Commit 35dbdc41 authored by 王敏's avatar 王敏
Browse files

暂时去掉fused_moe中量化kernel的bug修复

parent cff5452a
......@@ -183,8 +183,7 @@ def fused_moe_kernel_awq(
compute_type: tl.constexpr,
has_zp: tl.constexpr,
use_int4_w4a16: tl.constexpr,
use_int8_w8a16: tl.constexpr,
enable_expert_parallel: int,):
use_int8_w8a16: tl.constexpr):
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
......@@ -202,17 +201,6 @@ def fused_moe_kernel_awq(
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) # [block_m]
token_mask = offs_token < num_valid_tokens
off_experts = tl.load(expert_ids_ptr + pid_m)
if enable_expert_parallel:
if off_experts == -1:
# -----------------------------------------------------------
# Write back zeros to the output when the expert is not
# in the current expert parallel rank.
write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N,
offs_token, token_mask, BLOCK_SIZE_M,
BLOCK_SIZE_N, compute_type)
return
offs_bn = (pid_n * BLOCK_SIZE_N +
tl.arange(0, BLOCK_SIZE_N)) % N # [block_n]
offs_k = tl.arange(0, BLOCK_SIZE_K) # 0, 1, 2, ...... , 127 # # [block_k]
......@@ -220,6 +208,8 @@ def fused_moe_kernel_awq(
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
offs_k[None, :] * stride_ak) # [block_m, block_k]
off_experts = tl.load(expert_ids_ptr + pid_m)
if use_int4_w4a16:
# [0, 1, 2, ...... , 126, 127] --> [0, 0, 1, 1 ...... , 63, 63]
# [128, 129, 130, ...... , 254, 255] --> [64, 64, 65, 65 ...... , 127, 127]
......@@ -350,8 +340,7 @@ def fused_moe_kernel_gptq_awq(
compute_type: tl.constexpr,
has_zp: tl.constexpr,
use_int4_w4a16: tl.constexpr,
use_int8_w8a16: tl.constexpr,
enable_expert_parallel: int,):
use_int8_w8a16: tl.constexpr):
"""
Implements the fused computation for a Mixture of Experts (MOE) using
token and expert matrices.
......@@ -405,23 +394,14 @@ def fused_moe_kernel_gptq_awq(
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
token_mask = offs_token < num_valid_tokens
off_experts = tl.load(expert_ids_ptr + pid_m)
if enable_expert_parallel:
if off_experts == -1:
# -----------------------------------------------------------
# Write back zeros to the output when the expert is not
# in the current expert parallel rank.
write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N,
offs_token, token_mask, BLOCK_SIZE_M,
BLOCK_SIZE_N, compute_type)
return
offs_bn = (pid_n * BLOCK_SIZE_N +
tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
offs_k[None, :] * stride_ak)
off_experts = tl.load(expert_ids_ptr + pid_m)
if use_int4_w4a16:
b_ptrs = b_ptr + off_experts * stride_be + \
(offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * stride_bn
......@@ -561,8 +541,7 @@ def fused_moe_kernel(
use_fp8_w8a8: tl.constexpr,
use_int8_w8a8: tl.constexpr,
use_int8_w8a16: tl.constexpr,
per_channel_quant: tl.constexpr,
enable_expert_parallel: int,
per_channel_quant: tl.constexpr
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using
......@@ -625,13 +604,12 @@ def fused_moe_kernel(
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
return
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(
tl.int64)
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
token_mask = offs_token < num_valid_tokens
off_experts = tl.load(expert_ids_ptr + pid_m)
if enable_expert_parallel:
if off_experts == -1:
# -----------------------------------------------------------
# Write back zeros to the output when the expert is not
......@@ -642,7 +620,7 @@ def fused_moe_kernel(
return
offs_bn = (pid_n * BLOCK_SIZE_N +
tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
offs_k[None, :] * stride_ak)
......@@ -761,8 +739,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_int4_w4a16: bool,
per_channel_quant: bool,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool]=False,
enable_expert_parallel: int=0,) -> None:
use_nn_moe: Optional[bool]=False) -> None:
assert topk_weights is not None or not mul_routed_weight
assert topk_weights is None or topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
......@@ -849,7 +826,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
has_zp=B_zp is not None,
use_int4_w4a16=use_int4_w4a16,
use_int8_w8a16=use_int8_w8a16,
enable_expert_parallel=enable_expert_parallel,
**config,
)
else:
......@@ -888,7 +864,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
has_zp=B_zp is not None,
use_int4_w4a16=use_int4_w4a16,
use_int8_w8a16=use_int8_w8a16,
enable_expert_parallel=enable_expert_parallel,
**config,
)
else:
......@@ -937,7 +912,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
per_channel_quant=per_channel_quant,
enable_expert_parallel=enable_expert_parallel,
# BLOCK_SIZE_K=BLOCK_SIZE_K,
**config,
)
......@@ -1750,7 +1724,6 @@ def fused_experts_impl(hidden_states: torch.Tensor,
moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'],
global_num_experts, expert_map))
enable_expert_parallel = (int)(expert_map is not None)
invoke_fused_moe_kernel(qcurr_hidden_states,
w1,
intermediate_cache1,
......@@ -1772,8 +1745,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape,
use_nn_moe=use_nn_moe,
enable_expert_parallel=enable_expert_parallel)
use_nn_moe=use_nn_moe)
if activation == "silu":
torch.ops._C.silu_and_mul(intermediate_cache2,
......@@ -1834,8 +1806,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape,
use_nn_moe=use_nn_moe,
enable_expert_parallel=enable_expert_parallel)
use_nn_moe=use_nn_moe)
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape),
out_hidden_states[begin_chunk_idx:end_chunk_idx])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment