import re
import time

import pytest
import torch
from einops import rearrange
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.opt import opt_config_to_gpt2_config, remap_state_dict_hf_opt
from flash_attn.utils.generation import update_graph_cache
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import AutoTokenizer, OPTConfig
from transformers.models.opt.modeling_opt import OPTForCausalLM


@pytest.mark.parametrize(
    "model_name", ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b"]
)
# @pytest.mark.parametrize('model_name', ["facebook/opt-350m"])
def test_opt_state_dict(model_name):
    """Check that a remapped HF OPT checkpoint lines up with a fresh GPTLMHeadModel.

    The HF config is converted with ``opt_config_to_gpt2_config`` and the pretrained
    weights are remapped with ``remap_state_dict_hf_opt``. The remapped state dict must
    then expose exactly the same keys, with the same tensor shapes, as the state dict of
    ``GPTLMHeadModel(config)``.

    Args:
        model_name: Model identifier passed to ``OPTConfig.from_pretrained`` and
            ``state_dict_from_pretrained``.
    """
    config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
    pretrained_state_dict = remap_state_dict_hf_opt(state_dict_from_pretrained(model_name), config)
    model = GPTLMHeadModel(config)
    state_dict = model.state_dict()

    # Key sets must agree before shapes are compared, otherwise the lookup below
    # could raise KeyError instead of producing a readable assertion failure.
    assert state_dict.keys() == pretrained_state_dict.keys()
    for name, param in state_dict.items():
        pretrained_shape = pretrained_state_dict[name].shape
        assert (
            param.shape == pretrained_shape
        ), f"{name}: model shape {tuple(param.shape)} != pretrained shape {tuple(pretrained_shape)}"


@pytest.mark.parametrize(
    "model_name", ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b"]
)
# @pytest.mark.parametrize('model_name', ["facebook/opt-350m"])
def test_opt_optimized(model_name):
    """Check that our implementation of OPT (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.

    Concretely, the max absolute error of our fp16 output against the HF fp32 output must be
    below 3x the max absolute error of the HF fp16 output against the same fp32 output.

    Args:
        model_name: Model identifier used to load the config and all three models.
    """
    dtype = torch.float16
    device = "cuda"
    config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = True
    # Only prenorm supports residual_in_fp32
    config.residual_in_fp32 = getattr(config, "prenorm", True)
    config.pad_vocab_size_multiple = 8

    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)

    # fp32 HF model is the reference; fp16 HF model sets the error budget.
    model_ref = OPTForCausalLM.from_pretrained(model_name).to(device=device)
    model_hf = OPTForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device=device)

    model.eval()
    model_ref.eval()
    model_hf.eval()

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    # NOTE(review): `seqlens` is not used below. The draw is kept because removing it would
    # change the random stream and therefore the `input_ids` this test has been run with.
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
    )
    if model_name != "facebook/opt-350m":  # The OPT-350m projects the embeddings to dimension 512
        out = model.transformer(input_ids)
        out_hf = model_hf.model(input_ids).last_hidden_state
        out_ref = model_ref.model(input_ids).last_hidden_state

        out_err = (out - out_ref).abs()
        out_hf_err = (out_hf - out_ref).abs()
        out_max_diff = out_err.max().item()
        out_hf_max_diff = out_hf_err.max().item()
        print(f"Output max diff: {out_max_diff}")
        print(f"Output mean diff: {out_err.mean().item()}")
        print(f"HF fp16 max diff: {out_hf_max_diff}")
        print(f"HF fp16 mean diff: {out_hf_err.mean().item()}")
        assert (
            out_max_diff < 3 * out_hf_max_diff
        ), f"hidden-state max diff {out_max_diff} exceeds 3x HF fp16 max diff {out_hf_max_diff}"

    logits = model(input_ids).logits
    logits_hf = model_hf(input_ids).logits
    logits_ref = model_ref(input_ids).logits

    logits_err = (logits - logits_ref).abs()
    logits_hf_err = (logits_hf - logits_ref).abs()
    logits_max_diff = logits_err.max().item()
    logits_hf_max_diff = logits_hf_err.max().item()
    print(f"Logits max diff: {logits_max_diff}")
    print(f"Logits mean diff: {logits_err.mean().item()}")
    print(f"HF fp16 max diff: {logits_hf_max_diff}")
    print(f"HF fp16 mean diff: {logits_hf_err.mean().item()}")
    assert (
        logits_max_diff < 3 * logits_hf_max_diff
    ), f"logits max diff {logits_max_diff} exceeds 3x HF fp16 max diff {logits_hf_max_diff}"


@pytest.mark.parametrize(
    "model_name",
    [
        "facebook/opt-125m",
        "facebook/opt-350m",
        "facebook/opt-1.3b",
        "facebook/opt-2.7b",
        "facebook/opt-6.7b",
    ],
)
# @pytest.mark.parametrize('model_name', ["facebook/opt-125m"])
def test_opt_generation(model_name):
    """Check that our implementation of OPT generation matches the HF implementation:
    the scores in fp16 should be around the same as the HF scores in fp16, when compared to
    the HF scores in fp32.

    The test proceeds in four stages:
      1. Build a slow reference by repeatedly running the full forward pass and taking the
         argmax of the last position.
      2. Run ``model.generate`` without CUDA graphs and, when flash attention is enabled,
         again with ``cg=True``.
      3. Run HF ``generate`` in fp16 and in fp32.
      4. Compare generated sequences and per-step scores.

    Args:
        model_name: Model identifier used to load the config, tokenizer and all models.
    """
    print(f"\nMODEL: {model_name}")
    verbose = False
    dtype = torch.float16
    device = "cuda"
    rtol, atol = 3e-3, 3e-1
    config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
    # Only prenorm supports residual_in_fp32
    config.residual_in_fp32 = getattr(config, "prenorm", True)
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = True

    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model.eval()

    torch.manual_seed(0)
    # OPT tokenizer requires use_fast=False
    # https://huggingface.co/docs/transformers/model_doc/opt
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    eos_token_id = tokenizer.eos_token_id

    input_ids = tokenizer("Hello, my dog is cute and he", return_tensors="pt").input_ids.to(
        device=device
    )
    max_length = 25
    # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
    # max_length = input_ids.shape[1] + 40

    # Slow generation for reference
    sequences = []
    scores = []
    cur_input_ids = input_ids
    with torch.inference_mode():
        scores.append(model(cur_input_ids).logits[:, -1])
        sequences.append(scores[-1].argmax(dim=-1))
        # One token was already produced above, hence the `+ 1` on the start of the range.
        for _ in range(input_ids.shape[1] + 1, max_length):
            cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], "b -> b 1")], dim=-1)
            scores.append(model(cur_input_ids).logits[:, -1])
            sequences.append(scores[-1].argmax(dim=-1))
            if eos_token_id is not None and (sequences[-1] == eos_token_id).all():
                break
    sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
    scores = tuple(scores)

    # Timings below use time.perf_counter(): it is monotonic and meant for measuring
    # short intervals, whereas time.time() can jump if the system clock is adjusted.
    print("Without CUDA graph")
    torch.cuda.synchronize()
    start = time.perf_counter()
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        eos_token_id=eos_token_id,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.perf_counter() - start) * 1000:.0f}ms")
    if verbose:
        print(out.sequences)
    print(tokenizer.batch_decode(out.sequences.tolist()))
    if getattr(config, "use_flash_attn", False):
        # Capture graph outside the timing loop
        batch_size, seqlen_og = input_ids.shape
        model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
        print("With CUDA graph")
        torch.cuda.synchronize()
        start = time.perf_counter()
        # NOTE(review): `out_cg` is only printed, never compared against `out` — confirm
        # whether a sequence/score check for the CUDA-graph path is intended.
        out_cg = model.generate(
            input_ids=input_ids,
            max_length=max_length,
            cg=True,
            return_dict_in_generate=True,
            output_scores=True,
            enable_timing=True,
        )
        torch.cuda.synchronize()
        print(f"Prompt processing + decoding time: {(time.perf_counter() - start) * 1000:.0f}ms")
        if verbose:
            print(out_cg.sequences)
        print(tokenizer.batch_decode(out_cg.sequences.tolist()))

    # Drop our model before loading the HF ones.
    del model

    model_hf = OPTForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device=device)
    model_hf.eval()
    print("HF fp16")
    torch.cuda.synchronize()
    start = time.perf_counter()
    out_hf = model_hf.generate(
        input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.perf_counter() - start) * 1000:.0f}ms")
    del model_hf

    model_ref = OPTForCausalLM.from_pretrained(model_name).to(device=device)
    model_ref.eval()
    print("HF fp32")
    torch.cuda.synchronize()
    start = time.perf_counter()
    out_ref = model_ref.generate(
        input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.perf_counter() - start) * 1000:.0f}ms")
    del model_ref
    print(tokenizer.batch_decode(out_ref.sequences.tolist()))

    # Stack the per-step score tuples once; they are reused by the prints and the asserts.
    out_scores = torch.stack(out.scores, dim=1)
    hf_scores = torch.stack(out_hf.scores, dim=1)
    ref_scores = torch.stack(out_ref.scores, dim=1)
    scores_err = (out_scores - ref_scores).abs()
    hf_scores_err = (hf_scores - ref_scores).abs()
    scores_max_diff = scores_err.max().item()
    hf_scores_max_diff = hf_scores_err.max().item()

    if verbose:
        print(f"Scores max diff: {scores_max_diff}")
        print(f"Scores mean diff: {scores_err.mean().item()}")
        print(f"HF fp16 max diff: {hf_scores_max_diff}")
        print(f"HF fp16 mean diff: {hf_scores_err.mean().item()}")

    # generate() must reproduce the slow token-by-token reference.
    assert torch.all(out.sequences == sequences)
    assert torch.allclose(out_scores, torch.stack(scores, dim=1), rtol=rtol, atol=atol)
    # Generated tokens must match both HF runs.
    assert torch.all(out.sequences == out_ref.sequences)
    assert torch.all(out.sequences == out_hf.sequences)

    assert (
        scores_max_diff < 3 * hf_scores_max_diff
    ), f"scores max diff {scores_max_diff} exceeds 3x HF fp16 max diff {hf_scores_max_diff}"