Commit 35bea728 authored by Boris Fomitchev

Code review comments - changing parallel test condition


Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
parent 84a5997a
@@ -110,11 +110,12 @@ class VocabParallelEmbedding(torch.nn.Module):
         self.scale_grad_by_freq = False
         self.sparse = False
         self._weight = None
+        self.model_parallel_size = get_model_parallel_world_size()
         # Divide the weight matrix along the vocaburaly dimension.
         self.vocab_start_index, self.vocab_end_index = \
             VocabUtility.vocab_range_from_global_vocab_size(
                 self.num_embeddings, get_model_parallel_rank(),
-                get_model_parallel_world_size())
+                self.model_parallel_size)
         self.num_embeddings_per_partition = self.vocab_end_index - \
             self.vocab_start_index
@@ -127,7 +128,7 @@ class VocabParallelEmbedding(torch.nn.Module):
             self.num_embeddings_per_partition, 0, init_method)

     def forward(self, input_):
-        if self.num_embeddings_per_partition < self.num_embeddings:
+        if self.model_parallel_size > 1:
             # Build the mask.
             input_mask = (input_ < self.vocab_start_index) | \
                          (input_ >= self.vocab_end_index)
@@ -142,7 +143,7 @@ class VocabParallelEmbedding(torch.nn.Module):
                                      self.norm_type, self.scale_grad_by_freq,
                                      self.sparse)
         # Mask the output embedding.
-        if self.num_embeddings_per_partition < self.num_embeddings:
+        if self.model_parallel_size > 1:
             output_parallel[input_mask, :] = 0.0
         # Reduce across all the model parallel GPUs.
         output = reduce_from_model_parallel_region(output_parallel)
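
For context, the logic being gated reads roughly as follows: each model-parallel rank owns a contiguous slice of the vocabulary, looks up only the tokens that fall inside its slice, zeroes everything else, and an all-reduce sums the partial results into the full embedding output. Below is a minimal single-process sketch of that pattern, assuming only PyTorch; world_size, the per-rank weight shards, and the plain sum standing in for reduce_from_model_parallel_region are illustrative stand-ins, not the actual Megatron helpers.

# Minimal single-process sketch of the masked-lookup-then-reduce pattern the
# diff gates on model_parallel_size > 1. world_size and the per-rank weight
# shards below are illustrative stand-ins for Megatron's
# get_model_parallel_world_size()/get_model_parallel_rank() and the sharded
# embedding weight; summing the per-rank outputs stands in for the all-reduce
# done by reduce_from_model_parallel_region.
import torch
import torch.nn.functional as F

vocab_size, hidden, world_size = 8, 4, 2
per_partition = vocab_size // world_size
weight = torch.randn(vocab_size, hidden)   # full table; sharded per rank below
input_ = torch.tensor([0, 3, 5, 7])

partials = []
for rank in range(world_size):
    start, end = rank * per_partition, (rank + 1) * per_partition
    shard = weight[start:end]                     # rows owned by this rank
    if world_size > 1:                            # the condition the commit switches to
        input_mask = (input_ < start) | (input_ >= end)
        masked_input = input_ - start             # shift to shard-local indices
        masked_input[input_mask] = 0              # any in-range index works; zeroed below
    else:
        input_mask = torch.zeros_like(input_, dtype=torch.bool)
        masked_input = input_
    out = F.embedding(masked_input, shard)
    out[input_mask, :] = 0.0                      # drop rows owned by other ranks
    partials.append(out)

output = torch.stack(partials).sum(dim=0)         # plays the role of the all-reduce
assert torch.allclose(output, F.embedding(input_, weight))

With an even vocabulary split as above, the new check model_parallel_size > 1 and the old check num_embeddings_per_partition < num_embeddings agree; the commit simply tests the world size directly instead of inferring it from the partition size.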