Unverified commit fb7f9a16, authored by Vasilis Vryniotis, committed by GitHub

Add MViT architecture in TorchVision (#6198)

* Adding MViT v2 architecture (#6105)

* Adding mvitv2 architecture

* Fixing memory issues in tests and minor refactorings.

* Adding input validation

* Adding docs and minor refactoring

* Add `min_temporal_size` to the supported meta-data.

* Switch Tuple[int, int, int] to List[int] to more easily support the 2D case

* Adding more docs and references

* Change naming conventions of classes to follow the same pattern as MobileNetV3

* Fix test breakage.

* Update todos

* Performance optimizations.

* Add support for MViT v1 (#6179)

* Switch implementation to v1 variant.

* Fix docs

* Adding back a v2 pseudovariant

* Changing the way the network is configured.

* Temporarily removing v2

* Adding weights.

* Expand _squeeze/_unsqueeze to support arbitrary dims.

* Update references script.

* Fix tests.

* Fixing frames and preprocessing.

* Fix std/mean values in transforms.

* Add permanent Dropout and update the weights.

* Update accuracies.

* Fix documentation

* Remove unnecessary expected file.

* Skip big model test

* Rewrite the configuration logic to reduce LOC.

* Fix mypy
parent 074adabe
@@ -465,6 +465,7 @@ pre-trained weights:
 .. toctree::
    :maxdepth: 1

+   models/video_mvit
    models/video_resnet

 |
......
Video MViT
==========

.. currentmodule:: torchvision.models.video

The MViT model is based on the
`MViTv2: Improved Multiscale Vision Transformers for Classification and Detection
<https://arxiv.org/abs/2112.01526>`__ and `Multiscale Vision Transformers
<https://arxiv.org/abs/2104.11227>`__ papers.

Model builders
--------------

The following model builders can be used to instantiate an MViT model, with or
without pre-trained weights. All the model builders internally rely on the
``torchvision.models.video.MViT`` base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/video/mvit.py>`_ for
more details about this class.

.. autosummary::
    :toctree: generated/
    :template: function.rst

    mvit_v1_b
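A minimal usage sketch of the new builder (the weight enum below is the Kinetics-400 entry added by this commit; the 400-class output matches that dataset's label count):

    import torch
    from torchvision.models.video import mvit_v1_b, MViT_V1_B_Weights

    # Build the model; passing weights=None gives a randomly initialised network instead.
    weights = MViT_V1_B_Weights.KINETICS400_V1
    model = mvit_v1_b(weights=weights).eval()

    # MViT consumes clips shaped (batch, channels, frames, height, width).
    clip = torch.rand(1, 3, 16, 224, 224)
    with torch.inference_mode():
        scores = model(clip)
    print(scores.shape)  # torch.Size([1, 400])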
@@ -152,7 +152,7 @@ def main(args):
         split="train",
         step_between_clips=1,
         transform=transform_train,
-        frame_rate=15,
+        frame_rate=args.frame_rate,
         extensions=(
             "avi",
             "mp4",
@@ -189,7 +189,7 @@ def main(args):
         split="val",
         step_between_clips=1,
         transform=transform_test,
-        frame_rate=15,
+        frame_rate=args.frame_rate,
         extensions=(
             "avi",
             "mp4",
@@ -324,6 +324,7 @@ def parse_args():
     parser.add_argument("--model", default="r2plus1d_18", type=str, help="model name")
     parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
     parser.add_argument("--clip-len", default=16, type=int, metavar="N", help="number of frames per clip")
+    parser.add_argument("--frame-rate", default=15, type=int, metavar="N", help="the frame rate")
     parser.add_argument(
         "--clips-per-video", default=5, type=int, metavar="N", help="maximum number of clips per video to consider"
     )
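For reference, argparse maps the hyphenated flag to an underscored attribute, which is why the dataset hunks above can read `args.frame_rate`; a self-contained sketch:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--frame-rate", default=15, type=int, metavar="N", help="the frame rate")

    # Hyphens in the flag name become underscores on the parsed namespace.
    args = parser.parse_args(["--frame-rate", "30"])
    print(args.frame_rate)  # 30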
......
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
@@ -87,6 +87,7 @@ def test_schema_meta_validation(model_fn):
         "license",
         "_metrics",
         "min_size",
+        "min_temporal_size",
         "num_params",
         "recipe",
         "unquantized",
@@ -180,7 +181,7 @@ def test_transforms_jit(model_fn):
             "input_shape": (1, 3, 520, 520),
         },
         "video": {
-            "input_shape": (1, 4, 3, 112, 112),
+            "input_shape": (1, 3, 4, 112, 112),
         },
         "optical_flow": {
             "input_shape": (1, 3, 128, 128),
@@ -194,6 +195,8 @@ def test_transforms_jit(model_fn):
     if module_name == "optical_flow":
         args = (x, x)
     else:
+        if module_name == "video":
+            x = x.permute(0, 2, 1, 3, 4)
         args = (x,)

     problematic_weights = []
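The permutation matters because the preset video transforms accept clips as (batch, frames, channels, height, width), while the test tensor is created in the model layout of (batch, channels, frames, height, width); a quick check of the dimension swap:

    import torch

    # Model layout, matching the updated input_shape above: (N, C, T, H, W).
    x = torch.rand(1, 3, 4, 112, 112)

    # Transform layout: (N, T, C, H, W).
    x = x.permute(0, 2, 1, 3, 4)
    print(x.shape)  # torch.Size([1, 4, 3, 112, 112])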
......
@@ -309,6 +309,9 @@ _model_params = {
         "image_size": 56,
         "input_shape": (1, 3, 56, 56),
     },
+    "mvit_v1_b": {
+        "input_shape": (1, 3, 16, 224, 224),
+    },
 }

 # speeding up slow models:
 slow_models = [
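The new entry means the video smoke test feeds MViT a 16-frame 224x224 clip; combined with the shared `num_classes: 50` default shown below, the check reduces to roughly this (a sketch, not the literal test body):

    import torch
    from torchvision.models.video import mvit_v1_b

    # Build a reduced-head model and run one clip through it, as the test does.
    model = mvit_v1_b(num_classes=50).eval()
    out = model(torch.rand(1, 3, 16, 224, 224))
    assert out.shape == (1, 50)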
@@ -830,6 +833,8 @@ def test_video_model(model_fn, dev):
         "num_classes": 50,
     }
     model_name = model_fn.__name__
+    if SKIP_BIG_MODEL and model_name in skipped_big_models:
+        pytest.skip("Skipped to reduce memory usage. Set env var SKIP_BIG_MODEL=0 to enable test for this model")
     kwargs = {**defaults, **_model_params.get(model_name, {})}
     num_classes = kwargs.get("num_classes")
     input_shape = kwargs.pop("input_shape")
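A sketch of the gate this hunk references (the default value and the set contents are assumptions, inferred from the commit message's "Skip big model test" entry):

    import os

    # Big models are skipped unless the environment opts in with SKIP_BIG_MODEL=0.
    SKIP_BIG_MODEL = os.getenv("SKIP_BIG_MODEL", "1") == "1"
    skipped_big_models = {"mvit_v1_b"}  # assumed: the newly added MViT is the model gated here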
......
+from .mvit import *
 from .resnet import *
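The star import re-exports the new builders through the video namespace, so after this change the following minimal check passes:

    import torch.nn as nn
    import torchvision.models.video as video

    # Both the builder and the base class are now part of the package API.
    assert callable(video.mvit_v1_b)
    assert issubclass(video.MViT, nn.Module)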
This diff is collapsed.