"docs/vscode:/vscode.git/clone" did not exist on "cd9e60c76c776c42431b7ae523fcfe7835546d74"
Unverified Commit 4e121310 authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[Core][Test] fix function name typo in custom allreduce (#4750)

parent fcc2994b
...@@ -25,7 +25,7 @@ def graph_allreduce(world_size, rank, distributed_init_port): ...@@ -25,7 +25,7 @@ def graph_allreduce(world_size, rank, distributed_init_port):
init_test_distributed_environment(1, world_size, rank, init_test_distributed_environment(1, world_size, rank,
distributed_init_port) distributed_init_port)
custom_all_reduce.init_custom_all_reduce() custom_all_reduce.init_custom_ar()
for sz in test_sizes: for sz in test_sizes:
for dtype in [torch.float32, torch.float16, torch.bfloat16]: for dtype in [torch.float32, torch.float16, torch.bfloat16]:
with custom_all_reduce.capture(): with custom_all_reduce.capture():
...@@ -61,7 +61,7 @@ def eager_allreduce(world_size, rank, distributed_init_port): ...@@ -61,7 +61,7 @@ def eager_allreduce(world_size, rank, distributed_init_port):
distributed_init_port) distributed_init_port)
sz = 1024 sz = 1024
custom_all_reduce.init_custom_all_reduce() custom_all_reduce.init_custom_ar()
fa = custom_all_reduce.get_handle() fa = custom_all_reduce.get_handle()
inp = torch.ones(sz, dtype=torch.float32, device=device) inp = torch.ones(sz, dtype=torch.float32, device=device)
out = fa.all_reduce_unreg(inp) out = fa.all_reduce_unreg(inp)
......
...@@ -52,6 +52,10 @@ def init_custom_ar() -> None: ...@@ -52,6 +52,10 @@ def init_custom_ar() -> None:
"current process. This might be the case if 'CUDA_VISIBLE_DEVICES'" "current process. This might be the case if 'CUDA_VISIBLE_DEVICES'"
" is set.") " is set.")
return return
# we only use a subset of GPUs here
# so we only need to check the nvlink connectivity of these GPUs
num_dev = world_size
# test nvlink first, this will filter out most of the cases # test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported # where custom allreduce is not supported
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment