Commit 3d062a1c authored by zhuwenwen
Browse files

update fusedmoe

parent fe79f042
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import functools import functools
import json import json
import os import os
from typing import Any, Callable, Optional from typing import Any, Callable, Optional, List, Optional, Tuple
import torch import torch
...@@ -31,7 +31,7 @@ from vllm.platforms import current_platform ...@@ -31,7 +31,7 @@ from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled # from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -502,7 +502,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -502,7 +502,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_int8_w8a16: bool, use_int8_w8a16: bool,
use_int4_w4a16: bool, use_int4_w4a16: bool,
per_channel_quant: bool, per_channel_quant: bool,
block_shape: Optional[list[int],] = None, block_shape: Optional[List[int],] = None,
use_nn_moe: Optional[bool]=False) -> None: use_nn_moe: Optional[bool]=False) -> None:
assert topk_weights is not None or not mul_routed_weight assert topk_weights is not None or not mul_routed_weight
assert topk_weights is None or topk_weights.stride(1) == 1 assert topk_weights is None or topk_weights.stride(1) == 1
...@@ -534,7 +534,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -534,7 +534,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
EM = min(sorted_token_ids.size(0), EM = min(sorted_token_ids.size(0),
A.size(0) * top_k * config['BLOCK_SIZE_M']) A.size(0) * top_k * config['BLOCK_SIZE_M'])
grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv( grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv(
B.size(1) if not use_nn_moe else B.size[2], META['BLOCK_SIZE_N']), ) B.size(1) if not use_nn_moe else B.size(2), META['BLOCK_SIZE_N']), )
if (use_int8_w8a16 or use_int4_w4a16) and \ if (use_int8_w8a16 or use_int4_w4a16) and \
block_shape is not None and block_shape[1] > 0: block_shape is not None and block_shape[1] > 0:
...@@ -621,7 +621,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -621,7 +621,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
sorted_token_ids, sorted_token_ids,
expert_ids, expert_ids,
num_tokens_post_padded, num_tokens_post_padded,
B.size(1) if not use_nn_moe else B.size[2], B.size(1) if not use_nn_moe else B.size(2),
B.size(1), B.size(1),
EM, EM,
num_tokens, num_tokens,
...@@ -660,7 +660,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -660,7 +660,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
def get_config_file_name(E: int, def get_config_file_name(E: int,
N: int, N: int,
dtype: Optional[str], dtype: Optional[str],
block_shape: Optional[list[int]] = None, use_nn_moe: Optional[bool] = False) -> str: block_shape: Optional[List[int]] = None, use_nn_moe: Optional[bool] = False) -> str:
device_name = current_platform.get_device_name().replace(" ", "_") device_name = current_platform.get_device_name().replace(" ", "_")
dtype_selector = "" if not dtype else f",dtype={dtype}" dtype_selector = "" if not dtype else f",dtype={dtype}"
block_shape_selector = ("" if not block_shape or not all(block_shape) else block_shape_selector = ("" if not block_shape or not all(block_shape) else
...@@ -783,7 +783,7 @@ def get_default_config( ...@@ -783,7 +783,7 @@ def get_default_config(
topk: int, topk: int,
dtype: Optional[str], dtype: Optional[str],
is_marlin: bool, is_marlin: bool,
block_shape: Optional[list[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool]=False, use_nn_moe: Optional[bool]=False,
) -> dict[str, int]: ) -> dict[str, int]:
if dtype == "fp8_w8a8" and block_shape is not None: if dtype == "fp8_w8a8" and block_shape is not None:
...@@ -846,7 +846,7 @@ def try_get_optimal_moe_config( ...@@ -846,7 +846,7 @@ def try_get_optimal_moe_config(
dtype: Optional[str], dtype: Optional[str],
M: int, M: int,
is_marlin: bool = False, is_marlin: bool = False,
block_shape: Optional[list[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
) -> dict[str, int]: ) -> dict[str, int]:
from vllm.model_executor.layers.fused_moe import get_config from vllm.model_executor.layers.fused_moe import get_config
...@@ -893,9 +893,9 @@ def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor, ...@@ -893,9 +893,9 @@ def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor,
def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]: def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]:
if is_rocm_aiter_moe_enabled(): # if is_rocm_aiter_moe_enabled():
from .rocm_aiter_fused_moe import rocm_aiter_topk_softmax # from .rocm_aiter_fused_moe import rocm_aiter_topk_softmax
return rocm_aiter_topk_softmax # return rocm_aiter_topk_softmax
return vllm_topk_softmax return vllm_topk_softmax
...@@ -1033,7 +1033,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor, ...@@ -1033,7 +1033,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
w2_zp: Optional[torch.Tensor] = None, w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> None: use_nn_moe: Optional[bool] = False) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
activation, apply_router_weight_on_input, use_fp8_w8a8, activation, apply_router_weight_on_input, use_fp8_w8a8,
...@@ -1064,7 +1064,7 @@ def inplace_fused_experts_fake( ...@@ -1064,7 +1064,7 @@ def inplace_fused_experts_fake(
w2_zp: Optional[torch.Tensor] = None, w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> None: use_nn_moe: Optional[bool] = False) -> None:
pass pass
...@@ -1099,7 +1099,7 @@ def outplace_fused_experts( ...@@ -1099,7 +1099,7 @@ def outplace_fused_experts(
w2_zp: Optional[torch.Tensor] = None, w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> torch.Tensor: use_nn_moe: Optional[bool] = False) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, activation, apply_router_weight_on_input, False, activation, apply_router_weight_on_input,
...@@ -1130,7 +1130,7 @@ def outplace_fused_experts_fake( ...@@ -1130,7 +1130,7 @@ def outplace_fused_experts_fake(
w2_zp: Optional[torch.Tensor] = None, w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> torch.Tensor: use_nn_moe: Optional[bool] = False) -> torch.Tensor:
return torch.empty_like(hidden_states) return torch.empty_like(hidden_states)
...@@ -1183,7 +1183,7 @@ def fused_experts( ...@@ -1183,7 +1183,7 @@ def fused_experts(
w2_zp: Optional[torch.Tensor] = None, w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None, block_shape: Optional[List[int]] = None,
allow_deep_gemm: bool = False, allow_deep_gemm: bool = False,
allow_cutlass_block_scaled_grouped_gemm: bool = False, allow_cutlass_block_scaled_grouped_gemm: bool = False,
use_nn_moe: Optional[bool] = False) -> torch.Tensor: use_nn_moe: Optional[bool] = False) -> torch.Tensor:
...@@ -1268,7 +1268,7 @@ def fused_experts_impl( ...@@ -1268,7 +1268,7 @@ def fused_experts_impl(
w2_zp: Optional[torch.Tensor] = None, w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
# Check constraints. # Check constraints.
...@@ -1276,7 +1276,7 @@ def fused_experts_impl( ...@@ -1276,7 +1276,7 @@ def fused_experts_impl(
assert hidden_states.size(1) // 2 == w1.size(2), ( assert hidden_states.size(1) // 2 == w1.size(2), (
"Hidden size mismatch") "Hidden size mismatch")
elif use_nn_moe: elif use_nn_moe:
assert hidden_states.size[1] == w1.size[1], "Hidden size mismatch" assert hidden_states.size(1) == w1.size(1), "Hidden size mismatch"
else: else:
assert hidden_states.size(1) == w1.size(2), ( assert hidden_states.size(1) == w1.size(2), (
f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}") f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}")
...@@ -1394,6 +1394,7 @@ def fused_experts_impl( ...@@ -1394,6 +1394,7 @@ def fused_experts_impl(
w1_scale, w1_scale,
w1_zp, w1_zp,
curr_topk_weights, curr_topk_weights,
curr_topk_ids,
sorted_token_ids, sorted_token_ids,
expert_ids, expert_ids,
num_tokens_post_padded, num_tokens_post_padded,
...@@ -1432,6 +1433,7 @@ def fused_experts_impl( ...@@ -1432,6 +1433,7 @@ def fused_experts_impl(
w2_scale, w2_scale,
w2_zp, w2_zp,
curr_topk_weights, curr_topk_weights,
curr_topk_ids,
sorted_token_ids, sorted_token_ids,
expert_ids, expert_ids,
num_tokens_post_padded, num_tokens_post_padded,
...@@ -1479,7 +1481,7 @@ def fused_moe( ...@@ -1479,7 +1481,7 @@ def fused_moe(
w2_zp: Optional[torch.Tensor] = None, w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
...@@ -1525,7 +1527,7 @@ def fused_moe( ...@@ -1525,7 +1527,7 @@ def fused_moe(
a1. a1.
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for - a2_scale (Optional[torch.Tensor]): Optional scale to be used for
a2. a2.
- block_shape: (Optional[list[int]]): Optional block size for block-wise - block_shape: (Optional[List[int]]): Optional block size for block-wise
quantization. quantization.
Returns: Returns:
...@@ -1577,7 +1579,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -1577,7 +1579,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
per_act_token_quant: bool = False, per_act_token_quant: bool = False,
block_shape: Optional[list[int]] = None, block_shape: Optional[List[int]] = None,
): ):
super().__init__( super().__init__(
FusedMoEQuantConfig.make( FusedMoEQuantConfig.make(
...@@ -1762,7 +1764,7 @@ def modular_triton_fused_moe( ...@@ -1762,7 +1764,7 @@ def modular_triton_fused_moe(
use_int8_w8a16: bool, use_int8_w8a16: bool,
use_int4_w4a16: bool, use_int4_w4a16: bool,
per_act_token_quant: bool, per_act_token_quant: bool,
block_shape: Optional[list[int]] = None, block_shape: Optional[List[int]] = None,
) -> mk.FusedMoEModularKernel: ) -> mk.FusedMoEModularKernel:
return mk.FusedMoEModularKernel( return mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(), MoEPrepareAndFinalizeNoEP(),
......
...@@ -51,10 +51,10 @@ else: ...@@ -51,10 +51,10 @@ else:
fused_experts = None # type: ignore fused_experts = None # type: ignore
FusedMoEPermuteExpertsUnpermute = None # type: ignore FusedMoEPermuteExpertsUnpermute = None # type: ignore
FusedMoEPrepareAndFinalize = None # type: ignore FusedMoEPrepareAndFinalize = None # type: ignore
if is_rocm_aiter_moe_enabled(): # if is_rocm_aiter_moe_enabled():
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 # from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
rocm_aiter_grouped_topk as grouped_topk) # rocm_aiter_grouped_topk as grouped_topk)
elif current_platform.is_cpu(): if current_platform.is_cpu():
pass pass
else: else:
from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment