Unverified Commit 09c5ddd6 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Adding missing named param check on ViT (#5196)

parent f9bee391
...@@ -11,7 +11,7 @@ from torchvision.transforms.functional import InterpolationMode ...@@ -11,7 +11,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.vision_transformer import VisionTransformer, interpolate_embeddings # noqa: F401 from ...models.vision_transformer import VisionTransformer, interpolate_embeddings # noqa: F401
from ._api import WeightsEnum, Weights from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = [ __all__ = [
"VisionTransformer", "VisionTransformer",
...@@ -111,6 +111,9 @@ def _vision_transformer( ...@@ -111,6 +111,9 @@ def _vision_transformer(
) -> VisionTransformer: ) -> VisionTransformer:
image_size = kwargs.pop("image_size", 224) image_size = kwargs.pop("image_size", 224)
if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = VisionTransformer( model = VisionTransformer(
image_size=image_size, image_size=image_size,
patch_size=patch_size, patch_size=patch_size,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment