import re

import torch
import pytest

from transformers import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF

from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gpt import remap_state_dict_hf_gpt2
from flash_attn.utils.pretrained import state_dict_from_pretrained


@pytest.mark.parametrize('model_name', ["gpt2", "gpt2-medium"])
# @pytest.mark.parametrize('model_name', ["gpt2"])
def test_gpt2_state_dict(model_name):
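    """Check that the remapped HF GPT-2 checkpoint has exactly the same parameter names and
    shapes as a freshly constructed flash_attn GPTLMHeadModel."""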
    config = GPT2Config.from_pretrained(model_name)
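    # Remap the HF GPT-2 checkpoint into the parameter naming / layout used by flash_attn's GPT model.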
    pretrained_state_dict = remap_state_dict_hf_gpt2(state_dict_from_pretrained(model_name), config)
    model = GPTLMHeadModel(config)
    state_dict = model.state_dict()
    assert state_dict.keys() == pretrained_state_dict.keys()
    for k in state_dict.keys():
        assert state_dict[k].shape == pretrained_state_dict[k].shape


@pytest.mark.parametrize('model_name', ["gpt2", "gpt2-medium"])
# @pytest.mark.parametrize('model_name', ["gpt2"])
def test_gpt2_non_optimized(model_name):
    """Check that our implementation of GPT2 (without any optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    config = GPT2Config.from_pretrained(model_name)

    model = GPTLMHeadModel.from_pretrained(model_name, config)
    model = model.cuda().to(dtype=dtype)

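    # model_ref stays in fp32 and serves as the reference; model_hf is the HF model in fp16,
    # used as the baseline for how much error fp16 alone introduces.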
    model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).cuda()
    model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).cuda().to(dtype=dtype)

    model.eval()
    model_ref.eval()
    model_hf.eval()

    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda')
    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
                              device='cuda')
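    # flash_attn's GPT backbone returns the hidden-states tensor directly; the HF models return
    # a ModelOutput, so grab last_hidden_state for comparison.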
    out = model.transformer(input_ids)
    out_hf = model_hf.transformer(input_ids).last_hidden_state
    out_ref = model_ref.transformer(input_ids).last_hidden_state

    print(f'Output max diff: {(out - out_ref).abs().max().item()}')
    print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
    print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
    print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
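    # Our fp16 output should be about as close to the fp32 reference as HF's own fp16 output is
    # (within a 3x margin).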
    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()

    logits = model(input_ids).logits
    logits_hf = model_hf(input_ids).logits
    logits_ref = model_ref(input_ids).logits

    print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
    print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
    print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
    print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
    assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item()


@pytest.mark.parametrize('model_name', ["gpt2", "gpt2-medium"])
# @pytest.mark.parametrize('model_name', ["gpt2"])
def test_gpt2_optimized(model_name):
    """Check that our implementation of GPT2 (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    config = GPT2Config.from_pretrained(model_name)
    vocab_size_og = config.vocab_size
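    # Enable the optimized paths: FlashAttention, fused bias + GEMM, fused MLP,
    # fused dropout + residual-add + LayerNorm, and fp32 residual accumulation.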
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True
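    # Pad the vocabulary (embedding and LM head) to a multiple of 8 for more efficient
    # tensor-core GEMMs; the extra logits are sliced off before comparing against HF below.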
    config.pad_vocab_size_multiple = 8

    model = GPTLMHeadModel.from_pretrained(model_name, config)
    model = model.cuda().to(dtype=dtype)

    model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).cuda()
    model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).cuda().to(dtype=dtype)

    model.eval()
    model_ref.eval()
    model_hf.eval()

    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda')
    input_ids = torch.randint(0, vocab_size_og, (batch_size, max_seqlen), dtype=torch.long,
                              device='cuda')
    out = model.transformer(input_ids)
    out_hf = model_hf.transformer(input_ids).last_hidden_state
    out_ref = model_ref.transformer(input_ids).last_hidden_state

    print(f'Output max diff: {(out - out_ref).abs().max().item()}')
    print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
    print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
    print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()

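    # The padded model produces extra logits for the padding tokens; keep only the original vocab.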
    logits = model(input_ids).logits[..., :vocab_size_og]
    logits_hf = model_hf(input_ids).logits
    logits_ref = model_ref(input_ids).logits

    print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
    print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
    print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
    print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
    assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item()