"tests/vscode:/vscode.git/clone" did not exist on "fecae12cd7deb969dcbba37fda9d2d234697a944"
Unverified Commit 845adb3e authored by XuruiYang's avatar XuruiYang Committed by GitHub
Browse files

[Model] Add LongCat-Flash (#23991)


Signed-off-by: default avataryangxurui <yangxurui@meituan.com>
Co-authored-by: default avataryangxurui <yangxurui@meituan.com>
parent 90b139cf
......@@ -44,6 +44,9 @@ __global__ void moe_align_block_size_kernel(
for (size_t i = tid; i < numel; i += stride) {
int expert_id = topk_ids[i];
if (expert_id >= num_experts) {
continue;
}
int warp_idx = expert_id / experts_per_warp;
int expert_offset = expert_id % experts_per_warp;
atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset], 1);
......@@ -95,12 +98,15 @@ template <typename scalar_t>
__global__ void count_and_sort_expert_tokens_kernel(
const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer,
size_t numel) {
size_t numel, int32_t num_experts) {
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
for (size_t i = tid; i < numel; i += stride) {
int32_t expert_id = topk_ids[i];
if (expert_id >= num_experts) {
continue;
}
int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
sorted_token_ids[rank_post_pad] = i;
}
......@@ -269,7 +275,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(
topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(),
cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel());
cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel(), num_experts);
}
});
}
......
......@@ -428,6 +428,7 @@ th {
| `MiniMaxM1ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-M1-40k`, `MiniMaxAI/MiniMax-M1-80k`, etc. | | | ✅︎ |
| `MiniMaxText01ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-Text-01`, etc. | | | ✅︎ |
| `Zamba2ForCausalLM` | Zamba2 | `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc. | | | ✅︎ |
| `LongcatFlashForCausalLM` | LongCat-Flash | `meituan-longcat/LongCat-Flash-Chat`, `meituan-longcat/LongCat-Flash-Chat-FP8` | ✅︎ |✅︎ | ✅︎ |
Some models are supported only via the [Transformers backend](#transformers). The purpose of the table below is to acknowledge models which we officially support in this way. The logs will say that the Transformers backend is being used, and you will see no warning that this is fallback behaviour. This means that, if you have issues with any of the models listed below, please [make an issue](https://github.com/vllm-project/vllm/issues/new/choose) and we'll do our best to fix it!
......
......@@ -138,7 +138,7 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=True)
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=td.hidden_states,
router_logits=score,
use_grouped_topk=False,
......@@ -206,7 +206,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=False)
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=td.hidden_states,
router_logits=score,
use_grouped_topk=False,
......
......@@ -273,6 +273,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
is_available_online=False),
"Llama4ForCausalLM": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501
is_available_online=False),
"LongcatFlashForCausalLM": _HfExamplesInfo
("meituan-longcat/LongCat-Flash-Chat", trust_remote_code=True),
"MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"),
"Mamba2ForCausalLM": _HfExamplesInfo("mistralai/Mamba-Codestral-7B-v0.1",
min_transformers_version="4.55.3",
......@@ -639,6 +641,10 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
speculative_model="zai-org/GLM-4.5",
min_transformers_version="4.54",
is_available_online=False),
"LongCatFlashMTPModel": _HfExamplesInfo(
"meituan-longcat/LongCat-Flash-Chat",
trust_remote_code=True,
speculative_model="meituan-longcat/LongCat-Flash-Chat"),
"MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL",
trust_remote_code=True,
speculative_model="XiaomiMiMo/MiMo-7B-RL"),
......
......@@ -428,9 +428,8 @@ def dummy_hf_overrides(
num_hidden_layers = (3 if model_arch
== "Gemma3nForConditionalGeneration" else 1)
text_config.update({
update_dict = {
"num_layers": num_layers,
"num_hidden_layers": num_hidden_layers,
"num_experts": num_experts,
"num_experts_per_tok": 2,
"num_local_experts": num_experts,
......@@ -440,7 +439,14 @@ def dummy_hf_overrides(
"n_routed_experts": num_experts,
# For Gemma-3n
"num_kv_shared_layers": 1,
})
}
# Update num_hidden_layers for non-Longcat architectures
if model_arch != "LongcatFlashForCausalLM" \
and model_arch != "LongCatFlashMTPModel":
update_dict["num_hidden_layers"] = num_hidden_layers
text_config.update(update_dict)
if hasattr(hf_config, "vision_config"):
hf_config.vision_config.update({
......
......@@ -96,7 +96,7 @@ def test_routing_strategy_integration(monkeypatch, device):
envs.environment_variables[env_name] = lambda s=strategy: s
# Test the select_experts method
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=top_k,
......
......@@ -1131,7 +1131,8 @@ class ModelConfig:
if not hasattr(self.hf_text_config, "model_type"):
return False
elif self.hf_text_config.model_type in \
('deepseek_v2', 'deepseek_v3', 'deepseek_mtp', 'kimi_k2'):
('deepseek_v2', 'deepseek_v3', 'deepseek_mtp',
'kimi_k2', 'longcat_flash'):
return self.hf_text_config.kv_lora_rank is not None
elif self.hf_text_config.model_type == 'eagle':
# if the model is an EAGLE module, check for the
......@@ -1257,6 +1258,9 @@ class ModelConfig:
or self.hf_config.model_type == "qwen3_next_mtp"):
total_num_hidden_layers = getattr(self.hf_text_config,
"num_nextn_predict_layers", 0)
elif (self.hf_config.model_type == "longcat_flash_mtp"):
total_num_hidden_layers = getattr(self.hf_text_config,
"num_nextn_predict_layers", 1)
else:
total_num_hidden_layers = getattr(self.hf_text_config,
"num_hidden_layers", 0)
......
......@@ -31,7 +31,8 @@ logger = init_logger(__name__)
SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
"mlp_speculator", "draft_model", "deepseek_mtp",
"ernie_mtp", "qwen3_next_mtp", "mimo_mtp"]
"ernie_mtp", "qwen3_next_mtp", "mimo_mtp",
"longcat_flash_mtp"]
@config
......@@ -186,6 +187,13 @@ class SpeculativeConfig:
"n_predict": n_predict,
"architectures": ["Qwen3NextMTP"]
})
if hf_config.model_type == "longcat_flash":
hf_config.model_type = "longcat_flash_mtp"
n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
hf_config.update({
"n_predict": n_predict,
"architectures": ["LongCatFlashMTPModel"]
})
return hf_config
......@@ -332,6 +340,15 @@ class SpeculativeConfig:
"one layer. Might need some code changes " \
"to support multiple layers."
)
elif (self.draft_model_config.hf_config.model_type
in ("longcat_flash_mtp")):
self.method = "longcat_flash_mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"LongCat MTP models only have " \
"one layer. Might need some code changes " \
"to support multiple layers."
)
else:
self.method = "draft_model"
raise NotImplementedError(
......@@ -548,7 +565,7 @@ class SpeculativeConfig:
def use_eagle(self) -> bool:
return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp",
"qwen3_next_mtp")
"qwen3_next_mtp", "longcat_flash_mtp")
def __repr__(self) -> str:
method = self.method
......
......@@ -664,6 +664,76 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
)
@triton.jit
def compute_identity_kernel(
top_k: int,
hidden_states_ptr: tl.tensor,
expert_scales_ptr: tl.tensor,
num_tokens: int,
output_ptr: tl.tensor,
hidden_dim: int,
scales_stride: int,
BLOCK_SIZE: tl.constexpr,
) -> None:
pid = tl.program_id(0)
batch_id = pid // (hidden_dim // BLOCK_SIZE)
dim_offset = pid % (hidden_dim // BLOCK_SIZE) * BLOCK_SIZE
if batch_id >= num_tokens or dim_offset >= hidden_dim:
return
h = tl.load(hidden_states_ptr + batch_id * hidden_dim + dim_offset +
tl.arange(0, BLOCK_SIZE),
mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim)
result = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for i in range(top_k):
scale = tl.load(expert_scales_ptr + batch_id * scales_stride + i)
result += h * scale
tl.store(output_ptr + batch_id * hidden_dim + dim_offset +
tl.arange(0, BLOCK_SIZE),
result,
mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim)
def zero_experts_compute_triton(expert_indices: torch.Tensor,
expert_scales: torch.Tensor, num_experts: int,
zero_expert_type: str,
hidden_states: torch.Tensor) -> torch.Tensor:
N = expert_indices.numel()
top_k = expert_indices.size(-1)
grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']), )
if zero_expert_type == "identity":
zero_expert_mask = expert_indices < num_experts
zero_expert_scales = expert_scales.clone()
zero_expert_scales[zero_expert_mask] = 0.0
normal_expert_mask = expert_indices >= num_experts
expert_indices[normal_expert_mask] = 0
expert_scales[normal_expert_mask] = 0.0
output = torch.zeros_like(hidden_states).to(hidden_states.device)
hidden_dim = hidden_states.size(-1)
num_tokens = hidden_states.size(0)
grid = lambda meta: (num_tokens * (hidden_dim // meta['BLOCK_SIZE']), )
compute_identity_kernel[grid](
top_k,
hidden_states,
zero_expert_scales,
num_tokens,
output,
hidden_dim,
zero_expert_scales.stride(0),
BLOCK_SIZE=256,
)
return output
# Adapted from: https://github.com/sgl-project/sglang/pull/2628
def get_config_file_name(E: int,
N: int,
......@@ -940,6 +1010,25 @@ def fused_topk(
return topk_weights, topk_ids, token_expert_indices
def fused_topk_bias(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
e_score_correction_bias: torch.Tensor,
topk: int,
renormalize: bool,
):
n_routed_experts = gating_output.shape[-1]
scores = gating_output.softmax(dim=-1)
scores_for_choice = scores.view(
-1, n_routed_experts) + e_score_correction_bias.unsqueeze(0)
topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1,
sorted=False)[1]
topk_weights = scores.gather(1, topk_indices)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights.to(torch.float32), topk_indices.to(torch.int32)
# This is used by the Deepseek-V2 and Deepseek-V3 model
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def grouped_topk(
......
......@@ -24,6 +24,8 @@ from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEConfig, FusedMoEParallelConfig,
FusedMoEQuantConfig, biased_moe_quant_config)
from vllm.model_executor.layers.fused_moe.fused_moe import (
zero_experts_compute_triton)
# yapf: enable
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat, FusedMoEModularKernel,
......@@ -548,7 +550,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
topk_weights, topk_ids = FusedMoE.select_experts(
zero_expert_num = getattr(layer, 'zero_expert_num', 0)
zero_expert_type = getattr(layer, 'zero_expert_type', None)
topk_weights, topk_ids, zero_expert_result = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
......@@ -565,11 +570,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_map=expert_map,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count)
logical_replica_count=logical_replica_count,
global_num_experts=global_num_experts,
zero_expert_num=zero_expert_num,
zero_expert_type=zero_expert_type)
if self.rocm_aiter_moe_enabled:
assert self.fused_experts is None
return self.rocm_aiter_fused_experts(
result = self.rocm_aiter_fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
......@@ -591,7 +599,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
if self.moe.has_bias:
raise ValueError(
"FusedMoEModularKernel does not support bias.")
return self.fused_experts(
result = self.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
......@@ -605,7 +613,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
)
else:
assert fused_experts is not None
return fused_experts(
result = fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
......@@ -619,6 +627,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_map=expert_map,
)
if zero_expert_num != 0 and zero_expert_type is not None:
assert not isinstance(result, tuple), \
"Shared + zero experts are mutually exclusive not yet supported"
return result, zero_expert_result
else:
return result
def forward_cpu(
self,
layer: torch.nn.Module,
......@@ -942,6 +957,8 @@ class FusedMoE(CustomOp):
num_redundant_experts: int = 0,
has_bias: bool = False,
is_sequence_parallel=False,
zero_expert_num: Optional[int] = 0,
zero_expert_type: Optional[str] = None,
):
super().__init__()
if params_dtype is None:
......@@ -976,6 +993,8 @@ class FusedMoE(CustomOp):
vllm_parallel_config=vllm_config.parallel_config))
self.global_num_experts = num_experts + num_redundant_experts
self.zero_expert_num = zero_expert_num
self.zero_expert_type = zero_expert_type
# Round up hidden size if needed.
hidden_size = maybe_roundup_hidden_size(hidden_size, moe_in_dtype,
......@@ -1656,25 +1675,30 @@ class FusedMoE(CustomOp):
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
global_num_experts: Optional[int] = None,
zero_expert_num: Optional[int] = None,
zero_expert_type: Optional[str] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Route the input hidden states to the top-k experts based on the
router logits.
Returns:
(topk_weights, topk_ids) (tuple[torch.Tensor, torch.Tensor]):
The weights and *global physical* expert ids of the top-k experts.
(topk_weights, topk_ids, zero_expert_result)
(tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
The weights, expert ids, and zero expert computation result.
**Compatibility**: When EPLB is not enabled, the returned ids are
equivalent to global logical ids, so should be compatible with
plain MoE implementations without redundant experts.
"""
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, fused_topk_bias)
# Check if we should use a routing simulation strategy
routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY
if routing_strategy != "":
return RoutingSimulator.simulate_routing(
topk_weights, topk_ids = RoutingSimulator.simulate_routing(
hidden_states=hidden_states,
router_logits=router_logits,
strategy_name=routing_strategy,
......@@ -1697,6 +1721,16 @@ class FusedMoE(CustomOp):
e_score_correction_bias=e_score_correction_bias)
if indices_type is not None:
topk_ids = topk_ids.to(dtype=indices_type)
elif e_score_correction_bias is not None:
topk_weights, topk_ids = fused_topk_bias(
hidden_states=hidden_states,
gating_output=router_logits,
e_score_correction_bias=e_score_correction_bias.data,
topk=top_k,
renormalize=renormalize,
)
if routed_scaling_factor is not None:
topk_weights *= routed_scaling_factor
elif custom_routing_function is None:
topk_weights, topk_ids, token_expert_indices = fused_topk(
hidden_states=hidden_states,
......@@ -1729,7 +1763,20 @@ class FusedMoE(CustomOp):
assert topk_ids.dtype == indices_type or indices_type is None
return topk_weights, topk_ids
# Compute zero expert result if needed
if (zero_expert_num is not None and zero_expert_num > 0
and zero_expert_type is not None
and global_num_experts is not None):
zero_expert_result = zero_experts_compute_triton(
expert_indices=topk_ids,
expert_scales=topk_weights,
num_experts=global_num_experts,
zero_expert_type=zero_expert_type,
hidden_states=hidden_states,
)
else:
zero_expert_result = None
return topk_weights, topk_ids, zero_expert_result
def must_reduce_shared_expert_outputs(self) -> bool:
"""
......@@ -1878,6 +1925,11 @@ class FusedMoE(CustomOp):
assert self.shared_experts is None or isinstance(
final_hidden_states, tuple)
if isinstance(final_hidden_states, tuple):
final_hidden_states, zero_expert_result = final_hidden_states
if zero_expert_result is not None:
final_hidden_states += zero_expert_result
if not skip_result_store:
if self.shared_experts is None:
full_fused_final_hidden_states[
......@@ -1992,6 +2044,9 @@ class FusedMoE(CustomOp):
shared_output,
final_hidden_states,
)
elif self.zero_expert_num is not None and self.zero_expert_num > 0:
assert isinstance(final_hidden_states, tuple)
final_hidden_states, zero_expert_result = final_hidden_states
def reduce_output(states: torch.Tensor,
do_combine: bool = True) -> torch.Tensor:
......@@ -2003,14 +2058,16 @@ class FusedMoE(CustomOp):
return states
if self.shared_experts is None:
assert not isinstance(final_hidden_states, tuple)
return reduce_output(final_hidden_states)
else:
if self.shared_experts is not None:
return (
reduce_output(final_hidden_states[0], do_combine=False),
reduce_output(final_hidden_states[1]),
)
elif self.zero_expert_num is not None and self.zero_expert_num > 0:
assert isinstance(final_hidden_states, torch.Tensor)
return reduce_output(final_hidden_states) + zero_expert_result
else:
return reduce_output(final_hidden_states)
@classmethod
def make_expert_params_mapping(
......
......@@ -103,7 +103,6 @@ class MultiHeadLatentAttention(CustomOp):
)
self.prefix = prefix
self.debug_layer_idx = int(self.prefix.split(".")[-2])
def forward_native(
self,
......
......@@ -520,7 +520,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
assert activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
......
......@@ -486,7 +486,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `BitsAndBytesMoEMethod` yet.")
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
......
......@@ -385,7 +385,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
"`CompressedTensorsW4A4MoeMethod` yet.")
assert activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
......@@ -934,7 +934,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
"EPLB not supported for "
"`CompressedTensorsW8A8Fp8MoEMethod` yet.")
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
......@@ -1195,7 +1195,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
......@@ -1502,7 +1502,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
assert activation == "silu", (
f"{activation} not supported for Marlin MoE.")
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
......@@ -1747,7 +1747,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
......
......@@ -146,7 +146,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
......
......@@ -18,6 +18,8 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe.layer import (
UnquantizedFusedMoEMethod)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
......@@ -174,6 +176,10 @@ class Fp8Config(QuantizationConfig):
return UnquantizedLinearMethod()
return Fp8LinearMethod(self)
elif isinstance(layer, FusedMoE):
if is_layer_skipped(prefix=prefix,
ignored_layers=self.ignored_layers,
fused_mapping=self.packed_modules_mapping):
return UnquantizedFusedMoEMethod(layer.moe_config)
return Fp8MoEMethod(self, layer)
elif isinstance(layer, Attention):
return Fp8KVCacheMethod(self)
......@@ -927,6 +933,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if enable_eplb:
assert expert_load_view is not None
assert logical_to_physical_map is not None
......@@ -943,8 +950,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
assert (renormalize and use_grouped_topk
and custom_routing_function is None)
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
result = torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
routing_logits=router_logits.to(torch.float32),
routing_bias=e_score_correction_bias,
x=x,
......@@ -965,7 +971,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
else:
assert (not renormalize
and custom_routing_function is not None)
return apply_flashinfer_per_tensor_scale_fp8(
result = apply_flashinfer_per_tensor_scale_fp8(
layer=layer,
hidden_states=x,
router_logits=router_logits,
......@@ -976,7 +982,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_group=topk_group,
apply_router_weight_on_input=apply_router_weight_on_input)
topk_weights, topk_ids = FusedMoE.select_experts(
zero_expert_num = getattr(layer, 'zero_expert_num', 0)
zero_expert_type = getattr(layer, 'zero_expert_type', None)
select_result = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
......@@ -994,17 +1003,22 @@ class Fp8MoEMethod(FusedMoEMethodBase):
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
global_num_experts=global_num_experts,
zero_expert_num=zero_expert_num,
zero_expert_type=zero_expert_type,
)
#
# Note: the order of checks is important since self.fused_experts
# can override fused_experts or cutlass but not rocm or marlin.
#
topk_weights, topk_ids, zero_expert_result = select_result
if self.rocm_aiter_moe_enabled:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
rocm_aiter_fused_experts)
assert self.fused_experts is None
return rocm_aiter_fused_experts(
result = rocm_aiter_fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
......@@ -1018,7 +1032,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
assert activation == "silu", (
f"{activation} not supported for Marlin MoE.")
assert self.fused_experts is None
return torch.ops.vllm.fused_marlin_moe(
result = torch.ops.vllm.fused_marlin_moe(
x,
layer.w13_weight,
layer.w2_weight,
......@@ -1035,7 +1049,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
expert_map=expert_map,
workspace=layer.workspace)
elif self.fused_experts:
return self.fused_experts(
result = self.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
......@@ -1055,7 +1069,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
assert scoring_func == 'sigmoid', (
f"Expected 'sigmoid' scoring func but got {scoring_func}")
return flashinfer_cutlass_moe_fp8(
result = flashinfer_cutlass_moe_fp8(
x,
layer,
topk_weights,
......@@ -1068,7 +1082,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
else:
from vllm.model_executor.layers.fused_moe import fused_experts
return fused_experts(
result = fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
......@@ -1083,6 +1097,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
allow_deep_gemm=self.allow_deep_gemm,
allow_cutlass_block_scaled_grouped_gemm=(
self.allow_cutlass_block_scaled_grouped_gemm))
if zero_expert_num != 0 and zero_expert_type is not None:
assert not isinstance(result, tuple), \
"Shared + zero experts are mutually exclusive not yet supported"
return result, zero_expert_result
else:
return result
class Fp8KVCacheMethod(BaseKVCacheMethod):
......
......@@ -555,7 +555,7 @@ class GGUFMoEMethod(FusedMoEMethodBase):
"Apply router weight on input is not supported for"
"fused GGUF MoE method.")
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
......
......@@ -669,7 +669,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
assert activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
......
......@@ -543,7 +543,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
apply_router_weight_on_input=apply_router_weight_on_input)
# Expert selection
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
......@@ -1491,7 +1491,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
)[0]
return out
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
......
......@@ -332,7 +332,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
from vllm.model_executor.layers.fused_moe import fused_experts
assert activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment