Unverified Commit 53a7ebd8 authored by Lianmin Zheng, committed by GitHub

Update fused_moe (#553)

parent ad5f04d6
@@ -9,9 +9,9 @@ from typing import Any, Dict, Optional, Tuple
 import torch
 import triton
 import triton.language as tl
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
-from vllm.utils import is_hip

 logger = init_logger(__name__)
@@ -108,16 +108,12 @@ def fused_moe_kernel(
     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
     offs_k = tl.arange(0, BLOCK_SIZE_K)
-    a_ptrs = a_ptr + (
-        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
-    )
+    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
+                      offs_k[None, :] * stride_ak)

     off_experts = tl.load(expert_ids_ptr + pid_m)
-    b_ptrs = (
-        b_ptr
-        + off_experts * stride_be
-        + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
-    )
+    b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
+                                                offs_bn[None, :] * stride_bn)

     if use_fp8:
         a_scale = tl.load(a_scale_ptr)
@@ -133,12 +129,13 @@ def fused_moe_kernel(
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         # Load the next block of A and B, generate a mask by checking the
         # K dimension.
-        a = tl.load(
-            a_ptrs,
-            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
-            other=0.0,
-        )
-        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+        a = tl.load(a_ptrs,
+                    mask=token_mask[:, None] &
+                    (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+                    other=0.0)
+        b = tl.load(b_ptrs,
+                    mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
+                    other=0.0)
         # We accumulate along the K dimension.
         if use_fp8:
             accumulator = tl.dot(a, b, acc=accumulator)
@@ -149,7 +146,9 @@ def fused_moe_kernel(
         b_ptrs += BLOCK_SIZE_K * stride_bk

     if MUL_ROUTED_WEIGHT:
-        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
+        moe_weight = tl.load(topk_weights_ptr + offs_token,
+                             mask=token_mask,
+                             other=0)
         accumulator = accumulator * moe_weight[:, None]

     if use_fp8:
@@ -159,14 +158,15 @@ def fused_moe_kernel(
     # -----------------------------------------------------------
     # Write back the block of the output
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
+    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
+        None, :]
     c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
     tl.store(c_ptrs, accumulator, mask=c_mask)


 def moe_align_block_size(
-    topk_ids: torch.Tensor, block_size: int, num_experts: int
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        topk_ids: torch.Tensor, block_size: int,
+        num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Aligns the token distribution across experts to be compatible with block
     size for matrix multiplication.
@@ -205,38 +205,32 @@ def moe_align_block_size(
     by block_size for proper block matrix operations.
     """
     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
-    sorted_ids = torch.empty(
-        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
-    )
+    sorted_ids = torch.empty((max_num_tokens_padded, ),
+                             dtype=torch.int32,
+                             device=topk_ids.device)
     sorted_ids.fill_(topk_ids.numel())
     max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
-    expert_ids = torch.empty(
-        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
-    )
-    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
-    ops.moe_align_block_size(
-        topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
-    )
+    expert_ids = torch.empty((max_num_m_blocks, ),
+                             dtype=torch.int32,
+                             device=topk_ids.device)
+    num_tokens_post_pad = torch.empty((1),
+                                      dtype=torch.int32,
+                                      device=topk_ids.device)
+    ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
+                             expert_ids, num_tokens_post_pad)
     return sorted_ids, expert_ids, num_tokens_post_pad


-def invoke_fused_moe_kernel(
-    A: torch.Tensor,
-    B: torch.Tensor,
-    C: torch.Tensor,
-    A_scale: Optional[torch.Tensor],
-    B_scale: Optional[torch.Tensor],
-    topk_weights: torch.Tensor,
-    topk_ids: torch.Tensor,
-    sorted_token_ids: torch.Tensor,
-    expert_ids: torch.Tensor,
-    num_tokens_post_padded: torch.Tensor,
-    mul_routed_weight: bool,
-    top_k: int,
-    config: Dict[str, Any],
-    compute_type: tl.dtype,
-    use_fp8: bool,
-) -> None:
+def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
+                            A_scale: Optional[torch.Tensor],
+                            B_scale: Optional[torch.Tensor],
+                            topk_weights: torch.Tensor, topk_ids: torch.Tensor,
+                            sorted_token_ids: torch.Tensor,
+                            expert_ids: torch.Tensor,
+                            num_tokens_post_padded: torch.Tensor,
+                            mul_routed_weight: bool, top_k: int,
+                            config: Dict[str, Any], compute_type: tl.dtype,
+                            use_fp8: bool) -> None:
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
@@ -247,10 +241,8 @@ def invoke_fused_moe_kernel(
         A, A_scale = ops.scaled_fp8_quant(A, A_scale)
         assert B_scale is not None

-    grid = lambda META: (
-        triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
-        * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
-    )
+    grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
+        'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )

     fused_moe_kernel[grid](
         A,
@@ -288,7 +280,8 @@ def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
 @functools.lru_cache
-def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]:
+def get_moe_configs(E: int, N: int,
+                    dtype: Optional[str]) -> Optional[Dict[int, Any]]:
     """
     Return optimized configurations for the fused MoE kernel.
@@ -303,11 +296,11 @@ def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int,
     json_file_name = get_config_file_name(E, N, dtype)

     config_file_path = os.path.join(
-        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
-    )
+        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
     if os.path.exists(config_file_path):
         with open(config_file_path) as f:
-            logger.info("Using configuration from %s for MoE layer.", config_file_path)
+            logger.info("Using configuration from %s for MoE layer.",
+                        config_file_path)
             # If a configuration has been found, return it
             return {int(key): val for key, val in json.load(f).items()}
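A concrete example of the mapping this returns may help when reading the lookup in the hunks below; only the key names come from the code in this diff, and every number here is an invented placeholder.

# Sketch of a loaded MoE config map: outer keys are token counts M (converted
# to int above), inner dicts are Triton launch parameters. All values are
# illustrative assumptions, not taken from any shipped config file.
configs = {
    1: {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    64: {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32,
         "GROUP_SIZE_M": 8, "num_warps": 8, "num_stages": 4},
}
# The callers below pick the entry whose key is closest to the batch's token
# count, e.g. M = 48 selects the 64-token entry:
config = configs[min(configs.keys(), key=lambda x: abs(x - 48))]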
@@ -316,87 +309,97 @@ def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int,
     return None


-def fused_moe(
-    hidden_states: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-    inplace: bool = False,
-    override_config: Optional[Dict[str, Any]] = None,
-    use_fp8: bool = False,
-    w1_scale: Optional[torch.Tensor] = None,
-    w2_scale: Optional[torch.Tensor] = None,
-    a1_scale: Optional[torch.Tensor] = None,
-    a2_scale: Optional[torch.Tensor] = None,
-) -> torch.Tensor:
-    """
-    This function computes a Mixture of Experts (MoE) layer using two sets of
-    weights, w1 and w2, and top-k gating mechanism.
-
-    Parameters:
-    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
-    - w1 (torch.Tensor): The first set of expert weights.
-    - w2 (torch.Tensor): The second set of expert weights.
-    - gating_output (torch.Tensor): The output of the gating operation
-        (before softmax).
-    - topk (int): The number of top-k experts to select.
-    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
-    - inplace (bool): If True, perform the operation in-place.
-        Defaults to False.
-    - override_config (Optional[Dict[str, Any]]): Optional override
-        for the kernel configuration.
-    - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
-        products for w1 and w2. Defaults to False.
-    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
-        w1.
-    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
-        w2.
-
-    Returns:
-    - torch.Tensor: The output tensor after applying the MoE layer.
-    """
-    # Check constraints.
-    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
-    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
-    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
-    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
-    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
-    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
-    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
-    M, _ = hidden_states.shape
-    E, N, _ = w1.shape
-
-    if is_hip():
-        # The MoE kernels are not yet supported on ROCm.
-        routing_weights = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
-        topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
-    else:
-        import vllm._moe_C as moe_kernels
-
-        topk_weights = torch.empty(
-            M, topk, dtype=torch.float32, device=hidden_states.device
-        )
-        topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
-        token_expert_indicies = torch.empty(
-            M, topk, dtype=torch.int32, device=hidden_states.device
-        )
-        moe_kernels.topk_softmax(
-            topk_weights,
-            topk_ids,
-            token_expert_indicies,
-            gating_output.float(),  # TODO(woosuk): Optimize this.
-        )
-        del token_expert_indicies  # Not used. Will be used in the future.
-
-        if renormalize:
-            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+def get_default_config(
+    M: int,
+    E: int,
+    N: int,
+    K: int,
+    topk: int,
+    dtype: Optional[str],
+) -> Dict[str, int]:
+    config = {
+        'BLOCK_SIZE_M': 64,
+        'BLOCK_SIZE_N': 64,
+        'BLOCK_SIZE_K': 32,
+        'GROUP_SIZE_M': 8
+    }
+    if M <= E:
+        config = {
+            'BLOCK_SIZE_M': 16,
+            'BLOCK_SIZE_N': 32,
+            'BLOCK_SIZE_K': 64,
+            'GROUP_SIZE_M': 1
+        }
+    return config
+
+
+def fused_topk(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
+    assert hidden_states.shape[0] == gating_output.shape[0], (
+        "Number of tokens mismatch")
+
+    M, _ = hidden_states.shape
+
+    topk_weights = torch.empty(M,
+                               topk,
+                               dtype=torch.float32,
+                               device=hidden_states.device)
+    topk_ids = torch.empty(M,
+                           topk,
+                           dtype=torch.int32,
+                           device=hidden_states.device)
+    token_expert_indicies = torch.empty(M,
+                                        topk,
+                                        dtype=torch.int32,
+                                        device=hidden_states.device)
+    ops.topk_softmax(
+        topk_weights,
+        topk_ids,
+        token_expert_indicies,
+        gating_output.float(),  # TODO(woosuk): Optimize this.
+    )
+    del token_expert_indicies  # Not used. Will be used in the future.
+
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    return topk_weights, topk_ids
+
+
+def fused_experts(hidden_states: torch.Tensor,
+                  w1: torch.Tensor,
+                  w2: torch.Tensor,
+                  topk_weights: torch.Tensor,
+                  topk_ids: torch.Tensor,
+                  inplace: bool = False,
+                  override_config: Optional[Dict[str, Any]] = None,
+                  use_fp8: bool = False,
+                  w1_scale: Optional[torch.Tensor] = None,
+                  w2_scale: Optional[torch.Tensor] = None,
+                  a1_scale: Optional[torch.Tensor] = None,
+                  a2_scale: Optional[torch.Tensor] = None):
+    # Check constraints.
+    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
+    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
+    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
+    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
+    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
+    assert hidden_states.dtype in [
+        torch.float32, torch.float16, torch.bfloat16
+    ]
+
+    M, _ = hidden_states.shape
+    E, N, _ = w1.shape

     if override_config:
         config = override_config
     else:
         # First try to load optimal config from the file
-        configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)
+        configs = get_moe_configs(E, w2.shape[2],
+                                  "float8" if use_fp8 else None)

         if configs:
             # If an optimal configuration map has been found, look up the
@@ -404,48 +407,26 @@ def fused_moe(
             config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
         else:
             # Else use the default config
-            config = {
-                "BLOCK_SIZE_M": 128,
-                "BLOCK_SIZE_N": 64,
-                "BLOCK_SIZE_K": 128,
-                "GROUP_SIZE_M": 1,
-                "num_warps": 4,
-                "num_stages": 4,
-            }
-            if M <= E:
-                config = {
-                    "BLOCK_SIZE_M": 128,
-                    "BLOCK_SIZE_N": 256,
-                    "BLOCK_SIZE_K": 128,
-                    "GROUP_SIZE_M": 16,
-                    "num_warps": 8,
-                    "num_stages": 4,
-                }
+            config = get_default_config(M, E, N, w1.shape[2],
+                                        topk_ids.shape[1],
+                                        "float8" if use_fp8 else None)

-    intermediate_cache1 = torch.empty(
-        (M, topk_ids.shape[1], N),
-        device=hidden_states.device,
-        dtype=hidden_states.dtype,
-    )
-    intermediate_cache2 = torch.empty(
-        (M * topk_ids.shape[1], N // 2),
-        device=hidden_states.device,
-        dtype=hidden_states.dtype,
-    )
-    intermediate_cache3 = torch.empty(
-        (M, topk_ids.shape[1], w2.shape[1]),
-        device=hidden_states.device,
-        dtype=hidden_states.dtype,
-    )
+    intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
+                                      device=hidden_states.device,
+                                      dtype=hidden_states.dtype)
+    intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
+                                      device=hidden_states.device,
+                                      dtype=hidden_states.dtype)
+    intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
+                                      device=hidden_states.device,
+                                      dtype=hidden_states.dtype)

     sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
-        topk_ids, config["BLOCK_SIZE_M"], E
-    )
-    compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
+        topk_ids, config['BLOCK_SIZE_M'], E)
+    compute_type = (tl.bfloat16
+                    if hidden_states.dtype == torch.bfloat16 else tl.float16)

-    invoke_fused_moe_kernel(
-        hidden_states,
+    invoke_fused_moe_kernel(hidden_states,
                             w1,
                             intermediate_cache1,
                             a1_scale,
@@ -459,13 +440,11 @@ def fused_moe(
                             topk_ids.shape[1],
                             config,
                             compute_type=compute_type,
-        use_fp8=use_fp8,
-    )
+                            use_fp8=use_fp8)

     ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))

-    invoke_fused_moe_kernel(
-        intermediate_cache2,
+    invoke_fused_moe_kernel(intermediate_cache2,
                             w2,
                             intermediate_cache3,
                             a2_scale,
@@ -479,13 +458,71 @@ def fused_moe(
                             1,
                             config,
                             compute_type=compute_type,
-        use_fp8=use_fp8,
-    )
+                            use_fp8=use_fp8)

     if inplace:
-        return torch.sum(
-            intermediate_cache3.view(*intermediate_cache3.shape),
-            dim=1,
-            out=hidden_states,
-        )
-    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
+        return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
+                         dim=1,
+                         out=hidden_states)
+    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
+                     dim=1)
+
+
+def fused_moe(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    inplace: bool = False,
+    override_config: Optional[Dict[str, Any]] = None,
+    use_fp8: bool = False,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    """
+    This function computes a Mixture of Experts (MoE) layer using two sets of
+    weights, w1 and w2, and top-k gating mechanism.
+
+    Parameters:
+    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
+    - w1 (torch.Tensor): The first set of expert weights.
+    - w2 (torch.Tensor): The second set of expert weights.
+    - gating_output (torch.Tensor): The output of the gating operation
+        (before softmax).
+    - topk (int): The number of top-k experts to select.
+    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
+    - inplace (bool): If True, perform the operation in-place.
+        Defaults to False.
+    - override_config (Optional[Dict[str, Any]]): Optional override
+        for the kernel configuration.
+    - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
+        products for w1 and w2. Defaults to False.
+    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
+        w1.
+    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
+        w2.
+
+    Returns:
+    - torch.Tensor: The output tensor after applying the MoE layer.
+    """
+    # Check constraints.
+    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
+
+    topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
+                                        renormalize)
+    return fused_experts(hidden_states,
+                         w1,
+                         w2,
+                         topk_weights,
+                         topk_ids,
+                         inplace=inplace,
+                         override_config=override_config,
+                         use_fp8=use_fp8,
+                         w1_scale=w1_scale,
+                         w2_scale=w2_scale,
+                         a1_scale=a1_scale,
+                         a2_scale=a2_scale)
\ No newline at end of file
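A minimal usage sketch of the API after this change, for orientation only; the import path, tensor shapes, and dtypes below are assumptions for illustration and are not part of the commit.

import torch

# Import path assumed; adjust to wherever this module lives in your tree.
from sglang.srt.layers.fused_moe import fused_experts, fused_moe, fused_topk

M, K, E, topk = 4, 512, 8, 2      # tokens, hidden size, num experts, experts per token
N = 1024                          # rows of w1 (gate + up projections, so w2 consumes N // 2)

hidden_states = torch.randn(M, K, dtype=torch.float16, device="cuda")
w1 = torch.randn(E, N, K, dtype=torch.float16, device="cuda")       # per-expert first projection
w2 = torch.randn(E, K, N // 2, dtype=torch.float16, device="cuda")  # per-expert second projection
gating_output = torch.randn(M, E, dtype=torch.float16, device="cuda")

# One-shot wrapper: routing + expert computation.
out = fused_moe(hidden_states, w1, w2, gating_output, topk, renormalize=True)

# Equivalent two-step form introduced by this change: the routing result can
# now be reused or customized separately from the expert GEMMs.
topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, renormalize=True)
out2 = fused_experts(hidden_states, w1, w2, topk_weights, topk_ids)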