Commit ada4710d authored by Tri Dao

[ViT] Run black on vit.py

parent a81900d4
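For reference, a formatting pass like the one in this commit can be reproduced with black's Python API; a minimal sketch follows (the module path and the line length of 100 are assumptions, not taken from the repository's actual configuration):

    import pathlib

    import black

    # Hypothetical path; adjust to wherever vit.py lives in your checkout.
    path = pathlib.Path("flash_attn/models/vit.py")
    source = path.read_text()
    # black.format_str returns the reformatted source; the Mode settings here are assumed.
    formatted = black.format_str(source, mode=black.Mode(line_length=100))
    if formatted != source:
        path.write_text(formatted)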
@@ -2,26 +2,21 @@
# Inspired by / adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
import math
import re
from collections import OrderedDict
from copy import deepcopy
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from flash_attn.layers.patch_embed import PatchEmbed
from flash_attn.modules.block import Block
from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import FusedMLP, Mlp
from timm.models.helpers import named_apply
from torch.nn.init import trunc_normal_
from torchvision.ops import StochasticDepth

try:
    from flash_attn.ops.layer_norm import dropout_add_layer_norm
@@ -29,11 +24,18 @@ except ImportError:
    dropout_add_layer_norm = None


def create_mixer_cls(
    num_heads, qkv_bias, attn_drop, use_flash_attn, fused_bias_fc, cross_attn=False
):
    mixer_cls = partial(
        MHA,
        num_heads=num_heads,
        cross_attn=cross_attn,
        qkv_proj_bias=qkv_bias,
        dropout=attn_drop,
        fused_bias_fc=fused_bias_fc,
        use_flash_attn=use_flash_attn,
    )
    return mixer_cls
@@ -46,47 +48,78 @@ def create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp):
    return mlp_cls


def create_block(
    embed_dim,
    num_heads,
    mlp_ratio,
    qkv_bias,
    drop_rate,
    attn_drop_rate,
    drop_path1,
    drop_path2,
    norm_layer,
    act_layer,
    use_flash_attn,
    fused_bias_fc,
    fused_mlp,
    fused_dropout_add_ln,
    layer_idx=None,
    n_layer=None,
    last_layer_subset=False,
):
    mixer_cls = create_mixer_cls(
        num_heads,
        qkv_bias,
        attn_drop_rate,
        use_flash_attn,
        fused_bias_fc,
        cross_attn=(last_layer_subset and layer_idx == n_layer - 1),
    )
    mlp_cls = create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp)
    # TD [2022-10-15]: Force residual in fp32 in case of DeepSpeed
    block = Block(
        embed_dim,
        mixer_cls,
        mlp_cls,
        norm_cls=norm_layer,
        prenorm=True,
        resid_dropout1=drop_rate,
        resid_dropout2=drop_rate,
        drop_path1=drop_path1,
        drop_path2=drop_path2,
        fused_dropout_add_ln=fused_dropout_add_ln,
        residual_in_fp32=True,
    )
    return block


class VisionTransformer(nn.Module):
    """Vision Transformer
    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
        - https://arxiv.org/abs/2010.11929
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        global_pool="token",
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        init_values=None,
        class_token=True,
        no_embed_class=False,
        pre_norm=False,
        fc_norm=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        weight_init="",
        embed_layer=PatchEmbed,
        norm_layer=None,
        act_layer=None,
@@ -119,40 +152,45 @@ class VisionTransformer(nn.Module):
            act_layer: (nn.Module): MLP activation layer
        """
        super().__init__()
        assert global_pool == "token", "Only support pooling with CLS token"
        assert class_token
        assert init_values is None, "LayerScale is not supported yet"
        assert weight_init == ""
        assert fc_norm is None
        # pre_norm seems redundant, as there's a LayerNorm right at the start of each block, idk
        assert not pre_norm
        use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = (
            self.embed_dim
        ) = embed_dim  # num_features for consistency with other models
        self.num_prefix_tokens = 1 if class_token else 0
        self.no_embed_class = no_embed_class

        patch_embed_extra_kwargs = (
            {"fused_bias_fc": fused_bias_fc} if embed_layer is PatchEmbed else {}
        )
        self.patch_embed = embed_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
            **patch_embed_extra_kwargs,
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
        embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
        self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)

        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule

        # We change the order of dropout, residual and layer norm:
        # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
@@ -160,31 +198,47 @@ class VisionTransformer(nn.Module):
        # the main branch (output of MLP). The model definition is unchanged, but the mapping of the
        # nn.Dropout probabilities is changed.
        # This is for performance reasons: we can fuse dropout + add + layer_norm.
        self.blocks = nn.ModuleList(
            [
                create_block(
                    embed_dim,
                    num_heads,
                    mlp_ratio,
                    qkv_bias,
                    drop_rate,
                    attn_drop_rate,
                    drop_path1=dpr[i - 1] if i > 0 else 0.0,
                    drop_path2=dpr[i],
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                    use_flash_attn=use_flash_attn,
                    fused_bias_fc=fused_bias_fc,
                    fused_mlp=fused_mlp,
                    fused_dropout_add_ln=fused_dropout_add_ln,
                    layer_idx=i,
                    n_layer=depth,
                    last_layer_subset=(global_pool == "token"),
                )
                for i in range(depth)
            ]
        )

        self.dropout = nn.Dropout(p=drop_rate)
        self.drop_path = StochasticDepth(p=dpr[-1], mode="row")
        self.norm = norm_layer(embed_dim)

        self.fused_dropout_add_ln = fused_dropout_add_ln
        if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
            raise ImportError("dropout_add_layer_norm is not installed")

        # Classifier Head
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        self.init_weights(weight_init)

    def init_weights(self, mode=""):
        assert mode == ""
        trunc_normal_(self.pos_embed, std=0.02)
        if self.cls_token is not None:
            nn.init.normal_(self.cls_token, std=1e-6)
        named_apply(init_weights_vit_timm, self)
@@ -195,7 +249,7 @@ class VisionTransformer(nn.Module):
    @torch.jit.ignore
    def no_weight_decay(self):
        return {"pos_embed", "cls_token"}

    def _pos_embed(self, x):
        if self.no_embed_class:
@@ -220,7 +274,7 @@ class VisionTransformer(nn.Module):
        x = self.patch_embed(x)
        hidden_states = self._pos_embed(x)
        residual = None
        if self.global_pool != "token" or all_tokens:
            # if True:
            for block in self.blocks:
                hidden_states, residual = block(hidden_states, residual)
@@ -229,8 +283,9 @@ class VisionTransformer(nn.Module):
                hidden_states, residual = block(hidden_states, residual)
            # For the last layer, we only want the 1st token of the output. So we do cross-attention
            # where the query is the 1st token and the key/value is the whole sequence.
            hidden_states, residual = self.blocks[-1](
                hidden_states, residual, mixer_subset=slice(0, 1)
            )
        if not self.fused_dropout_add_ln:
            residual = self.drop_path(self.dropout(hidden_states)) + residual
            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
@@ -238,21 +293,30 @@ class VisionTransformer(nn.Module):
            if self.drop_path.p == 0 or not self.training:
                rowscale = None
            else:
                rowscale = self.drop_path(
                    torch.ones(
                        hidden_states.shape[:-1],
                        device=hidden_states.device,
                        dtype=hidden_states.dtype,
                    )
                )
            # Set prenorm=False here since we don't need the residual
            hidden_states = dropout_add_layer_norm(
                hidden_states,
                residual,
                self.norm.weight,
                self.norm.bias,
                self.dropout.p if self.training else 0.0,
                self.norm.eps,
                rowscale=rowscale,
                prenorm=False,
                residual_in_fp32=True,
            )
        return hidden_states

    def forward_head(self, x, pre_logits: bool = False):
        if self.global_pool:
            x = x[:, self.num_prefix_tokens :].mean(dim=1) if self.global_pool == "avg" else x[:, 0]
        return x if pre_logits else self.head(x)

    def forward(self, x):
@@ -261,41 +325,46 @@ class VisionTransformer(nn.Module):
        return x

    def load_state_dict(self, state_dict, strict=True):
        patch_embed_weight = state_dict["patch_embed.proj.weight"]
        if patch_embed_weight.dim() == 4:
            # convert from Conv2d to Linear
            state_dict["patch_embed.proj.weight"] = rearrange(
                patch_embed_weight, "o c h w -> o (c h w)"
            )

        def key_mapping_attn(key):
            key = re.sub(r"^blocks.(\d+).attn.qkv.", r"blocks.\1.mixer.Wqkv.", key)
            key = re.sub(r"^blocks.(\d+).attn.proj.", r"blocks.\1.mixer.out_proj.", key)
            return key

        state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
        n_layer = len(self.blocks)
        # Convert from Wqkv to Wq and Wkv for cross attention (last layer)
        if (
            self.blocks[-1].mixer.cross_attn
            and f"blocks.{n_layer - 1}.mixer.Wqkv.weight" in state_dict
        ):
            Wqkv = state_dict.pop(f"blocks.{n_layer - 1}.mixer.Wqkv.weight")
            bqkv = state_dict.pop(f"blocks.{n_layer - 1}.mixer.Wqkv.bias")
            state_dict[f"blocks.{n_layer - 1}.mixer.Wq.weight"] = Wqkv[: self.embed_dim]
            state_dict[f"blocks.{n_layer - 1}.mixer.Wkv.weight"] = Wqkv[self.embed_dim :]
            state_dict[f"blocks.{n_layer - 1}.mixer.Wq.bias"] = bqkv[: self.embed_dim]
            state_dict[f"blocks.{n_layer - 1}.mixer.Wkv.bias"] = bqkv[self.embed_dim :]
        return super().load_state_dict(state_dict, strict=strict)


def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility)"""
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif hasattr(module, "init_weights"):
        module.init_weights()


def vit_base_patch16_224(pretrained=False, **kwargs):
    """ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    assert not pretrained
...
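A minimal usage sketch of the model factory above (not part of this commit; the import path, the assumption that the factory returns a constructed VisionTransformer with default flags, and the availability of a CUDA device are all assumptions):

    import torch

    from flash_attn.models.vit import vit_base_patch16_224  # assumed module path

    # Build ViT-B/16 with the defaults shown above; flash-attn / fused kernels stay at
    # whatever the constructor defaults to unless overridden via keyword arguments.
    model = vit_base_patch16_224().cuda().half()
    model.eval()

    images = torch.randn(2, 3, 224, 224, device="cuda", dtype=torch.float16)
    with torch.no_grad():
        logits = model(images)  # expected shape (2, 1000) with the default classification head
    print(logits.shape)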