Unverified Commit 8ad7285e authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[Kernels] Clean up FusedMoeMethodBase and modular kernel setup. Remove extra...


[Kernels] Clean up FusedMoeMethodBase and modular kernel setup.  Remove extra arguments from modular kernel methods. (#22035)
Signed-off-by: default avatarBill Nell <bnell@redhat.com>
Co-authored-by: default avatarMichael Goin <mgoin64@gmail.com>
parent 48b01fd4
...@@ -399,6 +399,7 @@ steps: ...@@ -399,6 +399,7 @@ steps:
- label: Kernels MoE Test %N - label: Kernels MoE Test %N
mirror_hardwares: [amdexperimental] mirror_hardwares: [amdexperimental]
source_file_dependencies: source_file_dependencies:
- csrc/quantization/cutlass_w8a8/moe/
- csrc/moe/ - csrc/moe/
- tests/kernels/moe - tests/kernels/moe
- vllm/model_executor/layers/fused_moe/ - vllm/model_executor/layers/fused_moe/
......
...@@ -175,11 +175,19 @@ implementations that input `FusedMoEActivationFormat.Standard` support chunking ...@@ -175,11 +175,19 @@ implementations that input `FusedMoEActivationFormat.Standard` support chunking
### FusedMoEModularKernel Initialization ### FusedMoEModularKernel Initialization
`FusedMoEMethodBase` class has 2 methods that are collectively responsible in creating the `FusedMoEModularKernel` object. They are, `FusedMoEMethodBase` class has 3 methods that are collectively responsible in creating the `FusedMoEModularKernel` object. They are,
* maybe_make_prepare_finalize,
* select_gemm_impl, and * select_gemm_impl, and
* init_prepare_finalize * init_prepare_finalize
#### maybe_make_prepare_finalize
The `maybe_make_prepare_finalize` method is responsbile for constructing an instance of `FusedMoEPrepareAndFinalize` when appropriate based on the current all2all backend, e.g. when EP + DP is enabled. The base class method currently constructs all the `FusedMoEPrepareAndFinalize` objects for the EP+DP case. Derived classes can override this method to construct prepare/finalize objects for different scenarios, e.g. `ModelOptNvFp4FusedMoE` can construct a `FlashInferCutlassMoEPrepareAndFinalize` for the EP+TP case.
Please refer to the implementations in,
* `ModelOptNvFp4FusedMoE`
#### select_gemm_impl #### select_gemm_impl
The `select_gemm_impl` method is undefined in the base class. It is the responsibility of the derived class to implement a method that constructs a valid/appropriate `FusedMoEPermuteExpertsUnpermute` object. The `select_gemm_impl` method is undefined in the base class. It is the responsibility of the derived class to implement a method that constructs a valid/appropriate `FusedMoEPermuteExpertsUnpermute` object.
......
...@@ -70,12 +70,27 @@ def parse_args(): ...@@ -70,12 +70,27 @@ def parse_args():
default=64, default=64,
help=("Maximum number of sequences to be processed in a single iteration."), help=("Maximum number of sequences to be processed in a single iteration."),
) )
parser.add_argument(
"--max-model-len",
type=int,
help=("Maximum number of tokens to be processed in a single iteration."),
)
parser.add_argument(
"--timeout",
type=int,
default=300,
help=("Number of seconds before unresponsive process is killed."),
)
parser.add_argument( parser.add_argument(
"--gpu-memory-utilization", "--gpu-memory-utilization",
type=float, type=float,
default=0.8, default=0.8,
help=("Fraction of GPU memory vLLM is allowed to allocate (0.0, 1.0]."), help=("Fraction of GPU memory vLLM is allowed to allocate (0.0, 1.0]."),
) )
parser.add_argument(
"--quantization",
type=str,
)
return parser.parse_args() return parser.parse_args()
...@@ -90,7 +105,9 @@ def main( ...@@ -90,7 +105,9 @@ def main(
enforce_eager, enforce_eager,
trust_remote_code, trust_remote_code,
max_num_seqs, max_num_seqs,
max_model_len,
gpu_memory_utilization, gpu_memory_utilization,
quantization,
): ):
os.environ["VLLM_DP_RANK"] = str(global_dp_rank) os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank) os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
...@@ -142,7 +159,9 @@ def main( ...@@ -142,7 +159,9 @@ def main(
enable_expert_parallel=True, enable_expert_parallel=True,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
max_num_seqs=max_num_seqs, max_num_seqs=max_num_seqs,
max_model_len=max_model_len,
gpu_memory_utilization=gpu_memory_utilization, gpu_memory_utilization=gpu_memory_utilization,
quantization=quantization,
) )
outputs = llm.generate(prompts, sampling_params) outputs = llm.generate(prompts, sampling_params)
# Print the outputs. # Print the outputs.
...@@ -198,14 +217,16 @@ if __name__ == "__main__": ...@@ -198,14 +217,16 @@ if __name__ == "__main__":
args.enforce_eager, args.enforce_eager,
args.trust_remote_code, args.trust_remote_code,
args.max_num_seqs, args.max_num_seqs,
args.max_model_len,
args.gpu_memory_utilization, args.gpu_memory_utilization,
args.quantization,
), ),
) )
proc.start() proc.start()
procs.append(proc) procs.append(proc)
exit_code = 0 exit_code = 0
for proc in procs: for proc in procs:
proc.join(timeout=300) proc.join(timeout=args.timeout)
if proc.exitcode is None: if proc.exitcode is None:
print(f"Killing process {proc.pid} that didn't stop within 5 minutes.") print(f"Killing process {proc.pid} that didn't stop within 5 minutes.")
proc.kill() proc.kill()
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional, Union
import torch import torch
# Fused experts and PrepareFinalize imports # Fused experts and PrepareFinalize imports
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts) BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
BatchedTritonOrDeepGemmExperts) BatchedTritonOrDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts, NaiveBatchedExperts) BatchedTritonExperts, NaiveBatchedExperts)
from vllm.model_executor.layers.fused_moe.layer import TritonExperts from vllm.model_executor.layers.fused_moe.layer import (FusedMoEMethodBase,
TritonExperts)
from vllm.model_executor.layers.fused_moe.prepare_finalize import ( from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP) MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts) TritonOrDeepGemmExperts)
from vllm.utils import has_deep_ep, has_pplx from vllm.model_executor.layers.quantization.utils.quant_utils import (
cutlass_fp4_supported)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_fp8_supported)
from vllm.platforms import current_platform
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
from vllm.utils.deep_gemm import is_deep_gemm_supported
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
if has_deep_ep():
@dataclass
class PrepareFinalizeInfo:
activation_format: mk.FusedMoEActivationFormat
supported_dtypes: list[Union[torch.dtype, str]]
blocked_quantization_support: bool
backend: Optional[str]
supports_apply_weight_on_input: bool = True
@dataclass
class ExpertInfo:
activation_format: mk.FusedMoEActivationFormat
supported_dtypes: list[Union[torch.dtype, str]]
blocked_quantization_support: bool
supports_chunking: bool
supports_expert_map: bool
needs_matching_quant: bool = False
needs_deep_gemm: bool = False
PREPARE_FINALIZE_INFO: dict[mk.FusedMoEPrepareAndFinalize,
PrepareFinalizeInfo] = {}
EXPERT_INFO: dict[mk.FusedMoEPermuteExpertsUnpermute, ExpertInfo] = {}
MK_ALL_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
MK_FUSED_EXPERT_TYPES: list[mk.FusedMoEPermuteExpertsUnpermute] = []
standard_format = mk.FusedMoEActivationFormat.Standard
batched_format = mk.FusedMoEActivationFormat.BatchedExperts
common_float_types: list[Union[torch.dtype, str]] = [
torch.float8_e4m3fn, torch.bfloat16, torch.float16, torch.float32
]
common_float_and_int_types = common_float_types + [torch.int8]
nv_fp4_types = ["nvfp4"]
fp8_types = [torch.float8_e4m3fn]
def register_prepare_and_finalize(
kind,
activation_format: mk.FusedMoEActivationFormat,
supported_dtypes: list[Union[torch.dtype, str]],
blocked_quantization_support: bool,
backend: Optional[str],
force_multigpu: bool = False,
supports_apply_weight_on_input: bool = True,
):
global PREPARE_FINALIZE_INFO
global MK_ALL_PREPARE_FINALIZE_TYPES
global MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
global MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES
assert kind not in PREPARE_FINALIZE_INFO
PREPARE_FINALIZE_INFO[kind] = PrepareFinalizeInfo(
activation_format,
supported_dtypes,
blocked_quantization_support,
backend,
supports_apply_weight_on_input,
)
MK_ALL_PREPARE_FINALIZE_TYPES.append(kind)
if backend is not None or force_multigpu:
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES.append(kind)
else:
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES.append(kind)
def register_experts(
kind,
activation_format: mk.FusedMoEActivationFormat,
supported_dtypes: list[Union[torch.dtype, str]],
blocked_quantization_support: bool,
supports_chunking: bool,
supports_expert_map: bool,
needs_matching_quant: bool = False,
needs_deep_gemm: bool = False,
):
global EXPERT_INFO
global MK_FUSED_EXPERT_TYPES
assert kind not in EXPERT_INFO
EXPERT_INFO[kind] = ExpertInfo(
activation_format,
supported_dtypes,
blocked_quantization_support,
supports_chunking,
supports_expert_map,
needs_matching_quant,
needs_deep_gemm,
)
MK_FUSED_EXPERT_TYPES.append(kind)
def prepare_finalize_info(kind) -> PrepareFinalizeInfo:
info = PREPARE_FINALIZE_INFO.get(kind)
assert info is not None
return info
def expert_info(kind) -> ExpertInfo:
info = EXPERT_INFO.get(kind)
assert info is not None
return info
register_prepare_and_finalize(
MoEPrepareAndFinalizeNoEP,
standard_format,
common_float_types,
blocked_quantization_support=True,
backend=None,
)
register_experts(
BatchedTritonExperts,
batched_format,
common_float_types,
blocked_quantization_support=True,
supports_chunking=False,
supports_expert_map=False,
needs_matching_quant=True,
)
register_experts(
TritonExperts,
standard_format,
common_float_and_int_types,
blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=True,
needs_matching_quant=True,
)
register_experts(
NaiveBatchedExperts,
batched_format,
common_float_and_int_types,
blocked_quantization_support=True,
supports_chunking=False,
supports_expert_map=True,
)
# Disable on blackwell for now
if has_deep_ep() and not current_platform.has_device_capability(100):
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize) DeepEPHTPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
DeepEPLLPrepareAndFinalize) DeepEPLLPrepareAndFinalize)
register_prepare_and_finalize(
DeepEPHTPrepareAndFinalize,
standard_format,
common_float_types,
blocked_quantization_support=True,
backend="deepep_high_throughput",
)
register_prepare_and_finalize(
DeepEPLLPrepareAndFinalize,
batched_format,
common_float_types,
blocked_quantization_support=True,
backend="deepep_low_latency",
)
if has_pplx(): if has_pplx():
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize) PplxPrepareAndFinalize)
register_prepare_and_finalize(
PplxPrepareAndFinalize,
batched_format,
common_float_and_int_types,
blocked_quantization_support=True,
backend="pplx",
)
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES = [] if (has_flashinfer_cutlass_fused_moe()
if has_pplx(): and current_platform.has_device_capability(100)):
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES += [PplxPrepareAndFinalize] from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
if has_deep_ep(): FlashInferExperts)
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES += [ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize FlashInferCutlassMoEPrepareAndFinalize)
]
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES = [MoEPrepareAndFinalizeNoEP] register_prepare_and_finalize(
FlashInferCutlassMoEPrepareAndFinalize,
standard_format,
nv_fp4_types,
blocked_quantization_support=True,
backend=None,
force_multigpu=True,
supports_apply_weight_on_input=False,
)
MK_ALL_PREPARE_FINALIZE_TYPES = (MK_MULTI_GPU_PREPARE_FINALIZE_TYPES + register_experts(
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES) FlashInferExperts,
standard_format,
nv_fp4_types,
blocked_quantization_support=True,
supports_chunking=True,
# Note: this is a hack to get it to run for now
supports_expert_map=True,
)
else:
FlashInferCutlassMoEPrepareAndFinalize = None
MK_FUSED_EXPERT_TYPES = [ if has_deep_gemm() and is_deep_gemm_supported():
BatchedDeepGemmExperts, register_experts(
BatchedTritonExperts, BatchedDeepGemmExperts,
NaiveBatchedExperts, batched_format,
BatchedTritonOrDeepGemmExperts, fp8_types,
CutlassExpertsFp8, blocked_quantization_support=True,
DeepGemmExperts, supports_chunking=False,
TritonOrDeepGemmExperts, supports_expert_map=False,
TritonExperts, needs_matching_quant=False,
] needs_deep_gemm=True,
)
register_experts(
DeepGemmExperts,
standard_format,
fp8_types,
blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=True,
needs_matching_quant=False,
needs_deep_gemm=True,
),
register_experts(
BatchedTritonOrDeepGemmExperts,
batched_format,
common_float_and_int_types,
blocked_quantization_support=True,
supports_chunking=False,
supports_expert_map=False,
needs_matching_quant=True,
needs_deep_gemm=True,
)
register_experts(
TritonOrDeepGemmExperts,
standard_format,
common_float_and_int_types,
blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=True,
needs_matching_quant=True,
needs_deep_gemm=True,
)
if cutlass_fp8_supported():
from vllm.model_executor.layers.fused_moe import (CutlassBatchedExpertsFp8,
CutlassExpertsFp8)
register_experts(
CutlassExpertsFp8,
standard_format,
fp8_types,
blocked_quantization_support=False,
supports_chunking=True,
supports_expert_map=False,
)
register_experts(
CutlassBatchedExpertsFp8,
batched_format,
fp8_types,
blocked_quantization_support=False,
supports_chunking=False,
supports_expert_map=False,
)
if cutlass_fp4_supported():
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassExpertsFp4)
register_experts(
CutlassExpertsFp4,
standard_format,
nv_fp4_types,
blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=False,
)
MK_QUANT_CONFIGS = [ MK_QUANT_CONFIGS = [
None, None,
...@@ -85,3 +343,156 @@ MK_QUANT_CONFIGS = [ ...@@ -85,3 +343,156 @@ MK_QUANT_CONFIGS = [
# block-quantized weights and per-token activations # block-quantized weights and per-token activations
# block-quantized weights and per-tensor activations # block-quantized weights and per-tensor activations
] ]
if cutlass_fp4_supported() or has_flashinfer_cutlass_fused_moe():
MK_QUANT_CONFIGS += [
FusedMoEQuantConfig(quant_dtype="nvfp4",
per_out_ch_quant=False,
per_act_token_quant=False,
block_shape=None),
]
def _make_gscale(num_experts: int) -> torch.Tensor:
return torch.ones((num_experts, ),
device=torch.cuda.current_device(),
dtype=torch.float32)
def make_prepare_finalize(
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
backend: Optional[str],
moe: FusedMoEConfig,
) -> mk.FusedMoEPrepareAndFinalize:
if backend != "naive" and backend is not None:
prepare_finalize = FusedMoEMethodBase._maybe_make_prepare_finalize(moe)
assert prepare_finalize is not None
return prepare_finalize
elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize:
return FlashInferCutlassMoEPrepareAndFinalize(
use_dp=moe.moe_parallel_config.dp_size > 1,
a1_gscale=_make_gscale(moe.num_local_experts),
)
else:
return MoEPrepareAndFinalizeNoEP()
def _slice(rank: int, num_local_experts: int, t: torch.Tensor) -> torch.Tensor:
s = rank * num_local_experts
e = s + num_local_experts
return t[s:e]
def make_fused_experts(
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
moe: FusedMoEConfig,
num_dispatchers: int,
w1_gs: Optional[torch.Tensor],
w2_gs: Optional[torch.Tensor],
) -> mk.FusedMoEPermuteExpertsUnpermute:
use_fp8 = moe.quant_dtype == torch.float8_e4m3fn
batch_kwargs = {
"max_num_tokens": moe.max_num_tokens,
"num_dispatchers": num_dispatchers,
}
quant_kwargs = {
"use_fp8_w8a8": use_fp8,
"use_int8_w8a8": False,
"use_int8_w8a16": False,
"use_int4_w4a16": False,
"block_shape": moe.block_shape,
"per_act_token_quant": moe.per_act_token_quant,
}
deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()}
if fused_experts_type == BatchedDeepGemmExperts:
kwargs = batch_kwargs | {
"block_shape": moe.block_shape,
"per_act_token_quant": moe.per_act_token_quant,
}
print(f"Making BatchedDeepGemmExperts {kwargs} ...")
experts = BatchedDeepGemmExperts(**kwargs)
elif fused_experts_type == BatchedTritonExperts:
kwargs = batch_kwargs | quant_kwargs
print(f"Making BatchedTritonExperts {kwargs} ...")
experts = BatchedTritonExperts(**kwargs)
elif fused_experts_type == BatchedTritonOrDeepGemmExperts:
kwargs = batch_kwargs | quant_kwargs | deepgemm_kwargs
print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
experts = BatchedTritonOrDeepGemmExperts(**kwargs)
elif fused_experts_type == DeepGemmExperts:
print("Making DeepGemmExperts () ...")
experts = DeepGemmExperts()
elif fused_experts_type == TritonExperts:
kwargs = quant_kwargs
print(f"Making TritonExperts {kwargs} ...")
experts = TritonExperts(**kwargs)
elif fused_experts_type == TritonOrDeepGemmExperts:
kwargs = quant_kwargs | deepgemm_kwargs
print(f"Making TritonOrDeepGemmExperts {kwargs} ...")
experts = TritonOrDeepGemmExperts(**kwargs)
elif fused_experts_type == NaiveBatchedExperts:
kwargs = batch_kwargs | quant_kwargs
print(f"Making NaiveBatchedExperts {kwargs} ...")
experts = NaiveBatchedExperts(**kwargs)
elif fused_experts_type == CutlassExpertsFp8:
kwargs = {
"out_dtype": moe.in_dtype,
"per_act_token_quant": moe.per_act_token_quant,
"per_out_ch_quant": moe.per_out_ch_quant,
"block_shape": moe.block_shape,
}
print(f"Making CutlassExpertsFp8 {kwargs} ...")
experts = CutlassExpertsFp8(**kwargs)
elif fused_experts_type == CutlassBatchedExpertsFp8:
kwargs = {
"max_experts_per_worker": moe.num_local_experts,
"num_dispatchers": num_dispatchers,
"out_dtype": moe.in_dtype,
"per_act_token_quant": moe.per_act_token_quant,
"per_out_ch_quant": moe.per_out_ch_quant,
"block_shape": moe.block_shape,
}
print(f"Making CutlassBatchedExpertsFp8 {kwargs} ...")
experts = CutlassBatchedExpertsFp8(**kwargs)
elif fused_experts_type == CutlassExpertsFp4:
assert w1_gs is not None and w2_gs is not None
num_experts = moe.num_local_experts
rank = moe.moe_parallel_config.dp_rank
kwargs = {
"g1_alphas": _slice(rank, num_experts, (1 / w1_gs)),
"g2_alphas": _slice(rank, num_experts, (1 / w2_gs)),
"a1_gscale": _make_gscale(num_experts),
"a2_gscale": _make_gscale(num_experts),
"max_experts_per_worker": num_experts,
"out_dtype": moe.in_dtype,
"per_act_token_quant": moe.per_act_token_quant,
"per_out_ch_quant": moe.per_out_ch_quant,
"block_shape": moe.block_shape,
"num_dispatchers": num_dispatchers,
}
print(f"Making CutlassExpertsFp4 {kwargs} ...")
experts = CutlassExpertsFp4(**kwargs)
elif fused_experts_type == FlashInferExperts:
assert w1_gs is not None and w2_gs is not None
num_experts = moe.num_local_experts
rank = moe.moe_parallel_config.dp_rank
kwargs = {
"g1_alphas": _slice(rank, num_experts, (1 / w1_gs)),
"g2_alphas": _slice(rank, num_experts, (1 / w2_gs)),
"a1_gscale": _make_gscale(num_experts),
"a2_gscale": _make_gscale(num_experts),
"out_dtype": moe.in_dtype,
"quant_dtype": "nvfp4",
"ep_rank": moe.ep_rank,
"ep_size": moe.ep_size,
"tp_rank": moe.tp_rank,
"tp_size": moe.tp_size,
}
print(f"Making FlashInferExperts {kwargs} ...")
experts = FlashInferExperts(**kwargs)
else:
raise RuntimeError(f"Unknown fused experts type: {fused_experts_type}")
return experts
...@@ -52,7 +52,7 @@ def profile_modular_kernel( ...@@ -52,7 +52,7 @@ def profile_modular_kernel(
rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts) rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts)
# make modular kernel # make modular kernel
mk = make_modular_kernel(config, vllm_config) mk = make_modular_kernel(config, vllm_config, weights)
mk_kwargs = { mk_kwargs = {
"hidden_states": rank_tensors.hidden_states, "hidden_states": rank_tensors.hidden_states,
...@@ -83,7 +83,7 @@ def rank_worker( ...@@ -83,7 +83,7 @@ def rank_worker(
# sanity check # sanity check
from vllm import envs from vllm import envs
if config.fused_moe_chunk_size is not None: if config.fused_moe_chunk_size is not None:
assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE) assert config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE
# get weights to this device # get weights to this device
weights.to_current_device() weights.to_current_device()
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm._custom_ops as ops
from vllm.utils.deep_gemm import per_block_cast_to_fp8
def per_token_cast_to_fp8(
x: torch.Tensor, block_size: int) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
pad_size = (block_size - (n % block_size)) % block_size
x = torch.nn.functional.pad(x,
(0, pad_size), value=0) if pad_size > 0 else x
x_view = x.view(m, -1, block_size)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
def make_non_quant_weights(
e: int,
n: int,
k: int,
dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Return weights w1, w2
"""
device = torch.cuda.current_device()
w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 15
w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 15
return w1, w2
def make_block_quant_fp8_weights(
e: int,
n: int,
k: int,
block_size: list[int],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Return weights w1, w2, w1_scale, w2_scale
"""
dtype = torch.bfloat16
device = torch.cuda.current_device()
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
w1_bf16, w2_bf16 = make_non_quant_weights(e, n, k, dtype)
w1_bf16 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)
w2_bf16 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)
block_n, block_k = block_size[0], block_size[1]
n_tiles_w1 = ((2 * n) + block_n - 1) // block_n
k_tiles_w1 = (k + block_k - 1) // block_k
n_tiles_w2 = (k + block_n - 1) // block_n
k_tiles_w2 = (n + block_k - 1) // block_k
w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn, device=device)
w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn, device=device)
w1_s = torch.empty((e, n_tiles_w1, k_tiles_w1),
device=device,
dtype=torch.float32)
w2_s = torch.empty((e, n_tiles_w2, k_tiles_w2),
device=device,
dtype=torch.float32)
assert w1_s.shape == (e, (2 * n + (block_n - 1)) // block_n,
(k + (block_k - 1)) // block_k)
assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2]
for i in range(e):
w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i],
block_size=[block_k, block_n])
w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i],
block_size=[block_k, block_n])
return w1, w2, w1_s, w2_s
def make_quant_fp8_weights(
e: int,
n: int,
k: int,
per_out_channel_quant: bool,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Return w1, w2, w1_scale, w2_scale
"""
q_dtype = torch.float8_e4m3fn
w1, w2 = make_non_quant_weights(e, n, k, dtype=torch.bfloat16)
# w1 -> w1_q, w2 -> w2_q
w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype)
w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype)
n_b_scales = 2 * n if per_out_channel_quant else 1
k_b_scales = k if per_out_channel_quant else 1
w1_scale = torch.empty((e, n_b_scales, 1),
device="cuda",
dtype=torch.float32)
w2_scale = torch.empty((e, k_b_scales, 1),
device="cuda",
dtype=torch.float32)
for expert in range(e):
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
w1[expert], use_per_token_if_dynamic=per_out_channel_quant)
w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
w2[expert], use_per_token_if_dynamic=per_out_channel_quant)
return w1_q, w2_q, w1_scale, w2_scale
...@@ -133,7 +133,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, ...@@ -133,7 +133,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
per_act_token_quant=per_act_token_quant, per_act_token_quant=per_act_token_quant,
) )
B, B_q, B_scale, _, _, _ = make_test_weights( (B, B_q, B_scale, _), _ = make_test_weights(
num_experts, num_experts,
N // 2, N // 2,
K, K,
...@@ -243,7 +243,7 @@ def test_fused_moe_batched_experts( ...@@ -243,7 +243,7 @@ def test_fused_moe_batched_experts(
act_dtype = dtype act_dtype = dtype
quant_dtype = None quant_dtype = None
w1_16, w1, w1_s, w2_16, w2, w2_s = make_test_weights( (w1_16, w1, w1_s, _), (w2_16, w2, w2_s, _) = make_test_weights(
e, e,
n, n,
k, k,
......
...@@ -161,18 +161,20 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed, ...@@ -161,18 +161,20 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
a = torch.randn((M, K), dtype=dtype) / 10 a = torch.randn((M, K), dtype=dtype) / 10
score = torch.randn((M, E), dtype=dtype) score = torch.randn((M, E), dtype=dtype)
_, w1, w1_s, _, w2, w2_s = make_test_weights(E, (_, w1, w1_s, _), (_, w2, w2_s,
N, _) = make_test_weights(E,
K, N,
dtype, K,
torch.float8_e4m3fn, dtype,
per_act_token_quant=False, torch.float8_e4m3fn,
block_shape=block_size) per_act_token_quant=False,
block_shape=block_size)
m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True, m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True,
use_int8_w8a8=False, use_int8_w8a8=False,
use_int8_w8a16=False, use_int8_w8a16=False,
use_int4_w4a16=False, use_int4_w4a16=False,
use_mxfp4_w4a4=False,
per_act_token_quant=False, per_act_token_quant=False,
block_shape=block_size) block_shape=block_size)
...@@ -247,13 +249,14 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, ...@@ -247,13 +249,14 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
a = torch.randn((M, K), dtype=dtype) / 10 a = torch.randn((M, K), dtype=dtype) / 10
score = torch.randn((M, E), dtype=dtype) score = torch.randn((M, E), dtype=dtype)
_, w1, w1_s, _, w2, w2_s = make_test_weights(E, (_, w1, w1_s, _), (_, w2, w2_s,
N, _) = make_test_weights(E,
K, N,
dtype, K,
torch.float8_e4m3fn, dtype,
per_act_token_quant=False, torch.float8_e4m3fn,
block_shape=block_size) per_act_token_quant=False,
block_shape=block_size)
# Note: for now use_compile will error out if the problem size is # Note: for now use_compile will error out if the problem size is
# large enough to trigger chunking. I'm leaving the flag and # large enough to trigger chunking. I'm leaving the flag and
......
...@@ -118,13 +118,14 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): ...@@ -118,13 +118,14 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
a = torch.randn((M, K), dtype=dtype) / 10 a = torch.randn((M, K), dtype=dtype) / 10
score = torch.randn((M, E), dtype=dtype) score = torch.randn((M, E), dtype=dtype)
_, w1, w1_s, _, w2, w2_s = make_test_weights(E, (_, w1, w1_s, _), (_, w2, w2_s,
N, _) = make_test_weights(E,
K, N,
dtype, K,
torch.int8, dtype,
per_act_token_quant=False, torch.int8,
block_shape=block_size) per_act_token_quant=False,
block_shape=block_size)
# Set the context to avoid lots of warning spam. # Set the context to avoid lots of warning spam.
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
......
...@@ -9,6 +9,7 @@ import random ...@@ -9,6 +9,7 @@ import random
import pytest import pytest
import torch import torch
from tests.kernels.moe.utils import per_token_cast_to_fp8
from tests.kernels.utils import baseline_scaled_mm from tests.kernels.utils import baseline_scaled_mm
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -16,20 +17,6 @@ from vllm.utils import cdiv ...@@ -16,20 +17,6 @@ from vllm.utils import cdiv
from vllm.utils.deep_gemm import per_block_cast_to_fp8 from vllm.utils.deep_gemm import per_block_cast_to_fp8
def per_token_cast_to_fp8(
x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
pad_size = (128 - (n % 128)) % 128
x = torch.nn.functional.pad(x,
(0, pad_size), value=0) if pad_size > 0 else x
x_view = x.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
fp8_data = (x_view *
(448.0 / x_amax.unsqueeze(2))).to(dtype=torch.float8_e4m3fn)
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
@pytest.mark.parametrize("num_groups, expected_m_per_group, k, n", [ @pytest.mark.parametrize("num_groups, expected_m_per_group, k, n", [
(4, 8192, 7168, 4096), (4, 8192, 7168, 4096),
(4, 8192, 2048, 7168), (4, 8192, 2048, 7168),
...@@ -76,7 +63,7 @@ def test_cutlass_grouped_gemm( ...@@ -76,7 +63,7 @@ def test_cutlass_grouped_gemm(
device=device, device=device,
dtype=torch.float)) dtype=torch.float))
for i in range(num_groups): for i in range(num_groups):
y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i], [128, 128])
for i in range(num_groups): for i in range(num_groups):
a = x_fp8[0][ep_offset[i]:ep_offset[i + 1]] a = x_fp8[0][ep_offset[i]:ep_offset[i + 1]]
......
...@@ -70,8 +70,10 @@ def make_block_quant_fp8_weights( ...@@ -70,8 +70,10 @@ def make_block_quant_fp8_weights(
""" """
Return weights w1q, w2q, w1_scale, w2_scale Return weights w1q, w2q, w1_scale, w2_scale
""" """
w1, w1q, w1_scale, w2, w2q, w2_scale = make_test_weights( (_, w1q, w1_scale, _), (_, w2q, w2_scale,
e, n, k, torch.bfloat16, torch.float8_e4m3fn, block_size) _) = make_test_weights(e, n, k, torch.bfloat16,
torch.float8_e4m3fn,
block_size)
return w1q, w2q, w1_scale, w2_scale return w1q, w2q, w1_scale, w2_scale
......
...@@ -132,9 +132,9 @@ def run_single_case(m, n, k, topk, num_experts, block_size): ...@@ -132,9 +132,9 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
# Note: W1 has shape (E, 2N, K), so N = 512 # Note: W1 has shape (E, 2N, K), so N = 512
# can trigger the deepgemm path. # can trigger the deepgemm path.
MNKs = [ MNKs = [
(1024, 512, 128), (1024, 768, 128),
(1024, 512, 512), (1024, 768, 512),
(2048, 512, 512), (2048, 768, 512),
(512, 1024, 1024), (512, 1024, 1024),
(512, 2048, 2048), (512, 2048, 2048),
(4096, 4096, 1024), (4096, 4096, 1024),
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from tests.kernels.moe.utils import make_test_weights
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype)
from tests.kernels.utils import torch_moe
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts, is_valid_flashinfer_cutlass_fused_moe)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
if not has_flashinfer_cutlass_fused_moe(
) or not current_platform.has_device_capability(100):
pytest.skip("Requires flashinfer_cutlass_fused_moe and nvfp4 support",
allow_module_level=True)
MNK_FACTORS = [
(2, 1024, 1024),
(2, 1024, 1536),
(2, 3072, 1024),
(2, 3072, 1536),
(64, 1024, 1024),
(64, 1024, 1536),
(64, 3072, 1024),
(64, 2048, 1536),
(224, 1024, 1024),
(224, 1024, 1536),
]
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", [40, 64, 256])
#@pytest.mark.parametrize("e", [128, 256])
@pytest.mark.parametrize("topk", [1, 6, 8])
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
@torch.inference_mode()
def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
dtype: torch.dtype):
current_platform.seed_everything(7)
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))):
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
quant_blocksize = 16
(_, w1_q, w1_blockscale,
w1_gs), (_, w2_q, w2_blockscale, w2_gs) = make_test_weights(
e,
n,
k,
in_dtype=dtype,
quant_dtype="nvfp4",
block_shape=None, # use quant_blocksize?
per_act_token_quant=False,
)
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(a,
score,
topk,
renormalize=False)
a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q)
assert w1_gs is not None
assert w2_gs is not None
assert w1_blockscale is not None
assert w2_blockscale is not None
flashinfer_experts = FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
FlashInferExperts(
a1_gscale=a1_gs,
g1_alphas=(1 / w1_gs),
a2_gscale=a2_gs,
g2_alphas=(1 / w2_gs),
out_dtype=dtype,
quant_dtype="nvfp4",
))
flashinfer_output = flashinfer_experts(
hidden_states=a,
w1=w1_q,
w1_scale=w1_blockscale,
w2=w2_q,
w2_scale=w2_blockscale,
a1_scale=a1_gs,
a2_scale=a2_gs,
topk_weights=topk_weights,
topk_ids=topk_ids,
)
# Reference check:
a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(a.flatten(), dim=-1)).to(torch.float32)
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale)
_, m_k = a_fp4.shape
a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4,
a_scale_interleaved,
a_global_scale,
dtype=a.dtype,
device=a.device,
block_size=quant_blocksize)
w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype)
w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)
for idx in range(0, e):
w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx],
w1_blockscale[idx],
w1_gs[idx],
dtype=dtype,
device=w1_q.device,
block_size=quant_blocksize)
w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx],
w2_blockscale[idx],
w2_gs[idx],
dtype=dtype,
device=w2_q.device,
block_size=quant_blocksize)
torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)
torch.testing.assert_close(torch_output,
flashinfer_output,
atol=1e-1,
rtol=1e-1)
if __name__ == "__main__":
test_flashinfer_fp4_moe_no_graph((2, 1024, 1024), 40, 1, torch.half)
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy import copy
import textwrap
import traceback
from itertools import product from itertools import product
from typing import Optional from typing import Optional
...@@ -10,41 +12,51 @@ import torch ...@@ -10,41 +12,51 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config import VllmConfig, current_platform, set_current_vllm_config from vllm.config import VllmConfig, current_platform, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
BatchedTritonOrDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.layer import TritonExperts
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts)
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from .modular_kernel_tools.common import (Config, RankTensors, WeightTensors, from .modular_kernel_tools.common import (Config, RankTensors, WeightTensors,
reference_moe_impl, reference_moe_impl,
run_modular_kernel) run_modular_kernel)
from .modular_kernel_tools.mk_objects import ( from .modular_kernel_tools.mk_objects import (
MK_FUSED_EXPERT_TYPES, MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES, MK_MULTI_GPU_PREPARE_FINALIZE_TYPES,
MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES) MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, expert_info)
from .modular_kernel_tools.parallel_utils import (ProcessGroupInfo, from .modular_kernel_tools.parallel_utils import (ProcessGroupInfo,
parallel_launch_with_config) parallel_launch_with_config)
# TODO (varun): These requirements are very strict and could be relaxed. has_any_multi_gpu_package = (has_deep_ep() or has_deep_gemm() or has_pplx()
has_all_packages = (has_deep_ep() and has_deep_gemm() and has_pplx()) or has_flashinfer_cutlass_fused_moe())
meets_package_requirements = pytest.mark.skipif( meets_multi_gpu_requirements = pytest.mark.skipif(
not has_all_packages, not has_any_multi_gpu_package,
reason="Requires deep_ep & deep_gemm & pplx packages", reason="Requires deep_ep or deep_gemm or pplx or flashinfer packages",
) )
def format_result(verbose, msg, ex=None):
if ex is not None:
x = str(ex)
newx = x.strip(" \n\t")[:16]
if len(newx) < len(x):
newx = newx + " ..."
prefix = "E\t"
print(f"{textwrap.indent(traceback.format_exc(), prefix)}")
print(f"FAILED {msg} - {newx}\n")
elif verbose:
print(f"PASSED {msg}")
else:
print(".", end="")
def rank_worker( def rank_worker(
pgi: ProcessGroupInfo, pgi: ProcessGroupInfo,
vllm_config: VllmConfig, vllm_config: VllmConfig,
cpu_group, cpu_group,
config: Config, config: Config,
weights: WeightTensors, weights: WeightTensors,
verbose: bool,
): ):
current_platform.seed_everything(pgi.rank) current_platform.seed_everything(pgi.rank)
...@@ -61,39 +73,64 @@ def rank_worker( ...@@ -61,39 +73,64 @@ def rank_worker(
TOPKs = config.topks TOPKs = config.topks
assert isinstance(TOPKs, list) assert isinstance(TOPKs, list)
for m, topk in product(Ms, TOPKs): exceptions = []
print(f"Running m={m}, topk={topk} ...") count = 0
# override m and topk
cfgx = copy.deepcopy(config)
cfgx.Ms = m
cfgx.topks = topk
# inputs for rank
rank_tensors = RankTensors.make(cfgx, pgi)
# modular kernel out
mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights,
rank_tensors)
with set_current_vllm_config(vllm_config): for m, topk in product(Ms, TOPKs):
ref_out = reference_moe_impl(cfgx, weights, rank_tensors) try:
print(f"Running[{pgi.rank}]: m={m}, topk={topk} ...")
torch.testing.assert_close(ref_out, mk_out, atol=3e-2, rtol=3e-2) count = count + 1
# override m and topk
cfgx = copy.deepcopy(config)
def run(config: Config): cfgx.Ms = m
cfgx.topks = topk
# inputs for rank
rank_tensors = RankTensors.make(cfgx, pgi)
# modular kernel out
mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights,
rank_tensors)
with set_current_vllm_config(vllm_config):
ref_out = reference_moe_impl(cfgx, weights, rank_tensors)
if config.quant_dtype == "nvfp4":
atol = 1e-1
rtol = 1e-1
else:
atol = 3e-2
rtol = 3e-2
torch.testing.assert_close(ref_out, mk_out, atol=atol, rtol=rtol)
format_result(verbose, config.describe())
except Exception as ex:
format_result(verbose, config.describe(), ex)
exceptions.append(ex)
if len(exceptions) > 0:
raise RuntimeError(
f"{len(exceptions)} of {count} tests failed in child process, "
f"rank={pgi.rank}.")
else:
print(f"{count} of {count} tests passed in child process, "
f"rank={pgi.rank}.")
def run(config: Config, verbose: bool):
assert config.is_valid() assert config.is_valid()
print(f"Testing config \n{config.describe()} ...")
weights: WeightTensors = WeightTensors.make(config) weights: WeightTensors = WeightTensors.make(config)
vllm_config, env_dict = config.make_env_data() vllm_config, env_dict = config.make_env_data()
parallel_launch_with_config(config.world_size, rank_worker, vllm_config, parallel_launch_with_config(config.world_size, rank_worker, vllm_config,
env_dict, config, weights) env_dict, config, weights, verbose)
Ms = [32, 64] Ms = [32, 64]
Ks = [7168] # hidden sizes # hidden sizes, making this too large will cause fp4 tests to fail.
# Also needs to be a multiple of 1024 for deep_gemm.
Ks = [2048]
Ns = [2048] Ns = [2048]
TOPKs = [4, 1] TOPKs = [4, 1]
Es = [32] Es = [32]
...@@ -103,19 +140,16 @@ FUSED_MOE_CHUNK_SIZEs = [None, 16] ...@@ -103,19 +140,16 @@ FUSED_MOE_CHUNK_SIZEs = [None, 16]
def is_nyi_config(config: Config) -> bool: def is_nyi_config(config: Config) -> bool:
# We know these configs to be legitimate. but still fail. # We know these configs to be legitimate. but still fail.
info = expert_info(config.fused_experts_type)
if (config.fused_experts_type in [ if info.needs_matching_quant:
BatchedTritonExperts, BatchedTritonOrDeepGemmExperts,
TritonExperts, TritonOrDeepGemmExperts
]):
# The triton kernels expect both per-act-token-quant and # The triton kernels expect both per-act-token-quant and
# per-out-ch-quant or neither. # per-out-ch-quant or neither.
unsupported_quant_config = ((config.is_per_act_token_quant + unsupported_quant_config = ((config.is_per_act_token_quant +
config.is_per_out_ch_quant) == 1) config.is_per_out_ch_quant) == 1)
return unsupported_quant_config return unsupported_quant_config
# cutlass kernels dont support expert_maps yet. return not info.supports_expert_map
return config.fused_experts_type == CutlassExpertsFp8
@pytest.mark.parametrize("k", Ks) @pytest.mark.parametrize("k", Ks)
...@@ -128,13 +162,13 @@ def is_nyi_config(config: Config) -> bool: ...@@ -128,13 +162,13 @@ def is_nyi_config(config: Config) -> bool:
product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES)) product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES))
@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs) @pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
@pytest.mark.parametrize("world_size", [2]) @pytest.mark.parametrize("world_size", [2])
@meets_package_requirements @meets_multi_gpu_requirements
def test_modular_kernel_combinations_multigpu( def test_modular_kernel_combinations_multigpu(
k: int, n: int, e: int, dtype: torch.dtype, k: int, n: int, e: int, dtype: torch.dtype,
quant_config: FusedMoEQuantConfig, quant_config: Optional[FusedMoEQuantConfig],
combination: tuple[mk.FusedMoEPrepareAndFinalize, combination: tuple[mk.FusedMoEPrepareAndFinalize,
mk.FusedMoEPermuteExpertsUnpermute], mk.FusedMoEPermuteExpertsUnpermute],
fused_moe_chunk_size: Optional[int], world_size: int): fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig):
config = Config( config = Config(
Ms=Ms, Ms=Ms,
...@@ -149,14 +183,15 @@ def test_modular_kernel_combinations_multigpu( ...@@ -149,14 +183,15 @@ def test_modular_kernel_combinations_multigpu(
fused_moe_chunk_size=fused_moe_chunk_size, fused_moe_chunk_size=fused_moe_chunk_size,
world_size=world_size, world_size=world_size,
) )
if not config.is_valid(): if not config.is_valid():
pytest.skip(f"Tests config {config} is not valid. Skipping ...") pytest.skip(f"Tests config {config} is not valid. Skipping ...")
if is_nyi_config(config): if is_nyi_config(config):
pytest.skip(f"Tests config {config} is nyi. Skipping ...") pytest.skip(f"Tests config {config} is nyi. Skipping ...")
print(f"{config.describe()}") verbosity = pytestconfig.getoption('verbose')
run(config) run(config, verbosity > 0)
@pytest.mark.parametrize("k", Ks) @pytest.mark.parametrize("k", Ks)
...@@ -169,13 +204,12 @@ def test_modular_kernel_combinations_multigpu( ...@@ -169,13 +204,12 @@ def test_modular_kernel_combinations_multigpu(
product(MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES)) product(MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES))
@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs) @pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
@pytest.mark.parametrize("world_size", [1]) @pytest.mark.parametrize("world_size", [1])
@meets_package_requirements
def test_modular_kernel_combinations_singlegpu( def test_modular_kernel_combinations_singlegpu(
k: int, n: int, e: int, dtype: torch.dtype, k: int, n: int, e: int, dtype: torch.dtype,
quant_config: FusedMoEQuantConfig, quant_config: Optional[FusedMoEQuantConfig],
combination: tuple[mk.FusedMoEPrepareAndFinalize, combination: tuple[mk.FusedMoEPrepareAndFinalize,
mk.FusedMoEPermuteExpertsUnpermute], mk.FusedMoEPermuteExpertsUnpermute],
fused_moe_chunk_size: Optional[int], world_size: int): fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig):
config = Config( config = Config(
Ms=Ms, Ms=Ms,
K=k, K=k,
...@@ -196,7 +230,8 @@ def test_modular_kernel_combinations_singlegpu( ...@@ -196,7 +230,8 @@ def test_modular_kernel_combinations_singlegpu(
if is_nyi_config(config): if is_nyi_config(config):
pytest.skip(f"Tests config {config} is nyi. Skipping ...") pytest.skip(f"Tests config {config} is nyi. Skipping ...")
run(config) verbosity = pytestconfig.getoption('verbose')
run(config, verbosity > 0)
if __name__ == '__main__': if __name__ == '__main__':
...@@ -211,4 +246,4 @@ if __name__ == '__main__': ...@@ -211,4 +246,4 @@ if __name__ == '__main__':
args = parser.parse_args() args = parser.parse_args()
config = make_config(args) config = make_config(args)
run(config) run(config, True)
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import pytest import pytest
import torch import torch
from tests.kernels.moe.utils import make_test_weights
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX, FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype) dequantize_nvfp4_to_dtype)
...@@ -43,41 +44,20 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, ...@@ -43,41 +44,20 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
VllmConfig(parallel_config=ParallelConfig( VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))): pipeline_parallel_size=1))):
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
quant_blocksize = 16 quant_blocksize = 16
round_up = lambda x, y: (x + y - 1) // y * y
sf_w1_2n = round_up(2 * n, 128) a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
sf_w1_k = round_up(k // quant_blocksize, 4)
w1_blockscale = torch.empty((e, sf_w1_2n, sf_w1_k), (_, w1_q, w1_blockscale,
device="cuda", w1_gs), (_, w2_q, w2_blockscale, w2_gs) = make_test_weights(
dtype=torch.float8_e4m3fn) e,
n,
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 k,
sf_w2_k = round_up(k, 128) in_dtype=dtype,
sf_w2_n = round_up(n // quant_blocksize, 4) quant_dtype="nvfp4",
w2_blockscale = torch.empty((e, sf_w2_k, sf_w2_n), block_shape=None, # use quant_blocksize?
device="cuda", per_act_token_quant=False,
dtype=torch.float8_e4m3fn) )
w1_q = torch.empty((e, 2 * n, k // 2),
device="cuda",
dtype=torch.uint8)
w2_q = torch.empty((e, k, n // 2), device="cuda", dtype=torch.uint8)
w1_gs = torch.empty((e, ), device="cuda", dtype=torch.float32)
w2_gs = torch.empty((e, ), device="cuda", dtype=torch.float32)
for expert in range(e):
w1_amax = torch.abs(w1).max().to(torch.float32)
w2_amax = torch.abs(w2).max().to(torch.float32)
w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax
w1_q[expert], w1_blockscale[expert] = ops.scaled_fp4_quant(
w1[expert], w1_gs[expert])
w2_q[expert], w2_blockscale[expert] = ops.scaled_fp4_quant(
w2[expert], w2_gs[expert])
score = torch.randn((m, e), device="cuda", dtype=dtype) score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(a, topk_weights, topk_ids, _ = fused_topk(a,
...@@ -88,6 +68,11 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, ...@@ -88,6 +68,11 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32) a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32) a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
assert w1_gs is not None
assert w2_gs is not None
assert w1_blockscale is not None
assert w2_blockscale is not None
cutlass_output = cutlass_moe_fp4( cutlass_output = cutlass_moe_fp4(
a=a, a=a,
a1_gscale=a1_gs, a1_gscale=a1_gs,
...@@ -104,14 +89,13 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, ...@@ -104,14 +89,13 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
n=n, n=n,
k=k, k=k,
e=e, e=e,
device=a.device,
) )
# Reference check: # Reference check:
a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(a.flatten(), dim=-1)).to(torch.float32) torch.amax(a.flatten(), dim=-1)).to(torch.float32)
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale) a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale)
_, m_k = a_fp4.shape
a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4, a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4,
a_scale_interleaved, a_scale_interleaved,
a_global_scale, a_global_scale,
...@@ -126,14 +110,14 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, ...@@ -126,14 +110,14 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx], w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx],
w1_blockscale[idx], w1_blockscale[idx],
w1_gs[idx], w1_gs[idx],
dtype=w1.dtype, dtype=dtype,
device=w1.device, device=w1_q.device,
block_size=quant_blocksize) block_size=quant_blocksize)
w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx], w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx],
w2_blockscale[idx], w2_blockscale[idx],
w2_gs[idx], w2_gs[idx],
dtype=w2.dtype, dtype=dtype,
device=w2.device, device=w2_q.device,
block_size=quant_blocksize) block_size=quant_blocksize)
torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk) torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)
......
...@@ -9,7 +9,8 @@ import torch ...@@ -9,7 +9,8 @@ import torch
from tests.kernels.utils import torch_experts from tests.kernels.utils import torch_experts
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassBatchedExpertsFp8)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel) FusedMoEModularKernel)
...@@ -123,12 +124,8 @@ def pplx_cutlass_moe( ...@@ -123,12 +124,8 @@ def pplx_cutlass_moe(
num_local_experts=num_local_experts, num_local_experts=num_local_experts,
num_dispatchers=num_dispatchers) num_dispatchers=num_dispatchers)
experts = CutlassExpertsFp8(num_local_experts, experts = CutlassBatchedExpertsFp8(num_local_experts, num_dispatchers,
out_dtype, out_dtype, per_act_token, per_out_ch)
per_act_token,
per_out_ch,
num_dispatchers=num_dispatchers,
use_batched_format=True)
fused_cutlass_experts = FusedMoEModularKernel( fused_cutlass_experts = FusedMoEModularKernel(
prepare_finalize, prepare_finalize,
......
...@@ -770,7 +770,7 @@ def test_pplx_moe_slow( ...@@ -770,7 +770,7 @@ def test_pplx_moe_slow(
a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10 a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
_, w1, w1_s, _, w2, w2_s = make_test_weights( (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
e, e,
n, n,
k, k,
...@@ -836,7 +836,7 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool, ...@@ -836,7 +836,7 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
args = dict() args = dict()
if make_weights: if make_weights:
_, w1, w1_s, _, w2, w2_s = make_test_weights( (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
e, e,
n, n,
k, k,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional from typing import Optional, Union
import torch import torch
import vllm._custom_ops as ops import vllm._custom_ops as ops
from tests.kernels.quant_utils import per_block_cast_to_int8 from tests.kernels.quant_utils import per_block_cast_to_int8
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX)
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts) BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
...@@ -169,28 +171,41 @@ def make_quantized_test_activations( ...@@ -169,28 +171,41 @@ def make_quantized_test_activations(
def moe_quantize_weights( def moe_quantize_weights(
w: torch.Tensor, w: torch.Tensor,
w_s: Optional[torch.Tensor], w_s: Optional[torch.Tensor],
quant_dtype: Optional[torch.dtype], quant_dtype: Union[torch.dtype, str, None],
per_token_quant: bool, per_token_quant: bool,
block_shape: Optional[list[int]], block_shape: Optional[list[int]],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
assert (quant_dtype == torch.float8_e4m3fn assert (quant_dtype == torch.float8_e4m3fn or quant_dtype == torch.int8
or quant_dtype == torch.int8), "only fp8/int8 supported" or quant_dtype == "nvfp4"), "only fp8/int8/nvfp4 supported"
w_gs = None
if block_shape is not None: if block_shape is not None:
assert not per_token_quant assert not per_token_quant
if quant_dtype == torch.int8: if quant_dtype == torch.int8:
w, w_s = per_block_cast_to_int8(w, block_shape) w, w_s = per_block_cast_to_int8(w, block_shape)
else: elif quant_dtype == torch.float8_e4m3fn:
w, w_s = per_block_cast_to_fp8(w, block_shape) w, w_s = per_block_cast_to_fp8(w, block_shape)
elif quant_dtype == "nvfp4":
raise RuntimeError("blocked quantization not supported for nvfp4")
else:
raise RuntimeError(f"Unsupported quant type {quant_dtype}")
else: else:
if quant_dtype == torch.int8: if quant_dtype == torch.int8:
w, w_s = ops.scaled_int8_quant( w, w_s = ops.scaled_int8_quant(
w, w_s, use_per_token_if_dynamic=per_token_quant) w, w_s, use_per_token_if_dynamic=per_token_quant)
else: elif quant_dtype == torch.float8_e4m3fn:
w, w_s = ops.scaled_fp8_quant( w, w_s = ops.scaled_fp8_quant(
w, w_s, use_per_token_if_dynamic=per_token_quant) w, w_s, use_per_token_if_dynamic=per_token_quant)
elif quant_dtype == "nvfp4":
assert not per_token_quant
w_amax = torch.abs(w).max().to(torch.float32)
w_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w_amax
w, w_s = ops.scaled_fp4_quant(w, w_gs)
else:
raise RuntimeError(f"Unsupported quant type {quant_dtype}")
return w, w_s return w, w_s, w_gs
def make_test_weight( def make_test_weight(
...@@ -198,21 +213,26 @@ def make_test_weight( ...@@ -198,21 +213,26 @@ def make_test_weight(
rows: int, rows: int,
cols: int, cols: int,
in_dtype: torch.dtype = torch.bfloat16, in_dtype: torch.dtype = torch.bfloat16,
quant_dtype: Optional[torch.dtype] = None, quant_dtype: Union[torch.dtype, str, None] = None,
block_shape: Optional[list[int]] = None, block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False, per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15 w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15
w_gs = None
if quant_dtype is not None: if quant_dtype is not None:
w_l = [None] * e w_l = [None] * e
w_s_l = [None] * e w_s_l = [None] * e
w_gs_l = [None] * e
for idx in range(e): for idx in range(e):
w_l[idx], w_s_l[idx] = moe_quantize_weights( w_l[idx], w_s_l[idx], w_gs_l[idx] = moe_quantize_weights(
w_16[idx], None, quant_dtype, per_act_token_quant, block_shape) w_16[idx], None, quant_dtype, per_act_token_quant, block_shape)
w = torch.stack(w_l) w = torch.stack(w_l)
w_s = torch.stack(w_s_l) w_s = torch.stack(w_s_l)
if e > 0 and w_gs_l[0] is not None:
w_gs = torch.stack(w_gs_l)
if w_s.ndim == 2: if w_s.ndim == 2:
assert w_s.shape[-1] == 1 assert w_s.shape[-1] == 1
w_s = w_s.view(-1, 1, 1) w_s = w_s.view(-1, 1, 1)
...@@ -225,8 +245,9 @@ def make_test_weight( ...@@ -225,8 +245,9 @@ def make_test_weight(
else: else:
w = w_16 w = w_16
w_s = None w_s = None
w_gs = None
return w_16, w, w_s return w_16, w, w_s, w_gs
def make_test_weights( def make_test_weights(
...@@ -234,14 +255,30 @@ def make_test_weights( ...@@ -234,14 +255,30 @@ def make_test_weights(
n: int, n: int,
k: int, k: int,
in_dtype: torch.dtype = torch.bfloat16, in_dtype: torch.dtype = torch.bfloat16,
quant_dtype: Optional[torch.dtype] = None, quant_dtype: Union[torch.dtype, str, None] = None,
block_shape: Optional[list[int]] = None, block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False, per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor, ) -> tuple[tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
torch.Tensor, Optional[torch.Tensor]]: Optional[torch.Tensor]],
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]]:
return ( return (
*make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape, make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape,
per_act_token_quant), per_act_token_quant),
*make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape, make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape,
per_act_token_quant), per_act_token_quant),
) )
def per_token_cast_to_fp8(
x: torch.Tensor,
block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
pad_size = (block_size - (n % block_size)) % block_size
x = torch.nn.functional.pad(x,
(0, pad_size), value=0) if pad_size > 0 else x
x_view = x.view(m, -1, block_size)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
...@@ -105,7 +105,8 @@ class DeviceCommunicatorBase: ...@@ -105,7 +105,8 @@ class DeviceCommunicatorBase:
# we initialize the all2all manager used in expert parallel. # we initialize the all2all manager used in expert parallel.
use_ep = config.parallel_config.data_parallel_size > 1 use_ep = config.parallel_config.data_parallel_size > 1
self.use_all2all = "ep" in unique_name and use_ep self.is_ep_communicator = "ep" in unique_name
self.use_all2all = self.is_ep_communicator and use_ep
self.all2all_manager: Optional[All2AllManagerBase] = None self.all2all_manager: Optional[All2AllManagerBase] = None
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
...@@ -246,7 +247,7 @@ class DeviceCommunicatorBase: ...@@ -246,7 +247,7 @@ class DeviceCommunicatorBase:
""" """
Prepare the communication buffer for the model. Prepare the communication buffer for the model.
""" """
if not self.use_all2all: if not self.is_ep_communicator:
return return
moe_modules = [ moe_modules = [
...@@ -254,7 +255,7 @@ class DeviceCommunicatorBase: ...@@ -254,7 +255,7 @@ class DeviceCommunicatorBase:
if module.__class__.__name__ == "FusedMoE" if module.__class__.__name__ == "FusedMoE"
] ]
for module in moe_modules: for module in moe_modules:
module.quant_method.init_prepare_finalize(module.moe_config) module.quant_method.init_prepare_finalize()
def dispatch( def dispatch(
self, hidden_states: torch.Tensor, self, hidden_states: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment