import os
import re
import time

import torch
import pytest

from einops import rearrange

from transformers import GPT2Config, GPT2Tokenizer, OPTConfig, AutoTokenizer
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF
from transformers.models.opt.modeling_opt import OPTForCausalLM

from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gpt import remap_state_dict_hf_gpt2
from flash_attn.models.opt import remap_state_dict_hf_opt, opt_config_to_gpt2_config
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.distributed import all_gather_raw
from flash_attn.utils.generation import update_graph_cache


@pytest.mark.parametrize('fused_ft_kernel', [False, True])
@pytest.mark.parametrize('optimized', [False, True])
@pytest.mark.parametrize('rotary', [False, True])
@pytest.mark.parametrize('model_name', ["gpt2"])
def test_greedy_decode_gpt2(model_name, rotary, optimized, fused_ft_kernel):
    """Check that our implementation of GPT2 generation matches the HF implementation:
    the scores in fp16 should be around the same as the HF scores in fp16, when compared to
    the HF scores in fp32.

    The test also checks two things against our own model:
    - ``model.generate`` must match a slow greedy decode that re-runs the full forward pass on
      the growing prefix at every step.
    - When ``fused_ft_kernel`` is set, the CUDA-graph decoding path (``cg=True``) must produce
      the same tokens as the non-graph path.
    """
    dtype = torch.float16
    device = 'cuda'
    rtol, atol = 3e-3, 3e-1
    config = GPT2Config.from_pretrained(model_name)
    if rotary:
        config.n_positions = 0
        config.rotary_emb_dim = 64
    config.residual_in_fp32 = True
    if optimized:
        config.use_flash_attn = True
        config.fused_bias_fc = True
        config.fused_mlp = True
        config.fused_dropout_add_ln = True

    # If rotary, the HF checkpoint is loaded non-strictly (strict=False), so checkpoint weights
    # without a counterpart in our model (the learned position embeddings) are ignored.
    # The resulting model would be nonsense, but it doesn't matter: in that case it is only
    # compared against itself, not against HF.
    model = GPTLMHeadModel.from_pretrained(model_name, config, strict=not rotary, device=device,
                                           dtype=dtype)
    model.eval()

    if not rotary:
        # HF fp32 is the ground truth; HF fp16 sets the error budget for our fp16 model.
        model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device)
        model_hf = GPT2LMHeadModelHF.from_pretrained(model_name,
                                                     torch_dtype=dtype).to(device=device)
        model_ref.eval()
        model_hf.eval()

    torch.manual_seed(0)
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    input_ids = tokenizer("Hello, my dog is cute and",
                          return_tensors="pt").input_ids.to(device=device)
    max_length = 30

    # Slow generation for reference: full forward pass over the whole prefix at every step,
    # taking the argmax of the last position's logits.
    sequences = []
    scores = []
    cur_input_ids = input_ids
    with torch.inference_mode():
        scores.append(model(cur_input_ids).logits[:, -1])
        sequences.append(scores[-1].argmax(dim=-1))
        for _ in range(input_ids.shape[1] + 1, max_length):
            cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], 'b -> b 1')], dim=-1)
            scores.append(model(cur_input_ids).logits[:, -1])
            sequences.append(scores[-1].argmax(dim=-1))
    sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
    scores = tuple(scores)

    out = model.generate(input_ids=input_ids, max_length=max_length,
                         fused_ft_kernel=fused_ft_kernel,
                         return_dict_in_generate=True, output_scores=True, timing=True)
    print(out.sequences)
    print(tokenizer.batch_decode(out.sequences.tolist()))
    if fused_ft_kernel:
        out_cg = model.generate(input_ids=input_ids, max_length=max_length,
                                fused_ft_kernel=fused_ft_kernel, cg=True,
                                return_dict_in_generate=True, output_scores=True, timing=True)
        print(out_cg.sequences)

    if not rotary:
        out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
                                   return_dict_in_generate=True, output_scores=True)
        out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length,
                                     return_dict_in_generate=True, output_scores=True)
        # No eos_token_id is passed to our generate or to the slow loop, so every run
        # should produce max_length tokens and the score stacks can be compared elementwise.
        scores_ref = torch.stack(out_ref.scores, 1)
        our_err = (torch.stack(out.scores, 1) - scores_ref).abs()
        hf_err = (torch.stack(out_hf.scores, 1) - scores_ref).abs()
        print(f'Scores max diff: {our_err.max().item()}')
        print(f'Scores mean diff: {our_err.mean().item()}')
        print(f'HF fp16 max diff: {hf_err.max().item()}')
        print(f'HF fp16 mean diff: {hf_err.mean().item()}')
        print(tokenizer.batch_decode(out_ref.sequences.tolist()))

    # torch.equal returns False on a shape mismatch. Elementwise == would raise a broadcast
    # error instead.
    assert torch.equal(out.sequences, sequences)
    assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1),
                          rtol=rtol, atol=atol)
    if fused_ft_kernel:
        # The CUDA-graph path must decode exactly the same tokens as the eager path.
        assert torch.equal(out_cg.sequences, out.sequences)
    if not rotary:
        assert torch.equal(out.sequences, out_ref.sequences)
        assert torch.equal(out.sequences, out_hf.sequences)
        # Our fp16 error vs. HF fp32 must stay within 3x of HF's own fp16 error.
        assert our_err.max().item() < 3 * hf_err.max().item()


@pytest.mark.parametrize('model_name', ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b", "facebook/opt-2.7b", "facebook/opt-6.7b"])
def test_greedy_decode_opt(model_name):
    """Check that our implementation of OPT generation matches the HF implementation:
    the scores in fp16 should be around the same as the HF scores in fp16, when compared to
    the HF scores in fp32.

    The test also checks two things against our own model:
    - ``model.generate`` must match a slow greedy decode that stops once every sequence has
      emitted EOS.
    - The CUDA-graph decoding path must produce the same tokens as the non-graph path.

    Wall-clock times for each generation run are printed. ``torch.cuda.synchronize`` brackets
    each run so the timings include the GPU work.
    """
    print(f'\nMODEL: {model_name}')
    verbose = False
    dtype = torch.float16
    device = 'cuda'
    rtol, atol = 3e-3, 3e-1
    fused_ft_kernel = True
    config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
    # Only prenorm supports residual_in_fp32
    config.residual_in_fp32 = getattr(config, 'prenorm', True)
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = True

    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model.eval()

    torch.manual_seed(0)
    # OPT tokenizer requires use_fast=False
    # https://huggingface.co/docs/transformers/model_doc/opt
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    eos_token_id = tokenizer.eos_token_id

    input_ids = tokenizer("Hello, my dog is cute and",
                          return_tensors="pt").input_ids.to(device=device)
    max_length = 60

    # Slow generation for reference: full forward pass over the whole prefix at every step,
    # taking the argmax of the last position's logits, stopping early once all sequences hit EOS.
    sequences = []
    scores = []
    cur_input_ids = input_ids
    with torch.inference_mode():
        scores.append(model(cur_input_ids).logits[:, -1])
        sequences.append(scores[-1].argmax(dim=-1))
        for _ in range(input_ids.shape[1] + 1, max_length):
            cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], 'b -> b 1')], dim=-1)
            scores.append(model(cur_input_ids).logits[:, -1])
            sequences.append(scores[-1].argmax(dim=-1))
            if eos_token_id is not None and (sequences[-1] == eos_token_id).all():
                break
    sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
    scores = tuple(scores)

    print('Without CUDA graph')
    torch.cuda.synchronize()
    start = time.perf_counter()
    out = model.generate(input_ids=input_ids, max_length=max_length,
                         eos_token_id=eos_token_id, fused_ft_kernel=fused_ft_kernel,
                         return_dict_in_generate=True, output_scores=True, timing=True)
    torch.cuda.synchronize()
    print(f'Prompt processing + decoding time: {(time.perf_counter() - start) * 1000:.0f}ms')
    if verbose:
        print(out.sequences)
    print(tokenizer.batch_decode(out.sequences.tolist()))
    out_cg = None
    if fused_ft_kernel:
        # Capture graph outside the timing loop
        batch_size, seqlen_og = input_ids.shape
        model._decoding_cache = update_graph_cache(
            model, None, batch_size, seqlen_og, max_length
        )
        print('With CUDA graph')
        torch.cuda.synchronize()
        start = time.perf_counter()
        # Pass the same eos_token_id as the eager run above, so both runs stop under the same
        # condition and their outputs can be compared.
        out_cg = model.generate(input_ids=input_ids, max_length=max_length,
                                eos_token_id=eos_token_id, fused_ft_kernel=fused_ft_kernel,
                                cg=True, return_dict_in_generate=True, output_scores=True,
                                timing=True)
        torch.cuda.synchronize()
        print(f'Prompt processing + decoding time: {(time.perf_counter() - start) * 1000:.0f}ms')
        if verbose:
            print(out_cg.sequences)
        print(tokenizer.batch_decode(out_cg.sequences.tolist()))

    # Free our model before loading the HF models.
    del model

    model_hf = OPTForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device=device)
    model_hf.eval()
    print("HF fp16")
    torch.cuda.synchronize()
    start = time.perf_counter()
    out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
                               return_dict_in_generate=True, output_scores=True)
    torch.cuda.synchronize()
    print(f'Prompt processing + decoding time: {(time.perf_counter() - start) * 1000:.0f}ms')
    del model_hf

    model_ref = OPTForCausalLM.from_pretrained(model_name).to(device=device)
    model_ref.eval()
    print("HF fp32")
    torch.cuda.synchronize()
    start = time.perf_counter()
    out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length,
                                 return_dict_in_generate=True, output_scores=True)
    torch.cuda.synchronize()
    print(f'Prompt processing + decoding time: {(time.perf_counter() - start) * 1000:.0f}ms')
    del model_ref
    print(tokenizer.batch_decode(out_ref.sequences.tolist()))

    if verbose:
        verbose_ref = torch.stack(out_ref.scores, 1)
        verbose_ours = (torch.stack(out.scores, 1) - verbose_ref).abs()
        verbose_hf = (torch.stack(out_hf.scores, 1) - verbose_ref).abs()
        print(f'Scores max diff: {verbose_ours.max().item()}')
        print(f'Scores mean diff: {verbose_ours.mean().item()}')
        print(f'HF fp16 max diff: {verbose_hf.max().item()}')
        print(f'HF fp16 mean diff: {verbose_hf.mean().item()}')

    # torch.equal returns False on a shape mismatch (e.g. one run stopping at EOS earlier).
    # Elementwise == would raise a broadcast error instead.
    assert torch.equal(out.sequences, sequences)
    assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1),
                          rtol=rtol, atol=atol)
    if out_cg is not None:
        # The CUDA-graph path must decode exactly the same tokens as the eager path.
        assert torch.equal(out_cg.sequences, out.sequences)
    assert torch.equal(out.sequences, out_ref.sequences)
    assert torch.equal(out.sequences, out_hf.sequences)

    # The sequences match, so all score stacks have the same shape here.
    scores_ref = torch.stack(out_ref.scores, 1)
    our_max_err = (torch.stack(out.scores, 1) - scores_ref).abs().max().item()
    hf_max_err = (torch.stack(out_hf.scores, 1) - scores_ref).abs().max().item()
    # Our fp16 error vs. HF fp32 must stay within 3x of HF's own fp16 error.
    assert our_max_err < 3 * hf_max_err