Commit e7a963f5 authored by yangql's avatar yangql
Browse files

新增fusemoe手写算子的支持,需要group-gemm包

parent 6880bf15
......@@ -15,7 +15,11 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supports_layer)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.model_executor.layers.fused_moe import fused_experts
os.environ['W4A16_MOE_CUDA'] = os.environ.get('W4A16_MOE_CUDA', '0')
if os.environ['W4A16_MOE_CUDA'] == '1':
from vllm.model_executor.layers.quantization.utils.fused_moe_cuda import fused_experts_cuda
class MoeWNA16Config(QuantizationConfig):
"""Config class for MOE WNA16 (W8A16/W4A16) quantization."""
......@@ -176,6 +180,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
def __init__(self, quant_config: MoeWNA16Config):
    """Store the quantization config and read the kernel-selection flags.

    Args:
        quant_config: The MoE WNA16 quantization configuration for this
            method; kept on the instance as ``self.quant_config``.
    """
    self.quant_config = quant_config
    # Both flags are enabled only when the variable is exactly the string '1'.
    self.use_w4a16_moe_sz = os.environ.get('AWQ_MOE_SZ') == '1'
    # Use .get() with a '0' default, like the lookup above, so a missing
    # variable disables the flag instead of raising KeyError.
    self.use_w4a16_cuda = os.environ.get('W4A16_MOE_CUDA', '0') == '1'
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
......@@ -329,7 +334,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
assert activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
......@@ -347,6 +352,26 @@ class MoeWNA16Method(FusedMoEMethodBase):
weight_bits = self.quant_config.weight_bits
has_zp = self.quant_config.has_zp
if self.use_w4a16_cuda:
m = topk_ids.shape[0]
if m <= 64:
return fused_experts_cuda(x,
layer.w13_qweight,
layer.w2_qweight,
topk_weights,
topk_ids,
inplace=True,
use_fp8_w8a8=False,
use_int4_w4a16=weight_bits == 4,
use_int8_w8a16=False,
w1_scale=layer.w13_scales,
w2_scale=layer.w2_scales,
w1_zp=None,
w2_zp=None,
a1_scale=None,
a2_scale=None,
block_shape=[0, layer.group_size],
expert_map=expert_map)
return fused_experts(
x,
......
# SPDX-License-Identifier: Apache-2.0
"""Fused MoE kernel."""
import functools
import json
import os
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
import triton
import triton.language as tl
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size)
from grouped_gemm import moe_gemm_w4a16
from grouped_gemm.ops import permute as permute_topK, unpermute as unpermute_topK
import torch.nn.functional as F
logger = init_logger(__name__)
# Per-token-count "mode" values handed to the grouped-GEMM kernels
# (``mode=`` argument of gemm1_w4a16 / gemm2_w4a16). Keys are token counts.
# The tables are module-level constants so they are built once at import
# time instead of on every call.
_BW_GEMM1_MODES = {
    1: 83, 2: 77, 3: 32, 4: 38, 5: 38, 6: 87, 7: 82, 8: 42,
    9: 83, 10: 42, 11: 42, 12: 87, 13: 42, 14: 38, 15: 42, 16: 42,
    17: 42, 18: 87, 19: 87, 20: 83, 21: 83, 22: 83, 23: 83, 24: 27,
    25: 42, 26: 83, 27: 38, 28: 42, 29: 42, 30: 38, 31: 42, 32: 38,
}
_BW_GEMM2_MODES = {
    1: 23, 2: 88, 3: 74, 4: 39, 5: 43, 6: 88, 7: 88, 8: 89,
    9: 73, 10: 88, 11: 88, 12: 88, 13: 88, 14: 88, 15: 88, 16: 88,
    17: 88, 18: 43, 19: 88, 20: 43, 21: 43, 22: 43, 23: 88, 24: 88,
    25: 88, 26: 88, 27: 88, 28: 88, 29: 88, 30: 43, 31: 88, 32: 88,
}
# This table is intentionally sparse: token counts that are not listed
# (5, 7, 9, ...) fall back to the entry for 32.
_K100AI_GEMM1_MODES = {
    1: 79, 2: 34, 3: 34, 4: 34, 6: 34, 8: 34, 16: 34, 24: 34, 32: 34,
}
_K100AI_GEMM2_MODES = {
    1: 64, 2: 33, 3: 33, 4: 37, 5: 37, 6: 33, 7: 33, 8: 37,
    9: 37, 10: 37, 11: 37, 12: 37, 13: 37, 14: 38, 15: 38, 16: 72,
    17: 72, 18: 72, 19: 72, 20: 72, 21: 72, 22: 72, 23: 72, 24: 39,
    25: 39, 26: 39, 27: 39, 28: 39, 29: 39, 30: 39, 31: 39, 32: 39,
}


def config_cuda(M, device_name=None):
    """Return the ``(gemm1_mode, gemm2_mode)`` pair for ``M`` tokens.

    The "BW" tables are used when the device name contains the substring
    ``"BW"``; every other device uses the "k100ai" tables. A token count
    that is missing from a table resolves to that table's entry for 32.

    Args:
        M: Number of tokens used as the lookup key.
        device_name: Optional device name. When omitted (the original
            behaviour) it is queried from ``current_platform``; passing it
            lets callers avoid that query.

    Returns:
        A tuple ``(mode_1, mode_2)`` for the first and second grouped GEMM.
    """
    if device_name is None:
        device_name = current_platform.get_device_name()

    if "BW" in device_name:
        gemm1_modes, gemm2_modes = _BW_GEMM1_MODES, _BW_GEMM2_MODES
    else:
        gemm1_modes, gemm2_modes = _K100AI_GEMM1_MODES, _K100AI_GEMM2_MODES

    mode_1 = gemm1_modes.get(M, gemm1_modes[32])
    mode_2 = gemm2_modes.get(M, gemm2_modes[32])
    return mode_1, mode_2
def fused_experts_cuda(hidden_states: torch.Tensor,
                       w1: torch.Tensor,
                       w2: torch.Tensor,
                       topk_weights: torch.Tensor,
                       topk_ids: torch.Tensor,
                       inplace: bool = False,
                       use_fp8_w8a8: bool = False,
                       use_int8_w8a16: bool = False,
                       use_int4_w4a16: bool = False,
                       w1_scale: Optional[torch.Tensor] = None,
                       w2_scale: Optional[torch.Tensor] = None,
                       w1_zp: Optional[torch.Tensor] = None,
                       w2_zp: Optional[torch.Tensor] = None,
                       a1_scale: Optional[torch.Tensor] = None,
                       a2_scale: Optional[torch.Tensor] = None,
                       block_shape: Optional[List[int]] = None,
                       expert_map: Optional[torch.Tensor] = None,):
    """Entry point that forwards to ``fused_experts_impl_cuda``.

    All arguments are passed through unchanged, except that ``inplace`` is
    normalised to a plain ``True``/``False`` before forwarding.

    Returns:
        ``hidden_states`` itself when ``inplace`` is truthy, otherwise the
        tensor returned by ``fused_experts_impl_cuda``.
    """
    impl_result = fused_experts_impl_cuda(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        bool(inplace),
        use_fp8_w8a8,
        use_int8_w8a16,
        use_int4_w4a16,
        w1_scale,
        w2_scale,
        w1_zp,
        w2_zp,
        a1_scale,
        a2_scale,
        block_shape,
        expert_map,
    )
    # In in-place mode the caller gets back the very tensor it passed in.
    return hidden_states if inplace else impl_result
def fused_experts_impl_cuda(hidden_states: torch.Tensor,
                            w1: torch.Tensor,
                            w2: torch.Tensor,
                            topk_weights: torch.Tensor,
                            topk_ids: torch.Tensor,
                            inplace: bool = False,
                            use_fp8_w8a8: bool = False,
                            use_int8_w8a16: bool = False,
                            use_int4_w4a16: bool = False,
                            w1_scale: Optional[torch.Tensor] = None,
                            w2_scale: Optional[torch.Tensor] = None,
                            w1_zp: Optional[torch.Tensor] = None,
                            w2_zp: Optional[torch.Tensor] = None,
                            a1_scale: Optional[torch.Tensor] = None,
                            a2_scale: Optional[torch.Tensor] = None,
                            block_shape: Optional[List[int]] = None,
                            expert_map: Optional[torch.Tensor] = None,):
    """Run the fused-MoE expert computation with the grouped_gemm kernels.

    Tokens are processed in chunks of at most ``CHUNK_SIZE`` rows. For each
    chunk the pipeline is: ``moe_align_block_size`` -> ``gemm1_w4a16`` ->
    ``silu_and_mul`` -> ``gemm2_w4a16`` (which also receives the routing
    weights) -> ``moe_sum`` into the output rows of that chunk.

    Args:
        hidden_states: 2-D contiguous input of dtype float32/float16/bfloat16.
        w1: Contiguous 3-D expert weights for the first GEMM; its last
            dimension must equal ``hidden_states.shape[1] // 2``.
        w2: Contiguous expert weights for the second GEMM.
        topk_weights: Routing weights, same shape as ``topk_ids``.
        topk_ids: Selected expert ids per token.
        inplace: When true the result is written into ``hidden_states``.
        w1_scale, w2_scale: Passed to the kernels as their scale argument.
        expert_map: Forwarded to ``moe_align_block_size``.
        use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, w1_zp, w2_zp,
        a1_scale, a2_scale, block_shape: Accepted for signature
            compatibility; this function does not read them.

    Returns:
        The output tensor (``hidden_states`` itself when ``inplace``).

    Raises:
        AssertionError: If a shape/contiguity/dtype constraint is violated.
        ValueError: If the dtype is unsupported (still raised when asserts
            are disabled).
    """
    # Check constraints.
    assert hidden_states.shape[1] // 2 == w1.shape[
        2], "Hidden size mismatch"
    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    supported_dtypes = (torch.float32, torch.float16, torch.bfloat16)
    assert hidden_states.dtype in supported_dtypes
    # Explicit check kept so the dtype is still validated under `python -O`.
    if hidden_states.dtype not in supported_dtypes:
        raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")

    num_tokens, _ = hidden_states.shape
    E, N, _ = w1.shape
    top_k = topk_ids.shape[1]

    # We execute the fused_moe kernel in chunks to circumvent this issue:
    # https://github.com/vllm-project/vllm/issues/5938
    CHUNK_SIZE = 32768
    M = min(num_tokens, CHUNK_SIZE)

    # Block size handed to moe_align_block_size; the same value is used to
    # trim expert_ids below.
    ALIGN_BLOCK_SIZE = 16
    # NOTE(review): the group size is fixed at 64 and `block_shape` is not
    # consulted -- confirm the kernels only target 64-wide groups.
    GROUP_SIZE = 64

    intermediate_cache1 = torch.empty((M, top_k, N),
                                      device=hidden_states.device,
                                      dtype=hidden_states.dtype)
    intermediate_cache2 = torch.empty((M * top_k, N // 2),
                                      device=hidden_states.device,
                                      dtype=hidden_states.dtype)
    intermediate_cache3 = torch.empty((M, top_k, w2.shape[1]),
                                      device=hidden_states.device,
                                      dtype=hidden_states.dtype)

    if inplace:
        out_hidden_states = hidden_states
    else:
        out_hidden_states = torch.empty_like(hidden_states)

    # The kernel modes depend only on M, so look them up once rather than
    # once per chunk.
    # NOTE(review): a shorter final chunk still uses the modes chosen for M;
    # confirm whether tokens_in_chunk was the intended key.
    mode_1, mode_2 = config_cuda(M)

    for chunk in range((num_tokens // CHUNK_SIZE) + 1):
        begin_chunk_idx = chunk * CHUNK_SIZE
        end_chunk_idx = min(begin_chunk_idx + CHUNK_SIZE, num_tokens)
        curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
        tokens_in_chunk, _ = curr_hidden_states.shape

        if tokens_in_chunk == 0:
            break

        if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
            # Shrink the intermediate caches for the last chunk. In most
            # cases there is only one chunk, so the caches already have the
            # right size and need no adjustment.
            intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
            intermediate_cache2 = intermediate_cache2[:tokens_in_chunk *
                                                      top_k]
            intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]

        curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]

        sorted_token_ids, expert_ids, num_tokens_post_padded = (
            moe_align_block_size(curr_topk_ids, ALIGN_BLOCK_SIZE, E,
                                 expert_map, tokens_in_chunk))
        expert_ids = expert_ids[:num_tokens_post_padded // ALIGN_BLOCK_SIZE]

        # Cast once and reuse for both GEMMs (the original cast twice).
        # NOTE(review): uint16 holds values up to 65535 only -- confirm the
        # sorted ids always fit for the chunk sizes used in practice.
        sorted_token_ids_u16 = sorted_token_ids.to(torch.uint16)

        moe_gemm_w4a16.gemm1_w4a16(sorted_token_ids_u16,
                                   curr_hidden_states,   # hidden_states
                                   w1,                   # w1
                                   intermediate_cache1,  # gemm1_out
                                   expert_ids,           # expert_id_vec
                                   w1_scale,             # scale_zero
                                   GROUP_SIZE,           # group_size
                                   topk=top_k,
                                   mode=mode_1)

        torch.ops._C.silu_and_mul(intermediate_cache2,
                                  intermediate_cache1.view(-1, N))

        moe_gemm_w4a16.gemm2_w4a16(sorted_token_ids_u16,
                                   intermediate_cache2,  # hidden_states
                                   w2,                   # w2
                                   intermediate_cache3,  # gemm2_out
                                   expert_ids,           # expert_id_vec
                                   w2_scale,             # scale_zero
                                   curr_topk_weights,    # topk_weights
                                   GROUP_SIZE,           # group_size
                                   topk=top_k,
                                   mode=mode_2)

        # Reduce the per-expert outputs into this chunk's output rows.
        ops.moe_sum(intermediate_cache3,
                    out_hidden_states[begin_chunk_idx:end_chunk_idx])

    return out_hidden_states
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment