Commit ada4710d authored by Tri Dao

[ViT] Run black on vit.py

parent a81900d4
@@ -2,26 +2,21 @@
# Inspired by / adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
import math
import re
from functools import partial
from copy import deepcopy
from collections import OrderedDict
from copy import deepcopy
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import trunc_normal_
from torchvision.ops import StochasticDepth
from einops import rearrange
from timm.models.helpers import named_apply
from flash_attn.layers.patch_embed import PatchEmbed
from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp, FusedMLP
from flash_attn.modules.block import Block
from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import FusedMLP, Mlp
from timm.models.helpers import named_apply
from torch.nn.init import trunc_normal_
from torchvision.ops import StochasticDepth
try:
from flash_attn.ops.layer_norm import dropout_add_layer_norm
@@ -29,11 +24,18 @@ except ImportError:
dropout_add_layer_norm = None
def create_mixer_cls(num_heads, qkv_bias, attn_drop, use_flash_attn, fused_bias_fc,
cross_attn=False):
mixer_cls = partial(MHA, num_heads=num_heads, cross_attn=cross_attn, qkv_proj_bias=qkv_bias,
dropout=attn_drop, fused_bias_fc=fused_bias_fc,
use_flash_attn=use_flash_attn)
def create_mixer_cls(
num_heads, qkv_bias, attn_drop, use_flash_attn, fused_bias_fc, cross_attn=False
):
mixer_cls = partial(
MHA,
num_heads=num_heads,
cross_attn=cross_attn,
qkv_proj_bias=qkv_bias,
dropout=attn_drop,
fused_bias_fc=fused_bias_fc,
use_flash_attn=use_flash_attn,
)
return mixer_cls
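Side note on the partial-class pattern used here: create_mixer_cls does not build an attention module, it returns MHA with most arguments pre-bound, so only the embedding dimension has to be supplied later when the block is constructed. A minimal self-contained sketch with a stand-in class (not flash-attn's MHA; the assumption that the result is later called as mixer_cls(embed_dim) is how the block factory below uses it):
from functools import partial

class _ToyMixer:  # illustrative stand-in, not flash_attn.modules.mha.MHA
    def __init__(self, embed_dim, num_heads=8, cross_attn=False, dropout=0.0):
        self.embed_dim, self.num_heads, self.cross_attn = embed_dim, num_heads, cross_attn

# Bind everything except the embedding dimension, as create_mixer_cls does.
mixer_cls = partial(_ToyMixer, num_heads=12, cross_attn=False, dropout=0.1)
mixer = mixer_cls(768)  # only the dimension is supplied at block-construction time
assert (mixer.embed_dim, mixer.num_heads) == (768, 12)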
@@ -46,54 +48,85 @@ def create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp):
return mlp_cls
def create_block(embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate,
drop_path1, drop_path2, norm_layer, act_layer, use_flash_attn, fused_bias_fc,
fused_mlp, fused_dropout_add_ln, layer_idx=None, n_layer=None,
last_layer_subset=False):
mixer_cls = create_mixer_cls(num_heads, qkv_bias, attn_drop_rate, use_flash_attn, fused_bias_fc,
cross_attn=(last_layer_subset and layer_idx == n_layer - 1))
def create_block(
embed_dim,
num_heads,
mlp_ratio,
qkv_bias,
drop_rate,
attn_drop_rate,
drop_path1,
drop_path2,
norm_layer,
act_layer,
use_flash_attn,
fused_bias_fc,
fused_mlp,
fused_dropout_add_ln,
layer_idx=None,
n_layer=None,
last_layer_subset=False,
):
mixer_cls = create_mixer_cls(
num_heads,
qkv_bias,
attn_drop_rate,
use_flash_attn,
fused_bias_fc,
cross_attn=(last_layer_subset and layer_idx == n_layer - 1),
)
mlp_cls = create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp)
# TD [2022-10-15]: Force residual in fp32 in case of DeepSpeed
block = Block(embed_dim, mixer_cls, mlp_cls, norm_cls=norm_layer,
prenorm=True, resid_dropout1=drop_rate, resid_dropout2=drop_rate,
drop_path1=drop_path1, drop_path2=drop_path2,
fused_dropout_add_ln=fused_dropout_add_ln, residual_in_fp32=True)
block = Block(
embed_dim,
mixer_cls,
mlp_cls,
norm_cls=norm_layer,
prenorm=True,
resid_dropout1=drop_rate,
resid_dropout2=drop_rate,
drop_path1=drop_path1,
drop_path2=drop_path2,
fused_dropout_add_ln=fused_dropout_add_ln,
residual_in_fp32=True,
)
return block
class VisionTransformer(nn.Module):
""" Vision Transformer
"""Vision Transformer
A PyTorch impl of: `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
- https://arxiv.org/abs/2010.11929
"""
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
global_pool='token',
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.,
qkv_bias=True,
init_values=None,
class_token=True,
no_embed_class=False,
pre_norm=False,
fc_norm=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
weight_init='',
embed_layer=PatchEmbed,
norm_layer=None,
act_layer=None,
use_flash_attn=False,
fused_bias_fc=False,
fused_mlp=False,
fused_dropout_add_ln=False,
self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
global_pool="token",
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.0,
qkv_bias=True,
init_values=None,
class_token=True,
no_embed_class=False,
pre_norm=False,
fc_norm=None,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.0,
weight_init="",
embed_layer=PatchEmbed,
norm_layer=None,
act_layer=None,
use_flash_attn=False,
fused_bias_fc=False,
fused_mlp=False,
fused_dropout_add_ln=False,
):
"""
Args:
@@ -119,40 +152,45 @@ class VisionTransformer(nn.Module):
act_layer: (nn.Module): MLP activation layer
"""
super().__init__()
assert global_pool == 'token', 'Only support pooling with CLS token'
assert global_pool == "token", "Only support pooling with CLS token"
assert class_token
assert init_values is None, 'LayerScale is not supported yet'
assert weight_init == ''
assert init_values is None, "LayerScale is not supported yet"
assert weight_init == ""
assert fc_norm is None
# pre_norm seems redundant, as there's a LayerNorm right at the start of each block, idk
assert not pre_norm
use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
act_layer = act_layer or nn.GELU
self.num_classes = num_classes
self.global_pool = global_pool
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_features = (
self.embed_dim
) = embed_dim # num_features for consistency with other models
self.num_prefix_tokens = 1 if class_token else 0
self.no_embed_class = no_embed_class
patch_embed_extra_kwargs = ({'fused_bias_fc': fused_bias_fc} if embed_layer is PatchEmbed
else {})
patch_embed_extra_kwargs = (
{"fused_bias_fc": fused_bias_fc} if embed_layer is PatchEmbed else {}
)
self.patch_embed = embed_layer(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
**patch_embed_extra_kwargs
**patch_embed_extra_kwargs,
)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, depth)
] # stochastic depth decay rule
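Worked example of the decay rule, with depth and drop_path_rate chosen purely for illustration:
import torch

depth, drop_path_rate = 4, 0.1  # toy values for the example
dpr = [round(x.item(), 4) for x in torch.linspace(0, drop_path_rate, depth)]
print(dpr)  # [0.0, 0.0333, 0.0667, 0.1] -- the drop probability grows linearly with depth
# As in the create_block calls below, block i then gets
# drop_path1 = dpr[i - 1] (0.0 for the first block) and drop_path2 = dpr[i].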
# We change the order of dropout, residual and layer norm:
# Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
@@ -160,31 +198,47 @@ class VisionTransformer(nn.Module):
# the main branch (output of MLP). The model definition is unchanged, but the mapping of the
# nn.Dropout probabilities is changed.
# This is for performance reasons: we can fuse dropout + add + layer_norm.
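A minimal, framework-free sketch of this reordering (illustration only, not the actual Block implementation, which also handles drop-path, the fused kernel, and fp32 residuals): both groupings produce the same residual stream; the second simply places dropout + add + layer_norm next to each other so they can be fused.
import torch
import torch.nn as nn

torch.manual_seed(0)
dim = 16
ln, branch, drop = nn.LayerNorm(dim), nn.Linear(dim, dim), nn.Dropout(0.0)  # p=0 so both paths agree exactly
x = torch.randn(2, 4, dim)

# Standard pre-norm sublayer: LN -> branch (Attn / MLP) -> Dropout -> Add
out_standard = x + drop(branch(ln(x)))

# Reordered grouping: Dropout -> Add -> LN happen together at the start of the next
# sublayer, with the residual stream carried alongside the normalized hidden states.
residual = drop(branch(ln(x))) + x   # dropout + add (fusable with the LN below)
normed = ln(residual)                # input to the next Attn / MLP branch

assert torch.allclose(out_standard, residual)  # same values, different grouping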
self.blocks = nn.ModuleList([create_block(
embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate,
drop_path1=dpr[i-1] if i > 0 else 0., drop_path2=dpr[i],
norm_layer=norm_layer, act_layer=act_layer, use_flash_attn=use_flash_attn,
fused_bias_fc=fused_bias_fc, fused_mlp=fused_mlp,
fused_dropout_add_ln=fused_dropout_add_ln, layer_idx=i, n_layer=depth,
last_layer_subset=(global_pool == 'token')
) for i in range(depth)])
self.blocks = nn.ModuleList(
[
create_block(
embed_dim,
num_heads,
mlp_ratio,
qkv_bias,
drop_rate,
attn_drop_rate,
drop_path1=dpr[i - 1] if i > 0 else 0.0,
drop_path2=dpr[i],
norm_layer=norm_layer,
act_layer=act_layer,
use_flash_attn=use_flash_attn,
fused_bias_fc=fused_bias_fc,
fused_mlp=fused_mlp,
fused_dropout_add_ln=fused_dropout_add_ln,
layer_idx=i,
n_layer=depth,
last_layer_subset=(global_pool == "token"),
)
for i in range(depth)
]
)
self.dropout = nn.Dropout(p=drop_rate)
self.drop_path = StochasticDepth(p=dpr[-1], mode='row')
self.drop_path = StochasticDepth(p=dpr[-1], mode="row")
self.norm = norm_layer(embed_dim)
self.fused_dropout_add_ln = fused_dropout_add_ln
if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
raise ImportError('dropout_add_layer_norm is not installed')
raise ImportError("dropout_add_layer_norm is not installed")
# Classifier Head
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
self.init_weights(weight_init)
def init_weights(self, mode=''):
assert mode == ''
trunc_normal_(self.pos_embed, std=.02)
def init_weights(self, mode=""):
assert mode == ""
trunc_normal_(self.pos_embed, std=0.02)
if self.cls_token is not None:
nn.init.normal_(self.cls_token, std=1e-6)
named_apply(init_weights_vit_timm, self)
@@ -195,7 +249,7 @@ class VisionTransformer(nn.Module):
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
return {"pos_embed", "cls_token"}
def _pos_embed(self, x):
if self.no_embed_class:
@@ -220,8 +274,8 @@ class VisionTransformer(nn.Module):
x = self.patch_embed(x)
hidden_states = self._pos_embed(x)
residual = None
if self.global_pool != 'token' or all_tokens:
# if True:
if self.global_pool != "token" or all_tokens:
# if True:
for block in self.blocks:
hidden_states, residual = block(hidden_states, residual)
else:
@@ -229,8 +283,9 @@ class VisionTransformer(nn.Module):
hidden_states, residual = block(hidden_states, residual)
# For the last layer, we only want the 1st token of the output. So we do cross-attention
# where the query is the 1st token and the key/value is the whole sequence.
hidden_states, residual = self.blocks[-1](hidden_states, residual,
mixer_subset=slice(0, 1))
hidden_states, residual = self.blocks[-1](
hidden_states, residual, mixer_subset=slice(0, 1)
)
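A toy single-head illustration of the mixer_subset trick (plain PyTorch, not flash-attn's MHA or its kernels): restricting the query to the CLS token means the last layer computes a single attention row instead of a full L x L map, while keys and values still cover the whole sequence.
import torch

B, L, D = 2, 197, 64                # batch, 196 patches + CLS, head dim (toy sizes)
x = torch.randn(B, L, D)
q = x[:, :1]                        # queries: CLS token only -> (B, 1, D)
k, v = x, x                         # keys / values: full sequence
attn = torch.softmax(q @ k.transpose(-1, -2) / D**0.5, dim=-1)  # (B, 1, L)
cls_out = attn @ v                  # (B, 1, D): the only output forward_head needs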
if not self.fused_dropout_add_ln:
residual = self.drop_path(self.dropout(hidden_states)) + residual
hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
@@ -238,21 +293,30 @@ class VisionTransformer(nn.Module):
if self.drop_path.p == 0 or not self.training:
rowscale = None
else:
rowscale = self.drop_path(torch.ones(
hidden_states.shape[:-1], device=hidden_states.device,
dtype=hidden_states.dtype)
rowscale = self.drop_path(
torch.ones(
hidden_states.shape[:-1],
device=hidden_states.device,
dtype=hidden_states.dtype,
)
)
# Set prenorm=False here since we don't need the residual
hidden_states = dropout_add_layer_norm(
hidden_states, residual, self.norm.weight, self.norm.bias,
self.dropout.p if self.training else 0.0, self.norm.eps, rowscale=rowscale,
prenorm=False, residual_in_fp32=True
hidden_states,
residual,
self.norm.weight,
self.norm.bias,
self.dropout.p if self.training else 0.0,
self.norm.eps,
rowscale=rowscale,
prenorm=False,
residual_in_fp32=True,
)
return hidden_states
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool:
x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
x = x[:, self.num_prefix_tokens :].mean(dim=1) if self.global_pool == "avg" else x[:, 0]
return x if pre_logits else self.head(x)
def forward(self, x):
@@ -261,41 +325,46 @@ class VisionTransformer(nn.Module):
return x
def load_state_dict(self, state_dict, strict=True):
patch_embed_weight = state_dict['patch_embed.proj.weight']
patch_embed_weight = state_dict["patch_embed.proj.weight"]
if patch_embed_weight.dim() == 4:
# convert from Conv2d to Linear
state_dict['patch_embed.proj.weight'] = rearrange(patch_embed_weight,
'o c h w -> o (c h w)')
state_dict["patch_embed.proj.weight"] = rearrange(
patch_embed_weight, "o c h w -> o (c h w)"
)
def key_mapping_attn(key):
key = re.sub(r'^blocks.(\d+).attn.qkv.', r'blocks.\1.mixer.Wqkv.', key)
key = re.sub(r'^blocks.(\d+).attn.proj.', r'blocks.\1.mixer.out_proj.', key)
key = re.sub(r"^blocks.(\d+).attn.qkv.", r"blocks.\1.mixer.Wqkv.", key)
key = re.sub(r"^blocks.(\d+).attn.proj.", r"blocks.\1.mixer.out_proj.", key)
return key
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
n_layer = len(self.blocks)
# Convert from Wqkv to Wq and Wkv for cross attention (last layer)
if (self.blocks[-1].mixer.cross_attn
and f'blocks.{n_layer - 1}.mixer.Wqkv.weight' in state_dict):
Wqkv = state_dict.pop(f'blocks.{n_layer - 1}.mixer.Wqkv.weight')
bqkv = state_dict.pop(f'blocks.{n_layer - 1}.mixer.Wqkv.bias')
state_dict[f'blocks.{n_layer - 1}.mixer.Wq.weight'] = Wqkv[:self.embed_dim]
state_dict[f'blocks.{n_layer - 1}.mixer.Wkv.weight'] = Wqkv[self.embed_dim:]
state_dict[f'blocks.{n_layer - 1}.mixer.Wq.bias'] = bqkv[:self.embed_dim]
state_dict[f'blocks.{n_layer - 1}.mixer.Wkv.bias'] = bqkv[self.embed_dim:]
if (
self.blocks[-1].mixer.cross_attn
and f"blocks.{n_layer - 1}.mixer.Wqkv.weight" in state_dict
):
Wqkv = state_dict.pop(f"blocks.{n_layer - 1}.mixer.Wqkv.weight")
bqkv = state_dict.pop(f"blocks.{n_layer - 1}.mixer.Wqkv.bias")
state_dict[f"blocks.{n_layer - 1}.mixer.Wq.weight"] = Wqkv[: self.embed_dim]
state_dict[f"blocks.{n_layer - 1}.mixer.Wkv.weight"] = Wqkv[self.embed_dim :]
state_dict[f"blocks.{n_layer - 1}.mixer.Wq.bias"] = bqkv[: self.embed_dim]
state_dict[f"blocks.{n_layer - 1}.mixer.Wkv.bias"] = bqkv[self.embed_dim :]
return super().load_state_dict(state_dict, strict=strict)
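Two self-contained examples of the conversions performed above (toy sizes, hypothetical key name): the regex renames timm-style attention keys to this module's names, and the fused QKV weight is split row-wise, with the first embed_dim rows becoming the query projection and the remaining rows the key/value projection.
import re
import torch

# Key renaming: timm's blocks.<i>.attn.* -> this model's blocks.<i>.mixer.*
key = "blocks.11.attn.qkv.weight"
key = re.sub(r"^blocks.(\d+).attn.qkv.", r"blocks.\1.mixer.Wqkv.", key)
assert key == "blocks.11.mixer.Wqkv.weight"

# Wqkv -> (Wq, Wkv) split for the cross-attention last layer (toy embed_dim):
D = 4
Wqkv = torch.randn(3 * D, D)        # fused projection: rows [0, D) are Q, rows [D, 3D) are K and V
Wq, Wkv = Wqkv[:D], Wqkv[D:]
assert Wq.shape == (D, D) and Wkv.shape == (2 * D, D)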
def init_weights_vit_timm(module: nn.Module, name: str = ''):
""" ViT weight initialization, original timm impl (for reproducibility) """
def init_weights_vit_timm(module: nn.Module, name: str = ""):
"""ViT weight initialization, original timm impl (for reproducibility)"""
if isinstance(module, nn.Linear):
trunc_normal_(module.weight, std=.02)
trunc_normal_(module.weight, std=0.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif hasattr(module, 'init_weights'):
elif hasattr(module, "init_weights"):
module.init_weights()
def vit_base_patch16_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
"""ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
"""
assert not pretrained
......
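A possible end-to-end usage sketch. Assumptions: the import path flash_attn.models.vit, that the elided factory body forwards **kwargs to VisionTransformer with the ViT-B/16 sizes from the docstring, and that a CUDA GPU is available (FlashAttention kernels require fp16/bf16 inputs on GPU). Treat the keyword choice as illustrative, not prescriptive.
import torch
from flash_attn.models.vit import vit_base_patch16_224  # assumed import path

model = vit_base_patch16_224(use_flash_attn=True).cuda().to(torch.float16)
x = torch.randn(2, 3, 224, 224, device="cuda", dtype=torch.float16)
with torch.no_grad():
    logits = model(x)               # expected shape: (2, 1000)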