# test_per_token_quant_fp8.py

import itertools
from typing import Tuple

import pytest
import torch
from sgl_kernel import sgl_per_token_quant_fp8
from vllm import _custom_ops as ops

from sglang.srt.utils import is_hip

is_hip_ = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn


def vllm_per_token_quant_fp8(
    input: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    return ops.scaled_fp8_quant(input, use_per_token_if_dynamic=True)


def sglang_per_token_quant_fp8(
    input: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Per-token scales: one float32 value per row, filled in by the kernel.
    scale = torch.zeros(input.size(0), device=input.device, dtype=torch.float32)
    # FP8 output buffer; the kernel writes quantized values into it in place.
    output = torch.empty_like(input, device=input.device, dtype=fp8_type_)

    sgl_per_token_quant_fp8(input, output, scale)
    # Reshape to (num_tokens, 1) to match vLLM's per-token scale layout.
    scale = scale.reshape(-1, 1)

    return output, scale
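

# Hedged reference (not part of the original test file): a minimal pure-PyTorch
# sketch of per-token (per-row) dynamic FP8 quantization, assuming each row is
# scaled symmetrically by its absolute max against the FP8 dtype's max value.
# The helper name and the clamp epsilon are illustrative assumptions; it only
# shows what the two kernel wrappers above are expected to compute.
def _reference_per_token_quant_fp8(
    input: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    finfo = torch.finfo(fp8_type_)
    # One scale per token (row): row_abs_max / fp8_max, clamped away from zero.
    row_abs_max = input.abs().amax(dim=-1, keepdim=True).float()
    scale = (row_abs_max / finfo.max).clamp(min=1e-12)
    # Quantize: divide by the scale, clamp to the representable range, then cast.
    output = (input.float() / scale).clamp(finfo.min, finfo.max).to(fp8_type_)
    return output, scale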


@pytest.mark.parametrize(
    "num_tokens,hidden_dim",
    list(itertools.product([128, 256, 512], [512, 2048, 4096])),
)
def test_per_token_quant_compare_implementations(
    num_tokens: int,
    hidden_dim: int,
):
    device = torch.device("cuda")
    x = torch.rand((num_tokens, hidden_dim), dtype=torch.float16, device=device)

    vllm_out, vllm_scale = vllm_per_token_quant_fp8(x)
    sglang_out, sglang_scale = sglang_per_token_quant_fp8(x)

    # The per-token scales from the two implementations should match closely.
    torch.testing.assert_close(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-3)
    # Compare the quantized values after upcasting to float32.
    torch.testing.assert_close(
        vllm_out.float(), sglang_out.float(), rtol=1e-3, atol=1e-3
    )


if __name__ == "__main__":
    # Run the tests in this file with pytest.
    pytest.main([__file__])
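
# Example invocation (assumes a CUDA device with sgl-kernel and vLLM installed):
#   pytest test_per_token_quant_fp8.py -v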