# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/ops/test_fused_dense_parallel.py

import math

import pytest
import torch
import torch.nn.functional as F
from apex.transformer import parallel_state, tensor_parallel
from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, FusedMLP, ParallelFusedMLP

is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8


@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize("has_bias", [True, False])
# @pytest.mark.parametrize('has_bias', [False])
@pytest.mark.parametrize("out_features", [1024])
@pytest.mark.parametrize("in_features", [4096])
def test_fused_linear_bias(
    in_features, out_features, has_bias, sequence_parallel, world_size, dtype
):
    assert out_features % world_size == 0
    rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    seqlen = 512
    assert batch_size * seqlen % world_size == 0
    x_pt = torch.randn(
        batch_size * seqlen, in_features, device=device, dtype=dtype, requires_grad=True
    )
    if sequence_parallel:
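        # scatter_to_sequence_parallel_region splits the (batch * seqlen) dim
        # across ranks, so each rank owns a contiguous chunk of tokens.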
        x = (
            tensor_parallel.scatter_to_sequence_parallel_region(x_pt)
            .detach()
            .clone()
            .requires_grad_()
        )
    else:
        x = x_pt.detach().clone().requires_grad_()

    model_pt = torch.nn.Linear(in_features, out_features, bias=has_bias, device=device, dtype=dtype)
    partition_out_features = out_features // world_size
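    # ColumnParallelLinear shards the weight along out_features: each rank
    # holds a contiguous block of partition_out_features output rows.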
    model = ColumnParallelLinear(
        in_features,
        out_features,
        parallel_state.get_tensor_model_parallel_group(),
        bias=has_bias,
        sequence_parallel=sequence_parallel,
        device=device,
        dtype=dtype,
    )
    with torch.no_grad():
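        # nn.Linear stores weight as (out_features, in_features), so this
        # rank's shard is a contiguous block of rows.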
        model.weight.copy_(
            model_pt.weight[rank * partition_out_features : (rank + 1) * partition_out_features]
        )
        if has_bias:
            model.bias.copy_(
                model_pt.bias[rank * partition_out_features : (rank + 1) * partition_out_features]
            )

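    # Each rank computes only its slice of the output features, so compare
    # against the matching columns of the reference output.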
    out = model(x)
    out_pt = model_pt(x_pt)
    assert torch.allclose(
        out,
        out_pt[:, rank * partition_out_features : (rank + 1) * partition_out_features],
        rtol=rtol,
        atol=atol,
    )

    # Scale the gradient down, otherwise it gets a bit too large.
    g = torch.randn_like(out_pt) / 32
    out_pt.backward(g)
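    # Backprop only this rank's output-feature slice of the upstream gradient.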
    out.backward(g[:, rank * partition_out_features : (rank + 1) * partition_out_features])
    parallel_state.destroy_model_parallel()

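    # Under sequence parallelism the input gradient stays sharded along the
    # token dimension, so compare against this rank's chunk.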
    partition_batch_dim = batch_size * seqlen // world_size
    assert torch.allclose(
        x.grad,
        x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else x_pt.grad,
        rtol=rtol,
        atol=atol,
    )
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(
        model.weight.grad,
        model_pt.weight.grad[rank * partition_out_features : (rank + 1) * partition_out_features],
        rtol=rtol,
        atol=atol * 10,
    )
    if has_bias:
        assert torch.allclose(
            model.bias.grad,
            model_pt.bias.grad[rank * partition_out_features : (rank + 1) * partition_out_features],
            rtol=rtol,
            atol=atol * 5,
        )


@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize("has_bias2", [True, False])
# @pytest.mark.parametrize('has_bias2', [True])
@pytest.mark.parametrize("out_features", [4096])
@pytest.mark.parametrize("in_features", [1024])
def test_fused_mlp(in_features, out_features, has_bias2, sequence_parallel, world_size, dtype):
    assert out_features % world_size == 0
    rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    seqlen = 512
    assert batch_size * seqlen % world_size == 0
    x_pt = torch.randn(
        batch_size * seqlen, in_features, device=device, dtype=dtype, requires_grad=True
    )
    # Generate g before constructing the model so that all ranks draw the same
    # values: rank 0 initializes an extra bias parameter, which would otherwise
    # advance the RNG differently across ranks.
    # Scale the gradient down, otherwise it gets a bit too large.
    g = torch.randn_like(x_pt) / 32
    if sequence_parallel:
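        # As above, each rank takes its contiguous chunk of the tokens.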
        x = (
            tensor_parallel.scatter_to_sequence_parallel_region(x_pt)
            .detach()
            .clone()
            .requires_grad_()
        )
    else:
        x = x_pt.detach().clone().requires_grad_()

    model_pt_fc1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
    model_pt_fc2 = torch.nn.Linear(
        out_features, in_features, bias=has_bias2, device=device, dtype=dtype
    )
    partition_out_features = out_features // world_size
    partition_in_features = in_features // world_size
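    # fc1 is column-parallel and fc2 is row-parallel: fc2's partial outputs
    # are summed across ranks, so its bias must live on exactly one rank
    # (rank 0 here) to avoid being added world_size times.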
    model = ParallelFusedMLP(
        in_features,
        out_features,
        in_features,
        process_group=parallel_state.get_tensor_model_parallel_group(),
        bias2=has_bias2 and rank == 0,
        sequence_parallel=sequence_parallel,
        device=device,
        dtype=dtype,
    )

    with torch.no_grad():
        model.fc1.weight.copy_(
            model_pt_fc1.weight[rank * partition_out_features : (rank + 1) * partition_out_features]
        )
        model.fc1.bias.copy_(
            model_pt_fc1.bias[rank * partition_out_features : (rank + 1) * partition_out_features]
        )
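        # fc2's weight is sliced along its input dim (dim 1), matching the
        # column-parallel split of fc1's output.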
        model.fc2.weight.copy_(
            model_pt_fc2.weight[
                :, rank * partition_out_features : (rank + 1) * partition_out_features
            ]
        )
        if has_bias2 and rank == 0:
            model.fc2.bias.copy_(model_pt_fc2.bias)

    out = model(x)
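    # The fused MLP uses the tanh approximation of GELU, so the reference
    # forward must match it for the tolerances below to hold.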
    out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate="tanh"))
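    # With sequence_parallel the fc2 reduction becomes a reduce-scatter, so
    # each rank's output covers only its slice of the token dimension.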
    partition_batch_dim = batch_size * seqlen // world_size
    assert torch.allclose(
        out,
        out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else out_pt,
        rtol=rtol,
        atol=atol,
    )

    out_pt.backward(g)
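    # When sequence_parallel is on, each rank backpropagates only its token
    # slice of the upstream gradient.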
    out.backward(
        g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else g
    )
    parallel_state.destroy_model_parallel()

    assert torch.allclose(
        x.grad,
        x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else x_pt.grad,
        rtol=rtol,
        atol=atol,
    )
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(
        model.fc1.weight.grad,
        model_pt_fc1.weight.grad[
            rank * partition_out_features : (rank + 1) * partition_out_features
        ],
        rtol=rtol,
        atol=atol * 10,
    )
    assert torch.allclose(
        model.fc1.bias.grad,
        model_pt_fc1.bias.grad[rank * partition_out_features : (rank + 1) * partition_out_features],
        rtol=rtol,
        atol=atol * 5,
    )
    assert torch.allclose(
        model.fc2.weight.grad,
        model_pt_fc2.weight.grad[
            :, rank * partition_out_features : (rank + 1) * partition_out_features
        ],
        rtol=rtol,
        atol=atol * 10,
    )
    if has_bias2 and rank == 0:
        assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)