Commit 77753d0a authored by Jared Casper

Small fixes.

parent 5f4ddd9b
@@ -174,14 +174,14 @@ def initialize_model_parallel(
         if len(ranks) > 1:
             embedding_ranks = [ranks[0], ranks[-1]]
             position_embedding_ranks = [ranks[0]]
-            if pipeline_model_parallel_split_rank_ is not None:
-                if ranks[pipeline_model_parallel_split_rank_] not in embedding_ranks:
+            if pipeline_model_parallel_split_rank is not None:
+                if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks:
                     embedding_ranks = [ranks[0],
-                                       ranks[pipeline_model_parallel_split_rank_],
+                                       ranks[pipeline_model_parallel_split_rank],
                                        ranks[-1]]
-                if ranks[pipeline_model_parallel_split_rank_] not in position_embedding_ranks:
+                if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks:
                     position_embedding_ranks = [ranks[0],
-                                                ranks[pipeline_model_parallel_split_rank_]]
+                                                ranks[pipeline_model_parallel_split_rank]]
         else:
             embedding_ranks = ranks
             position_embedding_ranks = ranks
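The hunk above (in initialize_model_parallel, likely megatron/core/parallel_state.py) renames what appears to be a leftover trailing-underscore identifier, pipeline_model_parallel_split_rank_, to the function's actual argument name. The logic decides which ranks of a pipeline group also hold embedding and position-embedding weights: the first and last stage always, plus the split-rank stage for encoder-decoder models. A minimal standalone sketch of that selection, with a hypothetical ranks list for illustration:

# Standalone sketch of the embedding-rank selection in the hunk above.
# `ranks` is a hypothetical pipeline group; `split_rank` mirrors the
# renamed pipeline_model_parallel_split_rank argument.
def select_embedding_ranks(ranks, split_rank=None):
    if len(ranks) > 1:
        embedding_ranks = [ranks[0], ranks[-1]]
        position_embedding_ranks = [ranks[0]]
        if split_rank is not None:
            if ranks[split_rank] not in embedding_ranks:
                embedding_ranks = [ranks[0], ranks[split_rank], ranks[-1]]
            if ranks[split_rank] not in position_embedding_ranks:
                position_embedding_ranks = [ranks[0], ranks[split_rank]]
    else:
        embedding_ranks = ranks
        position_embedding_ranks = ranks
    return embedding_ranks, position_embedding_ranks

# e.g. a 4-stage pipeline group with the encoder/decoder split at stage 2:
print(select_embedding_ranks([0, 2, 4, 6], split_rank=2))
# -> ([0, 4, 6], [0, 4])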
@@ -4,6 +4,7 @@ import torch
 from typing import List, Sequence

 from megatron.core.utils import divide
+from megatron.core import parallel_state


 def split_tensor_along_last_dim(
     tensor: torch.Tensor,
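The remaining hunks touch megatron/core/tensor_parallel/utils.py (path inferred from the imports shown). The bare get_tensor_model_parallel_* calls used below had no matching import in this module, so the new from megatron.core import parallel_state together with module-qualified calls appears to fix the dangling references. Qualified access also stays late-bound, which matters for module-level state set up after import; a tiny generic Python illustration (not Megatron code):

import types

# Sketch: module-qualified access is resolved at call time. With
# `from m import f`, the local name is bound once at import; with
# `import m`, `m.f()` looks up the attribute on every call.
m = types.ModuleType("m")  # stand-in for a module like parallel_state
m.world_size = lambda: 1   # placeholder before initialization

get_ws = m.world_size      # `from m import ...` style early binding
m.world_size = lambda: 8   # module state replaced later

print(get_ws())        # -> 1: the early-bound name is stale
print(m.world_size())  # -> 8: qualified lookup sees the update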
@@ -47,8 +48,8 @@ def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
     """
     partition_size = torch.numel(tensor) // \
-        get_tensor_model_parallel_world_size()
-    start_index = partition_size * get_tensor_model_parallel_rank()
+        parallel_state.get_tensor_model_parallel_world_size()
+    start_index = partition_size * parallel_state.get_tensor_model_parallel_rank()
     end_index = start_index + partition_size
     if new_buffer:
         data = torch.empty(partition_size, dtype=tensor.dtype,
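With the qualified getters in place, split_tensor_into_1d_equal_chunks returns this rank's contiguous slice of the flattened tensor. A minimal CPU-only sketch of the same arithmetic, with rank and world_size passed explicitly as stand-ins for the parallel_state getters (the real function allocates on the current CUDA device with requires_grad=False):

import torch

def split_into_1d_equal_chunks(tensor, rank, world_size, new_buffer=False):
    # Same slicing arithmetic as the function above, with the rank and
    # world size passed in instead of read from parallel_state.
    partition_size = torch.numel(tensor) // world_size
    start = partition_size * rank
    end = start + partition_size
    if new_buffer:
        data = torch.empty(partition_size, dtype=tensor.dtype)
        data.copy_(tensor.view(-1)[start:end])
    else:
        data = tensor.view(-1)[start:end]
    return data

t = torch.arange(8.)
print(split_into_1d_equal_chunks(t, rank=1, world_size=4))  # tensor([2., 3.])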
@@ -70,7 +71,7 @@ def gather_split_1d_tensor(tensor):
         tensor: A Tensor or view of this rank's portion of the data.
     """
     numel_gathered = torch.numel(tensor) * \
-        get_tensor_model_parallel_world_size()
+        parallel_state.get_tensor_model_parallel_world_size()
     gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
                            device=torch.cuda.current_device(),
                            requires_grad=False)
@@ -80,7 +81,7 @@ def gather_split_1d_tensor(tensor):
     # This API calls NCCL all-gather directly, whereas the former
     # does internal copies and can potentially cause a slowdown.
     torch.distributed._all_gather_base(gathered, tensor,
-                                       group=get_tensor_model_parallel_group())
+                                       group=parallel_state.get_tensor_model_parallel_group())
     return gathered
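gather_split_1d_tensor is the inverse operation: each rank contributes its 1-D chunk and every rank receives the full flattened tensor via torch.distributed._all_gather_base (a private API, exposed as all_gather_into_tensor in newer PyTorch releases). A single-process sketch that mimics the reassembly with torch.cat in place of the collective, so it runs without a distributed backend:

import torch

def gather_split_1d(chunks):
    # Stand-in for the _all_gather_base call above: write each rank's
    # 1-D chunk back into one preallocated flat output buffer.
    numel_gathered = sum(torch.numel(c) for c in chunks)
    gathered = torch.empty(numel_gathered, dtype=chunks[0].dtype)
    torch.cat(chunks, out=gathered)
    return gathered

chunks = [torch.tensor([0., 1.]), torch.tensor([2., 3.])]
print(gather_split_1d(chunks))  # tensor([0., 1., 2., 3.])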