"vllm/vscode:/vscode.git/clone" did not exist on "428dd1445ee3750099967084725849c4920721a5"
Unverified Commit 94096a47 authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[UX] Separate marlin moe config logic from triton moe (#23006)

parent a258ad8b
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Fused MoE utilities for GPTQ.""" """Fused MoE utilities for GPTQ."""
import functools
from typing import Optional from typing import Optional
import torch import torch
import vllm._custom_ops as ops import vllm._custom_ops as ops
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size
moe_align_block_size, try_get_optimal_moe_config)
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_make_workspace_new, maybe_warn_marlin_atomic_add) marlin_make_workspace_new, maybe_warn_marlin_atomic_add)
from vllm.scalar_type import ScalarType, scalar_types from vllm.scalar_type import ScalarType, scalar_types
...@@ -98,17 +96,11 @@ def fused_marlin_moe(hidden_states: torch.Tensor, ...@@ -98,17 +96,11 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
N = w2.shape[1] * 16 N = w2.shape[1] * 16
topk = topk_ids.shape[1] topk = topk_ids.shape[1]
get_config_func = functools.partial( # M block size selection logic
try_get_optimal_moe_config, # TODO: tune this further for specific models
w1.shape, for block_size_m in [8, 16, 32, 48, 64]:
w2.shape, if M * topk / E / block_size_m < 0.9:
topk_ids.shape[1], break
None,
is_marlin=True,
)
config = get_config_func(M)
block_size_m = config["BLOCK_SIZE_M"]
if global_num_experts == -1: if global_num_experts == -1:
global_num_experts = E global_num_experts = E
......
...@@ -801,7 +801,6 @@ def get_default_config( ...@@ -801,7 +801,6 @@ def get_default_config(
K: int, K: int,
topk: int, topk: int,
dtype: Optional[str], dtype: Optional[str],
is_marlin: bool,
block_shape: Optional[list[int]] = None, block_shape: Optional[list[int]] = None,
) -> dict[str, int]: ) -> dict[str, int]:
if dtype == "fp8_w8a8" and block_shape is not None: if dtype == "fp8_w8a8" and block_shape is not None:
...@@ -832,11 +831,6 @@ def get_default_config( ...@@ -832,11 +831,6 @@ def get_default_config(
config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1} config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1}
else: else:
config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1} config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1}
elif is_marlin:
for block_size_m in [8, 16, 32, 48, 64]:
if M * topk / E / block_size_m < 0.9:
break
return {"BLOCK_SIZE_M": block_size_m}
elif M <= E: elif M <= E:
config = { config = {
"BLOCK_SIZE_M": 16, "BLOCK_SIZE_M": 16,
...@@ -860,7 +854,6 @@ def try_get_optimal_moe_config( ...@@ -860,7 +854,6 @@ def try_get_optimal_moe_config(
top_k: int, top_k: int,
dtype: Optional[str], dtype: Optional[str],
M: int, M: int,
is_marlin: bool = False,
block_shape: Optional[list[int]] = None, block_shape: Optional[list[int]] = None,
) -> dict[str, int]: ) -> dict[str, int]:
from vllm.model_executor.layers.fused_moe import get_config from vllm.model_executor.layers.fused_moe import get_config
...@@ -883,7 +876,7 @@ def try_get_optimal_moe_config( ...@@ -883,7 +876,7 @@ def try_get_optimal_moe_config(
else: else:
# Else use the default config # Else use the default config
config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, config = get_default_config(M, E, N, w1_shape[2], top_k, dtype,
is_marlin, block_shape) block_shape)
return config return config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment