Commit 031d4ca8 authored by zhuwenwen
Browse files

update fused_moe.py

parent 77f7bb45
...@@ -109,28 +109,14 @@ def fused_moe_kernel_gptq_awq( ...@@ -109,28 +109,14 @@ def fused_moe_kernel_gptq_awq(
# Map program ids `pid` to the block of C it should compute. # Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse. # This is done in a grouped ordering to promote L2 data reuse.
pid = tl.program_id(axis=0) pid = tl.program_id(axis=0)
# num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
# num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
# num_pid_in_group = GROUP_SIZE_M * num_pid_n num_pid_in_group = GROUP_SIZE_M * num_pid_n
# group_id = pid // num_pid_in_group group_id = pid // num_pid_in_group
# first_pid_m = group_id * GROUP_SIZE_M first_pid_m = group_id * GROUP_SIZE_M
# group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
# pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
# pid_n = (pid % num_pid_in_group) // group_size_m pid_n = (pid % num_pid_in_group) // group_size_m
if GROUP_SIZE_M ==1:
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
pid_m = pid // num_pid_n
pid_n = pid % num_pid_n
else:
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
# ---------------------------------------------------------- # ----------------------------------------------------------
# Create pointers for the first blocks of A and B. # Create pointers for the first blocks of A and B.
...@@ -332,14 +318,28 @@ def fused_moe_kernel( ...@@ -332,14 +318,28 @@ def fused_moe_kernel(
# Map program ids `pid` to the block of C it should compute. # Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse. # This is done in a grouped ordering to promote L2 data reuse.
pid = tl.program_id(axis=0) pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) # num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) # num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n # num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group # group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M # first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) # group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) # pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m # pid_n = (pid % num_pid_in_group) // group_size_m
if GROUP_SIZE_M ==1:
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
pid_m = pid // num_pid_n
pid_n = pid % num_pid_n
else:
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
# ---------------------------------------------------------- # ----------------------------------------------------------
# Create pointers for the first blocks of A and B. # Create pointers for the first blocks of A and B.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment