test_gptj.py 7.12 KB
Newer Older
Tri Dao's avatar
Tri Dao committed
1
2
# Copyright (c) 2023, Tri Dao.

Tri Dao's avatar
Tri Dao committed
3
import time
Tri Dao's avatar
Tri Dao committed
4
5

import pytest
Tri Dao's avatar
Tri Dao committed
6
import torch
Tri Dao's avatar
Tri Dao committed
7
from flash_attn.models.gpt import GPTLMHeadModel
Tri Dao's avatar
Tri Dao committed
8
from flash_attn.models.gptj import gptj_config_to_gpt2_config, remap_state_dict_hf_gptj
Tri Dao's avatar
Tri Dao committed
9
from flash_attn.utils.generation import update_graph_cache
Tri Dao's avatar
Tri Dao committed
10
11
12
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import AutoTokenizer, GPTJConfig
from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
Tri Dao's avatar
Tri Dao committed
13
14


Tri Dao's avatar
Tri Dao committed
15
@pytest.mark.parametrize("model_name", ["EleutherAI/gpt-j-6B"])
def test_gptj_state_dict(model_name):
    """Check that the remapped HF GPT-J checkpoint lines up with ``GPTLMHeadModel``.

    The pretrained weights (fetched via ``state_dict_from_pretrained``) are renamed with
    ``remap_state_dict_hf_gptj``. The result must contain exactly the parameter names of a
    freshly constructed ``GPTLMHeadModel``, and every tensor must have the same shape.
    """
    hf_config = GPTJConfig.from_pretrained(model_name)
    config = gptj_config_to_gpt2_config(hf_config)
    raw_state_dict = state_dict_from_pretrained(model_name)
    pretrained_state_dict = remap_state_dict_hf_gptj(raw_state_dict, config)

    # Build on the meta device: only shapes are needed here, and without
    # device='meta' the model initialization is very slow.
    model = GPTLMHeadModel(config, device="meta")
    state_dict = model.state_dict()

    assert state_dict.keys() == pretrained_state_dict.keys()
    for name, tensor in state_dict.items():
        assert tensor.shape == pretrained_state_dict[name].shape


26
@pytest.mark.parametrize("model_name", ["EleutherAI/gpt-j-6B", "togethercomputer/GPT-JT-6B-v1"])
def test_gptj_optimized(model_name):
    """Check that our implementation of GPT-J (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.

    Three models see the same random ``input_ids``. Each one is deleted before the next is
    loaded (presumably to keep peak GPU memory down):

    * ours: ``GPTLMHeadModel`` in fp16 with FlashAttention and the fused kernels enabled;
    * ref:  HF ``GPTJForCausalLM`` loaded without ``torch_dtype`` (the fp32 reference);
    * hf:   HF ``GPTJForCausalLM`` in fp16, whose error versus ref sets the tolerance.

    Both the final hidden states and the LM logits of ours must stay within 3x the HF fp16
    max error.
    """
    dtype = torch.float16
    device = "cuda"
    config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
    config.use_flash_attn = True  # FlashAttention-2 supports headdim 256
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model.eval()

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
    )
    with torch.no_grad():
        out = model.transformer(input_ids)
        logits = model(input_ids).logits
    del model

    # Without device_map, the model is loaded on the CPU, which is very slow
    model_ref = GPTJForCausalLM.from_pretrained(model_name, device_map={"": device})
    model_ref.eval()
    with torch.no_grad():
        out_ref = model_ref.transformer(input_ids).last_hidden_state
        logits_ref = model_ref(input_ids).logits
    del model_ref

    model_hf = GPTJForCausalLM.from_pretrained(
        model_name, torch_dtype=dtype, device_map={"": device}
    )
    model_hf.eval()
    # Like the two passes above, run without autograd. Otherwise this forward pass records a
    # graph and keeps every activation alive for a backward pass that never happens.
    with torch.no_grad():
        out_hf = model_hf.transformer(input_ids).last_hidden_state
        logits_hf = model_hf(input_ids).logits
    del model_hf

    out_max_diff = (out - out_ref).abs().max().item()
    out_hf_max_diff = (out_hf - out_ref).abs().max().item()
    print(f"Output max diff: {out_max_diff}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {out_hf_max_diff}")
    print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
    assert out_max_diff < 3 * out_hf_max_diff

    logits_max_diff = (logits - logits_ref).abs().max().item()
    logits_hf_max_diff = (logits_hf - logits_ref).abs().max().item()
    print(f"Logits max diff: {logits_max_diff}")
    print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {logits_hf_max_diff}")
    print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
    assert logits_max_diff < 3 * logits_hf_max_diff
Tri Dao's avatar
Tri Dao committed
84
85


Tri Dao's avatar
Tri Dao committed
86
@pytest.mark.parametrize("model_name", ["EleutherAI/gpt-j-6B"])
def test_gptj_generation(model_name):
    """Check that generation with our implementation of GPT-J (with all optimizations
    enabled) matches the HF implementation.

    First, the HF model in fp16 generates a continuation of a random prompt. The resulting
    sequence is then scored in one forward pass by the HF model in fp32, which serves as the
    reference. Our fp16 model is compared against that reference, using the HF fp16 error as
    the tolerance scale:

    * a single parallel forward pass over the HF-generated sequence must be within 2x;
    * ``generate`` without CUDA graph, given ``teacher_outputs=out_hf.sequences``, must have
      per-step scores within 2x;
    * ``generate`` with CUDA graph (``cg=True``), using the same settings, must produce scores
      bitwise identical to the run without CUDA graph.
    """
    dtype = torch.float16
    device = "cuda"
    config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
    config.use_flash_attn = True  # FlashAttention-2 supports headdim 256
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = True
    # Only prenorm supports residual_in_fp32
    config.residual_in_fp32 = True

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    eos_token_id = tokenizer.eos_token_id

    torch.manual_seed(0)
    batch_size = 1
    seqlen = 100
    max_length = 150
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
    )

    model_hf = GPTJForCausalLM.from_pretrained(
        model_name, torch_dtype=dtype, device_map={"": device}
    )
    model_hf.eval()
    print("HF fp16")
    torch.cuda.synchronize()
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    out_hf = model_hf.generate(
        input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.perf_counter() - start) * 1000:.0f}ms")
    del model_hf

    model_ref = GPTJForCausalLM.from_pretrained(model_name, device_map={"": device})
    model_ref.eval()
    with torch.no_grad():
        # Logits at positions seqlen-1 .. L-2 are the ones that predict the generated tokens.
        logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]
    del model_ref

    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model.eval()

    print("Without CUDA graph")
    torch.cuda.synchronize()
    start = time.perf_counter()
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        eos_token_id=eos_token_id,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
        teacher_outputs=out_hf.sequences,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.perf_counter() - start) * 1000:.0f}ms")

    # Capture graph outside the timing loop
    batch_size, seqlen_og = input_ids.shape
    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
    print("With CUDA graph")
    torch.cuda.synchronize()
    start = time.perf_counter()
    out_cg = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        cg=True,
        # Use the same stopping criterion as the non-graph run; the two outputs are compared
        # with torch.equal below.
        eos_token_id=eos_token_id,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
        teacher_outputs=out_hf.sequences,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.perf_counter() - start) * 1000:.0f}ms")

    with torch.no_grad():
        logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1]
    logits_hf = torch.stack(out_hf.scores, dim=1)
    logits = torch.stack(out.scores, dim=1)
    logits_cg = torch.stack(out_cg.scores, dim=1)

    del model

    hf_error = (logits_hf - logits_ref).abs().max().item()
    assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error

    logits_error = (logits - logits_ref).abs().max().item()
    print(f"HF fp16 logits max diff: {hf_error}")
    print(f"Logits max diff: {logits_error}")
    assert logits_error < 2 * hf_error
    print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}")
    assert torch.equal(logits_cg, logits)