test_mbgmv.py 3.58 KB
Newer Older
zhouxiang's avatar
zhouxiang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import pytest
import torch
from torch.nn.utils.rnn import pad_sequence

from lmdeploy.pytorch.kernels.mbgmv import mbgmv_a, mbgmv_b


class TestMBGMV:
    """Check the paged mbgmv_a/mbgmv_b LoRA kernels against a dense
    per-sample ``x @ A @ B`` reference computed with plain matmuls."""

    @pytest.fixture
    def dtype(self):
        yield torch.float16

    @pytest.fixture
    def head_size(self):
        # input feature width
        yield 64

    @pytest.fixture
    def out_head_size(self):
        # output feature width of the B projection
        yield 32

    @pytest.fixture
    def batch_size(self):
        yield 8

    @pytest.fixture
    def ranks(self):
        # LoRA rank of each adapter (two adapters here)
        yield torch.tensor([2, 4]).cuda()

    @pytest.fixture
    def page_start(self, ranks):
        # every adapter's page list starts at offset 0 in the page table
        yield torch.zeros_like(ranks)

    @pytest.fixture
    def input(self, batch_size, head_size, dtype):
        # zero-centered random activations, one row per batch element
        yield (torch.rand(batch_size, head_size, dtype=dtype) - 0.5).cuda()

    @pytest.fixture
    def adapter_ids(self, batch_size, ranks):
        # random adapter assignment for each batch element
        yield torch.randint(0, len(ranks), (batch_size, )).cuda()

    @pytest.fixture
    def scaling(self, adapter_ids):
        # unit scaling so the kernel output matches the raw reference
        yield torch.ones(adapter_ids.size(0)).cuda()

    @pytest.fixture
    def lora_a(self, ranks, head_size, dtype):
        # down-projection weights, one (head_size, rank) matrix per adapter
        yield [(torch.rand(head_size, rank, dtype=dtype) - 0.5).cuda()
               for rank in ranks]

    @pytest.fixture
    def lora_b(self, ranks, out_head_size, dtype):
        # up-projection weights, one (rank, out_head_size) matrix per adapter
        yield [(torch.rand(rank, out_head_size, dtype=dtype) - 0.5).cuda()
               for rank in ranks]

    @pytest.fixture
    def page_table(self, ranks):
        # scatter all rank pages randomly, then split per adapter and pad
        # to a rectangular (num_adapters, max_rank) table
        perm = torch.randperm(sum(ranks))
        per_adapter = perm.split(ranks.tolist())
        yield pad_sequence(per_adapter, batch_first=True).cuda()

    @pytest.fixture
    def paged_lora_a(self, lora_a, ranks, page_table, head_size, dtype):
        # write each adapter's A^T rows into its pages of the shared cache
        cache = torch.empty(sum(ranks), head_size, dtype=dtype).cuda()
        for pages, rank, weight in zip(page_table, ranks, lora_a):
            cache[pages[:rank]] = weight.t()
        yield cache

    @pytest.fixture
    def paged_lora_b(self, lora_b, ranks, page_table, head_size, out_head_size,
                     dtype):
        # cache is head_size wide on purpose; B only fills the first
        # out_head_size columns, so the test exercises a strided view
        cache = torch.empty(sum(ranks), head_size, dtype=dtype).cuda()
        for pages, rank, weight in zip(page_table, ranks, lora_b):
            cache[pages[:rank], :out_head_size] = weight
        yield cache

    @pytest.fixture
    def gt(self, input, adapter_ids, lora_a, lora_b):
        # dense reference: x @ A @ B with each sample's own adapter
        rows = [
            row.unsqueeze(0) @ lora_a[a_id] @ lora_b[a_id]
            for row, a_id in zip(input, adapter_ids)
        ]
        yield torch.cat(rows)

    def test_mbgmv(self, input, paged_lora_a, paged_lora_b, out_head_size,
                   adapter_ids, scaling, page_table, ranks, page_start, gt):
        """Run the two-stage paged kernel and compare with the reference."""
        max_rank = page_table.size(-1)

        # stage 1: project inputs down through the paged A cache
        hidden = mbgmv_a(input,
                         paged_lora_a,
                         adapter_ids=adapter_ids,
                         rank_page_table=page_table,
                         rank_page_start=page_start,
                         ranks=ranks,
                         max_rank=max_rank)

        # stage 2: project back up through a strided view of the B cache
        output = mbgmv_b(hidden,
                         paged_lora_b[..., :out_head_size],
                         adapter_ids=adapter_ids,
                         scaling=scaling,
                         rank_page_table=page_table,
                         rank_page_start=page_start,
                         ranks=ranks,
                         max_rank=max_rank)
        torch.testing.assert_close(gt, output, atol=2e-3, rtol=1e-5)