Commit 5fb6df0e authored by Tri Dao

Implement BERT

parent dc24c226
@@ -5,7 +5,7 @@ import torch.nn as nn
 from einops import rearrange
 from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
-from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis
+from flash_attn.bert_padding import unpad_input, pad_input
 class FlashAttention(nn.Module):
...
# Copyright (c) 2022, Tri Dao.
# This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
# https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
# https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
# Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
import re
import logging
from functools import partial
from collections.abc import Sequence
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertConfig
from einops import rearrange
from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense
from flash_attn.modules.block import Block
from flash_attn.modules.embedding import BertEmbeddings
from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis
try:
from flash_attn.ops.fused_dense import FusedDenseTD
except ImportError:
FusedDenseTD = None
try:
from flash_attn.ops.layer_norm import dropout_add_layer_norm, layer_norm
except ImportError:
dropout_add_layer_norm, layer_norm = None, None
try:
from flash_attn.losses.cross_entropy_apex import CrossEntropyLossApex
except ImportError:
CrossEntropyLossApex = None
logger = logging.getLogger(__name__)
def create_mixer_cls(config):
use_flash_attn = getattr(config, 'use_flash_attn', False)
fused_bias_fc = getattr(config, 'fused_bias_fc', False)
mixer_cls = partial(MHA, num_heads=config.num_attention_heads,
dropout=config.attention_probs_dropout_prob, causal=False,
fused_bias_fc=fused_bias_fc, use_flash_attn=use_flash_attn)
return mixer_cls
def create_mlp_cls(config, layer_idx=None):
inner_dim = config.intermediate_size
fused_dense_gelu_dense = getattr(config, 'fused_dense_gelu_dense', False)
if not fused_dense_gelu_dense:
mlp_cls = partial(Mlp, hidden_features=inner_dim,
activation=partial(F.gelu, approximate='tanh'))
else:
mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
# mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
if isinstance(mlp_checkpoint_lvl, Sequence):
assert layer_idx is not None
mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
mlp_cls = partial(FusedDenseGeluDense, hidden_features=inner_dim,
checkpoint_lvl=mlp_checkpoint_lvl)
return mlp_cls
def create_block(config, layer_idx=None):
mixer_cls = create_mixer_cls(config)
mlp_cls = create_mlp_cls(config, layer_idx)
norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_eps)
block = Block(config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
prenorm=False, resid_dropout=config.hidden_dropout_prob,
fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False))
return block
# https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
def _init_weights(module, initializer_range=0.02):
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, std=initializer_range)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, std=initializer_range)
if module.padding_idx is not None:
nn.init.zeros_(module.weight[module.padding_idx])
class BertEncoder(nn.Module):
def __init__(self, config: BertConfig):
super().__init__()
self.use_flash_attn = getattr(config, 'use_flash_attn', False)
self.layers = nn.ModuleList([create_block(config, layer_idx=i)
for i in range(config.num_hidden_layers)])
def forward(self, hidden_states, key_padding_mask=None):
if key_padding_mask is None or not self.use_flash_attn:
mixer_kwargs = ({'key_padding_mask': key_padding_mask}
if key_padding_mask is not None else None)
for layer in self.layers:
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
else:
batch, seqlen = hidden_states.shape[:2]
hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
hidden_states, key_padding_mask
)
mixer_kwargs = {'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen_in_batch}
for layer in self.layers:
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
hidden_states = pad_input(hidden_states, indices, batch, seqlen)
return hidden_states
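As a rough illustration of the unpadded code path above, here is a minimal sketch (not part of the commit; it assumes a CUDA device and that flash_attn.bert_padding is importable) of how unpad_input and pad_input round-trip a padded batch around the encoder layers:

# Hedged sketch of the unpad/pad round trip used by BertEncoder.forward above.
import torch
from flash_attn.bert_padding import unpad_input, pad_input

batch, seqlen, hidden = 2, 8, 16
hidden_states = torch.randn(batch, seqlen, hidden, device='cuda', dtype=torch.float16)
lengths = torch.tensor([8, 5], device='cuda')
key_padding_mask = torch.arange(seqlen, device='cuda')[None, :] < lengths[:, None]  # (batch, seqlen) bool

# Drop the padded positions: the unpadded tensor has shape (total_tokens, hidden)
hidden_states_unpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
    hidden_states, key_padding_mask
)
assert hidden_states_unpad.shape[0] == int(lengths.sum())       # 13 tokens
assert cu_seqlens.tolist() == [0, 8, 13] and max_seqlen_in_batch == 8

# ... the layers run on the unpadded tensor, then it is scattered back to (batch, seqlen, hidden)
hidden_states_padded = pad_input(hidden_states_unpad, indices, batch, seqlen)
assert hidden_states_padded.shape == (batch, seqlen, hidden)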
class BertPooler(nn.Module):
def __init__(self, config):
super().__init__()
fused_bias_fc = getattr(config, 'fused_bias_fc', False)
if fused_bias_fc and FusedDenseTD is None:
raise ImportError('fused_dense is not installed')
linear_cls = nn.Linear if not fused_bias_fc else FusedDenseTD
self.dense = linear_cls(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states, pool=True):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0] if pool else hidden_states
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
class BertPredictionHeadTransform(nn.Module):
def __init__(self, config):
super().__init__()
fused_bias_fc = getattr(config, 'fused_bias_fc', False)
if fused_bias_fc and FusedDenseTD is None:
raise ImportError('fused_dense is not installed')
self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
if self.fused_dropout_add_ln and layer_norm is None:
raise ImportError('dropout_add_layer_norm is not installed')
linear_cls = nn.Linear if not fused_bias_fc else FusedDenseTD
self.dense = linear_cls(config.hidden_size, config.hidden_size)
self.transform_act_fn = nn.GELU(approximate='tanh')
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
if not self.fused_dropout_add_ln:
hidden_states = self.layer_norm(hidden_states)
else:
hidden_states = layer_norm(hidden_states, self.layer_norm.weight, self.layer_norm.bias,
self.layer_norm.eps)
return hidden_states
class BertLMPredictionHead(nn.Module):
def __init__(self, config):
super().__init__()
fused_bias_fc = getattr(config, 'fused_bias_fc', False)
if fused_bias_fc and FusedDenseTD is None:
raise ImportError('fused_dense is not installed')
linear_cls = nn.Linear if not fused_bias_fc else FusedDenseTD
self.transform = BertPredictionHeadTransform(config)
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = linear_cls(config.hidden_size, config.vocab_size, bias=True)
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
return hidden_states
class BertPreTrainingHeads(nn.Module):
def __init__(self, config):
super().__init__()
self.predictions = BertLMPredictionHead(config)
self.seq_relationship = nn.Linear(config.hidden_size, 2)
def forward(self, sequence_output, pooled_output):
prediction_scores = self.predictions(sequence_output)
seq_relationship_score = self.seq_relationship(pooled_output)
return prediction_scores, seq_relationship_score
class BertPreTrainedModel(nn.Module):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
def __init__(self, config, *inputs, **kwargs):
super().__init__()
if not isinstance(config, BertConfig):
raise ValueError(
"Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
"To create a model from a Google pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__
))
self.config = config
@classmethod
def from_pretrained(cls, model_name, config, *inputs, **kwargs):
"""
Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
Params:
pretrained_model_name_or_path: either:
- a path or url to a pretrained model archive containing:
. `bert_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a BertForPretraining instance
- a path or url to a pretrained model archive containing:
. `bert_config.json` a configuration file for the model
. `model.chkpt` a TensorFlow checkpoint
*inputs, **kwargs: additional input for the specific Bert class
(ex: num_labels for BertForSequenceClassification)
"""
# Instantiate model.
model = cls(config, *inputs, **kwargs)
load_return = model.load_state_dict(remap_state_dict(state_dict_from_pretrained(model_name),
config), strict=False)
logger.info(load_return)
return model
class BertModel(BertPreTrainedModel):
def __init__(self, config: BertConfig, add_pooling_layer=True):
super().__init__(config)
self.pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
if config.vocab_size % self.pad_vocab_size_multiple != 0:
config.vocab_size += (self.pad_vocab_size_multiple
- (config.vocab_size % self.pad_vocab_size_multiple))
self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
raise ImportError('dropout_add_layer_norm is not installed')
assert config.position_embedding_type == 'absolute'
assert config.hidden_act in ['gelu', 'gelu_new', 'gelu_fast']
self.embeddings = BertEmbeddings(config.hidden_size, config.vocab_size,
config.max_position_embeddings, config.type_vocab_size,
padding_idx=config.pad_token_id)
self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.encoder = BertEncoder(config)
self.pooler = BertPooler(config) if add_pooling_layer else None
self.apply(partial(_init_weights, initializer_range=config.initializer_range))
def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None,
masked_tokens_mask=None):
hidden_states = self.embeddings(input_ids, position_ids=position_ids,
token_type_ids=token_type_ids)
# TD [2022-12:18]: Don't need to force residual in fp32
if not self.fused_dropout_add_ln:
hidden_states = self.emb_drop(hidden_states)
hidden_states = self.emb_ln(hidden_states)
else:
hidden_states = dropout_add_layer_norm(
hidden_states, None, self.emb_ln.weight, self.emb_ln.bias,
self.emb_drop.p if self.training else 0.0, self.emb_ln.eps, prenorm=False,
)
sequence_output = self.encoder(hidden_states, key_padding_mask=attention_mask)
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
return sequence_output, pooled_output
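A brief usage sketch (not part of the commit): the optional speedups are switched on by setting plain attributes on a stock transformers BertConfig, since the code above reads them with getattr and falls back to the unfused paths when they are absent. The flags below assume the optional CUDA extensions (flash_attn, fused_dense, dropout_add_layer_norm) are installed; leave them unset otherwise.

# Hedged usage sketch for BertModel above; assumes a CUDA device and fp16 weights.
import torch
from transformers import BertConfig

config = BertConfig.from_pretrained('bert-base-uncased')
config.use_flash_attn = True            # FlashAttention inside MHA (needs fp16/bf16 on CUDA)
config.fused_bias_fc = True             # FusedDenseTD for the linear layers
config.fused_dense_gelu_dense = True    # FusedDenseGeluDense MLP
config.fused_dropout_add_ln = True      # fused dropout + residual + LayerNorm

model = BertModel(config).cuda().to(dtype=torch.float16).eval()
input_ids = torch.randint(0, config.vocab_size, (2, 128), device='cuda')
attention_mask = torch.ones(2, 128, dtype=torch.bool, device='cuda')
sequence_output, pooled_output = model(input_ids, attention_mask=attention_mask)
# sequence_output: (2, 128, hidden_size), pooled_output: (2, hidden_size)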
class BertForPreTraining(BertPreTrainedModel):
def __init__(self, config: BertConfig):
super().__init__(config)
# If dense_seq_output, we only need to pass the hidden states for the masked out tokens
# (around 15%) to the classifier heads.
self.dense_seq_output = getattr(config, 'dense_seq_output', False)
# If last_layer_subset, we only need to compute the last layer for a subset of tokens
# (e.g., the tokens we need to compute the masked LM loss and the next-sentence prediction).
self.last_layer_subset = getattr(config, 'last_layer_subset', False)
assert not self.last_layer_subset, 'last_layer_subset is not implemented yet'
use_xentropy = getattr(config, 'use_xentropy', False)
if use_xentropy and CrossEntropyLossApex is None:
raise ImportError('xentropy_cuda is not installed')
loss_cls = nn.CrossEntropyLoss if not use_xentropy else CrossEntropyLossApex
self.bert = BertModel(config)
self.cls = BertPreTrainingHeads(config)
self.mlm_loss = loss_cls(ignore_index=0)
self.nsp_loss = loss_cls(ignore_index=-1)
# Initialize weights and apply final processing
self.apply(partial(_init_weights, initializer_range=config.initializer_range))
self.tie_weights()
def tie_weights(self):
self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight
def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None,
labels=None, next_sentence_label=None):
"""
Outputs:
if `labels` and `next_sentence_label` are not `None`:
Outputs the total_loss which is the sum of the masked language modeling loss and the next
sentence classification loss.
if `labels` or `next_sentence_label` is `None`:
Outputs a tuple comprising
- the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
- the next sentence classification logits of shape [batch_size, 2].
"""
masked_tokens_mask = labels > 0 if (self.last_layer_subset and labels is not None) else None
sequence_output, pooled_output = self.bert(
input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
attention_mask=attention_mask.bool(), masked_tokens_mask=masked_tokens_mask
)
if self.dense_seq_output and labels is not None:
masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()
if not self.last_layer_subset:
sequence_output = index_first_axis(rearrange(sequence_output, 'b s d -> (b s) d'),
masked_token_idx)
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
if labels is not None and next_sentence_label is not None:
if masked_token_idx is not None: # prediction_scores are already flattened
masked_lm_loss = self.mlm_loss(prediction_scores,
labels.flatten()[masked_token_idx])
else:
masked_lm_loss = self.mlm_loss(rearrange(prediction_scores, '... v -> (...) v'),
rearrange(labels, '... -> (...)'))
next_sentence_loss = self.nsp_loss(rearrange(seq_relationship_score, '... t -> (...) t'),
rearrange(next_sentence_label, '... -> (...)'))
total_loss = (masked_lm_loss + next_sentence_loss).float()
# Masked Language Model Accuracy
masked_lm_labels_flat = labels.view(-1)
mlm_labels = masked_lm_labels_flat[masked_lm_labels_flat != 0]
if not self.dense_seq_output:
prediction_scores_flat = rearrange(prediction_scores, '... v -> (...) v')
mlm_predictions_scores = prediction_scores_flat[masked_lm_labels_flat != 0]
mlm_predictions = mlm_predictions_scores.argmax(dim=-1)
else:
mlm_predictions = prediction_scores.argmax(dim=-1)
mlm_acc = (mlm_predictions == mlm_labels).sum(dtype=torch.float) / mlm_labels.numel()
return total_loss, prediction_scores, seq_relationship_score, mlm_acc, mlm_labels.numel()
else:
return prediction_scores, seq_relationship_score
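To make the dense_seq_output path above concrete, here is a small sketch (not from the commit) of the gather it performs: only the positions with a non-zero MLM label (roughly 15% of tokens) are pushed through the prediction head.

# Hedged sketch of the dense_seq_output gather in BertForPreTraining.forward above.
import torch
from einops import rearrange
from flash_attn.bert_padding import index_first_axis

batch, seqlen, hidden = 2, 6, 4
sequence_output = torch.randn(batch, seqlen, hidden)
labels = torch.tensor([[0, 0, 7, 0, 0, 9],
                       [0, 3, 0, 0, 0, 0]])   # 0 = not an MLM target
masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()

dense_output = index_first_axis(rearrange(sequence_output, 'b s d -> (b s) d'),
                                masked_token_idx)
assert dense_output.shape == (3, hidden)      # only the 3 masked positions survive
# The MLM loss is then taken against labels.flatten()[masked_token_idx].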
def state_dict_from_pretrained(model_name):
from transformers.utils import WEIGHTS_NAME
from transformers.utils.hub import cached_file
return torch.load(cached_file(model_name, WEIGHTS_NAME))
def remap_state_dict(state_dict, config):
# LayerNorm
def key_mapping_ln_gamma_beta(key):
key = re.sub(r'LayerNorm.gamma$', 'LayerNorm.weight', key)
key = re.sub(r'LayerNorm.beta$', 'LayerNorm.bias', key)
return key
state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items())
# Layers
def key_mapping_layers(key):
return re.sub(r'^bert.encoder.layer.', 'bert.encoder.layers.', key)
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
# LayerNorm
def key_mapping_ln(key):
key = re.sub(r'^bert.embeddings.LayerNorm.', 'bert.emb_ln.', key)
key = re.sub(r'^bert.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)',
r'bert.encoder.layers.\1.norm1.\2', key)
key = re.sub(r'^bert.encoder.layers.(\d+).output.LayerNorm.(weight|bias)',
r'bert.encoder.layers.\1.norm2.\2', key)
key = re.sub(r'^cls.predictions.transform.LayerNorm.(weight|bias)',
r'cls.predictions.transform.layer_norm.\1', key)
return key
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
# MLP
def key_mapping_mlp(key):
key = re.sub(r'^bert.encoder.layers.(\d+).intermediate.dense.(weight|bias)',
r'bert.encoder.layers.\1.mlp.fc1.\2', key)
key = re.sub(r'^bert.encoder.layers.(\d+).output.dense.(weight|bias)',
r'bert.encoder.layers.\1.mlp.fc2.\2', key)
return key
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
# Attention
for d in range(config.num_hidden_layers):
Wq = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.query.weight')
Wk = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.key.weight')
Wv = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.value.weight')
bq = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.query.bias')
bk = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.key.bias')
bv = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.value.bias')
state_dict[f'bert.encoder.layers.{d}.mixer.Wqkv.weight'] = torch.cat(
[Wq, Wk, Wv], dim=0
)
state_dict[f'bert.encoder.layers.{d}.mixer.Wqkv.bias'] = torch.cat(
[bq, bk, bv], dim=0
)
def key_mapping_attn(key):
return re.sub(r'^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)',
r'bert.encoder.layers.\1.mixer.out_proj.\2', key)
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
def key_mapping_decoder_bias(key):
return re.sub(r'^cls.predictions.bias', 'cls.predictions.decoder.bias', key)
state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())
return state_dict
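A quick sketch (not part of the commit) of what the attention remapping above does to a HuggingFace checkpoint: the per-layer query/key/value projections are concatenated into a single Wqkv, so each merged weight has shape (3 * hidden_size, hidden_size) and the keys line up with the state dict of the model defined in this file. Running it downloads the bert-base-uncased weights.

# Hedged sketch: inspect the Wqkv merge produced by remap_state_dict above.
from transformers import BertConfig

config = BertConfig.from_pretrained('bert-base-uncased')
remapped = remap_state_dict(state_dict_from_pretrained('bert-base-uncased'), config)

w = remapped['bert.encoder.layers.0.mixer.Wqkv.weight']
b = remapped['bert.encoder.layers.0.mixer.Wqkv.bias']
assert w.shape == (3 * config.hidden_size, config.hidden_size)
assert b.shape == (3 * config.hidden_size,)
# from_pretrained above feeds exactly this remapped dict to load_state_dict(..., strict=False).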
@@ -10,7 +10,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from transformers.models.gpt2.configuration_gpt2 import GPT2Config
+from transformers import GPT2Config
 from flash_attn.modules.mha import MHA
 from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense
...
@@ -23,10 +23,16 @@ class Block(nn.Module):
     def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm,
                  dropout_cls=nn.Dropout, prenorm=True, resid_dropout=0., drop_path=0.,
-                 fused_dropout_add_ln=False):
+                 fused_dropout_add_ln=False, return_residual=False):
+        """
+        return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
+        This is for performance reason: for post-norm architecture, returning the input allows us
+        to fuse the backward of nn.Linear with the residual connection.
+        """
         super().__init__()
         self.prenorm = prenorm
         self.fused_dropout_add_ln = fused_dropout_add_ln
+        self.return_residual = return_residual
         if mixer_cls is None:
             mixer_cls = partial(MHA, num_heads=dim // 64)
         if mlp_cls is None:
@@ -92,8 +98,11 @@ class Block(nn.Module):
             return hidden_states, residual
         else:
             assert residual is None
-            mixer_out = self.mixer(hidden_states,
-                                   **(mixer_kwargs if mixer_kwargs is not None else {}))
+            mixer_out = self.mixer(
+                hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
+            )
+            if self.return_residual:  # mixer out is actually a pair here
+                mixer_out, hidden_states = mixer_out
             if not self.fused_dropout_add_ln:
                 hidden_states = self.norm1((self.drop_path1(self.dropout1(mixer_out))
                                             + hidden_states).to(dtype=self.norm1.weight.dtype))
@@ -111,6 +120,8 @@ class Block(nn.Module):
             )
             if not isinstance(self.mlp, nn.Identity):
                 mlp_out = self.mlp(hidden_states)
+                if self.return_residual:  # mlp out is actually a pair here
+                    mlp_out, hidden_states = mlp_out
                 if not self.fused_dropout_add_ln:
                     hidden_states = self.norm2((self.drop_path2(self.dropout2(mlp_out))
                                                 + hidden_states).to(dtype=self.norm2.weight.dtype))
...
@@ -3,8 +3,6 @@
 import torch
 import torch.nn as nn
-from einops import repeat
 class GPT2Embeddings(nn.Module):
@@ -21,15 +19,51 @@ class GPT2Embeddings(nn.Module):
     def forward(self, input_ids, position_ids=None):
         """
             input_ids: (batch, seqlen)
+            position_ids: (batch, seqlen)
         """
         batch_size, seqlen = input_ids.shape
-        input_embeddings = self.word_embeddings(input_ids)
+        embeddings = self.word_embeddings(input_ids)
         if self.max_position_embeddings > 0:
             if position_ids is None:
-                position_ids = repeat(torch.arange(seqlen, dtype=torch.long,
-                                                   device=input_ids.device),
-                                      's -> b s', b=batch_size)
+                position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
             position_embeddings = self.position_embeddings(position_ids)
-            return input_embeddings + position_embeddings
-        else:
-            return input_embeddings
+            embeddings = embeddings + position_embeddings
+        return embeddings
+
+
+class BertEmbeddings(nn.Module):
+    def __init__(self, embed_dim, vocab_size, max_position_embeddings, type_vocab_size,
+                 padding_idx=None):
+        """
+        If max_position_embeddings <= 0, there's no position embeddings
+        If type_vocab_size <= 0, there's no token type embeddings
+        """
+        super().__init__()
+        self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
+        self.max_position_embeddings = max_position_embeddings
+        self.type_vocab_size = type_vocab_size
+        if self.max_position_embeddings > 0:
+            self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim)
+        if self.type_vocab_size > 0:
+            self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim)
+
+    def forward(self, input_ids, position_ids=None, token_type_ids=None):
+        """
+            input_ids: (batch, seqlen)
+            position_ids: (batch, seqlen)
+            token_type_ids: (batch, seqlen)
+        """
+        batch_size, seqlen = input_ids.shape
+        embeddings = self.word_embeddings(input_ids)
+        if self.max_position_embeddings > 0:
+            if position_ids is None:
+                position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
+            position_embeddings = self.position_embeddings(position_ids)
+            embeddings = embeddings + position_embeddings
+        if self.type_vocab_size > 0:
+            if token_type_ids is None:
+                token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
+            token_type_embeddings = self.token_type_embeddings(token_type_ids)
+            embeddings = embeddings + token_type_embeddings
+        return embeddings
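For reference, a small sketch (not from the commit) of the new BertEmbeddings module reconstructed above: word, position, and token-type embeddings are summed, with dropout and LayerNorm applied later by BertModel. The constructor arguments follow the call made in BertModel (hidden_size, vocab_size, max_position_embeddings, type_vocab_size, padding_idx).

# Hedged sketch of BertEmbeddings usage; runs on CPU.
import torch
from flash_attn.modules.embedding import BertEmbeddings

emb = BertEmbeddings(embed_dim=32, vocab_size=100, max_position_embeddings=16,
                     type_vocab_size=2, padding_idx=0)
input_ids = torch.randint(0, 100, (2, 16))
out = emb(input_ids)              # position_ids defaults to arange, token_type_ids to zeros
assert out.shape == (2, 16, 32)   # word + position + token-type embeddings, summed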
@@ -53,28 +53,49 @@ class FlashSelfAttention(nn.Module):
         self.dropout_p = attention_dropout
         self.triton = triton
-    def forward(self, qkv):
+    def forward(self, qkv, cu_seqlens=None, max_seqlen=None):
         """Implements the multihead softmax attention.
         Arguments
         ---------
-            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
+            qkv: The tensor containing the query, key, and value.
+                If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
+                If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
+                (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
+            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
+                of the sequences in the batch, used to index into qkv.
+            max_seqlen: int. Maximum sequence length in the batch.
+        Returns:
+        --------
+            out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
+                else (B, S, H, D).
         """
         assert qkv.dtype in [torch.float16, torch.bfloat16]
         assert qkv.is_cuda
-        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
-        if self.triton and (self.dropout_p == 0 or not self.training):  # Triton version doesn't support dropout
-            output = flash_attn_qkvpacked_func(qkv, None, self.causal, self.softmax_scale)
-        else:
-            qkv = rearrange(qkv, 'b s ... -> (b s) ...')
-            max_s = seqlen
-            cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
-                                      device=qkv.device)
-            output = flash_attn_unpadded_qkvpacked_func(
-                qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
+        unpadded = cu_seqlens is not None
+        if unpadded:
+            assert cu_seqlens.dtype == torch.int32
+            assert max_seqlen is not None
+            assert isinstance(max_seqlen, int)
+            return flash_attn_unpadded_qkvpacked_func(
+                qkv, cu_seqlens, max_seqlen, self.dropout_p if self.training else 0.0,
                 softmax_scale=self.softmax_scale, causal=self.causal
             )
-            output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
-        return output
+        else:
+            batch_size, seqlen = qkv.shape[0], qkv.shape[1]
+            # Triton version doesn't support dropout
+            if self.triton and (self.dropout_p == 0 or not self.training):
+                output = flash_attn_qkvpacked_func(qkv, None, self.causal, self.softmax_scale)
+            else:
+                qkv = rearrange(qkv, 'b s ... -> (b s) ...')
+                max_seqlen = seqlen
+                cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                                          device=qkv.device)
+                output = flash_attn_unpadded_qkvpacked_func(
+                    qkv, cu_seqlens, max_seqlen, self.dropout_p if self.training else 0.0,
+                    softmax_scale=self.softmax_scale, causal=self.causal
+                )
+                output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
+            return output
 class FlashCrossAttention(nn.Module):
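To make the new unpadded interface of FlashSelfAttention above concrete, a hedged sketch of how the cu_seqlens tensor is typically built from per-sequence lengths; this mirrors what unpad_input in flash_attn.bert_padding computes. The shapes assume 12 heads of dimension 64 and a CUDA device.

# Hedged sketch: building the (batch_size + 1,) int32 cu_seqlens used by the unpadded path.
import torch
import torch.nn.functional as F

lengths = torch.tensor([5, 8, 3], dtype=torch.int32, device='cuda')
cu_seqlens = F.pad(torch.cumsum(lengths, dim=0, dtype=torch.int32), (1, 0))
# cu_seqlens == tensor([0, 5, 13, 16]): sequence i occupies rows cu_seqlens[i]:cu_seqlens[i+1]
# of the packed qkv tensor of shape (total, 3, nheads, headdim), with total = 16 here.
max_seqlen = int(lengths.max())
qkv = torch.randn(int(cu_seqlens[-1]), 3, 12, 64, dtype=torch.float16, device='cuda')
# FlashSelfAttention(...)(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) returns (total, nheads, headdim).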
@@ -146,16 +167,24 @@ class SelfAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.dropout_p = attention_dropout
-    def forward(self, qkv):
+    def forward(self, qkv, key_padding_mask=None):
         """Implements the multihead softmax attention.
         Arguments
         ---------
             qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
+            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
+                False means to mask out. (B, S)
         """
         batch_size, seqlen = qkv.shape[0], qkv.shape[1]
         q, k, v = qkv.unbind(dim=2)
         softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
         scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
+        if key_padding_mask is not None:
+            padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype,
+                                      device=scores.device)
+            padding_mask.masked_fill_(key_padding_mask, 0.0)
+            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
+            scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s')
         if self.causal:
             # "triu_tril_cuda_template" not implemented for 'BFloat16'
             # So we have to construct the mask in float
@@ -239,6 +268,7 @@ class MHA(nn.Module):
         self.causal = causal
         self.dwconv = dwconv
         self.rotary_emb_dim = rotary_emb_dim
+        self.use_flash_attn = use_flash_attn
         self.return_residual = return_residual
         self.checkpointing = checkpointing
@@ -279,12 +309,35 @@ class MHA(nn.Module):
         # output projection always have the bias (for now)
         self.out_proj = linear_cls(embed_dim, embed_dim, **factory_kwargs)
-    def forward(self, x, x_kv=None):
+    def forward(self, x, x_kv=None, cu_seqlens=None, max_seqlen=None, key_padding_mask=None):
         """
         Arguments:
-            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim)
+            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
+                cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
+                is the sum of the sequence lengths in the batch.
             x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
+            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
+                of the sequences in the batch, used to index into x. Only applicable when using
+                FlashAttention.
+            max_seqlen: int. Maximum sequence length in the batch.
+            key_padding_mask: boolean mask, True means to keep, False means to mask out.
+                (batch, seqlen). Only applicable when not using FlashAttention.
         """
+        if cu_seqlens is not None:
+            assert max_seqlen is not None
+            assert key_padding_mask is None
+            assert self.use_flash_attn
+            assert not self.cross_attn, ('Unpadded FlashAttention code path for cross-attention '
+                                         'is not implemented yet')
+            assert not self.dwconv
+            assert self.rotary_emb_dim == 0
+        if key_padding_mask is not None:
+            assert cu_seqlens is None
+            assert max_seqlen is None
+            assert not self.use_flash_attn
+            assert not self.cross_attn, ('Key padding mask code path for cross-attention '
+                                         'is not implemented yet')
         if not self.cross_attn:
             if not self.return_residual:
                 qkv = self.Wqkv(x)
@@ -293,14 +346,15 @@ class MHA(nn.Module):
             if self.dwconv:
                 qkv = rearrange(self.dwconv_qkv(rearrange(qkv, 'b s d -> b d s'))[..., :-2],
                                 'b d s -> b s d').contiguous()
-            qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
+            qkv = rearrange(qkv, '... (three h d) -> ... three h d', three=3, h=self.num_heads)
             if self.rotary_emb_dim > 0:
                 qkv = self.rotary_emb(qkv)
+            extra_kwargs = ({'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen}
+                            if self.use_flash_attn else {'key_padding_mask': key_padding_mask})
             if not self.checkpointing:
-                context = self.inner_attn(qkv)
+                context = self.inner_attn(qkv, **extra_kwargs)
             else:
-                # context = torch.utils.checkpoint.checkpoint(self._inner_attention, qkv)
-                context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv)
+                context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **extra_kwargs)
         else:
             q = rearrange(self.Wq(x), 'b s (h d) -> b s h d', h=self.num_heads)
             kv = rearrange(self.Wkv(x if x_kv is None else x_kv), 'b s (two h d) -> b s two h d',
@@ -313,7 +367,6 @@ class MHA(nn.Module):
             if not self.checkpointing:
                 context = self.inner_attn(q, kv)
             else:
-                # context = torch.utils.checkpoint.checkpoint(self._inner_attention, qkv)
                 context = torch.utils.checkpoint.checkpoint(self.inner_attn, q, kv)
-        out = self.out_proj(rearrange(context, 'b s h d -> b s (h d)'))
+        out = self.out_proj(rearrange(context, '... h d -> ... (h d)'))
         return out if not self.return_residual else (out, x)
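A small sketch (not from the commit) of the additive key-padding trick that SelfAttention uses above: adding a large negative bias at padded key positions is, after the softmax, equivalent to masking them out.

# Hedged sketch: additive -10000.0 bias vs. boolean masking of attention scores; runs on CPU.
import torch

batch, seqlen = 1, 4
scores = torch.randn(batch, 1, seqlen, seqlen)                 # (batch, heads, query, key)
key_padding_mask = torch.tensor([[True, True, True, False]])   # last key position is padding

padding_bias = torch.full((batch, seqlen), -10000.0)
padding_bias.masked_fill_(key_padding_mask, 0.0)               # keep -> 0.0, pad -> -10000.0
attn_add = torch.softmax(scores + padding_bias[:, None, None, :], dim=-1)

attn_ref = torch.softmax(scores.masked_fill(~key_padding_mask[:, None, None, :], float('-inf')),
                         dim=-1)
assert torch.allclose(attn_add, attn_ref, atol=1e-4)           # padded keys get ~zero weight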
@@ -200,6 +200,10 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
                                  None, None)
+def layer_norm(x, weight, bias, epsilon):
+    return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False)
+
 def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None, layerscale=None,
                            prenorm=False, residual_in_fp32=False,
                            return_dropout_mask=False):
...
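The new layer_norm helper above reuses the fused dropout-add-LayerNorm kernel with dropout 0 and no residual. As a hedged sanity check (it assumes the dropout_add_layer_norm CUDA extension is built, a CUDA device, and a hidden dimension the kernel supports, e.g. 768):

# Hedged sketch comparing the fused layer_norm wrapper above with PyTorch's layer_norm.
import torch
import torch.nn.functional as F
from flash_attn.ops.layer_norm import layer_norm

hidden = 768
x = torch.randn(4, 512, hidden, device='cuda', dtype=torch.float16)
weight = torch.ones(hidden, device='cuda', dtype=torch.float16)
bias = torch.zeros(hidden, device='cuda', dtype=torch.float16)

out_fused = layer_norm(x, weight, bias, 1e-12)
out_ref = F.layer_norm(x.float(), (hidden,), weight.float(), bias.float(), 1e-12).to(x.dtype)
print((out_fused - out_ref).abs().max())   # expected to differ by at most a few fp16 ulps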
import re
from collections import OrderedDict
import torch
import torch.nn.functional as F
import pytest
from einops import rearrange
from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertModel as BertModelHF
from transformers.models.bert.modeling_bert import BertForPreTraining as BertForPreTrainingHF
from flash_attn.models.bert import BertModel, BertForPreTraining
from flash_attn.models.bert import state_dict_from_pretrained
from flash_attn.models.bert import remap_state_dict
@pytest.mark.parametrize('model_name', ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_state_dict(model_name):
config = BertConfig.from_pretrained(model_name)
pretrained_state_dict = remap_state_dict(state_dict_from_pretrained(model_name), config)
model = BertForPreTraining(config)
state_dict = model.state_dict()
assert state_dict.keys() == pretrained_state_dict.keys()
for k in state_dict.keys():
assert state_dict[k].shape == pretrained_state_dict[k].shape
def get_hf_models(model_name, config, dtype):
pretrained_state_dict = state_dict_from_pretrained(model_name)
def key_mapping_ln_gamma_beta(key):
key = re.sub(r'LayerNorm.gamma$', 'LayerNorm.weight', key)
key = re.sub(r'LayerNorm.beta$', 'LayerNorm.bias', key)
return key
pretrained_state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v)
for k, v in pretrained_state_dict.items())
model_hf = BertForPreTrainingHF(config)
# Missing key(s) in state_dict: "bert.embeddings.position_ids", "cls.predictions.decoder.bias"
# position_ids is a buffer, and predictions.decoder.bias is tied to predictions.bias.
model_hf.load_state_dict(pretrained_state_dict, strict=False)
model_hf.cuda().to(dtype=dtype)
return model_hf
@pytest.mark.parametrize('model_name', ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_non_optimized(model_name):
"""Check that our implementation of BERT (without any optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
dtype = torch.float16
config = BertConfig.from_pretrained(model_name)
# Our implementation assumes the activation is nn.GELU(approximate='tanh')
# Huggingface calls it "gelu_new" or "gelu_fast".
config.hidden_act = "gelu_new"
model = BertForPreTraining.from_pretrained(model_name, config)
model = model.cuda().to(dtype=dtype)
model_ref = get_hf_models(model_name, config, torch.float32)
model_hf = get_hf_models(model_name, config, torch.float16)
model.eval()
model_ref.eval()
model_hf.eval()
torch.manual_seed(0)
batch_size = 4
max_seqlen = 512
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda')
attention_mask = torch.arange(max_seqlen, device='cuda')[None, :] < seqlens[:, None]
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
device='cuda')
sequence_output, pooled_output = model.bert(input_ids, attention_mask=attention_mask)
out_hf = model_hf.bert(input_ids, attention_mask=attention_mask)
sequence_output_hf, pooled_output_hf = out_hf.last_hidden_state, out_hf.pooler_output
out_ref = model_ref.bert(input_ids, attention_mask=attention_mask)
sequence_output_ref, pooled_output_ref = out_ref.last_hidden_state, out_ref.pooler_output
print(f'Output max diff: {(sequence_output - sequence_output_ref).abs().max().item()}')
print(f'Output mean diff: {(sequence_output - sequence_output_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(sequence_output_hf - sequence_output_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(sequence_output_hf - sequence_output_ref).abs().mean().item()}')
assert (sequence_output - sequence_output_ref).abs().max().item() < 2 * (sequence_output_hf - sequence_output_ref).abs().max().item()
assert (pooled_output - pooled_output_ref).abs().max().item() < 2 * (pooled_output_hf - pooled_output_ref).abs().max().item()
@pytest.mark.parametrize('model_name', ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_optimized(model_name):
"""Check that our implementation of BERT (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
dtype = torch.float16
config = BertConfig.from_pretrained(model_name)
# Our implementation assumes the activation is nn.GELU(approximate='tanh')
# Huggingface calls it "gelu_new" or "gelu_fast".
config.hidden_act = "gelu_new"
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_dense_gelu_dense = True
config.fused_dropout_add_ln = True
model = BertForPreTraining.from_pretrained(model_name, config)
model = model.cuda().to(dtype=dtype)
model_ref = get_hf_models(model_name, config, torch.float32)
model_hf = get_hf_models(model_name, config, torch.float16)
model.eval()
model_ref.eval()
model_hf.eval()
torch.manual_seed(0)
batch_size = 4
max_seqlen = 512
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda')
attention_mask = torch.arange(max_seqlen, device='cuda')[None, :] < seqlens[:, None]
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
device='cuda')
sequence_output, pooled_output = model.bert(input_ids, attention_mask=attention_mask)
out_hf = model_hf.bert(input_ids, attention_mask=attention_mask)
sequence_output_hf, pooled_output_hf = out_hf.last_hidden_state, out_hf.pooler_output
# Need to zero out the padded tokens in the sequence before comparison.
sequence_output_hf[~attention_mask, :] = 0.0
out_ref = model_ref.bert(input_ids, attention_mask=attention_mask)
sequence_output_ref, pooled_output_ref = out_ref.last_hidden_state, out_ref.pooler_output
sequence_output_ref[~attention_mask, :] = 0.0
print(f'BertModel output max diff: {(sequence_output - sequence_output_ref).abs().max().item()}')
print(f'BertModel output mean diff: {(sequence_output - sequence_output_ref).abs().mean().item()}')
print(f'HF fp16 BertModel max diff: {(sequence_output_hf - sequence_output_ref).abs().max().item()}')
print(f'HF fp16 BertModel mean diff: {(sequence_output_hf - sequence_output_ref).abs().mean().item()}')
assert (sequence_output - sequence_output_ref).abs().max().item() < 4 * (sequence_output_hf - sequence_output_ref).abs().max().item()
assert (pooled_output - pooled_output_ref).abs().max().item() < 4 * (pooled_output_hf - pooled_output_ref).abs().max().item()
prediction_scores, seq_relationship_scores = model(input_ids, attention_mask=attention_mask)
# Need to zero out the padded tokens in the sequence before comparison.
prediction_scores = prediction_scores.clone()
prediction_scores[~attention_mask, :] = 0.0
out_hf = model_hf(input_ids, attention_mask=attention_mask)
prediction_scores_hf, seq_relationship_scores_hf = out_hf.prediction_logits, out_hf.seq_relationship_logits
prediction_scores_hf[~attention_mask, :] = 0.0
out_ref = model_ref(input_ids, attention_mask=attention_mask)
prediction_scores_ref, seq_relationship_scores_ref = out_ref.prediction_logits, out_ref.seq_relationship_logits
prediction_scores_ref[~attention_mask, :] = 0.0
print(f'prediction_scores max diff: {(prediction_scores - prediction_scores_ref).abs().max().item()}')
print(f'prediction_scores mean diff: {(prediction_scores - prediction_scores_ref).abs().mean().item()}')
print(f'HF fp16 prediction_scores max diff: {(prediction_scores_hf - prediction_scores_ref).abs().max().item()}')
print(f'HF fp16 prediction_scores mean diff: {(prediction_scores_hf - prediction_scores_ref).abs().mean().item()}')
assert (prediction_scores - prediction_scores_ref).abs().max().item() < 2 * (prediction_scores_hf - prediction_scores_ref).abs().max().item()
assert (seq_relationship_scores - seq_relationship_scores_ref).abs().max().item() < 2 * (seq_relationship_scores_hf - seq_relationship_scores_ref).abs().max().item()
@pytest.mark.parametrize('model_name', ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_dense_seq_output(model_name):
"""Check that our implementation of BERT (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
dtype = torch.float16
config = BertConfig.from_pretrained(model_name)
# Our implementation assumes the activation is nn.GELU(approximate='tanh')
# Huggingface calls it "gelu_new" or "gelu_fast".
config.hidden_act = "gelu_new"
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_dense_gelu_dense = True
config.fused_dropout_add_ln = True
config.dense_seq_output = True
config.use_xentropy = True
model = BertForPreTraining.from_pretrained(model_name, config)
model = model.cuda().to(dtype=dtype)
model_ref = get_hf_models(model_name, config, torch.float32)
model_hf = get_hf_models(model_name, config, torch.float16)
model.eval()
model_ref.eval()
model_hf.eval()
torch.manual_seed(0)
batch_size = 4
max_seqlen = 512
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda')
attention_mask = torch.arange(max_seqlen, device='cuda')[None, :] < seqlens[:, None]
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
device='cuda')
labels = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
device='cuda')
labels[(torch.rand(batch_size, max_seqlen, device='cuda') < 0.15) | ~attention_mask] = 0
masked_tokens_mask = labels.flatten() > 0
next_sequence_label = torch.randint(0, 2, (batch_size,), device='cuda')
total_loss, prediction_scores, seq_relationship_scores, _, _ = model(
input_ids, attention_mask=attention_mask,
labels=labels, next_sentence_label=next_sequence_label
)
out_hf = model_hf(input_ids, attention_mask=attention_mask,
labels=labels, next_sentence_label=next_sequence_label)
prediction_scores_hf, seq_relationship_scores_hf = out_hf.prediction_logits, out_hf.seq_relationship_logits
prediction_scores_hf = rearrange(prediction_scores_hf, 'b s d -> (b s) d')[masked_tokens_mask]
out_ref = model_ref(input_ids, attention_mask=attention_mask,
labels=labels, next_sentence_label=next_sequence_label)
prediction_scores_ref, seq_relationship_scores_ref = out_ref.prediction_logits, out_ref.seq_relationship_logits
prediction_scores_ref = rearrange(prediction_scores_ref, 'b s d -> (b s) d')[masked_tokens_mask]
print(f'prediction_scores max diff: {(prediction_scores - prediction_scores_ref).abs().max().item()}')
print(f'prediction_scores mean diff: {(prediction_scores - prediction_scores_ref).abs().mean().item()}')
print(f'HF fp16 prediction_scores max diff: {(prediction_scores_hf - prediction_scores_ref).abs().max().item()}')
print(f'HF fp16 prediction_scores mean diff: {(prediction_scores_hf - prediction_scores_ref).abs().mean().item()}')
assert (prediction_scores - prediction_scores_ref).abs().max().item() < 2 * (prediction_scores_hf - prediction_scores_ref).abs().max().item()