"test/srt/quant/test_w4a8_deepseek_v3.py" did not exist on "daed453e84ba4f30681f8a458e522ad1249d10af"
test_gpt_parallel.py 8.96 KB
Newer Older
1
2
3
# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/models/test_gpt_parallel.py

import math

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from apex.transformer import parallel_state
from einops import rearrange
from flash_attn.losses.cross_entropy import CrossEntropyLoss
from flash_attn.models.gpt import GPTLMHeadModel, shard_state_dict_tp
from flash_attn.utils.distributed import allreduce_sequence_parallel_grad
from transformers import GPT2Config

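# bfloat16 is only supported on SM 8.x (Ampere) or newer, so only add it to the dtype sweep there.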
is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8


@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize("has_pos_emb", [True, False])
# @pytest.mark.parametrize('has_pos_emb', [True])
@pytest.mark.parametrize("dim", [1024])
def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
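    """Compare a tensor-parallel GPTLMHeadModel against a single-device reference:
    forward logits, cross-entropy loss, and the sharded gradients should all agree
    within tolerance on every rank."""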
    head_dim = 64
    assert dim % head_dim == 0
    num_heads = dim // head_dim
    assert num_heads % world_size == 0
    vocab_size = 50264
    assert vocab_size % world_size == 0
    num_layers = 2
    rtol, atol = (3e-3, 1e-1) if dtype == torch.bfloat16 else (3e-3, 1e-2)
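    # torchrun provides MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE for the env:// init method.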
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
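    # Set up the Apex tensor-model-parallel group that the sharded model and loss will use.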
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 1024
    assert (batch_size * seqlen) % world_size == 0
    input_ids = torch.randint(0, vocab_size, (batch_size, seqlen + 1), device=device)

    # We need to generate g here so that all processes get the same gradient,
    # as rank 0 will have an extra bias that changes the RNG.
    g = torch.randn(batch_size * seqlen, device=device)

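    # Reference config: dropout disabled and fused kernels enabled so both models compute
    # the same function; rotary embeddings are used when there is no learned position embedding.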
    config = GPT2Config(
        n_embd=dim,
        n_head=num_heads,
        n_layer=num_layers,
        n_positions=seqlen if has_pos_emb else 0,
        vocab_size=50257,
        resid_pdrop=0.0,
        embd_pdrop=0.0,
        attn_pdrop=0.0,
        scale_attn_by_inverse_layer_idx=True,
        use_flash_attn=True,
        fused_mlp=True,
        fused_bias_fc=True,
        fused_dropout_add_ln=True,
        residual_in_fp32=True,
        rotary_emb_fraction=0.0 if has_pos_emb else 0.5,
        pad_vocab_size_multiple=8 * world_size,
        sequence_parallel=sequence_parallel,
    )
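    # Mirror pad_vocab_size_multiple: round vocab_size up so the embedding and lm_head shard evenly.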
    config.vocab_size = math.ceil(config.vocab_size / (8 * world_size)) * (8 * world_size)
    model_pt = GPTLMHeadModel(config, device=device)

    def init_layer_norm(module):
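        # LayerNorm defaults to weight=1, bias=0; randomize so the gradient comparison is non-trivial.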
        if isinstance(module, nn.LayerNorm):
            nn.init.normal_(module.weight)
            nn.init.normal_(module.bias)

    model_pt.apply(init_layer_norm)

    model = GPTLMHeadModel(config, process_group=process_group, device=device)
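    # Sanity check: summing the per-rank (non-shared) parameter counts plus the shared
    # (replicated) parameters should recover the reference model's total parameter count.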
    total_nparams = sum(p.numel() for p in model_pt.parameters())
    sharded_nparams = sum(p.numel() for p in model.parameters())
    sharded_nparams_all = torch.empty(world_size, dtype=torch.long, device=device)
    torch.distributed.all_gather_into_tensor(
        sharded_nparams_all, torch.tensor([sharded_nparams], device=device), group=process_group
    )
    shared_nparams = sum(
        p.numel() for p in model.parameters() if getattr(p, "_shared_params", False)
    )
    shared_nparams_all = torch.empty(world_size, dtype=torch.long, device=device)
    torch.distributed.all_gather_into_tensor(
        shared_nparams_all, torch.tensor([shared_nparams], device=device), group=process_group
    )
    assert torch.all(shared_nparams_all == shared_nparams)
    assert total_nparams == (
        (sharded_nparams_all - shared_nparams_all).sum().item() + shared_nparams
    )

    # vocab_size has been rounded up here
    partition_vocab_size = config.vocab_size // world_size
    partition_dim = dim // world_size
    partition_hidden_dim = 4 * dim // world_size
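    # Load the reference weights sharded for this rank, then re-tie the embedding / lm_head weights.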
    with torch.no_grad():
        model.load_state_dict(shard_state_dict_tp(model_pt.state_dict(), config, world_size, rank))
        model.tie_weights()

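    # Forward in mixed precision; the tensor-parallel model returns logits sharded along the vocab dimension.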
    with torch.autocast(device_type="cuda", dtype=dtype):
        out = model(input_ids[:, :-1]).logits
        if not sequence_parallel:
            out = rearrange(out, "b s d -> (b s) d")
        out_pt = rearrange(model_pt(input_ids[:, :-1]).logits, "b s d -> (b s) d")
    partition_batch_dim = batch_size * seqlen // world_size
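    # Each rank holds the full (b s) rows but only its slice of the vocabulary logits.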
    assert torch.allclose(
        out,
        out_pt[:, rank * partition_vocab_size : (rank + 1) * partition_vocab_size],
        rtol=rtol,
        atol=atol,
    )
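    # The process_group-aware CrossEntropyLoss consumes the vocab-sharded logits directly.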
    loss_fn = CrossEntropyLoss(inplace_backward=True, reduction="none", process_group=process_group)
    loss_fn_pt = CrossEntropyLoss(inplace_backward=True, reduction="none")
    loss = loss_fn(out, input_ids[:, 1:].flatten())
    loss_pt = loss_fn_pt(out_pt, input_ids[:, 1:].flatten())
    assert torch.allclose(loss, loss_pt, rtol=rtol, atol=atol)

    loss_pt.backward(g)
    loss.backward(g)
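    # All-reduce gradients of the sequence-parallel (replicated) parameters, e.g. LayerNorms, across the TP group.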
    allreduce_sequence_parallel_grad(model, process_group)
    parallel_state.destroy_model_parallel()

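    # Shard the reference model's gradients the same way the weights are sharded so they can be compared per rank.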
    grad_dict = shard_state_dict_tp(
        {k: v.grad for k, v in model_pt.named_parameters()}, config, world_size, rank
    )

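    # Gradient checks: embeddings and final LayerNorm first, then each transformer layer below.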
    assert torch.allclose(
        model.transformer.embeddings.word_embeddings.weight.grad,
        grad_dict["transformer.embeddings.word_embeddings.weight"],
        rtol=rtol,
        atol=atol * 5,
    )
    if has_pos_emb:
        assert torch.allclose(
            model.transformer.embeddings.position_embeddings.weight.grad,
            grad_dict["transformer.embeddings.position_embeddings.weight"],
            rtol=rtol,
            atol=atol,
        )
    assert torch.allclose(
        model.transformer.ln_f.weight.grad,
        grad_dict["transformer.ln_f.weight"],
        rtol=rtol,
        atol=atol,
    )
    assert torch.allclose(
        model.transformer.ln_f.bias.grad, grad_dict["transformer.ln_f.bias"], rtol=rtol, atol=atol
    )
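    # Per-layer checks: the column-parallel Wqkv / fc1 and row-parallel out_proj / fc2 weights
    # are compared on every rank; the out_proj and fc2 biases are only compared on rank 0.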
    for i in range(num_layers):
        assert torch.allclose(
            model.transformer.layers[i].mixer.Wqkv.weight.grad,
            grad_dict[f"transformer.layers.{i}.mixer.Wqkv.weight"],
            rtol=rtol,
            atol=atol * 10,
        )
        assert torch.allclose(
            model.transformer.layers[i].mixer.Wqkv.bias.grad,
            grad_dict[f"transformer.layers.{i}.mixer.Wqkv.bias"],
            rtol=rtol,
            atol=atol * 10,
        )
        assert torch.allclose(
            model.transformer.layers[i].mixer.out_proj.weight.grad,
            grad_dict[f"transformer.layers.{i}.mixer.out_proj.weight"],
            rtol=rtol,
            atol=atol * 10,
        )
        if rank == 0:
            assert torch.allclose(
                model.transformer.layers[i].mixer.out_proj.bias.grad,
                grad_dict[f"transformer.layers.{i}.mixer.out_proj.bias"],
                rtol=rtol,
                atol=atol * 5,
            )
        assert torch.allclose(
            model.transformer.layers[i].mlp.fc1.weight.grad,
            grad_dict[f"transformer.layers.{i}.mlp.fc1.weight"],
            rtol=rtol,
            atol=atol * 10,
        )
        assert torch.allclose(
            model.transformer.layers[i].mlp.fc1.bias.grad,
            grad_dict[f"transformer.layers.{i}.mlp.fc1.bias"],
            rtol=rtol,
            atol=atol * 10,
        )
        assert torch.allclose(
            model.transformer.layers[i].mlp.fc2.weight.grad,
            grad_dict[f"transformer.layers.{i}.mlp.fc2.weight"],
            rtol=rtol,
            atol=atol * 10,
        )
        if rank == 0:
            assert torch.allclose(
                model.transformer.layers[i].mlp.fc2.bias.grad,
                grad_dict[f"transformer.layers.{i}.mlp.fc2.bias"],
                rtol=rtol,
                atol=atol * 5,
            )

        assert torch.allclose(
            model.transformer.layers[i].norm1.weight.grad,
            grad_dict[f"transformer.layers.{i}.norm1.weight"],
            rtol=rtol,
            atol=atol,
        )
        assert torch.allclose(
            model.transformer.layers[i].norm1.bias.grad,
            grad_dict[f"transformer.layers.{i}.norm1.bias"],
            rtol=rtol,
            atol=atol,
        )
        assert torch.allclose(
            model.transformer.layers[i].norm2.weight.grad,
            grad_dict[f"transformer.layers.{i}.norm2.weight"],
            rtol=rtol,
            atol=atol,
        )
        assert torch.allclose(
            model.transformer.layers[i].norm2.bias.grad,
            grad_dict[f"transformer.layers.{i}.norm2.bias"],
            rtol=rtol,
            atol=atol,
        )