OpenDAS / Megatron-LM / Commits / 0d77c0e9

Commit 0d77c0e9, authored Mar 07, 2022 by Vijay Korthikanti
Parent: 02bb1f5c

    refactor to help merge with main

Showing 3 changed files with 111 additions and 104 deletions:

    megatron/model/fused_layer_norm.py    +2   -2
    megatron/model/language_model.py      +18  -8
    megatron/mpu/layers.py                +91  -94
megatron/model/fused_layer_norm.py

```diff
@@ -97,8 +97,8 @@ class MixedFusedLayerNorm(torch.nn.Module):
         self.sequence_parallel = sequence_parallel

         # set sequence parallelism flag on weight and bias parameters
-        self.weight.sequence_parallel = self.sequence_parallel
-        self.bias.sequence_parallel = self.sequence_parallel
+        setattr(self.weight, 'sequence_parallel', self.sequence_parallel)
+        setattr(self.bias, 'sequence_parallel', self.sequence_parallel)

     def reset_parameters(self):
```
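The `setattr` form writes the same `sequence_parallel` flag as the removed attribute assignments; downstream code can then pick out the tagged parameters with `getattr`. A minimal sketch of that pattern outside Megatron (the helper below is illustrative, not the repository's API):

```python
import torch

def collect_sequence_parallel_params(module: torch.nn.Module):
    """Illustrative helper: return parameters tagged with `sequence_parallel`."""
    return [p for p in module.parameters()
            if getattr(p, 'sequence_parallel', False)]

# Tag parameters the same way the commit does.
layer = torch.nn.LayerNorm(8)
setattr(layer.weight, 'sequence_parallel', True)
setattr(layer.bias, 'sequence_parallel', True)
assert len(collect_sequence_parallel_params(layer)) == 2
```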
megatron/model/language_model.py

```diff
@@ -26,21 +26,31 @@ from megatron.model.transformer import ParallelTransformer
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal, scaled_init_method_normal


 def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
                        bias=None):
     """LM logits using word embedding weights."""
     args = get_args()
     # Parallel logits.
-    if not args.model_parallel_memory_opt:
-        input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
+    if args.async_tensor_model_parallel_allreduce or \
+            args.model_parallel_memory_opt:
+        input_parallel = input_
+        model_parallel = mpu.get_tensor_model_parallel_world_size() > 1
+        async_grad_allreduce = args.async_tensor_model_parallel_allreduce and \
+            model_parallel
+        model_parallel_memory_opt = args.model_parallel_memory_opt and \
+            model_parallel
     else:
-        input_parallel = input_
+        input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
+        async_grad_allreduce = False
+        model_parallel_memory_opt = False

     # Matrix multiply.
-    if bias is None:
-        logits_parallel = F.linear(input_parallel, word_embeddings_weight)
-    else:
-        logits_parallel = F.linear(input_parallel, word_embeddings_weight, bias)
+    logits_parallel = mpu.LinearWithGradAccumulationAndAsyncCommunication.apply(
+        input_parallel, word_embeddings_weight, bias,
+        args.gradient_accumulation_fusion, async_grad_allreduce,
+        model_parallel_memory_opt)
     # Gather if needed.
     if parallel_output:
         return logits_parallel
```
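In the `else` branch above, `mpu.copy_to_tensor_model_parallel_region` is Megatron's identity-forward / all-reduce-backward hook; when either flag is set, the explicit copy is skipped because the fused autograd function now owns the communication. A simplified stand-alone sketch of that hook (not the Megatron implementation; it reduces over the default process group only):

```python
import torch
import torch.distributed as dist

class _CopyToModelParallelRegion(torch.autograd.Function):
    """Identity in forward; all-reduce the gradient in backward (sketch)."""

    @staticmethod
    def forward(ctx, input_):
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        # Megatron reduces over the tensor-model-parallel group; this sketch
        # uses the default group and is a no-op when torch.distributed is off.
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(grad_output)
        return grad_output
```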
megatron/mpu/layers.py

```diff
@@ -202,56 +202,34 @@ class VocabParallelEmbedding(torch.nn.Module):
         return output


-class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
+class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
     """
-    Column-parallel linear layer execution with asynchronous all-reduce
-    execution in backprop.
+    Linear layer execution with asynchronous communication and gradient
+    accumulation fusion in backprop.
     """

     @staticmethod
-    def forward(ctx, input, weight, bias):
+    def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
+                async_grad_allreduce, model_parallel_memory_opt):
         ctx.save_for_backward(input, weight)
         ctx.use_bias = bias is not None
-        output = torch.matmul(input, weight.t())
-        if bias is not None:
-            output = output + bias
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        input, weight = ctx.saved_tensors
-        use_bias = ctx.use_bias
-        grad_input = grad_output.matmul(weight)
-        # Asyncronous all-reduce
-        handle = torch.distributed.all_reduce(
-                grad_input, group=get_tensor_model_parallel_group(), async_op=True)
-        # Delay the start of weight gradient computation shortly (3us) to have
-        # all-reduce scheduled first and have GPU resources allocated
-        _ = torch.empty(1, device=grad_output.device) + 1
-        grad_weight = grad_output.t().matmul(input)
-        grad_bias = grad_output.sum(dim=0) if use_bias else None
-        handle.wait()
-        return grad_input, grad_weight, grad_bias
-
-
-class ColumnParallelLinearWithSequenceParallelism(torch.autograd.Function):
-    """Column-parallel linear layer execution with asynchronous all-reduce
-    execution in backprop.
-    """
-
-    @staticmethod
-    def forward(ctx, input, weight, bias):
-        ctx.save_for_backward(input, weight)
-        ctx.use_bias = bias is not None
-
-        world_size = get_tensor_model_parallel_world_size()
-        dim_size = list(input.size())
-        dim_size[0] = dim_size[0] * world_size
-
-        total_input = torch.empty(dim_size, dtype=input.dtype,
-                                  device=torch.cuda.current_device(),
-                                  requires_grad=False)
-        torch.distributed._all_gather_base(total_input, input,
-                                           group=get_tensor_model_parallel_group())
+        ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
+        ctx.async_grad_allreduce = async_grad_allreduce
+        ctx.model_parallel_memory_opt = model_parallel_memory_opt
+
+        if model_parallel_memory_opt:
+            world_size = get_tensor_model_parallel_world_size()
+            dim_size = list(input.size())
+            dim_size[0] = dim_size[0] * world_size
+
+            total_input = torch.empty(dim_size, dtype=input.dtype,
+                                      device=torch.cuda.current_device(),
+                                      requires_grad=False)
+            torch.distributed._all_gather_base(total_input, input,
+                                               group=get_tensor_model_parallel_group())
+        else:
+            total_input = input

         output = torch.matmul(total_input, weight.t())
         if bias is not None:
             output = output + bias
```
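The gather in `forward` concatenates the sequence-partitioned activations along dim 0, so the buffer it fills is `world_size` times longer in the sequence dimension. A shape-only sketch, single process and no real communication (`world_size` here is a stand-in constant):

```python
import torch

world_size = 4                             # stand-in for get_tensor_model_parallel_world_size()
local_input = torch.randn(128, 2, 1024)    # per-rank activation: [s/p, b, h]

dim_size = list(local_input.size())
dim_size[0] = dim_size[0] * world_size     # gathered sequence length: s = (s/p) * p

# With torch.distributed initialized, torch.distributed._all_gather_base(
# total_input, local_input, group=...) would fill this buffer.
total_input = torch.empty(dim_size, dtype=local_input.dtype)
print(total_input.shape)                   # torch.Size([512, 2, 1024])
```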
```diff
@@ -261,41 +239,72 @@ class ColumnParallelLinearWithSequenceParallelism(torch.autograd.Function):
     def backward(ctx, grad_output):
         input, weight = ctx.saved_tensors
         use_bias = ctx.use_bias

-        world_size = get_tensor_model_parallel_world_size()
-        dim_size = list(input.size())
-        dim_size[0] = dim_size[0] * world_size
-
-        total_input = torch.empty(dim_size, dtype=input.dtype,
-                                  device=torch.cuda.current_device(),
-                                  requires_grad=False)
-        handle = torch.distributed._all_gather_base(total_input, input,
-                                                    group=get_tensor_model_parallel_group(),
-                                                    async_op=True)
-        # Delay the start of input gradient computation shortly (3us) to have
-        # gather scheduled first and have GPU resources allocated
-        _ = torch.empty(1, device=grad_output.device) + 1
+        if ctx.model_parallel_memory_opt:
+            world_size = get_tensor_model_parallel_world_size()
+            dim_size = list(input.size())
+            dim_size[0] = dim_size[0] * world_size
+
+            total_input = torch.empty(dim_size, dtype=input.dtype,
+                                      device=torch.cuda.current_device(),
+                                      requires_grad=False)
+            handle = torch.distributed._all_gather_base(total_input, input,
+                                                        group=get_tensor_model_parallel_group(),
+                                                        async_op=True)
+            # Delay the start of input gradient computation shortly (3us) to have
+            # gather scheduled first and have GPU resources allocated
+            _ = torch.empty(1, device=grad_output.device) + 1
+        else:
+            total_input = input
         grad_input = grad_output.matmul(weight)
-        handle.wait()
-        dim_size = list(input.size())
-        sub_grad_input = torch.empty(dim_size, dtype=input.dtype,
-                                     device=torch.cuda.current_device(),
-                                     requires_grad=False)
-        # reduce_scatter
-        handle = torch.distributed._reduce_scatter_base(sub_grad_input, grad_input,
-                                                        group=get_tensor_model_parallel_group(),
-                                                        async_op=True)
-        # Delay the start of weight gradient computation shortly (3us) to have
-        # reduce scatter scheduled first and have GPU resources allocated
-        _ = torch.empty(1, device=grad_output.device) + 1
-        grad_weight = grad_output.t().matmul(total_input)
+
+        if ctx.model_parallel_memory_opt:
+            handle.wait()
+
+        # Convert the tensor shapes to 2D for execution compatibility
+        grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1],
+                                       grad_output.shape[2])
+        total_input = total_input.view(total_input.shape[0] * total_input.shape[1],
+                                       total_input.shape[2])
+
+        if ctx.async_grad_allreduce:
+            # Asynchronous all-reduce
+            handle = torch.distributed.all_reduce(
+                    grad_input, group=get_tensor_model_parallel_group(), async_op=True)
+            # Delay the start of weight gradient computation shortly (3us) to have
+            # all-reduce scheduled first and have GPU resources allocated
+            _ = torch.empty(1, device=grad_output.device) + 1
+
+        if ctx.model_parallel_memory_opt:
+            assert not ctx.async_grad_allreduce
+            dim_size = list(input.size())
+            sub_grad_input = torch.empty(dim_size, dtype=input.dtype,
+                                         device=torch.cuda.current_device(),
+                                         requires_grad=False)
+            # reduce_scatter
+            handle = torch.distributed._reduce_scatter_base(sub_grad_input, grad_input,
+                                                            group=get_tensor_model_parallel_group(),
+                                                            async_op=True)
+            # Delay the start of weight gradient computation shortly (3us) to have
+            # reduce scatter scheduled first and have GPU resources allocated
+            _ = torch.empty(1, device=grad_output.device) + 1
+
+        if ctx.gradient_accumulation_fusion:
+            fused_dense_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, weight.main_grad)
+            grad_weight = None
+        else:
+            grad_weight = grad_output.t().matmul(total_input)
         grad_bias = grad_output.sum(dim=0) if use_bias else None
-        handle.wait()
-        return sub_grad_input, grad_weight, grad_bias
+
+        if ctx.model_parallel_memory_opt:
+            handle.wait()
+            return sub_grad_input, grad_weight, grad_bias
+
+        if ctx.async_grad_allreduce:
+            handle.wait()
+
+        return grad_input, grad_weight, grad_bias
```
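The backward pass reuses the same overlap idiom in each branch: launch an asynchronous collective, enqueue a tiny dummy kernel so the collective is scheduled first, do the independent matmul, then `wait()`. Distilled into a stand-alone function (the function name is mine; assumes an initialized CUDA `torch.distributed` backend):

```python
import torch
import torch.distributed as dist

def overlapped_allreduce_backward(grad_output, weight, total_input):
    """Sketch of overlapping a grad all-reduce with the weight-grad GEMM."""
    grad_input = grad_output.matmul(weight)

    handle = dist.all_reduce(grad_input, async_op=True)  # start collective
    _ = torch.empty(1, device=grad_output.device) + 1    # tiny delay kernel

    grad_weight = grad_output.t().matmul(total_input)    # independent work
    handle.wait()                                        # sync before grad_input is used
    return grad_input, grad_weight
```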
```diff
@@ -375,37 +384,25 @@ class ColumnParallelLinear(torch.nn.Module):
         self.async_tensor_model_parallel_allreduce = (
                 not args.no_async_tensor_model_parallel_allreduce and
                 world_size > 1)
-        self.model_parallel_memory_opt = args.model_parallel_memory_opt
+        self.model_parallel_memory_opt = (
+                args.model_parallel_memory_opt and
+                world_size > 1)
+        assert not self.async_tensor_model_parallel_allreduce or \
+            not self.model_parallel_memory_opt

     def forward(self, input_):
         bias = self.bias if not self.skip_bias_add else None

-        if self.async_tensor_model_parallel_allreduce:
-            input_shape = input_.shape
-            input_ = input_.view(input_shape[0] * input_shape[1],
-                                 input_shape[2])
-            # Maxtrix multiply with asynchronouse all-reduce execution
-            output_parallel = ColumnParallelLinearWithAsyncAllreduce.apply(
-                    input_, self.weight, bias)
-            output_parallel = output_parallel.view(
-                    input_shape[0], input_shape[1], output_parallel.shape[1])
+        if self.async_tensor_model_parallel_allreduce or \
+                self.model_parallel_memory_opt:
+            input_parallel = input_
         else:
             # Set up backprop all-reduce.
-            if self.model_parallel_memory_opt:
-                input_shape = input_.shape
-                input_ = input_.view(input_shape[0] * input_shape[1],
-                                     input_shape[2])
-                output_parallel = ColumnParallelLinearWithSequenceParallelism.apply(
-                        input_, self.weight, bias)
-                world_size = get_tensor_model_parallel_world_size()
-                output_parallel = output_parallel.view(
-                        input_shape[0] * world_size, input_shape[1],
-                        output_parallel.shape[1])
-            else:
-                input_parallel = copy_to_tensor_model_parallel_region(input_)
-                # Matrix multiply.
-                output_parallel = F.linear(input_parallel, self.weight, bias)
+            input_parallel = copy_to_tensor_model_parallel_region(input_)
+
+        # Matrix multiply.
+        output_parallel = LinearWithGradAccumulationAndAsyncCommunication.apply(
+            input_parallel, self.weight, bias, self.gradient_accumulation_fusion,
+            self.async_tensor_model_parallel_allreduce, self.model_parallel_memory_opt)
+
         if self.gather_output:
             # All-gather across the partitions.
             assert not self.model_parallel_memory_opt
```
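`ColumnParallelLinear.forward` now routes every case through a single autograd function that takes three non-tensor flags after `bias`. A minimal single-process sketch of that call pattern (not the Megatron class, no communication): each flag argument needs a `None` slot in `backward`'s return.

```python
import torch

class _LinearSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
                async_grad_allreduce, model_parallel_memory_opt):
        ctx.save_for_backward(input, weight)
        ctx.use_bias = bias is not None
        output = input.matmul(weight.t())
        if bias is not None:
            output = output + bias
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input = grad_output.matmul(weight)
        grad_weight = grad_output.t().matmul(input)
        grad_bias = grad_output.sum(dim=0) if ctx.use_bias else None
        # One gradient per forward argument; the three flags take None.
        return grad_input, grad_weight, grad_bias, None, None, None

x = torch.randn(6, 4, requires_grad=True)
w = torch.randn(3, 4, requires_grad=True)
b = torch.randn(3, requires_grad=True)
y = _LinearSketch.apply(x, w, b, False, False, False)
y.sum().backward()
print(x.grad.shape, w.grad.shape, b.grad.shape)
```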