OpenDAS / Megatron-LM · Commits

Commit 5942af97, authored Sep 23, 2022 by Jared Casper
Parent: c2ea914f

    Alias core.parallel_state as mpu and use it throughout code. RIP mpu.
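The hunks below only update call sites; the alias itself (inside megatron/core) is not shown on this page. As a rough sketch of what the commit message describes, namely aliasing core.parallel_state as mpu so existing mpu.* call sites keep working, megatron/core/__init__.py could contain something like the following. The file contents here are assumed for illustration, not taken from this diff:

    # Hypothetical sketch only: the real megatron/core/__init__.py is not part of
    # the hunks shown on this page.
    from megatron.core import parallel_state
    from megatron.core import tensor_parallel
    from megatron.core import utils

    # Alias parallel_state as mpu so callers can write
    # `from megatron.core import mpu` and keep using mpu.get_*() helpers.
    mpu = parallel_state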
Changes: 63 files in this commit. Showing 20 changed files with 82 additions and 126 deletions (+82 -126) on this page.
megatron/model/distributed.py                +1   -1
megatron/model/gpt_model.py                  +3   -4
megatron/model/language_model.py             +9   -9
megatron/model/module.py                     +2   -3
megatron/model/multiple_choice.py            +0   -1
megatron/model/realm_model.py                +1   -1
megatron/model/t5_model.py                   +5   -7
megatron/model/transformer.py                +24  -25
megatron/model/vision/knn_monitor.py         +2   -1
megatron/mpu/__init__.py                     +0   -37
megatron/optimizer/distrib_optimizer.py      +4   -5
megatron/optimizer/optimizer.py              +4   -6
megatron/p2p_communication.py                +12  -12
megatron/schedules.py                        +1   -1
megatron/text_generation/api.py              +1   -1
megatron/text_generation/communication.py    +1   -1
megatron/text_generation/forward_step.py     +2   -3
megatron/text_generation/generation.py       +2   -1
megatron/training.py                         +2   -3
megatron/utils.py                            +6   -4
megatron/model/distributed.py

@@ -8,7 +8,7 @@ import torch
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

 from megatron import get_args
-from megatron import mpu
+from megatron.core import mpu
 from .module import MegatronModule
megatron/model/gpt_model.py

@@ -5,8 +5,7 @@
 import torch

 from megatron import get_args
-from megatron import mpu
-from megatron import core
+from megatron.core import tensor_parallel
 from .module import MegatronModule
 from .enums import AttnMaskType

@@ -34,9 +33,9 @@ def post_language_model_processing(lm_output, labels, logit_weights,
         labels = labels.transpose(0,1).contiguous()
         if fp16_lm_cross_entropy:
             assert output.dtype == torch.half
-            loss = core.tensor_parallel.vocab_parallel_cross_entropy(output, labels)
+            loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels)
         else:
-            loss = core.tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels)
+            loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels)

         # [s b] => [b, s]
         loss = loss.transpose(0,1).contiguous()
megatron/model/language_model.py

@@ -6,7 +6,7 @@ import torch
 import torch.nn.functional as F

 from megatron import get_args
-from megatron import core
+from megatron.core import mpu, tensor_parallel
 from .module import MegatronModule
 from megatron.model.enums import LayerType, AttnMaskType
 from megatron.model.transformer import ParallelTransformer

@@ -22,15 +22,15 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
     if args.async_tensor_model_parallel_allreduce or \
             args.sequence_parallel:
         input_parallel = input_
-        model_parallel = core.get_tensor_model_parallel_world_size() > 1
+        model_parallel = mpu.get_tensor_model_parallel_world_size() > 1
         async_grad_allreduce = args.async_tensor_model_parallel_allreduce and \
             model_parallel and not args.sequence_parallel
     else:
-        input_parallel = core.tensor_parallel.copy_to_tensor_model_parallel_region(input_)
+        input_parallel = tensor_parallel.copy_to_tensor_model_parallel_region(input_)
         async_grad_allreduce = False

     # Matrix multiply.
-    logits_parallel = core.tensor_parallel.linear_with_grad_accumulation_and_async_allreduce(
+    logits_parallel = tensor_parallel.linear_with_grad_accumulation_and_async_allreduce(
         input=input_parallel,
         weight=word_embeddings_weight,
         bias=bias,

@@ -42,7 +42,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
     if parallel_output:
         return logits_parallel

-    return core.tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel)
+    return tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel)


 def get_language_model(num_tokentypes, add_pooler,

@@ -106,7 +106,7 @@ class Pooler(MegatronModule):
         # gather data along sequence dimensions
         # same pooler is run on all tensor parallel nodes
         if self.sequence_parallel:
-            hidden_states = core.tensor_parallel.gather_from_sequence_parallel_region(
+            hidden_states = tensor_parallel.gather_from_sequence_parallel_region(
                 hidden_states,
                 tensor_parallel_output_grad=False)

@@ -146,7 +146,7 @@ class Embedding(MegatronModule):
         args = get_args()

         # Word embeddings (parallel).
-        self.word_embeddings = core.tensor_parallel.VocabParallelEmbedding(
+        self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
             vocab_size, self.hidden_size,
             init_method=self.init_method,
             params_dtype=args.params_dtype,

@@ -229,8 +229,8 @@ class Embedding(MegatronModule):
         # Dropout.
         if self.sequence_parallel:
-            embeddings = core.tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
-            with core.tensor_parallel.get_cuda_rng_tracker().fork():
+            embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
+            with tensor_parallel.get_cuda_rng_tracker().fork():
                 embeddings = self.embedding_dropout(embeddings)
         else:
             embeddings = self.embedding_dropout(embeddings)
megatron/model/module.py

@@ -7,8 +7,7 @@ from torch.autograd import Variable
 from torch.nn.parameter import Parameter

 from megatron import get_args
-from megatron import mpu
-from megatron import core
+from megatron.core import mpu, tensor_parallel


 _FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)

@@ -77,7 +76,7 @@ class MegatronModule(torch.nn.Module):
             self._word_embeddings_for_head_key = 'word_embeddings_for_head'
             # set word_embeddings weights to 0 here, then copy first
             # stage's weights using all_reduce below.
-            self.word_embeddings = core.tensor_parallel.VocabParallelEmbedding(
+            self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
                 args.padded_vocab_size, args.hidden_size,
                 init_method=init_method_normal(args.init_method_std),
                 params_dtype=args.params_dtype,
megatron/model/multiple_choice.py

@@ -5,7 +5,6 @@
 import torch

 from megatron import get_args, print_rank_last
-from megatron import mpu
 from megatron.model.enums import AttnMaskType
 from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
 from megatron.model.language_model import get_language_model
megatron/model/realm_model.py

@@ -5,7 +5,7 @@ from megatron import get_args, print_rank_0
 from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
 from megatron.model import BertModel
 from .module import MegatronModule
-from megatron import mpu
+from megatron.core import mpu
 from megatron.model.enums import AttnMaskType
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal
megatron/model/t5_model.py

@@ -4,10 +4,8 @@
 import torch

-from megatron import (
-    get_args,
-    mpu
-)
+from megatron import get_args
+from megatron.core import tensor_parallel
 from megatron.model.enums import AttnMaskType
 from megatron.model.language_model import parallel_lm_logits, get_language_model
 from megatron.model.transformer import LayerNorm

@@ -151,10 +149,10 @@ class T5Model(MegatronModule):
             lm_labels = lm_labels.transpose(0,1).contiguous()
             if self.fp16_lm_cross_entropy:
                 assert lm_logits.dtype == torch.half
-                lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
+                lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels)
             else:
-                lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
-                                                           lm_labels)
+                lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits.float(),
+                                                                       lm_labels)
             # [s b] => [b s]
             lm_loss = lm_loss.transpose(0,1).contiguous()
             return lm_loss
megatron/model/transformer.py

@@ -6,10 +6,9 @@ from contextlib import nullcontext
 import torch
 import torch.nn.functional as F

-from megatron import get_timers, get_args
-from megatron.core import get_global_memory_buffer
-from megatron import core
+from megatron import get_timers, get_args, core
 from .module import MegatronModule
+from megatron.core import mpu, tensor_parallel
 from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
 from megatron.model import LayerNorm
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax

@@ -79,7 +78,7 @@ class ParallelMLP(MegatronModule):
         # Project to 4h.
-        self.dense_h_to_4h = core.tensor_parallel.ColumnParallelLinear(
+        self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
             args.hidden_size,
             args.ffn_hidden_size,
             gather_output=False,

@@ -96,7 +95,7 @@ class ParallelMLP(MegatronModule):
             self.activation_func = erf_gelu

         # Project back to h.
-        self.dense_4h_to_h = core.tensor_parallel.RowParallelLinear(
+        self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
             args.ffn_hidden_size,
             args.hidden_size,
             input_is_parallel=True,

@@ -189,7 +188,7 @@ class CoreAttention(MegatronModule):
         projection_size = args.kv_channels * args.num_attention_heads

         # Per attention head and per partition values.
-        world_size = core.get_tensor_model_parallel_world_size()
+        world_size = mpu.get_tensor_model_parallel_world_size()
         self.hidden_size_per_partition = core.utils.divide(projection_size, world_size)
         self.hidden_size_per_attention_head = core.utils.divide(

@@ -237,7 +236,7 @@ class CoreAttention(MegatronModule):
             output_size[0] * output_size[1], -1)

         # preallocting input tensor: [b * np, sq, sk]
-        matmul_input_buffer = get_global_memory_buffer().get_tensor(
+        matmul_input_buffer = mpu.get_global_memory_buffer().get_tensor(
             (output_size[0]*output_size[1], output_size[2], output_size[3]),
             query_layer.dtype, "mpu")

@@ -263,7 +262,7 @@ class CoreAttention(MegatronModule):
         # seem a bit unusual, but is taken from the original Transformer paper.

         if not self.sequence_parallel:
-            with core.tensor_parallel.get_cuda_rng_tracker().fork():
+            with tensor_parallel.get_cuda_rng_tracker().fork():
                 attention_probs = self.attention_dropout(attention_probs)
         else:
             attention_probs = self.attention_dropout(attention_probs)

@@ -327,7 +326,7 @@ class ParallelAttention(MegatronModule):
         projection_size = args.kv_channels * args.num_attention_heads

         # Per attention head and per partition values.
-        world_size = core.get_tensor_model_parallel_world_size()
+        world_size = mpu.get_tensor_model_parallel_world_size()
         self.hidden_size_per_attention_head = core.utils.divide(
             projection_size, args.num_attention_heads)
         self.num_attention_heads_per_partition = core.utils.divide(

@@ -335,7 +334,7 @@ class ParallelAttention(MegatronModule):
         # Strided linear layer.
         if attention_type == AttnType.self_attn:
-            self.query_key_value = core.tensor_parallel.ColumnParallelLinear(
+            self.query_key_value = tensor_parallel.ColumnParallelLinear(
                 args.hidden_size,
                 3 * projection_size,
                 gather_output=False,

@@ -344,7 +343,7 @@ class ParallelAttention(MegatronModule):
                 **_args_to_kwargs())
         else:
             assert attention_type == AttnType.cross_attn
-            self.query = core.tensor_parallel.ColumnParallelLinear(
+            self.query = tensor_parallel.ColumnParallelLinear(
                 args.hidden_size,
                 projection_size,
                 gather_output=False,

@@ -353,7 +352,7 @@ class ParallelAttention(MegatronModule):
                 **_args_to_kwargs())

-            self.key_value = core.tensor_parallel.ColumnParallelLinear(
+            self.key_value = tensor_parallel.ColumnParallelLinear(
                 args.hidden_size,
                 2 * projection_size,
                 gather_output=False,

@@ -366,7 +365,7 @@ class ParallelAttention(MegatronModule):
         self.checkpoint_core_attention = args.recompute_granularity == 'selective'

         # Output.
-        self.dense = core.tensor_parallel.RowParallelLinear(
+        self.dense = tensor_parallel.RowParallelLinear(
             projection_size,
             args.hidden_size,
             input_is_parallel=True,

@@ -386,7 +385,7 @@ class ParallelAttention(MegatronModule):
                 value_layer, attention_mask)
             return output_

-        hidden_states = core.tensor_parallel.checkpoint(
+        hidden_states = tensor_parallel.checkpoint(
             custom_forward,
             False, query_layer, key_layer, value_layer, attention_mask)

@@ -439,7 +438,7 @@ class ParallelAttention(MegatronModule):
             # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
             (query_layer,
              key_layer,
-             value_layer) = core.tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3)
+             value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3)
         else:
             # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
             mixed_kv_layer, _ = self.key_value(encoder_output)

@@ -452,7 +451,7 @@ class ParallelAttention(MegatronModule):
             # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
             (key_layer,
-             value_layer) = core.tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)
+             value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)

             # Attention head [sq, b, h] --> [sq, b, hp]
             query_layer, _ = self.query(hidden_states)

@@ -769,7 +768,7 @@ class ParallelTransformer(MegatronModule):
         self.sequence_parallel = args.sequence_parallel

         # Number of layers.
-        self.num_layers = core.get_num_layers(
+        self.num_layers = mpu.get_num_layers(
             args, args.model_type == ModelType.encoder_and_decoder)

         self.drop_path_rates = [rate.item() for rate in torch.linspace(0, self.drop_path_rate, args.num_layers)]

@@ -799,21 +798,21 @@ class ParallelTransformer(MegatronModule):
                 # layers to stages like (each list is a model chunk):
                 # Stage 0: [0, 1] [4, 5]
                 # Stage 1: [2, 3] [6, 7]
-                offset = core.get_virtual_pipeline_model_parallel_rank() * (
+                offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
                     args.num_layers // args.virtual_pipeline_model_parallel_size) + \
-                    (core.get_pipeline_model_parallel_rank() * self.num_layers)
+                    (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
             else:
                 # Each stage gets a contiguous set of layers.
                 if args.model_type == ModelType.encoder_and_decoder and \
-                        core.get_pipeline_model_parallel_world_size() > 1:
-                    pipeline_rank = core.get_pipeline_model_parallel_rank()
+                        mpu.get_pipeline_model_parallel_world_size() > 1:
+                    pipeline_rank = mpu.get_pipeline_model_parallel_rank()
                     if layer_type == LayerType.encoder:
                         offset = pipeline_rank * self.num_layers
                     else:
                         num_ranks_in_enc = args.pipeline_model_parallel_split_rank
                         offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
                 else:
-                    offset = core.get_pipeline_model_parallel_rank() * self.num_layers
+                    offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

         if self.num_layers == 0:
             # When a standalone embedding stage is used (e.g.,

@@ -862,7 +861,7 @@ class ParallelTransformer(MegatronModule):
             # A method to further reduce memory usage reducing checkpoints.
             l = 0
             while l < self.num_layers:
-                hidden_states = core.tensor_parallel.checkpoint(
+                hidden_states = tensor_parallel.checkpoint(
                     custom(l, l + self.recompute_num_layers),
                     self.distribute_saved_activations,
                     hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)

@@ -874,7 +873,7 @@ class ParallelTransformer(MegatronModule):
             # A method fully use the device memory removing redundant re-computation.
             for l in range(self.num_layers):
                 if l < self.recompute_num_layers:
-                    hidden_states = core.tensor_parallel.checkpoint(
+                    hidden_states = tensor_parallel.checkpoint(
                         custom(l, l + 1),
                         self.distribute_saved_activations,
                         hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)

@@ -932,7 +931,7 @@ class ParallelTransformer(MegatronModule):
             )

         if self.sequence_parallel:
-            rng_context = core.tensor_parallel.get_cuda_rng_tracker().fork()
+            rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
         else:
             rng_context = nullcontext()
megatron/model/vision/knn_monitor.py

 import torch.nn.functional as F
 import torch
-from megatron import print_rank_0, get_args, mpu
+from megatron import print_rank_0, get_args
+from megatron.core import mpu
 from megatron.data.vit_dataset import ClassificationTransform
 from megatron.data.image_folder import ImageFolder
megatron/mpu/__init__.py  (deleted, 100644 → 0)

-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-"""Model parallel utility interface."""
-
-from .initialize import is_unitialized
-from .initialize import destroy_model_parallel
-from .initialize import get_data_parallel_group
-from .initialize import get_data_parallel_rank
-from .initialize import get_data_parallel_world_size
-from .initialize import get_embedding_group
-from .initialize import get_position_embedding_group
-from .initialize import get_model_parallel_group
-from .initialize import get_tensor_model_parallel_group
-from .initialize import get_pipeline_model_parallel_group
-from .initialize import get_tensor_model_parallel_rank, set_tensor_model_parallel_rank
-from .initialize import get_pipeline_model_parallel_rank, set_pipeline_model_parallel_rank
-from .initialize import is_pipeline_first_stage, is_pipeline_last_stage
-from .initialize import is_rank_in_embedding_group
-from .initialize import is_rank_in_position_embedding_group
-from .initialize import is_pipeline_stage_before_split, is_pipeline_stage_after_split
-from .initialize import is_pipeline_stage_at_split
-from .initialize import get_num_layers
-from .initialize import get_tensor_model_parallel_src_rank
-from .initialize import get_data_parallel_src_rank
-from .initialize import get_pipeline_model_parallel_first_rank
-from .initialize import get_pipeline_model_parallel_last_rank
-from .initialize import get_pipeline_model_parallel_next_rank
-from .initialize import get_pipeline_model_parallel_prev_rank
-from .initialize import get_tensor_model_parallel_world_size, set_tensor_model_parallel_world_size
-from .initialize import get_pipeline_model_parallel_world_size, set_pipeline_model_parallel_world_size
-from .initialize import get_virtual_pipeline_model_parallel_rank, set_virtual_pipeline_model_parallel_rank
-from .initialize import initialize_model_parallel
-from .initialize import model_parallel_is_initialized
-
-from .utils import divide
-from .utils import split_tensor_along_last_dim
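With megatron/mpu/__init__.py deleted, the names it re-exported are reached through megatron.core instead. The remaining hunks on this page follow the migration pattern sketched below; the snippet is illustrative only (logits and labels stand in for whatever tensors a call site actually has):

    # Before this commit (via the deleted megatron.mpu shim):
    #   from megatron import mpu
    #   world_size = mpu.get_tensor_model_parallel_world_size()
    #   loss = mpu.vocab_parallel_cross_entropy(logits, labels)

    # After this commit:
    from megatron.core import mpu, tensor_parallel

    world_size = mpu.get_tensor_model_parallel_world_size()              # parallel-state queries stay on mpu
    loss = tensor_parallel.vocab_parallel_cross_entropy(logits, labels)  # tensor-parallel ops move to tensor_parallel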
megatron/optimizer/distrib_optimizer.py

@@ -8,10 +8,9 @@ import torch
 from megatron import get_args
 from megatron import get_timers
-from megatron import mpu
 from megatron import print_rank_0
+from megatron.core import mpu, tensor_parallel
 from megatron.model.module import param_is_not_shared
-from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate

 from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper

@@ -290,9 +289,9 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                     shard_model_param = model_param.detach().view(-1) \
                         [param_range.start:param_range.end]
                     shard_main_param = shard_model_param.clone().float()
-                    mpu.copy_tensor_model_parallel_attributes(
+                    tensor_parallel.copy_tensor_model_parallel_attributes(
                         shard_model_param, model_param)
-                    mpu.copy_tensor_model_parallel_attributes(
+                    tensor_parallel.copy_tensor_model_parallel_attributes(
                         shard_main_param, model_param)
                     if hasattr(model_param, 'shared'):
                         shard_model_param.shared = model_param.shared

@@ -309,7 +308,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                         [param_range.start:param_range.end]
                     model_fp32_params_this_group.append(model_param)
                     shard_fp32_params_this_group.append(shard_model_param)
-                    mpu.copy_tensor_model_parallel_attributes(
+                    tensor_parallel.copy_tensor_model_parallel_attributes(
                         shard_model_param, model_param)
                     if hasattr(model_param, 'shared'):
                         shard_model_param.shared = model_param.shared
megatron/optimizer/optimizer.py

@@ -11,13 +11,11 @@ from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

 from megatron import get_timers
-from megatron import mpu
-from megatron import core
 from megatron import print_rank_0
+from megatron.core import mpu, tensor_parallel
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import Float16Module
 from megatron.model.module import param_is_not_shared
-from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate
 from megatron.utils import unwrap_model

 from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32

@@ -103,7 +101,7 @@ class MegatronOptimizer(ABC):
             grad = param.grad
             grad_not_none = grad is not None
             is_not_shared = param_is_not_shared(param)
-            is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
+            is_not_tp_duplicate = tensor_parallel.param_is_not_tensor_parallel_duplicate(param)
             if grad_not_none and is_not_shared and is_not_tp_duplicate:
                 grads_for_norm.append(grad)

@@ -528,8 +526,8 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
                     # Create a copy
                     main_param = param.detach().clone().float()

                    # Copy tensor model parallel attributes.
-                    core.tensor_parallel.copy_tensor_model_parallel_attributes(main_param,
-                                                                               param)
+                    tensor_parallel.copy_tensor_model_parallel_attributes(main_param,
+                                                                          param)
                     if hasattr(param, 'shared'):
                         main_param.shared = param.shared

                     # Replace the optimizer params with the new fp32 copy.
megatron/p2p_communication.py

@@ -4,8 +4,8 @@ from functools import reduce
 import operator

 import torch

-from megatron import get_args
-from megatron import mpu
+from megatron import get_args, core
+from megatron.core import mpu


 def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,

@@ -81,10 +81,10 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     if args.scatter_gather_tensors_in_pipeline and \
             not args.sequence_parallel:
         if tensor_send_next is not None:
-            tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)
+            tensor_send_next = core.tensor_parallel.split_tensor_into_1d_equal_chunks(tensor_send_next)
         if tensor_send_prev is not None:
-            tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev)
+            tensor_send_prev = core.tensor_parallel.split_tensor_into_1d_equal_chunks(tensor_send_prev)

     # Send tensors in both the forward and backward directions as appropriate.
     if args.use_ring_exchange_p2p:

@@ -127,18 +127,18 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     if args.scatter_gather_tensors_in_pipeline and \
             not args.sequence_parallel:
         if recv_prev:
-            tensor_recv_prev = mpu.gather_split_1d_tensor(
+            tensor_recv_prev = core.tensor_parallel.gather_split_1d_tensor(
                 tensor_recv_prev).view(tensor_shape).requires_grad_()
-            tensor_recv_prev = mpu.make_viewless_tensor(tensor_recv_prev,
-                                                        requires_grad=True,
-                                                        keep_graph=False)
+            tensor_recv_prev = core.utils.make_viewless_tensor(tensor_recv_prev,
+                                                               requires_grad=True,
+                                                               keep_graph=False)
         if recv_next:
-            tensor_recv_next = mpu.gather_split_1d_tensor(
+            tensor_recv_next = core.tensor_parallel.gather_split_1d_tensor(
                 tensor_recv_next).view(tensor_shape).requires_grad_()
-            tensor_recv_next = mpu.make_viewless_tensor(tensor_recv_next,
-                                                        requires_grad=True,
-                                                        keep_graph=False)
+            tensor_recv_next = core.utils.make_viewless_tensor(tensor_recv_next,
+                                                               requires_grad=True,
+                                                               keep_graph=False)

     return tensor_recv_prev, tensor_recv_next
megatron/schedules.py

@@ -8,8 +8,8 @@ from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
 from megatron import get_args
 from megatron import get_num_microbatches
 from megatron import get_timers
-from megatron import mpu
 from megatron import p2p_communication
+from megatron.core import mpu
 from megatron.utils import unwrap_model
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import Float16Module
megatron/text_generation/api.py

@@ -5,7 +5,7 @@
 import torch

-from megatron import mpu
+from megatron.core import mpu
 from .communication import broadcast_float_list
 from .generation import (
     generate_tokens_probs_and_return_on_first_stage,
megatron/text_generation/communication.py

@@ -5,7 +5,7 @@
 import torch

-from megatron import mpu
+from megatron.core import mpu
megatron/text_generation/forward_step.py

@@ -6,9 +6,8 @@ from collections.abc import Iterable
 import torch

-from megatron import (
-    get_args,
-    mpu)
+from megatron import get_args
+from megatron.core import mpu
 from .communication import (
     send_to_next_pipeline_rank,
     recv_from_prev_pipeline_rank_)
megatron/text_generation/generation.py

@@ -5,7 +5,8 @@
 import torch
 import torch.nn.functional as F

-from megatron import get_args, get_tokenizer, mpu
+from megatron import get_args, get_tokenizer
+from megatron.core import mpu
 from megatron.utils import get_ltor_masks_and_position_ids
 from .communication import (
     copy_from_last_to_first_pipeline_stage,
megatron/training.py

@@ -19,8 +19,7 @@ from megatron import get_current_global_batch_size
 from megatron import get_num_microbatches
 from megatron import is_last_rank
 from megatron import update_num_microbatches
-from megatron import mpu
-from megatron import core
+from megatron.core import mpu, tensor_parallel
 from megatron import print_rank_0
 from megatron import print_rank_last
 from megatron.checkpointing import load_checkpoint

@@ -258,7 +257,7 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
     # are set for all params so the optimizer can use them.
     for model_module in model:
         for param in model_module.parameters():
-            core.tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param)
+            tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

     # Print number of parameters.
     if mpu.get_data_parallel_rank() == 0:
megatron/utils.py

@@ -10,11 +10,13 @@ from torch.nn.parallel import DistributedDataParallel as torchDDP
 from apex.multi_tensor_apply import multi_tensor_applier
 import amp_C

-from megatron import get_args
-from megatron import get_adlr_autoresume
-from megatron import mpu
-from megatron.model.module import param_is_not_shared
+from megatron import (
+    get_args,
+    get_adlr_autoresume,)
+from megatron.core import mpu
+from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate
+from megatron.model.module import param_is_not_shared


 def unwrap_model(model, module_instances=(torchDDP)):