"git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "c2c4f57f6311ba143c6156ab1d1a1d9413e6e4d0"
test_rearange_all_gather.py

import pytest
import torch

from lmdeploy.pytorch.kernels.rearange_all_gather import rearange_all_gather


class TestRearangeAllGather:
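    """Compare rearange_all_gather against a pure-PyTorch reference implementation."""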

    @pytest.fixture
    def seq_lens(self, request):
        yield torch.tensor(request.param, device='cuda')

    @pytest.fixture
    def start_loc(self, seq_lens):
        yield seq_lens.cumsum(0) - seq_lens

    @pytest.fixture
    def ranks(self):
        yield torch.tensor([4, 8]).cuda()

    @pytest.fixture
    def adapter_ids(self, seq_lens, ranks):
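        # assign each sequence a random adapter, i.e. a random entry of `ranks`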
        num_ranks = len(ranks)
        num_seqs = len(seq_lens)
        ret = torch.randint(0, num_ranks, (num_seqs, )).cuda()
        yield ret

    @pytest.fixture
    def world_size(self):
        yield 2

    @pytest.fixture
    def input(self, seq_lens, ranks):
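        # flattened input tokens, feature dim padded to the largest adapter rank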
        max_rank = max(ranks)
        total_len = seq_lens.sum()
        yield torch.rand(total_len, max_rank).cuda()

    @pytest.fixture
    def rank_per_input(self, seq_lens, ranks, adapter_ids):
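        # expand the per-sequence adapter ids into a per-token rank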
        token_adapter_ids = [
            torch.full((slen, ), ada_id)
            for slen, ada_id in zip(seq_lens, adapter_ids)
        ]
        token_adapter_ids = torch.cat(token_adapter_ids).cuda()
        yield ranks[token_adapter_ids]

    @pytest.fixture
    def valid_mask(self, rank_per_input, seq_lens, ranks):
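        # mark the feature columns that hold real data for each token's rank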
        max_rank = max(ranks)
        total_len = seq_lens.sum()
        mask = torch.zeros(total_len, max_rank).to(bool)
        for r, m in zip(rank_per_input, mask):
            m[:r] = True
        yield mask.cuda()

    @pytest.fixture
    def gt(self, input, rank_per_input, ranks, world_size):
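        # reference output: for every token, take the first (rank // world_size)
        # columns of each of the world_size partitions (each max_rank // world_size
        # wide) and pack them into a contiguous prefix of the output row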
        max_rank = max(ranks)
        pranks = rank_per_input // world_size
        pmax_rank = max_rank // world_size
        output = torch.empty_like(input)
        for pr, inp, out in zip(pranks, input, output):
            pindex = torch.arange(pr).cuda()
            index = [pindex + ws * pmax_rank for ws in range(world_size)]
            index = torch.cat(index)
            out[:index.size(0)] = inp[index]
        yield output

    @pytest.mark.parametrize('seq_lens', [[30, 50, 70, 90], [1, 1, 1, 1]],
                             indirect=True)
    def test_gather(self, input, start_loc, seq_lens, adapter_ids, ranks,
                    world_size, gt, valid_mask):
        max_seq_len = max(seq_lens)
        output = rearange_all_gather(input,
                                     start_loc,
                                     seq_lens,
                                     adapter_ids,
                                     ranks,
                                     world_size,
                                     max_seq_len=max_seq_len)
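        # zero out the padding columns on both sides before comparing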
        output = output.where(valid_mask, output.new_tensor(0))
        gt = gt.where(valid_mask, gt.new_tensor(0))
        torch.testing.assert_close(output, gt)