"vllm/vscode:/vscode.git/clone" did not exist on "037a6487af3429bb3f3e1adfe3e2f5e5e95aa420"
Unverified Commit 8ad7285e authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[Kernels] Clean up FusedMoeMethodBase and modular kernel setup. Remove extra...


[Kernels] Clean up FusedMoeMethodBase and modular kernel setup.  Remove extra arguments from modular kernel methods. (#22035)
Signed-off-by: default avatarBill Nell <bnell@redhat.com>
Co-authored-by: default avatarMichael Goin <mgoin64@gmail.com>
parent 48b01fd4
...@@ -399,6 +399,7 @@ steps: ...@@ -399,6 +399,7 @@ steps:
- label: Kernels MoE Test %N - label: Kernels MoE Test %N
mirror_hardwares: [amdexperimental] mirror_hardwares: [amdexperimental]
source_file_dependencies: source_file_dependencies:
- csrc/quantization/cutlass_w8a8/moe/
- csrc/moe/ - csrc/moe/
- tests/kernels/moe - tests/kernels/moe
- vllm/model_executor/layers/fused_moe/ - vllm/model_executor/layers/fused_moe/
......
...@@ -175,11 +175,19 @@ implementations that input `FusedMoEActivationFormat.Standard` support chunking ...@@ -175,11 +175,19 @@ implementations that input `FusedMoEActivationFormat.Standard` support chunking
### FusedMoEModularKernel Initialization ### FusedMoEModularKernel Initialization
`FusedMoEMethodBase` class has 2 methods that are collectively responsible in creating the `FusedMoEModularKernel` object. They are, `FusedMoEMethodBase` class has 3 methods that are collectively responsible in creating the `FusedMoEModularKernel` object. They are,
* maybe_make_prepare_finalize,
* select_gemm_impl, and * select_gemm_impl, and
* init_prepare_finalize * init_prepare_finalize
#### maybe_make_prepare_finalize
The `maybe_make_prepare_finalize` method is responsbile for constructing an instance of `FusedMoEPrepareAndFinalize` when appropriate based on the current all2all backend, e.g. when EP + DP is enabled. The base class method currently constructs all the `FusedMoEPrepareAndFinalize` objects for the EP+DP case. Derived classes can override this method to construct prepare/finalize objects for different scenarios, e.g. `ModelOptNvFp4FusedMoE` can construct a `FlashInferCutlassMoEPrepareAndFinalize` for the EP+TP case.
Please refer to the implementations in,
* `ModelOptNvFp4FusedMoE`
#### select_gemm_impl #### select_gemm_impl
The `select_gemm_impl` method is undefined in the base class. It is the responsibility of the derived class to implement a method that constructs a valid/appropriate `FusedMoEPermuteExpertsUnpermute` object. The `select_gemm_impl` method is undefined in the base class. It is the responsibility of the derived class to implement a method that constructs a valid/appropriate `FusedMoEPermuteExpertsUnpermute` object.
......
...@@ -70,12 +70,27 @@ def parse_args(): ...@@ -70,12 +70,27 @@ def parse_args():
default=64, default=64,
help=("Maximum number of sequences to be processed in a single iteration."), help=("Maximum number of sequences to be processed in a single iteration."),
) )
parser.add_argument(
"--max-model-len",
type=int,
help=("Maximum number of tokens to be processed in a single iteration."),
)
parser.add_argument(
"--timeout",
type=int,
default=300,
help=("Number of seconds before unresponsive process is killed."),
)
parser.add_argument( parser.add_argument(
"--gpu-memory-utilization", "--gpu-memory-utilization",
type=float, type=float,
default=0.8, default=0.8,
help=("Fraction of GPU memory vLLM is allowed to allocate (0.0, 1.0]."), help=("Fraction of GPU memory vLLM is allowed to allocate (0.0, 1.0]."),
) )
parser.add_argument(
"--quantization",
type=str,
)
return parser.parse_args() return parser.parse_args()
...@@ -90,7 +105,9 @@ def main( ...@@ -90,7 +105,9 @@ def main(
enforce_eager, enforce_eager,
trust_remote_code, trust_remote_code,
max_num_seqs, max_num_seqs,
max_model_len,
gpu_memory_utilization, gpu_memory_utilization,
quantization,
): ):
os.environ["VLLM_DP_RANK"] = str(global_dp_rank) os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank) os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
...@@ -142,7 +159,9 @@ def main( ...@@ -142,7 +159,9 @@ def main(
enable_expert_parallel=True, enable_expert_parallel=True,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
max_num_seqs=max_num_seqs, max_num_seqs=max_num_seqs,
max_model_len=max_model_len,
gpu_memory_utilization=gpu_memory_utilization, gpu_memory_utilization=gpu_memory_utilization,
quantization=quantization,
) )
outputs = llm.generate(prompts, sampling_params) outputs = llm.generate(prompts, sampling_params)
# Print the outputs. # Print the outputs.
...@@ -198,14 +217,16 @@ if __name__ == "__main__": ...@@ -198,14 +217,16 @@ if __name__ == "__main__":
args.enforce_eager, args.enforce_eager,
args.trust_remote_code, args.trust_remote_code,
args.max_num_seqs, args.max_num_seqs,
args.max_model_len,
args.gpu_memory_utilization, args.gpu_memory_utilization,
args.quantization,
), ),
) )
proc.start() proc.start()
procs.append(proc) procs.append(proc)
exit_code = 0 exit_code = 0
for proc in procs: for proc in procs:
proc.join(timeout=300) proc.join(timeout=args.timeout)
if proc.exitcode is None: if proc.exitcode is None:
print(f"Killing process {proc.pid} that didn't stop within 5 minutes.") print(f"Killing process {proc.pid} that didn't stop within 5 minutes.")
proc.kill() proc.kill()
......
...@@ -7,41 +7,22 @@ import torch ...@@ -7,41 +7,22 @@ import torch
import vllm._custom_ops as ops import vllm._custom_ops as ops
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from tests.kernels.moe.utils import make_test_weights, per_token_cast_to_fp8
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype)
from tests.kernels.utils import torch_experts from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size
# Fused experts and PrepareFinalize imports from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
BatchedTritonOrDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig) FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts, NaiveBatchedExperts)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.layer import (FusedMoEMethodBase,
TritonExperts)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts)
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
from .mk_objects import (expert_info, make_fused_experts,
make_prepare_finalize, prepare_finalize_info)
from .parallel_utils import ProcessGroupInfo from .parallel_utils import ProcessGroupInfo
from .utils import (make_block_quant_fp8_weights, make_non_quant_weights,
make_quant_fp8_weights, per_token_cast_to_fp8)
if has_pplx():
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize)
if has_deep_ep():
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
DeepEPLLPrepareAndFinalize)
def _describe_tensor(t: Optional[torch.Tensor], name: str) -> str: def _describe_tensor(t: Optional[torch.Tensor], name: str) -> str:
...@@ -69,24 +50,31 @@ class Config: ...@@ -69,24 +50,31 @@ class Config:
torch_trace_dir_path: Optional[str] = None torch_trace_dir_path: Optional[str] = None
def __post_init__(self):
if self.quant_config is None:
self.quant_config = FusedMoEQuantConfig()
def describe(self) -> str: def describe(self) -> str:
s = "" s = ""
s += "== Config: \n" s += "== Config:\n"
s += f" world_size={self.world_size} \n" s += f" world_size={self.world_size}\n"
s += f" PF={self.prepare_finalize_type.__name__} \n" s += f" PF={self.prepare_finalize_type.__name__}\n"
s += f" FE={self.fused_experts_type.__name__} \n" s += f" FE={self.fused_experts_type.__name__}\n"
s += f" topk={self.topks} \n" s += f" E={self.E}\n"
s += f" dtype={self.dtype} \n" s += f" Ms={self.Ms}\n"
s += f" fused_moe_chunk_size={self.fused_moe_chunk_size} \n" s += f" N={self.N}\n"
s += " Quant: \n" s += f" K={self.K}\n"
s += f" fused_moe_chunk_size={self.fused_moe_chunk_size} \n " s += f" topk={self.topks}\n"
s += f" dtype={self.dtype}\n"
s += f" fused_moe_chunk_size={self.fused_moe_chunk_size}\n"
s += " Quant:\n"
if self.quant_config is not None: if self.quant_config is not None:
s += f" q_dtype={self.quant_dtype} \n" s += f" q_dtype={self.quant_dtype}\n"
s += f" q_block_shape={self.quant_block_shape} \n" s += f" q_block_shape={self.quant_block_shape}\n"
s += f" q_per_out_ch_quant={self.is_per_out_ch_quant} \n" s += f" q_per_out_ch_quant={self.is_per_out_ch_quant}\n"
s += f" q_per_act_token={self.is_per_act_token_quant} \n" s += f" q_per_act_token={self.is_per_act_token_quant}\n"
else: else:
s += " quant=None \n" s += " quant=None\n"
return s return s
@property @property
...@@ -95,34 +83,28 @@ class Config: ...@@ -95,34 +83,28 @@ class Config:
return self.Ms return self.Ms
@property @property
def quant_dtype(self) -> Optional[torch.dtype]: def quant_dtype(self) -> Union[torch.dtype, str, None]:
if self.quant_config is None: assert self.quant_config is not None
return None
return self.quant_config.quant_dtype return self.quant_config.quant_dtype
@property @property
def is_per_act_token_quant(self) -> bool: def is_per_act_token_quant(self) -> bool:
if self.quant_config is None: assert self.quant_config is not None
return False
return self.quant_config.per_act_token_quant return self.quant_config.per_act_token_quant
@property @property
def is_per_tensor_act_quant(self) -> bool: def is_per_tensor_act_quant(self) -> bool:
if self.quant_config is None:
return False
return (not self.is_per_act_token_quant return (not self.is_per_act_token_quant
and self.quant_block_shape is None) and self.quant_block_shape is None)
@property @property
def is_per_out_ch_quant(self) -> bool: def is_per_out_ch_quant(self) -> bool:
if self.quant_config is None: assert self.quant_config is not None
return False
return self.quant_config.per_out_ch_quant return self.quant_config.per_out_ch_quant
@property @property
def quant_block_shape(self) -> Optional[list[int]]: def quant_block_shape(self) -> Optional[list[int]]:
if self.quant_config is None: assert self.quant_config is not None
return None
return self.quant_config.block_shape return self.quant_config.block_shape
@property @property
...@@ -130,17 +112,6 @@ class Config: ...@@ -130,17 +112,6 @@ class Config:
assert isinstance(self.topks, int) assert isinstance(self.topks, int)
return self.topks return self.topks
@property
def topk_ids_dtype(self) -> Optional[torch.dtype]:
topk_ids_dtype = None
if self.prepare_finalize_type == PplxPrepareAndFinalize:
topk_ids_dtype = torch.uint32
elif self.prepare_finalize_type in [
DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize
]:
topk_ids_dtype = torch.int64
return topk_ids_dtype
@property @property
def num_local_experts(self) -> int: def num_local_experts(self) -> int:
return self.E // self.world_size return self.E // self.world_size
...@@ -154,12 +125,17 @@ class Config: ...@@ -154,12 +125,17 @@ class Config:
vllm_config.parallel_config.enable_expert_parallel = True vllm_config.parallel_config.enable_expert_parallel = True
env_dict = { env_dict = {
"VLLM_ALL2ALL_BACKEND": self.all2all_backend(),
"VLLM_USE_DEEP_GEMM": str(int(self.needs_deep_gemm())), "VLLM_USE_DEEP_GEMM": str(int(self.needs_deep_gemm())),
} }
backend = self.all2all_backend()
if backend is not None:
env_dict.update({"VLLM_ALL2ALL_BACKEND": backend})
if self.fused_moe_chunk_size is not None: if self.fused_moe_chunk_size is not None:
env_dict.update( env_dict.update(
{"VLLM_FUSED_MOE_CHUNK_SIZE": str(self.fused_moe_chunk_size)}) {"VLLM_FUSED_MOE_CHUNK_SIZE": str(self.fused_moe_chunk_size)})
return vllm_config, env_dict return vllm_config, env_dict
def is_fp8_block_quantized(self): def is_fp8_block_quantized(self):
...@@ -167,85 +143,59 @@ class Config: ...@@ -167,85 +143,59 @@ class Config:
and self.quant_block_shape is not None) and self.quant_block_shape is not None)
def is_batched_prepare_finalize(self): def is_batched_prepare_finalize(self):
return self.prepare_finalize_type in [ info = prepare_finalize_info(self.prepare_finalize_type)
PplxPrepareAndFinalize, DeepEPLLPrepareAndFinalize return (mk.FusedMoEActivationFormat.BatchedExperts ==
] info.activation_format)
def is_batched_fused_experts(self): def is_batched_fused_experts(self):
return self.fused_experts_type in [ info = expert_info(self.fused_experts_type)
CutlassExpertsFp8, BatchedDeepGemmExperts, BatchedTritonExperts, return (mk.FusedMoEActivationFormat.BatchedExperts ==
NaiveBatchedExperts, BatchedTritonOrDeepGemmExperts info.activation_format)
]
def is_standard_fused_experts(self): def is_standard_fused_experts(self):
return self.fused_experts_type in [ info = expert_info(self.fused_experts_type)
CutlassExpertsFp8, DeepGemmExperts, TritonOrDeepGemmExperts, return mk.FusedMoEActivationFormat.Standard == info.activation_format
TritonExperts
] def fe_supported_types(self):
info = expert_info(self.fused_experts_type)
def is_fe_16bit_supported(self): return info.supported_dtypes
return self.fused_experts_type in [
BatchedTritonExperts, BatchedTritonOrDeepGemmExperts, def pf_supported_types(self):
NaiveBatchedExperts, TritonExperts info = prepare_finalize_info(self.prepare_finalize_type)
] return info.supported_dtypes
def is_fe_fp8_supported(self): def is_block_quant_supported(self):
return self.fused_experts_type in [ info = expert_info(self.fused_experts_type)
BatchedDeepGemmExperts, return info.blocked_quantization_support
BatchedTritonExperts,
BatchedTritonOrDeepGemmExperts,
CutlassExpertsFp8,
DeepGemmExperts,
TritonExperts,
TritonOrDeepGemmExperts,
NaiveBatchedExperts,
]
def is_fe_block_fp8_supported(self):
return self.fused_experts_type in [
BatchedDeepGemmExperts,
BatchedTritonOrDeepGemmExperts,
DeepGemmExperts,
TritonExperts,
TritonOrDeepGemmExperts,
BatchedTritonExperts,
NaiveBatchedExperts,
]
def is_fe_supports_chunking(self): def is_fe_supports_chunking(self):
return self.fused_experts_type in [ info = expert_info(self.fused_experts_type)
CutlassExpertsFp8, DeepGemmExperts, TritonOrDeepGemmExperts, return info.supports_chunking
TritonExperts
] def supports_expert_map(self):
info = expert_info(self.fused_experts_type)
return info.supports_expert_map
def supports_apply_weight_on_input(self):
info = prepare_finalize_info(self.prepare_finalize_type)
return info.supports_apply_weight_on_input
def needs_deep_gemm(self): def needs_deep_gemm(self):
return self.fused_experts_type in [ info = expert_info(self.fused_experts_type)
BatchedDeepGemmExperts, return info.needs_deep_gemm
DeepGemmExperts,
]
def needs_pplx(self): def needs_pplx(self):
return self.prepare_finalize_type in [PplxPrepareAndFinalize] info = prepare_finalize_info(self.prepare_finalize_type)
return info.backend == "pplx"
def needs_deep_ep(self): def needs_deep_ep(self):
return self.prepare_finalize_type in [ info = prepare_finalize_info(self.prepare_finalize_type)
DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize return (info.backend == "deepep_high_throughput"
] or info.backend == "deepep_low_latency")
def all2all_backend(self): def all2all_backend(self):
if self.needs_pplx(): info = prepare_finalize_info(self.prepare_finalize_type)
return "pplx" return info.backend
if self.prepare_finalize_type == DeepEPHTPrepareAndFinalize:
return "deepep_high_throughput"
if self.prepare_finalize_type == DeepEPLLPrepareAndFinalize:
return "deepep_low_latency"
return "naive"
def needs_all2all(self):
return self.prepare_finalize_type in [
PplxPrepareAndFinalize, DeepEPHTPrepareAndFinalize,
DeepEPLLPrepareAndFinalize
]
def is_valid(self): def is_valid(self):
# Check prepare-finalize and fused-experts compatibility # Check prepare-finalize and fused-experts compatibility
...@@ -267,28 +217,28 @@ class Config: ...@@ -267,28 +217,28 @@ class Config:
# invalid quant config # invalid quant config
return False return False
# check bf16 / fp16 support # check type support
is_16bit = (self.dtype.itemsize == 2 and self.quant_dtype is None) if self.quant_dtype is None:
if is_16bit and not self.is_fe_16bit_supported(): if (self.dtype not in self.pf_supported_types()
or self.dtype not in self.fe_supported_types()):
return False return False
else:
# Check fp8 support if (self.quant_dtype not in self.pf_supported_types()
is_fp8 = self.quant_dtype == torch.float8_e4m3fn or self.quant_dtype not in self.fe_supported_types()):
if is_fp8 and not self.is_fe_fp8_supported():
return False return False
# Check fp8 block quanization support # Check block quanization support
is_block_quatized = self.quant_block_shape is not None is_block_quatized = self.quant_block_shape is not None
if is_block_quatized and not is_fp8: if is_block_quatized and self.quant_dtype is None:
return False return False
if is_block_quatized and not self.is_fe_block_fp8_supported(): if is_block_quatized and not self.is_block_quant_supported():
return False return False
# deep_gemm only works with block-quantized # deep_gemm only works with block-quantized
if self.needs_deep_gemm() and not is_block_quatized: if self.needs_deep_gemm() and not is_block_quatized:
return False return False
# Check dependencies # Check dependencies (turn into asserts?)
if self.needs_deep_ep() and not has_deep_ep(): if self.needs_deep_ep() and not has_deep_ep():
return False return False
if self.needs_deep_gemm() and not has_deep_gemm(): if self.needs_deep_gemm() and not has_deep_gemm():
...@@ -305,6 +255,8 @@ class WeightTensors: ...@@ -305,6 +255,8 @@ class WeightTensors:
w2: torch.Tensor w2: torch.Tensor
w1_scale: Optional[torch.Tensor] w1_scale: Optional[torch.Tensor]
w2_scale: Optional[torch.Tensor] w2_scale: Optional[torch.Tensor]
w1_gs: Optional[torch.Tensor] = None
w2_gs: Optional[torch.Tensor] = None
def describe(self): def describe(self):
s = "" s = ""
...@@ -313,13 +265,20 @@ class WeightTensors: ...@@ -313,13 +265,20 @@ class WeightTensors:
s += f' - {_describe_tensor(self.w2, "w2")} \n' s += f' - {_describe_tensor(self.w2, "w2")} \n'
s += f' - {_describe_tensor(self.w1_scale, "w1_scale")} \n' s += f' - {_describe_tensor(self.w1_scale, "w1_scale")} \n'
s += f' - {_describe_tensor(self.w2_scale, "w2_scale")} \n' s += f' - {_describe_tensor(self.w2_scale, "w2_scale")} \n'
s += f' - {_describe_tensor(self.w1_gs, "w1_gs")} \n'
s += f' - {_describe_tensor(self.w2_gs, "w2_gs")} \n'
return s return s
def is_quantized(self) -> bool:
# or w1_scale is not None?
return (self.w1.dtype == torch.float8_e4m3fn
or self.w1.dtype == torch.uint8 or self.w1.dtype == torch.int8)
def to_current_device(self): def to_current_device(self):
self.w1 = self.w1.to(device=torch.cuda.current_device()) self.w1 = self.w1.to(device=torch.cuda.current_device())
self.w2 = self.w2.to(device=torch.cuda.current_device()) self.w2 = self.w2.to(device=torch.cuda.current_device())
is_quantized = self.w1.dtype == torch.float8_e4m3fn
if is_quantized: if self.is_quantized():
assert self.w1_scale is not None assert self.w1_scale is not None
assert self.w2_scale is not None assert self.w2_scale is not None
self.w1_scale = self.w1_scale.to( self.w1_scale = self.w1_scale.to(
...@@ -327,56 +286,51 @@ class WeightTensors: ...@@ -327,56 +286,51 @@ class WeightTensors:
self.w2_scale = self.w2_scale.to( self.w2_scale = self.w2_scale.to(
device=torch.cuda.current_device()) device=torch.cuda.current_device())
if self.w1_gs is not None:
assert self.w2_gs is not None
self.w1_gs = self.w1_gs.to(device=torch.cuda.current_device())
self.w2_gs = self.w2_gs.to(device=torch.cuda.current_device())
def slice_weights(self, rank: int, def slice_weights(self, rank: int,
num_local_experts: int) -> "WeightTensors": num_local_experts: int) -> "WeightTensors":
s = rank * num_local_experts s = rank * num_local_experts
e = s + num_local_experts e = s + num_local_experts
w1 = self.w1[s:e, :, :] w1 = self.w1[s:e, :, :]
w2 = self.w2[s:e, :, :] w2 = self.w2[s:e, :, :]
is_quantized = self.w1.dtype == torch.float8_e4m3fn
w1_scale, w2_scale = (None, None) w1_scale, w2_scale = (None, None)
if is_quantized: if self.is_quantized():
assert self.w1_scale is not None assert self.w1_scale is not None
assert self.w2_scale is not None assert self.w2_scale is not None
w1_scale = self.w1_scale[s:e, :, :] w1_scale = self.w1_scale[s:e, :, :]
w2_scale = self.w2_scale[s:e, :, :] w2_scale = self.w2_scale[s:e, :, :]
return WeightTensors(w1, w2, w1_scale, w2_scale)
@staticmethod w1_gs = self.w1_gs
def make(config: Config) -> "WeightTensors": w2_gs = self.w2_gs
if w1_gs is not None:
assert w2_gs is not None
w1_gs = w1_gs[s:e]
w2_gs = w2_gs[s:e]
if config.quant_dtype is None: return WeightTensors(w1, w2, w1_scale, w2_scale, w1_gs, w2_gs)
# just make normal dtype weights
w1, w2 = make_non_quant_weights(e=config.E,
n=config.N,
k=config.K,
dtype=config.dtype)
return WeightTensors(w1=w1, w2=w2, w1_scale=None, w2_scale=None)
assert config.quant_dtype == torch.float8_e4m3fn @staticmethod
if not config.is_fp8_block_quantized(): def make(config: Config) -> "WeightTensors":
w1, w2, w1_scale, w2_scale = make_quant_fp8_weights( (_, w1, w1_scale, w1_gs), (_, w2, w2_scale, w2_gs) = make_test_weights(
e=config.E,
n=config.N,
k=config.K,
per_out_channel_quant=config.is_per_out_ch_quant,
)
return WeightTensors(w1=w1,
w2=w2,
w1_scale=w1_scale,
w2_scale=w2_scale)
assert config.quant_block_shape is not None
w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
e=config.E, e=config.E,
n=config.N, n=config.N,
k=config.K, k=config.K,
block_size=config.quant_block_shape, in_dtype=config.dtype,
quant_dtype=config.quant_dtype,
block_shape=config.quant_block_shape,
per_act_token_quant=config.is_per_out_ch_quant,
) )
return WeightTensors(w1=w1, return WeightTensors(w1=w1,
w2=w2, w2=w2,
w1_scale=w1_scale, w1_scale=w1_scale,
w2_scale=w2_scale) w2_scale=w2_scale,
w1_gs=w1_gs,
w2_gs=w2_gs)
@dataclass @dataclass
...@@ -449,7 +403,6 @@ class RankTensors: ...@@ -449,7 +403,6 @@ class RankTensors:
dtype=dtype) dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(hidden_states, score, topk, topk_weights, topk_ids, _ = fused_topk(hidden_states, score, topk,
False) False)
topk_ids = topk_ids.to(config.topk_ids_dtype)
# distribute topk_ids evenly # distribute topk_ids evenly
for mi in range(m): for mi in range(m):
...@@ -457,7 +410,7 @@ class RankTensors: ...@@ -457,7 +410,7 @@ class RankTensors:
topk_ids = topk_ids.to(device=torch.cuda.current_device()) topk_ids = topk_ids.to(device=torch.cuda.current_device())
expert_map = None expert_map = None
if config.world_size > 1: if config.world_size > 1 and config.supports_expert_map():
expert_map = torch.full((global_num_experts, ), expert_map = torch.full((global_num_experts, ),
fill_value=-1, fill_value=-1,
dtype=torch.int32) dtype=torch.int32)
...@@ -480,92 +433,100 @@ class RankTensors: ...@@ -480,92 +433,100 @@ class RankTensors:
def reference_moe_impl(config: Config, weights: WeightTensors, def reference_moe_impl(config: Config, weights: WeightTensors,
rank_tensors: RankTensors) -> torch.Tensor: rank_tensors: RankTensors) -> torch.Tensor:
return torch_experts(a=rank_tensors.hidden_states, if config.quant_dtype == "nvfp4":
w1=weights.w1, quant_blocksize = 16
w2=weights.w2, dtype = config.dtype
w1_q = weights.w1
w1_blockscale = weights.w1_scale
w1_gs = weights.w1_gs
w2_q = weights.w2
w2_blockscale = weights.w2_scale
w2_gs = weights.w2_gs
a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(
rank_tensors.hidden_states.flatten(), dim=-1)).to(torch.float32)
assert w1_gs is not None
assert w2_gs is not None
assert w1_blockscale is not None
assert w2_blockscale is not None
assert w1_blockscale.shape[1] % 128 == 0
assert w1_blockscale.shape[2] % 4 == 0
assert w2_blockscale.shape[1] % 128 == 0
assert w2_blockscale.shape[2] % 4 == 0
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(
rank_tensors.hidden_states, a_global_scale)
a = dequantize_nvfp4_to_dtype(a_fp4,
a_scale_interleaved,
a_global_scale,
dtype=dtype,
device=a_fp4.device,
block_size=quant_blocksize)
e = w1_q.shape[0]
n = w1_q.shape[1] // 2
k = w2_q.shape[1]
w1 = torch.zeros((e, 2 * n, k), device="cuda", dtype=dtype)
w2 = torch.zeros((e, k, n), device="cuda", dtype=dtype)
for idx in range(0, e):
w1[idx] = dequantize_nvfp4_to_dtype(w1_q[idx],
w1_blockscale[idx],
w1_gs[idx],
dtype=dtype,
device=w1_q.device,
block_size=quant_blocksize)
w2[idx] = dequantize_nvfp4_to_dtype(w2_q[idx],
w2_blockscale[idx],
w2_gs[idx],
dtype=dtype,
device=w2_q.device,
block_size=quant_blocksize)
a_scale = None
w1_scale = None
w2_scale = None
quant_dtype = None
per_act_token_quant = False
block_shape = None
else:
a = rank_tensors.hidden_states
a_scale = rank_tensors.hidden_states_scale
w1 = weights.w1
w1_scale = weights.w1_scale
w2 = weights.w2
w2_scale = weights.w2_scale
quant_dtype = config.quant_dtype
per_act_token_quant = config.is_per_act_token_quant
block_shape = config.quant_block_shape
return torch_experts(a=a,
w1=w1,
w2=w2,
topk_weight=rank_tensors.topk_weights, topk_weight=rank_tensors.topk_weights,
topk_ids=rank_tensors.topk_ids, topk_ids=rank_tensors.topk_ids,
global_num_experts=config.E, global_num_experts=config.E,
expert_map=None, expert_map=None,
w1_scale=weights.w1_scale, w1_scale=w1_scale,
w2_scale=weights.w2_scale, w2_scale=w2_scale,
a1_scale=rank_tensors.hidden_states_scale, a1_scale=a_scale,
quant_dtype=config.quant_dtype, quant_dtype=quant_dtype,
per_act_token_quant=config.is_per_act_token_quant, per_act_token_quant=per_act_token_quant,
block_shape=config.quant_block_shape, block_shape=block_shape,
apply_router_weights_on_input=config.topk == 1) apply_router_weights_on_input=config.topk == 1
and config.supports_apply_weight_on_input())
def make_fused_experts(
config: Config, moe: FusedMoEConfig,
num_dispatchers: int) -> mk.FusedMoEPermuteExpertsUnpermute:
use_fp8 = config.quant_dtype == torch.float8_e4m3fn
batch_kwargs = {
"max_num_tokens": moe.max_num_tokens,
"num_dispatchers": num_dispatchers,
}
quant_kwargs = {
"use_fp8_w8a8": use_fp8,
"use_int8_w8a8": False,
"use_int8_w8a16": False,
"use_int4_w4a16": False,
"block_shape": config.quant_block_shape,
"per_act_token_quant": config.is_per_act_token_quant,
}
deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()}
if config.fused_experts_type == BatchedDeepGemmExperts:
kwargs = batch_kwargs | {
"block_shape": config.quant_block_shape,
"per_act_token_quant": config.is_per_act_token_quant,
}
print(f"Making BatchedDeepGemmExperts {kwargs} ...")
experts = BatchedDeepGemmExperts(**kwargs)
elif config.fused_experts_type == BatchedTritonExperts:
kwargs = batch_kwargs | quant_kwargs
print(f"Making BatchedTritonExperts {kwargs} ...")
experts = BatchedTritonExperts(**kwargs)
elif config.fused_experts_type == BatchedTritonOrDeepGemmExperts:
kwargs = batch_kwargs | quant_kwargs | deepgemm_kwargs
print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
experts = BatchedTritonOrDeepGemmExperts(**kwargs)
elif config.fused_experts_type == DeepGemmExperts:
print("Making DeepGemmExperts () ...")
experts = DeepGemmExperts()
elif config.fused_experts_type == TritonExperts:
kwargs = quant_kwargs
print(f"Making TritonExperts {kwargs} ...")
experts = TritonExperts(**kwargs)
elif config.fused_experts_type == TritonOrDeepGemmExperts:
kwargs = quant_kwargs | deepgemm_kwargs
print(f"Making TritonOrDeepGemmExperts {kwargs} ...")
experts = TritonOrDeepGemmExperts(**kwargs)
elif config.fused_experts_type == NaiveBatchedExperts:
kwargs = batch_kwargs | quant_kwargs
print(f"Making NaiveBatchedExperts {kwargs} ...")
experts = NaiveBatchedExperts(**kwargs)
elif config.fused_experts_type == CutlassExpertsFp8:
use_batched_format = config.is_batched_prepare_finalize()
num_experts = (moe.num_local_experts
if use_batched_format else moe.num_experts)
kwargs = {
"max_experts_per_worker": num_experts,
"out_dtype": moe.in_dtype,
"per_act_token_quant": config.is_per_act_token_quant,
"per_out_ch_quant": config.is_per_out_ch_quant,
"block_shape": config.quant_block_shape,
"num_dispatchers": num_dispatchers,
"use_batched_format": use_batched_format
}
print(f"Making CutlassExpertsFp8 {kwargs} ...")
experts = CutlassExpertsFp8(**kwargs)
return experts
def make_modular_kernel(config: Config, def make_modular_kernel(
vllm_config: VllmConfig) -> mk.FusedMoEModularKernel: config: Config,
vllm_config: VllmConfig,
weights: WeightTensors,
) -> mk.FusedMoEModularKernel:
def next_power_of_2(x): def next_power_of_2(x):
import math import math
...@@ -579,6 +540,7 @@ def make_modular_kernel(config: Config, ...@@ -579,6 +540,7 @@ def make_modular_kernel(config: Config,
dp_size_=get_dp_group().world_size, dp_size_=get_dp_group().world_size,
vllm_parallel_config=vllm_config.parallel_config, vllm_parallel_config=vllm_config.parallel_config,
) )
moe = FusedMoEConfig( moe = FusedMoEConfig(
num_experts=config.E, num_experts=config.E,
experts_per_token=config.topk, experts_per_token=config.topk,
...@@ -591,15 +553,16 @@ def make_modular_kernel(config: Config, ...@@ -591,15 +553,16 @@ def make_modular_kernel(config: Config,
) )
# make modular kernel # make modular kernel
prepare_finalize = None prepare_finalize = make_prepare_finalize(config.prepare_finalize_type,
if config.needs_all2all(): config.all2all_backend(), moe)
prepare_finalize = FusedMoEMethodBase.maybe_make_prepare_finalize(moe)
assert prepare_finalize is not None fused_experts = make_fused_experts(
else: config.fused_experts_type,
prepare_finalize = MoEPrepareAndFinalizeNoEP() moe,
prepare_finalize.num_dispatchers(),
fused_experts = make_fused_experts(config, moe, weights.w1_gs,
prepare_finalize.num_dispatchers()) weights.w2_gs,
)
modular_kernel = mk.FusedMoEModularKernel( modular_kernel = mk.FusedMoEModularKernel(
prepare_finalize=prepare_finalize, fused_experts=fused_experts) prepare_finalize=prepare_finalize, fused_experts=fused_experts)
...@@ -620,22 +583,45 @@ def run_modular_kernel( ...@@ -620,22 +583,45 @@ def run_modular_kernel(
# weights for rank # weights for rank
rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts) rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts)
mk = make_modular_kernel(config, vllm_config) mk = make_modular_kernel(config, vllm_config, weights)
mk_kwargs = { mk_kwargs = {
"hidden_states": rank_tensors.hidden_states.clone( "hidden_states":
rank_tensors.hidden_states.clone(
), # impls might update the tensor in place ), # impls might update the tensor in place
"w1": rank_weights.w1, "w1":
"w2": rank_weights.w2, rank_weights.w1,
"topk_weights": rank_tensors.topk_weights, "w2":
"topk_ids": rank_tensors.topk_ids, rank_weights.w2,
"expert_map": rank_tensors.expert_map, "topk_weights":
"w1_scale": rank_weights.w1_scale, rank_tensors.topk_weights,
"w2_scale": rank_weights.w2_scale, "topk_ids":
"a1_scale": rank_tensors.hidden_states_scale, rank_tensors.topk_ids.to(mk.prepare_finalize.topk_indices_dtype()),
"global_num_experts": config.E, "expert_map":
"apply_router_weight_on_input": config.topk == 1, rank_tensors.expert_map,
"w1_scale":
rank_weights.w1_scale,
"w2_scale":
rank_weights.w2_scale,
"a1_scale":
rank_tensors.hidden_states_scale,
"global_num_experts":
config.E,
"apply_router_weight_on_input":
config.topk == 1 and config.supports_apply_weight_on_input(),
} }
num_tokens = rank_tensors.hidden_states.shape[0]
num_tokens_across_dp = torch.tensor([num_tokens] * config.world_size,
device="cuda",
dtype=torch.int)
with set_forward_context(
None,
vllm_config,
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp,
):
out = mk.forward(**mk_kwargs) out = mk.forward(**mk_kwargs)
return out return out
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional, Union
import torch import torch
# Fused experts and PrepareFinalize imports # Fused experts and PrepareFinalize imports
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts) BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
BatchedTritonOrDeepGemmExperts) BatchedTritonOrDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts, NaiveBatchedExperts) BatchedTritonExperts, NaiveBatchedExperts)
from vllm.model_executor.layers.fused_moe.layer import TritonExperts from vllm.model_executor.layers.fused_moe.layer import (FusedMoEMethodBase,
TritonExperts)
from vllm.model_executor.layers.fused_moe.prepare_finalize import ( from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP) MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts) TritonOrDeepGemmExperts)
from vllm.utils import has_deep_ep, has_pplx from vllm.model_executor.layers.quantization.utils.quant_utils import (
cutlass_fp4_supported)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_fp8_supported)
from vllm.platforms import current_platform
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
from vllm.utils.deep_gemm import is_deep_gemm_supported
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
if has_deep_ep():
@dataclass
class PrepareFinalizeInfo:
activation_format: mk.FusedMoEActivationFormat
supported_dtypes: list[Union[torch.dtype, str]]
blocked_quantization_support: bool
backend: Optional[str]
supports_apply_weight_on_input: bool = True
@dataclass
class ExpertInfo:
activation_format: mk.FusedMoEActivationFormat
supported_dtypes: list[Union[torch.dtype, str]]
blocked_quantization_support: bool
supports_chunking: bool
supports_expert_map: bool
needs_matching_quant: bool = False
needs_deep_gemm: bool = False
PREPARE_FINALIZE_INFO: dict[mk.FusedMoEPrepareAndFinalize,
PrepareFinalizeInfo] = {}
EXPERT_INFO: dict[mk.FusedMoEPermuteExpertsUnpermute, ExpertInfo] = {}
MK_ALL_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
MK_FUSED_EXPERT_TYPES: list[mk.FusedMoEPermuteExpertsUnpermute] = []
standard_format = mk.FusedMoEActivationFormat.Standard
batched_format = mk.FusedMoEActivationFormat.BatchedExperts
common_float_types: list[Union[torch.dtype, str]] = [
torch.float8_e4m3fn, torch.bfloat16, torch.float16, torch.float32
]
common_float_and_int_types = common_float_types + [torch.int8]
nv_fp4_types = ["nvfp4"]
fp8_types = [torch.float8_e4m3fn]
def register_prepare_and_finalize(
kind,
activation_format: mk.FusedMoEActivationFormat,
supported_dtypes: list[Union[torch.dtype, str]],
blocked_quantization_support: bool,
backend: Optional[str],
force_multigpu: bool = False,
supports_apply_weight_on_input: bool = True,
):
global PREPARE_FINALIZE_INFO
global MK_ALL_PREPARE_FINALIZE_TYPES
global MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
global MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES
assert kind not in PREPARE_FINALIZE_INFO
PREPARE_FINALIZE_INFO[kind] = PrepareFinalizeInfo(
activation_format,
supported_dtypes,
blocked_quantization_support,
backend,
supports_apply_weight_on_input,
)
MK_ALL_PREPARE_FINALIZE_TYPES.append(kind)
if backend is not None or force_multigpu:
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES.append(kind)
else:
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES.append(kind)
def register_experts(
kind,
activation_format: mk.FusedMoEActivationFormat,
supported_dtypes: list[Union[torch.dtype, str]],
blocked_quantization_support: bool,
supports_chunking: bool,
supports_expert_map: bool,
needs_matching_quant: bool = False,
needs_deep_gemm: bool = False,
):
global EXPERT_INFO
global MK_FUSED_EXPERT_TYPES
assert kind not in EXPERT_INFO
EXPERT_INFO[kind] = ExpertInfo(
activation_format,
supported_dtypes,
blocked_quantization_support,
supports_chunking,
supports_expert_map,
needs_matching_quant,
needs_deep_gemm,
)
MK_FUSED_EXPERT_TYPES.append(kind)
def prepare_finalize_info(kind) -> PrepareFinalizeInfo:
info = PREPARE_FINALIZE_INFO.get(kind)
assert info is not None
return info
def expert_info(kind) -> ExpertInfo:
info = EXPERT_INFO.get(kind)
assert info is not None
return info
register_prepare_and_finalize(
MoEPrepareAndFinalizeNoEP,
standard_format,
common_float_types,
blocked_quantization_support=True,
backend=None,
)
register_experts(
BatchedTritonExperts,
batched_format,
common_float_types,
blocked_quantization_support=True,
supports_chunking=False,
supports_expert_map=False,
needs_matching_quant=True,
)
register_experts(
TritonExperts,
standard_format,
common_float_and_int_types,
blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=True,
needs_matching_quant=True,
)
register_experts(
NaiveBatchedExperts,
batched_format,
common_float_and_int_types,
blocked_quantization_support=True,
supports_chunking=False,
supports_expert_map=True,
)
# Disable on blackwell for now
if has_deep_ep() and not current_platform.has_device_capability(100):
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize) DeepEPHTPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
DeepEPLLPrepareAndFinalize) DeepEPLLPrepareAndFinalize)
register_prepare_and_finalize(
DeepEPHTPrepareAndFinalize,
standard_format,
common_float_types,
blocked_quantization_support=True,
backend="deepep_high_throughput",
)
register_prepare_and_finalize(
DeepEPLLPrepareAndFinalize,
batched_format,
common_float_types,
blocked_quantization_support=True,
backend="deepep_low_latency",
)
if has_pplx(): if has_pplx():
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize) PplxPrepareAndFinalize)
register_prepare_and_finalize(
PplxPrepareAndFinalize,
batched_format,
common_float_and_int_types,
blocked_quantization_support=True,
backend="pplx",
)
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES = [] if (has_flashinfer_cutlass_fused_moe()
if has_pplx(): and current_platform.has_device_capability(100)):
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES += [PplxPrepareAndFinalize] from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
if has_deep_ep(): FlashInferExperts)
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES += [ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize FlashInferCutlassMoEPrepareAndFinalize)
]
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES = [MoEPrepareAndFinalizeNoEP] register_prepare_and_finalize(
FlashInferCutlassMoEPrepareAndFinalize,
standard_format,
nv_fp4_types,
blocked_quantization_support=True,
backend=None,
force_multigpu=True,
supports_apply_weight_on_input=False,
)
MK_ALL_PREPARE_FINALIZE_TYPES = (MK_MULTI_GPU_PREPARE_FINALIZE_TYPES + register_experts(
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES) FlashInferExperts,
standard_format,
nv_fp4_types,
blocked_quantization_support=True,
supports_chunking=True,
# Note: this is a hack to get it to run for now
supports_expert_map=True,
)
else:
FlashInferCutlassMoEPrepareAndFinalize = None
MK_FUSED_EXPERT_TYPES = [ if has_deep_gemm() and is_deep_gemm_supported():
register_experts(
BatchedDeepGemmExperts, BatchedDeepGemmExperts,
BatchedTritonExperts, batched_format,
NaiveBatchedExperts, fp8_types,
BatchedTritonOrDeepGemmExperts, blocked_quantization_support=True,
CutlassExpertsFp8, supports_chunking=False,
supports_expert_map=False,
needs_matching_quant=False,
needs_deep_gemm=True,
)
register_experts(
DeepGemmExperts, DeepGemmExperts,
standard_format,
fp8_types,
blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=True,
needs_matching_quant=False,
needs_deep_gemm=True,
),
register_experts(
BatchedTritonOrDeepGemmExperts,
batched_format,
common_float_and_int_types,
blocked_quantization_support=True,
supports_chunking=False,
supports_expert_map=False,
needs_matching_quant=True,
needs_deep_gemm=True,
)
register_experts(
TritonOrDeepGemmExperts, TritonOrDeepGemmExperts,
TritonExperts, standard_format,
] common_float_and_int_types,
blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=True,
needs_matching_quant=True,
needs_deep_gemm=True,
)
if cutlass_fp8_supported():
from vllm.model_executor.layers.fused_moe import (CutlassBatchedExpertsFp8,
CutlassExpertsFp8)
register_experts(
CutlassExpertsFp8,
standard_format,
fp8_types,
blocked_quantization_support=False,
supports_chunking=True,
supports_expert_map=False,
)
register_experts(
CutlassBatchedExpertsFp8,
batched_format,
fp8_types,
blocked_quantization_support=False,
supports_chunking=False,
supports_expert_map=False,
)
if cutlass_fp4_supported():
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassExpertsFp4)
register_experts(
CutlassExpertsFp4,
standard_format,
nv_fp4_types,
blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=False,
)
MK_QUANT_CONFIGS = [ MK_QUANT_CONFIGS = [
None, None,
...@@ -85,3 +343,156 @@ MK_QUANT_CONFIGS = [ ...@@ -85,3 +343,156 @@ MK_QUANT_CONFIGS = [
# block-quantized weights and per-token activations # block-quantized weights and per-token activations
# block-quantized weights and per-tensor activations # block-quantized weights and per-tensor activations
] ]
if cutlass_fp4_supported() or has_flashinfer_cutlass_fused_moe():
MK_QUANT_CONFIGS += [
FusedMoEQuantConfig(quant_dtype="nvfp4",
per_out_ch_quant=False,
per_act_token_quant=False,
block_shape=None),
]
def _make_gscale(num_experts: int) -> torch.Tensor:
return torch.ones((num_experts, ),
device=torch.cuda.current_device(),
dtype=torch.float32)
def make_prepare_finalize(
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
backend: Optional[str],
moe: FusedMoEConfig,
) -> mk.FusedMoEPrepareAndFinalize:
if backend != "naive" and backend is not None:
prepare_finalize = FusedMoEMethodBase._maybe_make_prepare_finalize(moe)
assert prepare_finalize is not None
return prepare_finalize
elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize:
return FlashInferCutlassMoEPrepareAndFinalize(
use_dp=moe.moe_parallel_config.dp_size > 1,
a1_gscale=_make_gscale(moe.num_local_experts),
)
else:
return MoEPrepareAndFinalizeNoEP()
def _slice(rank: int, num_local_experts: int, t: torch.Tensor) -> torch.Tensor:
s = rank * num_local_experts
e = s + num_local_experts
return t[s:e]
def make_fused_experts(
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
moe: FusedMoEConfig,
num_dispatchers: int,
w1_gs: Optional[torch.Tensor],
w2_gs: Optional[torch.Tensor],
) -> mk.FusedMoEPermuteExpertsUnpermute:
use_fp8 = moe.quant_dtype == torch.float8_e4m3fn
batch_kwargs = {
"max_num_tokens": moe.max_num_tokens,
"num_dispatchers": num_dispatchers,
}
quant_kwargs = {
"use_fp8_w8a8": use_fp8,
"use_int8_w8a8": False,
"use_int8_w8a16": False,
"use_int4_w4a16": False,
"block_shape": moe.block_shape,
"per_act_token_quant": moe.per_act_token_quant,
}
deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()}
if fused_experts_type == BatchedDeepGemmExperts:
kwargs = batch_kwargs | {
"block_shape": moe.block_shape,
"per_act_token_quant": moe.per_act_token_quant,
}
print(f"Making BatchedDeepGemmExperts {kwargs} ...")
experts = BatchedDeepGemmExperts(**kwargs)
elif fused_experts_type == BatchedTritonExperts:
kwargs = batch_kwargs | quant_kwargs
print(f"Making BatchedTritonExperts {kwargs} ...")
experts = BatchedTritonExperts(**kwargs)
elif fused_experts_type == BatchedTritonOrDeepGemmExperts:
kwargs = batch_kwargs | quant_kwargs | deepgemm_kwargs
print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
experts = BatchedTritonOrDeepGemmExperts(**kwargs)
elif fused_experts_type == DeepGemmExperts:
print("Making DeepGemmExperts () ...")
experts = DeepGemmExperts()
elif fused_experts_type == TritonExperts:
kwargs = quant_kwargs
print(f"Making TritonExperts {kwargs} ...")
experts = TritonExperts(**kwargs)
elif fused_experts_type == TritonOrDeepGemmExperts:
kwargs = quant_kwargs | deepgemm_kwargs
print(f"Making TritonOrDeepGemmExperts {kwargs} ...")
experts = TritonOrDeepGemmExperts(**kwargs)
elif fused_experts_type == NaiveBatchedExperts:
kwargs = batch_kwargs | quant_kwargs
print(f"Making NaiveBatchedExperts {kwargs} ...")
experts = NaiveBatchedExperts(**kwargs)
elif fused_experts_type == CutlassExpertsFp8:
kwargs = {
"out_dtype": moe.in_dtype,
"per_act_token_quant": moe.per_act_token_quant,
"per_out_ch_quant": moe.per_out_ch_quant,
"block_shape": moe.block_shape,
}
print(f"Making CutlassExpertsFp8 {kwargs} ...")
experts = CutlassExpertsFp8(**kwargs)
elif fused_experts_type == CutlassBatchedExpertsFp8:
kwargs = {
"max_experts_per_worker": moe.num_local_experts,
"num_dispatchers": num_dispatchers,
"out_dtype": moe.in_dtype,
"per_act_token_quant": moe.per_act_token_quant,
"per_out_ch_quant": moe.per_out_ch_quant,
"block_shape": moe.block_shape,
}
print(f"Making CutlassBatchedExpertsFp8 {kwargs} ...")
experts = CutlassBatchedExpertsFp8(**kwargs)
elif fused_experts_type == CutlassExpertsFp4:
assert w1_gs is not None and w2_gs is not None
num_experts = moe.num_local_experts
rank = moe.moe_parallel_config.dp_rank
kwargs = {
"g1_alphas": _slice(rank, num_experts, (1 / w1_gs)),
"g2_alphas": _slice(rank, num_experts, (1 / w2_gs)),
"a1_gscale": _make_gscale(num_experts),
"a2_gscale": _make_gscale(num_experts),
"max_experts_per_worker": num_experts,
"out_dtype": moe.in_dtype,
"per_act_token_quant": moe.per_act_token_quant,
"per_out_ch_quant": moe.per_out_ch_quant,
"block_shape": moe.block_shape,
"num_dispatchers": num_dispatchers,
}
print(f"Making CutlassExpertsFp4 {kwargs} ...")
experts = CutlassExpertsFp4(**kwargs)
elif fused_experts_type == FlashInferExperts:
assert w1_gs is not None and w2_gs is not None
num_experts = moe.num_local_experts
rank = moe.moe_parallel_config.dp_rank
kwargs = {
"g1_alphas": _slice(rank, num_experts, (1 / w1_gs)),
"g2_alphas": _slice(rank, num_experts, (1 / w2_gs)),
"a1_gscale": _make_gscale(num_experts),
"a2_gscale": _make_gscale(num_experts),
"out_dtype": moe.in_dtype,
"quant_dtype": "nvfp4",
"ep_rank": moe.ep_rank,
"ep_size": moe.ep_size,
"tp_rank": moe.tp_rank,
"tp_size": moe.tp_size,
}
print(f"Making FlashInferExperts {kwargs} ...")
experts = FlashInferExperts(**kwargs)
else:
raise RuntimeError(f"Unknown fused experts type: {fused_experts_type}")
return experts
...@@ -52,7 +52,7 @@ def profile_modular_kernel( ...@@ -52,7 +52,7 @@ def profile_modular_kernel(
rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts) rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts)
# make modular kernel # make modular kernel
mk = make_modular_kernel(config, vllm_config) mk = make_modular_kernel(config, vllm_config, weights)
mk_kwargs = { mk_kwargs = {
"hidden_states": rank_tensors.hidden_states, "hidden_states": rank_tensors.hidden_states,
...@@ -83,7 +83,7 @@ def rank_worker( ...@@ -83,7 +83,7 @@ def rank_worker(
# sanity check # sanity check
from vllm import envs from vllm import envs
if config.fused_moe_chunk_size is not None: if config.fused_moe_chunk_size is not None:
assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE) assert config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE
# get weights to this device # get weights to this device
weights.to_current_device() weights.to_current_device()
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm._custom_ops as ops
from vllm.utils.deep_gemm import per_block_cast_to_fp8
def per_token_cast_to_fp8(
x: torch.Tensor, block_size: int) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
pad_size = (block_size - (n % block_size)) % block_size
x = torch.nn.functional.pad(x,
(0, pad_size), value=0) if pad_size > 0 else x
x_view = x.view(m, -1, block_size)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
def make_non_quant_weights(
e: int,
n: int,
k: int,
dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Return weights w1, w2
"""
device = torch.cuda.current_device()
w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 15
w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 15
return w1, w2
def make_block_quant_fp8_weights(
e: int,
n: int,
k: int,
block_size: list[int],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Return weights w1, w2, w1_scale, w2_scale
"""
dtype = torch.bfloat16
device = torch.cuda.current_device()
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
w1_bf16, w2_bf16 = make_non_quant_weights(e, n, k, dtype)
w1_bf16 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)
w2_bf16 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)
block_n, block_k = block_size[0], block_size[1]
n_tiles_w1 = ((2 * n) + block_n - 1) // block_n
k_tiles_w1 = (k + block_k - 1) // block_k
n_tiles_w2 = (k + block_n - 1) // block_n
k_tiles_w2 = (n + block_k - 1) // block_k
w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn, device=device)
w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn, device=device)
w1_s = torch.empty((e, n_tiles_w1, k_tiles_w1),
device=device,
dtype=torch.float32)
w2_s = torch.empty((e, n_tiles_w2, k_tiles_w2),
device=device,
dtype=torch.float32)
assert w1_s.shape == (e, (2 * n + (block_n - 1)) // block_n,
(k + (block_k - 1)) // block_k)
assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2]
for i in range(e):
w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i],
block_size=[block_k, block_n])
w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i],
block_size=[block_k, block_n])
return w1, w2, w1_s, w2_s
def make_quant_fp8_weights(
e: int,
n: int,
k: int,
per_out_channel_quant: bool,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Return w1, w2, w1_scale, w2_scale
"""
q_dtype = torch.float8_e4m3fn
w1, w2 = make_non_quant_weights(e, n, k, dtype=torch.bfloat16)
# w1 -> w1_q, w2 -> w2_q
w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype)
w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype)
n_b_scales = 2 * n if per_out_channel_quant else 1
k_b_scales = k if per_out_channel_quant else 1
w1_scale = torch.empty((e, n_b_scales, 1),
device="cuda",
dtype=torch.float32)
w2_scale = torch.empty((e, k_b_scales, 1),
device="cuda",
dtype=torch.float32)
for expert in range(e):
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
w1[expert], use_per_token_if_dynamic=per_out_channel_quant)
w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
w2[expert], use_per_token_if_dynamic=per_out_channel_quant)
return w1_q, w2_q, w1_scale, w2_scale
...@@ -133,7 +133,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, ...@@ -133,7 +133,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
per_act_token_quant=per_act_token_quant, per_act_token_quant=per_act_token_quant,
) )
B, B_q, B_scale, _, _, _ = make_test_weights( (B, B_q, B_scale, _), _ = make_test_weights(
num_experts, num_experts,
N // 2, N // 2,
K, K,
...@@ -243,7 +243,7 @@ def test_fused_moe_batched_experts( ...@@ -243,7 +243,7 @@ def test_fused_moe_batched_experts(
act_dtype = dtype act_dtype = dtype
quant_dtype = None quant_dtype = None
w1_16, w1, w1_s, w2_16, w2, w2_s = make_test_weights( (w1_16, w1, w1_s, _), (w2_16, w2, w2_s, _) = make_test_weights(
e, e,
n, n,
k, k,
......
...@@ -161,7 +161,8 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed, ...@@ -161,7 +161,8 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
a = torch.randn((M, K), dtype=dtype) / 10 a = torch.randn((M, K), dtype=dtype) / 10
score = torch.randn((M, E), dtype=dtype) score = torch.randn((M, E), dtype=dtype)
_, w1, w1_s, _, w2, w2_s = make_test_weights(E, (_, w1, w1_s, _), (_, w2, w2_s,
_) = make_test_weights(E,
N, N,
K, K,
dtype, dtype,
...@@ -173,6 +174,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed, ...@@ -173,6 +174,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
use_int8_w8a8=False, use_int8_w8a8=False,
use_int8_w8a16=False, use_int8_w8a16=False,
use_int4_w4a16=False, use_int4_w4a16=False,
use_mxfp4_w4a4=False,
per_act_token_quant=False, per_act_token_quant=False,
block_shape=block_size) block_shape=block_size)
...@@ -247,7 +249,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, ...@@ -247,7 +249,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
a = torch.randn((M, K), dtype=dtype) / 10 a = torch.randn((M, K), dtype=dtype) / 10
score = torch.randn((M, E), dtype=dtype) score = torch.randn((M, E), dtype=dtype)
_, w1, w1_s, _, w2, w2_s = make_test_weights(E, (_, w1, w1_s, _), (_, w2, w2_s,
_) = make_test_weights(E,
N, N,
K, K,
dtype, dtype,
......
...@@ -118,7 +118,8 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): ...@@ -118,7 +118,8 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
a = torch.randn((M, K), dtype=dtype) / 10 a = torch.randn((M, K), dtype=dtype) / 10
score = torch.randn((M, E), dtype=dtype) score = torch.randn((M, E), dtype=dtype)
_, w1, w1_s, _, w2, w2_s = make_test_weights(E, (_, w1, w1_s, _), (_, w2, w2_s,
_) = make_test_weights(E,
N, N,
K, K,
dtype, dtype,
......
...@@ -9,6 +9,7 @@ import random ...@@ -9,6 +9,7 @@ import random
import pytest import pytest
import torch import torch
from tests.kernels.moe.utils import per_token_cast_to_fp8
from tests.kernels.utils import baseline_scaled_mm from tests.kernels.utils import baseline_scaled_mm
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -16,20 +17,6 @@ from vllm.utils import cdiv ...@@ -16,20 +17,6 @@ from vllm.utils import cdiv
from vllm.utils.deep_gemm import per_block_cast_to_fp8 from vllm.utils.deep_gemm import per_block_cast_to_fp8
def per_token_cast_to_fp8(
x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
pad_size = (128 - (n % 128)) % 128
x = torch.nn.functional.pad(x,
(0, pad_size), value=0) if pad_size > 0 else x
x_view = x.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
fp8_data = (x_view *
(448.0 / x_amax.unsqueeze(2))).to(dtype=torch.float8_e4m3fn)
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
@pytest.mark.parametrize("num_groups, expected_m_per_group, k, n", [ @pytest.mark.parametrize("num_groups, expected_m_per_group, k, n", [
(4, 8192, 7168, 4096), (4, 8192, 7168, 4096),
(4, 8192, 2048, 7168), (4, 8192, 2048, 7168),
...@@ -76,7 +63,7 @@ def test_cutlass_grouped_gemm( ...@@ -76,7 +63,7 @@ def test_cutlass_grouped_gemm(
device=device, device=device,
dtype=torch.float)) dtype=torch.float))
for i in range(num_groups): for i in range(num_groups):
y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i], [128, 128])
for i in range(num_groups): for i in range(num_groups):
a = x_fp8[0][ep_offset[i]:ep_offset[i + 1]] a = x_fp8[0][ep_offset[i]:ep_offset[i + 1]]
......
...@@ -70,8 +70,10 @@ def make_block_quant_fp8_weights( ...@@ -70,8 +70,10 @@ def make_block_quant_fp8_weights(
""" """
Return weights w1q, w2q, w1_scale, w2_scale Return weights w1q, w2q, w1_scale, w2_scale
""" """
w1, w1q, w1_scale, w2, w2q, w2_scale = make_test_weights( (_, w1q, w1_scale, _), (_, w2q, w2_scale,
e, n, k, torch.bfloat16, torch.float8_e4m3fn, block_size) _) = make_test_weights(e, n, k, torch.bfloat16,
torch.float8_e4m3fn,
block_size)
return w1q, w2q, w1_scale, w2_scale return w1q, w2q, w1_scale, w2_scale
......
...@@ -132,9 +132,9 @@ def run_single_case(m, n, k, topk, num_experts, block_size): ...@@ -132,9 +132,9 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
# Note: W1 has shape (E, 2N, K), so N = 512 # Note: W1 has shape (E, 2N, K), so N = 512
# can trigger the deepgemm path. # can trigger the deepgemm path.
MNKs = [ MNKs = [
(1024, 512, 128), (1024, 768, 128),
(1024, 512, 512), (1024, 768, 512),
(2048, 512, 512), (2048, 768, 512),
(512, 1024, 1024), (512, 1024, 1024),
(512, 2048, 2048), (512, 2048, 2048),
(4096, 4096, 1024), (4096, 4096, 1024),
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from tests.kernels.moe.utils import make_test_weights
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype)
from tests.kernels.utils import torch_moe
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts, is_valid_flashinfer_cutlass_fused_moe)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
if not has_flashinfer_cutlass_fused_moe(
) or not current_platform.has_device_capability(100):
pytest.skip("Requires flashinfer_cutlass_fused_moe and nvfp4 support",
allow_module_level=True)
MNK_FACTORS = [
(2, 1024, 1024),
(2, 1024, 1536),
(2, 3072, 1024),
(2, 3072, 1536),
(64, 1024, 1024),
(64, 1024, 1536),
(64, 3072, 1024),
(64, 2048, 1536),
(224, 1024, 1024),
(224, 1024, 1536),
]
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", [40, 64, 256])
#@pytest.mark.parametrize("e", [128, 256])
@pytest.mark.parametrize("topk", [1, 6, 8])
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
@torch.inference_mode()
def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
dtype: torch.dtype):
current_platform.seed_everything(7)
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))):
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
quant_blocksize = 16
(_, w1_q, w1_blockscale,
w1_gs), (_, w2_q, w2_blockscale, w2_gs) = make_test_weights(
e,
n,
k,
in_dtype=dtype,
quant_dtype="nvfp4",
block_shape=None, # use quant_blocksize?
per_act_token_quant=False,
)
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(a,
score,
topk,
renormalize=False)
a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q)
assert w1_gs is not None
assert w2_gs is not None
assert w1_blockscale is not None
assert w2_blockscale is not None
flashinfer_experts = FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
FlashInferExperts(
a1_gscale=a1_gs,
g1_alphas=(1 / w1_gs),
a2_gscale=a2_gs,
g2_alphas=(1 / w2_gs),
out_dtype=dtype,
quant_dtype="nvfp4",
))
flashinfer_output = flashinfer_experts(
hidden_states=a,
w1=w1_q,
w1_scale=w1_blockscale,
w2=w2_q,
w2_scale=w2_blockscale,
a1_scale=a1_gs,
a2_scale=a2_gs,
topk_weights=topk_weights,
topk_ids=topk_ids,
)
# Reference check:
a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(a.flatten(), dim=-1)).to(torch.float32)
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale)
_, m_k = a_fp4.shape
a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4,
a_scale_interleaved,
a_global_scale,
dtype=a.dtype,
device=a.device,
block_size=quant_blocksize)
w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype)
w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)
for idx in range(0, e):
w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx],
w1_blockscale[idx],
w1_gs[idx],
dtype=dtype,
device=w1_q.device,
block_size=quant_blocksize)
w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx],
w2_blockscale[idx],
w2_gs[idx],
dtype=dtype,
device=w2_q.device,
block_size=quant_blocksize)
torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)
torch.testing.assert_close(torch_output,
flashinfer_output,
atol=1e-1,
rtol=1e-1)
if __name__ == "__main__":
test_flashinfer_fp4_moe_no_graph((2, 1024, 1024), 40, 1, torch.half)
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy import copy
import textwrap
import traceback
from itertools import product from itertools import product
from typing import Optional from typing import Optional
...@@ -10,41 +12,51 @@ import torch ...@@ -10,41 +12,51 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config import VllmConfig, current_platform, set_current_vllm_config from vllm.config import VllmConfig, current_platform, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
BatchedTritonOrDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.layer import TritonExperts
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts)
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from .modular_kernel_tools.common import (Config, RankTensors, WeightTensors, from .modular_kernel_tools.common import (Config, RankTensors, WeightTensors,
reference_moe_impl, reference_moe_impl,
run_modular_kernel) run_modular_kernel)
from .modular_kernel_tools.mk_objects import ( from .modular_kernel_tools.mk_objects import (
MK_FUSED_EXPERT_TYPES, MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES, MK_MULTI_GPU_PREPARE_FINALIZE_TYPES,
MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES) MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, expert_info)
from .modular_kernel_tools.parallel_utils import (ProcessGroupInfo, from .modular_kernel_tools.parallel_utils import (ProcessGroupInfo,
parallel_launch_with_config) parallel_launch_with_config)
# TODO (varun): These requirements are very strict and could be relaxed. has_any_multi_gpu_package = (has_deep_ep() or has_deep_gemm() or has_pplx()
has_all_packages = (has_deep_ep() and has_deep_gemm() and has_pplx()) or has_flashinfer_cutlass_fused_moe())
meets_package_requirements = pytest.mark.skipif( meets_multi_gpu_requirements = pytest.mark.skipif(
not has_all_packages, not has_any_multi_gpu_package,
reason="Requires deep_ep & deep_gemm & pplx packages", reason="Requires deep_ep or deep_gemm or pplx or flashinfer packages",
) )
def format_result(verbose, msg, ex=None):
if ex is not None:
x = str(ex)
newx = x.strip(" \n\t")[:16]
if len(newx) < len(x):
newx = newx + " ..."
prefix = "E\t"
print(f"{textwrap.indent(traceback.format_exc(), prefix)}")
print(f"FAILED {msg} - {newx}\n")
elif verbose:
print(f"PASSED {msg}")
else:
print(".", end="")
def rank_worker( def rank_worker(
pgi: ProcessGroupInfo, pgi: ProcessGroupInfo,
vllm_config: VllmConfig, vllm_config: VllmConfig,
cpu_group, cpu_group,
config: Config, config: Config,
weights: WeightTensors, weights: WeightTensors,
verbose: bool,
): ):
current_platform.seed_everything(pgi.rank) current_platform.seed_everything(pgi.rank)
...@@ -61,8 +73,13 @@ def rank_worker( ...@@ -61,8 +73,13 @@ def rank_worker(
TOPKs = config.topks TOPKs = config.topks
assert isinstance(TOPKs, list) assert isinstance(TOPKs, list)
exceptions = []
count = 0
for m, topk in product(Ms, TOPKs): for m, topk in product(Ms, TOPKs):
print(f"Running m={m}, topk={topk} ...") try:
print(f"Running[{pgi.rank}]: m={m}, topk={topk} ...")
count = count + 1
# override m and topk # override m and topk
cfgx = copy.deepcopy(config) cfgx = copy.deepcopy(config)
cfgx.Ms = m cfgx.Ms = m
...@@ -78,22 +95,42 @@ def rank_worker( ...@@ -78,22 +95,42 @@ def rank_worker(
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
ref_out = reference_moe_impl(cfgx, weights, rank_tensors) ref_out = reference_moe_impl(cfgx, weights, rank_tensors)
torch.testing.assert_close(ref_out, mk_out, atol=3e-2, rtol=3e-2) if config.quant_dtype == "nvfp4":
atol = 1e-1
rtol = 1e-1
def run(config: Config): else:
atol = 3e-2
rtol = 3e-2
torch.testing.assert_close(ref_out, mk_out, atol=atol, rtol=rtol)
format_result(verbose, config.describe())
except Exception as ex:
format_result(verbose, config.describe(), ex)
exceptions.append(ex)
if len(exceptions) > 0:
raise RuntimeError(
f"{len(exceptions)} of {count} tests failed in child process, "
f"rank={pgi.rank}.")
else:
print(f"{count} of {count} tests passed in child process, "
f"rank={pgi.rank}.")
def run(config: Config, verbose: bool):
assert config.is_valid() assert config.is_valid()
print(f"Testing config \n{config.describe()} ...")
weights: WeightTensors = WeightTensors.make(config) weights: WeightTensors = WeightTensors.make(config)
vllm_config, env_dict = config.make_env_data() vllm_config, env_dict = config.make_env_data()
parallel_launch_with_config(config.world_size, rank_worker, vllm_config, parallel_launch_with_config(config.world_size, rank_worker, vllm_config,
env_dict, config, weights) env_dict, config, weights, verbose)
Ms = [32, 64] Ms = [32, 64]
Ks = [7168] # hidden sizes # hidden sizes, making this too large will cause fp4 tests to fail.
# Also needs to be a multiple of 1024 for deep_gemm.
Ks = [2048]
Ns = [2048] Ns = [2048]
TOPKs = [4, 1] TOPKs = [4, 1]
Es = [32] Es = [32]
...@@ -103,19 +140,16 @@ FUSED_MOE_CHUNK_SIZEs = [None, 16] ...@@ -103,19 +140,16 @@ FUSED_MOE_CHUNK_SIZEs = [None, 16]
def is_nyi_config(config: Config) -> bool: def is_nyi_config(config: Config) -> bool:
# We know these configs to be legitimate. but still fail. # We know these configs to be legitimate. but still fail.
info = expert_info(config.fused_experts_type)
if (config.fused_experts_type in [ if info.needs_matching_quant:
BatchedTritonExperts, BatchedTritonOrDeepGemmExperts,
TritonExperts, TritonOrDeepGemmExperts
]):
# The triton kernels expect both per-act-token-quant and # The triton kernels expect both per-act-token-quant and
# per-out-ch-quant or neither. # per-out-ch-quant or neither.
unsupported_quant_config = ((config.is_per_act_token_quant + unsupported_quant_config = ((config.is_per_act_token_quant +
config.is_per_out_ch_quant) == 1) config.is_per_out_ch_quant) == 1)
return unsupported_quant_config return unsupported_quant_config
# cutlass kernels dont support expert_maps yet. return not info.supports_expert_map
return config.fused_experts_type == CutlassExpertsFp8
@pytest.mark.parametrize("k", Ks) @pytest.mark.parametrize("k", Ks)
...@@ -128,13 +162,13 @@ def is_nyi_config(config: Config) -> bool: ...@@ -128,13 +162,13 @@ def is_nyi_config(config: Config) -> bool:
product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES)) product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES))
@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs) @pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
@pytest.mark.parametrize("world_size", [2]) @pytest.mark.parametrize("world_size", [2])
@meets_package_requirements @meets_multi_gpu_requirements
def test_modular_kernel_combinations_multigpu( def test_modular_kernel_combinations_multigpu(
k: int, n: int, e: int, dtype: torch.dtype, k: int, n: int, e: int, dtype: torch.dtype,
quant_config: FusedMoEQuantConfig, quant_config: Optional[FusedMoEQuantConfig],
combination: tuple[mk.FusedMoEPrepareAndFinalize, combination: tuple[mk.FusedMoEPrepareAndFinalize,
mk.FusedMoEPermuteExpertsUnpermute], mk.FusedMoEPermuteExpertsUnpermute],
fused_moe_chunk_size: Optional[int], world_size: int): fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig):
config = Config( config = Config(
Ms=Ms, Ms=Ms,
...@@ -149,14 +183,15 @@ def test_modular_kernel_combinations_multigpu( ...@@ -149,14 +183,15 @@ def test_modular_kernel_combinations_multigpu(
fused_moe_chunk_size=fused_moe_chunk_size, fused_moe_chunk_size=fused_moe_chunk_size,
world_size=world_size, world_size=world_size,
) )
if not config.is_valid(): if not config.is_valid():
pytest.skip(f"Tests config {config} is not valid. Skipping ...") pytest.skip(f"Tests config {config} is not valid. Skipping ...")
if is_nyi_config(config): if is_nyi_config(config):
pytest.skip(f"Tests config {config} is nyi. Skipping ...") pytest.skip(f"Tests config {config} is nyi. Skipping ...")
print(f"{config.describe()}") verbosity = pytestconfig.getoption('verbose')
run(config) run(config, verbosity > 0)
@pytest.mark.parametrize("k", Ks) @pytest.mark.parametrize("k", Ks)
...@@ -169,13 +204,12 @@ def test_modular_kernel_combinations_multigpu( ...@@ -169,13 +204,12 @@ def test_modular_kernel_combinations_multigpu(
product(MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES)) product(MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES))
@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs) @pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
@pytest.mark.parametrize("world_size", [1]) @pytest.mark.parametrize("world_size", [1])
@meets_package_requirements
def test_modular_kernel_combinations_singlegpu( def test_modular_kernel_combinations_singlegpu(
k: int, n: int, e: int, dtype: torch.dtype, k: int, n: int, e: int, dtype: torch.dtype,
quant_config: FusedMoEQuantConfig, quant_config: Optional[FusedMoEQuantConfig],
combination: tuple[mk.FusedMoEPrepareAndFinalize, combination: tuple[mk.FusedMoEPrepareAndFinalize,
mk.FusedMoEPermuteExpertsUnpermute], mk.FusedMoEPermuteExpertsUnpermute],
fused_moe_chunk_size: Optional[int], world_size: int): fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig):
config = Config( config = Config(
Ms=Ms, Ms=Ms,
K=k, K=k,
...@@ -196,7 +230,8 @@ def test_modular_kernel_combinations_singlegpu( ...@@ -196,7 +230,8 @@ def test_modular_kernel_combinations_singlegpu(
if is_nyi_config(config): if is_nyi_config(config):
pytest.skip(f"Tests config {config} is nyi. Skipping ...") pytest.skip(f"Tests config {config} is nyi. Skipping ...")
run(config) verbosity = pytestconfig.getoption('verbose')
run(config, verbosity > 0)
if __name__ == '__main__': if __name__ == '__main__':
...@@ -211,4 +246,4 @@ if __name__ == '__main__': ...@@ -211,4 +246,4 @@ if __name__ == '__main__':
args = parser.parse_args() args = parser.parse_args()
config = make_config(args) config = make_config(args)
run(config) run(config, True)
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import pytest import pytest
import torch import torch
from tests.kernels.moe.utils import make_test_weights
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX, FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype) dequantize_nvfp4_to_dtype)
...@@ -43,41 +44,20 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, ...@@ -43,41 +44,20 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
VllmConfig(parallel_config=ParallelConfig( VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))): pipeline_parallel_size=1))):
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
quant_blocksize = 16 quant_blocksize = 16
round_up = lambda x, y: (x + y - 1) // y * y
sf_w1_2n = round_up(2 * n, 128) a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
sf_w1_k = round_up(k // quant_blocksize, 4)
w1_blockscale = torch.empty((e, sf_w1_2n, sf_w1_k), (_, w1_q, w1_blockscale,
device="cuda", w1_gs), (_, w2_q, w2_blockscale, w2_gs) = make_test_weights(
dtype=torch.float8_e4m3fn) e,
n,
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 k,
sf_w2_k = round_up(k, 128) in_dtype=dtype,
sf_w2_n = round_up(n // quant_blocksize, 4) quant_dtype="nvfp4",
w2_blockscale = torch.empty((e, sf_w2_k, sf_w2_n), block_shape=None, # use quant_blocksize?
device="cuda", per_act_token_quant=False,
dtype=torch.float8_e4m3fn) )
w1_q = torch.empty((e, 2 * n, k // 2),
device="cuda",
dtype=torch.uint8)
w2_q = torch.empty((e, k, n // 2), device="cuda", dtype=torch.uint8)
w1_gs = torch.empty((e, ), device="cuda", dtype=torch.float32)
w2_gs = torch.empty((e, ), device="cuda", dtype=torch.float32)
for expert in range(e):
w1_amax = torch.abs(w1).max().to(torch.float32)
w2_amax = torch.abs(w2).max().to(torch.float32)
w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax
w1_q[expert], w1_blockscale[expert] = ops.scaled_fp4_quant(
w1[expert], w1_gs[expert])
w2_q[expert], w2_blockscale[expert] = ops.scaled_fp4_quant(
w2[expert], w2_gs[expert])
score = torch.randn((m, e), device="cuda", dtype=dtype) score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(a, topk_weights, topk_ids, _ = fused_topk(a,
...@@ -88,6 +68,11 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, ...@@ -88,6 +68,11 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32) a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32) a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
assert w1_gs is not None
assert w2_gs is not None
assert w1_blockscale is not None
assert w2_blockscale is not None
cutlass_output = cutlass_moe_fp4( cutlass_output = cutlass_moe_fp4(
a=a, a=a,
a1_gscale=a1_gs, a1_gscale=a1_gs,
...@@ -104,14 +89,13 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, ...@@ -104,14 +89,13 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
n=n, n=n,
k=k, k=k,
e=e, e=e,
device=a.device,
) )
# Reference check: # Reference check:
a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(a.flatten(), dim=-1)).to(torch.float32) torch.amax(a.flatten(), dim=-1)).to(torch.float32)
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale) a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale)
_, m_k = a_fp4.shape
a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4, a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4,
a_scale_interleaved, a_scale_interleaved,
a_global_scale, a_global_scale,
...@@ -126,14 +110,14 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, ...@@ -126,14 +110,14 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx], w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx],
w1_blockscale[idx], w1_blockscale[idx],
w1_gs[idx], w1_gs[idx],
dtype=w1.dtype, dtype=dtype,
device=w1.device, device=w1_q.device,
block_size=quant_blocksize) block_size=quant_blocksize)
w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx], w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx],
w2_blockscale[idx], w2_blockscale[idx],
w2_gs[idx], w2_gs[idx],
dtype=w2.dtype, dtype=dtype,
device=w2.device, device=w2_q.device,
block_size=quant_blocksize) block_size=quant_blocksize)
torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk) torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)
......
...@@ -9,7 +9,8 @@ import torch ...@@ -9,7 +9,8 @@ import torch
from tests.kernels.utils import torch_experts from tests.kernels.utils import torch_experts
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassBatchedExpertsFp8)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel) FusedMoEModularKernel)
...@@ -123,12 +124,8 @@ def pplx_cutlass_moe( ...@@ -123,12 +124,8 @@ def pplx_cutlass_moe(
num_local_experts=num_local_experts, num_local_experts=num_local_experts,
num_dispatchers=num_dispatchers) num_dispatchers=num_dispatchers)
experts = CutlassExpertsFp8(num_local_experts, experts = CutlassBatchedExpertsFp8(num_local_experts, num_dispatchers,
out_dtype, out_dtype, per_act_token, per_out_ch)
per_act_token,
per_out_ch,
num_dispatchers=num_dispatchers,
use_batched_format=True)
fused_cutlass_experts = FusedMoEModularKernel( fused_cutlass_experts = FusedMoEModularKernel(
prepare_finalize, prepare_finalize,
......
...@@ -770,7 +770,7 @@ def test_pplx_moe_slow( ...@@ -770,7 +770,7 @@ def test_pplx_moe_slow(
a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10 a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
_, w1, w1_s, _, w2, w2_s = make_test_weights( (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
e, e,
n, n,
k, k,
...@@ -836,7 +836,7 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool, ...@@ -836,7 +836,7 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
args = dict() args = dict()
if make_weights: if make_weights:
_, w1, w1_s, _, w2, w2_s = make_test_weights( (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
e, e,
n, n,
k, k,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional from typing import Optional, Union
import torch import torch
import vllm._custom_ops as ops import vllm._custom_ops as ops
from tests.kernels.quant_utils import per_block_cast_to_int8 from tests.kernels.quant_utils import per_block_cast_to_int8
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX)
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts) BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
...@@ -169,28 +171,41 @@ def make_quantized_test_activations( ...@@ -169,28 +171,41 @@ def make_quantized_test_activations(
def moe_quantize_weights( def moe_quantize_weights(
w: torch.Tensor, w: torch.Tensor,
w_s: Optional[torch.Tensor], w_s: Optional[torch.Tensor],
quant_dtype: Optional[torch.dtype], quant_dtype: Union[torch.dtype, str, None],
per_token_quant: bool, per_token_quant: bool,
block_shape: Optional[list[int]], block_shape: Optional[list[int]],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
assert (quant_dtype == torch.float8_e4m3fn assert (quant_dtype == torch.float8_e4m3fn or quant_dtype == torch.int8
or quant_dtype == torch.int8), "only fp8/int8 supported" or quant_dtype == "nvfp4"), "only fp8/int8/nvfp4 supported"
w_gs = None
if block_shape is not None: if block_shape is not None:
assert not per_token_quant assert not per_token_quant
if quant_dtype == torch.int8: if quant_dtype == torch.int8:
w, w_s = per_block_cast_to_int8(w, block_shape) w, w_s = per_block_cast_to_int8(w, block_shape)
else: elif quant_dtype == torch.float8_e4m3fn:
w, w_s = per_block_cast_to_fp8(w, block_shape) w, w_s = per_block_cast_to_fp8(w, block_shape)
elif quant_dtype == "nvfp4":
raise RuntimeError("blocked quantization not supported for nvfp4")
else:
raise RuntimeError(f"Unsupported quant type {quant_dtype}")
else: else:
if quant_dtype == torch.int8: if quant_dtype == torch.int8:
w, w_s = ops.scaled_int8_quant( w, w_s = ops.scaled_int8_quant(
w, w_s, use_per_token_if_dynamic=per_token_quant) w, w_s, use_per_token_if_dynamic=per_token_quant)
else: elif quant_dtype == torch.float8_e4m3fn:
w, w_s = ops.scaled_fp8_quant( w, w_s = ops.scaled_fp8_quant(
w, w_s, use_per_token_if_dynamic=per_token_quant) w, w_s, use_per_token_if_dynamic=per_token_quant)
elif quant_dtype == "nvfp4":
assert not per_token_quant
w_amax = torch.abs(w).max().to(torch.float32)
w_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w_amax
w, w_s = ops.scaled_fp4_quant(w, w_gs)
else:
raise RuntimeError(f"Unsupported quant type {quant_dtype}")
return w, w_s return w, w_s, w_gs
def make_test_weight( def make_test_weight(
...@@ -198,21 +213,26 @@ def make_test_weight( ...@@ -198,21 +213,26 @@ def make_test_weight(
rows: int, rows: int,
cols: int, cols: int,
in_dtype: torch.dtype = torch.bfloat16, in_dtype: torch.dtype = torch.bfloat16,
quant_dtype: Optional[torch.dtype] = None, quant_dtype: Union[torch.dtype, str, None] = None,
block_shape: Optional[list[int]] = None, block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False, per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15 w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15
w_gs = None
if quant_dtype is not None: if quant_dtype is not None:
w_l = [None] * e w_l = [None] * e
w_s_l = [None] * e w_s_l = [None] * e
w_gs_l = [None] * e
for idx in range(e): for idx in range(e):
w_l[idx], w_s_l[idx] = moe_quantize_weights( w_l[idx], w_s_l[idx], w_gs_l[idx] = moe_quantize_weights(
w_16[idx], None, quant_dtype, per_act_token_quant, block_shape) w_16[idx], None, quant_dtype, per_act_token_quant, block_shape)
w = torch.stack(w_l) w = torch.stack(w_l)
w_s = torch.stack(w_s_l) w_s = torch.stack(w_s_l)
if e > 0 and w_gs_l[0] is not None:
w_gs = torch.stack(w_gs_l)
if w_s.ndim == 2: if w_s.ndim == 2:
assert w_s.shape[-1] == 1 assert w_s.shape[-1] == 1
w_s = w_s.view(-1, 1, 1) w_s = w_s.view(-1, 1, 1)
...@@ -225,8 +245,9 @@ def make_test_weight( ...@@ -225,8 +245,9 @@ def make_test_weight(
else: else:
w = w_16 w = w_16
w_s = None w_s = None
w_gs = None
return w_16, w, w_s return w_16, w, w_s, w_gs
def make_test_weights( def make_test_weights(
...@@ -234,14 +255,30 @@ def make_test_weights( ...@@ -234,14 +255,30 @@ def make_test_weights(
n: int, n: int,
k: int, k: int,
in_dtype: torch.dtype = torch.bfloat16, in_dtype: torch.dtype = torch.bfloat16,
quant_dtype: Optional[torch.dtype] = None, quant_dtype: Union[torch.dtype, str, None] = None,
block_shape: Optional[list[int]] = None, block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False, per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor, ) -> tuple[tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
torch.Tensor, Optional[torch.Tensor]]: Optional[torch.Tensor]],
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]]:
return ( return (
*make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape, make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape,
per_act_token_quant), per_act_token_quant),
*make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape, make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape,
per_act_token_quant), per_act_token_quant),
) )
def per_token_cast_to_fp8(
x: torch.Tensor,
block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
pad_size = (block_size - (n % block_size)) % block_size
x = torch.nn.functional.pad(x,
(0, pad_size), value=0) if pad_size > 0 else x
x_view = x.view(m, -1, block_size)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
...@@ -105,7 +105,8 @@ class DeviceCommunicatorBase: ...@@ -105,7 +105,8 @@ class DeviceCommunicatorBase:
# we initialize the all2all manager used in expert parallel. # we initialize the all2all manager used in expert parallel.
use_ep = config.parallel_config.data_parallel_size > 1 use_ep = config.parallel_config.data_parallel_size > 1
self.use_all2all = "ep" in unique_name and use_ep self.is_ep_communicator = "ep" in unique_name
self.use_all2all = self.is_ep_communicator and use_ep
self.all2all_manager: Optional[All2AllManagerBase] = None self.all2all_manager: Optional[All2AllManagerBase] = None
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
...@@ -246,7 +247,7 @@ class DeviceCommunicatorBase: ...@@ -246,7 +247,7 @@ class DeviceCommunicatorBase:
""" """
Prepare the communication buffer for the model. Prepare the communication buffer for the model.
""" """
if not self.use_all2all: if not self.is_ep_communicator:
return return
moe_modules = [ moe_modules = [
...@@ -254,7 +255,7 @@ class DeviceCommunicatorBase: ...@@ -254,7 +255,7 @@ class DeviceCommunicatorBase:
if module.__class__.__name__ == "FusedMoE" if module.__class__.__name__ == "FusedMoE"
] ]
for module in moe_modules: for module in moe_modules:
module.quant_method.init_prepare_finalize(module.moe_config) module.quant_method.init_prepare_finalize()
def dispatch( def dispatch(
self, hidden_states: torch.Tensor, self, hidden_states: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment