Unverified commit dcc56d62, authored by bnellnm and committed by GitHub
Browse files

[Bugfix] Fix function names in test_block_fp8.py (#16033)


Signed-off-by: Bill Nell <bnell@redhat.com>
parent f15e70d9
......@@ -360,7 +360,7 @@ def fp8_perm(m, idx):
return m[idx, ...]
def test_moe_permute(a, a_s, topk_ids, num_groups, topk, block_m):
def _moe_permute(a, a_s, topk_ids, num_groups, topk, block_m):
M, K = a.shape
sorted_token_ids, m_indices, num_pad = moe_align_block_size(
......@@ -379,7 +379,7 @@ def test_moe_permute(a, a_s, topk_ids, num_groups, topk, block_m):
return a, a_s, m_indices, inv_perm
def test_moe_unpermute(out, inv_perm, topk, K, topk_weight):
def _moe_unpermute(out, inv_perm, topk, K, topk_weight):
M = topk_weight.shape[0]
out = out[inv_perm, ...]
tmp_out = out.view(-1, topk, K)
......@@ -401,7 +401,7 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
a_q, a_s = per_token_group_quant_fp8(a, block_m)
a_q, a_s, m_indices, inv_perm = test_moe_permute(a_q, a_s, topk_ids,
a_q, a_s, m_indices, inv_perm = _moe_permute(a_q, a_s, topk_ids,
num_groups, topk, block_m)
inter_out = torch.zeros((a_q.shape[0], N * 2),
......@@ -419,7 +419,7 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(act_out_q, act_out_s), (w2, w2_s), out, m_indices)
final_out = test_moe_unpermute(out, inv_perm, topk, K, topk_weight)
final_out = _moe_unpermute(out, inv_perm, topk, K, topk_weight)
return final_out
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment