OpenDAS / Megatron-LM · Commits

Commit 269f28f7
Authored Mar 07, 2022 by Vijay Korthikanti

fixes to main merge

Parent: 6fdbf26b
Showing 5 changed files with 13 additions and 12 deletions:

  megatron/arguments.py             +5  -0
  megatron/model/language_model.py  +3  -5
  megatron/model/transformer.py     +0  -2
  megatron/mpu/__init__.py          +1  -1
  megatron/mpu/layers.py            +4  -4
megatron/arguments.py

@@ -291,6 +291,11 @@ def parse_args(extra_args_provider=None, defaults={},
             'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \
             'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR)
 
+    # model parallel memory optimization
+    if args.model_parallel_memory_opt:
+        assert not args.async_tensor_model_parallel_allreduce
+
     _print_args(args)
     return args
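Note: the new block makes --model-parallel-memory-opt and --async-tensor-model-parallel-allreduce mutually exclusive at argument-parsing time. Below is a minimal, self-contained sketch of the same validation pattern; the argparse wiring and flag defaults are illustrative assumptions, not the real parse_args() from megatron/arguments.py.

    # Minimal sketch of the mutual-exclusion check added in parse_args();
    # the flag wiring is illustrative, not the real Megatron-LM argument setup.
    import argparse

    def parse_args_sketch(argv=None):
        parser = argparse.ArgumentParser()
        parser.add_argument('--model-parallel-memory-opt', action='store_true',
                            help='Enable the model-parallel memory optimization.')
        parser.add_argument('--async-tensor-model-parallel-allreduce', action='store_true',
                            help='Overlap the tensor-model-parallel all-reduce with compute.')
        args = parser.parse_args(argv)

        # The memory optimization and the async all-reduce are mutually exclusive,
        # mirroring the assert added to megatron/arguments.py in this commit.
        if args.model_parallel_memory_opt:
            assert not args.async_tensor_model_parallel_allreduce, \
                '--model-parallel-memory-opt is incompatible with ' \
                '--async-tensor-model-parallel-allreduce'
        return args

    if __name__ == '__main__':
        print(parse_args_sketch(['--model-parallel-memory-opt']))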
megatron/model/language_model.py

@@ -34,23 +34,21 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
     # Parallel logits.
     if args.async_tensor_model_parallel_allreduce or \
        args.model_parallel_memory_opt:
-        input_parallel = input
+        input_parallel = input_
         model_parallel = mpu.get_tensor_model_parallel_world_size() > 1
         async_grad_allreduce = args.async_tensor_model_parallel_allreduce and \
                                model_parallel
-        model_parallel_memory_opt = args.model_parallel_memory_opt and \
-                                    model_parallel
     else:
         input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
         async_grad_allreduce = False
-        model_parallel_memory_opt = False
 
     # Matrix multiply.
     logits_parallel = mpu.LinearWithGradAccumulationAndAsyncCommunication.apply(
         input_parallel, word_embeddings_weight, bias,
         args.gradient_accumulation_fusion,
-        async_grad_allreduce, model_parallel_memory_opt)
+        async_grad_allreduce, None)
 
     # Gather if needed.
     if parallel_output:
         return logits_parallel
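Note on the first change in this hunk: replacing input with input_ matters because input is a Python builtin, so the unsuffixed name binds the builtin function instead of raising a NameError, and the merge error only surfaces later at runtime. The standalone snippet below illustrates that behavior; it is a toy demonstration, not Megatron-LM code.

    # Demonstrates why `input_parallel = input` fails silently:
    # `input` resolves to the builtin function rather than raising NameError.
    import builtins

    def parallel_lm_logits_buggy(input_):
        input_parallel = input          # binds builtins.input, not the argument
        return input_parallel

    def parallel_lm_logits_fixed(input_):
        input_parallel = input_         # binds the function argument, as intended
        return input_parallel

    if __name__ == '__main__':
        tensor_like = [[0.1, 0.2], [0.3, 0.4]]   # stand-in for the activation tensor
        assert parallel_lm_logits_buggy(tensor_like) is builtins.input
        assert parallel_lm_logits_fixed(tensor_like) is tensor_like
        print('buggy version returns the builtin `input`; fixed version returns the argument')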
megatron/model/transformer.py

@@ -881,7 +881,6 @@ class ParallelTransformer(MegatronModule):
                     enc_dec_attn_mask=enc_dec_attn_mask,
                     inference_params=inference_params)
-
 
         # Final layer norm.
         if self.post_process:
             # Reverting data format change [s b h] --> [b s h].

@@ -899,5 +898,4 @@ class ParallelTransformer(MegatronModule):
         else:
             output = hidden_states
-
 
         return output
megatron/mpu/__init__.py

@@ -49,7 +49,7 @@ from .initialize import get_virtual_pipeline_model_parallel_rank, set_virtual_pi
 from .initialize import initialize_model_parallel
 from .initialize import model_parallel_is_initialized
-from .layers import LinearWithGradAccumulationAndAsyncAllreduce
+from .layers import LinearWithGradAccumulationAndAsyncCommunication
 from .layers import ColumnParallelLinear
 from .layers import RowParallelLinear
 from .layers import VocabParallelEmbedding
megatron/mpu/layers.py

@@ -299,12 +299,12 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
         if ctx.model_parallel_memory_opt:
             handle.wait()
-            return sub_grad_input, grad_weight, grad_bias
+            return sub_grad_input, grad_weight, grad_bias, None, None, None
 
         if ctx.async_grad_allreduce:
             handle.wait()
 
-        return grad_input, grad_weight, grad_bias
+        return grad_input, grad_weight, grad_bias, None, None, None
 
 
 class ColumnParallelLinear(torch.nn.Module):

@@ -504,9 +504,9 @@ class RowParallelLinear(torch.nn.Module):
             assert not self.model_parallel_memory_opt
             input_parallel = scatter_to_tensor_model_parallel_region(input_)
         # Matrix multiply.
-        output_parallel = LinearWithGradAccumulationAndAsyncAllreduce.apply(
+        output_parallel = LinearWithGradAccumulationAndAsyncCommunication.apply(
             input_parallel, self.weight, None,
-            self.gradient_accumulation_fusion, None)
+            self.gradient_accumulation_fusion, None, None)
         # All-reduce across all the partitions.
         if self.model_parallel_memory_opt:
             output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
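Note: both layers.py changes follow from the torch.autograd.Function contract: backward() must return exactly one value per argument passed to forward() (None for non-tensor flags), and every call site of apply() must supply the full argument list. The sketch below illustrates that contract under the assumption that forward() here takes six arguments (input, weight, bias, gradient_accumulation_fusion, async_grad_allreduce, model_parallel_memory_opt), which is what the added ', None, None, None' and the extra None at the RowParallelLinear call site suggest; it is a simplified stand-in, not the actual LinearWithGradAccumulationAndAsyncCommunication.

    # Minimal torch.autograd.Function sketch showing the backward() contract that
    # the layers.py fix restores: one return value per forward() input, with None
    # for the three non-tensor flags. Simplified stand-in, not the Megatron kernel.
    import torch

    class LinearSketch(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input_, weight, bias, gradient_accumulation_fusion,
                    async_grad_allreduce, model_parallel_memory_opt):
            ctx.save_for_backward(input_, weight)
            ctx.use_bias = bias is not None
            output = input_.matmul(weight.t())
            if bias is not None:
                output = output + bias
            return output

        @staticmethod
        def backward(ctx, grad_output):
            input_, weight = ctx.saved_tensors
            grad_input = grad_output.matmul(weight)
            grad_weight = grad_output.reshape(-1, grad_output.shape[-1]).t().matmul(
                input_.reshape(-1, input_.shape[-1]))
            grad_bias = grad_output.sum(dim=0) if ctx.use_bias else None
            # Six forward inputs -> six return values; the flags get None.
            return grad_input, grad_weight, grad_bias, None, None, None

    if __name__ == '__main__':
        x = torch.randn(4, 8, requires_grad=True)
        w = torch.randn(16, 8, requires_grad=True)
        b = torch.randn(16, requires_grad=True)
        y = LinearSketch.apply(x, w, b, False, None, None)
        y.sum().backward()
        print(x.grad.shape, w.grad.shape, b.grad.shape)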