Commit 01bcbb1e authored by lim

Initial commit

parent 187361d1
from argparse import ArgumentParser
from ..feature import AbstractFeature
class ParallelLinearFeature(AbstractFeature):
def __init__(self):
super().__init__('parallel-linear-impl')
def register_args(self, parser: ArgumentParser):
group = parser.add_argument_group(title=self.feature_name)
group.add_argument('--parallel-linear-impl', type=str,
default=None,
choices=['flux'],
help='Specify the method to replace ColumnParallelLinear/RowParallelLinear')
group.add_argument('--save-flux-gather-input', action='store_true', default=False,
help='Use the gathered input of AGKernel for wgrad computation')
group.add_argument('--flux-transpose-weight', action='store_true', default=False,
help='Whether to transpose weight when using flux kernel')
group.add_argument('--disable-bw-flux-gemmrs-op', action='store_false', default=True, dest='enable_bw_flux_gemmrs_op',
help='Do not use flux.GemmRS in backward pass')
def validate_args(self, args):
if args.parallel_linear_impl == "flux" and args.transformer_impl != 'transformer_engine':
raise AssertionError('flux is only supported with transformer_engine implementation')
def register_patches(self, patch_manager, args):
# flux
from dcu_megatron.core.tensor_parallel.layers import (
FluxColumnParallelLinear,
FluxRowParallelLinear
)
from dcu_megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_flux_spec
if args.parallel_linear_impl == 'flux':
patch_manager.register_patch("megatron.core.extensions.transformer_engine.TEColumnParallelLinear",
FluxColumnParallelLinear)
patch_manager.register_patch("megatron.core.extensions.transformer_engine.TERowParallelLinear",
FluxRowParallelLinear)
patch_manager.register_patch("megatron.core.models.gpt.gpt_layer_specs.get_gpt_layer_with_transformer_engine_spec",
get_gpt_layer_with_flux_spec)
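A minimal usage sketch (hypothetical driver code, not part of this commit) showing how the flags registered above parse; it assumes AbstractFeature can be instantiated through this subclass.

parser = ArgumentParser()
feature = ParallelLinearFeature()
feature.register_args(parser)
args = parser.parse_args(['--parallel-linear-impl', 'flux', '--flux-transpose-weight'])
assert args.parallel_linear_impl == 'flux'
assert args.flux_transpose_weight is True
# --disable-bw-flux-gemmrs-op is store_false, so its dest defaults to True when the flag is absent
assert args.enable_bw_flux_gemmrs_op is True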
import importlib
import sys
import types
def get_func_name(func):
if isinstance(func, str):
return func
return '.'.join((func.__module__, func.__qualname__))
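For illustration (hypothetical snippet, not in the commit): get_func_name returns strings unchanged and builds a dotted path from a callable's module and qualified name.

def _local_example():
    pass

print(get_func_name(_local_example))   # e.g. '__main__._local_example' when run as a script
print(get_func_name('pkg.mod.func'))   # 'pkg.mod.func' (strings pass through)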
def dummy_function_wrapper(func_name):
def dummy_function(*args, **kwargs):
raise RuntimeError('function {} does not exist'.format(func_name))
return dummy_function
class Patch:
def __init__(self, orig_func_or_cls_name, new_func_or_cls, create_dummy, apply_wrapper=False, remove_origin_wrappers=False):
split_name = orig_func_or_cls_name.rsplit('.', 1)
if len(split_name) == 1:
self.orig_module_name, self.orig_func_or_cls_name = orig_func_or_cls_name, None
else:
self.orig_module_name, self.orig_func_or_cls_name = split_name
self.orig_module = None
self.orig_func_or_cls = None
self.patch_func_or_cls = None
self.wrappers = []
self.remove_origin_wrappers = False
if (
new_func_or_cls is None
and not remove_origin_wrappers
):
new_func_or_cls = dummy_function_wrapper(orig_func_or_cls_name)
self.set_patch_func(new_func_or_cls, apply_wrapper=apply_wrapper, remove_origin_wrappers=remove_origin_wrappers)
self.is_applied = False
self.create_dummy = create_dummy
@property
def orig_func_or_cls_id(self):
return id(self.orig_func_or_cls)
@property
def patch_func_id(self):
return id(self.patch_func_or_cls)
@staticmethod
def remove_wrappers(module, func_name, func):
while True:
if (
module.__dict__
and func_name in module.__dict__
and isinstance(module.__dict__[func_name], (staticmethod, classmethod))
):
func = module.__dict__[func_name].__func__
if hasattr(func, '__wrapped__') and func.__wrapped__ is not None:
func = func.__wrapped__
elif hasattr(func, '__closure__') and func.__closure__ is not None:
func = func.__closure__[0].cell_contents
else:
break
return func
def set_patch_func(self, new_func_or_cls=None, force_patch=False, apply_wrapper=False, remove_origin_wrappers=False):
if remove_origin_wrappers:
self.remove_origin_wrappers = True
else:
assert new_func_or_cls is not None
if new_func_or_cls is None:
return
if (
apply_wrapper
or (hasattr(new_func_or_cls, '__name__') and new_func_or_cls.__name__.endswith(('wrapper', 'decorator')))
):
for wrapper in self.wrappers:
if id(wrapper) == id(new_func_or_cls):
raise RuntimeError(f"wrapper {getattr(new_func_or_cls, '__name__')} has already been applied")
self.wrappers.append(new_func_or_cls)
else:
if (
self.patch_func_or_cls
and not force_patch
and id(new_func_or_cls) != id(self.patch_func_or_cls)
):
raise RuntimeError('a patch for {} already exists!'.format(self.orig_func_or_cls_name))
self.patch_func_or_cls = new_func_or_cls
self.is_applied = False
def apply_patch(self):
if self.is_applied:
return
self.orig_module, self.orig_func_or_cls = Patch.parse_path(self.orig_module_name, self.orig_func_or_cls_name, self.create_dummy)
final_patch_func_or_cls = self.orig_func_or_cls
if self.patch_func_or_cls is not None:
final_patch_func_or_cls = self.patch_func_or_cls
# remove original wrappers
if self.remove_origin_wrappers:
final_patch_func_or_cls = self.remove_wrappers(self.orig_module, self.orig_func_or_cls_name, final_patch_func_or_cls)
# add new wrappers
for wrapper in self.wrappers:
final_patch_func_or_cls = wrapper(final_patch_func_or_cls)
if self.orig_func_or_cls_name is not None:
setattr(self.orig_module, self.orig_func_or_cls_name, final_patch_func_or_cls)
for key, value in sys.modules.copy().items():
if self.orig_func_or_cls_name is not None and hasattr(value, self.orig_func_or_cls_name) \
and id(getattr(value, self.orig_func_or_cls_name)) == self.orig_func_or_cls_id:
setattr(value, self.orig_func_or_cls_name, final_patch_func_or_cls)
self.is_applied = True
@staticmethod
def parse_path(module_path, function_name, create_dummy):
from importlib.machinery import ModuleSpec
modules = module_path.split('.')
for i in range(1, len(modules) + 1):
parent = '.'.join(modules[:i - 1])
path = '.'.join(modules[:i])
try:
importlib.import_module(path)
except ModuleNotFoundError as e:
if not parent or not hasattr(importlib.import_module(parent), modules[i - 1]):
if not create_dummy:
raise ModuleNotFoundError(e) from e
sys.modules[path] = types.ModuleType(path)
sys.modules[path].__file__ = 'dcu_megatron.dummy_module.py'
sys.modules[path].__spec__ = ModuleSpec(path, None)
if parent:
setattr(importlib.import_module(parent), modules[i - 1], sys.modules[path])
else:
module = getattr(importlib.import_module(parent), modules[i - 1])
if hasattr(module, function_name):
return module, getattr(module, function_name)
elif create_dummy:
return module, dummy_function_wrapper(function_name)
else:
raise RuntimeError('{} does not exist in {}'.format(function_name, module))
if function_name is not None and not hasattr(sys.modules[module_path], function_name):
setattr(sys.modules[module_path], function_name, None)
return sys.modules[module_path], getattr(sys.modules[module_path], function_name) if function_name is not None else None
class MegatronPatchesManager:
patches_info = {}
@staticmethod
def register_patch(
orig_func_or_cls_name,
new_func_or_cls=None,
force_patch=False,
create_dummy=False,
apply_wrapper=False,
remove_origin_wrappers=False
):
if orig_func_or_cls_name not in MegatronPatchesManager.patches_info:
MegatronPatchesManager.patches_info[orig_func_or_cls_name] = Patch(
orig_func_or_cls_name,
new_func_or_cls,
create_dummy,
apply_wrapper=apply_wrapper,
remove_origin_wrappers=remove_origin_wrappers
)
else:
MegatronPatchesManager.patches_info.get(orig_func_or_cls_name).set_patch_func(
new_func_or_cls,
force_patch,
apply_wrapper=apply_wrapper,
remove_origin_wrappers=remove_origin_wrappers
)
@staticmethod
def register_cls_funcs(orig_class, new_funcs: list = None, create_dummy=False):
if new_funcs is None:
new_funcs = []
if not orig_class.endswith("."):
orig_class += "."
for new_func in new_funcs:
assert hasattr(new_func, '__name__') and not new_func.__name__.endswith(('wrapper', 'decorator'))
orig_func_name = orig_class + new_func.__name__
MegatronPatchesManager.register_patch(orig_func_name, new_func_or_cls=new_func, create_dummy=create_dummy)
@staticmethod
def apply_patches():
for patch in MegatronPatchesManager.patches_info.values():
patch.apply_patch()
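A hypothetical usage sketch of the manager defined above (the module path and replacement function are made up for illustration); registered patches only take effect once apply_patches() runs.

def fast_gelu(x):
    # stand-in replacement used only for this example
    return x

def logging_wrapper(fn):
    # names ending in 'wrapper'/'decorator' are treated as wrappers automatically
    def inner(*args, **kwargs):
        print('calling', fn)
        return fn(*args, **kwargs)
    return inner

# create_dummy=True tolerates a module path that does not exist yet
MegatronPatchesManager.register_patch('some_pkg.activations.gelu', fast_gelu, create_dummy=True)
MegatronPatchesManager.register_patch('some_pkg.activations.gelu', logging_wrapper, apply_wrapper=True)
MegatronPatchesManager.apply_patches()   # nothing is modified until this call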
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. Except portions as noted which are Copyright (c) 2023 OpenGVLab and licensed under the MIT license found in LICENSE.
import numpy as np
import torch
import math
import random
from PIL import Image, ImageDraw
from torchvision import transforms as T
from torchvision.transforms import Compose, RandAugment, RandomResizedCrop, Resize, ToPILImage
# Imagenet's mean and std.
pixel_mean = [123.675, 116.28, 103.53]
pixel_std = [58.395, 57.12, 57.375]
# Reshape for broadcasting.
pixel_mean = torch.Tensor(pixel_mean).view(-1, 1, 1)
pixel_std = torch.Tensor(pixel_std).view(-1, 1, 1)
def convert_to_rgb(image):
return image.convert("RGB")
def _transform_train_aug():
return Compose([
ToPILImage(),
Resize(scale=random.random() / 2 + 0.5),
convert_to_rgb,
RandAugment(2, 5, isPIL=True, augs=['Identity', 'AutoContrast', 'Brightness', 'Sharpness', 'Equalize',
'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
])
def _transform_test():
return Compose([
ToPILImage(),
convert_to_rgb,
])
def standardize_image(img):
"""Standardize image pixel values."""
return (torch.Tensor(np.array(img)).permute(2, 0, 1) - pixel_mean) / pixel_std
def get_visual_transform(
img,
factor: int = 28,
min_pixels: int = 56 * 56,
max_pixels: int = 14 * 14 * 4 * 1280,
augment=False
):
img = np.array(img)
if augment:
visual_transform = _transform_train_aug()
else:
visual_transform = _transform_test()
img = visual_transform(img)
w, h = img.size
h_bar, w_bar = smart_resize(h, w, factor, min_pixels, max_pixels)
img = img.resize((w_bar, h_bar))
# Standardize pixel values.
img = standardize_image(img)
imgs = [img]
return imgs
# copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py
def smart_resize(
height: int, width: int, factor: int = 28, min_pixels: int = 56 * 56, max_pixels: int = 14 * 14 * 4 * 1280
):
"""Rescales the image so that the following conditions are met:
1. Both dimensions (height and width) are divisible by 'factor'.
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
3. The aspect ratio of the image is maintained as closely as possible.
"""
if height < factor or width < factor:
raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
elif max(height, width) / min(height, width) > 200:
raise ValueError(
f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
)
h_bar = round(height / factor) * factor
w_bar = round(width / factor) * factor
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = math.floor(height / beta / factor) * factor
w_bar = math.floor(width / beta / factor) * factor
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = math.ceil(height * beta / factor) * factor
w_bar = math.ceil(width * beta / factor) * factor
return h_bar, w_bar
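A worked example of the rounding behavior (numbers computed from the function above, not taken from the commit):

# 100 x 250 image, factor 28: each side is rounded to the nearest multiple of 28.
assert smart_resize(100, 250) == (112, 252)   # 4 * 28 and 9 * 28; 28224 pixels is within bounds

# Very large images are scaled down so that h_bar * w_bar <= max_pixels
# while both sides stay divisible by the factor.
h_bar, w_bar = smart_resize(4000, 6000)
assert h_bar % 28 == 0 and w_bar % 28 == 0
assert h_bar * w_bar <= 14 * 14 * 4 * 1280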
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
"""Utilities for exchanging data between ranks."""
import logging
from collections import defaultdict
from typing import Optional, Set
import torch
from megatron.core.utils import get_pg_size
from megatron.core.dist_checkpointing.dict_utils import nested_values
from megatron.core.dist_checkpointing.mapping import ShardedStateDict, ShardedTensor, is_main_replica, ReplicaId
from megatron.core.dist_checkpointing.utils import _sharded_tensor_shard_id, _ShardId
from megatron.core.dist_checkpointing.exchange_utils import ShardDistribution, _shard_size, distribute_shards_to_ranks
logger = logging.getLogger(__name__)
def is_main_replica_norm(replica_id: ReplicaId):
if isinstance(replica_id, int):
return replica_id == 0
return len(replica_id) > 0 and replica_id[-1] == 0
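This relaxes the main-replica check for norm parameters: only the last element of a tuple replica_id must be 0, while the stricter is_main_replica used below treats only an all-zero replica_id as main. Illustrative values:

assert is_main_replica_norm((1, 0, 0)) is True    # last element is 0
assert is_main_replica_norm((0, 0, 1)) is False
assert is_main_replica_norm(0) is True            # int replica ids behave as usual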
def determine_main_replica_uniform_distribution(
sharded_state_dict: ShardedStateDict,
parallelization_group: torch.distributed.ProcessGroup,
ignore_groups: bool = False,
) -> Optional[ShardDistribution]:
"""Computes the save distribution.
Should be used in conjunction with `distribute_main_replicas_with_precomputed_distribution`
which applies the computed save distribution.
We rely on the fact that the assignment algorithm is deterministic on all ranks,
so there is no extra communication needed after metadata exchange.
Args:
sharded_state_dict (ShardedStateDict): state dict to compute the distribution of
parallelization_group (ProcessGroup): distribution will be computed
within this process group
ignore_groups (bool, optional): whether the distribution defines groups.
This option is primarily used during loading, as it ensures that all replicas,
including non-main ones, are loaded by this parallelization group
Defaults to False.
Returns (ShardDistribution, optional): distribution that can be used to apply the
parallelization. Returns None if the process_group is trivial (1 rank)
"""
if parallelization_group is None:
parallelization_group = torch.distributed.group.WORLD
group_size = get_pg_size(group=parallelization_group)
if group_size <= 1:
return
local_shards = list(
sh_base
for sh_base in nested_values(sharded_state_dict)
if isinstance(sh_base, ShardedTensor)
)
local_shards_no_data = [ten.without_data() for ten in local_shards]
all_shards = [None] * get_pg_size(group=parallelization_group)
torch.distributed.all_gather_object(
all_shards, local_shards_no_data, group=parallelization_group
)
shard_to_ranks = defaultdict(list)
shard_to_size = {}
shard_to_metadata = {}
group_has_main_replica: Set[_ShardId] = set()
group_has_non_main_replica: Set[_ShardId] = set()
for rank, rank_shards in enumerate(all_shards):
for sh_ten in rank_shards:
shard_id = _sharded_tensor_shard_id(sh_ten)
shard_to_ranks[shard_id].append(rank)
if shard_id not in shard_to_size:
shard_to_size[shard_id] = _shard_size(sh_ten)
shard_to_metadata[shard_id] = sh_ten
if 'norm' in shard_id[0]:
if is_main_replica_norm(sh_ten.replica_id):
group_has_main_replica.add(shard_id)
else:
group_has_non_main_replica.add(shard_id)
else:
if is_main_replica(sh_ten.replica_id):
group_has_main_replica.add(shard_id)
else:
group_has_non_main_replica.add(shard_id)
# we always include all main replicas, and non-main only if `ignore_groups`
shards_in_this_group: Set[_ShardId] = group_has_main_replica
if ignore_groups:
shards_in_this_group = shards_in_this_group | group_has_non_main_replica
# cross-parallel-group references are empty if `not ignore_groups`,
# otherwise it's `group_has_non_main_replica - group_has_main_replica`
cross_parallelization_group_loads = shards_in_this_group - group_has_main_replica
# Filter out shards that don't belong to this group
shard_to_ranks = {k: v for k, v in shard_to_ranks.items() if k in shards_in_this_group}
shard_to_saving_rank = distribute_shards_to_ranks(
shard_to_ranks, shard_to_size, len(all_shards), cross_parallelization_group_loads
)
return ShardDistribution(
shard_to_saving_rank, shards_in_this_group, shard_to_metadata, shard_to_ranks
)
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
""" FS Reader with metadata cached support. """
import os
from typing import Union
from torch.distributed.checkpoint import Metadata
from hyckpt_torch import FileSystemReader
class CachedMetadataFileSystemReader(FileSystemReader):
"""
Extends FileSystemReader to cache metadata for improved performance.
Attributes:
_cached_metadata (Metadata or None): Cached metadata from the file system.
"""
def __init__(self, path: Union[str, os.PathLike]) -> None:
"""
Initialize with file system path.
Args:
path (Union[str, os.PathLike]): Path to the checkpoint directory or file.
"""
super().__init__(path=path)
self._cached_metadata = None
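The body shown here only initializes the cache; a minimal sketch of the intended override (an assumption, not part of this diff, and assuming hyckpt_torch's FileSystemReader exposes the usual read_metadata()) would lazily read and reuse the metadata:

def read_metadata(self) -> Metadata:
    # Read the .metadata file once and serve subsequent calls from the cache.
    if self._cached_metadata is None:
        self._cached_metadata = super().read_metadata()
    return self._cached_metadata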
""" Storage writer for PyT Distributed format allowing asynchronous save. """
import logging
from pathlib import Path
from typing import List, Tuple
import torch
from torch import multiprocessing as mp
from hyckpt_torch import _write_items
from megatron.core.dist_checkpointing.strategies.async_utils import _disable_gc
from megatron.core.dist_checkpointing.strategies.filesystem_async import _process_memory
WriteBucket = Tuple[Path, str, Tuple[list, list]] # represents writes to a single file
@staticmethod
@_disable_gc()
def write_preloaded_data(
transform_list,
local_proc_idx: int,
write_bucket: WriteBucket,
results_queue: mp.SimpleQueue,
count_queue: mp.JoinableQueue,
use_fsync: bool,
) -> None:
"""
Performs actual data saving to storage.
Args:
transform_list: list of transforms for the items to write (unused in this implementation)
local_proc_idx (int): index of the local process that performs the writing
write_bucket (WriteBucket): data to write to storage
results_queue (mp.SimpleQueue): queue used to return the write results
to the proxy checkpoint process
count_queue (mp.JoinableQueue): queue used to mark the worker task as completed
use_fsync (bool): if True, calls os.fsync at the end of saving
Returns: None; the write results are put into `results_queue`
"""
logger = logging.getLogger(__name__)
logger.debug(f'{local_proc_idx} started')
mem_before = _process_memory()
rank = torch.distributed.get_rank()
local_results = []
try:
local_results = _write_items(write_bucket)
'''
for result in local_results:
if hasattr(result.index, 'index'):
from dataclasses import replace
new_index = replace(result.index, index=rank)
new_result = replace(result, index=new_index)
'''
local_output = (local_proc_idx, local_results)
except Exception as e:
logger.debug(f'{local_proc_idx} failed')
local_output = (local_proc_idx, e)
results_queue.put(local_output)
# Signal this process is done.
count_queue.get()
count_queue.task_done()
mem_after = _process_memory()
logger.debug(
f"{local_proc_idx} consumed: {mem_after - mem_before},"
f" before: {mem_before}, after: {mem_after}"
)
@staticmethod
def preload_tensors(write_buckets: List[WriteBucket], non_blocking=True) -> List[WriteBucket]:
"""Preload tensors in state_dict to host memory through CPU memory
Args:
write_buckets(List): List of `WriteBucket`,
which includes what to be saved in a checkpoint
non_blocking (bool, optional): knob to enable pinned D2H memcpy. Default is True.
"""
result = []
for bucket in write_buckets:
file_name, storage_key, (bytes_data, tensor_data) = bucket
tensor_data = [
(item, tensor.to("cpu", non_blocking=False)) for item, tensor in tensor_data
]
result.append((file_name, storage_key, (bytes_data, tensor_data)))
if non_blocking:
torch.cuda.synchronize()
return result
import logging
from typing import Optional
from megatron.core.dist_checkpointing.exchange_utils import (
ShardDistribution,
determine_main_replica_uniform_distribution,
)
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.dist_checkpointing.strategies.fully_parallel import distribute_main_replicas_with_precomputed_distribution
logger = logging.getLogger(__name__)
class FullyParallelLoadStrategyWrapper():
def apply_loading_parallelization(
self, sharded_state_dict: ShardedStateDict
) -> Optional[ShardDistribution]:
"""Distributes the load across ranks by exchanging metadata.
Exchanges metadata from the state dict and computes the uniform
(as close as possible) distribution of loads among the ranks.
Marks ShardedTensors to be loaded by the current rank with replica_id 0
(and others with non 0 values).
If `self.do_cache_distribution` is True, caches the distribution between
the calls and subsequent distributions happen without any inter-rank
communication.
Args:
sharded_state_dict (ShardedStateDict): state dict to distribute the loading
Returns:
ShardDistribution (optional): the computed loading distribution
"""
if self.do_cache_distribution and self.cached_distribution is not None:
logger.debug(f'Apply *cached* load parallelization')
precomputed_distribution = self.cached_distribution
else:
logger.debug(f'Apply load parallelization')
precomputed_distribution = determine_main_replica_uniform_distribution(
sharded_state_dict, self.parallelization_group
)
distribute_main_replicas_with_precomputed_distribution(
sharded_state_dict, self.parallelization_group, precomputed_distribution
)
if self.do_cache_distribution:
self.cached_distribution = precomputed_distribution
return precomputed_distribution
import io
import os
import pickle
import warnings
from collections import ChainMap, defaultdict
from dataclasses import dataclass
from itertools import product
from logging import getLogger
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, cast
import torch
from packaging.version import Version as PkgVersion
from torch.distributed import checkpoint
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharded_tensor import Shard
from torch.distributed._shard.sharded_tensor import ShardedTensor as TorchShardedTensor
from torch.distributed._shard.sharded_tensor import ShardedTensorMetadata, TensorProperties
from torch.distributed.checkpoint import (
BytesStorageMetadata,
DefaultLoadPlanner,
DefaultSavePlanner,
#FileSystemReader,
FileSystemWriter,
LoadPlan,
Metadata,
ReadItem,
SavePlan,
TensorStorageMetadata,
WriteItem,
)
from hyckpt_torch import FileSystemReader
from torch.distributed.checkpoint._nested_dict import FLATTEN_MAPPING, unflatten_state_dict
from torch.distributed.checkpoint._traverse import OBJ_PATH, traverse_state_dict
from torch.distributed.checkpoint.metadata import Metadata
from torch.distributed.checkpoint.planner_helpers import _create_write_items
from megatron.core.utils import get_torch_version, is_torch_min_version
from megatron.core.dist_checkpointing.core import CheckpointingException
from megatron.core.dist_checkpointing.dict_utils import nested_values
from megatron.core.dist_checkpointing.mapping import (
ShardedBase,
ShardedObject,
ShardedStateDict,
ShardedTensor,
StateDict,
is_main_replica,
)
from megatron.core.dist_checkpointing.strategies.async_utils import AsyncRequest
from megatron.core.dist_checkpointing.strategies.base import (
AsyncSaveShardedStrategy,
LoadShardedStrategy,
StrategyAction,
register_default_strategy,
)
from megatron.core.dist_checkpointing.strategies.cached_metadata_filesystem_reader import CachedMetadataFileSystemReader
from megatron.core.dist_checkpointing.strategies.filesystem_async import FileSystemWriterAsync
from megatron.core.dist_checkpointing.strategies.resharding import (
TensorReformulationMetadata,
apply_nd_flattened_tensors_reformulation,
is_nd_flattened_tensor,
nd_flattened_tensor_reformulated_global_shape,
restore_nd_flattened_tensors_formulation,
)
from megatron.core.dist_checkpointing.strategies.state_dict_saver import save_state_dict_async_finalize, save_state_dict_async_plan
from megatron.core.dist_checkpointing.strategies.torch import (
_replace_state_dict_keys_with_sharded_keys,
mcore_to_pyt_state_dict,
MCoreLoadPlanner,
_replace_sharded_keys_with_state_dict_keys,
_restore_dict_types,
_unwrap_pyt_sharded_tensor
)
# Name of the metadata file written by torch.distributed.checkpoint
# (assumed value; referenced by remove_sharded_tensors below).
_metadata_fn = ".metadata"
def get_reformulation_metadata(
sharded_state_dict: ShardedStateDict, checkpoint_dir: Path
) -> Dict[str, TensorReformulationMetadata]:
"""Reads MCore data for N-D flattened tensors from checkpoint metadata during ckpt load.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict to load
checkpoint_dir (Path): checkpoint directory
Returns:
Dict[str, TensorReformulationMetadata] - dictionary that maps keys of every
N-D flattened tensor from the sharded_state_dict to its original global shape
as stored in `mcore_data` in the checkpoint.
"""
ckpt_metadata = FileSystemReader(checkpoint_dir).read_metadata()
reformulation_metadata = {}
for sh_ten in nested_values(sharded_state_dict):
if not is_nd_flattened_tensor(sh_ten):
continue
try:
ckpt_global_shape = ckpt_metadata.mcore_data[sh_ten.key][
'nd_reformulated_orig_global_shape'
]
except KeyError as e:
if len(sh_ten.global_shape) == 1:
warnings.warn(
f'Legacy checkpoint format detected for 1-D flattened tensor {sh_ten}. '
'Skip metadata reformulation.'
)
continue
raise CheckpointingException(
f'Cannot find global shape metadata for N-D flattened tensor {sh_ten} '
f'in checkpoint metadata: {ckpt_metadata.mcore_data}'
) from e
reformulation_metadata[sh_ten.key] = TensorReformulationMetadata(
ckpt_global_shape, ckpt_metadata.state_dict_metadata[sh_ten.key].size
)
return reformulation_metadata
class TorchDistLoadShardedStrategy(LoadShardedStrategy):
def __init__(self):
self.cached_global_metadata: Optional[Metadata] = None
super().__init__()
def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> StateDict:
"""Translates MCore ShardedTensors to PyT ShardedTensors & loads from PyT Distributed fmt.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict with mapping
information to instruct loading
checkpoint_dir (Path): checkpoint directory
Returns: loaded state dict
"""
# Apply N-D tensors resharding
reformulation_metadata = get_reformulation_metadata(sharded_state_dict, checkpoint_dir)
sharded_state_dict, formulation_restore_data = apply_nd_flattened_tensors_reformulation(
sharded_state_dict, reformulation_metadata
)
# Check if there are legacy 1-D flattened tensors in the checkpoint
has_legacy_1d_flattened_tensors = False
for sh_ten in nested_values(sharded_state_dict):
if is_nd_flattened_tensor(sh_ten) and sh_ten.key not in reformulation_metadata:
has_legacy_1d_flattened_tensors = True
break
flexible_shape_sharded_tensors = [
sh_ten
for sh_ten in nested_values(sharded_state_dict)
if isinstance(sh_ten, ShardedTensor) and not sh_ten.allow_shape_mismatch
]
allow_shape_mismatch_sharded_tensors = {
sh_ten.key: sh_ten
for sh_ten in nested_values(sharded_state_dict)
if isinstance(sh_ten, ShardedTensor) and sh_ten.allow_shape_mismatch
}
orig_sharded_state_dict = sharded_state_dict
# MCore state dict to PyT Distributed compatible
(sharded_state_dict, flat_mapping, rename_mapping) = (
_replace_state_dict_keys_with_sharded_keys(sharded_state_dict)
)
pyt_state_dict = mcore_to_pyt_state_dict(
sharded_state_dict, True, load_legacy_1d_flatten_tensors=has_legacy_1d_flattened_tensors
)
# Load PyT Distributed format
fsr = CachedMetadataFileSystemReader(checkpoint_dir)
checkpoint.load_state_dict(
pyt_state_dict,
fsr,
planner=MCoreLoadPlanner(
shapes_validation_sharded_tensors=flexible_shape_sharded_tensors,
allow_shape_mismatch_sharded_tensors=allow_shape_mismatch_sharded_tensors,
),
)
self.cached_global_metadata = (
fsr.read_metadata()
) # no storage interaction thanks to caching
pyt_state_dict = cast(
Dict[str, Union[TorchShardedTensor, List[io.BytesIO]]], pyt_state_dict
)
# Unwrap ShardedTensors and return to original state dict
mcore_state_dict = {
k: v if not isinstance(v, TorchShardedTensor) else _unwrap_pyt_sharded_tensor(v)
for k, v in pyt_state_dict.items()
}
mcore_state_dict = _replace_sharded_keys_with_state_dict_keys(
mcore_state_dict, flat_mapping, rename_mapping
)
_restore_dict_types(mcore_state_dict, orig_sharded_state_dict)
# Apply N-D tensors resharding postprocessing
mcore_state_dict = restore_nd_flattened_tensors_formulation(
mcore_state_dict, formulation_restore_data
)
return mcore_state_dict
def load_tensors_metadata(self, checkpoint_dir: Path, metadata: Metadata = None):
"""Uses tensors metadata stored in the metadata file."""
if metadata is None:
fs_reader = FileSystemReader(checkpoint_dir)
metadata = fs_reader.read_metadata()
mcore_data = getattr(metadata, 'mcore_data', {})
sharded_metadata = {}
for k, tp in metadata.state_dict_metadata.items():
if not isinstance(tp, TensorStorageMetadata):
continue # load only tensors
nd_orig_global_shape = mcore_data.get(k, {}).get('nd_reformulated_orig_global_shape')
if nd_orig_global_shape is None:
# Regular tensor
sharded_metadata[k] = ShardedTensor.from_rank_offsets(
k, torch.empty(tp.size, **tp.properties.__dict__, device='meta')
).without_data()
else:
# N-D flattened tensor
unflat_ten = torch.empty(
nd_orig_global_shape, **tp.properties.__dict__, device='meta'
)
flat_ten = unflat_ten.flatten()
sharded_metadata[k] = ShardedTensor.from_rank_offsets_flat(
k,
flat_ten,
unflat_ten.shape,
flattened_range=slice(0, unflat_ten.numel()), # whole slice
).without_data()
return sharded_metadata
def load_sharded_metadata(self, checkpoint_dir: Path) -> ShardedStateDict:
"""Uses tensors and objects metadata stored in the metadata file."""
fs_reader = FileSystemReader(checkpoint_dir)
metadata = fs_reader.read_metadata()
sharded_metadata = {}
for metadata_key, storage_metadata in metadata.state_dict_metadata.items():
if not isinstance(storage_metadata, BytesStorageMetadata):
continue
sh_obj = ShardedObject.empty_from_unique_key(metadata_key)
sharded_metadata[sh_obj.unique_key] = sh_obj
sharded_metadata.update(self.load_tensors_metadata(checkpoint_dir, metadata))
return sharded_metadata
def remove_sharded_tensors(self, checkpoint_dir: str, key_prefix: str):
"""Removes checkpoint files whose keys have the given prefix.
Performs the following steps:
1. checks whether there are files that start with the key_prefix
2. loads metadata
3. removes all entries from the metadata that start with the key_prefix
4. resaves the new metadata and removes the old metadata
5. removes the relevant files
"""
assert is_torch_min_version(
"2.3.0"
), f'torch >= 2.3.0 is required for remove_sharded_tensors'
distckpt_files = [f for f in os.listdir(checkpoint_dir) if f.endswith("distcp")]
files_to_remove = [f for f in distckpt_files if f.startswith(key_prefix)]
if not files_to_remove:
warnings.warn(
f'There are no files in {checkpoint_dir} that begin with "{key_prefix}".'
f' Skipping removal.'
)
return
fs_reader = FileSystemReader(checkpoint_dir)
original_metadata = fs_reader.read_metadata()
new_state_dict_metadata = {}
new_planner_data = {}
new_storage_data = {}
for k in original_metadata.state_dict_metadata.keys():
if k.startswith(key_prefix):
continue
new_state_dict_metadata[k] = original_metadata.state_dict_metadata[k]
for k in original_metadata.planner_data.keys():
if k.startswith(key_prefix):
continue
new_planner_data[k] = original_metadata.planner_data[k]
for k in original_metadata.storage_data.keys():
if k.fqn.startswith(key_prefix):
continue
new_storage_data[k] = original_metadata.storage_data[k]
metadata = Metadata(
state_dict_metadata=new_state_dict_metadata,
planner_data=new_planner_data,
storage_data=new_storage_data,
)
fs_writer = FileSystemWriter(checkpoint_dir)
metadata_filename = cast(Path, fs_writer.fs.concat_path(fs_writer.path, _metadata_fn))
tmp_path = cast(Path, fs_writer.fs.concat_path(fs_writer.path, f"{_metadata_fn}.tmp"))
old_path = cast(Path, fs_writer.fs.concat_path(fs_writer.path, f"{_metadata_fn}.bck"))
## save the new metadata
with fs_writer.fs.create_stream(tmp_path, "wb") as metadata_file:
pickle.dump(metadata, metadata_file)
try:
os.fsync(metadata_file.fileno())
except AttributeError:
os.sync()
## move the old metadata
fs_writer.fs.rename(fs_writer.metadata_path, old_path)
try:
## rename the new metadata
fs_writer.fs.rename(tmp_path, fs_writer.metadata_path)
## finally, remove the files we want to drop
for f in files_to_remove:
fs_writer.fs.rm_file(checkpoint_dir / f)
except Exception as e:
fs_writer.fs.rename(old_path, fs_writer.metadata_path)
raise e
else:
fs_writer.fs.rm_file(old_path)
def can_handle_sharded_objects(self):
return True
def check_backend_compatibility(self, loaded_version):
pass # TODO
def check_version_compatibility(self, loaded_version):
pass # TODO
import logging
import numpy as np
import torch
from megatron.core.dist_checkpointing import ShardedTensor
from megatron.core.dist_checkpointing.core import CheckpointingException
from megatron.core.dist_checkpointing.mapping import is_main_replica
logger = logging.getLogger(__name__)
def _compute_shards_access(rank_sharding):
shard_access_cnt = torch.zeros(
rank_sharding[0][1].axis_fragmentations, dtype=torch.int, device="cpu"
)
for rank, sharding in rank_sharding:
if is_main_replica(sharding.replica_id):
if 'norm' in sharding.key:
shard_access_cnt[sharding.local_chunk_offset_in_global()] = 1
else:
shard_access_cnt[sharding.local_chunk_offset_in_global()] += 1
return shard_access_cnt
def _validate_sharding_for_key_flattened(tensors_by_shard):
all_slices = []
local_shape = tensors_by_shard[0].local_shape
for sharding in tensors_by_shard:
assert sharding.local_shape == local_shape
sharding: ShardedTensor
if not is_main_replica(sharding.replica_id):
continue
if all_slices and 'norm' in sharding.key:
continue
all_slices.append((sharding.flattened_range.start, sharding.flattened_range.stop))
starts, stops = map(np.asarray, zip(*sorted(all_slices)))
expected_size = np.prod(local_shape)
if starts[0] != 0 or stops[-1] != expected_size or not np.all(starts[1:] == stops[:-1]):
raise CheckpointingException(
f"Flattened ranges don't cover the whole shard {tensors_by_shard[0]} of size {expected_size}. Ranges: {(starts, stops)}"
)
class _BaseDataParallel():
def backward_dw(self, *inputs, **kwargs):
"""
Calls the wrapped module's backward_dw() method.
"""
return self.module.backward_dw(*inputs, **kwargs)
import torch
from megatron.training import get_args
from megatron.core.transformer.cuda_graphs import is_graph_capturing
class DistributedDataParallel():
def _make_backward_post_hook(self, param: torch.nn.Parameter):
"""
Creates a backward post-hook to dispatch an all-reduce / reduce-scatter when
ready (i.e., when all grads in a bucket have been computed in all microbatches
in a batch).
"""
def hook(*unused):
if is_graph_capturing():
return
if param in self.param_to_bucket_group:
assert param.requires_grad
if self.ddp_config.overlap_grad_reduce:
# support dualpipev
if not get_args().gradient_accumulation_fusion or not get_args().delay_wgrad_compute:
assert (
param.grad is not None
), 'param.grad being None is not safe when overlap_grad_reduce is True'
if param.grad is not None and (
not param.grad_added_to_main_grad or getattr(param, 'zero_out_wgrad', False)
):
param.main_grad.add_(param.grad.data)
param.grad = None
if self.ddp_config.overlap_grad_reduce:
self.param_to_bucket_group[param].register_grad_ready(param)
return hook
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import List, Optional
import torch
try:
from torch.distributed._tensor import DTensor, distribute_tensor
HAVE_DTENSOR = True
except ImportError:
HAVE_DTENSOR = False
from megatron.core import mpu
from megatron.core import parallel_state
from megatron.core.utils import get_model_config
from megatron.training.global_vars import get_args
from ...training.edgc_utils import Utils
# ProcessGroupCollection is referenced in the signature below; assumed import path
# (recent megatron-core), adjust if it lives elsewhere in this Megatron version.
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.distributed.finalize_model_grads import (
_allreduce_conditional_embedding_grads,
_allreduce_non_tensor_model_parallel_grads,
_allreduce_word_embedding_grads,
_allreduce_position_embedding_grads,
reset_model_temporary_tensors,
_update_router_expert_bias
)
def finalize_model_grads(
model: List[torch.nn.Module],
num_tokens: Optional[torch.Tensor] = None,
pg_collection: Optional[ProcessGroupCollection] = None,
):
"""
All-reduce all model grads across DP replicas, layernorm grads for sequence parallelism,
embedding grads across first and last pipeline stages (if not tied),
scale gradients by `num_tokens`.
"""
args = get_args()
config = get_model_config(model[0])
if pg_collection is not None:
assert hasattr(pg_collection, 'tp')
assert hasattr(pg_collection, 'pp')
assert hasattr(pg_collection, 'embd'), (
"pg_collection must have an 'embd' group. Previous versions used the default "
"`parallel_state.default_embedding_ranks` to create this process group. "
"If you are using the default process group, use "
"`parallel_state.get_embedding_group()`. "
"If you don't need embd_group, explicitly set it to None."
)
assert hasattr(pg_collection, 'pos_embd'), (
"pg_collection must have a 'pos_embd' group. Previous versions used the default "
"`parallel_state.default_position_embedding_ranks` to create this process group. "
"If you are using the default process group, use "
"`parallel_state.get_position_embedding_group()`. "
"If you don't need pos_embd_group, explicitly set it to None."
)
assert hasattr(pg_collection, 'dp_cp')
tp_group = pg_collection.tp
pp_group = pg_collection.pp
embd_group = pg_collection.embd
pos_emb_group = pg_collection.pos_embd
dp_cp_group = pg_collection.dp_cp
else:
tp_group = parallel_state.get_tensor_model_parallel_group()
pp_group = parallel_state.get_pipeline_model_parallel_group()
embd_group = parallel_state.get_embedding_group(check_initialized=False)
pos_emb_group = parallel_state.get_position_embedding_group(check_initialized=False)
dp_cp_group = parallel_state.get_data_parallel_group(with_context_parallel=True)
# All-reduce / reduce-scatter across DP replicas.
if config.timers is not None:
config.timers('all-grads-sync', log_level=1).start(barrier=config.barrier_with_L1_time)
def _handle_all_reduce_time_start(args, config):
if args.all_reduce_time:
config.timers('DP_time', log_level=0).start()
def _handle_all_reduce_time_end(args, config):
if args.all_reduce_time:
config.timers('DP_time').stop()
def _update_gradient_compression_state(args):
if args.max_rank is None:
if args.is_loading_checkpoint:
if args.curr_iteration >= (args.latest_iteration + 12):
args.grad_comp_enabled = True
else:
if args.curr_iteration >= 12:
args.grad_comp_enabled = True
else:
if args.curr_iteration > args.warm_up_train_iter:
if args.begin_max_rank:
args.grad_comp_enabled = not (args.is_loading_checkpoint and (
len(Utils.mapped_rank) == 0 or Utils.mapped_rank[-1] is None))
elif (args.curr_iteration % args.rank_adjust_window_size == 1) and (
args.curr_iteration != (args.latest_iteration + 1)):
args.grad_comp_enabled = True
if not mpu.is_pipeline_first_stage():
_update_mapped_rank_based_on_final_rank(args)
elif args.begin_warm_up:
args.grad_comp_enabled = False
args.begin_warm_up = False
args.grad_comp = args.grad_comp_enabled
def _update_mapped_rank_based_on_final_rank(args):
if len(Utils.mapped_rank) >= 2:
if args.final_rank is None:
args.grad_comp_enabled = False
elif args.final_rank != Utils.mapped_rank[-2]:
if args.final_rank is not None:
args.mapped_rank = args.final_rank
else:
args.grad_comp_enabled = False
else:
args.mapped_rank = args.final_rank
def _get_find_rank(args):
"""Helper to determine rank when finding rank upper limit."""
if args.mapped_rank is not None:
return int(args.mapped_rank)
if args.is_loading_checkpoint:
return int(Utils.mapped_rank[-1] if Utils.mapped_rank else args.max_rank)
return int(args.max_rank)
def _get_adaptive_rank(args):
"""Helper to determine rank during adaptive compression."""
if args.is_loading_checkpoint:
delta_iter = args.curr_iteration - args.latest_iteration
else:
delta_iter = args.curr_iteration
return 2 ** int((delta_iter - 9) / 3)
def compressor_update(args):
if not args.enable_dynamic_grad_comp or not args.grad_comp:
args.compressor = None
return
if args.fp16:
compression_dtype = torch.float16
elif args.bf16:
compression_dtype = torch.bfloat16
else:
compression_dtype = torch.float32
rank = _get_find_rank(args) if args.find_rank_upper_limit else _get_adaptive_rank(args)
if args.pre_rank is not None:
if args.pre_rank == rank:
args.compressor.begin_iteration(args.curr_iteration)
return
args.pre_rank = rank
from .power_sgd import PowerSGDCompressor
args.compressor = PowerSGDCompressor(
ef_layout_manager=args.ef_manager,
rank=rank,
compression_dtype=compression_dtype
)
args.compressor.begin_iteration(args.curr_iteration)
if args.enable_dynamic_grad_comp and not args.overlap_grad_reduce:
_handle_all_reduce_time_start(args, config)
for model_chunk in model:
if args.enable_dynamic_grad_comp:
_update_gradient_compression_state(args)
compressor_update(args)
model_chunk.finish_grad_sync()
if args.enable_dynamic_grad_comp:
if args.begin_max_rank:
args.begin_max_rank = False
if not args.overlap_grad_reduce:
_handle_all_reduce_time_end(args, config)
if args.enable_dynamic_grad_comp:
if args.all_reduce_time:
args.params_all_reduce_time = config.timers('DP_time').elapsed(reset=True) * 1000.0
if config.timers is not None:
config.timers('all-grads-sync').stop()
# All-reduce t_embedder grads (for pp & vpp of DiT).
if config.timers is not None:
config.timers('conditional-embedder-grads-all-reduce', log_level=1).start(
barrier=config.barrier_with_L1_time
)
_allreduce_conditional_embedding_grads(model, config, pp_group)
if config.timers is not None:
config.timers('conditional-embedder-grads-all-reduce').stop()
# All-reduce layer-norm grads (for sequence parallelism) and non-tensor parallel modules.
if config.timers is not None:
config.timers('non-tensor-parallel-grads-all-reduce', log_level=1).start(
barrier=config.barrier_with_L1_time
)
_allreduce_non_tensor_model_parallel_grads(model, config, tp_group)
if config.timers is not None:
config.timers('non-tensor-parallel-grads-all-reduce').stop()
# All-reduce embedding grads (for pipeline parallelism).
if config.timers is not None:
config.timers('embedding-grads-all-reduce', log_level=1).start(
barrier=config.barrier_with_L1_time
)
_allreduce_word_embedding_grads(model, config, embd_group, pp_group)
_allreduce_position_embedding_grads(model, config, pos_emb_group, pp_group)
if config.timers is not None:
config.timers('embedding-grads-all-reduce').stop()
if config.moe_router_enable_expert_bias:
_update_router_expert_bias(model, config)
reset_model_temporary_tensors(config, model)
# Normalize gradients for per-token loss normalization.
# If we are normalizing by the number of tokens, we use that as the divisor; this
# number is the total number of non-padded tokens in the global batch.
if num_tokens is not None:
# the number of tokens is only present on the last stage, so broadcast it
# to the other ranks in the pipeline parallel group.
assert not isinstance(pp_group, list)
last_rank = get_pp_last_rank(pp_group)
torch.distributed.broadcast(num_tokens, src=last_rank, group=pp_group)
# all-reduce across DP ranks.
torch.distributed.all_reduce(num_tokens, group=dp_cp_group)
for model_chunk in model:
if num_tokens > 0:
scaling = 1.0 / num_tokens
model_chunk.scale_gradients(scaling)