"vscode:/vscode.git/clone" did not exist on "f44afef6d61faa40be65c14800307f42fa64ca55"
Unverified commit 7d6917be, authored by Andreas Karatzas, committed by GitHub
Browse files

[ROCm] Fix MoE kernel test failures on gfx950 (#37833)


Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Signed-off-by: Matthew Wong <Matthew.Wong2@amd.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Matthew Wong <Matthew.Wong2@amd.com>
parent e38817fa
...@@ -32,6 +32,14 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -32,6 +32,14 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, FusedMoEQuantConfig,
RoutingMethodType, RoutingMethodType,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8Dynamic128Sym,
kFp8DynamicTensorSym,
kFp8DynamicTokenSym,
kFp8Static128BlockSym,
kFp8StaticChannelSym,
kFp8StaticTensorSym,
)
from vllm.utils.import_utils import ( from vllm.utils.import_utils import (
has_aiter, has_aiter,
has_deep_ep, has_deep_ep,
...@@ -152,6 +160,39 @@ class Config: ...@@ -152,6 +160,39 @@ class Config:
return vllm_config, env_dict return vllm_config, env_dict
def fe_supports_quant_scheme(self) -> bool:
"""Check if the fused experts class supports this quant config.
See https://github.com/ROCm/aiter/issues/2419 for AITER gaps."""
if self.quant_config is None or self.quant_dtype is None:
return True
if self.quant_dtype != torch.float8_e4m3fn:
return True
# Derive QuantKeys from test config
if self.quant_block_shape is not None:
w_key = kFp8Static128BlockSym
a_key = kFp8Dynamic128Sym
elif self.is_per_out_ch_quant:
w_key = kFp8StaticChannelSym
a_key = (
kFp8DynamicTokenSym
if self.is_per_act_token_quant
else kFp8StaticTensorSym
)
else:
w_key = kFp8StaticTensorSym
a_key = (
kFp8DynamicTensorSym
if self.is_per_act_token_quant
else kFp8StaticTensorSym
)
fe_cls = self.fused_experts_type
if hasattr(fe_cls, "_supports_quant_scheme"):
try:
return fe_cls._supports_quant_scheme(w_key, a_key)
except NotImplementedError:
pass
return True
def is_fp8_block_quantized(self): def is_fp8_block_quantized(self):
return ( return (
self.quant_dtype == torch.float8_e4m3fn self.quant_dtype == torch.float8_e4m3fn
...@@ -253,6 +294,15 @@ class Config: ...@@ -253,6 +294,15 @@ class Config:
f"{self.fe_supported_types()}." f"{self.fe_supported_types()}."
) )
# Check quant scheme compatibility with fused experts class
if not self.fe_supports_quant_scheme():
return False, (
f"FE {self.fused_experts_type.__name__} does not support "
f"quant scheme (per_out_ch={self.is_per_out_ch_quant}, "
f"per_act_token={self.is_per_act_token_quant}, "
f"block={self.quant_block_shape})"
)
# Check block quantization support # Check block quantization support
is_block_quantized = self.quant_block_shape is not None is_block_quantized = self.quant_block_shape is not None
if is_block_quantized and self.quant_dtype is None: if is_block_quantized and self.quant_dtype is None:
......
...@@ -384,9 +384,18 @@ def test_legacy_routing( ...@@ -384,9 +384,18 @@ def test_legacy_routing(
logits = gating_output logits = gating_output
if sm_first: if sm_first:
logits = torch.softmax(logits, dim=-1) logits = torch.softmax(logits, dim=-1)
sparse_logits = topk_fn(logits, topk, apply_softmax=not sm_first) topk_result = topk_fn(logits, topk, apply_softmax=not sm_first)
topk_ids = sparse_logits.indx.to(torch.long) # topk_fn returns SparseMatrix on NVIDIA, plain tuple on ROCm.
topk_weights = sparse_logits.vals if isinstance(topk_result, tuple):
topk_weights, topk_ids_raw, bitmatrix = topk_result
from triton_kernels.routing import routing_from_bitmatrix
routing_data_ref, gather_indx_ref, scatter_indx_ref = routing_from_bitmatrix(
bitmatrix, topk_weights, topk_ids_raw, num_experts, topk
)
else:
topk_ids = topk_result.indx.to(torch.long)
topk_weights = topk_result.vals
routing_data_ref, gather_indx_ref, scatter_indx_ref = make_routing_data( routing_data_ref, gather_indx_ref, scatter_indx_ref = make_routing_data(
topk_ids, topk_weights, num_experts topk_ids, topk_weights, num_experts
) )
......
...@@ -108,6 +108,23 @@ def rank_worker( ...@@ -108,6 +108,23 @@ def rank_worker(
# inputs for rank # inputs for rank
rank_tensors = RankTensors.make(config, pgi) rank_tensors = RankTensors.make(config, pgi)
# Skip unsupported: AITER block-scaled MoE does not
# support apply_router_weight_on_input (topk=1 path).
# https://github.com/ROCm/aiter/issues/2418
if (
topk == 1
and config.supports_apply_weight_on_input()
and getattr(config.fused_experts_type, "__name__", "") == "AiterExperts"
and config.quant_block_shape is not None
):
print(
f"Skipping[{pgi.rank}]: m={m}, topk={topk}"
" (AITER block-scaled + weight-on-input,"
" https://github.com/ROCm/aiter/issues/2418)"
)
count -= 1
continue
# modular kernel out # modular kernel out
mk_out = run_modular_kernel(pgi, vllm_config, config, weights, rank_tensors) mk_out = run_modular_kernel(pgi, vllm_config, config, weights, rank_tensors)
...@@ -121,6 +138,47 @@ def rank_worker( ...@@ -121,6 +138,47 @@ def rank_worker(
atol = 3e-2 atol = 3e-2
rtol = 3e-2 rtol = 3e-2
# On ROCm, AITER FP8 fused MoE uses hardware FP8
# dot-product which can produce slightly larger error
# than dequant+f32 matmul at FP8 representable-value
# boundaries. Allow a small percentage of elements to
# exceed the base tolerance by a bounded margin.
# https://github.com/ROCm/aiter/issues/2421
from vllm.platforms import current_platform as _cp
is_aiter_fp8 = (
_cp.is_rocm()
and getattr(config.fused_experts_type, "__name__", "") == "AiterExperts"
and config.quant_config is not None
)
if is_aiter_fp8:
diff = (ref_out - mk_out).abs()
n_total = diff.numel()
max_diff = diff.max().item()
n_exceed = int((diff > atol).sum().item())
pct_exceed = n_exceed / n_total * 100
# FP8 hw matmul vs f32 reference: up to ~4% of
# elements may exceed base tolerance (the check below
# allows 5%), but max error should stay within 4x
# base tolerance.
max_pct_allowed = 5.0
relaxed_atol = atol * 4
print(
f"[AITER FP8 precision] "
f"max_diff={max_diff:.6f}, "
f"exceed_atol={n_exceed}/{n_total} "
f"({pct_exceed:.4f}%), "
f"max_pct_allowed={max_pct_allowed}%, "
f"relaxed_limit={relaxed_atol}"
)
assert pct_exceed <= max_pct_allowed, (
f"AITER FP8: {pct_exceed:.2f}% elements exceed "
f"atol={atol} (max allowed {max_pct_allowed}%)"
)
assert max_diff <= relaxed_atol, (
f"AITER FP8: max_diff={max_diff:.6f} exceeds "
f"relaxed limit {relaxed_atol}"
)
else:
torch.testing.assert_close(ref_out, mk_out, atol=atol, rtol=rtol) torch.testing.assert_close(ref_out, mk_out, atol=atol, rtol=rtol)
format_result(verbose, config.describe()) format_result(verbose, config.describe())
except Exception as ex: except Exception as ex:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable from collections.abc import Callable
from unittest.mock import patch
import pytest import pytest
import torch import torch
from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed.eplb.eplb_state import EplbLayerState from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.fused_moe.router.router_factory import ( from vllm.model_executor.layers.fused_moe.router.router_factory import (
create_fused_moe_router, create_fused_moe_router,
) )
from vllm.model_executor.models.llama4 import Llama4MoE from vllm.model_executor.models.llama4 import Llama4MoE
from vllm.platforms import current_platform
def _is_aiter_capable() -> bool:
    """Tell whether the current platform can run AITER (gfx942/gfx950).

    Returns:
        ``False`` off ROCm or when ``vllm.platforms.rocm`` cannot provide
        ``_ON_MI3XX``; otherwise the value of ``_ON_MI3XX``.
    """
    if current_platform.is_rocm():
        try:
            from vllm.platforms.rocm import _ON_MI3XX
        except ImportError:
            return False
        return _ON_MI3XX
    return False
# Test parameters # Test parameters
MK_S = [(32, 256), (64, 512)] MK_S = [(32, 256), (64, 512)]
...@@ -96,6 +112,60 @@ def assert_routing_results_close( ...@@ -96,6 +112,60 @@ def assert_routing_results_close(
) )
def assert_aiter_routing_valid(
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    top_k: int,
    num_experts: int,
    renormalize: bool,
    routed_scaling_factor: float = 1.0,
):
    """Validate AITER routing outputs are structurally correct.

    AITER grouped_topk is a fundamentally different implementation from
    the Python baseline (different group selection, scoring internals),
    so numerical comparison is not meaningful. Instead we verify the
    outputs satisfy the routing contract: correct shapes, valid expert
    IDs, no repeated expert per token, non-negative weights, and proper
    normalization.

    Args:
        topk_weights: Routing weights; must have shape (n_tokens, top_k),
            where n_tokens is taken from its first dimension.
        topk_ids: Selected expert IDs; must have the same shape.
        top_k: Number of experts selected per token.
        num_experts: Exclusive upper bound for valid expert IDs.
        renormalize: When true, each token's weights must sum to
            ``routed_scaling_factor`` (within 1e-3 abs/rel tolerance).
        routed_scaling_factor: Expected per-token weight sum when
            ``renormalize`` is set.

    Raises:
        AssertionError: If any of the checks above fails.
    """
    n_tokens = topk_weights.shape[0]

    # Shape
    assert topk_weights.shape == (n_tokens, top_k), (
        f"weights shape {topk_weights.shape} != ({n_tokens}, {top_k})"
    )
    assert topk_ids.shape == (n_tokens, top_k), (
        f"ids shape {topk_ids.shape} != ({n_tokens}, {top_k})"
    )

    # Expert IDs in valid range
    assert (topk_ids >= 0).all() and (topk_ids < num_experts).all(), (
        f"expert IDs out of range [0, {num_experts}): "
        f"min={topk_ids.min().item()}, max={topk_ids.max().item()}"
    )

    # No duplicate expert IDs per token. After sorting each row, a repeated
    # ID shows up as two equal neighbours, so all tokens are checked with
    # one sort instead of one `unique()` call per token.
    sorted_ids = topk_ids.sort(dim=-1).values
    dup_rows = (sorted_ids[:, 1:] == sorted_ids[:, :-1]).any(dim=-1)
    if dup_rows.any():
        first_bad = int(dup_rows.nonzero()[0].item())
        raise AssertionError(
            f"token {first_bad}: duplicate expert IDs "
            f"{topk_ids[first_bad].tolist()}"
        )

    # Weights are non-negative
    assert (topk_weights >= 0).all(), "negative routing weights"

    # If renormalized, weights should sum to ~scaling_factor per token
    # (renormalization to 1.0 happens before scaling)
    if renormalize:
        sums = topk_weights.sum(dim=-1)
        torch.testing.assert_close(
            sums,
            torch.full_like(sums, routed_scaling_factor),
            atol=1e-3,
            rtol=1e-3,
        )
def baseline_fused_topk( def baseline_fused_topk(
router_logits: torch.Tensor, top_k: int, renormalize: bool router_logits: torch.Tensor, top_k: int, renormalize: bool
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
...@@ -400,10 +470,7 @@ def test_grouped_topk( ...@@ -400,10 +470,7 @@ def test_grouped_topk(
hidden_states, router_logits = make_test_data(m, k, global_num_experts) hidden_states, router_logits = make_test_data(m, k, global_num_experts)
# Get router output # Compute baseline (pure Python implementation)
topk_weights, topk_ids = router.select_experts(hidden_states, router_logits)
# Compute baseline
baseline_weights, baseline_ids = baseline_grouped_topk( baseline_weights, baseline_ids = baseline_grouped_topk(
router_logits, router_logits,
top_k, top_k,
...@@ -415,8 +482,32 @@ def test_grouped_topk( ...@@ -415,8 +482,32 @@ def test_grouped_topk(
routed_scaling_factor, routed_scaling_factor,
) )
# Compare results # Test 1: Python/Triton path against baseline (exact match)
assert_routing_results_close(topk_weights, topk_ids, baseline_weights, baseline_ids) with patch(
"vllm.model_executor.layers.fused_moe.router.grouped_topk_router.rocm_aiter_ops.is_fused_moe_enabled",
return_value=False,
):
py_weights, py_ids = router.select_experts(hidden_states, router_logits)
assert_routing_results_close(py_weights, py_ids, baseline_weights, baseline_ids)
# Test 2: AITER path — verify outputs are structurally valid.
# AITER grouped_topk is a different implementation so we can't
# compare numerically against the Python baseline.
if _is_aiter_capable():
# Force-enable AITER for gfx942/gfx950 regardless of env var,
# so CI always exercises this path on capable hardware.
with patch.object(rocm_aiter_ops, "_AITER_ENABLED", True):
aiter_weights, aiter_ids = router.select_experts(
hidden_states, router_logits
)
assert_aiter_routing_valid(
aiter_weights,
aiter_ids,
top_k,
global_num_experts,
renormalize,
routed_scaling_factor,
)
@pytest.mark.parametrize("m,k", MK_S) @pytest.mark.parametrize("m,k", MK_S)
......
...@@ -14,6 +14,7 @@ import torch.nn as nn ...@@ -14,6 +14,7 @@ import torch.nn as nn
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm.platforms import current_platform
from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.utils.torch_utils import is_torch_equal_or_newer
...@@ -51,6 +52,60 @@ class SimpleSharedExperts(nn.Module): ...@@ -51,6 +52,60 @@ class SimpleSharedExperts(nn.Module):
return self.down(nn.functional.silu(gate) * up) return self.down(nn.functional.silu(gate) * up)
def _assert_close(
    actual: torch.Tensor,
    expected: torch.Tensor,
    atol: float,
    rtol: float,
    label: str,
) -> None:
    """assert_close that prints diff diagnostics on both success and failure.

    Args:
        actual: Tensor produced by the code under test.
        expected: Reference tensor.
        atol: Absolute tolerance forwarded to ``torch.testing.assert_close``;
            also the threshold for the "exceed_atol" count in the printout.
        rtol: Relative tolerance forwarded to ``torch.testing.assert_close``.
        label: Prefix used in the printed line and in failure messages.

    Raises:
        AssertionError: If either tensor contains NaN, or if the tensors
            are not close within the given tolerances.
    """
    n_total = actual.numel()

    # Element-wise statistics are undefined for empty tensors (max/min of
    # nothing, division by zero) and for mismatched shapes. Let
    # torch.testing.assert_close decide and report those cases itself,
    # without a custom message that would hide its shape explanation.
    if n_total == 0 or actual.shape != expected.shape:
        print(
            f"[{label}] "
            f"shape={list(actual.shape)}, "
            f"expected_shape={list(expected.shape)}, "
            f"diff diagnostics skipped"
        )
        torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
        return

    actual_nans = int(actual.isnan().sum().item())
    expected_nans = int(expected.isnan().sum().item())
    actual_zeros = int((actual == 0).sum().item())
    expected_zeros = int((expected == 0).sum().item())

    diff = (actual - expected).abs()
    if not diff.is_floating_point():
        # mean() is only defined for floating-point tensors.
        diff = diff.to(torch.float32)
    max_diff = diff.max().item()
    mean_diff = diff.mean().item()
    n_exceed = int((diff > atol).sum().item())
    pct_exceed = n_exceed / n_total * 100

    print(
        f"[{label}] "
        f"shape={list(actual.shape)}, "
        f"max_diff={max_diff:.6e}, "
        f"mean_diff={mean_diff:.6e}, "
        f"exceed_atol({atol})={n_exceed}/{n_total} ({pct_exceed:.2f}%), "
        f"actual=[{actual.min().item():.4f}, {actual.max().item():.4f}], "
        f"expected=[{expected.min().item():.4f}, {expected.max().item():.4f}], "
        f"nan(actual/expected)={actual_nans}/{expected_nans}, "
        f"zeros(actual/expected)={actual_zeros}/{expected_zeros}"
    )

    # Report NaNs separately so they are not mistaken for a tolerance miss.
    assert actual_nans == 0, (
        f"{label}: actual has {actual_nans}/{n_total} NaN values "
        f"(expected has {expected_nans}). "
        f"This indicates a kernel bug, not a precision issue."
    )
    assert expected_nans == 0, (
        f"{label}: expected has {expected_nans}/{n_total} NaN values. "
        f"This indicates a kernel bug, not a precision issue."
    )

    torch.testing.assert_close(
        actual,
        expected,
        atol=atol,
        rtol=rtol,
        msg=(
            f"{label}: max_diff={max_diff:.6e}, mean_diff={mean_diff:.6e}, "
            f"exceed_atol({atol})={n_exceed}/{n_total} ({pct_exceed:.2f}%)"
        ),
    )
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def setup_cuda(): def setup_cuda():
if not torch.cuda.is_available(): if not torch.cuda.is_available():
...@@ -61,6 +116,9 @@ def setup_cuda(): ...@@ -61,6 +116,9 @@ def setup_cuda():
@pytest.mark.parametrize("num_tokens", [1, 32]) @pytest.mark.parametrize("num_tokens", [1, 32])
@pytest.mark.parametrize("hidden_size,latent_size", [(256, 128), (128, 64)]) @pytest.mark.parametrize("hidden_size,latent_size", [(256, 128), (128, 64)])
@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
)
@pytest.mark.skipif( @pytest.mark.skipif(
is_torch_equal_or_newer("2.10.0"), is_torch_equal_or_newer("2.10.0"),
reason="Test fails with PyTorch 2.10.0 see: https://github.com/vllm-project/vllm/issues/33995", reason="Test fails with PyTorch 2.10.0 see: https://github.com/vllm-project/vllm/issues/33995",
...@@ -70,14 +128,24 @@ def test_routed_input_transform_inside_vs_outside( ...@@ -70,14 +128,24 @@ def test_routed_input_transform_inside_vs_outside(
hidden_size: int, hidden_size: int,
latent_size: int, latent_size: int,
dtype: torch.dtype, dtype: torch.dtype,
use_rocm_aiter: bool,
dist_init, dist_init,
workspace_init, workspace_init,
monkeypatch,
): ):
"""Compare SharedFusedMoE with transform inside vs manually applying outside. """Compare SharedFusedMoE with transform inside vs manually applying outside.
Method A (inside): SharedFusedMoE with routed_input_transform Method A (inside): SharedFusedMoE with routed_input_transform
Method B (outside): Manually transform, then SharedFusedMoE without transform Method B (outside): Manually transform, then SharedFusedMoE without transform
""" """
if current_platform.is_rocm() and use_rocm_aiter:
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1" if use_rocm_aiter else "0")
monkeypatch.setenv("VLLM_ROCM_USE_AITER_MOE", "1" if use_rocm_aiter else "0")
from vllm._aiter_ops import rocm_aiter_ops
rocm_aiter_ops.refresh_env_variables()
torch.manual_seed(42) torch.manual_seed(42)
torch.cuda.manual_seed(42)
num_experts = 8 num_experts = 8
top_k = 2 top_k = 2
...@@ -125,7 +193,13 @@ def test_routed_input_transform_inside_vs_outside( ...@@ -125,7 +193,13 @@ def test_routed_input_transform_inside_vs_outside(
prefix="moe_without_transform", prefix="moe_without_transform",
) )
# Weights are created via torch.empty (uninitialized).
# Initialize with seeded random values for reproducibility.
with torch.no_grad(): with torch.no_grad():
moe_with_transform.w13_weight.normal_()
moe_with_transform.w13_weight.div_(10)
moe_with_transform.w2_weight.normal_()
moe_with_transform.w2_weight.div_(10)
moe_without_transform.w13_weight.copy_(moe_with_transform.w13_weight) moe_without_transform.w13_weight.copy_(moe_with_transform.w13_weight)
moe_without_transform.w2_weight.copy_(moe_with_transform.w2_weight) moe_without_transform.w2_weight.copy_(moe_with_transform.w2_weight)
...@@ -139,9 +213,14 @@ def test_routed_input_transform_inside_vs_outside( ...@@ -139,9 +213,14 @@ def test_routed_input_transform_inside_vs_outside(
hidden_states = torch.randn(num_tokens, hidden_size, device="cuda", dtype=dtype) hidden_states = torch.randn(num_tokens, hidden_size, device="cuda", dtype=dtype)
router_logits = torch.randn(num_tokens, num_experts, device="cuda", dtype=dtype) router_logits = torch.randn(num_tokens, num_experts, device="cuda", dtype=dtype)
# Clone inputs so any in-place modification by Method A
# cannot affect Method B's computation.
hidden_states_A = hidden_states.clone()
router_logits_A = router_logits.clone()
with set_forward_context(None, vllm_config, num_tokens=num_tokens): with set_forward_context(None, vllm_config, num_tokens=num_tokens):
shared_out_A, routed_out_A = moe_with_transform( shared_out_A, routed_out_A = moe_with_transform(
hidden_states, router_logits hidden_states_A, router_logits_A
) )
transformed_hidden = routed_transform(hidden_states) transformed_hidden = routed_transform(hidden_states)
...@@ -149,19 +228,19 @@ def test_routed_input_transform_inside_vs_outside( ...@@ -149,19 +228,19 @@ def test_routed_input_transform_inside_vs_outside(
transformed_hidden, router_logits transformed_hidden, router_logits
) )
torch.testing.assert_close( expected_shared_out = shared_experts(hidden_states)
_assert_close(
routed_out_A, routed_out_A,
routed_out_B, routed_out_B,
atol=1e-3, atol=1e-3,
rtol=1e-3, rtol=1e-3,
msg="Routed output should match: transform inside vs outside", label="Routed output: transform inside vs outside",
) )
_assert_close(
expected_shared_out = shared_experts(hidden_states)
torch.testing.assert_close(
shared_out_A, shared_out_A,
expected_shared_out, expected_shared_out,
atol=1e-3, atol=1e-3,
rtol=1e-3, rtol=1e-3,
label="Shared expert output",
) )
...@@ -10,18 +10,15 @@ import torch ...@@ -10,18 +10,15 @@ import torch
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
persistent_masked_m_silu_mul_quant, persistent_masked_m_silu_mul_quant,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import (
get_fp8_min_max,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.deep_gemm import DeepGemmQuantScaleFMT, has_deep_gemm from vllm.utils.deep_gemm import DeepGemmQuantScaleFMT, has_deep_gemm
from vllm.utils.math_utils import cdiv, round_up from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import set_random_seed
if current_platform.is_fp8_fnuz(): fp8_dtype = current_platform.fp8_dtype()
pytest.skip(
"Tests in this file require float8_e4m3fn and platform does not support",
allow_module_level=True,
)
fp8_dtype = torch.float8_e4m3fn
CASES = [ CASES = [
(1, 1, 128, fp8_dtype), (1, 1, 128, fp8_dtype),
...@@ -58,22 +55,21 @@ def as_uint8(x) -> torch.Tensor: ...@@ -58,22 +55,21 @@ def as_uint8(x) -> torch.Tensor:
def silu(x: torch.Tensor) -> torch.Tensor: def silu(x: torch.Tensor) -> torch.Tensor:
one_f32 = torch.tensor([1.0], device=x.device, dtype=torch.float32)
x_f32 = x.to(torch.float32) x_f32 = x.to(torch.float32)
act_f32 = x_f32 / (one_f32 + torch.exp(-x_f32)) act_f32 = x_f32 / (1.0 + torch.exp(-x_f32))
assert act_f32.dtype == torch.float32 if current_platform.is_cuda():
# C++ kernel returns bf16
return act_f32.to(torch.bfloat16) return act_f32.to(torch.bfloat16)
# Triton fallback stays in f32
return act_f32
def do_quant(x: torch.Tensor, group_size: int, ceil_ue8m0: bool): def do_quant(x: torch.Tensor, group_size: int, ceil_ue8m0: bool):
fp8_min_val, fp8_max_val = get_fp8_min_max()
eps_bf16 = torch.tensor([1e-10], device=x.device, dtype=torch.bfloat16) eps_bf16 = torch.tensor([1e-10], device=x.device, dtype=torch.bfloat16)
one_bf16 = torch.tensor([1.0], device=x.device, dtype=torch.bfloat16) one_bf16 = torch.tensor([1.0], device=x.device, dtype=torch.bfloat16)
fp8_max_bf16 = torch.tensor( fp8_max_bf16 = torch.tensor([fp8_max_val], device=x.device, dtype=torch.bfloat16)
[torch.finfo(fp8_dtype).max], device=x.device, dtype=torch.bfloat16 fp8_min_bf16 = torch.tensor([fp8_min_val], device=x.device, dtype=torch.bfloat16)
)
fp8_min_bf16 = torch.tensor(
[torch.finfo(fp8_dtype).min], device=x.device, dtype=torch.bfloat16
)
fp8_max_inv = one_bf16 / fp8_max_bf16 fp8_max_inv = one_bf16 / fp8_max_bf16
assert fp8_max_inv.dtype == torch.bfloat16 assert fp8_max_inv.dtype == torch.bfloat16
...@@ -81,6 +77,8 @@ def do_quant(x: torch.Tensor, group_size: int, ceil_ue8m0: bool): ...@@ -81,6 +77,8 @@ def do_quant(x: torch.Tensor, group_size: int, ceil_ue8m0: bool):
num_groups = x.numel() // group_size num_groups = x.numel() // group_size
x_og_shape = x.shape x_og_shape = x.shape
if current_platform.is_cuda():
# C++ kernel computes entirely in bf16
x = x.to(torch.bfloat16) x = x.to(torch.bfloat16)
x = x.view((-1, group_size)) x = x.view((-1, group_size))
amax = x.abs().amax(dim=1).clamp(min=eps_bf16) amax = x.abs().amax(dim=1).clamp(min=eps_bf16)
...@@ -94,9 +92,21 @@ def do_quant(x: torch.Tensor, group_size: int, ceil_ue8m0: bool): ...@@ -94,9 +92,21 @@ def do_quant(x: torch.Tensor, group_size: int, ceil_ue8m0: bool):
inv_s = one_bf16 / s inv_s = one_bf16 / s
inv_s = inv_s.view((num_groups, 1)) inv_s = inv_s.view((num_groups, 1))
xq = torch.clamp(x * inv_s, min=fp8_min_bf16.item(), max=fp8_max_bf16.item()).to( xq = torch.clamp(
fp8_dtype x * inv_s, min=fp8_min_bf16.item(), max=fp8_max_bf16.item()
) ).to(fp8_dtype)
else:
# Triton fallback computes in f32. Use multiply-by-reciprocal
# to match Triton's constexpr evaluation of 1.0/fp8_max.
fp8_min_f, fp8_max_f = get_fp8_min_max()
x = x.to(torch.float32).view((-1, group_size))
amax = x.abs().amax(dim=1).clamp(min=1e-10)
s = amax * (1.0 / fp8_max_f)
if ceil_ue8m0:
s = torch.exp2(torch.ceil(torch.log2(s)))
inv_s = (1.0 / s).view((num_groups, 1))
xq = torch.clamp(x * inv_s, min=fp8_min_f, max=fp8_max_f).to(fp8_dtype)
xq = xq.view(x_og_shape) xq = xq.view(x_og_shape)
xs = s.view((-1, xq.size(-1) // group_size)) xs = s.view((-1, xq.size(-1) // group_size))
...@@ -112,12 +122,10 @@ def silu_mul_quant( ...@@ -112,12 +122,10 @@ def silu_mul_quant(
assert gate.dtype == torch.bfloat16 assert gate.dtype == torch.bfloat16
assert up.dtype == torch.bfloat16 assert up.dtype == torch.bfloat16
act_bf16 = silu(gate) act = silu(gate)
assert act_bf16.dtype == torch.bfloat16
# act & mul # act & mul
a_m = act_bf16 * up a_m = act * up
assert a_m.dtype == torch.bfloat16
q, s = do_quant(a_m, group_size, ceil_ue8m0) q, s = do_quant(a_m, group_size, ceil_ue8m0)
return q, s return q, s
...@@ -221,8 +229,12 @@ def test_silu_mul_fp8_quant_deep_gemm(E: int, T: int, H: int, fp8_type: torch.dt ...@@ -221,8 +229,12 @@ def test_silu_mul_fp8_quant_deep_gemm(E: int, T: int, H: int, fp8_type: torch.dt
scale_fmts = [ scale_fmts = [
DeepGemmQuantScaleFMT.FLOAT32, DeepGemmQuantScaleFMT.FLOAT32,
DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0, DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0,
DeepGemmQuantScaleFMT.UE8M0,
] ]
# UE8M0 (int32 packed) scales require the C++ kernel which is
# not available on ROCm (#ifndef USE_ROCM).
# https://github.com/ROCm/aiter/issues/2420
if current_platform.is_cuda():
scale_fmts.append(DeepGemmQuantScaleFMT.UE8M0)
# Run the SiLU V2 kernel # Run the SiLU V2 kernel
for scale_fmt in scale_fmts: for scale_fmt in scale_fmts:
...@@ -274,6 +286,19 @@ def test_silu_mul_fp8_quant_deep_gemm(E: int, T: int, H: int, fp8_type: torch.dt ...@@ -274,6 +286,19 @@ def test_silu_mul_fp8_quant_deep_gemm(E: int, T: int, H: int, fp8_type: torch.dt
for e in range(E): for e in range(E):
nt = tokens_per_expert[e].item() nt = tokens_per_expert[e].item()
if current_platform.is_rocm():
# On ROCm the Triton fallback kernel uses f32 math
# intrinsics (tl.exp) that may differ from PyTorch's
# torch.exp by 1 ULP. At FP8 quantization
# boundaries this can flip one representable value.
# Allow 1 FP8 quantum of tolerance.
torch.testing.assert_close(
y_q[e, :nt].to(torch.float32),
ref_y_q[e, :nt].to(torch.float32),
atol=32.0,
rtol=0.2,
)
else:
torch.testing.assert_close( torch.testing.assert_close(
y_q[e, :nt].to(torch.float32), y_q[e, :nt].to(torch.float32),
ref_y_q[e, :nt].to(torch.float32), ref_y_q[e, :nt].to(torch.float32),
......
...@@ -16,7 +16,7 @@ from vllm.platforms import current_platform ...@@ -16,7 +16,7 @@ from vllm.platforms import current_platform
"platform_method,expected_backend", "platform_method,expected_backend",
[ [
("is_cuda", UnquantizedMoeBackend.TRITON), # Default CUDA without FlashInfer ("is_cuda", UnquantizedMoeBackend.TRITON), # Default CUDA without FlashInfer
("is_rocm", UnquantizedMoeBackend.TRITON), ("is_rocm", UnquantizedMoeBackend.TRITON), # ROCm without AITER
("is_cpu", UnquantizedMoeBackend.CPU), ("is_cpu", UnquantizedMoeBackend.CPU),
("is_xpu", UnquantizedMoeBackend.XPU), ("is_xpu", UnquantizedMoeBackend.XPU),
("is_tpu", UnquantizedMoeBackend.TPU), ("is_tpu", UnquantizedMoeBackend.TPU),
...@@ -27,13 +27,19 @@ from vllm.platforms import current_platform ...@@ -27,13 +27,19 @@ from vllm.platforms import current_platform
"vllm.model_executor.layers.fused_moe.oracle.unquantized.has_flashinfer", "vllm.model_executor.layers.fused_moe.oracle.unquantized.has_flashinfer",
return_value=False, return_value=False,
) )
@patch(
"vllm.model_executor.layers.fused_moe.oracle.unquantized.rocm_aiter_ops.is_fused_moe_enabled",
return_value=False,
)
def test_select_default_backend_by_platform( def test_select_default_backend_by_platform(
mock_aiter_enabled,
mock_has_flashinfer, mock_has_flashinfer,
monkeypatch, monkeypatch,
platform_method, platform_method,
expected_backend, expected_backend,
): ):
"""Test backend selection for different platforms.""" """Test default backend selection per platform with all optional
accelerators (FlashInfer, AITER) disabled."""
with patch( with patch(
"vllm.model_executor.layers.fused_moe.oracle.unquantized.current_platform" "vllm.model_executor.layers.fused_moe.oracle.unquantized.current_platform"
) as mock_platform: ) as mock_platform:
...@@ -58,6 +64,39 @@ def test_select_default_backend_by_platform( ...@@ -58,6 +64,39 @@ def test_select_default_backend_by_platform(
assert selected_backend == expected_backend assert selected_backend == expected_backend
@patch(
    "vllm.model_executor.layers.fused_moe.oracle.unquantized.has_flashinfer",
    return_value=False,
)
@patch(
    "vllm.model_executor.layers.fused_moe.oracle.unquantized.rocm_aiter_ops.is_fused_moe_enabled",
    return_value=True,
)
@pytest.mark.skipif(
    not current_platform.is_rocm(), reason="ROCm-specific backend selection test"
)
def test_select_rocm_aiter_backend(mock_aiter_enabled, mock_has_flashinfer):
    """Test ROCm backend selection when AITER is available.

    FlashInfer is patched to report unavailable and the AITER fused-MoE
    switch to report enabled; the platform seen by the backend oracle is
    mocked to identify as ROCm only. The selected backend must be AITER.
    """
    # Answer each platform probe explicitly; only ROCm reports True.
    platform_answers = {
        "is_cuda": False,
        "is_rocm": True,
        "is_cpu": False,
        "is_xpu": False,
        "is_tpu": False,
        "is_out_of_tree": False,
    }
    with patch(
        "vllm.model_executor.layers.fused_moe.oracle.unquantized.current_platform"
    ) as mock_platform:
        for probe, answer in platform_answers.items():
            getattr(mock_platform, probe).return_value = answer

        selected_backend = select_unquantized_moe_backend(
            moe_config=make_dummy_moe_config(),
            use_ep=False,
            use_dp=False,
        )
        assert selected_backend == UnquantizedMoeBackend.AITER
@patch( @patch(
"vllm.model_executor.layers.fused_moe.oracle.unquantized.has_flashinfer", "vllm.model_executor.layers.fused_moe.oracle.unquantized.has_flashinfer",
return_value=True, return_value=True,
......
...@@ -941,7 +941,7 @@ def torch_experts( ...@@ -941,7 +941,7 @@ def torch_experts(
if b_bias1 is not None: if b_bias1 is not None:
tmp1 = tmp1 + b_bias1[i].view(1, -1).to(out.dtype) tmp1 = tmp1 + b_bias1[i].view(1, -1).to(out.dtype)
tmp2 = SiluAndMul()(tmp1).to(out.dtype) tmp2 = act()(tmp1).to(out.dtype)
tmp2, b_scale = moe_kernel_quantize_input( tmp2, b_scale = moe_kernel_quantize_input(
tmp2, a2_scale, quant_dtype, per_act_token_quant, block_shape tmp2, a2_scale, quant_dtype, per_act_token_quant, block_shape
......
...@@ -19,6 +19,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( ...@@ -19,6 +19,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, QuantKey,
get_fp8_min_max,
kFp8Dynamic128Sym, kFp8Dynamic128Sym,
kFp8Static128BlockSym, kFp8Static128BlockSym,
) )
...@@ -117,7 +118,10 @@ def _silu_mul_fp8_quant_deep_gemm( ...@@ -117,7 +118,10 @@ def _silu_mul_fp8_quant_deep_gemm(
gate = gate * (1.0 / (1.0 + tl.exp(-gate))) gate = gate * (1.0 / (1.0 + tl.exp(-gate)))
y = gate * up y = gate * up
y_s = tl.maximum(tl.max(tl.abs(y)), eps) / fp8_max # Use multiply-by-reciprocal to match PyTorch's tensor/scalar
# division precision (Triton GPU fast-division for constexpr
# divisors can introduce 1-ULP error).
y_s = tl.maximum(tl.max(tl.abs(y)), eps) * (1.0 / fp8_max)
if ceil_ue8m0: if ceil_ue8m0:
y_s = tl.exp2(tl.ceil(tl.log2(y_s))) y_s = tl.exp2(tl.ceil(tl.log2(y_s)))
...@@ -190,7 +194,7 @@ def persistent_masked_m_silu_mul_quant( ...@@ -190,7 +194,7 @@ def persistent_masked_m_silu_mul_quant(
tokens_per_expert = tokens_per_expert.to(device=y.device, dtype=torch.int32) tokens_per_expert = tokens_per_expert.to(device=y.device, dtype=torch.int32)
fp8_dtype = torch.float8_e4m3fn fp8_dtype = current_platform.fp8_dtype()
y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device) y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device)
ys_shape, ys_strides, ys_dtype = scales_shape_stride_dtype(E, T, G, quant_scale_fmt) ys_shape, ys_strides, ys_dtype = scales_shape_stride_dtype(E, T, G, quant_scale_fmt)
...@@ -210,11 +214,14 @@ def persistent_masked_m_silu_mul_quant( ...@@ -210,11 +214,14 @@ def persistent_masked_m_silu_mul_quant(
device_id=y.device.index device_id=y.device.index
).to_int() ).to_int()
if cuda_arch >= 80: if current_platform.is_cuda() and cuda_arch >= 80:
torch.ops._C.persistent_masked_m_silu_mul_quant( torch.ops._C.persistent_masked_m_silu_mul_quant(
y, tokens_per_expert, y_q, y_s, ceil_ue8m0 y, tokens_per_expert, y_q, y_s, ceil_ue8m0
) )
else: else:
# Triton fallback for ROCm -- the C++ kernel is guarded by
# #ifndef USE_ROCM in activation_kernels.cu.
# https://github.com/ROCm/aiter/issues/2420
stride_cnt_e = tokens_per_expert.stride()[0] stride_cnt_e = tokens_per_expert.stride()[0]
# Static grid over experts and H-groups. # Static grid over experts and H-groups.
...@@ -224,13 +231,11 @@ def persistent_masked_m_silu_mul_quant( ...@@ -224,13 +231,11 @@ def persistent_masked_m_silu_mul_quant(
stride_i_e, stride_i_t, stride_i_h = y.stride() stride_i_e, stride_i_t, stride_i_h = y.stride()
stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride() stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride()
f_info = torch.finfo(fp8_dtype) fp8_min, fp8_max = get_fp8_min_max()
fp8_max = f_info.max
fp8_min = f_info.min
eps: float = 1e-10 eps: float = 1e-10
assert y_s.dtype == torch.float32, ( assert y_s.dtype == torch.float32, (
"_silu_mul_fp8_quant_deep_gemm does" "_silu_mul_fp8_quant_deep_gemm Triton fallback does not "
"not support {y_s.dtype} scales. Only torch.float32 supported." f"support {y_s.dtype} scales. Only torch.float32 supported."
) )
_silu_mul_fp8_quant_deep_gemm[grid]( _silu_mul_fp8_quant_deep_gemm[grid](
y, y,
......
...@@ -253,10 +253,16 @@ def triton_kernel_moe_forward( ...@@ -253,10 +253,16 @@ def triton_kernel_moe_forward(
logits = gating_output logits = gating_output
if sm_first: if sm_first:
logits = torch.softmax(logits, dim=-1) logits = torch.softmax(logits, dim=-1)
sparse_logits = topk_fn(logits, topk, apply_softmax=not sm_first) topk_result = topk_fn(logits, topk, apply_softmax=not sm_first)
# sparse_logits.indx contains global expert IDs – remap to local. # topk may return a tuple (vals, indx, bitmatrix) or a
topk_ids = expert_map[sparse_logits.indx.to(torch.long)] # SparseMatrix depending on the triton_kernels version.
topk_weights = sparse_logits.vals if isinstance(topk_result, tuple):
topk_weights, topk_ids_raw, _ = topk_result
else:
topk_weights = topk_result.vals
topk_ids_raw = topk_result.indx
# topk_ids_raw contains global expert IDs - remap to local.
topk_ids = expert_map[topk_ids_raw.to(torch.long)]
local_num_experts = w1.shape[0] local_num_experts = w1.shape[0]
routing_data, gather_idx, scatter_idx = make_routing_data( routing_data, gather_idx, scatter_idx = make_routing_data(
topk_ids, topk_weights, local_num_experts topk_ids, topk_weights, local_num_experts
...@@ -422,8 +428,13 @@ def triton_kernel_fused_mxfp4_w4a8_experts( ...@@ -422,8 +428,13 @@ def triton_kernel_fused_mxfp4_w4a8_experts(
assert quant_config.w1_bias is None or quant_config.w1_bias.dtype == torch.float32 assert quant_config.w1_bias is None or quant_config.w1_bias.dtype == torch.float32
assert quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32 assert quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32
# Shape check, only check non-mxfp4 # Shape check: when weights are padded (e.g. hidden_size padded for
assert hidden_states.shape[-1] == w1.shape[-2] # GFX950 swizzle), unpadded_K_w1 carries the original dimension.
expected_K_w1 = unpadded_K_w1 if unpadded_K_w1 is not None else w1.shape[-2]
assert hidden_states.shape[-1] == expected_K_w1, (
f"hidden_states K={hidden_states.shape[-1]} != "
f"expected K={expected_K_w1} (w1 K={w1.shape[-2]})"
)
assert w2.shape[-1] == w1.shape[1] assert w2.shape[-1] == w1.shape[1]
E, _, N = w1.shape E, _, N = w1.shape
...@@ -483,6 +494,12 @@ def triton_kernel_fused_mxfp4_w4a8_experts( ...@@ -483,6 +494,12 @@ def triton_kernel_fused_mxfp4_w4a8_experts(
unpadded_K=unpadded_K_w2, unpadded_K=unpadded_K_w2,
) )
# When hidden_size was padded for alignment (e.g. GFX950 swizzle),
# the kernel output has the padded dimension. Slice back to the
# original hidden_size so downstream layers see the expected shape.
if unpadded_N_w2 is not None and intermediate_cache3.shape[-1] != unpadded_N_w2:
intermediate_cache3 = intermediate_cache3[..., :unpadded_N_w2].contiguous()
return intermediate_cache3 return intermediate_cache3
......
...@@ -741,11 +741,14 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): ...@@ -741,11 +741,14 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
# TP=4 yields intermediate_size_per_partition=384), AITER raises: # TP=4 yields intermediate_size_per_partition=384), AITER raises:
# "device_gemm ... does not support this GEMM problem". # "device_gemm ... does not support this GEMM problem".
# Fall back to emulation in that case. # Fall back to emulation in that case.
# For gpt_oss models, create_weights rounds up the dimensions
# internally, so the alignment check is skipped.
if ( if (
not self.emulate not self.emulate
and self.use_rocm_aiter_moe and self.use_rocm_aiter_moe
and self.ocp_mx_scheme is not None and self.ocp_mx_scheme is not None
and self.ocp_mx_scheme.startswith("w_mxfp4") and self.ocp_mx_scheme.startswith("w_mxfp4")
and self.model_type != "gpt_oss"
and moe.intermediate_size_per_partition % CK_MXFP4_MOE_DIM_ALIGNMENT != 0 and moe.intermediate_size_per_partition % CK_MXFP4_MOE_DIM_ALIGNMENT != 0
): ):
logger.warning_once( logger.warning_once(
...@@ -819,6 +822,18 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): ...@@ -819,6 +822,18 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
"unpadded_hidden_size", hidden_size "unpadded_hidden_size", hidden_size
) )
# On GFX950, the GFX950MXScaleLayout swizzle requires
# hidden_size to be a multiple of 256 (SCALE_K = hidden_size / 32
# must be divisible by 8). Pad hidden_size for weight/scale
# allocation; the original value is preserved in unpadded_hidden_size.
# Only applies to the native (non-emulated) CK path on GFX950.
if (
self.model_type == "gpt_oss"
and current_platform.is_rocm()
and not self.emulate
):
hidden_size = round_up(hidden_size, 256)
# WEIGHTS # WEIGHTS
w13_weight = torch.nn.Parameter( w13_weight = torch.nn.Parameter(
torch.empty( torch.empty(
......
...@@ -615,8 +615,8 @@ def _per_token_group_quant_fp8( ...@@ -615,8 +615,8 @@ def _per_token_group_quant_fp8(
# Avoid to divide zero # Avoid to divide zero
eps, eps,
# Information for float8 # Information for float8
fp8_min, fp8_min: tl.constexpr,
fp8_max, fp8_max: tl.constexpr,
use_ue8m0: tl.constexpr, use_ue8m0: tl.constexpr,
# Meta-parameters # Meta-parameters
BLOCK: tl.constexpr, BLOCK: tl.constexpr,
...@@ -647,8 +647,12 @@ def _per_token_group_quant_fp8( ...@@ -647,8 +647,12 @@ def _per_token_group_quant_fp8(
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
# Quant # Quant
# Use multiply-by-reciprocal instead of division to match PyTorch's
# tensor/scalar division precision (GPU fast-division for constexpr
# divisors can introduce 1-ULP error that flips FP8 quantization at
# representable-value boundaries).
_absmax = tl.maximum(tl.max(tl.abs(y)), eps) _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
scale_raw = _absmax / fp8_max scale_raw = _absmax * (1.0 / fp8_max)
y_s = tl.math.exp2(tl.ceil(tl.log2(scale_raw))) if use_ue8m0 else scale_raw y_s = tl.math.exp2(tl.ceil(tl.log2(scale_raw))) if use_ue8m0 else scale_raw
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
...@@ -667,8 +671,8 @@ def _silu_mul_per_token_group_quant_fp8_colmajor( ...@@ -667,8 +671,8 @@ def _silu_mul_per_token_group_quant_fp8_colmajor(
y_s_col_stride: tl.int64, y_s_col_stride: tl.int64,
# Information for float8 # Information for float8
eps, eps,
fp8_min, fp8_min: tl.constexpr,
fp8_max, fp8_max: tl.constexpr,
use_ue8m0: tl.constexpr, use_ue8m0: tl.constexpr,
# Meta-parameters # Meta-parameters
GROUP_SIZE: tl.constexpr, GROUP_SIZE: tl.constexpr,
...@@ -709,7 +713,7 @@ def _silu_mul_per_token_group_quant_fp8_colmajor( ...@@ -709,7 +713,7 @@ def _silu_mul_per_token_group_quant_fp8_colmajor(
# quant # quant
_absmax = tl.maximum(tl.max(tl.abs(y), axis=1), eps) _absmax = tl.maximum(tl.max(tl.abs(y), axis=1), eps)
scale_raw = _absmax / fp8_max scale_raw = _absmax * (1.0 / fp8_max)
y_s = tl.math.exp2(tl.ceil(tl.log2(scale_raw))) if use_ue8m0 else scale_raw y_s = tl.math.exp2(tl.ceil(tl.log2(scale_raw))) if use_ue8m0 else scale_raw
y_s = tl.reshape(y_s, (BLOCK_M, 1)) y_s = tl.reshape(y_s, (BLOCK_M, 1))
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
...@@ -808,8 +812,8 @@ def _per_token_group_quant_fp8_colmajor( ...@@ -808,8 +812,8 @@ def _per_token_group_quant_fp8_colmajor(
# Avoid to divide zero # Avoid to divide zero
eps, eps,
# Information for float8 # Information for float8
fp8_min, fp8_min: tl.constexpr,
fp8_max, fp8_max: tl.constexpr,
use_ue8m0: tl.constexpr, use_ue8m0: tl.constexpr,
# Meta-parameters # Meta-parameters
BLOCK: tl.constexpr, BLOCK: tl.constexpr,
...@@ -849,7 +853,7 @@ def _per_token_group_quant_fp8_colmajor( ...@@ -849,7 +853,7 @@ def _per_token_group_quant_fp8_colmajor(
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
# Quant # Quant
_absmax = tl.maximum(tl.max(tl.abs(y)), eps) _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
scale_raw = _absmax / fp8_max scale_raw = _absmax * (1.0 / fp8_max)
y_s = tl.math.exp2(tl.ceil(tl.log2(scale_raw))) if use_ue8m0 else scale_raw y_s = tl.math.exp2(tl.ceil(tl.log2(scale_raw))) if use_ue8m0 else scale_raw
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment