Commit 5fb6df0e authored by Tri Dao

Implement BERT

parent dc24c226
@@ -5,7 +5,7 @@ import torch.nn as nn
 from einops import rearrange
 from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
-from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis
+from flash_attn.bert_padding import unpad_input, pad_input
 class FlashAttention(nn.Module):
...
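The unpad_input / pad_input helpers imported above convert between padded (batch, seqlen, ...) tensors and the packed (total, ...) layout that the unpadded FlashAttention kernels consume. A minimal round-trip sketch; the 4-tuple return of unpad_input matches the bert_padding version this commit targets, so double-check the signature against your installed flash_attn:

import torch
from flash_attn.bert_padding import unpad_input, pad_input

batch, seqlen, dim = 2, 8, 16
x = torch.randn(batch, seqlen, dim)
# True = real token, False = padding
attention_mask = torch.tensor([[1] * 8, [1] * 5 + [0] * 3], dtype=torch.bool)

# unpad_input packs the real tokens into (total, dim) and returns the metadata
# needed by the unpadded kernels: flat indices, cu_seqlens, and max_seqlen.
x_unpad, indices, cu_seqlens, max_seqlen = unpad_input(x, attention_mask)
assert x_unpad.shape == (attention_mask.sum().item(), dim)

# pad_input scatters the packed tokens back; padded positions come back as zeros.
x_repad = pad_input(x_unpad, indices, batch, seqlen)
assert torch.equal(x_repad[attention_mask], x_unpad)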
This diff is collapsed.
@@ -10,7 +10,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from transformers.models.gpt2.configuration_gpt2 import GPT2Config
+from transformers import GPT2Config
 from flash_attn.modules.mha import MHA
 from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense
...
@@ -23,10 +23,16 @@ class Block(nn.Module):
     def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm,
                  dropout_cls=nn.Dropout, prenorm=True, resid_dropout=0., drop_path=0.,
-                 fused_dropout_add_ln=False):
+                 fused_dropout_add_ln=False, return_residual=False):
+        """
+        return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
+        This is for performance reasons: for the post-norm architecture, returning the input allows
+        us to fuse the backward of nn.Linear with the residual connection.
+        """
         super().__init__()
         self.prenorm = prenorm
         self.fused_dropout_add_ln = fused_dropout_add_ln
+        self.return_residual = return_residual
         if mixer_cls is None:
             mixer_cls = partial(MHA, num_heads=dim // 64)
         if mlp_cls is None:
@@ -92,8 +98,11 @@ class Block(nn.Module):
             return hidden_states, residual
         else:
             assert residual is None
-            mixer_out = self.mixer(hidden_states,
-                                   **(mixer_kwargs if mixer_kwargs is not None else {}))
+            mixer_out = self.mixer(
+                hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
+            )
+            if self.return_residual:  # mixer out is actually a pair here
+                mixer_out, hidden_states = mixer_out
             if not self.fused_dropout_add_ln:
                 hidden_states = self.norm1((self.drop_path1(self.dropout1(mixer_out))
                                             + hidden_states).to(dtype=self.norm1.weight.dtype))
@@ -111,6 +120,8 @@ class Block(nn.Module):
                 )
             if not isinstance(self.mlp, nn.Identity):
                 mlp_out = self.mlp(hidden_states)
+                if self.return_residual:  # mlp out is actually a pair here
+                    mlp_out, hidden_states = mlp_out
                 if not self.fused_dropout_add_ln:
                     hidden_states = self.norm2((self.drop_path2(self.dropout2(mlp_out))
                                                 + hidden_states).to(dtype=self.norm2.weight.dtype))
...
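To make the return_residual contract concrete: when it is enabled, each sub-layer returns a pair (out, input) instead of just out, and the Block performs the residual add itself, which is what lets the fused linear kernels fold the residual add into their backward. A toy sketch of the convention (illustrative names, not the actual Block code):

import torch
import torch.nn as nn

class ReturnResidualLinear(nn.Linear):
    # Toy sub-layer following the return_residual convention: hand the input back
    # alongside the output so the caller can do the residual add (and fuse it).
    def forward(self, x):
        return super().forward(x), x

mixer = ReturnResidualLinear(256, 256)
norm1 = nn.LayerNorm(256)
x = torch.randn(4, 16, 256)
mixer_out, residual = mixer(x)               # residual is just x, returned by the sub-layer
hidden_states = norm1(mixer_out + residual)  # post-norm: LayerNorm(sublayer(x) + x)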
@@ -3,8 +3,6 @@
 import torch
 import torch.nn as nn
-from einops import repeat
 class GPT2Embeddings(nn.Module):
@@ -21,15 +19,51 @@ class GPT2Embeddings(nn.Module):
     def forward(self, input_ids, position_ids=None):
         """
             input_ids: (batch, seqlen)
+            position_ids: (batch, seqlen)
         """
         batch_size, seqlen = input_ids.shape
-        input_embeddings = self.word_embeddings(input_ids)
+        embeddings = self.word_embeddings(input_ids)
         if self.max_position_embeddings > 0:
             if position_ids is None:
-                position_ids = repeat(torch.arange(seqlen, dtype=torch.long,
-                                                   device=input_ids.device),
-                                      's -> b s', b=batch_size)
+                position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
             position_embeddings = self.position_embeddings(position_ids)
-            return input_embeddings + position_embeddings
-        else:
-            return input_embeddings
+            embeddings = embeddings + position_embeddings
+        return embeddings
+
+
+class BertEmbeddings(nn.Module):
+
+    def __init__(self, embed_dim, vocab_size, max_position_embeddings, type_vocab_size,
+                 padding_idx=None):
+        """
+            If max_position_embeddings <= 0, there's no position embeddings
+            If type_vocab_size <= 0, there's no token type embeddings
+        """
+        super().__init__()
+        self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
+        self.max_position_embeddings = max_position_embeddings
+        self.type_vocab_size = type_vocab_size
+        if self.max_position_embeddings > 0:
+            self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim)
+        if self.type_vocab_size > 0:
+            self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim)
+
+    def forward(self, input_ids, position_ids=None, token_type_ids=None):
+        """
+            input_ids: (batch, seqlen)
+            position_ids: (batch, seqlen)
+            token_type_ids: (batch, seqlen)
+        """
+        batch_size, seqlen = input_ids.shape
+        embeddings = self.word_embeddings(input_ids)
+        if self.max_position_embeddings > 0:
+            if position_ids is None:
+                position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
+            position_embeddings = self.position_embeddings(position_ids)
+            embeddings = embeddings + position_embeddings
+        if self.type_vocab_size > 0:
+            if token_type_ids is None:
+                token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
+            token_type_embeddings = self.token_type_embeddings(token_type_ids)
+            embeddings = embeddings + token_type_embeddings
+        return embeddings
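A quick usage sketch of the new BertEmbeddings with BERT-base-like sizes; the numbers are illustrative and the import path assumes the module lives in flash_attn.modules.embedding as in this commit's tree:

import torch
from flash_attn.modules.embedding import BertEmbeddings

emb = BertEmbeddings(embed_dim=768, vocab_size=30522, max_position_embeddings=512,
                     type_vocab_size=2, padding_idx=0)
input_ids = torch.randint(0, 30522, (2, 128))
token_type_ids = torch.zeros(2, 128, dtype=torch.long)  # all sentence-A tokens
out = emb(input_ids, token_type_ids=token_type_ids)     # (2, 128, 768)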
@@ -53,24 +53,45 @@ class FlashSelfAttention(nn.Module):
         self.dropout_p = attention_dropout
         self.triton = triton

-    def forward(self, qkv):
+    def forward(self, qkv, cu_seqlens=None, max_seqlen=None):
         """Implements the multihead softmax attention.
         Arguments
         ---------
-            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
+            qkv: The tensor containing the query, key, and value.
+                If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
+                If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
+                (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
+            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
+                of the sequences in the batch, used to index into qkv.
+            max_seqlen: int. Maximum sequence length in the batch.
+        Returns:
+        --------
+            out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
+                else (B, S, H, D).
         """
         assert qkv.dtype in [torch.float16, torch.bfloat16]
         assert qkv.is_cuda
-        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
-        if self.triton and (self.dropout_p == 0 or not self.training):  # Triton version doesn't support dropout
-            output = flash_attn_qkvpacked_func(qkv, None, self.causal, self.softmax_scale)
-        else:
-            qkv = rearrange(qkv, 'b s ... -> (b s) ...')
-            max_s = seqlen
-            cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
-                                      device=qkv.device)
-            output = flash_attn_unpadded_qkvpacked_func(
-                qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
-                softmax_scale=self.softmax_scale, causal=self.causal
-            )
-            output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
+        unpadded = cu_seqlens is not None
+        if unpadded:
+            assert cu_seqlens.dtype == torch.int32
+            assert max_seqlen is not None
+            assert isinstance(max_seqlen, int)
+            return flash_attn_unpadded_qkvpacked_func(
+                qkv, cu_seqlens, max_seqlen, self.dropout_p if self.training else 0.0,
+                softmax_scale=self.softmax_scale, causal=self.causal
+            )
+        else:
+            batch_size, seqlen = qkv.shape[0], qkv.shape[1]
+            # Triton version doesn't support dropout
+            if self.triton and (self.dropout_p == 0 or not self.training):
+                output = flash_attn_qkvpacked_func(qkv, None, self.causal, self.softmax_scale)
+            else:
+                qkv = rearrange(qkv, 'b s ... -> (b s) ...')
+                max_seqlen = seqlen
+                cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
+                                          device=qkv.device)
+                output = flash_attn_unpadded_qkvpacked_func(
+                    qkv, cu_seqlens, max_seqlen, self.dropout_p if self.training else 0.0,
+                    softmax_scale=self.softmax_scale, causal=self.causal
+                )
+                output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
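A sketch of the two call modes of FlashSelfAttention after this change (class assumed to be exposed as flash_attn.modules.mha.FlashSelfAttention, as in this diff; shapes are illustrative). The unpadded branch expects packed qkv of shape (total, 3, H, D) plus cu_seqlens/max_seqlen; a CUDA device and fp16/bf16 are required by the kernels:

import torch
from flash_attn.modules.mha import FlashSelfAttention

attn = FlashSelfAttention(causal=False, attention_dropout=0.0)

# Padded mode: (B, S, 3, H, D); the module builds cu_seqlens internally.
qkv = torch.randn(2, 128, 3, 12, 64, device='cuda', dtype=torch.float16)
out = attn(qkv)                                                      # (B, S, H, D)

# Unpadded mode: two sequences of lengths 128 and 100 packed into (total, 3, H, D).
cu_seqlens = torch.tensor([0, 128, 228], dtype=torch.int32, device='cuda')
qkv_unpad = torch.randn(228, 3, 12, 64, device='cuda', dtype=torch.float16)
out_unpad = attn(qkv_unpad, cu_seqlens=cu_seqlens, max_seqlen=128)   # (total, H, D)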
@@ -146,16 +167,24 @@ class SelfAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.dropout_p = attention_dropout

-    def forward(self, qkv):
+    def forward(self, qkv, key_padding_mask=None):
         """Implements the multihead softmax attention.
         Arguments
         ---------
             qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
+            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
+                False means to mask out. (B, S)
         """
         batch_size, seqlen = qkv.shape[0], qkv.shape[1]
         q, k, v = qkv.unbind(dim=2)
         softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
         scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
+        if key_padding_mask is not None:
+            padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype,
+                                      device=scores.device)
+            padding_mask.masked_fill_(key_padding_mask, 0.0)
+            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
+            scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s')
         if self.causal:
             # "triu_tril_cuda_template" not implemented for 'BFloat16'
             # So we have to construct the mask in float
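The non-flash SelfAttention path now takes a boolean key_padding_mask (True = keep); internally it is applied as an additive bias of -10000.0 on masked key positions, which the comment above notes is faster than masked_fill_. A small sketch of building that mask from per-sequence lengths (shapes are illustrative):

import torch
from flash_attn.modules.mha import SelfAttention

attn = SelfAttention(causal=False, attention_dropout=0.0)

batch, seqlen, nheads, headdim = 2, 16, 4, 32
qkv = torch.randn(batch, seqlen, 3, nheads, headdim)

# True for real tokens, False for padding.
lengths = torch.tensor([16, 10])
key_padding_mask = torch.arange(seqlen)[None, :] < lengths[:, None]  # (B, S)

out = attn(qkv, key_padding_mask=key_padding_mask)                   # (B, S, H, D)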
@@ -239,6 +268,7 @@ class MHA(nn.Module):
         self.causal = causal
         self.dwconv = dwconv
         self.rotary_emb_dim = rotary_emb_dim
+        self.use_flash_attn = use_flash_attn
         self.return_residual = return_residual
         self.checkpointing = checkpointing
@@ -279,12 +309,35 @@ class MHA(nn.Module):
         # output projection always have the bias (for now)
         self.out_proj = linear_cls(embed_dim, embed_dim, **factory_kwargs)

-    def forward(self, x, x_kv=None):
+    def forward(self, x, x_kv=None, cu_seqlens=None, max_seqlen=None, key_padding_mask=None):
         """
         Arguments:
-            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim)
+            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
+                cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
+                is the sum of the sequence lengths in the batch.
             x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
+            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
+                of the sequences in the batch, used to index into x. Only applicable when using
+                FlashAttention.
+            max_seqlen: int. Maximum sequence length in the batch.
+            key_padding_mask: boolean mask, True means to keep, False means to mask out.
+                (batch, seqlen). Only applicable when not using FlashAttention.
         """
+        if cu_seqlens is not None:
+            assert max_seqlen is not None
+            assert key_padding_mask is None
+            assert self.use_flash_attn
+            assert not self.cross_attn, ('Unpadded FlashAttention code path for cross-attention '
+                                         'is not implemented yet')
+            assert not self.dwconv
+            assert self.rotary_emb_dim == 0
+        if key_padding_mask is not None:
+            assert cu_seqlens is None
+            assert max_seqlen is None
+            assert not self.use_flash_attn
+            assert not self.cross_attn, ('Key padding mask code path for cross-attention '
+                                         'is not implemented yet')
         if not self.cross_attn:
             if not self.return_residual:
                 qkv = self.Wqkv(x)
@@ -293,14 +346,15 @@ class MHA(nn.Module):
             if self.dwconv:
                 qkv = rearrange(self.dwconv_qkv(rearrange(qkv, 'b s d -> b d s'))[..., :-2],
                                 'b d s -> b s d').contiguous()
-            qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
+            qkv = rearrange(qkv, '... (three h d) -> ... three h d', three=3, h=self.num_heads)
             if self.rotary_emb_dim > 0:
                 qkv = self.rotary_emb(qkv)
+            extra_kwargs = ({'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen}
+                            if self.use_flash_attn else {'key_padding_mask': key_padding_mask})
             if not self.checkpointing:
-                context = self.inner_attn(qkv)
+                context = self.inner_attn(qkv, **extra_kwargs)
             else:
-                # context = torch.utils.checkpoint.checkpoint(self._inner_attention, qkv)
-                context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv)
+                context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **extra_kwargs)
         else:
             q = rearrange(self.Wq(x), 'b s (h d) -> b s h d', h=self.num_heads)
             kv = rearrange(self.Wkv(x if x_kv is None else x_kv), 'b s (two h d) -> b s two h d',
@@ -313,7 +367,6 @@ class MHA(nn.Module):
             if not self.checkpointing:
                 context = self.inner_attn(q, kv)
             else:
-                # context = torch.utils.checkpoint.checkpoint(self._inner_attention, qkv)
                 context = torch.utils.checkpoint.checkpoint(self.inner_attn, q, kv)
-        out = self.out_proj(rearrange(context, 'b s h d -> b s (h d)'))
+        out = self.out_proj(rearrange(context, '... h d -> ... (h d)'))
         return out if not self.return_residual else (out, x)
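Putting the MHA changes together: the padded path (key_padding_mask) targets the non-flash SelfAttention, while the unpadded path (cu_seqlens/max_seqlen) requires use_flash_attn=True and packed (total, hidden_dim) input, typically produced by unpad_input. A hedged sketch of both modes; constructor arguments beyond the first two and use_flash_attn are left at their defaults (check mha.py for the full signature), and unpad_input's 4-tuple return matches the bert_padding version this commit targets:

import torch
from flash_attn.bert_padding import unpad_input
from flash_attn.modules.mha import MHA

batch, seqlen, dim, nheads = 2, 128, 768, 12
x = torch.randn(batch, seqlen, dim, device='cuda', dtype=torch.float16)
lengths = torch.tensor([128, 100], device='cuda')
key_padding_mask = torch.arange(seqlen, device='cuda')[None, :] < lengths[:, None]  # True = keep

# Padded path: standard attention plus the additive padding mask.
mha_padded = MHA(dim, num_heads=nheads, use_flash_attn=False).to(device='cuda', dtype=torch.float16)
out_padded = mha_padded(x, key_padding_mask=key_padding_mask)                 # (B, S, dim)

# Unpadded FlashAttention path: pack the valid tokens first, then pass the metadata.
x_unpad, indices, cu_seqlens, max_seqlen = unpad_input(x, key_padding_mask)
mha_flash = MHA(dim, num_heads=nheads, use_flash_attn=True).to(device='cuda', dtype=torch.float16)
out_unpad = mha_flash(x_unpad, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)  # (total, dim)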
@@ -200,6 +200,10 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
                 None, None)


+def layer_norm(x, weight, bias, epsilon):
+    return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False)
+
+
 def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None, layerscale=None,
                            prenorm=False, residual_in_fp32=False,
                            return_dropout_mask=False):
...
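The new layer_norm helper routes a plain LayerNorm (no dropout, no residual) through the fused DropoutAddLayerNorm kernel. A minimal numerical sanity check, assuming the dropout_layer_norm CUDA extension is built and that the helper is importable from flash_attn.ops.layer_norm (adjust the import to your tree):

import torch
import torch.nn.functional as F
from flash_attn.ops.layer_norm import layer_norm   # import path assumed

hidden = 768
x = torch.randn(4, 512, hidden, device='cuda', dtype=torch.float16)
weight = torch.randn(hidden, device='cuda', dtype=torch.float16)
bias = torch.randn(hidden, device='cuda', dtype=torch.float16)

out_fused = layer_norm(x, weight, bias, 1e-12)
out_ref = F.layer_norm(x.float(), (hidden,), weight.float(), bias.float(), 1e-12)
print((out_fused.float() - out_ref).abs().max())   # should be small (fp16 rounding only)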
import re
from collections import OrderedDict

import torch
import torch.nn.functional as F
import pytest

from einops import rearrange

from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertModel as BertModelHF
from transformers.models.bert.modeling_bert import BertForPreTraining as BertForPreTrainingHF

from flash_attn.models.bert import BertModel, BertForPreTraining
from flash_attn.models.bert import state_dict_from_pretrained
from flash_attn.models.bert import remap_state_dict


@pytest.mark.parametrize('model_name', ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_state_dict(model_name):
    config = BertConfig.from_pretrained(model_name)
    pretrained_state_dict = remap_state_dict(state_dict_from_pretrained(model_name), config)
    model = BertForPreTraining(config)
    state_dict = model.state_dict()
    assert state_dict.keys() == pretrained_state_dict.keys()
    for k in state_dict.keys():
        assert state_dict[k].shape == pretrained_state_dict[k].shape


def get_hf_models(model_name, config, dtype):
    pretrained_state_dict = state_dict_from_pretrained(model_name)

    def key_mapping_ln_gamma_beta(key):
        key = re.sub(r'LayerNorm.gamma$', 'LayerNorm.weight', key)
        key = re.sub(r'LayerNorm.beta$', 'LayerNorm.bias', key)
        return key

    pretrained_state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v)
                                        for k, v in pretrained_state_dict.items())
    model_hf = BertForPreTrainingHF(config)
    # Missing key(s) in state_dict: "bert.embeddings.position_ids", "cls.predictions.decoder.bias"
    # position_ids is a buffer, and predictions.decoder.bias is tied to predictions.bias.
    model_hf.load_state_dict(pretrained_state_dict, strict=False)
    model_hf.cuda().to(dtype=dtype)
    return model_hf


@pytest.mark.parametrize('model_name', ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_non_optimized(model_name):
    """Check that our implementation of BERT (without any optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    config = BertConfig.from_pretrained(model_name)
    # Our implementation assumes the activation is nn.GELU(approximate='tanh')
    # Huggingface calls it "gelu_new" or "gelu_fast".
    config.hidden_act = "gelu_new"
    model = BertForPreTraining.from_pretrained(model_name, config)
    model = model.cuda().to(dtype=dtype)
    model_ref = get_hf_models(model_name, config, torch.float32)
    model_hf = get_hf_models(model_name, config, torch.float16)
    model.eval()
    model_ref.eval()
    model_hf.eval()
    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda')
    attention_mask = torch.arange(max_seqlen, device='cuda')[None, :] < seqlens[:, None]
    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
                              device='cuda')
    sequence_output, pooled_output = model.bert(input_ids, attention_mask=attention_mask)
    out_hf = model_hf.bert(input_ids, attention_mask=attention_mask)
    sequence_output_hf, pooled_output_hf = out_hf.last_hidden_state, out_hf.pooler_output
    out_ref = model_ref.bert(input_ids, attention_mask=attention_mask)
    sequence_output_ref, pooled_output_ref = out_ref.last_hidden_state, out_ref.pooler_output
    print(f'Output max diff: {(sequence_output - sequence_output_ref).abs().max().item()}')
    print(f'Output mean diff: {(sequence_output - sequence_output_ref).abs().mean().item()}')
    print(f'HF fp16 max diff: {(sequence_output_hf - sequence_output_ref).abs().max().item()}')
    print(f'HF fp16 mean diff: {(sequence_output_hf - sequence_output_ref).abs().mean().item()}')
    assert (sequence_output - sequence_output_ref).abs().max().item() < 2 * (sequence_output_hf - sequence_output_ref).abs().max().item()
    assert (pooled_output - pooled_output_ref).abs().max().item() < 2 * (pooled_output_hf - pooled_output_ref).abs().max().item()


@pytest.mark.parametrize('model_name', ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_optimized(model_name):
    """Check that our implementation of BERT (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    config = BertConfig.from_pretrained(model_name)
    # Our implementation assumes the activation is nn.GELU(approximate='tanh')
    # Huggingface calls it "gelu_new" or "gelu_fast".
    config.hidden_act = "gelu_new"
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_dense_gelu_dense = True
    config.fused_dropout_add_ln = True
    model = BertForPreTraining.from_pretrained(model_name, config)
    model = model.cuda().to(dtype=dtype)
    model_ref = get_hf_models(model_name, config, torch.float32)
    model_hf = get_hf_models(model_name, config, torch.float16)
    model.eval()
    model_ref.eval()
    model_hf.eval()
    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda')
    attention_mask = torch.arange(max_seqlen, device='cuda')[None, :] < seqlens[:, None]
    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
                              device='cuda')
    sequence_output, pooled_output = model.bert(input_ids, attention_mask=attention_mask)
    out_hf = model_hf.bert(input_ids, attention_mask=attention_mask)
    sequence_output_hf, pooled_output_hf = out_hf.last_hidden_state, out_hf.pooler_output
    # Need to zero out the padded tokens in the sequence before comparison.
    sequence_output_hf[~attention_mask, :] = 0.0
    out_ref = model_ref.bert(input_ids, attention_mask=attention_mask)
    sequence_output_ref, pooled_output_ref = out_ref.last_hidden_state, out_ref.pooler_output
    sequence_output_ref[~attention_mask, :] = 0.0
    print(f'BertModel output max diff: {(sequence_output - sequence_output_ref).abs().max().item()}')
    print(f'BertModel output mean diff: {(sequence_output - sequence_output_ref).abs().mean().item()}')
    print(f'HF fp16 BertModel max diff: {(sequence_output_hf - sequence_output_ref).abs().max().item()}')
    print(f'HF fp16 BertModel mean diff: {(sequence_output_hf - sequence_output_ref).abs().mean().item()}')
    assert (sequence_output - sequence_output_ref).abs().max().item() < 4 * (sequence_output_hf - sequence_output_ref).abs().max().item()
    assert (pooled_output - pooled_output_ref).abs().max().item() < 4 * (pooled_output_hf - pooled_output_ref).abs().max().item()
    prediction_scores, seq_relationship_scores = model(input_ids, attention_mask=attention_mask)
    # Need to zero out the padded tokens in the sequence before comparison.
    prediction_scores = prediction_scores.clone()
    prediction_scores[~attention_mask, :] = 0.0
    out_hf = model_hf(input_ids, attention_mask=attention_mask)
    prediction_scores_hf, seq_relationship_scores_hf = out_hf.prediction_logits, out_hf.seq_relationship_logits
    prediction_scores_hf[~attention_mask, :] = 0.0
    out_ref = model_ref(input_ids, attention_mask=attention_mask)
    prediction_scores_ref, seq_relationship_scores_ref = out_ref.prediction_logits, out_ref.seq_relationship_logits
    prediction_scores_ref[~attention_mask, :] = 0.0
    print(f'prediction_scores max diff: {(prediction_scores - prediction_scores_ref).abs().max().item()}')
    print(f'prediction_scores mean diff: {(prediction_scores - prediction_scores_ref).abs().mean().item()}')
    print(f'HF fp16 prediction_scores max diff: {(prediction_scores_hf - prediction_scores_ref).abs().max().item()}')
    print(f'HF fp16 prediction_scores mean diff: {(prediction_scores_hf - prediction_scores_ref).abs().mean().item()}')
    assert (prediction_scores - prediction_scores_ref).abs().max().item() < 2 * (prediction_scores_hf - prediction_scores_ref).abs().max().item()
    assert (seq_relationship_scores - seq_relationship_scores_ref).abs().max().item() < 2 * (seq_relationship_scores_hf - seq_relationship_scores_ref).abs().max().item()


@pytest.mark.parametrize('model_name', ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_dense_seq_output(model_name):
    """Check that our implementation of BERT (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    config = BertConfig.from_pretrained(model_name)
    # Our implementation assumes the activation is nn.GELU(approximate='tanh')
    # Huggingface calls it "gelu_new" or "gelu_fast".
    config.hidden_act = "gelu_new"
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_dense_gelu_dense = True
    config.fused_dropout_add_ln = True
    config.dense_seq_output = True
    config.use_xentropy = True
    model = BertForPreTraining.from_pretrained(model_name, config)
    model = model.cuda().to(dtype=dtype)
    model_ref = get_hf_models(model_name, config, torch.float32)
    model_hf = get_hf_models(model_name, config, torch.float16)
    model.eval()
    model_ref.eval()
    model_hf.eval()
    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda')
    attention_mask = torch.arange(max_seqlen, device='cuda')[None, :] < seqlens[:, None]
    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
                              device='cuda')
    labels = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
                           device='cuda')
    labels[(torch.rand(batch_size, max_seqlen, device='cuda') < 0.15) | ~attention_mask] = 0
    masked_tokens_mask = labels.flatten() > 0
    next_sequence_label = torch.randint(0, 2, (batch_size,), device='cuda')
    total_loss, prediction_scores, seq_relationship_scores, _, _ = model(
        input_ids, attention_mask=attention_mask,
        labels=labels, next_sentence_label=next_sequence_label
    )
    out_hf = model_hf(input_ids, attention_mask=attention_mask,
                      labels=labels, next_sentence_label=next_sequence_label)
    prediction_scores_hf, seq_relationship_scores_hf = out_hf.prediction_logits, out_hf.seq_relationship_logits
    prediction_scores_hf = rearrange(prediction_scores_hf, 'b s d -> (b s) d')[masked_tokens_mask]
    out_ref = model_ref(input_ids, attention_mask=attention_mask,
                        labels=labels, next_sentence_label=next_sequence_label)
    prediction_scores_ref, seq_relationship_scores_ref = out_ref.prediction_logits, out_ref.seq_relationship_logits
    prediction_scores_ref = rearrange(prediction_scores_ref, 'b s d -> (b s) d')[masked_tokens_mask]
    print(f'prediction_scores max diff: {(prediction_scores - prediction_scores_ref).abs().max().item()}')
    print(f'prediction_scores mean diff: {(prediction_scores - prediction_scores_ref).abs().mean().item()}')
    print(f'HF fp16 prediction_scores max diff: {(prediction_scores_hf - prediction_scores_ref).abs().max().item()}')
    print(f'HF fp16 prediction_scores mean diff: {(prediction_scores_hf - prediction_scores_ref).abs().mean().item()}')
    assert (prediction_scores - prediction_scores_ref).abs().max().item() < 2 * (prediction_scores_hf - prediction_scores_ref).abs().max().item()
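Since test_bert_state_dict verifies that remap_state_dict produces exactly the keys and shapes our BertForPreTraining expects, the remapped Huggingface weights can also be loaded manually; a short sketch of what the from_pretrained classmethod used in the tests is assumed to do under the hood:

from transformers import BertConfig
from flash_attn.models.bert import BertForPreTraining, remap_state_dict, state_dict_from_pretrained

config = BertConfig.from_pretrained('bert-base-uncased')
config.hidden_act = 'gelu_new'  # our implementation assumes the tanh-approximated GELU
model = BertForPreTraining(config)
model.load_state_dict(remap_state_dict(state_dict_from_pretrained('bert-base-uncased'), config))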