Unverified commit 7e8186e0 authored by Vasilis Vryniotis, committed by GitHub

Add support of MViTv2 video variants (#6373)

* Extending to support MViTv2

* Fix docs, mypy and linter

* Refactor the relative positional code.

* Code refactoring.

* Rename vars.

* Update docs.

* Replace assert with exception.

* Update docs.

* Minor refactoring.

* Remove the square input limitation.

* Moving methods around.

* Modify the shortcut in the attention layer.

* Add ported weights.

* Introduce a `residual_cls` config on the attention layer.

* Make the patch_embed kernel/padding/stride configurable.

* Apply changes from code-review.

* Remove stale todo.
parent 6908129a
@@ -12,7 +12,7 @@ The MViT model is based on the
 Model builders
 --------------
-The following model builders can be used to instantiate a MViT model, with or
+The following model builders can be used to instantiate a MViT v1 or v2 model, with or
 without pre-trained weights. All the model builders internally rely on the
 ``torchvision.models.video.MViT`` base class. Please refer to the `source
 code
@@ -24,3 +24,4 @@ more details about this class.
     :template: function.rst

     mvit_v1_b
+    mvit_v2_s
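A minimal usage sketch of the new builder (illustrative only, not part of this diff):

import torch
from torchvision.models.video import mvit_v2_s, MViT_V2_S_Weights

# Load the ported Kinetics-400 weights; weights.transforms() provides the matching
# video preprocessing (resize, crop, normalization).
weights = MViT_V2_S_Weights.KINETICS400_V1
model = mvit_v2_s(weights=weights).eval()

# Dummy clip: (batch, channels, frames, height, width).
clip = torch.rand(1, 3, 16, 224, 224)
with torch.no_grad():
    logits = model(clip)  # (1, 400) class scores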
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
@@ -309,6 +309,9 @@ _model_params = {
     "mvit_v1_b": {
         "input_shape": (1, 3, 16, 224, 224),
     },
+    "mvit_v2_s": {
+        "input_shape": (1, 3, 16, 224, 224),
+    },
 }
 # speeding up slow models:
 slow_models = [
......
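A rough sketch of the smoke test this configuration drives (illustrative only; the real harness in test_models.py handles eval mode, seeding and device placement):

import torch
from torchvision.models.video import mvit_v2_s

model = mvit_v2_s().eval()
x = torch.rand(1, 3, 16, 224, 224)  # the registered input_shape
with torch.no_grad():
    out = model(x)
assert out.shape == (1, 400)  # default Kinetics-400 head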
@@ -19,12 +19,11 @@ __all__ = [
     "MViT",
     "MViT_V1_B_Weights",
     "mvit_v1_b",
+    "MViT_V2_S_Weights",
+    "mvit_v2_s",
 ]
-# TODO: Consider handle 2d input if Temporal is 1
 @dataclass
 class MSBlockConfig:
     num_heads: int
@@ -106,28 +105,121 @@ class Pool(nn.Module):
         return x, (T, H, W)
def _interpolate(embedding: torch.Tensor, d: int) -> torch.Tensor:
if embedding.shape[0] == d:
return embedding
return (
nn.functional.interpolate(
embedding.permute(1, 0).unsqueeze(0),
size=d,
mode="linear",
)
.squeeze(0)
.permute(1, 0)
)
def _add_rel_pos(
attn: torch.Tensor,
q: torch.Tensor,
q_thw: Tuple[int, int, int],
k_thw: Tuple[int, int, int],
rel_pos_h: torch.Tensor,
rel_pos_w: torch.Tensor,
rel_pos_t: torch.Tensor,
) -> torch.Tensor:
# Modified code from: https://github.com/facebookresearch/SlowFast/commit/1aebd71a2efad823d52b827a3deaf15a56cf4932
q_t, q_h, q_w = q_thw
k_t, k_h, k_w = k_thw
dh = int(2 * max(q_h, k_h) - 1)
dw = int(2 * max(q_w, k_w) - 1)
dt = int(2 * max(q_t, k_t) - 1)
# Scale up rel pos if shapes for q and k are different.
q_h_ratio = max(k_h / q_h, 1.0)
k_h_ratio = max(q_h / k_h, 1.0)
dist_h = torch.arange(q_h)[:, None] * q_h_ratio - (torch.arange(k_h)[None, :] + (1.0 - k_h)) * k_h_ratio
q_w_ratio = max(k_w / q_w, 1.0)
k_w_ratio = max(q_w / k_w, 1.0)
dist_w = torch.arange(q_w)[:, None] * q_w_ratio - (torch.arange(k_w)[None, :] + (1.0 - k_w)) * k_w_ratio
q_t_ratio = max(k_t / q_t, 1.0)
k_t_ratio = max(q_t / k_t, 1.0)
dist_t = torch.arange(q_t)[:, None] * q_t_ratio - (torch.arange(k_t)[None, :] + (1.0 - k_t)) * k_t_ratio
# Intepolate rel pos if needed.
rel_pos_h = _interpolate(rel_pos_h, dh)
rel_pos_w = _interpolate(rel_pos_w, dw)
rel_pos_t = _interpolate(rel_pos_t, dt)
Rh = rel_pos_h[dist_h.long()]
Rw = rel_pos_w[dist_w.long()]
Rt = rel_pos_t[dist_t.long()]
B, n_head, _, dim = q.shape
r_q = q[:, :, 1:].reshape(B, n_head, q_t, q_h, q_w, dim)
rel_h_q = torch.einsum("bythwc,hkc->bythwk", r_q, Rh) # [B, H, q_t, qh, qw, k_h]
rel_w_q = torch.einsum("bythwc,wkc->bythwk", r_q, Rw) # [B, H, q_t, qh, qw, k_w]
# [B, H, q_t, q_h, q_w, dim] -> [q_t, B, H, q_h, q_w, dim] -> [q_t, B*H*q_h*q_w, dim]
r_q = r_q.permute(2, 0, 1, 3, 4, 5).reshape(q_t, B * n_head * q_h * q_w, dim)
# [q_t, B*H*q_h*q_w, dim] * [q_t, dim, k_t] = [q_t, B*H*q_h*q_w, k_t] -> [B*H*q_h*q_w, q_t, k_t]
rel_q_t = torch.matmul(r_q, Rt.transpose(1, 2)).transpose(0, 1)
# [B*H*q_h*q_w, q_t, k_t] -> [B, H, q_t, q_h, q_w, k_t]
rel_q_t = rel_q_t.view(B, n_head, q_h, q_w, q_t, k_t).permute(0, 1, 4, 2, 3, 5)
# Combine rel pos.
rel_pos = (
rel_h_q[:, :, :, :, :, None, :, None]
+ rel_w_q[:, :, :, :, :, None, None, :]
+ rel_q_t[:, :, :, :, :, :, None, None]
).reshape(B, n_head, q_t * q_h * q_w, k_t * k_h * k_w)
# Add it to attention
attn[:, :, 1:, 1:] += rel_pos
return attn
def _add_shortcut(x: torch.Tensor, shortcut: torch.Tensor, residual_with_cls_embed: bool):
if residual_with_cls_embed:
x.add_(shortcut)
else:
x[:, :, 1:, :] += shortcut[:, :, 1:, :]
return x
torch.fx.wrap("_add_rel_pos")
torch.fx.wrap("_add_shortcut")
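# --- Illustration only, not part of this patch: behaviour of the helpers above ---
# _interpolate linearly resizes a relative positional table along its first dimension,
# e.g. from 15 learned positions to the 2 * max(q, k) - 1 positions a block needs:
#
#     table = torch.zeros(15, 64)     # (positions, head_dim)
#     _interpolate(table, 31).shape   # torch.Size([31, 64])
#
# In _add_rel_pos, when q_h == k_h == 4 the lookup index dist_h = i - j + (k_h - 1)
# runs over 0..6, i.e. exactly 2 * max(q_h, k_h) - 1 = 7 rows of rel_pos_h are addressed.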
class MultiscaleAttention(nn.Module):
    def __init__(
        self,
+       input_size: List[int],
        embed_dim: int,
+       output_dim: int,
        num_heads: int,
        kernel_q: List[int],
        kernel_kv: List[int],
        stride_q: List[int],
        stride_kv: List[int],
        residual_pool: bool,
+       residual_with_cls_embed: bool,
+       rel_pos_embed: bool,
        dropout: float = 0.0,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
+       self.output_dim = output_dim
        self.num_heads = num_heads
-       self.head_dim = embed_dim // num_heads
+       self.head_dim = output_dim // num_heads
        self.scaler = 1.0 / math.sqrt(self.head_dim)
        self.residual_pool = residual_pool
+       self.residual_with_cls_embed = residual_with_cls_embed

-       self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
+       self.qkv = nn.Linear(embed_dim, 3 * output_dim)
-       layers: List[nn.Module] = [nn.Linear(embed_dim, embed_dim)]
+       layers: List[nn.Module] = [nn.Linear(output_dim, output_dim)]
        if dropout > 0.0:
            layers.append(nn.Dropout(dropout, inplace=True))
        self.project = nn.Sequential(*layers)
@@ -177,24 +269,52 @@ class MultiscaleAttention(nn.Module):
                norm_layer(self.head_dim),
            )
self.rel_pos_h: Optional[nn.Parameter] = None
self.rel_pos_w: Optional[nn.Parameter] = None
self.rel_pos_t: Optional[nn.Parameter] = None
if rel_pos_embed:
size = max(input_size[1:])
q_size = size // stride_q[1] if len(stride_q) > 0 else size
kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size
spatial_dim = 2 * max(q_size, kv_size) - 1
temporal_dim = 2 * input_size[0] - 1
self.rel_pos_h = nn.Parameter(torch.zeros(spatial_dim, self.head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(spatial_dim, self.head_dim))
self.rel_pos_t = nn.Parameter(torch.zeros(temporal_dim, self.head_dim))
nn.init.trunc_normal_(self.rel_pos_h, std=0.02)
nn.init.trunc_normal_(self.rel_pos_w, std=0.02)
nn.init.trunc_normal_(self.rel_pos_t, std=0.02)
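        # Illustration only, not part of this patch: for mvit_v2_s on a 16 x 224 x 224 clip the
        # first block receives input_size = [16 // 2, 224 // 4, 224 // 4] = [8, 56, 56]; with
        # stride_q = [1, 1, 1] and stride_kv = [1, 8, 8] this yields q_size = 56 and kv_size = 7,
        # so spatial_dim = 2 * max(56, 7) - 1 = 111 and temporal_dim = 2 * 8 - 1 = 15 table rows.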
    def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]:
        B, N, C = x.shape
        q, k, v = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).transpose(1, 3).unbind(dim=2)

        if self.pool_k is not None:
-           k = self.pool_k(k, thw)[0]
+           k, k_thw = self.pool_k(k, thw)
+       else:
+           k_thw = thw
        if self.pool_v is not None:
            v = self.pool_v(v, thw)[0]
        if self.pool_q is not None:
            q, thw = self.pool_q(q, thw)

        attn = torch.matmul(self.scaler * q, k.transpose(2, 3))
if self.rel_pos_h is not None and self.rel_pos_w is not None and self.rel_pos_t is not None:
attn = _add_rel_pos(
attn,
q,
thw,
k_thw,
self.rel_pos_h,
self.rel_pos_w,
self.rel_pos_t,
)
        attn = attn.softmax(dim=-1)

        x = torch.matmul(attn, v)
        if self.residual_pool:
-           x.add_(q)
+           _add_shortcut(x, q, self.residual_with_cls_embed)
-       x = x.transpose(1, 2).reshape(B, -1, C)
+       x = x.transpose(1, 2).reshape(B, -1, self.output_dim)
        x = self.project(x)

        return x, thw
@@ -203,13 +323,18 @@ class MultiscaleAttention(nn.Module):
class MultiscaleBlock(nn.Module):
    def __init__(
        self,
+       input_size: List[int],
        cnf: MSBlockConfig,
        residual_pool: bool,
+       residual_with_cls_embed: bool,
+       rel_pos_embed: bool,
+       proj_after_attn: bool,
        dropout: float = 0.0,
        stochastic_depth_prob: float = 0.0,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
    ) -> None:
        super().__init__()
+       self.proj_after_attn = proj_after_attn

        self.pool_skip: Optional[nn.Module] = None
        if _prod(cnf.stride_q) > 1:
@@ -219,24 +344,30 @@ class MultiscaleBlock(nn.Module):
                nn.MaxPool3d(kernel_skip, stride=cnf.stride_q, padding=padding_skip), None  # type: ignore[arg-type]
            )

+       attn_dim = cnf.output_channels if proj_after_attn else cnf.input_channels

        self.norm1 = norm_layer(cnf.input_channels)
-       self.norm2 = norm_layer(cnf.input_channels)
+       self.norm2 = norm_layer(attn_dim)
        self.needs_transposal = isinstance(self.norm1, nn.BatchNorm1d)
        self.attn = MultiscaleAttention(
+           input_size,
            cnf.input_channels,
+           attn_dim,
            cnf.num_heads,
            kernel_q=cnf.kernel_q,
            kernel_kv=cnf.kernel_kv,
            stride_q=cnf.stride_q,
            stride_kv=cnf.stride_kv,
+           rel_pos_embed=rel_pos_embed,
            residual_pool=residual_pool,
+           residual_with_cls_embed=residual_with_cls_embed,
            dropout=dropout,
            norm_layer=norm_layer,
        )
        self.mlp = MLP(
-           cnf.input_channels,
+           attn_dim,
-           [4 * cnf.input_channels, cnf.output_channels],
+           [4 * attn_dim, cnf.output_channels],
            activation_layer=nn.GELU,
            dropout=dropout,
            inplace=None,
@@ -249,36 +380,45 @@ class MultiscaleBlock(nn.Module):
            self.project = nn.Linear(cnf.input_channels, cnf.output_channels)

    def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]:
+       x_norm1 = self.norm1(x.transpose(1, 2)).transpose(1, 2) if self.needs_transposal else self.norm1(x)
+       x_attn, thw_new = self.attn(x_norm1, thw)
+       x = x if self.project is None or not self.proj_after_attn else self.project(x_norm1)
        x_skip = x if self.pool_skip is None else self.pool_skip(x, thw)[0]
+       x = x_skip + self.stochastic_depth(x_attn)
-       x = self.norm1(x.transpose(1, 2)).transpose(1, 2) if self.needs_transposal else self.norm1(x)
-       x, thw = self.attn(x, thw)
-       x = x_skip + self.stochastic_depth(x)
-       x_norm = self.norm2(x.transpose(1, 2)).transpose(1, 2) if self.needs_transposal else self.norm2(x)
-       x_proj = x if self.project is None else self.project(x_norm)
-       return x_proj + self.stochastic_depth(self.mlp(x_norm)), thw
+       x_norm2 = self.norm2(x.transpose(1, 2)).transpose(1, 2) if self.needs_transposal else self.norm2(x)
+       x_proj = x if self.project is None or self.proj_after_attn else self.project(x_norm2)
+       return x_proj + self.stochastic_depth(self.mlp(x_norm2)), thw_new

class PositionalEncoding(nn.Module):
-   def __init__(self, embed_size: int, spatial_size: Tuple[int, int], temporal_size: int) -> None:
+   def __init__(self, embed_size: int, spatial_size: Tuple[int, int], temporal_size: int, rel_pos_embed: bool) -> None:
        super().__init__()
        self.spatial_size = spatial_size
        self.temporal_size = temporal_size

        self.class_token = nn.Parameter(torch.zeros(embed_size))
+       self.spatial_pos: Optional[nn.Parameter] = None
+       self.temporal_pos: Optional[nn.Parameter] = None
+       self.class_pos: Optional[nn.Parameter] = None
+       if not rel_pos_embed:
            self.spatial_pos = nn.Parameter(torch.zeros(self.spatial_size[0] * self.spatial_size[1], embed_size))
            self.temporal_pos = nn.Parameter(torch.zeros(self.temporal_size, embed_size))
            self.class_pos = nn.Parameter(torch.zeros(embed_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
+       class_token = self.class_token.expand(x.size(0), -1).unsqueeze(1)
+       x = torch.cat((class_token, x), dim=1)
+       if self.spatial_pos is not None and self.temporal_pos is not None and self.class_pos is not None:
            hw_size, embed_size = self.spatial_pos.shape
            pos_embedding = torch.repeat_interleave(self.temporal_pos, hw_size, dim=0)
            pos_embedding.add_(self.spatial_pos.unsqueeze(0).expand(self.temporal_size, -1, -1).reshape(-1, embed_size))
            pos_embedding = torch.cat((self.class_pos.unsqueeze(0), pos_embedding), dim=0).unsqueeze(0)
-           class_token = self.class_token.expand(x.size(0), -1).unsqueeze(1)
-           return torch.cat((class_token, x), dim=1).add_(pos_embedding)
+           x.add_(pos_embedding)
+       return x
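# Illustration only, not part of this patch: with the mvit_v1_b defaults the absolute
# encodings are spatial_pos of shape (56 * 56, 96), temporal_pos of shape (8, 96) and
# class_pos of shape (96,), combined into a (1, 1 + 8 * 56 * 56, 96) tensor added to the
# tokens. With rel_pos_embed=True (mvit_v2_s) only the class token is prepended here and
# the relative embeddings are applied inside MultiscaleAttention instead.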
class MViT(nn.Module):
@@ -288,12 +428,18 @@ class MViT(nn.Module):
        temporal_size: int,
        block_setting: Sequence[MSBlockConfig],
        residual_pool: bool,
+       residual_with_cls_embed: bool,
+       rel_pos_embed: bool,
+       proj_after_attn: bool,
        dropout: float = 0.5,
        attention_dropout: float = 0.0,
        stochastic_depth_prob: float = 0.0,
        num_classes: int = 400,
        block: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
+       patch_embed_kernel: Tuple[int, int, int] = (3, 7, 7),
+       patch_embed_stride: Tuple[int, int, int] = (2, 4, 4),
+       patch_embed_padding: Tuple[int, int, int] = (1, 3, 3),
    ) -> None:
        """
        MViT main class.
@@ -303,12 +449,19 @@ class MViT(nn.Module):
            temporal_size (int): The temporal size ``T`` of the input.
            block_setting (sequence of MSBlockConfig): The Network structure.
            residual_pool (bool): If True, use MViTv2 pooling residual connection.
+           residual_with_cls_embed (bool): If True, the addition on the residual connection will include
+               the class embedding.
+           rel_pos_embed (bool): If True, use MViTv2's relative positional embeddings.
+           proj_after_attn (bool): If True, apply the projection after the attention.
            dropout (float): Dropout rate. Default: 0.0.
            attention_dropout (float): Attention dropout rate. Default: 0.0.
            stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0.
            num_classes (int): The number of classes.
            block (callable, optional): Module specifying the layer which consists of the attention and mlp.
            norm_layer (callable, optional): Module specifying the normalization layer to use.
+           patch_embed_kernel (tuple of ints): The kernel of the convolution that patchifies the input.
+           patch_embed_stride (tuple of ints): The stride of the convolution that patchifies the input.
+           patch_embed_padding (tuple of ints): The padding of the convolution that patchifies the input.
        """
        super().__init__()
        # This implementation employs a different parameterization scheme than the one used at PyTorch Video:
@@ -330,16 +483,19 @@ class MViT(nn.Module):
        self.conv_proj = nn.Conv3d(
            in_channels=3,
            out_channels=block_setting[0].input_channels,
-           kernel_size=(3, 7, 7),
-           stride=(2, 4, 4),
-           padding=(1, 3, 3),
+           kernel_size=patch_embed_kernel,
+           stride=patch_embed_stride,
+           padding=patch_embed_padding,
        )
+       input_size = [size // stride for size, stride in zip((temporal_size,) + spatial_size, self.conv_proj.stride)]

        # Spatio-Temporal Class Positional Encoding
        self.pos_encoding = PositionalEncoding(
            embed_size=block_setting[0].input_channels,
-           spatial_size=(spatial_size[0] // self.conv_proj.stride[1], spatial_size[1] // self.conv_proj.stride[2]),
-           temporal_size=temporal_size // self.conv_proj.stride[0],
+           spatial_size=(input_size[1], input_size[2]),
+           temporal_size=input_size[0],
+           rel_pos_embed=rel_pos_embed,
        )
        # Encoder module
@@ -350,13 +506,20 @@ class MViT(nn.Module):
            self.blocks.append(
                block(
+                   input_size=input_size,
                    cnf=cnf,
                    residual_pool=residual_pool,
+                   residual_with_cls_embed=residual_with_cls_embed,
+                   rel_pos_embed=rel_pos_embed,
+                   proj_after_attn=proj_after_attn,
                    dropout=attention_dropout,
                    stochastic_depth_prob=sd_prob,
                    norm_layer=norm_layer,
                )
            )
+           if len(cnf.stride_q) > 0:
+               input_size = [size // stride for size, stride in zip(input_size, cnf.stride_q)]
        self.norm = norm_layer(block_setting[-1].output_channels)

        # Classifier module
@@ -380,6 +543,8 @@ class MViT(nn.Module):
                nn.init.trunc_normal_(weights, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
+       # Convert if necessary (B, C, H, W) -> (B, C, 1, H, W)
+       x = _unsqueeze(x, 5, 2)[0]
        # patchify and reshape: (B, C, T, H, W) -> (B, embed_channels[0], T', H', W') -> (B, THW', embed_channels[0])
        x = self.conv_proj(x)
        x = x.flatten(2).transpose(1, 2)
@@ -420,6 +585,9 @@ def _mvit(
        temporal_size=temporal_size,
        block_setting=block_setting,
        residual_pool=kwargs.pop("residual_pool", False),
+       residual_with_cls_embed=kwargs.pop("residual_with_cls_embed", True),
+       rel_pos_embed=kwargs.pop("rel_pos_embed", False),
+       proj_after_attn=kwargs.pop("proj_after_attn", False),
        stochastic_depth_prob=stochastic_depth_prob,
        **kwargs,
    )
@@ -461,6 +629,37 @@ class MViT_V1_B_Weights(WeightsEnum):
    DEFAULT = KINETICS400_V1
class MViT_V2_S_Weights(WeightsEnum):
KINETICS400_V1 = Weights(
url="https://download.pytorch.org/models/mvit_v2_s-ae3be167.pth",
transforms=partial(
VideoClassification,
crop_size=(224, 224),
resize_size=(256,),
mean=(0.45, 0.45, 0.45),
std=(0.225, 0.225, 0.225),
),
meta={
"min_size": (224, 224),
"min_temporal_size": 16,
"categories": _KINETICS400_CATEGORIES,
"recipe": "https://github.com/facebookresearch/SlowFast/blob/main/MODEL_ZOO.md",
"_docs": (
"The weights were ported from the paper. The accuracies are estimated on video-level "
"with parameters `frame_rate=7.5`, `clips_per_video=5`, and `clip_len=16`"
),
"num_params": 34537744,
"_metrics": {
"Kinetics-400": {
"acc@1": 80.757,
"acc@5": 94.665,
}
},
},
)
DEFAULT = KINETICS400_V1
@register_model()
def mvit_v1_b(*, weights: Optional[MViT_V1_B_Weights] = None, progress: bool = True, **kwargs: Any) -> MViT:
    """
@@ -548,6 +747,138 @@ def mvit_v1_b(*, weights: Optional[MViT_V1_B_Weights] = None, progress: bool = True, **kwargs: Any) -> MViT:
        temporal_size=16,
        block_setting=block_setting,
        residual_pool=False,
residual_with_cls_embed=False,
stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.2),
weights=weights,
progress=progress,
**kwargs,
)
@register_model()
def mvit_v2_s(*, weights: Optional[MViT_V2_S_Weights] = None, progress: bool = True, **kwargs: Any) -> MViT:
"""
Constructs a small MViTV2 architecture from
`Multiscale Vision Transformers <https://arxiv.org/abs/2104.11227>`__.
Args:
weights (:class:`~torchvision.models.video.MViT_V2_S_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.video.MViT_V2_S_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.video.MViT``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/video/mvit.py>`_
for more details about this class.
.. autoclass:: torchvision.models.video.MViT_V2_S_Weights
:members:
"""
weights = MViT_V2_S_Weights.verify(weights)
config: Dict[str, List] = {
"num_heads": [1, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 8, 8],
"input_channels": [96, 96, 192, 192, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 768],
"output_channels": [96, 192, 192, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 768, 768],
"kernel_q": [
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
],
"kernel_kv": [
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
[3, 3, 3],
],
"stride_q": [
[1, 1, 1],
[1, 2, 2],
[1, 1, 1],
[1, 2, 2],
[1, 1, 1],
[1, 1, 1],
[1, 1, 1],
[1, 1, 1],
[1, 1, 1],
[1, 1, 1],
[1, 1, 1],
[1, 1, 1],
[1, 1, 1],
[1, 1, 1],
[1, 2, 2],
[1, 1, 1],
],
"stride_kv": [
[1, 8, 8],
[1, 4, 4],
[1, 4, 4],
[1, 2, 2],
[1, 2, 2],
[1, 2, 2],
[1, 2, 2],
[1, 2, 2],
[1, 2, 2],
[1, 2, 2],
[1, 2, 2],
[1, 2, 2],
[1, 2, 2],
[1, 2, 2],
[1, 1, 1],
[1, 1, 1],
],
}
block_setting = []
for i in range(len(config["num_heads"])):
block_setting.append(
MSBlockConfig(
num_heads=config["num_heads"][i],
input_channels=config["input_channels"][i],
output_channels=config["output_channels"][i],
kernel_q=config["kernel_q"][i],
kernel_kv=config["kernel_kv"][i],
stride_q=config["stride_q"][i],
stride_kv=config["stride_kv"][i],
)
)
return _mvit(
spatial_size=(224, 224),
temporal_size=16,
block_setting=block_setting,
residual_pool=True,
residual_with_cls_embed=False,
rel_pos_embed=True,
proj_after_attn=True,
        stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.2),
        weights=weights,
        progress=progress,
......
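For illustration (not part of this commit), the new constructor options can also be exercised directly on the MViT base class with a toy single-block configuration; all values below are made up for the example:

import torch
from torchvision.models.video.mvit import MSBlockConfig, MViT

# Toy one-block network; the numbers are illustrative, not a recommended config.
cfg = MSBlockConfig(
    num_heads=1,
    input_channels=96,
    output_channels=96,
    kernel_q=[3, 3, 3],
    kernel_kv=[3, 3, 3],
    stride_q=[1, 1, 1],
    stride_kv=[1, 8, 8],
)
model = MViT(
    spatial_size=(224, 224),
    temporal_size=16,
    block_setting=[cfg],
    residual_pool=True,             # MViTv2-style pooled residual
    residual_with_cls_embed=False,  # keep the class token out of that residual
    rel_pos_embed=True,             # decomposed relative positional embeddings
    proj_after_attn=True,           # project inside the attention block
    stochastic_depth_prob=0.0,
)
out = model(torch.rand(1, 3, 16, 224, 224))  # expected shape: (1, 400)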