Commit 899a2db4 authored by zhuwenwen's avatar zhuwenwen
Browse files

sync v0.15.1(ex fused_moe&models)

parent 78c1f9e5
......@@ -26,6 +26,7 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend,
convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel,
make_fp8_moe_kernel_for_mkm,
make_fp8_moe_quant_config,
select_fp8_moe_backend,
)
......@@ -34,6 +35,7 @@ from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
convert_to_nvfp4_moe_kernel_format,
is_global_sf_supported_for_nvfp4_backend,
make_nvfp4_moe_kernel,
make_nvfp4_moe_kernel_for_mkm,
make_nvfp4_moe_quant_config,
select_nvfp4_moe_backend,
)
......@@ -52,11 +54,13 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_prepare_finalize,
flashinfer_trtllm_fp4_moe,
flashinfer_trtllm_fp4_routed_moe,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe,
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
......@@ -80,9 +84,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTokenSym,
kNvfp4Dynamic,
kNvfp4Static,
pad_nvfp4_activation_for_cutlass,
pad_nvfp4_weight_for_cutlass,
slice_nvfp4_output,
swizzle_blockscale,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
......@@ -735,23 +736,47 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
activation_key=kFp8StaticTensorSym,
)
# Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.kernel is not None:
return self.kernel.prepare_finalize.topk_indices_dtype()
return None
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
)
# TRT LLM not supported with all2all yet.
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
return None
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
# For no-EP case, don't use the MKM framework.
if not self.moe.moe_parallel_config.use_all2all_kernels:
return None
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
self.moe,
use_deepseek_fp8_block_scale=False,
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
assert self.moe_quant_config is not None
assert self.experts_cls is not None
return make_fp8_moe_kernel_for_mkm(
moe_config=self.moe,
quant_config=self.moe_quant_config,
experts_cls=self.experts_cls,
prepare_finalize=prepare_finalize,
)
def create_weights(
......@@ -835,7 +860,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
def _setup_kernel(
self,
layer: FusedMoE,
layer: torch.nn.Module,
w13: torch.Tensor,
w2: torch.Tensor,
w13_scale: torch.Tensor,
......@@ -865,13 +890,11 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config:
assert self.experts_cls is not None
self.moe_mk, self.use_inplace = make_fp8_moe_kernel(
self.kernel, self.use_inplace = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
experts_cls=self.experts_cls,
routing_tables=layer._maybe_init_expert_routing_tables(),
shared_experts=layer.shared_experts,
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
......@@ -972,8 +995,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
f"but got {layer.activation}"
)
assert self.moe_mk is not None
return self.moe_mk(
assert self.kernel is not None
return self.kernel(
x,
layer.w13_weight,
layer.w2_weight,
......@@ -1257,16 +1280,9 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
layer.weight = Parameter(weight, requires_grad=False)
else:
# Swizzle block scales and pad the packed NVFP4 weights for kernel
# alignment (CUTLASS/FlashInfer require K and N divisible by 32).
swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)
weight, weights_padding_cols = pad_nvfp4_weight_for_cutlass(
layer.weight.data
)
layer.weights_padding_cols = weights_padding_cols
layer.weight = Parameter(weight, requires_grad=False)
layer.weight = Parameter(layer.weight.data, requires_grad=False)
def apply(
self,
......@@ -1288,6 +1304,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
)
output_dtype = x.dtype
output_shape = [x.shape[0], layer.weight.shape[0]]
# quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4, x_blockscale = scaled_fp4_quant(
......@@ -1302,12 +1319,6 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
assert layer.weight_scale.dtype == torch.float8_e4m3fn
assert layer.alpha.dtype == torch.float32
# Pad activations to match weight K-dimension padding
weights_padding_cols = getattr(layer, "weights_padding_cols", 0)
output_size = layer.output_size_per_partition
output_shape = [x.shape[0], output_size]
x_fp4 = pad_nvfp4_activation_for_cutlass(x_fp4, weights_padding_cols)
mm_args = (
x_fp4,
layer.weight,
......@@ -1316,7 +1327,6 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
layer.alpha,
output_dtype,
)
if self.backend.startswith("flashinfer-"):
backend_name = self.backend[len("flashinfer-") :]
out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
......@@ -1324,9 +1334,6 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
assert self.backend == "cutlass"
out = cutlass_scaled_fp4_mm(*mm_args)
# Slice output to remove N-dimension padding
out = slice_nvfp4_output(out, output_size)
if bias is not None:
out = out + bias
return out.view(*output_shape)
......@@ -1353,27 +1360,50 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
activation_key=kNvfp4Dynamic,
)
# Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None
self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
self.nvfp4_backend
)
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.kernel is not None:
return self.kernel.prepare_finalize.topk_indices_dtype()
return None
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
)
if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
return None
elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
# For no-EP case, don't use the MKM framework.
if not self.moe.moe_parallel_config.use_all2all_kernels:
return None
# For now, fp4 moe only works with the flashinfer dispatcher.
prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
self.moe
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
else:
return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
assert self.moe_quant_config is not None
assert self.experts_cls is not None
return make_nvfp4_moe_kernel_for_mkm(
moe_config=self.moe,
quant_config=self.moe_quant_config,
experts_cls=self.experts_cls,
prepare_finalize=prepare_finalize,
)
def uses_weight_scale_2_pattern(self) -> bool:
......@@ -1498,7 +1528,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
)
layer.register_parameter("w2_input_scale", w2_input_scale)
def process_weights_after_loading(self, layer: FusedMoE) -> None:
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
"""
Convert NVFP4 MoE weights into kernel format and setup the kernel.
"""
......@@ -1550,14 +1580,15 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config:
if self.moe_quant_config and (
(not self.moe.moe_parallel_config.use_all2all_kernels)
or self.moe.moe_parallel_config.use_naive_all2all_kernels
):
assert self.experts_cls is not None
self.moe_mk = make_nvfp4_moe_kernel(
self.kernel = make_nvfp4_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
experts_cls=self.experts_cls,
shared_experts=layer.shared_experts,
routing_tables=layer._maybe_init_expert_routing_tables(),
)
@property
......@@ -1658,8 +1689,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
global_num_experts=layer.global_num_experts,
)
else:
assert self.moe_mk is not None
return self.moe_mk(
assert self.kernel is not None
return self.kernel(
x,
layer.w13_weight,
layer.w2_weight,
......@@ -1675,4 +1706,4 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod
ModelOptNvFp4Config.FusedMoEMethodCls = ModelOptNvFp4FusedMoE
ModelOptNvFp4Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod
ModelOptNvFp4Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod
\ No newline at end of file
......@@ -1053,32 +1053,32 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x.shape[:-1], -1)
trtllm_gen_output = trtllm_fp4_block_scale_moe(
routing_logits=router_logits.to(torch.bfloat16),
routing_bias=None,
hidden_states=x_quant,
hidden_states_scale=x_scale,
gemm1_weights=layer.w13_weight, # uint8 (e2m1 x 2)
gemm1_weights_scale=layer.w13_weight_scale, # uint8 (e4m3 x 2)
gemm1_bias=layer.w13_bias, # fp32 per expert per channel
gemm1_alpha=layer.gemm1_alpha, # fp32 per expert
gemm1_beta=layer.gemm1_beta, # fp32 per expert
gemm1_clamp_limit=layer.gemm1_clamp_limit, # fp32 per expert
gemm2_weights=layer.w2_weight, # uint8 (e2m1 x 2)
gemm2_weights_scale=layer.w2_weight_scale, # ue8m0
gemm2_bias=layer.w2_bias, # fp32 per expert per channel
output1_scale_scalar=None,
output1_scale_gate_scalar=None,
output2_scale_scalar=None,
num_experts=layer.global_num_experts,
top_k=layer.top_k,
n_group=None,
topk_group=None,
intermediate_size=self.intermediate_size, # padded to multiple of 256
local_expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=self.num_experts,
routed_scaling_factor=None,
routing_method_type=1 if layer.renormalize else 0,
do_finalize=True,
router_logits.to(torch.bfloat16),
None, # routing_bias
x_quant,
x_scale,
layer.w13_weight, # uint8 (e2m1 x 2)
layer.w13_weight_scale, # uint8 (e4m3 x 2)
layer.w13_bias, # fp32 per expert per channel
layer.gemm1_alpha, # fp32 per expert
layer.gemm1_beta, # fp32 per expert
layer.gemm1_clamp_limit, # fp32 per expert
layer.w2_weight, # uint8 (e2m1 x 2)
layer.w2_weight_scale, # ue8m0
layer.w2_bias, # fp32 per expert per channel
None, # output1_scale_scalar
None, # output1_scale_gate_scalar
None, # output2_scale_scalar
layer.global_num_experts,
layer.top_k,
None, # n_group
None, # topk_group
self.intermediate_size, # padded to multiple of 256
layer.ep_rank * layer.local_num_experts, # local_expert_offset
self.num_experts, # local num experts
None, # routed_scaling_factor
1 if layer.renormalize else 0, # routing_method_type, renormalize
True, # do finalize
tune_max_num_tokens=max(self.max_capture_size, 1),
)[0]
return trtllm_gen_output
......@@ -1170,4 +1170,4 @@ class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
activation="swiglu_oai",
)
hidden_states = hidden_states[..., : self.original_hidden_size].contiguous()
return hidden_states
return hidden_states
\ No newline at end of file
......@@ -14,6 +14,9 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEParallelConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
create_flashinfer_prepare_finalize,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kNvfp4Dynamic,
......@@ -32,6 +35,7 @@ logger = init_logger(__name__)
__all__ = [
"reorder_w1w3_to_w3w1",
"build_flashinfer_fp4_cutlass_moe_prepare_finalize",
]
#
......@@ -132,6 +136,17 @@ def reorder_w1w3_to_w3w1(
)
def build_flashinfer_fp4_cutlass_moe_prepare_finalize(
moe: FusedMoEConfig,
) -> mk.FusedMoEPrepareAndFinalize:
"""Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
use_dp = moe.moe_parallel_config.dp_size > 1
enable_alltoallv = moe.moe_parallel_config.all2all_backend == "flashinfer_all2allv"
return create_flashinfer_prepare_finalize(
use_dp=use_dp, use_nvfp4=True, enable_alltoallv=enable_alltoallv
)
def prepare_static_weights_for_trtllm_fp4_moe(
# args_dequant,
# args,
......@@ -526,4 +541,4 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
w2_scale = swizzle_blockscale(w2_scale)
return w13, w13_scale, w13_scale_2, a13_scale, w2, w2_scale, w2_scale_2, a2_scale
return w13, w13_scale, w13_scale_2, a13_scale, w2, w2_scale, w2_scale_2, a2_scale
\ No newline at end of file
......@@ -4,8 +4,15 @@ from enum import Enum
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
create_flashinfer_prepare_finalize,
)
from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up
......@@ -156,6 +163,18 @@ def make_fp8_moe_alpha_scales_for_fi(
return g1_alphas, g2_alphas
def build_flashinfer_fp8_cutlass_moe_prepare_finalize(
moe: FusedMoEConfig | None, use_deepseek_fp8_block_scale: bool = False
) -> mk.FusedMoEPrepareAndFinalize:
"""Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False
# Propagate block-scale flag so prepare/finalize can skip act quantization
# and inform the kernel to consume per-block weight scales.
return create_flashinfer_prepare_finalize(
use_dp, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale
)
def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
backend_map = {
"throughput": FlashinferMoeBackend.CUTLASS,
......@@ -293,4 +312,4 @@ def prepare_fp8_moe_layer_for_fi(
w2_input_scale=w2_input_scale,
)
return w13, w2, w13_scale
return w13, w2, w13_scale
\ No newline at end of file
......@@ -16,7 +16,6 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import (
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
from .quant_utils import pack_cols, unpack_cols
logger = init_logger(__name__)
......@@ -675,66 +674,6 @@ def apply_awq_marlin_linear(
return output.reshape(out_shape)
def apply_rtn_marlin_linear(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
workspace: torch.Tensor,
quant_type: ScalarType,
output_size_per_partition: int,
input_size_per_partition: int,
input_global_scale: torch.Tensor | None = None,
bias: torch.Tensor | None = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
input_dtype: torch.dtype | None = None,
) -> torch.Tensor:
reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition,)
use_atomic_add = should_use_atomic_add_reduce(
m=reshaped_x.size(0),
n=output_size_per_partition,
k=reshaped_x.size(1),
device=input.device,
dtype=input.dtype,
)
a_scales = None
if input_dtype == torch.int8:
assert quant_type == scalar_types.uint4b8, (
"W8A8-INT8 is not supported by marlin kernel."
)
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
a_scales = a_scales * input_global_scale
elif input_dtype == torch.float8_e4m3fn:
assert quant_type == scalar_types.uint4b8, (
"INT8 weight + FP8 activation is not supported."
)
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
output = ops.gptq_marlin_gemm(
reshaped_x,
None,
weight,
bias,
weight_scale,
a_scales,
None,
None,
None,
None,
workspace,
quant_type,
size_m=reshaped_x.shape[0],
size_n=output_size_per_partition,
size_k=input_size_per_partition,
use_atomic_add=use_atomic_add,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False,
)
return output.reshape(out_shape)
def merge_scales_zeros(marlin_s: torch.Tensor, marlin_zp: torch.Tensor,
data_num_0: int, data_num_1: int) -> torch.Tensor:
"""
......
......@@ -174,15 +174,15 @@ def get_rope(
scaling_factor = rope_parameters["factor"]
scaling_alpha = rope_parameters["alpha"]
if scaling_alpha:
rotary_emb = DynamicNTKAlphaRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_alpha,
dtype,
)
rotary_emb = DynamicNTKAlphaRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_alpha,
dtype,
)
else:
rotary_emb = DynamicNTKScalingRotaryEmbedding(
head_size,
......
......@@ -6,8 +6,6 @@ from dataclasses import dataclass
from vllm import envs
import os
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter, UninitializedParameter
......
......@@ -37,12 +37,6 @@ def gemm_bank_conf(weight):
return True
else:
return False
def set_random_seed(seed: int | None) -> None:
from vllm.platforms import current_platform
current_platform.seed_everything(seed)
def set_weight_attrs(
......
......@@ -14,6 +14,7 @@ from vllm.distributed.parallel_state import get_dp_group, is_global_first_rank
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import compute_aligned_M
from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEModularMethod
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts,
)
......@@ -168,10 +169,9 @@ def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
# modular kernels could invoke deep_gemm_moe_fp8
return True
mk: FusedMoEModularKernel = module.quant_method.fused_experts
# Further check if the ModularKernel implementation uses the DeepGemmExperts
return isinstance(
module.quant_method.moe_mk, (DeepGemmExperts, TritonOrDeepGemmExperts)
)
return isinstance(mk.fused_experts, (DeepGemmExperts, TritonOrDeepGemmExperts))
FP8_GEMM_NT_WARMUP_CACHE: set[torch.Size] = set()
......
......@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import gc
import json
import os
import time
......@@ -14,18 +15,65 @@ import ray
import torch
from ray.experimental.tqdm_ray import tqdm
from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
RoutingMethodType,
_get_config_dtype_str,
)
from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts,
)
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config
from vllm.triton_utils import triton
from vllm.utils import FlexibleArgumentParser
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import set_random_seed
# 移除全局的 current_platform 导入,改为在需要时局部导入
# FP8_DTYPE = current_platform.fp8_dtype()
# Default interval for clearing Triton JIT cache during tuning
# Set to 0 to disable automatic cache clearing
_CACHE_CLEAR_INTERVAL_ENV = "VLLM_MOE_TUNE_CACHE_CLEAR_INTERVAL"
TRITON_CACHE_CLEAR_INTERVAL = int(os.environ.get(_CACHE_CLEAR_INTERVAL_ENV, "50"))
def clear_triton_cache():
"""Clear Triton JIT compilation cache and Python/CUDA memory.
This helps prevent OOM during tuning with large models (many experts).
"""
# Force Python garbage collection
gc.collect()
# Clear CUDA memory cache
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Try to clear Triton's runtime cache
try:
if (
hasattr(triton, "runtime")
and hasattr(triton.runtime, "cache")
and hasattr(triton.runtime.cache, "clear")
):
triton.runtime.cache.clear()
except ImportError:
# Triton not installed, skip cache clearing
pass
except AttributeError:
# Triton version doesn't have expected cache API
pass
except Exception as e:
print(f"Warning: Failed to clear Triton cache: {e}")
# Additional garbage collection after clearing caches
gc.collect()
def ensure_divisibility(numerator, denominator, text):
"""Ensure that numerator is divisible by the denominator."""
......@@ -195,10 +243,36 @@ def benchmark_config(
block_shape=block_quant_shape,
)
deep_gemm_experts = None
if use_deep_gemm:
deep_gemm_experts = mk.FusedMoEModularKernel(
prepare_finalize=MoEPrepareAndFinalizeNoEP(),
fused_experts=TritonOrDeepGemmExperts(
moe_config=FusedMoEConfig(
num_experts=num_experts,
experts_per_token=topk,
hidden_dim=hidden_size,
intermediate_size_per_partition=shard_intermediate_size,
num_local_experts=num_experts,
activation="silu",
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
in_dtype=init_dtype,
routing_method=RoutingMethodType.TopK,
device="cuda",
),
quant_config=quant_config,
),
)
with override_config(config):
topk_weights, topk_ids, token_expert_indices = fused_topk(
x, input_gating, topk, renormalize=not use_deep_gemm
)
if use_deep_gemm:
return deep_gemm_experts(
x, w1, w2, topk_weights, topk_ids, inplace=True
)
return fused_experts(
x,
w1,
......@@ -207,7 +281,7 @@ def benchmark_config(
topk_ids,
inplace=True,
quant_config=quant_config,
allow_deep_gemm=use_deep_gemm,
use_nn_moe=nn_moe,
)
# JIT compilation & warmup
......@@ -227,8 +301,8 @@ def benchmark_config(
# run()
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event = torch.Event(enable_timing=True)
end_event = torch.Event(enable_timing=True)
latencies: list[float] = []
for i in range(num_iters):
......@@ -252,10 +326,10 @@ def get_rocm_tuning_space(use_fp16, nn_moe: Optional[bool] = False):
block_k_range = [32, 64, 128, 256]
if not use_fp16:
block_k_range.remove(16) # BLOCK_K=16 not supported for fp8
num_warps_range = [2, 4, 8]
group_m_range = [1, 16, 32, 64]
num_stage_range = [2, 3, 4, 5]
# waves_per_eu_range = [0]
num_warps_range = [1, 2, 4, 8]
group_m_range = [1, 4, 8, 16, 32]
num_stage_range = [2]
# waves_per_eu_range = [0, 1, 2, 4]
# matrix_instr_nonkdim_range = [16, 32] if use_fp16 else []
# kpack_range = [1, 2] if use_fp16 else []
......@@ -453,7 +527,8 @@ class BenchmarkWorker:
pass
else:
torch.set_default_device("cuda:"+ str(device_id))
current_platform.seed_everything(seed)
set_random_seed(seed)
self.seed = seed
# Store the logical device ID for Ray
self.device_id = device_id
......@@ -474,7 +549,10 @@ class BenchmarkWorker:
) -> tuple[dict[str, int], float]:
# 局部导入 current_platform
from vllm.platforms import current_platform
current_platform.seed_everything(self.seed)
from vllm.model_executor.layers.fused_moe.fused_moe import get_moe_configs, get_default_config
set_random_seed(self.seed)
dtype_str = _get_config_dtype_str(
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
)
......@@ -559,7 +637,7 @@ class BenchmarkWorker:
need_device_guard = True
with torch.cuda.device(self.device_id) if need_device_guard else nullcontext():
for config in tqdm(search_space):
for idx, config in enumerate(tqdm(search_space)):
try:
kernel_time = benchmark_config(
config,
......@@ -582,6 +660,19 @@ class BenchmarkWorker:
if kernel_time < best_time:
best_time = kernel_time
best_config = config
# Periodically clear Triton JIT cache to prevent OOM
# This is especially important for large models with many experts
if (
TRITON_CACHE_CLEAR_INTERVAL > 0
and idx > 0
and idx % TRITON_CACHE_CLEAR_INTERVAL == 0
):
clear_triton_cache()
# Final cleanup after tuning completes
clear_triton_cache()
now = datetime.now()
print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}")
assert best_config is not None
......@@ -668,19 +759,24 @@ def main(args: argparse.Namespace):
E = config.ffn_config.moe_num_experts
topk = config.ffn_config.moe_top_k
intermediate_size = config.ffn_config.ffn_hidden_size
hidden_size = config.hidden_size
elif config.architectures[0] == "JambaForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
hidden_size = config.hidden_size
elif config.architectures[0] in (
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
"DeepseekV32ForCausalLM",
"Glm4MoeForCausalLM",
"Glm4MoeLiteForCausalLM",
"NemotronHForCausalLM",
):
E = config.n_routed_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
hidden_size = config.hidden_size
elif config.architectures[0] in (
"Qwen2MoeForCausalLM",
"Qwen3MoeForCausalLM",
......@@ -689,14 +785,27 @@ def main(args: argparse.Namespace):
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
hidden_size = config.hidden_size
elif config.architectures[0] == "Qwen3VLMoeForConditionalGeneration":
text_config = config.get_text_config()
E = text_config.num_experts
topk = text_config.num_experts_per_tok
intermediate_size = text_config.moe_intermediate_size
hidden_size = text_config.hidden_size
elif config.architectures[0] in ("HunYuanMoEV1ForCausalLM"):
E = config.num_experts
topk = config.moe_topk[0]
intermediate_size = config.moe_intermediate_size[0]
hidden_size = config.hidden_size
elif config.architectures[0] in ("Step3VLForConditionalGeneration"):
E = config.text_config.moe_num_experts
topk = config.text_config.moe_top_k
intermediate_size = config.text_config.moe_intermediate_size
elif config.architectures[0] in ["Qwen3OmniMoeForConditionalGeneration"]:
E = config.thinker_config.text_config.num_experts
topk = config.thinker_config.text_config.num_experts_per_tok
intermediate_size = config.thinker_config.text_config.moe_intermediate_size
hidden_size = config.thinker_config.text_config.hidden_size
else:
# Support for llama4
config = config.get_text_config()
......@@ -704,16 +813,16 @@ def main(args: argparse.Namespace):
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
hidden_size = config.hidden_size
enable_ep = bool(args.enable_expert_parallel)
if enable_ep:
ensure_divisibility(E, tp_size, "Number of experts")
E = E // tp_size
shard_intermediate_size = 2 * intermediate_size
else:
ensure_divisibility(intermediate_size, tp_size, "intermediate_size")
ensure_divisibility(intermediate_size, args.tp_size, "intermediate_size")
shard_intermediate_size = 2 * intermediate_size // tp_size
hidden_size = config.hidden_size
dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
dtype = torch.float16 if current_platform.is_rocm() else config.dtype
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16"
block_quant_shape = get_weight_block_size_safety(config)
......
......@@ -152,32 +152,32 @@ def use_rocm_custom_paged_attention(
# custom paged attn always supported on V0. On V1, requires sliding window
# disabled due to observed numerical discrepancy.
if ON_GFX9:
return (
(sliding_window == 0 or sliding_window == (-1, -1))
and (qtype == torch.half or qtype == torch.bfloat16)
and (head_size == 64 or head_size == 128)
and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16)
and max_seq_len <= 128 * 1024
and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
and sinks is None
)
else:
return (
ON_GFX11_GFX12
and (sliding_window == 0 or sliding_window == (-1, -1))
and (qtype == torch.half or qtype == torch.bfloat16)
and head_size == 128
and block_size == 16
and (gqa_ratio >= 3 and gqa_ratio <= 16)
and max_seq_len <= 128 * 1024
and alibi_slopes is None
and kv_cache_dtype == "auto"
and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN
and sinks is None
)
# if ON_GFX9:
# return (
# (sliding_window == 0 or sliding_window == (-1, -1))
# and (qtype == torch.half or qtype == torch.bfloat16)
# and (head_size == 64 or head_size == 128)
# and (block_size == 16 or block_size == 32)
# and (gqa_ratio >= 1 and gqa_ratio <= 16)
# and max_seq_len <= 128 * 1024
# and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
# and sinks is None
# )
# else:
# return (
# ON_GFX11_GFX12
# and (sliding_window == 0 or sliding_window == (-1, -1))
# and (qtype == torch.half or qtype == torch.bfloat16)
# and head_size == 128
# and block_size == 16
# and (gqa_ratio >= 3 and gqa_ratio <= 16)
# and max_seq_len <= 128 * 1024
# and alibi_slopes is None
# and kv_cache_dtype == "auto"
# and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN
# and sinks is None
# )
return False
......
......@@ -16,20 +16,10 @@ class FilesystemResolver(LoRAResolver):
self, base_model_name: str, lora_name: str
) -> LoRARequest | None:
lora_path = os.path.join(self.lora_cache_dir, lora_name)
maybe_lora_request = await self._get_lora_req_from_path(
lora_name, lora_path, base_model_name
)
return maybe_lora_request
async def _get_lora_req_from_path(
self, lora_name: str, lora_path: str, base_model_name: str
) -> LoRARequest | None:
"""Builds a LoraRequest pointing to the lora path if it's a valid
LoRA adapter and has a matching base_model_name.
"""
if os.path.exists(lora_path):
adapter_config_path = os.path.join(lora_path, "adapter_config.json")
adapter_config_path = os.path.join(
self.lora_cache_dir, lora_name, "adapter_config.json"
)
if os.path.exists(adapter_config_path):
with open(adapter_config_path) as file:
adapter_config = json.load(file)
......@@ -59,4 +49,4 @@ def register_filesystem_resolver():
fs_resolver = FilesystemResolver(lora_cache_dir)
LoRAResolverRegistry.register_resolver("Filesystem Resolver", fs_resolver)
return
return
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import os
from huggingface_hub import HfApi, snapshot_download
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.lora.resolver import LoRAResolverRegistry
from vllm.plugins.lora_resolvers.filesystem_resolver import FilesystemResolver
logger = init_logger(__name__)
class HfHubResolver(FilesystemResolver):
def __init__(self, repo_list: list[str]):
logger.warning(
"LoRA is allowing resolution from the following repositories on"
" HF Hub: %s please note that allowing remote downloads"
" is not secure, and that this plugin is not intended for use in"
" production environments.",
repo_list,
)
self.repo_list: list[str] = repo_list
self.adapter_dirs: dict[str, set[str]] = {}
async def resolve_lora(
self, base_model_name: str, lora_name: str
) -> LoRARequest | None:
"""Resolves potential LoRA requests in a remote repo on HF Hub.
This is effectively the same behavior as the filesystem resolver, but
with a snapshot_download on dirs containing an adapter config prior
to inspecting the cached dir to build a potential LoRA
request.
"""
# If a LoRA name begins with the repository name, it's disambiguated
maybe_repo = await self._resolve_repo(lora_name)
# If we haven't inspected this repo before, save available adapter dirs
if maybe_repo is not None and maybe_repo not in self.adapter_dirs:
self.adapter_dirs[maybe_repo] = await self._get_adapter_dirs(maybe_repo)
maybe_subpath = await self._resolve_repo_subpath(lora_name, maybe_repo)
if maybe_repo is None or maybe_subpath is None:
return None
repo_path = await asyncio.to_thread(
snapshot_download,
repo_id=maybe_repo,
allow_patterns=f"{maybe_subpath}/*" if maybe_subpath != "." else "*",
)
lora_path = os.path.join(repo_path, maybe_subpath)
maybe_lora_request = await self._get_lora_req_from_path(
lora_name, lora_path, base_model_name
)
return maybe_lora_request
async def _resolve_repo(self, lora_name: str) -> str | None:
"""Given a fully qualified path to a LoRA with respect to its HF Hub
repo, match the right repo to potentially download from if one exists.
Args:
lora_name: Path to LoRA in HF Hub, e.g., <org>/<repo>/<subpath>,
match on <org>/<repo> (if it contains an adapter directly) or
<org>/<repo>/ if it may have one in subdirs.
"""
for potential_repo in self.repo_list:
if lora_name.startswith(potential_repo) and (
len(lora_name) == len(potential_repo)
or lora_name[len(potential_repo)] == "/"
):
return potential_repo
return None
async def _resolve_repo_subpath(
self, lora_name: str, maybe_repo: str | None
) -> str | None:
"""Given the fully qualified path of the LoRA with respect to the HF
Repo, get the subpath to download from assuming it's actually got an
adapter in it.
Args:
lora_name: Path to LoRA in HF Hub, e.g., <org>/<repo>/<subpath>
maybe_repo: Path to the repo to match against if one exists.
"""
if maybe_repo is None:
return None
repo_len = len(maybe_repo)
if lora_name == maybe_repo or (
len(lora_name) == repo_len + 1 and lora_name[-1] == "/"
):
# Resolves to the root of the directory
adapter_dir = "."
else:
# It's a subpath; removing trailing slashes if there are any
adapter_dir = lora_name[repo_len + 1 :].rstrip("/")
# Only download if the directory actually contains an adapter
is_adapter = adapter_dir in self.adapter_dirs[maybe_repo]
return adapter_dir if is_adapter else None
async def _get_adapter_dirs(self, repo_name: str) -> set[str]:
"""Gets the subpaths within a HF repo that contain an adapter config.
Args:
repo_name: Name of the HF hub repo to inspect.
"""
repo_files = await asyncio.to_thread(HfApi().list_repo_files, repo_id=repo_name)
adapter_dirs = {
os.path.dirname(name)
for name in repo_files
if name.endswith("adapter_config.json")
}
if "adapter_config.json" in repo_files:
adapter_dirs.add(".")
return adapter_dirs
def register_hf_hub_resolver():
"""Register the Hf hub LoRA Resolver with vLLM"""
hf_repo_list = envs.VLLM_LORA_RESOLVER_HF_REPO_LIST
is_enabled = (
envs.VLLM_PLUGINS is not None and "lora_hf_hub_resolver" in envs.VLLM_PLUGINS
)
if hf_repo_list:
if not is_enabled:
logger.warning(
"It appears that VLLM_LORA_RESOLVER_HF_REPO_LIST is set, but "
"lora_hf_hub_resolver is not enabled in VLLM_PLUGINS; you must"
" enable this resolver directly in VLLM_PLUGINS to use it "
" because it allows remote downloads."
)
else:
hf_hub_resolver = HfHubResolver(hf_repo_list.split(","))
LoRAResolverRegistry.register_resolver("Hf Hub Resolver", hf_hub_resolver)
return
from ctypes import *
import os
import time
import threading
class Prof:
def __init__(self):
self.use_nvtx = os.getenv('VLLM_PROF_NVTX') is not None
self.roc_tracer_flag = False
self.lib = None
if self.use_nvtx:
self.lib = cdll.LoadLibrary("libnvToolsExt.so")
self.lib.nvtxRangePushA.argtypes = [c_char_p]
self.lib.nvtxRangePushA.restype = c_int
self.lib.nvtxRangePop.restype = c_int
self.use_roctx = os.getenv('VLLM_PROF_ROCTX') is not None
if self.use_roctx:
self.lib = cdll.LoadLibrary("libroctracer64.so")
self.lib.roctxRangePushA.argtypes = [c_char_p]
self.lib.roctxRangePushA.restype = c_int
self.lib.roctxRangePop.restype = c_int
self.tm = time.perf_counter()
self.push_depth = {}
def StartTracer(self):
if self.use_roctx:
if self.lib is None:
self.lib = cdll.LoadLibrary("libroctracer64.so")
self.lib.roctracer_start()
self.roc_tracer_flag = True
def StopTracer(self):
if self.use_roctx:
if self.lib is None:
self.lib = cdll.LoadLibrary("libroctracer64.so")
self.lib.roctracer_stop()
self.roc_tracer_flag = False
def thread_depth_add(self, num):
current_thread = threading.current_thread()
thread_id = current_thread.ident
if thread_id not in self.push_depth.keys():
self.push_depth[thread_id] = 0
if num < 0 and self.push_depth[thread_id] == 0:
return False
self.push_depth[thread_id] += num
return True
def ProfRangePush(self, message):
if profile.use_nvtx:
profile.lib.nvtxRangePushA(message.encode('utf-8'))
self.thread_depth_add(1)
if profile.use_roctx and self.roc_tracer_flag:
profile.lib.roctxRangePushA(message.encode('utf-8'))
self.thread_depth_add(1)
def ProfRangePop(self):
if profile.use_nvtx:
if not self.thread_depth_add(-1):
return
profile.lib.nvtxRangePop()
if profile.use_roctx and self.roc_tracer_flag:
if not self.thread_depth_add(-1):
return
profile.lib.roctxRangePop()
def ProfRangeAutoPush(self, message):
self.ProfRangePop()
self.ProfRangePush(message)
profile = Prof()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
__all__ = [
......
......@@ -448,7 +448,7 @@ class KimiK2ToolParser(ToolParser):
if current_tool_call_matches:
tool_id, tool_args = current_tool_call_matches.groups()
tool_name = tool_id.split(":")[0].split(".")[-1]
current_tool_call["id"] = tool_id.strip()
current_tool_call["id"] = tool_id
current_tool_call["name"] = tool_name
current_tool_call["arguments"] = tool_args
else:
......@@ -458,7 +458,7 @@ class KimiK2ToolParser(ToolParser):
if current_tool_call_name_matches:
(tool_id_str,) = current_tool_call_name_matches.groups()
tool_name = tool_id_str.split(":")[0].split(".")[-1]
current_tool_call["id"] = tool_id_str.strip()
current_tool_call["id"] = tool_id_str
current_tool_call["name"] = tool_name
current_tool_call["arguments"] = ""
else:
......@@ -597,4 +597,4 @@ class KimiK2ToolParser(ToolParser):
except Exception:
logger.exception("Error trying to handle streaming tool call.")
return None # do not stream a delta. skip this token ID.
return None # do not stream a delta. skip this token ID.
\ No newline at end of file
......@@ -331,7 +331,7 @@ def patch_rope_parameters(config: PretrainedConfig) -> None:
partial_rotary_factor = getattr_iter(config, names, None, warn=True)
ompe = getattr(config, "original_max_position_embeddings", None)
if Version(version("transformers")) < Version("5.0.0"):
if Version(version("transformers")) < Version("5.0.0.dev0"):
# Transformers v4 installed, legacy config fields may be present
if (rope_scaling := getattr(config, "rope_scaling", None)) is not None:
config.rope_parameters = rope_scaling
......
......@@ -14,7 +14,6 @@ from __future__ import annotations
import importlib
_CLASS_TO_MODULE: dict[str, str] = {
"AfmoeConfig": "vllm.transformers_utils.configs.afmoe",
"BagelConfig": "vllm.transformers_utils.configs.bagel",
......
......@@ -398,7 +398,6 @@ MODEL_ARCH_CONFIG_CONVERTORS = {
"qwen3_next_mtp": Qwen3NextMTPModelArchConfigConvertor,
"mimo_mtp": MimoMTPModelArchConfigConvertor,
"glm4_moe_mtp": GLM4MoeMTPModelArchConfigConvertor,
"glm_ocr_mtp": GLM4MoeMTPModelArchConfigConvertor,
"ernie_mtp": ErnieMTPModelArchConfigConvertor,
"pangu_ultra_moe_mtp": PanguUltraMoeMTPModelArchConfigConvertor,
"longcat_flash_mtp": LongCatFlashMTPModelArchConfigConvertor,
......
......@@ -5,7 +5,6 @@ from typing import Any
from vllm.logger import init_logger
from vllm.platforms import current_platform
import torch
logger = init_logger(__name__)
......@@ -25,7 +24,6 @@ elif current_platform.is_xpu():
elif current_platform.is_rocm():
try:
from vllm._custom_ops import reshape_and_cache_cuda
# from flash_attn import flash_attn_varlen_func # type: ignore[no-redef]
from flash_attn import flash_attn_varlen_func, vllm_flash_attn_varlen_func
except ImportError:
......@@ -97,7 +95,7 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
def flash_attn_supports_fp8() -> bool:
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
if current_platform.is_rocm():
return True
return (
get_flash_attn_version() == 3
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment