Commit fee048ff authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge remote-tracking branch 'origin/v0.8.5.post1-opt1-wm' into v0.8.5.post1-opt1

parents 1a280da1 35dbdc41
...@@ -88,8 +88,7 @@ def fused_moe_kernel_awq( ...@@ -88,8 +88,7 @@ def fused_moe_kernel_awq(
compute_type: tl.constexpr, compute_type: tl.constexpr,
has_zp: tl.constexpr, has_zp: tl.constexpr,
use_int4_w4a16: tl.constexpr, use_int4_w4a16: tl.constexpr,
use_int8_w8a16: tl.constexpr, use_int8_w8a16: tl.constexpr):
enable_expert_parallel: int,):
pid = tl.program_id(axis=0) pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
...@@ -107,17 +106,6 @@ def fused_moe_kernel_awq( ...@@ -107,17 +106,6 @@ def fused_moe_kernel_awq(
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) # [block_m] offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) # [block_m]
token_mask = offs_token < num_valid_tokens token_mask = offs_token < num_valid_tokens
off_experts = tl.load(expert_ids_ptr + pid_m)
if enable_expert_parallel:
if off_experts == -1:
# -----------------------------------------------------------
# Write back zeros to the output when the expert is not
# in the current expert parallel rank.
write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N,
offs_token, token_mask, BLOCK_SIZE_M,
BLOCK_SIZE_N, compute_type)
return
offs_bn = (pid_n * BLOCK_SIZE_N + offs_bn = (pid_n * BLOCK_SIZE_N +
tl.arange(0, BLOCK_SIZE_N)) % N # [block_n] tl.arange(0, BLOCK_SIZE_N)) % N # [block_n]
offs_k = tl.arange(0, BLOCK_SIZE_K) # 0, 1, 2, ...... , 127 # # [block_k] offs_k = tl.arange(0, BLOCK_SIZE_K) # 0, 1, 2, ...... , 127 # # [block_k]
...@@ -125,6 +113,8 @@ def fused_moe_kernel_awq( ...@@ -125,6 +113,8 @@ def fused_moe_kernel_awq(
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
offs_k[None, :] * stride_ak) # [block_m, block_k] offs_k[None, :] * stride_ak) # [block_m, block_k]
off_experts = tl.load(expert_ids_ptr + pid_m)
if use_int4_w4a16: if use_int4_w4a16:
# [0, 1, 2, ...... , 126, 127] --> [0, 0, 1, 1 ...... , 63, 63] # [0, 1, 2, ...... , 126, 127] --> [0, 0, 1, 1 ...... , 63, 63]
# [128, 129, 130, ...... , 254, 255] --> [64, 64, 65, 65 ...... , 127, 127] # [128, 129, 130, ...... , 254, 255] --> [64, 64, 65, 65 ...... , 127, 127]
...@@ -255,8 +245,7 @@ def fused_moe_kernel_gptq_awq( ...@@ -255,8 +245,7 @@ def fused_moe_kernel_gptq_awq(
compute_type: tl.constexpr, compute_type: tl.constexpr,
has_zp: tl.constexpr, has_zp: tl.constexpr,
use_int4_w4a16: tl.constexpr, use_int4_w4a16: tl.constexpr,
use_int8_w8a16: tl.constexpr, use_int8_w8a16: tl.constexpr):
enable_expert_parallel: int,):
""" """
Implements the fused computation for a Mixture of Experts (MOE) using Implements the fused computation for a Mixture of Experts (MOE) using
token and expert matrices. token and expert matrices.
...@@ -310,23 +299,14 @@ def fused_moe_kernel_gptq_awq( ...@@ -310,23 +299,14 @@ def fused_moe_kernel_gptq_awq(
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
token_mask = offs_token < num_valid_tokens token_mask = offs_token < num_valid_tokens
off_experts = tl.load(expert_ids_ptr + pid_m)
if enable_expert_parallel:
if off_experts == -1:
# -----------------------------------------------------------
# Write back zeros to the output when the expert is not
# in the current expert parallel rank.
write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N,
offs_token, token_mask, BLOCK_SIZE_M,
BLOCK_SIZE_N, compute_type)
return
offs_bn = (pid_n * BLOCK_SIZE_N + offs_bn = (pid_n * BLOCK_SIZE_N +
tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K) offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
offs_k[None, :] * stride_ak) offs_k[None, :] * stride_ak)
off_experts = tl.load(expert_ids_ptr + pid_m)
if use_int4_w4a16: if use_int4_w4a16:
b_ptrs = b_ptr + off_experts * stride_be + \ b_ptrs = b_ptr + off_experts * stride_be + \
(offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * stride_bn (offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * stride_bn
...@@ -467,7 +447,6 @@ def fused_moe_kernel( ...@@ -467,7 +447,6 @@ def fused_moe_kernel(
use_int8_w8a8: tl.constexpr, use_int8_w8a8: tl.constexpr,
use_int8_w8a16: tl.constexpr, use_int8_w8a16: tl.constexpr,
per_channel_quant: tl.constexpr, per_channel_quant: tl.constexpr,
enable_expert_parallel: int,
): ):
""" """
Implements the fused computation for a Mixture of Experts (MOE) using Implements the fused computation for a Mixture of Experts (MOE) using
...@@ -530,24 +509,23 @@ def fused_moe_kernel( ...@@ -530,24 +509,23 @@ def fused_moe_kernel(
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
return return
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(
tl.int64) offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
token_mask = offs_token < num_valid_tokens token_mask = offs_token < num_valid_tokens
off_experts = tl.load(expert_ids_ptr + pid_m) off_experts = tl.load(expert_ids_ptr + pid_m)
if enable_expert_parallel: if off_experts == -1:
if off_experts == -1: # -----------------------------------------------------------
# ----------------------------------------------------------- # Write back zeros to the output when the expert is not
# Write back zeros to the output when the expert is not # in the current expert parallel rank.
# in the current expert parallel rank. write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N,
write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token, token_mask, BLOCK_SIZE_M,
offs_token, token_mask, BLOCK_SIZE_M, BLOCK_SIZE_N, compute_type)
BLOCK_SIZE_N, compute_type) return
return
offs_bn = (pid_n * BLOCK_SIZE_N + offs_bn = (pid_n * BLOCK_SIZE_N +
tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K) offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
offs_k[None, :] * stride_ak) offs_k[None, :] * stride_ak)
...@@ -666,8 +644,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -666,8 +644,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_int4_w4a16: bool, use_int4_w4a16: bool,
per_channel_quant: bool, per_channel_quant: bool,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool]=False, use_nn_moe: Optional[bool]=False) -> None:
enable_expert_parallel: int=0,) -> None:
assert topk_weights is not None or not mul_routed_weight assert topk_weights is not None or not mul_routed_weight
assert topk_weights is None or topk_weights.stride(1) == 1 assert topk_weights is None or topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1 assert sorted_token_ids.stride(0) == 1
...@@ -754,7 +731,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -754,7 +731,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
has_zp=B_zp is not None, has_zp=B_zp is not None,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
enable_expert_parallel=enable_expert_parallel,
**config, **config,
) )
else: else:
...@@ -793,7 +769,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -793,7 +769,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
has_zp=B_zp is not None, has_zp=B_zp is not None,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
enable_expert_parallel=enable_expert_parallel,
**config, **config,
) )
else: else:
...@@ -842,8 +817,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -842,8 +817,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_int8_w8a8=use_int8_w8a8, use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
per_channel_quant=per_channel_quant, per_channel_quant=per_channel_quant,
enable_expert_parallel=enable_expert_parallel,
# BLOCK_SIZE_K=BLOCK_SIZE_K,
**config, **config,
) )
...@@ -1665,8 +1638,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1665,8 +1638,6 @@ def fused_experts_impl(hidden_states: torch.Tensor,
moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'],
global_num_experts, expert_map)) global_num_experts, expert_map))
enable_expert_parallel = (int)(expert_map is not None)
invoke_fused_moe_kernel(qcurr_hidden_states, invoke_fused_moe_kernel(qcurr_hidden_states,
w1, w1,
intermediate_cache1, intermediate_cache1,
...@@ -1688,8 +1659,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1688,8 +1659,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant, per_channel_quant=per_channel_quant,
block_shape=block_shape, block_shape=block_shape,
use_nn_moe=use_nn_moe, use_nn_moe=use_nn_moe)
enable_expert_parallel=enable_expert_parallel)
if activation == "silu": if activation == "silu":
torch.ops._C.silu_and_mul(intermediate_cache2, torch.ops._C.silu_and_mul(intermediate_cache2,
...@@ -1733,8 +1703,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1733,8 +1703,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant, per_channel_quant=per_channel_quant,
block_shape=block_shape, block_shape=block_shape,
use_nn_moe=use_nn_moe, use_nn_moe=use_nn_moe)
enable_expert_parallel=enable_expert_parallel)
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape),
out_hidden_states[begin_chunk_idx:end_chunk_idx]) out_hidden_states[begin_chunk_idx:end_chunk_idx])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment