Unverified Commit 11fd69dd authored by smit kadvani's avatar smit kadvani Committed by GitHub
Browse files

[amd][gptoss] Perf gain because of block alignment (#28024)


Signed-off-by: default avatarSmit Kadvani <smit.kadvani@gmail.com>
Co-authored-by: default avatarSmit Shaileshbhai Kadvani <kadvani@meta.com>
parent c0a4b95d
......@@ -43,6 +43,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
_can_support_mxfp4,
_swizzle_mxfp4,
get_padding_alignment,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
from vllm.model_executor.utils import set_weight_attrs
......@@ -282,10 +283,11 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
)
hidden_size = round_up(hidden_size, 128)
elif current_platform.is_rocm():
pad_align = get_padding_alignment()
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 256
intermediate_size_per_partition, pad_align
)
hidden_size = round_up(hidden_size, 256)
hidden_size = round_up(hidden_size, pad_align)
else:
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 64
......
......@@ -7,6 +7,7 @@ import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
logger = init_logger(__name__)
......@@ -99,6 +100,14 @@ def _can_support_mxfp4(
)
def get_padding_alignment():
return (
256
if triton.runtime.driver.active.get_current_target().arch in ("gfx950",)
else 128
)
def _dequant_mxfp4(
x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype
) -> torch.Tensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment