import os
import re
import time

import pytest
import torch
from einops import rearrange
from flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_gpt2
from flash_attn.models.opt import opt_config_to_gpt2_config, remap_state_dict_hf_opt
from flash_attn.utils.generation import update_graph_cache
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import AutoTokenizer, GPT2Config, GPT2Tokenizer, OPTConfig
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF
from transformers.models.opt.modeling_opt import OPTForCausalLM


@pytest.mark.parametrize("fused_ft_kernel", [False, True])
18
# @pytest.mark.parametrize('fused_ft_kernel', [True])
Tri Dao's avatar
Tri Dao committed
19
@pytest.mark.parametrize("optimized", [False, True])
20
# @pytest.mark.parametrize('optimized', [False])
Tri Dao's avatar
Tri Dao committed
21
@pytest.mark.parametrize("rotary", [False, True])
22
# @pytest.mark.parametrize('rotary', [False])
Tri Dao's avatar
Tri Dao committed
23
@pytest.mark.parametrize("model_name", ["gpt2"])
Tri Dao's avatar
Tri Dao committed
24
def test_greedy_decode_gpt2(model_name, rotary, optimized, fused_ft_kernel):
    """Check that our implementation of GPT2 generation matches the HF implementation:
    the scores in fp16 should be around the same as the HF scores in fp16, when compared to
    the HF scores in fp32.
    """
    dtype = torch.float16
    device = "cuda"
    rtol, atol = 3e-3, 3e-1
    config = GPT2Config.from_pretrained(model_name)
    if rotary:
        # Rotary embeddings replace the learned position embeddings entirely.
        config.n_positions = 0
        config.rotary_emb_fraction = 0.5
        config.rotary_emb_base = 24000
    config.residual_in_fp32 = True
    if optimized:
        config.use_flash_attn = True
        config.fused_bias_fc = True
        config.fused_mlp = True
        config.fused_dropout_add_ln = True

    # If rotary, we still load the weights from HF but ignore the position embeddings
    # (strict=False below). The model outputs will be nonsense, but that doesn't
    # matter for this test.
    model = GPTLMHeadModel.from_pretrained(
        model_name, config, strict=not rotary, device=device, dtype=dtype
    )
    model.eval()

    if not rotary:
        model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device)
        model_hf = GPT2LMHeadModelHF.from_pretrained(model_name, torch_dtype=dtype).to(
            device=device
        )
        model_ref.eval()
        model_hf.eval()

    torch.manual_seed(0)
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    input_ids = tokenizer("Hello, my dog is cute and he", return_tensors="pt").input_ids.to(
        device=device
    )
    max_length = 25
    # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
    # max_length = input_ids.shape[1] + 40

    # Slow generation for reference: greedy-decode by re-running the full forward
    # pass on the growing sequence at every step (no KV cache).
    sequences = []
    scores = []
    cur_input_ids = input_ids
    with torch.inference_mode():
        scores.append(model(cur_input_ids).logits[:, -1])
        sequences.append(scores[-1].argmax(dim=-1))
        for _ in range(input_ids.shape[1] + 1, max_length):
            cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], "b -> b 1")], dim=-1)
            scores.append(model(cur_input_ids).logits[:, -1])
            sequences.append(scores[-1].argmax(dim=-1))
    sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
    scores = tuple(scores)

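    # Fast generation under test: generate() decodes incrementally with a KV cache;
    # fused_ft_kernel toggles the fused decoding attention kernel.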
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        fused_ft_kernel=fused_ft_kernel,
        return_dict_in_generate=True,
        output_scores=True,
        timing=True,
    )
    print(out.sequences)
    print(tokenizer.batch_decode(out.sequences.tolist()))
    if fused_ft_kernel:
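        # Same decoding, with CUDA graph capture/replay enabled (cg=True).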
        out_cg = model.generate(
            input_ids=input_ids,
            max_length=max_length,
            fused_ft_kernel=fused_ft_kernel,
            cg=True,
            return_dict_in_generate=True,
            output_scores=True,
            timing=True,
        )
        print(out_cg.sequences)

    if not rotary:
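        # HF generation in fp16 and fp32: fp32 serves as the reference, fp16 as
        # the yardstick for acceptable fp16 numerical error.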
        out_hf = model_hf.generate(
            input_ids=input_ids,
            max_length=max_length,
            return_dict_in_generate=True,
            output_scores=True,
        )
        out_ref = model_ref.generate(
            input_ids=input_ids,
            max_length=max_length,
            return_dict_in_generate=True,
            output_scores=True,
        )

        print(
            f"Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
        )
        print(
            f"Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
        )
        print(
            f"HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
        )
        print(
            f"HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
        )
        print(tokenizer.batch_decode(out_ref.sequences.tolist()))

    assert torch.all(out.sequences == sequences)
    assert torch.allclose(
        torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), rtol=rtol, atol=atol
    )
    if not rotary:
        assert torch.all(out.sequences == out_ref.sequences)
        assert torch.all(out.sequences == out_hf.sequences)

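        # Our max deviation from the fp32 reference should be within 3x that of
        # HF's own fp16 model.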
        assert (
            torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)
        ).abs().max().item() < 3 * (
            torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)
        ).abs().max().item()


@pytest.mark.parametrize(
    "model_name",
    [
        "facebook/opt-125m",
        "facebook/opt-350m",
        "facebook/opt-1.3b",
        "facebook/opt-2.7b",
        "facebook/opt-6.7b",
    ],
)
# @pytest.mark.parametrize('model_name', ["facebook/opt-125m"])
def test_greedy_decode_opt(model_name):
    """Check that our implementation of OPT generation matches the HF implementation:
    the scores in fp16 should be around the same as the HF scores in fp16, when compared to
    the HF scores in fp32.
    """
    print(f"\nMODEL: {model_name}")
    verbose = False
    dtype = torch.float16
    device = "cuda"
    rtol, atol = 3e-3, 3e-1
    fused_ft_kernel = True
    config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
    # Only prenorm supports residual_in_fp32
    config.residual_in_fp32 = getattr(config, "prenorm", True)
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = True

    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model.eval()

    torch.manual_seed(0)
    # OPT tokenizer requires use_fast=False
    # https://huggingface.co/docs/transformers/model_doc/opt
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    eos_token_id = tokenizer.eos_token_id

    input_ids = tokenizer("Hello, my dog is cute and he", return_tensors="pt").input_ids.to(
        device=device
    )
    max_length = 25
    # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
    # max_length = input_ids.shape[1] + 40

    # Slow generation for reference: greedy-decode by re-running the full forward
    # pass on the growing sequence at every step (no KV cache); stop early on EOS.
    sequences = []
    scores = []
    cur_input_ids = input_ids
    with torch.inference_mode():
        scores.append(model(cur_input_ids).logits[:, -1])
        sequences.append(scores[-1].argmax(dim=-1))
        for _ in range(input_ids.shape[1] + 1, max_length):
            cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], "b -> b 1")], dim=-1)
            scores.append(model(cur_input_ids).logits[:, -1])
            sequences.append(scores[-1].argmax(dim=-1))
            if eos_token_id is not None and (sequences[-1] == eos_token_id).all():
                break
    sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
    scores = tuple(scores)

    print("Without CUDA graph")
    torch.cuda.synchronize()
    start = time.time()
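    # Fast generation under test: incremental decoding with a KV cache and the
    # fused decoding attention kernel (fused_ft_kernel=True above).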
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        eos_token_id=eos_token_id,
        fused_ft_kernel=fused_ft_kernel,
        return_dict_in_generate=True,
        output_scores=True,
        timing=True,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    if verbose:
        print(out.sequences)
    print(tokenizer.batch_decode(out.sequences.tolist()))
    if fused_ft_kernel:
        # Capture graph outside the timing loop
        batch_size, seqlen_og = input_ids.shape
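        # update_graph_cache captures the decoding CUDA graph(s) ahead of time, so
        # the timed generate(cg=True) call below only replays them.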
        model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
        print("With CUDA graph")
        torch.cuda.synchronize()
        start = time.time()
        out_cg = model.generate(
            input_ids=input_ids,
            max_length=max_length,
            fused_ft_kernel=fused_ft_kernel,
            cg=True,
            return_dict_in_generate=True,
            output_scores=True,
            timing=True,
        )
        torch.cuda.synchronize()
        print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
        if verbose:
            print(out_cg.sequences)
        print(tokenizer.batch_decode(out_cg.sequences.tolist()))

    # Free GPU memory before loading the HF models for comparison.
    del model

    model_hf = OPTForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device=device)
    model_hf.eval()
    print("HF fp16")
    torch.cuda.synchronize()
    start = time.time()
    out_hf = model_hf.generate(
        input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    del model_hf

    model_ref = OPTForCausalLM.from_pretrained(model_name).to(device=device)
    model_ref.eval()
    print("HF fp32")
    torch.cuda.synchronize()
    start = time.time()
    out_ref = model_ref.generate(
        input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    del model_ref
    print(tokenizer.batch_decode(out_ref.sequences.tolist()))

    if verbose:
        print(
            f"Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
        )
        print(
            f"Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
        )
        print(
            f"HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
        )
        print(
            f"HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
        )

    assert torch.all(out.sequences == sequences)
    assert torch.allclose(
        torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), rtol=rtol, atol=atol
    )
    assert torch.all(out.sequences == out_ref.sequences)
    assert torch.all(out.sequences == out_hf.sequences)

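    # Our max deviation from the fp32 reference should be within 3x that of HF's
    # own fp16 model.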
    assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (
        torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)
    ).abs().max().item()