Commit 77753d0a authored by Jared Casper

Small fixes.

parent 5f4ddd9b
@@ -174,14 +174,14 @@ def initialize_model_parallel(
         if len(ranks) > 1:
             embedding_ranks = [ranks[0], ranks[-1]]
             position_embedding_ranks = [ranks[0]]
-            if pipeline_model_parallel_split_rank_ is not None:
-                if ranks[pipeline_model_parallel_split_rank_] not in embedding_ranks:
+            if pipeline_model_parallel_split_rank is not None:
+                if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks:
                     embedding_ranks = [ranks[0],
-                                       ranks[pipeline_model_parallel_split_rank_],
+                                       ranks[pipeline_model_parallel_split_rank],
                                        ranks[-1]]
-                if ranks[pipeline_model_parallel_split_rank_] not in position_embedding_ranks:
+                if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks:
                     position_embedding_ranks = [ranks[0],
-                                                ranks[pipeline_model_parallel_split_rank_]]
+                                                ranks[pipeline_model_parallel_split_rank]]
         else:
             embedding_ranks = ranks
             position_embedding_ranks = ranks
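Note on the hunk above: this block decides which pipeline ranks own the embedding and position-embedding weights. Normally that is the first and last rank of each pipeline group; when pipeline_model_parallel_split_rank is set (typically the encoder/decoder boundary), the rank at the split is added as well. Below is a minimal standalone sketch of that selection logic; the helper name pick_embedding_ranks and the example group of 8 ranks are illustrative only and not part of this commit.

# Illustrative sketch of the rank-selection logic shown in the hunk above.
# The helper name and example values are hypothetical.
def pick_embedding_ranks(ranks, split_rank=None):
    """Return (embedding_ranks, position_embedding_ranks) for one pipeline group."""
    if len(ranks) > 1:
        embedding_ranks = [ranks[0], ranks[-1]]
        position_embedding_ranks = [ranks[0]]
        if split_rank is not None:
            if ranks[split_rank] not in embedding_ranks:
                embedding_ranks = [ranks[0], ranks[split_rank], ranks[-1]]
            if ranks[split_rank] not in position_embedding_ranks:
                position_embedding_ranks = [ranks[0], ranks[split_rank]]
    else:
        embedding_ranks = list(ranks)
        position_embedding_ranks = list(ranks)
    return embedding_ranks, position_embedding_ranks

# One pipeline group of 8 ranks with a split at index 3:
# prints ([0, 3, 7], [0, 3])
print(pick_embedding_ranks(list(range(8)), split_rank=3))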
@@ -4,6 +4,7 @@ import torch
 from typing import List, Sequence
 
 from megatron.core.utils import divide
+from megatron.core import parallel_state
 
 def split_tensor_along_last_dim(
     tensor: torch.Tensor,
@@ -47,8 +48,8 @@ def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
     """
     partition_size = torch.numel(tensor) // \
-        get_tensor_model_parallel_world_size()
-    start_index = partition_size * get_tensor_model_parallel_rank()
+        parallel_state.get_tensor_model_parallel_world_size()
+    start_index = partition_size * parallel_state.get_tensor_model_parallel_rank()
     end_index = start_index + partition_size
     if new_buffer:
         data = torch.empty(partition_size, dtype=tensor.dtype,
@@ -70,7 +71,7 @@ def gather_split_1d_tensor(tensor):
         tensor: A Tensor or view of this rank's portion of the data.
     """
     numel_gathered = torch.numel(tensor) * \
-        get_tensor_model_parallel_world_size()
+        parallel_state.get_tensor_model_parallel_world_size()
     gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
                            device=torch.cuda.current_device(),
                            requires_grad=False)
@@ -80,7 +81,7 @@ def gather_split_1d_tensor(tensor):
     # This API calls directly NCCL all-gather versus the former does
     # internal copies and can potentially cause slow down.
     torch.distributed._all_gather_base(gathered, tensor,
-                                       group=get_tensor_model_parallel_group())
+                                       group=parallel_state.get_tensor_model_parallel_group())
     return gathered
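Note on the second file: these hunks qualify the tensor-model-parallel queries through the newly imported parallel_state module instead of calling bare helper names. As a rough illustration of how the two touched functions fit together, here is a usage sketch of the split/gather round trip; the import path is inferred from the surrounding code, and the snippet assumes torch.distributed and Megatron's model-parallel groups are already initialized on every tensor-parallel rank. It is not part of the commit.

# Rough usage sketch of the split/gather round trip (runs as a collective:
# every rank in the tensor-model-parallel group must execute it).
import torch
from megatron.core import parallel_state
# Import path assumed from the surrounding code; may differ in your checkout.
from megatron.core.tensor_parallel.utils import (
    split_tensor_into_1d_equal_chunks,
    gather_split_1d_tensor,
)

def split_gather_roundtrip():
    tp_world = parallel_state.get_tensor_model_parallel_world_size()

    # The tensor's numel must divide evenly across the tensor-parallel group,
    # and the data must be identical on every rank for the check below.
    full = torch.arange(4 * tp_world, dtype=torch.float32,
                        device=torch.cuda.current_device())

    # Each rank keeps only its contiguous 1-D slice of the flattened tensor ...
    local = split_tensor_into_1d_equal_chunks(full, new_buffer=True)

    # ... and an all-gather over the tensor-parallel group reassembles it.
    rebuilt = gather_split_1d_tensor(local)
    assert torch.equal(rebuilt, full.view(-1))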