OpenDAS / Megatron-LM · Commits

Commit 35bea728
authored Jul 30, 2020 by Boris Fomitchev

Code review comments - changing parallel test condition

Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>

parent 84a5997a
Showing 1 changed file with 4 additions and 3 deletions:

megatron/mpu/layers.py (+4, -3)
megatron/mpu/layers.py @ 35bea728

--- a/megatron/mpu/layers.py
+++ b/megatron/mpu/layers.py
@@ -110,11 +110,12 @@ class VocabParallelEmbedding(torch.nn.Module):
         self.scale_grad_by_freq = False
         self.sparse = False
         self._weight = None
+        self.model_parallel_size = get_model_parallel_world_size()
         # Divide the weight matrix along the vocaburaly dimension.
         self.vocab_start_index, self.vocab_end_index = \
             VocabUtility.vocab_range_from_global_vocab_size(
                 self.num_embeddings, get_model_parallel_rank(),
-                get_model_parallel_world_size())
+                self.model_parallel_size)
         self.num_embeddings_per_partition = self.vocab_end_index - \
             self.vocab_start_index

@@ -127,7 +128,7 @@ class VocabParallelEmbedding(torch.nn.Module):
             self.num_embeddings_per_partition, 0, init_method)

     def forward(self, input_):
-        if self.num_embeddings_per_partition < self.num_embeddings:
+        if self.model_parallel_size > 1:
             # Build the mask.
             input_mask = (input_ < self.vocab_start_index) | \
                          (input_ >= self.vocab_end_index)
@@ -142,7 +143,7 @@ class VocabParallelEmbedding(torch.nn.Module):
                                       self.norm_type, self.scale_grad_by_freq,
                                       self.sparse)
         # Mask the output embedding.
-        if self.num_embeddings_per_partition < self.num_embeddings:
+        if self.model_parallel_size > 1:
             output_parallel[input_mask, :] = 0.0
         # Reduce across all the model parallel GPUs.
         output = reduce_from_model_parallel_region(output_parallel)
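For context: the two if tests changed above guard VocabParallelEmbedding's mask-and-reduce path, which is only needed when the vocabulary is actually sharded across model-parallel ranks; testing model_parallel_size > 1 states that intent directly instead of inferring it from partition sizes. Below is a minimal single-process sketch of the pattern this condition guards, not the actual Megatron-LM module: the range helper mirrors the arithmetic of VocabUtility.vocab_range_from_global_vocab_size under an even-split assumption, while partial_embedding_lookup and the two-rank demo are hypothetical stand-ins for the sharded forward pass and the cross-GPU all-reduce.

import torch


def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
    # Even split of the vocabulary across ranks (assumes divisibility),
    # mirroring what VocabUtility computes in megatron/mpu.
    per_partition = global_vocab_size // world_size
    start = rank * per_partition
    return start, start + per_partition


def partial_embedding_lookup(input_, weight, vocab_start, vocab_end,
                             model_parallel_size):
    # Hypothetical stand-in for VocabParallelEmbedding.forward on one rank.
    if model_parallel_size > 1:  # the test condition this commit switches to
        # Build the mask of tokens owned by other ranks.
        input_mask = (input_ < vocab_start) | (input_ >= vocab_end)
        masked_input = input_.clone() - vocab_start
        masked_input[input_mask] = 0
    else:
        masked_input = input_
    output_parallel = torch.nn.functional.embedding(masked_input, weight)
    if model_parallel_size > 1:
        # Zero the rows this rank does not own; summing across ranks
        # (the all-reduce, emulated by the loop below) restores them.
        output_parallel[input_mask, :] = 0.0
    return output_parallel


# Two simulated ranks together reproduce a full-vocabulary lookup.
torch.manual_seed(0)
full_weight = torch.randn(8, 4)               # vocab size 8, embedding dim 4
tokens = torch.tensor([1, 5, 7])
output = torch.zeros(len(tokens), 4)
for rank in range(2):
    start, end = vocab_range_from_global_vocab_size(8, rank, 2)
    output += partial_embedding_lookup(tokens, full_weight[start:end],
                                       start, end, model_parallel_size=2)
assert torch.equal(output, torch.nn.functional.embedding(tokens, full_weight))

Each rank zeroes the output rows for tokens outside its vocabulary slice, so the sum across ranks contains exactly one real embedding per token, which is why the guard only needs to fire when more than one model-parallel rank exists.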