Commit 13cdceb3 authored by Tri Dao

Implement last_layer_subset optimization for BERT

parent 5fb6df0e
@@ -17,6 +17,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 from transformers import BertConfig
+from transformers.models.bert.modeling_bert import BaseModelOutputWithPoolingAndCrossAttentions
+from transformers.models.bert.modeling_bert import BertForPreTrainingOutput
 from einops import rearrange
@@ -24,7 +26,8 @@ from flash_attn.modules.mha import MHA
 from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense
 from flash_attn.modules.block import Block
 from flash_attn.modules.embedding import BertEmbeddings
-from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis
+from flash_attn.bert_padding import unpad_input, pad_input
+from flash_attn.bert_padding import index_first_axis, index_first_axis_residual
 try:
     from flash_attn.ops.fused_dense import FusedDenseTD
@@ -45,21 +48,27 @@ except ImportError:
 logger = logging.getLogger(__name__)

-def create_mixer_cls(config):
+def create_mixer_cls(config, cross_attn=False, return_residual=False):
     use_flash_attn = getattr(config, 'use_flash_attn', False)
     fused_bias_fc = getattr(config, 'fused_bias_fc', False)
-    mixer_cls = partial(MHA, num_heads=config.num_attention_heads,
-                        dropout=config.attention_probs_dropout_prob, causal=False,
-                        fused_bias_fc=fused_bias_fc, use_flash_attn=use_flash_attn)
+    mixer_cls = partial(MHA, num_heads=config.num_attention_heads, cross_attn=cross_attn,
+                        dropout=config.attention_probs_dropout_prob, causal=False,
+                        fused_bias_fc=fused_bias_fc, use_flash_attn=use_flash_attn,
+                        return_residual=return_residual)
     return mixer_cls

-def create_mlp_cls(config, layer_idx=None):
+def create_mlp_cls(config, layer_idx=None, return_residual=False):
     inner_dim = config.intermediate_size
     fused_dense_gelu_dense = getattr(config, 'fused_dense_gelu_dense', False)
+    if fused_dense_gelu_dense:
+        assert config.hidden_act in ['gelu_new', 'gelu_fast'], ('fused_dense_gelu_dense only '
+                                                                'supports approximate gelu')
     if not fused_dense_gelu_dense:
+        approximate = 'tanh' if config.hidden_act in ['gelu_new', 'gelu_fast'] else 'none'
         mlp_cls = partial(Mlp, hidden_features=inner_dim,
-                          activation=partial(F.gelu, approximate='tanh'))
+                          activation=partial(F.gelu, approximate=approximate),
+                          return_residual=return_residual)
     else:
         mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
         # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
@@ -67,17 +76,24 @@ def create_mlp_cls(config, layer_idx=None):
             assert layer_idx is not None
            mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
         mlp_cls = partial(FusedDenseGeluDense, hidden_features=inner_dim,
-                          checkpoint_lvl=mlp_checkpoint_lvl)
+                          checkpoint_lvl=mlp_checkpoint_lvl, return_residual=return_residual)
     return mlp_cls

 def create_block(config, layer_idx=None):
-    mixer_cls = create_mixer_cls(config)
-    mlp_cls = create_mlp_cls(config, layer_idx)
+    last_layer_subset = getattr(config, 'last_layer_subset', False)
+    cross_attn = last_layer_subset and layer_idx == config.num_hidden_layers - 1
+    # TD [2022-12-19]: For cross attention (last layer), we actually want to return the
+    # residual x_kv, not residual x. But it's annoying to change the API (and it only affects
+    # one layer) so we just choose not to return residual in this case.
+    return_residual = not cross_attn
+    mixer_cls = create_mixer_cls(config, cross_attn, return_residual=return_residual)
+    mlp_cls = create_mlp_cls(config, layer_idx, return_residual=return_residual)
     norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_eps)
     block = Block(config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
                   prenorm=False, resid_dropout=config.hidden_dropout_prob,
-                  fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False))
+                  fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False),
+                  return_residual=return_residual)
     return block
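A standalone sketch of how the per-layer flags come out of create_block when last_layer_subset is enabled (toy config, not the library code): only the final layer is built as a cross-attention block, and only that layer skips the return_residual path.

```python
# Assumed toy config; illustrates only the flag logic above.
class DummyConfig:
    num_hidden_layers = 4
    last_layer_subset = True

config = DummyConfig()
for layer_idx in range(config.num_hidden_layers):
    cross_attn = config.last_layer_subset and layer_idx == config.num_hidden_layers - 1
    return_residual = not cross_attn
    print(layer_idx, cross_attn, return_residual)
# 0 False True
# 1 False True
# 2 False True
# 3 True False
```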
@@ -101,21 +117,49 @@ class BertEncoder(nn.Module):
         self.layers = nn.ModuleList([create_block(config, layer_idx=i)
                                      for i in range(config.num_hidden_layers)])

-    def forward(self, hidden_states, key_padding_mask=None):
+    def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
+        """If subset_mask is not None, we only want output for the subset of the sequence.
+        This means that we only compute the last layer output for these tokens.
+        subset_mask: (batch, seqlen), dtype=torch.bool
+        """
         if key_padding_mask is None or not self.use_flash_attn:
             mixer_kwargs = ({'key_padding_mask': key_padding_mask}
                             if key_padding_mask is not None else None)
             for layer in self.layers:
                 hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
+            if subset_mask is not None:
+                hidden_states = hidden_states[subset_mask]
         else:
             batch, seqlen = hidden_states.shape[:2]
             hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
                 hidden_states, key_padding_mask
             )
             mixer_kwargs = {'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen_in_batch}
-            for layer in self.layers:
-                hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
-            hidden_states = pad_input(hidden_states, indices, batch, seqlen)
+            if subset_mask is None:
+                for layer in self.layers:
+                    hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
+                hidden_states = pad_input(hidden_states, indices, batch, seqlen)
+            else:
+                for layer in self.layers[:-1]:
+                    hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
+                if key_padding_mask is not None:
+                    subset_idx = torch.nonzero(subset_mask[key_padding_mask], as_tuple=False).flatten()
+                    subset_seqlens = (subset_mask & key_padding_mask).sum(dim=-1, dtype=torch.int32)
+                    subset_cu_seqlens = F.pad(torch.cumsum(subset_seqlens, dim=0,
+                                                           dtype=torch.torch.int32), (1, 0))
+                else:
+                    subset_idx = torch.nonzero(subset_mask, as_tuple=False).flatten()
+                    subset_seqlens = subset_mask.sum(dim=-1, dtype=torch.int32)
+                    subset_cu_seqlens = F.pad(torch.cumsum(subset_seqlens, dim=0,
+                                                           dtype=torch.torch.int32), (1, 0))
+                hidden_states_subset, hidden_states = index_first_axis_residual(
+                    hidden_states, subset_idx
+                )
+                # It's ok to set max_seqlen_q to be much larger
+                mixer_kwargs = {'x_kv': hidden_states,
+                                'cu_seqlens': subset_cu_seqlens, 'max_seqlen': max_seqlen_in_batch,
+                                'cu_seqlens_k': cu_seqlens, 'max_seqlen_k': max_seqlen_in_batch}
+                hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
         return hidden_states
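For the unpadded code path above, the cumulative sequence lengths that the last (cross-attention) layer receives can be reproduced with plain PyTorch. A minimal sketch with assumed toy shapes (not the library code): queries come only from the subset tokens, keys/values from all non-padded tokens.

```python
import torch
import torch.nn.functional as F

batch, seqlen = 2, 6
key_padding_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                                 [1, 1, 1, 1, 1, 1]], dtype=torch.bool)
subset_mask = torch.tensor([[1, 0, 1, 0, 0, 0],
                            [1, 0, 0, 1, 1, 0]], dtype=torch.bool)

# Keys/values: one row per non-padded token, cumulated per sequence.
seqlens_k = key_padding_mask.sum(dim=-1, dtype=torch.int32)
cu_seqlens_k = F.pad(torch.cumsum(seqlens_k, dim=0, dtype=torch.int32), (1, 0))

# Queries: only the subset tokens, restricted to non-padded positions.
subset_seqlens = (subset_mask & key_padding_mask).sum(dim=-1, dtype=torch.int32)
cu_seqlens_q = F.pad(torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32), (1, 0))

# subset_idx indexes the subset rows within the unpadded (token-major) layout.
subset_idx = torch.nonzero(subset_mask[key_padding_mask], as_tuple=False).flatten()
print(cu_seqlens_k.tolist())   # [0, 4, 10]
print(cu_seqlens_q.tolist())   # [0, 2, 5]
print(subset_idx.tolist())     # [0, 2, 4, 7, 8]
```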
@@ -151,7 +195,8 @@ class BertPredictionHeadTransform(nn.Module):
             raise ImportError('dropout_add_layer_norm is not installed')
         linear_cls = nn.Linear if not fused_bias_fc else FusedDenseTD
         self.dense = linear_cls(config.hidden_size, config.hidden_size)
-        self.transform_act_fn = nn.GELU(approximate='tanh')
+        approximate = 'tanh' if config.hidden_act in ['gelu_new', 'gelu_fast'] else 'none'
+        self.transform_act_fn = nn.GELU(approximate=approximate)
         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -264,6 +309,11 @@ class BertModel(BertPreTrainedModel):
     def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None,
                 masked_tokens_mask=None):
+        """If masked_tokens_mask is not None (i.e. last_layer_subset == True in BertForPreTraining),
+        we only want the output for the masked tokens. This means that we only compute the last
+        layer output for these tokens.
+        masked_tokens_mask: (batch, seqlen), dtype=torch.bool
+        """
         hidden_states = self.embeddings(input_ids, position_ids=position_ids,
                                         token_type_ids=token_type_ids)
         # TD [2022-12-18]: Don't need to force residual in fp32
@@ -275,9 +325,38 @@ class BertModel(BertPreTrainedModel):
                 hidden_states, None, self.emb_ln.weight, self.emb_ln.bias,
                 self.emb_drop.p if self.training else 0.0, self.emb_ln.eps, prenorm=False,
             )
-        sequence_output = self.encoder(hidden_states, key_padding_mask=attention_mask)
-        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
-        return sequence_output, pooled_output
+        if masked_tokens_mask is not None:
+            batch_size, seqlen = input_ids.shape[:2]
+            # We also need the first column for the CLS token
+            first_col_mask = torch.zeros(batch_size, seqlen, dtype=torch.bool,
+                                         device=input_ids.device)
+            first_col_mask[:, 0] = True
+            subset_mask = masked_tokens_mask | first_col_mask
+        else:
+            subset_mask = None
+        sequence_output = self.encoder(hidden_states, key_padding_mask=attention_mask,
+                                       subset_mask=subset_mask)
+        if masked_tokens_mask is None:
+            pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+        else:
+            # TD [2022-03-01]: the indexing here is very tricky.
+            if attention_mask is not None:
+                subset_idx = subset_mask[attention_mask]
+                pool_input = sequence_output[first_col_mask[attention_mask][subset_idx]]
+                sequence_output = sequence_output[masked_tokens_mask[attention_mask][subset_idx]]
+            else:
+                pool_input = sequence_output[first_col_mask[subset_mask]]
+                sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
+            pooled_output = (self.pooler(pool_input, pool=False)
+                             if self.pooler is not None else None)
+        return BaseModelOutputWithPoolingAndCrossAttentions(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+        )
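The "tricky indexing" above can be checked on toy tensors. The encoder returns one row per subset token (CLS positions plus masked positions, non-padded only), and the boolean-mask chain splits those rows back into the pooler input and the masked-token output. A hedged sketch with assumed shapes, not the model code:

```python
import torch

batch_size, seqlen, hidden = 2, 5, 3
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]], dtype=torch.bool)
masked_tokens_mask = torch.tensor([[0, 1, 0, 0, 0],
                                   [0, 0, 1, 0, 1]], dtype=torch.bool)
first_col_mask = torch.zeros(batch_size, seqlen, dtype=torch.bool)
first_col_mask[:, 0] = True
subset_mask = masked_tokens_mask | first_col_mask

# Pretend encoder output: one row per subset token, in unpadded row-major order.
num_subset = int((subset_mask & attention_mask).sum())
sequence_output = torch.arange(num_subset * hidden, dtype=torch.float32).reshape(num_subset, hidden)

subset_idx = subset_mask[attention_mask]            # bool, one entry per non-padded token
pool_input = sequence_output[first_col_mask[attention_mask][subset_idx]]
mlm_output = sequence_output[masked_tokens_mask[attention_mask][subset_idx]]
print(pool_input.shape)   # torch.Size([2, 3]) -> one CLS row per sequence
print(mlm_output.shape)   # torch.Size([3, 3]) -> one row per masked token
```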
 class BertForPreTraining(BertPreTrainedModel):
@@ -290,11 +369,13 @@ class BertForPreTraining(BertPreTrainedModel):
         # If last_layer_subset, we only need the compute the last layer for a subset of tokens
         # (e.g., the tokens we need to compute the masked LM loss and the next-sentence prediction).
         self.last_layer_subset = getattr(config, 'last_layer_subset', False)
-        assert not self.last_layer_subset, 'last_layer_subset is not implemented yet'
+        if self.last_layer_subset:
+            assert self.dense_seq_output, 'last_layer_subset requires dense_seq_output'
         use_xentropy = getattr(config, 'use_xentropy', False)
         if use_xentropy and CrossEntropyLossApex is None:
             raise ImportError('xentropy_cuda is not installed')
-        loss_cls = nn.CrossEntropyLoss if not use_xentropy else CrossEntropyLossApex
+        loss_cls = (nn.CrossEntropyLoss if not use_xentropy
+                    else partial(CrossEntropyLossApex, inplace_backward=True))

         self.bert = BertModel(config)
         self.cls = BertPreTrainingHeads(config)
@@ -311,6 +392,8 @@ class BertForPreTraining(BertPreTrainedModel):
     def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None,
                 labels=None, next_sentence_label=None):
         """
+        If labels are provided, they must be 0 for masked out tokens (as specified in the attention
+        mask).
         Outputs:
             if `labels` and `next_sentence_label` are not `None`:
                 Outputs the total_loss which is the sum of the masked language modeling loss and the next
@@ -322,10 +405,12 @@ class BertForPreTraining(BertPreTrainedModel):
         """
         masked_tokens_mask = labels > 0 if (self.last_layer_subset and labels is not None) else None
-        sequence_output, pooled_output = self.bert(
+        outputs = self.bert(
             input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
-            attention_mask=attention_mask.bool(), masked_tokens_mask=masked_tokens_mask
+            attention_mask=attention_mask.bool() if attention_mask is not None else None,
+            masked_tokens_mask=masked_tokens_mask
         )
+        sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output
         if self.dense_seq_output and labels is not None:
             masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()
             if not self.last_layer_subset:
@@ -333,8 +418,9 @@ class BertForPreTraining(BertPreTrainedModel):
                                                  masked_token_idx)
         prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

+        total_loss = None
         if labels is not None and next_sentence_label is not None:
-            if masked_token_idx is not None:  # prediction_scores are already flattened
+            if self.dense_seq_output and labels is not None:  # prediction_scores are already flattened
                 masked_lm_loss = self.mlm_loss(prediction_scores,
                                                labels.flatten()[masked_token_idx])
             else:
@@ -342,22 +428,13 @@ class BertForPreTraining(BertPreTrainedModel):
                                                rearrange(labels, '... -> (...)'))
             next_sentence_loss = self.nsp_loss(rearrange(seq_relationship_score, '... t -> (...) t'),
                                                rearrange(next_sentence_label, '... -> (...)'))
-            total_loss = (masked_lm_loss + next_sentence_loss).float()
-            # Masked Language Model Accuracy
-            masked_lm_labels_flat = labels.view(-1)
-            mlm_labels = masked_lm_labels_flat[masked_lm_labels_flat != 0]
-            if not self.dense_seq_output:
-                prediction_scores_flat = rearrange(prediction_scores, '... v -> (...) v')
-                mlm_predictions_scores = prediction_scores_flat[masked_lm_labels_flat != 0]
-                mlm_predictions = mlm_predictions_scores.argmax(dim=-1)
-            else:
-                mlm_predictions = prediction_scores.argmax(dim=-1)
-            mlm_acc = (mlm_predictions == mlm_labels).sum(dtype=torch.float) / mlm_labels.numel()
-            return total_loss, prediction_scores, seq_relationship_score, mlm_acc, mlm_labels.numel()
-        else:
-            return prediction_scores, seq_relationship_score
+            total_loss = masked_lm_loss.float() + next_sentence_loss.float()
+
+        return BertForPreTrainingOutput(
+            loss=total_loss,
+            prediction_logits=prediction_scores,
+            seq_relationship_logits=seq_relationship_score,
+        )
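Since the forward pass now returns a BertForPreTrainingOutput instead of the old (total_loss, ..., mlm_acc, num_masked) tuple, the masked-LM accuracy has to be recomputed by the caller if it is still wanted. A hedged sketch (assuming dense_seq_output is off, so prediction_logits is (batch, seqlen, vocab), and the labels == 0 "ignore" convention used above):

```python
import torch

def mlm_accuracy(prediction_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # labels == 0 marks positions that carry no MLM target.
    labels_flat = labels.view(-1)
    logits_flat = prediction_logits.view(-1, prediction_logits.shape[-1])
    masked = labels_flat != 0
    if masked.sum() == 0:
        return torch.tensor(0.0)
    preds = logits_flat[masked].argmax(dim=-1)
    return (preds == labels_flat[masked]).float().mean()

# Toy check: 2 masked positions, 1 predicted correctly.
logits = torch.zeros(1, 3, 5)
logits[0, 0, 2] = 1.0   # predicts token 2
logits[0, 1, 4] = 1.0   # predicts token 4
labels = torch.tensor([[2, 3, 0]])  # positions 0 and 1 are masked, position 2 ignored
print(mlm_accuracy(logits, labels))  # tensor(0.5000)
```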
 def state_dict_from_pretrained(model_name):
@@ -401,6 +478,7 @@ def remap_state_dict(state_dict, config):
     state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

     # Attention
+    last_layer_subset = getattr(config, 'last_layer_subset', False)
     for d in range(config.num_hidden_layers):
         Wq = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.query.weight')
         Wk = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.key.weight')
@@ -408,12 +486,22 @@ def remap_state_dict(state_dict, config):
         bq = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.query.bias')
         bk = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.key.bias')
         bv = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.value.bias')
-        state_dict[f'bert.encoder.layers.{d}.mixer.Wqkv.weight'] = torch.cat(
-            [Wq, Wk, Wv], dim=0
-        )
-        state_dict[f'bert.encoder.layers.{d}.mixer.Wqkv.bias'] = torch.cat(
-            [bq, bk, bv], dim=0
-        )
+        if not (last_layer_subset and d == config.num_hidden_layers - 1):
+            state_dict[f'bert.encoder.layers.{d}.mixer.Wqkv.weight'] = torch.cat(
+                [Wq, Wk, Wv], dim=0
+            )
+            state_dict[f'bert.encoder.layers.{d}.mixer.Wqkv.bias'] = torch.cat(
+                [bq, bk, bv], dim=0
+            )
+        else:
+            state_dict[f'bert.encoder.layers.{d}.mixer.Wq.weight'] = Wq
+            state_dict[f'bert.encoder.layers.{d}.mixer.Wkv.weight'] = torch.cat(
+                [Wk, Wv], dim=0
+            )
+            state_dict[f'bert.encoder.layers.{d}.mixer.Wq.bias'] = bq
+            state_dict[f'bert.encoder.layers.{d}.mixer.Wkv.bias'] = torch.cat(
+                [bk, bv], dim=0
+            )
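The weight remapping above packs the separate HF query/key/value projections differently depending on the layer. A toy-dimension sketch (not the actual checkpoint code) of the two layouts:

```python
import torch

hidden = 8
Wq, Wk, Wv = (torch.randn(hidden, hidden) for _ in range(3))

Wqkv = torch.cat([Wq, Wk, Wv], dim=0)   # (3 * hidden, hidden) -> self-attention layers
Wkv = torch.cat([Wk, Wv], dim=0)        # (2 * hidden, hidden) -> cross-attention last layer
print(Wqkv.shape, Wq.shape, Wkv.shape)  # torch.Size([24, 8]) torch.Size([8, 8]) torch.Size([16, 8])

# The packed layouts are recoverable, so both mappings carry the same weights.
assert torch.equal(Wqkv[:hidden], Wq) and torch.equal(Wqkv[hidden:], Wkv)
```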
     def key_mapping_attn(key):
         return re.sub(r'^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)',
                       r'bert.encoder.layers.\1.mixer.out_proj.\2', key)
@@ -423,4 +511,23 @@ def remap_state_dict(state_dict, config):
         return re.sub(r'^cls.predictions.bias', 'cls.predictions.decoder.bias', key)
     state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())

+    # Word embedding
+    pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
+    if pad_vocab_size_multiple > 1:
+        word_embeddings = state_dict['bert.embeddings.word_embeddings.weight']
+        state_dict['bert.embeddings.word_embeddings.weight'] = F.pad(
+            word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
+        )
+        decoder_weight = state_dict['cls.predictions.decoder.weight']
+        state_dict['cls.predictions.decoder.weight'] = F.pad(
+            decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0])
+        )
+        # If the vocab was padded, we want to set the decoder bias for those padded indices to be
+        # strongly negative (i.e. the decoder shouldn't predict those indices).
+        # TD [2022-05-09]: I don't think it affects the MLPerf training.
+        decoder_bias = state_dict['cls.predictions.decoder.bias']
+        state_dict['cls.predictions.decoder.bias'] = F.pad(
+            decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
+        )
+
     return state_dict
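The new word-embedding block pads the vocabulary dimension and pushes the decoder bias of the padded ids far negative. A toy sketch (assumed sizes, not the checkpoint code) of the same F.pad calls:

```python
import math
import torch
import torch.nn.functional as F

orig_vocab, hidden, multiple = 30522, 4, 8
padded_vocab = math.ceil(orig_vocab / multiple) * multiple   # 30528

emb = torch.randn(orig_vocab, hidden)
emb_padded = F.pad(emb, (0, 0, 0, padded_vocab - orig_vocab))            # pad rows, keep columns
bias = torch.zeros(orig_vocab)
bias_padded = F.pad(bias, (0, padded_vocab - orig_vocab), value=-100.0)  # padded ids never predicted

print(emb_padded.shape)                            # torch.Size([30528, 4])
print(bias_padded[-(padded_vocab - orig_vocab):])  # tensor([-100., -100., -100., -100., -100., -100.])
```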
@@ -120,34 +120,55 @@ class FlashCrossAttention(nn.Module):
         self.dropout_p = attention_dropout
         self.triton = triton

-    def forward(self, q, kv):
+    def forward(self, q, kv, cu_seqlens=None, max_seqlen=None, cu_seqlens_k=None, max_seqlen_k=None):
         """Implements the multihead softmax attention.
         Arguments
         ---------
             q: The tensor containing the query. (B, Sq, H, D)
             kv: The tensor containing the key and value. (B, Sk, 2, H, D)
+            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
+                of the sequences in the batch, used to index into q.
+            max_seqlen: int. Maximum sequence length in the batch of q.
+            cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
+                of the sequences in the batch, used to index into kv.
+            max_seqlen_k: int. Maximum sequence length in the batch of k and v.
         """
         assert q.dtype in [torch.float16, torch.bfloat16]
         assert q.is_cuda and kv.is_cuda
-        batch_size, seqlen_q = q.shape[0], q.shape[1]
-        seqlen_k = kv.shape[1]
-        assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
-        if self.triton and (self.dropout_p == 0.0 or not self.training):  # Triton version doesn't support dropout
-            output = flash_attn_kvpacked_func(q, kv, None, self.causal, self.softmax_scale)
-        else:
-            q = rearrange(q, 'b s ... -> (b s) ...')
-            kv = rearrange(kv, 'b s ... -> (b s) ...')
-            cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q,
-                                        dtype=torch.int32, device=q.device)
-            cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k,
-                                        dtype=torch.int32, device=kv.device)
-            output = flash_attn_unpadded_kvpacked_func(
-                q, kv, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
-                self.dropout_p if self.training else 0.0,
-                softmax_scale=self.softmax_scale, causal=self.causal
-            )
-            output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
-        return output
+        unpadded = cu_seqlens is not None
+        if unpadded:
+            assert cu_seqlens.dtype == torch.int32
+            assert max_seqlen is not None
+            assert isinstance(max_seqlen, int)
+            assert cu_seqlens_k is not None
+            assert cu_seqlens_k.dtype == torch.int32
+            assert max_seqlen_k is not None
+            assert isinstance(max_seqlen, int)
+            return flash_attn_unpadded_kvpacked_func(
+                q, kv, cu_seqlens, cu_seqlens_k, max_seqlen, max_seqlen_k,
+                self.dropout_p if self.training else 0.0,
+                softmax_scale=self.softmax_scale, causal=self.causal
+            )
+        else:
+            batch_size, seqlen_q = q.shape[0], q.shape[1]
+            seqlen_k = kv.shape[1]
+            assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
+            if self.triton and (self.dropout_p == 0.0 or not self.training):  # Triton version doesn't support dropout
+                output = flash_attn_kvpacked_func(q, kv, None, self.causal, self.softmax_scale)
+            else:
+                q = rearrange(q, 'b s ... -> (b s) ...')
+                kv = rearrange(kv, 'b s ... -> (b s) ...')
+                cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q,
+                                            dtype=torch.int32, device=q.device)
+                cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k,
+                                            dtype=torch.int32, device=kv.device)
+                output = flash_attn_unpadded_kvpacked_func(
+                    q, kv, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
+                    self.dropout_p if self.training else 0.0,
+                    softmax_scale=self.softmax_scale, causal=self.causal
+                )
+                output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
+            return output
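A pure-PyTorch reference (hedged: it ignores dropout, causal masking, and the actual CUDA kernel; the helper name and shapes are illustrative) of what the unpadded kvpacked cross-attention path is expected to compute, with cu_seqlens / cu_seqlens_k marking per-sequence boundaries in the token-major q and kv tensors:

```python
import math
import torch

def unpadded_kvpacked_attention_ref(q, kv, cu_seqlens_q, cu_seqlens_k):
    """q: (total_q, nheads, d); kv: (total_k, 2, nheads, d); cu_* give per-sequence offsets."""
    out = torch.empty_like(q)
    scale = 1.0 / math.sqrt(q.shape[-1])
    for b in range(len(cu_seqlens_q) - 1):
        qs, qe = int(cu_seqlens_q[b]), int(cu_seqlens_q[b + 1])
        ks, ke = int(cu_seqlens_k[b]), int(cu_seqlens_k[b + 1])
        qi = q[qs:qe]                      # (sq_b, h, d)
        ki, vi = kv[ks:ke].unbind(dim=1)   # (sk_b, h, d) each
        scores = torch.einsum('qhd,khd->hqk', qi, ki) * scale
        out[qs:qe] = torch.einsum('hqk,khd->qhd', scores.softmax(dim=-1), vi)
    return out

# Two sequences: 2 and 3 subset (query) tokens attending over 4 and 6 key/value tokens.
nheads, d = 2, 8
cu_seqlens_q = torch.tensor([0, 2, 5], dtype=torch.int32)
cu_seqlens_k = torch.tensor([0, 4, 10], dtype=torch.int32)
q = torch.randn(5, nheads, d)
kv = torch.randn(10, 2, nheads, d)
print(unpadded_kvpacked_attention_ref(q, kv, cu_seqlens_q, cu_seqlens_k).shape)  # torch.Size([5, 2, 8])
```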
 class SelfAttention(nn.Module):
@@ -214,12 +235,14 @@ class CrossAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.dropout_p = attention_dropout

-    def forward(self, q, kv):
+    def forward(self, q, kv, key_padding_mask=None):
         """Implements the multihead softmax attention.
         Arguments
         ---------
             q: The tensor containing the query. (B, Sq, H, D)
             kv: The tensor containing the key and value. (B, Sk, 2, H, D)
+            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
+                False means to mask out. (B, Sk)
         """
         batch_size, seqlen_q = q.shape[0], q.shape[1]
         seqlen_k = kv.shape[1]
@@ -227,6 +250,12 @@ class CrossAttention(nn.Module):
         k, v = kv.unbind(dim=2)
         softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
         scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
+        if key_padding_mask is not None:
+            padding_mask = torch.full((batch_size, seqlen_k), -10000.0, dtype=scores.dtype,
+                                      device=scores.device)
+            padding_mask.masked_fill_(key_padding_mask, 0.0)
+            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
+            scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s')
         if self.causal:
             # "triu_tril_cuda_template" not implemented for 'BFloat16'
             # So we have to construct the mask in float
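The additive -10000.0 bias used above is the usual cheap stand-in for masking. A toy sketch (assumed small shapes, not the module code) showing it agrees, to numerical precision, with masked_fill on the scores:

```python
import torch
from einops import rearrange

batch, heads, sq, sk = 1, 1, 2, 4
scores = torch.randn(batch, heads, sq, sk)
key_padding_mask = torch.tensor([[True, True, True, False]])  # last key is padding

padding_bias = torch.full((batch, sk), -10000.0, dtype=scores.dtype)
padding_bias.masked_fill_(key_padding_mask, 0.0)
attn_add = (scores + rearrange(padding_bias, 'b s -> b 1 1 s')).softmax(dim=-1)

attn_fill = scores.masked_fill(rearrange(~key_padding_mask, 'b s -> b 1 1 s'),
                               float('-inf')).softmax(dim=-1)
print(torch.allclose(attn_add, attn_fill, atol=1e-6))  # True
print(attn_add[..., -1])                               # ~0 probability on the padded key
```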
@@ -295,9 +324,11 @@ class MHA(nn.Module):
                                            groups=3 * embed_dim)
             inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
         else:
+            # TODO: use the residual linear class for Wq
             self.Wq = linear_cls(embed_dim, embed_dim, bias=bias, **factory_kwargs)
-            self.Wkv = linear_cls(embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs)
+            if not self.return_residual:
+                self.Wkv = linear_cls(embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs)
+            else:
+                self.Wkv = linear_resid_cls(embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs)
             if self.dwconv:
                 self.dwconv_q = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, padding=2,
                                           groups=embed_dim)
@@ -309,7 +340,8 @@ class MHA(nn.Module):
         # output projection always have the bias (for now)
         self.out_proj = linear_cls(embed_dim, embed_dim, **factory_kwargs)

-    def forward(self, x, x_kv=None, cu_seqlens=None, max_seqlen=None, key_padding_mask=None):
+    def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None,
+                **kwargs):
         """
         Arguments:
             x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
@@ -327,17 +359,15 @@ class MHA(nn.Module):
             assert max_seqlen is not None
             assert key_padding_mask is None
             assert self.use_flash_attn
-            assert not self.cross_attn, ('Unpadded FlashAttention code path for cross-attention'
-                                         'is not implemented yet')
             assert not self.dwconv
             assert self.rotary_emb_dim == 0
         if key_padding_mask is not None:
             assert cu_seqlens is None
             assert max_seqlen is None
             assert not self.use_flash_attn
-            assert not self.cross_attn, ('Key padding mask code path for cross-attention'
-                                         'is not implemented yet')
+        kwargs = ({'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen, **kwargs}
+                  if self.use_flash_attn else {'key_padding_mask': key_padding_mask, **kwargs})
         if not self.cross_attn:
             if not self.return_residual:
                 qkv = self.Wqkv(x)
@@ -349,24 +379,30 @@ class MHA(nn.Module):
             qkv = rearrange(qkv, '... (three h d) -> ... three h d', three=3, h=self.num_heads)
             if self.rotary_emb_dim > 0:
                 qkv = self.rotary_emb(qkv)
-            extra_kwargs = ({'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen}
-                            if self.use_flash_attn else {'key_padding_mask': key_padding_mask})
             if not self.checkpointing:
-                context = self.inner_attn(qkv, **extra_kwargs)
+                context = self.inner_attn(qkv, **kwargs)
             else:
-                context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **extra_kwargs)
+                context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
         else:
-            q = rearrange(self.Wq(x), 'b s (h d) -> b s h d', h=self.num_heads)
-            kv = rearrange(self.Wkv(x if x_kv is None else x_kv), 'b s (two h d) -> b s two h d',
-                           two=2, h=self.num_heads)
+            if not self.return_residual:
+                q = self.Wq(x)
+                kv = self.Wkv(x_kv if x_kv is not None else x)
+            else:
+                if x_kv is not None:
+                    kv, x_kv = self.Wkv(x_kv)
+                else:
+                    kv, x = self.Wkv(x)
+                q = self.Wq(x)
+            q = rearrange(q, '... (h d) -> ... h d', h=self.num_heads)
+            kv = rearrange(kv, '... (two h d) -> ... two h d', two=2, h=self.num_heads)
             if self.dwconv:
                 q = rearrange(self.dwconv_q(rearrange(q, 'b s d -> b d s'))[..., :-2],
                               'b d s -> b s d').contiguous()
                 kv = rearrange(self.dwconv_kv(rearrange(kv, 'b s d -> b d s'))[..., :-2],
                                'b d s -> b s d').contiguous()
             if not self.checkpointing:
-                context = self.inner_attn(q, kv)
+                context = self.inner_attn(q, kv, **kwargs)
             else:
-                context = torch.utils.checkpoint.checkpoint(self.inner_attn, q, kv)
+                context = torch.utils.checkpoint.checkpoint(self.inner_attn, q, kv, **kwargs)
         out = self.out_proj(rearrange(context, '... h d -> ... (h d)'))
         return out if not self.return_residual else (out, x)
@@ -15,20 +15,21 @@ except ImportError:

 class Mlp(nn.Module):

     def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.gelu,
-                 device=None, dtype=None):
+                 return_residual=False, device=None, dtype=None):
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
+        self.return_residual = return_residual
         self.fc1 = nn.Linear(in_features, hidden_features, **factory_kwargs)
         self.activation = activation
         self.fc2 = nn.Linear(hidden_features, out_features, **factory_kwargs)

     def forward(self, x):
-        x = self.fc1(x)
-        x = self.activation(x)
-        x = self.fc2(x)
-        return x
+        y = self.fc1(x)
+        y = self.activation(y)
+        y = self.fc2(y)
+        return y if not self.return_residual else (y, x)

 class FusedDenseGeluDense(nn.Module):
...
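Why return_residual returns a pair: the caller gets the sub-module output together with the exact input it saw, so the residual add (and, in the real Block, a fused dropout-add-layernorm) happens in one place. A minimal sketch, not the library Block:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyMlp(nn.Module):
    def __init__(self, dim, return_residual=False):
        super().__init__()
        self.return_residual = return_residual
        self.fc1 = nn.Linear(dim, 4 * dim)
        self.fc2 = nn.Linear(4 * dim, dim)

    def forward(self, x):
        y = self.fc2(F.gelu(self.fc1(x)))
        return (y, x) if self.return_residual else y

mlp = TinyMlp(dim=16, return_residual=True)
hidden = torch.randn(2, 5, 16)
out, residual = mlp(hidden)
hidden = residual + out          # residual connection done by the caller
print(hidden.shape)              # torch.Size([2, 5, 16])
```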
...@@ -53,15 +53,12 @@ def test_bert_non_optimized(model_name): ...@@ -53,15 +53,12 @@ def test_bert_non_optimized(model_name):
""" """
dtype = torch.float16 dtype = torch.float16
config = BertConfig.from_pretrained(model_name) config = BertConfig.from_pretrained(model_name)
# Our implementation assumes the activation is nn.GELU(approximate='tanh')
# Huggingface calls it "gelu_new" or "gelu_fast".
config.hidden_act = "gelu_new"
model = BertForPreTraining.from_pretrained(model_name, config) model = BertForPreTraining.from_pretrained(model_name, config)
model = model.cuda().to(dtype=dtype) model = model.cuda().to(dtype=dtype)
model_ref = get_hf_models(model_name, config, torch.float32) model_ref = get_hf_models(model_name, config, torch.float32)
model_hf = get_hf_models(model_name, config, torch.float16) model_hf = get_hf_models(model_name, config, dtype)
model.eval() model.eval()
model_ref.eval() model_ref.eval()
...@@ -74,7 +71,8 @@ def test_bert_non_optimized(model_name): ...@@ -74,7 +71,8 @@ def test_bert_non_optimized(model_name):
attention_mask = torch.arange(max_seqlen, device='cuda')[None, :] < seqlens[:, None] attention_mask = torch.arange(max_seqlen, device='cuda')[None, :] < seqlens[:, None]
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
device='cuda') device='cuda')
sequence_output, pooled_output = model.bert(input_ids, attention_mask=attention_mask) out = model.bert(input_ids, attention_mask=attention_mask)
sequence_output, pooled_output = out.last_hidden_state, out.pooler_output
out_hf = model_hf.bert(input_ids, attention_mask=attention_mask) out_hf = model_hf.bert(input_ids, attention_mask=attention_mask)
sequence_output_hf, pooled_output_hf = out_hf.last_hidden_state, out_hf.pooler_output sequence_output_hf, pooled_output_hf = out_hf.last_hidden_state, out_hf.pooler_output
out_ref = model_ref.bert(input_ids, attention_mask=attention_mask) out_ref = model_ref.bert(input_ids, attention_mask=attention_mask)
...@@ -84,8 +82,8 @@ def test_bert_non_optimized(model_name): ...@@ -84,8 +82,8 @@ def test_bert_non_optimized(model_name):
print(f'Output mean diff: {(sequence_output - sequence_output_ref).abs().mean().item()}') print(f'Output mean diff: {(sequence_output - sequence_output_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(sequence_output_hf - sequence_output_ref).abs().max().item()}') print(f'HF fp16 max diff: {(sequence_output_hf - sequence_output_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(sequence_output_hf - sequence_output_ref).abs().mean().item()}') print(f'HF fp16 mean diff: {(sequence_output_hf - sequence_output_ref).abs().mean().item()}')
assert (sequence_output - sequence_output_ref).abs().max().item() < 2 * (sequence_output_hf - sequence_output_ref).abs().max().item() assert (sequence_output - sequence_output_ref).abs().max().item() < 3 * (sequence_output_hf - sequence_output_ref).abs().max().item()
assert (pooled_output - pooled_output_ref).abs().max().item() < 2 * (pooled_output_hf - pooled_output_ref).abs().max().item() assert (pooled_output - pooled_output_ref).abs().max().item() < 3 * (pooled_output_hf - pooled_output_ref).abs().max().item()
@pytest.mark.parametrize('model_name', ["bert-base-uncased", "bert-large-uncased"]) @pytest.mark.parametrize('model_name', ["bert-base-uncased", "bert-large-uncased"])
...@@ -97,8 +95,9 @@ def test_bert_optimized(model_name): ...@@ -97,8 +95,9 @@ def test_bert_optimized(model_name):
""" """
dtype = torch.float16 dtype = torch.float16
config = BertConfig.from_pretrained(model_name) config = BertConfig.from_pretrained(model_name)
# Our implementation assumes the activation is nn.GELU(approximate='tanh') # Our implementation of fused_dense_gelu_dense assumes the activation is
# Huggingface calls it "gelu_new" or "gelu_fast". # nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new" or "gelu_fast".
# If you just want "gelu", disable fused_dense_gelu_dense.
config.hidden_act = "gelu_new" config.hidden_act = "gelu_new"
config.use_flash_attn = True config.use_flash_attn = True
config.fused_bias_fc = True config.fused_bias_fc = True
...@@ -109,7 +108,7 @@ def test_bert_optimized(model_name): ...@@ -109,7 +108,7 @@ def test_bert_optimized(model_name):
model = model.cuda().to(dtype=dtype) model = model.cuda().to(dtype=dtype)
model_ref = get_hf_models(model_name, config, torch.float32) model_ref = get_hf_models(model_name, config, torch.float32)
model_hf = get_hf_models(model_name, config, torch.float16) model_hf = get_hf_models(model_name, config, dtype)
model.eval() model.eval()
model_ref.eval() model_ref.eval()
...@@ -122,7 +121,8 @@ def test_bert_optimized(model_name): ...@@ -122,7 +121,8 @@ def test_bert_optimized(model_name):
attention_mask = torch.arange(max_seqlen, device='cuda')[None, :] < seqlens[:, None] attention_mask = torch.arange(max_seqlen, device='cuda')[None, :] < seqlens[:, None]
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
device='cuda') device='cuda')
sequence_output, pooled_output = model.bert(input_ids, attention_mask=attention_mask) out = model.bert(input_ids, attention_mask=attention_mask)
sequence_output, pooled_output = out.last_hidden_state, out.pooler_output
out_hf = model_hf.bert(input_ids, attention_mask=attention_mask) out_hf = model_hf.bert(input_ids, attention_mask=attention_mask)
sequence_output_hf, pooled_output_hf = out_hf.last_hidden_state, out_hf.pooler_output sequence_output_hf, pooled_output_hf = out_hf.last_hidden_state, out_hf.pooler_output
# Need to zero out the padded tokens in the sequence before comparison. # Need to zero out the padded tokens in the sequence before comparison.
...@@ -138,7 +138,8 @@ def test_bert_optimized(model_name): ...@@ -138,7 +138,8 @@ def test_bert_optimized(model_name):
assert (sequence_output - sequence_output_ref).abs().max().item() < 4 * (sequence_output_hf - sequence_output_ref).abs().max().item() assert (sequence_output - sequence_output_ref).abs().max().item() < 4 * (sequence_output_hf - sequence_output_ref).abs().max().item()
assert (pooled_output - pooled_output_ref).abs().max().item() < 4 * (pooled_output_hf - pooled_output_ref).abs().max().item() assert (pooled_output - pooled_output_ref).abs().max().item() < 4 * (pooled_output_hf - pooled_output_ref).abs().max().item()
prediction_scores, seq_relationship_scores = model(input_ids, attention_mask=attention_mask) out = model(input_ids, attention_mask=attention_mask)
prediction_scores, seq_relationship_scores = out.prediction_logits, out.seq_relationship_logits
# Need to zero out the padded tokens in the sequence before comparison. # Need to zero out the padded tokens in the sequence before comparison.
prediction_scores = prediction_scores.clone() prediction_scores = prediction_scores.clone()
prediction_scores[~attention_mask, :] = 0.0 prediction_scores[~attention_mask, :] = 0.0
...@@ -157,30 +158,36 @@ def test_bert_optimized(model_name): ...@@ -157,30 +158,36 @@ def test_bert_optimized(model_name):
assert (seq_relationship_scores - seq_relationship_scores_ref).abs().max().item() < 2 * (seq_relationship_scores_hf - seq_relationship_scores_ref).abs().max().item() assert (seq_relationship_scores - seq_relationship_scores_ref).abs().max().item() < 2 * (seq_relationship_scores_hf - seq_relationship_scores_ref).abs().max().item()
@pytest.mark.parametrize('last_layer_subset', [False, True])
# @pytest.mark.parametrize('last_layer_subset', [True])
@pytest.mark.parametrize('has_key_padding_mask', [True, False])
# @pytest.mark.parametrize('has_key_padding_mask', [True])
@pytest.mark.parametrize('model_name', ["bert-base-uncased", "bert-large-uncased"]) @pytest.mark.parametrize('model_name', ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"]) # @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_dense_seq_output(model_name): def test_bert_dense_seq_output(model_name, has_key_padding_mask, last_layer_subset):
"""Check that our implementation of BERT (with all optimizations enabled) matches the """Check that our implementation of BERT (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32. forward pass in fp16, when compared to the HF forward pass in fp32.
""" """
dtype = torch.float16 dtype = torch.float16
config = BertConfig.from_pretrained(model_name) config = BertConfig.from_pretrained(model_name)
# Our implementation assumes the activation is nn.GELU(approximate='tanh') # Our implementation of fused_dense_gelu_dense assumes the activation is
# Huggingface calls it "gelu_new" or "gelu_fast". # nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new" or "gelu_fast".
# If you just want "gelu", disable fused_dense_gelu_dense.
config.hidden_act = "gelu_new" config.hidden_act = "gelu_new"
config.use_flash_attn = True config.use_flash_attn = True
config.fused_bias_fc = True config.fused_bias_fc = True
config.fused_dense_gelu_dense = True config.fused_dense_gelu_dense = True
config.fused_dropout_add_ln = True config.fused_dropout_add_ln = True
config.dense_seq_output = True config.dense_seq_output = True
config.last_layer_subset = last_layer_subset
config.use_xentropy = True config.use_xentropy = True
model = BertForPreTraining.from_pretrained(model_name, config) model = BertForPreTraining.from_pretrained(model_name, config)
model = model.cuda().to(dtype=dtype) model = model.cuda().to(dtype=dtype)
model_ref = get_hf_models(model_name, config, torch.float32) model_ref = get_hf_models(model_name, config, torch.float32)
model_hf = get_hf_models(model_name, config, torch.float16) model_hf = get_hf_models(model_name, config, dtype)
model.eval() model.eval()
model_ref.eval() model_ref.eval()
...@@ -190,19 +197,25 @@ def test_bert_dense_seq_output(model_name): ...@@ -190,19 +197,25 @@ def test_bert_dense_seq_output(model_name):
batch_size = 4 batch_size = 4
max_seqlen = 512 max_seqlen = 512
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda') seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda')
attention_mask = torch.arange(max_seqlen, device='cuda')[None, :] < seqlens[:, None] if has_key_padding_mask:
attention_mask = torch.arange(max_seqlen, device='cuda')[None, :] < seqlens[:, None]
else:
attention_mask = None
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
device='cuda') device='cuda')
labels = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, labels = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
device='cuda') device='cuda')
labels[(torch.rand(batch_size, max_seqlen, device='cuda') < 0.15) | ~attention_mask] = 0 if attention_mask is not None:
labels[~attention_mask] = 0
labels[(torch.rand(batch_size, max_seqlen, device='cuda') > 0.15)] = 0
masked_tokens_mask = labels.flatten() > 0 masked_tokens_mask = labels.flatten() > 0
next_sequence_label = torch.randint(0, 2, (batch_size,), device='cuda') next_sequence_label = torch.randint(0, 2, (batch_size,), device='cuda')
total_loss, prediction_scores, seq_relationship_scores, _, _ = model( out = model(
input_ids, attention_mask=attention_mask, input_ids, attention_mask=attention_mask,
labels=labels, next_sentence_label=next_sequence_label labels=labels, next_sentence_label=next_sequence_label
) )
prediction_scores, seq_relationship_scores = out.prediction_logits, out.seq_relationship_logits
out_hf = model_hf(input_ids, attention_mask=attention_mask, out_hf = model_hf(input_ids, attention_mask=attention_mask,
labels=labels, next_sentence_label=next_sequence_label) labels=labels, next_sentence_label=next_sequence_label)
prediction_scores_hf, seq_relationship_scores_hf = out_hf.prediction_logits, out_hf.seq_relationship_logits prediction_scores_hf, seq_relationship_scores_hf = out_hf.prediction_logits, out_hf.seq_relationship_logits
...@@ -217,3 +230,6 @@ def test_bert_dense_seq_output(model_name): ...@@ -217,3 +230,6 @@ def test_bert_dense_seq_output(model_name):
     print(f'HF fp16 prediction_scores max diff: {(prediction_scores_hf - prediction_scores_ref).abs().max().item()}')
     print(f'HF fp16 prediction_scores mean diff: {(prediction_scores_hf - prediction_scores_ref).abs().mean().item()}')
     assert (prediction_scores - prediction_scores_ref).abs().max().item() < 2 * (prediction_scores_hf - prediction_scores_ref).abs().max().item()
+    assert (seq_relationship_scores - seq_relationship_scores_ref).abs().max().item() < 2 * (seq_relationship_scores_hf - seq_relationship_scores_ref).abs().max().item()
+    # The loss calculation from HF is wrong: it doesn't ignore the labels that are 0.
+    # assert (out.loss - out_ref.loss).abs().max().item() < 2 * (out_hf.loss - out_ref.loss).abs().max().item()
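For reference, the label convention exercised by this test, reproduced on toy CPU tensors (a hedged sketch, not the test itself): labels == 0 means no MLM loss at that position; roughly 15% of the non-padded positions keep a target token id.

```python
import torch

batch_size, max_seqlen, vocab_size = 2, 8, 100
seqlens = torch.tensor([6, 8])
attention_mask = torch.arange(max_seqlen)[None, :] < seqlens[:, None]
labels = torch.randint(1, vocab_size, (batch_size, max_seqlen))
labels[~attention_mask] = 0                              # never predict padding
labels[torch.rand(batch_size, max_seqlen) > 0.15] = 0    # keep ~15% as masked targets
masked_tokens_mask = labels > 0
print(masked_tokens_mask.sum().item(), 'masked positions out of', int(attention_mask.sum()))
```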