OpenDAS / Megatron-LM / Commits

Commit 5942af97
Authored Sep 23, 2022 by Jared Casper
Parent: c2ea914f

Alias core.parallel_state as mpu and use it throughout code. RIP mpu.

Changes: 63 files in total. This page shows 20 changed files with 82 additions and 126 deletions (+82 -126).
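The alias named in the commit message lives in megatron/core/__init__.py, which is not among the 20 files shown on this page. Below is a minimal sketch of the aliasing pattern the message describes, assuming the core package already exposes the parallel_state, tensor_parallel, and utils submodules that the diffs below reference; the exact file contents are an assumption, not part of this page.

# megatron/core/__init__.py -- illustrative sketch only, not the file from this commit
from megatron.core import parallel_state
from megatron.core import tensor_parallel
from megatron.core import utils

# Alias parallel_state under its legacy name so call sites can keep writing
# mpu.get_tensor_model_parallel_rank() etc. after the megatron/mpu package is removed.
mpu = parallel_state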
megatron/model/distributed.py             +1  -1
megatron/model/gpt_model.py               +3  -4
megatron/model/language_model.py          +9  -9
megatron/model/module.py                  +2  -3
megatron/model/multiple_choice.py         +0  -1
megatron/model/realm_model.py             +1  -1
megatron/model/t5_model.py                +5  -7
megatron/model/transformer.py             +24 -25
megatron/model/vision/knn_monitor.py      +2  -1
megatron/mpu/__init__.py                  +0  -37
megatron/optimizer/distrib_optimizer.py   +4  -5
megatron/optimizer/optimizer.py           +4  -6
megatron/p2p_communication.py             +12 -12
megatron/schedules.py                     +1  -1
megatron/text_generation/api.py           +1  -1
megatron/text_generation/communication.py +1  -1
megatron/text_generation/forward_step.py  +2  -3
megatron/text_generation/generation.py    +2  -1
megatron/training.py                      +2  -3
megatron/utils.py                         +6  -4
megatron/model/distributed.py

@@ -8,7 +8,7 @@ import torch
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from megatron import get_args
-from megatron import mpu
+from megatron.core import mpu
 from .module import MegatronModule
megatron/model/gpt_model.py

@@ -5,8 +5,7 @@
 import torch
 from megatron import get_args
-from megatron import mpu
-from megatron import core
+from megatron.core import tensor_parallel
 from .module import MegatronModule
 from .enums import AttnMaskType

@@ -34,9 +33,9 @@ def post_language_model_processing(lm_output, labels, logit_weights,
         labels = labels.transpose(0, 1).contiguous()
         if fp16_lm_cross_entropy:
             assert output.dtype == torch.half
-            loss = core.tensor_parallel.vocab_parallel_cross_entropy(output, labels)
+            loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels)
         else:
-            loss = core.tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels)
+            loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels)

         # [s b] => [b, s]
         loss = loss.transpose(0, 1).contiguous()
megatron/model/language_model.py

@@ -6,7 +6,7 @@ import torch
 import torch.nn.functional as F

 from megatron import get_args
-from megatron import core
+from megatron.core import mpu, tensor_parallel
 from .module import MegatronModule
 from megatron.model.enums import LayerType, AttnMaskType
 from megatron.model.transformer import ParallelTransformer

@@ -22,15 +22,15 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
     if args.async_tensor_model_parallel_allreduce or \
             args.sequence_parallel:
         input_parallel = input_
-        model_parallel = core.get_tensor_model_parallel_world_size() > 1
+        model_parallel = mpu.get_tensor_model_parallel_world_size() > 1
         async_grad_allreduce = args.async_tensor_model_parallel_allreduce and \
             model_parallel and not args.sequence_parallel
     else:
-        input_parallel = core.tensor_parallel.copy_to_tensor_model_parallel_region(input_)
+        input_parallel = tensor_parallel.copy_to_tensor_model_parallel_region(input_)
         async_grad_allreduce = False

     # Matrix multiply.
-    logits_parallel = core.tensor_parallel.linear_with_grad_accumulation_and_async_allreduce(
+    logits_parallel = tensor_parallel.linear_with_grad_accumulation_and_async_allreduce(
         input=input_parallel,
         weight=word_embeddings_weight,
         bias=bias,

@@ -42,7 +42,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
     if parallel_output:
         return logits_parallel

-    return core.tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel)
+    return tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel)


 def get_language_model(num_tokentypes, add_pooler,

@@ -106,7 +106,7 @@ class Pooler(MegatronModule):
         # gather data along sequence dimensions
         # same pooler is run on all tensor parallel nodes
         if self.sequence_parallel:
-            hidden_states = core.tensor_parallel.gather_from_sequence_parallel_region(
+            hidden_states = tensor_parallel.gather_from_sequence_parallel_region(
                 hidden_states,
                 tensor_parallel_output_grad=False)

@@ -146,7 +146,7 @@ class Embedding(MegatronModule):
         args = get_args()

         # Word embeddings (parallel).
-        self.word_embeddings = core.tensor_parallel.VocabParallelEmbedding(
+        self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
             vocab_size, self.hidden_size,
             init_method=self.init_method,
             params_dtype=args.params_dtype,

@@ -229,8 +229,8 @@ class Embedding(MegatronModule):
         # Dropout.
         if self.sequence_parallel:
-            embeddings = core.tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
-            with core.tensor_parallel.get_cuda_rng_tracker().fork():
+            embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
+            with tensor_parallel.get_cuda_rng_tracker().fork():
                 embeddings = self.embedding_dropout(embeddings)
         else:
             embeddings = self.embedding_dropout(embeddings)
megatron/model/module.py

@@ -7,8 +7,7 @@ from torch.autograd import Variable
 from torch.nn.parameter import Parameter

 from megatron import get_args
-from megatron import mpu
-from megatron import core
+from megatron.core import mpu, tensor_parallel


 _FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)

@@ -77,7 +76,7 @@ class MegatronModule(torch.nn.Module):
             self._word_embeddings_for_head_key = 'word_embeddings_for_head'
             # set word_embeddings weights to 0 here, then copy first
             # stage's weights using all_reduce below.
-            self.word_embeddings = core.tensor_parallel.VocabParallelEmbedding(
+            self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
                 args.padded_vocab_size, args.hidden_size,
                 init_method=init_method_normal(args.init_method_std),
                 params_dtype=args.params_dtype,
megatron/model/multiple_choice.py

@@ -5,7 +5,6 @@
 import torch

 from megatron import get_args, print_rank_last
-from megatron import mpu
 from megatron.model.enums import AttnMaskType
 from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
 from megatron.model.language_model import get_language_model
megatron/model/realm_model.py

@@ -5,7 +5,7 @@ from megatron import get_args, print_rank_0
 from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
 from megatron.model import BertModel
 from .module import MegatronModule
-from megatron import mpu
+from megatron.core import mpu
 from megatron.model.enums import AttnMaskType
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal
megatron/model/t5_model.py

@@ -4,10 +4,8 @@
 import torch

-from megatron import (
-    get_args,
-    mpu
-)
+from megatron import get_args
+from megatron.core import tensor_parallel
 from megatron.model.enums import AttnMaskType
 from megatron.model.language_model import parallel_lm_logits, get_language_model
 from megatron.model.transformer import LayerNorm

@@ -151,10 +149,10 @@ class T5Model(MegatronModule):
             lm_labels = lm_labels.transpose(0, 1).contiguous()
             if self.fp16_lm_cross_entropy:
                 assert lm_logits.dtype == torch.half
-                lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
+                lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels)
             else:
-                lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
-                                                           lm_labels)
+                lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits.float(),
+                                                                       lm_labels)
             # [s b] => [b s]
             lm_loss = lm_loss.transpose(0, 1).contiguous()
         return lm_loss
megatron/model/transformer.py

@@ -6,10 +6,9 @@ from contextlib import nullcontext
 import torch
 import torch.nn.functional as F

-from megatron import get_timers, get_args
-from megatron.core import get_global_memory_buffer
-from megatron import core
+from megatron import get_timers, get_args, core
 from .module import MegatronModule
+from megatron.core import mpu, tensor_parallel
 from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
 from megatron.model import LayerNorm
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax

@@ -79,7 +78,7 @@ class ParallelMLP(MegatronModule):
         # Project to 4h.
-        self.dense_h_to_4h = core.tensor_parallel.ColumnParallelLinear(
+        self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
             args.hidden_size,
             args.ffn_hidden_size,
             gather_output=False,

@@ -96,7 +95,7 @@ class ParallelMLP(MegatronModule):
             self.activation_func = erf_gelu

         # Project back to h.
-        self.dense_4h_to_h = core.tensor_parallel.RowParallelLinear(
+        self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
             args.ffn_hidden_size,
             args.hidden_size,
             input_is_parallel=True,

@@ -189,7 +188,7 @@ class CoreAttention(MegatronModule):
         projection_size = args.kv_channels * args.num_attention_heads

         # Per attention head and per partition values.
-        world_size = core.get_tensor_model_parallel_world_size()
+        world_size = mpu.get_tensor_model_parallel_world_size()
         self.hidden_size_per_partition = core.utils.divide(projection_size,
                                                            world_size)
         self.hidden_size_per_attention_head = core.utils.divide(

@@ -237,7 +236,7 @@ class CoreAttention(MegatronModule):
             output_size[0] * output_size[1], -1)

         # preallocting input tensor: [b * np, sq, sk]
-        matmul_input_buffer = get_global_memory_buffer().get_tensor(
+        matmul_input_buffer = mpu.get_global_memory_buffer().get_tensor(
            (output_size[0]*output_size[1], output_size[2], output_size[3]),
            query_layer.dtype, "mpu")

@@ -263,7 +262,7 @@ class CoreAttention(MegatronModule):
         # seem a bit unusual, but is taken from the original Transformer paper.

         if not self.sequence_parallel:
-            with core.tensor_parallel.get_cuda_rng_tracker().fork():
+            with tensor_parallel.get_cuda_rng_tracker().fork():
                 attention_probs = self.attention_dropout(attention_probs)
         else:
             attention_probs = self.attention_dropout(attention_probs)

@@ -327,7 +326,7 @@ class ParallelAttention(MegatronModule):
         projection_size = args.kv_channels * args.num_attention_heads

         # Per attention head and per partition values.
-        world_size = core.get_tensor_model_parallel_world_size()
+        world_size = mpu.get_tensor_model_parallel_world_size()
         self.hidden_size_per_attention_head = core.utils.divide(
             projection_size, args.num_attention_heads)
         self.num_attention_heads_per_partition = core.utils.divide(

@@ -335,7 +334,7 @@ class ParallelAttention(MegatronModule):
         # Strided linear layer.
         if attention_type == AttnType.self_attn:
-            self.query_key_value = core.tensor_parallel.ColumnParallelLinear(
+            self.query_key_value = tensor_parallel.ColumnParallelLinear(
                 args.hidden_size,
                 3 * projection_size,
                 gather_output=False,

@@ -344,7 +343,7 @@ class ParallelAttention(MegatronModule):
                 **_args_to_kwargs())
         else:
             assert attention_type == AttnType.cross_attn
-            self.query = core.tensor_parallel.ColumnParallelLinear(
+            self.query = tensor_parallel.ColumnParallelLinear(
                 args.hidden_size,
                 projection_size,
                 gather_output=False,

@@ -353,7 +352,7 @@ class ParallelAttention(MegatronModule):
                 **_args_to_kwargs())

-            self.key_value = core.tensor_parallel.ColumnParallelLinear(
+            self.key_value = tensor_parallel.ColumnParallelLinear(
                 args.hidden_size,
                 2 * projection_size,
                 gather_output=False,

@@ -366,7 +365,7 @@ class ParallelAttention(MegatronModule):
         self.checkpoint_core_attention = args.recompute_granularity == 'selective'

         # Output.
-        self.dense = core.tensor_parallel.RowParallelLinear(
+        self.dense = tensor_parallel.RowParallelLinear(
             projection_size,
             args.hidden_size,
             input_is_parallel=True,

@@ -386,7 +385,7 @@ class ParallelAttention(MegatronModule):
                                           value_layer, attention_mask)
             return output_

-        hidden_states = core.tensor_parallel.checkpoint(
+        hidden_states = tensor_parallel.checkpoint(
             custom_forward,
             False, query_layer, key_layer, value_layer, attention_mask)

@@ -439,7 +438,7 @@ class ParallelAttention(MegatronModule):
             # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
             (query_layer,
              key_layer,
-             value_layer) = core.tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3)
+             value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3)
         else:
             # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
             mixed_kv_layer, _ = self.key_value(encoder_output)

@@ -452,7 +451,7 @@ class ParallelAttention(MegatronModule):
             # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
             (key_layer,
-             value_layer) = core.tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)
+             value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)

             # Attention head [sq, b, h] --> [sq, b, hp]
             query_layer, _ = self.query(hidden_states)

@@ -769,7 +768,7 @@ class ParallelTransformer(MegatronModule):
         self.sequence_parallel = args.sequence_parallel

         # Number of layers.
-        self.num_layers = core.get_num_layers(
+        self.num_layers = mpu.get_num_layers(
             args, args.model_type == ModelType.encoder_and_decoder)

         self.drop_path_rates = [rate.item() for rate in torch.linspace(0, self.drop_path_rate, args.num_layers)]

@@ -799,21 +798,21 @@ class ParallelTransformer(MegatronModule):
                 # layers to stages like (each list is a model chunk):
                 # Stage 0: [0, 1] [4, 5]
                 # Stage 1: [2, 3] [6, 7]
-                offset = core.get_virtual_pipeline_model_parallel_rank() * (
+                offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
                     args.num_layers // args.virtual_pipeline_model_parallel_size) + \
-                    (core.get_pipeline_model_parallel_rank() * self.num_layers)
+                    (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
             else:
                 # Each stage gets a contiguous set of layers.
                 if args.model_type == ModelType.encoder_and_decoder and \
-                        core.get_pipeline_model_parallel_world_size() > 1:
-                    pipeline_rank = core.get_pipeline_model_parallel_rank()
+                        mpu.get_pipeline_model_parallel_world_size() > 1:
+                    pipeline_rank = mpu.get_pipeline_model_parallel_rank()
                     if layer_type == LayerType.encoder:
                         offset = pipeline_rank * self.num_layers
                     else:
                         num_ranks_in_enc = args.pipeline_model_parallel_split_rank
                         offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
                 else:
-                    offset = core.get_pipeline_model_parallel_rank() * self.num_layers
+                    offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

         if self.num_layers == 0:
             # When a standalone embedding stage is used (e.g.,

@@ -862,7 +861,7 @@ class ParallelTransformer(MegatronModule):
         # A method to further reduce memory usage reducing checkpoints.
         l = 0
         while l < self.num_layers:
-            hidden_states = core.tensor_parallel.checkpoint(
+            hidden_states = tensor_parallel.checkpoint(
                 custom(l, l + self.recompute_num_layers),
                 self.distribute_saved_activations,
                 hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)

@@ -874,7 +873,7 @@ class ParallelTransformer(MegatronModule):
         # A method fully use the device memory removing redundant re-computation.
         for l in range(self.num_layers):
             if l < self.recompute_num_layers:
-                hidden_states = core.tensor_parallel.checkpoint(
+                hidden_states = tensor_parallel.checkpoint(
                     custom(l, l + 1),
                     self.distribute_saved_activations,
                     hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)

@@ -932,7 +931,7 @@ class ParallelTransformer(MegatronModule):
                 )

         if self.sequence_parallel:
-            rng_context = core.tensor_parallel.get_cuda_rng_tracker().fork()
+            rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
         else:
             rng_context = nullcontext()
megatron/model/vision/knn_monitor.py

 import torch.nn.functional as F
 import torch
-from megatron import print_rank_0, get_args, mpu
+from megatron import print_rank_0, get_args
+from megatron.core import mpu
 from megatron.data.vit_dataset import ClassificationTransform
 from megatron.data.image_folder import ImageFolder
megatron/mpu/__init__.py  deleted 100644 → 0

-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-"""Model parallel utility interface."""
-
-from .initialize import is_unitialized
-from .initialize import destroy_model_parallel
-from .initialize import get_data_parallel_group
-from .initialize import get_data_parallel_rank
-from .initialize import get_data_parallel_world_size
-from .initialize import get_embedding_group
-from .initialize import get_position_embedding_group
-from .initialize import get_model_parallel_group
-from .initialize import get_tensor_model_parallel_group
-from .initialize import get_pipeline_model_parallel_group
-from .initialize import get_tensor_model_parallel_rank, set_tensor_model_parallel_rank
-from .initialize import get_pipeline_model_parallel_rank, set_pipeline_model_parallel_rank
-from .initialize import is_pipeline_first_stage, is_pipeline_last_stage
-from .initialize import is_rank_in_embedding_group
-from .initialize import is_rank_in_position_embedding_group
-from .initialize import is_pipeline_stage_before_split, is_pipeline_stage_after_split
-from .initialize import is_pipeline_stage_at_split
-from .initialize import get_num_layers
-from .initialize import get_tensor_model_parallel_src_rank
-from .initialize import get_data_parallel_src_rank
-from .initialize import get_pipeline_model_parallel_first_rank
-from .initialize import get_pipeline_model_parallel_last_rank
-from .initialize import get_pipeline_model_parallel_next_rank
-from .initialize import get_pipeline_model_parallel_prev_rank
-from .initialize import get_tensor_model_parallel_world_size, set_tensor_model_parallel_world_size
-from .initialize import get_pipeline_model_parallel_world_size, set_pipeline_model_parallel_world_size
-from .initialize import get_virtual_pipeline_model_parallel_rank, set_virtual_pipeline_model_parallel_rank
-from .initialize import initialize_model_parallel
-from .initialize import model_parallel_is_initialized
-
-from .utils import divide
-from .utils import split_tensor_along_last_dim
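With the megatron/mpu package deleted, call sites import the alias from megatron.core instead, as the diffs on this page show. Below is a small before/after sketch of the migration pattern, for illustration only; the function names are taken from the changed files above, and running it assumes an installed Megatron-LM with model parallelism already initialized.

# Before this commit: everything came from the megatron/mpu package.
# from megatron import mpu
# world_size = mpu.get_tensor_model_parallel_world_size()
# loss = mpu.vocab_parallel_cross_entropy(logits, labels)

# After this commit: parallel-state queries stay on the mpu alias,
# tensor-parallel ops move to megatron.core.tensor_parallel.
from megatron.core import mpu, tensor_parallel

world_size = mpu.get_tensor_model_parallel_world_size()
loss = tensor_parallel.vocab_parallel_cross_entropy(logits, labels)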
megatron/optimizer/distrib_optimizer.py

@@ -8,10 +8,9 @@ import torch
 from megatron import get_args
 from megatron import get_timers
-from megatron import mpu
 from megatron import print_rank_0
+from megatron.core import mpu, tensor_parallel
 from megatron.model.module import param_is_not_shared
-from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate

 from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper

@@ -290,9 +289,9 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                     shard_model_param = model_param.detach().view(-1) \
                         [param_range.start:param_range.end]
                     shard_main_param = shard_model_param.clone().float()
-                    mpu.copy_tensor_model_parallel_attributes(
+                    tensor_parallel.copy_tensor_model_parallel_attributes(
                         shard_model_param, model_param)
-                    mpu.copy_tensor_model_parallel_attributes(
+                    tensor_parallel.copy_tensor_model_parallel_attributes(
                         shard_main_param, model_param)
                     if hasattr(model_param, 'shared'):
                         shard_model_param.shared = model_param.shared

@@ -309,7 +308,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                         [param_range.start:param_range.end]
                     model_fp32_params_this_group.append(model_param)
                     shard_fp32_params_this_group.append(shard_model_param)
-                    mpu.copy_tensor_model_parallel_attributes(
+                    tensor_parallel.copy_tensor_model_parallel_attributes(
                         shard_model_param, model_param)
                     if hasattr(model_param, 'shared'):
                         shard_model_param.shared = model_param.shared
megatron/optimizer/optimizer.py

@@ -11,13 +11,11 @@ from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

 from megatron import get_timers
-from megatron import mpu
-from megatron import core
 from megatron import print_rank_0
+from megatron.core import mpu, tensor_parallel
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import Float16Module
 from megatron.model.module import param_is_not_shared
-from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate
 from megatron.utils import unwrap_model

 from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32

@@ -103,7 +101,7 @@ class MegatronOptimizer(ABC):
             grad = param.grad
             grad_not_none = grad is not None
             is_not_shared = param_is_not_shared(param)
-            is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
+            is_not_tp_duplicate = tensor_parallel.param_is_not_tensor_parallel_duplicate(param)
             if grad_not_none and is_not_shared and is_not_tp_duplicate:
                 grads_for_norm.append(grad)

@@ -528,8 +526,8 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
                     # Create a copy
                     main_param = param.detach().clone().float()
                     # Copy tensor model parallel attributes.
-                    core.tensor_parallel.copy_tensor_model_parallel_attributes(main_param,
-                                                                               param)
+                    tensor_parallel.copy_tensor_model_parallel_attributes(main_param,
+                                                                          param)
                     if hasattr(param, 'shared'):
                         main_param.shared = param.shared

                     # Replace the optimizer params with the new fp32 copy.
megatron/p2p_communication.py

@@ -4,8 +4,8 @@ from functools import reduce
 import operator
 import torch

-from megatron import get_args
-from megatron import mpu
+from megatron import get_args, core
+from megatron.core import mpu


 def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,

@@ -81,10 +81,10 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
             args.scatter_gather_tensors_in_pipeline and \
             not args.sequence_parallel:
         if tensor_send_next is not None:
-            tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)
+            tensor_send_next = core.tensor_parallel.split_tensor_into_1d_equal_chunks(tensor_send_next)

         if tensor_send_prev is not None:
-            tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev)
+            tensor_send_prev = core.tensor_parallel.split_tensor_into_1d_equal_chunks(tensor_send_prev)

     # Send tensors in both the forward and backward directions as appropriate.
     if args.use_ring_exchange_p2p:

@@ -127,18 +127,18 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
             args.scatter_gather_tensors_in_pipeline and \
             not args.sequence_parallel:
         if recv_prev:
-            tensor_recv_prev = mpu.gather_split_1d_tensor(
+            tensor_recv_prev = core.tensor_parallel.gather_split_1d_tensor(
                 tensor_recv_prev).view(tensor_shape).requires_grad_()
-            tensor_recv_prev = mpu.make_viewless_tensor(tensor_recv_prev,
-                                                        requires_grad=True,
-                                                        keep_graph=False)
+            tensor_recv_prev = core.utils.make_viewless_tensor(tensor_recv_prev,
+                                                               requires_grad=True,
+                                                               keep_graph=False)

         if recv_next:
-            tensor_recv_next = mpu.gather_split_1d_tensor(
+            tensor_recv_next = core.tensor_parallel.gather_split_1d_tensor(
                 tensor_recv_next).view(tensor_shape).requires_grad_()
-            tensor_recv_next = mpu.make_viewless_tensor(tensor_recv_next,
-                                                        requires_grad=True,
-                                                        keep_graph=False)
+            tensor_recv_next = core.utils.make_viewless_tensor(tensor_recv_next,
+                                                               requires_grad=True,
+                                                               keep_graph=False)

     return tensor_recv_prev, tensor_recv_next
megatron/schedules.py

@@ -8,8 +8,8 @@ from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
 from megatron import get_args
 from megatron import get_num_microbatches
 from megatron import get_timers
-from megatron import mpu
 from megatron import p2p_communication
+from megatron.core import mpu
 from megatron.utils import unwrap_model
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import Float16Module
megatron/text_generation/api.py

@@ -5,7 +5,7 @@
 import torch

-from megatron import mpu
+from megatron.core import mpu
 from .communication import broadcast_float_list
 from .generation import (
     generate_tokens_probs_and_return_on_first_stage,
megatron/text_generation/communication.py

@@ -5,7 +5,7 @@
 import torch

-from megatron import mpu
+from megatron.core import mpu
megatron/text_generation/forward_step.py

@@ -6,9 +6,8 @@ from collections.abc import Iterable
 import torch

-from megatron import (
-    get_args,
-    mpu)
+from megatron import get_args
+from megatron.core import mpu
 from .communication import (
     send_to_next_pipeline_rank,
     recv_from_prev_pipeline_rank_)
megatron/text_generation/generation.py

@@ -5,7 +5,8 @@
 import torch
 import torch.nn.functional as F

-from megatron import get_args, get_tokenizer, mpu
+from megatron import get_args, get_tokenizer
+from megatron.core import mpu
 from megatron.utils import get_ltor_masks_and_position_ids
 from .communication import (
     copy_from_last_to_first_pipeline_stage,
megatron/training.py

@@ -19,8 +19,7 @@ from megatron import get_current_global_batch_size
 from megatron import get_num_microbatches
 from megatron import is_last_rank
 from megatron import update_num_microbatches
-from megatron import mpu
-from megatron import core
+from megatron.core import mpu, tensor_parallel
 from megatron import print_rank_0
 from megatron import print_rank_last
 from megatron.checkpointing import load_checkpoint

@@ -258,7 +257,7 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
     # are set for all params so the optimizer can use them.
     for model_module in model:
         for param in model_module.parameters():
-            core.tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param)
+            tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

     # Print number of parameters.
     if mpu.get_data_parallel_rank() == 0:
megatron/utils.py

@@ -10,11 +10,13 @@ from torch.nn.parallel import DistributedDataParallel as torchDDP
 from apex.multi_tensor_apply import multi_tensor_applier
 import amp_C

-from megatron import get_args
-from megatron import get_adlr_autoresume
-from megatron import mpu
-from megatron.model.module import param_is_not_shared
+from megatron import (
+    get_args,
+    get_adlr_autoresume,
+)
+from megatron.core import mpu
 from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate
+from megatron.model.module import param_is_not_shared


 def unwrap_model(model, module_instances=(torchDDP)):