# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/models/test_gpt_parallel.py

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytest

from einops import rearrange

from transformers import GPT2Config

from apex.transformer import parallel_state

from flash_attn.models.gpt import GPTLMHeadModel, shard_state_dict_tp
from flash_attn.losses.cross_entropy import CrossEntropyLoss
from flash_attn.utils.distributed import allreduce_sequence_parallel_grad

# bfloat16 is only tested on GPUs with compute capability >= 8.0 (Ampere or newer)
is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8


@pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('sequence_parallel', [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize('has_pos_emb', [True, False])
# @pytest.mark.parametrize('has_pos_emb', [True])
@pytest.mark.parametrize('dim', [1024])
def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
    head_dim = 64
    assert dim % head_dim == 0
    num_heads = dim // head_dim
    assert num_heads % world_size == 0
    vocab_size = 50264
    assert vocab_size % world_size == 0
    num_layers = 2
    rtol, atol = (3e-3, 1e-1) if dtype == torch.bfloat16 else (3e-3, 1e-2)
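    # Initialize the default distributed process group (if not already done) and an
    # Apex tensor-model-parallel group of size world_size.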
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    device = f'cuda:{torch.distributed.get_rank()}'
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 1024
    assert (batch_size * seqlen) % world_size == 0
    input_ids = torch.randint(0, vocab_size, (batch_size, seqlen + 1), device=device)

    # We need to generate g here so that all processes get the same gradient,
    # as rank 0 will have an extra bias that changes the RNG.
    g = torch.randn(batch_size * seqlen, device=device)

    config = GPT2Config(n_embd=dim, n_head=num_heads, n_layer=num_layers,
                        n_positions=seqlen if has_pos_emb else 0,
                        vocab_size=50257, resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0,
                        scale_attn_by_inverse_layer_idx=True, use_flash_attn=True,
                        fused_mlp=True, fused_bias_fc=True, fused_dropout_add_ln=True,
                        residual_in_fp32=True,
                        rotary_emb_fraction=0.0 if has_pos_emb else 0.5,
                        pad_vocab_size_multiple=8 * world_size,
                        sequence_parallel=sequence_parallel)
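    # Round the vocab size up to a multiple of 8 * world_size, mirroring
    # pad_vocab_size_multiple, so the embedding and lm_head shard evenly across ranks.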
    config.vocab_size = math.ceil(config.vocab_size / (8 * world_size)) * (8 * world_size)
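
    # Reference (non-tensor-parallel) model that the sharded model is checked against.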
    model_pt = GPTLMHeadModel(config, device=device)

    def init_layer_norm(module):
        if isinstance(module, nn.LayerNorm):
            nn.init.normal_(module.weight)
            nn.init.normal_(module.bias)
    model_pt.apply(init_layer_norm)

    model = GPTLMHeadModel(config, process_group=process_group, device=device)
    total_nparams = sum(p.numel() for p in model_pt.parameters())
    sharded_nparams = sum(p.numel() for p in model.parameters())
    sharded_nparams_all = torch.empty(world_size, dtype=torch.long, device=device)
    torch.distributed.all_gather_into_tensor(
        sharded_nparams_all, torch.tensor([sharded_nparams], device=device), group=process_group
    )
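    # Parameters flagged with `_shared_params` are replicated across tensor-parallel
    # ranks rather than sharded; count them separately.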
    shared_nparams = sum(p.numel() for p in model.parameters()
                                    if getattr(p, '_shared_params', False))
    shared_nparams_all = torch.empty(world_size, dtype=torch.long, device=device)
    torch.distributed.all_gather_into_tensor(
        shared_nparams_all, torch.tensor([shared_nparams], device=device), group=process_group
    )
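    # Every rank must report the same shared-parameter count, and the reference model's
    # parameter count must equal the sum of the per-rank shards plus one copy of the
    # shared parameters.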
    assert torch.all(shared_nparams_all == shared_nparams)
    assert total_nparams == ((sharded_nparams_all - shared_nparams_all).sum().item()
                             + shared_nparams)

    # vocab_size has been rounded up here
    partition_vocab_size = config.vocab_size // world_size
    partition_dim = dim // world_size
    partition_hidden_dim = 4 * dim // world_size
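    # Load the reference weights, sharded for this rank, into the tensor-parallel model.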
    with torch.no_grad():
        model.load_state_dict(shard_state_dict_tp(model_pt.state_dict(), config, world_size, rank))
        model.tie_weights()

    with torch.autocast(device_type='cuda', dtype=dtype):
        out = model(input_ids[:, :-1]).logits
        if not sequence_parallel:
            out = rearrange(out, 'b s d -> (b s) d')
        out_pt = rearrange(model_pt(input_ids[:, :-1]).logits, 'b s d -> (b s) d')
    partition_batch_dim = batch_size * seqlen // world_size
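    # Each rank produces a vocab-dimension shard of the logits; compare it against the
    # matching slice of the reference logits.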
    assert torch.allclose(
        out, out_pt[:, rank * partition_vocab_size:(rank + 1) * partition_vocab_size],
        rtol=rtol, atol=atol
    )
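    # The parallel CrossEntropyLoss consumes vocab-sharded logits, reducing across the
    # tensor-parallel group; per-token losses should match the non-parallel reference.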
    loss_fn = CrossEntropyLoss(inplace_backward=True, reduction='none', process_group=process_group)
    loss_fn_pt = CrossEntropyLoss(inplace_backward=True, reduction='none')
    loss = loss_fn(out, input_ids[:, 1:].flatten())
    loss_pt = loss_fn_pt(out_pt, input_ids[:, 1:].flatten())
    assert torch.allclose(loss, loss_pt, rtol=rtol, atol=atol)

    loss_pt.backward(g)
    loss.backward(g)
    # With sequence parallelism, gradients of the parameters that are replicated across
    # the tensor-parallel group (e.g. LayerNorm) must be all-reduced before comparison.
    allreduce_sequence_parallel_grad(model, process_group)
    parallel_state.destroy_model_parallel()

    # Shard the reference model's gradients the same way the weights were sharded so
    # they can be compared rank by rank.
    grad_dict = shard_state_dict_tp({k: v.grad for k, v in model_pt.named_parameters()},
                                    config, world_size, rank)

    assert torch.allclose(
        model.transformer.embeddings.word_embeddings.weight.grad,
        grad_dict['transformer.embeddings.word_embeddings.weight'],
        rtol=rtol, atol=atol * 5
    )
    if has_pos_emb:
        assert torch.allclose(
            model.transformer.embeddings.position_embeddings.weight.grad,
            grad_dict['transformer.embeddings.position_embeddings.weight'],
            rtol=rtol, atol=atol
        )
    assert torch.allclose(model.transformer.ln_f.weight.grad, grad_dict['transformer.ln_f.weight'],
                          rtol=rtol, atol=atol)
    assert torch.allclose(model.transformer.ln_f.bias.grad, grad_dict['transformer.ln_f.bias'],
                          rtol=rtol, atol=atol)
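    # Per-layer checks: sharded weight gradients are compared against the matching shard
    # of the reference gradients; out_proj / fc2 biases exist only on rank 0 (see the
    # note about g above), so they are checked on rank 0 only.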
    for i in range(num_layers):
        assert torch.allclose(
            model.transformer.layers[i].mixer.Wqkv.weight.grad,
            grad_dict[f'transformer.layers.{i}.mixer.Wqkv.weight'],
            rtol=rtol, atol=atol * 10
        )
        assert torch.allclose(
            model.transformer.layers[i].mixer.Wqkv.bias.grad,
            grad_dict[f'transformer.layers.{i}.mixer.Wqkv.bias'],
            rtol=rtol, atol=atol * 10
        )
        assert torch.allclose(
            model.transformer.layers[i].mixer.out_proj.weight.grad,
            grad_dict[f'transformer.layers.{i}.mixer.out_proj.weight'],
            rtol=rtol, atol=atol * 10
        )
        if rank == 0:
            assert torch.allclose(model.transformer.layers[i].mixer.out_proj.bias.grad,
                                  grad_dict[f'transformer.layers.{i}.mixer.out_proj.bias'],
                                  rtol=rtol, atol=atol * 5)
        assert torch.allclose(
            model.transformer.layers[i].mlp.fc1.weight.grad,
            grad_dict[f'transformer.layers.{i}.mlp.fc1.weight'],
            rtol=rtol, atol=atol * 10
        )
        assert torch.allclose(
            model.transformer.layers[i].mlp.fc1.bias.grad,
            grad_dict[f'transformer.layers.{i}.mlp.fc1.bias'],
            rtol=rtol, atol=atol * 10
        )
        assert torch.allclose(
            model.transformer.layers[i].mlp.fc2.weight.grad,
            grad_dict[f'transformer.layers.{i}.mlp.fc2.weight'],
            rtol=rtol, atol=atol * 10
        )
        if rank == 0:
            assert torch.allclose(model.transformer.layers[i].mlp.fc2.bias.grad,
                                  grad_dict[f'transformer.layers.{i}.mlp.fc2.bias'],
                                  rtol=rtol, atol=atol * 5)

        assert torch.allclose(model.transformer.layers[i].norm1.weight.grad,
                              grad_dict[f'transformer.layers.{i}.norm1.weight'],
                              rtol=rtol, atol=atol)
        assert torch.allclose(model.transformer.layers[i].norm1.bias.grad,
                              grad_dict[f'transformer.layers.{i}.norm1.bias'],
                              rtol=rtol, atol=atol)
        assert torch.allclose(model.transformer.layers[i].norm2.weight.grad,
                              grad_dict[f'transformer.layers.{i}.norm2.weight'],
                              rtol=rtol, atol=atol)
        assert torch.allclose(model.transformer.layers[i].norm2.bias.grad,
                              grad_dict[f'transformer.layers.{i}.norm2.bias'],
                              rtol=rtol, atol=atol)