Unverified commit 251c18d1, authored by Xinyu Chen, committed by GitHub
Browse files

skip fp8e4b15 on xpu (#39957)


Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
parent 512765d5
...@@ -21,6 +21,7 @@ from vllm.model_executor.layers.quantization.turboquant.config import ( ...@@ -21,6 +21,7 @@ from vllm.model_executor.layers.quantization.turboquant.config import (
from vllm.model_executor.layers.quantization.turboquant.quantizer import ( from vllm.model_executor.layers.quantization.turboquant.quantizer import (
generate_wht_signs, generate_wht_signs,
) )
from vllm.platforms import current_platform
from vllm.utils.math_utils import next_power_of_2 from vllm.utils.math_utils import next_power_of_2
# ============================================================================ # ============================================================================
...@@ -345,7 +346,8 @@ class TestLloydMax: ...@@ -345,7 +346,8 @@ class TestLloydMax:
# Rotation matrix tests (GPU required) # Rotation matrix tests (GPU required)
# ============================================================================ # ============================================================================
CUDA_AVAILABLE = torch.cuda.is_available() GPGPU_AVAILABLE = torch.cuda.is_available() or torch.xpu.is_available()
DEVICE_TYPE = current_platform.device_type
def generate_rotation_matrix(d: int, seed: int, device: str = "cpu") -> torch.Tensor: def generate_rotation_matrix(d: int, seed: int, device: str = "cpu") -> torch.Tensor:
...@@ -360,16 +362,16 @@ def generate_rotation_matrix(d: int, seed: int, device: str = "cpu") -> torch.Te ...@@ -360,16 +362,16 @@ def generate_rotation_matrix(d: int, seed: int, device: str = "cpu") -> torch.Te
return Q.to(device) return Q.to(device)
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="CUDA not available") @pytest.mark.skipif(not GPGPU_AVAILABLE, reason="GPGPU not available")
class TestRotationMatrix: class TestRotationMatrix:
"""Tests for the QR-based rotation (standalone benchmarks only).""" """Tests for the QR-based rotation (standalone benchmarks only)."""
@pytest.mark.parametrize("dim", [64, 96, 128, 256]) @pytest.mark.parametrize("dim", [64, 96, 128, 256])
def test_rotation_matrix_shape_and_orthogonal(self, dim): def test_rotation_matrix_shape_and_orthogonal(self, dim):
Pi = generate_rotation_matrix(dim, seed=42, device="cuda") Pi = generate_rotation_matrix(dim, seed=42, device=DEVICE_TYPE)
assert Pi.shape == (dim, dim) assert Pi.shape == (dim, dim)
eye = Pi @ Pi.T eye = Pi @ Pi.T
assert torch.allclose(eye, torch.eye(dim, device="cuda"), atol=1e-5), ( assert torch.allclose(eye, torch.eye(dim, device=DEVICE_TYPE), atol=1e-5), (
f"Pi not orthogonal for dim={dim}" f"Pi not orthogonal for dim={dim}"
) )
...@@ -385,7 +387,7 @@ class TestRotationMatrix: ...@@ -385,7 +387,7 @@ class TestRotationMatrix:
def test_rotation_matrix_det_is_pm1(self): def test_rotation_matrix_det_is_pm1(self):
"""Orthogonal matrix determinant must be +1 or -1.""" """Orthogonal matrix determinant must be +1 or -1."""
Pi = generate_rotation_matrix(128, seed=42, device="cuda") Pi = generate_rotation_matrix(128, seed=42, device=DEVICE_TYPE)
det = torch.linalg.det(Pi) det = torch.linalg.det(Pi)
assert abs(abs(det.item()) - 1.0) < 1e-4 assert abs(abs(det.item()) - 1.0) < 1e-4
...@@ -403,31 +405,31 @@ def _build_hadamard(d: int, device: str = "cpu") -> torch.Tensor: ...@@ -403,31 +405,31 @@ def _build_hadamard(d: int, device: str = "cpu") -> torch.Tensor:
return (H / math.sqrt(d)).to(torch.device(device)) return (H / math.sqrt(d)).to(torch.device(device))
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="CUDA not available") @pytest.mark.skipif(not GPGPU_AVAILABLE, reason="GPGPU not available")
class TestWHTRotation: class TestWHTRotation:
"""Tests for the WHT rotation actually used in serving.""" """Tests for the WHT rotation actually used in serving."""
@pytest.mark.parametrize("dim", [64, 128, 256]) @pytest.mark.parametrize("dim", [64, 128, 256])
def test_wht_orthonormal(self, dim): def test_wht_orthonormal(self, dim):
"""signs * H must be orthonormal: (signs*H) @ (signs*H)^T = I.""" """signs * H must be orthonormal: (signs*H) @ (signs*H)^T = I."""
signs = generate_wht_signs(dim, seed=42, device="cuda") signs = generate_wht_signs(dim, seed=42, device=DEVICE_TYPE)
H = _build_hadamard(dim, "cuda") H = _build_hadamard(dim, DEVICE_TYPE)
PiT = (signs.unsqueeze(1) * H).contiguous() PiT = (signs.unsqueeze(1) * H).contiguous()
eye = PiT @ PiT.T eye = PiT @ PiT.T
assert torch.allclose(eye, torch.eye(dim, device="cuda"), atol=1e-5), ( assert torch.allclose(eye, torch.eye(dim, device=DEVICE_TYPE), atol=1e-5), (
f"WHT rotation not orthonormal for dim={dim}" f"WHT rotation not orthonormal for dim={dim}"
) )
@pytest.mark.parametrize("dim", [64, 128, 256]) @pytest.mark.parametrize("dim", [64, 128, 256])
def test_wht_self_inverse(self, dim): def test_wht_self_inverse(self, dim):
"""PiT should be self-inverse: PiT @ PiT = I (up to sign flip).""" """PiT should be self-inverse: PiT @ PiT = I (up to sign flip)."""
signs = generate_wht_signs(dim, seed=42, device="cuda") signs = generate_wht_signs(dim, seed=42, device=DEVICE_TYPE)
H = _build_hadamard(dim, "cuda") H = _build_hadamard(dim, DEVICE_TYPE)
PiT = (signs.unsqueeze(1) * H).contiguous() PiT = (signs.unsqueeze(1) * H).contiguous()
Pi = PiT.T.contiguous() Pi = PiT.T.contiguous()
# Pi @ PiT should be identity (rotation then inverse) # Pi @ PiT should be identity (rotation then inverse)
result = Pi @ PiT result = Pi @ PiT
assert torch.allclose(result, torch.eye(dim, device="cuda"), atol=1e-5), ( assert torch.allclose(result, torch.eye(dim, device=DEVICE_TYPE), atol=1e-5), (
f"WHT rotation not self-inverse for dim={dim}" f"WHT rotation not self-inverse for dim={dim}"
) )
...@@ -454,7 +456,7 @@ class TestWHTRotation: ...@@ -454,7 +456,7 @@ class TestWHTRotation:
# ============================================================================ # ============================================================================
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="CUDA not available") @pytest.mark.skipif(not GPGPU_AVAILABLE, reason="GPGPU not available")
class TestStoreDecodeRoundTrip: class TestStoreDecodeRoundTrip:
"""End-to-end: store KV into TQ cache, decode, compare vs fp16 ref.""" """End-to-end: store KV into TQ cache, decode, compare vs fp16 ref."""
...@@ -487,11 +489,11 @@ class TestStoreDecodeRoundTrip: ...@@ -487,11 +489,11 @@ class TestStoreDecodeRoundTrip:
block_size = 16 block_size = 16
num_blocks = 1 num_blocks = 1
device = torch.device("cuda") device = torch.device(DEVICE_TYPE)
# Generate rotation # Generate rotation
signs = generate_wht_signs(D, seed=42, device=device) signs = generate_wht_signs(D, seed=42, device=device)
H = _build_hadamard(D, "cuda") H = _build_hadamard(D, DEVICE_TYPE)
PiT = (signs.unsqueeze(1) * H).contiguous().float() PiT = (signs.unsqueeze(1) * H).contiguous().float()
Pi = PiT.T.contiguous() Pi = PiT.T.contiguous()
......
...@@ -13,6 +13,7 @@ from typing import Any ...@@ -13,6 +13,7 @@ from typing import Any
import torch import torch
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.v1.attention.ops.triton_decode_attention import ( from vllm.v1.attention.ops.triton_decode_attention import (
_fwd_kernel_stage2, _fwd_kernel_stage2,
...@@ -22,10 +23,15 @@ _FP8_E4B15: dict[int, int] = {} ...@@ -22,10 +23,15 @@ _FP8_E4B15: dict[int, int] = {}
def _use_fp8_e4b15(device: int = 0) -> int: def _use_fp8_e4b15(device: int = 0) -> int:
"""Return 1 if device needs fp8e4b15 (Ampere/Ada, SM < 8.9), else 0.""" """Return 1 if device needs fp8e4b15 (Ampere/Ada, SM < 8.9), else 0.
On non-CUDA platforms (e.g. XPU), always returns 0 (use e4nv format).
"""
if device not in _FP8_E4B15: if device not in _FP8_E4B15:
cap = torch.cuda.get_device_capability(device) if current_platform.is_cuda_alike():
_FP8_E4B15[device] = 1 if cap < (8, 9) else 0 cap = torch.cuda.get_device_capability(device)
_FP8_E4B15[device] = 1 if cap < (8, 9) else 0
else:
_FP8_E4B15[device] = 0
return _FP8_E4B15[device] return _FP8_E4B15[device]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment