import re

import pytest
import torch
from transformers import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF

from flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_gpt2
from flash_attn.utils.pretrained import state_dict_from_pretrained


@pytest.mark.parametrize("model_name", ["gpt2", "gpt2-medium"])
def test_gpt2_state_dict(model_name):
    """Check that a remapped HF GPT-2 checkpoint lines up with our model's state dict.

    The HF checkpoint for ``model_name`` is loaded with ``state_dict_from_pretrained``
    and converted with ``remap_state_dict_hf_gpt2``. Its keys must be exactly the keys
    of a freshly constructed ``GPTLMHeadModel``, and every tensor must have the same
    shape as the corresponding parameter/buffer.
    """
    config = GPT2Config.from_pretrained(model_name)
    pretrained_state_dict = remap_state_dict_hf_gpt2(state_dict_from_pretrained(model_name), config)
    model = GPTLMHeadModel(config)
    state_dict = model.state_dict()
    assert state_dict.keys() == pretrained_state_dict.keys()
    for name, tensor in state_dict.items():
        # Report the offending key so a shape mismatch is easy to locate.
        assert tensor.shape == pretrained_state_dict[name].shape, name


@pytest.mark.parametrize("model_name", ["gpt2", "gpt2-medium"])
def test_gpt2_non_optimized(model_name):
    """Check our GPT-2 (no optimizations enabled) against the HF implementation.

    Three models are built from the same pretrained checkpoint:

    * ``model``: our ``GPTLMHeadModel`` cast to fp16,
    * ``model_hf``: the HF ``GPT2LMHeadModel`` cast to fp16,
    * ``model_ref``: the HF ``GPT2LMHeadModel`` left in its loaded dtype (fp32 reference).

    For both the transformer's last hidden states and the LM logits, the max absolute
    error of our fp16 model against the reference must be less than 3x the max
    absolute error of the HF fp16 model against the same reference. In other words,
    our fp16 numerics must be about as good as HF's fp16 numerics.
    """
    dtype = torch.float16
    config = GPT2Config.from_pretrained(model_name)

    model = GPTLMHeadModel.from_pretrained(model_name, config)
    model = model.cuda().to(dtype=dtype)

    model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).cuda()
    model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).cuda().to(dtype=dtype)

    model.eval()
    model_ref.eval()
    model_hf.eval()

    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
    )

    # Last hidden states of the transformer backbone.
    out = model.transformer(input_ids)
    out_hf = model_hf.transformer(input_ids).last_hidden_state
    out_ref = model_ref.transformer(input_ids).last_hidden_state

    out_err = (out - out_ref).abs()
    out_hf_err = (out_hf - out_ref).abs()
    print(f"Output max diff: {out_err.max().item()}")
    print(f"Output mean diff: {out_err.mean().item()}")
    print(f"HF fp16 max diff: {out_hf_err.max().item()}")
    print(f"HF fp16 mean diff: {out_hf_err.mean().item()}")
    assert out_err.max().item() < 3 * out_hf_err.max().item()

    # Full forward pass through the LM head.
    logits = model(input_ids).logits
    logits_hf = model_hf(input_ids).logits
    logits_ref = model_ref(input_ids).logits

    logits_err = (logits - logits_ref).abs()
    logits_hf_err = (logits_hf - logits_ref).abs()
    print(f"Logits max diff: {logits_err.max().item()}")
    print(f"Logits mean diff: {logits_err.mean().item()}")
    print(f"HF fp16 max diff: {logits_hf_err.max().item()}")
    print(f"HF fp16 mean diff: {logits_hf_err.mean().item()}")
    assert logits_err.max().item() < 3 * logits_hf_err.max().item()


@pytest.mark.parametrize("model_name", ["gpt2", "gpt2-medium"])
def test_gpt2_optimized(model_name):
    """Check our GPT-2 (with all optimizations enabled) against the HF implementation.

    Same comparison as the non-optimized test, with two differences:

    * the config enables the optimization flags set below;
    * the vocabulary is padded to a multiple of 8.

    Three models are built from the same pretrained checkpoint: ours in fp16, HF in
    fp16, and HF left in its loaded dtype as the reference. For both the transformer's
    last hidden states and the LM logits, the max absolute error of our model against
    the reference must be less than 3x the max absolute error of the HF fp16 model
    against the same reference.
    """
    dtype = torch.float16
    config = GPT2Config.from_pretrained(model_name)
    # Remember the real vocab size before padding; the HF models only know this many tokens.
    vocab_size_og = config.vocab_size
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True
    config.pad_vocab_size_multiple = 8

    model = GPTLMHeadModel.from_pretrained(model_name, config)
    model = model.cuda().to(dtype=dtype)

    model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).cuda()
    model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).cuda().to(dtype=dtype)

    model.eval()
    model_ref.eval()
    model_hf.eval()

    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
    # Draw ids only from the original vocab so every model sees valid tokens.
    input_ids = torch.randint(
        0, vocab_size_og, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
    )

    # Last hidden states of the transformer backbone.
    out = model.transformer(input_ids)
    out_hf = model_hf.transformer(input_ids).last_hidden_state
    out_ref = model_ref.transformer(input_ids).last_hidden_state

    out_err = (out - out_ref).abs()
    out_hf_err = (out_hf - out_ref).abs()
    print(f"Output max diff: {out_err.max().item()}")
    print(f"Output mean diff: {out_err.mean().item()}")
    print(f"HF fp16 max diff: {out_hf_err.max().item()}")
    print(f"HF fp16 mean diff: {out_hf_err.mean().item()}")
    assert out_err.max().item() < 3 * out_hf_err.max().item()

    # Drop the padded vocab columns so our logits line up with the HF logits.
    logits = model(input_ids).logits[..., :vocab_size_og]
    logits_hf = model_hf(input_ids).logits
    logits_ref = model_ref(input_ids).logits

    logits_err = (logits - logits_ref).abs()
    logits_hf_err = (logits_hf - logits_ref).abs()
    print(f"Logits max diff: {logits_err.max().item()}")
    print(f"Logits mean diff: {logits_err.mean().item()}")
    print(f"HF fp16 max diff: {logits_hf_err.max().item()}")
    print(f"HF fp16 mean diff: {logits_hf_err.mean().item()}")
    assert logits_err.max().item() < 3 * logits_hf_err.max().item()