Commit 4ac087d9 authored by zhuwenwen's avatar zhuwenwen
Browse files

[Bugfix] adding chunking mechanism to fused_moe to handle large inputs

parent 4440e8c0
...@@ -30,6 +30,7 @@ if TYPE_CHECKING: ...@@ -30,6 +30,7 @@ if TYPE_CHECKING:
VLLM_TRACE_FUNCTION: int = 0 VLLM_TRACE_FUNCTION: int = 0
VLLM_ATTENTION_BACKEND: Optional[str] = None VLLM_ATTENTION_BACKEND: Optional[str] = None
VLLM_CPU_KVCACHE_SPACE: int = 0 VLLM_CPU_KVCACHE_SPACE: int = 0
VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024
VLLM_USE_RAY_COMPILED_DAG: bool = False VLLM_USE_RAY_COMPILED_DAG: bool = False
VLLM_WORKER_MULTIPROC_METHOD: str = "spawn" VLLM_WORKER_MULTIPROC_METHOD: str = "spawn"
VLLM_IMAGE_FETCH_TIMEOUT: int = 5 VLLM_IMAGE_FETCH_TIMEOUT: int = 5
...@@ -231,6 +232,9 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -231,6 +232,9 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_WORKER_MULTIPROC_METHOD": "VLLM_WORKER_MULTIPROC_METHOD":
lambda: os.getenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn"), lambda: os.getenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn"),
"VLLM_FUSED_MOE_CHUNK_SIZE":
lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "65536")),
# Timeout for fetching images when serving multimodal models # Timeout for fetching images when serving multimodal models
# Default is 5 seconds # Default is 5 seconds
"VLLM_IMAGE_FETCH_TIMEOUT": "VLLM_IMAGE_FETCH_TIMEOUT":
......
...@@ -8,6 +8,7 @@ import torch ...@@ -8,6 +8,7 @@ import torch
import triton import triton
import triton.language as tl import triton.language as tl
import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -331,6 +332,31 @@ def get_default_config( ...@@ -331,6 +332,31 @@ def get_default_config(
return config return config
def try_get_optimal_moe_config(
w1_shape: Tuple[int, ...],
w2_shape: Tuple[int, ...],
top_k: int,
dtype: Optional[str],
M: int,
override_config: Optional[Dict[str, Any]] = None,
):
if override_config:
config = override_config
else:
# First try to load optimal config from the file
E, _, N = w2_shape
configs = get_moe_configs(E, N, dtype)
if configs:
# If an optimal configuration map has been found, look up the
# optimal config
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else:
# Else use the default config
config = get_default_config(M, E, N, w1_shape[2], top_k, dtype)
return config
def fused_topk( def fused_topk(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
gating_output: torch.Tensor, gating_output: torch.Tensor,
...@@ -368,14 +394,16 @@ def fused_topk( ...@@ -368,14 +394,16 @@ def fused_topk(
# This is used by the Deepseek-V2 model # This is used by the Deepseek-V2 model
def grouped_topk( def grouped_topk(hidden_states: torch.Tensor,
hidden_states: torch.Tensor, gating_output: torch.Tensor,
gating_output: torch.Tensor, topk: int,
topk: int, renormalize: bool,
renormalize: bool, num_expert_group: int = 0,
num_expert_group: int = 0, topk_group: int = 0):
topk_group: int = 0,
): assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
scores = torch.softmax(gating_output, dim=-1) scores = torch.softmax(gating_output, dim=-1)
num_token = scores.shape[0] num_token = scores.shape[0]
group_scores = scores.view(num_token, num_expert_group, group_scores = scores.view(num_token, num_expert_group,
...@@ -420,25 +448,23 @@ def fused_experts(hidden_states: torch.Tensor, ...@@ -420,25 +448,23 @@ def fused_experts(hidden_states: torch.Tensor,
torch.float32, torch.float16, torch.bfloat16 torch.float32, torch.float16, torch.bfloat16
] ]
M, _ = hidden_states.shape num_tokens, _ = hidden_states.shape
E, N, _ = w1.shape E, N, _ = w1.shape
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE)
get_config_func = functools.partial(
try_get_optimal_moe_config,
w1.shape,
w2.shape,
topk_ids.shape[1],
"float8" if use_fp8 else None,
override_config=override_config,
)
if override_config: config = get_config_func(M)
config = override_config
else:
# First try to load optimal config from the file
configs = get_moe_configs(E, w2.shape[2],
"float8" if use_fp8 else None)
if configs:
# If an optimal configuration map has been found, look up the
# optimal config
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else:
# Else use the default config
config = get_default_config(M, E, N, w1.shape[2],
topk_ids.shape[1],
"float8" if use_fp8 else None)
intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N), intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
device=hidden_states.device, device=hidden_states.device,
...@@ -450,51 +476,78 @@ def fused_experts(hidden_states: torch.Tensor, ...@@ -450,51 +476,78 @@ def fused_experts(hidden_states: torch.Tensor,
device=hidden_states.device, device=hidden_states.device,
dtype=hidden_states.dtype) dtype=hidden_states.dtype)
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
topk_ids, config['BLOCK_SIZE_M'], E)
compute_type = (tl.bfloat16 compute_type = (tl.bfloat16
if hidden_states.dtype == torch.bfloat16 else tl.float16) if hidden_states.dtype == torch.bfloat16 else tl.float16)
invoke_fused_moe_kernel(hidden_states,
w1,
intermediate_cache1,
a1_scale,
w1_scale,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
topk_ids.shape[1],
config,
compute_type=compute_type,
use_fp8=use_fp8)
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
invoke_fused_moe_kernel(intermediate_cache2,
w2,
intermediate_cache3,
a2_scale,
w2_scale,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
True,
1,
config,
compute_type=compute_type,
use_fp8=use_fp8)
if inplace: if inplace:
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), out_hidden_states = hidden_states
dim=1, else:
out=hidden_states) out_hidden_states = torch.empty_like(hidden_states)
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1) for chunk in range((num_tokens // CHUNK_SIZE) + 1):
begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE,
min((chunk + 1) * CHUNK_SIZE,
num_tokens))
curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
tokens_in_chunk, _ = curr_hidden_states.shape
if tokens_in_chunk == 0:
break
if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
# Adjust the intermediate cache size and config for the last
# chunk. Note that in most cases we only have one chunk
# so the cache size and config are already set correctly and
# do not need to be adjusted.
intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
config = get_config_func(tokens_in_chunk)
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E))
invoke_fused_moe_kernel(curr_hidden_states,
w1,
intermediate_cache1,
a1_scale,
w1_scale,
curr_topk_weights,
curr_topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
topk_ids.shape[1],
config,
compute_type=compute_type,
use_fp8=use_fp8)
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
invoke_fused_moe_kernel(intermediate_cache2,
w2,
intermediate_cache3,
a2_scale,
w2_scale,
curr_topk_weights,
curr_topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
True,
1,
config,
compute_type=compute_type,
use_fp8=use_fp8)
torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1,
out=out_hidden_states[begin_chunk_idx:end_chunk_idx])
return out_hidden_states
def fused_moe( def fused_moe(
...@@ -506,6 +559,9 @@ def fused_moe( ...@@ -506,6 +559,9 @@ def fused_moe(
renormalize: bool, renormalize: bool,
inplace: bool = False, inplace: bool = False,
override_config: Optional[Dict[str, Any]] = None, override_config: Optional[Dict[str, Any]] = None,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
use_fp8: bool = False, use_fp8: bool = False,
w1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None,
...@@ -528,6 +584,10 @@ def fused_moe( ...@@ -528,6 +584,10 @@ def fused_moe(
Defaults to False. Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override - override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration. for the kernel configuration.
- num_expert_group: Optional[int]: additional parameter for grouped_topk
- topk_group: Optional[int]: additional parameter for grouped_topk
- use_grouped_topk: If True, use grouped_topk instead of fused_topk
note: Deepseekv2 model uses grouped_topk
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False. products for w1 and w2. Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
...@@ -541,8 +601,15 @@ def fused_moe( ...@@ -541,8 +601,15 @@ def fused_moe(
# Check constraints. # Check constraints.
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, if use_grouped_topk:
renormalize) assert num_expert_group is not None and topk_group is not None
topk_weights, topk_ids = grouped_topk(hidden_states, gating_output,
topk, renormalize,
num_expert_group, topk_group)
else:
topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
renormalize)
return fused_experts(hidden_states, return fused_experts(hidden_states,
w1, w1,
w2, w2,
...@@ -554,4 +621,4 @@ def fused_moe( ...@@ -554,4 +621,4 @@ def fused_moe(
w1_scale=w1_scale, w1_scale=w1_scale,
w2_scale=w2_scale, w2_scale=w2_scale,
a1_scale=a1_scale, a1_scale=a1_scale,
a2_scale=a2_scale) a2_scale=a2_scale)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment