Unverified Commit 537c9755 authored by Zhuohan Li's avatar Zhuohan Li Committed by GitHub
Browse files

[Minor] Small fix to make distributed init logic in worker look cleaner (#2905)

parent 786b7f18
@@ -93,8 +93,6 @@ class Worker:
        # Initialize the distributed environment.
        init_distributed_environment(self.parallel_config, self.rank,
                                     cupy_port, self.distributed_init_method)
-       if not self.parallel_config.disable_custom_all_reduce:
-           init_custom_ar()
        # Initialize the model.
        set_random_seed(self.model_config.seed)
@@ -288,6 +286,10 @@ def init_distributed_environment(
    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                      parallel_config.pipeline_parallel_size)
+   # Initialize a custom fast all-reduce implementation.
+   if not parallel_config.disable_custom_all_reduce:
+       init_custom_ar()

def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
    # Check if the GPU supports the dtype.
    ...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment