"tests/vscode:/vscode.git/clone" did not exist on "11d8e3ce2c0b7789a173bbf9e8fcc42b7c7e3cf6"
Unverified commit 094c116f authored by YanbingJiang, committed by GitHub

Update python API of activation, topk, norm and rope and remove vllm dependency (#6614)


Co-authored-by: Wu, Chunyuan <chunyuan.wu@intel.com>
Co-authored-by: jianan-gu <jianan.gu@intel.com>
Co-authored-by: sdp <sdp@gnr799219.jf.intel.com>
parent e56685ac
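
The changes below gate the new sgl_kernel CPU operators on an x86 host with Intel AMX support and drop the vllm fallbacks on that path. A minimal sketch of the dispatch pattern applied to the activation, topk, norm and rope layers (names simplified; this is not the actual sglang class):

from torch import nn

class CustomOpSketch(nn.Module):
    """Selects a backend-specific forward implementation once, up front."""

    def __init__(self, is_cuda: bool, is_cpu: bool, cpu_amx_available: bool):
        super().__init__()
        if is_cuda:
            self._forward_method = self.forward_cuda
        elif is_cpu and cpu_amx_available:
            # New in this commit: AMX-capable CPUs dispatch to sgl_kernel CPU ops.
            self._forward_method = self.forward_cpu
        else:
            # Everything else uses the pure-PyTorch reference path, which no
            # longer pulls in vllm kernels.
            self._forward_method = self.forward_native

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    # Placeholder backends for the sketch only.
    def forward_cuda(self, x):
        return x

    def forward_cpu(self, x):
        return x

    def forward_native(self, x):
        return x
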
......@@ -39,6 +39,7 @@ RUN git clone https://github.com/sgl-project/sglang.git && \
cp pyproject_cpu.toml pyproject.toml && \
pip install -v .
ENV SGLANG_USE_CPU_ENGINE=1
ENV LD_PRELOAD=/sgl-workspace/miniforge3/lib/libiomp5.so:/sgl-workspace/miniforge3/lib/libtcmalloc.so:/sgl-workspace/miniforge3/lib/libtbbmalloc.so.2
WORKDIR /sgl-workspace/sglang
from torch import nn
from sglang.srt.utils import is_cuda, is_hip
from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip
_is_cuda = is_cuda()
_is_hip = is_hip()
_is_cpu = is_cpu()
_is_cpu_amx_available = cpu_has_amx_support()
class CustomOp(nn.Module):
......@@ -75,5 +77,7 @@ class CustomOp(nn.Module):
return self.forward_cuda
elif _is_hip:
return self.forward_hip
elif _is_cpu and _is_cpu_amx_available:
return self.forward_cpu
else:
return self.forward_native
......@@ -29,11 +29,19 @@ from sglang.srt.distributed import (
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.utils import is_cuda, is_npu, set_weight_attrs
from sglang.srt.utils import (
cpu_has_amx_support,
is_cpu,
is_cuda,
is_npu,
set_weight_attrs,
)
from sglang.utils import resolve_obj_by_qualname
_is_cuda = is_cuda()
_is_npu = is_npu()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
if _is_cuda:
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
......@@ -53,6 +61,15 @@ class SiluAndMul(CustomOp):
silu_and_mul(x, out)
return out
def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
if _is_cpu_amx_available:
d = x.shape[-1] // 2
output_shape = x.shape[:-1] + (d,)
out = torch.ops.sgl_kernel.silu_and_mul_cpu(x)
return out
else:
return self.forward_native(x)
class GeluAndMul(CustomOp):
def __init__(self, approximate="tanh"):
......@@ -185,8 +202,8 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
return nn.Identity()
if not _is_cuda and not _is_npu:
if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
logger.info(
"sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
"sgl-kernel is not available on Non-NV platforms or Non-AMX CPUs. Fallback to other kernel libraries."
)
from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
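
The SiluAndMul.forward_cpu added above calls torch.ops.sgl_kernel.silu_and_mul_cpu on AMX machines and otherwise falls back to forward_native; for reference, that native path computes standard SwiGLU gating. A sketch of the semantics (not the sglang source):

import torch
import torch.nn.functional as F

def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half: the first half is gated by SiLU,
    # the second half is the multiplicative branch.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]
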
......@@ -20,12 +20,21 @@ import torch
import torch.nn as nn
from sglang.srt.custom_op import CustomOp
from sglang.srt.utils import get_bool_env_var, is_cuda, is_hip, is_npu
from sglang.srt.utils import (
cpu_has_amx_support,
get_bool_env_var,
is_cpu,
is_cuda,
is_hip,
is_npu,
)
_is_cuda = is_cuda()
_is_hip = is_hip()
_is_npu = is_npu()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
if _is_cuda:
from sgl_kernel import (
......@@ -122,6 +131,23 @@ class RMSNorm(CustomOp):
else:
return x, residual
def forward_cpu(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if _is_cpu_amx_available:
if residual is not None:
torch.ops.sgl_kernel.fused_add_rmsnorm_cpu(
x, residual, self.weight.data, self.variance_epsilon
)
return x, residual
return torch.ops.sgl_kernel.rmsnorm_cpu(
x, self.weight.data, self.variance_epsilon
)
else:
return self.forward_native(x, residual)
class GemmaRMSNorm(CustomOp):
def __init__(
......@@ -188,7 +214,7 @@ class Gemma3RMSNorm(nn.Module):
return f"{tuple(self.weight.shape)}, eps={self.eps}"
if not (_is_cuda or _is_hip or _is_npu):
if not (_is_cuda or _is_hip or _is_npu or (_is_cpu and _is_cpu_amx_available)):
logger.info(
"sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
)
......
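
RMSNorm.forward_cpu above routes to fused_add_rmsnorm_cpu / rmsnorm_cpu when AMX is present and otherwise defers to forward_native. As a reminder of what that fallback computes, a simplified sketch (fp32 upcasting omitted):

import torch
from typing import Optional

def rmsnorm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float,
                      residual: Optional[torch.Tensor] = None):
    if residual is not None:
        # The fused kernel also writes this sum back into residual in place.
        x = x + residual
        residual = x
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    out = x * torch.rsqrt(variance + eps) * weight
    return out if residual is None else (out, residual)
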
......@@ -25,9 +25,11 @@ from sglang.srt.layers.quantization.int8_kernel import (
sglang_per_token_group_quant_int8,
)
from sglang.srt.utils import (
cpu_has_amx_support,
direct_register_custom_op,
get_bool_env_var,
get_device_name,
is_cpu,
is_cuda,
is_hip,
log_info_on_rank0,
......@@ -36,9 +38,13 @@ from sglang.srt.utils import (
_is_hip = is_hip()
_is_cuda = is_cuda()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
if _is_cuda:
from sgl_kernel import gelu_and_mul, silu_and_mul
elif _is_cpu and _is_cpu_amx_available:
pass
else:
from vllm import _custom_ops as vllm_ops
from vllm._custom_ops import scaled_fp8_quant
......
......@@ -241,7 +241,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
return moe_forward_native(
layer,
......@@ -260,7 +264,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError("The TPU backend currently does not support MoE.")
forward_native = forward_cuda
forward_native = forward_cpu
class FusedMoE(torch.nn.Module):
......
......@@ -28,10 +28,18 @@ from sglang.srt.managers.expert_location_dispatch import (
topk_ids_logical_to_physical,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
from sglang.srt.utils import (
cpu_has_amx_support,
get_compiler_backend,
is_cpu,
is_cuda,
is_hip,
)
_is_cuda = is_cuda()
_is_hip = is_hip()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
if _is_cuda:
from sgl_kernel import moe_fused_gate
......@@ -40,7 +48,7 @@ if _is_cuda or _is_hip:
from sgl_kernel import topk_softmax
def fused_topk_native(
def fused_topk_torch_native(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
......@@ -61,6 +69,20 @@ def fused_topk_native(
return topk_weights, topk_ids
def fused_topk_cpu(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
):
return torch.ops.sgl_kernel.topk_softmax_cpu(
hidden_states=hidden_states,
gating_output=gating_output,
topk=topk,
renormalize=renormalize,
)
def fused_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
......@@ -115,7 +137,7 @@ def _fused_topk_postprocess(
# This is used by the Deepseek V2/V3/R1 series models
@torch.compile(dynamic=True, backend=get_compiler_backend())
def grouped_topk(
def grouped_topk_gpu(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
......@@ -171,6 +193,32 @@ def grouped_topk(
return topk_weights, topk_ids
def grouped_topk_cpu(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
num_fused_shared_experts: int = 0,
routed_scaling_factor: Optional[float] = None,
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
):
assert expert_location_dispatch_info is None
return torch.ops.sgl_kernel.grouped_topk_cpu(
hidden_states,
gating_output,
topk,
renormalize,
num_expert_group,
topk_group,
num_fused_shared_experts,
routed_scaling_factor,
num_token_non_padded,
)
def biased_grouped_topk_impl(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
......@@ -258,7 +306,7 @@ def _biased_grouped_topk_postprocess(
return topk_ids
def biased_grouped_topk(
def biased_grouped_topk_gpu(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
correction_bias: torch.Tensor,
......@@ -322,6 +370,45 @@ def biased_grouped_topk(
)
def biased_grouped_topk_cpu(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
correction_bias: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
compiled: bool = True,
num_fused_shared_experts: int = 0,
routed_scaling_factor: Optional[float] = None,
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
):
assert expert_location_dispatch_info is None
return torch.ops.sgl_kernel.biased_grouped_topk_cpu(
hidden_states,
gating_output,
correction_bias,
topk,
renormalize,
num_expert_group,
topk_group,
num_fused_shared_experts,
routed_scaling_factor,
num_token_non_padded,
)
if _is_cpu and _is_cpu_amx_available:
biased_grouped_topk = biased_grouped_topk_cpu
grouped_topk = grouped_topk_cpu
fused_topk_native = fused_topk_cpu
else:
biased_grouped_topk = biased_grouped_topk_gpu
grouped_topk = grouped_topk_gpu
fused_topk_native = fused_topk_torch_native
def select_experts(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
......
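
With the renames above, grouped_topk, biased_grouped_topk and fused_topk_native are bound at import time to either the GPU/native implementations or the sgl_kernel CPU ones. For orientation, the fused_topk_torch_native path reduces to softmax, top-k selection and optional renormalization; a sketch of those semantics (not the exact sglang code):

import torch

def fused_topk_reference(gating_output: torch.Tensor, topk: int, renormalize: bool):
    scores = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
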
......@@ -14,15 +14,18 @@ from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_qu
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
from sglang.srt.layers.quantization.utils import (
all_close_1d,
cpu_has_amx_support,
per_tensor_dequantize,
replace_parameter,
)
from sglang.srt.utils import is_cuda, is_npu, set_weight_attrs
from sglang.srt.utils import is_cpu, is_cuda, is_npu, set_weight_attrs
_is_cuda = is_cuda()
_is_npu = is_npu()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
if not _is_cuda and not _is_npu:
if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
from vllm import _custom_ops as vllm_ops
from vllm._custom_ops import scaled_fp8_quant
......
......@@ -64,7 +64,9 @@ from sglang.srt.layers.quantization.utils import (
)
from sglang.srt.layers.utils import is_sm100_supported
from sglang.srt.utils import (
cpu_has_amx_support,
get_bool_env_var,
is_cpu,
is_cuda,
is_hip,
is_npu,
......@@ -76,6 +78,8 @@ from sglang.srt.utils import (
_is_hip = is_hip()
_is_cuda = is_cuda()
_is_npu = is_npu()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_is_fp8_fnuz = is_fp8_fnuz()
......@@ -88,7 +92,7 @@ if _is_hip:
from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
from aiter.ops.shuffle import shuffle_weight
if not _is_cuda and not _is_npu:
if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
from vllm._custom_ops import scaled_fp8_quant
......
......@@ -6,12 +6,14 @@ from typing import List, Mapping, Tuple, Union
import torch
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.utils import is_cuda, is_npu
from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu
_is_cuda = is_cuda()
_is_npu = is_npu()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
if not _is_cuda and not _is_npu:
if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
from vllm._custom_ops import scaled_fp8_quant
......
......@@ -8,11 +8,13 @@ import torch
import torch.nn as nn
from sglang.srt.custom_op import CustomOp
from sglang.srt.utils import is_cuda, is_hip, is_npu
from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu
_is_cuda = is_cuda()
_is_hip = is_hip()
_is_npu = is_npu()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
if _is_cuda:
from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
......@@ -85,7 +87,9 @@ class RotaryEmbedding(CustomOp):
if not _is_cuda:
cache = cache.to(dtype)
if not (_is_cuda or _is_npu) or self.head_size not in [64, 128, 256, 512]:
if (
not (_is_cuda or _is_npu) or self.head_size not in [64, 128, 256, 512]
) and not (_is_cpu and _is_cpu_amx_available):
from vllm._custom_ops import rotary_embedding
self.vllm_rotary_embedding = rotary_embedding
......@@ -148,6 +152,26 @@ class RotaryEmbedding(CustomOp):
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
def forward_cpu(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
positions = torch.add(positions, offsets) if offsets is not None else positions
if _is_cpu_amx_available:
return torch.ops.sgl_kernel.rotary_embedding_cpu(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
self.is_neox_style,
)
else:
return self.forward_native(positions, query, key, offsets)
def forward_cuda(
self,
positions: torch.Tensor,
......@@ -697,6 +721,21 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
key = key_rot
return query.to(dtype), key.to(dtype)
def forward_cpu(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
positions = torch.add(positions, offsets) if offsets is not None else positions
if _is_cpu_amx_available:
return torch.ops.sgl_kernel.rotary_embedding_cpu(
positions, query, key, self.head_size, self.cos_sin_cache, False
)
else:
return self.forward_native(positions, query, key, offsets)
class Llama3RotaryEmbedding(RotaryEmbedding):
......
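
Both forward_cpu methods above call torch.ops.sgl_kernel.rotary_embedding_cpu when AMX is available and otherwise fall back to forward_native. The underlying operation is the usual rotary embedding; a sketch of the neox-style rotation for a single head (semantics only, assuming cos and sin are the per-position half-dimension vectors taken from cos_sin_cache):

import torch

def apply_neox_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: [..., rotary_dim]; cos/sin: [..., rotary_dim // 2]
    x1, x2 = torch.chunk(x, 2, dim=-1)
    cos = torch.cat((cos, cos), dim=-1)
    sin = torch.cat((sin, sin), dim=-1)
    rotated = torch.cat((-x2, x1), dim=-1)
    return x * cos + rotated * sin
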
......@@ -111,6 +111,7 @@ from sglang.srt.utils import (
)
_is_hip = is_hip()
_is_cpu_amx_available = cpu_has_amx_support()
# Use a small KV cache pool size for tests in CI
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
......@@ -302,7 +303,7 @@ class ModelRunner:
if (
server_args.attention_backend == "intel_amx"
and server_args.device == "cpu"
and not cpu_has_amx_support()
and not _is_cpu_amx_available
):
logger.info(
"The current platform does not support Intel AMX, will fallback to torch_native backend."
......
......@@ -72,7 +72,7 @@ from sglang.srt.layers.quantization.int8_utils import (
block_dequant as int8_block_dequant,
)
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
......@@ -95,8 +95,10 @@ from sglang.srt.utils import (
LazyValue,
add_prefix,
bind_or_assign,
cpu_has_amx_support,
get_bool_env_var,
get_int_env_var,
is_cpu,
is_cuda,
is_hip,
is_non_idle_and_non_empty,
......@@ -107,9 +109,13 @@ _is_hip = is_hip()
_is_cuda = is_cuda()
_is_fp8_fnuz = is_fp8_fnuz()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
if _is_cuda:
from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
elif _is_cpu and _is_cpu_amx_available:
pass
else:
from vllm._custom_ops import awq_dequantize
......@@ -665,13 +671,14 @@ class DeepseekV2AttentionMLA(nn.Module):
if rope_scaling:
rope_scaling["rope_type"] = "deepseek_yarn"
self.rotary_emb = get_rope(
self.rotary_emb = get_rope_wrapper(
qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=False,
device=global_server_args_dict["device"],
)
if rope_scaling:
......
......@@ -160,7 +160,7 @@ def is_npu() -> bool:
return hasattr(torch, "npu") and torch.npu.is_available()
def is_cpu() -> bool:
def is_host_cpu_x86() -> bool:
machine = platform.machine().lower()
return (
machine in ("x86_64", "amd64", "i386", "i686")
......@@ -169,6 +169,10 @@ def is_cpu() -> bool:
)
def is_cpu() -> bool:
return os.getenv("SGLANG_USE_CPU_ENGINE", "0") == "1" and is_host_cpu_x86()
def is_flashinfer_available():
"""
Check whether flashinfer is available.
......@@ -1452,6 +1456,15 @@ def get_device(device_id: Optional[int] = None) -> str:
"Habana frameworks detected, but failed to import 'habana_frameworks.torch.hpu'."
)
if is_cpu():
if cpu_has_amx_support():
logger.info("Intel AMX is detected, using CPU with Intel AMX support.")
else:
logger.warning(
"CPU device enabled, using torch native backend, low performance expected."
)
return "cpu"
raise RuntimeError("No accelerator (CUDA, XPU, HPU) is available.")
......
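
With is_cpu() now gated on SGLANG_USE_CPU_ENGINE, the CPU engine is opt-in. A hypothetical way to exercise the new branch in get_device (environment and calls shown for illustration only, assuming no accelerator is present):

import os
os.environ["SGLANG_USE_CPU_ENGINE"] = "1"  # required for is_cpu() to return True

from sglang.srt.utils import cpu_has_amx_support, get_device, is_cpu

if is_cpu():
    # get_device() returns "cpu" here and logs whether Intel AMX kernels
    # will be used or the torch-native fallback applies.
    print(get_device(), cpu_has_amx_support())
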
......@@ -21,7 +21,7 @@ class TestActivation(CustomTestCase):
ref_out = SiluAndMul(x)
atol = rtol = precision[ref_out.dtype]
self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol))
torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)
def test_activation(self):
for params in itertools.product(self.M, self.N, self.dtype):
......
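
The test updates in this and the following hunks swap self.assertTrue(torch.allclose(...)) for torch.testing.assert_close(...), which also checks shape and dtype and raises with a mismatch report instead of collapsing the comparison into a bare boolean. A minimal standalone illustration:

import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.0, 2.0, 3.1])

# assertTrue(torch.allclose(...)) only reports "False is not true" on failure;
# assert_close describes which elements differ and by how much.
torch.testing.assert_close(a, b, atol=0.2, rtol=0.0)    # passes
# torch.testing.assert_close(a, b, atol=0.01, rtol=0.0) # would raise with a detailed diff
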
......@@ -60,8 +60,8 @@ class TestGemm(CustomTestCase):
)
atol = rtol = precision[ref.dtype]
self.assertTrue(torch.allclose(ref, out, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(ref, out2, atol=atol, rtol=rtol))
torch.testing.assert_close(ref, out, atol=atol, rtol=rtol)
torch.testing.assert_close(ref, out2, atol=atol, rtol=rtol)
def test_bf16_gemm(self):
for params in itertools.product(
......@@ -100,13 +100,13 @@ class TestGemm(CustomTestCase):
out = torch.ops.sgl_kernel.int8_scaled_mm_cpu(
Aq2, Bq, As2, Bs, bias if has_bias else None, torch.bfloat16, False
)
self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol))
torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)
# test the fused version
fused_out = torch.ops.sgl_kernel.int8_scaled_mm_with_quant(
A, Bq, Bs, bias if has_bias else None, torch.bfloat16, False
)
self.assertTrue(torch.allclose(ref_out, fused_out, atol=atol, rtol=rtol))
torch.testing.assert_close(ref_out, fused_out, atol=atol, rtol=rtol)
def test_int8_gemm(self):
for params in itertools.product(
......@@ -165,7 +165,7 @@ class TestGemm(CustomTestCase):
prepack,
)
atol = rtol = precision[ref.dtype]
self.assertTrue(torch.allclose(ref, opt, atol=atol, rtol=rtol))
torch.testing.assert_close(ref, opt, atol=atol, rtol=rtol)
def test_fp8_gemm(self):
for params in itertools.product(
......
......@@ -91,9 +91,7 @@ class TestFusedExperts(CustomTestCase):
fused_output = fused_moe(a, w1, w2, score, topk, renormalize, prepack)
atol = rtol = precision[torch_output.dtype]
self.assertTrue(
torch.allclose(torch_output, fused_output, atol=atol, rtol=rtol)
)
torch.testing.assert_close(torch_output, fused_output, atol=atol, rtol=rtol)
def test_bf16_moe(self):
for params in itertools.product(
......@@ -171,7 +169,7 @@ class TestFusedExperts(CustomTestCase):
# Increase the tolerance for large input shapes
if M > 35:
atol = rtol = 0.02
self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol))
torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)
def test_int8_moe(self):
for params in itertools.product(
......@@ -235,7 +233,7 @@ class TestFusedExperts(CustomTestCase):
)
atol = rtol = precision[dtype]
self.assertTrue(torch.allclose(ref_out.bfloat16(), out, atol=atol, rtol=rtol))
torch.testing.assert_close(ref_out.bfloat16(), out, atol=atol, rtol=rtol)
def test_fp8_moe(self):
for params in itertools.product(
......
......@@ -47,7 +47,7 @@ class TestNorm(CustomTestCase):
ref_out = self._forward_native(x, weight, variance_epsilon)
atol = rtol = precision[ref_out.dtype]
self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol))
torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)
ref_x = x.clone()
residual = torch.randn([m, hidden_size], dtype=dtype)
......@@ -61,8 +61,8 @@ class TestNorm(CustomTestCase):
ref_x, weight, variance_epsilon, ref_residual
)
self.assertTrue(torch.allclose(x, ref_x, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(residual, ref_residual, atol=atol, rtol=rtol))
torch.testing.assert_close(x, ref_x, atol=atol, rtol=rtol)
torch.testing.assert_close(residual, ref_residual, atol=atol, rtol=rtol)
def _l2norm_test(self, m, n, dtype):
......@@ -75,7 +75,7 @@ class TestNorm(CustomTestCase):
ref_out = self._forward_native(x, fake_ones_weight, variance_epsilon)
atol = rtol = precision[ref_out.dtype]
self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol))
torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)
def test_norm(self):
for params in itertools.product(self.M, self.N, self.dtype):
......
......@@ -211,12 +211,12 @@ class TestQKVProjWithROPE(CustomTestCase):
qk_rope_head_dim,
)
atol = rtol = precision[q_ref.dtype]
self.assertTrue(torch.allclose(q_ref, q_out, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(k_ref, k_out, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(v_ref, v_out, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(fused_q_out, q_out))
self.assertTrue(torch.allclose(fused_k_out, k_out))
self.assertTrue(torch.allclose(fused_v_out, v_out))
torch.testing.assert_close(q_ref, q_out, atol=atol, rtol=rtol)
torch.testing.assert_close(k_ref, k_out, atol=atol, rtol=rtol)
torch.testing.assert_close(v_ref, v_out, atol=atol, rtol=rtol)
torch.testing.assert_close(fused_q_out, q_out)
torch.testing.assert_close(fused_k_out, k_out)
torch.testing.assert_close(fused_v_out, v_out)
def test_int8_qkv_proj_with_rope(self):
dtype = torch.bfloat16
......@@ -302,12 +302,12 @@ class TestQKVProjWithROPE(CustomTestCase):
qk_rope_head_dim,
)
atol = rtol = precision[q_ref.dtype]
self.assertTrue(torch.allclose(q_ref, q_out, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(k_ref, k_out, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(v_ref, v_out, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(fused_q_out, q_out))
self.assertTrue(torch.allclose(fused_k_out, k_out))
self.assertTrue(torch.allclose(fused_v_out, v_out))
torch.testing.assert_close(q_ref, q_out, atol=atol, rtol=rtol)
torch.testing.assert_close(k_ref, k_out, atol=atol, rtol=rtol)
torch.testing.assert_close(v_ref, v_out, atol=atol, rtol=rtol)
torch.testing.assert_close(fused_q_out, q_out)
torch.testing.assert_close(fused_k_out, k_out)
torch.testing.assert_close(fused_v_out, v_out)
def test_fp8_qkv_proj_with_rope(self):
dtype = torch.bfloat16
......
......@@ -75,8 +75,8 @@ class TestROPE(CustomTestCase):
)
atol = rtol = precision[q_pe.dtype]
self.assertTrue(torch.allclose(q_pe, q_pe_clone, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(k_pe, k_pe_clone, atol=atol, rtol=rtol))
torch.testing.assert_close(q_pe, q_pe_clone, atol=atol, rtol=rtol)
torch.testing.assert_close(k_pe, k_pe_clone, atol=atol, rtol=rtol)
torch.testing.assert_close(k_pe, k_pe_clone)
def test_origin_rope(self):
......