Commit 04cd82bd authored by mohammad's avatar mohammad
Browse files

debuged the cpu init for mp=1

parent 42d2be06
......@@ -72,14 +72,6 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
weight.partition_dim = partition_dim
weight.partition_stride = stride
# If we only use 1 process for model parallelism, bypass scatter.
world_size = get_model_parallel_world_size()
if world_size == 1:
init_method(weight)
if return_master_weight:
return weight
return None
# Initialize master weight
master_weight = torch.empty(output_size, input_size,
dtype=torch.float,
......@@ -93,6 +85,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
weight_list = torch.split(master_weight, per_partition_per_stride_size,
dim=partition_dim)
rank = get_model_parallel_rank()
world_size = get_model_parallel_world_size()
my_weight_list = weight_list[rank::world_size]
with torch.no_grad():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment