test_bert.py 13 KB
Newer Older
Tri Dao's avatar
Tri Dao committed
1
2
3
import re
from collections import OrderedDict

Tri Dao's avatar
Tri Dao committed
4
import pytest
Tri Dao's avatar
Tri Dao committed
5
6
7
import torch
import torch.nn.functional as F
from einops import rearrange
Tri Dao's avatar
Tri Dao committed
8
9
from flash_attn.models.bert import BertForPreTraining, BertModel, remap_state_dict
from flash_attn.utils.pretrained import state_dict_from_pretrained
Tri Dao's avatar
Tri Dao committed
10
11
from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertForPreTraining as BertForPreTrainingHF
Tri Dao's avatar
Tri Dao committed
12
from transformers.models.bert.modeling_bert import BertModel as BertModelHF
Tri Dao's avatar
Tri Dao committed
13
14


@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_state_dict(model_name):
    """The remapped HF checkpoint must line up exactly with our model's state dict:
    identical key sets and, key by key, identical tensor shapes.
    """
    config = BertConfig.from_pretrained(model_name)
    remapped = remap_state_dict(state_dict_from_pretrained(model_name), config)
    reference = BertForPreTraining(config).state_dict()
    assert reference.keys() == remapped.keys()
    for name, tensor in reference.items():
        assert tensor.shape == remapped[name].shape


def get_hf_models(model_name, config, dtype):
    """Build the Huggingface reference BertForPreTraining on GPU.

    Loads the pretrained checkpoint, renames old-style LayerNorm parameters
    ("gamma"/"beta" -> "weight"/"bias"), loads it non-strictly, and returns the
    model moved to CUDA in the requested dtype.
    """
    pretrained_state_dict = state_dict_from_pretrained(model_name)

    def key_mapping_ln_gamma_beta(key):
        # Escape the dot: the original patterns used an unescaped "." which
        # matches any character, not just the literal "." separator.
        key = re.sub(r"LayerNorm\.gamma$", "LayerNorm.weight", key)
        key = re.sub(r"LayerNorm\.beta$", "LayerNorm.bias", key)
        return key

    pretrained_state_dict = OrderedDict(
        (key_mapping_ln_gamma_beta(k), v) for k, v in pretrained_state_dict.items()
    )
    model_hf = BertForPreTrainingHF(config)
    # strict=False because of expected missing keys:
    # Missing key(s) in state_dict: "bert.embeddings.position_ids", "cls.predictions.decoder.bias"
    # position_ids is a buffer, and predictions.decoder.bias is tied to predictions.bias.
    model_hf.load_state_dict(pretrained_state_dict, strict=False)
    model_hf.cuda().to(dtype=dtype)
    return model_hf


@pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_non_optimized(model_name):
    """Check that our implementation of BERT (without any optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    config = BertConfig.from_pretrained(model_name)

    model = BertForPreTraining.from_pretrained(model_name, config)
    model = model.cuda().to(dtype=dtype)

    # fp32 HF model is the reference; fp16 HF model sets the error budget.
    model_ref = get_hf_models(model_name, config, torch.float32)
    model_hf = get_hf_models(model_name, config, dtype)
    for m in (model, model_ref, model_hf):
        m.eval()

    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
    attention_mask = torch.arange(max_seqlen, device="cuda")[None, :] < seqlens[:, None]
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
    )
    out = model.bert(input_ids, attention_mask=attention_mask)
    sequence_output, pooled_output = out.last_hidden_state, out.pooler_output
    out_hf = model_hf.bert(input_ids, attention_mask=attention_mask)
    sequence_output_hf, pooled_output_hf = out_hf.last_hidden_state, out_hf.pooler_output
    out_ref = model_ref.bert(input_ids, attention_mask=attention_mask)
    sequence_output_ref, pooled_output_ref = out_ref.last_hidden_state, out_ref.pooler_output

    # Our fp16 error vs fp32 must stay within 3x the HF fp16 error vs fp32.
    seq_err = (sequence_output - sequence_output_ref).abs()
    seq_err_hf = (sequence_output_hf - sequence_output_ref).abs()
    print(f"Output max diff: {seq_err.max().item()}")
    print(f"Output mean diff: {seq_err.mean().item()}")
    print(f"HF fp16 max diff: {seq_err_hf.max().item()}")
    print(f"HF fp16 mean diff: {seq_err_hf.mean().item()}")
    assert seq_err.max().item() < 3 * seq_err_hf.max().item()
    pooled_err = (pooled_output - pooled_output_ref).abs()
    pooled_err_hf = (pooled_output_hf - pooled_output_ref).abs()
    assert pooled_err.max().item() < 3 * pooled_err_hf.max().item()


@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_optimized(model_name):
    """Check that our implementation of BERT (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    config = BertConfig.from_pretrained(model_name)
    # Our implementation of fused_mlp assumes the activation is
    # nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new" or "gelu_fast".
    # If you just want "gelu", disable fused_mlp.
    config.hidden_act = "gelu_new"
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = True

    model = BertForPreTraining.from_pretrained(model_name, config)
    model = model.cuda().to(dtype=dtype)

    # fp32 HF model is the reference; fp16 HF model sets the error budget.
    model_ref = get_hf_models(model_name, config, torch.float32)
    model_hf = get_hf_models(model_name, config, dtype)

    model.eval()
    model_ref.eval()
    model_hf.eval()

    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
    attention_mask = torch.arange(max_seqlen, device="cuda")[None, :] < seqlens[:, None]
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
    )
    out = model.bert(input_ids, attention_mask=attention_mask)
    sequence_output, pooled_output = out.last_hidden_state, out.pooler_output
    out_hf = model_hf.bert(input_ids, attention_mask=attention_mask)
    sequence_output_hf, pooled_output_hf = out_hf.last_hidden_state, out_hf.pooler_output
    # Need to zero out the padded tokens in the sequence before comparison.
    sequence_output_hf[~attention_mask, :] = 0.0
    out_ref = model_ref.bert(input_ids, attention_mask=attention_mask)
    sequence_output_ref, pooled_output_ref = out_ref.last_hidden_state, out_ref.pooler_output
    sequence_output_ref[~attention_mask, :] = 0.0

    print(
        f"BertModel output max diff: {(sequence_output - sequence_output_ref).abs().max().item()}"
    )
    print(
        f"BertModel output mean diff: {(sequence_output - sequence_output_ref).abs().mean().item()}"
    )
    print(
        f"HF fp16 BertModel max diff: {(sequence_output_hf - sequence_output_ref).abs().max().item()}"
    )
    print(
        f"HF fp16 BertModel mean diff: {(sequence_output_hf - sequence_output_ref).abs().mean().item()}"
    )
    assert (sequence_output - sequence_output_ref).abs().max().item() < 4 * (
        sequence_output_hf - sequence_output_ref
    ).abs().max().item()
    assert (pooled_output - pooled_output_ref).abs().max().item() < 4 * (
        pooled_output_hf - pooled_output_ref
    ).abs().max().item()

    # Full BertForPreTraining forward: compare MLM logits and NSP logits.
    out = model(input_ids, attention_mask=attention_mask)
    prediction_scores, seq_relationship_scores = out.prediction_logits, out.seq_relationship_logits
    # Need to zero out the padded tokens in the sequence before comparison.
    prediction_scores = prediction_scores.clone()
    prediction_scores[~attention_mask, :] = 0.0
    out_hf = model_hf(input_ids, attention_mask=attention_mask)
    prediction_scores_hf, seq_relationship_scores_hf = (
        out_hf.prediction_logits,
        out_hf.seq_relationship_logits,
    )
    prediction_scores_hf[~attention_mask, :] = 0.0
    out_ref = model_ref(input_ids, attention_mask=attention_mask)
    prediction_scores_ref, seq_relationship_scores_ref = (
        out_ref.prediction_logits,
        out_ref.seq_relationship_logits,
    )
    prediction_scores_ref[~attention_mask, :] = 0.0

    print(
        f"prediction_scores max diff: {(prediction_scores - prediction_scores_ref).abs().max().item()}"
    )
    print(
        f"prediction_scores mean diff: {(prediction_scores - prediction_scores_ref).abs().mean().item()}"
    )
    # Fixed previously garbled labels ("prediction_scoresff" / "prediction_scoresiff").
    print(
        f"HF fp16 prediction_scores max diff: {(prediction_scores_hf - prediction_scores_ref).abs().max().item()}"
    )
    print(
        f"HF fp16 prediction_scores mean diff: {(prediction_scores_hf - prediction_scores_ref).abs().mean().item()}"
    )
    assert (prediction_scores - prediction_scores_ref).abs().max().item() < 2 * (
        prediction_scores_hf - prediction_scores_ref
    ).abs().max().item()
    assert (seq_relationship_scores - seq_relationship_scores_ref).abs().max().item() < 2 * (
        seq_relationship_scores_hf - seq_relationship_scores_ref
    ).abs().max().item()


@pytest.mark.parametrize("last_layer_subset", [False, True])
# @pytest.mark.parametrize('last_layer_subset', [True])
@pytest.mark.parametrize("has_key_padding_mask", [True, False])
# @pytest.mark.parametrize('has_key_padding_mask', [True])
@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_dense_seq_output(model_name, has_key_padding_mask, last_layer_subset):
    """Check that our implementation of BERT (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    config = BertConfig.from_pretrained(model_name)
    # Our implementation of fused_mlp assumes the activation is
    # nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new" or "gelu_fast".
    # If you just want "gelu", disable fused_mlp.
    config.hidden_act = "gelu_new"
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = True
    config.dense_seq_output = True
    config.last_layer_subset = last_layer_subset
    config.use_xentropy = True

    model = BertForPreTraining.from_pretrained(model_name, config)
    model = model.cuda().to(dtype=dtype)

    # fp32 HF model is the reference; fp16 HF model sets the error budget.
    model_ref = get_hf_models(model_name, config, torch.float32)
    model_hf = get_hf_models(model_name, config, dtype)

    model.eval()
    model_ref.eval()
    model_hf.eval()

    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
    if has_key_padding_mask:
        attention_mask = torch.arange(max_seqlen, device="cuda")[None, :] < seqlens[:, None]
    else:
        attention_mask = None
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
    )
    labels = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
    )
    # Zero out labels at padded positions, then keep ~15% of positions as
    # "masked" tokens; label 0 marks positions that are not predicted.
    if attention_mask is not None:
        labels[~attention_mask] = 0
    labels[(torch.rand(batch_size, max_seqlen, device="cuda") > 0.15)] = 0
    masked_tokens_mask = labels.flatten() > 0
    next_sequence_label = torch.randint(0, 2, (batch_size,), device="cuda")

    out = model(
        input_ids,
        attention_mask=attention_mask,
        labels=labels,
        next_sentence_label=next_sequence_label,
    )
    # With dense_seq_output, prediction_logits appears to cover only the masked
    # tokens, so the HF logits below are gathered with masked_tokens_mask to
    # match — NOTE(review): confirm against BertForPreTraining's forward.
    prediction_scores, seq_relationship_scores = out.prediction_logits, out.seq_relationship_logits
    out_hf = model_hf(
        input_ids,
        attention_mask=attention_mask,
        labels=labels,
        next_sentence_label=next_sequence_label,
    )
    prediction_scores_hf, seq_relationship_scores_hf = (
        out_hf.prediction_logits,
        out_hf.seq_relationship_logits,
    )
    prediction_scores_hf = rearrange(prediction_scores_hf, "b s d -> (b s) d")[masked_tokens_mask]
    out_ref = model_ref(
        input_ids,
        attention_mask=attention_mask,
        labels=labels,
        next_sentence_label=next_sequence_label,
    )
    prediction_scores_ref, seq_relationship_scores_ref = (
        out_ref.prediction_logits,
        out_ref.seq_relationship_logits,
    )
    prediction_scores_ref = rearrange(prediction_scores_ref, "b s d -> (b s) d")[masked_tokens_mask]

    print(
        f"prediction_scores max diff: {(prediction_scores - prediction_scores_ref).abs().max().item()}"
    )
    print(
        f"prediction_scores mean diff: {(prediction_scores - prediction_scores_ref).abs().mean().item()}"
    )
    # Fixed previously garbled labels ("prediction_scoresff" / "prediction_scoresiff").
    print(
        f"HF fp16 prediction_scores max diff: {(prediction_scores_hf - prediction_scores_ref).abs().max().item()}"
    )
    print(
        f"HF fp16 prediction_scores mean diff: {(prediction_scores_hf - prediction_scores_ref).abs().mean().item()}"
    )
    assert (prediction_scores - prediction_scores_ref).abs().max().item() < 2 * (
        prediction_scores_hf - prediction_scores_ref
    ).abs().max().item()
    assert (seq_relationship_scores - seq_relationship_scores_ref).abs().max().item() < 2 * (
        seq_relationship_scores_hf - seq_relationship_scores_ref
    ).abs().max().item()
    # The loss calculation from HF is wrong: it doesn't ignore the labels that are 0.
    # assert (out.loss - out_ref.loss).abs().max().item() < 2 * (out_hf.loss - out_ref.loss).abs().max().item()