Commit 23c54150 authored by Hao Lin's avatar Hao Lin
Browse files

Fix test combine args


Signed-off-by: default avatarHao Lin <linhaomails@gmail.com>
parent 8a0ca8e2
......@@ -143,7 +143,7 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in
if with_topk:
combine_args.update({'topk_weights': recv_topk_weights})
if previous_mode:
dispatch_args.update({'previous_event': buffer.capture()})
combine_args.update({'previous_event': buffer.capture()})
combined_x, combined_topk_weights, event = buffer.combine(**combine_args)
event.current_stream_wait() if async_mode else ()
check_x = combined_x.float() / is_token_in_rank.sum(dim=1).unsqueeze(1)
......
......@@ -127,7 +127,7 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
if with_topk:
combine_args.update({'topk_weights': recv_topk_weights})
if previous_mode:
dispatch_args.update({'previous_event': buffer.capture()})
combine_args.update({'previous_event': buffer.capture()})
combined_x, combined_topk_weights, event = buffer.combine(**combine_args)
event.current_stream_wait() if async_mode else ()
check_x = combined_x.float() / is_token_in_rank.sum(dim=1).unsqueeze(1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment