OpenDAS / Megatron-LM · Commits

Commit 269f28f7
Authored Mar 07, 2022 by Vijay Korthikanti

fixes to main merge

Parent 6fdbf26b
Showing 5 changed files with 13 additions and 12 deletions.
megatron/arguments.py             +5 -0
megatron/model/language_model.py  +3 -5
megatron/model/transformer.py     +0 -2
megatron/mpu/__init__.py          +1 -1
megatron/mpu/layers.py            +4 -4
megatron/arguments.py

@@ -291,6 +291,11 @@ def parse_args(extra_args_provider=None, defaults={},
             'v1.10 and above (Nvidia Pytorch container >= 21.07). Current '\
             'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR)
 
+    # model parallel memory optmization
+    if args.model_parallel_memory_opt:
+        assert not args.async_tensor_model_parallel_allreduce
+
     _print_args(args)
     return args
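The added check makes the model-parallel memory optimization and the async tensor-model-parallel all-reduce mutually exclusive at argument-parsing time. A minimal standalone sketch of the same pattern, assuming the two options are registered as plain store_true flags with these spellings (the real parse_args registers many more options and prints them via _print_args):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--model-parallel-memory-opt', action='store_true')
parser.add_argument('--async-tensor-model-parallel-allreduce', action='store_true')
args = parser.parse_args(['--model-parallel-memory-opt'])

# Mirror of the added check: the memory optimization replaces the async
# all-reduce path, so enabling both is treated as a configuration error.
if args.model_parallel_memory_opt:
    assert not args.async_tensor_model_parallel_allreduce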
megatron/model/language_model.py

@@ -34,23 +34,21 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
     # Parallel logits.
     if args.async_tensor_model_parallel_allreduce or\
        args.model_parallel_memory_opt:
-        input_parallel = input
+        input_parallel = input_
         model_parallel = mpu.get_tensor_model_parallel_world_size() > 1
         async_grad_allreduce = args.async_tensor_model_parallel_allreduce and\
             model_parallel
-        model_parallel_memory_opt = args.model_parallel_memory_opt and\
-            model_parallel
     else:
         input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
         async_grad_allreduce = False
-        model_parallel_memory_opt = False
 
     # Matrix multiply.
     logits_parallel = mpu.LinearWithGradAccumulationAndAsyncCommunication.apply(
         input_parallel, word_embeddings_weight, bias,
         args.gradient_accumulation_fusion,
-        async_grad_allreduce, model_parallel_memory_opt)
+        async_grad_allreduce, None)
 
     # Gather if needed.
     if parallel_output:
         return logits_parallel
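In the else branch the input still goes through mpu.copy_to_tensor_model_parallel_region, which in Megatron's mpu is an autograd function that is the identity in forward and all-reduces the gradient over the tensor-parallel group in backward; on the async/fused path that all-reduce is instead folded into the linear layer's backward, which is why the hunk can use input_ directly. A minimal single-process sketch of the copy-to-region pattern (a simplified stand-in, not the actual mpu implementation; it uses a one-rank gloo group so it runs without a launcher):

import os
import torch
import torch.distributed as dist

class CopyToModelParallelRegion(torch.autograd.Function):
    # Identity in forward; all-reduce the gradient across the (here trivial)
    # tensor-parallel group in backward.
    @staticmethod
    def forward(ctx, input_):
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        grad = grad_output.contiguous()
        dist.all_reduce(grad)
        return grad

if __name__ == '__main__':
    # Single-rank process group so the example is self-contained.
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank=0, world_size=1)
    x = torch.randn(2, 3, requires_grad=True)
    (CopyToModelParallelRegion.apply(x) ** 2).sum().backward()
    print(x.grad.shape)  # gradient arrives after the (trivial) all-reduce
    dist.destroy_process_group()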
megatron/model/transformer.py

@@ -881,7 +881,6 @@ class ParallelTransformer(MegatronModule):
                                       enc_dec_attn_mask=enc_dec_attn_mask,
                                       inference_params=inference_params)
 
-
         # Final layer norm.
         if self.post_process:
             # Reverting data format change [s b h] --> [b s h].
@@ -899,5 +898,4 @@ class ParallelTransformer(MegatronModule):
         else:
             output = hidden_states
 
-
         return output
megatron/mpu/__init__.py

@@ -49,7 +49,7 @@ from .initialize import get_virtual_pipeline_model_parallel_rank, set_virtual_pi
 from .initialize import initialize_model_parallel
 from .initialize import model_parallel_is_initialized
-from .layers import LinearWithGradAccumulationAndAsyncAllreduce
+from .layers import LinearWithGradAccumulationAndAsyncCommunication
 from .layers import ColumnParallelLinear
 from .layers import RowParallelLinear
 from .layers import VocabParallelEmbedding
megatron/mpu/layers.py

@@ -299,12 +299,12 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
         if ctx.model_parallel_memory_opt:
             handle.wait()
-            return sub_grad_input, grad_weight, grad_bias
+            return sub_grad_input, grad_weight, grad_bias, None, None, None
 
         if ctx.async_grad_allreduce:
             handle.wait()
 
-        return grad_input, grad_weight, grad_bias
+        return grad_input, grad_weight, grad_bias, None, None, None
 
 
 class ColumnParallelLinear(torch.nn.Module):
@@ -504,9 +504,9 @@ class RowParallelLinear(torch.nn.Module):
             assert not self.model_parallel_memory_opt
             input_parallel = scatter_to_tensor_model_parallel_region(input_)
         # Matrix multiply.
-        output_parallel = LinearWithGradAccumulationAndAsyncAllreduce.apply(
+        output_parallel = LinearWithGradAccumulationAndAsyncCommunication.apply(
             input_parallel, self.weight, None,
-            self.gradient_accumulation_fusion, None)
+            self.gradient_accumulation_fusion, None, None)
         # All-reduce across all the partitions.
         if self.model_parallel_memory_opt:
             output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
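Both hunks follow from the torch.autograd.Function contract: backward must return one value for every argument that forward receives after ctx. The apply() calls in this diff pass six arguments (input, weight, bias, gradient_accumulation_fusion, async_grad_allreduce, model_parallel_memory_opt), so the gradient tuples are padded with None for the non-tensor flags, and RowParallelLinear's call to the renamed class gains a sixth positional None. A toy, self-contained illustration of that contract (hypothetical ScaleByFlag, not Megatron code):

import torch

class ScaleByFlag(torch.autograd.Function):
    # forward takes a tensor plus a plain Python value, so backward must
    # return one entry per forward argument and pad the flag with None.
    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, grad_output):
        # Gradients for (x, scale): the non-tensor flag gets None, mirroring
        # the trailing Nones added to the returns in this commit.
        return grad_output * ctx.scale, None

x = torch.randn(4, requires_grad=True)
ScaleByFlag.apply(x, 2.0).sum().backward()
print(x.grad)  # every entry equals 2.0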