Commit 899a2db4 authored by zhuwenwen's avatar zhuwenwen
Browse files

sync v0.15.1(ex fused_moe&models)

parent 78c1f9e5
...@@ -26,6 +26,7 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( ...@@ -26,6 +26,7 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend, Fp8MoeBackend,
convert_to_fp8_moe_kernel_format, convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel, make_fp8_moe_kernel,
make_fp8_moe_kernel_for_mkm,
make_fp8_moe_quant_config, make_fp8_moe_quant_config,
select_fp8_moe_backend, select_fp8_moe_backend,
) )
...@@ -34,6 +35,7 @@ from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( ...@@ -34,6 +35,7 @@ from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
convert_to_nvfp4_moe_kernel_format, convert_to_nvfp4_moe_kernel_format,
is_global_sf_supported_for_nvfp4_backend, is_global_sf_supported_for_nvfp4_backend,
make_nvfp4_moe_kernel, make_nvfp4_moe_kernel,
make_nvfp4_moe_kernel_for_mkm,
make_nvfp4_moe_quant_config, make_nvfp4_moe_quant_config,
select_nvfp4_moe_backend, select_nvfp4_moe_backend,
) )
...@@ -52,11 +54,13 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( ...@@ -52,11 +54,13 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
) )
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_prepare_finalize,
flashinfer_trtllm_fp4_moe, flashinfer_trtllm_fp4_moe,
flashinfer_trtllm_fp4_routed_moe, flashinfer_trtllm_fp4_routed_moe,
) )
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe, apply_fi_trtllm_fp8_per_tensor_moe,
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
) )
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp, W8A8BlockFp8LinearOp,
...@@ -80,9 +84,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -80,9 +84,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTokenSym, kFp8StaticTokenSym,
kNvfp4Dynamic, kNvfp4Dynamic,
kNvfp4Static, kNvfp4Static,
pad_nvfp4_activation_for_cutlass,
pad_nvfp4_weight_for_cutlass,
slice_nvfp4_output,
swizzle_blockscale, swizzle_blockscale,
) )
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
...@@ -735,23 +736,47 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -735,23 +736,47 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
activation_key=kFp8StaticTensorSym, activation_key=kFp8StaticTensorSym,
) )
# Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.kernel is not None:
return self.kernel.prepare_finalize.topk_indices_dtype()
return None
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
self, self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None: ) -> mk.FusedMoEPrepareAndFinalize | None:
raise ValueError( # TRT LLM not supported with all2all yet.
f"{self.__class__.__name__} uses the new modular kernel initialization " if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
"logic. This function should not be called." return None
) elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
# For no-EP case, don't use the MKM framework.
if not self.moe.moe_parallel_config.use_all2all_kernels:
return None
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
self.moe,
use_deepseek_fp8_block_scale=False,
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl( def select_gemm_impl(
self, self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize, prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module, layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute: ) -> mk.FusedMoEPermuteExpertsUnpermute:
raise ValueError( assert self.moe_quant_config is not None
f"{self.__class__.__name__} uses the new modular kernel initialization " assert self.experts_cls is not None
"logic. This function should not be called." return make_fp8_moe_kernel_for_mkm(
moe_config=self.moe,
quant_config=self.moe_quant_config,
experts_cls=self.experts_cls,
prepare_finalize=prepare_finalize,
) )
def create_weights( def create_weights(
...@@ -835,7 +860,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -835,7 +860,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
def _setup_kernel( def _setup_kernel(
self, self,
layer: FusedMoE, layer: torch.nn.Module,
w13: torch.Tensor, w13: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
w13_scale: torch.Tensor, w13_scale: torch.Tensor,
...@@ -865,13 +890,11 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -865,13 +890,11 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config: if self.moe_quant_config:
assert self.experts_cls is not None assert self.experts_cls is not None
self.moe_mk, self.use_inplace = make_fp8_moe_kernel( self.kernel, self.use_inplace = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config, moe_quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
fp8_backend=self.fp8_backend, fp8_backend=self.fp8_backend,
experts_cls=self.experts_cls, experts_cls=self.experts_cls,
routing_tables=layer._maybe_init_expert_routing_tables(),
shared_experts=layer.shared_experts,
) )
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
...@@ -972,8 +995,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -972,8 +995,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
f"but got {layer.activation}" f"but got {layer.activation}"
) )
assert self.moe_mk is not None assert self.kernel is not None
return self.moe_mk( return self.kernel(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
...@@ -1257,16 +1280,9 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): ...@@ -1257,16 +1280,9 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
layer.weight_scale = Parameter(weight_scale, requires_grad=False) layer.weight_scale = Parameter(weight_scale, requires_grad=False)
layer.weight = Parameter(weight, requires_grad=False) layer.weight = Parameter(weight, requires_grad=False)
else: else:
# Swizzle block scales and pad the packed NVFP4 weights for kernel
# alignment (CUTLASS/FlashInfer require K and N divisible by 32).
swizzled_weight_scale = swizzle_blockscale(layer.weight_scale) swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False) layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)
layer.weight = Parameter(layer.weight.data, requires_grad=False)
weight, weights_padding_cols = pad_nvfp4_weight_for_cutlass(
layer.weight.data
)
layer.weights_padding_cols = weights_padding_cols
layer.weight = Parameter(weight, requires_grad=False)
def apply( def apply(
self, self,
...@@ -1288,6 +1304,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): ...@@ -1288,6 +1304,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
) )
output_dtype = x.dtype output_dtype = x.dtype
output_shape = [x.shape[0], layer.weight.shape[0]]
# quantize BF16 or FP16 to (FP4 and interleaved block scale) # quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4, x_blockscale = scaled_fp4_quant( x_fp4, x_blockscale = scaled_fp4_quant(
...@@ -1302,12 +1319,6 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): ...@@ -1302,12 +1319,6 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
assert layer.weight_scale.dtype == torch.float8_e4m3fn assert layer.weight_scale.dtype == torch.float8_e4m3fn
assert layer.alpha.dtype == torch.float32 assert layer.alpha.dtype == torch.float32
# Pad activations to match weight K-dimension padding
weights_padding_cols = getattr(layer, "weights_padding_cols", 0)
output_size = layer.output_size_per_partition
output_shape = [x.shape[0], output_size]
x_fp4 = pad_nvfp4_activation_for_cutlass(x_fp4, weights_padding_cols)
mm_args = ( mm_args = (
x_fp4, x_fp4,
layer.weight, layer.weight,
...@@ -1316,7 +1327,6 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): ...@@ -1316,7 +1327,6 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
layer.alpha, layer.alpha,
output_dtype, output_dtype,
) )
if self.backend.startswith("flashinfer-"): if self.backend.startswith("flashinfer-"):
backend_name = self.backend[len("flashinfer-") :] backend_name = self.backend[len("flashinfer-") :]
out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name) out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
...@@ -1324,9 +1334,6 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): ...@@ -1324,9 +1334,6 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
assert self.backend == "cutlass" assert self.backend == "cutlass"
out = cutlass_scaled_fp4_mm(*mm_args) out = cutlass_scaled_fp4_mm(*mm_args)
# Slice output to remove N-dimension padding
out = slice_nvfp4_output(out, output_size)
if bias is not None: if bias is not None:
out = out + bias out = out + bias
return out.view(*output_shape) return out.view(*output_shape)
...@@ -1353,27 +1360,50 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1353,27 +1360,50 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
activation_key=kNvfp4Dynamic, activation_key=kNvfp4Dynamic,
) )
# Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None
self.use_global_sf = is_global_sf_supported_for_nvfp4_backend( self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
self.nvfp4_backend self.nvfp4_backend
) )
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.kernel is not None:
return self.kernel.prepare_finalize.topk_indices_dtype()
return None
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
self, self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None: ) -> mk.FusedMoEPrepareAndFinalize | None:
raise ValueError( if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
f"{self.__class__.__name__} uses the new modular kernel initialization " return None
"logic. This function should not be called." elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
) # For no-EP case, don't use the MKM framework.
if not self.moe.moe_parallel_config.use_all2all_kernels:
return None
# For now, fp4 moe only works with the flashinfer dispatcher.
prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
self.moe
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
else:
return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl( def select_gemm_impl(
self, self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize, prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module, layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute: ) -> mk.FusedMoEPermuteExpertsUnpermute:
raise ValueError( assert self.moe_quant_config is not None
f"{self.__class__.__name__} uses the new modular kernel initialization " assert self.experts_cls is not None
"logic. This function should not be called." return make_nvfp4_moe_kernel_for_mkm(
moe_config=self.moe,
quant_config=self.moe_quant_config,
experts_cls=self.experts_cls,
prepare_finalize=prepare_finalize,
) )
def uses_weight_scale_2_pattern(self) -> bool: def uses_weight_scale_2_pattern(self) -> bool:
...@@ -1498,7 +1528,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1498,7 +1528,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
) )
layer.register_parameter("w2_input_scale", w2_input_scale) layer.register_parameter("w2_input_scale", w2_input_scale)
def process_weights_after_loading(self, layer: FusedMoE) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
""" """
Convert NVFP4 MoE weights into kernel format and setup the kernel. Convert NVFP4 MoE weights into kernel format and setup the kernel.
""" """
...@@ -1550,14 +1580,15 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1550,14 +1580,15 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases. # in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config: if self.moe_quant_config and (
(not self.moe.moe_parallel_config.use_all2all_kernels)
or self.moe.moe_parallel_config.use_naive_all2all_kernels
):
assert self.experts_cls is not None assert self.experts_cls is not None
self.moe_mk = make_nvfp4_moe_kernel( self.kernel = make_nvfp4_moe_kernel(
moe_quant_config=self.moe_quant_config, moe_quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
experts_cls=self.experts_cls, experts_cls=self.experts_cls,
shared_experts=layer.shared_experts,
routing_tables=layer._maybe_init_expert_routing_tables(),
) )
@property @property
...@@ -1658,8 +1689,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1658,8 +1689,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
) )
else: else:
assert self.moe_mk is not None assert self.kernel is not None
return self.moe_mk( return self.kernel(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
...@@ -1675,4 +1706,4 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1675,4 +1706,4 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod
ModelOptNvFp4Config.FusedMoEMethodCls = ModelOptNvFp4FusedMoE ModelOptNvFp4Config.FusedMoEMethodCls = ModelOptNvFp4FusedMoE
ModelOptNvFp4Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod ModelOptNvFp4Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod
\ No newline at end of file
...@@ -1053,32 +1053,32 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -1053,32 +1053,32 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x.shape[:-1], -1) x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x.shape[:-1], -1)
trtllm_gen_output = trtllm_fp4_block_scale_moe( trtllm_gen_output = trtllm_fp4_block_scale_moe(
routing_logits=router_logits.to(torch.bfloat16), router_logits.to(torch.bfloat16),
routing_bias=None, None, # routing_bias
hidden_states=x_quant, x_quant,
hidden_states_scale=x_scale, x_scale,
gemm1_weights=layer.w13_weight, # uint8 (e2m1 x 2) layer.w13_weight, # uint8 (e2m1 x 2)
gemm1_weights_scale=layer.w13_weight_scale, # uint8 (e4m3 x 2) layer.w13_weight_scale, # uint8 (e4m3 x 2)
gemm1_bias=layer.w13_bias, # fp32 per expert per channel layer.w13_bias, # fp32 per expert per channel
gemm1_alpha=layer.gemm1_alpha, # fp32 per expert layer.gemm1_alpha, # fp32 per expert
gemm1_beta=layer.gemm1_beta, # fp32 per expert layer.gemm1_beta, # fp32 per expert
gemm1_clamp_limit=layer.gemm1_clamp_limit, # fp32 per expert layer.gemm1_clamp_limit, # fp32 per expert
gemm2_weights=layer.w2_weight, # uint8 (e2m1 x 2) layer.w2_weight, # uint8 (e2m1 x 2)
gemm2_weights_scale=layer.w2_weight_scale, # ue8m0 layer.w2_weight_scale, # ue8m0
gemm2_bias=layer.w2_bias, # fp32 per expert per channel layer.w2_bias, # fp32 per expert per channel
output1_scale_scalar=None, None, # output1_scale_scalar
output1_scale_gate_scalar=None, None, # output1_scale_gate_scalar
output2_scale_scalar=None, None, # output2_scale_scalar
num_experts=layer.global_num_experts, layer.global_num_experts,
top_k=layer.top_k, layer.top_k,
n_group=None, None, # n_group
topk_group=None, None, # topk_group
intermediate_size=self.intermediate_size, # padded to multiple of 256 self.intermediate_size, # padded to multiple of 256
local_expert_offset=layer.ep_rank * layer.local_num_experts, layer.ep_rank * layer.local_num_experts, # local_expert_offset
local_num_experts=self.num_experts, self.num_experts, # local num experts
routed_scaling_factor=None, None, # routed_scaling_factor
routing_method_type=1 if layer.renormalize else 0, 1 if layer.renormalize else 0, # routing_method_type, renormalize
do_finalize=True, True, # do finalize
tune_max_num_tokens=max(self.max_capture_size, 1), tune_max_num_tokens=max(self.max_capture_size, 1),
)[0] )[0]
return trtllm_gen_output return trtllm_gen_output
...@@ -1170,4 +1170,4 @@ class IpexMxfp4MoEMethod(Mxfp4MoEMethod): ...@@ -1170,4 +1170,4 @@ class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
activation="swiglu_oai", activation="swiglu_oai",
) )
hidden_states = hidden_states[..., : self.original_hidden_size].contiguous() hidden_states = hidden_states[..., : self.original_hidden_size].contiguous()
return hidden_states return hidden_states
\ No newline at end of file
...@@ -14,6 +14,9 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -14,6 +14,9 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEParallelConfig, FusedMoEParallelConfig,
RoutingMethodType, RoutingMethodType,
) )
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
create_flashinfer_prepare_finalize,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, QuantKey,
kNvfp4Dynamic, kNvfp4Dynamic,
...@@ -32,6 +35,7 @@ logger = init_logger(__name__) ...@@ -32,6 +35,7 @@ logger = init_logger(__name__)
__all__ = [ __all__ = [
"reorder_w1w3_to_w3w1", "reorder_w1w3_to_w3w1",
"build_flashinfer_fp4_cutlass_moe_prepare_finalize",
] ]
# #
...@@ -132,6 +136,17 @@ def reorder_w1w3_to_w3w1( ...@@ -132,6 +136,17 @@ def reorder_w1w3_to_w3w1(
) )
def build_flashinfer_fp4_cutlass_moe_prepare_finalize(
    moe: FusedMoEConfig,
) -> mk.FusedMoEPrepareAndFinalize:
    """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel (NVFP4).

    Args:
        moe: MoE configuration. Its ``moe_parallel_config`` supplies the
            data-parallel size and the configured all2all backend name.

    Returns:
        The object built by ``create_flashinfer_prepare_finalize`` with
        ``use_nvfp4=True``.
    """
    parallel_config = moe.moe_parallel_config
    return create_flashinfer_prepare_finalize(
        # Data parallelism counts as enabled once there is more than one rank.
        use_dp=parallel_config.dp_size > 1,
        use_nvfp4=True,
        # All-to-all-v is requested only for the matching backend string.
        enable_alltoallv=parallel_config.all2all_backend == "flashinfer_all2allv",
    )
def prepare_static_weights_for_trtllm_fp4_moe( def prepare_static_weights_for_trtllm_fp4_moe(
# args_dequant, # args_dequant,
# args, # args,
...@@ -526,4 +541,4 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass( ...@@ -526,4 +541,4 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
w2_scale = swizzle_blockscale(w2_scale) w2_scale = swizzle_blockscale(w2_scale)
return w13, w13_scale, w13_scale_2, a13_scale, w2, w2_scale, w2_scale_2, a2_scale return w13, w13_scale, w13_scale_2, a13_scale, w2, w2_scale, w2_scale_2, a2_scale
\ No newline at end of file
...@@ -4,8 +4,15 @@ from enum import Enum ...@@ -4,8 +4,15 @@ from enum import Enum
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs from vllm import envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
create_flashinfer_prepare_finalize,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up from vllm.utils.math_utils import round_up
...@@ -156,6 +163,18 @@ def make_fp8_moe_alpha_scales_for_fi( ...@@ -156,6 +163,18 @@ def make_fp8_moe_alpha_scales_for_fi(
return g1_alphas, g2_alphas return g1_alphas, g2_alphas
def build_flashinfer_fp8_cutlass_moe_prepare_finalize(
    moe: FusedMoEConfig | None, use_deepseek_fp8_block_scale: bool = False
) -> mk.FusedMoEPrepareAndFinalize:
    """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel (FP8).

    Args:
        moe: MoE configuration, or ``None``. When ``None`` data parallelism
            is treated as disabled.
        use_deepseek_fp8_block_scale: Forwarded unchanged to
            ``create_flashinfer_prepare_finalize``.

    Returns:
        The object built by ``create_flashinfer_prepare_finalize``.
    """
    if moe is None:
        use_dp = False
    else:
        use_dp = moe.moe_parallel_config.dp_size > 1

    # Propagate block-scale flag so prepare/finalize can skip act quantization
    # and inform the kernel to consume per-block weight scales.
    return create_flashinfer_prepare_finalize(
        use_dp, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale
    )
def get_flashinfer_moe_backend() -> FlashinferMoeBackend: def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
backend_map = { backend_map = {
"throughput": FlashinferMoeBackend.CUTLASS, "throughput": FlashinferMoeBackend.CUTLASS,
...@@ -293,4 +312,4 @@ def prepare_fp8_moe_layer_for_fi( ...@@ -293,4 +312,4 @@ def prepare_fp8_moe_layer_for_fi(
w2_input_scale=w2_input_scale, w2_input_scale=w2_input_scale,
) )
return w13, w2, w13_scale return w13, w2, w13_scale
\ No newline at end of file
...@@ -16,7 +16,6 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import ( ...@@ -16,7 +16,6 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import (
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types from vllm.scalar_type import ScalarType, scalar_types
from .quant_utils import pack_cols, unpack_cols from .quant_utils import pack_cols, unpack_cols
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -675,66 +674,6 @@ def apply_awq_marlin_linear( ...@@ -675,66 +674,6 @@ def apply_awq_marlin_linear(
return output.reshape(out_shape) return output.reshape(out_shape)
def apply_rtn_marlin_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    workspace: torch.Tensor,
    quant_type: ScalarType,
    output_size_per_partition: int,
    input_size_per_partition: int,
    input_global_scale: torch.Tensor | None = None,
    bias: torch.Tensor | None = None,
    use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
    input_dtype: torch.dtype | None = None,
) -> torch.Tensor:
    """Apply an RTN-quantized linear layer through ``ops.gptq_marlin_gemm``.

    The input is flattened to 2-D, optionally quantized to int8 / fp8
    (selected by ``input_dtype``), multiplied by the quantized weight, and
    reshaped so that only the last dimension changes to
    ``output_size_per_partition``.

    Args:
        input: Activations; all leading dimensions are folded into rows.
        weight: Quantized weight handed to the GEMM as-is.
        weight_scale: Weight scales handed to the GEMM as-is.
        workspace: Workspace tensor handed to the GEMM as-is.
        quant_type: Weight quantization type; must equal
            ``scalar_types.uint4b8`` whenever activations are quantized.
        output_size_per_partition: Passed as ``size_n`` and used as the
            last dimension of the result.
        input_size_per_partition: Passed as ``size_k``.
        input_global_scale: Multiplied into the activation scales on the
            int8 path only.
        bias: Optional bias forwarded to the GEMM.
        use_fp32_reduce: Forwarded to the GEMM.
        input_dtype: ``torch.int8`` or ``torch.float8_e4m3fn`` to quantize
            activations first; any other value leaves them untouched.

    Returns:
        The GEMM output reshaped to ``input.shape[:-1] + (N,)``.
    """
    x_2d = input.reshape(-1, input.shape[-1])
    result_shape = input.shape[:-1] + (output_size_per_partition,)

    # Decided from the un-quantized activations, before any dtype change.
    atomic_add = should_use_atomic_add_reduce(
        m=x_2d.size(0),
        n=output_size_per_partition,
        k=x_2d.size(1),
        device=input.device,
        dtype=input.dtype,
    )

    act_scales = None
    if input_dtype == torch.int8:
        assert quant_type == scalar_types.uint4b8, (
            "W8A8-INT8 is not supported by marlin kernel."
        )
        x_2d, act_scales = marlin_quant_input(x_2d, input_dtype)
        # NOTE(review): this multiply fails if input_global_scale is left at
        # its None default -- presumably callers always pass it for int8;
        # verify against callers.
        act_scales = act_scales * input_global_scale
    elif input_dtype == torch.float8_e4m3fn:
        assert quant_type == scalar_types.uint4b8, (
            "INT8 weight + FP8 activation is not supported."
        )
        x_2d, act_scales = marlin_quant_input(x_2d, input_dtype)

    # Positional operands given as None are ones this path does not supply.
    gemm_out = ops.gptq_marlin_gemm(
        x_2d,
        None,
        weight,
        bias,
        weight_scale,
        act_scales,
        None,
        None,
        None,
        None,
        workspace,
        quant_type,
        size_m=x_2d.shape[0],
        size_n=output_size_per_partition,
        size_k=input_size_per_partition,
        use_atomic_add=atomic_add,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=False,
    )

    return gemm_out.reshape(result_shape)
def merge_scales_zeros(marlin_s: torch.Tensor, marlin_zp: torch.Tensor, def merge_scales_zeros(marlin_s: torch.Tensor, marlin_zp: torch.Tensor,
data_num_0: int, data_num_1: int) -> torch.Tensor: data_num_0: int, data_num_1: int) -> torch.Tensor:
""" """
......
...@@ -174,15 +174,15 @@ def get_rope( ...@@ -174,15 +174,15 @@ def get_rope(
scaling_factor = rope_parameters["factor"] scaling_factor = rope_parameters["factor"]
scaling_alpha = rope_parameters["alpha"] scaling_alpha = rope_parameters["alpha"]
if scaling_alpha: if scaling_alpha:
rotary_emb = DynamicNTKAlphaRotaryEmbedding( rotary_emb = DynamicNTKAlphaRotaryEmbedding(
head_size, head_size,
rotary_dim, rotary_dim,
max_position, max_position,
base, base,
is_neox_style, is_neox_style,
scaling_alpha, scaling_alpha,
dtype, dtype,
) )
else: else:
rotary_emb = DynamicNTKScalingRotaryEmbedding( rotary_emb = DynamicNTKScalingRotaryEmbedding(
head_size, head_size,
......
...@@ -6,8 +6,6 @@ from dataclasses import dataclass ...@@ -6,8 +6,6 @@ from dataclasses import dataclass
from vllm import envs from vllm import envs
import os import os
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn.parameter import Parameter, UninitializedParameter from torch.nn.parameter import Parameter, UninitializedParameter
......
...@@ -37,12 +37,6 @@ def gemm_bank_conf(weight): ...@@ -37,12 +37,6 @@ def gemm_bank_conf(weight):
return True return True
else: else:
return False return False
def set_random_seed(seed: int | None) -> None:
    """Forward ``seed`` to the current platform's ``seed_everything``."""
    # Imported inside the function rather than at module import time.
    from vllm.platforms import current_platform as platform

    platform.seed_everything(seed)
def set_weight_attrs( def set_weight_attrs(
......
...@@ -14,6 +14,7 @@ from vllm.distributed.parallel_state import get_dp_group, is_global_first_rank ...@@ -14,6 +14,7 @@ from vllm.distributed.parallel_state import get_dp_group, is_global_first_rank
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import compute_aligned_M from vllm.model_executor.layers.fused_moe.deep_gemm_utils import compute_aligned_M
from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEModularMethod from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEModularMethod
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts, TritonOrDeepGemmExperts,
) )
...@@ -168,10 +169,9 @@ def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool: ...@@ -168,10 +169,9 @@ def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
# modular kernels could invoke deep_gemm_moe_fp8 # modular kernels could invoke deep_gemm_moe_fp8
return True return True
mk: FusedMoEModularKernel = module.quant_method.fused_experts
# Further check if the ModularKernel implementation uses the DeepGemmExperts # Further check if the ModularKernel implementation uses the DeepGemmExperts
return isinstance( return isinstance(mk.fused_experts, (DeepGemmExperts, TritonOrDeepGemmExperts))
module.quant_method.moe_mk, (DeepGemmExperts, TritonOrDeepGemmExperts)
)
FP8_GEMM_NT_WARMUP_CACHE: set[torch.Size] = set() FP8_GEMM_NT_WARMUP_CACHE: set[torch.Size] = set()
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse import argparse
import gc
import json import json
import os import os
import time import time
...@@ -14,18 +15,65 @@ import ray ...@@ -14,18 +15,65 @@ import ray
import torch import torch
from ray.experimental.tqdm_ray import tqdm from ray.experimental.tqdm_ray import tqdm
from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
RoutingMethodType,
_get_config_dtype_str, _get_config_dtype_str,
) )
from vllm.model_executor.layers.fused_moe.fused_moe import * from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts,
)
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config from vllm.transformers_utils.config import get_config
from vllm.triton_utils import triton from vllm.triton_utils import triton
from vllm.utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import set_random_seed
# Removed the global current_platform import; it is imported locally where needed # Removed the global current_platform import; it is imported locally where needed
# FP8_DTYPE = current_platform.fp8_dtype() # FP8_DTYPE = current_platform.fp8_dtype()
# Number of tuning iterations between Triton JIT cache clears.
# Read once at import time from the environment variable named below
# (default "50"); a value of 0 disables the periodic clearing.
_CACHE_CLEAR_INTERVAL_ENV = "VLLM_MOE_TUNE_CACHE_CLEAR_INTERVAL"
TRITON_CACHE_CLEAR_INTERVAL = int(os.getenv(_CACHE_CLEAR_INTERVAL_ENV, "50"))
def clear_triton_cache():
    """Clear Triton JIT compilation cache and Python/CUDA memory.

    This helps prevent OOM during tuning with large models (many experts).
    Steps, in order: run the garbage collector, empty the CUDA cache when
    CUDA is available, call ``triton.runtime.cache.clear()`` on a
    best-effort basis, then run the garbage collector again.
    """
    # Force Python garbage collection.
    gc.collect()

    # Clear CUDA memory cache.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Best-effort Triton cache clear: a missing attribute anywhere along
    # ``triton.runtime.cache.clear`` (or an ImportError) is skipped silently;
    # any other failure is only reported as a warning.
    try:
        triton.runtime.cache.clear()
    except (ImportError, AttributeError):
        pass
    except Exception as e:
        print(f"Warning: Failed to clear Triton cache: {e}")

    # Additional garbage collection after clearing caches.
    gc.collect()
def ensure_divisibility(numerator, denominator, text): def ensure_divisibility(numerator, denominator, text):
"""Ensure that numerator is divisible by the denominator.""" """Ensure that numerator is divisible by the denominator."""
...@@ -195,10 +243,36 @@ def benchmark_config( ...@@ -195,10 +243,36 @@ def benchmark_config(
block_shape=block_quant_shape, block_shape=block_quant_shape,
) )
deep_gemm_experts = None
if use_deep_gemm:
deep_gemm_experts = mk.FusedMoEModularKernel(
prepare_finalize=MoEPrepareAndFinalizeNoEP(),
fused_experts=TritonOrDeepGemmExperts(
moe_config=FusedMoEConfig(
num_experts=num_experts,
experts_per_token=topk,
hidden_dim=hidden_size,
intermediate_size_per_partition=shard_intermediate_size,
num_local_experts=num_experts,
activation="silu",
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
in_dtype=init_dtype,
routing_method=RoutingMethodType.TopK,
device="cuda",
),
quant_config=quant_config,
),
)
with override_config(config): with override_config(config):
topk_weights, topk_ids, token_expert_indices = fused_topk( topk_weights, topk_ids, token_expert_indices = fused_topk(
x, input_gating, topk, renormalize=not use_deep_gemm x, input_gating, topk, renormalize=not use_deep_gemm
) )
if use_deep_gemm:
return deep_gemm_experts(
x, w1, w2, topk_weights, topk_ids, inplace=True
)
return fused_experts( return fused_experts(
x, x,
w1, w1,
...@@ -207,7 +281,7 @@ def benchmark_config( ...@@ -207,7 +281,7 @@ def benchmark_config(
topk_ids, topk_ids,
inplace=True, inplace=True,
quant_config=quant_config, quant_config=quant_config,
allow_deep_gemm=use_deep_gemm, use_nn_moe=nn_moe,
) )
# JIT compilation & warmup # JIT compilation & warmup
...@@ -227,8 +301,8 @@ def benchmark_config( ...@@ -227,8 +301,8 @@ def benchmark_config(
# run() # run()
torch.cuda.synchronize() torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True) start_event = torch.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True) end_event = torch.Event(enable_timing=True)
latencies: list[float] = [] latencies: list[float] = []
for i in range(num_iters): for i in range(num_iters):
...@@ -252,10 +326,10 @@ def get_rocm_tuning_space(use_fp16, nn_moe: Optional[bool] = False): ...@@ -252,10 +326,10 @@ def get_rocm_tuning_space(use_fp16, nn_moe: Optional[bool] = False):
block_k_range = [32, 64, 128, 256] block_k_range = [32, 64, 128, 256]
if not use_fp16: if not use_fp16:
block_k_range.remove(16) # BLOCK_K=16 not supported for fp8 block_k_range.remove(16) # BLOCK_K=16 not supported for fp8
num_warps_range = [2, 4, 8] num_warps_range = [1, 2, 4, 8]
group_m_range = [1, 16, 32, 64] group_m_range = [1, 4, 8, 16, 32]
num_stage_range = [2, 3, 4, 5] num_stage_range = [2]
# waves_per_eu_range = [0] # waves_per_eu_range = [0, 1, 2, 4]
# matrix_instr_nonkdim_range = [16, 32] if use_fp16 else [] # matrix_instr_nonkdim_range = [16, 32] if use_fp16 else []
# kpack_range = [1, 2] if use_fp16 else [] # kpack_range = [1, 2] if use_fp16 else []
...@@ -453,7 +527,8 @@ class BenchmarkWorker: ...@@ -453,7 +527,8 @@ class BenchmarkWorker:
pass pass
else: else:
torch.set_default_device("cuda:"+ str(device_id)) torch.set_default_device("cuda:"+ str(device_id))
current_platform.seed_everything(seed)
set_random_seed(seed)
self.seed = seed self.seed = seed
# Store the logical device ID for Ray # Store the logical device ID for Ray
self.device_id = device_id self.device_id = device_id
...@@ -474,7 +549,10 @@ class BenchmarkWorker: ...@@ -474,7 +549,10 @@ class BenchmarkWorker:
) -> tuple[dict[str, int], float]: ) -> tuple[dict[str, int], float]:
# Import current_platform locally # Import current_platform locally
from vllm.platforms import current_platform from vllm.platforms import current_platform
current_platform.seed_everything(self.seed)
from vllm.model_executor.layers.fused_moe.fused_moe import get_moe_configs, get_default_config
set_random_seed(self.seed)
dtype_str = _get_config_dtype_str( dtype_str = _get_config_dtype_str(
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8 dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
) )
...@@ -559,7 +637,7 @@ class BenchmarkWorker: ...@@ -559,7 +637,7 @@ class BenchmarkWorker:
need_device_guard = True need_device_guard = True
with torch.cuda.device(self.device_id) if need_device_guard else nullcontext(): with torch.cuda.device(self.device_id) if need_device_guard else nullcontext():
for config in tqdm(search_space): for idx, config in enumerate(tqdm(search_space)):
try: try:
kernel_time = benchmark_config( kernel_time = benchmark_config(
config, config,
...@@ -582,6 +660,19 @@ class BenchmarkWorker: ...@@ -582,6 +660,19 @@ class BenchmarkWorker:
if kernel_time < best_time: if kernel_time < best_time:
best_time = kernel_time best_time = kernel_time
best_config = config best_config = config
# Periodically clear Triton JIT cache to prevent OOM
# This is especially important for large models with many experts
if (
TRITON_CACHE_CLEAR_INTERVAL > 0
and idx > 0
and idx % TRITON_CACHE_CLEAR_INTERVAL == 0
):
clear_triton_cache()
# Final cleanup after tuning completes
clear_triton_cache()
now = datetime.now() now = datetime.now()
print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}") print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}")
assert best_config is not None assert best_config is not None
...@@ -668,19 +759,24 @@ def main(args: argparse.Namespace): ...@@ -668,19 +759,24 @@ def main(args: argparse.Namespace):
E = config.ffn_config.moe_num_experts E = config.ffn_config.moe_num_experts
topk = config.ffn_config.moe_top_k topk = config.ffn_config.moe_top_k
intermediate_size = config.ffn_config.ffn_hidden_size intermediate_size = config.ffn_config.ffn_hidden_size
hidden_size = config.hidden_size
elif config.architectures[0] == "JambaForCausalLM": elif config.architectures[0] == "JambaForCausalLM":
E = config.num_experts E = config.num_experts
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size intermediate_size = config.intermediate_size
hidden_size = config.hidden_size
elif config.architectures[0] in ( elif config.architectures[0] in (
"DeepseekV2ForCausalLM", "DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM", "DeepseekV3ForCausalLM",
"DeepseekV32ForCausalLM", "DeepseekV32ForCausalLM",
"Glm4MoeForCausalLM", "Glm4MoeForCausalLM",
"Glm4MoeLiteForCausalLM",
"NemotronHForCausalLM",
): ):
E = config.n_routed_experts E = config.n_routed_experts
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size intermediate_size = config.moe_intermediate_size
hidden_size = config.hidden_size
elif config.architectures[0] in ( elif config.architectures[0] in (
"Qwen2MoeForCausalLM", "Qwen2MoeForCausalLM",
"Qwen3MoeForCausalLM", "Qwen3MoeForCausalLM",
...@@ -689,14 +785,27 @@ def main(args: argparse.Namespace): ...@@ -689,14 +785,27 @@ def main(args: argparse.Namespace):
E = config.num_experts E = config.num_experts
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size intermediate_size = config.moe_intermediate_size
hidden_size = config.hidden_size
elif config.architectures[0] == "Qwen3VLMoeForConditionalGeneration":
text_config = config.get_text_config()
E = text_config.num_experts
topk = text_config.num_experts_per_tok
intermediate_size = text_config.moe_intermediate_size
hidden_size = text_config.hidden_size
elif config.architectures[0] in ("HunYuanMoEV1ForCausalLM"): elif config.architectures[0] in ("HunYuanMoEV1ForCausalLM"):
E = config.num_experts E = config.num_experts
topk = config.moe_topk[0] topk = config.moe_topk[0]
intermediate_size = config.moe_intermediate_size[0] intermediate_size = config.moe_intermediate_size[0]
hidden_size = config.hidden_size
elif config.architectures[0] in ("Step3VLForConditionalGeneration"): elif config.architectures[0] in ("Step3VLForConditionalGeneration"):
E = config.text_config.moe_num_experts E = config.text_config.moe_num_experts
topk = config.text_config.moe_top_k topk = config.text_config.moe_top_k
intermediate_size = config.text_config.moe_intermediate_size intermediate_size = config.text_config.moe_intermediate_size
elif config.architectures[0] in ["Qwen3OmniMoeForConditionalGeneration"]:
E = config.thinker_config.text_config.num_experts
topk = config.thinker_config.text_config.num_experts_per_tok
intermediate_size = config.thinker_config.text_config.moe_intermediate_size
hidden_size = config.thinker_config.text_config.hidden_size
else: else:
# Support for llama4 # Support for llama4
config = config.get_text_config() config = config.get_text_config()
...@@ -704,16 +813,16 @@ def main(args: argparse.Namespace): ...@@ -704,16 +813,16 @@ def main(args: argparse.Namespace):
E = config.num_local_experts E = config.num_local_experts
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size intermediate_size = config.intermediate_size
hidden_size = config.hidden_size
enable_ep = bool(args.enable_expert_parallel) enable_ep = bool(args.enable_expert_parallel)
if enable_ep: if enable_ep:
ensure_divisibility(E, tp_size, "Number of experts") ensure_divisibility(E, tp_size, "Number of experts")
E = E // tp_size E = E // tp_size
shard_intermediate_size = 2 * intermediate_size shard_intermediate_size = 2 * intermediate_size
else: else:
ensure_divisibility(intermediate_size, tp_size, "intermediate_size") ensure_divisibility(intermediate_size, args.tp_size, "intermediate_size")
shard_intermediate_size = 2 * intermediate_size // tp_size shard_intermediate_size = 2 * intermediate_size // tp_size
hidden_size = config.hidden_size dtype = torch.float16 if current_platform.is_rocm() else config.dtype
dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
use_fp8_w8a8 = args.dtype == "fp8_w8a8" use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16" use_int8_w8a16 = args.dtype == "int8_w8a16"
block_quant_shape = get_weight_block_size_safety(config) block_quant_shape = get_weight_block_size_safety(config)
......
...@@ -152,32 +152,32 @@ def use_rocm_custom_paged_attention( ...@@ -152,32 +152,32 @@ def use_rocm_custom_paged_attention(
# custom paged attn always supported on V0. On V1, requires sliding window # custom paged attn always supported on V0. On V1, requires sliding window
# disabled due to observed numerical discrepancy. # disabled due to observed numerical discrepancy.
if ON_GFX9: # if ON_GFX9:
return ( # return (
(sliding_window == 0 or sliding_window == (-1, -1)) # (sliding_window == 0 or sliding_window == (-1, -1))
and (qtype == torch.half or qtype == torch.bfloat16) # and (qtype == torch.half or qtype == torch.bfloat16)
and (head_size == 64 or head_size == 128) # and (head_size == 64 or head_size == 128)
and (block_size == 16 or block_size == 32) # and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16) # and (gqa_ratio >= 1 and gqa_ratio <= 16)
and max_seq_len <= 128 * 1024 # and max_seq_len <= 128 * 1024
and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) # and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
and sinks is None # and sinks is None
) # )
else: # else:
return ( # return (
ON_GFX11_GFX12 # ON_GFX11_GFX12
and (sliding_window == 0 or sliding_window == (-1, -1)) # and (sliding_window == 0 or sliding_window == (-1, -1))
and (qtype == torch.half or qtype == torch.bfloat16) # and (qtype == torch.half or qtype == torch.bfloat16)
and head_size == 128 # and head_size == 128
and block_size == 16 # and block_size == 16
and (gqa_ratio >= 3 and gqa_ratio <= 16) # and (gqa_ratio >= 3 and gqa_ratio <= 16)
and max_seq_len <= 128 * 1024 # and max_seq_len <= 128 * 1024
and alibi_slopes is None # and alibi_slopes is None
and kv_cache_dtype == "auto" # and kv_cache_dtype == "auto"
and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN # and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN
and sinks is None # and sinks is None
) # )
return False return False
......
...@@ -16,20 +16,10 @@ class FilesystemResolver(LoRAResolver): ...@@ -16,20 +16,10 @@ class FilesystemResolver(LoRAResolver):
self, base_model_name: str, lora_name: str self, base_model_name: str, lora_name: str
) -> LoRARequest | None: ) -> LoRARequest | None:
lora_path = os.path.join(self.lora_cache_dir, lora_name) lora_path = os.path.join(self.lora_cache_dir, lora_name)
maybe_lora_request = await self._get_lora_req_from_path(
lora_name, lora_path, base_model_name
)
return maybe_lora_request
async def _get_lora_req_from_path(
self, lora_name: str, lora_path: str, base_model_name: str
) -> LoRARequest | None:
"""Builds a LoraRequest pointing to the lora path if it's a valid
LoRA adapter and has a matching base_model_name.
"""
if os.path.exists(lora_path): if os.path.exists(lora_path):
adapter_config_path = os.path.join(lora_path, "adapter_config.json") adapter_config_path = os.path.join(
self.lora_cache_dir, lora_name, "adapter_config.json"
)
if os.path.exists(adapter_config_path): if os.path.exists(adapter_config_path):
with open(adapter_config_path) as file: with open(adapter_config_path) as file:
adapter_config = json.load(file) adapter_config = json.load(file)
...@@ -59,4 +49,4 @@ def register_filesystem_resolver(): ...@@ -59,4 +49,4 @@ def register_filesystem_resolver():
fs_resolver = FilesystemResolver(lora_cache_dir) fs_resolver = FilesystemResolver(lora_cache_dir)
LoRAResolverRegistry.register_resolver("Filesystem Resolver", fs_resolver) LoRAResolverRegistry.register_resolver("Filesystem Resolver", fs_resolver)
return return
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import os
from huggingface_hub import HfApi, snapshot_download
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.lora.resolver import LoRAResolverRegistry
from vllm.plugins.lora_resolvers.filesystem_resolver import FilesystemResolver
logger = init_logger(__name__)
class HfHubResolver(FilesystemResolver):
    """LoRA resolver that downloads adapters from an allow-list of HF Hub repos.

    A LoRA name is interpreted as ``<repo>[/<subpath>]`` where ``<repo>`` must
    be one of the entries in ``repo_list``. The matching adapter directory is
    fetched with ``snapshot_download`` and then validated through the
    inherited ``_get_lora_req_from_path`` helper.
    """

    def __init__(self, repo_list: list[str]):
        # NOTE(review): the parent initializer is not invoked here (matching
        # the original); presumably because this resolver has no local cache
        # directory of its own - confirm against FilesystemResolver.
        logger.warning(
            "LoRA is allowing resolution from the following repositories on"
            " HF Hub: %s please note that allowing remote downloads"
            " is not secure, and that this plugin is not intended for use in"
            " production environments.",
            repo_list,
        )
        self.repo_list: list[str] = repo_list
        # repo id -> adapter directories discovered in that repo; filled
        # lazily the first time a repo is resolved.
        self.adapter_dirs: dict[str, set[str]] = {}

    async def resolve_lora(
        self, base_model_name: str, lora_name: str
    ) -> LoRARequest | None:
        """Resolves potential LoRA requests in a remote repo on HF Hub.

        This is effectively the same behavior as the filesystem resolver, but
        with a snapshot_download on dirs containing an adapter config prior
        to inspecting the cached dir to build a potential LoRA request.

        Args:
            base_model_name: Name of the base model, forwarded to the path
                validation helper.
            lora_name: Path to the LoRA in HF Hub, ``<repo>[/<subpath>]``.

        Returns:
            Whatever ``_get_lora_req_from_path`` builds for the downloaded
            directory, or None if the name matches no allowed repo or no
            adapter directory inside it.
        """
        # If a LoRA name begins with the repository name, it's disambiguated
        maybe_repo = await self._resolve_repo(lora_name)
        if maybe_repo is None:
            return None

        # If we haven't inspected this repo before, save available adapter dirs
        if maybe_repo not in self.adapter_dirs:
            self.adapter_dirs[maybe_repo] = await self._get_adapter_dirs(maybe_repo)

        maybe_subpath = await self._resolve_repo_subpath(lora_name, maybe_repo)
        if maybe_subpath is None:
            return None

        # Restrict the download to the adapter directory; "." means the
        # adapter lives at the repository root, so everything is fetched.
        repo_path = await asyncio.to_thread(
            snapshot_download,
            repo_id=maybe_repo,
            allow_patterns=f"{maybe_subpath}/*" if maybe_subpath != "." else "*",
        )

        lora_path = os.path.join(repo_path, maybe_subpath)
        return await self._get_lora_req_from_path(
            lora_name, lora_path, base_model_name
        )

    async def _resolve_repo(self, lora_name: str) -> str | None:
        """Given a fully qualified path to a LoRA with respect to its HF Hub
        repo, match the right repo to potentially download from if one exists.

        Args:
            lora_name: Path to LoRA in HF Hub, e.g., <org>/<repo>/<subpath>,
                match on <org>/<repo> (if it contains an adapter directly) or
                <org>/<repo>/ if it may have one in subdirs.

        Returns:
            The first entry of ``repo_list`` that equals ``lora_name`` or is
            a ``/``-delimited prefix of it, otherwise None.
        """
        for potential_repo in self.repo_list:
            # An empty entry would otherwise match any name starting with "/"
            if not potential_repo:
                continue
            if lora_name == potential_repo or lora_name.startswith(
                potential_repo + "/"
            ):
                return potential_repo
        return None

    async def _resolve_repo_subpath(
        self, lora_name: str, maybe_repo: str | None
    ) -> str | None:
        """Given the fully qualified path of the LoRA with respect to the HF
        Repo, get the subpath to download from assuming it's actually got an
        adapter in it.

        Args:
            lora_name: Path to LoRA in HF Hub, e.g., <org>/<repo>/<subpath>
            maybe_repo: Path to the repo to match against if one exists.

        Returns:
            ``"."`` for the repository root, the relative subdirectory for a
            nested adapter, or None when the target is not a known adapter
            directory of ``maybe_repo``.
        """
        if maybe_repo is None:
            return None

        # Everything after "<repo>/" with trailing slashes removed; an empty
        # remainder (bare repo name or only trailing slashes) is the root.
        adapter_dir = lora_name[len(maybe_repo) + 1 :].rstrip("/") or "."

        # Only download if the directory actually contains an adapter. A repo
        # that has not been inspected yet is treated as having none instead
        # of raising KeyError.
        known_dirs = self.adapter_dirs.get(maybe_repo, ())
        return adapter_dir if adapter_dir in known_dirs else None

    async def _get_adapter_dirs(self, repo_name: str) -> set[str]:
        """Gets the subpaths within a HF repo that contain an adapter config.

        Args:
            repo_name: Name of the HF hub repo to inspect.

        Returns:
            The set of directories holding a file named exactly
            ``adapter_config.json``; the repository root is reported as
            ``"."``.
        """
        repo_files = await asyncio.to_thread(HfApi().list_repo_files, repo_id=repo_name)
        # Compare the basename exactly so that files such as
        # "old_adapter_config.json" are not mistaken for an adapter config,
        # and normalise the root directory ("" from dirname) to ".".
        return {
            os.path.dirname(name) or "."
            for name in repo_files
            if os.path.basename(name) == "adapter_config.json"
        }
def register_hf_hub_resolver():
    """Register the Hf hub LoRA Resolver with vLLM.

    Nothing is registered unless VLLM_LORA_RESOLVER_HF_REPO_LIST is set.
    Because the resolver allows remote downloads, it must additionally be
    enabled by name ("lora_hf_hub_resolver") in VLLM_PLUGINS; otherwise only
    a warning is logged.
    """
    hf_repo_list = envs.VLLM_LORA_RESOLVER_HF_REPO_LIST
    if not hf_repo_list:
        return

    is_enabled = (
        envs.VLLM_PLUGINS is not None and "lora_hf_hub_resolver" in envs.VLLM_PLUGINS
    )
    if not is_enabled:
        logger.warning(
            "It appears that VLLM_LORA_RESOLVER_HF_REPO_LIST is set, but "
            "lora_hf_hub_resolver is not enabled in VLLM_PLUGINS; you must"
            " enable this resolver directly in VLLM_PLUGINS to use it "
            " because it allows remote downloads."
        )
        return

    # Tolerate whitespace around the commas and stray/trailing commas so that
    # values like "org/a, org/b," do not produce " org/b" or "" as repo ids.
    repos = [
        repo for repo in (part.strip() for part in hf_repo_list.split(",")) if repo
    ]
    if not repos:
        return

    hf_hub_resolver = HfHubResolver(repos)
    LoRAResolverRegistry.register_resolver("Hf Hub Resolver", hf_hub_resolver)
from ctypes import *
import os
import time
import threading
class Prof:
    """ctypes wrapper for emitting NVTX / ROCTX profiling range markers.

    NVTX markers are enabled when the VLLM_PROF_NVTX environment variable is
    set, ROCTX markers when VLLM_PROF_ROCTX is set. ROCTX ranges are only
    emitted between StartTracer() and StopTracer(). A per-thread depth
    counter prevents popping more ranges than were pushed.
    """

    def __init__(self):
        self.use_nvtx = os.getenv('VLLM_PROF_NVTX') is not None
        self.roc_tracer_flag = False
        # Most recently loaded library handle; kept for backward
        # compatibility with code that inspects `lib` directly.
        self.lib = None
        # Separate handles so that enabling both backends does not make one
        # library overwrite the other.
        self._nvtx_lib = None
        self._roctx_lib = None
        if self.use_nvtx:
            self._nvtx()
        self.use_roctx = os.getenv('VLLM_PROF_ROCTX') is not None
        if self.use_roctx:
            self._roctx()
        self.tm = time.perf_counter()
        # thread ident -> number of currently open ranges on that thread
        self.push_depth = {}

    def _nvtx(self):
        """Return the NVTX library handle, loading it on first use."""
        if self._nvtx_lib is None:
            lib = cdll.LoadLibrary("libnvToolsExt.so")
            lib.nvtxRangePushA.argtypes = [c_char_p]
            lib.nvtxRangePushA.restype = c_int
            lib.nvtxRangePop.restype = c_int
            self._nvtx_lib = lib
            self.lib = lib
        return self._nvtx_lib

    def _roctx(self):
        """Return the roctracer library handle, loading it on first use."""
        if self._roctx_lib is None:
            lib = cdll.LoadLibrary("libroctracer64.so")
            lib.roctxRangePushA.argtypes = [c_char_p]
            lib.roctxRangePushA.restype = c_int
            lib.roctxRangePop.restype = c_int
            self._roctx_lib = lib
            self.lib = lib
        return self._roctx_lib

    def StartTracer(self):
        """Start roctracer and enable ROCTX range emission (no-op without ROCTX)."""
        if self.use_roctx:
            self._roctx().roctracer_start()
            self.roc_tracer_flag = True

    def StopTracer(self):
        """Stop roctracer and disable ROCTX range emission (no-op without ROCTX)."""
        if self.use_roctx:
            self._roctx().roctracer_stop()
            self.roc_tracer_flag = False

    def thread_depth_add(self, num):
        """Adjust the calling thread's open-range counter by ``num``.

        Returns:
            False (leaving the counter untouched) when ``num`` is negative
            and the counter is already zero; True otherwise.
        """
        thread_id = threading.current_thread().ident
        if thread_id not in self.push_depth:
            self.push_depth[thread_id] = 0
        if num < 0 and self.push_depth[thread_id] == 0:
            return False
        self.push_depth[thread_id] += num
        return True

    def ProfRangePush(self, message):
        """Open a range named ``message`` on every active backend."""
        # Use this instance's own state rather than the module-level
        # singleton so that independent instances behave correctly.
        if self.use_nvtx:
            self._nvtx().nvtxRangePushA(message.encode('utf-8'))
            self.thread_depth_add(1)
        if self.use_roctx and self.roc_tracer_flag:
            self._roctx().roctxRangePushA(message.encode('utf-8'))
            self.thread_depth_add(1)

    def ProfRangePop(self):
        """Close the innermost range on every active backend.

        Returns early without calling into the library when the calling
        thread has no open range.
        """
        if self.use_nvtx:
            if not self.thread_depth_add(-1):
                return
            self._nvtx().nvtxRangePop()
        if self.use_roctx and self.roc_tracer_flag:
            if not self.thread_depth_add(-1):
                return
            self._roctx().roctxRangePop()

    def ProfRangeAutoPush(self, message):
        """Close the current range (if any) and open a new one named ``message``."""
        self.ProfRangePop()
        self.ProfRangePush(message)


# Module-level singleton.
profile = Prof()
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser, ReasoningParserManager from vllm.reasoning.abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
__all__ = [ __all__ = [
......
...@@ -448,7 +448,7 @@ class KimiK2ToolParser(ToolParser): ...@@ -448,7 +448,7 @@ class KimiK2ToolParser(ToolParser):
if current_tool_call_matches: if current_tool_call_matches:
tool_id, tool_args = current_tool_call_matches.groups() tool_id, tool_args = current_tool_call_matches.groups()
tool_name = tool_id.split(":")[0].split(".")[-1] tool_name = tool_id.split(":")[0].split(".")[-1]
current_tool_call["id"] = tool_id.strip() current_tool_call["id"] = tool_id
current_tool_call["name"] = tool_name current_tool_call["name"] = tool_name
current_tool_call["arguments"] = tool_args current_tool_call["arguments"] = tool_args
else: else:
...@@ -458,7 +458,7 @@ class KimiK2ToolParser(ToolParser): ...@@ -458,7 +458,7 @@ class KimiK2ToolParser(ToolParser):
if current_tool_call_name_matches: if current_tool_call_name_matches:
(tool_id_str,) = current_tool_call_name_matches.groups() (tool_id_str,) = current_tool_call_name_matches.groups()
tool_name = tool_id_str.split(":")[0].split(".")[-1] tool_name = tool_id_str.split(":")[0].split(".")[-1]
current_tool_call["id"] = tool_id_str.strip() current_tool_call["id"] = tool_id_str
current_tool_call["name"] = tool_name current_tool_call["name"] = tool_name
current_tool_call["arguments"] = "" current_tool_call["arguments"] = ""
else: else:
...@@ -597,4 +597,4 @@ class KimiK2ToolParser(ToolParser): ...@@ -597,4 +597,4 @@ class KimiK2ToolParser(ToolParser):
except Exception: except Exception:
logger.exception("Error trying to handle streaming tool call.") logger.exception("Error trying to handle streaming tool call.")
return None # do not stream a delta. skip this token ID. return None # do not stream a delta. skip this token ID.
\ No newline at end of file
...@@ -331,7 +331,7 @@ def patch_rope_parameters(config: PretrainedConfig) -> None: ...@@ -331,7 +331,7 @@ def patch_rope_parameters(config: PretrainedConfig) -> None:
partial_rotary_factor = getattr_iter(config, names, None, warn=True) partial_rotary_factor = getattr_iter(config, names, None, warn=True)
ompe = getattr(config, "original_max_position_embeddings", None) ompe = getattr(config, "original_max_position_embeddings", None)
if Version(version("transformers")) < Version("5.0.0"): if Version(version("transformers")) < Version("5.0.0.dev0"):
# Transformers v4 installed, legacy config fields may be present # Transformers v4 installed, legacy config fields may be present
if (rope_scaling := getattr(config, "rope_scaling", None)) is not None: if (rope_scaling := getattr(config, "rope_scaling", None)) is not None:
config.rope_parameters = rope_scaling config.rope_parameters = rope_scaling
......
...@@ -14,7 +14,6 @@ from __future__ import annotations ...@@ -14,7 +14,6 @@ from __future__ import annotations
import importlib import importlib
_CLASS_TO_MODULE: dict[str, str] = { _CLASS_TO_MODULE: dict[str, str] = {
"AfmoeConfig": "vllm.transformers_utils.configs.afmoe", "AfmoeConfig": "vllm.transformers_utils.configs.afmoe",
"BagelConfig": "vllm.transformers_utils.configs.bagel", "BagelConfig": "vllm.transformers_utils.configs.bagel",
......
...@@ -398,7 +398,6 @@ MODEL_ARCH_CONFIG_CONVERTORS = { ...@@ -398,7 +398,6 @@ MODEL_ARCH_CONFIG_CONVERTORS = {
"qwen3_next_mtp": Qwen3NextMTPModelArchConfigConvertor, "qwen3_next_mtp": Qwen3NextMTPModelArchConfigConvertor,
"mimo_mtp": MimoMTPModelArchConfigConvertor, "mimo_mtp": MimoMTPModelArchConfigConvertor,
"glm4_moe_mtp": GLM4MoeMTPModelArchConfigConvertor, "glm4_moe_mtp": GLM4MoeMTPModelArchConfigConvertor,
"glm_ocr_mtp": GLM4MoeMTPModelArchConfigConvertor,
"ernie_mtp": ErnieMTPModelArchConfigConvertor, "ernie_mtp": ErnieMTPModelArchConfigConvertor,
"pangu_ultra_moe_mtp": PanguUltraMoeMTPModelArchConfigConvertor, "pangu_ultra_moe_mtp": PanguUltraMoeMTPModelArchConfigConvertor,
"longcat_flash_mtp": LongCatFlashMTPModelArchConfigConvertor, "longcat_flash_mtp": LongCatFlashMTPModelArchConfigConvertor,
......
...@@ -5,7 +5,6 @@ from typing import Any ...@@ -5,7 +5,6 @@ from typing import Any
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
import torch
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -25,7 +24,6 @@ elif current_platform.is_xpu(): ...@@ -25,7 +24,6 @@ elif current_platform.is_xpu():
elif current_platform.is_rocm(): elif current_platform.is_rocm():
try: try:
from vllm._custom_ops import reshape_and_cache_cuda from vllm._custom_ops import reshape_and_cache_cuda
# from flash_attn import flash_attn_varlen_func # type: ignore[no-redef]
from flash_attn import flash_attn_varlen_func, vllm_flash_attn_varlen_func from flash_attn import flash_attn_varlen_func, vllm_flash_attn_varlen_func
except ImportError: except ImportError:
...@@ -97,7 +95,7 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None: ...@@ -97,7 +95,7 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
def flash_attn_supports_fp8() -> bool: def flash_attn_supports_fp8() -> bool:
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938": if current_platform.is_rocm():
return True return True
return ( return (
get_flash_attn_version() == 3 get_flash_attn_version() == 3
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment