Project: evt_fugx1/dcu_megatron

Commit 8da353a4, authored Mar 25, 2025 by dongcl
Adapt to Megatron v0.11
Parent: b9fdbcfa
Showing 3 changed files with 221 additions and 1 deletion:

  dcu_megatron/adaptor/megatron_adaptor.py    +3    -0
  dcu_megatron/training/arguments.py          +139  -1
  dcu_megatron/training/initialize.py         +79   -0
dcu_megatron/adaptor/megatron_adaptor.py

@@ -147,9 +147,12 @@ class CoreAdaptation(MegatronAdaptationABC):
     def patch_training(self):
         from ..training.tokenizer import build_tokenizer
+        from ..training.initialize import initialize_megatron
         MegatronAdaptation.register('megatron.training.tokenizer.tokenizer.build_tokenizer',
                                     build_tokenizer)
+        MegatronAdaptation.register('megatron.training.initialize.initialize_megatron',
+                                    initialize_megatron)
 
     def patch_miscellaneous(self):
         from ..training.arguments import parse_args
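For context: MegatronAdaptation.register() pairs a dotted Megatron import path with a replacement callable, and the adaptor later monkey-patches that attribute so upstream Megatron code transparently picks up the dcu_megatron version (here, the custom build_tokenizer and the new initialize_megatron). The repository's actual MegatronAdaptation class is not part of this diff; the sketch below is only a minimal illustration of that register-then-patch pattern, with hypothetical names.

    # Hypothetical sketch of a register-by-dotted-path patcher; not the
    # dcu_megatron implementation, just the general mechanism it relies on.
    import importlib

    _registry = {}

    def register(dotted_path, replacement):
        # Remember which attribute should be swapped out.
        _registry[dotted_path] = replacement

    def apply_patches():
        # Overwrite each registered attribute on its owning module, e.g.
        # megatron.training.initialize.initialize_megatron -> patched function.
        for dotted_path, replacement in _registry.items():
            module_path, attr = dotted_path.rsplit('.', 1)
            setattr(importlib.import_module(module_path), attr, replacement)

Patching the module attribute only takes effect for call sites that resolve the function through the module at call time; code that did a from-import before the patch keeps the old reference, which is why adaptors like this register their patches as early as possible.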
dcu_megatron/training/arguments.py

@@ -8,7 +8,6 @@ from megatron.training.arguments import (
     _add_learning_rate_args,
     _add_checkpointing_args,
     _add_mixed_precision_args,
-    _add_distributed_args,
     _add_validation_args,
     _add_data_args,
     _add_autoresume_args,
@@ -88,6 +87,145 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
     return args
 
 
+def _add_distributed_args(parser):
+    group = parser.add_argument_group(title='distributed')
+
+    group.add_argument('--tensor-model-parallel-size', type=int, default=1,
+                       help='Degree of tensor model parallelism.')
+    group.add_argument('--encoder-tensor-model-parallel-size', type=int, default=0,
+                       help='Degree of tensor model parallelism for the encoder.')
+    group.add_argument('--pipeline-model-parallel-size', type=int, default=1,
+                       help='Degree of pipeline model parallelism.')
+    group.add_argument('--encoder-pipeline-model-parallel-size', type=int, default=0,
+                       help=('Degree of pipeline model parallelism in the encoder. This is '
+                             'independent of the amount of pipeline in the decoder.'))
+    group.add_argument('--pipeline-model-parallel-split-rank', type=int, default=None,
+                       help=('Rank where encoder and decoder should be split. '
+                             'Deprecated; use --encoder-pipeline-model-parallel-size instead.'))
+    group.add_argument('--decoder-first-pipeline-num-layers', type=int, default=None,
+                       help=('The number of transformer layers on the first pipeline stage of the decoder. '
+                             'Default None is even split of transformer layers across all pipeline stages'))
+    group.add_argument('--decoder-last-pipeline-num-layers', type=int, default=None,
+                       help=('The number of transformer layers on the last pipeline stage of the decoder. '
+                             'Default None is even split of transformer layers across all pipeline stages'))
+    group.add_argument('--model-parallel-size', type=int, default=None,
+                       help='Old model parallel argument, do not use. Use '
+                       '--tensor-model-parallel-size instead.')
+    group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None,
+                       help='Number of layers per virtual pipeline stage')
+    group.add_argument('--num-virtual-stages-per-pipeline-rank', type=int, default=None,
+                       help='Number of virtual pipeline stages per pipeline parallelism rank')
+    group.add_argument('--microbatch-group-size-per-virtual-pipeline-stage', type=int, default=None,
+                       help='Number of contiguous microbatches per virtual pipeline stage',
+                       dest='microbatch_group_size_per_vp_stage')
+    group.add_argument('--no-overlap-p2p-communication', action='store_false',
+                       help='overlap pipeline parallel communication with forward and backward chunks in 1F1B',
+                       dest='overlap_p2p_comm')
+    group.add_argument('--overlap-p2p-communication-warmup-flush', action='store_true',
+                       default=False,
+                       help='if set, overlap pipeline parallel communication in warmup and flush',
+                       dest='overlap_p2p_comm_warmup_flush')
+    group.add_argument('--distributed-backend', default='nccl',
+                       choices=['nccl', 'gloo'],
+                       help='Which backend to use for distributed training.')
+    group.add_argument('--distributed-timeout-minutes', type=int, default=10,
+                       help='Timeout minutes for torch.distributed.')
+    group.add_argument('--overlap-grad-reduce', action='store_true',
+                       default=False, help='If set, overlap DDP grad reduce.')
+    group.add_argument('--defer-embedding-wgrad-compute', action='store_true',
+                       default=False,
+                       help='If set, defers the vocabulary projection linear layer weight'
+                       'gradient compute to pipeline flush.',
+                       dest='defer_embedding_wgrad_compute')
+    group.add_argument('--wgrad-deferral-limit', type=int, default=0,
+                       help='Number of micro-batches for which'
+                       'weight gradient computation of vocabulary projection is deferred, defaults to 0 which'
+                       'means all the micro-batches are deferred. Invalid if `defer-embedding-wgrad-compute`'
+                       'is not set')
+    group.add_argument('--no-align-grad-reduce', action='store_false',
+                       help='If not set, all PP stages will launch gradient reduces simultaneously. '
+                       'Otherwise, each PP stage will independently launch as needed.',
+                       dest='align_grad_reduce')
+    group.add_argument('--ddp-bucket-size', type=int, default=None,
+                       help='Bucket size for data-parallel communication')
+    group.add_argument('--ddp-average-in-collective', action='store_true',
+                       default=False,
+                       help='If set, average directly in data-parallel communication collective.')
+    group.add_argument('--overlap-param-gather', action='store_true',
+                       default=False,
+                       help='If set, overlap param all-gather in distributed optimizer.')
+    group.add_argument('--overlap-param-gather-with-optimizer-step', action='store_true',
+                       default=False,
+                       help='If set, overlap param all-gather of first bucket with optimizer step.')
+    group.add_argument('--no-align-param-gather', action='store_false',
+                       help='If not set, all PP stages will launch param all-gathers simultaneously. '
+                       'Otherwise, each PP stage will independently launch as needed.',
+                       dest='align_param_gather')
+    group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
+                       help='If not set, use scatter/gather to optimize communication of tensors in pipeline.',
+                       dest='scatter_gather_tensors_in_pipeline')
+    group.add_argument('--use-ring-exchange-p2p', action='store_true',
+                       default=False,
+                       help='If set, use custom-built ring exchange '
+                       'for p2p communications. Note that this option will require '
+                       'a custom built image that support ring-exchange p2p.')
+    group.add_argument('--local-rank', type=int, default=int(os.getenv('LOCAL_RANK', '0')),
+                       help='local rank passed from distributed launcher.')
+    group.add_argument('--lazy-mpu-init', type=bool, required=False,
+                       help='If set to True, initialize_megatron() '
+                       'skips DDP initialization and returns function to '
+                       'complete it instead.Also turns on '
+                       '--use-cpu-initialization flag. This is for '
+                       'external DDP manager.')
+    group.add_argument('--account-for-embedding-in-pipeline-split', action='store_true',
+                       default=False,
+                       help='If set, *input* embedding layer will be treated as a standard transformer'
+                       'layer in the context of partition and placement for pipeline parallelism.')
+    group.add_argument('--account-for-loss-in-pipeline-split', action='store_true',
+                       default=False,
+                       help='If set, loss layer will be treated as a standard transformer'
+                       'layer in the context of partition and placement for pipeline parallelism.')
+    group.add_argument('--use-distributed-optimizer', action='store_true',
+                       help='Use distributed optimizer.')
+    group.add_argument('--num-distributed-optimizer-instances', type=int, default=1,
+                       help='Number of Distributed Optimizer copies across Data Parallel domain.')
+    group.add_argument('--use-torch-fsdp2', action='store_true',
+                       help="Use the torch FSDP2 implementation. FSDP2 is not currently working with Pipeline Parallel."
+                       "It is still not in a stable release stage, and may therefore contain bugs or other potential issues.")
+    group.add_argument('--context-parallel-size', type=int, default=1,
+                       help='Degree of context parallelism.')
+    group.add_argument('--cp-comm-type', nargs='+', type=str, default=["p2p"],
+                       help='Inter-gpu communication type for context parallelism: '
+                       'p2p, a2a, allgather or a2a+p2p. If a single string is provided, '
+                       'all layers will share the same communication type. Users can also '
+                       'specify separated types for each layer like '
+                       '--cp-comm-type p2p p2p a2a a2a a2a+p2p a2a+p2p')
+    group.add_argument('--hierarchical-context-parallel-sizes', nargs='+', type=int, default=None,
+                       help='Degrees of the hierarchical context parallelism. Users should '
+                       'provide a list to specify the sizes for different levels. '
+                       '--hierarchical-context-parallel-sizes 2 4 indicates every two adjacent gpus '
+                       'forms the first level of cp groups and the cp ranks with the same odevity '
+                       'forms the second level of cp groups.')
+    group.add_argument('--nccl-communicator-config-path', type=str, default=None,
+                       help='Path to the yaml file with NCCL communicator '
+                       'configurations. The number of min/max thread groups and thread '
+                       'group cluster size of each communicator can be configured by '
+                       'setting `min_ctas`, `max_ctas`, and `cga_cluster_size`.')
+    group.add_argument('--use-tp-pp-dp-mapping', action='store_true', default=False,
+                       help='If set, distributed ranks initialize order is changed '
+                       'from tp-cp-ep-dp-pp to tp-cp-ep-pp-dp.')
+    group.add_argument('--replication', action='store_true', default=False,
+                       help="If set, replication of local checkpoints is enabled. "
+                       "Needs to be enabled on all ranks.")
+    group.add_argument('--replication-jump', default=None, type=int,
+                       help="Specifies `J`, the spacing between ranks storing replicas of a given rank's data. "
+                       "Replicas for rank `n` may be on ranks `n+J`, `n+2J`, ..., or `n-J`, `n-2J`, etc. "
+                       "This flag has an effect only if --replication is used. "
+                       "and must be consistent across all ranks.")
+    group.add_argument('--replication-factor', default=2, type=int,
+                       help="Number of machines storing the replica of a given rank's data.")
+
+    group.add_argument('--rank', default=-1, type=int,
+                       help='node rank for distributed training')
+    group.add_argument('--world-size', type=int, default=8,
+                       help='number of nodes for distributed training')
+    group.add_argument('--dist-url',
+                       help='Which master node url for distributed training.')
+
+    return parser
+
 
 def _add_tokenizer_args(parser):
     group = parser.add_argument_group(title='tokenizer')
     group.add_argument('--vocab-size', type=int, default=None,
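The 139 added lines give dcu_megatron its own _add_distributed_args (the upstream import of that helper is removed in the first hunk). It appears to mirror the upstream Megatron options while additionally exposing --rank, --world-size, and --dist-url so process-group bootstrapping can be driven from the command line (used by initialize.py below). As a quick sanity check, the helper can be exercised on its own; the snippet below is only an illustrative sketch, and the import path and argv values are assumptions, not part of the commit.

    # Illustrative smoke test for the patched _add_distributed_args.
    # The import path and all argv values are placeholders.
    import argparse
    from dcu_megatron.training.arguments import _add_distributed_args  # assumed path

    parser = argparse.ArgumentParser('smoke-test', allow_abbrev=False)
    parser = _add_distributed_args(parser)
    args = parser.parse_args([
        '--tensor-model-parallel-size', '2',
        '--pipeline-model-parallel-size', '2',
        '--rank', '0',
        '--world-size', '2',
        '--dist-url', 'tcp://10.0.0.1:29500',  # hypothetical master endpoint
    ])
    print(args.tensor_model_parallel_size, args.rank, args.world_size, args.dist_url)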
dcu_megatron/training/initialize.py (new file, mode 100644)

"""Megatron initialization."""
import torch
from datetime import timedelta

from megatron.training import get_args
from megatron.core import mpu


def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
    """Initialize torch.distributed and core model parallel."""
    args = get_args()

    device_count = torch.cuda.device_count()
    if torch.distributed.is_initialized():
        if args.rank == 0:
            print(
                "torch distributed is already initialized, "
                "skipping initialization ...",
                flush=True,
            )
        args.rank = torch.distributed.get_rank()
        args.world_size = torch.distributed.get_world_size()
    else:
        if args.rank == 0:
            print("> initializing torch distributed ...", flush=True)

        # Manually set the device ids.
        if device_count > 0:
            torch.cuda.set_device(args.local_rank)
            device_id = torch.device(f'cuda:{args.local_rank}')
        else:
            device_id = None

        # Call the init process
        init_process_group_kwargs = {
            'backend': args.distributed_backend,
            'world_size': args.world_size,
            'rank': args.rank,
            'init_method': args.dist_url,
            'timeout': timedelta(minutes=args.distributed_timeout_minutes),
        }
        torch.distributed.init_process_group(**init_process_group_kwargs)

    # Set the tensor model-parallel, pipeline model-parallel, and
    # data-parallel communicators.
    if device_count > 0:
        if mpu.model_parallel_is_initialized():
            print("model parallel is already initialized")
        else:
            mpu.initialize_model_parallel(
                args.tensor_model_parallel_size,
                args.pipeline_model_parallel_size,
                args.virtual_pipeline_model_parallel_size,
                args.pipeline_model_parallel_split_rank,
                context_parallel_size=args.context_parallel_size,
                hierarchical_context_parallel_sizes=args.hierarchical_context_parallel_sizes,
                expert_model_parallel_size=args.expert_model_parallel_size,
                num_distributed_optimizer_instances=args.num_distributed_optimizer_instances,
                expert_tensor_parallel_size=args.expert_tensor_parallel_size,
                distributed_timeout_minutes=args.distributed_timeout_minutes,
                nccl_communicator_config_path=args.nccl_communicator_config_path,
                order='tp-cp-ep-dp-pp' if not args.use_tp_pp_dp_mapping else 'tp-cp-ep-pp-dp',
                encoder_tensor_model_parallel_size=args.encoder_tensor_model_parallel_size,
                encoder_pipeline_model_parallel_size=args.encoder_pipeline_model_parallel_size,
                get_embedding_ranks=get_embedding_ranks,
                get_position_embedding_ranks=get_position_embedding_ranks,
            )
            if args.rank == 0:
                print(
                    f"> initialized tensor model parallel with size "
                    f"{mpu.get_tensor_model_parallel_world_size()}"
                )
                print(
                    f"> initialized pipeline model parallel with size "
                    f"{mpu.get_pipeline_model_parallel_world_size()}"
                )
\ No newline at end of file
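The init_process_group_kwargs dictionary maps the new command-line flags directly onto torch.distributed.init_process_group: --dist-url becomes init_method, --rank and --world-size identify the process, and --distributed-timeout-minutes sets the collective timeout. For reference, the equivalent direct call looks like the sketch below (backend, address, rank, and sizes are placeholder values, not from the commit); an init_method of 'env://' would instead read MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE from the environment.

    # Equivalent direct call with placeholder values (sketch only).
    import torch
    from datetime import timedelta

    torch.distributed.init_process_group(
        backend='nccl',                      # --distributed-backend
        init_method='tcp://10.0.0.1:29500',  # --dist-url
        rank=0,                              # --rank
        world_size=2,                        # --world-size
        timeout=timedelta(minutes=10),       # --distributed-timeout-minutes
    )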