Unverified commit 7e8186e0 authored by Vasilis Vryniotis, committed by GitHub

Add support for MViTv2 video variants (#6373)

* Extending to support MViTv2

* Fix docs, mypy and linter

* Refactor the relative positional code.

* Code refactoring.

* Rename vars.

* Update docs.

* Replace assert with exception.

* Update docs.

* Minor refactoring.

* Remove the square input limitation.

* Moving methods around.

* Modify the shortcut in the attention layer.

* Add ported weights.

* Introduce a `residual_cls` config on the attention layer.

* Make the patch_embed kernel/padding/stride configurable.

* Apply changes from code-review.

* Remove stale todo.
parent 6908129a
@@ -12,7 +12,7 @@ The MViT model is based on the
Model builders
--------------
-The following model builders can be used to instantiate a MViT model, with or
+The following model builders can be used to instantiate a MViT v1 or v2 model, with or
without pre-trained weights. All the model builders internally rely on the
``torchvision.models.video.MViT`` base class. Please refer to the `source
code
@@ -24,3 +24,4 @@ more details about this class.
:template: function.rst
mvit_v1_b
mvit_v2_s
@@ -309,6 +309,9 @@ _model_params = {
"mvit_v1_b": {
"input_shape": (1, 3, 16, 224, 224),
},
"mvit_v2_s": {
"input_shape": (1, 3, 16, 224, 224),
},
}
# speeding up slow models:
slow_models = [
......
@@ -19,12 +19,11 @@ __all__ = [
"MViT",
"MViT_V1_B_Weights",
"mvit_v1_b",
"MViT_V2_S_Weights",
"mvit_v2_s",
]
# TODO: Consider handling 2d input if the temporal dim is 1
@dataclass
class MSBlockConfig:
num_heads: int
@@ -106,28 +105,121 @@ class Pool(nn.Module):
return x, (T, H, W)
def _interpolate(embedding: torch.Tensor, d: int) -> torch.Tensor:
if embedding.shape[0] == d:
return embedding
return (
nn.functional.interpolate(
embedding.permute(1, 0).unsqueeze(0),
size=d,
mode="linear",
)
.squeeze(0)
.permute(1, 0)
)
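A minimal sketch of what `_interpolate` does, assuming the helper as defined above: it linearly resizes a relative-position table along its first dimension so a table stored for one resolution can serve another (the shapes below are hypothetical):

```python
import torch

# Hypothetical toy table: 15 relative offsets, head_dim of 96.
table = torch.randn(15, 96)
resized = _interpolate(table, 27)  # linear interpolation along dim 0
assert resized.shape == (27, 96)
assert _interpolate(table, 15) is table  # no-op when the size already matches
```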
def _add_rel_pos(
attn: torch.Tensor,
q: torch.Tensor,
q_thw: Tuple[int, int, int],
k_thw: Tuple[int, int, int],
rel_pos_h: torch.Tensor,
rel_pos_w: torch.Tensor,
rel_pos_t: torch.Tensor,
) -> torch.Tensor:
# Modified code from: https://github.com/facebookresearch/SlowFast/commit/1aebd71a2efad823d52b827a3deaf15a56cf4932
q_t, q_h, q_w = q_thw
k_t, k_h, k_w = k_thw
dh = int(2 * max(q_h, k_h) - 1)
dw = int(2 * max(q_w, k_w) - 1)
dt = int(2 * max(q_t, k_t) - 1)
# Scale up rel pos if shapes for q and k are different.
q_h_ratio = max(k_h / q_h, 1.0)
k_h_ratio = max(q_h / k_h, 1.0)
dist_h = torch.arange(q_h)[:, None] * q_h_ratio - (torch.arange(k_h)[None, :] + (1.0 - k_h)) * k_h_ratio
q_w_ratio = max(k_w / q_w, 1.0)
k_w_ratio = max(q_w / k_w, 1.0)
dist_w = torch.arange(q_w)[:, None] * q_w_ratio - (torch.arange(k_w)[None, :] + (1.0 - k_w)) * k_w_ratio
q_t_ratio = max(k_t / q_t, 1.0)
k_t_ratio = max(q_t / k_t, 1.0)
dist_t = torch.arange(q_t)[:, None] * q_t_ratio - (torch.arange(k_t)[None, :] + (1.0 - k_t)) * k_t_ratio
# Interpolate rel pos if needed.
rel_pos_h = _interpolate(rel_pos_h, dh)
rel_pos_w = _interpolate(rel_pos_w, dw)
rel_pos_t = _interpolate(rel_pos_t, dt)
Rh = rel_pos_h[dist_h.long()]
Rw = rel_pos_w[dist_w.long()]
Rt = rel_pos_t[dist_t.long()]
B, n_head, _, dim = q.shape
r_q = q[:, :, 1:].reshape(B, n_head, q_t, q_h, q_w, dim)
rel_h_q = torch.einsum("bythwc,hkc->bythwk", r_q, Rh) # [B, H, q_t, qh, qw, k_h]
rel_w_q = torch.einsum("bythwc,wkc->bythwk", r_q, Rw) # [B, H, q_t, qh, qw, k_w]
# [B, H, q_t, q_h, q_w, dim] -> [q_t, B, H, q_h, q_w, dim] -> [q_t, B*H*q_h*q_w, dim]
r_q = r_q.permute(2, 0, 1, 3, 4, 5).reshape(q_t, B * n_head * q_h * q_w, dim)
# [q_t, B*H*q_h*q_w, dim] * [q_t, dim, k_t] = [q_t, B*H*q_h*q_w, k_t] -> [B*H*q_h*q_w, q_t, k_t]
rel_q_t = torch.matmul(r_q, Rt.transpose(1, 2)).transpose(0, 1)
# [B*H*q_h*q_w, q_t, k_t] -> [B, H, q_t, q_h, q_w, k_t]
rel_q_t = rel_q_t.view(B, n_head, q_h, q_w, q_t, k_t).permute(0, 1, 4, 2, 3, 5)
# Combine rel pos.
rel_pos = (
rel_h_q[:, :, :, :, :, None, :, None]
+ rel_w_q[:, :, :, :, :, None, None, :]
+ rel_q_t[:, :, :, :, :, :, None, None]
).reshape(B, n_head, q_t * q_h * q_w, k_t * k_h * k_w)
# Add it to attention
attn[:, :, 1:, 1:] += rel_pos
return attn
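To make the tensor contract concrete, here is a hedged toy invocation of `_add_rel_pos` (all shapes are hypothetical; index 0 of the attention matrix is the class token, which the decomposed relative terms skip):

```python
import torch

B, heads, head_dim = 2, 4, 8
q_thw = k_thw = (2, 4, 4)                  # (T, H, W) token grid
L = 1 + 2 * 4 * 4                          # class token + T*H*W patches
q = torch.randn(B, heads, L, head_dim)
attn = torch.randn(B, heads, L, L)
rel_h = torch.randn(2 * 4 - 1, head_dim)   # one row per relative height offset
rel_w = torch.randn(2 * 4 - 1, head_dim)
rel_t = torch.randn(2 * 2 - 1, head_dim)
attn = _add_rel_pos(attn, q, q_thw, k_thw, rel_h, rel_w, rel_t)
# Only attn[:, :, 1:, 1:] was modified; the class-token row/column is untouched.
```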
def _add_shortcut(x: torch.Tensor, shortcut: torch.Tensor, residual_with_cls_embed: bool):
if residual_with_cls_embed:
x.add_(shortcut)
else:
x[:, :, 1:, :] += shortcut[:, :, 1:, :]
return x
torch.fx.wrap("_add_rel_pos")
torch.fx.wrap("_add_shortcut")
class MultiscaleAttention(nn.Module):
def __init__(
self,
input_size: List[int],
embed_dim: int,
output_dim: int,
num_heads: int,
kernel_q: List[int],
kernel_kv: List[int],
stride_q: List[int],
stride_kv: List[int],
residual_pool: bool,
residual_with_cls_embed: bool,
rel_pos_embed: bool,
dropout: float = 0.0,
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
) -> None:
super().__init__()
self.embed_dim = embed_dim
self.output_dim = output_dim
self.num_heads = num_heads
-self.head_dim = embed_dim // num_heads
+self.head_dim = output_dim // num_heads
self.scaler = 1.0 / math.sqrt(self.head_dim)
self.residual_pool = residual_pool
self.residual_with_cls_embed = residual_with_cls_embed
-self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
-layers: List[nn.Module] = [nn.Linear(embed_dim, embed_dim)]
+self.qkv = nn.Linear(embed_dim, 3 * output_dim)
+layers: List[nn.Module] = [nn.Linear(output_dim, output_dim)]
if dropout > 0.0:
layers.append(nn.Dropout(dropout, inplace=True))
self.project = nn.Sequential(*layers)
@@ -177,24 +269,52 @@ class MultiscaleAttention(nn.Module):
norm_layer(self.head_dim),
)
self.rel_pos_h: Optional[nn.Parameter] = None
self.rel_pos_w: Optional[nn.Parameter] = None
self.rel_pos_t: Optional[nn.Parameter] = None
if rel_pos_embed:
size = max(input_size[1:])
q_size = size // stride_q[1] if len(stride_q) > 0 else size
kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size
spatial_dim = 2 * max(q_size, kv_size) - 1
temporal_dim = 2 * input_size[0] - 1
self.rel_pos_h = nn.Parameter(torch.zeros(spatial_dim, self.head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(spatial_dim, self.head_dim))
self.rel_pos_t = nn.Parameter(torch.zeros(temporal_dim, self.head_dim))
nn.init.trunc_normal_(self.rel_pos_h, std=0.02)
nn.init.trunc_normal_(self.rel_pos_w, std=0.02)
nn.init.trunc_normal_(self.rel_pos_t, std=0.02)
def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]:
B, N, C = x.shape
q, k, v = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).transpose(1, 3).unbind(dim=2)
if self.pool_k is not None:
-k = self.pool_k(k, thw)[0]
+k, k_thw = self.pool_k(k, thw)
+else:
+k_thw = thw
if self.pool_v is not None:
v = self.pool_v(v, thw)[0]
if self.pool_q is not None:
q, thw = self.pool_q(q, thw)
attn = torch.matmul(self.scaler * q, k.transpose(2, 3))
if self.rel_pos_h is not None and self.rel_pos_w is not None and self.rel_pos_t is not None:
attn = _add_rel_pos(
attn,
q,
thw,
k_thw,
self.rel_pos_h,
self.rel_pos_w,
self.rel_pos_t,
)
attn = attn.softmax(dim=-1)
x = torch.matmul(attn, v)
if self.residual_pool:
-x.add_(q)
-x = x.transpose(1, 2).reshape(B, -1, C)
+_add_shortcut(x, q, self.residual_with_cls_embed)
+x = x.transpose(1, 2).reshape(B, -1, self.output_dim)
x = self.project(x)
return x, thw
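As a hedged illustration of the layer's contract (the settings match the first `mvit_v2_s` block config further down; the pooling submodules are built in code elided from this diff, so this assumes the full class as it exists in torchvision):

```python
import torch

attn = MultiscaleAttention(
    input_size=[8, 56, 56],       # (T, H, W) after patchification
    embed_dim=96,
    output_dim=96,
    num_heads=1,
    kernel_q=[3, 3, 3],
    kernel_kv=[3, 3, 3],
    stride_q=[1, 1, 1],
    stride_kv=[1, 8, 8],
    residual_pool=True,
    residual_with_cls_embed=False,
    rel_pos_embed=True,
)
x = torch.randn(2, 1 + 8 * 56 * 56, 96)   # (B, 1 + T*H*W, C)
y, thw = attn(x, (8, 56, 56))             # stride_q of all 1s leaves thw unchanged
```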
@@ -203,13 +323,18 @@ class MultiscaleAttention(nn.Module):
class MultiscaleBlock(nn.Module):
def __init__(
self,
input_size: List[int],
cnf: MSBlockConfig,
residual_pool: bool,
residual_with_cls_embed: bool,
rel_pos_embed: bool,
proj_after_attn: bool,
dropout: float = 0.0,
stochastic_depth_prob: float = 0.0,
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
) -> None:
super().__init__()
self.proj_after_attn = proj_after_attn
self.pool_skip: Optional[nn.Module] = None
if _prod(cnf.stride_q) > 1:
@@ -219,24 +344,30 @@ class MultiscaleBlock(nn.Module):
nn.MaxPool3d(kernel_skip, stride=cnf.stride_q, padding=padding_skip), None # type: ignore[arg-type]
)
attn_dim = cnf.output_channels if proj_after_attn else cnf.input_channels
self.norm1 = norm_layer(cnf.input_channels)
-self.norm2 = norm_layer(cnf.input_channels)
+self.norm2 = norm_layer(attn_dim)
self.needs_transposal = isinstance(self.norm1, nn.BatchNorm1d)
self.attn = MultiscaleAttention(
input_size,
cnf.input_channels,
attn_dim,
cnf.num_heads,
kernel_q=cnf.kernel_q,
kernel_kv=cnf.kernel_kv,
stride_q=cnf.stride_q,
stride_kv=cnf.stride_kv,
rel_pos_embed=rel_pos_embed,
residual_pool=residual_pool,
residual_with_cls_embed=residual_with_cls_embed,
dropout=dropout,
norm_layer=norm_layer,
)
self.mlp = MLP(
-cnf.input_channels,
-[4 * cnf.input_channels, cnf.output_channels],
+attn_dim,
+[4 * attn_dim, cnf.output_channels],
activation_layer=nn.GELU,
dropout=dropout,
inplace=None,
@@ -249,36 +380,45 @@ class MultiscaleBlock(nn.Module):
self.project = nn.Linear(cnf.input_channels, cnf.output_channels)
def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]:
+x_norm1 = self.norm1(x.transpose(1, 2)).transpose(1, 2) if self.needs_transposal else self.norm1(x)
+x_attn, thw_new = self.attn(x_norm1, thw)
+x = x if self.project is None or not self.proj_after_attn else self.project(x_norm1)
x_skip = x if self.pool_skip is None else self.pool_skip(x, thw)[0]
+x = x_skip + self.stochastic_depth(x_attn)
-x = self.norm1(x.transpose(1, 2)).transpose(1, 2) if self.needs_transposal else self.norm1(x)
-x, thw = self.attn(x, thw)
-x = x_skip + self.stochastic_depth(x)
+x_norm2 = self.norm2(x.transpose(1, 2)).transpose(1, 2) if self.needs_transposal else self.norm2(x)
+x_proj = x if self.project is None or self.proj_after_attn else self.project(x_norm2)
-x_norm = self.norm2(x.transpose(1, 2)).transpose(1, 2) if self.needs_transposal else self.norm2(x)
-x_proj = x if self.project is None else self.project(x_norm)
-return x_proj + self.stochastic_depth(self.mlp(x_norm)), thw
+return x_proj + self.stochastic_depth(self.mlp(x_norm2)), thw_new
class PositionalEncoding(nn.Module):
-def __init__(self, embed_size: int, spatial_size: Tuple[int, int], temporal_size: int) -> None:
+def __init__(self, embed_size: int, spatial_size: Tuple[int, int], temporal_size: int, rel_pos_embed: bool) -> None:
super().__init__()
self.spatial_size = spatial_size
self.temporal_size = temporal_size
self.class_token = nn.Parameter(torch.zeros(embed_size))
-self.spatial_pos = nn.Parameter(torch.zeros(self.spatial_size[0] * self.spatial_size[1], embed_size))
-self.temporal_pos = nn.Parameter(torch.zeros(self.temporal_size, embed_size))
-self.class_pos = nn.Parameter(torch.zeros(embed_size))
+self.spatial_pos: Optional[nn.Parameter] = None
+self.temporal_pos: Optional[nn.Parameter] = None
+self.class_pos: Optional[nn.Parameter] = None
+if not rel_pos_embed:
+self.spatial_pos = nn.Parameter(torch.zeros(self.spatial_size[0] * self.spatial_size[1], embed_size))
+self.temporal_pos = nn.Parameter(torch.zeros(self.temporal_size, embed_size))
+self.class_pos = nn.Parameter(torch.zeros(embed_size))
def forward(self, x: torch.Tensor) -> torch.Tensor:
-hw_size, embed_size = self.spatial_pos.shape
-pos_embedding = torch.repeat_interleave(self.temporal_pos, hw_size, dim=0)
-pos_embedding.add_(self.spatial_pos.unsqueeze(0).expand(self.temporal_size, -1, -1).reshape(-1, embed_size))
-pos_embedding = torch.cat((self.class_pos.unsqueeze(0), pos_embedding), dim=0).unsqueeze(0)
class_token = self.class_token.expand(x.size(0), -1).unsqueeze(1)
-return torch.cat((class_token, x), dim=1).add_(pos_embedding)
+x = torch.cat((class_token, x), dim=1)
+if self.spatial_pos is not None and self.temporal_pos is not None and self.class_pos is not None:
+hw_size, embed_size = self.spatial_pos.shape
+pos_embedding = torch.repeat_interleave(self.temporal_pos, hw_size, dim=0)
+pos_embedding.add_(self.spatial_pos.unsqueeze(0).expand(self.temporal_size, -1, -1).reshape(-1, embed_size))
+pos_embedding = torch.cat((self.class_pos.unsqueeze(0), pos_embedding), dim=0).unsqueeze(0)
+x.add_(pos_embedding)
+return x
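A hedged sketch of the module's two modes after this change (with relative embeddings enabled, only the class token is prepended and the absolute tables stay `None`):

```python
import torch

pe_abs = PositionalEncoding(embed_size=96, spatial_size=(56, 56), temporal_size=8, rel_pos_embed=False)
pe_rel = PositionalEncoding(embed_size=96, spatial_size=(56, 56), temporal_size=8, rel_pos_embed=True)

x = torch.randn(2, 8 * 56 * 56, 96)
assert pe_abs(x).shape == (2, 1 + 8 * 56 * 56, 96)  # class token + absolute pos added
assert pe_rel(x).shape == (2, 1 + 8 * 56 * 56, 96)  # class token only
assert pe_rel.spatial_pos is None                   # no absolute tables in rel-pos mode
```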
class MViT(nn.Module):
@@ -288,12 +428,18 @@ class MViT(nn.Module):
temporal_size: int,
block_setting: Sequence[MSBlockConfig],
residual_pool: bool,
residual_with_cls_embed: bool,
rel_pos_embed: bool,
proj_after_attn: bool,
dropout: float = 0.5,
attention_dropout: float = 0.0,
stochastic_depth_prob: float = 0.0,
num_classes: int = 400,
block: Optional[Callable[..., nn.Module]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
patch_embed_kernel: Tuple[int, int, int] = (3, 7, 7),
patch_embed_stride: Tuple[int, int, int] = (2, 4, 4),
patch_embed_padding: Tuple[int, int, int] = (1, 3, 3),
) -> None:
"""
MViT main class.
@@ -303,12 +449,19 @@ class MViT(nn.Module):
temporal_size (int): The temporal size ``T`` of the input.
block_setting (sequence of MSBlockConfig): The Network structure.
residual_pool (bool): If True, use MViTv2 pooling residual connection.
residual_with_cls_embed (bool): If True, the addition on the residual connection will include
the class embedding.
rel_pos_embed (bool): If True, use MViTv2's relative positional embeddings.
proj_after_attn (bool): If True, apply the projection after the attention.
dropout (float): Dropout rate. Default: 0.5.
attention_dropout (float): Attention dropout rate. Default: 0.0.
stochastic_depth_prob (float): Stochastic depth rate. Default: 0.0.
num_classes (int): The number of classes.
block (callable, optional): Module specifying the layer which consists of the attention and mlp.
norm_layer (callable, optional): Module specifying the normalization layer to use.
patch_embed_kernel (tuple of ints): The kernel of the convolution that patchifies the input.
patch_embed_stride (tuple of ints): The stride of the convolution that patchifies the input.
patch_embed_padding (tuple of ints): The padding of the convolution that patchifies the input.
"""
super().__init__()
# This implementation employs a different parameterization scheme than the one used at PyTorch Video:
@@ -330,16 +483,19 @@ class MViT(nn.Module):
self.conv_proj = nn.Conv3d(
in_channels=3,
out_channels=block_setting[0].input_channels,
-kernel_size=(3, 7, 7),
-stride=(2, 4, 4),
-padding=(1, 3, 3),
+kernel_size=patch_embed_kernel,
+stride=patch_embed_stride,
+padding=patch_embed_padding,
)
input_size = [size // stride for size, stride in zip((temporal_size,) + spatial_size, self.conv_proj.stride)]
# Spatio-Temporal Class Positional Encoding
self.pos_encoding = PositionalEncoding(
embed_size=block_setting[0].input_channels,
-spatial_size=(spatial_size[0] // self.conv_proj.stride[1], spatial_size[1] // self.conv_proj.stride[2]),
-temporal_size=temporal_size // self.conv_proj.stride[0],
+spatial_size=(input_size[1], input_size[2]),
+temporal_size=input_size[0],
+rel_pos_embed=rel_pos_embed,
)
# Encoder module
@@ -350,13 +506,20 @@ class MViT(nn.Module):
self.blocks.append(
block(
input_size=input_size,
cnf=cnf,
residual_pool=residual_pool,
residual_with_cls_embed=residual_with_cls_embed,
rel_pos_embed=rel_pos_embed,
proj_after_attn=proj_after_attn,
dropout=attention_dropout,
stochastic_depth_prob=sd_prob,
norm_layer=norm_layer,
)
)
if len(cnf.stride_q) > 0:
input_size = [size // stride for size, stride in zip(input_size, cnf.stride_q)]
self.norm = norm_layer(block_setting[-1].output_channels)
# Classifier module
@@ -380,6 +543,8 @@ class MViT(nn.Module):
nn.init.trunc_normal_(weights, std=0.02)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Convert if necessary (B, C, H, W) -> (B, C, 1, H, W)
x = _unsqueeze(x, 5, 2)[0]
# patchify and reshape: (B, C, T, H, W) -> (B, embed_channels[0], T', H', W') -> (B, THW', embed_channels[0])
x = self.conv_proj(x)
x = x.flatten(2).transpose(1, 2)
@@ -420,6 +585,9 @@ def _mvit(
temporal_size=temporal_size,
block_setting=block_setting,
residual_pool=kwargs.pop("residual_pool", False),
residual_with_cls_embed=kwargs.pop("residual_with_cls_embed", True),
rel_pos_embed=kwargs.pop("rel_pos_embed", False),
proj_after_attn=kwargs.pop("proj_after_attn", False),
stochastic_depth_prob=stochastic_depth_prob,
**kwargs,
)
@@ -461,6 +629,37 @@ class MViT_V1_B_Weights(WeightsEnum):
DEFAULT = KINETICS400_V1
class MViT_V2_S_Weights(WeightsEnum):
KINETICS400_V1 = Weights(
url="https://download.pytorch.org/models/mvit_v2_s-ae3be167.pth",
transforms=partial(
VideoClassification,
crop_size=(224, 224),
resize_size=(256,),
mean=(0.45, 0.45, 0.45),
std=(0.225, 0.225, 0.225),
),
meta={
"min_size": (224, 224),
"min_temporal_size": 16,
"categories": _KINETICS400_CATEGORIES,
"recipe": "https://github.com/facebookresearch/SlowFast/blob/main/MODEL_ZOO.md",
"_docs": (
"The weights were ported from the paper. The accuracies are estimated on video-level "
"with parameters `frame_rate=7.5`, `clips_per_video=5`, and `clip_len=16`"
),
"num_params": 34537744,
"_metrics": {
"Kinetics-400": {
"acc@1": 80.757,
"acc@5": 94.665,
}
},
},
)
DEFAULT = KINETICS400_V1
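A hedged usage sketch based on the builder and the test input shape added in this diff (the `weights.meta["categories"]` lookup follows torchvision's multi-weight API):

```python
import torch
from torchvision.models.video import mvit_v2_s, MViT_V2_S_Weights

weights = MViT_V2_S_Weights.KINETICS400_V1
model = mvit_v2_s(weights=weights).eval()

clip = torch.rand(1, 3, 16, 224, 224)  # (B, C, T, H, W), as in the test config
with torch.inference_mode():
    scores = model(clip).softmax(dim=-1)
top = scores.argmax().item()
print(weights.meta["categories"][top])  # predicted Kinetics-400 label
```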
@register_model()
def mvit_v1_b(*, weights: Optional[MViT_V1_B_Weights] = None, progress: bool = True, **kwargs: Any) -> MViT:
"""
@@ -548,6 +747,138 @@ def mvit_v1_b(*, weights: Optional[MViT_V1_B_Weights] = None, progress: bool = T
temporal_size=16,
block_setting=block_setting,
residual_pool=False,
residual_with_cls_embed=False,
stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.2),
weights=weights,
progress=progress,
**kwargs,
)
@register_model()
def mvit_v2_s(*, weights: Optional[MViT_V2_S_Weights] = None, progress: bool = True, **kwargs: Any) -> MViT:
"""
Constructs a small MViTV2 architecture from
`MViTv2: Improved Multiscale Vision Transformers for Classification and Detection <https://arxiv.org/abs/2112.01526>`__.
Args:
weights (:class:`~torchvision.models.video.MViT_V2_S_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.video.MViT_V2_S_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.video.MViT``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/video/mvit.py>`_
for more details about this class.
.. autoclass:: torchvision.models.video.MViT_V2_S_Weights
:members:
"""
weights = MViT_V2_S_Weights.verify(weights)
config: Dict[str, List] = {
"num_heads": [1, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 8, 8],
"input_channels": [96, 96, 192, 192, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 768],
"output_channels": [96, 192, 192, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 768, 768],
"kernel_q": [
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
],
"kernel_kv": [
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
],
"stride_q": [
[1, 1, 1],
[1, 2, 2],
[1, 1, 1],
[1, 2, 2],
[1, 1, 1],
[1, 1, 1],
[1, 1, 1],
[1, 1, 1],
[1, 1, 1],
[1, 1, 1],
[1, 1, 1],
[1, 1, 1],
[1, 1, 1],
[1, 1, 1],
[1, 2, 2],
[1, 1, 1],
],
"stride_kv": [
[1, 8, 8],
[1, 4, 4],
[1, 4, 4],
[1, 2, 2],
[1, 2, 2],
[1, 2, 2],
[1, 2, 2],
[1, 2, 2],
[1, 2, 2],
[1, 2, 2],
[1, 2, 2],
[1, 2, 2],
[1, 2, 2],
[1, 2, 2],
[1, 1, 1],
[1, 1, 1],
],
}
block_setting = []
for i in range(len(config["num_heads"])):
block_setting.append(
MSBlockConfig(
num_heads=config["num_heads"][i],
input_channels=config["input_channels"][i],
output_channels=config["output_channels"][i],
kernel_q=config["kernel_q"][i],
kernel_kv=config["kernel_kv"][i],
stride_q=config["stride_q"][i],
stride_kv=config["stride_kv"][i],
)
)
return _mvit(
spatial_size=(224, 224),
temporal_size=16,
block_setting=block_setting,
residual_pool=True,
residual_with_cls_embed=False,
rel_pos_embed=True,
proj_after_attn=True,
stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.2),
weights=weights,
progress=progress,
......
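As a worked example of the resolution bookkeeping above (the `input_size = [size // stride ...]` updates combined with this config's `stride_q` entries), assuming the default patch-embedding stride of (2, 4, 4):

```python
from math import prod

input_size = [16 // 2, 224 // 4, 224 // 4]   # patchify a 16x224x224 clip: [8, 56, 56]
spatial_pool_blocks = {1, 3, 14}             # blocks whose stride_q is [1, 2, 2]
for i in range(16):
    stride = [1, 2, 2] if i in spatial_pool_blocks else [1, 1, 1]
    input_size = [size // s for size, s in zip(input_size, stride)]
print(input_size)            # [8, 7, 7] at the final stage
print(1 + prod(input_size))  # 393 tokens per clip, including the class token
```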