Unverified Commit e00715eb authored by Hubert Lu, committed by GitHub

[AMD] Add test_fused_moe.py and test_rope_rocm.py to AMD CI (#5246)

parent ea4bf122
@@ -125,6 +125,7 @@ suites = {
         TestFile("test_chunked_prefill.py", 313),
         TestFile("test_eval_fp8_accuracy.py", 303),
         TestFile("test_function_call_parser.py", 10),
+        TestFile("test_fused_moe.py", 30),
         TestFile("test_input_embeddings.py", 38),
         TestFile("test_metrics.py", 32),
         TestFile("test_no_chunked_prefill.py", 108),
@@ -142,6 +143,7 @@ suites = {
         TestFile("test_vertex_endpoint.py", 31),
         # TestFile("test_vision_chunked_prefill.py", 175), # Disabled temporarily and track in #7701
         TestFile("test_reasoning_parser.py", 5),
+        TestFile("test_rope_rocm.py", 3),
     ],
     "per-commit-npu": [
         TestFile("test_ascend_attention_backend.py", 400),
...
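The second argument to each TestFile entry is the file's estimated runtime in seconds, which the suite runner uses to balance files across CI shards. A minimal sketch of that mechanism, assuming a TestFile dataclass and a greedy partitioner (illustrative names and logic, not the exact run_suite.py implementation):

import dataclasses

@dataclasses.dataclass
class TestFile:
    name: str
    estimated_time: float = 60  # seconds, used only for scheduling

def partition(files, num_shards):
    # Greedy longest-processing-time packing: assign each file, largest
    # first, to the currently lightest shard.
    shards = [[] for _ in range(num_shards)]
    loads = [0.0] * num_shards
    for f in sorted(files, key=lambda f: f.estimated_time, reverse=True):
        i = loads.index(min(loads))
        shards[i].append(f)
        loads[i] += f.estimated_time
    return shards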
@@ -6,8 +6,14 @@ from tqdm import tqdm

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
+from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
+from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
+from sglang.srt.utils import is_hip
 from sglang.test.test_utils import CustomTestCase

+_is_hip = is_hip()
+_is_fp8_fnuz = is_fp8_fnuz()
+

 class TestFusedMOE(CustomTestCase):
     NUM_EXPERTS = [8, 64]
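The two module-level flags gate the ROCm-specific paths below. A hedged sketch of what the imported helpers presumably check (the real definitions live in sglang.srt.utils and sglang.srt.layers.quantization.fp8_kernel and may differ in detail):

import torch

def is_hip() -> bool:
    # True when PyTorch was built against ROCm rather than CUDA.
    return torch.version.hip is not None

def is_fp8_fnuz() -> bool:
    # Assumption: on the ROCm GPUs this CI targets (e.g. MI300/gfx942),
    # the hardware-native float8 format is e4m3fnuz rather than e4m3fn.
    # The actual helper may inspect the gfx architecture instead.
    return is_hip()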
@@ -64,7 +70,7 @@ class TestFusedMOE(CustomTestCase):
         topk_weight = topk_weight.view(-1)
         topk_ids = topk_ids.view(-1)

-        if w1.dtype == torch.float8_e4m3fn:
+        if w1.dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]:
             w1_compute = w1.to(a.dtype)
             w2_compute = w2.to(a.dtype)
@@ -97,7 +103,7 @@ class TestFusedMOE(CustomTestCase):
             if use_fp8_w8a8:
                 # AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
                 capability = torch.cuda.get_device_capability()
-                if not (capability[0] >= 9 or capability == (8, 9)):
+                if not _is_hip and not (capability[0] >= 9 or capability == (8, 9)):
                     return

                 a = self.create_random_cuda_tensor((m, k), dtype)
@@ -106,12 +112,26 @@ class TestFusedMOE(CustomTestCase):
                 w1 = w1.to(torch.float8_e4m3fn)
                 w2 = w2.to(torch.float8_e4m3fn)

                 score = self.create_random_cuda_tensor((m, e), dtype)
                 w1_scale = self.create_random_cuda_tensor(e, torch.float32)
                 w2_scale = self.create_random_cuda_tensor(e, torch.float32)
                 a1_scale = self.create_random_cuda_tensor(1, torch.float32)
                 a2_scale = self.create_random_cuda_tensor(1, torch.float32)

+                # Handle HIP case: normalize float8 weights so fused kernel
+                # doesn't break on ROCm.
+                if _is_fp8_fnuz:
+                    # Normalize to e4m3fnuz on HIP
+                    w1, w1_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                        weight=w1,
+                        weight_scale=w1_scale,
+                        input_scale=a1_scale,
+                    )
+                    w2, w2_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                        weight=w2,
+                        weight_scale=w2_scale,
+                        input_scale=a2_scale,
+                    )
+
                 sglang_output = fused_moe(
                     a,
                     w1,
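e4m3fn and e4m3fnuz are both 8-bit floats, but fnuz has no negative zero, encodes NaN as the 0x80 bit pattern, and decodes the same bits to half the value (its exponent bias is 8 instead of 7). A rough sketch of what normalize_e4m3fn_to_e4m3fnuz does, following the recipe commonly used in vLLM/sglang (consult fp8_utils for the authoritative version):

import torch

def normalize_e4m3fn_to_e4m3fnuz(weight, weight_scale, input_scale=None):
    assert weight.dtype == torch.float8_e4m3fn
    # Bit pattern 0x80 (-128 as int8) means -0.0 in e4m3fn but NaN in
    # e4m3fnuz, so clear it before reinterpreting the raw bits.
    w_int8 = weight.view(torch.int8)
    w_int8[w_int8 == -128] = 0
    weight = w_int8.view(torch.float8_e4m3fnuz)
    # The same bit pattern decodes to half the value in e4m3fnuz, so
    # double the scales to keep dequantized results unchanged.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight, weight_scale, input_scale

Note that the test above discards the returned input scale (the `_`); a1_scale and a2_scale are passed to fused_moe unchanged.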
@@ -127,12 +147,19 @@ class TestFusedMOE(CustomTestCase):
                 )

                 torch_output = self.torch_naive_moe(
-                    a, w1, w2, score, topk, w1_scale, w2_scale, a1_scale, a2_scale
+                    a,
+                    w1,
+                    w2,
+                    score,
+                    topk,
+                    w1_scale,
+                    w2_scale,
+                    a1_scale,
+                    a2_scale,
                 )
                 torch.testing.assert_close(
                     sglang_output, torch_output, rtol=rtol, atol=atol
                 )
             else:
                 a = self.create_random_cuda_tensor((m, k), dtype)
                 w1 = self.create_random_cuda_tensor((e, 2 * n, k), dtype)
...
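For context, torch_naive_moe is the dense PyTorch reference that fused_moe is checked against. A minimal sketch of such a reference in the shapes this test uses (simplified: it omits the fp8 dequantization that the real method performs with the scale arguments):

import torch
from sglang.srt.layers.activation import SiluAndMul

def naive_moe(a, w1, w2, score, topk):
    # a: (m, k), w1: (e, 2n, k), w2: (e, k, n), score: (m, e)
    topk_weight, topk_ids = torch.topk(score.softmax(dim=-1), topk)
    m = a.shape[0]
    a_rep = a.repeat_interleave(topk, dim=0)  # one row per (token, expert)
    out = torch.zeros(m * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    for e_id in range(w1.shape[0]):
        mask = topk_ids.view(-1) == e_id
        if mask.any():
            h = SiluAndMul()(a_rep[mask] @ w1[e_id].T)  # gated MLP, stage 1
            out[mask] = h @ w2[e_id].T  # stage 2
    # Combine each token's top-k expert outputs with the routing weights.
    return (out.view(m, topk, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)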
test_rope_rocm.py (new file):

import unittest

import torch

from sglang.srt.layers.rotary_embedding import RotaryEmbedding
from sglang.srt.utils import get_bool_env_var, is_hip
from sglang.test.test_utils import CustomTestCase

torch.manual_seed(0)

_is_hip = is_hip()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

# Each case is (head_size, rotary_dim, max_pos, base, is_neox, dtype,
# device, batch_size, seq_len, num_q, num_kv), matching _run_case below.
_CASES = [
    (64, 64, 32, 8000, True, torch.bfloat16, "cuda", 32, 32, 1, 1),
    (256, 128, 4096, 10000, True, torch.bfloat16, "cuda", 2, 512, 4, 2),
    (512, 128, 311, 10000, True, torch.bfloat16, "cuda", 3, 39, 4, 2),
    (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 32, 8),
    (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 16, 4),
    (512, 128, 311, 10000, False, torch.bfloat16, "cuda", 3, 39, 4, 2),
]
@unittest.skipIf(_use_aiter, reason="SGLANG_USE_AITER=1 will not use vllm path.")
class TestRotaryEmbeddingNative(CustomTestCase):
    # Compare RotaryEmbedding.forward_hip() to forward_native().

    def _run_case(
        self,
        head_size: int,
        rotary_dim: int,
        max_pos: int,
        base: int,
        is_neox: bool,
        dtype: torch.dtype,
        device: str,
        batch_size: int,
        seq_len: int,
        num_q: int,
        num_kv: int,
    ) -> None:
        rope_ref = RotaryEmbedding(
            head_size, rotary_dim, max_pos, base, is_neox, dtype
        ).to(device)
        rope_hip = RotaryEmbedding(
            head_size, rotary_dim, max_pos, base, is_neox, dtype
        ).to(device)

        pos_ids = torch.arange(seq_len, device=device).repeat(batch_size)
        query = torch.randn(
            batch_size * seq_len, num_q * head_size, dtype=dtype, device=device
        )
        key = torch.randn(
            batch_size * seq_len, num_kv * head_size, dtype=dtype, device=device
        )

        q_ref, k_ref = rope_ref.forward_native(pos_ids, query.clone(), key.clone())
        q_hip, k_hip = rope_hip.forward_hip(pos_ids, query.clone(), key.clone())

        torch.testing.assert_close(q_ref, q_hip, atol=1e-2, rtol=1e-2)
        torch.testing.assert_close(k_ref, k_hip, atol=1e-2, rtol=1e-2)

    def test_all_cases(self) -> None:
        """Drive over the full parameter matrix using subTest()."""
        for case in _CASES:
            with self.subTest(case=case):
                self._run_case(*case)
@unittest.skipIf(not _use_aiter, reason="Requires AMD GPU plus SGLANG_USE_AITER=1")
class TestRotaryEmbeddingAITer(CustomTestCase):
    @staticmethod
    def _run_case_aiter(
        head_size: int,
        rotary_dim: int,
        max_pos: int,
        base: int,
        is_neox: bool,
        dtype: torch.dtype,
        device: str,
        batch_size: int,
        seq_len: int,
        num_q: int,
        num_kv: int,
    ) -> None:
        from aiter.rotary_embedding import RotaryEmbedding as AiterRotaryEmbedding

        rope_ref = AiterRotaryEmbedding(
            head_size, rotary_dim, max_pos, base, is_neox, dtype
        ).to(device)
        rope_hip = AiterRotaryEmbedding(
            head_size, rotary_dim, max_pos, base, is_neox, dtype
        ).to(device)

        pos_ids = torch.arange(seq_len, device=device).repeat(batch_size)
        query = torch.randn(
            batch_size * seq_len, num_q * head_size, dtype=dtype, device=device
        )
        key = torch.randn(
            batch_size * seq_len, num_kv * head_size, dtype=dtype, device=device
        )

        q_ref, k_ref = rope_ref.forward_native(pos_ids, query.clone(), key.clone())
        q_hip, k_hip = rope_hip.forward_hip(pos_ids, query.clone(), key.clone())

        torch.testing.assert_close(q_ref, q_hip, atol=1e-2, rtol=1e-2)
        torch.testing.assert_close(k_ref, k_hip, atol=1e-2, rtol=1e-2)

    def test_all_cases(self) -> None:
        for case in _CASES:
            with self.subTest(case=case):
                self._run_case_aiter(*case)


if __name__ == "__main__":
    unittest.main()
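Both classes treat forward_native as the ground truth for forward_hip. For orientation, a hedged sketch of the NeoX-style rotation that forward_native applies when is_neox is True (simplified: it ignores partial-rotary handling and the cos/sin cache of the real RotaryEmbedding):

import torch

def neox_rope(x: torch.Tensor, positions: torch.Tensor, base: float = 10000.0):
    # x: (num_tokens, num_heads, head_size); rotate the full head dim.
    head_size = x.shape[-1]
    inv_freq = 1.0 / base ** (
        torch.arange(0, head_size, 2, dtype=torch.float32) / head_size
    )
    freqs = torch.outer(positions.float(), inv_freq)  # (num_tokens, head_size // 2)
    cos = freqs.cos()[:, None, :]  # broadcast over the head axis
    sin = freqs.sin()[:, None, :]
    # NeoX style rotates (first half, second half) of the head dim as pairs;
    # GPT-J style (is_neox=False) would pair even/odd lanes instead.
    x1, x2 = x.float().chunk(2, dim=-1)
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1).to(x.dtype)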