OpenDAS / Megatron-LM / Commits

Commit 3c92fa93
Authored Mar 23, 2023 by Jared Casper
Parent: 0b44909c

    Move pipeline parallel functionality into core with associated changes.

Showing 20 of the commit's 28 changed files, with 578 additions and 325 deletions (+578 -325); the remaining files are on the next page of the commit view.
examples/detxoify_lm/finetune_gpt.py                    +2    -1
megatron/core/enums.py                                  +7    -0
megatron/core/parallel_state.py                         +41   -8
megatron/core/pipeline_parallel/__init__.py             +1    -0
megatron/core/pipeline_parallel/p2p_communication.py    +187  -136
megatron/core/pipeline_parallel/schedules.py            +282  -158
megatron/core/tensor_parallel/layers.py                 +6    -2
megatron/core/utils.py                                  +15   -0
megatron/model/__init__.py                              +0    -1
megatron/model/enums.py                                 +0    -4
megatron/model/retro_transformer.py                     +2    -1
megatron/model/transformer.py                           +2    -1
megatron/training.py                                    +22   -6
pretrain_bert.py                                        +2    -1
pretrain_gpt.py                                         +2    -1
pretrain_ict.py                                         +1    -1
pretrain_retro.py                                       +2    -1
pretrain_t5.py                                          +2    -1
pretrain_vision_classify.py                             +1    -1
pretrain_vision_dino.py                                 +1    -1
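Most of the per-file changes below are mechanical import-path updates that follow from two relocations; a hedged before/after sketch (paths taken from the diffs on this page):

```python
# Before this commit, the pipeline-parallel schedules and the ModelType enum
# lived outside megatron.core:
#   from megatron.schedules import get_forward_backward_func
#   from megatron.model import ModelType
#
# After this commit, both come from megatron.core:
from megatron.core.pipeline_parallel import get_forward_backward_func
from megatron.core.enums import ModelType
```

p2p_communication.py and schedules.py move under megatron/core/pipeline_parallel/, and ModelType moves from megatron/model/enums.py into the new megatron/core/enums.py.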
examples/detxoify_lm/finetune_gpt.py

```diff
@@ -17,7 +17,8 @@ from megatron import print_rank_0
 from megatron.core import mpu
 from megatron.data.blendable_dataset import BlendableDataset
 from megatron.data.gpt_dataset import build_train_valid_test_datasets
-from megatron.model import GPTModel, ModelType
+from megatron.model import GPTModel
+from megatron.core.enums import ModelType
 from megatron.training import pretrain
 from megatron.utils import get_ltor_masks_and_position_ids
 from megatron.utils import average_losses_across_data_parallel_group
```
megatron/core/enums.py (new file, mode 0 → 100644)

```python
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

import enum


class ModelType(enum.Enum):
    encoder_or_decoder = 1
    encoder_and_decoder = 2
```
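A small usage sketch of the relocated enum (an illustrative helper, not part of the commit):

```python
from megatron.core.enums import ModelType


def has_separate_encoder_and_decoder(model_type: ModelType) -> bool:
    """Illustrative helper: True for models with distinct encoder and decoder stacks."""
    return model_type == ModelType.encoder_and_decoder


assert has_separate_encoder_and_decoder(ModelType.encoder_and_decoder)
assert not has_separate_encoder_and_decoder(ModelType.encoder_or_decoder)
```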
megatron/core/parallel_state.py

```diff
@@ -58,12 +58,40 @@ def initialize_model_parallel(
     Initialize model data parallel groups.

     Arguments:
-        tensor_model_parallel_size: number of GPUs used for tensor model parallelism.
-        pipeline_model_parallel_size: number of GPUs used for pipeline model parallelism.
-        virtual_pipeline_model_parallel_size: number of virtual stages (interleaved
-                                              pipeline).
-        pipeline_model_parallel_split_rank: for models with both encoder and decoder,
-                                            rank in pipeline with split point.
+        tensor_model_parallel_size (int, default = 1):
+            The number of GPUs to split individual tensors across.
+
+        pipeline_model_parallel_size (int, default = 1):
+            The number of tensor parallel GPU groups to split the
+            Transformer layers across. For example, if
+            tensor_model_parallel_size is 4 and
+            pipeline_model_parallel_size is 2, the model will be split
+            into 2 groups of 4 GPUs.
+
+        virtual_pipeline_model_parallel_size (int, optional):
+            The number of stages that each pipeline group will have,
+            interleaving as necessary. If None, no interleaving is
+            performed. For example, if tensor_model_parallel_size is 1,
+            pipeline_model_parallel_size is 4,
+            virtual_pipeline_model_parallel_size is 2, and there are
+            16 transformer layers in the model, the model will be
+            split into 8 stages with two layers each and each GPU
+            would get 2 stages as such (layer number starting with 1):
+
+            GPU 0: [1, 2] [9, 10]
+            GPU 1: [3, 4] [11, 12]
+            GPU 2: [5, 6] [13, 14]
+            GPU 3: [7, 8] [15, 16]
+
+        pipeline_model_parallel_split_rank (int, optional):
+            For models with both an encoder and decoder, the rank in
+            pipeline to switch between encoder and decoder (i.e. the
+            first rank of the decoder). This allows the user to set
+            the pipeline parallel size of the encoder and decoder
+            independently. For example, if
+            pipeline_model_parallel_size is 8 and
+            pipeline_model_parallel_split_rank is 3, then ranks 0-2
+            will be the encoder and ranks 3-7 will be the decoder.

     Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
     use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
```
```diff
@@ -298,8 +326,8 @@ def set_pipeline_model_parallel_rank(rank):

 def set_pipeline_model_parallel_split_rank(rank):
     """Set pipeline model parallel split rank."""
-    global _MPU_PIPELINE_MODEL_PARALLEL_SPLIT_RANK
-    _MPU_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank
+    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
+    _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank


 def get_tensor_model_parallel_rank():
@@ -318,6 +346,11 @@ def get_pipeline_model_parallel_rank():
     return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())


+def get_pipeline_model_parallel_split_rank():
+    """Return pipeline model parallel split rank."""
+    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
+    return _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
+
+
 def is_pipeline_first_stage(ignore_virtual=False):
     """Return True if in the first pipeline model-parallel stage, False otherwise."""
```
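For orientation, a minimal sketch of a call that matches the interleaving example in the new docstring; it assumes torch.distributed has already been initialized across 16 ranks (e.g. via torchrun) before parallel state is configured:

```python
from megatron.core import parallel_state

# Mirrors the docstring example: 4 pipeline stages, each holding 2 virtual
# stages, so 16 transformer layers are split into 8 chunks of 2 layers.
parallel_state.initialize_model_parallel(
    tensor_model_parallel_size=1,
    pipeline_model_parallel_size=4,
    virtual_pipeline_model_parallel_size=2,
)

# For an encoder-decoder model one could instead pass, say,
# pipeline_model_parallel_size=8 and pipeline_model_parallel_split_rank=3,
# making ranks 0-2 the encoder and ranks 3-7 the decoder.
```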
megatron/core/pipeline_parallel/__init__.py (new file, mode 0 → 100644)

```python
from .schedules import get_forward_backward_func
```
megatron/p2p_communication.py → megatron/core/pipeline_parallel/p2p_communication.py
This diff is collapsed (not expanded on this page).

megatron/schedules.py → megatron/core/pipeline_parallel/schedules.py
This diff is collapsed (not expanded on this page).
megatron/core/tensor_parallel/layers.py

```diff
@@ -13,6 +13,8 @@ import torch.nn.functional as F
 import torch.nn.init as init
 from torch.nn.parameter import Parameter

+from torch.cuda.amp import custom_fwd, custom_bwd
+
 from megatron.core.parallel_state import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -214,6 +216,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
     """See linear_with_grad_accumulation_and_async_allreduce"""

     @staticmethod
+    @custom_fwd
     def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
                 async_grad_allreduce, sequence_parallel):
         ctx.save_for_backward(input, weight)
@@ -243,6 +246,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
         return output

     @staticmethod
+    @custom_bwd
     def backward(ctx, grad_output):
         input, weight = ctx.saved_tensors
         use_bias = ctx.use_bias
@@ -407,8 +411,8 @@ def linear_with_grad_accumulation_and_async_allreduce(
                 "maximum speedup"
             )
             linear_with_grad_accumulation_and_async_allreduce.warned = True

-    return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)
+    with torch.cuda.amp.autocast(enabled=False):
+        return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)

 linear_with_grad_accumulation_and_async_allreduce.warned = False


 class ColumnParallelLinear(torch.nn.Module):
```
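A self-contained sketch (toy code, not Megatron's) of the pattern this diff adopts: custom_fwd/custom_bwd make a hand-written autograd.Function behave consistently under AMP, and the caller disables autocast so that .apply() receives exactly the dtypes it was handed:

```python
import torch
from torch.cuda.amp import custom_fwd, custom_bwd


class ScaleByTwo(torch.autograd.Function):
    """Toy Function standing in for LinearWithGradAccumulationAndAsyncCommunication."""

    @staticmethod
    @custom_fwd
    def forward(ctx, x):
        return x * 2

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        return grad_output * 2


def scale_by_two(x):
    # Same call pattern as the diff: run the Function with autocast disabled.
    with torch.cuda.amp.autocast(enabled=False):
        return ScaleByTwo.apply(x)


x = torch.ones(4, requires_grad=True)
scale_by_two(x).sum().backward()
print(x.grad)  # tensor([2., 2., 2., 2.])
```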
megatron/core/utils.py

```diff
@@ -20,6 +20,21 @@ def divide(numerator, denominator):
     ensure_divisibility(numerator, denominator)
     return numerator // denominator


+def get_attr_wrapped_model(model, attr):
+    """Get an attribute from a wrapped model"""
+    if isinstance(model, list):
+        raise RuntimeError("_get_attr_wrapped_model given a list of models")
+
+    while not hasattr(model, attr):
+        if not hasattr(model, "module"):
+            raise RuntimeError(f"_get_attr_wrapped_model couldn't find attribute {attr}")
+
+        model = model.module
+    return getattr(model, attr)
+
+
+def get_model_type(model):
+    return get_attr_wrapped_model(model, 'model_type')
+
+
 class GlobalMemoryBuffer:
     """Global buffer to avoid dynamic memory allocations.
```
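The new helper walks through nested `.module` wrappers (the kind added by DDP-style and fp16 wrappers) until the requested attribute is found; a toy illustration (the wrapper classes here are made up for the example):

```python
from megatron.core.utils import get_attr_wrapped_model, get_model_type


class ToyModel:
    model_type = "encoder_or_decoder"  # stands in for a ModelType member


class Wrapper:
    def __init__(self, module):
        self.module = module


wrapped = Wrapper(Wrapper(ToyModel()))                 # two layers of wrapping
print(get_attr_wrapped_model(wrapped, "model_type"))   # encoder_or_decoder
print(get_model_type(wrapped))                         # same attribute, looked up by name
```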
megatron/model/__init__.py

```diff
@@ -8,4 +8,3 @@ from .gpt_model import GPTModel
 from .t5_model import T5Model
 from .language_model import get_language_model
 from .module import Float16Module
-from .enums import ModelType
```
megatron/model/enums.py

```diff
@@ -2,10 +2,6 @@
 import enum

-class ModelType(enum.Enum):
-    encoder_or_decoder = 1
-    encoder_and_decoder = 2
-
 class LayerType(enum.Enum):
     encoder = 1
     decoder = 2
```
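With ModelType removed here, the layer/attention enums stay in megatron.model.enums while the model-level enum now comes from core, as the transformer diffs below show:

```python
# New split of the enum imports after this commit:
from megatron.core.enums import ModelType
from megatron.model.enums import AttnMaskType, LayerType, AttnType
```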
megatron/model/retro_transformer.py

```diff
@@ -20,7 +20,8 @@ from megatron import get_args, get_retro_args, get_tensorboard_writer
 from megatron.core import parallel_state
 from megatron.core import tensor_parallel
 from megatron.core import utils as core_utils
-from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
+from megatron.core.enums import ModelType
+from megatron.model.enums import AttnMaskType, LayerType, AttnType
 from megatron.model import LayerNorm
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
```
megatron/model/transformer.py

```diff
@@ -9,7 +9,8 @@ import torch.nn.functional as F
 from megatron import get_timers, get_args, core, get_num_microbatches
 from .module import MegatronModule
 from megatron.core import mpu, tensor_parallel
-from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
+from megatron.core.enums import ModelType
+from megatron.model.enums import AttnMaskType, LayerType, AttnType
 from megatron.model import LayerNorm
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
```
megatron/training.py

```diff
@@ -25,8 +25,8 @@ from megatron import print_rank_last
 from megatron.checkpointing import load_checkpoint
 from megatron.checkpointing import save_checkpoint
 from megatron.model import Float16Module
-from megatron.model import ModelType
 from megatron.model import GPTModel
+from megatron.core.enums import ModelType
 from megatron.optimizer import get_megatron_optimizer
 from megatron.initialize import initialize_megatron
 from megatron.initialize import write_args_to_tensorboard
@@ -37,7 +37,7 @@ from megatron.utils import check_adlr_autoresume_termination
 from megatron.utils import unwrap_model
 from megatron.data.data_samplers import build_pretraining_data_loader
 from megatron.utils import calc_params_l2_norm
-from megatron.schedules import get_forward_backward_func
+from megatron.core.pipeline_parallel import get_forward_backward_func
 from megatron.utils import report_memory
 from megatron.model.vision.knn_monitor import compute_feature_bank
@@ -400,6 +400,7 @@ def setup_model_and_optimizer(model_provider_func,
     return model, optimizer, opt_param_scheduler

+
 def train_step(forward_step_func, data_iterator,
                model, optimizer, opt_param_scheduler):
     """Single training step."""
@@ -418,8 +419,16 @@ def train_step(forward_step_func, data_iterator,
     forward_backward_func = get_forward_backward_func()
     fwd_bwd_timers = timers if args.timing_log_level > 1 else None
     losses_reduced = forward_backward_func(
-        forward_step_func, data_iterator, model,
-        optimizer, fwd_bwd_timers, forward_only=False)
+        forward_step_func=forward_step_func,
+        data_iterator=data_iterator,
+        model=model,
+        num_microbatches=get_num_microbatches(),
+        dtype=args.params_dtype,
+        tensor_shape=(args.seq_length, args.micro_batch_size, args.hidden_size),
+        grad_scaler=optimizer.scale_loss,
+        sequence_parallel=args.sequence_parallel,
+        forward_only=False,
+        timers=fwd_bwd_timers)
     timers('forward-backward').stop()

     # Empty unused memory.
@@ -794,8 +803,15 @@ def evaluate(forward_step_func,
             forward_backward_func = get_forward_backward_func()
             loss_dicts = forward_backward_func(
-                forward_step_func, data_iterator, model, optimizer=None,
-                timers=None, forward_only=True)
+                forward_step_func=forward_step_func,
+                data_iterator=data_iterator,
+                model=model,
+                num_microbatches=get_num_microbatches(),
+                dtype=args.params_dtype,
+                tensor_shape=(args.seq_length, args.micro_batch_size, args.hidden_size),
+                sequence_parallel=args.sequence_parallel,
+                forward_only=True,
+                timers=None)

             # Empty unused memory
             if args.empty_unused_memory_level >= 1:
```
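The relocated forward_backward_func is now called with explicit keyword arguments for microbatching, dtype, and tensor shape instead of positional optimizer/timers arguments. A sketch of the updated calling convention, lifted from the train_step() hunk above (the surrounding setup, i.e. initialized model-parallel state, a model list, an optimizer, and a data iterator, is assumed to exist):

```python
from megatron import get_args, get_num_microbatches, get_timers
from megatron.core.pipeline_parallel import get_forward_backward_func


def run_train_step_fwd_bwd(forward_step_func, data_iterator, model, optimizer):
    """Illustrative wrapper mirroring the forward/backward call in train_step()."""
    args = get_args()
    timers = get_timers()
    forward_backward_func = get_forward_backward_func()
    fwd_bwd_timers = timers if args.timing_log_level > 1 else None
    return forward_backward_func(
        forward_step_func=forward_step_func,
        data_iterator=data_iterator,
        model=model,
        num_microbatches=get_num_microbatches(),
        dtype=args.params_dtype,
        tensor_shape=(args.seq_length, args.micro_batch_size, args.hidden_size),
        grad_scaler=optimizer.scale_loss,
        sequence_parallel=args.sequence_parallel,
        forward_only=False,
        timers=fwd_bwd_timers)
```

For evaluation, the same call drops grad_scaler and passes forward_only=True, as in the evaluate() hunk.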
pretrain_bert.py

```diff
@@ -11,8 +11,9 @@ from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_timers
 from megatron.core import tensor_parallel
+from megatron.core.enums import ModelType
 from megatron.data.dataset_utils import build_train_valid_test_datasets
-from megatron.model import BertModel, ModelType
+from megatron.model import BertModel
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group
```
pretrain_gpt.py

```diff
@@ -9,8 +9,9 @@ from megatron import print_rank_0
 from megatron import get_timers
 from megatron import get_tokenizer
 from megatron.core import tensor_parallel
+from megatron.core.enums import ModelType
 from megatron.data.gpt_dataset import build_train_valid_test_datasets
-from megatron.model import GPTModel, ModelType
+from megatron.model import GPTModel
 from megatron.training import pretrain
 from megatron.utils import get_ltor_masks_and_position_ids
 from megatron.utils import average_losses_across_data_parallel_group
```
pretrain_ict.py

```diff
@@ -13,9 +13,9 @@ from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_timers
 from megatron.core import mpu
+from megatron.core.enums import ModelType
 from megatron.data.biencoder_dataset_utils import get_ict_batch
 from megatron.data.dataset_utils import build_train_valid_test_datasets
-from megatron.model import ModelType
 from megatron.model.biencoder_model import biencoder_model_provider
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group
```
pretrain_retro.py

```diff
@@ -10,7 +10,8 @@ from megatron import get_timers
 from megatron import get_tokenizer
 from megatron import print_rank_0
 from megatron.core import mpu, tensor_parallel
-from megatron.model import GPTModel, ModelType
+from megatron.core.enums import ModelType
+from megatron.model import GPTModel
 from megatron.training import pretrain
 from megatron.utils import get_ltor_masks_and_position_ids
 from tools.retro.pretraining.retro_dataset import get_retro_datasets
```
pretrain_t5.py

```diff
@@ -12,8 +12,9 @@ from megatron import (
     print_rank_0
 )
 from megatron.core import tensor_parallel
+from megatron.core.enums import ModelType
 from megatron.data.dataset_utils import build_train_valid_test_datasets
-from megatron.model import T5Model, ModelType
+from megatron.model import T5Model
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group
```
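For orientation, the ModelType value these scripts import is ultimately the model type handed to pretrain(); the call sites are outside the hunks shown here, so the mapping below is an assumption based on the two enum members:

```python
from megatron.core.enums import ModelType

# Assumed mapping (the pretrain() calls are not part of the hunks on this page):
# single-stack scripts use encoder_or_decoder, the T5 script uses encoder_and_decoder.
MODEL_TYPE_BY_SCRIPT = {
    "pretrain_gpt.py": ModelType.encoder_or_decoder,
    "pretrain_bert.py": ModelType.encoder_or_decoder,
    "pretrain_t5.py": ModelType.encoder_and_decoder,
}
```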
pretrain_vision_classify.py

```diff
@@ -6,8 +6,8 @@ import torch
 import torch.nn.functional as F
 from functools import partial
 from megatron import get_args, get_timers, print_rank_0
+from megatron.core.enums import ModelType
 from megatron.data.vit_dataset import build_train_valid_datasets
-from megatron.model import ModelType
 from megatron.model.vision.classification import VitClassificationModel
 from megatron.model.vision.classification import MitClassificationModel
 from megatron.training import pretrain
```
pretrain_vision_dino.py

```diff
@@ -7,6 +7,7 @@ import numpy as np
 import torch.distributed as dist
 from functools import partial
 from megatron import get_args, get_timers, print_rank_0
+from megatron.core.enums import ModelType
 from megatron.data.vit_dataset import build_train_valid_datasets
 from megatron.model.vision.dino import DINOPretrainModel
 from megatron.model.vision.knn_monitor import knn_predict, get_feature_bank
@@ -15,7 +16,6 @@ from megatron.utils import average_losses_across_data_parallel_group, unwrap_mod
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import Float16Module
-from megatron.model import ModelType


 def model_provider(pre_process=True, post_process=True):
     """Build the model."""
```