import pytest
import torch

from tests.kernels.utils import opcheck
from vllm._custom_ops import permute_cols

from .utils import torch_version


@pytest.mark.parametrize('shape', [(1, 512), (544, 4096), (67, 8192)])
@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16])
def test_permute_cols(shape, dtype):
    """Check that ``permute_cols`` reorders columns like ``x[:, perm]``.

    Only PyTorch 2.3.x and 2.4.x are exercised. On 2.4.x the registered
    ``torch.ops._C.permute_cols`` op is additionally validated with
    ``opcheck``. Any other version is reported as skipped instead of
    silently passing.
    """
    is_v23 = torch_version.startswith("2.3")
    is_v24 = torch_version.startswith("2.4")
    if not (is_v23 or is_v24):
        # A bare print() is captured by pytest and the test would be
        # reported as passed; skip makes the untested version visible.
        pytest.skip(
            f"PyTorch version {torch_version} is not specifically handled.")

    x = torch.randn(shape, dtype=dtype).cuda()
    # Random permutation of the column indices, cast to torch.int
    # (int32) before being handed to the op.
    perm = torch.randperm(x.shape[1]).to(torch.int).cuda()

    if is_v24:
        # Only run opcheck on 2.4.x, matching the original version gating.
        opcheck(torch.ops._C.permute_cols, (x, perm))

    y = permute_cols(x, perm)
    # Reference is plain advanced indexing. assert_close raises on mismatch,
    # unlike torch.allclose whose boolean result is easy to drop.
    torch.testing.assert_close(y, x[:, perm])