Commit b1babea8 authored by zhuwenwen
Browse files

update fused_moe.py

parent 5c004388
...@@ -704,9 +704,9 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -704,9 +704,9 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
get_moe_wna16_block_config(config=config, get_moe_wna16_block_config(config=config,
use_moe_wna16_cuda=use_moe_wna16_cuda, use_moe_wna16_cuda=use_moe_wna16_cuda,
num_valid_tokens=num_tokens, num_valid_tokens=num_tokens,
size_k=A.size[1], size_k=A.size(1),
size_n=B.size[1], size_n=B.size(1),
num_experts=B.size[1], num_experts=B.size(1),
group_size=block_shape[1], group_size=block_shape[1],
real_top_k=top_k, real_top_k=top_k,
block_size_m=config["BLOCK_SIZE_M"])) block_size_m=config["BLOCK_SIZE_M"]))
...@@ -732,8 +732,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -732,8 +732,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
sorted_token_ids, sorted_token_ids,
expert_ids, expert_ids,
num_tokens_post_padded, num_tokens_post_padded,
B.size[1], B.size(1),
A.size[1], A.size(1),
EM, EM,
topk_ids.numel(), topk_ids.numel(),
A.stride(0), A.stride(0),
...@@ -749,7 +749,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -749,7 +749,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
B_zp.stride(0) if B_zp is not None else 0, B_zp.stride(0) if B_zp is not None else 0,
B_zp.stride(2) if B_zp is not None else 0, B_zp.stride(2) if B_zp is not None else 0,
B_zp.stride(1) if B_zp is not None else 0, B_zp.stride(1) if B_zp is not None else 0,
block_k_diviable=A.size[1] % config["BLOCK_SIZE_K"] == 0, block_k_diviable=A.size(1) % config["BLOCK_SIZE_K"] == 0,
group_size=block_shape[1], group_size=block_shape[1],
MUL_ROUTED_WEIGHT=mul_routed_weight, MUL_ROUTED_WEIGHT=mul_routed_weight,
top_k=top_k, top_k=top_k,
...@@ -770,8 +770,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -770,8 +770,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
sorted_token_ids, sorted_token_ids,
expert_ids, expert_ids,
num_tokens_post_padded, num_tokens_post_padded,
B.size[1], B.size(1),
A.size[1], A.size(1),
EM, EM,
num_tokens, num_tokens,
A.stride(0), A.stride(0),
...@@ -787,7 +787,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -787,7 +787,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
B_zp.stride(0) if B_zp is not None else 0, B_zp.stride(0) if B_zp is not None else 0,
B_zp.stride(2) if B_zp is not None else 0, B_zp.stride(2) if B_zp is not None else 0,
B_zp.stride(1) if B_zp is not None else 0, B_zp.stride(1) if B_zp is not None else 0,
block_k_diviable=A.size[1] % config["BLOCK_SIZE_K"] == 0, block_k_diviable=A.size(1) % config["BLOCK_SIZE_K"] == 0,
group_size=block_shape[1], group_size=block_shape[1],
MUL_ROUTED_WEIGHT=mul_routed_weight, MUL_ROUTED_WEIGHT=mul_routed_weight,
top_k=top_k, top_k=top_k,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment