Commit 13cdceb3 authored by Tri Dao

Implement last_layer_subset optimization for BERT

parent 5fb6df0e
......@@ -17,6 +17,8 @@ import torch.nn as nn
import torch.nn.functional as F
from transformers import BertConfig
from transformers.models.bert.modeling_bert import BaseModelOutputWithPoolingAndCrossAttentions
from transformers.models.bert.modeling_bert import BertForPreTrainingOutput
from einops import rearrange
......@@ -24,7 +26,8 @@ from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense
from flash_attn.modules.block import Block
from flash_attn.modules.embedding import BertEmbeddings
from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis
from flash_attn.bert_padding import unpad_input, pad_input
from flash_attn.bert_padding import index_first_axis, index_first_axis_residual
try:
from flash_attn.ops.fused_dense import FusedDenseTD
......@@ -45,21 +48,27 @@ except ImportError:
logger = logging.getLogger(__name__)
def create_mixer_cls(config):
def create_mixer_cls(config, cross_attn=False, return_residual=False):
use_flash_attn = getattr(config, 'use_flash_attn', False)
fused_bias_fc = getattr(config, 'fused_bias_fc', False)
mixer_cls = partial(MHA, num_heads=config.num_attention_heads,
mixer_cls = partial(MHA, num_heads=config.num_attention_heads, cross_attn=cross_attn,
dropout=config.attention_probs_dropout_prob, causal=False,
fused_bias_fc=fused_bias_fc, use_flash_attn=use_flash_attn)
fused_bias_fc=fused_bias_fc, use_flash_attn=use_flash_attn,
return_residual=return_residual)
return mixer_cls
def create_mlp_cls(config, layer_idx=None):
def create_mlp_cls(config, layer_idx=None, return_residual=False):
inner_dim = config.intermediate_size
fused_dense_gelu_dense = getattr(config, 'fused_dense_gelu_dense', False)
if fused_dense_gelu_dense:
assert config.hidden_act in ['gelu_new', 'gelu_fast'], ('fused_dense_gelu_dense only '
'supports approximate gelu')
if not fused_dense_gelu_dense:
approximate = 'tanh' if config.hidden_act in ['gelu_new', 'gelu_fast'] else 'none'
mlp_cls = partial(Mlp, hidden_features=inner_dim,
activation=partial(F.gelu, approximate='tanh'))
activation=partial(F.gelu, approximate=approximate),
return_residual=return_residual)
else:
mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
# mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
......@@ -67,17 +76,24 @@ def create_mlp_cls(config, layer_idx=None):
assert layer_idx is not None
mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
mlp_cls = partial(FusedDenseGeluDense, hidden_features=inner_dim,
checkpoint_lvl=mlp_checkpoint_lvl)
checkpoint_lvl=mlp_checkpoint_lvl, return_residual=return_residual)
return mlp_cls
def create_block(config, layer_idx=None):
mixer_cls = create_mixer_cls(config)
mlp_cls = create_mlp_cls(config, layer_idx)
last_layer_subset = getattr(config, 'last_layer_subset', False)
cross_attn = last_layer_subset and (layer_idx == config.num_hidden_layers - 1)
# TD [2022-12-19]: For cross attention (last layer), we actually want to return the
# residual x_kv, not residual x. But it's annoying to change the API (and it only affects
# one layer) so we just choose not to return residual in this case.
return_residual = not cross_attn
mixer_cls = create_mixer_cls(config, cross_attn, return_residual=return_residual)
mlp_cls = create_mlp_cls(config, layer_idx, return_residual=return_residual)
norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_eps)
block = Block(config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
prenorm=False, resid_dropout=config.hidden_dropout_prob,
fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False))
fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False),
return_residual=return_residual)
return block
......@@ -101,21 +117,49 @@ class BertEncoder(nn.Module):
self.layers = nn.ModuleList([create_block(config, layer_idx=i)
for i in range(config.num_hidden_layers)])
def forward(self, hidden_states, key_padding_mask=None):
def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
"""If subset_mask is not None, we only want output for the subset of the sequence.
This means that we only compute the last layer output for these tokens.
subset_mask: (batch, seqlen), dtype=torch.bool
"""
if key_padding_mask is None or not self.use_flash_attn:
mixer_kwargs = ({'key_padding_mask': key_padding_mask}
if key_padding_mask is not None else None)
for layer in self.layers:
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
if subset_mask is not None:
hidden_states = hidden_states[subset_mask]
else:
batch, seqlen = hidden_states.shape[:2]
hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
hidden_states, key_padding_mask
)
mixer_kwargs = {'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen_in_batch}
if subset_mask is None:
for layer in self.layers:
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
hidden_states = pad_input(hidden_states, indices, batch, seqlen)
else:
for layer in self.layers[:-1]:
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
if key_padding_mask is not None:
subset_idx = torch.nonzero(subset_mask[key_padding_mask], as_tuple=False).flatten()
subset_seqlens = (subset_mask & key_padding_mask).sum(dim=-1, dtype=torch.int32)
subset_cu_seqlens = F.pad(torch.cumsum(subset_seqlens, dim=0,
dtype=torch.int32), (1, 0))
else:
subset_idx = torch.nonzero(subset_mask, as_tuple=False).flatten()
subset_seqlens = subset_mask.sum(dim=-1, dtype=torch.int32)
subset_cu_seqlens = F.pad(torch.cumsum(subset_seqlens, dim=0,
dtype=torch.int32), (1, 0))
hidden_states_subset, hidden_states = index_first_axis_residual(
hidden_states, subset_idx
)
# It's ok to set max_seqlen_q to be much larger
mixer_kwargs = {'x_kv': hidden_states,
'cu_seqlens': subset_cu_seqlens, 'max_seqlen': max_seqlen_in_batch,
'cu_seqlens_k': cu_seqlens, 'max_seqlen_k': max_seqlen_in_batch}
hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
return hidden_states
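For readers following the subset bookkeeping above, here is a minimal, self-contained sketch (made-up masks and sizes, not part of the diff) of how a boolean subset mask becomes flat indices and cumulative sequence lengths for the unpadded last layer:

import torch
import torch.nn.functional as F

key_padding_mask = torch.tensor([[True, True, True, False],
                                 [True, True, False, False]])   # (batch=2, seqlen=4)
subset_mask = torch.tensor([[True, False, True, False],
                            [True, True, False, False]])        # e.g. CLS + masked tokens

# Indices of the subset tokens within the *unpadded* token stream, i.e. after
# unpad_input has already dropped the padded positions.
subset_idx = torch.nonzero(subset_mask[key_padding_mask], as_tuple=False).flatten()
# tensor([0, 2, 3, 4])

# Per-sequence subset lengths and their cumulative sums: these become cu_seqlens
# for the queries of the last (cross-attention) layer.
subset_seqlens = (subset_mask & key_padding_mask).sum(dim=-1, dtype=torch.int32)  # tensor([2, 2])
subset_cu_seqlens = F.pad(torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32), (1, 0))
# tensor([0, 2, 4]): queries 0:2 belong to sequence 0, queries 2:4 to sequence 1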
......@@ -151,7 +195,8 @@ class BertPredictionHeadTransform(nn.Module):
raise ImportError('dropout_add_layer_norm is not installed')
linear_cls = nn.Linear if not fused_bias_fc else FusedDenseTD
self.dense = linear_cls(config.hidden_size, config.hidden_size)
self.transform_act_fn = nn.GELU(approximate='tanh')
approximate = 'tanh' if config.hidden_act in ['gelu_new', 'gelu_fast'] else 'none'
self.transform_act_fn = nn.GELU(approximate=approximate)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
......@@ -264,6 +309,11 @@ class BertModel(BertPreTrainedModel):
def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None,
masked_tokens_mask=None):
"""If masked_tokens_mask is not None (i.e. last_layer_subset == True in BertForPreTraining),
we only want the output for the masked tokens. This means that we only compute the last
layer output for these tokens.
masked_tokens_mask: (batch, seqlen), dtype=torch.bool
"""
hidden_states = self.embeddings(input_ids, position_ids=position_ids,
token_type_ids=token_type_ids)
# TD [2022-12-18]: Don't need to force residual in fp32
......@@ -275,9 +325,38 @@ class BertModel(BertPreTrainedModel):
hidden_states, None, self.emb_ln.weight, self.emb_ln.bias,
self.emb_drop.p if self.training else 0.0, self.emb_ln.eps, prenorm=False,
)
sequence_output = self.encoder(hidden_states, key_padding_mask=attention_mask)
if masked_tokens_mask is not None:
batch_size, seqlen = input_ids.shape[:2]
# We also need the first column for the CLS token
first_col_mask = torch.zeros(batch_size, seqlen, dtype=torch.bool,
device=input_ids.device)
first_col_mask[:, 0] = True
subset_mask = masked_tokens_mask | first_col_mask
else:
subset_mask = None
sequence_output = self.encoder(hidden_states, key_padding_mask=attention_mask,
subset_mask=subset_mask)
if masked_tokens_mask is None:
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
return sequence_output, pooled_output
else:
# TD [2022-03-01]: the indexing here is very tricky.
if attention_mask is not None:
subset_idx = subset_mask[attention_mask]
pool_input = sequence_output[first_col_mask[attention_mask][subset_idx]]
sequence_output = sequence_output[masked_tokens_mask[attention_mask][subset_idx]]
else:
pool_input = sequence_output[first_col_mask[subset_mask]]
sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
pooled_output = (self.pooler(pool_input, pool=False)
if self.pooler is not None else None)
return BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
)
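The nested boolean indexing in the branch above is the trickiest part of this change, so here is a tiny worked example (made-up masks; mlm_output is a hypothetical name for the reassigned sequence_output):

import torch

attention_mask = torch.tensor([[True, True, True, False]])        # (batch=1, seqlen=4)
masked_tokens_mask = torch.tensor([[False, False, True, False]])  # token 2 is masked
first_col_mask = torch.tensor([[True, False, False, False]])      # CLS token
subset_mask = masked_tokens_mask | first_col_mask                 # tokens 0 and 2

# The encoder returns one row per subset token (here 2 rows), in unpadded order.
sequence_output = torch.arange(2 * 8, dtype=torch.float32).reshape(2, 8)  # (n_subset, hidden=8)

# Boolean mask over the unpadded tokens saying which of them made it into the subset.
subset_idx = subset_mask[attention_mask]                # tensor([True, False, True])
# Among the subset rows, pick the CLS row for the pooler ...
pool_input = sequence_output[first_col_mask[attention_mask][subset_idx]]      # row 0
# ... and the masked-token rows for the MLM head.
mlm_output = sequence_output[masked_tokens_mask[attention_mask][subset_idx]]  # row 1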
class BertForPreTraining(BertPreTrainedModel):
......@@ -290,11 +369,13 @@ class BertForPreTraining(BertPreTrainedModel):
# If last_layer_subset, we only need to compute the last layer for a subset of tokens
# (e.g., the tokens we need to compute the masked LM loss and the next-sentence prediction).
self.last_layer_subset = getattr(config, 'last_layer_subset', False)
assert not self.last_layer_subset, 'last_layer_subset is not implemented yet'
if self.last_layer_subset:
assert self.dense_seq_output, 'last_layer_subset requires dense_seq_output'
use_xentropy = getattr(config, 'use_xentropy', False)
if use_xentropy and CrossEntropyLossApex is None:
raise ImportError('xentropy_cuda is not installed')
loss_cls = nn.CrossEntropyLoss if not use_xentropy else CrossEntropyLossApex
loss_cls = (nn.CrossEntropyLoss if not use_xentropy
else partial(CrossEntropyLossApex, inplace_backward=True))
self.bert = BertModel(config)
self.cls = BertPreTrainingHeads(config)
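As a usage sketch (not part of the diff), the optimization is driven entirely by config attributes; the flag names below all appear in this commit or its tests, and the import path assumes this file is flash_attn/models/bert.py as elsewhere in the repo:

import torch
from transformers import BertConfig
from flash_attn.models.bert import BertForPreTraining

config = BertConfig.from_pretrained('bert-base-uncased')
config.use_flash_attn = True      # unpadded FlashAttention code path
config.dense_seq_output = True    # MLM head only sees the masked tokens
config.last_layer_subset = True   # last layer runs only on CLS + masked tokens (requires dense_seq_output)
config.use_xentropy = True        # optional fused cross-entropy (needs xentropy_cuda)

model = BertForPreTraining.from_pretrained('bert-base-uncased', config)
model = model.cuda().to(dtype=torch.float16)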
......@@ -311,6 +392,8 @@ class BertForPreTraining(BertPreTrainedModel):
def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None,
labels=None, next_sentence_label=None):
"""
If labels are provided, they must be 0 for masked out tokens (as specified in the attention
mask).
Outputs:
if `labels` and `next_sentence_label` are not `None`:
Outputs the total_loss which is the sum of the masked language modeling loss and the next
......@@ -322,10 +405,12 @@ class BertForPreTraining(BertPreTrainedModel):
"""
masked_tokens_mask = labels > 0 if (self.last_layer_subset and labels is not None) else None
sequence_output, pooled_output = self.bert(
outputs = self.bert(
input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
attention_mask=attention_mask.bool(), masked_tokens_mask=masked_tokens_mask
attention_mask=attention_mask.bool() if attention_mask is not None else None,
masked_tokens_mask=masked_tokens_mask
)
sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output
if self.dense_seq_output and labels is not None:
masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()
if not self.last_layer_subset:
......@@ -333,8 +418,9 @@ class BertForPreTraining(BertPreTrainedModel):
masked_token_idx)
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
total_loss = None
if labels is not None and next_sentence_label is not None:
if masked_token_idx is not None: # prediction_scores are already flattened
if self.dense_seq_output and labels is not None: # prediction_scores are already flattened
masked_lm_loss = self.mlm_loss(prediction_scores,
labels.flatten()[masked_token_idx])
else:
......@@ -342,22 +428,13 @@ class BertForPreTraining(BertPreTrainedModel):
rearrange(labels, '... -> (...)'))
next_sentence_loss = self.nsp_loss(rearrange(seq_relationship_score, '... t -> (...) t'),
rearrange(next_sentence_label, '... -> (...)'))
total_loss = (masked_lm_loss + next_sentence_loss).float()
# Masked Language Model Accuracy
masked_lm_labels_flat = labels.view(-1)
mlm_labels = masked_lm_labels_flat[masked_lm_labels_flat != 0]
if not self.dense_seq_output:
prediction_scores_flat = rearrange(prediction_scores, '... v -> (...) v')
mlm_predictions_scores = prediction_scores_flat[masked_lm_labels_flat != 0]
mlm_predictions = mlm_predictions_scores.argmax(dim=-1)
else:
mlm_predictions = prediction_scores.argmax(dim=-1)
mlm_acc = (mlm_predictions == mlm_labels).sum(dtype=torch.float) / mlm_labels.numel()
total_loss = masked_lm_loss.float() + next_sentence_loss.float()
return total_loss, prediction_scores, seq_relationship_score, mlm_acc, mlm_labels.numel()
else:
return prediction_scores, seq_relationship_score
return BertForPreTrainingOutput(
loss=total_loss,
prediction_logits=prediction_scores,
seq_relationship_logits=seq_relationship_score,
)
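Continuing the config/model sketch above, a hedged example of calling the pretraining forward pass and reading the new BertForPreTrainingOutput (shapes and the masking pattern are illustrative only; labels are 0 for tokens that are not masked, as the docstring requires):

import torch

batch_size, seqlen = 4, 128
input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device='cuda')
labels = torch.zeros_like(input_ids)   # 0 = not masked, ignored by the MLM loss
labels[:, 5] = torch.randint(1, config.vocab_size, (batch_size,), device='cuda')  # pretend position 5 is masked
next_sentence_label = torch.randint(0, 2, (batch_size,), device='cuda')

out = model(input_ids, labels=labels, next_sentence_label=next_sentence_label)
# out.loss                      total MLM + NSP loss
# out.prediction_logits         flattened over the masked tokens when dense_seq_output is set
# out.seq_relationship_logits   (batch_size, 2)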
def state_dict_from_pretrained(model_name):
......@@ -401,6 +478,7 @@ def remap_state_dict(state_dict, config):
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
# Attention
last_layer_subset = getattr(config, 'last_layer_subset', False)
for d in range(config.num_hidden_layers):
Wq = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.query.weight')
Wk = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.key.weight')
......@@ -408,12 +486,22 @@ def remap_state_dict(state_dict, config):
bq = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.query.bias')
bk = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.key.bias')
bv = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.value.bias')
if not (last_layer_subset and d == config.num_hidden_layers - 1):
state_dict[f'bert.encoder.layers.{d}.mixer.Wqkv.weight'] = torch.cat(
[Wq, Wk, Wv], dim=0
)
state_dict[f'bert.encoder.layers.{d}.mixer.Wqkv.bias'] = torch.cat(
[bq, bk, bv], dim=0
)
else:
state_dict[f'bert.encoder.layers.{d}.mixer.Wq.weight'] = Wq
state_dict[f'bert.encoder.layers.{d}.mixer.Wkv.weight'] = torch.cat(
[Wk, Wv], dim=0
)
state_dict[f'bert.encoder.layers.{d}.mixer.Wq.bias'] = bq
state_dict[f'bert.encoder.layers.{d}.mixer.Wkv.bias'] = torch.cat(
[bk, bv], dim=0
)
def key_mapping_attn(key):
return re.sub(r'^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)',
r'bert.encoder.layers.\1.mixer.out_proj.\2', key)
......@@ -423,4 +511,23 @@ def remap_state_dict(state_dict, config):
return re.sub(r'^cls.predictions.bias', 'cls.predictions.decoder.bias', key)
state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())
# Word embedding
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
if pad_vocab_size_multiple > 1:
word_embeddings = state_dict['bert.embeddings.word_embeddings.weight']
state_dict['bert.embeddings.word_embeddings.weight'] = F.pad(
word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
)
decoder_weight = state_dict['cls.predictions.decoder.weight']
state_dict['cls.predictions.decoder.weight'] = F.pad(
decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0])
)
# If the vocab was padded, we want to set the decoder bias for those padded indices to be
# strongly negative (i.e. the decoder shouldn't predict those indices).
# TD [2022-05-09]: I don't think it affects the MLPerf training.
decoder_bias = state_dict['cls.predictions.decoder.bias']
state_dict['cls.predictions.decoder.bias'] = F.pad(
decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
)
return state_dict
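A small shape sketch (hidden size is illustrative) of what the new last-layer branch of the attention remapping above produces compared to the ordinary fused-QKV layers:

import torch

h = 768
Wq, Wk, Wv = (torch.randn(h, h) for _ in range(3))  # HF per-projection weights for one layer

# Ordinary layers: a single fused QKV projection.
Wqkv = torch.cat([Wq, Wk, Wv], dim=0)                # (3h, h) -> mixer.Wqkv.weight

# Last layer with last_layer_subset: queries come only from the subset tokens,
# so Q stays separate and only K and V are fused.
Wq_last = Wq                                         # (h, h)  -> mixer.Wq.weight
Wkv_last = torch.cat([Wk, Wv], dim=0)                # (2h, h) -> mixer.Wkv.weight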
......@@ -120,15 +120,36 @@ class FlashCrossAttention(nn.Module):
self.dropout_p = attention_dropout
self.triton = triton
def forward(self, q, kv):
def forward(self, q, kv, cu_seqlens=None, max_seqlen=None, cu_seqlens_k=None, max_seqlen_k=None):
"""Implements the multihead softmax attention.
Arguments
---------
q: The tensor containing the query. (B, Sq, H, D)
kv: The tensor containing the key and value. (B, Sk, 2, H, D)
cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into q.
max_seqlen: int. Maximum sequence length in the batch of q.
cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into kv.
max_seqlen_k: int. Maximum sequence length in the batch of k and v.
"""
assert q.dtype in [torch.float16, torch.bfloat16]
assert q.is_cuda and kv.is_cuda
unpadded = cu_seqlens is not None
if unpadded:
assert cu_seqlens.dtype == torch.int32
assert max_seqlen is not None
assert isinstance(max_seqlen, int)
assert cu_seqlens_k is not None
assert cu_seqlens_k.dtype == torch.int32
assert max_seqlen_k is not None
assert isinstance(max_seqlen_k, int)
return flash_attn_unpadded_kvpacked_func(
q, kv, cu_seqlens, cu_seqlens_k, max_seqlen, max_seqlen_k,
self.dropout_p if self.training else 0.0,
softmax_scale=self.softmax_scale, causal=self.causal
)
else:
batch_size, seqlen_q = q.shape[0], q.shape[1]
seqlen_k = kv.shape[1]
assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
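To make the new unpadded arguments above concrete, here is a small illustration (made-up lengths; the total_q/total_k packing follows flash-attn's unpadded convention and is an assumption, not shown in this hunk): two sequences with 3 and 5 key/value tokens, of which 2 and 1 are query (subset) tokens.

import torch

cu_seqlens   = torch.tensor([0, 2, 3], dtype=torch.int32)  # query boundaries:     2 + 1 tokens
cu_seqlens_k = torch.tensor([0, 3, 8], dtype=torch.int32)  # key/value boundaries: 3 + 5 tokens
max_seqlen, max_seqlen_k = 2, 5
# q  is packed as (total_q=3, nheads, head_dim)
# kv is packed as (total_k=8, 2, nheads, head_dim)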
......@@ -214,12 +235,14 @@ class CrossAttention(nn.Module):
self.softmax_scale = softmax_scale
self.dropout_p = attention_dropout
def forward(self, q, kv):
def forward(self, q, kv, key_padding_mask=None):
"""Implements the multihead softmax attention.
Arguments
---------
q: The tensor containing the query. (B, Sq, H, D)
kv: The tensor containing the key and value. (B, Sk, 2, H, D)
key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
False means to mask out. (B, Sk)
"""
batch_size, seqlen_q = q.shape[0], q.shape[1]
seqlen_k = kv.shape[1]
......@@ -227,6 +250,12 @@ class CrossAttention(nn.Module):
k, v = kv.unbind(dim=2)
softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
if key_padding_mask is not None:
padding_mask = torch.full((batch_size, seqlen_k), -10000.0, dtype=scores.dtype,
device=scores.device)
padding_mask.masked_fill_(key_padding_mask, 0.0)
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s')
if self.causal:
# "triu_tril_cuda_template" not implemented for 'BFloat16'
# So we have to construct the mask in float
......@@ -295,9 +324,11 @@ class MHA(nn.Module):
groups=3 * embed_dim)
inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
else:
# TODO: use the residual linear class for Wq
self.Wq = linear_cls(embed_dim, embed_dim, bias=bias, **factory_kwargs)
if not self.return_residual:
self.Wkv = linear_cls(embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs)
else:
self.Wkv = linear_resid_cls(embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs)
if self.dwconv:
self.dwconv_q = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, padding=2,
groups=embed_dim)
......@@ -309,7 +340,8 @@ class MHA(nn.Module):
# output projection always have the bias (for now)
self.out_proj = linear_cls(embed_dim, embed_dim, **factory_kwargs)
def forward(self, x, x_kv=None, cu_seqlens=None, max_seqlen=None, key_padding_mask=None):
def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None,
**kwargs):
"""
Arguments:
x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
......@@ -327,17 +359,15 @@ class MHA(nn.Module):
assert max_seqlen is not None
assert key_padding_mask is None
assert self.use_flash_attn
assert not self.cross_attn, ('Unpadded FlashAttention code path for cross-attention'
'is not implemented yet')
assert not self.dwconv
assert self.rotary_emb_dim == 0
if key_padding_mask is not None:
assert cu_seqlens is None
assert max_seqlen is None
assert not self.use_flash_attn
assert not self.cross_attn, ('Key padding mask code path for cross-attention'
'is not implemented yet')
kwargs = ({'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen, **kwargs}
if self.use_flash_attn else {'key_padding_mask': key_padding_mask, **kwargs})
if not self.cross_attn:
if not self.return_residual:
qkv = self.Wqkv(x)
......@@ -349,24 +379,30 @@ class MHA(nn.Module):
qkv = rearrange(qkv, '... (three h d) -> ... three h d', three=3, h=self.num_heads)
if self.rotary_emb_dim > 0:
qkv = self.rotary_emb(qkv)
extra_kwargs = ({'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen}
if self.use_flash_attn else {'key_padding_mask': key_padding_mask})
if not self.checkpointing:
context = self.inner_attn(qkv, **extra_kwargs)
context = self.inner_attn(qkv, **kwargs)
else:
context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
else:
if not self.return_residual:
q = self.Wq(x)
kv = self.Wkv(x_kv if x_kv is not None else x)
else:
context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **extra_kwargs)
if x_kv is not None:
kv, x_kv = self.Wkv(x_kv)
else:
q = rearrange(self.Wq(x), 'b s (h d) -> b s h d', h=self.num_heads)
kv = rearrange(self.Wkv(x if x_kv is None else x_kv), 'b s (two h d) -> b s two h d',
two=2, h=self.num_heads)
kv, x = self.Wkv(x)
q = self.Wq(x)
q = rearrange(q, '... (h d) -> ... h d', h=self.num_heads)
kv = rearrange(kv, '... (two h d) -> ... two h d', two=2, h=self.num_heads)
if self.dwconv:
q = rearrange(self.dwconv_q(rearrange(q, 'b s d -> b d s'))[..., :-2],
'b d s -> b s d').contiguous()
kv = rearrange(self.dwconv_kv(rearrange(kv, 'b s d -> b d s'))[..., :-2],
'b d s -> b s d').contiguous()
if not self.checkpointing:
context = self.inner_attn(q, kv)
context = self.inner_attn(q, kv, **kwargs)
else:
context = torch.utils.checkpoint.checkpoint(self.inner_attn, q, kv)
context = torch.utils.checkpoint.checkpoint(self.inner_attn, q, kv, **kwargs)
out = self.out_proj(rearrange(context, '... h d -> ... (h d)'))
return out if not self.return_residual else (out, x)
......@@ -15,20 +15,21 @@ except ImportError:
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.gelu,
device=None, dtype=None):
return_residual=False, device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.return_residual = return_residual
self.fc1 = nn.Linear(in_features, hidden_features, **factory_kwargs)
self.activation = activation
self.fc2 = nn.Linear(hidden_features, out_features, **factory_kwargs)
def forward(self, x):
x = self.fc1(x)
x = self.activation(x)
x = self.fc2(x)
return x
y = self.fc1(x)
y = self.activation(y)
y = self.fc2(y)
return y if not self.return_residual else (y, x)
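A minimal usage sketch of the return_residual contract added above: the module now hands back its unmodified input alongside the output, presumably so the surrounding Block can do the residual add (and the fused dropout + LayerNorm) itself.

import torch
from flash_attn.modules.mlp import Mlp

mlp = Mlp(in_features=16, hidden_features=64, return_residual=True)
x = torch.randn(2, 4, 16)
y, residual = mlp(x)      # y = fc2(activation(fc1(x)))
assert residual is x      # the residual is the untouched input tensor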
class FusedDenseGeluDense(nn.Module):
......
......@@ -53,15 +53,12 @@ def test_bert_non_optimized(model_name):
"""
dtype = torch.float16
config = BertConfig.from_pretrained(model_name)
# Our implementation assumes the activation is nn.GELU(approximate='tanh')
# Huggingface calls it "gelu_new" or "gelu_fast".
config.hidden_act = "gelu_new"
model = BertForPreTraining.from_pretrained(model_name, config)
model = model.cuda().to(dtype=dtype)
model_ref = get_hf_models(model_name, config, torch.float32)
model_hf = get_hf_models(model_name, config, torch.float16)
model_hf = get_hf_models(model_name, config, dtype)
model.eval()
model_ref.eval()
......@@ -74,7 +71,8 @@ def test_bert_non_optimized(model_name):
attention_mask = torch.arange(max_seqlen, device='cuda')[None, :] < seqlens[:, None]
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
device='cuda')
sequence_output, pooled_output = model.bert(input_ids, attention_mask=attention_mask)
out = model.bert(input_ids, attention_mask=attention_mask)
sequence_output, pooled_output = out.last_hidden_state, out.pooler_output
out_hf = model_hf.bert(input_ids, attention_mask=attention_mask)
sequence_output_hf, pooled_output_hf = out_hf.last_hidden_state, out_hf.pooler_output
out_ref = model_ref.bert(input_ids, attention_mask=attention_mask)
......@@ -84,8 +82,8 @@ def test_bert_non_optimized(model_name):
print(f'Output mean diff: {(sequence_output - sequence_output_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(sequence_output_hf - sequence_output_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(sequence_output_hf - sequence_output_ref).abs().mean().item()}')
assert (sequence_output - sequence_output_ref).abs().max().item() < 2 * (sequence_output_hf - sequence_output_ref).abs().max().item()
assert (pooled_output - pooled_output_ref).abs().max().item() < 2 * (pooled_output_hf - pooled_output_ref).abs().max().item()
assert (sequence_output - sequence_output_ref).abs().max().item() < 3 * (sequence_output_hf - sequence_output_ref).abs().max().item()
assert (pooled_output - pooled_output_ref).abs().max().item() < 3 * (pooled_output_hf - pooled_output_ref).abs().max().item()
@pytest.mark.parametrize('model_name', ["bert-base-uncased", "bert-large-uncased"])
......@@ -97,8 +95,9 @@ def test_bert_optimized(model_name):
"""
dtype = torch.float16
config = BertConfig.from_pretrained(model_name)
# Our implementation assumes the activation is nn.GELU(approximate='tanh')
# Huggingface calls it "gelu_new" or "gelu_fast".
# Our implementation of fused_dense_gelu_dense assumes the activation is
# nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new" or "gelu_fast".
# If you just want "gelu", disable fused_dense_gelu_dense.
config.hidden_act = "gelu_new"
config.use_flash_attn = True
config.fused_bias_fc = True
......@@ -109,7 +108,7 @@ def test_bert_optimized(model_name):
model = model.cuda().to(dtype=dtype)
model_ref = get_hf_models(model_name, config, torch.float32)
model_hf = get_hf_models(model_name, config, torch.float16)
model_hf = get_hf_models(model_name, config, dtype)
model.eval()
model_ref.eval()
......@@ -122,7 +121,8 @@ def test_bert_optimized(model_name):
attention_mask = torch.arange(max_seqlen, device='cuda')[None, :] < seqlens[:, None]
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
device='cuda')
sequence_output, pooled_output = model.bert(input_ids, attention_mask=attention_mask)
out = model.bert(input_ids, attention_mask=attention_mask)
sequence_output, pooled_output = out.last_hidden_state, out.pooler_output
out_hf = model_hf.bert(input_ids, attention_mask=attention_mask)
sequence_output_hf, pooled_output_hf = out_hf.last_hidden_state, out_hf.pooler_output
# Need to zero out the padded tokens in the sequence before comparison.
......@@ -138,7 +138,8 @@ def test_bert_optimized(model_name):
assert (sequence_output - sequence_output_ref).abs().max().item() < 4 * (sequence_output_hf - sequence_output_ref).abs().max().item()
assert (pooled_output - pooled_output_ref).abs().max().item() < 4 * (pooled_output_hf - pooled_output_ref).abs().max().item()
prediction_scores, seq_relationship_scores = model(input_ids, attention_mask=attention_mask)
out = model(input_ids, attention_mask=attention_mask)
prediction_scores, seq_relationship_scores = out.prediction_logits, out.seq_relationship_logits
# Need to zero out the padded tokens in the sequence before comparison.
prediction_scores = prediction_scores.clone()
prediction_scores[~attention_mask, :] = 0.0
......@@ -157,30 +158,36 @@ def test_bert_optimized(model_name):
assert (seq_relationship_scores - seq_relationship_scores_ref).abs().max().item() < 2 * (seq_relationship_scores_hf - seq_relationship_scores_ref).abs().max().item()
@pytest.mark.parametrize('last_layer_subset', [False, True])
# @pytest.mark.parametrize('last_layer_subset', [True])
@pytest.mark.parametrize('has_key_padding_mask', [True, False])
# @pytest.mark.parametrize('has_key_padding_mask', [True])
@pytest.mark.parametrize('model_name', ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_dense_seq_output(model_name):
def test_bert_dense_seq_output(model_name, has_key_padding_mask, last_layer_subset):
"""Check that our implementation of BERT (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
dtype = torch.float16
config = BertConfig.from_pretrained(model_name)
# Our implementation assumes the activation is nn.GELU(approximate='tanh')
# Huggingface calls it "gelu_new" or "gelu_fast".
# Our implementation of fused_dense_gelu_dense assumes the activation is
# nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new" or "gelu_fast".
# If you just want "gelu", disable fused_dense_gelu_dense.
config.hidden_act = "gelu_new"
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_dense_gelu_dense = True
config.fused_dropout_add_ln = True
config.dense_seq_output = True
config.last_layer_subset = last_layer_subset
config.use_xentropy = True
model = BertForPreTraining.from_pretrained(model_name, config)
model = model.cuda().to(dtype=dtype)
model_ref = get_hf_models(model_name, config, torch.float32)
model_hf = get_hf_models(model_name, config, torch.float16)
model_hf = get_hf_models(model_name, config, dtype)
model.eval()
model_ref.eval()
......@@ -190,19 +197,25 @@ def test_bert_dense_seq_output(model_name):
batch_size = 4
max_seqlen = 512
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda')
if has_key_padding_mask:
attention_mask = torch.arange(max_seqlen, device='cuda')[None, :] < seqlens[:, None]
else:
attention_mask = None
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
device='cuda')
labels = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
device='cuda')
labels[(torch.rand(batch_size, max_seqlen, device='cuda') < 0.15) | ~attention_mask] = 0
if attention_mask is not None:
labels[~attention_mask] = 0
labels[(torch.rand(batch_size, max_seqlen, device='cuda') > 0.15)] = 0
masked_tokens_mask = labels.flatten() > 0
next_sequence_label = torch.randint(0, 2, (batch_size,), device='cuda')
total_loss, prediction_scores, seq_relationship_scores, _, _ = model(
out = model(
input_ids, attention_mask=attention_mask,
labels=labels, next_sentence_label=next_sequence_label
)
prediction_scores, seq_relationship_scores = out.prediction_logits, out.seq_relationship_logits
out_hf = model_hf(input_ids, attention_mask=attention_mask,
labels=labels, next_sentence_label=next_sequence_label)
prediction_scores_hf, seq_relationship_scores_hf = out_hf.prediction_logits, out_hf.seq_relationship_logits
......@@ -217,3 +230,6 @@ def test_bert_dense_seq_output(model_name):
print(f'HF fp16 prediction_scores max diff: {(prediction_scores_hf - prediction_scores_ref).abs().max().item()}')
print(f'HF fp16 prediction_scores mean diff: {(prediction_scores_hf - prediction_scores_ref).abs().mean().item()}')
assert (prediction_scores - prediction_scores_ref).abs().max().item() < 2 * (prediction_scores_hf - prediction_scores_ref).abs().max().item()
assert (seq_relationship_scores - seq_relationship_scores_ref).abs().max().item() < 2 * (seq_relationship_scores_hf - seq_relationship_scores_ref).abs().max().item()
# The loss calculation from HF is wrong: it doesn't ignore the labels that are 0.
# assert (out.loss - out_ref.loss).abs().max().item() < 2 * (out_hf.loss - out_ref.loss).abs().max().item()