lim / dcu_megatron_core_v0.15.0 — Commit 01bcbb1e

Authored Feb 12, 2026 by lim

Initial commit

Parent: 187361d1
Pipeline #3395: canceled with stages
Changes: 371 · Pipelines: 1
Showing 20 changed files with 2339 additions and 0 deletions (+2339 −0)
dcu_megatron/adaptor/features_manager/tensor_parallel/__pycache__/parallel_linear_feature.cpython-310.pyc  (+0 −0)
dcu_megatron/adaptor/features_manager/tensor_parallel/parallel_linear_feature.py  (+41 −0)
dcu_megatron/adaptor/megatron_adaptor.py  (+504 −0)
dcu_megatron/adaptor/patch_utils.py  (+196 −0)
dcu_megatron/core/__pycache__/parallel_state.cpython-310.pyc  (+0 −0)
dcu_megatron/core/__pycache__/quantization_utils.cpython-310.pyc  (+0 −0)
dcu_megatron/core/datasets/VL_helpers.py  (+569 −0)
dcu_megatron/core/datasets/image_processing.py  (+102 −0)
dcu_megatron/core/dist_checkpoint/exchange_utils.py  (+111 −0)
dcu_megatron/core/dist_checkpoint/strategies/cached_metadata_filesystem_reader.py  (+28 −0)
dcu_megatron/core/dist_checkpoint/strategies/filesystem_async.py  (+90 −0)
dcu_megatron/core/dist_checkpoint/strategies/fully_parallel.py  (+50 −0)
dcu_megatron/core/dist_checkpoint/strategies/torch.py  (+328 −0)
dcu_megatron/core/dist_checkpoint/validation.py  (+45 −0)
dcu_megatron/core/distributed/__pycache__/data_parallel_base.cpython-310.pyc  (+0 −0)
dcu_megatron/core/distributed/__pycache__/param_and_grad_buffer.cpython-310.pyc  (+0 −0)
dcu_megatron/core/distributed/__pycache__/power_sgd.cpython-310.pyc  (+0 −0)
dcu_megatron/core/distributed/data_parallel_base.py  (+7 −0)
dcu_megatron/core/distributed/distributed_data_parallel.py  (+37 −0)
dcu_megatron/core/distributed/finalize_model_grads.py  (+231 −0)

dcu_megatron/adaptor/features_manager/tensor_parallel/__pycache__/parallel_linear_feature.cpython-310.pyc (new file, mode 100644)
File added

dcu_megatron/adaptor/features_manager/tensor_parallel/parallel_linear_feature.py (new file, mode 100644)
from argparse import ArgumentParser

from ..feature import AbstractFeature


class ParallelLinearFeature(AbstractFeature):

    def __init__(self):
        super().__init__('parallel-linear-impl')

    def register_args(self, parser: ArgumentParser):
        group = parser.add_argument_group(title=self.feature_name)
        group.add_argument(
            '--parallel-linear-impl', type=str, default=None, choices=['flux'],
            help='Specify the method to replace ColumnParallelLinear/RowParallelLinear')
        group.add_argument(
            '--save-flux-gather-input', action='store_true', default=False,
            help='use gathered input of AGKernel for wgrad computation')
        group.add_argument(
            '--flux-transpose-weight', action='store_true', default=False,
            help='Whether to transpose weight when using flux kernel')
        group.add_argument(
            '--disable-bw-flux-gemmrs-op', action='store_false', default=True,
            dest='enable_bw_flux_gemmrs_op',
            help='Do not use flux.GemmRS in backward pass')

    def validate_args(self, args):
        if args.parallel_linear_impl == "flux" and args.transformer_impl != 'transformer_engine':
            raise AssertionError('flux is only supported with transformer_engine implementation')

    def register_patches(self, patch_manager, args):
        # flux
        from dcu_megatron.core.tensor_parallel.layers import (
            FluxColumnParallelLinear,
            FluxRowParallelLinear,
        )
        from dcu_megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_flux_spec

        if args.parallel_linear_impl == 'flux':
            patch_manager.register_patch(
                "megatron.core.extensions.transformer_engine.TEColumnParallelLinear",
                FluxColumnParallelLinear)
            patch_manager.register_patch(
                "megatron.core.extensions.transformer_engine.TERowParallelLinear",
                FluxRowParallelLinear)
            patch_manager.register_patch(
                "megatron.core.models.gpt.gpt_layer_specs.get_gpt_layer_with_transformer_engine_spec",
                get_gpt_layer_with_flux_spec)
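
For orientation, a minimal driver sketch (hypothetical, not part of this commit) showing how a feature object like this plugs into argument parsing and validation:

# Hypothetical usage sketch; assumes only the ParallelLinearFeature API shown above.
from argparse import ArgumentParser

feature = ParallelLinearFeature()

parser = ArgumentParser()
feature.register_args(parser)
args = parser.parse_args(['--parallel-linear-impl', 'flux'])
args.transformer_impl = 'transformer_engine'  # normally supplied by Megatron's own arg parser

feature.validate_args(args)  # raises AssertionError for unsupported combinations
# feature.register_patches(patch_manager, args) would then swap the TE linear
# classes for their Flux counterparts via the patch manager.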

dcu_megatron/adaptor/megatron_adaptor.py (new file, mode 100644)
(Diff collapsed: 504 added lines not shown.)

dcu_megatron/adaptor/patch_utils.py (new file, mode 100644)
import importlib
import sys
import types


def get_func_name(func):
    if isinstance(func, str):
        return func
    return '.'.join((func.__module__, func.__qualname__))


def dummy_function_wrapper(func_name):
    def dummy_function(*args, **kwargs):
        raise RuntimeError('function {} does not exist'.format(func_name))
    return dummy_function


class Patch:
    def __init__(self, orig_func_or_cls_name, new_func_or_cls, create_dummy,
                 apply_wrapper=False, remove_origin_wrappers=False):
        split_name = orig_func_or_cls_name.rsplit('.', 1)
        if len(split_name) == 1:
            self.orig_module_name, self.orig_func_or_cls_name = orig_func_or_cls_name, None
        else:
            self.orig_module_name, self.orig_func_or_cls_name = split_name
        self.orig_module = None
        self.orig_func_or_cls = None

        self.patch_func_or_cls = None
        self.wrappers = []
        self.remove_origin_wrappers = False
        if new_func_or_cls is None and not remove_origin_wrappers:
            new_func_or_cls = dummy_function_wrapper(orig_func_or_cls_name)
        self.set_patch_func(new_func_or_cls, apply_wrapper=apply_wrapper,
                            remove_origin_wrappers=remove_origin_wrappers)
        self.is_applied = False
        self.create_dummy = create_dummy

    @property
    def orig_func_or_cls_id(self):
        return id(self.orig_func_or_cls)

    @property
    def patch_func_id(self):
        return id(self.patch_func_or_cls)

    @staticmethod
    def remove_wrappers(module, func_name, func):
        while True:
            if (
                module.__dict__
                and func_name in module.__dict__
                and isinstance(module.__dict__[func_name], (staticmethod, classmethod))
            ):
                func = module.__dict__[func_name].__func__
            if hasattr(func, '__wrapped__') and func.__wrapped__ is not None:
                func = func.__wrapped__
            elif hasattr(func, '__closure__') and func.__closure__ is not None:
                func = func.__closure__[0].cell_contents
            else:
                break
        return func

    def set_patch_func(self, new_func_or_cls=None, force_patch=False,
                       apply_wrapper=False, remove_origin_wrappers=False):
        if remove_origin_wrappers:
            self.remove_origin_wrappers = True
        else:
            assert new_func_or_cls is not None
        if new_func_or_cls is None:
            return
        if (
            apply_wrapper
            or (hasattr(new_func_or_cls, '__name__')
                and new_func_or_cls.__name__.endswith(('wrapper', 'decorator')))
        ):
            for wrapper in self.wrappers:
                if id(wrapper) == id(new_func_or_cls):
                    raise RuntimeError(
                        f"wrapper {getattr(new_func_or_cls, '__name__')} has already been applied")
            self.wrappers.append(new_func_or_cls)
        else:
            if (
                self.patch_func_or_cls
                and not force_patch
                and id(new_func_or_cls) != id(self.patch_func_or_cls)
            ):
                raise RuntimeError(
                    'the patch of {} already exists!'.format(self.orig_func_or_cls_name))
            self.patch_func_or_cls = new_func_or_cls
        self.is_applied = False

    def apply_patch(self):
        if self.is_applied:
            return

        self.orig_module, self.orig_func_or_cls = Patch.parse_path(
            self.orig_module_name, self.orig_func_or_cls_name, self.create_dummy)

        final_patch_func_or_cls = self.orig_func_or_cls
        if self.patch_func_or_cls is not None:
            final_patch_func_or_cls = self.patch_func_or_cls

        # remove original wrappers
        if self.remove_origin_wrappers:
            final_patch_func_or_cls = self.remove_wrappers(
                self.orig_module, self.orig_func_or_cls_name, final_patch_func_or_cls)

        # add new wrappers
        for wrapper in self.wrappers:
            final_patch_func_or_cls = wrapper(final_patch_func_or_cls)

        if self.orig_func_or_cls_name is not None:
            setattr(self.orig_module, self.orig_func_or_cls_name, final_patch_func_or_cls)
        for key, value in sys.modules.copy().items():
            if self.orig_func_or_cls_name is not None \
                    and hasattr(value, self.orig_func_or_cls_name) \
                    and id(getattr(value, self.orig_func_or_cls_name)) == self.orig_func_or_cls_id:
                setattr(value, self.orig_func_or_cls_name, final_patch_func_or_cls)
        self.is_applied = True

    @staticmethod
    def parse_path(module_path, function_name, create_dummy):
        from importlib.machinery import ModuleSpec
        modules = module_path.split('.')
        for i in range(1, len(modules) + 1):
            parent = '.'.join(modules[:i - 1])
            path = '.'.join(modules[:i])
            try:
                importlib.import_module(path)
            except ModuleNotFoundError as e:
                if not parent or not hasattr(importlib.import_module(parent), modules[i - 1]):
                    if not create_dummy:
                        raise ModuleNotFoundError(e) from e
                    sys.modules[path] = types.ModuleType(path)
                    sys.modules[path].__file__ = 'dcu_megatron.dummy_module.py'
                    sys.modules[path].__spec__ = ModuleSpec(path, None)
                    if parent:
                        setattr(importlib.import_module(parent), modules[i - 1], sys.modules[path])
                else:
                    module = getattr(importlib.import_module(parent), modules[i - 1])
                    if hasattr(module, function_name):
                        return module, getattr(module, function_name)
                    elif create_dummy:
                        return module, dummy_function_wrapper(function_name)
                    else:
                        raise RuntimeError('{} does not exist in {}'.format(function_name, module))

        if function_name is not None and not hasattr(sys.modules[module_path], function_name):
            setattr(sys.modules[module_path], function_name, None)
        return sys.modules[module_path], getattr(
            sys.modules[module_path], function_name) if function_name is not None else None


class MegatronPatchesManager:
    patches_info = {}

    @staticmethod
    def register_patch(orig_func_or_cls_name, new_func_or_cls=None, force_patch=False,
                       create_dummy=False, apply_wrapper=False, remove_origin_wrappers=False):
        if orig_func_or_cls_name not in MegatronPatchesManager.patches_info:
            MegatronPatchesManager.patches_info[orig_func_or_cls_name] = Patch(
                orig_func_or_cls_name, new_func_or_cls, create_dummy,
                apply_wrapper=apply_wrapper, remove_origin_wrappers=remove_origin_wrappers)
        else:
            MegatronPatchesManager.patches_info.get(orig_func_or_cls_name).set_patch_func(
                new_func_or_cls, force_patch, apply_wrapper=apply_wrapper,
                remove_origin_wrappers=remove_origin_wrappers)

    @staticmethod
    def register_cls_funcs(orig_class, new_funcs: list = None, create_dummy=False):
        if not orig_class.endswith("."):
            orig_class += "."
        for new_func in new_funcs:
            assert hasattr(new_func, '__name__') and not new_func.__name__.endswith(
                ('wrapper', 'decorator'))
            orig_func_name = orig_class + new_func.__name__
            MegatronPatchesManager.register_patch(
                orig_func_name, new_func_or_cls=new_func, create_dummy=create_dummy)

    @staticmethod
    def apply_patches():
        for patch in MegatronPatchesManager.patches_info.values():
            patch.apply_patch()
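
As a concrete illustration of the manager's flow (a sketch; `fast_gelu` is a made-up replacement, not part of this commit): register a patch against a dotted path, then apply all patches at once so the replacement is rebound both on the target module and in every already-imported module that holds the original symbol.

# Hypothetical usage sketch of MegatronPatchesManager; `fast_gelu` is illustrative.
import torch

def fast_gelu(x):
    # Stand-in replacement implementation.
    return x * torch.sigmoid(1.702 * x)

MegatronPatchesManager.register_patch('torch.nn.functional.gelu', fast_gelu)
MegatronPatchesManager.apply_patches()  # rebinds the attribute in torch.nn.functional
                                        # and in any imported module referencing it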

dcu_megatron/core/__pycache__/parallel_state.cpython-310.pyc (new file, mode 100644)
File added

dcu_megatron/core/__pycache__/quantization_utils.cpython-310.pyc (new file, mode 100644)
File added

dcu_megatron/core/datasets/VL_helpers.py (new file, mode 100644)
(Diff collapsed: 569 added lines not shown.)

dcu_megatron/core/datasets/image_processing.py (new file, mode 100644)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. Except portions as noted
# which are Copyright (c) 2023 OpenGVLab and licensed under the MIT license found in LICENSE.

import numpy as np
import torch
import math
import random

from PIL import Image, ImageDraw
from torchvision import transforms as T
from torchvision.transforms import Compose, RandAugment, RandomResizedCrop, Resize, ToPILImage

# Imagenet's mean and std.
pixel_mean = [123.675, 116.28, 103.53]
pixel_std = [58.395, 57.12, 57.375]

# Reshape for broadcasting.
pixel_mean = torch.Tensor(pixel_mean).view(-1, 1, 1)
pixel_std = torch.Tensor(pixel_std).view(-1, 1, 1)


def convert_to_rgb(image):
    return image.convert("RGB")


def _transform_train_aug():
    return Compose([
        ToPILImage(),
        Resize(scale=random.random() / 2 + 0.5),
        convert_to_rgb,
        RandAugment(2, 5, isPIL=True, augs=[
            'Identity', 'AutoContrast', 'Brightness', 'Sharpness', 'Equalize',
            'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
    ])


def _transform_test():
    return Compose([
        ToPILImage(),
        convert_to_rgb,
    ])


def standardize_image(img):
    """Standardize image pixel values."""
    return (torch.Tensor(np.array(img)).permute(2, 0, 1) - pixel_mean) / pixel_std


def get_visual_transform(img, factor: int = 28, min_pixels: int = 56 * 56,
                         max_pixels: int = 14 * 14 * 4 * 1280, augment=False):
    img = np.array(img)
    if augment:
        visual_transform = _transform_train_aug()
    else:
        visual_transform = _transform_test()
    img = visual_transform(img)

    w, h = img.size
    h_bar, w_bar = smart_resize(h, w, factor, min_pixels, max_pixels)
    img = img.resize((w_bar, h_bar))

    # Standardize pixel values.
    img = standardize_image(img)
    imgs = [img]
    return imgs


# copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py
def smart_resize(height: int, width: int, factor: int = 28, min_pixels: int = 56 * 56,
                 max_pixels: int = 14 * 14 * 4 * 1280):
    """Rescales the image so that the following conditions are met:

    1. Both dimensions (height and width) are divisible by 'factor'.
    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
    3. The aspect ratio of the image is maintained as closely as possible.
    """
    if height < factor or width < factor:
        raise ValueError(
            f"height: {height} or width: {width} must be larger than factor: {factor}")
    elif max(height, width) / min(height, width) > 200:
        raise ValueError(
            f"absolute aspect ratio must be smaller than 200, "
            f"got {max(height, width) / min(height, width)}")
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = math.floor(height / beta / factor) * factor
        w_bar = math.floor(width / beta / factor) * factor
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar
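
A quick worked example of the resizing rules (illustrative values): a 1000×700 input with the default factor of 28 rounds each side to the nearest multiple of 28, and the result already falls inside the default pixel budget of 14·14·4·1280 = 1,003,520, so no beta rescaling is applied.

# Worked example for smart_resize with illustrative values.
h_bar, w_bar = smart_resize(1000, 700)  # round(1000/28)*28 = 1008, round(700/28)*28 = 700
assert (h_bar, w_bar) == (1008, 700)
assert h_bar % 28 == 0 and w_bar % 28 == 0
assert 56 * 56 <= h_bar * w_bar <= 14 * 14 * 4 * 1280  # 705600 pixels, within budget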

dcu_megatron/core/dist_checkpoint/exchange_utils.py (new file, mode 100644)
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.

"""Utilities for exchanging data between ranks."""

import logging
from collections import defaultdict
from typing import Optional, Set

import torch

from megatron.core.utils import get_pg_size
from megatron.core.dist_checkpointing.dict_utils import nested_values
from megatron.core.dist_checkpointing.mapping import (
    ShardedStateDict,
    ShardedTensor,
    is_main_replica,
    ReplicaId,
)
from megatron.core.dist_checkpointing.utils import _sharded_tensor_shard_id, _ShardId
from megatron.core.dist_checkpointing.exchange_utils import (
    ShardDistribution,
    _shard_size,
    distribute_shards_to_ranks,
)

logger = logging.getLogger(__name__)


def is_main_replica_norm(replica_id: ReplicaId):
    if isinstance(replica_id, int):
        return replica_id == 0
    return len(replica_id) > 0 and replica_id[-1] == 0


def determine_main_replica_uniform_distribution(
    sharded_state_dict: ShardedStateDict,
    parallelization_group: torch.distributed.ProcessGroup,
    ignore_groups: bool = False,
) -> Optional[ShardDistribution]:
    """Computes the save distribution.

    Should be used in conjunction with `distribute_main_replicas_with_precomputed_distribution`
    which applies the computed save distribution.

    We rely on the fact that the assignment algorithm is deterministic on all ranks,
    so there is no extra communication needed after the metadata exchange.

    Args:
        sharded_state_dict (ShardedStateDict): state dict to compute the distribution of
        parallelization_group (ProcessGroup): distribution will be computed
            within this process group
        ignore_groups (bool, optional): whether the distribution defines groups.
            This option is primarily used during loading, as it ensures that all replicas,
            including non-main ones, are loaded by this parallelization group.
            Defaults to False.

    Returns (ShardDistribution, optional): distribution that can be used to apply the
        parallelization. Returns None if the process_group is trivial (1 rank)
    """
    if parallelization_group is None:
        parallelization_group = torch.distributed.group.WORLD
    group_size = get_pg_size(group=parallelization_group)
    if group_size <= 1:
        return
    local_shards = list(
        sh_base
        for sh_base in nested_values(sharded_state_dict)
        if isinstance(sh_base, ShardedTensor)
    )
    local_shards_no_data = [ten.without_data() for ten in local_shards]
    all_shards = [None] * get_pg_size(group=parallelization_group)
    torch.distributed.all_gather_object(
        all_shards, local_shards_no_data, group=parallelization_group
    )

    shard_to_ranks = defaultdict(list)
    shard_to_size = {}
    shard_to_metadata = {}
    group_has_main_replica: Set[_ShardId] = set()
    group_has_non_main_replica: Set[_ShardId] = set()
    for rank, rank_shards in enumerate(all_shards):
        for sh_ten in rank_shards:
            shard_id = _sharded_tensor_shard_id(sh_ten)
            shard_to_ranks[shard_id].append(rank)
            if shard_id not in shard_to_size:
                shard_to_size[shard_id] = _shard_size(sh_ten)
                shard_to_metadata[shard_id] = sh_ten
            if 'norm' in shard_id[0]:
                if is_main_replica_norm(sh_ten.replica_id):
                    group_has_main_replica.add(shard_id)
                else:
                    group_has_non_main_replica.add(shard_id)
            else:
                if is_main_replica(sh_ten.replica_id):
                    group_has_main_replica.add(shard_id)
                else:
                    group_has_non_main_replica.add(shard_id)

    # we always include all main replicas, and non-main only if `ignore_groups`
    shards_in_this_group: Set[_ShardId] = group_has_main_replica
    if ignore_groups:
        shards_in_this_group = shards_in_this_group | group_has_non_main_replica

    # cross-parallel-group references are empty if `not ignore_groups`,
    # otherwise it's `group_has_non_main_replica - group_has_main_replica`
    cross_parallelization_group_loads = shards_in_this_group - group_has_main_replica

    # Filter out shards that don't belong to this group
    shard_to_ranks = {k: v for k, v in shard_to_ranks.items() if k in shards_in_this_group}

    shard_to_saving_rank = distribute_shards_to_ranks(
        shard_to_ranks, shard_to_size, len(all_shards), cross_parallelization_group_loads
    )

    return ShardDistribution(
        shard_to_saving_rank, shards_in_this_group, shard_to_metadata, shard_to_ranks
    )
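
The 'norm' special-casing above hinges on `is_main_replica_norm` only inspecting the last element of a tuple `replica_id`, whereas Megatron's stock `is_main_replica` requires every element to be zero. A few illustrative checks:

# Illustrative checks for is_main_replica_norm (tuple replica_id convention assumed).
assert is_main_replica_norm(0) is True          # int form: main replica iff 0
assert is_main_replica_norm(2) is False
assert is_main_replica_norm((1, 0, 0)) is True  # only the last element matters here
assert is_main_replica_norm((0, 0, 1)) is False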

dcu_megatron/core/dist_checkpoint/strategies/cached_metadata_filesystem_reader.py (new file, mode 100644)
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""FS Reader with cached metadata support."""

import os
from typing import Union

from torch.distributed.checkpoint import Metadata
from hyckpt_torch import FileSystemReader


class CachedMetadataFileSystemReader(FileSystemReader):
    """
    Extends FileSystemReader to cache metadata for improved performance.

    Attributes:
        _cached_metadata (Metadata or None): Cached metadata from the file system.
    """

    def __init__(self, path: Union[str, os.PathLike]) -> None:
        """
        Initialize with a file system path.

        Args:
            path (Union[str, os.PathLike]): Path to the checkpoint directory or file.
        """
        super().__init__(path=path)
        self._cached_metadata = None
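
The diff shows only the constructor; a caching `read_metadata` override, which the `_cached_metadata` attribute exists to serve, would plausibly look like the following sketch (an assumption, not shown in the commit):

# Hypothetical sketch (not part of the shown diff) of the caching override.
class _SketchCachedReader(FileSystemReader):
    def __init__(self, path):
        super().__init__(path=path)
        self._cached_metadata = None

    def read_metadata(self) -> Metadata:
        if self._cached_metadata is None:
            self._cached_metadata = super().read_metadata()  # touch storage only once
        return self._cached_metadata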

dcu_megatron/core/dist_checkpoint/strategies/filesystem_async.py (new file, mode 100644)
""" Storage writer for PyT Distributed format allowing asynchronous save. """
import
logging
from
pathlib
import
Path
from
typing
import
List
,
Tuple
import
torch
from
torch
import
multiprocessing
as
mp
from
hyckpt_torch
import
_write_items
from
megatron.core.dist_checkpointing.strategies.async_utils
import
_disable_gc
from
megatron.core.dist_checkpointing.strategies.filesystem_async
import
_process_memory
WriteBucket
=
Tuple
[
Path
,
str
,
Tuple
[
list
,
list
]]
# represents writes to a single file
@
staticmethod
@
_disable_gc
()
def
write_preloaded_data
(
transform_list
,
local_proc_idx
:
int
,
write_bucket
:
WriteBucket
,
results_queue
:
mp
.
SimpleQueue
,
count_queue
:
mp
.
JoinableQueue
,
use_fsync
:
bool
,
)
->
None
:
"""
Performs actual data saving to storage.
Args:
local_proc_idx (int): index of a local process that performs writing
write_bucket (WriteBucket): data to write to storage
results_queue (mp.Queue): queue to return the write results
to the proxy checkpoint process.
count_queue (mp.JoinableQueue): queue to marks worker task as completed
use_fsync (bool): if True, calls os.fsync at the end of saving
Returns: None, the write result are put into the `queue`
"""
logger
=
logging
.
getLogger
(
__name__
)
logger
.
debug
(
f
'
{
local_proc_idx
}
started'
)
mem_before
=
_process_memory
()
rank
=
torch
.
distributed
.
get_rank
()
local_results
=
[]
try
:
local_results
=
_write_items
(
write_bucket
)
'''
for result in local_results:
if hasattr(result.index, 'index'):
from dataclasses import replace
new_index = replace(result.index, index=rank)
new_result = replace(result, index=new_index)
'''
local_output
=
(
local_proc_idx
,
local_results
)
except
Exception
as
e
:
logger
.
debug
(
f
'
{
local_proc_idx
}
failed'
)
local_output
=
(
local_proc_idx
,
e
)
results_queue
.
put
(
local_output
)
# Signal this process is done.
count_queue
.
get
()
count_queue
.
task_done
()
mem_after
=
_process_memory
()
logger
.
debug
(
f
"
{
local_proc_idx
}
consumed:
{
mem_after
-
mem_before
}
,"
f
" before:
{
mem_before
}
, after:
{
mem_after
}
"
)
@
staticmethod
def
preload_tensors
(
write_buckets
:
List
[
WriteBucket
],
non_blocking
=
True
)
->
List
[
WriteBucket
]:
"""Preload tensors in state_dict to host memory through CPU memory
Args:
write_buckets(List): List of `WriteBucket`,
which includes what to be saved in a checkpoint
non_blocking (bool, optional): knob to enable pinned D2H memcpy. Default is True.
"""
result
=
[]
for
bucket
in
write_buckets
:
file_name
,
storage_key
,
(
bytes_data
,
tensor_data
)
=
bucket
tensor_data
=
[
(
item
,
tensor
.
to
(
"cpu"
,
non_blocking
=
False
))
for
item
,
tensor
in
tensor_data
]
result
.
append
((
file_name
,
storage_key
,
(
bytes_data
,
tensor_data
)))
if
non_blocking
:
torch
.
cuda
.
synchronize
()
return
result

dcu_megatron/core/dist_checkpoint/strategies/fully_parallel.py (new file, mode 100644)
import logging
from typing import Optional

from megatron.core.dist_checkpointing.exchange_utils import (
    ShardDistribution,
    determine_main_replica_uniform_distribution,
)
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.dist_checkpointing.strategies.fully_parallel import (
    distribute_main_replicas_with_precomputed_distribution,
)

logger = logging.getLogger(__name__)


class FullyParallelLoadStrategyWrapper():
    def apply_loading_parallelization(
        self, sharded_state_dict: ShardedStateDict
    ) -> Optional[ShardDistribution]:
        """Distributes the load across ranks by exchanging metadata.

        Exchanges metadata from the state dict and computes the uniform
        (as close as possible) distribution of loads among the ranks.
        Marks ShardedTensors to be loaded by the current rank with replica_id 0
        (and others with non-0 values).

        If `self.do_cache_distribution` is True, caches the distribution between
        the calls and subsequent distributions happen without any inter-rank
        communication.

        Args:
            sharded_state_dict (ShardedStateDict): state dict to distribute the loading

        Returns:
            ShardDistribution (optional): the computed loading distribution
        """
        if self.do_cache_distribution and self.cached_distribution is not None:
            logger.debug('Apply *cached* load parallelization')
            precomputed_distribution = self.cached_distribution
        else:
            logger.debug('Apply load parallelization')
            precomputed_distribution = determine_main_replica_uniform_distribution(
                sharded_state_dict, self.parallelization_group
            )

        distribute_main_replicas_with_precomputed_distribution(
            sharded_state_dict, self.parallelization_group, precomputed_distribution
        )
        if self.do_cache_distribution:
            self.cached_distribution = precomputed_distribution

        return precomputed_distribution

dcu_megatron/core/dist_checkpoint/strategies/torch.py (new file, mode 100644)
import io
import os
import pickle
import warnings
from collections import ChainMap, defaultdict
from dataclasses import dataclass
from itertools import product
from logging import getLogger
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, cast

import torch
from packaging.version import Version as PkgVersion
from torch.distributed import checkpoint
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharded_tensor import Shard
from torch.distributed._shard.sharded_tensor import ShardedTensor as TorchShardedTensor
from torch.distributed._shard.sharded_tensor import ShardedTensorMetadata, TensorProperties
from torch.distributed.checkpoint import (
    BytesStorageMetadata,
    DefaultLoadPlanner,
    DefaultSavePlanner,
    # FileSystemReader,
    FileSystemWriter,
    LoadPlan,
    Metadata,
    ReadItem,
    SavePlan,
    TensorStorageMetadata,
    WriteItem,
)
from hyckpt_torch import FileSystemReader
from torch.distributed.checkpoint._nested_dict import FLATTEN_MAPPING, unflatten_state_dict
from torch.distributed.checkpoint._traverse import OBJ_PATH, traverse_state_dict
from torch.distributed.checkpoint.metadata import Metadata
from torch.distributed.checkpoint.planner_helpers import _create_write_items

from megatron.core.utils import get_torch_version, is_torch_min_version
from megatron.core.dist_checkpointing.core import CheckpointingException
from megatron.core.dist_checkpointing.dict_utils import nested_values
from megatron.core.dist_checkpointing.mapping import (
    ShardedBase,
    ShardedObject,
    ShardedStateDict,
    ShardedTensor,
    StateDict,
    is_main_replica,
)
from megatron.core.dist_checkpointing.strategies.async_utils import AsyncRequest
from megatron.core.dist_checkpointing.strategies.base import (
    AsyncSaveShardedStrategy,
    LoadShardedStrategy,
    StrategyAction,
    register_default_strategy,
)
from megatron.core.dist_checkpointing.strategies.cached_metadata_filesystem_reader import (
    CachedMetadataFileSystemReader,
)
from megatron.core.dist_checkpointing.strategies.filesystem_async import FileSystemWriterAsync
from megatron.core.dist_checkpointing.strategies.resharding import (
    TensorReformulationMetadata,
    apply_nd_flattened_tensors_reformulation,
    is_nd_flattened_tensor,
    nd_flattened_tensor_reformulated_global_shape,
    restore_nd_flattened_tensors_formulation,
)
from megatron.core.dist_checkpointing.strategies.state_dict_saver import (
    save_state_dict_async_finalize,
    save_state_dict_async_plan,
)
from megatron.core.dist_checkpointing.strategies.torch import (
    _replace_state_dict_keys_with_sharded_keys,
    mcore_to_pyt_state_dict,
    MCoreLoadPlanner,
    _replace_sharded_keys_with_state_dict_keys,
    _restore_dict_types,
    _unwrap_pyt_sharded_tensor,
)


def get_reformulation_metadata(
    sharded_state_dict: ShardedStateDict, checkpoint_dir: Path
) -> Dict[str, TensorReformulationMetadata]:
    """Reads MCore data for N-D flattened tensors from checkpoint metadata during ckpt load.

    Args:
        sharded_state_dict (ShardedStateDict): sharded state dict to load
        checkpoint_dir (Path): checkpoint directory

    Returns:
        Dict[str, TensorReformulationMetadata] - dictionary that maps keys of every
            N-D flattened tensor from the sharded_state_dict to its original global shape
            as stored in `mcore_data` in the checkpoint.
    """
    ckpt_metadata = FileSystemReader(checkpoint_dir).read_metadata()
    reformulation_metadata = {}
    for sh_ten in nested_values(sharded_state_dict):
        if not is_nd_flattened_tensor(sh_ten):
            continue
        try:
            ckpt_global_shape = ckpt_metadata.mcore_data[sh_ten.key][
                'nd_reformulated_orig_global_shape'
            ]
        except KeyError as e:
            if len(sh_ten.global_shape) == 1:
                warnings.warn(
                    f'Legacy checkpoint format detected for 1-D flattened tensor {sh_ten}. '
                    'Skip metadata reformulation.'
                )
                continue
            raise CheckpointingException(
                f'Cannot find global shape metadata for N-D flattened tensor {sh_ten} '
                f'in checkpoint metadata: {ckpt_metadata.mcore_data}'
            ) from e

        reformulation_metadata[sh_ten.key] = TensorReformulationMetadata(
            ckpt_global_shape, ckpt_metadata.state_dict_metadata[sh_ten.key].size
        )
    return reformulation_metadata


class TorchDistLoadShardedStrategy(LoadShardedStrategy):
    def __init__(self):
        self.cached_global_metadata: Optional[Metadata] = None
        super().__init__()

    def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> StateDict:
        """Translates MCore ShardedTensors to PyT ShardedTensors & loads from PyT Distributed fmt.

        Args:
            sharded_state_dict (ShardedStateDict): sharded state dict with mapping
                information to instruct loading
            checkpoint_dir (Path): checkpoint directory

        Returns: loaded state dict
        """
        # Apply N-D tensors resharding
        reformulation_metadata = get_reformulation_metadata(sharded_state_dict, checkpoint_dir)
        sharded_state_dict, formulation_restore_data = apply_nd_flattened_tensors_reformulation(
            sharded_state_dict, reformulation_metadata
        )

        # Check if there are legacy 1-D flattened tensors in the checkpoint
        has_legacy_1d_flattened_tensors = False
        for sh_ten in nested_values(sharded_state_dict):
            if is_nd_flattened_tensor(sh_ten) and sh_ten.key not in reformulation_metadata:
                has_legacy_1d_flattened_tensors = True
                break

        flexible_shape_sharded_tensors = [
            sh_ten
            for sh_ten in nested_values(sharded_state_dict)
            if isinstance(sh_ten, ShardedTensor) and not sh_ten.allow_shape_mismatch
        ]
        allow_shape_mismatch_sharded_tensors = {
            sh_ten.key: sh_ten
            for sh_ten in nested_values(sharded_state_dict)
            if isinstance(sh_ten, ShardedTensor) and sh_ten.allow_shape_mismatch
        }

        orig_sharded_state_dict = sharded_state_dict
        # MCore state dict to PyT Distributed compatible
        (sharded_state_dict, flat_mapping, rename_mapping) = (
            _replace_state_dict_keys_with_sharded_keys(sharded_state_dict)
        )
        pyt_state_dict = mcore_to_pyt_state_dict(
            sharded_state_dict, True,
            load_legacy_1d_flatten_tensors=has_legacy_1d_flattened_tensors,
        )
        # Load PyT Distributed format
        fsr = CachedMetadataFileSystemReader(checkpoint_dir)
        checkpoint.load_state_dict(
            pyt_state_dict,
            fsr,
            planner=MCoreLoadPlanner(
                shapes_validation_sharded_tensors=flexible_shape_sharded_tensors,
                allow_shape_mismatch_sharded_tensors=allow_shape_mismatch_sharded_tensors,
            ),
        )

        # no storage interaction thanks to caching
        self.cached_global_metadata = fsr.read_metadata()

        pyt_state_dict = cast(
            Dict[str, Union[TorchShardedTensor, List[io.BytesIO]]], pyt_state_dict
        )
        # Unwrap ShardedTensors and return to original state dict
        mcore_state_dict = {
            k: v if not isinstance(v, TorchShardedTensor) else _unwrap_pyt_sharded_tensor(v)
            for k, v in pyt_state_dict.items()
        }
        mcore_state_dict = _replace_sharded_keys_with_state_dict_keys(
            mcore_state_dict, flat_mapping, rename_mapping
        )
        _restore_dict_types(mcore_state_dict, orig_sharded_state_dict)
        # Apply N-D tensors resharding postprocessing
        mcore_state_dict = restore_nd_flattened_tensors_formulation(
            mcore_state_dict, formulation_restore_data
        )
        return mcore_state_dict

    def load_tensors_metadata(self, checkpoint_dir: Path, metadata: Metadata = None):
        """Uses tensors metadata stored in the metadata file."""
        if metadata is None:
            fs_reader = FileSystemReader(checkpoint_dir)
            metadata = fs_reader.read_metadata()

        mcore_data = getattr(metadata, 'mcore_data', {})
        sharded_metadata = {}
        for k, tp in metadata.state_dict_metadata.items():
            if not isinstance(tp, TensorStorageMetadata):
                continue  # load only tensors

            nd_orig_global_shape = mcore_data.get(k, {}).get('nd_reformulated_orig_global_shape')
            if nd_orig_global_shape is None:
                # Regular tensor
                sharded_metadata[k] = ShardedTensor.from_rank_offsets(
                    k, torch.empty(tp.size, **tp.properties.__dict__, device='meta')
                ).without_data()
            else:
                # N-D flattened tensor
                unflat_ten = torch.empty(
                    nd_orig_global_shape, **tp.properties.__dict__, device='meta'
                )
                flat_ten = unflat_ten.flatten()
                sharded_metadata[k] = ShardedTensor.from_rank_offsets_flat(
                    k,
                    flat_ten,
                    unflat_ten.shape,
                    flattened_range=slice(0, unflat_ten.numel()),  # whole slice
                ).without_data()
        return sharded_metadata

    def load_sharded_metadata(self, checkpoint_dir: Path) -> ShardedStateDict:
        """Uses tensors and objects metadata stored in the metadata file."""
        fs_reader = FileSystemReader(checkpoint_dir)
        metadata = fs_reader.read_metadata()

        sharded_metadata = {}
        for metadata_key, storage_metadata in metadata.state_dict_metadata.items():
            if not isinstance(storage_metadata, BytesStorageMetadata):
                continue
            sh_obj = ShardedObject.empty_from_unique_key(metadata_key)
            sharded_metadata[sh_obj.unique_key] = sh_obj

        sharded_metadata.update(self.load_tensors_metadata(checkpoint_dir, metadata))
        return sharded_metadata

    def remove_sharded_tensors(self, checkpoint_dir: str, key_prefix: str):
        """Removes checkpoint files whose keys have the given prefix.

        Performs the following steps:
        1. checks whether there are files that start with the key_prefix
        2. loads metadata
        3. removes all entries from the metadata that start with the key_prefix
        4. resaves the new metadata and removes the old metadata
        5. removes the relevant files
        """
        assert is_torch_min_version("2.3.0"), \
            'torch >= 2.3.0 is required for remove_sharded_tensors'

        distckpt_files = [f for f in os.listdir(checkpoint_dir) if f.endswith("distcp")]
        files_to_remove = [f for f in distckpt_files if f.startswith(key_prefix)]

        if not files_to_remove:
            warnings.warn(
                f'There are no files in {checkpoint_dir} that begin with "{key_prefix}".'
                f' Skipping removal.'
            )
            return

        fs_reader = FileSystemReader(checkpoint_dir)
        original_metadata = fs_reader.read_metadata()

        new_state_dict_metadata = {}
        new_planner_data = {}
        new_storage_data = {}
        for k in original_metadata.state_dict_metadata.keys():
            if k.startswith(key_prefix):
                continue
            new_state_dict_metadata[k] = original_metadata.state_dict_metadata[k]
        for k in original_metadata.planner_data.keys():
            if k.startswith(key_prefix):
                continue
            new_planner_data[k] = original_metadata.planner_data[k]
        for k in original_metadata.storage_data.keys():
            if k.fqn.startswith(key_prefix):
                continue
            new_storage_data[k] = original_metadata.storage_data[k]
        metadata = Metadata(
            state_dict_metadata=new_state_dict_metadata,
            planner_data=new_planner_data,
            storage_data=new_storage_data,
        )
        fs_writer = FileSystemWriter(checkpoint_dir)
        metadata_filename = cast(Path, fs_writer.fs.concat_path(fs_writer.path, _metadata_fn))
        tmp_path = cast(
            metadata_filename, fs_writer.fs.concat_path(fs_writer.path, f"{_metadata_fn}.tmp")
        )
        old_path = cast(
            metadata_filename, fs_writer.fs.concat_path(fs_writer.path, f"{_metadata_fn}.bck")
        )

        ## save the new metadata
        with fs_writer.fs.create_stream(tmp_path, "wb") as metadata_file:
            pickle.dump(metadata, metadata_file)
            try:
                os.fsync(metadata_file.fileno())
            except AttributeError:
                os.sync()
        ## move the old metadata
        fs_writer.fs.rename(fs_writer.metadata_path, old_path)
        try:
            ## rename the new metadata
            fs_writer.fs.rename(tmp_path, fs_writer.metadata_path)
            ## finally, remove the files we want to drop
            for f in files_to_remove:
                fs_writer.fs.rm_file(checkpoint_dir / f)
        except Exception as e:
            fs_writer.fs.rename(old_path, fs_writer.metadata_path)
            raise e
        else:
            fs_writer.fs.rm_file(old_path)

    def can_handle_sharded_objects(self):
        return True

    def check_backend_compatibility(self, loaded_version):
        pass  # TODO

    def check_version_compatibility(self, loaded_version):
        pass  # TODO
\ No newline at end of file
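
A brief usage sketch of the strategy (the checkpoint path is illustrative; `sharded_state_dict` would normally come from the model):

# Hypothetical usage sketch; the checkpoint directory name is illustrative.
from pathlib import Path

strategy = TorchDistLoadShardedStrategy()
ckpt_dir = Path('/checkpoints/iter_0001000')

# Inspect what the checkpoint contains without loading tensor data:
sharded_metadata = strategy.load_sharded_metadata(ckpt_dir)

# Full load: sharded_state_dict would normally be model.sharded_state_dict().
# state_dict = strategy.load(sharded_state_dict, ckpt_dir)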

dcu_megatron/core/dist_checkpoint/validation.py (new file, mode 100644)
import logging

import numpy as np
import torch

from megatron.core.dist_checkpointing import ShardedTensor
from megatron.core.dist_checkpointing.core import CheckpointingException
from megatron.core.dist_checkpointing.mapping import is_main_replica

logger = logging.getLogger(__name__)


def _compute_shards_access(rank_sharding):
    shard_access_cnt = torch.zeros(
        rank_sharding[0][1].axis_fragmentations, dtype=torch.int, device="cpu"
    )
    for rank, sharding in rank_sharding:
        if is_main_replica(sharding.replica_id):
            if 'norm' in sharding.key:
                shard_access_cnt[sharding.local_chunk_offset_in_global()] = 1
            else:
                shard_access_cnt[sharding.local_chunk_offset_in_global()] += 1
    return shard_access_cnt


def _validate_sharding_for_key_flattened(tensors_by_shard):
    all_slices = []
    local_shape = tensors_by_shard[0].local_shape
    for sharding in tensors_by_shard:
        assert sharding.local_shape == local_shape
        sharding: ShardedTensor
        if not is_main_replica(sharding.replica_id):
            continue
        if all_slices and 'norm' in sharding.key:
            continue

        all_slices.append((sharding.flattened_range.start, sharding.flattened_range.stop))

    starts, stops = map(np.asarray, zip(*sorted(all_slices)))
    expected_size = np.product(local_shape)
    if (
        starts[0] != 0
        or stops[-1] != expected_size
        or not np.all(starts[1:] == stops[:-1])
    ):
        raise CheckpointingException(
            f"Flattened ranges don't cover the whole shard {tensors_by_shard[0]} "
            f"of size {expected_size}. Ranges: {(starts, stops)}"
        )
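
The contiguity condition at the end of `_validate_sharding_for_key_flattened` can be seen with a small example: the sorted (start, stop) pairs must begin at 0, end at the shard's element count, and chain without gaps.

# Illustrative coverage check over a shard of 4 * 5 = 20 elements.
import numpy as np

all_slices = [(8, 15), (0, 8), (15, 20)]          # arrives unordered
starts, stops = map(np.asarray, zip(*sorted(all_slices)))
assert starts[0] == 0 and stops[-1] == 20
assert np.all(starts[1:] == stops[:-1])           # each range picks up where the last ended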

dcu_megatron/core/distributed/__pycache__/data_parallel_base.cpython-310.pyc (new file, mode 100644)
File added

dcu_megatron/core/distributed/__pycache__/param_and_grad_buffer.cpython-310.pyc (new file, mode 100644)
File added

dcu_megatron/core/distributed/__pycache__/power_sgd.cpython-310.pyc (new file, mode 100644)
File added

dcu_megatron/core/distributed/data_parallel_base.py (new file, mode 100644)
class _BaseDataParallel():
    def backward_dw(self, *inputs, **kwargs):
        """
        Calls the wrapped module's backward_dw() method.
        """
        return self.module.backward_dw(*inputs, **kwargs)
\ No newline at end of file

dcu_megatron/core/distributed/distributed_data_parallel.py (new file, mode 100644)
import torch

from megatron.training import get_args
from megatron.core.transformer.cuda_graphs import is_graph_capturing


class DistributedDataParallel():
    def _make_backward_post_hook(self, param: torch.nn.Parameter):
        """
        Creates a backward post-hook to dispatch an all-reduce / reduce-scatter when
        ready (i.e., when all grads in a bucket have been computed in all microbatches
        in a batch).
        """

        def hook(*unused):
            if is_graph_capturing():
                return

            if param in self.param_to_bucket_group:
                assert param.requires_grad
                if self.ddp_config.overlap_grad_reduce:
                    # support dualpipev
                    if not get_args().gradient_accumulation_fusion \
                            or not get_args().delay_wgrad_compute:
                        assert (
                            param.grad is not None
                        ), 'param.grad being None is not safe when overlap_grad_reduce is True'
                if param.grad is not None and (
                    not param.grad_added_to_main_grad or getattr(param, 'zero_out_wgrad', False)
                ):
                    param.main_grad.add_(param.grad.data)
                param.grad = None

                if self.ddp_config.overlap_grad_reduce:
                    self.param_to_bucket_group[param].register_grad_ready(param)

        return hook
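
This fragment only defines the hook body; in Megatron-style DDP such a hook is typically attached to the parameter's gradient-accumulator autograd node. A sketch of that registration pattern follows (an assumption for illustration, since the registration code is not part of this diff):

# Sketch of the usual registration pattern (assumed; not shown in this diff).
import torch

param = torch.nn.Parameter(torch.randn(3))
param_tmp = param.expand_as(param)                  # creates a node in the autograd graph
grad_acc = param_tmp.grad_fn.next_functions[0][0]   # the AccumulateGrad node for `param`
grad_acc.register_hook(lambda *unused: print('grads ready for this param'))

(param * 2).sum().backward()                        # fires the hook after grad accumulation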

dcu_megatron/core/distributed/finalize_model_grads.py (new file, mode 100755)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from typing import List, Optional

import torch

try:
    from torch.distributed._tensor import DTensor, distribute_tensor

    HAVE_DTENSOR = True
except ImportError:
    HAVE_DTENSOR = False

from megatron.core import mpu
from megatron.core import parallel_state
from megatron.core.utils import get_model_config
from megatron.training.global_vars import get_args
from ...training.edgc_utils import Utils
from megatron.core.distributed.finalize_model_grads import (
    _allreduce_conditional_embedding_grads,
    _allreduce_non_tensor_model_parallel_grads,
    _allreduce_word_embedding_grads,
    _allreduce_position_embedding_grads,
    reset_model_temporary_tensors,
    _update_router_expert_bias,
)

# Note: ProcessGroupCollection and get_pp_last_rank are referenced below but are
# not imported anywhere in this diff.


def finalize_model_grads(
    model: List[torch.nn.Module],
    num_tokens: Optional[torch.Tensor] = None,
    pg_collection: Optional[ProcessGroupCollection] = None,
):
    """
    All-reduce all model grads across DP replicas, layernorm grads for sequence parallelism,
    embedding grads across first and last pipeline stages (if not tied),
    and scale gradients by `num_tokens`.
    """
    args = get_args()
    config = get_model_config(model[0])

    if pg_collection is not None:
        assert hasattr(pg_collection, 'tp')
        assert hasattr(pg_collection, 'pp')
        assert hasattr(pg_collection, 'embd'), (
            "pg_collection must have an embd. In previous versions, the default "
            "`parallel_state.default_embedding_ranks` was used to create the process group. "
            "If you are using the default process group, please use "
            "`parallel_state.get_embedding_group()`. "
            "If you don't need embd_group, you need to explicitly set it to None."
        )
        assert hasattr(pg_collection, 'pos_embd'), (
            "pg_collection must have a pos_embd. In previous versions, the default "
            "`parallel_state.default_position_embedding_ranks` was used to create the "
            "process group. If you are using the default process group, please use "
            "`parallel_state.get_position_embedding_group()`. "
            "If you don't need pos_embd_group, you need to explicitly set it to None."
        )
        assert hasattr(pg_collection, 'dp_cp')
        tp_group = pg_collection.tp
        pp_group = pg_collection.pp
        embd_group = pg_collection.embd
        pos_emb_group = pg_collection.pos_embd
        dp_cp_group = pg_collection.dp_cp
    else:
        tp_group = parallel_state.get_tensor_model_parallel_group()
        pp_group = parallel_state.get_pipeline_model_parallel_group()
        embd_group = parallel_state.get_embedding_group(check_initialized=False)
        pos_emb_group = parallel_state.get_position_embedding_group(check_initialized=False)
        dp_cp_group = parallel_state.get_data_parallel_group(with_context_parallel=True)

    # All-reduce / reduce-scatter across DP replicas.
    if config.timers is not None:
        config.timers('all-grads-sync', log_level=1).start(barrier=config.barrier_with_L1_time)

    def _handle_all_reduce_time_start(args, config):
        if args.all_reduce_time:
            config.timers('DP_time', log_level=0).start()

    def _handle_all_reduce_time_end(args, config):
        if args.all_reduce_time:
            config.timers('DP_time').stop()

    def _update_gradient_compression_state(args):
        if args.max_rank is None:
            if args.is_loading_checkpoint:
                if args.curr_iteration >= (args.latest_iteration + 12):
                    args.grad_comp_enabled = True
            else:
                if args.curr_iteration >= 12:
                    args.grad_comp_enabled = True
        else:
            if args.curr_iteration > args.warm_up_train_iter:
                if args.begin_max_rank:
                    args.grad_comp_enabled = not (
                        args.is_loading_checkpoint
                        and (len(Utils.mapped_rank) == 0 or Utils.mapped_rank[-1] is None)
                    )
                elif (args.curr_iteration % args.rank_adjust_window_size == 1) and (
                    args.curr_iteration != (args.latest_iteration + 1)
                ):
                    args.grad_comp_enabled = True
                    if not mpu.is_pipeline_first_stage():
                        _update_mapped_rank_based_on_final_rank(args)
            elif args.begin_warm_up:
                args.grad_comp_enabled = False
                args.begin_warm_up = False
        args.grad_comp = args.grad_comp_enabled

    def _update_mapped_rank_based_on_final_rank(args):
        if len(Utils.mapped_rank) >= 2:
            if args.final_rank is None:
                args.grad_comp_enabled = False
            elif args.final_rank != Utils.mapped_rank[-2]:
                if args.final_rank is not None:
                    args.mapped_rank = args.final_rank
                else:
                    args.grad_comp_enabled = False
        else:
            args.mapped_rank = args.final_rank

    def _get_find_rank(args):
        """Helper to determine rank when finding the rank upper limit."""
        if args.mapped_rank is not None:
            return int(args.mapped_rank)
        if args.is_loading_checkpoint:
            return int(Utils.mapped_rank[-1] if Utils.mapped_rank else args.max_rank)
        return int(args.max_rank)

    def _get_adaptive_rank(args):
        """Helper to determine rank during adaptive compression."""
        if args.is_loading_checkpoint:
            delta_iter = args.curr_iteration - args.latest_iteration
        else:
            delta_iter = args.curr_iteration
        return 2 ** int((delta_iter - 9) / 3)

    def compressor_update(args):
        if not args.enable_dynamic_grad_comp or not args.grad_comp:
            args.compressor = None
            return
        if args.fp16:
            compression_dtype = torch.float16
        elif args.bf16:
            compression_dtype = torch.bfloat16
        else:
            compression_dtype = torch.float32

        rank = _get_find_rank(args) if args.find_rank_upper_limit else _get_adaptive_rank(args)

        if args.pre_rank is not None:
            if args.pre_rank == rank:
                args.compressor.begin_iteration(args.curr_iteration)
                return
        args.pre_rank = rank

        from .power_sgd import PowerSGDCompressor
        args.compressor = PowerSGDCompressor(
            ef_layout_manager=args.ef_manager, rank=rank, compression_dtype=compression_dtype
        )
        args.compressor.begin_iteration(args.curr_iteration)

    if args.enable_dynamic_grad_comp and not args.overlap_grad_reduce:
        _handle_all_reduce_time_start(args, config)

    for model_chunk in model:
        if args.enable_dynamic_grad_comp:
            _update_gradient_compression_state(args)
            compressor_update(args)
        model_chunk.finish_grad_sync()

    if args.enable_dynamic_grad_comp:
        if args.begin_max_rank:
            args.begin_max_rank = False
        if not args.overlap_grad_reduce:
            _handle_all_reduce_time_end(args, config)

    if args.enable_dynamic_grad_comp:
        if args.all_reduce_time:
            args.params_all_reduce_time = config.timers('DP_time').elapsed(reset=True) * 1000.0

    if config.timers is not None:
        config.timers('all-grads-sync').stop()

    # All-reduce t_embedder grads (for pp & vpp of DiT).
    if config.timers is not None:
        config.timers('conditional-embedder-grads-all-reduce', log_level=1).start(
            barrier=config.barrier_with_L1_time
        )
    _allreduce_conditional_embedding_grads(model, config, pp_group)
    if config.timers is not None:
        config.timers('conditional-embedder-grads-all-reduce').stop()

    # All-reduce layer-norm grads (for sequence parallelism) and non-tensor-parallel modules.
    if config.timers is not None:
        config.timers('non-tensor-parallel-grads-all-reduce', log_level=1).start(
            barrier=config.barrier_with_L1_time
        )
    _allreduce_non_tensor_model_parallel_grads(model, config, tp_group)
    if config.timers is not None:
        config.timers('non-tensor-parallel-grads-all-reduce').stop()

    # All-reduce embedding grads (for pipeline parallelism).
    if config.timers is not None:
        config.timers('embedding-grads-all-reduce', log_level=1).start(
            barrier=config.barrier_with_L1_time
        )
    _allreduce_word_embedding_grads(model, config, embd_group, pp_group)
    _allreduce_position_embedding_grads(model, config, pos_emb_group, pp_group)
    if config.timers is not None:
        config.timers('embedding-grads-all-reduce').stop()

    if config.moe_router_enable_expert_bias:
        _update_router_expert_bias(model, config)

    reset_model_temporary_tensors(config, model)

    # Normalize gradients for per-token loss normalization.
    # If normalizing by the number of tokens, we use that as the divisor; this number
    # will be the total number of non-padded tokens in the global batch.
    if num_tokens is not None:
        # The number of tokens is only present on the last stage, so broadcast it
        # to the other ranks in the pipeline parallel group.
        assert not isinstance(pp_group, list)
        last_rank = get_pp_last_rank(pp_group)
        torch.distributed.broadcast(num_tokens, src=last_rank, group=pp_group)
        # All-reduce across DP ranks.
        torch.distributed.all_reduce(num_tokens, group=dp_cp_group)
        for model_chunk in model:
            if num_tokens > 0:
                scaling = 1.0 / num_tokens
                model_chunk.scale_gradients(scaling)
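
The final per-token normalization is plain scalar scaling; a toy illustration (values made up) of what `scale_gradients` amounts to per gradient tensor:

# Toy illustration of the per-token scaling step (illustrative values).
import torch

num_tokens = torch.tensor(1536.0)   # non-padded tokens in the global batch, after all-reduce
grad = torch.full((4,), 3.0)        # stand-in for a main_grad buffer
grad.mul_(1.0 / num_tokens)         # conceptually what model_chunk.scale_gradients applies
assert torch.allclose(grad, torch.full((4,), 3.0 / 1536.0))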