Unverified Commit 09c5ddd6 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Adding missing named param check on ViT (#5196)

parent f9bee391
......@@ -11,7 +11,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.vision_transformer import VisionTransformer, interpolate_embeddings # noqa: F401
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface
from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = [
"VisionTransformer",
......@@ -111,6 +111,9 @@ def _vision_transformer(
) -> VisionTransformer:
image_size = kwargs.pop("image_size", 224)
if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = VisionTransformer(
image_size=image_size,
patch_size=patch_size,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment