OpenDAS / Megatron-LM

Commit 356eb36a
Authored May 16, 2022 by Vijay Korthikanti
Parent: aae72ffc

address review comments

Showing 5 changed files with 34 additions and 38 deletions (+34 −38)
megatron/arguments.py               +3  −3
megatron/model/language_model.py    +1  −3
megatron/model/transformer.py       +19 −17
megatron/mpu/mappings.py            +11 −14
megatron/optimizer/__init__.py      +0  −1
megatron/arguments.py  (view file @ 356eb36a)

@@ -302,7 +302,7 @@ def parse_args(extra_args_provider=None, defaults={},
                'selective checkpointing granularity'

     # disable async_tensor_model_parallel_allreduce when
-    # model parallel memory optmization is enabled
+    # model parallel memory optimization is enabled
     if args.sequence_parallel:
         args.async_tensor_model_parallel_allreduce = False

@@ -489,7 +489,7 @@ def _add_training_args(parser):
     group.add_argument('--checkpoint-granularity', type=str, default=None,
                        choices=['full', 'selective'],
-                       help='Checkpoint activatins to allow for training '
+                       help='Checkpoint activations to allow for training '
                        'with larger models, sequences, and batch sizes. '
                        'It is supported at two granularities 1) full: '
                        'whole transformer layer is checkpointed, '

@@ -567,7 +567,7 @@ def _add_training_args(parser):
                        'check persist_ln_hidden_sizes if your hidden '
                        'size is supported.')
     group.add_argument('--sequence-parallel', action='store_true',
-                       help='Enable sequence parallel optmization.')
+                       help='Enable sequence parallel optimization.')
     group.add_argument('--no-gradient-accumulation-fusion',
                        action='store_false',
                        help='Disable fusing gradient accumulation to weight '
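
Note: the hunks above are help-string typo fixes plus the existing rule that --sequence-parallel turns off async_tensor_model_parallel_allreduce. A minimal, hypothetical argparse sketch of that flag interplay (not the repo's parse_args; the async flag and its default are assumed here for illustration):

```python
# Minimal sketch (not the repo's parse_args): how --sequence-parallel is expected
# to force async_tensor_model_parallel_allreduce off, per the hunk above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint-granularity', type=str, default=None,
                    choices=['full', 'selective'],
                    help='Checkpoint activations to allow for training '
                         'with larger models, sequences, and batch sizes.')
parser.add_argument('--sequence-parallel', action='store_true',
                    help='Enable sequence parallel optimization.')

args = parser.parse_args(['--sequence-parallel'])
# Assumed default here; in Megatron this flag is defined elsewhere in arguments.py.
args.async_tensor_model_parallel_allreduce = True

# disable async_tensor_model_parallel_allreduce when
# model parallel memory optimization is enabled
if args.sequence_parallel:
    args.async_tensor_model_parallel_allreduce = False

print(args.async_tensor_model_parallel_allreduce)  # False
```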

megatron/model/language_model.py  (view file @ 356eb36a)

@@ -220,11 +220,9 @@ class Embedding(MegatronModule):
         if self.fp32_residual_connection:
             embeddings = embeddings.float()

-        if self.sequence_parallel:
-            embeddings = mpu.scatter_to_sequence_parallel_region(embeddings)
-
         # Dropout.
         if self.sequence_parallel:
+            embeddings = mpu.scatter_to_sequence_parallel_region(embeddings)
             with mpu.get_cuda_rng_tracker().fork():
                 embeddings = self.embedding_dropout(embeddings)
         else:
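
Note: after this reordering, embeddings are scattered across the sequence dimension only on the sequence-parallel path, immediately before dropout under the tensor-parallel RNG fork. A conceptual, single-process sketch of what a scatter along the sequence dimension does (the real mpu.scatter_to_sequence_parallel_region distributes shards across tensor-parallel ranks instead of returning them locally):

```python
# Conceptual, single-process sketch of scattering [s, b, h] activations along
# the sequence dimension; each rank would keep only its own slice.
import torch

def scatter_along_sequence_dim(embeddings: torch.Tensor, world_size: int, rank: int):
    s = embeddings.size(0)
    assert s % world_size == 0, \
        "sequence length must be divisible by tensor parallel size"
    local_s = s // world_size
    offset = rank * local_s
    return embeddings[offset:offset + local_s].contiguous()

s, b, h = 8, 2, 4
embeddings = torch.randn(s, b, h)
world_size = 2
shards = [scatter_along_sequence_dim(embeddings, world_size, r) for r in range(world_size)]
assert all(shard.shape == (s // world_size, b, h) for shard in shards)
assert torch.equal(torch.cat(shards, dim=0), embeddings)
```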

megatron/model/transformer.py  (view file @ 356eb36a)

@@ -130,21 +130,21 @@ class SwitchMLP(MegatronModule):
             self.experts.append(ParallelMLP(init_method, output_layer_init_method))

     def forward(self, hidden_states):
-        # hidden_states: [b, s, h]
-        b = hidden_states.size(0)
-        s = hidden_states.size(1)
+        # hidden_states: [s, b, h]
+        s = hidden_states.size(0)
+        b = hidden_states.size(1)
         h = hidden_states.size(2)
         route = self.router(hidden_states)
         route = torch.nn.functional.softmax(route, dim=2)
         max_prob, max_ind = torch.max(route, dim=2)
-        max_prob = torch.unsqueeze(max_prob, 2) # [b s 1]
+        max_prob = torch.unsqueeze(max_prob, 2) # [s b 1]

         # TODO (rprenger) TODO this could be made easier to read
-        # Converting [b, s, h] to [b*s, h].
+        # Converting [s, b, h] to [s*b, h].
         # Each vector could be routed differently
-        hidden_states = hidden_states.view(-1, hidden_states.size(2)) # [b*s h]
-        max_prob = max_prob.view(-1, max_prob.size(2)) # [b*s 1]
-        max_ind = max_ind.view(-1) # [b*s]
+        hidden_states = hidden_states.view(-1, hidden_states.size(2)) # [s*b h]
+        max_prob = max_prob.view(-1, max_prob.size(2)) # [s*b 1]
+        max_ind = max_ind.view(-1) # [s*b]

         output_total = torch.empty_like(hidden_states)
         output_bias_total = torch.empty_like(hidden_states)

@@ -160,14 +160,14 @@ class SwitchMLP(MegatronModule):
         output_total = output_total*max_prob
         output_bias_total = output_bias_total*max_prob
-        output_total = output_total.view(b, s, h)
-        output_bias_total = output_bias_total.view(b, s, h)
+        output_total = output_total.view(s, b, h)
+        output_bias_total = output_bias_total.view(s, b, h)

         return output_total, output_bias_total


 class CoreAttention(MegatronModule):

-    matmul_input = None
+    matmul_input_buffer = None

     def __init__(self, layer_number,
                  attn_mask_type=AttnMaskType.padding):
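
Note: these SwitchMLP hunks only switch the assumed activation layout from [b, s, h] to [s, b, h]; the routing logic itself is unchanged. A minimal sketch of the reshape-route-reshape pattern with a toy router and toy experts (the real class dispatches tokens to ParallelMLP experts):

```python
# Minimal sketch of the SwitchMLP reshape pattern with an [s, b, h] input,
# using a toy router/expert; the real class uses ParallelMLP experts.
import torch

s, b, h, num_experts = 4, 2, 8, 2
hidden_states = torch.randn(s, b, h)          # [s, b, h]
router = torch.nn.Linear(h, num_experts)
experts = torch.nn.ModuleList(torch.nn.Linear(h, h) for _ in range(num_experts))

route = torch.nn.functional.softmax(router(hidden_states), dim=2)
max_prob, max_ind = torch.max(route, dim=2)
max_prob = torch.unsqueeze(max_prob, 2)       # [s, b, 1]

hidden_flat = hidden_states.view(-1, h)       # [s*b, h]
max_prob = max_prob.view(-1, 1)               # [s*b, 1]
max_ind = max_ind.view(-1)                    # [s*b]

output_total = torch.empty_like(hidden_flat)
for expert_idx, expert in enumerate(experts):
    rows = (max_ind == expert_idx).nonzero(as_tuple=True)[0]
    if rows.numel() > 0:
        output_total[rows] = expert(hidden_flat[rows])

output_total = (output_total * max_prob).view(s, b, h)
assert output_total.shape == (s, b, h)
```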

@@ -235,8 +235,8 @@ class CoreAttention(MegatronModule):
             output_size[0] * output_size[1], -1)

         # preallocting input tensor: [b * np, sq, sk]
-        if CoreAttention.matmul_input is None:
-            CoreAttention.matmul_input = torch.empty(
+        if CoreAttention.matmul_input_buffer is None:
+            CoreAttention.matmul_input_buffer = torch.empty(
                 output_size[0]*output_size[1],
                 output_size[2],
                 output_size[3],

@@ -245,7 +245,7 @@ class CoreAttention(MegatronModule):
         # Raw attention scores. [b * np, sq, sk]
         matmul_result = torch.baddbmm(
-            CoreAttention.matmul_input,
+            CoreAttention.matmul_input_buffer,
             query_layer.transpose(0, 1),   # [b * np, sq, hn]
             key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
             beta=0.0, alpha=(1.0/self.norm_factor))

@@ -311,7 +311,7 @@ class CoreAttention(MegatronModule):
 class ParallelAttention(MegatronModule):
     """Parallel self-attention layer abstract class.

-    Self-attention layer takes input with size [b, s, h]
+    Self-attention layer takes input with size [s, b, h]
     and returns output of the same size.
     """
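
Note: the rename to matmul_input_buffer only clarifies that the class attribute is a lazily allocated scratch tensor passed to torch.baddbmm as its input operand (ignored numerically since beta=0.0). A standalone sketch of that pattern, not the full CoreAttention; the shape re-check is an addition here, the hunk above only checks for None:

```python
# Sketch of the preallocated-buffer pattern used with torch.baddbmm for raw
# attention scores; a standalone toy, not the full CoreAttention module.
import torch

class ToyAttentionScores:
    matmul_input_buffer = None  # shared scratch tensor, [b*np, sq, sk]

    def __init__(self, norm_factor: float):
        self.norm_factor = norm_factor

    def scores(self, query_layer, key_layer):
        # query_layer: [sq, b*np, hn], key_layer: [sk, b*np, hn]
        sq, bnp, hn = query_layer.shape
        sk = key_layer.size(0)
        if (ToyAttentionScores.matmul_input_buffer is None or
                ToyAttentionScores.matmul_input_buffer.shape != (bnp, sq, sk)):
            ToyAttentionScores.matmul_input_buffer = torch.empty(
                bnp, sq, sk, dtype=query_layer.dtype, device=query_layer.device)
        # beta=0.0 means the buffer's contents do not affect the result.
        return torch.baddbmm(
            ToyAttentionScores.matmul_input_buffer,
            query_layer.transpose(0, 1),                  # [b*np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),    # [b*np, hn, sk]
            beta=0.0, alpha=1.0 / self.norm_factor)

attn = ToyAttentionScores(norm_factor=8.0)
q = torch.randn(5, 6, 64)   # [sq, b*np, hn]
k = torch.randn(7, 6, 64)   # [sk, b*np, hn]
print(attn.scores(q, k).shape)  # torch.Size([6, 5, 7])
```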

@@ -529,7 +529,7 @@ def bias_dropout_add_fused_inference(x: torch.Tensor,
 class ParallelTransformerLayer(MegatronModule):
     """A single transformer layer.

-    Transformer layer takes input with size [b, s, h] and returns an
+    Transformer layer takes input with size [s, b, h] and returns an
     output of the same size.
     """

@@ -603,7 +603,7 @@ class ParallelTransformerLayer(MegatronModule):
     def forward(self, hidden_states, attention_mask,
                 encoder_output=None, enc_dec_attn_mask=None,
                 inference_params=None):
-        # hidden_states: [b, s, h]
+        # hidden_states: [s, b, h]

         # Layer norm at the beginning of the transformer layer.
         layernorm_output = self.input_layernorm(hidden_states)

@@ -882,6 +882,8 @@ class ParallelTransformer(MegatronModule):
     def forward(self, hidden_states, attention_mask,
                 encoder_output=None, enc_dec_attn_mask=None,
                 inference_params=None):
+        # hidden_states: [s, b, h]
+
         # Checks.
         if inference_params:
             assert self.checkpoint_granularity is None, \
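
Note: the remaining transformer.py hunks only update comments and docstrings to the [s, b, h] convention. With the sequence dimension first, a per-rank sequence shard is a contiguous first-dimension slice, which is what the mappings.py helpers below rely on. A tiny sketch of the layout difference:

```python
# Tiny sketch: with [s, b, h] the per-rank sequence shard is a contiguous
# first-dimension slice; with [b, s, h] the same shard is strided.
import torch

s, b, h, world_size, rank = 8, 2, 4, 2, 0
local_s = s // world_size

x_sbh = torch.randn(s, b, h)                     # [s, b, h]
shard_sbh = x_sbh[rank * local_s:(rank + 1) * local_s]
assert shard_sbh.is_contiguous()

x_bsh = x_sbh.transpose(0, 1).contiguous()       # [b, s, h]
shard_bsh = x_bsh[:, rank * local_s:(rank + 1) * local_s]
assert not shard_bsh.is_contiguous()
```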

megatron/mpu/mappings.py  (view file @ 356eb36a)

@@ -38,7 +38,7 @@ def _split_along_last_dim(input_):
     world_size = get_tensor_model_parallel_world_size()
     # Bypass the function if we are using only 1 GPU.
     if world_size == 1:
         return input_

     # Split along last dimension.
     ...

@@ -57,15 +57,16 @@ def _split_along_first_dim(input_):
     world_size = get_tensor_model_parallel_world_size()
     # Bypass the function if we are using only 1 GPU.
     if world_size == 1:
         return input_

     # Split along first dimension.
     dim_size = input_.size()[0]
-    assert dim_size % world_size == 0
+    assert dim_size % world_size == 0, \
+        "First dimension of the tensor should be divisible by tensor parallel size"
     local_dim_size = dim_size // world_size
     rank = get_tensor_model_parallel_rank()
-    dim_offset = rank * (local_dim_size)
+    dim_offset = rank * local_dim_size

     output = input_[dim_offset:dim_offset+local_dim_size].contiguous()

@@ -77,7 +78,7 @@ def _gather_along_last_dim(input_):
     world_size = get_tensor_model_parallel_world_size()
     # Bypass the function if we are using only 1 GPU.
     if world_size == 1:
         return input_

     # Size and dimension.
     ...

@@ -99,7 +100,7 @@ def _gather_along_first_dim(input_):
     world_size = get_tensor_model_parallel_world_size()
     # Bypass the function if we are using only 1 GPU.
     if world_size == 1:
         return input_

     dim_size = list(input_.size())
     ...

@@ -116,11 +117,13 @@ def _reduce_scatter_along_first_dim(input_):
     """Reduce-scatter the input tensor across model parallel group."""
     world_size = get_tensor_model_parallel_world_size()
     # Bypass the function if we are using only 1 GPU.
-    if get_tensor_model_parallel_world_size() == 1:
+    if world_size == 1:
         return input_

     dim_size = list(input_.size())
-    assert dim_size[0] % world_size == 0
+    assert dim_size[0] % world_size == 0, \
+        "First dimension of the tensor should be divisible by tensor parallel size"
     dim_size[0] = dim_size[0] // world_size

     output = torch.empty(dim_size, dtype=input_.dtype,
     ...

@@ -130,12 +133,6 @@ def _reduce_scatter_along_first_dim(input_):
     return output


-def _reduce_scatter_along_last_dim(input_):
-    output = _reduce(input_)
-    output = _split_along_last_dim(output)
-    return output
-
-
 class _CopyToModelParallelRegion(torch.autograd.Function):
     """Pass the input to the model parallel region."""
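
Note: the mappings.py hunks add divisibility messages to the first-dimension split and reduce-scatter asserts, reuse the already-computed world_size, and remove the _reduce_scatter_along_last_dim helper. A conceptual, single-process simulation of a reduce-scatter along the first dimension (the real helper uses torch.distributed collectives over the tensor-model-parallel group):

```python
# Single-process simulation of reduce-scatter along the first dimension:
# sum the per-rank inputs element-wise, then give each rank one first-dim
# shard of the sum. The real _reduce_scatter_along_first_dim does this with
# torch.distributed over the tensor-model-parallel group.
import torch

def reduce_scatter_along_first_dim(per_rank_inputs):
    world_size = len(per_rank_inputs)
    dim_size = list(per_rank_inputs[0].size())
    assert dim_size[0] % world_size == 0, \
        "First dimension of the tensor should be divisible by tensor parallel size"
    reduced = torch.stack(per_rank_inputs).sum(dim=0)
    local = dim_size[0] // world_size
    return [reduced[r * local:(r + 1) * local].contiguous() for r in range(world_size)]

world_size, s, h = 2, 8, 4
inputs = [torch.randn(s, h) for _ in range(world_size)]
outputs = reduce_scatter_along_first_dim(inputs)
assert all(out.shape == (s // world_size, h) for out in outputs)
assert torch.allclose(torch.cat(outputs, dim=0), inputs[0] + inputs[1])
```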

megatron/optimizer/__init__.py  (view file @ 356eb36a)

@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import torch
 from apex.optimizers import FusedAdam as Adam
 from apex.optimizers import FusedSGD as SGD