"driver/include/tensor.hpp" did not exist on "2c9b8c2432ffe2eceba32d07ce8b0e467dd4538e"
Commit 1feb9426 authored by Tri Dao

[ViT] Use dropout_add_ln for the 1st layer norm

parent 45bcf37b
@@ -104,14 +104,14 @@ def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_resid
             nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer))


-class GPT2Model(nn.Module):
+class GPTModel(nn.Module):

     def __init__(self, config: GPT2Config):
         super().__init__()
-        self.pad_vocab_size_multiple_8 = getattr(config, 'pad_vocab_size_multiple_8', False)
-        if self.pad_vocab_size_multiple_8:
-            if config.vocab_size % 8 != 0:
-                config.vocab_size += 8 - (config.vocab_size % 8)
+        self.pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
+        if config.vocab_size % self.pad_vocab_size_multiple != 0:
+            config.vocab_size += (self.pad_vocab_size_multiple
+                                  - (config.vocab_size % self.pad_vocab_size_multiple))
         self.embeddings = GPT2Embeddings(config.hidden_size, config.vocab_size,
                                          config.max_position_embeddings)
@@ -153,11 +153,11 @@ class GPT2Model(nn.Module):
         return hidden_states


-class GPT2LMHeadModel(nn.Module):
+class GPTLMHeadModel(nn.Module):

     def __init__(self, config: GPT2Config):
         super().__init__()
-        self.transformer = GPT2Model(config)
+        self.transformer = GPTModel(config)
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
         # Initialize weights and apply final processing
         self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
...
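For intuition, the padding logic above rounds config.vocab_size up to the nearest multiple of pad_vocab_size_multiple (default 1, i.e. no padding), generalizing the old boolean multiple-of-8 flag. A minimal standalone sketch of the same arithmetic; the pad_vocab_size helper name and the example sizes are illustrative, not part of the commit:

# Illustrative helper (not from the commit): round vocab_size up to a multiple.
def pad_vocab_size(vocab_size: int, multiple: int = 1) -> int:
    if vocab_size % multiple != 0:
        vocab_size += multiple - (vocab_size % multiple)
    return vocab_size

assert pad_vocab_size(50257, 1) == 50257  # default multiple of 1: unchanged
assert pad_vocab_size(50257, 8) == 50264  # rounded up to the next multiple of 8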
@@ -18,6 +18,11 @@ from flash_attn.modules.mha import MHA
 from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense
 from flash_attn.modules.block import Block

+try:
+    from flash_attn.ops.layer_norm import dropout_add_layer_norm
+except ImportError:
+    dropout_add_layer_norm = None
+

 def create_mixer_cls(num_heads, qkv_bias, attn_drop, use_flash_attn, fused_bias_fc,
                      cross_attn=False):
@@ -152,6 +157,10 @@ class VisionTransformer(nn.Module):
         # (in the pretrained weight) is the final layer norm.
         self.norm_0 = norm_layer(embed_dim)

+        self.fused_dropout_add_ln = fused_dropout_add_ln
+        if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
+            raise ImportError('dropout_add_layer_norm is not installed')
+
         self.blocks = nn.ModuleList([create_block(
             embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, drop_path=dpr[i],
             norm_layer=norm_layer, act_layer=act_layer, use_flash_attn=use_flash_attn,
@@ -193,7 +202,7 @@ class VisionTransformer(nn.Module):
         if self.cls_token is not None:
             x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
         x = x + self.pos_embed
-        return self.pos_drop(x)
+        return x

     def forward_features(self, x, all_tokens=True):
         """
@@ -201,8 +210,17 @@ class VisionTransformer(nn.Module):
         cls token.
         """
         x = self.patch_embed(x)
+        x = self._pos_embed(x)
         # TD [2022-10-15]: Force residual in fp32 in case of DeepSpeed
-        residual = self._pos_embed(x).float()
-        hidden_states = self.norm_0(residual.to(dtype=self.norm_0.weight.dtype))
+        if not self.fused_dropout_add_ln:
+            residual = self.pos_drop(x).float()
+            hidden_states = self.norm_0(residual.to(dtype=self.norm_0.weight.dtype))
+        else:
+            hidden_states, residual = dropout_add_layer_norm(
+                x, None, self.norm_0.weight, self.norm_0.bias,
+                self.pos_drop.p if self.training else 0.0, self.norm_0.eps, prenorm=True,
+                residual_in_fp32=True
+            )
         if self.global_pool != 'token' or all_tokens:
             for block in self.blocks:
...
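The forward_features change above replaces the unfused path (positional dropout, then norm_0 on an fp32 residual) with a single dropout_add_layer_norm call when fused_dropout_add_ln is enabled. As a rough, unfused sketch of what that prenorm=True call is expected to compute in this ViT path; the dropout_add_ln_reference function below is an approximation for intuition, not the library's fused kernel, and dropout randomness will of course differ:

import torch.nn.functional as F

def dropout_add_ln_reference(x, residual, weight, bias, p, eps):
    # Unfused approximation: dropout(x), optionally add a residual, keep the
    # running residual in fp32, then LayerNorm in the weight's dtype.
    dropped = F.dropout(x, p=p)
    new_residual = dropped.float() if residual is None else (dropped + residual).float()
    normed = F.layer_norm(new_residual.to(weight.dtype), (x.shape[-1],), weight, bias, eps)
    # With prenorm=True, both the normalized output and the fp32 residual are returned.
    return normed, new_residual

Here residual is None because this is the first layer norm right after the patch and position embeddings, which matches the x, None arguments passed to the fused call in the diff.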
@@ -64,7 +64,6 @@ class FusedDenseGeluDense(nn.Module):
         self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)

     def forward(self, x):
-        assert x.dtype in [torch.float16, torch.bfloat16]
         assert x.is_cuda
         fn = (fused_dense_gelu_dense_function_td if not self.return_residual
               else fused_dense_res_gelu_dense_function_td)
...