test_bert.py 13.8 KB
Newer Older
Tri Dao's avatar
Tri Dao committed
1
2
3
import re
from collections import OrderedDict

Tri Dao's avatar
Tri Dao committed
4
import pytest
Tri Dao's avatar
Tri Dao committed
5
6
7
8
import torch
import torch.nn.functional as F
from einops import rearrange
from transformers import BertConfig
Kevin Hu's avatar
Kevin Hu committed
9
from transformers.models.bert.modeling_bert import BertForPreTraining as BertForPreTrainingHF
Tri Dao's avatar
Tri Dao committed
10
from transformers.models.bert.modeling_bert import BertModel as BertModelHF
Tri Dao's avatar
Tri Dao committed
11

Kevin Hu's avatar
Kevin Hu committed
12
13
14
15
16
17
from flash_attn.models.bert import (
    BertForPreTraining,
    BertModel,
    inv_remap_state_dict,
    remap_state_dict,
)
Kevin Hu's avatar
Kevin Hu committed
18
19
from flash_attn.utils.pretrained import state_dict_from_pretrained

Tri Dao's avatar
Tri Dao committed
20

@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
def test_bert_state_dict(model_name):
    """Check that the remapped HF checkpoint lines up exactly with our model's parameters.

    The pretrained HF state dict is converted with ``remap_state_dict`` and compared against the
    state dict of a freshly constructed ``BertForPreTraining`` built from the same config: both
    must have the same key set, and every tensor must have the same shape.
    """
    config = BertConfig.from_pretrained(model_name)
    pretrained_state_dict = remap_state_dict(state_dict_from_pretrained(model_name), config)
    model = BertForPreTraining(config)
    state_dict = model.state_dict()
    # Compare as sets so that, on failure, pytest reports exactly which keys are missing/extra.
    assert set(state_dict.keys()) == set(pretrained_state_dict.keys())
    for k, v in state_dict.items():
        assert v.shape == pretrained_state_dict[k].shape, f"Shape mismatch for {k}"


def get_hf_models(model_name, config, dtype):
    """Build the Hugging Face ``BertForPreTraining`` reference model from a pretrained checkpoint.

    Args:
        model_name: checkpoint name passed to ``state_dict_from_pretrained``
            (e.g. ``"bert-base-uncased"``).
        config: ``BertConfig`` used to construct the HF model.
        dtype: dtype the model is cast to after loading (the tests use ``torch.float32`` for the
            reference and ``torch.float16`` for the HF fp16 baseline).

    Returns:
        A single HF model moved to CUDA and cast to ``dtype`` (despite the plural function name).
    """
    pretrained_state_dict = state_dict_from_pretrained(model_name)

    def key_mapping_ln_gamma_beta(key):
        # Rename LayerNorm "gamma"/"beta" parameter names to the "weight"/"bias" names the HF
        # module expects. The dot is escaped so only the literal "LayerNorm.gamma" /
        # "LayerNorm.beta" suffix is rewritten (an unescaped "." matches any character).
        key = re.sub(r"LayerNorm\.gamma$", "LayerNorm.weight", key)
        key = re.sub(r"LayerNorm\.beta$", "LayerNorm.bias", key)
        return key

    pretrained_state_dict = OrderedDict(
        (key_mapping_ln_gamma_beta(k), v) for k, v in pretrained_state_dict.items()
    )
    model_hf = BertForPreTrainingHF(config)
    # Missing key(s) in state_dict: "bert.embeddings.position_ids", "cls.predictions.decoder.bias"
    # position_ids is a buffer, and predictions.decoder.bias is tied to predictions.bias.
    model_hf.load_state_dict(pretrained_state_dict, strict=False)
    model_hf = model_hf.cuda().to(dtype=dtype)
    return model_hf


@pytest.mark.parametrize("model_name", ["bert-base-uncased"])
def test_bert_non_optimized(model_name):
    """Check that our implementation of BERT (without any optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    config = BertConfig.from_pretrained(model_name)

    model = BertForPreTraining.from_pretrained(model_name, config)
    model = model.cuda().to(dtype=dtype)

    # fp32 HF model is the ground truth; fp16 HF model sets the error budget we must stay within.
    model_ref = get_hf_models(model_name, config, torch.float32)
    model_hf = get_hf_models(model_name, config, dtype)

    model.eval()
    model_ref.eval()
    model_hf.eval()

    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
    # Per-sequence lengths drawn from [max_seqlen // 2, max_seqlen]; positions at or beyond a
    # sequence's length are False in attention_mask.
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
    attention_mask = torch.arange(max_seqlen, device="cuda")[None, :] < seqlens[:, None]
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
    )

    out = model.bert(input_ids, attention_mask=attention_mask)
    sequence_output, pooled_output = out.last_hidden_state, out.pooler_output
    out_hf = model_hf.bert(input_ids, attention_mask=attention_mask)
    sequence_output_hf, pooled_output_hf = out_hf.last_hidden_state, out_hf.pooler_output
    out_ref = model_ref.bert(input_ids, attention_mask=attention_mask)
    sequence_output_ref, pooled_output_ref = out_ref.last_hidden_state, out_ref.pooler_output

    # Compute each absolute-error tensor once and reuse it for the prints and the asserts.
    seq_diff = (sequence_output - sequence_output_ref).abs()
    seq_diff_hf = (sequence_output_hf - sequence_output_ref).abs()
    print(f"Output max diff: {seq_diff.max().item()}")
    print(f"Output mean diff: {seq_diff.mean().item()}")
    print(f"HF fp16 max diff: {seq_diff_hf.max().item()}")
    print(f"HF fp16 mean diff: {seq_diff_hf.mean().item()}")
    assert seq_diff.max().item() < 3 * seq_diff_hf.max().item()
    assert (pooled_output - pooled_output_ref).abs().max().item() < 3 * (
        pooled_output_hf - pooled_output_ref
    ).abs().max().item()


@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
def test_bert_optimized(model_name):
    """Check that our implementation of BERT (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.

    Both the encoder outputs (``model.bert``) and the pretraining-head outputs (``model``) are
    checked. Padded positions are zeroed in the HF outputs and in our prediction scores before
    comparison.
    """
    dtype = torch.float16
    config = BertConfig.from_pretrained(model_name)
    # Our implementation of fused_mlp assumes the activation is
    # nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new", "gelu_fast", or "gelu_pytorch_tanh".
    # If you just want "gelu", disable fused_mlp.
    config.hidden_act = "gelu_new"
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = True

    model = BertForPreTraining.from_pretrained(model_name, config)
    model = model.cuda().to(dtype=dtype)

    # fp32 HF model is the ground truth; fp16 HF model sets the error budget we must stay within.
    model_ref = get_hf_models(model_name, config, torch.float32)
    model_hf = get_hf_models(model_name, config, dtype)

    model.eval()
    model_ref.eval()
    model_hf.eval()

    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
    attention_mask = torch.arange(max_seqlen, device="cuda")[None, :] < seqlens[:, None]
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
    )
    # True at padded positions; used to zero those positions before comparison.
    padding_mask = ~attention_mask

    # --- Encoder outputs -------------------------------------------------------------------
    out = model.bert(input_ids, attention_mask=attention_mask)
    sequence_output, pooled_output = out.last_hidden_state, out.pooler_output
    out_hf = model_hf.bert(input_ids, attention_mask=attention_mask)
    sequence_output_hf, pooled_output_hf = out_hf.last_hidden_state, out_hf.pooler_output
    # Need to zero out the padded tokens in the sequence before comparison.
    # NOTE(review): our sequence_output is not zeroed here; presumably the flash-attn
    # unpad/pad path already leaves zeros at padded positions — confirm.
    sequence_output_hf[padding_mask, :] = 0.0
    out_ref = model_ref.bert(input_ids, attention_mask=attention_mask)
    sequence_output_ref, pooled_output_ref = out_ref.last_hidden_state, out_ref.pooler_output
    sequence_output_ref[padding_mask, :] = 0.0

    seq_diff = (sequence_output - sequence_output_ref).abs()
    seq_diff_hf = (sequence_output_hf - sequence_output_ref).abs()
    print(f"BertModel output max diff: {seq_diff.max().item()}")
    print(f"BertModel output mean diff: {seq_diff.mean().item()}")
    print(f"HF fp16 BertModel max diff: {seq_diff_hf.max().item()}")
    print(f"HF fp16 BertModel mean diff: {seq_diff_hf.mean().item()}")
    assert seq_diff.max().item() < 4 * seq_diff_hf.max().item()
    assert (pooled_output - pooled_output_ref).abs().max().item() < 4 * (
        pooled_output_hf - pooled_output_ref
    ).abs().max().item()

    # --- Pretraining-head outputs ----------------------------------------------------------
    out = model(input_ids, attention_mask=attention_mask)
    prediction_scores, seq_relationship_scores = out.prediction_logits, out.seq_relationship_logits
    # Need to zero out the padded tokens in the sequence before comparison.
    prediction_scores = prediction_scores.clone()
    prediction_scores[padding_mask, :] = 0.0
    out_hf = model_hf(input_ids, attention_mask=attention_mask)
    prediction_scores_hf, seq_relationship_scores_hf = (
        out_hf.prediction_logits,
        out_hf.seq_relationship_logits,
    )
    prediction_scores_hf[padding_mask, :] = 0.0
    out_ref = model_ref(input_ids, attention_mask=attention_mask)
    prediction_scores_ref, seq_relationship_scores_ref = (
        out_ref.prediction_logits,
        out_ref.seq_relationship_logits,
    )
    prediction_scores_ref[padding_mask, :] = 0.0

    pred_diff = (prediction_scores - prediction_scores_ref).abs()
    pred_diff_hf = (prediction_scores_hf - prediction_scores_ref).abs()
    print(f"prediction_scores max diff: {pred_diff.max().item()}")
    print(f"prediction_scores mean diff: {pred_diff.mean().item()}")
    print(f"HF fp16 prediction_scores max diff: {pred_diff_hf.max().item()}")
    print(f"HF fp16 prediction_scores mean diff: {pred_diff_hf.mean().item()}")
    assert pred_diff.max().item() < 2 * pred_diff_hf.max().item()
    assert (seq_relationship_scores - seq_relationship_scores_ref).abs().max().item() < 2 * (
        seq_relationship_scores_hf - seq_relationship_scores_ref
    ).abs().max().item()


@pytest.mark.parametrize("last_layer_subset", [False, True])
@pytest.mark.parametrize("has_key_padding_mask", [True, False])
@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
def test_bert_dense_seq_output(model_name, has_key_padding_mask, last_layer_subset):
    """Check our fully optimized BERT with ``dense_seq_output`` against the HF implementation.

    On top of the optimizations in ``test_bert_optimized``, this enables ``dense_seq_output``,
    ``use_xentropy`` and (optionally) ``last_layer_subset``, and runs with labels. Our
    ``prediction_logits`` are compared directly against the HF logits gathered at the positions
    where ``labels > 0``, i.e. only the masked tokens are compared. As in the other tests, our
    fp16 error against the HF fp32 reference must stay within a multiple of the HF fp16 error.
    """
    dtype = torch.float16
    config = BertConfig.from_pretrained(model_name)
    # Our implementation of fused_mlp assumes the activation is
    # nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new", "gelu_fast", or "gelu_pytorch_tanh".
    # If you just want "gelu", disable fused_mlp.
    config.hidden_act = "gelu_new"
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = True
    config.dense_seq_output = True
    config.last_layer_subset = last_layer_subset
    config.use_xentropy = True

    model = BertForPreTraining.from_pretrained(model_name, config)
    model = model.cuda().to(dtype=dtype)

    # fp32 HF model is the ground truth; fp16 HF model sets the error budget we must stay within.
    model_ref = get_hf_models(model_name, config, torch.float32)
    model_hf = get_hf_models(model_name, config, dtype)

    model.eval()
    model_ref.eval()
    model_hf.eval()

    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
    # seqlens is drawn even when no padding mask is used, so the RNG stream (and therefore the
    # inputs below) is the same for both values of has_key_padding_mask.
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
    if has_key_padding_mask:
        attention_mask = torch.arange(max_seqlen, device="cuda")[None, :] < seqlens[:, None]
    else:
        attention_mask = None
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
    )
    labels = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
    )
    # A label of 0 marks a position that is not a prediction target: clear padded positions,
    # then keep only ~15% of the remaining positions as targets.
    if attention_mask is not None:
        labels[~attention_mask] = 0
    labels[(torch.rand(batch_size, max_seqlen, device="cuda") > 0.15)] = 0
    masked_tokens_mask = labels.flatten() > 0
    next_sentence_label = torch.randint(0, 2, (batch_size,), device="cuda")

    out = model(
        input_ids,
        attention_mask=attention_mask,
        labels=labels,
        next_sentence_label=next_sentence_label,
    )
    prediction_scores, seq_relationship_scores = out.prediction_logits, out.seq_relationship_logits
    out_hf = model_hf(
        input_ids,
        attention_mask=attention_mask,
        labels=labels,
        next_sentence_label=next_sentence_label,
    )
    prediction_scores_hf, seq_relationship_scores_hf = (
        out_hf.prediction_logits,
        out_hf.seq_relationship_logits,
    )
    # Flatten (batch, seq) and keep only the masked-token rows so HF matches our dense output.
    prediction_scores_hf = rearrange(prediction_scores_hf, "b s d -> (b s) d")[masked_tokens_mask]
    out_ref = model_ref(
        input_ids,
        attention_mask=attention_mask,
        labels=labels,
        next_sentence_label=next_sentence_label,
    )
    prediction_scores_ref, seq_relationship_scores_ref = (
        out_ref.prediction_logits,
        out_ref.seq_relationship_logits,
    )
    prediction_scores_ref = rearrange(prediction_scores_ref, "b s d -> (b s) d")[masked_tokens_mask]

    pred_diff = (prediction_scores - prediction_scores_ref).abs()
    pred_diff_hf = (prediction_scores_hf - prediction_scores_ref).abs()
    print(f"prediction_scores max diff: {pred_diff.max().item()}")
    print(f"prediction_scores mean diff: {pred_diff.mean().item()}")
    print(f"HF fp16 prediction_scores max diff: {pred_diff_hf.max().item()}")
    print(f"HF fp16 prediction_scores mean diff: {pred_diff_hf.mean().item()}")
    assert pred_diff.max().item() < 2 * pred_diff_hf.max().item()
    assert (seq_relationship_scores - seq_relationship_scores_ref).abs().max().item() < 2 * (
        seq_relationship_scores_hf - seq_relationship_scores_ref
    ).abs().max().item()
    # The loss is intentionally not compared: the loss calculation from HF is wrong for this
    # setup because it doesn't ignore the labels that are 0, so out.loss and out_ref.loss differ.


@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
def test_inv_remap_state_dict(model_name: str):
    """
    Round-trip a pretrained HF BERT checkpoint through the flash_attn format and back.

    ``remap_state_dict`` converts the HF state dict to flash_attn's layout, and
    ``inv_remap_state_dict`` converts it back. The recovered dict must have the same keys as the
    original checkpoint, the same tensor shapes, and values equal within rtol=atol=1e-6.
    """
    hf_state_dict = state_dict_from_pretrained(model_name)
    bert_config = BertConfig.from_pretrained(model_name)

    recovered = inv_remap_state_dict(remap_state_dict(hf_state_dict, bert_config), bert_config)

    assert set(hf_state_dict.keys()) == set(recovered.keys())

    for name, original_tensor in hf_state_dict.items():
        restored_tensor = recovered[name]
        assert original_tensor.shape == restored_tensor.shape
        torch.testing.assert_close(original_tensor, restored_tensor, rtol=1e-6, atol=1e-6)