test_bert.py 12.8 KB
Newer Older
Tri Dao's avatar
Tri Dao committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import re
from collections import OrderedDict

import torch
import torch.nn.functional as F
import pytest

from einops import rearrange

from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertModel as BertModelHF
from transformers.models.bert.modeling_bert import BertForPreTraining as BertForPreTrainingHF

from flash_attn.models.bert import BertModel, BertForPreTraining
from flash_attn.models.bert import remap_state_dict
16
from flash_attn.utils.pretrained import state_dict_from_pretrained
Tri Dao's avatar
Tri Dao committed
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60


@pytest.mark.parametrize('model_name', ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_state_dict(model_name):
    """Check that the remapped pretrained checkpoint lines up with our model's parameters.

    After ``remap_state_dict``, the HF checkpoint must provide exactly the same set of keys as
    ``BertForPreTraining(config).state_dict()``, and every tensor must have the same shape as
    the model tensor of the same name.
    """
    config = BertConfig.from_pretrained(model_name)
    raw_checkpoint = state_dict_from_pretrained(model_name)
    remapped = remap_state_dict(raw_checkpoint, config)
    ours = BertForPreTraining(config).state_dict()
    assert ours.keys() == remapped.keys()
    for name, tensor in ours.items():
        assert tensor.shape == remapped[name].shape


def get_hf_models(model_name, config, dtype):
    """Build an HF ``BertForPreTraining`` loaded with the pretrained weights of ``model_name``.

    Args:
        model_name: checkpoint name passed to ``state_dict_from_pretrained``.
        config: ``BertConfig`` used to construct the HF model.
        dtype: dtype the model is cast to after being moved to CUDA.

    Returns:
        The HF model on CUDA, cast to ``dtype``.

    Raises:
        AssertionError: if the checkpoint leaves any parameter uninitialized other than the
            known-benign keys listed below.
    """
    pretrained_state_dict = state_dict_from_pretrained(model_name)

    def key_mapping_ln_gamma_beta(key):
        # The checkpoint names LayerNorm parameters gamma/beta; the HF modules expect weight/bias.
        # Dots are escaped so only a literal "LayerNorm.gamma" / "LayerNorm.beta" suffix matches.
        key = re.sub(r'LayerNorm\.gamma$', 'LayerNorm.weight', key)
        key = re.sub(r'LayerNorm\.beta$', 'LayerNorm.bias', key)
        return key

    pretrained_state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v)
                                        for k, v in pretrained_state_dict.items())
    model_hf = BertForPreTrainingHF(config)
    # Missing key(s) in state_dict: "bert.embeddings.position_ids", "cls.predictions.decoder.bias"
    # position_ids is a buffer, and predictions.decoder.bias is tied to predictions.bias.
    # strict=False tolerates those, but anything else missing means the key mapping is broken
    # and the HF reference models would run with random weights, so check explicitly.
    load_result = model_hf.load_state_dict(pretrained_state_dict, strict=False)
    allowed_missing_suffixes = ('position_ids', 'token_type_ids', 'predictions.decoder.bias')
    unexpected_missing = [k for k in load_result.missing_keys
                          if not k.endswith(allowed_missing_suffixes)]
    assert not unexpected_missing, f'Pretrained weights missing for: {unexpected_missing}'
    model_hf.cuda().to(dtype=dtype)
    return model_hf


@pytest.mark.parametrize('model_name', ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_non_optimized(model_name):
    """Compare our BERT (no optimizations enabled) against the HF implementation.

    The HF model in fp32 is the reference. Our fp16 forward pass must stay within 3x the error
    that the HF model itself shows in fp16, for both the sequence output and the pooled output.
    """
    dtype = torch.float16
    config = BertConfig.from_pretrained(model_name)

    model = BertForPreTraining.from_pretrained(model_name, config).cuda().to(dtype=dtype)
    model_ref = get_hf_models(model_name, config, torch.float32)
    model_hf = get_hf_models(model_name, config, dtype)
    for module in (model, model_ref, model_hf):
        module.eval()

    torch.manual_seed(0)
    batch_size, max_seqlen = 4, 512
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda')
    positions = torch.arange(max_seqlen, device='cuda')
    attention_mask = positions[None, :] < seqlens[:, None]
    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
                              device='cuda')

    # Forward passes, in the same order: ours, HF fp16, HF fp32 reference.
    out = model.bert(input_ids, attention_mask=attention_mask)
    out_hf = model_hf.bert(input_ids, attention_mask=attention_mask)
    out_ref = model_ref.bert(input_ids, attention_mask=attention_mask)

    seq_err = (out.last_hidden_state - out_ref.last_hidden_state).abs()
    seq_err_hf = (out_hf.last_hidden_state - out_ref.last_hidden_state).abs()
    print(f'Output max diff: {seq_err.max().item()}')
    print(f'Output mean diff: {seq_err.mean().item()}')
    print(f'HF fp16 max diff: {seq_err_hf.max().item()}')
    print(f'HF fp16 mean diff: {seq_err_hf.mean().item()}')

    pooled_err = (out.pooler_output - out_ref.pooler_output).abs()
    pooled_err_hf = (out_hf.pooler_output - out_ref.pooler_output).abs()
    # Our fp16 error must be within 3x of HF's own fp16 error.
    assert seq_err.max().item() < 3 * seq_err_hf.max().item()
    assert pooled_err.max().item() < 3 * pooled_err_hf.max().item()
Tri Dao's avatar
Tri Dao committed
87
88
89
90
91
92
93
94
95
96
97


@pytest.mark.parametrize('model_name', ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_optimized(model_name):
    """Check that our implementation of BERT (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.

    Two levels are checked: the encoder outputs (``model.bert``: sequence and pooled output)
    and the full pretraining heads (``model``: prediction and seq-relationship logits).
    """
    dtype = torch.float16
    config = BertConfig.from_pretrained(model_name)
    # Our implementation of fused_dense_gelu_dense assumes the activation is
    # nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new" or "gelu_fast".
    # If you just want "gelu", disable fused_dense_gelu_dense.
    config.hidden_act = "gelu_new"
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_dense_gelu_dense = True
    config.fused_dropout_add_ln = True

    model = BertForPreTraining.from_pretrained(model_name, config)
    model = model.cuda().to(dtype=dtype)

    model_ref = get_hf_models(model_name, config, torch.float32)
    model_hf = get_hf_models(model_name, config, dtype)

    model.eval()
    model_ref.eval()
    model_hf.eval()

    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda')
    attention_mask = torch.arange(max_seqlen, device='cuda')[None, :] < seqlens[:, None]
    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
                              device='cuda')

    out = model.bert(input_ids, attention_mask=attention_mask)
    sequence_output, pooled_output = out.last_hidden_state, out.pooler_output
    out_hf = model_hf.bert(input_ids, attention_mask=attention_mask)
    sequence_output_hf, pooled_output_hf = out_hf.last_hidden_state, out_hf.pooler_output
    # Need to zero out the padded tokens in the sequence before comparison.
    # NOTE(review): our own sequence_output is not zeroed here, which presumes our model already
    # returns zeros at padded positions — confirm if this assertion starts failing.
    sequence_output_hf[~attention_mask, :] = 0.0
    out_ref = model_ref.bert(input_ids, attention_mask=attention_mask)
    sequence_output_ref, pooled_output_ref = out_ref.last_hidden_state, out_ref.pooler_output
    sequence_output_ref[~attention_mask, :] = 0.0

    print(f'BertModel output max diff: {(sequence_output - sequence_output_ref).abs().max().item()}')
    print(f'BertModel output mean diff: {(sequence_output - sequence_output_ref).abs().mean().item()}')
    print(f'HF fp16 BertModel max diff: {(sequence_output_hf - sequence_output_ref).abs().max().item()}')
    print(f'HF fp16 BertModel mean diff: {(sequence_output_hf - sequence_output_ref).abs().mean().item()}')
    assert (sequence_output - sequence_output_ref).abs().max().item() < 4 * (sequence_output_hf - sequence_output_ref).abs().max().item()
    assert (pooled_output - pooled_output_ref).abs().max().item() < 4 * (pooled_output_hf - pooled_output_ref).abs().max().item()

    out = model(input_ids, attention_mask=attention_mask)
    prediction_scores, seq_relationship_scores = out.prediction_logits, out.seq_relationship_logits
    # Need to zero out the padded tokens in the sequence before comparison.
    prediction_scores = prediction_scores.clone()
    prediction_scores[~attention_mask, :] = 0.0
    out_hf = model_hf(input_ids, attention_mask=attention_mask)
    prediction_scores_hf, seq_relationship_scores_hf = out_hf.prediction_logits, out_hf.seq_relationship_logits
    prediction_scores_hf[~attention_mask, :] = 0.0
    out_ref = model_ref(input_ids, attention_mask=attention_mask)
    prediction_scores_ref, seq_relationship_scores_ref = out_ref.prediction_logits, out_ref.seq_relationship_logits
    prediction_scores_ref[~attention_mask, :] = 0.0

    print(f'prediction_scores max diff: {(prediction_scores - prediction_scores_ref).abs().max().item()}')
    print(f'prediction_scores mean diff: {(prediction_scores - prediction_scores_ref).abs().mean().item()}')
    print(f'HF fp16 prediction_scores max diff: {(prediction_scores_hf - prediction_scores_ref).abs().max().item()}')
    print(f'HF fp16 prediction_scores mean diff: {(prediction_scores_hf - prediction_scores_ref).abs().mean().item()}')
    assert (prediction_scores - prediction_scores_ref).abs().max().item() < 2 * (prediction_scores_hf - prediction_scores_ref).abs().max().item()
    assert (seq_relationship_scores - seq_relationship_scores_ref).abs().max().item() < 2 * (seq_relationship_scores_hf - seq_relationship_scores_ref).abs().max().item()


161
162
163
164
@pytest.mark.parametrize('last_layer_subset', [False, True])
# @pytest.mark.parametrize('last_layer_subset', [True])
@pytest.mark.parametrize('has_key_padding_mask', [True, False])
# @pytest.mark.parametrize('has_key_padding_mask', [True])
@pytest.mark.parametrize('model_name', ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_dense_seq_output(model_name, has_key_padding_mask, last_layer_subset):
    """Check that our implementation of BERT (all optimizations plus ``dense_seq_output``)
    matches the HF implementation: the output of our forward pass in fp16 should be around the
    same as the HF forward pass in fp16, when compared to the HF forward pass in fp32.

    In this test a label of 0 marks an unmasked token. Our ``prediction_logits`` are compared
    directly against the HF logits flattened to ``(batch * seqlen, vocab)`` and gathered at the
    positions whose label is > 0, so only masked tokens are checked.

    ``last_layer_subset`` is forwarded to the config. ``has_key_padding_mask=False`` runs the
    models without an attention mask.
    """
    dtype = torch.float16
    config = BertConfig.from_pretrained(model_name)
    # Our implementation of fused_dense_gelu_dense assumes the activation is
    # nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new" or "gelu_fast".
    # If you just want "gelu", disable fused_dense_gelu_dense.
    config.hidden_act = "gelu_new"
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_dense_gelu_dense = True
    config.fused_dropout_add_ln = True
    config.dense_seq_output = True
    config.last_layer_subset = last_layer_subset
    config.use_xentropy = True

    model = BertForPreTraining.from_pretrained(model_name, config)
    model = model.cuda().to(dtype=dtype)

    model_ref = get_hf_models(model_name, config, torch.float32)
    model_hf = get_hf_models(model_name, config, dtype)

    model.eval()
    model_ref.eval()
    model_hf.eval()

    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda')
    if has_key_padding_mask:
        attention_mask = torch.arange(max_seqlen, device='cuda')[None, :] < seqlens[:, None]
    else:
        attention_mask = None
    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
                              device='cuda')
    labels = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
                           device='cuda')
    # Label 0 means "not masked": padding is never masked, and roughly 85% of the remaining
    # positions are unmasked as well.
    if attention_mask is not None:
        labels[~attention_mask] = 0
    labels[(torch.rand(batch_size, max_seqlen, device='cuda') > 0.15)] = 0
    masked_tokens_mask = labels.flatten() > 0
    next_sentence_label = torch.randint(0, 2, (batch_size,), device='cuda')

    out = model(
        input_ids, attention_mask=attention_mask,
        labels=labels, next_sentence_label=next_sentence_label
    )
    prediction_scores, seq_relationship_scores = out.prediction_logits, out.seq_relationship_logits
    out_hf = model_hf(input_ids, attention_mask=attention_mask,
                      labels=labels, next_sentence_label=next_sentence_label)
    prediction_scores_hf, seq_relationship_scores_hf = out_hf.prediction_logits, out_hf.seq_relationship_logits
    # Gather the HF logits at the masked positions so they line up with our dense outputs.
    prediction_scores_hf = rearrange(prediction_scores_hf, 'b s d -> (b s) d')[masked_tokens_mask]
    out_ref = model_ref(input_ids, attention_mask=attention_mask,
                        labels=labels, next_sentence_label=next_sentence_label)
    prediction_scores_ref, seq_relationship_scores_ref = out_ref.prediction_logits, out_ref.seq_relationship_logits
    prediction_scores_ref = rearrange(prediction_scores_ref, 'b s d -> (b s) d')[masked_tokens_mask]

    print(f'prediction_scores max diff: {(prediction_scores - prediction_scores_ref).abs().max().item()}')
    print(f'prediction_scores mean diff: {(prediction_scores - prediction_scores_ref).abs().mean().item()}')
    print(f'HF fp16 prediction_scores max diff: {(prediction_scores_hf - prediction_scores_ref).abs().max().item()}')
    print(f'HF fp16 prediction_scores mean diff: {(prediction_scores_hf - prediction_scores_ref).abs().mean().item()}')
    assert (prediction_scores - prediction_scores_ref).abs().max().item() < 2 * (prediction_scores_hf - prediction_scores_ref).abs().max().item()
    assert (seq_relationship_scores - seq_relationship_scores_ref).abs().max().item() < 2 * (seq_relationship_scores_hf - seq_relationship_scores_ref).abs().max().item()
    # The loss calculation from HF is wrong: it doesn't ignore the labels that are 0.
    # assert (out.loss - out_ref.loss).abs().max().item() < 2 * (out_hf.loss - out_ref.loss).abs().max().item()