OpenDAS / Megatron-LM / Commits / 356eb36a

Commit 356eb36a authored May 16, 2022 by Vijay Korthikanti

address review comments

parent aae72ffc

Showing 5 changed files with 34 additions and 38 deletions (+34 -38)
megatron/arguments.py             +3  -3
megatron/model/language_model.py  +1  -3
megatron/model/transformer.py     +19 -17
megatron/mpu/mappings.py          +11 -14
megatron/optimizer/__init__.py    +0  -1
megatron/arguments.py

@@ -302,7 +302,7 @@ def parse_args(extra_args_provider=None, defaults={},
         'selective checkpointing granularity'

     # disable async_tensor_model_parallel_allreduce when
-    # model parallel memory optmization is enabled
+    # model parallel memory optimization is enabled
     if args.sequence_parallel:
         args.async_tensor_model_parallel_allreduce = False
@@ -489,7 +489,7 @@ def _add_training_args(parser):
     group.add_argument('--checkpoint-granularity', type=str, default=None,
                        choices=['full', 'selective'],
-                       help='Checkpoint activatins to allow for training '
+                       help='Checkpoint activations to allow for training '
                        'with larger models, sequences, and batch sizes. '
                        'It is supported at two granularities 1) full: '
                        'whole transformer layer is checkpointed, '
@@ -567,7 +567,7 @@ def _add_training_args(parser):
                        'check persist_ln_hidden_sizes if your hidden '
                        'size is supported.')
     group.add_argument('--sequence-parallel', action='store_true',
-                       help='Enable sequence parallel optmization.')
+                       help='Enable sequence parallel optimization.')
     group.add_argument('--no-gradient-accumulation-fusion',
                        action='store_false',
                        help='Disable fusing gradient accumulation to weight '
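A minimal standalone sketch (plain argparse, not Megatron's full parser; the async_tensor_model_parallel_allreduce default below is an assumption, since the real parser defines it via its own flag) of how the two flags touched in this file interact with the check in the first hunk: enabling --sequence-parallel switches the async tensor-model-parallel all-reduce off after parsing.

import argparse

def parse_args_sketch(argv=None):
    parser = argparse.ArgumentParser(description='Megatron-style flags (sketch)')
    group = parser.add_argument_group(title='training')
    group.add_argument('--checkpoint-granularity', type=str, default=None,
                       choices=['full', 'selective'],
                       help='Checkpoint activations to allow for training '
                       'with larger models, sequences, and batch sizes.')
    group.add_argument('--sequence-parallel', action='store_true',
                       help='Enable sequence parallel optimization.')
    args = parser.parse_args(argv)

    # Assumed default for this sketch only.
    args.async_tensor_model_parallel_allreduce = True
    # Disable async_tensor_model_parallel_allreduce when the sequence
    # parallel memory optimization is enabled (mirrors the hunk at line 302).
    if args.sequence_parallel:
        args.async_tensor_model_parallel_allreduce = False
    return args

if __name__ == '__main__':
    print(parse_args_sketch(['--sequence-parallel',
                             '--checkpoint-granularity', 'selective']))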
megatron/model/language_model.py
@@ -220,11 +220,9 @@ class Embedding(MegatronModule):
         if self.fp32_residual_connection:
             embeddings = embeddings.float()

-        if self.sequence_parallel:
-            embeddings = mpu.scatter_to_sequence_parallel_region(embeddings)
-
         # Dropout.
         if self.sequence_parallel:
+            embeddings = mpu.scatter_to_sequence_parallel_region(embeddings)
             with mpu.get_cuda_rng_tracker().fork():
                 embeddings = self.embedding_dropout(embeddings)
         else:
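The scatter now happens inside the sequence-parallel branch, immediately before dropout runs under the forked CUDA RNG. Below is a hedged single-process illustration of what mpu.scatter_to_sequence_parallel_region does to the [s, b, h] embeddings; the _sketch function, the explicit world_size/rank loop, and the plain torch Dropout are stand-ins for illustration, not the mpu implementation.

import torch

def scatter_to_sequence_parallel_region_sketch(embeddings, world_size, rank):
    # Each tensor-parallel rank keeps only its s/world_size slice of the
    # sequence dimension (dim 0 of an [s, b, h] tensor).
    s = embeddings.size(0)
    assert s % world_size == 0, "sequence length must divide evenly"
    local_s = s // world_size
    return embeddings[rank * local_s:(rank + 1) * local_s].contiguous()

s, b, h, world_size = 8, 2, 4, 4
embeddings = torch.randn(s, b, h)
dropout = torch.nn.Dropout(p=0.1)
for rank in range(world_size):
    local = scatter_to_sequence_parallel_region_sketch(embeddings, world_size, rank)
    local = dropout(local)  # in Megatron this runs under the forked CUDA RNG tracker
    assert local.shape == (s // world_size, b, h)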
megatron/model/transformer.py
@@ -130,21 +130,21 @@ class SwitchMLP(MegatronModule):
             self.experts.append(ParallelMLP(init_method, output_layer_init_method))

     def forward(self, hidden_states):
-        # hidden_states: [b, s, h]
-        b = hidden_states.size(0)
-        s = hidden_states.size(1)
+        # hidden_states: [s, b, h]
+        s = hidden_states.size(0)
+        b = hidden_states.size(1)
         h = hidden_states.size(2)
         route = self.router(hidden_states)
         route = torch.nn.functional.softmax(route, dim=2)
         max_prob, max_ind = torch.max(route, dim=2)
-        max_prob = torch.unsqueeze(max_prob, 2)  # [b s 1]
+        max_prob = torch.unsqueeze(max_prob, 2)  # [s b 1]

         # TODO (rprenger) TODO this could be made easier to read
-        # Converting [b, s, h] to [b*s, h].
+        # Converting [s, b, h] to [s*b, h].
         # Each vector could be routed differently
-        hidden_states = hidden_states.view(-1, hidden_states.size(2))  # [b*s h]
-        max_prob = max_prob.view(-1, max_prob.size(2))  # [b*s 1]
-        max_ind = max_ind.view(-1)  # [b*s]
+        hidden_states = hidden_states.view(-1, hidden_states.size(2))  # [s*b h]
+        max_prob = max_prob.view(-1, max_prob.size(2))  # [s*b 1]
+        max_ind = max_ind.view(-1)  # [s*b]

         output_total = torch.empty_like(hidden_states)
         output_bias_total = torch.empty_like(hidden_states)
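This hunk only corrects the layout in the comments and shape bookkeeping from [b, s, h] to [s, b, h]; the flatten itself is unchanged because view(-1, h) merges the two leading dimensions either way. A standalone sketch with hypothetical sizes of the top-1 routing flatten in the [s, b, h] layout:

import torch

s, b, h, num_experts = 6, 2, 8, 4
hidden_states = torch.randn(s, b, h)                                # [s, b, h]
router = torch.nn.Linear(h, num_experts)

route = torch.nn.functional.softmax(router(hidden_states), dim=2)   # [s, b, e]
max_prob, max_ind = torch.max(route, dim=2)                         # [s, b]
max_prob = torch.unsqueeze(max_prob, 2)                             # [s, b, 1]

flat_hidden = hidden_states.view(-1, h)                             # [s*b, h]
flat_prob = max_prob.view(-1, 1)                                    # [s*b, 1]
flat_ind = max_ind.view(-1)                                         # [s*b]

# Token i would be dispatched to expert flat_ind[i]; here we just scale by
# the gate value and restore the [s, b, h] layout, as the final view now does.
output = (flat_hidden * flat_prob).view(s, b, h)
assert output.shape == (s, b, h)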
@@ -160,14 +160,14 @@ class SwitchMLP(MegatronModule):
         output_total = output_total*max_prob
         output_bias_total = output_bias_total*max_prob
-        output_total = output_total.view(b, s, h)
-        output_bias_total = output_bias_total.view(b, s, h)
+        output_total = output_total.view(s, b, h)
+        output_bias_total = output_bias_total.view(s, b, h)

         return output_total, output_bias_total


 class CoreAttention(MegatronModule):
-    matmul_input = None
+    matmul_input_buffer = None

     def __init__(self, layer_number,
                  attn_mask_type=AttnMaskType.padding):
@@ -235,8 +235,8 @@ class CoreAttention(MegatronModule):
             output_size[0] * output_size[1], -1)

         # preallocting input tensor: [b * np, sq, sk]
-        if CoreAttention.matmul_input is None:
-            CoreAttention.matmul_input = torch.empty(
+        if CoreAttention.matmul_input_buffer is None:
+            CoreAttention.matmul_input_buffer = torch.empty(
                 output_size[0]*output_size[1],
                 output_size[2],
                 output_size[3],
@@ -245,7 +245,7 @@ class CoreAttention(MegatronModule):
         # Raw attention scores. [b * np, sq, sk]
         matmul_result = torch.baddbmm(
-            CoreAttention.matmul_input,
+            CoreAttention.matmul_input_buffer,
             query_layer.transpose(0, 1),   # [b * np, sq, hn]
             key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
             beta=0.0, alpha=(1.0/self.norm_factor))
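Both hunks rename the class-level scratch tensor from matmul_input to matmul_input_buffer. The pattern: allocate the buffer once per process and pass it as the input argument of torch.baddbmm with beta=0.0, so only its shape, dtype, and device matter and its (stale) contents are never read into the result. A hedged sketch of that pattern; CoreAttentionSketch, the pre-transposed shapes, and the norm factor are illustrative, not Megatron's code.

import torch

class CoreAttentionSketch:
    # Allocated lazily on first use and reused across calls.
    matmul_input_buffer = None

    def scores(self, query_layer, key_layer, norm_factor=8.0):
        # Assumed shapes: query_layer [b*np, sq, hn], key_layer [b*np, sk, hn].
        bnp, sq, hn = query_layer.shape
        sk = key_layer.shape[1]
        if CoreAttentionSketch.matmul_input_buffer is None:
            CoreAttentionSketch.matmul_input_buffer = torch.empty(
                bnp, sq, sk, dtype=query_layer.dtype, device=query_layer.device)
        # beta=0.0 ignores the buffer's values: result = alpha * Q @ K^T.
        return torch.baddbmm(
            CoreAttentionSketch.matmul_input_buffer,
            query_layer,
            key_layer.transpose(1, 2),
            beta=0.0, alpha=1.0 / norm_factor)

q = torch.randn(4, 5, 16)
k = torch.randn(4, 7, 16)
print(CoreAttentionSketch().scores(q, k).shape)  # torch.Size([4, 5, 7])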
@@ -311,7 +311,7 @@ class CoreAttention(MegatronModule):

 class ParallelAttention(MegatronModule):
     """Parallel self-attention layer abstract class.

-    Self-attention layer takes input with size [b, s, h]
+    Self-attention layer takes input with size [s, b, h]
     and returns output of the same size.
     """
@@ -529,7 +529,7 @@ def bias_dropout_add_fused_inference(x: torch.Tensor,

 class ParallelTransformerLayer(MegatronModule):
     """A single transformer layer.

-    Transformer layer takes input with size [b, s, h] and returns an
+    Transformer layer takes input with size [s, b, h] and returns an
     output of the same size.
     """
@@ -603,7 +603,7 @@ class ParallelTransformerLayer(MegatronModule):

     def forward(self, hidden_states, attention_mask,
                 encoder_output=None, enc_dec_attn_mask=None,
                 inference_params=None):
-        # hidden_states: [b, s, h]
+        # hidden_states: [s, b, h]

         # Layer norm at the beginning of the transformer layer.
         layernorm_output = self.input_layernorm(hidden_states)
@@ -882,6 +882,8 @@ class ParallelTransformer(MegatronModule):

     def forward(self, hidden_states, attention_mask,
                 encoder_output=None, enc_dec_attn_mask=None,
                 inference_params=None):
+        # hidden_states: [s, b, h]
+
         # Checks.
         if inference_params:
             assert self.checkpoint_granularity is None, \
megatron/mpu/mappings.py
@@ -38,7 +38,7 @@ def _split_along_last_dim(input_):
     world_size = get_tensor_model_parallel_world_size()
     # Bypass the function if we are using only 1 GPU.
-    if world_size==1:
+    if world_size == 1:
         return input_

     # Split along last dimension.
@@ -57,15 +57,16 @@ def _split_along_first_dim(input_):
     world_size = get_tensor_model_parallel_world_size()
     # Bypass the function if we are using only 1 GPU.
-    if world_size==1:
+    if world_size == 1:
         return input_

     # Split along first dimension.
     dim_size = input_.size()[0]
-    assert dim_size % world_size == 0
+    assert dim_size % world_size == 0, \
+        "First dimension of the tensor should be divisible by tensor parallel size"
     local_dim_size = dim_size // world_size
     rank = get_tensor_model_parallel_rank()
-    dim_offset = rank * (local_dim_size)
+    dim_offset = rank * local_dim_size

     output = input_[dim_offset:dim_offset+local_dim_size].contiguous()
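A single-process emulation of _split_along_first_dim with an assumed world_size and rank: each rank takes one contiguous slice of the first dimension, and the new assert message fires when the dimension does not divide evenly.

import torch

def split_along_first_dim_sketch(input_, world_size, rank):
    if world_size == 1:
        return input_
    dim_size = input_.size()[0]
    assert dim_size % world_size == 0, \
        "First dimension of the tensor should be divisible by tensor parallel size"
    local_dim_size = dim_size // world_size
    dim_offset = rank * local_dim_size
    return input_[dim_offset:dim_offset + local_dim_size].contiguous()

x = torch.arange(8).view(8, 1)
chunks = [split_along_first_dim_sketch(x, 4, r) for r in range(4)]
# Concatenating every rank's slice recovers the original tensor.
assert torch.equal(torch.cat(chunks, dim=0), x)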
@@ -77,7 +78,7 @@ def _gather_along_last_dim(input_):
     world_size = get_tensor_model_parallel_world_size()
     # Bypass the function if we are using only 1 GPU.
-    if world_size==1:
+    if world_size == 1:
         return input_

     # Size and dimension.
@@ -99,7 +100,7 @@ def _gather_along_first_dim(input_):
     world_size = get_tensor_model_parallel_world_size()
     # Bypass the function if we are using only 1 GPU.
-    if world_size==1:
+    if world_size == 1:
         return input_

     dim_size = list(input_.size())
@@ -116,11 +117,13 @@ def _reduce_scatter_along_first_dim(input_):
     """Reduce-scatter the input tensor across model parallel group."""
     world_size = get_tensor_model_parallel_world_size()
     # Bypass the function if we are using only 1 GPU.
-    if get_tensor_model_parallel_world_size() == 1:
+    if world_size == 1:
         return input_

     dim_size = list(input_.size())
-    assert dim_size[0] % world_size == 0
+    assert dim_size[0] % world_size == 0, \
+        "First dimension of the tensor should be divisible by tensor parallel size"
     dim_size[0] = dim_size[0] // world_size

     output = torch.empty(dim_size, dtype=input_.dtype,
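A single-process emulation of the reduce-scatter this function wraps: the per-rank inputs are summed elementwise and each rank keeps one first-dim slice of the sum, which is exactly what the dim_size[0] // world_size bookkeeping above sets up. The real code issues a collective over the tensor-model-parallel group; this sketch only mimics the math and the shape handling.

import torch

def reduce_scatter_along_first_dim_sketch(inputs_per_rank):
    world_size = len(inputs_per_rank)
    if world_size == 1:
        return inputs_per_rank
    dim_size = list(inputs_per_rank[0].size())
    assert dim_size[0] % world_size == 0, \
        "First dimension of the tensor should be divisible by tensor parallel size"
    dim_size[0] = dim_size[0] // world_size
    # Reduce: elementwise sum across ranks; scatter: split the sum along dim 0.
    summed = torch.stack(inputs_per_rank).sum(dim=0)
    return [chunk.contiguous() for chunk in torch.split(summed, dim_size[0], dim=0)]

inputs = [torch.ones(8, 2, 4) * (r + 1) for r in range(4)]
outputs = reduce_scatter_along_first_dim_sketch(inputs)
assert outputs[0].shape == (2, 2, 4) and torch.all(outputs[0] == 10.0)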
@@ -130,12 +133,6 @@ def _reduce_scatter_along_first_dim(input_):
     return output


-def _reduce_scatter_along_last_dim(input_):
-    output = _reduce(input_)
-    output = _split_along_last_dim(output)
-    return output
-
-
 class _CopyToModelParallelRegion(torch.autograd.Function):
     """Pass the input to the model parallel region."""
megatron/optimizer/__init__.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import torch
 from apex.optimizers import FusedAdam as Adam
 from apex.optimizers import FusedSGD as SGD