OpenDAS / Megatron-LM / Commits / 417c7f6a

Commit 417c7f6a, authored Jul 30, 2020 by Boris Fomitchev

Changes for NeMo/lightning compatibility

Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>

parent 07ebf714

Showing 4 changed files with 45 additions and 27 deletions (+45 / -27)
megatron/initialize.py      +0  -6
megatron/mpu/initialize.py  +8  -1
megatron/mpu/layers.py      +12 -8
megatron/mpu/mappings.py    +25 -12
megatron/initialize.py

@@ -39,12 +39,6 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
     # Make sure cuda is available.
     assert torch.cuda.is_available(), 'Megatron requires CUDA.'
-    # This is temporary WAR to make simple case like pytest calling with same args twice
-    # Need to implement clean factory init.
-    if mpu.model_parallel_is_initialized():
-        return
     # Parse args, build tokenizer, and set adlr-autoresume,
     # tensorboard-writer, and timers.
     set_global_variables(extra_args_provider=extra_args_provider,
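Note: with the early-return workaround removed from initialize_megatron(), code that might run setup twice (a pytest suite, a Lightning trainer re-entering setup) would need to guard the call itself. A minimal caller-side sketch, not part of this commit; ensure_megatron_initialized is a hypothetical helper name, while initialize_megatron and mpu.model_parallel_is_initialized are the functions shown in this diff.

    from megatron import mpu
    from megatron.initialize import initialize_megatron

    def ensure_megatron_initialized(extra_args_provider=None, args_defaults=None):
        # Skip setup entirely if model parallelism is already configured,
        # e.g. by an earlier test case or by an external launcher.
        if mpu.model_parallel_is_initialized():
            return
        initialize_megatron(extra_args_provider=extra_args_provider,
                            args_defaults=args_defaults or {})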
megatron/mpu/initialize.py

@@ -88,13 +88,16 @@ def model_parallel_is_initialized():
         return False
     return True


 def get_model_parallel_group():
     """Get the model parallel group the caller rank belongs to."""
     assert _MODEL_PARALLEL_GROUP is not None, \
         'model parallel group is not initialized'
     return _MODEL_PARALLEL_GROUP


+def set_model_parallel_group(group):
+    """Set model parallel group."""
+    global _MODEL_PARALLEL_GROUP
+    _MODEL_PARALLEL_GROUP = group
+
+
 def get_data_parallel_group():
     """Get the data parallel group the caller rank belongs to."""

@@ -102,6 +105,10 @@ def get_data_parallel_group():
         'data parallel group is not initialized'
     return _DATA_PARALLEL_GROUP


+def set_data_parallel_group(group):
+    """Set data parallel group."""
+    global _DATA_PARALLEL_GROUP
+    _DATA_PARALLEL_GROUP = group
+
+
 def set_model_parallel_world_size(world_size):
     """Set the model parallel size"""
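The new setters let a framework that owns torch.distributed initialization hand pre-built process groups to mpu instead of calling Megatron's own initialize_model_parallel. A sketch of one possible way to use them, for illustration only: inject_parallel_groups and its 2-way group layout are assumptions, not NeMo's actual logic; only set_model_parallel_group and set_data_parallel_group come from this commit.

    import torch
    from megatron.mpu.initialize import (set_data_parallel_group,
                                         set_model_parallel_group)

    def inject_parallel_groups(model_parallel_size=2):
        # Assumes torch.distributed.init_process_group() was already called
        # by the outer framework (e.g. a Lightning trainer).
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

        # Model-parallel groups: consecutive ranks share one group. Every rank
        # must call new_group() for every group, even ones it does not join.
        for i in range(world_size // model_parallel_size):
            ranks = range(i * model_parallel_size, (i + 1) * model_parallel_size)
            group = torch.distributed.new_group(list(ranks))
            if rank in ranks:
                set_model_parallel_group(group)

        # Data-parallel groups: ranks holding the same model-parallel slice.
        for i in range(model_parallel_size):
            ranks = range(i, world_size, model_parallel_size)
            group = torch.distributed.new_group(list(ranks))
            if rank in ranks:
                set_data_parallel_group(group)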
megatron/mpu/layers.py

@@ -127,19 +127,23 @@ class VocabParallelEmbedding(torch.nn.Module):
             self.num_embeddings_per_partition, 0, init_method)

     def forward(self, input_):
-        # Build the mask.
-        input_mask = (input_ < self.vocab_start_index) | \
-                     (input_ >= self.vocab_end_index)
-        # Mask the input.
-        masked_input = input_.clone() - self.vocab_start_index
-        masked_input[input_mask] = 0
+        if self.num_embeddings_per_partition < self.num_embeddings:
+            # Build the mask.
+            input_mask = (input_ < self.vocab_start_index) | \
+                         (input_ >= self.vocab_end_index)
+            # Mask the input.
+            masked_input = input_.clone() - self.vocab_start_index
+            masked_input[input_mask] = 0
+        else:
+            masked_input = input_
         # Get the embeddings.
         output_parallel = F.embedding(masked_input, self.weight,
                                       self.padding_idx, self.max_norm,
                                       self.norm_type, self.scale_grad_by_freq,
                                       self.sparse)
-        # Mask the output embedding.
-        output_parallel[input_mask, :] = 0.0
+        if self.num_embeddings_per_partition < self.num_embeddings:
+            # Mask the output embedding.
+            output_parallel[input_mask, :] = 0.0
         # Reduce across all the model parallel GPUs.
         output = reduce_from_model_parallel_region(output_parallel)
         return output
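The change guards the vocabulary masking so it only runs when the embedding table is actually partitioned; when one rank holds the full vocabulary (num_embeddings_per_partition == num_embeddings) the lookup proceeds directly. A standalone sketch of the same masking idea, not Megatron code; shard_embedding_lookup is a hypothetical helper used only to illustrate why the per-shard masking and the final sum recover the unsharded lookup.

    import torch
    import torch.nn.functional as F

    def shard_embedding_lookup(input_ids, weight_shard, vocab_start, vocab_end,
                               full_vocab_size):
        """Look up embeddings from one vocab shard covering [vocab_start, vocab_end)."""
        if weight_shard.size(0) < full_vocab_size:
            # Vocabulary is partitioned: zero out ids outside this shard's range,
            # shift the rest into local indices, and do the local lookup.
            mask = (input_ids < vocab_start) | (input_ids >= vocab_end)
            local_ids = input_ids.clone() - vocab_start
            local_ids[mask] = 0
            out = F.embedding(local_ids, weight_shard)
            out[mask, :] = 0.0  # rows for foreign tokens contribute nothing
        else:
            # Single shard holds the whole table (the new fast path): no masking needed.
            out = F.embedding(input_ids, weight_shard)
        # In Megatron the partial results are then summed across the model-parallel
        # group by reduce_from_model_parallel_region.
        return out

    if __name__ == "__main__":
        ids = torch.tensor([[1, 5, 9]])
        full = torch.randn(12, 4)
        # Two shards of 6 rows each reproduce the unsharded lookup once summed.
        a = shard_embedding_lookup(ids, full[:6], 0, 6, 12)
        b = shard_embedding_lookup(ids, full[6:], 6, 12, 12)
        assert torch.allclose(a + b, F.embedding(ids, full))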
megatron/mpu/mappings.py

@@ -15,20 +15,19 @@
 import torch

-from .initialize import get_model_parallel_group
+from .initialize import get_model_parallel_group, get_model_parallel_world_size, get_model_parallel_rank
 from .utils import split_tensor_along_last_dim


 def _reduce(input_):
     """All-reduce the the input tensor across model parallel group."""
-    group = get_model_parallel_group()
     # Bypass the function if we are using only 1 GPU.
-    if torch.distributed.get_world_size(group=group) == 1:
+    if get_model_parallel_world_size() == 1:
         return input_
     # All-reduce.
-    torch.distributed.all_reduce(input_, group=group)
+    torch.distributed.all_reduce(input_, group=get_model_parallel_group())
     return input_
@@ -36,18 +35,17 @@ def _reduce(input_):
 def _split(input_):
     """Split the tensor along its last dimension and keep the
     corresponding slice."""
-    group = get_model_parallel_group()
+    world_size = get_model_parallel_world_size()
     # Bypass the function if we are using only 1 GPU.
-    if torch.distributed.get_world_size(group=group) == 1:
+    if world_size == 1:
         return input_
     # Split along last dimension.
-    world_size = torch.distributed.get_world_size(group=group)
     input_list = split_tensor_along_last_dim(input_, world_size)
     # Note: torch.split does not create contiguous tensors by default.
-    rank = torch.distributed.get_rank(group=group)
+    rank = get_model_parallel_rank()
     output = input_list[rank].contiguous()
     return output
@@ -55,16 +53,15 @@ def _split(input_):
 def _gather(input_):
     """Gather tensors and concatinate along the last dimension."""
-    group = get_model_parallel_group()
+    world_size = get_model_parallel_world_size()
     # Bypass the function if we are using only 1 GPU.
-    if torch.distributed.get_world_size(group=group) == 1:
+    if world_size == 1:
         return input_
     # Size and dimension.
     last_dim = input_.dim() - 1
-    rank = torch.distributed.get_rank(group=group)
-    world_size = torch.distributed.get_world_size(group=group)
+    rank = get_model_parallel_rank()
     tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
     tensor_list[rank] = input_
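These three hunks replace direct torch.distributed queries on an explicit group handle with mpu's own accessors. A sketch of the accessor shape the refactor appears to rely on: the module-level override and its use are assumptions inferred from the set_model_parallel_world_size setter visible in mpu/initialize.py above, not the literal Megatron source.

    import torch
    from megatron.mpu.initialize import get_model_parallel_group

    _MPU_WORLD_SIZE = None  # would be installed by set_model_parallel_world_size()

    def get_model_parallel_world_size_sketch():
        # Prefer an explicit override (useful when an outside framework drives
        # initialization); otherwise ask torch.distributed about the cached group.
        if _MPU_WORLD_SIZE is not None:
            return _MPU_WORLD_SIZE
        return torch.distributed.get_world_size(group=get_model_parallel_group())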
@@ -79,6 +76,10 @@ def _gather(input_):
 class _CopyToModelParallelRegion(torch.autograd.Function):
     """Pass the input to the model parallel region."""

+    @staticmethod
+    def symbolic(graph, input_):
+        return input_
+
     @staticmethod
     def forward(ctx, input_):
         return input_

@@ -91,6 +92,10 @@ class _CopyToModelParallelRegion(torch.autograd.Function):
 class _ReduceFromModelParallelRegion(torch.autograd.Function):
     """All-redcue the input from the model parallel region."""

+    @staticmethod
+    def symbolic(graph, input_):
+        return _reduce(input_)
+
     @staticmethod
     def forward(ctx, input_):
         return _reduce(input_)

@@ -103,6 +108,10 @@ class _ReduceFromModelParallelRegion(torch.autograd.Function):
 class _ScatterToModelParallelRegion(torch.autograd.Function):
     """Split the input and keep only the corresponding chuck to the rank."""

+    @staticmethod
+    def symbolic(graph, input_):
+        return _split(input_)
+
     @staticmethod
     def forward(ctx, input_):
         return _split(input_)

@@ -115,6 +124,10 @@ class _ScatterToModelParallelRegion(torch.autograd.Function):
 class _GatherFromModelParallelRegion(torch.autograd.Function):
     """Gather the input from model parallel region and concatinate."""

+    @staticmethod
+    def symbolic(graph, input_):
+        return _gather(input_)
+
     @staticmethod
     def forward(ctx, input_):
         return _gather(input_)
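Each autograd Function gains a symbolic(graph, input_) staticmethod, the hook torch.onnx.export() consults to decide which nodes to emit in place of an opaque autograd.Function, presumably so NeMo's export paths can trace these layers. During a single-GPU export the _reduce/_split/_gather calls hit their world_size == 1 bypass, so the added symbolics effectively collapse to a pass-through. A generic illustration of the hook, not Megatron code; TracedIdentity is a toy Function:

    import torch

    class TracedIdentity(torch.autograd.Function):
        @staticmethod
        def symbolic(g, input_):
            # Used by torch.onnx.export() to build the exported graph; here the
            # op is mapped explicitly to ONNX Identity, whereas Megatron's
            # versions return input_ (or _reduce/_split/_gather of it) directly.
            return g.op("Identity", input_)

        @staticmethod
        def forward(ctx, input_):
            return input_

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output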