# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/modules/test_block_parallel.py

import math
from functools import partial

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from apex.transformer import parallel_state, tensor_parallel
from einops import rearrange
from flash_attn.modules.block import Block
from flash_attn.modules.mha import MHA, ParallelMHA
from flash_attn.modules.mlp import FusedMLP, ParallelFusedMLP
from flash_attn.utils.distributed import allreduce_sequence_parallel_grad

# Gates the bfloat16 parametrization of test_block_parallel: True when the current CUDA
# device reports a major compute capability >= 8. torch.cuda.is_available() is checked
# first so that importing / collecting this module does not raise on a host without a
# usable CUDA device (get_device_capability would fail there).
is_sm8x = torch.cuda.is_available() and torch.cuda.get_device_capability("cuda")[0] >= 8


def _shard(t, rank, world_size, dim=0):
    """Return the ``rank``-th of ``world_size`` equal contiguous slices of ``t`` along ``dim``.

    Callers are expected to guarantee that ``t.shape[dim]`` is divisible by
    ``world_size``. The result is a view of ``t`` (``Tensor.narrow``).
    """
    size = t.shape[dim] // world_size
    return t.narrow(dim, rank * size, size)


def _shard_qkv(t, rank, world_size):
    """Slice a packed QKV weight ``(3 * o, i)`` or bias ``(3 * o,)`` for ``rank``.

    Q, K and V are sliced independently along ``o`` and then re-packed. This is the
    layout this test copies into (and compares against) ``ParallelMHA``'s local ``Wqkv``.
    """
    trailing = " i" if t.dim() == 2 else ""
    packed = rearrange(t, f"(three o){trailing} -> three o{trailing}", three=3)
    local = _shard(packed, rank, world_size, dim=1)
    return rearrange(local, f"three o{trailing} -> (three o){trailing}")


@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
@pytest.mark.parametrize("sequence_parallel", [True, False])
@pytest.mark.parametrize("dim", [1024])
def test_block_parallel(dim, sequence_parallel, world_size, dtype):
    """Compare a tensor-parallel ``Block`` against a single-device reference ``Block``.

    Every rank seeds the RNG identically and builds the reference block
    (``MHA`` + ``FusedMLP``). It then copies its rank's slice of each reference
    parameter into the tensor-parallel block (``ParallelMHA`` + ``ParallelFusedMLP``).
    Forward outputs, input gradients and parameter gradients are checked against the
    matching slices of the reference. With ``sequence_parallel`` the activations are
    additionally split along the flattened ``batch * seqlen`` dimension.
    """
    head_dim = 64
    assert dim % head_dim == 0
    num_heads = dim // head_dim
    assert num_heads % world_size == 0
    rtol, atol = (3e-3, 5e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    # Tear the model-parallel state down even when an assertion fails, so a failing
    # case does not leave stale process groups behind for the next parametrized case.
    try:
        rank = parallel_state.get_tensor_model_parallel_rank()
        tp_group = parallel_state.get_tensor_model_parallel_group()

        def shard(t, dim=0):
            # This rank's slice of a full reference tensor along `dim`.
            return _shard(t, rank, world_size, dim=dim)

        def shard_qkv(t):
            # This rank's slice of a packed reference QKV weight/bias.
            return _shard_qkv(t, rank, world_size)

        def local(t):
            # Rows of a full (batch * seqlen, ...) tensor held by this rank under
            # sequence parallelism; the whole tensor otherwise.
            return shard(t) if sequence_parallel else t

        torch.random.manual_seed(0)
        batch_size = 2
        seqlen = 1024
        assert (batch_size * seqlen) % world_size == 0
        x_pt = torch.randn(
            batch_size * seqlen, dim, device=device, dtype=dtype, requires_grad=True
        )
        # Note: the residual is created in the default dtype, not `dtype`.
        residual_pt = torch.randn(batch_size * seqlen, dim, device=device, requires_grad=True)
        # Generate g here, before any model is built, so that all processes get the same
        # gradient: rank 0 will have an extra bias that changes the RNG afterwards.
        # Scale it down by 32, otherwise the gradient gets a bit too large.
        g = torch.randn_like(x_pt) / 32
        if sequence_parallel:
            x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt)
            residual = tensor_parallel.scatter_to_sequence_parallel_region(residual_pt)
        else:
            x, residual = x_pt, residual_pt
        # Fresh leaf tensors: the parallel model's input grads land in x.grad / residual.grad.
        x = x.detach().clone().requires_grad_()
        residual = residual.detach().clone().requires_grad_()

        norm_cls = partial(nn.LayerNorm, device=device, dtype=dtype)

        # Single-device reference block.
        mixer_cls_pt = partial(
            MHA,
            num_heads=num_heads,
            rotary_emb_dim=head_dim // 2,
            use_flash_attn=True,
            device=device,
            dtype=dtype,
        )
        mlp_cls_pt = partial(FusedMLP, hidden_features=4 * dim, device=device, dtype=dtype)
        model_pt = Block(dim, mixer_cls_pt, mlp_cls_pt, norm_cls, fused_dropout_add_ln=True)
        with torch.no_grad():
            # Randomize the LayerNorm affine parameters (norm1 then norm2, weight then bias).
            for norm_pt in (model_pt.norm1, model_pt.norm2):
                nn.init.normal_(norm_pt.weight)
                nn.init.normal_(norm_pt.bias)

        # Tensor-parallel block under test.
        mixer_cls = partial(
            ParallelMHA,
            num_heads=num_heads,
            process_group=tp_group,
            rotary_emb_dim=head_dim // 2,
            use_flash_attn=True,
            sequence_parallel=sequence_parallel,
            device=device,
            dtype=dtype,
        )
        mlp_cls = partial(
            ParallelFusedMLP,
            hidden_features=4 * dim,
            process_group=tp_group,
            sequence_parallel=sequence_parallel,
            device=device,
            dtype=dtype,
        )
        model = Block(
            dim,
            mixer_cls,
            mlp_cls,
            norm_cls,
            fused_dropout_add_ln=True,
            sequence_parallel=sequence_parallel,
            mark_shared_params=True,
        )

        # Load this rank's slice of every reference parameter into the parallel block.
        # The out_proj / fc2 biases are only copied (and later checked) on rank 0.
        with torch.no_grad():
            model.mixer.Wqkv.weight.copy_(shard_qkv(model_pt.mixer.Wqkv.weight))
            model.mixer.Wqkv.bias.copy_(shard_qkv(model_pt.mixer.Wqkv.bias))
            model.mixer.out_proj.weight.copy_(shard(model_pt.mixer.out_proj.weight, dim=1))
            if rank == 0:
                model.mixer.out_proj.bias.copy_(model_pt.mixer.out_proj.bias)
            model.mlp.fc1.weight.copy_(shard(model_pt.mlp.fc1.weight))
            model.mlp.fc1.bias.copy_(shard(model_pt.mlp.fc1.bias))
            model.mlp.fc2.weight.copy_(shard(model_pt.mlp.fc2.weight, dim=1))
            if rank == 0:
                model.mlp.fc2.bias.copy_(model_pt.mlp.fc2.bias)
            for norm, norm_pt in ((model.norm1, model_pt.norm1), (model.norm2, model_pt.norm2)):
                norm.weight.copy_(norm_pt.weight)
                norm.bias.copy_(norm_pt.bias)

        # Forward: parallel block on (b s) d rows, reference on b s d.
        out, out_residual = model(x, residual, mixer_kwargs={"seqlen": seqlen})
        out_pt, out_residual_pt = model_pt(
            rearrange(x_pt, "(b s) d -> b s d", s=seqlen),
            rearrange(residual_pt, "(b s) d -> b s d", s=seqlen),
        )
        out_pt, out_residual_pt = [
            rearrange(t, "b s d -> (b s) d") for t in (out_pt, out_residual_pt)
        ]
        assert torch.allclose(out, local(out_pt), rtol=rtol, atol=atol)
        assert torch.allclose(out_residual, local(out_residual_pt), rtol=rtol, atol=atol)

        # Backward with the same upstream gradient (sliced under sequence parallelism).
        (out_pt + 2 * out_residual_pt).backward(g)
        (out + 2 * out_residual).backward(local(g))
        allreduce_sequence_parallel_grad(model, tp_group)

        # Tighter atol: the magnitude of x.grad is quite small.
        assert torch.allclose(x.grad, local(x_pt.grad), rtol=rtol, atol=atol / 10)
        assert torch.allclose(residual.grad, local(residual_pt.grad), rtol=rtol, atol=atol)
        # The error for d_weight and d_bias is quite a bit higher.
        assert torch.allclose(
            model.mixer.Wqkv.weight.grad,
            shard_qkv(model_pt.mixer.Wqkv.weight.grad),
            rtol=rtol,
            atol=atol * 10,
        )
        assert torch.allclose(
            model.mixer.Wqkv.bias.grad,
            shard_qkv(model_pt.mixer.Wqkv.bias.grad),
            rtol=rtol,
            atol=atol * 5,
        )
        assert torch.allclose(
            model.mixer.out_proj.weight.grad,
            shard(model_pt.mixer.out_proj.weight.grad, dim=1),
            rtol=rtol,
            atol=atol * 10,
        )
        if rank == 0:
            assert torch.allclose(
                model.mixer.out_proj.bias.grad,
                model_pt.mixer.out_proj.bias.grad,
                rtol=rtol,
                atol=atol * 5,
            )
        assert torch.allclose(
            model.mlp.fc1.weight.grad,
            shard(model_pt.mlp.fc1.weight.grad),
            rtol=rtol,
            atol=atol * 10,
        )
        assert torch.allclose(
            model.mlp.fc1.bias.grad,
            shard(model_pt.mlp.fc1.bias.grad),
            rtol=rtol,
            atol=atol * 5,
        )
        assert torch.allclose(
            model.mlp.fc2.weight.grad,
            shard(model_pt.mlp.fc2.weight.grad, dim=1),
            rtol=rtol,
            atol=atol * 10,
        )
        if rank == 0:
            assert torch.allclose(
                model.mlp.fc2.bias.grad, model_pt.mlp.fc2.bias.grad, rtol=rtol, atol=atol * 5
            )
        for norm, norm_pt in ((model.norm1, model_pt.norm1), (model.norm2, model_pt.norm2)):
            assert torch.allclose(norm.weight.grad, norm_pt.weight.grad, rtol=rtol, atol=atol * 5)
            assert torch.allclose(norm.bias.grad, norm_pt.bias.grad, rtol=rtol, atol=atol * 5)
    finally:
        parallel_state.destroy_model_parallel()