Commit e57e9270 authored by lishen's avatar lishen
Browse files

量化测试代码修改对应的tests修改

parent 830124e1
...@@ -6,7 +6,7 @@ import torch.distributed as dist ...@@ -6,7 +6,7 @@ import torch.distributed as dist
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
import deep_ep import deep_ep
from utils import init_dist, bench, bench_kineto, calc_diff, create_grouped_scores, inplace_unique, per_token_cast_to_fp8, per_token_cast_back, hash_tensor from utils import init_dist, bench, bench_kineto, calc_diff, create_grouped_scores, inplace_unique, per_token_cast_to_fp8, per_token_cast_pg_back, hash_tensor
# Test compatibility with low latency functions # Test compatibility with low latency functions
import test_low_latency import test_low_latency
...@@ -127,7 +127,7 @@ def test_main(args: argparse.Namespace, num_sms: int, ...@@ -127,7 +127,7 @@ def test_main(args: argparse.Namespace, num_sms: int,
hash_value += hash_tensor(recv_x[0]) hash_value += hash_tensor(recv_x[0])
hash_value += hash_tensor(recv_x[1]) hash_value += hash_tensor(recv_x[1])
recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x recv_x = per_token_cast_pg_back(*recv_x) if isinstance(recv_x, tuple) else recv_x
# Checks # Checks
recv_gbl_rank_prefix_sum = handle[-4] recv_gbl_rank_prefix_sum = handle[-4]
...@@ -153,7 +153,7 @@ def test_main(args: argparse.Namespace, num_sms: int, ...@@ -153,7 +153,7 @@ def test_main(args: argparse.Namespace, num_sms: int,
dispatch_args.update({'previous_event': buffer.capture()}) dispatch_args.update({'previous_event': buffer.capture()})
recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args) recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args)
event.current_stream_wait() if async_mode else () event.current_stream_wait() if async_mode else ()
recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x recv_x = per_token_cast_pg_back(*recv_x) if isinstance(recv_x, tuple) else recv_x
if not is_rand: if not is_rand:
check_data(recv_x, recv_gbl_rank_prefix_sum) check_data(recv_x, recv_gbl_rank_prefix_sum)
......
...@@ -5,7 +5,7 @@ import torch.distributed as dist ...@@ -5,7 +5,7 @@ import torch.distributed as dist
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
import deep_ep import deep_ep
from utils import init_dist, bench, calc_diff, inplace_unique, per_token_cast_to_fp8, per_token_cast_back from utils import init_dist, bench, calc_diff, inplace_unique, per_token_cast_to_fp8, per_token_cast_pg_back
# Test compatibility with low latency functions # Test compatibility with low latency functions
import test_low_latency import test_low_latency
...@@ -99,7 +99,7 @@ def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks ...@@ -99,7 +99,7 @@ def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks
dispatch_args.update({'previous_event': buffer.capture()}) dispatch_args.update({'previous_event': buffer.capture()})
recv_x, recv_topk_idx, recv_topk_weights, recv_num_tokens_per_expert_list, handle, event = buffer.dispatch(**dispatch_args) recv_x, recv_topk_idx, recv_topk_weights, recv_num_tokens_per_expert_list, handle, event = buffer.dispatch(**dispatch_args)
event.current_stream_wait() if async_mode else () event.current_stream_wait() if async_mode else ()
recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x recv_x = per_token_cast_pg_back(*recv_x) if isinstance(recv_x, tuple) else recv_x
# Checks # Checks
rank_prefix_matrix = handle[0] rank_prefix_matrix = handle[0]
...@@ -126,7 +126,7 @@ def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks ...@@ -126,7 +126,7 @@ def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks
dispatch_args.update({'num_worst_tokens': num_worst_tokens}) dispatch_args.update({'num_worst_tokens': num_worst_tokens})
recv_worst_x, recv_worst_topk_idx, recv_worst_topk_weights, empty_list, _, event = buffer.dispatch(**dispatch_args) recv_worst_x, recv_worst_topk_idx, recv_worst_topk_weights, empty_list, _, event = buffer.dispatch(**dispatch_args)
event.current_stream_wait() if async_mode else () event.current_stream_wait() if async_mode else ()
recv_worst_x = per_token_cast_back(*recv_worst_x) if isinstance(recv_worst_x, tuple) else recv_worst_x recv_worst_x = per_token_cast_pg_back(*recv_worst_x) if isinstance(recv_worst_x, tuple) else recv_worst_x
assert len(empty_list) == 0 assert len(empty_list) == 0
assert num_worst_tokens == recv_worst_x.size(0) assert num_worst_tokens == recv_worst_x.size(0)
assert num_worst_tokens == recv_worst_topk_idx.size(0) assert num_worst_tokens == recv_worst_topk_idx.size(0)
...@@ -143,7 +143,7 @@ def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks ...@@ -143,7 +143,7 @@ def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks
dispatch_args.update({'previous_event': buffer.capture()}) dispatch_args.update({'previous_event': buffer.capture()})
recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args) recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args)
event.current_stream_wait() if async_mode else () event.current_stream_wait() if async_mode else ()
recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x recv_x = per_token_cast_pg_back(*recv_x) if isinstance(recv_x, tuple) else recv_x
if current_x is not x_pure_rand: if current_x is not x_pure_rand:
check_data(recv_x, rank_prefix_matrix) check_data(recv_x, rank_prefix_matrix)
......
...@@ -4,7 +4,7 @@ import torch.distributed as dist ...@@ -4,7 +4,7 @@ import torch.distributed as dist
from functools import partial from functools import partial
import deep_ep import deep_ep
from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_pg_back
def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
...@@ -44,7 +44,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, ...@@ -44,7 +44,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
# print('run {}/{}, dispatch_use_fp8={}'.format(i + 1, num_times, dispatch_use_fp8)) # print('run {}/{}, dispatch_use_fp8={}'.format(i + 1, num_times, dispatch_use_fp8))
# return # return
packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x
simulated_gemm_x = per_token_cast_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape) \ simulated_gemm_x = per_token_cast_pg_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape) \
if dispatch_use_fp8 else packed_recv_x.clone() if dispatch_use_fp8 else packed_recv_x.clone()
# print(f"rank{rank}: packed_recv_x[0]\n{packed_recv_x[0].cpu()}\n") # print(f"rank{rank}: packed_recv_x[0]\n{packed_recv_x[0].cpu()}\n")
# print(f"rank{rank}: packed_recv_x[1]\n{packed_recv_x[1].cpu()}\n") # print(f"rank{rank}: packed_recv_x[1]\n{packed_recv_x[1].cpu()}\n")
...@@ -53,7 +53,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, ...@@ -53,7 +53,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group) dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)
for i in range(num_local_experts if do_check else 0): for i in range(num_local_experts if do_check else 0):
expert_id = rank * num_local_experts + i expert_id = rank * num_local_experts + i
recv_x = per_token_cast_back(packed_recv_x[0][i], packed_recv_x[1][i]) if dispatch_use_fp8 else packed_recv_x[i] recv_x = per_token_cast_pg_back(packed_recv_x[0][i], packed_recv_x[1][i]) if dispatch_use_fp8 else packed_recv_x[i]
recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i] recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i]
# Check expert indices # Check expert indices
......
...@@ -72,16 +72,16 @@ def per_token_cast_pg_back(x: torch.Tensor, x_scales: torch.Tensor): ...@@ -72,16 +72,16 @@ def per_token_cast_pg_back(x: torch.Tensor, x_scales: torch.Tensor):
x_scales = x_scales.view(x.size(0), -1, 1) x_scales = x_scales.view(x.size(0), -1, 1)
return (x_fp32_padded * x_scales).view(x_padded.shape).to(torch.bfloat16)[:,:n].contiguous() return (x_fp32_padded * x_scales).view(x_padded.shape).to(torch.bfloat16)[:,:n].contiguous()
def per_token_cast_pc_back(x_int8: torch.Tensor, x_scales: torch.Tensor): def per_token_cast_pc_back(x: torch.Tensor, x_scales: torch.Tensor):
if x_int8.numel() == 0: if x.numel() == 0:
return x_int8.to(torch.bfloat16) return x.to(torch.bfloat16)
assert x_int8.dim() == 2 assert x.dim() == 2
m, n = x_int8.shape m, n = x.shape
aligned_n = align_up(n, 128) aligned_n = align_up(n, 128)
x_int8_padded = torch.nn.functional.pad(x_int8, (0, aligned_n - n), mode='constant', value=0) x_padded = torch.nn.functional.pad(x, (0, aligned_n - n), mode='constant', value=0)
x_fp32_padded = x_int8_padded.to(torch.float32).view(m, -1, 1) x_fp32_padded = x_padded.to(torch.float32).view(m, -1, 1)
x_scales = x_scales.view(m, -1, 1).to(torch.float32) x_scales = x_scales.view(m, -1, 1).to(torch.float32)
x_deq = (x_fp32_padded * x_scales).view(m, aligned_n) x_deq = (x_fp32_padded * x_scales).view(m, aligned_n)
return x_deq[:, :n].to(torch.bfloat16).contiguous() return x_deq[:, :n].to(torch.bfloat16).contiguous()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment