test_sequence.py 4.68 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import colossalai
import colossalai.nn as col_nn
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import pytest

from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from functools import partial


# Launch config: the tensor-parallel entry is set to 'sequence' mode over 4 ranks.
CONFIG = {
    'parallel': {
        'tensor': {'size': 4, 'mode': 'sequence'},
    },
}


def check_ring_qk(rank, world_size):
    """Check ``RingQK`` against a single-device ``q @ k^T`` reference.

    Every rank creates full-length ``q``/``k`` tensors and overwrites them with
    rank 0's values via a broadcast over the SEQUENCE group. Each rank then
    takes its own contiguous chunk along the sequence dimension and runs
    ``RingQK`` on that chunk. It compares both the forward scores and the
    gradient w.r.t. ``q`` with the matching rows of the reference.

    Args:
        rank: index of this process, used to select the local sequence chunk.
            NOTE(review): this is the rank passed in by the caller; it only
            matches the SEQUENCE-parallel rank when the sequence group spans
            all ranks — confirm if the config changes.
        world_size: number of ranks the sequence is split across.

    Raises:
        AssertionError: if the sequence does not split evenly, or if the
            forward scores or the q-gradient disagree with the reference.
    """
    # params
    batch_size = 4
    num_heads = 4
    seq_length = 32
    attention_head_size = 32
    # Integer division below would otherwise silently drop trailing rows.
    assert seq_length % world_size == 0, \
        f'seq_length ({seq_length}) must be divisible by world_size ({world_size})'
    sub_seq_length = seq_length // world_size
    # Rows (dim 1) of the full sequence owned by this rank.
    local = slice(rank * sub_seq_length, (rank + 1) * sub_seq_length)

    # create master tensors; the broadcast makes them identical on every rank
    q = torch.rand(batch_size * num_heads, seq_length, attention_head_size).cuda()
    k = torch.rand(batch_size * num_heads, seq_length, attention_head_size).cuda()
    dist.broadcast(q, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))
    dist.broadcast(k, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))

    # create distributed tensors: a cloned, contiguous copy of this rank's chunk
    sub_q = q.clone()[:, local].contiguous()
    sub_k = k.clone()[:, local].contiguous()

    # Set autograd attributes. All four tensors are leaves, so their .grad is
    # populated by backward without retain_grad().
    q.requires_grad = True
    k.requires_grad = True
    sub_q.requires_grad = True
    sub_k.requires_grad = True

    # compute master attention scores: (B*H, S, D) @ (B*H, D, S) -> (B*H, S, S)
    a = torch.matmul(q, k.transpose(2, 1))

    # compute distributed attention scores for the local query rows
    ring_qk = colossalai.nn.layer.parallel_sequence.RingQK.apply
    sub_a = ring_qk(sub_q, sub_k, batch_size, num_heads, sub_seq_length)

    # check master and distributed attention scores
    sub_master_a = a[:, local]
    assert torch.allclose(sub_a, sub_master_a, rtol=1e-5, atol=1e-2), \
        'attention score cannot match'

    # run master backward; `a` is a non-leaf, so its grad must be retained explicitly
    a.retain_grad()
    a.mean().backward()

    # run distributed backward, feeding the matching slice of the master upstream grad
    partial_master_a_grad = a.grad[:, local]
    torch.autograd.backward(sub_a, partial_master_a_grad)

    # check master and distributed grads w.r.t. q
    partial_master_q_grad = q.grad[:, local]
    assert torch.allclose(sub_q.grad, partial_master_q_grad, rtol=1e-5, atol=1e-2), \
        'query gradient cannot match'


def check_ring_av(rank, world_size):
    """Check ``RingAV`` against a single-device ``a @ v`` reference.

    Every rank creates full-length ``a`` (attention scores) and ``v`` tensors
    and overwrites them with rank 0's values via a broadcast over the SEQUENCE
    group. Each rank then takes its own contiguous chunk of rows along dim 1
    and runs ``RingAV`` on that chunk. It compares both the forward output and
    the gradient w.r.t. ``a`` with the matching rows of the reference.

    Args:
        rank: index of this process, used to select the local sequence chunk.
            NOTE(review): this is the rank passed in by the caller; it only
            matches the SEQUENCE-parallel rank when the sequence group spans
            all ranks — confirm if the config changes.
        world_size: number of ranks the sequence is split across.

    Raises:
        AssertionError: if the sequence does not split evenly, or if the
            forward output or the gradient w.r.t. ``a`` disagree with the
            reference.
    """
    # params
    batch_size = 4
    num_heads = 4
    seq_length = 16
    attention_head_size = 32
    # Integer division below would otherwise silently drop trailing rows.
    assert seq_length % world_size == 0, \
        f'seq_length ({seq_length}) must be divisible by world_size ({world_size})'
    sub_seq_length = seq_length // world_size
    # Rows (dim 1) of the full sequence owned by this rank.
    local = slice(rank * sub_seq_length, (rank + 1) * sub_seq_length)

    # create master tensors; the broadcast makes them identical on every rank
    a = torch.rand(batch_size * num_heads, seq_length, seq_length).cuda()
    v = torch.rand(batch_size * num_heads, seq_length, attention_head_size).cuda()
    dist.broadcast(a, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))
    dist.broadcast(v, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))

    # create distributed tensors: a cloned, contiguous copy of this rank's rows
    sub_a = a.clone()[:, local].contiguous()
    sub_v = v.clone()[:, local].contiguous()

    # Set autograd attributes. All four tensors are leaves, so their .grad is
    # populated by backward without retain_grad().
    a.requires_grad = True
    v.requires_grad = True
    sub_a.requires_grad = True
    sub_v.requires_grad = True

    # compute master attention output: (B*H, S, S) @ (B*H, S, D) -> (B*H, S, D)
    out = torch.matmul(a, v)

    # compute distributed attention output for the local rows
    ring_av = colossalai.nn.layer.parallel_sequence.RingAV.apply
    sub_out = ring_av(sub_a, sub_v, batch_size, num_heads, attention_head_size, sub_seq_length)

    # check master and distributed output
    sub_master_out = out[:, local]
    assert torch.allclose(sub_out, sub_master_out, rtol=1e-5, atol=1e-2), \
        'attention output cannot match'

    # run master backward; `out` is a non-leaf, so its grad must be retained explicitly
    out.retain_grad()
    out.mean().backward()

    # run distributed backward, feeding the matching slice of the master upstream grad
    partial_master_out_grad = out.grad[:, local]
    torch.autograd.backward(sub_out, partial_master_out_grad)

    # check master and distributed grads w.r.t. the attention scores
    partial_master_a_grad = a.grad[:, local]
    assert torch.allclose(sub_a.grad, partial_master_a_grad, rtol=1e-5, atol=1e-2), \
        'attention score gradient cannot match'


# HC
def run_test(rank, world_size, port=29501):
    """Per-process worker: launch ColossalAI, run the ring checks, then tear down.

    Args:
        rank: rank of this process (``mp.spawn`` passes the process index as
            the first positional argument).
        world_size: total number of processes.
        port: TCP port on ``localhost`` used by ``colossalai.launch``. It
            defaults to the previously hard-coded 29501, so existing callers
            behave the same. Callers can pass a different port to avoid
            collisions with other distributed tests.
    """
    colossalai.launch(
        rank=rank,
        world_size=world_size,
        config=CONFIG,
        host='localhost',
        port=port,
    )

    # TODO(review): check_ring_qk is defined in this file but is not run here.
    # It was disabled (commented out) with no stated reason; confirm whether
    # RingQK passes and re-enable it.
    check_ring_av(rank, world_size)

    gpc.destroy()
    torch.cuda.empty_cache()


@pytest.mark.dist
def test_sequence():
    """Spawn 4 worker processes, each running ``run_test`` with its own rank."""
    nprocs = 4
    mp.spawn(partial(run_test, world_size=nprocs), nprocs=nprocs)


if __name__ == '__main__':
    # Allow running the distributed test directly as a script.
    test_sequence()