test_multinomial_sampling.py 1.78 KB
Newer Older
zhouxiang's avatar
zhouxiang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import pytest
import torch

from lmdeploy.pytorch.kernels import multinomial_sampling


class TestMultinomialSampling:
    """Check ``multinomial_sampling`` against one-hot score rows.

    Each row ``b`` of ``scores`` has a single 1 at column ``select_ids[b]``
    and 0 elsewhere. ``indices`` holds an independent random permutation of
    the column ids for every row. The test asserts that the kernel output
    equals ``indices[b, select_ids[b]]`` for every row ``b``.
    """

    # Inputs supplied through (indirect) parametrization.

    @pytest.fixture
    def num_tokens(self, request):
        """Number of score columns, taken from indirect parametrization."""
        return request.param

    @pytest.fixture
    def select_ids(self, request):
        """Per-row column index that receives the only non-zero score."""
        return request.param

    @pytest.fixture
    def batch_size(self, select_ids):
        """Number of rows: one per entry of ``select_ids``."""
        return len(select_ids)

    @pytest.fixture
    def dtype(self, request):
        """Dtype of ``scores``; the test below parametrizes this name."""
        return request.param

    # Derived tensors.

    @pytest.fixture
    def scores(self, num_tokens, batch_size, select_ids, dtype):
        """One-hot ``(batch_size, num_tokens)`` score matrix on the GPU.

        The matrix is filled in the default float dtype and only then cast
        to ``dtype``.
        """
        row_ids = torch.arange(batch_size, device='cuda')
        one_hot = torch.zeros(batch_size, num_tokens, device='cuda')
        one_hot[row_ids, select_ids] = 1
        return one_hot.to(dtype)

    @pytest.fixture
    def seeds(self, batch_size):
        """Random integers in ``[1000, 2000)``, one per row, on the GPU."""
        return torch.randint(1000, 2000, size=(batch_size, )).cuda()

    @pytest.fixture
    def offsets(self, batch_size):
        """Random integers in ``[1000, 2000)``, one per row, on the GPU."""
        return torch.randint(1000, 2000, size=(batch_size, )).cuda()

    @pytest.fixture
    def indices(self, scores):
        """Independent random permutation of the column ids for every row."""
        n_rows, n_cols = scores.shape
        # Permutations are drawn on the CPU, one row at a time, then moved.
        permutations = [torch.randperm(n_cols) for _row in range(n_rows)]
        return torch.stack(permutations, dim=0).cuda()

    @pytest.fixture
    def gt(self, batch_size, select_ids, indices):
        """Expected output: ``indices[b, select_ids[b]]`` for every row."""
        row_ids = torch.arange(batch_size).cuda()
        expected = indices[row_ids, select_ids]
        return expected

    # (4, 2) * 30 expands to 60 rows that alternate between columns 4 and 2.
    @pytest.mark.parametrize('dtype',
                             [torch.float32, torch.half, torch.bfloat16])
    @pytest.mark.parametrize(('num_tokens', 'select_ids'), [
        (8, (4, 2) * 30),
        (200, (50, 150)),
    ], indirect=True)
    def test_multinomial_sampling(self, scores, seeds, offsets, indices, gt):
        """Each row must yield the permuted id of its only hot column."""
        sampled = multinomial_sampling(scores, seeds, offsets, indices)
        torch.testing.assert_close(sampled, gt)