# Tests for the ``permute_cols`` custom op.
import pytest
import torch

from tests.kernels.utils import opcheck
from vllm._custom_ops import permute_cols
from .utils import torch_version


@pytest.mark.parametrize('shape', [(1, 512), (544, 4096), (67, 8192)])
@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16])
def test_permute_cols(shape, dtype):
    """Check ``permute_cols`` against plain column indexing ``x[:, perm]``.

    A random CUDA tensor of the given ``shape`` and ``dtype`` is built,
    together with a random int32 permutation of its column indices. The
    custom op's output must match ``x[:, perm]``. On PyTorch 2.4 the op is
    additionally run through ``opcheck``. Any PyTorch version other than
    2.3 or 2.4 is skipped explicitly instead of silently passing.
    """
    if not torch_version.startswith(("2.3", "2.4")):
        pytest.skip(
            f"PyTorch version {torch_version} is not specifically handled.")

    x = torch.randn(shape, dtype=dtype).cuda()
    # A random permutation of the column indices, as int32 on the GPU.
    perm = torch.randperm(x.shape[1]).to(torch.int).cuda()

    # NOTE(review): opcheck is only exercised on 2.4 — presumably the
    # opcheck helper relies on APIs missing from 2.3; confirm before
    # enabling it for other versions.
    if torch_version.startswith("2.4"):
        opcheck(torch.ops._C.permute_cols, (x, perm))

    y = permute_cols(x, perm)
    # assert_close raises on mismatch. A bare torch.allclose only returns
    # a bool, which must be asserted or the check can never fail.
    torch.testing.assert_close(y, x[:, perm])