Commit 8aa4619f authored by Jared Casper's avatar Jared Casper
Browse files

Cherry-pick fix from development repo.

parent 1b3dfa2f
...@@ -109,7 +109,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size, ...@@ -109,7 +109,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
per_partition_per_stride_size = divide(per_partition_size, stride) per_partition_per_stride_size = divide(per_partition_size, stride)
weight_list = torch.split(master_weight, per_partition_per_stride_size, weight_list = torch.split(master_weight, per_partition_per_stride_size,
dim=partition_dim) dim=partition_dim)
rank = get_model_parallel_rank() rank = get_tensor_model_parallel_rank()
world_size = get_tensor_model_parallel_world_size() world_size = get_tensor_model_parallel_world_size()
my_weight_list = weight_list[rank::world_size] my_weight_list = weight_list[rank::world_size]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment