Commit d2b52805 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.10.2rc1' into v0.10.2rc1-ori

parents 9a521c23 5438967f
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
}
}
{
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 5
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"8192": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"16384": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"8192": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
}
}
......@@ -3,61 +3,17 @@
from typing import Callable, Optional
import torch
from torch.nn import functional as F
from vllm import envs
class IPEXFusedMOE:
def __init__(self, layer: torch.nn.Module) -> None:
import intel_extension_for_pytorch as ipex
layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
layer.w13_weight,
layer.w2_weight,
use_prepack=envs.VLLM_CPU_MOE_PREPACK,
)
def __call__(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
assert activation == "silu", f"{activation} is not supported."
assert not apply_router_weight_on_input
return layer.ipex_fusion(
x,
use_grouped_topk,
top_k,
router_logits,
renormalize,
topk_group,
num_expert_group,
custom_routing_function,
scoring_func,
e_score_correction_bias,
)
class SGLFusedMOE:
def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:]
def __init__(self, layer: torch.nn.Module) -> None:
pass
@staticmethod
def _grouped_topk(
def grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
......@@ -65,8 +21,9 @@ class SGLFusedMOE:
num_expert_group: int = 0,
topk_group: int = 0,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
......@@ -80,9 +37,6 @@ class SGLFusedMOE:
num_token = scores.shape[0]
if e_score_correction_bias is not None:
# Store original scores before applying correction bias. We use
# biased scores for expert selection but original scores for
# routing weights
original_scores = scores
scores = scores + e_score_correction_bias.unsqueeze(0)
group_scores = (scores.view(num_token, num_expert_group,
......@@ -90,22 +44,18 @@ class SGLFusedMOE:
else:
group_scores = scores.view(num_token, num_expert_group,
-1).max(dim=-1).values # [n, n_group]
group_idx = torch.topk(group_scores,
k=topk_group,
dim=-1,
group_idx = torch.topk(group_scores, k=topk_group, dim=-1,
sorted=False)[1] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = group_mask.unsqueeze(-1).expand(
num_token, num_expert_group,
scores.shape[-1] // num_expert_group).reshape(num_token,
-1) # [n, e]
scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(),
float("-inf")) # [n, e]
if e_score_correction_bias is not None:
topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
# Use original unbiased scores for the routing weights
topk_weights = original_scores.gather(1, topk_ids)
else:
topk_weights, topk_ids = torch.topk(tmp_scores,
......@@ -114,13 +64,14 @@ class SGLFusedMOE:
sorted=False)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1,
keepdim=True)
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
if routed_scaling_factor != 1.0:
topk_weights = topk_weights * routed_scaling_factor
return topk_weights, topk_ids.to(torch.int32)
@staticmethod
def _select_experts(
def select_experts(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
......@@ -130,20 +81,20 @@ class SGLFusedMOE:
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
# DeekSeekv2 uses grouped_top_k
) -> tuple[torch.Tensor, torch.Tensor]:
if use_grouped_topk:
assert topk_group is not None
assert num_expert_group is not None
topk_weights, topk_ids = SGLFusedMOE._grouped_topk(
hidden_states=hidden_states,
return grouped_topk(hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias)
elif custom_routing_function is None:
assert scoring_func == "softmax"
......@@ -153,15 +104,65 @@ class SGLFusedMOE:
topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1)
if renormalize:
topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
topk_ids = topk_ids.to(torch.int32)
return topk_weights, topk_ids.to(torch.int32)
else:
topk_weights, topk_ids = custom_routing_function(
hidden_states=hidden_states,
return custom_routing_function(hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize)
return topk_weights, topk_ids
class IPEXFusedMOE:
def __init__(self, layer: torch.nn.Module) -> None:
import intel_extension_for_pytorch as ipex
layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
layer.w13_weight,
layer.w2_weight,
use_prepack=envs.VLLM_CPU_MOE_PREPACK,
)
def __call__(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
assert activation == "silu", f"{activation} is not supported."
assert not apply_router_weight_on_input
assert routed_scaling_factor == 1.0, \
f"routed_scaling_factor {routed_scaling_factor} is not supported."
return layer.ipex_fusion(
x,
use_grouped_topk,
top_k,
router_logits,
renormalize,
topk_group,
num_expert_group,
custom_routing_function,
scoring_func,
e_score_correction_bias,
)
class SGLFusedMOE:
def __init__(self, layer: torch.nn.Module) -> None:
pass
def __call__(
self,
......@@ -177,13 +178,14 @@ class SGLFusedMOE:
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
assert activation == "silu", f"{activation} is not supported."
assert not apply_router_weight_on_input
topk_weights, topk_ids = SGLFusedMOE._select_experts(
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
......@@ -193,6 +195,7 @@ class SGLFusedMOE:
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
)
......@@ -213,3 +216,82 @@ class SGLFusedMOE:
True,
)
return x
class CPUFusedMOE:
def __init__(self, layer: torch.nn.Module) -> None:
pass
def __call__(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
assert activation == "silu", f"{activation} is not supported."
assert not apply_router_weight_on_input
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
)
# Ref code from https://github.com/sgl-project/sglang/blob/716e682721397df103f347d22da8bd46c6016dab/python/sglang/srt/layers/moe/fused_moe_native.py#L53
len_experts = global_num_experts
cnts = topk_ids.new_zeros((topk_ids.shape[0], len_experts))
cnts.scatter_(1, topk_ids.to(torch.int64), 1)
tokens_per_expert = cnts.sum(dim=0)
idxs = topk_ids.view(-1).argsort()
sorted_tokens = x[idxs // topk_ids.shape[1]]
tokens_per_expert = tokens_per_expert.cpu().numpy()
outputs = []
start_idx = 0
for i, num_tokens in enumerate(tokens_per_expert):
end_idx = start_idx + num_tokens
if num_tokens == 0:
continue
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
layer_w13_weight = layer.w13_weight[i]
layer_w2_weight = layer.w2_weight[i]
gate_up = F.linear(tokens_for_this_expert, layer_w13_weight)
gate_up = silu_and_mul(gate_up)
expert_out = F.linear(gate_up, layer_w2_weight)
outputs.append(expert_out)
start_idx = end_idx
outs = torch.cat(outputs,
dim=0) if len(outputs) else sorted_tokens.new_empty(0)
new_x = torch.empty_like(outs)
new_x[idxs] = outs
final_out = (new_x.view(
*topk_ids.shape, -1).type(topk_weights.dtype).mul_(
topk_weights.unsqueeze(dim=-1)).sum(dim=1).type(new_x.dtype))
return final_out
......@@ -9,12 +9,13 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
moe_permute, moe_unpermute)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate, TopKWeightAndReduceNoOP)
from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
_fp8_quantize,
from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
_resize_cache)
from vllm.scalar_type import scalar_types
......@@ -34,6 +35,10 @@ def run_cutlass_moe_fp8(
w2_scale: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
......@@ -41,6 +46,7 @@ def run_cutlass_moe_fp8(
per_act_token: bool,
per_out_ch: bool,
use_batched_format: bool,
topk_weights: Optional[torch.Tensor],
):
a1q = hidden_states
......@@ -99,6 +105,22 @@ def run_cutlass_moe_fp8(
topk = local_topk_ids.size(1)
local_E = w1.size(0)
if use_batched_format:
mm1_out = _resize_cache(workspace13, (local_E * padded_M, N * 2))
act_out = _resize_cache(workspace2, (local_E * padded_M, N))
quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn),
(local_E * padded_M, N))
mm2_out = _resize_cache(workspace2, (local_E * padded_M, K))
else:
a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn),
(M * topk, K))
mm1_out = _resize_cache(workspace13, (M * topk, N * 2))
act_out = _resize_cache(workspace2, (M * topk, N))
# original workspace are based on input hidden_states dtype (bf16)
quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn),
(M * topk, N))
mm2_out = _resize_cache(workspace2, (M * topk, K))
if use_batched_format:
assert expert_num_tokens is not None
......@@ -120,11 +142,10 @@ def run_cutlass_moe_fp8(
w2_scale = w2_scale.reshape(w2_scale.size(0), -1)
a1q = a1q.reshape(-1, a1q.size(2))
a1q_scale = a1q_scale.reshape(-1, a1q_scale.size(2)).contiguous()
# c3x get_group_gemm_starts expects int64 to avoid overflow
# during offset calculations
expert_offsets = expert_offsets.to(torch.int64)
else:
expert_offsets = torch.empty((global_num_experts + 1),
dtype=torch.int32,
device=device)
problem_sizes1 = torch.empty((global_num_experts, 3),
dtype=torch.int32,
device=device)
......@@ -132,84 +153,57 @@ def run_cutlass_moe_fp8(
dtype=torch.int32,
device=device)
# With expert_map each Rank processes only a subset of experts. As
# a result not all of a_map and c2 tensors are filled. We fill it
# zeros for correctness.
if expert_map is not None:
a_map = torch.zeros((local_topk_ids.numel()),
dtype=torch.int32,
device=device)
else:
a_map = torch.empty((local_topk_ids.numel()),
dtype=torch.int32,
device=device)
c_map = torch.empty((local_topk_ids.numel()),
dtype=torch.int32,
device=device)
ops.get_cutlass_moe_mm_data(local_topk_ids, expert_offsets,
problem_sizes1, problem_sizes2, a_map,
c_map, global_num_experts, N, K)
a1q = _fp8_perm(a1q, a_map)
a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale
num_expert = global_num_experts if expert_map is None \
else expert_map.size(0)
# permuted a1q reuses workspace2
a1q, a1q_scale, expert_offsets, inv_perm, _ = moe_permute(
a1q,
a1q_scale,
topk_ids,
num_expert,
local_E,
expert_map,
permuted_hidden_states=a1q_perm)
expert_offsets = expert_offsets[:-1]
ab_strides1 = torch.full((w1.size(0), ),
K,
device=device,
dtype=torch.int64)
c_strides1 = torch.full((w1.size(0), ),
2 * N,
device=device,
dtype=torch.int64)
ab_strides2 = torch.full((w1.size(0), ),
N,
device=device,
dtype=torch.int64)
c_strides2 = torch.full((w1.size(0), ),
K,
device=device,
dtype=torch.int64)
if use_batched_format:
c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2))
c2 = _resize_cache(workspace2, (local_E * padded_M, N))
c3 = _resize_cache(workspace13, (local_E * padded_M, K))
else:
c1 = _resize_cache(workspace13, (M * topk, N * 2))
c2 = _resize_cache(workspace2, (M * topk, N))
c3 = _resize_cache(workspace13, (M * topk, K))
ops.get_cutlass_moe_mm_problem_sizes(local_topk_ids, problem_sizes1,
problem_sizes2,
global_num_experts, N, K)
if not per_act_token and (expert_map is not None or use_batched_format):
# this is necessary to avoid imprecise scale calculation caused by
# random data in the unused workspace. The workspace is unused when
# this rank handles only partial tokens, or when it is batched .
c1.fill_(0)
mm1_out.fill_(0)
ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale, expert_offsets,
ops.cutlass_moe_mm(mm1_out, a1q, w1, a1q_scale, w1_scale, expert_offsets,
problem_sizes1, ab_strides1, ab_strides1, c_strides1,
per_act_token, per_out_ch)
activation_callable(c2, c1)
activation_callable(act_out, mm1_out)
a2q, a2q_scale = ops.scaled_fp8_quant(
c2, a2_scale, use_per_token_if_dynamic=per_act_token)
act_out,
a2_scale,
use_per_token_if_dynamic=per_act_token,
output=quant_out)
if expert_map is not None:
c3.fill_(0)
mm2_out.fill_(0)
ops.cutlass_moe_mm(c3, a2q, w2, a2q_scale, w2_scale, expert_offsets,
ops.cutlass_moe_mm(mm2_out, a2q, w2, a2q_scale, w2_scale, expert_offsets,
problem_sizes2, ab_strides2, ab_strides2, c_strides2,
per_act_token, per_out_ch)
if use_batched_format:
output.copy_(c3.reshape(local_E, padded_M, K), non_blocking=True)
output.copy_(mm2_out.reshape(local_E, padded_M, K), non_blocking=True)
else:
# We can't do this inplace because output may point to the same tensor
# as c3.
output.copy_(c3[c_map].view(M * topk, K), non_blocking=True)
# for non-chunking mode the output is resized from workspace13
# so we need to make sure mm2_out uses workspace2.
moe_unpermute(out=output,
permuted_hidden_states=mm2_out,
topk_weights=topk_weights,
inv_permuted_idx=inv_perm)
class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
......@@ -219,6 +213,10 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
out_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
per_out_ch_quant: bool,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
block_shape: Optional[list[int]] = None,
):
super().__init__(
......@@ -229,6 +227,10 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
block_shape=block_shape,
))
self.out_dtype = out_dtype
self.ab_strides1 = ab_strides1
self.ab_strides2 = ab_strides2
self.c_strides1 = c_strides1
self.c_strides2 = c_strides2
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
......@@ -272,10 +274,11 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
run_cutlass_moe_fp8(
output, hidden_states, w1, w2, topk_ids, activation_callable,
global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
a2_scale, workspace13, workspace2, expert_num_tokens,
a2_scale, self.ab_strides1, self.ab_strides2, self.c_strides1,
self.c_strides2, workspace13, workspace2, expert_num_tokens,
self.out_dtype if self.out_dtype is not None else in_dtype,
self.per_act_token_quant, self.per_out_ch_quant,
use_batched_format)
use_batched_format, topk_weights)
class CutlassExpertsFp8(CutlassExpertsFp8Base):
......@@ -285,12 +288,20 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
out_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
per_out_ch_quant: bool,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
block_shape: Optional[list[int]] = None,
):
super().__init__(
out_dtype,
per_act_token_quant,
per_out_ch_quant,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
block_shape,
)
......@@ -307,6 +318,10 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
def supports_expert_map(self) -> bool:
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# topk weights and reduction are fused in moe_unpermute cuda kernel
return TopKWeightAndReduceNoOP()
def workspace_shapes(
self,
a: torch.Tensor,
......@@ -320,8 +335,8 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
workspace1 = (M * topk, max(N, K))
workspace2 = (M * topk, N // 2)
output = (M * topk, K)
workspace2 = (M * topk, max(N // 2, K))
output = (M, K)
return (workspace1, workspace2, output,
self.out_dtype if self.out_dtype is not None else a.dtype)
......@@ -335,12 +350,20 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
out_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
per_out_ch_quant: bool,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
block_shape: Optional[list[int]] = None,
):
super().__init__(
out_dtype,
per_act_token_quant,
per_out_ch_quant,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
block_shape,
)
assert max_experts_per_worker > 0
......@@ -378,7 +401,8 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
assert num_dp is not None
workspace1 = (self.max_experts_per_worker, padded_M * num_dp,
max(N, K))
workspace2 = (self.max_experts_per_worker, padded_M * num_dp, (N // 2))
workspace2 = (self.max_experts_per_worker, padded_M * num_dp,
max(N // 2, K))
output = (self.max_experts_per_worker, padded_M, K)
return (workspace1, workspace2, output,
self.out_dtype if self.out_dtype is not None else a.dtype)
......@@ -392,6 +416,10 @@ def cutlass_moe_fp8(
topk_ids: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
per_act_token: Optional[bool] = None,
activation: str = "silu",
a1_scale: Optional[torch.Tensor] = None,
......@@ -419,6 +447,17 @@ def cutlass_moe_fp8(
Shape: [num_experts] or [num_experts, 2N]
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
Shape: [num_experts] or [num_experts, K]
- ab_strides1 (torch.Tensor): The input/weight strides for the first gemm.
Shape: [num_experts]
- ab_strides2 (torch.Tensor): The input/weight strides for the second gemm.
Shape: [num_experts]
- c_strides1 (torch.Tensor): The output strides for the first gemm.
Shape: [num_experts]
- c_strides2 (torch.Tensor): The output strides for the second gemm.
Shape: [num_experts]
- per_act_token (Optional[bool]): Whether the scale is per-token or
per-tensor.
- activation (str): The activation function to use.
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
Shape: scalar or [M]
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
......@@ -450,6 +489,10 @@ def cutlass_moe_fp8(
out_dtype=a.dtype,
per_act_token_quant=per_act_token,
per_out_ch_quant=per_out_ch,
ab_strides1=ab_strides1,
ab_strides2=ab_strides2,
c_strides1=c_strides1,
c_strides2=c_strides2,
),
)
......
......@@ -7,6 +7,8 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
FlashInferCutlassMoEPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP)
from vllm.utils.flashinfer import (flashinfer_cutlass_fused_moe,
......@@ -59,8 +61,8 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
per_act_token_quant=False,
block_shape=None,
))
assert quant_dtype == "nvfp4", ("Only nvfp4 quantization is "
"currently supported.")
assert quant_dtype in ("nvfp4", torch.float8_e4m3fn), (
"Only nvfp4,fp8 quantization are currently supported.")
self.ep_rank = ep_rank
self.ep_size = ep_size
self.tp_rank = tp_rank
......@@ -120,7 +122,8 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
"""
aq_m, aq_n = aq.shape
workspace2 = ()
output_shape = (aq_m, aq_n * 2)
output_shape = (aq_m, aq_n * 2) if self.quant_dtype != \
torch.float8_e4m3fn else (aq_m, aq_n)
workspace_dtype = a.dtype
workspace1 = output_shape
# The workspace is determined by `aq`, since it comes after any
......@@ -149,14 +152,21 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: Optional[bool],
):
# Flashinfer CUTLASS kernel takes scalar global scales,
# min because inv_scale.
if self.quant_dtype == torch.float8_e4m3fn:
quant_scales = [
self.g1_alphas, self.a2_gscale, self.g2_alphas, self.a1_gscale
]
a1q_scale = None # not passing input_sf in fp8
fc1_expert_weights = w1
fc2_expert_weights = w2
else:
# Ensure w1_scale and w2_scale are not None before calling view
assert w1_scale is not None and w2_scale is not None, (
"w1_scale and w2_scale must not "
"be None for FlashInferExperts")
# Flashinfer CUTLASS kernel takes scalar global scales,
# min because inv_scale.
quant_scales = [
self.a1_gscale,
w1_scale.view(torch.int32),
......@@ -165,13 +175,16 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
w2_scale.view(torch.int32),
self.g2_alphas,
]
# FlashInfer API requires weight to be long for nvfp4
fc1_expert_weights = w1.view(torch.long)
fc2_expert_weights = w2.view(torch.long)
_ = flashinfer_cutlass_fused_moe(
input=hidden_states,
token_selected_experts=topk_ids.to(torch.int),
token_final_scales=topk_weights,
# FlashInfer API requires weight to be long for nvfp4
fc1_expert_weights=w1.view(torch.long),
fc2_expert_weights=w2.view(torch.long),
fc1_expert_weights=fc1_expert_weights,
fc2_expert_weights=fc2_expert_weights,
output_dtype=self.out_dtype,
quant_scales=quant_scales,
input_sf=a1q_scale,
......@@ -181,3 +194,50 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
ep_rank=self.ep_rank,
output=output,
)
def flashinfer_cutlass_moe_fp4(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
g1_alphas: torch.Tensor,
g2_alphas: torch.Tensor,
a1_gscale: torch.Tensor,
a2_gscale: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
fused_experts = mk.FusedMoEModularKernel(
FlashInferCutlassMoEPrepareAndFinalize(use_dp=False,
a1_gscale=a1_gscale),
FlashInferExperts(
g1_alphas=g1_alphas,
g2_alphas=g2_alphas,
a1_gscale=a1_gscale,
a2_gscale=a2_gscale,
out_dtype=hidden_states.dtype,
quant_dtype="nvfp4",
))
return fused_experts(
hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=inplace,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
)
......@@ -40,7 +40,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
# from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
......@@ -975,8 +975,23 @@ def grouped_topk(
num_expert_group: int = 0,
topk_group: int = 0,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK and \
current_platform.is_cuda() and \
num_expert_group <= 32 and topk <= 32 and \
e_score_correction_bias is not None:
return fused_grouped_topk(
hidden_states=hidden_states,
gating_output=gating_output,
topk=topk,
renormalize=renormalize,
e_score_correction_bias=e_score_correction_bias,
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor)
assert hidden_states.size(0) == gating_output.size(0), (
"Number of tokens mismatch")
......@@ -1022,9 +1037,39 @@ def grouped_topk(
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
if routed_scaling_factor != 1.0:
topk_weights = topk_weights * routed_scaling_factor
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
def fused_grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
e_score_correction_bias: torch.Tensor,
num_expert_group: int = 0,
topk_group: int = 0,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor]:
assert hidden_states.size(0) == gating_output.size(0), (
"Number of tokens mismatch")
if scoring_func == "softmax":
scores = torch.softmax(gating_output, dim=-1)
elif scoring_func == "sigmoid":
scores = gating_output.sigmoid()
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
scores_with_bias = scores + e_score_correction_bias.unsqueeze(0)
topk_values, topk_indices = ops.grouped_topk(
scores, scores_with_bias.to(scores.dtype), num_expert_group,
topk_group, topk, renormalize, routed_scaling_factor)
return topk_values.to(torch.float32), topk_indices.to(torch.int32)
def get_config_dtype_str(
dtype: torch.dtype,
use_int4_w4a16: Optional[bool] = False,
......@@ -1420,9 +1465,8 @@ def fused_experts(hidden_states: torch.Tensor,
# E8M0 scale, which means we requantize the weight and input to the specific
# scale. Fallen back to cutlass or triton for some cases would cause
# accuracy issue.
if (allow_deep_gemm and use_fp8_w8a8
and (is_blackwell_deep_gemm_e8m0_used()
or _valid_deep_gemm(hidden_states, w1, w2))):
if (allow_deep_gemm and use_fp8_w8a8 and
(is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2))):
assert apply_router_weight_on_input is False
assert is_act_and_mul, (
"DeepGemm only supports is_act_and_mul=True for now.")
......
......@@ -200,7 +200,9 @@ class FusedMoEMethodBase(QuantizeMethodBase):
else:
return None
def init_prepare_finalize(self):
# Note: init_prepare_finalize should only be called by
# prepare_communication_buffer_for_model.
def init_prepare_finalize(self, layer: torch.nn.Module):
assert self.moe is not None
prepare_finalize = self.maybe_make_prepare_finalize(self.moe)
......@@ -211,7 +213,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
assert self.fused_experts is None, \
f"Attempt to override experts for {id(self)}!"
self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
experts = self.select_gemm_impl(prepare_finalize, self.moe)
experts = self.select_gemm_impl(prepare_finalize, self.moe, layer)
self.fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
......@@ -221,6 +223,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
# based on the all2all implementation, select the appropriate
# gemm implementation
......@@ -243,6 +246,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
......@@ -275,6 +279,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
prepare_finalize: FusedMoEPrepareAndFinalize,
# TODO(bnell): Remove. Every layer should have an moe config object.
moe: FusedMoEConfig,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
if (prepare_finalize.activation_format ==
FusedMoEActivationFormat.BatchedExperts):
......@@ -374,12 +379,17 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
use_prepack=True,
)
elif current_platform.is_cpu():
if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
from vllm.model_executor.layers.fused_moe import cpu_fused_moe
dtype = layer.w13_weight.dtype
if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
from vllm.model_executor.layers.utils import (
check_cpu_sgl_kernel)
dtype_w13 = layer.w13_weight.dtype
_, n_w13, k_w13 = layer.w13_weight.size()
dtype_w2 = layer.w2_weight.dtype
_, n_w2, k_w2 = layer.w2_weight.size()
if (envs.VLLM_CPU_SGL_KERNEL
and torch._C._cpu._is_amx_tile_supported()
and dtype == torch.bfloat16):
and check_cpu_sgl_kernel(n_w13, k_w13, dtype_w13)
and check_cpu_sgl_kernel(n_w2, k_w2, dtype_w2)):
packed_w13_weight = torch.ops._C.convert_weight_packed(
layer.w13_weight)
assert packed_w13_weight.size() == layer.w13_weight.size()
......@@ -393,7 +403,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
else:
layer.cpu_fused_moe = cpu_fused_moe.IPEXFusedMOE(layer)
else:
raise NotImplementedError("CPU MOE only supports x86 arch.")
layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
def apply(
self,
......@@ -409,6 +419,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
......@@ -437,6 +448,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_map=expert_map,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
......@@ -461,6 +473,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
......@@ -481,6 +494,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
enable_eplb=enable_eplb,
......@@ -547,6 +561,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
......@@ -574,6 +589,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_map,
custom_routing_function,
scoring_func,
routed_scaling_factor,
e_score_correction_bias,
apply_router_weight_on_input,
activation,
......@@ -593,6 +609,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
......@@ -631,6 +648,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
use_nn_moe: Optional[bool] = False,
......@@ -652,6 +670,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
raise NotImplementedError(
"Expert score correction bias is not supported for TPU.")
assert activation == "silu", f"{activation} is not supported for TPU."
assert routed_scaling_factor == 1.0, \
f"routed_scaling_factor {routed_scaling_factor} is not supported " \
f"for TPU."
if enable_eplb is not False or expert_load_view is not None or \
logical_to_physical_map is not None or \
logical_replica_count is not None:
......@@ -670,6 +691,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
forward_native = forward_tpu
elif current_platform.is_cpu():
forward_native = forward_cpu
elif current_platform.is_xpu():
forward_native = forward_xpu
else:
forward_native = forward_cuda
......@@ -717,6 +740,26 @@ def determine_expert_map(
return (local_num_experts, expert_map)
def get_compressed_expert_map(expert_map: torch.Tensor) -> str:
"""
Compresses the expert map by removing any -1 entries.
Args:
expert_map (torch.Tensor): A tensor of shape (global_num_experts,)
mapping from global to local index. Contains -1 for experts not
assigned to the current rank.
Returns:
str: A string mapping from local to global index.
Using str to support hashing for logging once only.
"""
global_indices = torch.where(expert_map != -1)[0]
local_indices = expert_map[global_indices]
return ", ".join(
f"{local_index.item()}->{global_index.item()}"
for local_index, global_index in zip(local_indices, global_indices))
@CustomOp.register("fused_moe")
class FusedMoE(CustomOp):
"""FusedMoE layer for MoE models.
......@@ -759,6 +802,7 @@ class FusedMoE(CustomOp):
prefix: str = "",
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
......@@ -817,6 +861,12 @@ class FusedMoE(CustomOp):
ep_size=self.ep_size,
ep_rank=self.ep_rank,
global_num_experts=self.global_num_experts)
logger.info_once(
"[EP Rank %s/%s] Expert parallelism is enabled. Local/global"
" number of experts: %s/%s. Experts local to global index map:"
" %s.", self.ep_rank, self.ep_size, self.local_num_experts,
self.global_num_experts,
get_compressed_expert_map(self.expert_map))
else:
self.local_num_experts, self.expert_map = (self.global_num_experts,
None)
......@@ -835,6 +885,7 @@ class FusedMoE(CustomOp):
self.topk_group = topk_group
self.custom_routing_function = custom_routing_function
self.scoring_func = scoring_func
self.routed_scaling_factor = routed_scaling_factor
self.e_score_correction_bias = e_score_correction_bias
self.apply_router_weight_on_input = apply_router_weight_on_input
self.activation = activation
......@@ -917,7 +968,7 @@ class FusedMoE(CustomOp):
self.batched_router_logits: Optional[torch.Tensor] = None
if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels
or self.moe_parallel_config.use_flashinfer_cutlass_kernels):
or self.moe_config.use_flashinfer_cutlass_kernels):
self.batched_hidden_states = torch.zeros(
(moe.max_num_tokens, self.hidden_size),
dtype=moe.in_dtype,
......@@ -971,7 +1022,7 @@ class FusedMoE(CustomOp):
@property
def use_flashinfer_cutlass_kernels(self):
return self.moe_parallel_config.use_flashinfer_cutlass_kernels
return self.moe_config.use_flashinfer_cutlass_kernels
def update_expert_map(self):
# ep_size and ep_rank should already be updated
......@@ -1423,6 +1474,7 @@ class FusedMoE(CustomOp):
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
indices_type: Optional[torch.dtype] = None,
enable_eplb: bool = False,
......@@ -1467,6 +1519,7 @@ class FusedMoE(CustomOp):
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias)
if indices_type is not None:
topk_ids = topk_ids.to(dtype=indices_type)
......@@ -1634,6 +1687,7 @@ class FusedMoE(CustomOp):
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor,
e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation,
enable_eplb=self.enable_eplb,
......@@ -1674,7 +1728,7 @@ class FusedMoE(CustomOp):
# only when data parallelism (DP) is enabled.
use_flashinfer_cutlass_kernels = (
self.dp_size > 1
and self.moe_parallel_config.use_flashinfer_cutlass_kernels)
and self.moe_config.use_flashinfer_cutlass_kernels)
if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels
or use_flashinfer_cutlass_kernels):
......@@ -1683,7 +1737,7 @@ class FusedMoE(CustomOp):
do_naive_dispatch_combine: bool = (
self.dp_size > 1
and not self.moe_parallel_config.use_deepep_ht_kernels
and not self.moe_parallel_config.use_flashinfer_cutlass_kernels)
and not self.moe_config.use_flashinfer_cutlass_kernels)
if do_naive_dispatch_combine:
hidden_states, router_logits = get_ep_group().dispatch(
hidden_states, router_logits)
......@@ -1702,6 +1756,7 @@ class FusedMoE(CustomOp):
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor,
e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
......
......@@ -3,7 +3,6 @@
import torch
import torch.nn.functional as F
import torch_xla.experimental.custom_kernel # noqa: F401
def _histogram(input: torch.Tensor, min: int, max: int) -> torch.Tensor:
......@@ -41,6 +40,7 @@ def fused_moe(
gating_output: [*, num_experts]
"""
assert expert_map is None, "expert_map is not supported for pallas MoE."
import torch_xla.experimental.custom_kernel # noqa: F401
orig_shape = hidden_states.shape
hidden_size = hidden_states.shape[-1]
num_tokens = hidden_states.shape[:-1].numel()
......
......@@ -82,7 +82,8 @@ def moe_permute(
n_local_expert: int = -1,
expert_map: Optional[torch.Tensor] = None,
align_block_size: Optional[int] = None,
fill_invalid_expert: int = -1
fill_invalid_expert: int = -1,
permuted_hidden_states: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
torch.Tensor]:
"""
......@@ -100,9 +101,12 @@ def moe_permute(
- align_block_size (Optional[int]): align group gemm block size for deepgemm
- fill_invalid_expert(int): fill expert id in m_indices for invalid expert
to workaround DeepGemm unsupported -1 in m_indices
- permuted_hidden_states (Optional[torch.Tensor]): Optional output tensor.
If None, the output tensor will be created in this function.
Returns:
- permuted_hidden_states (torch.Tensor): permuted activation.
- a1q_scale (Optional[torch.Tensor]): quant scale for hidden_states
- a1q_scale (Optional[torch.Tensor]): permuted quant scale for hidden_states
if original scale not per-tensor scaling
- expert_first_token_offset (torch.Tensor): offset of the first token
of each expert for standard grouped gemm. if enable 'align_block_size'
expert_first_token_offset will align up to 'align_block_size'.
......@@ -122,11 +126,16 @@ def moe_permute(
1) // align_block_size * align_block_size
if n_local_expert == -1:
n_local_expert = n_expert
if permuted_hidden_states is None:
permuted_hidden_states = torch.empty(
(permuted_row_size, n_hidden),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
assert permuted_hidden_states.size() == (permuted_row_size, n_hidden), (
f"Expected permuted hidden states to be {(permuted_row_size, n_hidden)}"
f" but got {permuted_hidden_states.size()}")
token_expert_indices = torch.arange(0,
n_token * topk,
dtype=torch.int32,
......@@ -153,7 +162,8 @@ def moe_permute(
align_block_size, permuted_hidden_states,
expert_first_token_offset, inv_permuted_idx,
permuted_idx, m_indices)
if a1q_scale is not None:
if a1q_scale is not None and a1q_scale.dim() > 1:
a1q_scale = a1q_scale[permuted_idx.clamp(max=n_token * topk - 1) //
topk]
return (permuted_hidden_states, a1q_scale, expert_first_token_offset,
......@@ -185,6 +195,7 @@ def moe_unpermute(
n_hidden = permuted_hidden_states.size(-1)
assert (n_hidden * permuted_hidden_states.element_size()
) % 16 == 0, "unpermue kernel need hidden dim align to 16B"
torch.ops._moe_C.moe_unpermute(permuted_hidden_states, topk_weights,
inv_permuted_idx, expert_first_token_offset,
topk, out)
......
......@@ -267,6 +267,7 @@ def rocm_aiter_grouped_topk(
num_expert_group: int = 0,
topk_group: int = 0,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
token = hidden_states.shape[0]
......@@ -279,7 +280,7 @@ def rocm_aiter_grouped_topk(
if e_score_correction_bias is not None:
torch.ops.vllm.rocm_aiter_biased_grouped_topk(
gating_output,
e_score_correction_bias,
e_score_correction_bias.to(gating_output.dtype),
topk_weights,
topk_ids,
num_expert_group,
......@@ -298,6 +299,8 @@ def rocm_aiter_grouped_topk(
scoring_func,
)
if routed_scaling_factor != 1.0:
topk_weights = topk_weights * routed_scaling_factor
return topk_weights, topk_ids
......
......@@ -10,7 +10,7 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape,
deep_gemm_block_shape)
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
......@@ -107,7 +107,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# Note: the deep gemm workspaces are strictly larger than the triton
# workspaces so we can be pessimistic here and allocate for DeepGemm
# even if we fall back to triton later, e.g. if expert maps are set.
if self.allow_deep_gemm and (is_blackwell_deep_gemm_e8m0_used()
if self.allow_deep_gemm and (is_deep_gemm_e8m0_used()
or _valid_deep_gemm_shape(M, N, K)):
assert self.deep_gemm_expert is not None
return self.deep_gemm_expert.workspace_shapes(
......@@ -143,7 +143,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
):
use_deep_gemm = (self.allow_deep_gemm
and (_valid_deep_gemm(hidden_states, w1, w2)
or is_blackwell_deep_gemm_e8m0_used()))
or is_deep_gemm_e8m0_used()))
experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert
assert experts is not None
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP)
from vllm.utils import next_power_of_2
class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
moe: FusedMoEConfig,
gemm1_alpha,
gemm1_beta,
gemm1_clamp_limit,
w13_bias,
w2_bias,
max_capture_size,
):
super().__init__(moe.quant_config)
self.moe = moe
self.gemm1_alpha = gemm1_alpha
self.gemm1_beta = gemm1_beta
self.gemm1_clamp_limit = gemm1_clamp_limit
self.w13_bias = w13_bias
self.w2_bias = w2_bias
self.max_capture_size = max_capture_size
@property
def activation_formats(
self
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (mk.FusedMoEActivationFormat.Standard,
mk.FusedMoEActivationFormat.Standard)
def supports_chunking(self) -> bool:
return True
def supports_expert_map(self) -> bool:
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceNoOP()
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
# The workspaces for this implementation are managed by flashinfer.
# TODO(varun) : workspace1 is could be used as the output tensor. This
# is error-prone. Allow the `workspace_shapes` to return None workspaces
workspace1 = (M, K)
workspace2 = (0, 0)
output = (M, K)
return (workspace1, workspace2, output, a.dtype)
def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int,
local_num_experts: int):
# Number of tokens in the input tensor.
num_tokens = x.shape[0]
# Factor to account for the imbalance of the experts.
# factor equals to the
# max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
# 1.0 means perfect expert distribution.
# > 1.0 means some experts have more tokens than the perfect
# distribution.
# < 1.0 does not make sense.
imbalance_factor = 1.3
# Calculate the number of tokens per expert assuming perfect
# distribution.
num_tokens_per_expert = (num_tokens * top_k) // local_num_experts
# Apply the imbalance factor.
num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
# And pad the number to the next power of 2.
tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
# Cap to 8-64 tokens per CTA tile as it's the range supported by the
# kernel.
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
return tile_tokens_dim
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
):
topk = topk_ids.size(-1)
local_num_experts = w1.size(0)
intermediate_size = w2.size(1)
local_expert_offset = self.moe.ep_rank * local_num_experts
x_quant = hidden_states
x_scale = a1q_scale
if x_scale is not None:
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(
*x_quant.shape[:-1], -1)
packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
torch.bfloat16).view(torch.int16)
assert w1_scale is not None
assert w2_scale is not None
kwargs = {
"topk_ids":
packed_tensor,
"routing_bias":
None,
"hidden_states":
x_quant,
"hidden_states_scale":
x_scale,
"gemm1_weights":
w1,
"gemm1_weights_scale":
w1_scale,
"gemm1_bias":
self.w13_bias,
"gemm1_alpha":
self.gemm1_alpha,
"gemm1_beta":
self.gemm1_beta,
"gemm1_clamp_limit":
self.gemm1_clamp_limit,
"gemm2_weights":
w2,
"gemm2_weights_scale":
w2_scale,
"gemm2_bias":
self.w2_bias,
"output1_scale_scalar":
None,
"output1_scale_gate_scalar":
None,
"output2_scale_scalar":
None,
"num_experts":
global_num_experts,
"top_k":
topk,
"n_group":
None,
"topk_group":
None,
"intermediate_size":
intermediate_size,
"local_expert_offset":
local_expert_offset,
"local_num_experts":
local_num_experts,
"routed_scaling_factor":
None,
"tile_tokens_dim":
self._get_tile_tokens_dim(x_quant, topk, local_num_experts),
"routing_method_type":
1,
"do_finalize":
True,
"output":
output,
"tune_max_num_tokens":
self.max_capture_size,
}
from flashinfer import trtllm_fp4_block_scale_routed_moe
trtllm_fp4_block_scale_routed_moe(**kwargs)
return output
......@@ -12,6 +12,8 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_group_quant_int8, per_token_quant_int8)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
quant_dequant_mxfp4)
from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
mxfp8_quantize)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv
......@@ -177,6 +179,18 @@ def _mxfp4_quantize(
return A, None
def _mxfp8_quantize(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
per_act_token_quant: bool,
block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
assert A_scale is None
assert not per_act_token_quant
assert block_shape is None
return mxfp8_quantize(A)
def moe_kernel_quantize_input(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
......@@ -195,6 +209,8 @@ def moe_kernel_quantize_input(
is_sf_swizzled_layout=is_fp4_scale_swizzled)
elif quant_dtype == "mxfp4":
return _mxfp4_quantize(A, A_scale, per_act_token_quant, block_shape)
elif quant_dtype == "mxfp8":
return _mxfp8_quantize(A, A_scale, per_act_token_quant, block_shape)
else:
return A, A_scale
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
from einops import rearrange
......@@ -453,7 +455,14 @@ class _attention(torch.autograd.Function):
lightning_attention_ = _attention.apply
def lightning_attention(q, k, v, ed, block_size=256, kv_history=None):
def lightning_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
ed: torch.Tensor,
block_size: int = 256,
kv_history: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Apply lightning attention algorithm
to compute attention efficiently.
......
......@@ -35,6 +35,7 @@ logger = init_logger(__name__)
WEIGHT_LOADER_V2_SUPPORTED = [
"CompressedTensorsLinearMethod",
"CompressedTensorsLinearTransformMethod",
"BitBLASLinearMethod",
"GPTQBitBLASLinearMethod",
"AWQMarlinLinearMethod",
......@@ -42,7 +43,6 @@ WEIGHT_LOADER_V2_SUPPORTED = [
"GPTQMarlinLinearMethod",
"Fp8LinearMethod",
"MarlinLinearMethod",
"QQQLinearMethod",
"GPTQMarlin24LinearMethod",
"TPUInt8LinearMethod",
"GPTQLinearMethod",
......@@ -53,6 +53,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
"HQQMarlinMethod",
"QuarkLinearMethod",
"ModelOptNvFp4LinearMethod",
"PetitNvFp4LinearMethod",
]
......@@ -199,12 +200,12 @@ class UnquantizedLinearMethod(LinearMethodBase):
set_weight_attrs(weight, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# special postprocessing for CPU SGL
if current_platform.is_cpu() and envs.VLLM_CPU_SGL_KERNEL:
from vllm.model_executor.layers.utils import check_cpu_sgl_kernel
N, K = layer.weight.size()
dtype = layer.weight.dtype
if (torch._C._cpu._is_amx_tile_supported()
and dtype == torch.bfloat16 and N % 32 == 0
and K % 32 == 0):
if check_cpu_sgl_kernel(N, K, dtype):
packed_weight = torch.ops._C.convert_weight_packed(
layer.weight)
assert packed_weight.size() == layer.weight.size()
......@@ -216,7 +217,8 @@ class UnquantizedLinearMethod(LinearMethodBase):
else:
logger.warning(
"CPU SGL kernels require Intel AMX support,"
" bfloat16 weight, IC and OC are divisible by 32.")
" bf16/fp16/int8 weight, IC and OC are divisible by "
"32 and 16.")
layer.use_cpu_sgl = False
def apply(self,
......@@ -233,10 +235,10 @@ class LinearBase(CustomOp):
Args:
input_size: input dimension of the linear layer.
output_size: output dimension of the linear layer.
bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
prefix: Prefix for parameter names.
return_bias: If true, return bias together with outputs in forward pass.
"""
......@@ -378,13 +380,14 @@ class MergedReplicatedLinear(ReplicatedLinear):
Args:
input_size: input dimension of the linear layer.
output_size: output dimension of the linear layer.
output_sizes: list of output dimensions of the linear layer.
bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
return_bias: If true, return bias together with outputs in forward pass.
"""
def __init__(
......@@ -437,7 +440,7 @@ class MergedReplicatedLinear(ReplicatedLinear):
shard_offset = sum(self.output_sizes[:loaded_shard_id])
shard_size = self.output_sizes[loaded_shard_id]
param[shard_offset:shard_offset + shard_size] = loaded_weight
param.data[shard_offset:shard_offset + shard_size] = loaded_weight
@CustomOp.register("column_parallel_linear")
......@@ -1378,7 +1381,7 @@ class RowParallelLinear(LinearBase):
return output, output_bias
def extra_repr(self) -> str:
s = f"input_features={self.input_size_per_partition}"
s = f"in_features={self.input_size_per_partition}"
s += f", output_features={self.output_size}"
s += f", bias={self.bias is not None}"
s += f", tp_size={self.tp_size}"
......@@ -1469,7 +1472,7 @@ class QKVCrossParallelLinear(LinearBase):
self.bias = torch.nn.Parameter()
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
"weight_loader": self.weight_loader_v1,
})
else:
self.bias = None
......@@ -1579,6 +1582,18 @@ class QKVCrossParallelLinear(LinearBase):
k, v = kv_enc.split(self.kv_size, dim=-1)
return q, k, v
def weight_loader_v1(self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None):
# just like all other parameters, does not yet
# support loading bias with weight_loader_v2
layer = (self.q_proj_decoder
if loaded_shard_id == "q" else self.kv_proj_encoder)
target_param = self.select_proj_params(layer, param)
shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else ()
layer.weight_loader(target_param, loaded_weight, *shard_id_args)
def weight_loader(self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from abc import abstractmethod
from collections.abc import Iterable
from typing import TYPE_CHECKING
import torch
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
class MambaBase(ABC):
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
class MambaBase(AttentionLayerBase):
"""
Base class for Mamba-like layers which support the v1 engine.
Inherit from this class if you implement a custom layer.
......@@ -32,3 +38,8 @@ class MambaBase(ABC):
@abstractmethod
def mamba_type(self) -> str:
pass
@abstractmethod
def get_attn_backend(self) -> type["AttentionBackend"]:
"""Get the attention backend class for this Mamba layer."""
pass
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from typing import TYPE_CHECKING, Optional, Union
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
from typing import TYPE_CHECKING
import torch
import torch.distributed
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from vllm import envs
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.lightning_attn import (
lightning_attention, linear_decode_forward_triton)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
import torch
import torch.distributed
from vllm.model_executor.models.minimax_cache import MinimaxCacheParams
class MiniMaxText01RMSNormTP(CustomOp):
name = "MiniMaxText01RMSNormTP"
def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
super().__init__()
self.tp_world = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.weight = nn.Parameter(torch.ones(int(hidden_size /
self.tp_world)))
self.weight.weight_loader = self.weight_loader
self.variance_epsilon = eps
return
@staticmethod
def weight_loader(
param: nn.Parameter,
loaded_weight: torch.Tensor,
) -> None:
tp_world = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
shard_size = loaded_weight.shape[0] // tp_world
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
param.data.copy_(loaded_weight[shard])
return
def _forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
orig_dtype = x.dtype
x = x.to(torch.float32)
variance = x.pow(2).mean(dim=-1, keepdim=True, dtype=torch.float32)
if self.tp_world > 1:
variance = tensor_model_parallel_all_reduce(
variance) / self.tp_world
x = x * torch.rsqrt(variance + self.variance_epsilon)
weight = self.weight
if x.size(-1) != self.weight.size(0):
if self.weight.size(0) < x.size(-1):
repeat_count = (x.size(-1) + self.weight.size(0)) // x.size(-1)
full_weight = self.weight.repeat(repeat_count)
weight = full_weight[:x.size(-1)]
else:
weight = self.weight[:x.size(-1)]
x = x.to(orig_dtype) * weight
return x
def forward(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert residual is None, "RMSNorm does not support residual connection."
return self._forward(x)
class MiniMaxText01LinearKernel:
@staticmethod
def jit_linear_forward_prefix(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
kv_caches: torch.Tensor,
slope_rate: torch.Tensor,
block_size: int,
layer_idx: Optional[int] = None,
**kwargs) -> torch.Tensor:
slope_rate = slope_rate.to(torch.float32)
should_pad_dim = q.dim() == 3
if should_pad_dim:
q = q.unsqueeze(0)
k = k.unsqueeze(0)
v = v.unsqueeze(0)
b, h, n, d = q.shape
e = d
kv_history = kv_caches.reshape(1, h, d, e).contiguous()
output, kv_history = lightning_attention(q,
k,
v,
slope_rate,
block_size=block_size,
kv_history=kv_history)
kv_caches.copy_(kv_history[:, :, -1, :, :].reshape(h, d, e))
assert output.shape[0] == 1, "batch size must be 1"
return rearrange(output.squeeze(0), "h n d -> n (h d)")
class MiniMaxText01LinearAttention(nn.Module, MambaBase):
@property
def mamba_type(self) -> str:
return "linear_attention"
def get_attn_backend(self) -> type["AttentionBackend"]:
from vllm.v1.attention.backends.linear_attn import (
LinearAttentionBackend)
return LinearAttentionBackend
def get_state_dtype(self) -> tuple[torch.dtype]:
assert self.model_config is not None
assert self.cache_config is not None
return MambaStateDtypeCalculator.linear_attention_state_dtype(
self.model_config.dtype,
self.cache_config.mamba_cache_dtype,
)
def get_state_shape(self) -> tuple[tuple[int, int, int], ...]:
return MambaStateShapeCalculator.linear_attention_state_shape(
num_heads=self.num_heads,
tp_size=self.tp_size,
head_dim=self.head_dim)
def __init__(
self,
hidden_size: int,
hidden_inner_size: int,
num_heads: int,
head_dim: int,
max_position: int,
block_size: int,
num_hidden_layer: int,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
layer_idx: int = 0,
linear_layer_idx: int = 0,
prefix: str = "linear_attn",
) -> None:
super().__init__()
self.layer_idx = layer_idx
self.BLOCK = block_size
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = head_dim
self.total_num_heads = num_heads
self.hidden_inner_size = hidden_inner_size
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
assert self.total_num_heads % self.tp_size == 0
self.tp_heads = self.total_num_heads // self.tp_size
self.qkv_size = self.num_heads * self.head_dim
self.tp_hidden = self.head_dim * self.tp_heads
self.model_config = model_config
self.cache_config = cache_config
self.prefix = prefix
self.qkv_proj = ColumnParallelLinear(
hidden_size,
self.hidden_inner_size * 3,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.output_gate = ColumnParallelLinear(
hidden_size,
self.hidden_inner_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.output_gate",
)
self.out_proj = RowParallelLinear(
self.hidden_inner_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
self.norm = MiniMaxText01RMSNormTP(
self.hidden_inner_size,
eps=1e-5,
)
slope_rate = MiniMaxText01LinearAttention._build_slope_tensor(
self.num_heads)
if num_hidden_layer <= 1:
self.slope_rate = slope_rate * (1 + 1e-5)
else:
self.slope_rate = slope_rate * (1 - layer_idx /
(num_hidden_layer - 1) + 1e-5)
self.tp_slope = self.slope_rate[self.tp_rank *
self.tp_heads:(self.tp_rank + 1) *
self.tp_heads].contiguous()
if envs.VLLM_USE_V1:
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
@staticmethod
def weight_direct_load(param: torch.Tensor,
loaded_weight: torch.Tensor) -> None:
assert param.size() == loaded_weight.size()
param.data.copy_(loaded_weight)
return
@staticmethod
def _build_slope_tensor(n_attention_heads: int):
def get_slopes(n):
def get_slopes_power_of_2(n):
start = 2**(-(2**-(math.log2(n) - 3)))
ratio = start
return [start * ratio**i for i in range(n)]
if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
else:
closest_power_of_2 = 2**math.floor(math.log2(n))
return (get_slopes_power_of_2(closest_power_of_2) + get_slopes(
2 * closest_power_of_2)[0::2][:n - closest_power_of_2])
slopes = torch.tensor(get_slopes(n_attention_heads),
dtype=torch.float32).reshape(
n_attention_heads, 1, 1)
return slopes
def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
attn_metadata):
hidden = []
for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)):
if _prefill_idx >= len(attn_metadata.query_start_loc):
break
if _prefill_idx >= len(state_indices_tensor):
break
# prefills are packed at end of batch in V1
offset = attn_metadata.num_decode_tokens if envs.VLLM_USE_V1 else 0
_start = attn_metadata.query_start_loc[offset + _prefill_idx]
_end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
slot_id = state_indices_tensor[offset + _prefill_idx]
qs = q[_start:_end].transpose(0, 1).contiguous()
ks = k[_start:_end].transpose(0, 1).contiguous()
vs = v[_start:_end].transpose(0, 1).contiguous()
slice_layer_cache = kv_cache[slot_id, ...]
out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix(
qs,
ks,
vs,
slice_layer_cache,
self.tp_slope,
self.BLOCK,
layer_idx=self.layer_idx)
hidden.append(out_slice.contiguous())
if attn_metadata.num_decode_tokens > 0:
hidden_decode = self._decode_infer(q, k, v, kv_cache,
state_indices_tensor,
attn_metadata)
if envs.VLLM_USE_V1:
hidden.insert(0, hidden_decode)
else:
hidden.append(hidden_decode)
if not hidden:
return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
hidden = torch.concat(hidden, dim=0).contiguous()
return hidden
def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
attn_metadata):
if not envs.VLLM_USE_V1:
q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
num_prefills = getattr(attn_metadata, "num_prefills", 0)
slot_id = state_indices_tensor[num_prefills:]
else:
q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
slot_id = state_indices_tensor[:attn_metadata.num_decodes]
hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope,
slot_id, 32)
return hidden
def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
positions: torch.Tensor,
kv_caches: MinimaxCacheParams) -> None:
if not envs.VLLM_USE_V1:
self._forward(hidden_states, output, positions, kv_caches)
else:
torch.ops.vllm.linear_attention(
hidden_states,
output,
positions,
self.prefix,
)
def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
positions: torch.Tensor,
kv_caches: Optional[MinimaxCacheParams]) -> None:
forward_context = get_forward_context()
attn_metadata: AttentionMetadata = forward_context.attn_metadata
if envs.VLLM_USE_V1 and attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, LinearAttentionMetadata)
num_actual_tokens = attn_metadata.num_prefill_tokens + \
attn_metadata.num_decode_tokens
else:
num_actual_tokens = hidden_states.shape[0]
qkv, _ = self.qkv_proj(hidden_states[:num_actual_tokens])
qkv32 = qkv.to(torch.float32)
qkvact = torch.nn.functional.silu(qkv32)
qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
if envs.VLLM_USE_V1:
if attn_metadata is not None:
kv_cache = self.kv_cache[forward_context.virtual_engine][0]
state_indices_tensor = attn_metadata.state_indices_tensor
num_prefills = getattr(attn_metadata, "num_prefills", 0)
if num_prefills > 0:
num_decode_tokens = getattr(attn_metadata,
"num_decode_tokens", 0)
for prefill_idx in range(num_prefills):
q_start = attn_metadata.query_start_loc[
num_decode_tokens + prefill_idx]
q_end = attn_metadata.query_start_loc[num_decode_tokens
+ prefill_idx +
1]
query_len = q_end - q_start
context_len = attn_metadata.seq_lens[
num_decode_tokens + prefill_idx] - query_len
if context_len == 0:
block_to_clear = state_indices_tensor[
num_decode_tokens + prefill_idx]
kv_cache[block_to_clear, ...] = 0
else:
assert kv_caches is not None
kv_cache = kv_caches.minimax_cache
state_indices_tensor = kv_caches.state_indices_tensor
decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
if attn_metadata is None:
hidden = torch.empty((q.shape[0], q.shape[1] * q.shape[2]),
device=q.device,
dtype=q.dtype)
else:
if not decode_only:
hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
state_indices_tensor,
attn_metadata)
else:
hidden = self._decode_infer(q, k, v, kv_cache,
state_indices_tensor,
attn_metadata)
hidden = self.norm._forward(hidden)
gate, _ = self.output_gate(hidden_states[:num_actual_tokens])
hidden = F.sigmoid(gate) * hidden
hidden = hidden.to(hidden_states.dtype)
output[:num_actual_tokens], _ = self.out_proj(hidden)
def linear_attention(
hidden_states: torch.Tensor,
output: torch.Tensor,
positions: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
self._forward(hidden_states=hidden_states,
output=output,
positions=positions,
kv_caches=None)
def linear_attention_fake(
hidden_states: torch.Tensor,
output: torch.Tensor,
positions: torch.Tensor,
layer_name: str,
) -> None:
return
direct_register_custom_op(
op_name="linear_attention",
op_func=linear_attention,
mutates_args=["output"],
fake_impl=linear_attention_fake,
dispatch_key=current_platform.dispatch_key,
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import NamedTuple, Optional
from typing import TYPE_CHECKING, NamedTuple, Optional
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
import torch
from torch import nn
......@@ -27,6 +30,8 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn, selective_state_update)
from vllm.model_executor.models.mamba_cache import MambaCacheParams
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata
......@@ -183,22 +188,26 @@ class MambaMixer(MambaBase, CustomOp):
def forward(self,
hidden_states: torch.Tensor,
output: torch.Tensor,
mamba_cache_params: Optional[MambaCacheParams] = None):
if not envs.VLLM_USE_V1:
return CustomOp.forward(self, hidden_states, mamba_cache_params)
CustomOp.forward(self, hidden_states, output, mamba_cache_params)
else:
return self.forward_cuda(
torch.ops.vllm.mamba_mixer(
hidden_states,
mamba_cache_params,
output,
self.prefix,
)
def forward_native(self,
hidden_states: torch.Tensor,
output: torch.Tensor,
mamba_cache_params: Optional[MambaCacheParams] = None):
pass
def forward_cuda(self,
hidden_states: torch.Tensor,
output: torch.Tensor,
mamba_cache_params: Optional[MambaCacheParams] = None):
"""
Run the Mamba-1 SSM pipeline.
......@@ -237,6 +246,7 @@ class MambaMixer(MambaBase, CustomOp):
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
has_initial_states = mamba1_metadata.has_initial_states
num_padded_decodes = mamba1_metadata.num_padded_decodes
else:
assert isinstance(attn_metadata, AttentionMetadata)
assert mamba_cache_params is not None
......@@ -248,6 +258,7 @@ class MambaMixer(MambaBase, CustomOp):
has_initial_states = None
if context_lens_tensor is not None:
has_initial_states = context_lens_tensor > 0
num_padded_decodes = attn_metadata.num_decode_tokens
# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
......@@ -267,6 +278,7 @@ class MambaMixer(MambaBase, CustomOp):
num_decodes = attn_metadata.num_decode_tokens # token count (=request)
has_prefill = num_prefill_tokens > 0
has_decode = num_decode_tokens > 0
num_actual_tokens = num_prefill_tokens + num_decode_tokens
prefill_decode_split = split_batch_to_prefill_and_decode(
hidden_states_BC,
......@@ -278,6 +290,7 @@ class MambaMixer(MambaBase, CustomOp):
num_decode_tokens,
num_prefills,
num_decodes,
num_padded_decodes,
)
hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p
hidden_states_BC_d = prefill_decode_split.hidden_states_BC_d
......@@ -371,7 +384,7 @@ class MambaMixer(MambaBase, CustomOp):
else:
out = self.out_proj(scan_outputs_combined.transpose(-2, -1))[0]
return out
output[:num_actual_tokens] = out
def get_state_dtype(self) -> tuple[torch.dtype]:
assert self.model_config is not None
......@@ -394,6 +407,11 @@ class MambaMixer(MambaBase, CustomOp):
def mamba_type(self) -> str:
return "mamba1"
def get_attn_backend(self) -> type["AttentionBackend"]:
from vllm.v1.attention.backends.mamba1_attn import (
Mamba1AttentionBackend)
return Mamba1AttentionBackend
def _time_proj_bias(self) -> Optional[torch.Tensor]:
if hasattr(self.dt_proj, "bias") and self.dt_proj.bias is not None:
return self.dt_proj.bias.float()
......@@ -421,18 +439,27 @@ def split_batch_to_prefill_and_decode(
num_decode_tokens: int,
num_prefills: int,
num_decodes: int,
num_padded_decodes: int,
) -> PrefillDecodeSplit:
num_actual_tokens = num_prefill_tokens + num_padded_decodes
if envs.VLLM_USE_V1:
# In v1, decode tokens come first, then prefill tokens.
hidden_states_BC_d, hidden_states_BC_p = torch.split(
hidden_states_BC, [num_decode_tokens, num_prefill_tokens], dim=-1)
gate_d, gate_p = torch.split(gate,
[num_decode_tokens, num_prefill_tokens],
hidden_states_BC[..., :num_actual_tokens],
[num_padded_decodes, num_prefill_tokens],
dim=-1)
gate_d, gate_p = torch.split(gate[..., :num_actual_tokens],
[num_padded_decodes, num_prefill_tokens],
dim=-1)
# num_padded_decodes accounts for CUDA graph padding when applicable
state_indices_tensor_d, state_indices_tensor_p = torch.split(
state_indices_tensor, [num_decodes, num_prefills], dim=0)
state_indices_tensor[:num_padded_decodes + num_prefills],
[num_padded_decodes, num_prefills],
dim=0)
query_start_loc_p = (query_start_loc[-num_prefills - 1:] -
num_decodes if num_prefills > 0 else None)
num_padded_decodes if num_prefills > 0 else None)
has_initial_states_p = has_initial_states[-num_prefills:] if (
has_initial_states is not None and num_prefills > 0) else None
else:
......@@ -459,3 +486,32 @@ def split_batch_to_prefill_and_decode(
query_start_loc_p=query_start_loc_p,
has_initial_states_p=has_initial_states_p,
)
def mamba_mixer(
hidden_states: torch.Tensor,
output: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
self.forward_cuda(hidden_states=hidden_states,
output=output,
mamba_cache_params=None)
def mamba_mixer_fake(
hidden_states: torch.Tensor,
output: torch.Tensor,
layer_name: str,
) -> None:
return
direct_register_custom_op(
op_name="mamba_mixer",
op_func=mamba_mixer,
mutates_args=["output"],
fake_impl=mamba_mixer_fake,
dispatch_key=current_platform.dispatch_key,
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union
from typing import TYPE_CHECKING, Optional, Union
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
import torch
from torch import nn
......@@ -758,6 +761,11 @@ class MambaMixer2(MambaBase, CustomOp):
def mamba_type(self) -> str:
return "mamba2"
def get_attn_backend(self) -> type["AttentionBackend"]:
from vllm.v1.attention.backends.mamba2_attn import (
Mamba2AttentionBackend)
return Mamba2AttentionBackend
def mamba_mixer2(
hidden_states: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment