OpenDAS / Megatron-LM / Commits

Commit dd96d402, authored Mar 08, 2022 by Vijay Korthikanti

bug fixes

Parent: 269f28f7
Showing 3 changed files with 53 additions and 32 deletions:

megatron/model/transformer.py  +41  -17
megatron/mpu/layers.py          +4   -1
megatron/mpu/mappings.py        +8  -14
megatron/model/transformer.py
@@ -188,6 +188,7 @@ class ParallelAttention(MegatronModule):
         self.attention_type = attention_type
         self.attn_mask_type = attn_mask_type
         self.params_dtype = args.params_dtype
+        self.model_parallel_memory_opt = args.model_parallel_memory_opt

         projection_size = args.kv_channels * args.num_attention_heads
@@ -391,7 +392,11 @@ class ParallelAttention(MegatronModule):
         # This is actually dropping out entire tokens to attend to, which might
         # seem a bit unusual, but is taken from the original Transformer paper.
-        with mpu.get_cuda_rng_tracker().fork():
-            attention_probs = self.attention_dropout(attention_probs)
+        if not self.model_parallel_memory_opt:
+            with mpu.get_cuda_rng_tracker().fork():
+                attention_probs = self.attention_dropout(attention_probs)
+        else:
+            attention_probs = self.attention_dropout(attention_probs)

         # =========================
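Context on the guard above: mpu.get_cuda_rng_tracker().fork() runs the dropout under a separately tracked CUDA RNG state, so each tensor-parallel rank draws the dropout mask it is supposed to while the default RNG stream is left untouched. Below is a minimal sketch of that fork-style bookkeeping; it is an assumption about what such a tracker does, not Megatron's implementation, and the class and method names are illustrative.

```python
from contextlib import contextmanager
import torch

class SimpleRngTracker:
    """Keeps named CUDA RNG states separate from the default generator (illustrative)."""

    def __init__(self):
        self.states = {}

    def add(self, name, seed):
        # Record a CUDA RNG state seeded independently of the default stream.
        orig_state = torch.cuda.get_rng_state()
        torch.cuda.manual_seed(seed)
        self.states[name] = torch.cuda.get_rng_state()
        torch.cuda.set_rng_state(orig_state)

    @contextmanager
    def fork(self, name='model-parallel-rng'):
        # Swap the tracked state in, run the wrapped region (e.g. dropout),
        # save its advanced state, then restore the default state afterwards.
        orig_state = torch.cuda.get_rng_state()
        torch.cuda.set_rng_state(self.states[name])
        try:
            yield
        finally:
            self.states[name] = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(orig_state)

# Usage sketch (assumes a CUDA device):
#   tracker = SimpleRngTracker(); tracker.add('model-parallel-rng', 1234)
#   with tracker.fork():
#       y = torch.nn.functional.dropout(x, p=0.1)
```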
@@ -865,32 +870,51 @@ class ParallelTransformer(MegatronModule):
         if self.model_parallel_memory_opt:
             encoder_output = mpu.scatter_to_sequence_parallel_region(encoder_output)

-        # Forward pass.
-        if self.activations_checkpoint_method is not None:
-            hidden_states = self._checkpointed_forward(hidden_states,
-                                                       attention_mask,
-                                                       encoder_output,
-                                                       enc_dec_attn_mask)
-        else:
-            for index in range(self.num_layers):
-                layer = self._get_layer(index)
-                hidden_states = layer(
-                    hidden_states,
-                    attention_mask,
-                    encoder_output=encoder_output,
-                    enc_dec_attn_mask=enc_dec_attn_mask,
-                    inference_params=inference_params)
+        # Forward pass.
+        if self.model_parallel_memory_opt:
+            if self.activations_checkpoint_method is not None:
+                with mpu.get_cuda_rng_tracker().fork():
+                    hidden_states = self._checkpointed_forward(hidden_states,
+                                                               attention_mask,
+                                                               encoder_output,
+                                                               enc_dec_attn_mask)
+            else:
+                for index in range(self.num_layers):
+                    layer = self._get_layer(index)
+                    hidden_states = layer(
+                        hidden_states,
+                        attention_mask,
+                        encoder_output=encoder_output,
+                        enc_dec_attn_mask=enc_dec_attn_mask,
+                        inference_params=inference_params)
+        else:
+            # Forward pass.
+            if self.activations_checkpoint_method is not None:
+                hidden_states = self._checkpointed_forward(hidden_states,
+                                                           attention_mask,
+                                                           encoder_output,
+                                                           enc_dec_attn_mask)
+            else:
+                for index in range(self.num_layers):
+                    layer = self._get_layer(index)
+                    hidden_states = layer(
+                        hidden_states,
+                        attention_mask,
+                        encoder_output=encoder_output,
+                        enc_dec_attn_mask=enc_dec_attn_mask,
+                        inference_params=inference_params)

         # Final layer norm.
         if self.post_process:
             # Reverting data format change [s b h] --> [b s h].
             hidden_states = self.final_layernorm(hidden_states)

             if self.layer_type == LayerType.encoder and \
                     self.model_type == ModelType.encoder_and_decoder and \
                     self.model_parallel_memory_opt:
                 output = hidden_states
             else:
                 if self.model_parallel_memory_opt:
                     hidden_states = mpu.gather_from_sequence_parallel_region(hidden_states)
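The scatter/gather pair used above partitions activations along the sequence (first) dimension across the tensor-model-parallel ranks and reassembles them afterwards. The sketch below shows only the forward shape logic, assuming a [sequence, batch, hidden] layout and an already initialized process group; it is not the mpu implementation (the real versions are autograd functions, so the backward of the scatter is a gather and vice versa), and the function names here are illustrative counterparts of the mpu calls.

```python
import torch
import torch.distributed as dist

def split_along_sequence_dim(x, group=None):
    # Keep only this rank's contiguous slice of the sequence (first) dimension.
    world_size = dist.get_world_size(group=group)
    if world_size == 1:
        return x
    assert x.size(0) % world_size == 0, "sequence length must divide the group size"
    local = x.size(0) // world_size
    rank = dist.get_rank(group=group)
    return x[rank * local:(rank + 1) * local].contiguous()

def gather_along_sequence_dim(x, group=None):
    # All-gather every rank's slice back into the full sequence.
    world_size = dist.get_world_size(group=group)
    if world_size == 1:
        return x
    out = torch.empty((x.size(0) * world_size,) + tuple(x.shape[1:]),
                      dtype=x.dtype, device=x.device)
    torch.distributed._all_gather_base(out, x.contiguous(), group=group)
    return out
```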
megatron/mpu/layers.py
@@ -215,7 +215,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
         ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
         ctx.async_grad_allreduce = async_grad_allreduce
         ctx.model_parallel_memory_opt = model_parallel_memory_opt

         if model_parallel_memory_opt:
             world_size = get_tensor_model_parallel_world_size()
             dim_size = list(input.size())
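The flags stored on ctx above are the standard way a torch.autograd.Function carries non-tensor configuration from forward() to backward(), where they select the gradient communication path. A stripped-down illustration of the pattern follows, with a hypothetical reduce_grad_input flag and a bias-free linear; the real LinearWithGradAccumulationAndAsyncCommunication does considerably more.

```python
import torch

class LinearWithFlag(torch.autograd.Function):
    """Illustrative only: flags saved on ctx in forward() steer backward()."""

    @staticmethod
    def forward(ctx, input_, weight, reduce_grad_input):
        ctx.save_for_backward(input_, weight)
        ctx.reduce_grad_input = reduce_grad_input  # plain attributes are allowed on ctx
        return input_.matmul(weight.t())

    @staticmethod
    def backward(ctx, grad_output):
        input_, weight = ctx.saved_tensors
        grad_input = grad_output.matmul(weight)
        grad_weight = grad_output.reshape(-1, grad_output.size(-1)).t().matmul(
            input_.reshape(-1, input_.size(-1)))
        if ctx.reduce_grad_input and torch.distributed.is_initialized():
            # e.g. sum the input gradient across the tensor-parallel group
            torch.distributed.all_reduce(grad_input)
        return grad_input, grad_weight, None
```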
@@ -487,6 +487,8 @@ class RowParallelLinear(torch.nn.Module):
             self.bias = Parameter(torch.empty(
                 self.output_size, device=torch.cuda.current_device(),
                 dtype=args.params_dtype))
+            setattr(self.bias, 'sequence_parallel', args.model_parallel_memory_opt)

             # Always initialize bias to zero.
             with torch.no_grad():
                 self.bias.zero_()
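The new setattr tags the replicated bias so that later stages of training can find every parameter whose gradient was computed from only a local slice of the sequence and reconcile it across ranks. How the tag is consumed is not shown in this diff; the helper below is a hedged sketch with a hypothetical name.

```python
import torch

def allreduce_sequence_parallel_grads(model, group=None):
    # Sum the gradient of every parameter tagged 'sequence_parallel' across the
    # tensor-parallel group, since each rank computed it from a partial sequence.
    for param in model.parameters():
        if getattr(param, 'sequence_parallel', False) and param.grad is not None:
            torch.distributed.all_reduce(param.grad, group=group)
```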
@@ -496,6 +498,7 @@ class RowParallelLinear(torch.nn.Module):
         self.gradient_accumulation_fusion = args.gradient_accumulation_fusion

     def forward(self, input_):
         # Set up backprop all-reduce.
         if self.input_is_parallel:
megatron/mpu/mappings.py
@@ -67,7 +67,7 @@ def _split_along_first_dim(input_):
     rank = get_tensor_model_parallel_rank()
     dim_offset = rank * (local_dim_size)

-    output = input_[dim_offset:dim_offset+local_dim_size]
+    output = input_[dim_offset:dim_offset+local_dim_size].contiguous()

     return output
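The one-line change above makes the locally kept slice own dense storage before it is handed to collectives such as torch.distributed._all_gather_base, which expect flat, contiguous buffers. A small illustration of what .contiguous() costs when the slice is already dense:

```python
import torch

x = torch.arange(24, dtype=torch.float32).reshape(6, 4)   # pretend [seq, hidden]
chunk = x[2:4]                       # slicing dim 0 of a contiguous tensor
print(chunk.is_contiguous())         # True here, but not guaranteed for every input_
safe_chunk = chunk.contiguous()      # a no-op when already contiguous, a copy otherwise
print(safe_chunk.data_ptr() == chunk.data_ptr())   # True when no copy was needed
```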
@@ -106,33 +106,27 @@ def _gather_along_first_dim(input_):
     dim_size[0] = dim_size[0] * world_size

-    output = torch.empty(dim_size, dtype=input_.dtype,
-                         device=torch.cuda.current_device(),
-                         requires_grad=False)
-    torch.distributed._all_gather_base(output, input_,
+    output = torch.empty(dim_size, dtype=input_.dtype,
+                         device=torch.cuda.current_device())
+    torch.distributed._all_gather_base(output, input_.contiguous(),
                                        group=get_tensor_model_parallel_group())

     return output


 def _reduce_scatter_along_first_dim(input_):
     """Reduce-scatter the input tensor across model parallel group."""
     world_size = get_tensor_model_parallel_world_size()
     # Bypass the function if we are using only 1 GPU.
     if get_tensor_model_parallel_world_size() == 1:
         return input_

     dim_size = list(input_.size())
     assert dim_size[0] % world_size == 0
     dim_size[0] = dim_size[0] // world_size

-    output = torch.empty(dim_size, dtype=input_.dtype,
-                         device=torch.cuda.current_device(),
-                         requires_grad=False)
-    # reduce_scatter
-    torch.distributed._reduce_scatter_base(output, input_,
+    output = torch.empty(dim_size, dtype=input_.dtype,
+                         device=torch.cuda.current_device())
+    torch.distributed._reduce_scatter_base(output, input_.contiguous(),
                                            group=get_tensor_model_parallel_group())

     return output
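For intuition, the two helpers in this file are shape-wise counterparts: the gather concatenates every rank's [s/p, b, h] shard back into [s, b, h], while the reduce-scatter sums the rank-local [s, b, h] tensors and hands each rank its [s/p, b, h] slice. The single-process sketch below illustrates that contract without a real process group; world_size, the shard layout, and the stand-in tensors are assumptions for the example only.

```python
import torch

world_size = 4                                      # assumed tensor-parallel size
full = torch.randn(8 * world_size, 2, 16)           # [seq, batch, hidden]
shards = list(full.chunk(world_size, dim=0))        # one [seq/p, batch, hidden] shard per "rank"

# gather-along-first-dim: every rank contributes its shard and gets the full sequence back
gathered = torch.cat(shards, dim=0)
assert gathered.shape[0] == shards[0].shape[0] * world_size

# reduce-scatter-along-first-dim: sum the rank-local tensors, keep this rank's slice
per_rank_inputs = [full + r for r in range(world_size)]   # stand-ins for rank-local tensors
summed = torch.stack(per_rank_inputs).sum(dim=0)
rank = 1
out = summed.chunk(world_size, dim=0)[rank]
assert out.shape[0] == full.shape[0] // world_size
```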