import re
from collections import OrderedDict

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange
from flash_attn.models.bert import BertForPreTraining, BertModel, remap_state_dict
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertForPreTraining as BertForPreTrainingHF
from transformers.models.bert.modeling_bert import BertModel as BertModelHF


@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_state_dict(model_name):
    config = BertConfig.from_pretrained(model_name)
    pretrained_state_dict = remap_state_dict(state_dict_from_pretrained(model_name), config)
    model = BertForPreTraining(config)
    state_dict = model.state_dict()
    assert state_dict.keys() == pretrained_state_dict.keys()
    for k in state_dict.keys():
        assert state_dict[k].shape == pretrained_state_dict[k].shape


def get_hf_models(model_name, config, dtype):
    pretrained_state_dict = state_dict_from_pretrained(model_name)
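    # Older BERT checkpoints name the LayerNorm parameters "gamma"/"beta"; current HF
    # modules expect "weight"/"bias", so remap those keys before loading.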

    def key_mapping_ln_gamma_beta(key):
        key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
        key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key)
        return key

    pretrained_state_dict = OrderedDict(
        (key_mapping_ln_gamma_beta(k), v) for k, v in pretrained_state_dict.items()
    )
    model_hf = BertForPreTrainingHF(config)
    # Missing key(s) in state_dict: "bert.embeddings.position_ids", "cls.predictions.decoder.bias"
    # position_ids is a buffer, and predictions.decoder.bias is tied to predictions.bias.
    model_hf.load_state_dict(pretrained_state_dict, strict=False)
    model_hf.cuda().to(dtype=dtype)
    return model_hf


@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_non_optimized(model_name):
    """Check that our implementation of BERT (without any optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    config = BertConfig.from_pretrained(model_name)

    model = BertForPreTraining.from_pretrained(model_name, config)
    model = model.cuda().to(dtype=dtype)

    model_ref = get_hf_models(model_name, config, torch.float32)
    model_hf = get_hf_models(model_name, config, dtype)

    model.eval()
    model_ref.eval()
    model_hf.eval()

    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
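    # Build a padded batch: each sequence gets a random length in [max_seqlen // 2, max_seqlen],
    # and attention_mask marks the valid (non-padded) positions.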
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
    attention_mask = torch.arange(max_seqlen, device="cuda")[None, :] < seqlens[:, None]
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
    )
    out = model.bert(input_ids, attention_mask=attention_mask)
    sequence_output, pooled_output = out.last_hidden_state, out.pooler_output
    out_hf = model_hf.bert(input_ids, attention_mask=attention_mask)
    sequence_output_hf, pooled_output_hf = out_hf.last_hidden_state, out_hf.pooler_output
    out_ref = model_ref.bert(input_ids, attention_mask=attention_mask)
    sequence_output_ref, pooled_output_ref = out_ref.last_hidden_state, out_ref.pooler_output

    print(f"Output max diff: {(sequence_output - sequence_output_ref).abs().max().item()}")
    print(f"Output mean diff: {(sequence_output - sequence_output_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(sequence_output_hf - sequence_output_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(sequence_output_hf - sequence_output_ref).abs().mean().item()}")
    assert (sequence_output - sequence_output_ref).abs().max().item() < 3 * (
        sequence_output_hf - sequence_output_ref
    ).abs().max().item()
    assert (pooled_output - pooled_output_ref).abs().max().item() < 3 * (
        pooled_output_hf - pooled_output_ref
    ).abs().max().item()


@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_optimized(model_name):
    """Check that our implementation of BERT (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    config = BertConfig.from_pretrained(model_name)
    # Our implementation of fused_mlp assumes the activation is
    # nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new" or "gelu_fast".
    # If you just want "gelu", disable fused_mlp.
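    # (gelu_new/gelu_fast use the tanh approximation
    #  0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3))),
    #  which differs slightly from the exact erf-based "gelu".)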
    config.hidden_act = "gelu_new"
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = True

    model = BertForPreTraining.from_pretrained(model_name, config)
    model = model.cuda().to(dtype=dtype)

    model_ref = get_hf_models(model_name, config, torch.float32)
    model_hf = get_hf_models(model_name, config, dtype)

    model.eval()
    model_ref.eval()
    model_hf.eval()

    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
    attention_mask = torch.arange(max_seqlen, device="cuda")[None, :] < seqlens[:, None]
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
    )
    out = model.bert(input_ids, attention_mask=attention_mask)
    sequence_output, pooled_output = out.last_hidden_state, out.pooler_output
    out_hf = model_hf.bert(input_ids, attention_mask=attention_mask)
    sequence_output_hf, pooled_output_hf = out_hf.last_hidden_state, out_hf.pooler_output
    # Need to zero out the padded tokens in the sequence before comparison.
    sequence_output_hf[~attention_mask, :] = 0.0
    out_ref = model_ref.bert(input_ids, attention_mask=attention_mask)
    sequence_output_ref, pooled_output_ref = out_ref.last_hidden_state, out_ref.pooler_output
    sequence_output_ref[~attention_mask, :] = 0.0

    print(
        f"BertModel output max diff: {(sequence_output - sequence_output_ref).abs().max().item()}"
    )
    print(
        f"BertModel output mean diff: {(sequence_output - sequence_output_ref).abs().mean().item()}"
    )
    print(
        f"HF fp16 BertModel max diff: {(sequence_output_hf - sequence_output_ref).abs().max().item()}"
    )
    print(
        f"HF fp16 BertModel mean diff: {(sequence_output_hf - sequence_output_ref).abs().mean().item()}"
    )
    assert (sequence_output - sequence_output_ref).abs().max().item() < 4 * (
        sequence_output_hf - sequence_output_ref
    ).abs().max().item()
    assert (pooled_output - pooled_output_ref).abs().max().item() < 4 * (
        pooled_output_hf - pooled_output_ref
    ).abs().max().item()

    out = model(input_ids, attention_mask=attention_mask)
    prediction_scores, seq_relationship_scores = out.prediction_logits, out.seq_relationship_logits
    # Need to zero out the padded tokens in the sequence before comparison.
    prediction_scores = prediction_scores.clone()
    prediction_scores[~attention_mask, :] = 0.0
    out_hf = model_hf(input_ids, attention_mask=attention_mask)
    prediction_scores_hf, seq_relationship_scores_hf = (
        out_hf.prediction_logits,
        out_hf.seq_relationship_logits,
    )
    prediction_scores_hf[~attention_mask, :] = 0.0
    out_ref = model_ref(input_ids, attention_mask=attention_mask)
    prediction_scores_ref, seq_relationship_scores_ref = (
        out_ref.prediction_logits,
        out_ref.seq_relationship_logits,
    )
    prediction_scores_ref[~attention_mask, :] = 0.0

    print(
        f"prediction_scores max diff: {(prediction_scores - prediction_scores_ref).abs().max().item()}"
    )
    print(
        f"prediction_scores mean diff: {(prediction_scores - prediction_scores_ref).abs().mean().item()}"
    )
    print(
        f"HF fp16 prediction_scores max diff: {(prediction_scores_hf - prediction_scores_ref).abs().max().item()}"
    )
    print(
        f"HF fp16 prediction_scores mean diff: {(prediction_scores_hf - prediction_scores_ref).abs().mean().item()}"
    )
    assert (prediction_scores - prediction_scores_ref).abs().max().item() < 2 * (
        prediction_scores_hf - prediction_scores_ref
    ).abs().max().item()
    assert (seq_relationship_scores - seq_relationship_scores_ref).abs().max().item() < 2 * (
        seq_relationship_scores_hf - seq_relationship_scores_ref
    ).abs().max().item()


@pytest.mark.parametrize("last_layer_subset", [False, True])
# @pytest.mark.parametrize('last_layer_subset', [True])
@pytest.mark.parametrize("has_key_padding_mask", [True, False])
# @pytest.mark.parametrize('has_key_padding_mask', [True])
@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_dense_seq_output(model_name, has_key_padding_mask, last_layer_subset):
    """Check that our implementation of BERT (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    config = BertConfig.from_pretrained(model_name)
    # Our implementation of fused_mlp assumes the activation is
    # nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new" or "gelu_fast".
    # If you just want "gelu", disable fused_mlp.
    config.hidden_act = "gelu_new"
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = True
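    # dense_seq_output / last_layer_subset are the MLPerf-style BERT optimizations:
    # MLM logits are computed only at masked (nonzero-label) positions, and with
    # last_layer_subset the last layer is run only for those tokens.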
    config.dense_seq_output = True
    config.last_layer_subset = last_layer_subset
    config.use_xentropy = True

    model = BertForPreTraining.from_pretrained(model_name, config)
    model = model.cuda().to(dtype=dtype)

    model_ref = get_hf_models(model_name, config, torch.float32)
    model_hf = get_hf_models(model_name, config, dtype)

    model.eval()
    model_ref.eval()
    model_hf.eval()

    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
    if has_key_padding_mask:
        attention_mask = torch.arange(max_seqlen, device="cuda")[None, :] < seqlens[:, None]
    else:
        attention_mask = None
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
    )
    labels = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
    )
    if attention_mask is not None:
        labels[~attention_mask] = 0
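    # Keep roughly 15% of positions as MLM targets; zero out the rest (label 0 means "not predicted").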
    labels[(torch.rand(batch_size, max_seqlen, device="cuda") > 0.15)] = 0
    masked_tokens_mask = labels.flatten() > 0
    next_sequence_label = torch.randint(0, 2, (batch_size,), device="cuda")

    out = model(
        input_ids,
        attention_mask=attention_mask,
        labels=labels,
        next_sentence_label=next_sequence_label,
    )
    prediction_scores, seq_relationship_scores = out.prediction_logits, out.seq_relationship_logits
    out_hf = model_hf(
        input_ids,
        attention_mask=attention_mask,
        labels=labels,
        next_sentence_label=next_sequence_label,
    )
    prediction_scores_hf, seq_relationship_scores_hf = (
        out_hf.prediction_logits,
        out_hf.seq_relationship_logits,
    )
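    # With dense_seq_output, our model returns MLM logits only for the masked positions,
    # so gather the same subset from HF's full (batch, seqlen, vocab) logits before comparing.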
    prediction_scores_hf = rearrange(prediction_scores_hf, "b s d -> (b s) d")[masked_tokens_mask]
    out_ref = model_ref(
        input_ids,
        attention_mask=attention_mask,
        labels=labels,
        next_sentence_label=next_sequence_label,
    )
    prediction_scores_ref, seq_relationship_scores_ref = (
        out_ref.prediction_logits,
        out_ref.seq_relationship_logits,
    )
    prediction_scores_ref = rearrange(prediction_scores_ref, "b s d -> (b s) d")[masked_tokens_mask]

    print(
        f"prediction_scores max diff: {(prediction_scores - prediction_scores_ref).abs().max().item()}"
    )
    print(
        f"prediction_scores mean diff: {(prediction_scores - prediction_scores_ref).abs().mean().item()}"
    )
    print(
        f"HF fp16 prediction_scores max diff: {(prediction_scores_hf - prediction_scores_ref).abs().max().item()}"
    )
    print(
        f"HF fp16 prediction_scores mean diff: {(prediction_scores_hf - prediction_scores_ref).abs().mean().item()}"
    )
    assert (prediction_scores - prediction_scores_ref).abs().max().item() < 2 * (
        prediction_scores_hf - prediction_scores_ref
    ).abs().max().item()
    assert (seq_relationship_scores - seq_relationship_scores_ref).abs().max().item() < 2 * (
        seq_relationship_scores_hf - seq_relationship_scores_ref
    ).abs().max().item()
    # The loss calculation from HF is wrong: it doesn't ignore the labels that are 0.
    # assert (out.loss - out_ref.loss).abs().max().item() < 2 * (out_hf.loss - out_ref.loss).abs().max().item()