"vllm/vscode:/vscode.git/clone" did not exist on "70af44fd1051b629ff22d98ebbba723e47221886"
Commit 99ffef47 authored by zhuwenwen's avatar zhuwenwen
Browse files

修复qwen3-moe的awq配置导致fp16加载错误,修复dpsk-moe手写算子首字耗时增加问题

parent 550a1e5e
...@@ -354,7 +354,7 @@ class MoeWNA16Method(FusedMoEMethodBase): ...@@ -354,7 +354,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
has_zp = self.quant_config.has_zp has_zp = self.quant_config.has_zp
if self.use_w4a16_cuda: if self.use_w4a16_cuda:
m = topk_ids.shape[0] m = topk_ids.shape[0]
if m <= 64: if m <= 512:
return fused_experts_cuda(x, return fused_experts_cuda(x,
layer.w13_qweight, layer.w13_qweight,
layer.w2_qweight, layer.w2_qweight,
......
...@@ -21,136 +21,194 @@ from grouped_gemm import moe_gemm_w4a16 ...@@ -21,136 +21,194 @@ from grouped_gemm import moe_gemm_w4a16
from grouped_gemm.ops import permute as permute_topK, unpermute as unpermute_topK from grouped_gemm.ops import permute as permute_topK, unpermute as unpermute_topK
import torch.nn.functional as F import torch.nn.functional as F
logger = init_logger(__name__) logger = init_logger(__name__)
device_name = current_platform.get_device_name()
def config_cuda(M): def config_cuda(M):
bw_gemm1_mode_dict = {
1: 83, k100ai_gemm1_m_to_mode_dict = {
2: 77, 1: 'M16N16K256NN1NW8B240',
3: 32, 2: 'M16N16K256NN1NW16B360',
4: 38, 3: 'M16N16K256NN1NW16B360',
5: 38, 4: 'M16N16K256NN1NW16B360',
6: 87, 5: 'M16N16K256NN1NW16B360',
7: 82, 6: 'M16N16K256NN1NW16B360',
8: 42, 7: 'M16N16K256NN1NW16B360',
9: 83, 8: 'M16N16K256NN1NW16B240',
10: 42, 9: 'M16N16K256NN1NW16B360',
11: 42, 10: 'M16N16K256NN1NW16B360',
12: 87, 11: 'M16N16K256NN1NW16B360',
13: 42, 12: 'M16N16K256NN1NW16B360',
14: 38, 13: 'M16N16K256NN1NW16B360',
15: 42, 14: 'M16N16K256NN1NW16B360',
16: 42, 15: 'M16N16K256NN1NW16B360',
17: 42, 16: 'M16N16K256NN1NW16B240',
18: 87, 17: 'M16N16K256NN1NW16B360',
19: 87, 18: 'M16N16K256NN1NW16B360',
20: 83, 19: 'M16N16K256NN1NW16B360',
21: 83, 20: 'M16N16K256NN1NW16B360',
22: 83, 21: 'M16N16K256NN1NW16B360',
23: 83, 22: 'M16N16K256NN1NW16B360',
24: 27, 23: 'M16N16K256NN1NW16B360',
25: 42, 24: 'M16N16K256NN1NW16B360',
26: 83, 25: 'M16N16K256NN1NW16B360',
27: 38, 26: 'M16N16K256NN1NW16B360',
28: 42, 27: 'M16N16K256NN1NW16B240',
29: 42, 28: 'M16N16K256NN1NW16B360',
30: 38, 29: 'M16N16K256NN1NW16B360',
31: 42, 30: 'M16N16K256NN1NW16B360',
32: 38 31: 'M16N16K256NN1NW16B240',
32: 'M16N16K256NN1NW16B360',
64: 'M16N16K256NN1NW16B360',
128: 'M16N16K256NN1NW16B240',
256: 'M16N16K256NN1NW16B360',
512: 'M16N16K128NN1NW8B120',
768: 'M16N32K128NN1NW16B100',
1024: 'M16N32K128NN1NW16B120',
} }
bw_gemm2_mode_dict = {
1: 23,
2: 88, k100ai_gemm2_m_to_mode_dict = {
3: 74, 1: 'M16N32K256NN8NW1B240',
4: 39, 2: 'M16N32K256NN8NW1B360',
5: 43, 3: 'M16N32K256NN8NW1B360',
6: 88, 4: 'M16N32K256NN4NW1B360',
7: 88, 5: 'M16N32K256NN4NW1B360',
8: 89, 6: 'M16N32K256NN4NW1B360',
9: 73, 7: 'M16N32K256NN4NW1B360',
10: 88, 8: 'M16N32K256NN8NW1B360',
11: 88, 9: 'M16N32K256NN8NW1B240',
12: 88, 10: 'M16N32K256NN8NW1B240',
13: 88, 11: 'M16N32K256NN8NW1B240',
14: 88, 12: 'M16N32K256NN8NW1B240',
15: 88, 13: 'M16N32K256NN4NW1B360',
16: 88, 14: 'M16N32K256NN16NW1B360',
17: 88, 15: 'M16N32K256NN16NW1B360',
18: 43, 16: 'M16N32K256NN16NW1B360',
19: 88, 17: 'M16N32K256NN8NW1B240',
20: 43, 18: 'M16N32K256NN8NW1B240',
21: 43, 19: 'M16N32K256NN16NW1B360',
22: 43, 20: 'M16N32K256NN16NW1B360',
23: 88, 21: 'M16N32K256NN16NW1B360',
24: 88, 22: 'M16N32K256NN16NW1B360',
25: 88, 23: 'M16N32K256NN16NW1B360',
26: 88, 24: 'M16N32K256NN16NW1B240',
27: 88, 25: 'M16N32K256NN16NW1B360',
28: 88, 26: 'M16N32K256NN16NW1B360',
29: 88, 27: 'M16N32K256NN16NW1B360',
30: 43, 28: 'M16N32K256NN16NW1B360',
31: 88, 29: 'M16N64K256NN4NW1B240',
32: 88 30: 'M16N32K256NN16NW1B360',
31: 'M16N32K256NN16NW1B360',
32: 'M16N32K256NN16NW1B240',
64: 'M16N32K256NN16NW1B360',
128: 'M16N64K256NN4NW1B240',
256: 'M16N32K256NN16NW1B360',
512: 'M16N64K256NN8NW1B120',
768: 'M16N64K256NN16NW1B360',
1024: 'M16N64K256NN16NW1B360',
} }
k100ai_gemm1_mode_dict = {
1: 79,
2: 34, bw_gemm1_m_to_mode_dict = {
3: 34, 1: 'M16N16K256NN1NW8B360',
4: 34, 2: 'M16N16K256NN1NW4B360',
6: 34, 3: 'M16N32K256NN1NW8B240',
8: 34, 4: 'M16N32K256NN1NW4B360',
16: 34, 5: 'M16N64K256NN1NW4B240',
24: 34, 6: 'M16N32K256NN1NW8B240',
32: 34, 7: 'M16N32K256NN1NW8B360',
8: 'M16N64K256NN1NW4B360',
9: 'M16N64K256NN1NW4B240',
10: 'M16N32K256NN1NW8B240',
11: 'M16N64K256NN1NW4B240',
12: 'M16N64K256NN1NW4B360',
13: 'M16N32K256NN1NW8B240',
14: 'M16N32K256NN1NW8B240',
15: 'M16N32K256NN1NW8B240',
16: 'M16N64K256NN1NW4B360',
17: 'M16N32K256NN1NW8B240',
18: 'M16N64K256NN1NW4B240',
19: 'M16N32K256NN1NW8B240',
20: 'M16N32K256NN1NW8B240',
21: 'M16N32K256NN1NW8B240',
22: 'M16N32K256NN1NW8B240',
23: 'M16N32K256NN1NW8B240',
24: 'M16N64K256NN1NW4B240',
25: 'M16N32K256NN1NW8B240',
26: 'M16N32K256NN1NW8B240',
27: 'M16N32K256NN1NW8B240',
28: 'M16N32K256NN1NW8B240',
29: 'M16N64K256NN1NW4B240',
30: 'M16N64K256NN1NW4B240',
31: 'M16N32K256NN1NW8B240',
32: 'M16N64K256NN1NW4B240',
64: 'M16N32K256NN1NW8B240',
128: 'M16N64K256NN1NW4B240',
256: 'M16N64K256NN1NW4B240',
512: 'M16N64K256NN1NW4B240',
768: 'M16N64K256NN1NW4B240',
1024: 'M16N64K256NN1NW4B240',
} }
k100ai_gemm2_mode_dict = {
1: 64,
2: 33, bw_gemm2_m_to_mode_dict = {
3: 33, 1: 'M16N32K128NN8NW1B240',
4: 37, 2: 'M16N64K256NN8NW1B240',
5: 37, 3: 'M16N64K256NN4NW1B360',
6: 33, 4: 'M16N64K256NN16NW1B240',
7: 33, 5: 'M16N64K256NN8NW1B240',
8: 37, 6: 'M16N64K256NN8NW1B240',
9: 37, 7: 'M16N64K256NN16NW1B240',
10: 37, 8: 'M16N64K256NN8NW1B240',
11: 37, 9: 'M16N64K256NN16NW1B360',
12: 37, 10: 'M16N64K256NN8NW1B240',
13: 37, 11: 'M16N64K256NN16NW1B360',
14: 38, 12: 'M16N64K256NN8NW1B240',
15: 38, 13: 'M16N64K256NN16NW1B240',
16: 72, 14: 'M16N64K256NN16NW1B360',
17: 72, 15: 'M16N64K256NN16NW1B240',
18: 72, 16: 'M16N64K256NN16NW1B240',
19: 72, 17: 'M16N64K256NN8NW1B240',
20: 72, 18: 'M16N64K256NN8NW1B240',
21: 72, 19: 'M16N64K256NN16NW1B240',
22: 72, 20: 'M16N64K256NN8NW1B240',
23: 72, 21: 'M16N64K256NN16NW1B240',
24: 39, 22: 'M16N64K256NN16NW1B360',
25: 39, 23: 'M16N64K256NN16NW1B360',
26: 39, 24: 'M16N64K256NN16NW1B240',
27: 39, 25: 'M16N64K256NN8NW1B240',
28: 39, 26: 'M16N64K256NN16NW1B240',
29: 39, 27: 'M16N64K256NN16NW1B240',
30: 39, 28: 'M16N64K256NN16NW1B240',
31: 39, 29: 'M16N64K256NN16NW1B240',
32: 39, 30: 'M16N64K256NN16NW1B240',
31: 'M16N64K256NN8NW1B240',
32: 'M16N64K256NN16NW1B240',
64: 'M16N64K256NN16NW1B240',
128: 'M16N64K256NN16NW1B240',
256: 'M16N64K256NN16NW1B240',
512: 'M16N64K256NN16NW1B240',
768: 'M16N64K256NN16NW1B240',
1024: 'M16N64K256NN16NW1B240',
} }
device_name = device_name = current_platform.get_device_name() reference_points = [32, 64, 128, 256, 512, 1024]
if "BW" in device_name: NearestM = -1
gemm1_mode_dict = bw_gemm1_mode_dict
gemm2_mode_dict = bw_gemm2_mode_dict if M <= 32:
NearestM = M
else: else:
gemm1_mode_dict = k100ai_gemm1_mode_dict NearestM = min(reference_points, key=lambda x: abs(x - M))
gemm2_mode_dict = k100ai_gemm2_mode_dict
mode_1 = gemm1_mode_dict.get(M, gemm1_mode_dict[32]) if device_name == "K100_AI":
mode_2 = gemm2_mode_dict.get(M, gemm2_mode_dict[32]) mode_1 = k100ai_gemm1_m_to_mode_dict.get(M, k100ai_gemm1_m_to_mode_dict[NearestM])
mode_2 = k100ai_gemm2_m_to_mode_dict.get(M, k100ai_gemm2_m_to_mode_dict[NearestM])
else:
mode_1 = bw_gemm1_m_to_mode_dict.get(M, k100ai_gemm1_m_to_mode_dict[NearestM])
mode_2 = bw_gemm2_m_to_mode_dict.get(M, k100ai_gemm2_m_to_mode_dict[NearestM])
return mode_1, mode_2 return mode_1, mode_2
def fused_experts_cuda(hidden_states: torch.Tensor, def fused_experts_cuda(hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
...@@ -218,18 +276,19 @@ def fused_experts_impl_cuda(hidden_states: torch.Tensor, ...@@ -218,18 +276,19 @@ def fused_experts_impl_cuda(hidden_states: torch.Tensor,
E, N, _ = w1.shape E, N, _ = w1.shape
# We execute the fused_moe kernel in chunks to circumvent this issue: # We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938 # https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = 32768
M = min(num_tokens, CHUNK_SIZE)
M = num_tokens
topk = topk_ids.shape[1]
mode_1, mode_2 = config_cuda(M)
# config = get_config_func(M) # config = get_config_func(M)
intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N), intermediate_cache1 = torch.empty((M, topk, N),
device=hidden_states.device, device=hidden_states.device,
dtype=hidden_states.dtype) dtype=hidden_states.dtype)
intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2), intermediate_cache2 = torch.empty((M * topk, N // 2),
device=hidden_states.device, device=hidden_states.device,
dtype=hidden_states.dtype) dtype=hidden_states.dtype)
intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]), intermediate_cache3 = torch.empty((M, topk, w2.shape[1]),
device=hidden_states.device, device=hidden_states.device,
dtype=hidden_states.dtype) dtype=hidden_states.dtype)
...@@ -247,63 +306,33 @@ def fused_experts_impl_cuda(hidden_states: torch.Tensor, ...@@ -247,63 +306,33 @@ def fused_experts_impl_cuda(hidden_states: torch.Tensor,
else: else:
out_hidden_states = torch.empty_like(hidden_states) out_hidden_states = torch.empty_like(hidden_states)
for chunk in range((num_tokens // CHUNK_SIZE) + 1):
begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE,
min((chunk + 1) * CHUNK_SIZE,
num_tokens))
curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
tokens_in_chunk, _ = curr_hidden_states.shape
if tokens_in_chunk == 0:
break
if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
# Adjust the intermediate cache size and config for the last
# chunk. Note that in most cases we only have one chunk
# so the cache size and config are already set correctly and
# do not need to be adjusted.
intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
intermediate_cache2 = intermediate_cache2[:tokens_in_chunk *
topk_ids.shape[1]]
intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
# config = get_config_func(tokens_in_chunk)
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
sorted_token_ids, expert_ids, num_tokens_post_padded = ( sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(curr_topk_ids, 16, E, expert_map, curr_hidden_states.shape[0])) moe_align_block_size(topk_ids, 16, E, expert_map, hidden_states.shape[0]))
moe_gemm_w4a16.gemm1_w4a16(sorted_token_ids, # sorted_token_ids.to(torch.uint16)
mode_1, mode_2 = config_cuda(M) hidden_states, # hidden_states
expert_ids = expert_ids[:num_tokens_post_padded // 16]
moe_gemm_w4a16.gemm1_w4a16(sorted_token_ids.to(torch.uint16), # sorted_token_ids.to(torch.uint16)
curr_hidden_states, # hidden_states
w1, # w1 w1, # w1
intermediate_cache1, # gemm1_out intermediate_cache1, # gemm1_out
num_tokens_post_padded, # 实际专家数
expert_ids, # expert_id_vec expert_ids, # expert_id_vec
w1_scale, # scale_zero w1_scale, # scale_zero
64, # group_size 64, # group_size
topk=topk_ids.shape[1], # topk topk=topk, # topk
mode=mode_1) # mode=gemm1_mode mode=mode_1) # mode=gemm1_mode
torch.ops._C.silu_and_mul(intermediate_cache2, torch.ops._C.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
intermediate_cache1.view(-1, N)) # return intermediate_cache2
moe_gemm_w4a16.gemm2_w4a16(sorted_token_ids, # sorted_token_ids.to(torch.uint16)
moe_gemm_w4a16.gemm2_w4a16(sorted_token_ids.to(torch.uint16), # sorted_token_ids.to(torch.uint16)
intermediate_cache2, # hidden_states intermediate_cache2, # hidden_states
w2, # w2 w2, # w2
intermediate_cache3, # gemm2_out intermediate_cache3, # gemm2_out
num_tokens_post_padded,
expert_ids, # expert_id_vec expert_ids, # expert_id_vec
w2_scale, # scale_zero w2_scale, # scale_zero
curr_topk_weights, # topk_weights topk_weights, # topk_weights
64, # group_size 64, # group_size
topk=topk_ids.shape[1], # topk topk=topk, # topk
mode=mode_2) # mode=gemm2_mode mode=mode_2) # mode=gemm2_mode
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), out_hidden_states)
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape),
out_hidden_states[begin_chunk_idx:end_chunk_idx])
return out_hidden_states
return out_hidden_states
\ No newline at end of file
...@@ -332,9 +332,13 @@ class Qwen3MoeModel(nn.Module): ...@@ -332,9 +332,13 @@ class Qwen3MoeModel(nn.Module):
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.config = config self.config = config
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
# if self.config.quantization_config["bits"] == 4: # if self.config.quantization_config["bits"] == 4:
# os.environ['LLAMA_NN'] = '0' os.environ['LLAMA_NN'] = '0'
# os.environ['LM_NN'] = '0' os.environ['LM_NN'] = '0'
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
...@@ -352,10 +356,6 @@ class Qwen3MoeModel(nn.Module): ...@@ -352,10 +356,6 @@ class Qwen3MoeModel(nn.Module):
make_empty_intermediate_tensors_factory( make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size)) ["hidden_states", "residual"], config.hidden_size))
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
self.tritonsingleton= W8a8GetCacheJSON() self.tritonsingleton= W8a8GetCacheJSON()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment