# tests/modules/test_mha_parallel.py (committed by Tri Dao)
# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/modules/test_mha_parallel.py

import math

Tri Dao's avatar
Tri Dao committed
6
import pytest
Tri Dao's avatar
Tri Dao committed
7
8
import torch
import torch.nn.functional as F
Tri Dao's avatar
Tri Dao committed
9
from apex.transformer import parallel_state, tensor_parallel
Tri Dao's avatar
Tri Dao committed
10
11
12
from einops import rearrange
from flash_attn.modules.mha import MHA, ParallelMHA

# True when the current CUDA device has compute capability major version >= 8; gates the
# bfloat16 parametrization of test_mha_parallel. The is_available() guard keeps importing /
# collecting this module from raising on a machine without a usable CUDA device.
is_sm8x = torch.cuda.is_available() and torch.cuda.get_device_capability("cuda")[0] >= 8


def _shard_qkv(tensor, rank, partition_dim):
    """Return ``rank``'s shard of a fused QKV parameter (or of its gradient).

    ``tensor`` has Q, K and V stacked along dim 0 (pattern ``"(three o) ..."``), so it can be
    the 2-D ``Wqkv.weight`` or the 1-D ``Wqkv.bias``. Within each of the three chunks, the
    entries ``[rank * partition_dim, (rank + 1) * partition_dim)`` along dim 0 are kept, and
    the three slices are re-fused along dim 0.
    """
    qkv = rearrange(tensor, "(three o) ... -> three o ...", three=3)
    shard = qkv[:, rank * partition_dim : (rank + 1) * partition_dim]
    return rearrange(shard, "three o ... -> (three o) ...")


@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
@pytest.mark.parametrize("sequence_parallel", [True, False])
@pytest.mark.parametrize("head_dim", [64, 128])
@pytest.mark.parametrize("embed_dim", [1024, 4096])
def test_mha_parallel(embed_dim, head_dim, sequence_parallel, world_size, dtype):
    """Check tensor-parallel ``ParallelMHA`` against a single-device ``MHA``.

    Every rank builds a full ``MHA`` reference and a ``ParallelMHA`` over a tensor-parallel
    group of ``world_size`` ranks, and copies this rank's shard of the reference weights into
    the parallel model. It then compares the forward output and the gradients of the input,
    ``Wqkv`` weight/bias, ``out_proj`` weight and (rank 0 only) ``out_proj`` bias. With
    ``sequence_parallel`` the output and input gradient are compared on this rank's
    contiguous slice of the flattened ``(batch * seqlen)`` dimension.

    Skipped when the job was launched with fewer than ``world_size`` processes.
    """
    assert embed_dim % head_dim == 0
    num_heads = embed_dim // head_dim
    assert num_heads % world_size == 0
    rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    launched = torch.distributed.get_world_size()
    if world_size > launched:
        # Every process sees the same launched world size, so all ranks skip together.
        pytest.skip(f"needs {world_size} processes, job was launched with {launched}")

    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    try:
        rank = parallel_state.get_tensor_model_parallel_rank()
        # set seed
        torch.random.manual_seed(0)
        batch_size = 2
        seqlen = 1024
        assert (batch_size * seqlen) % world_size == 0
        # Width of this rank's slice of each of Q/K/V and of out_proj's input features.
        partition_dim = embed_dim // world_size
        # Flattened (batch * seqlen) rows owned by each rank under sequence parallelism.
        partition_batch_dim = batch_size * seqlen // world_size

        def _local(t):
            """Slice ``t`` to this rank's rows under sequence parallelism, else return it."""
            if not sequence_parallel:
                return t
            return t[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]

        x_pt = torch.randn(
            batch_size * seqlen, embed_dim, device=device, dtype=dtype, requires_grad=True
        )
        # Generate g here, before the models are built, so that all processes get the same
        # gradient: rank 0 will have an extra bias that changes the RNG. Without the 1/32
        # scale the gradients get a bit too large for the tolerances used below.
        g = torch.randn_like(x_pt) / 32
        if sequence_parallel:
            x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt)
        else:
            x = x_pt
        x = x.detach().clone().requires_grad_()

        rotary_emb_dim = head_dim // 2
        model_pt = MHA(
            embed_dim,
            num_heads,
            rotary_emb_dim=rotary_emb_dim,
            use_flash_attn=True,
            device=device,
            dtype=dtype,
        )
        model = ParallelMHA(
            embed_dim,
            num_heads,
            parallel_state.get_tensor_model_parallel_group(),
            rotary_emb_dim=rotary_emb_dim,
            use_flash_attn=True,
            sequence_parallel=sequence_parallel,
            device=device,
            dtype=dtype,
        )

        # Give the parallel model exactly this rank's shard of the reference weights;
        # only rank 0 receives the out_proj bias.
        with torch.no_grad():
            model.Wqkv.weight.copy_(_shard_qkv(model_pt.Wqkv.weight, rank, partition_dim))
            model.Wqkv.bias.copy_(_shard_qkv(model_pt.Wqkv.bias, rank, partition_dim))
            model.out_proj.weight.copy_(
                model_pt.out_proj.weight[:, rank * partition_dim : (rank + 1) * partition_dim]
            )
            if rank == 0:
                model.out_proj.bias.copy_(model_pt.out_proj.bias)

        out = model(x, seqlen=seqlen)
        out_pt = rearrange(
            model_pt(rearrange(x_pt, "(b s) d -> b s d", s=seqlen)), "b s d -> (b s) d"
        )
        assert torch.allclose(out, _local(out_pt), rtol=rtol, atol=atol)

        out_pt.backward(g)
        out.backward(_local(g))
    finally:
        # Tear down even when an assertion above fails; otherwise the model-parallel state
        # stays initialized for the next parametrized case (presumably making its
        # initialize_model_parallel call fail as well -- verify against apex).
        parallel_state.destroy_model_parallel()

    assert torch.allclose(
        x.grad,
        _local(x_pt.grad),
        rtol=rtol,
        atol=atol / 100,  # magnitude of x.grad is quite small
    )
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(
        model.Wqkv.weight.grad,
        _shard_qkv(model_pt.Wqkv.weight.grad, rank, partition_dim),
        rtol=rtol,
        atol=atol * 10,
    )
    assert torch.allclose(
        model.Wqkv.bias.grad,
        _shard_qkv(model_pt.Wqkv.bias.grad, rank, partition_dim),
        rtol=rtol,
        atol=atol * 5,
    )
    assert torch.allclose(
        model.out_proj.weight.grad,
        model_pt.out_proj.weight.grad[:, rank * partition_dim : (rank + 1) * partition_dim],
        rtol=rtol,
        atol=atol * 10,
    )
    if rank == 0:
        assert torch.allclose(
            model.out_proj.bias.grad, model_pt.out_proj.bias.grad, rtol=rtol, atol=atol * 5
        )