Binary files ./megatron/core/datasets/helpers.cpython-310-x86_64-linux-gnu.so and ../megatron-lm/megatron/core/datasets/helpers.cpython-310-x86_64-linux-gnu.so differ
Binary files ./megatron/core/datasets/__pycache__/blended_dataset.cpython-310.pyc and ../megatron-lm/megatron/core/datasets/__pycache__/blended_dataset.cpython-310.pyc differ
Binary files ./megatron/core/datasets/__pycache__/blended_megatron_dataset_builder.cpython-310.pyc and ../megatron-lm/megatron/core/datasets/__pycache__/blended_megatron_dataset_builder.cpython-310.pyc differ
Binary files ./megatron/core/datasets/__pycache__/blended_megatron_dataset_config.cpython-310.pyc and ../megatron-lm/megatron/core/datasets/__pycache__/blended_megatron_dataset_config.cpython-310.pyc differ
Binary files ./megatron/core/datasets/__pycache__/gpt_dataset.cpython-310.pyc and ../megatron-lm/megatron/core/datasets/__pycache__/gpt_dataset.cpython-310.pyc differ
Binary files ./megatron/core/datasets/__pycache__/indexed_dataset.cpython-310.pyc and ../megatron-lm/megatron/core/datasets/__pycache__/indexed_dataset.cpython-310.pyc differ
Binary files ./megatron/core/datasets/__pycache__/__init__.cpython-310.pyc and ../megatron-lm/megatron/core/datasets/__pycache__/__init__.cpython-310.pyc differ
Binary files ./megatron/core/datasets/__pycache__/megatron_dataset.cpython-310.pyc and ../megatron-lm/megatron/core/datasets/__pycache__/megatron_dataset.cpython-310.pyc differ
Binary files ./megatron/core/datasets/__pycache__/megatron_tokenizer.cpython-310.pyc and ../megatron-lm/megatron/core/datasets/__pycache__/megatron_tokenizer.cpython-310.pyc differ
Binary files ./megatron/core/datasets/__pycache__/utils.cpython-310.pyc and ../megatron-lm/megatron/core/datasets/__pycache__/utils.cpython-310.pyc differ
Binary files ./megatron/core/datasets/__pycache__/utils_s3.cpython-310.pyc and ../megatron-lm/megatron/core/datasets/__pycache__/utils_s3.cpython-310.pyc differ
diff -rN ./megatron/core/dist_checkpointing/dict_utils.py ../megatron-lm/megatron/core/dist_checkpointing/dict_utils.py
10c10
< from typing import Any, Callable, Dict, Iterable, List, Tuple, TypeVar, Union
---
> from typing import Any, Callable, Iterable, Optional, Tuple, Union
12d11
< import numpy as np
15,16d13
< U, V = TypeVar("U"), TypeVar("V")
<
51c48
< elif isinstance(x, list): # type: ignore
---
> elif isinstance(x, list):
94,95c91
< elif isinstance(x1, list) or isinstance(x1, tuple) or isinstance(x1, np.ndarray):
< assert type(x1) == type(x2)
---
> elif isinstance(x1, list) and isinstance(x2, list):
108,114d103
< # TODO: change with concrete type that has both replica_id and data attrs
< elif hasattr(x1, 'replica_id') and hasattr(x2, 'replica_id'):
< assert type(x1) == type(x2)
< only_left, only_right, mismatch = diff(
< x1.data, x2.data, prefix + (type(x1),)
< ) # type: ignore
< _is_mismatch = False
148c137
< except:
---
> except Exception:
187c176
< def dict_list_map_inplace(f: Callable[[U], V], x: Union[Dict, List, U]):
---
> def dict_list_map_inplace(f: Callable, x: Union[dict, list]):
199c188
< def dict_list_map_outplace(f: Callable[[U], V], x: Union[Dict, List, U]) -> Union[Dict, List, V]:
---
> def dict_list_map_outplace(f: Callable, x: Union[dict, list]):
209c198
< def merge(x1: Union[dict, list], x2: Union[dict, list], key: Tuple[Union[str, int], ...] = ()):
---
> def merge(x1: dict, x2: dict, key: Tuple[str, ...]
= ()): 220,221c209 < f'Cannot merge two lists with different lengths ({len(x1)} and {len(x2)}, ' < f'encountered at level {key})' --- > f'Cannot merge two lists with different lengths ({len(x1)} and {len(x2)}, encountered at level {key})' 227,228c215 < f'Duplicate non-dict and non-list values encountered: `{x1}` and `{x2}` ' < f'(at level {key})' --- > f'Duplicate non-dict and non-list values encountered: `{x1}` and `{x2}` (at level {key})' diff -rN ./megatron/core/dist_checkpointing/exchange_utils.py ../megatron-lm/megatron/core/dist_checkpointing/exchange_utils.py 1,519d0 < # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. < < """Utilities for exchanging data between ranks.""" < < import logging < from collections import defaultdict < from functools import reduce < from itertools import zip_longest < from time import time < from typing import Dict, List, NamedTuple, Optional, Set, Tuple, TypeVar, cast < < import numpy as np < import torch < < from .core import CheckpointingException < from .dict_utils import nested_values < from .mapping import ShardedStateDict, ShardedTensor, is_main_replica < from .utils import _sharded_tensor_shard_id, _ShardId < < # TODO: remove TE references once the TE bug is fixed < # Check if Transformer Engine has Float8Tensor class < HAVE_TE_FLOAT8TENSOR = False < try: < from transformer_engine.pytorch.float8_tensor import Float8Tensor < < HAVE_TE_FLOAT8TENSOR = True < except (ImportError, ModuleNotFoundError): < # Float8Tensor not found < pass < < < def is_float8tensor(tensor: torch.Tensor) -> bool: < """Check if a tensor is a Transformer Engine Float8Tensor""" < return HAVE_TE_FLOAT8TENSOR and isinstance(tensor, Float8Tensor) < < < logger = logging.getLogger(__name__) < < < class ShardDistribution(NamedTuple): < """Represents a distribution of ShardedTensors. < < Given distribution is valid only for a specific parallelization group, < which is implicit here (not referenced by this class). < < Args: < main_rank_for_shard (Dict[_ShardId, int]): specifies which rank should hold < the main replica for a given shard < shards_in_this_group (Set[_ShardId]): which shards have a main replica < in this parallelization group < shard_to_metadata (Dict[_ShardId, ShardedTensor]): maps ShardedTensor < identifier to the original ShardedTensor < all_ranks_for_shard (Dict[_ShardId, List[int]]): specifies which ranks < need a given shard in a given parallelization group < < """ < < main_rank_for_shard: Dict[_ShardId, int] < shards_in_this_group: Set[_ShardId] < shard_to_metadata: Dict[_ShardId, ShardedTensor] < all_ranks_for_shard: Dict[_ShardId, List[int]] < < < def _shard_size(sh_ten: ShardedTensor): < """Returns size in bytes of a given sharded tensor.""" < if sh_ten.flattened_range is None: < numel = np.product(sh_ten.local_shape) < else: < numel = sh_ten.flattened_range.stop - sh_ten.flattened_range.start < return numel * torch._utils._element_size(sh_ten.dtype) < < < def _get_empty_tensor_for_exchange( < shard_id: _ShardId, < needed_shards: Dict[_ShardId, ShardedTensor], < unneeded_shards: Dict[_ShardId, ShardedTensor], < loaded_tensors: Dict[_ShardId, torch.Tensor], < ) -> Tuple[torch.Tensor, Optional[torch.device]]: < """Determines the empty tensor to use for exchange. < < If shard_id is needed by this rank, it will be in the `unloaded_shards`. 
< Otherwise, the metadata for this tensor can be found in `shard_to_metadata` < < Args: < shard_id (_ShardId): shard_id that will be exchanged < needed_shards (Dict[_ShardId, ShardedTensor]): mapping from shard ids < to metadata for shards needed by this rank < unneeded_shards (Dict[_ShardId, ShardedTensor]): mapping from shard ids < to metadata for shards that can be discarded after exchange < loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping where useful tensors < are placed in < < Returns: < Tuple[torch.Tensor, Optional[torch.device]]: empty CUDA tensor to be exchanged, < and the device of the original state dict tensor (if there was any) < """ < local_unloaded_sh_ten = needed_shards.get(shard_id) < if local_unloaded_sh_ten is None: < orig_device = None # this tensor will be discarded anyway < sh_ten = unneeded_shards[shard_id] < if sh_ten.data is None: < sh_ten.init_data('cuda') < tensor = sh_ten.data < sh_ten.data = None # won't be used. free memory < else: < tensor = sh_ten.data < if tensor.device.type == 'cpu': < tensor = torch.empty_like(tensor, device='cuda') < else: < local_unloaded_sh_ten.init_data('cuda') < orig_device = local_unloaded_sh_ten.data.device < tensor = local_unloaded_sh_ten.data < if tensor.device.type == 'cpu': < tensor = torch.empty_like(tensor, device='cuda') < loaded_tensors[shard_id] = tensor < return tensor, orig_device < < < T = TypeVar('T') < < < def distribute_shards_to_ranks( < shard_to_ranks: Dict[T, List[int]], shard_to_size: Dict[T, int], num_ranks: int < ) -> Dict[T, int]: < """Computes uniform distribution of workload across ranks, based on sizes. < < Currently, the assignment is greedy, based on: < 1. Firstly, the coverage of each shard < (how many ranks the shard is available on; lower coverage is assigned first) < 2. Secondly, the size of each shard (larger size is assigned first) < 3. Finally, shard id for differentiation. < < Third step is added because we rely on the fact that < the assignment is deterministic on all ranks. < < Args: < shard_to_ranks (Dict[T, List[int]]): mapping of rank access to shards < shard_to_size (Dict[T, int]): sizes of each shard < num_ranks (int): number of ranks in the parallelization group < < Returns (Dict[T, int]): assignment of shard to rank (which rank should do the work < to achieve maximal uniformity) < """ < shard_to_ranks = {k: tuple(v) for k, v in shard_to_ranks.items()} < shard_to_saving_rank = {} < rank_sizes = [(0, rank) for rank in range(num_ranks)] < < # start from tensors of lowest coverage, then go by tensor size from largest (hence minus size) < for shard_id, shard_ranks in sorted( < shard_to_ranks.items(), < key=lambda sh_id_ranks: ( < len(sh_id_ranks[1]), < -shard_to_size[sh_id_ranks[0]], < sh_id_ranks[0], < ), < ): < # assign greedily to the least occupied rank < size, rank = min((size, rank) for size, rank in rank_sizes if rank in shard_ranks) < < shard_to_saving_rank[shard_id] = rank < rank_sizes[rank] = (size + shard_to_size[shard_id], rank) < < logger.debug(f'distribute_shards_to_ranks distribution: {rank_sizes}') < < return shard_to_saving_rank < < < def determine_main_replica_uniform_distribution( < sharded_state_dict: ShardedStateDict, < parallelization_group: torch.distributed.ProcessGroup, < ignore_groups: bool = False, < ) -> Optional[ShardDistribution]: < """Computes the save distribution. < < Should be used in conjunction with `distribute_main_replicas_with_precomputed_distribution` < which applies the computed save distribution. 
< < We rely on the fact that the assignment algorithm is deterministic on all ranks, < so there is no extra communication needed after metadata exchange. < < Args: < sharded_state_dict (ShardedStateDict): state dict to compute the distribution of < parallelization_group (ProcessGroup): distribution will be computed < within this process group < ignore_groups (bool, optional): whether the distribution defines groups. < This option is primarily used during loading, as it ensures that all replicas, < including non-main ones, are loaded by this parallelization group < Defaults to False. < < Returns (ShardDistribution, optional): distribution that can be used to apply the < parallelization. Returns None if the process_group is trivial (1 rank) < < """ < group_size = torch.distributed.get_world_size(group=parallelization_group) < if group_size <= 1: < return < local_shards = list( < sh_base < for sh_base in nested_values(sharded_state_dict) < if isinstance(sh_base, ShardedTensor) < ) < local_shards_no_data = [ten.without_data() for ten in local_shards] < < all_shards = [None] * torch.distributed.get_world_size(group=parallelization_group) < torch.distributed.all_gather_object( < all_shards, local_shards_no_data, group=parallelization_group < ) < < shard_to_ranks = defaultdict(list) < shard_to_size = {} < shard_to_metadata = {} < shards_in_this_parallelization_group: Set[_ShardId] = set() < for rank, rank_shards in enumerate(all_shards): < for sh_ten in rank_shards: < shard_id = _sharded_tensor_shard_id(sh_ten) < shard_to_ranks[shard_id].append(rank) < if shard_id not in shard_to_size: < shard_to_size[shard_id] = _shard_size(sh_ten) < shard_to_metadata[shard_id] = sh_ten < if is_main_replica(sh_ten.replica_id) or ignore_groups: < shards_in_this_parallelization_group.add(shard_id) < < shard_to_ranks = { < k: v for k, v in shard_to_ranks.items() if k in shards_in_this_parallelization_group < } < < shard_to_saving_rank = distribute_shards_to_ranks( < shard_to_ranks, shard_to_size, len(all_shards) < ) < < return ShardDistribution( < shard_to_saving_rank, < shards_in_this_parallelization_group, < shard_to_metadata, < shard_to_ranks, < ) < < < @torch.no_grad() < def exchange_loaded_tensors_gather_rounds( < loaded_tensors: Dict[_ShardId, torch.Tensor], < unloaded_shards: Dict[_ShardId, ShardedTensor], < shard_distribution: ShardDistribution = None, < parallelization_group: Optional[torch.distributed.ProcessGroup] = None, < ) -> Dict[_ShardId, torch.Tensor]: < """Exchange the tensors loaded by different ranks with several all_gather calls. < < Groups tensors by dtype, divide tensors that will be exchanged into rounds < and execute all_gather for tensors from each round. < < Note: the loading is distributed across ranks based on total loaded size < in bytes, so there is no guarantee that number of rounds needed for each < rank will be similar, which might result in a lot of almost empty < all_gathers. The solution would be to group all tensors into a one < bytes tensor and do a single all_gather (with similarly sized messages). < < Args: < loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor < shard ids to tensors already loaded by this rank. < unloaded_shards (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor < shard ids to ShardedTensors that aren't loaded yet. < shard_distribution (ShardDistribution): distribution of all shards < parallelization_group (ProcessGroup, optional): process group used for load < distribution. 
Tensors will be exchanged within this group < < Returns: < Dict[_ShardId, torch.Tensor]: dictionary mapping shard ids to tensors < needed by this rank to load a given state dict. Includes < previously loaded tensors (from `loaded_tensors` input) < """ < main_rank_for_shard, _, shard_to_metadata, all_ranks_for_shard = shard_distribution < local_rank = torch.distributed.get_rank(group=parallelization_group) < < all_loaded_tensors = dict(loaded_tensors) < < # Group by dtype so that we all_gather tensors of the same dtype < for dtype in sorted(set(map(lambda sh_ten: sh_ten.dtype, shard_to_metadata.values())), key=str): < < start = time() < # shards_by_rank maps rank to tensors loaded by this rank < shards_by_rank: List[List[torch.Tensor]] = [ < [] for _ in range(torch.distributed.get_world_size(group=parallelization_group)) < ] < for shard_id, rank in main_rank_for_shard.items(): < if len(all_ranks_for_shard[shard_id]) == 1: < assert all_ranks_for_shard[shard_id][0] == main_rank_for_shard[shard_id], ( < f'When there is only 1 ranks that needs a given shard,' < f' it should be the loading rank.' < f' Got: needs [{all_ranks_for_shard[shard_id][0]}]' < f' vs loads [{main_rank_for_shard[shard_id]}]' < ) < # Skipping the exchange since only the loading rank needs this tensor < # TODO: we can employ some optimizations even for `len(shard_to_ranks) > 1` < # case, e.g. P2P exchange. Currently handling this case saves most of the < # work though. < continue < if shard_to_metadata[shard_id].dtype == dtype: < shards_by_rank[rank].append(shard_id) < < # Transpose `shards_by_rank` to form exchange rounds < shards_by_round = zip_longest(*shards_by_rank, fillvalue=None) < for round_idx, round_shard_ids in enumerate(shards_by_round): < round_tensors = [] < orig_devices = {} < for rank, shard_id in enumerate(round_shard_ids): < if shard_id is None: < # if no more useful data, the given rank will exchange empty tensor < local_ten = torch.empty(0, dtype=dtype, device='cuda') < orig_device = None < else: < assert isinstance(shard_id, tuple), type(shard_id) < if rank == local_rank: < assert shard_id in all_loaded_tensors, (shard_id, all_loaded_tensors.keys()) < orig_device = all_loaded_tensors[shard_id] < all_loaded_tensors[shard_id] = all_loaded_tensors[shard_id].cuda() < local_ten = all_loaded_tensors[shard_id] < else: < local_ten, orig_device = _get_empty_tensor_for_exchange( < shard_id, unloaded_shards, shard_to_metadata, all_loaded_tensors < ) < # Because of a TE bug, we have to exchange a nominal dtype instead of FP8 < # It's ok to keep the nominal dtype after exchange, because TE will handle < # this during state dict load. 
< # TODO: remove it once the bug is fixed < if is_float8tensor(local_ten): < local_ten = local_ten.from_float8() < all_loaded_tensors[shard_id] = local_ten < < round_tensors.append(local_ten) < if orig_device is not None: < orig_devices[shard_id] = orig_device < < torch.distributed.all_gather( < list(round_tensors), < round_tensors[local_rank], < group=parallelization_group, < async_op=False, < ) < < # Move tensors back to CPU if originally was on CPU < for shard_id, orig_device in orig_devices.items(): < all_loaded_tensors[shard_id] = all_loaded_tensors[shard_id].to(orig_device) < < del round_tensors # remove tensor references < < end = time() < if torch.distributed.get_rank() == 0: < logger.debug(f'{dtype} exchange rounds all_gather schedule took {end - start}s') < < return all_loaded_tensors < < < def exchange_loaded_tensors_gather_object( < loaded_tensors: Dict[_ShardId, torch.Tensor], < unloaded_shards: Dict[_ShardId, ShardedTensor], < shard_distribution: ShardDistribution, < parallelization_group: Optional[torch.distributed.ProcessGroup] = None, < ) -> Dict[_ShardId, torch.Tensor]: < """Exchange the tensors loaded by different ranks with a simple all_gather_object call. < < This version can be used for debugging purposes do to its simplistic < implementation. Shouldn't be used if performance is important. < < Args: < loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor < shard ids to tensors already loaded by this rank. < unloaded_shards (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor < shard ids to ShardedTensors that aren't loaded yet. < shard_distribution (ShardDistribution): distribution of all shards < parallelization_group (ProcessGroup, optional): process group used for load < distribution. Tensors will be exchanged within this group < < Returns: < Dict[_ShardId, torch.Tensor]: dictionary mapping shard ids to tensors < needed by this rank to load a given state dict. Includes < previously loaded tensors (from `loaded_tensors` input) < < """ < all_loaded_tensors_list = [None] * torch.distributed.get_world_size(group=parallelization_group) < torch.distributed.all_gather_object( < all_loaded_tensors_list, loaded_tensors, group=parallelization_group < ) < all_loaded_tensors_list = cast(List[Dict[_ShardId, torch.Tensor]], all_loaded_tensors_list) < all_loaded_tensors = reduce(lambda x, y: {**x, **y}, all_loaded_tensors_list) < < # Error checks < if len(all_loaded_tensors) != sum(map(len, all_loaded_tensors_list)): < err_msg = 'Duplicate shard ids loaded by different ranks' < if torch.distributed.get_rank() == 0: < logger.error( < f'{err_msg}. Shards ids by rank:' < f' {[lt.keys() for lt in all_loaded_tensors_list]}' < ) < raise CheckpointingException(err_msg) < < return all_loaded_tensors < < < @torch.no_grad() < def exchange_loaded_tensors_broadcast( < loaded_tensors: Dict[_ShardId, torch.Tensor], < unloaded_shards: Dict[_ShardId, ShardedTensor], < shard_distribution: ShardDistribution, < parallelization_group: Optional[torch.distributed.ProcessGroup] = None, < ) -> Dict[_ShardId, torch.Tensor]: < """Exchange the tensors loaded by different ranks by a series of broadcasts. < < For each rank for each loaded tensor do a broadcast to the whole group. < A reasonable tradeoff in terms of performance and simplicity. < < Args: < loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor < shard ids to tensors already loaded by this rank. 
< unloaded_shards (Dict[_ShardId, ShardedTensor]): mapping from ShardedTensor < shard ids to ShardedTensors that aren't loaded yet. < shard_distribution (ShardDistribution): distribution of all shards < parallelization_group (ProcessGroup, optional): process group used for load < distribution. Tensors will be exchanged within this group < < Returns: < Dict[_ShardId, torch.Tensor]: dictionary mapping shard ids to tensors < needed by this rank to load a given state dict. Includes < previously loaded tensors (from `loaded_tensors` input) < """ < main_rank_for_shard, _, shard_to_metadata, all_ranks_for_shard = shard_distribution < local_rank = torch.distributed.get_rank(group=parallelization_group) < < all_loaded_tensors = dict(loaded_tensors) < < start = time() < < for idx, (shard_id, rank) in enumerate(main_rank_for_shard.items()): < if len(all_ranks_for_shard[shard_id]) == 1: < assert all_ranks_for_shard[shard_id][0] == main_rank_for_shard[shard_id], ( < f'When there is only 1 ranks that needs a given shard,' < f' it should be the loading rank.' < f'Got: needs [{all_ranks_for_shard[shard_id][0]}]' < f' vs loads [{main_rank_for_shard[shard_id]}]' < ) < # Skipping the exchange since only the loading rank needs this tensor < # TODO: we can employ some optimizations even for `len(shard_to_ranks) > 1` case, < # e.g. P2P exchange. Currently handling this case saves most of the work though. < continue < if rank == local_rank: < assert shard_id in all_loaded_tensors, (shard_id, all_loaded_tensors.keys()) < orig_device = all_loaded_tensors[shard_id].device < local_ten = all_loaded_tensors[shard_id].cuda() < else: < local_ten, orig_device = _get_empty_tensor_for_exchange( < shard_id, unloaded_shards, shard_to_metadata, all_loaded_tensors < ) < < # Because of a TE bug, we have to exchange a nominal dtype instead of FP8 < # It's ok to keep the nominal dtype after exchange, because TE will handle < # this during state dict load. < # TODO: remove it once the bug is fixed < if is_float8tensor(local_ten): < local_ten = local_ten.from_float8() < all_loaded_tensors[shard_id] = local_ten < < global_src_rank = ( < rank < if parallelization_group == None < else torch.distributed.get_global_rank(parallelization_group, rank) < ) < # We can do async_op=True only if there is no CPU-copy follow-up < torch.distributed.broadcast( < local_ten, < src=global_src_rank, < group=parallelization_group, < async_op=orig_device is None, < ) < # Move tensor back to CPU if originally was on CPU < if orig_device is not None: < all_loaded_tensors[shard_id] = local_ten.to(orig_device) < del local_ten < < end = time() < if torch.distributed.get_rank() == 0: < logger.debug(f'exchange broadcast schedule took {end - start}s') < < return all_loaded_tensors < < < def exchange_by_distribution( < loaded_tensors: Dict[_ShardId, torch.Tensor], < unloaded_shards: Dict[_ShardId, ShardedTensor], < shard_distribution: ShardDistribution = None, < parallelization_group: Optional[torch.distributed.ProcessGroup] = None, < exchange_algo='broadcast', < ) -> Dict[_ShardId, torch.Tensor]: < """Exchange tensors loaded by different ranks using the specified exchange_algo. < < Args: < loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor < shard ids to tensors already loaded by this rank. < unloaded_shards (Dict[_ShardId, ShardedTensor]): mapping from ShardedTensor < shard ids to ShardedTensors that aren't loaded yet. 
< shard_distribution (ShardDistribution): distribution of all shards < parallelization_group (ProcessGroup, optional): process group used for load < distribution. Tensors will be exchanged within this group < exchange_algo (str): The algorithm used for performing exchanges. < Defaults to 'broadcast'. < < Returns: < Dict[_ShardId, torch.Tensor]: dictionary mapping shard ids to tensors < needed by this rank to load a given state dict. Includes < previously loaded tensors (from `loaded_tensors` input) < """ < < if exchange_algo == 'gather_object': < exchange_fn = exchange_loaded_tensors_gather_object < elif exchange_algo == 'gather_rounds': < exchange_fn = exchange_loaded_tensors_gather_rounds < elif exchange_algo == 'broadcast': < exchange_fn = exchange_loaded_tensors_broadcast < else: < raise NotImplementedError(f'Unrecognized gather algorithm: {exchange_algo}') < return exchange_fn(loaded_tensors, unloaded_shards, shard_distribution, parallelization_group) diff -rN ./megatron/core/dist_checkpointing/mapping.py ../megatron-lm/megatron/core/dist_checkpointing/mapping.py 13c13 < from typing import Any, Callable, Dict, List, Optional, Tuple, Union --- > from typing import Any, Callable, Dict, Optional, Tuple, Union 19c19 < from .dict_utils import dict_list_map_inplace --- > from .dict_utils import dict_list_map_inplace, dict_list_map_outplace 32,33d31 < """Base class for ShardedTensor and ShardedStateDict.""" < 44d41 < """Returns a new ShardedBase instance with data=None.""" 61,62c58 < global_offset: offset of a local tensor in a global tensor, < specified in number of tensor elements --- > global_offset: offset of a local tensor in a global tensor, specified in number of tensor elements 64,75c60,63 < replica_id: indicates given local tensor's replication wrt. < local tensors in different processes < prepend_axis_num: number of axes prepended to the local tensor to < reflect global tensor shape. The behavior is similar to < unsqueezing the local tensor. < allow_shape_mismatch: if True, during loading, the global shape of < a stored tensor does not have to match the expected global shape. < Useful for representing tensors with flexible shape, < e.g. padded. < flattened_range: specifies a slice that should be applied to a < flattened tensor with `local_shape` in order to get < the tensor stored as `data` --- > replica_id: indicates given local tensor's replication wrt. local tensors in different processes > prepend_axis_num: number of axes prepended to the local tensor to reflect global tensor shape. The behavior is similar to unsqueezing the local tensor. > allow_shape_mismatch: if True, during loading, the global shape of a stored tensor does not have to match the expected global shape. Useful for representing tensors with flexible shape, e.g. padded. > flattened_range: specifies a slice that should be applied to a flattened tensor with `local_shape` in order to get the tensor stored as `data` 132,133c120 < f'Local shape together with `prepend_axis_num` dimensions should be ' < f'equal to global shape dimensions for {self}' --- > f'Local shape together with `prepend_axis_num` dimensions should be equal to global shape dimensions for {self}' 148,151d134 < """ < Returns a tuple of int and slice objects representing a slice of the < global tensor that this ShardedTensor corresponds to. < """ 166,169d148 < """ < Returns a tuple of np.ndarrays representing the coordinates of the global tensor < that this ShardedTensor corresponds to. 
< """ 188,191d166 < """ < Returns a tuple of np.ndarrays representing the coordinates of the local tensor < that this ShardedTensor corresponds to. < """ 217,219d191 < """ < Returns the maximum allowed chunks for this ShardedTensor. < """ 249,252c221 < rank_offsets (Tuple[int, int, int]): each tuple < (axis, axis_rank_offset, axis_fragm) says that if < global tensor is divided into `axis_fragm` fragment along `axis` < axis, then local tensor data corresponds to the `axis_rank_offset` chunk. --- > rank_offsets (Tuple[int, int, int]): each tuple (axis, axis_rank_offset, axis_fragm) says that if global tensor is divided into `axis_fragm` fragment along `axis` axis, then local tensor data corresponds to the `axis_rank_offset` chunk. 334,335c303 < f'Flattened ShardedTensor data length ({data.numel()}) must meet the ' < f'slice length: {flattened_range.stop - flattened_range.start}' --- > f'Flattened ShardedTensor data length ({data.numel()}) must meet the slice length: {flattened_range.stop - flattened_range.start}' 345,354d312 < """ < Initialize the tensor data of this ShardedTensor. < < Only called if `data` attribute is None. < < Args: < device (Union[str, torch.device]): device to place the tensor on < init_fn (Callable, optional): function to use to initialize the tensor. < Defaults to `torch.empty`. < """ 361,486d318 < def narrow(self, dim: int, start: int, length: int) -> List['ShardedTensor']: < """This is an analogue of torch.narrow for ShardedTensors. < < Narrowing assumes that we narrow a local tensor on each rank. < This has consequences on local_shape, global_shape, global_offset, etc. < < Args: < dim (int): dimension to narrow. Doesn't include prepended axes. < start (int): start element < length (int): length of the slice < < Returns: < List[ShardedTensor]: narrowed ShardedTensors. For non-flat tensors, < the list will always have 1 element. For flat ShardedTensors the number of < elements varies depending on `dim` and on overlap, because flat < tensors must be contiguous. In particular the list can be empty. < """ < prepended_dim = dim + self.prepend_axis_num < local_length_along_dim = self.local_shape[dim] < < def _update_tuple(x, ind, val): < x = list(x) < x[ind] = val < return tuple(x) < < def _safe_div(x, y): < assert x % y == 0, (x, y) < return x // y < < # Decrease global shape and global offset by `length / local_length_along_dim` < assert ( < self.global_shape[prepended_dim] % local_length_along_dim == 0 < ), f'Only regular grid of local tensors is supported for narrowing, got: {self}' < assert ( < self.global_offset[prepended_dim] % local_length_along_dim == 0 < ), f'Only regular grid of local tensors is supported for narrowing, got: {self}' < global_shape = _update_tuple( < self.global_shape, < prepended_dim, < _safe_div(self.global_shape[prepended_dim] * length, local_length_along_dim), < ) < global_offset = _update_tuple( < self.global_offset, < prepended_dim, < _safe_div(self.global_offset[prepended_dim] * length, local_length_along_dim), < ) < < if self.flattened_range is None: < new_data = self.data.narrow(dim, start, length) < # always a single result tensor < return [ < replace( < self, < data=new_data, < local_shape=new_data.shape, < global_shape=global_shape, < global_offset=global_offset, < ) < ] < else: < if dim != 0: < raise CheckpointingException( < f'Narrowing along the first axis is supported for now only, got dim={dim}' < ) < < # If dim=0, we will always get 0 or 1 resulting tensor. < # If dim>1, in general there can be more result tensors (e.g. 
max 3 for dim=1) < < # For on original flat ShardedTensor of local shape [3, 4] and < # flattened_range=slice(5, 10), < # the X signs mark the actual (flat) data in `self.data` < # notice 12 (3*4) total "virtual" elements, out of which 5 is actual data. < # flat original: [.....XXXXX..] < < # If we narrow to start=1, length=1 in the original local shape dimensions, < # the overlapping flat slice would be: < # narrow to: [....XXXX....] < # flat overlap: [.....XXX....] < < # Now `data` is flattened and sliced, so we must compute local_shape manually < local_shape = _update_tuple(self.local_shape, dim, length) < other_dims_volume = np.prod( < _update_tuple(local_shape, dim, 1) < ) # 4 in the example above < volume_before_split = other_dims_volume * start # 4 in the example above < volume_of_split = other_dims_volume * length # 4 in the example above < < flat_slice_start_shifted = ( < self.flattened_range.start - volume_before_split < ) # 5 - 4 = 1 in the example above < flat_slice_stop_shifted = ( < self.flattened_range.stop - volume_before_split < ) # 10 - 4 = 6 in the example above < < # Find an intersection of < # (flat_slice_start_shifted, flat_slice_stop_shifted) vs (0, volume_of_split) < < if flat_slice_stop_shifted <= 0 or flat_slice_start_shifted >= volume_of_split: < return [] # no intersection < < # new_flattened_range = slice(1, 4) in the example above < new_flattened_range = slice( < max(flat_slice_start_shifted, 0), min(flat_slice_stop_shifted, volume_of_split) < ) < # Apply the intersection to the flattened data tensor. < # Compute start and slice appropriate length < intersection_slice_start = ( < new_flattened_range.start - flat_slice_start_shifted < ) # 0 in the example above < new_data = self.data[ < intersection_slice_start : intersection_slice_start < + new_flattened_range.stop < - new_flattened_range.start < ] < < return [ < replace( < self, < data=new_data, < local_shape=local_shape, < global_shape=global_shape, < global_offset=global_offset, < flattened_range=new_flattened_range, < ) < ] < 521d352 < """Returns the original object.""" 568,573c399 < """returns a unique key for this object""" < return ( < f'{self.key}/shard_' < f'{".".join(map(str, self.global_offset))}_' < f'{".".join(map(str, self.global_shape))}' < ) --- > return f'{self.key}/shard_{".".join(map(str, self.global_offset))}_{".".join(map(str, self.global_shape))}' 580,590d405 < """Instantiates a ShardedObject from a unique key. < < Args: < unique_key: a string of the form < /shard__ < replica_id: indicates local object replication wrt. < local objects in different processes < < Returns: < a ShardedObject with data=None < """ 597,598c412 < # This is a backward-compatible fix. We don't know the last < # element of global shape so set it to -1. --- > # This is a backward-compatible fix. We don't know the last element of global shape so set it to -1. 603,606d416 < FactoryBuildFn = Callable[[str, torch.Tensor, ReplicaId, Optional[slice]], ShardedStateDict] < FactoryMergeFn = Callable[[StateDict], torch.Tensor] < < 622,631c432,436 < data (torch.Tensor): original model parameter that will be further < transformed by this factory < build_fn (callable): function that transforms the original tensor < to a sharded state dict < merge_fn (callable): function that transforms loaded subtree back < into a single tensor (inverse of `build_fn`) < replica_id (ReplicaId): indicates factory replication wrt. 
< factories in different processes < flattened_range (slice, optional): indicates additional flattening < applied to the ShardedTensors produced by the factory --- > data (torch.Tensor): original model parameter that will be further transformed by this factory > build_fn (callable): function that transforms the original tensor to a sharded state dict > merge_fn (callable): function that transforms loaded subtree back into a single tensor (inverse of `build_fn`) > replica_id (ReplicaId): indicates factory replication wrt. factories in different processes > flattened_range (slice, optional): indicates additional flattening applied to the ShardedTensors produced by the factory 636,637c441,442 < build_fn: FactoryBuildFn < merge_fn: FactoryMergeFn --- > build_fn: Callable[[str, torch.Tensor, ReplicaId, Optional[slice]], ShardedStateDict] > merge_fn: Callable[[StateDict], torch.Tensor] 642d446 < """Builds a ShardedStateDict from the original tensor""" 657,658c461 < sharded_state_dict (ShardedStateDict): state dict possibly < containing ShardedTensorFactory objects --- > sharded_state_dict (ShardedStateDict): state dict possibly containing ShardedTensorFactory objects 679,684c482,484 < x2 (ShardedStateDict): subset of `x1` (in terms of dict keys) < with ShardedTensorFactory < as (possibly nested) values that define how to < merge objects from the `x1` state dict < key (Tuple[str, ...]): current key in a recursive call. < Used only for reporting meaningful errors --- > x2 (ShardedStateDict): subset of `x1` (in terms of dict keys) with ShardedTensorFactory > as (possibly nested) values that define how to merge objects from the `x1` state dict > key (Tuple[str, ...]): current key in a recursive call. Used only for reporting meaningful errors 697,698c497 < f'Different dict keys encountered in `apply_factory_merges` ' < f'({x1.keys()} vs {x2.keys()})' --- > f'Different dict keys encountered in `apply_factory_merges` ({x1.keys()} vs {x2.keys()})' 704,707c503 < err_msg = ( < f'Cannot merge two lists with different lengths ' < f'({len(x1)} and {len(x2)}, encountered at key {key})' < ) --- > err_msg = f'Cannot merge two lists with different lengths ({len(x1)} and {len(x2)}, encountered at key {key})' 716,717c512 < f'Invalid dict key {k} non-integer type encountered ' < f'in a list-dict merge at level {key}' --- > f'Invalid dict key {k} non-integer type encountered in a list-dict merge at level {key}' 721,722c516 < f'Dict key {k} out of bound for list of length' < f'{len(x1)} (encountered at level {key})' --- > f'Dict key {k} out of bound for list of length {len(x1)} (encountered at level {key})' Binary files ./megatron/core/dist_checkpointing/__pycache__/core.cpython-310.pyc and ../megatron-lm/megatron/core/dist_checkpointing/__pycache__/core.cpython-310.pyc differ Binary files ./megatron/core/dist_checkpointing/__pycache__/dict_utils.cpython-310.pyc and ../megatron-lm/megatron/core/dist_checkpointing/__pycache__/dict_utils.cpython-310.pyc differ Binary files ./megatron/core/dist_checkpointing/__pycache__/__init__.cpython-310.pyc and ../megatron-lm/megatron/core/dist_checkpointing/__pycache__/__init__.cpython-310.pyc differ Binary files ./megatron/core/dist_checkpointing/__pycache__/mapping.cpython-310.pyc and ../megatron-lm/megatron/core/dist_checkpointing/__pycache__/mapping.cpython-310.pyc differ Binary files ./megatron/core/dist_checkpointing/__pycache__/optimizer.cpython-310.pyc and ../megatron-lm/megatron/core/dist_checkpointing/__pycache__/optimizer.cpython-310.pyc differ Binary files 
./megatron/core/dist_checkpointing/__pycache__/serialization.cpython-310.pyc and ../megatron-lm/megatron/core/dist_checkpointing/__pycache__/serialization.cpython-310.pyc differ Binary files ./megatron/core/dist_checkpointing/__pycache__/utils.cpython-310.pyc and ../megatron-lm/megatron/core/dist_checkpointing/__pycache__/utils.cpython-310.pyc differ Binary files ./megatron/core/dist_checkpointing/__pycache__/validation.cpython-310.pyc and ../megatron-lm/megatron/core/dist_checkpointing/__pycache__/validation.cpython-310.pyc differ diff -rN ./megatron/core/dist_checkpointing/serialization.py ../megatron-lm/megatron/core/dist_checkpointing/serialization.py 7,8c7 < Additionally, `load` expects the sharded state dict argument as a guidance for < loading the sharded tensors. --- > Additionally, `load` expects the sharded state dict argument as a guidance for loading the sharded tensors. 19c18 < from .dict_utils import extract_matching_values, merge --- > from .dict_utils import dict_list_map_inplace, extract_matching_values, merge 23a23 > ShardedTensorFactory, 24a25 > apply_factories, 27d27 < from .state_dict_transformation import load_preprocess, save_preprocess 38c38 < from .utils import extract_sharded_base --- > from .utils import extract_nonpersistent, extract_sharded_base 44a45 > validate_sharding_integrity, 81,84c82,83 < sharded_strategy (LoadShardedStrategy, Tuple[str, int], optional): < configures loading behavior for sharded tensors < common_strategy (LoadCommonStrategy, Tuple[str, int], optional): < configures loading behavior for common data --- > sharded_strategy (LoadShardedStrategy, Tuple[str, int], optional): configures loading behavior for sharded tensors > common_strategy (LoadCommonStrategy, Tuple[str, int], optional): configures loading behavior for common data 109,110c108,115 < sharded_state_dict, nonpersistent_state_dict, sh_ten_factories = load_preprocess( < sharded_state_dict --- > # Create a copy of sharded_state_dict as the passed in state dict may have > # references that prevent tensors from being deallocated > sharded_state_dict, _ = extract_matching_values(sharded_state_dict, lambda x: True) > > sh_ten_factories, _ = extract_matching_values( > sharded_state_dict, > lambda x: isinstance(x, ShardedTensorFactory), > return_lists_as_dicts=True, 111a117,123 > apply_factories(sharded_state_dict) > > # Data inside sh_ten_factories no longer needed so delete them to reduce memory usage > dict_list_map_inplace(ShardedTensorFactory.without_data, sh_ten_factories) > # Non-persistent objects > nonpersistent_state_dict, sharded_state_dict = extract_nonpersistent(sharded_state_dict) > dict_list_map_inplace(lambda o: o.unwrap(), nonpersistent_state_dict) 150,152c162 < merge(common_state_dict, loaded_state_dict) < < loaded_state_dict = apply_factory_merges(common_state_dict, sh_ten_factories) --- > loaded_state_dict = apply_factory_merges(loaded_state_dict, sh_ten_factories) 153a164 > merge(common_state_dict, loaded_state_dict) 191,192c202 < Defaults to None - in this case a default load strategy for a given checkpoint type < is used. --- > Defaults to None - in this case a default load strategy for a given checkpoint type is used. 195,196c205 < CkptShardedMetadata: flat state dict without data describing ShardedTensors < in the checkpoint --- > CkptShardedMetadata: flat state dict without data describing ShardedTensors in the checkpoint 226,227c235 < Defaults to None - in this case a default load strategy for a given checkpoint type < is used. 
--- > Defaults to None - in this case a default load strategy for a given checkpoint type is used. 229,230c237,238 < Defaults to None - in this case a default load strategy for a given checkpoint type is < used. This strategy won't be used unless `sharded_strategy` can't handle ShardedObjects --- > Defaults to None - in this case a default load strategy for a given checkpoint type is used. > This strategy won't be used unless `sharded_strategy` can't handle ShardedObjects 318,321c326,327 < sharded_strategy (SaveShardedStrategy, Tuple[str, int], optional): < configures sharded tensors saving behavior and backend < common_strategy (SaveCommonStrategy, Tuple[str, int], optional): < configures common data saving behavior and backend --- > sharded_strategy (SaveShardedStrategy, Tuple[str, int], optional): configures sharded tensors saving behavior and backend > common_strategy (SaveCommonStrategy, Tuple[str, int], optional): configures common data saving behavior and backend 362c368,370 < sharded_state_dict, state_dict = save_preprocess(sharded_state_dict, validate_access_integrity) --- > apply_factories(sharded_state_dict) > _, sharded_state_dict = extract_nonpersistent(sharded_state_dict) > sharded_state_dict, state_dict = extract_sharded_base(sharded_state_dict) 365a374,376 > if validate_access_integrity: > validate_sharding_integrity(determine_global_metadata(sharded_state_dict)[1]) > 398d408 < """Get default save sharded strategy.""" 405d414 < """Get default save common strategy.""" 410d418 < """Get default load sharded strategy.""" diff -rN ./megatron/core/dist_checkpointing/state_dict_transformation.py ../megatron-lm/megatron/core/dist_checkpointing/state_dict_transformation.py 1,253d0 < # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. < < """ Utilities for transforming state_dict, including a tensor-aware implementation.""" < < import logging < from time import time < from typing import Any, Optional < < import torch < < from .dict_utils import dict_list_map_inplace, extract_matching_values, merge, nested_values < from .exchange_utils import determine_main_replica_uniform_distribution, exchange_by_distribution < from .mapping import ( < ShardedObject, < ShardedStateDict, < ShardedTensor, < ShardedTensorFactory, < apply_factories, < apply_factory_merges, < ) < from .utils import ( < _sharded_object_id, < _sharded_tensor_shard_id, < extract_nonpersistent, < extract_sharded_base, < ) < from .validation import determine_global_metadata, validate_sharding_integrity < < logger = logging.getLogger(__name__) < < < def save_preprocess(sharded_state_dict: ShardedStateDict, validate_access_integrity: bool = True): < """Preprocesses the given state dictionary by applying factories, < discarding non-persistent data and extracting the common state dictionary. < Optionally, it can validate sharding integrity. < < Args: < sharded_state_dict (ShardedStateDict): The initial state dictionary to be preprocessed. < validate_access_integrity (bool): If True, triggers validation of sharding integrity. < < Returns: < Tuple[ShardedStateDict, dict]: < The preprocessed sharded state dictionary and the common state dictionary. 
< """ < apply_factories(sharded_state_dict) < _, sharded_state_dict = extract_nonpersistent(sharded_state_dict) < sharded_part, common_state_dict = extract_sharded_base(sharded_state_dict) < if validate_access_integrity: < validate_sharding_integrity(determine_global_metadata(sharded_part)[1]) < return sharded_part, common_state_dict < < < def load_preprocess(sharded_state_dict: ShardedStateDict): < """Preprocesses the given state dictionary by applying factories < and extracting non-persistent data, without modifying the original dictionary. < < Args: < sharded_state_dict (ShardedStateDict): < The initial state dictionary to be processed (remains unchanged). < < Returns: < Tuple[ShardedStateDict, dict, dict]: < - A preprocessed copy of the sharded state dictionary. < - A dictionary containing non-persistent state data. < - A dictionary of `ShardedTensorFactory` instances. < """ < # Create a copy of sharded_state_dict as the passed in state dict may have < # references that prevent tensors from being deallocated < sharded_state_dict, _ = extract_matching_values(sharded_state_dict, lambda x: True) < < sh_ten_factories, _ = extract_matching_values( < sharded_state_dict, < lambda x: isinstance(x, ShardedTensorFactory), < return_lists_as_dicts=True, < ) < apply_factories(sharded_state_dict) < < # Data inside sh_ten_factories no longer needed so delete them to reduce memory usage < dict_list_map_inplace(ShardedTensorFactory.without_data, sh_ten_factories) < # Non-persistent objects < nonpersistent_state_dict, sharded_state_dict = extract_nonpersistent(sharded_state_dict) < dict_list_map_inplace(lambda o: o.unwrap(), nonpersistent_state_dict) < return sharded_state_dict, nonpersistent_state_dict, sh_ten_factories < < < def prepare_state_dict_for_save( < sharded_state_dict: ShardedStateDict, < async_prepare: bool = False, < algo: str = 'atomic', < validate_access_integrity: bool = True, < parallelization_group: Optional[torch.distributed.ProcessGroup] = None, < to_cpu: bool = True, < ): < """Creates a tensor-aware state dictionary that can be saved using the Local Checkpoint Manager. < < Args: < sharded_state_dict (ShardedStateDict): The initial state dictionary. < async_prepare (bool): If True, enables asynchronous preparation. < algo (str): The algorithm used to create the tensor-aware state dictionary. < validate_access_integrity (bool): If True, validates sharding integrity. < parallelization_group (torch.distributed.ProcessGroup): < The process group used for exchanges to avoid duplications. < to_cpu (bool): If True, moves all tensors from device to CPU. < < Returns: < ShardedStateDict: The tensor-aware state dictionary. < """ < < _start = time() < < if async_prepare: < raise NotImplementedError('Async state_dict preparation is not yet implemented') < if algo != 'atomic' and algo != 'fully_parallel': < raise NotImplementedError( < 'Only "atomic" and "fully_parallel" sharding algorithms are supported.' 
< ) < fully_parallel = algo == 'fully_parallel' < < sharded_part, common_state_dict = save_preprocess(sharded_state_dict, validate_access_integrity) < sharded_tensors = [] < sharded_objects = [] < for sh_base in nested_values(sharded_part): < if isinstance(sh_base, ShardedTensor): < sharded_tensors.append(sh_base) < else: < assert isinstance(sh_base, ShardedObject) < sharded_objects.append(sh_base) < if fully_parallel: < shard_to_saving_rank, _, shard_to_metadata = determine_main_replica_uniform_distribution( < sharded_part, parallelization_group, True < ) < < raw_tensors, raw_objects = {}, {} < for ten in sharded_tensors: < shard_id = _sharded_tensor_shard_id(ten) < if not fully_parallel or shard_to_saving_rank[shard_id] == torch.distributed.get_rank(): < # TODO cover creating copies on host in CheckpointManager.save() < if to_cpu: < raw_tensors[shard_id] = ten.data.to("cpu", non_blocking=True) < else: < raw_tensors[shard_id] = ten.data < ten.data = None < for obj in sharded_objects: < raw_objects[_sharded_object_id(obj)] = obj.data < obj.data = None < < logger.debug(f'prepare_state_dict_for_save took {time() - _start}') < < state_dict_for_save = { < 'raw_tensors': raw_tensors, < 'raw_objects': raw_objects, < 'common': common_state_dict, < 'sharded_state_dict': sharded_part, < } < if fully_parallel: < state_dict_for_save['shard_to_rank'] = shard_to_saving_rank < state_dict_for_save['shard_to_metadata'] = shard_to_metadata < return state_dict_for_save < < < def recreate_state_dict_after_load( < sharded_state_dict: ShardedStateDict, < loaded_state_dict: ShardedStateDict, < algo: str = 'atomic', < exchange_algo: str = 'broadcast', < validate_access_integrity: bool = True, < parallelization_group: Optional[torch.distributed.ProcessGroup] = None, < ): < """Creates a final sharded state dictionary from a tensor-aware state dictionary. < < Args: < sharded_state_dict (ShardedStateDict): < The initial sharded state dictionary generated from the model. < loaded_state_dict (ShardedStateDict): < Tensor-aware state dictionary used to fill in missing data in the sharded state. < algo (str): The algorithm used to reconstruct the state dictionary < from the tensor-aware state dictionary. < exchange_algo (str): The algorithm used for tensor exchanges during retrieval. < validate_access_integrity (bool): If True, performs validation of sharding integrity. < parallelization_group (torch.distributed.ProcessGroup): < The process group used for efficient exchanges during retrieval. < < Returns: < ShardedStateDict: The finalized sharded state dictionary. < """ < < if algo != 'atomic' and algo != 'fully_parallel': < raise NotImplementedError( < 'Only "atomic" and "fully_parallel" sharding algorithms are supported.' 
< ) < fully_parallel = algo == 'fully_parallel' < < # __adding__ common part < recreated_state_dict, _ = extract_matching_values(loaded_state_dict["common"], lambda x: True) < < if not sharded_state_dict: < return recreated_state_dict < # TODO validate laoded_state_dict["sharded_state_dict"] and sharded_state_dict are compatible < < sharded_state_dict, nonpersistent_state_dict, sh_ten_factories = load_preprocess( < sharded_state_dict < ) < # __adding__ nonpersistent part < merge(recreated_state_dict, nonpersistent_state_dict) < < sharded_part, _ = extract_sharded_base(sharded_state_dict) < if validate_access_integrity: < validate_sharding_integrity(determine_global_metadata(sharded_part)[1]) < < # load sharded tensors and sharded objects to sharded_part < loaded_tensors = loaded_state_dict['raw_tensors'] < # TODO cover restoring the original device (H2D) in CheckpointManager.load() < for k, v in loaded_tensors.items(): < loaded_tensors[k] = v.cuda() # H2D < if fully_parallel: < distribution = ( < loaded_state_dict['shard_to_rank'], < None, < loaded_state_dict['shard_to_metadata'], < ) < unloaded_shards = {} < for sh_base in nested_values(sharded_part): < if isinstance(sh_base, ShardedTensor): < shard_id = _sharded_tensor_shard_id(sh_base) < if shard_id not in loaded_tensors: < unloaded_shards[shard_id] = sh_base < loaded_tensors = exchange_by_distribution( < loaded_tensors, unloaded_shards, distribution, parallelization_group, exchange_algo < ) < loaded_objects = loaded_state_dict['raw_objects'] < < def load_sharded_base(x: Any): < if isinstance(x, ShardedTensor): < shard_id = _sharded_tensor_shard_id(x) < if shard_id not in loaded_tensors: < raise Exception( < 'The current local checkpoint implementation assumes' < 'consistent tensor sharding during load and save operations.' < f'However, the expected shard {x} (ID: {shard_id})' < f'was not found in the checkpoint. (IDs: {loaded_tensors.keys()})' < ) < x = loaded_tensors[shard_id] < if isinstance(x, ShardedObject): < object_id = _sharded_object_id(x) < assert object_id in loaded_objects, (x, object_id, loaded_objects.keys()) < x = loaded_objects[object_id] < return x < < dict_list_map_inplace(load_sharded_base, sharded_part) < sharded_part = apply_factory_merges(sharded_part, sh_ten_factories) < # __adding__ sharded_part < merge(recreated_state_dict, sharded_part) < return recreated_state_dict diff -rN ./megatron/core/dist_checkpointing/strategies/base.py ../megatron-lm/megatron/core/dist_checkpointing/strategies/base.py 9c9 < from typing import Any, DefaultDict, Union --- > from typing import Any, DefaultDict 16,17d15 < """Specifies save vs load and sharded vs common action.""" < 23a22 > _import_trigger = None 34,39c33,34 < from .tensorstore import register_default_tensorstore_strategies < < register_default_tensorstore_strategies() < from .zarr import register_default_zarr_strategies < < register_default_zarr_strategies() --- > from .tensorstore import _import_trigger > from .zarr import _import_trigger 42,44c37 < from .torch import register_default_torch_strategies < < register_default_torch_strategies() --- > from .torch import _import_trigger 58,74d50 < def register_default_strategy( < action: StrategyAction, < backend: str, < version: int, < strategy: Union['SaveStrategyBase', 'LoadStrategyBase'], < ): < """Adds a given strategy to the registry of default strategies. 
< < Args: < action (StrategyAction): specifies save/load and sharded/common < backend (str): backend that the strategy becomes a default for < version (int): version that the strategy becomes a default for < strategy (SaveStrategyBase, LoadStrategyBase): strategy to register < """ < default_strategies[action.value][(backend, version)] = strategy < < 80,81c56 < def check_backend_compatibility(self, loaded_backend): < """Verifies if this strategy is compatible with `loaded_backend`.""" --- > def check_backend_compatibility(self, loaded_version): 86d60 < """Verifies if this strategy is compatible with `loaded_version`.""" 117d90 < """Load common part of the checkpoint.""" 124d96 < """Load sharded objects from the checkpoint.""" 128d99 < """Load just the metadata from the checkpoint.""" 139d109 < """Load the sharded part of the checkpoint.""" 178d147 < """Save common part of the state dict.""" 184d152 < """Save sharded objects from the state dict.""" 193d160 < """Save the sharded part of the state dict.""" diff -rN ./megatron/core/dist_checkpointing/strategies/common.py ../megatron-lm/megatron/core/dist_checkpointing/strategies/common.py 15c15 < register_default_strategy, --- > default_strategies, 21a22,23 > _import_trigger = None > 27,34d28 < def register_default_common_strategies(): < """Register default common strategies.""" < register_default_strategy(StrategyAction.LOAD_COMMON, 'torch', 1, TorchCommonLoadStrategy()) < register_default_strategy( < StrategyAction.SAVE_COMMON, 'torch', 1, TorchCommonSaveStrategy('torch', 1) < ) < < 36,37d29 < """Common save strategy leveraging native torch save/load.""" < 39d30 < """Save common part of the state dict.""" 46c37 < """Save sharded objects from the state dict.""" --- > 54d44 < """This strategy can handle ShardedObjects.""" 59,60d48 < """Common load strategy leveraging native torch save/load.""" < 150d137 < """This strategy can handle ShardedObjects.""" 157a145,150 > > > default_strategies[StrategyAction.LOAD_COMMON.value][('torch', 1)] = TorchCommonLoadStrategy() > default_strategies[StrategyAction.SAVE_COMMON.value][('torch', 1)] = TorchCommonSaveStrategy( > 'torch', 1 > ) diff -rN ./megatron/core/dist_checkpointing/strategies/fully_parallel.py ../megatron-lm/megatron/core/dist_checkpointing/strategies/fully_parallel.py 2a3,5 > from collections import defaultdict > from functools import reduce > from itertools import zip_longest 5c8 < from typing import Dict, Optional, Tuple --- > from typing import Dict, List, NamedTuple, Optional, Set, Tuple, TypeVar, cast 6a10 > import numpy as np 18,22d21 < from megatron.core.dist_checkpointing.exchange_utils import ( < ShardDistribution, < determine_main_replica_uniform_distribution, < exchange_by_distribution, < ) 29d27 < from megatron.core.dist_checkpointing.utils import _sharded_tensor_shard_id, _ShardId 37a36,61 > # _ShardId uniquely identifies a ShardedTensor. This is a subset of ShardedTensor > # attributes: key (str), global_offset (tuple) and flattened_range (optional tuple) > _ShardId = Tuple[str, tuple, Optional[tuple]] > > > class SaveLoadDistribution(NamedTuple): > """Represents a save or load distribution of ShardedTensors. > > Given distribution is valid only for a specific parallelization group, > which is implicit here (not referenced by this class). 
> > Args: > main_rank_for_shard (Dict[_ShardId, int]): specifies which rank should hold > the main replica for a given shard > shards_in_this_group (Set[_ShardId]): which shards have a main replica > in this parallelization group > shard_to_metadata (Dict[_ShardId, ShardedTensor]): maps ShardedTensor > identifier to the original ShardedTensor > > """ > > main_rank_for_shard: Dict[_ShardId, int] > shards_in_this_group: Set[_ShardId] > shard_to_metadata: Dict[_ShardId, ShardedTensor] > > 74c98 < self.cached_distribution: Optional[ShardDistribution] = None --- > self.cached_distribution: Optional[SaveLoadDistribution] = None 172c196 < self.cached_distribution: Optional[ShardDistribution] = None --- > self.cached_distribution: Optional[SaveLoadDistribution] = None 237,242c261,271 < all_loaded_tensors = exchange_by_distribution( < loaded_tensors, < unloaded_shards, < precomputed_distribution, < self.parallelization_group, < self.exchange_algo, --- > if self.exchange_algo == 'gather_object': > exchange_fn = self.exchange_loaded_tensors_gather_object > elif self.exchange_algo == 'gather_rounds': > exchange_fn = self.exchange_loaded_tensors_gather_rounds > elif self.exchange_algo == 'broadcast': > exchange_fn = self.exchange_loaded_tensors_broadcast > else: > raise NotImplementedError(f'Unrecognized gather algorithm: {self.exchange_algo}') > > all_loaded_tensors = exchange_fn( > loaded_tensors, unloaded_shards, precomputed_distribution, self.parallelization_group 307c336 < ) -> Optional[ShardDistribution]: --- > ) -> Optional[SaveLoadDistribution]: 323c352 < ShardDistribution (optional): the computed loading distribution --- > SaveLoadDistribution (optional): the computed loading distribution 341a371,625 > def exchange_loaded_tensors_gather_object( > self, > loaded_tensors: Dict[_ShardId, torch.Tensor], > unloaded_shards: Dict[_ShardId, ShardedTensor], > precomputed_distribution: SaveLoadDistribution, > parallelization_group: Optional[torch.distributed.ProcessGroup] = None, > ) -> Dict[_ShardId, torch.Tensor]: > """Exchange the tensors loaded by different ranks with a simple all_gather_object call. > > This version can be used for debugging purposes do to its simplistic > implementation. Shouldn't be used if performance is important. > > Args: > loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor > shard ids to tensors already loaded by this rank. > unloaded_shards (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor > shard ids to ShardedTensors that aren't loaded yet. > precomputed_distribution (SaveLoadDistribution): uniform load distribution > parallelization_group (ProcessGroup, optional): process group used for load > distribution. Tensors will be exchanged within this group > > Returns: > Dict[_ShardId, torch.Tensor]: dictionary mapping shard ids to tensors > needed by this rank to load a given state dict. 
Includes > previously loaded tensors (from `loaded_tensors` input) > > """ > all_loaded_tensors_list = [None] * torch.distributed.get_world_size( > group=parallelization_group > ) > torch.distributed.all_gather_object( > all_loaded_tensors_list, loaded_tensors, group=parallelization_group > ) > all_loaded_tensors_list = cast(List[Dict[_ShardId, torch.Tensor]], all_loaded_tensors_list) > all_loaded_tensors = reduce(lambda x, y: {**x, **y}, all_loaded_tensors_list) > > # Error checks > if len(all_loaded_tensors) != sum(map(len, all_loaded_tensors_list)): > err_msg = 'Duplicate shard ids loaded by different ranks' > if torch.distributed.get_rank() == 0: > logger.error( > f'{err_msg}. Shards ids by rank: {[lt.keys() for lt in all_loaded_tensors_list]}' > ) > raise CheckpointingException(err_msg) > > return all_loaded_tensors > > @torch.no_grad() > def exchange_loaded_tensors_gather_rounds( > self, > loaded_tensors: Dict[_ShardId, torch.Tensor], > unloaded_shards: Dict[_ShardId, ShardedTensor], > precomputed_distribution: SaveLoadDistribution = None, > parallelization_group: Optional[torch.distributed.ProcessGroup] = None, > ) -> Dict[_ShardId, torch.Tensor]: > """Exchange the tensors loaded by different ranks with several all_gather calls. > > Groups tensors by dtype, divide tensors that will be exchanged into rounds > and execute all_gather for tensors from each round. > > Note: the loading is distributed across ranks based on total loaded size > in bytes, so there is no guarantee that number of rounds needed for each > rank will be similar, which might result in a lot of almost empty > all_gathers. The solution would be to group all tensors into a one > bytes tensor and do a single all_gather (with similarly sized messages). > > Args: > loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor > shard ids to tensors already loaded by this rank. > unloaded_shards (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor > shard ids to ShardedTensors that aren't loaded yet. > precomputed_distribution (SaveLoadDistribution): uniform load distribution > parallelization_group (ProcessGroup, optional): process group used for load > distribution. Tensors will be exchanged within this group > > Returns: > Dict[_ShardId, torch.Tensor]: dictionary mapping shard ids to tensors > needed by this rank to load a given state dict. 
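
The gather_object exchange above merges the per-rank dicts returned by all_gather_object and treats a shrinking key count as evidence of duplicate shard ids. A single-process sketch of just that merge-and-validate step; the per-rank dicts below are fabricated stand-ins for the all_gather_object result.

from functools import reduce

per_rank_loaded = [
    {('layer1.weight', (0,), None): 'tensor_a'},   # loaded by rank 0
    {('layer2.weight', (0,), None): 'tensor_b'},   # loaded by rank 1
]
merged = reduce(lambda x, y: {**x, **y}, per_rank_loaded)

# If two ranks had loaded the same shard id, the merged dict would be smaller than
# the sum of the per-rank dicts - exactly the condition raised as a
# CheckpointingException in the code above.
assert len(merged) == sum(map(len, per_rank_loaded)), 'Duplicate shard ids loaded by different ranks'
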
Includes > previously loaded tensors (from `loaded_tensors` input) > """ > shard_to_saving_rank, _, shard_to_metadata = precomputed_distribution > local_rank = torch.distributed.get_rank(group=self.parallelization_group) > > all_loaded_tensors = dict(loaded_tensors) > > # Group by dtype so that we all_gather tensors of the same dtype > for dtype in sorted( > set(map(lambda sh_ten: sh_ten.dtype, shard_to_metadata.values())), key=str > ): > > start = time() > # shards_by_rank maps rank to tensors loaded by this rank > shards_by_rank: List[List[torch.Tensor]] = [ > [] for _ in range(torch.distributed.get_world_size(group=parallelization_group)) > ] > for shard_id, rank in shard_to_saving_rank.items(): > if shard_to_metadata[shard_id].dtype == dtype: > shards_by_rank[rank].append(shard_id) > > # Transpose `shards_by_rank` to form exchange rounds > shards_by_round = zip_longest(*shards_by_rank, fillvalue=None) > for round_idx, round_shard_ids in enumerate(shards_by_round): > round_tensors = [] > orig_devices = {} > for rank, shard_id in enumerate(round_shard_ids): > if shard_id is None: > # if no more useful data, the given rank will exchange empty tensor > local_ten = torch.empty(0, dtype=dtype, device='cuda') > orig_device = None > else: > assert isinstance(shard_id, tuple), type(shard_id) > if rank == local_rank: > assert shard_id in all_loaded_tensors, ( > shard_id, > all_loaded_tensors.keys(), > ) > orig_device = all_loaded_tensors[shard_id] > all_loaded_tensors[shard_id] = all_loaded_tensors[shard_id].cuda() > local_ten = all_loaded_tensors[shard_id] > else: > local_ten, orig_device = self._get_empty_tensor_for_exchange( > shard_id, unloaded_shards, shard_to_metadata, all_loaded_tensors > ) > round_tensors.append(local_ten) > if orig_device is not None: > orig_devices[shard_id] = orig_device > > torch.distributed.all_gather( > list(round_tensors), > round_tensors[local_rank], > group=self.parallelization_group, > async_op=False, > ) > > # Move tensors back to CPU if originally was on CPU > for shard_id, orig_device in orig_devices.items(): > all_loaded_tensors[shard_id] = all_loaded_tensors[shard_id].to(orig_device) > > del round_tensors # remove tensor references > > end = time() > if torch.distributed.get_rank() == 0: > logger.debug(f'{dtype} exchange rounds all_gather schedule took {end - start}s') > > return all_loaded_tensors > > @torch.no_grad() > def exchange_loaded_tensors_broadcast( > self, > loaded_tensors: Dict[_ShardId, torch.Tensor], > unloaded_shards: Dict[_ShardId, ShardedTensor], > precomputed_distribution: SaveLoadDistribution = None, > parallelization_group: Optional[torch.distributed.ProcessGroup] = None, > ) -> Dict[_ShardId, torch.Tensor]: > """Exchange the tensors loaded by different ranks by a series of broadcasts. > > For each rank for each loaded tensor do a broadcast to the whole group. > A reasonable tradeoff in terms of performance and simplicity. > > Args: > loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor > shard ids to tensors already loaded by this rank. > unloaded_shards (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor > shard ids to ShardedTensors that aren't loaded yet. > precomputed_distribution (SaveLoadDistribution): uniform load distribution > parallelization_group (ProcessGroup, optional): process group used for load > distribution. Tensors will be exchanged within this group > > Returns: > Dict[_ShardId, torch.Tensor]: dictionary mapping shard ids to tensors > needed by this rank to load a given state dict. 
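
The gather_rounds exchange transposes the per-rank shard lists with zip_longest so that round k holds the k-th shard of every rank, padding exhausted ranks with None (exchanged as an empty tensor in the real code). A single-process sketch with made-up shard ids:

from itertools import zip_longest

shards_by_rank = [
    ['r0_shard0', 'r0_shard1', 'r0_shard2'],  # rank 0 loads three shards
    ['r1_shard0'],                            # rank 1 loads one shard
    ['r2_shard0', 'r2_shard1'],               # rank 2 loads two shards
]
shards_by_round = list(zip_longest(*shards_by_rank, fillvalue=None))

# Round 0 has one shard per rank; later rounds pad exhausted ranks with None.
assert shards_by_round[0] == ('r0_shard0', 'r1_shard0', 'r2_shard0')
assert shards_by_round[1] == ('r0_shard1', None, 'r2_shard1')
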
Includes > previously loaded tensors (from `loaded_tensors` input) > """ > shard_to_saving_rank, _, shard_to_metadata = precomputed_distribution > local_rank = torch.distributed.get_rank(group=self.parallelization_group) > > all_loaded_tensors = dict(loaded_tensors) > > start = time() > > for idx, (shard_id, rank) in enumerate(shard_to_saving_rank.items()): > if rank == local_rank: > assert shard_id in all_loaded_tensors, (shard_id, all_loaded_tensors.keys()) > orig_device = all_loaded_tensors[shard_id].device > local_ten = all_loaded_tensors[shard_id].cuda() > else: > local_ten, orig_device = self._get_empty_tensor_for_exchange( > shard_id, unloaded_shards, shard_to_metadata, all_loaded_tensors > ) > > global_src_rank = torch.distributed.get_global_rank(parallelization_group, rank) > # We can do async_op=True only if there is no CPU-copy follow-up > torch.distributed.broadcast( > local_ten, > src=global_src_rank, > group=parallelization_group, > async_op=orig_device is None, > ) > # Move tensor back to CPU if originally was on CPU > if orig_device is not None: > all_loaded_tensors[shard_id] = local_ten.to(orig_device) > del local_ten > > end = time() > if torch.distributed.get_rank() == 0: > logger.debug(f'exchange broadcast schedule took {end - start}s') > > return all_loaded_tensors > > def _get_empty_tensor_for_exchange( > self, > shard_id: _ShardId, > needed_shards: Dict[_ShardId, ShardedTensor], > unneeded_shards: Dict[_ShardId, ShardedTensor], > loaded_tensors: Dict[_ShardId, torch.Tensor], > ) -> Tuple[torch.Tensor, Optional[torch.device]]: > """Determines the empty tensor to use for exchange. > > If shard_id is needed by this rank, it will be in the `unloaded_shards`. > Otherwise, the metadata for this tensor can be found in `shard_to_metadata` > > Args: > shard_id (_ShardId): shard_id that will be exchanged > needed_shards (Dict[_ShardId, ShardedTensor]): mapping from shard ids > to metadata for shards needed by this rank > unneeded_shards (Dict[_ShardId, ShardedTensor]): mapping from shard ids > to metadata for shards that can be discarded after exchange > loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping where useful tensors > are placed in > > Returns: > Tuple[torch.Tensor, Optional[torch.device]]: empty CUDA tensor to be exchanged, > and the device of the original state dict tensor (if there was any) > """ > local_unloaded_sh_ten = needed_shards.get(shard_id) > if local_unloaded_sh_ten is None: > orig_device = None # this tensor will be discarded anyway > sh_ten = unneeded_shards[shard_id] > if sh_ten.data is None: > sh_ten.init_data('cuda') > tensor = sh_ten.data > sh_ten.data = None # won't be used. free memory > else: > tensor = sh_ten.data > if tensor.device.type == 'cpu': > tensor = torch.empty_like(tensor, device='cuda') > else: > local_unloaded_sh_ten.init_data('cuda') > orig_device = local_unloaded_sh_ten.data.device > tensor = local_unloaded_sh_ten.data > if tensor.device.type == 'cpu': > tensor = torch.empty_like(tensor, device='cuda') > loaded_tensors[shard_id] = tensor > return tensor, orig_device > 386a671,764 > def _sharded_tensor_shard_id(sharded_tensor: ShardedTensor) -> _ShardId: > """Unique id of the sharded tensor data. > > Should yield the same value for same data replicated on different ranks. 
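
Both the broadcast exchange and _get_empty_tensor_for_exchange above track the original device so that CPU tensors can be staged on the communication device and copied back after the collective. A device-bookkeeping sketch of that round trip; the collective itself is omitted because torch.distributed is not initialized here.

import torch

comm_device = 'cuda' if torch.cuda.is_available() else 'cpu'

state_dict_tensor = torch.zeros(4)       # pretend this came from the state dict
orig_device = state_dict_tensor.device   # remember where the result must end up
exchange_tensor = state_dict_tensor.to(comm_device)

# ... broadcast / all_gather would fill `exchange_tensor` here ...

if orig_device.type == 'cpu':
    # Move the exchanged data back to where the state dict expects it.
    state_dict_tensor = exchange_tensor.to(orig_device)
assert state_dict_tensor.device == orig_device
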
> > Args: > sharded_tensor (ShardedTensor): sharded tensor representing the data shard > > Returns (tuple): unique id of a data shard > """ > f_range = sharded_tensor.flattened_range > return ( > sharded_tensor.key, > sharded_tensor.global_offset, > None if f_range is None else (f_range.start, f_range.stop), > ) > > > def _shard_size(sh_ten: ShardedTensor): > """Returns size in bytes of a given sharded tensor.""" > if sh_ten.flattened_range is None: > numel = np.product(sh_ten.local_shape) > else: > numel = sh_ten.flattened_range.stop - sh_ten.flattened_range.start > return numel * torch._utils._element_size(sh_ten.dtype) > > > def determine_main_replica_uniform_distribution( > sharded_state_dict: ShardedStateDict, > parallelization_group: torch.distributed.ProcessGroup, > is_loading: bool = False, > ) -> Optional[SaveLoadDistribution]: > """Computes the save distribution. > > Should be used in conjunction with `distribute_main_replicas_with_precomputed_distribution` > which applies the computed save distribution. > > We rely on the fact that the assignment algorithm is deterministic on all ranks, > so there is no extra communication needed after metadata exchange. > > Args: > sharded_state_dict (ShardedStateDict): state dict to compute the distribution of > parallelization_group (ProcessGroup): distribution will be computed > within this process group > is_loading (bool, optional): whether the distribution is for loading or saving. > For loading, even non-main replicas must be loaded by this parallelization > group. Defaults to False. > > Returns (SaveLoadDistribution, optional): distribution that can be used to apply the > parallelization. Returns None if the process_group is trivial (1 rank) > > """ > group_size = torch.distributed.get_world_size(group=parallelization_group) > if group_size <= 1: > return > local_shards = list( > sh_base > for sh_base in nested_values(sharded_state_dict) > if isinstance(sh_base, ShardedTensor) > ) > local_shards_no_data = [ten.without_data() for ten in local_shards] > > all_shards = [None] * torch.distributed.get_world_size(group=parallelization_group) > torch.distributed.all_gather_object( > all_shards, local_shards_no_data, group=parallelization_group > ) > > shard_to_ranks = defaultdict(list) > shard_to_size = {} > shard_to_metadata = {} > shards_saved_by_this_parallelization_group: Set[_ShardId] = set() > for rank, rank_shards in enumerate(all_shards): > for sh_ten in rank_shards: > shard_id = _sharded_tensor_shard_id(sh_ten) > shard_to_ranks[shard_id].append(rank) > if shard_id not in shard_to_size: > shard_to_size[shard_id] = _shard_size(sh_ten) > shard_to_metadata[shard_id] = sh_ten > if is_main_replica(sh_ten.replica_id) or is_loading: > shards_saved_by_this_parallelization_group.add(shard_id) > > shard_to_ranks = { > k: v for k, v in shard_to_ranks.items() if k in shards_saved_by_this_parallelization_group > } > > shard_to_saving_rank = distribute_shards_to_ranks( > shard_to_ranks, shard_to_size, len(all_shards) > ) > > return SaveLoadDistribution( > shard_to_saving_rank, shards_saved_by_this_parallelization_group, shard_to_metadata > ) > > 390c768 < precomputed_distribution: Optional[ShardDistribution], --- > precomputed_distribution: Optional[SaveLoadDistribution], 402c780 < precomputed_distribution (ShardDistribution): distribution computed with --- > precomputed_distribution (SaveLoadDistribution): distribution computed with 439a818,865 > > > T = TypeVar('T') > > > def distribute_shards_to_ranks( > shard_to_ranks: Dict[T, List[int]], 
shard_to_size: Dict[T, int], num_ranks: int > ) -> Dict[T, int]: > """Computes uniform distribution of workload across ranks, based on sizes. > > Currently, the assignment is greedy, based on: > 1. Firstly, the coverage of each shard > (how many ranks the shard is available on; lower coverage is assigned first) > 2. Secondly, the size of each shard (larger size is assigned first) > 3. Finally, shard id for differentiation. > > Third step is added because we rely on the fact that the assignment is deterministic on all ranks. > > Args: > shard_to_ranks (Dict[T, List[int]]): mapping which tells which rank have access to which shards > shard_to_size (Dict[T, int]): sizes of each shard > num_ranks (int): number of ranks in the parallelization group > > Returns (Dict[T, int]): assignment of shard to rank (which rank should do the work > to achieve maximal uniformity) > """ > shard_to_ranks = {k: tuple(v) for k, v in shard_to_ranks.items()} > shard_to_saving_rank = {} > rank_sizes = [(0, rank) for rank in range(num_ranks)] > > # start from tensors with lowest coverage, then go by tensor size from largest (hence minus size) > for shard_id, shard_ranks in sorted( > shard_to_ranks.items(), > key=lambda sh_id_ranks: ( > len(sh_id_ranks[1]), > -shard_to_size[sh_id_ranks[0]], > sh_id_ranks[0], > ), > ): > # assign greedily to the least occupied rank > size, rank = min((size, rank) for size, rank in rank_sizes if rank in shard_ranks) > > shard_to_saving_rank[shard_id] = rank > rank_sizes[rank] = (size + shard_to_size[shard_id], rank) > > logger.debug(f'distribute_shards_to_ranks distribution: {rank_sizes}') > > return shard_to_saving_rank diff -rN ./megatron/core/dist_checkpointing/strategies/__init__.py ../megatron-lm/megatron/core/dist_checkpointing/strategies/__init__.py 4d3 < from megatron.core.dist_checkpointing.strategies.common import register_default_common_strategies 6,7c5,9 < # We load "common" strategies by default to be always available < register_default_common_strategies() --- > # We mock imports to populate the `default_strategies` objects. > # Since they are defined in base but populated in common, we have to mock > # import both modules. 
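
As a quick sanity check of distribute_shards_to_ranks above, here is an equivalent self-contained version run on made-up shard ids and sizes: shards are sorted by coverage (ascending), then size (descending), then id, and each is assigned to the least-loaded rank that can provide it.

def distribute_shards_to_ranks(shard_to_ranks, shard_to_size, num_ranks):
    shard_to_ranks = {k: tuple(v) for k, v in shard_to_ranks.items()}
    shard_to_saving_rank = {}
    rank_sizes = [(0, rank) for rank in range(num_ranks)]
    for shard_id, shard_ranks in sorted(
        shard_to_ranks.items(),
        key=lambda kv: (len(kv[1]), -shard_to_size[kv[0]], kv[0]),
    ):
        # Assign greedily to the least occupied rank that has access to the shard.
        size, rank = min((size, rank) for size, rank in rank_sizes if rank in shard_ranks)
        shard_to_saving_rank[shard_id] = rank
        rank_sizes[rank] = (size + shard_to_size[shard_id], rank)
    return shard_to_saving_rank


shard_to_ranks = {'a': [0, 1], 'b': [0, 1], 'c': [1]}   # 'c' is visible only to rank 1
shard_to_size = {'a': 100, 'b': 80, 'c': 60}
assignment = distribute_shards_to_ranks(shard_to_ranks, shard_to_size, num_ranks=2)
# 'c' (lowest coverage) is assigned first, to rank 1; 'a' (largest) then goes to the
# still-empty rank 0; 'b' goes to whichever rank is lighter at that point (rank 1).
assert assignment == {'c': 1, 'a': 0, 'b': 1}
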
> from megatron.core.dist_checkpointing.strategies.base import _import_trigger > from megatron.core.dist_checkpointing.strategies.common import _import_trigger Binary files ./megatron/core/dist_checkpointing/strategies/__pycache__/async_utils.cpython-310.pyc and ../megatron-lm/megatron/core/dist_checkpointing/strategies/__pycache__/async_utils.cpython-310.pyc differ Binary files ./megatron/core/dist_checkpointing/strategies/__pycache__/base.cpython-310.pyc and ../megatron-lm/megatron/core/dist_checkpointing/strategies/__pycache__/base.cpython-310.pyc differ Binary files ./megatron/core/dist_checkpointing/strategies/__pycache__/common.cpython-310.pyc and ../megatron-lm/megatron/core/dist_checkpointing/strategies/__pycache__/common.cpython-310.pyc differ Binary files ./megatron/core/dist_checkpointing/strategies/__pycache__/fully_parallel.cpython-310.pyc and ../megatron-lm/megatron/core/dist_checkpointing/strategies/__pycache__/fully_parallel.cpython-310.pyc differ Binary files ./megatron/core/dist_checkpointing/strategies/__pycache__/__init__.cpython-310.pyc and ../megatron-lm/megatron/core/dist_checkpointing/strategies/__pycache__/__init__.cpython-310.pyc differ diff -rN ./megatron/core/dist_checkpointing/strategies/state_dict_saver.py ../megatron-lm/megatron/core/dist_checkpointing/strategies/state_dict_saver.py 87,89c87 < # PyTorch 2.4 introduced additional `metadata` argument, < # we have to reference `is_coordinator` args by name < planner.set_up_planner(state_dict, is_coordinator=dist_wrapper.is_coordinator) --- > planner.set_up_planner(state_dict, dist_wrapper.is_coordinator) diff -rN ./megatron/core/dist_checkpointing/strategies/tensorstore.py ../megatron-lm/megatron/core/dist_checkpointing/strategies/tensorstore.py 16,19c16,21 < from .base import LoadShardedStrategy, StrategyAction, register_default_strategy < from .zarr import load_zarr_based_sharded_metadata, postprocess_numpy_array < < logger = getLogger(__name__) --- > from .base import LoadShardedStrategy, StrategyAction, default_strategies > from .zarr import ( > load_zarr_based_sharded_metadata, > numpy_to_torch_dtype_dict, > postprocess_numpy_array, > ) 20a23 > _import_trigger = None 22,26c25 < def register_default_tensorstore_strategies(): < """Register default strategies leveraging tensorstore.""" < register_default_strategy( < StrategyAction.LOAD_SHARDED, 'zarr', 1, TensorStoreLoadShardedStrategy() < ) --- > logger = getLogger(__name__) 30,31d28 < """Load strategy for Zarr backend using `tensorstore` for loading.""" < 64,65d60 < """Intersects the global slice with the actual shape (prevent overflow).""" < 128a124,128 > > > default_strategies[StrategyAction.LOAD_SHARDED.value][ > ('zarr', 1) > ] = TensorStoreLoadShardedStrategy() diff -rN ./megatron/core/dist_checkpointing/strategies/torch.py ../megatron-lm/megatron/core/dist_checkpointing/strategies/torch.py 3a4 > import dataclasses 4a6,7 > import itertools > import math 11a15 > import numpy as np 13c17 < from packaging.version import Version as PkgVersion --- > from pkg_resources import packaging 14a19 > from torch.distributed._shard._utils import narrow_tensor_by_index 18d22 < from torch.distributed._tensor import DTensor 24a29 > LoadPlanner, 32a38 > from torch.distributed.checkpoint.default_planner import create_default_local_save_plan 33a40 > from torch.distributed.checkpoint.planner import LoadItemType 34a42 > from torch.futures import Future 37c45 < from ..dict_utils import nested_values --- > from ..dict_utils import extract_matching_values, nested_values 42a51 > 
ShardedTensorFactory, 43a53,54 > apply_factories, > apply_factory_merges, 47,52c58 < from .base import ( < AsyncSaveShardedStrategy, < LoadShardedStrategy, < StrategyAction, < register_default_strategy, < ) --- > from .base import AsyncSaveShardedStrategy, LoadShardedStrategy, StrategyAction, default_strategies 64,65d69 < if not torch.cuda.is_available(): < raise ImportError 72,81c76 < < def register_default_torch_strategies(): < """Register default strategies related to PyT Distributed backend.""" < register_default_strategy( < StrategyAction.LOAD_SHARDED, 'torch_dist', 1, TorchDistLoadShardedStrategy() < ) < register_default_strategy( < StrategyAction.SAVE_SHARDED, 'torch_dist', 1, TorchDistSaveShardedStrategy('torch_dist', 1) < ) < --- > _import_trigger = None 119,122c114,116 < On high-level, this function follows the logic of < torch.distributed.fsdp._shard_utils._create_chunk_sharded_tensor. < Additionally, it saves `prepend_axis_num` and `has_flattened_range` (specific to MCore) < as attributes for further restoration in `_unwrap_pyt_sharded_tensor`. --- > On high-level, this function follows the logic of torch.distributed.fsdp._shard_utils._create_chunk_sharded_tensor. > Additionally, it saves `prepend_axis_num` and `has_flattened_range` (specific to MCore) as attributes > for further restoration in `_unwrap_pyt_sharded_tensor`. 233c227 < for fragment_offsets in product(*map(range, some_sh_ten.axis_fragmentations)): --- > for fragment_offsets in itertools.product(*map(range, some_sh_ten.axis_fragmentations)): 253d246 < # pylint: disable=line-too-long 281,282c274 < # Store MCore related data as PyTShardedTensor attribute. < # This won't be stored in the checkpoint, only for runtime purposes --- > # Store MCore related data as PyTShardedTensor attribute. This won't be stored in the checkpoint, only for runtime purposes 295,296c287 < """Convert state dict with ShardedTensors and ShardedObjects < to state dict compatible with PyT Dist format. --- > """Turn state dict with ShardedTensors and ShardedObjects to state dict compatible with PyT Dist format. 382,383c373 < """Group ShardedBase objects by keys and < return mappings required for recreating the original dict.""" --- > """Group ShardedBase objects by keys and return mappings required for recreating the original dict.""" 428,429d417 < """SavePlan with MCore specific data.""" < 451,453c439,440 < # `dedup_replicated_tensors` was deprecated in 2.3; this check avoids warnings < # during saving. < if PkgVersion(torch.__version__) <= PkgVersion("2.2"): --- > # `dedup_replicated_tensors` was deprecated in 2.3 - this avoids tons of warnings during saving > if packaging.version.Version(torch.__version__) <= packaging.version.Version("2.2"): 459,477c446,453 < """Adds IOBytes write request on non-coordinator ranks.""" < < # NOTE: for PyT 2.4.0a0 we can't rely on `create_default_local_save_plan` because < # some alpha versions (specifically 2.4.0a0+f70bd71a48 in 24.06 NGC PyTorch container) < # add iobytes request only on coordinator ranks and some alpha versions < # (specifically 2.4.0a0+3bcc3cddb5 in 24.07 NGC PyTorch container) < # add those requests on all ranks. We inline a simplified version of this method below. < write_items = [] < for fqn, obj in self.state_dict.items(): < assert not isinstance( < obj, DTensor < ) # translation from MCore ShardedTensors shouldn't result in DTensors < # Create write requests for tensor and bytes values. < # For MCore, these should be already non-duplicates. 
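
The hunk above only changes how the torch version is compared (packaging.version reached via pkg_resources versus an explicit PkgVersion import); the gate itself is the same. A stand-alone sketch of that gating pattern follows; where the resulting flag is actually passed is not shown here, so the kwargs dict is purely illustrative.

import torch
from packaging.version import Version as PkgVersion

planner_kwargs = {}
if PkgVersion(torch.__version__) <= PkgVersion("2.2"):
    # `dedup_replicated_tensors` was deprecated in torch 2.3, so only pass it on
    # older versions to avoid deprecation warnings (illustrative kwargs target).
    planner_kwargs['dedup_replicated_tensors'] = True
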
< write_items += _create_write_items(fqn, obj) < < self.plan = MCoreSavePlan( < items=write_items, < planner_data=self.mappings, --- > plan = create_default_local_save_plan(self.state_dict, self.is_coordinator) > self._add_non_coordinator_iobytes_request(plan) > if self.flatten_state_dict: > plan = dataclasses.replace(plan, planner_data=self.mappings) > plan = MCoreSavePlan( > items=plan.items, > storage_data=plan.storage_data, > planner_data=plan.planner_data, 483a460,461 > self.plan = plan > 487d464 < """Merges MCore data for all plans.""" 491a469,475 > def _add_non_coordinator_iobytes_request(self, plan): > if self.is_coordinator: > return > for fqn, obj in self.state_dict.items(): > if isinstance(obj, io.BytesIO): > plan.items.extend(_create_write_items(fqn, obj)) > 493d476 < """Make no transformations - bytes objects are already serialized.""" 527d509 < """Runs additional shapes validation.""" 599,600c581 < # cached outcome of `SavePlan.prepare_global_plan`, < # which aggregates local plans from all ranks --- > # cached outcome of `SavePlan.prepare_global_plan`, which aggregates local plans from all ranks 604,605c585 < # Cached global metadata, only `coordinator` for dist-ckpt holds < # if central plans are consistent over iters --- > # Cached global metadata, only `coordinator` for dist-ckpt holds if central plans are consistent over iters 616c596 < """Translates MCore ShardedTensors to PyT ShardedTensors & saves in PyT Distributed format. --- > """Translates MCore ShardedTensors to PyT ShardedTensors and saves in PyT Distributed format. 692,702d671 < """Reads MCore data for N-D flattened tensors from checkpoint metadata during ckpt load. < < Args: < sharded_state_dict (ShardedStateDict): sharded state dict to load < checkpoint_dir (Path): checkpoint directory < < Returns: < Dict[str, TensorReformulationMetadata] - dictionary that maps keys of every < N-D flattened tensor from the sharded_state_dict to its original global shape < as stored in `mcore_data` in the checkpoint. < """ 714,715c683 < f'Cannot find global shape metadata for N-D flattened tensor {sh_ten} ' < f'in checkpoint metadata: {ckpt_metadata.mcore_data}' --- > f'Cannot find global shape metadata for N-D flattened tensor {sh_ten} in checkpoint metadata: {ckpt_metadata.mcore_data}' 728c696 < """Translates MCore ShardedTensors to PyT ShardedTensors & loads from PyT Distributed fmt. --- > """Translates MCore ShardedTensors to PyT ShardedTensors and loads from PyT Distributed format. 
835a804,811 > > > default_strategies[StrategyAction.LOAD_SHARDED.value][ > ('torch_dist', 1) > ] = TorchDistLoadShardedStrategy() > default_strategies[StrategyAction.SAVE_SHARDED.value][('torch_dist', 1)] = ( > TorchDistSaveShardedStrategy('torch_dist', 1) > ) diff -rN ./megatron/core/dist_checkpointing/strategies/zarr.py ../megatron-lm/megatron/core/dist_checkpointing/strategies/zarr.py 5a6 > import threading 18,23c19 < from .base import ( < LoadShardedStrategy, < SaveShardedStrategy, < StrategyAction, < register_default_strategy, < ) --- > from .base import LoadShardedStrategy, SaveShardedStrategy, StrategyAction, default_strategies 45,46c41 < # Register a bfloat16 type with this import < import tensorstore # pylint: disable=unused-import --- > import tensorstore 54,55c49 < logger = getLogger(__name__) < --- > _import_trigger = None 57,61c51 < def register_default_zarr_strategies(): < """Register default strategies related to Zarr backend.""" < register_default_strategy( < StrategyAction.SAVE_SHARDED, 'zarr', 1, ZarrSaveShardedStrategy('zarr', 1) < ) --- > logger = getLogger(__name__) 65,66d54 < """Save strategy for Zarr backend.""" < 89,90c77 < b) is main replica but not the first chunk, < opens the arrays created in (a) (possibly by other process) --- > b) is main replica but not the first chunk, opens the arrays created in (a) (possibly by other process) 94,95c81 < sharded_tensors (List[ShardedTensor]): sharded tensors from a given rank < that will be saved to checkpoint --- > sharded_tensors (List[ShardedTensor]): sharded tensors from a given rank that will be saved to checkpoint 176,177d161 < """Load strategy for the Zarr backend.""" < 229d212 < """Turn numpy array to torch tensor.""" 257d239 < """Apply flattened range to a tensor.""" 262d243 < """Pad tensor to the expected shape.""" 274,277c255,257 < assert False, ( < f'Expected shape ({exp_sh}) smaller than actual ({x_sh})' < f' for {repr(expected_sharded_ten)}' < ) --- > assert ( > False > ), f'Expected shape ({exp_sh}) smaller than actual ({x_sh}) for {repr(expected_sharded_ten)}' 321a302,307 > > > # default_strategies[StrategyAction.LOAD_SHARDED.value][('zarr', 1)] = ZarrLoadShardedStrategy() > default_strategies[StrategyAction.SAVE_SHARDED.value][('zarr', 1)] = ZarrSaveShardedStrategy( > 'zarr', 1 > ) diff -rN ./megatron/core/dist_checkpointing/utils.py ../megatron-lm/megatron/core/dist_checkpointing/utils.py 5c5 < from typing import Dict, Optional, Tuple --- > from typing import Dict, Tuple 18,52d17 < # _ShardId uniquely identifies a ShardedTensor. This is a subset of ShardedTensor < # attributes: key (str), global_offset (tuple) and flattened_range (optional tuple) < _ShardId = Tuple[str, tuple, Optional[tuple]] < < < def _sharded_tensor_shard_id(sharded_tensor: ShardedTensor) -> _ShardId: < """Unique id of the sharded tensor data. < < Should yield the same value for same data replicated on different ranks. < < Args: < sharded_tensor (ShardedTensor): sharded tensor representing the data shard < < Returns (tuple): unique id of a data shard < """ < f_range = sharded_tensor.flattened_range < return ( < sharded_tensor.key, < sharded_tensor.global_offset, < None if f_range is None else (f_range.start, f_range.stop), < ) < < < def _sharded_object_id(sharded_object: ShardedObject) -> _ShardId: < """Unique id of the sharded object data. < < Should yield the same value for same data replicated on different ranks. 
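
As a reference for the shard-id and shard-size helpers that appear in this diff (in fully_parallel.py and, on the removed side, in utils.py), here is a runnable sketch written against a tiny stand-in dataclass instead of the real ShardedTensor; field names mirror the ones used in the diff, and math.prod stands in for np.product.

import math
from dataclasses import dataclass
from typing import Optional, Tuple

import torch


@dataclass
class FakeShard:
    # Stand-in for ShardedTensor with only the fields the helpers below need.
    key: str
    global_offset: Tuple[int, ...]
    local_shape: Tuple[int, ...]
    dtype: torch.dtype
    flattened_range: Optional[range] = None


def shard_id(sh_ten):
    # Mirrors _sharded_tensor_shard_id: (key, global_offset, flattened bounds).
    f_range = sh_ten.flattened_range
    return (
        sh_ten.key,
        sh_ten.global_offset,
        None if f_range is None else (f_range.start, f_range.stop),
    )


def shard_size_bytes(sh_ten):
    # Mirrors _shard_size: element count times element size in bytes.
    numel = (
        math.prod(sh_ten.local_shape)
        if sh_ten.flattened_range is None
        else sh_ten.flattened_range.stop - sh_ten.flattened_range.start
    )
    return numel * torch._utils._element_size(sh_ten.dtype)


shard = FakeShard('decoder.weight', (0, 0), (1024, 1024), torch.bfloat16)
assert shard_id(shard) == ('decoder.weight', (0, 0), None)
assert shard_size_bytes(shard) == 1024 * 1024 * 2  # bfloat16 is 2 bytes per element
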
< < Args: < sharded_object (ShardedObject): sharded object representing the data shard < < Returns (tuple): unique id of a data shard < """ < return (sharded_object.key, sharded_object.global_offset, sharded_object.global_shape) < 57,58c22 < """Extract a dict consisting of only ShardedTensor objects < from a given state dict with any objects. --- > """Extract a dict consisting of only ShardedTensor objects from a given state dict with any objects. 66,67c30 < - state dict with all objects other than ShardedTensor < (keeping the original state dict structure) --- > - state dict with all objects other than ShardedTensor (keeping the original state dict structure) 75,76c38 < """Extract a dict consisting of only ShardedTensor and ShardedTensorFactory objects < from a given state dict with any objects. --- > """Extract a dict consisting of only ShardedTensor and ShardedTensorFactory objects from a given state dict with any objects. 79,80c41 < sharded_state_dict: < state dict possibly containing ShardedTensor and ShardedTensorFactory objects --- > sharded_state_dict: state dict possibly containing ShardedTensor and ShardedTensorFactory objects 84,85c45 < - state dict with all ShardedTensor and ShardedTensorFactory < (keeping the original state dict structure) --- > - state dict with all ShardedTensor and ShardedTensorFactory (keeping the original state dict structure) 96,97c56,57 < """Extract a dict consisting of only ShardedTensor, ShardedTensorFactory < and LocalNonpersistentObject objects from a given state dict with any objects. --- > """Extract a dict consisting of only ShardedTensor, ShardedTensorFactory and LocalNonpersistentObject > objects from a given state dict with any objects. 100,101c60 < sharded_state_dict: state dict possibly containing ShardedTensor, ShardedTensorFactory < and LocalNonpersistentObject objects --- > sharded_state_dict: state dict possibly containing ShardedTensor, ShardedTensorFactory and LocalNonpersistentObject objects 105,106c64 < - state dict with all ShardedTensor, ShardedTensorFactory and LocalNonpersistentObject < (keeping the original state dict structure) --- > - state dict with all ShardedTensor, ShardedTensorFactory and LocalNonpersistentObject (keeping the original state dict structure) 118,127d75 < """Extract a dict consisting of only ShardedBase from a given state dict with any objects. < < Args: < sharded_state_dict: state dict possibly containing ShardedBase objects < < Returns: < Tuple[ShardedStateDict, StateDict]: tuple of: < - state dict with all ShardedBase objects (keeping the original state dict structure) < - state dict with all other objects (keeping the original state dict structure) < """ 134,145d81 < """Extract a dict consisting of only LocalNonpersistentObjects from a given state dict. < < Args: < sharded_state_dict: state dict possibly containing LocalNonpersistentObjects < < Returns: < Tuple[ShardedStateDict, StateDict]: tuple of: < - state dict with all LocalNonpersistentObjects < (keeping the original state dict structure) < - state dict with all other objects (keeping the original state dict structure) < """ < 201,202c137 < prefix_map (Dict[str, str]): < map of old->new prefixes. The first matching prefix for each key is used --- > prefix_map (Dict[str, str]): map of old->new prefixes. 
The first matching prefix for each key is used diff -rN ./megatron/core/distributed/distributed_data_parallel_config.py ../megatron-lm/megatron/core/distributed/distributed_data_parallel_config.py 17,24d16 < overlap_param_gather: bool = False < """If true, overlap param all-gather with forward compute.""" < < align_param_gather: bool = False < """If true, all PP stages will launch param all-gathers simultaneously. Otherwise, each < PP stage will independently launch as needed. < """ < 41,44d32 < < fp8_param_gather: bool = False < """If true, keep the compute param in fp8 (do not use any other intermediate dtype) and < perform the param all-gather in fp8.""" diff -rN ./megatron/core/distributed/distributed_data_parallel.py ../megatron-lm/megatron/core/distributed/distributed_data_parallel.py 4a5 > from typing import Dict, Optional 12c13 < from ..utils import is_float8tensor, log_single_rank --- > from ..utils import log_single_rank 14c15 < from .param_and_grad_buffer import _ParamAndGradBuffer, partition_buckets --- > from .param_and_grad_buffer import ParamAndGradBuffer 79c80,81 < self.param_to_bucket_group = {} --- > self.module = module > self.param_to_buffer = {} 85d86 < self.params_with_grad = [] 90,93d90 < # Track params with grad to enable direct setting < # of param.grad_added_to_main_grad < self.params_with_grad.append(param) < 102c99 < def _allocate_buffers_for_parameters( --- > def allocate_buffers_for_parameters( 106,107d102 < param_and_grad_dtype_to_offsets = {} < param_and_grad_dtype_to_indices = {} 111c106,107 < assert param.requires_grad --- > if not param.requires_grad: > continue 114,121d109 < if is_float8tensor(param): < # Currently TE's Float8Tensor is a wrapper of torch.Tensor. It has a "fake" < # dtype (usually a higher precision dtype such as bfloat16), but its actual < # data is stored in the form of a torch uint8 tensor within the Float8Tensor's < # ".data" attribute. Therefore, when creating the param buffer for fp8 params, < # it is necessary to use torch.uint8, not the "fake" dtype got from < # "param.dtype". < param_dtype = torch.uint8 128,148d115 < # Get the index of each param among the params with same dtype, if a param is fp8, < # use its "fake" high precision dtype to find which params have same dtype with it. < # For example: < # Case 1: < # params = [p1(bf16), p2(bf16), p3(bf16), p4(bf16)] < # param_and_grad_dtype_to_indices = { < # (torch.bfloat16, torch.float32): [0, 1, 2, 3], < # } < # Case 2: < # params = [p1(bf16), p2(fp8), p3(fp8), p4(bf16)] < # param_and_grad_dtype_to_indices = { < # (torch.bfloat16, torch.float32): [0, 3], < # (torch.uint8, torch.float32): [1, 2], < # } < # We need these indices to load a non-native-fp8 checkpoint in native-fp8 mode. 
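
The comment block above ends with two worked cases for param_and_grad_dtype_to_indices. Below is a standalone re-derivation of Case 2, with dtypes spelled as strings so it runs without a GPU or Transformer Engine: the running offset is counted per "fake" high-precision dtype (what param.dtype reports even for fp8 params), while the indices are grouped by the actual buffer dtype (uint8 for fp8 params).

# Each tuple is (fake dtype reported by param.dtype, actual buffer dtype).
params = [('bf16', 'bf16'), ('bf16', 'uint8'), ('bf16', 'uint8'), ('bf16', 'bf16')]
grad_dtype = 'fp32'

offsets = {}   # (fake dtype, grad dtype) -> running count
indices = {}   # (buffer dtype, grad dtype) -> positions among same-fake-dtype params

for fake_dtype, buffer_dtype in params:
    offset = offsets.get((fake_dtype, grad_dtype), 0)
    offsets[(fake_dtype, grad_dtype)] = offset + 1
    indices.setdefault((buffer_dtype, grad_dtype), []).append(offset)

# Matches "Case 2" in the comment above.
assert indices == {('bf16', 'fp32'): [0, 3], ('uint8', 'fp32'): [1, 2]}
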
< offset = param_and_grad_dtype_to_offsets.get((param.dtype, grad_dtype), 0) < param_and_grad_dtype_to_offsets[(param.dtype, grad_dtype)] = offset + 1 < indices = param_and_grad_dtype_to_indices.get((param_dtype, grad_dtype), []) < indices.append(offset) < param_and_grad_dtype_to_indices[(param_dtype, grad_dtype)] = indices < 150,152c117 < target_gradient_scaling_factor = 1.0 / parallel_state.get_data_parallel_world_size( < with_context_parallel=True < ) --- > target_gradient_scaling_factor = 1.0 / parallel_state.get_data_parallel_world_size() 167c132 < _ParamAndGradBuffer( --- > ParamAndGradBuffer( 176d140 < param_and_grad_dtype_to_indices[(param_dtype, grad_dtype)], 178a143,144 > for param in params: > self.param_to_buffer[param] = buffers[-1] 180,206c146 < # In some scenarios, we want to put buckets from different buffers into a group so that < # their communication can be aggregated. For example, when there are both fp8 buffers < # and bf16 buffers in the model and vpp is enabled, each model chunk will have an fp8 < # bucket and a bf16 bucket, which doubles the number of communication kernels, and < # because of the use of CUDA_DEVICE_MAX_CONNECTIONS=1, having multiple back-to-back < # communications will prevent the overlap of the communication kernels with computation < # kernels. < # If bucketing is explicitly disabled, then put all buckets in a buffer into a single < # bucket group. < bucket_groups = partition_buckets(buffers, force_single_bucket_group=disable_bucketing) < < # Set `next_param_gather_bucket_group` for different bucket groups by iterating through < # buckets in reverse order (since all-gathers happen in reverse order of buckets). < if self.ddp_config.use_distributed_optimizer and self.ddp_config.overlap_param_gather: < num_bucket_groups = len(bucket_groups) < for i in range(1, num_bucket_groups): < bucket_groups[num_bucket_groups - i].next_param_gather_bucket_group = ( < bucket_groups[num_bucket_groups - i - 1] < ) < < # Create map from param to bucket group, used in pre_hook. 
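
The reverse-order chaining a few lines up links bucket group i back to group i-1, because param all-gathers are issued in reverse bucket order. A throwaway-class sketch of exactly that loop:

class Group:
    def __init__(self, name):
        self.name = name
        self.next_param_gather_bucket_group = None


bucket_groups = [Group('g0'), Group('g1'), Group('g2')]
num_bucket_groups = len(bucket_groups)
for i in range(1, num_bucket_groups):
    bucket_groups[num_bucket_groups - i].next_param_gather_bucket_group = (
        bucket_groups[num_bucket_groups - i - 1]
    )

# Each group points at the group whose all-gather should be dispatched next.
assert bucket_groups[2].next_param_gather_bucket_group is bucket_groups[1]
assert bucket_groups[1].next_param_gather_bucket_group is bucket_groups[0]
assert bucket_groups[0].next_param_gather_bucket_group is None
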
< for bucket_group in bucket_groups: < for bucket in bucket_group.buckets: < for param in bucket.params_list: < self.param_to_bucket_group[param] = bucket_group < < return buffers, bucket_groups --- > return buffers 218,220c158 < data_parallel_world_size = parallel_state.get_data_parallel_world_size( < with_context_parallel=True < ) --- > data_parallel_world_size = parallel_state.get_data_parallel_world_size() 225c163 < self.buffers, self.bucket_groups = _allocate_buffers_for_parameters( --- > self.buffers = allocate_buffers_for_parameters( 232,237c170,173 < self.expert_parallel_buffers, self.expert_parallel_bucket_groups = ( < _allocate_buffers_for_parameters( < expert_parallel_params, < parallel_state.get_data_modulo_expert_parallel_group(with_context_parallel=True), < gradient_scaling_factor=expert_gradient_scaling_factor, < ) --- > self.expert_parallel_buffers = allocate_buffers_for_parameters( > expert_parallel_params, > parallel_state.get_data_modulo_expert_parallel_group(with_context_parallel=True), > gradient_scaling_factor=expert_gradient_scaling_factor, 263c199 < grad_acc.register_hook(self._make_backward_post_hook(param)) --- > grad_acc.register_hook(self._make_param_hook(param, self.param_to_buffer)) 266,300d201 < self.use_forward_hook = ( < self.ddp_config.use_distributed_optimizer and self.ddp_config.overlap_param_gather < ) < self.remove_forward_pre_hook_handles = {} < if self.use_forward_hook: < self.enable_forward_pre_hook() < self.overlap_param_gather_with_optimizer_step = False < < def enable_forward_pre_hook(self): < """ < Enable forward pre-hooks needed for param all-gather overlap with forward compute. < """ < assert self.use_forward_hook < assert len(self.remove_forward_pre_hook_handles) == 0 < # Register forward pre-hook for all sub-modules. < for module in self.module.modules(): < self.remove_forward_pre_hook_handles[module] = module.register_forward_pre_hook( < self._make_forward_pre_hook() < ) < < def disable_forward_pre_hook(self): < """ < Disable forward pre-hooks needed for param all-gather overlap with forward compute. < """ < assert self.use_forward_hook < # De-register forward pre-hook for all sub-modules. < for module in self.module.modules(): < assert self.remove_forward_pre_hook_handles[module] is not None < self.remove_forward_pre_hook_handles[module].remove() < del self.remove_forward_pre_hook_handles[module] < assert len(self.remove_forward_pre_hook_handles) == 0 < < # Force synchronize parameters. < self.start_param_sync(force_sync=True) < 307,340c208,212 < def _make_forward_pre_hook(self): < """ < Create a forward pre-hook to wait on all-gather handles when necessary (i.e., < when a module uses a parameter in a bucket with a still incomplete all-gather). < """ < < def hook(module, *unused): < assert ( < self.use_forward_hook < ), "Should use pre-hook only when overlap_param_gather is True" < < # Make sure all parameters in this module have been all-gathered as necessary. < for param in module.parameters(recurse=False): < # Skip parameters without an associated buffer (such parameters have a < # .requires_grad field equal to False). < if param not in self.param_to_bucket_group: < continue < assert param.requires_grad < < # If aligning param all-gather across pipeline stages, all-gather is dispatched < # by start_param_sync calls in core/pipeline_parallelism/schedules.py. < # If overlapping param all-gather with optimizer step, then all-gather has < # already been dispatched in optimizer step. 
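
The forward pre-hook machinery above attaches a hook to every sub-module so that the module's pending param all-gather can be awaited before its forward runs. A minimal sketch of the registration/removal pattern on a toy model; the hook here only records which modules it saw, and no distributed state is involved.

import torch


def make_forward_pre_hook(visited):
    def hook(module, *unused):
        # In the real code this is where finish_param_sync() would be called
        # for the bucket group owning this module's parameters.
        visited.append(type(module).__name__)
    return hook


model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
visited = []
handles = {m: m.register_forward_pre_hook(make_forward_pre_hook(visited)) for m in model.modules()}

model(torch.randn(2, 4))
assert 'Linear' in visited and 'ReLU' in visited

# De-registration mirrors disable_forward_pre_hook in the diff.
for handle in handles.values():
    handle.remove()
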
< skip_next_bucket_dispatch = ( < self.ddp_config.align_param_gather < or self.overlap_param_gather_with_optimizer_step < ) < self.param_to_bucket_group[param].finish_param_sync( < skip_next_bucket_dispatch=skip_next_bucket_dispatch < ) < < return hook < < def _make_backward_post_hook(self, param: torch.nn.Parameter): --- > def _make_param_hook( > self, > param: torch.nn.Parameter, > param_to_buffer: Dict[torch.nn.Parameter, ParamAndGradBuffer], > ): 342,344c214 < Creates a backward post-hook to dispatch an all-reduce / reduce-scatter when < ready (i.e., when all grads in a bucket have been computed in all microbatches < in a batch). --- > Creates the all-reduce / reduce-scatter hook for backprop. 347,349c217,218 < def hook(*unused): < if param in self.param_to_bucket_group: < assert param.requires_grad --- > def param_hook(*unused): > if param.requires_grad: 361c230 < self.param_to_bucket_group[param].register_grad_ready(param) --- > param_to_buffer[param].register_grad_ready(param) 363c232 < return hook --- > return param_hook 370,371c239,240 < for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups: < bucket_group.is_last_microbatch = False --- > for buffer in self.buffers + self.expert_parallel_buffers: > buffer.is_last_microbatch = False 375,398c244,245 < for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups: < bucket_group.is_last_microbatch = True < < def start_param_sync(self, *unused, force_sync: bool = False, force_dispatch: bool = False): < """ < Initiates param sync (all-gather) communication operations for all model parameters. < < By default, when overlap_param_gather is set to True, dispatches asynchronous communication < calls; when overlap_param_gather is set to False, calls synchronous communication < ops. Can override this default behavior using flags below. < < Args: < force_sync (bool, optional): force synchronous collective regardless of < other settings. < force_dispatch (bool, optional): force dispatch regardless of other settings. < """ < if not force_sync: < # If overlapping param AG with optimizer step, AG should not be dispatched again < # in forward_backward_step. 
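
The is_last_microbatch toggling above is what suppresses grad synchronization for all but the last microbatch. A stand-in sketch of that toggle wrapped in a context manager; the contextmanager wrapping is an assumption for illustration, since the diff only shows the two loops.

from contextlib import contextmanager


class ToyGroup:
    def __init__(self):
        self.is_last_microbatch = True


bucket_groups = [ToyGroup(), ToyGroup()]


@contextmanager
def no_sync(groups):
    # Mark every group as "not last microbatch" so grads accumulate locally.
    for g in groups:
        g.is_last_microbatch = False
    try:
        yield
    finally:
        for g in groups:
            g.is_last_microbatch = True


with no_sync(bucket_groups):
    assert not any(g.is_last_microbatch for g in bucket_groups)
assert all(g.is_last_microbatch for g in bucket_groups)
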
< if self.overlap_param_gather_with_optimizer_step and not force_dispatch: < return < < for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups: < bucket_group.start_param_sync(force_sync=force_sync) --- > for buffer in self.buffers + self.expert_parallel_buffers: > buffer.is_last_microbatch = True 409,410c256,262 < for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups: < bucket_group.start_grad_sync() --- > for buffer in self.buffers + self.expert_parallel_buffers: > buffer.start_grad_sync() > > def scale_gradients(self, scaling_factor: float) -> None: > """Scale all gradients inside the buffers by `scaling_factor`.""" > for buffer in self.buffers + self.expert_parallel_buffers: > buffer.scale_gradients(scaling_factor) 421,425d272 < for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups: < bucket_group.finish_grad_sync() < < def scale_gradients(self, scaling_factor: float): < """Scale all gradients inside the buffers by `scaling_factor`.""" 427c274 < buffer.scale_gradients(scaling_factor) --- > buffer.finish_grad_sync() 434,435c281,283 < for param in self.params_with_grad: < param.grad_added_to_main_grad = False --- > for param in self.module.parameters(): > if param.requires_grad: > param.grad_added_to_main_grad = False 438,439d285 < for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups: < bucket_group.reset() diff -rN ./megatron/core/distributed/__init__.py ../megatron-lm/megatron/core/distributed/__init__.py 6,10c6 < < # For backwards compatibility. ParamAndGradBuffer will be deprecated in future release. < # ParamAndGradBuffer (which is an alias of _ParamAndGradBuffer) is not intended to be < # consumed directly by external code. < from .param_and_grad_buffer import ParamAndGradBuffer --- > from .param_and_grad_buffer import ParamAndGradBuffer, shard_buffer diff -rN ./megatron/core/distributed/param_and_grad_buffer.py ../megatron-lm/megatron/core/distributed/param_and_grad_buffer.py 6d5 < import warnings 11d9 < from torch.distributed import _coalescing_manager 13c11 < from ..utils import is_float8tensor, log_on_each_pipeline_stage --- > from ..utils import log_on_each_pipeline_stage 40c38 < class _ParamAndGradBucket: --- > class Bucket: 42c40,42 < Bucket to keep track of a subset of the model's parameters and gradients. --- > Bucket to keep track of a subset of the model's gradients. Provides functionality to register > when params in the bucket have grads ready to be synced; an asynchronous communication call > is automatically launched when _all_ params in the bucket have grads ready. 44a45 > ddp_config: DistributedDataParallel config object. 49a51,52 > data_parallel_group: Data-parallel process group. > data_parallel_world_size: World size using the data-parallel group group. 53d55 < bucket_id: Index of bucket in buffer. 57a60 > ddp_config: DistributedDataParallelConfig, 62a66,67 > data_parallel_group: torch.distributed.ProcessGroup, > data_parallel_world_size: int, 64d68 < bucket_id: int, 65a70,75 > self.ddp_config = ddp_config > > # State for bookkeeping: params is the set of parameters this bucket is > # responsible for, params_with_grad is the set of parameters with grads > # available. When overlap_grad_reduce is True, communication (all-reduce > # or reduce-scatter) is issued when params_with_grad equals params. 68,69c78 < # Make sure there are no duplicate params. 
< assert len(self.params_list) == len(self.params) --- > self.params_with_grad = set() 76,102d84 < self.gradient_scaling_factor = gradient_scaling_factor < self.bucket_id = bucket_id < < < class _ParamAndGradBucketGroup: < """ < Put multiple buckets into a group so that their communications can be aggregated together. < Provides functionality to register when params in the bucket group have grads ready to be < synced; an asynchronous communication call is automatically launched when _all_ params in < the bucket group have grads ready. < < Args: < buckets: A list of buckets. < ddp_config: DistributedDataParallel config object. < data_parallel_group: Data-parallel process group. < data_parallel_world_size: World size using the data-parallel group group. < """ < < def __init__( < self, < buckets: List[_ParamAndGradBucket], < ddp_config: DistributedDataParallelConfig, < data_parallel_group: torch.distributed.ProcessGroup, < data_parallel_world_size: int, < ): < self.buckets = buckets < self.ddp_config = ddp_config 106,118c88 < < # State for bookkeeping: params is the set of parameters this bucket group is < # responsible for, params_with_grad is the set of parameters with grads < # available. When overlap_grad_reduce is True, communication (all-reduce < # or reduce-scatter) is issued when params_with_grad equals params. < self.param_to_bucket = {} < self.params = set() < for bucket in self.buckets: < for param in bucket.params_list: < self.param_to_bucket[param] = bucket < self.params.add(param) < < self.next_param_gather_bucket_group = None --- > self.gradient_scaling_factor = gradient_scaling_factor 121,123d90 < self.param_gather_handle = None < self.param_gather_dispatched = False < self.grad_reduce_handle = None 127c94 < Reset metadata in bucket group in preparation for the next iteration of training. --- > Reset metadata in bucket in preparation for the next iteration of training. 130,221c97,98 < self.is_last_microbatch = True < < def check_for_nan_in_grad(self): < """ < Make sure norm of grads in bucket are not NaN prior to data-parallel < all-reduce / reduce-scatter. < """ < global_rank = torch.distributed.get_rank() < norm_is_nan = self.buckets[0].grad_data.norm(p=2).isnan() < for i in range(1, len(self.buckets)): < norm_is_nan.logical_or_(self.buckets[i].grad_data.norm(p=2).isnan()) < assert not norm_is_nan, ( < f'Rank {global_rank}: found NaN in local grad norm in ' < f'backward pass before data-parallel communication collective. ' < f'Device: {torch.cuda.current_device()}, node: {os.uname()[1]}' < ) < < def start_param_sync(self, force_sync: bool = False): < """ < Initiates all necessary param all-gathers for this bucket. < < When ddp_config.overlap_param_gather is set to True, dispatches an asynchronous < communication call (unless force_sync is True). When ddp_config.overlap_param_gather < is set to False, makes synchronous call. < < Args: < force_sync (bool, optional): force synchronous collective regardless of < other settings if true. < """ < assert self.ddp_config.use_distributed_optimizer < < if force_sync: < if self.param_gather_handle is not None: < self.param_gather_handle.wait() < self.param_gather_handle = None < return < else: < assert self.param_gather_handle is None < < async_op = self.ddp_config.overlap_param_gather and not force_sync < # Coalesce communication kernels across buckets in the bucket group. 
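
The coalesced collectives in the hunks that follow pass shard_buffer(...)[data_parallel_rank] as each rank's local view. shard_buffer itself is defined elsewhere in param_and_grad_buffer.py; the stand-in below has the same intent and shows the even split of a flat buffer.

import torch


def shard_buffer(buffer: torch.Tensor, data_parallel_world_size: int):
    # View the flat buffer as `world_size` equal contiguous shards.
    assert buffer.numel() % data_parallel_world_size == 0
    shard_size = buffer.numel() // data_parallel_world_size
    return [
        buffer[r * shard_size:(r + 1) * shard_size]
        for r in range(data_parallel_world_size)
    ]


buf = torch.arange(16, dtype=torch.float32)
shards = shard_buffer(buf, data_parallel_world_size=4)
assert len(shards) == 4 and shards[2][0].item() == 8.0
# In the real code, shards[data_parallel_rank] is the local view handed to
# torch.distributed._all_gather_base / _reduce_scatter_base.
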
< with _coalescing_manager(self.data_parallel_group, async_ops=async_op) as cm: < for bucket in self.buckets: < local_data_view = shard_buffer(bucket.param_data, self.data_parallel_world_size)[ < self.data_parallel_rank < ] < torch.distributed._all_gather_base( < bucket.param_data, < local_data_view, < group=self.data_parallel_group, < async_op=async_op, < ) < if async_op: < self.param_gather_handle = cm < else: < # When using `_coalescing_manager`, even if a synchronous op (async_op=False) is used, < # `cm` is not None, which is different from when `_coalescing_manager` is not used in < # which case the torch.distributed._all_gather_base() will return None. In order to < # maintain consistency with prior code, we need to manually set communication handle to < # None. < self.param_gather_handle = None < self.param_gather_dispatched = True < < def finish_param_sync(self, skip_next_bucket_dispatch: bool = False): < """ < Finishes param sync communication operation for this bucket. Dispatches < next bucket's param sync if available, unless skip_next_bucket_dispatch < is True. < < When ddp_config.overlap_param_gather is set to True, waits for asynchronous < communication call to complete (and dispatches one if one is not already < outstanding). Throws assertion error if ddp_config.overlap_param_gather is set to < False. < < Args: < skip_next_bucket_dispatch (bool, optional): if true, dispatch next < bucket's communication if available. < """ < assert self.ddp_config.use_distributed_optimizer < assert self.ddp_config.overlap_param_gather < < # If current bucket's param AG has not been dispatched, dispatch it now (e.g., first < # AG bucket in first model chunk if ddp_config.align_param_gather is False). < if not self.param_gather_dispatched: < self.start_param_sync() < < if self.param_gather_handle is not None: < self.param_gather_handle.wait() < self.param_gather_handle = None < # Dispatch next bucket's asynchronous param AG. < if self.next_param_gather_bucket_group is not None and not skip_next_bucket_dispatch: < self.next_param_gather_bucket_group.start_param_sync() --- > self.communication_handle = None > self.is_communication_outstanding = False 225,226c102,103 < Initiates grad sync (all-reduce or reduce-scatter) communication operations < for all buckets in the bucket group. --- > Initiates grad sync (all-reduce or reduce-scatter) communication operation > for this bucket. 228,229c105,106 < When ddp_config.overlap_grad_reduce is set to True, dispatches an asynchronous < communication call. When ddp_config.overlap_grad_reduce is set to False, makes --- > When overlap_grad_reduce is set to True, dispatches an asynchronous > communication call. When overlap_grad_reduce is set to False, makes 233c110 < self.grad_reduce_handle is None --- > self.communication_handle is None and not self.is_communication_outstanding 235a113,114 > # Make sure norm of grads in bucket are not NaN > # prior to data-parallel all-reduce / reduce-scatter. 237c116,122 < self.check_for_nan_in_grad() --- > global_rank = torch.distributed.get_rank() > norm = self.grad_data.norm(p=2) > assert not norm.isnan(), ( > f'Rank {global_rank}: found NaN in local grad norm in ' > f'backward pass before data-parallel communication collective. 
' > f'Device: {torch.cuda.current_device()}, node: {os.uname()[1]}' > ) 241,243c126,127 < for bucket in self.buckets: < if bucket.gradient_scaling_factor != 1.0: < bucket.grad_data *= bucket.gradient_scaling_factor --- > if self.gradient_scaling_factor != 1.0: > self.grad_data *= self.gradient_scaling_factor 250,274c134,154 < # Use async communications only when overlap_grad_reduce is True. < async_op = self.ddp_config.overlap_grad_reduce < # Coalesce communication kernels across buckets in the bucket group. < with _coalescing_manager(self.data_parallel_group, async_ops=async_op) as cm: < for bucket in self.buckets: < if self.ddp_config.use_distributed_optimizer: < local_data_view = shard_buffer(bucket.grad_data, self.data_parallel_world_size)[ < self.data_parallel_rank < ] < torch.distributed._reduce_scatter_base( < local_data_view, < bucket.grad_data, < op=reduce_op, < group=self.data_parallel_group, < async_op=async_op, < ) < else: < torch.distributed.all_reduce( < bucket.grad_data, < op=reduce_op, < group=self.data_parallel_group, < async_op=async_op, < ) < if async_op: < self.grad_reduce_handle = cm --- > # Use async_op only when overlap_grad_reduce is True. > if self.ddp_config.use_distributed_optimizer: > local_data_view = shard_buffer(self.grad_data, self.data_parallel_world_size)[ > self.data_parallel_rank > ] > self.communication_handle = torch.distributed._reduce_scatter_base( > local_data_view, > self.grad_data, > op=reduce_op, > group=self.data_parallel_group, > async_op=self.ddp_config.overlap_grad_reduce, > ) > else: > self.communication_handle = torch.distributed.all_reduce( > self.grad_data, > op=reduce_op, > group=self.data_parallel_group, > async_op=self.ddp_config.overlap_grad_reduce, > ) > if self.ddp_config.overlap_grad_reduce: > self.is_communication_outstanding = True 276,281c156 < # When using `_coalescing_manager`, even if a synchronous op (async_op=False) is used, < # `cm` is not None, which is different from when `_coalescing_manager` is not used in < # which case the torch.distributed._reduce_scatter_base() will return None. In order to < # maintain consistency with prior code, we need to manually set communication handle to < # None. < self.grad_reduce_handle = None --- > self.is_communication_outstanding = False 285,286c160,161 < Finishes grad sync (all-reduce or reduce-scatter) communication operations < for all buckets in the bucket group. --- > Finishes grad sync (all-reduce or reduce-scatter) communication operation > for this bucket. 288,290c163,164 < When ddp_config.overlap_grad_reduce is set to True, waits for asynchronous < communication call to complete. When ddp_config.overlap_grad_reduce is set to False, < makes synchronous call. --- > When overlap_grad_reduce is set to True, waits for asynchronous communication > call to complete. When overlap_grad_reduce is set to False, makes synchronous call. 293d166 < self.param_gather_dispatched = False 297c170 < assert self.grad_reduce_handle is not None, ( --- > assert self.communication_handle is not None and self.is_communication_outstanding, ( 301,302c174 < self.grad_reduce_handle.wait() < self.grad_reduce_handle = None --- > self.communication_handle.wait() 309,310c181 < grads as ready when processing the last microbatch and ddp_config.overlap_grad_reduce < is True. --- > grads as ready when processing the last microbatch and overlap_grad_reduce is True. 
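
register_grad_ready, whose body follows just below, is a counting gate: each param reports once per microbatch and the communication call fires only when every param in the bucket (group) has reported. A simplified sketch that drops the is_last_microbatch / overlap_grad_reduce checks:

class ToyBucketGroup:
    def __init__(self, params):
        self.params = set(params)
        self.params_with_grad = set()
        self.grad_sync_started = False

    def start_grad_sync(self):
        self.grad_sync_started = True

    def register_grad_ready(self, param):
        assert param in self.params, 'Param is not in the bucket group'
        assert param not in self.params_with_grad, 'Cannot set grad twice'
        self.params_with_grad.add(param)
        # Issue the communication call only once all grads are available.
        if len(self.params_with_grad) == len(self.params):
            self.start_grad_sync()


group = ToyBucketGroup(params=['w1', 'w2'])
group.register_grad_ready('w1')
assert not group.grad_sync_started
group.register_grad_ready('w2')
assert group.grad_sync_started
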
311a183,184 > assert param in self.params, 'Param is not in the bucket' > assert param not in self.params_with_grad, 'Cannot set grad twice' 314,321c187,191 < ), 'register_grad_ready() should only be called when overlap_grad_reduce is True' < if self.is_last_microbatch: < assert param in self.param_to_bucket, 'Param is not in the bucket group' < assert param not in self.params_with_grad, 'Cannot set grad twice' < self.params_with_grad.add(param) < # If all params in bucket group have grads available, issue communication call. < if len(self.params_with_grad) == len(self.params): < self.start_grad_sync() --- > ), 'register_grad_ready() should be called only when overlapping grad reduce' > self.params_with_grad.add(param) > # If all params in bucket have grads available, issue communication call. > if len(self.params_with_grad) == len(self.params): > self.start_grad_sync() 324c194 < class _ParamAndGradBuffer: --- > class ParamAndGradBuffer: 341,343d210 < param_indices: The index of each param among the params with same dtype, if a param is fp8, < use its "fake" high precision dtype to determine which params have same dtype with it. < These indices are needed when loading a non-native-fp8 checkpoint in native-fp8 mode. 356d222 < param_indices: List[int], 359,360d224 < self.params = params < self.param_indices = param_indices 376a241 > self.is_last_microbatch = True 411,412c276,277 < param_start_index = 0 < bucket_start_index = param_start_index --- > data_start_index = 0 > bucket_data_start_index = data_start_index 418c283 < def _update_bucket_metadata(param_end_index: int) -> int: --- > def _create_new_bucket(data_end_index: int) -> int: 420,421c285,286 < Record metadata for the bucket starting at bucket_start_index and ending with the < passed-in param_end_index. Returns the bucket's end_index. --- > Create the bucket_id'th bucket with collected bucket_params, starting at > bucket_data_start_index. 423,431c288,294 < nonlocal bucket_start_index, bucket_params, bucket_id < per_bucket_numel_unpadded.append(param_end_index - bucket_start_index) < bucket_end_index = _pad_end_of_bucket_if_needed(param_end_index) < < # Record metadata of new bucket. < self.bucket_indices.append((bucket_start_index, bucket_end_index)) < bucket_start_index = bucket_end_index < < # Prepare for next bucket. --- > nonlocal bucket_data_start_index, bucket_params, bucket_id > per_bucket_numel_unpadded.append(data_end_index - bucket_data_start_index) > data_end_index = _pad_end_of_bucket_if_needed(data_end_index) > # Update bucket metadata. > self.bucket_indices.append((bucket_data_start_index, data_end_index)) > bucket_data_start_index = data_end_index > # Re-set bucket_params and increment bucket_id for next bucket. 434,436c297,298 < < # Return the potentially padded bucket_end_index. < return bucket_end_index --- > # Return the potentially padded data_end_index. > return data_end_index 452c314,317 < # Iterate through parameters in reverse order to roughly follow backprop order. --- > # Iterate through parameters in reverse order to roughly follow backprop order, > # and skip parameters that don't require gradients. > if not param.requires_grad: > continue 455c320 < param_start_index = _pad_start_of_param_if_needed(param_start_index) --- > data_start_index = _pad_start_of_param_if_needed(data_start_index) 460c325 < # end at the current param_start_index. --- > # end at the current data_start_index. 
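
The bucket-boundary logic above repeatedly pads indices with _pad_end_of_bucket_if_needed so the resulting buffer can be evenly sharded across data-parallel ranks. A simplified stand-in follows; the real helper may also pad for additional alignment requirements.

import math


def pad_end_of_bucket_if_needed(bucket_end_index: int, data_parallel_world_size: int) -> int:
    # Round the bucket boundary up to a multiple of the data-parallel world size.
    return int(math.ceil(bucket_end_index / data_parallel_world_size)) * data_parallel_world_size


assert pad_end_of_bucket_if_needed(1000, data_parallel_world_size=8) == 1000
assert pad_end_of_bucket_if_needed(1001, data_parallel_world_size=8) == 1008
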
463,464c328,329 < if param_start_index % self.data_parallel_world_size != 0: < param_start_index = _pad_end_of_bucket_if_needed(param_start_index) --- > if data_start_index % self.data_parallel_world_size != 0: > data_start_index = _pad_end_of_bucket_if_needed(data_start_index) 466c331 < bucket_end_index = _update_bucket_metadata(param_start_index) --- > _create_new_bucket(data_start_index) 468,469c333,334 < param_end_index = param_start_index + this_numel < self.param_index_map[param] = (param_start_index, param_end_index, bucket_id) --- > data_end_index = data_start_index + this_numel > self.param_index_map[param] = (data_start_index, data_end_index, bucket_id) 475c340,341 < bucket_size is not None and (param_end_index - bucket_start_index) >= bucket_size --- > bucket_size is not None > and (data_end_index - bucket_data_start_index) >= bucket_size 477,480c343,344 < bucket_end_index = _update_bucket_metadata(param_end_index) < param_start_index = bucket_end_index < else: < param_start_index = param_end_index --- > data_end_index = _create_new_bucket(data_end_index) > data_start_index = data_end_index 484c348 < bucket_end_index = _update_bucket_metadata(param_end_index) --- > data_end_index = _create_new_bucket(data_end_index) 488c352 < self.numel = bucket_end_index --- > self.numel = data_end_index 513,514c377,378 < bucket_params = [] < bucket_start_index = 0 --- > bucket_params = set() > bucket_data_start_index = 0 517c381,383 < param_start_index, param_end_index, bucket_id = self.param_index_map[param] --- > if not param.requires_grad: > continue > data_start_index, data_end_index, bucket_id = self.param_index_map[param] 522,523c388,389 < new_param_data = self._get( < param.data.shape, param_start_index, buffer_type=BufferType.PARAM --- > param.data = self._get( > param.data.shape, data_start_index, buffer_type=BufferType.PARAM 525,528d390 < if is_float8tensor(param): < param._data = new_param_data < else: < param.data = new_param_data 535c397 < param.data.shape, param_start_index, buffer_type=BufferType.GRAD --- > param.data.shape, data_start_index, buffer_type=BufferType.GRAD 538,546c400,406 < bucket_end_index = _pad_end_of_bucket_if_needed(param_start_index) < self.buckets.append( < self._new_bucket( < bucket_params=bucket_params, < start_index=bucket_start_index, < end_index=bucket_end_index, < numel_unpadded=per_bucket_numel_unpadded[cur_bucket_id], < bucket_id=cur_bucket_id, < ) --- > bucket_data_end_index = _pad_end_of_bucket_if_needed(data_start_index) > self._set_bucket( > bucket_params=bucket_params, > start_index=bucket_data_start_index, > end_index=bucket_data_end_index, > numel_unpadded=per_bucket_numel_unpadded[cur_bucket_id], > bucket_id=cur_bucket_id, 548,549c408,409 < bucket_start_index = bucket_end_index < bucket_params = [] --- > bucket_data_start_index = bucket_data_end_index > bucket_params = set() 553c413 < bucket_params.append(param) --- > bucket_params.add(param) 557,565c417,423 < bucket_end_index = _pad_end_of_bucket_if_needed(param_end_index) < self.buckets.append( < self._new_bucket( < bucket_params=bucket_params, < start_index=bucket_start_index, < end_index=bucket_end_index, < numel_unpadded=per_bucket_numel_unpadded[cur_bucket_id], < bucket_id=cur_bucket_id, < ) --- > bucket_data_end_index = _pad_end_of_bucket_if_needed(data_end_index) > self._set_bucket( > bucket_params=bucket_params, > start_index=bucket_data_start_index, > end_index=bucket_data_end_index, > numel_unpadded=per_bucket_numel_unpadded[cur_bucket_id], > bucket_id=cur_bucket_id, 603c461 < def 
_new_bucket( --- > def _set_bucket( 610c468 < ) -> _ParamAndGradBucket: --- > ): 612c470,471 < Helper function that creates a new bucket. Also updates param->bucket mapping. --- > Helper function to create new bucket, add it to list of buckets, and > also update param->bucket mapping. 631c490,491 < bucket = _ParamAndGradBucket( --- > bucket = Bucket( > ddp_config=self.ddp_config, 636a497,498 > data_parallel_group=self.data_parallel_group, > data_parallel_world_size=self.data_parallel_world_size, 638d499 < bucket_id=bucket_id, 639a501 > self.buckets.append(bucket) 644,645d505 < return bucket < 648c508,509 < Zero out the underlying grad_buffer. --- > Zero out the underlying grad_buffer and reset all buckets in preparation for the next > iteration of training. 650a512,514 > for bucket in self.buckets: > bucket.reset() > self.is_last_microbatch = True 651a516,519 > def start_grad_sync(self): > """ > Initiates grad sync (all-reduce or reduce-scatter) communication operations > for all buckets in the grad buffer. 653,711c521,526 < def partition_buckets( < buffers: List[_ParamAndGradBuffer], force_single_bucket_group: bool = False < ) -> List[_ParamAndGradBucketGroup]: < """ < Automatically regroup the buckets of input buffers and return a list of bucket groups. < < In some scenarios, we need to put buckets from different buffers into a group so that their < communication can be aggregated. < < For example, when there are both fp8 weights and bf16 biases in the model and virtual < pipeline parallelism is enabled, each model chunk will have an fp8 bucket and a bf16 bucket, < which doubles the number of communication kernels, and because of the use of < CUDA_DEVICE_MAX_CONNECTIONS=1, having multiple back-to-back communications will prevent the < overlap of communication kernels with computation kernels. < < The grouping strategy is: < 1. If force_single_bucket_group is True, put all buckets across all buffers into a single < bucket group. < 2. If force_single_bucket_group is False, when there is no fp8 buffer in the input buffers, < let each bucket group have only one bucket. < 3. If force_single_bucket_group is False, when using fp8 params, merge all non-fp8 buckets < into the last fp8 bucket group. < - Since the non-fp8 parameters (typically the biases of various layers) are relatively < small, they are likely to be grouped into a single non-fp8 bucket. < - The fp8 buckets start from the end of the model, i.e., the first bucket corresponds to < the end of the model, while the last bucket corresponds to the beginning. < - If we combine the non-fp8 bucket with the first fp8 bucket, we cannot initiate the < reduce-scatter to synchronize gradients after the backward pass at the end of the model < has completed. This is because we need to wait for the non-fp8 params from the beginning < layers to obtain their gradients. < - Combining the non-fp8 bucket with the last fp8 bucket can help avoid this issue. < < Args: < buffers (list): list of input buffers. < single_bucket_group_per_buffer (bool, optional): force group all buckets in each buffer < into a single bucket group. < """ < < if len(buffers) == 0: < return [] < < dtype_to_buffer_map = {} < for buffer in buffers: < dtype = buffer.param_dtype < # Make sure that the param_dtype of any two buffers is different. < assert dtype not in dtype_to_buffer_map < dtype_to_buffer_map[dtype] = buffer < < # Case 1: Put all buckets into a single bucket group if force_single_bucket_group is True. 
< if force_single_bucket_group: < buckets = [] < ddp_config = buffers[0].ddp_config < data_parallel_group = buffers[0].data_parallel_group < data_parallel_world_size = buffers[0].data_parallel_world_size < for buffer in buffers: < assert ddp_config == buffer.ddp_config < assert data_parallel_group == buffer.data_parallel_group < assert data_parallel_world_size == buffer.data_parallel_world_size < buckets.extend(buffer.buckets) --- > When overlap_grad_reduce is set to True, dispatches asynchronous communication > calls. When overlap_grad_reduce is set to False, calls synchronous > communication ops. > """ > for bucket in self.buckets: > bucket.start_grad_sync() 713,716c528,531 < bucket_group = _ParamAndGradBucketGroup( < buckets, ddp_config, data_parallel_group, data_parallel_world_size < ) < return [bucket_group] --- > def finish_grad_sync(self): > """ > Finishes grad sync (all-reduce or reduce-scatter) communication operations > for all buckets in the grad buffer. 718,758c533,538 < if torch.uint8 not in dtype_to_buffer_map: < # Case 2: When there is no fp8 buffer in the input buffers, let each bucket group have < # only one bucket. < bucket_groups = [] < for buffer in buffers: < for bucket in buffer.buckets: < bucket_groups.append( < _ParamAndGradBucketGroup( < [bucket], < buffer.ddp_config, < buffer.data_parallel_group, < buffer.data_parallel_world_size, < ) < ) < return bucket_groups < else: < # Case 3: When using fp8 params, merge all non-fp8 buckets into the last fp8 bucket group. < non_fp8_buckets = [] < for buffer in buffers: < if buffer.param_dtype != torch.uint8: < for bucket in buffer.buckets: < non_fp8_buckets.append(bucket) < < bucket_groups = [] < fp8_buffer = dtype_to_buffer_map[torch.uint8] < for bucket in fp8_buffer.buckets: < if len(bucket_groups) == len(fp8_buffer.buckets) - 1: < # The last bucket group. < group_buckets = [bucket] + non_fp8_buckets < else: < # The first N-1 bucket groups. < group_buckets = [bucket] < bucket_groups.append( < _ParamAndGradBucketGroup( < group_buckets, < buffer.ddp_config, < buffer.data_parallel_group, < buffer.data_parallel_world_size, < ) < ) < return bucket_groups --- > When overlap_grad_reduce is set to True, waits for asynchronous communication > calls to complete. When overlap_grad_reduce is set to False, calls synchronous > communication ops. > """ > for bucket in self.buckets: > bucket.finish_grad_sync() 759a540,542 > def register_grad_ready(self, param: torch.nn.Parameter): > """ > Registers grads for the passed-in param to be "ready" for grad sync. 761,769c544,552 < # For backwards compatibility. ParamAndGradBuffer will be deprecated in future release. < # _ParamAndGradBuffer is not intended to be consumed directly by external code. < class ParamAndGradBuffer(_ParamAndGradBuffer): < def __init__(self, *args, **kwargs): < super().__init__(*args, **kwargs) < warnings.warn( < "`ParamAndGradBuffer` will be deprecated in a future release, and is not " < "intended to be used by external code." < ) --- > When the number of microbatches is greater than 1, we only want to register > grads as ready when processing the last microbatch and overlap_grad_reduce is True. 
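The partition_buckets docstring above describes three grouping cases: force everything into one bucket group, one group per bucket when no fp8 buffer exists, or (with fp8 params) fold all non-fp8 buckets into the last fp8 bucket group so the reduce-scatter for the end-of-model layers is not held up waiting on gradients from the beginning of the model. A condensed, illustrative sketch of that strategy, using plain lists as stand-ins for buffers and buckets (torch.uint8 marks the fp8 buffer, as in the hunk):

# Illustrative only: the grouping strategy of partition_buckets in miniature.
import torch

def partition(buffers, force_single_bucket_group=False):
    # buffers: list of (param_dtype, [bucket, ...]) pairs
    if force_single_bucket_group:                              # Case 1
        return [[b for _, buckets in buffers for b in buckets]]
    if all(dtype != torch.uint8 for dtype, _ in buffers):      # Case 2
        return [[b] for _, buckets in buffers for b in buckets]
    # Case 3: append every non-fp8 bucket to the *last* fp8 bucket group,
    # since fp8 buckets start from the end of the model.
    non_fp8 = [b for dtype, buckets in buffers if dtype != torch.uint8 for b in buckets]
    fp8_buckets = next(buckets for dtype, buckets in buffers if dtype == torch.uint8)
    groups = [[b] for b in fp8_buckets]
    groups[-1].extend(non_fp8)
    return groups

buffers = [(torch.uint8, ["fp8_b0", "fp8_b1"]), (torch.bfloat16, ["bf16_b0"])]
print(partition(buffers))        # [['fp8_b0'], ['fp8_b1', 'bf16_b0']]
print(partition(buffers, True))  # [['fp8_b0', 'fp8_b1', 'bf16_b0']]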
> """ > assert ( > self.ddp_config.overlap_grad_reduce > ), 'register_grad_ready() should only be called when overlap_grad_reduce is True' > if self.is_last_microbatch: > bucket = self.param_to_bucket[param] > bucket.register_grad_ready(param) Binary files ./megatron/core/distributed/__pycache__/distributed_data_parallel_config.cpython-310.pyc and ../megatron-lm/megatron/core/distributed/__pycache__/distributed_data_parallel_config.cpython-310.pyc differ Binary files ./megatron/core/distributed/__pycache__/distributed_data_parallel.cpython-310.pyc and ../megatron-lm/megatron/core/distributed/__pycache__/distributed_data_parallel.cpython-310.pyc differ Binary files ./megatron/core/distributed/__pycache__/finalize_model_grads.cpython-310.pyc and ../megatron-lm/megatron/core/distributed/__pycache__/finalize_model_grads.cpython-310.pyc differ Binary files ./megatron/core/distributed/__pycache__/__init__.cpython-310.pyc and ../megatron-lm/megatron/core/distributed/__pycache__/__init__.cpython-310.pyc differ Binary files ./megatron/core/distributed/__pycache__/param_and_grad_buffer.cpython-310.pyc and ../megatron-lm/megatron/core/distributed/__pycache__/param_and_grad_buffer.cpython-310.pyc differ diff -rN ./megatron/core/extensions/transformer_engine.py ../megatron-lm/megatron/core/extensions/transformer_engine.py 1,969d0 < # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. < < import dataclasses < import os < import warnings < from typing import Callable < < import torch < import transformer_engine as te < from packaging.version import Version as PkgVersion < from torch import Tensor < < from megatron.core import ModelParallelConfig, parallel_state < from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding < from megatron.core.packed_seq_params import PackedSeqParams < from megatron.core.parallel_state import ( < get_context_parallel_global_ranks, < get_context_parallel_group, < get_tensor_and_expert_parallel_world_size, < get_tensor_model_parallel_group, < ) < from megatron.core.tensor_parallel import get_cuda_rng_tracker, get_expert_parallel_rng_tracker_name < from megatron.core.tensor_parallel.utils import divide < from megatron.core.transformer.enums import AttnMaskType < from megatron.core.transformer.transformer_config import TransformerConfig < from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint < from megatron.core.utils import get_te_version, is_te_min_version < < < def _get_extra_te_kwargs(config: TransformerConfig): < extra_transformer_engine_kwargs = {"params_dtype": config.params_dtype} < < if is_te_min_version("0.12.0"): < if config.use_cpu_initialization: < extra_transformer_engine_kwargs["device"] = 'cpu' < else: < extra_transformer_engine_kwargs["device"] = torch.cuda.current_device() < return extra_transformer_engine_kwargs < < < def condition_init_method(config, init_method): < """Condition TE init_method on config.perform_initialization.""" < return init_method if config.perform_initialization else (lambda w: None) < < < class TENorm: < """ < A conditional wrapper to initialize an instance of Transformer-Engine's < `LayerNorm` or `RMSNorm` based on input < """ < < # TODO should we ditch normalization config and just use spec to choose LayerNorm vs RMSNorm? 
< def __new__(cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5): < if config.normalization == "LayerNorm": < instance = te.pytorch.LayerNorm( < hidden_size=hidden_size, < eps=eps, < sequence_parallel=config.sequence_parallel, < zero_centered_gamma=config.layernorm_zero_centered_gamma, < **_get_extra_te_kwargs(config), < ) < elif config.normalization == "RMSNorm": < assert hasattr( < te.pytorch, "RMSNorm" < ), "Transformer-Engine >= v0.11 required to use this feature" < instance = te.pytorch.RMSNorm( < hidden_size=hidden_size, < eps=eps, < sequence_parallel=config.sequence_parallel, < zero_centered_gamma=config.layernorm_zero_centered_gamma, < **_get_extra_te_kwargs(config), < ) < else: < raise Exception('Only LayerNorm and RMSNorm are curently supported') < < return instance < < < class TELinear(te.pytorch.Linear): < """ < Wrapper for the Transformer-Engine's `Linear` layer. < < Note that if Megatron's parallel_state has not been initialized < yet, the tp_group passed to TE will be None and must be set later < via set_tensor_parallel_group(). < """ < < def __init__( < self, < input_size: int, < output_size: int, < *, < parallel_mode: str, < config: ModelParallelConfig, < init_method: Callable, < bias: bool, < skip_bias_add: bool, < skip_weight_param_allocation: bool, < tp_comm_buffer_name: str = None, < is_expert: bool = False, < ): < self.config = config < < # TE returns a zero length Tensor when bias=False and < # return_bias=True, but we prefer None. So in that case we < # tell TE to not return the bias, and return None < # ourselves. This way our forward always returns two values < # and we don't have to deal with the zero length Tensor. < self.te_return_bias = skip_bias_add and bias < self.is_first_microbatch = True < self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache < if skip_weight_param_allocation: < raise ValueError( < 'Transformer Engine linear layers do not support skip_weight_param_allocation' < ) < < extra_kwargs = _get_extra_te_kwargs(config) < < if is_te_min_version("0.8.0"): < if self.config.tp_comm_overlap: < if is_te_min_version("1.5.0"): < # Use old overlap flags if they were supplied instead < extra_kwargs["ub_overlap_ag"] = ( < self.config.tp_comm_overlap_ag < if hasattr(self.config, "tp_comm_overlap_ag") < else self.config.tp_comm_split_ag or self.config.tp_comm_atomic_ag < ) < extra_kwargs["ub_overlap_rs"] = ( < self.config.tp_comm_overlap_rs < if hasattr(self.config, "tp_comm_overlap_rs") < else self.config.tp_comm_split_rs or self.config.tp_comm_atomic_rs < ) < # Disable ub overlap for experts. < if is_expert: < extra_kwargs["ub_overlap_ag"] = False < extra_kwargs["ub_overlap_rs"] = False < else: < extra_kwargs["ub_split_ag"] = self.config.tp_comm_split_ag < extra_kwargs["ub_atomic_gemm_ag"] = self.config.tp_comm_atomic_ag < extra_kwargs["ub_split_rs"] = self.config.tp_comm_split_rs < extra_kwargs["ub_atomic_gemm_rs"] = self.config.tp_comm_atomic_rs < # Disable ub overlap for experts. 
< if is_expert: < extra_kwargs["ub_split_ag"] = False < extra_kwargs["ub_atomic_gemm_ag"] = False < extra_kwargs["ub_split_rs"] = False < extra_kwargs["ub_atomic_gemm_rs"] = False < if is_te_min_version("1.0.0", check_equality=False): < assert ( < tp_comm_buffer_name is not None < ), "Buffer name should be set to configure communication overlap settings" < extra_kwargs["ub_name"] = tp_comm_buffer_name < < self.expert_parallel = self.config.expert_model_parallel_size > 1 < if is_expert and self.expert_parallel: < rng_tracker_name = get_expert_parallel_rng_tracker_name() < else: < rng_tracker_name = None < if is_te_min_version("1.7.0"): < extra_kwargs["rng_tracker_name"] = rng_tracker_name < < # Disable communications in TE when using SP or EP by making TE agnostic of model parallel. < tp_size = self.config.tensor_model_parallel_size < tp_group = get_tensor_model_parallel_group(check_initialized=False) < if is_expert and (self.config.sequence_parallel or self.expert_parallel): < if self.config.moe_extended_tp: < tp_size = get_tensor_and_expert_parallel_world_size() < if parallel_mode == "column": < output_size = divide(output_size, tp_size) < elif parallel_mode == "row": < input_size = divide(input_size, tp_size) < parallel_mode = None < tp_size = 1 < tp_group = None < < super().__init__( < in_features=input_size, < out_features=output_size, < sequence_parallel=self.config.sequence_parallel, < fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion, < tp_group=tp_group, < tp_size=tp_size, < get_rng_state_tracker=( < get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None < ), < init_method=condition_init_method(config, init_method), < bias=bias, < return_bias=self.te_return_bias, < parallel_mode=parallel_mode, < **extra_kwargs, < ) < < for param in self.parameters(): < setattr(param, 'allreduce', not (is_expert and self.expert_parallel)) < < def forward(self, x): < """Forward.""" < _is_first_microbatch = ( < None if self.disable_parameter_transpose_cache else self.is_first_microbatch < ) < out = super().forward(x, is_first_microbatch=_is_first_microbatch) < self.is_first_microbatch = False < < # TE only returns a tuple when return_bias is True, otherwise < # it returns a single Tensor, we always want to return two < # values regardless of the arguments. < if self.te_return_bias: < return out < return out, None < < < class TELayerNormColumnParallelLinear(te.pytorch.LayerNormLinear): < """ < Wrapper for the Transformer-Engine's `LayerNormLinear` layer that combines < layernorm and linear layers < """ < < def __init__( < self, < input_size: int, < output_size: int, < *, < config: TransformerConfig, < init_method: Callable, < gather_output: bool, < bias: bool, < skip_bias_add: bool, < is_expert: bool, < skip_weight_param_allocation: bool = False, < tp_comm_buffer_name: str = None, < ): < self.config = config < < if gather_output: < raise ValueError('Transformer Engine linear layers do not support gather_output = True') < < if is_expert: < raise ValueError('Transformer Engine linear layers do not yet support MoE') < < if skip_weight_param_allocation: < raise ValueError( < 'Transformer Engine linear layers do not support skip_weight_param_allocation' < ) < < # TE returns a zero length Tensor when bias=False and < # return_bias=True, but we prefer None. So in that case we < # tell TE to not return the bias, and return None < # ourselves. This way our forward always returns two values < # and we don't have to deal with the zero length Tensor. 
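The comment block above explains the return convention shared by all the TE wrappers: when bias is not requested back, the wrapper tells TE not to return it and appends None itself, so callers can always unpack exactly two values. A small sketch of that convention, with base_forward as a generic placeholder rather than the actual TE call:

# Sketch of the (output, bias-or-None) convention used by the wrappers above.
from typing import Optional, Tuple
import torch

def wrapped_forward(base_forward, x: torch.Tensor,
                    te_return_bias: bool) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    out = base_forward(x)
    if te_return_bias:
        # In this case the underlying layer already returned (output, bias).
        return out
    return out, None

# Usage with a stand-in layer that never hands back a bias.
linear = torch.nn.Linear(4, 4, bias=False)
output, bias = wrapped_forward(linear, torch.randn(2, 4), te_return_bias=False)
assert bias is None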
< self.te_return_bias = skip_bias_add and bias < self.is_first_microbatch = True < self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache < extra_kwargs = _get_extra_te_kwargs(config) < < # Only Transformer-Engine version >= 0.11.0 supports `RMSNorm` < if is_te_min_version("0.11.0"): < extra_kwargs["normalization"] = self.config.normalization < elif self.config.normalization != "LayerNorm": < te_version = get_te_version() < raise ValueError( < f"Transformer Engine v{te_version} does not support {self.config.normalization}." < ) < < if is_te_min_version("0.8.0"): < if self.config.tp_comm_overlap: < extra_kwargs["ub_bulk_wgrad"] = self.config.tp_comm_bulk_wgrad < extra_kwargs["ub_bulk_dgrad"] = self.config.tp_comm_bulk_dgrad < if is_te_min_version("1.5.0", check_equality=False): < # Use old overlap flags if they were supplied instead < extra_kwargs["ub_overlap_ag"] = ( < self.config.tp_comm_overlap_ag < if hasattr(self.config, "tp_comm_overlap_ag") < else self.config.tp_comm_split_ag or self.config.tp_comm_atomic_ag < ) < if is_te_min_version("1.6.0.dev0", check_equality=False): < extra_kwargs["ub_overlap_rs_dgrad"] = ( < self.config.tp_comm_overlap_rs_dgrad < if hasattr(self.config, "tp_comm_overlap_rs_dgrad") < else False < ) < if tp_comm_buffer_name == 'qkv' and self.config.tp_comm_overlap_disable_qkv: < extra_kwargs["ub_overlap_ag"] = False < extra_kwargs["ub_overlap_rs_dgrad"] = False < < if tp_comm_buffer_name == 'fc1' and self.config.tp_comm_overlap_disable_fc1: < extra_kwargs["ub_overlap_ag"] = False < extra_kwargs["ub_overlap_rs_dgrad"] = False < else: < extra_kwargs["ub_atomic_gemm_ag"] = self.config.tp_comm_atomic_ag < extra_kwargs["ub_split_ag"] = self.config.tp_comm_split_ag < if is_te_min_version("1.0.0", check_equality=False): < assert ( < tp_comm_buffer_name is not None < ), "Buffer name should be set to configure communication overlap settings" < extra_kwargs["ub_name"] = tp_comm_buffer_name < < super().__init__( < in_features=input_size, < out_features=output_size, < eps=self.config.layernorm_epsilon, < sequence_parallel=self.config.sequence_parallel, < fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion, < tp_group=get_tensor_model_parallel_group(check_initialized=False), < tp_size=self.config.tensor_model_parallel_size, < get_rng_state_tracker=( < get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None < ), < init_method=condition_init_method(config, init_method), < bias=bias, < return_bias=self.te_return_bias, < parallel_mode="column", < return_layernorm_output=False, < zero_centered_gamma=self.config.layernorm_zero_centered_gamma, < **extra_kwargs, < ) < < def forward(self, x): < """Forward.""" < _is_first_microbatch = ( < None if self.disable_parameter_transpose_cache else self.is_first_microbatch < ) < out = super().forward(x, is_first_microbatch=_is_first_microbatch) < self.is_first_microbatch = False < < # TE only returns a tuple when return_bias is True, otherwise < # it returns a single Tensor, we always want to return two < # values regardless of the arguments. 
< if self.te_return_bias: < return out < return out, None < < def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): < """Sharding along axis 0, bias sharded""" < state_dict = self.state_dict(prefix='', keep_vars=True) < return make_sharded_tensors_for_checkpoint( < state_dict, prefix, {'weight': 0, 'bias': 0}, sharded_offsets < ) < < < class TEColumnParallelLinear(TELinear): < """ < Wrapper for the Transformer-Engine's `Linear` layer but specialized similar < to megatron's `ColumnParallelLinear` layer. < """ < < def __init__( < self, < input_size: int, < output_size: int, < *, < config: ModelParallelConfig, < init_method: Callable, < gather_output: bool, < bias: bool, < skip_bias_add: bool, < is_expert: bool, < skip_weight_param_allocation: bool = False, < tp_comm_buffer_name: str = None, < ): < if gather_output: < raise ValueError('Transformer Engine linear layers do not support gather_output = True') < < super().__init__( < input_size=input_size, < output_size=output_size, < parallel_mode="column", < config=config, < init_method=condition_init_method(config, init_method), < bias=bias, < skip_bias_add=skip_bias_add, < is_expert=is_expert, < skip_weight_param_allocation=skip_weight_param_allocation, < tp_comm_buffer_name=tp_comm_buffer_name, < ) < < def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): < """Sharding along axis 0, bias sharded""" < state_dict = self.state_dict(prefix='', keep_vars=True) < return make_sharded_tensors_for_checkpoint( < state_dict, prefix, {'weight': 0, 'bias': 0}, sharded_offsets < ) < < < class TERowParallelLinear(TELinear): < """ < Wrapper for the Transformer-Engine's `Linear` layer but specialized similar < to megatron's `RowParallelLinear` layer. < """ < < def __init__( < self, < input_size: int, < output_size: int, < *, < config: ModelParallelConfig, < init_method: Callable, < bias: bool, < input_is_parallel: bool, < skip_bias_add: bool, < is_expert: bool, < tp_comm_buffer_name: str = None, < ): < if not input_is_parallel: < raise ValueError( < "Transformer Engine linear layers do not support input_is_parallel = False" < ) < < super().__init__( < input_size=input_size, < output_size=output_size, < parallel_mode="row", < config=config, < init_method=condition_init_method(config, init_method), < bias=bias, < skip_bias_add=skip_bias_add, < skip_weight_param_allocation=False, # We don't currently use this for row parallel layers # pylint: disable=line-too-long < is_expert=is_expert, < tp_comm_buffer_name=tp_comm_buffer_name, < ) < < def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): < """Sharding along axis 1, bias not sharded""" < state_dict = self.state_dict(prefix='', keep_vars=True) < return make_sharded_tensors_for_checkpoint( < state_dict, prefix, {'weight': 1}, sharded_offsets < ) < < < class TEDotProductAttention(te.pytorch.DotProductAttention): < """ < Wrapper for the Transformer-Engine's `DotProductAttention` layer that also < has "flash attention" enabled. < < Note that if Megatron's parallel_state has not been initialized yet, the < tp_group and cp_group passed to TE will be None and must be set later < via set_tensor_parallel_group() and set_context_parallel_group(). 
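The sharded_state_dict methods above encode the tensor-parallel sharding axes via the axis maps passed to make_sharded_tensors_for_checkpoint: {'weight': 0, 'bias': 0} for column-parallel layers and {'weight': 1} for row-parallel layers. A pure-torch illustration of what those axes mean (tp size and shapes are arbitrary choices for the example):

# Column parallel shards output features (dim 0, weight and bias);
# row parallel shards input features (dim 1, bias kept whole).
import torch

tp = 2
weight = torch.randn(8, 6)   # [output_features, input_features]
bias = torch.randn(8)

col_weight_shards = torch.chunk(weight, tp, dim=0)   # two [4, 6] shards
col_bias_shards = torch.chunk(bias, tp, dim=0)       # two [4] shards

# The row-parallel bias is added after the all-reduce, so it is replicated
# rather than sharded.
row_weight_shards = torch.chunk(weight, tp, dim=1)   # two [8, 3] shards

assert col_weight_shards[0].shape == (4, 6)
assert row_weight_shards[0].shape == (8, 3)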
< """ < < cp_stream: torch.cuda.Stream = None < < def __init__( < self, < config: TransformerConfig, < layer_number: int, < attn_mask_type: AttnMaskType, < attention_type: str, < attention_dropout: float = None, < ): < self.config = config < self.te_forward_mask_type = False < self.qkv_format: str = 'sbhd' < < if self.config.apply_query_key_layer_scaling != bool( < int(os.getenv('NVTE_APPLY_QK_LAYER_SCALING', '0')) < ): < raise ValueError( < f"apply_query_key_layer_scaling is {self.config.apply_query_key_layer_scaling} " < f"but environment variable NVTE_APPLY_QK_LAYER_SCALING is " < f"{os.getenv('NVTE_APPLY_QK_LAYER_SCALING')}. Transformer Engine does not support " < f"setting query key layer scaling via argument, so these two must match." < ) < < extra_kwargs = {} < if is_te_min_version("0.11.0"): < extra_kwargs["num_gqa_groups"] = self.config.num_query_groups < elif self.config.num_query_groups != self.config.num_attention_heads: < raise ValueError( < f"Transformer Engine v{get_te_version()} does not support Grouped Query Attention, " < f"use a newer version of Transformer Engine. " < f"(num_query_groups ({self.config.num_query_groups}) != " < f"num_attention_heads ({self.config.num_attention_heads}))" < ) < < if is_te_min_version("0.10.0"): < extra_kwargs["attention_type"] = attention_type < # older version don't need attention_type < < if is_te_min_version("0.12.0", check_equality=False): < self.te_forward_mask_type = True < < # Only Transformer-Engine version >= 1.0.0 supports context parallelism < if is_te_min_version("1.0.0"): < if getattr(TEDotProductAttention, "cp_stream") is None: < TEDotProductAttention.cp_stream = torch.cuda.Stream() < extra_kwargs["cp_group"] = get_context_parallel_group(check_initialized=False) < extra_kwargs["cp_global_ranks"] = get_context_parallel_global_ranks( < check_initialized=False < ) < extra_kwargs["cp_stream"] = TEDotProductAttention.cp_stream < else: < assert ( < self.config.context_parallel_size == 1 < ), "Only Transformer-Engine version >= 1.0.0 supports context parallelism!" < < if self.config.deterministic_mode: < if int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")) != 0: < raise RuntimeError( < "deterministic_mode is on and we are using DotProductAttention from " < "Transformer Engine, but NVTE_ALLOW_NONDETERMINISTIC_ALGO is not 0. " < f"Currently set to: {os.getenv('NVTE_ALLOW_NONDETERMINISTIC_ALGO', 'not set')}." < ) < < if config.window_size is not None: < # Check version < assert is_te_min_version("1.2.0"), ( < f"Transformer-Engine v{get_te_version()} must be >= 1.2.0 to support" < "sliding window attention." 
< ) < extra_kwargs['window_size'] = config.window_size < < super().__init__( < num_attention_heads=self.config.num_attention_heads, < kv_channels=self.config.kv_channels, < attention_dropout=( < self.config.attention_dropout if attention_dropout is None else attention_dropout < ), < attn_mask_type=attn_mask_type.name, < sequence_parallel=self.config.sequence_parallel, < tp_size=self.config.tensor_model_parallel_size, < get_rng_state_tracker=( < get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None < ), < tp_group=get_tensor_model_parallel_group(check_initialized=False), < layer_number=layer_number, < **extra_kwargs, < ) < < def forward( < self, < query: Tensor, < key: Tensor, < value: Tensor, < attention_mask: Tensor, < attn_mask_type: AttnMaskType, < packed_seq_params: PackedSeqParams = None, < ): < """Forward.""" < packed_seq_kwargs = ( < dataclasses.asdict(packed_seq_params) if packed_seq_params is not None else {} < ) < # overwrite self.qkv_format depending on self.config.apply_rope_fusion, which can be set < # after init < if self.config.apply_rope_fusion and is_te_min_version("0.13.0", check_equality=False): < self.qkv_format = 'bshd' < < qkv_format = packed_seq_kwargs.get('qkv_format', self.qkv_format) < < if get_te_version() < PkgVersion("1.3.0"): < # TE 1.3.0 introduces precomputing max_seqlen to remove unnecessary kernels and D2H < # copies (#555) < # These two arguments did not exist prior to 1.3.0 < packed_seq_kwargs.pop("max_seqlen_q", None) < packed_seq_kwargs.pop("max_seqlen_kv", None) < < if self.config.apply_rope_fusion and qkv_format == 'bshd': < query, key, value = [x.transpose(0, 1).contiguous() for x in (query, key, value)] < # In PyTorch, the following two tensors are in fact the same: < # Tensor with shape (1, S, H, D) and stride (S*H*D, H*D, D, 1) < # Tensor with shape (1, S, H, D) and stride (H*D, H*D, D, 1) < # Stride for a dimension that is 1 has no meaning, so tensors created two different ways < # can have same shape but different strides. < # We unify them to the first one to pass the stride check in TE < if value.shape == key.shape and value.shape[0] == 1 and value.stride() != key.stride(): < value = value.as_strided(value.shape, key.stride()) < < if self.te_forward_mask_type: < if qkv_format == 'thd' and is_te_min_version("1.7.0"): < # thd format uses flash attention with cuDNN kernel which requires is_padding=True, < # so the only acceptable mask types are `padding_causal` and `padding`. These do not < # necessarily indicate there are padded tokens in the sequence. < if attn_mask_type == AttnMaskType.causal: < attn_mask_type = AttnMaskType.padding_causal < elif attn_mask_type == AttnMaskType.no_mask: < attn_mask_type = AttnMaskType.padding < core_attn_out = super().forward( < query, < key, < value, < attention_mask, < attn_mask_type=attn_mask_type.name, < **packed_seq_kwargs, < ) < else: < core_attn_out = super().forward(query, key, value, attention_mask, **packed_seq_kwargs) < < if self.config.apply_rope_fusion and qkv_format == 'bshd': < return core_attn_out.transpose(0, 1) < else: < return core_attn_out < < < if is_te_min_version("1.9.0.dev0"): < < class TEGroupedLinear(te.pytorch.GroupedLinear): < """ < Wrapper for the Transformer-Engine's `GroupedLinear` layer. < < Note that if Megatron's parallel_state has not been initialized < yet, the tp_group passed to TE will be None and must be set later < via set_tensor_parallel_group(). 
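The forward hunk above works around a TE stride check by re-striding value to match key when both have a size-1 batch dimension. The underlying PyTorch behavior is easy to demonstrate standalone: the stride of a size-1 dimension never moves the data pointer, so two tensors can hold identical data with identical shapes yet report different strides, and as_strided with the other tensor's strides is a data-preserving fix.

# Demo of the size-1-dimension stride quirk handled in the forward above.
import torch

S, H, D = 5, 2, 4
base = torch.randn(S, H, D)
key = base.unsqueeze(0)  # shape (1, S, H, D), stride (S*H*D, H*D, D, 1)
# Same storage and shape, but the size-1 batch dim advertises a different
# stride; such tensors can come out of certain view/transpose chains.
value = base.as_strided((1, S, H, D), (H * D, H * D, D, 1))

assert key.shape == value.shape
assert torch.equal(key, value)          # identical contents
assert key.stride() != value.stride()   # ...but different reported strides

# The fix used above: adopt key's strides (safe because batch dim has size 1).
value = value.as_strided(value.shape, key.stride())
assert value.stride() == key.stride() and torch.equal(key, value)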
< """ < < def __init__( < self, < num_gemms: int, < input_size: int, < output_size: int, < *, < parallel_mode: str, < config: ModelParallelConfig, < init_method: Callable, < bias: bool, < skip_bias_add: bool, < is_expert: bool = False, < tp_comm_buffer_name: str = None, < ): < self.config = config < < # TE returns a zero length Tensor when bias=False and < # return_bias=True, but we prefer None. So in that case we < # tell TE to not return the bias, and return None < # ourselves. This way our forward always returns two values < # and we don't have to deal with the zero length Tensor. < self.te_return_bias = skip_bias_add and bias < self.is_first_microbatch = True < self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache < < extra_kwargs = _get_extra_te_kwargs(config) < extra_kwargs["ub_name"] = tp_comm_buffer_name < < self.expert_parallel = self.config.expert_model_parallel_size > 1 < if self.expert_parallel: < extra_kwargs["rng_tracker_name"] = get_expert_parallel_rng_tracker_name() < < # For MoE models, the comms between TP and EP group is explicitly handled by < # MoE token dispatcher. So we disable comms by making TE agnostic of model parallel. < self.explicit_expert_comm = is_expert and ( < config.tensor_model_parallel_size > 1 or self.expert_parallel < ) < tp_group = get_tensor_model_parallel_group(check_initialized=False) < if self.explicit_expert_comm and config.moe_extended_tp: < tp_size = parallel_state.get_tensor_and_expert_parallel_world_size() < else: < tp_size = parallel_state.get_tensor_model_parallel_world_size() < if self.explicit_expert_comm: < if parallel_mode == "column": < output_size = divide(output_size, tp_size) < elif parallel_mode == "row": < input_size = divide(input_size, tp_size) < parallel_mode = None < tp_size = 1 < tp_group = None < < super().__init__( < num_gemms=num_gemms, < in_features=input_size, < out_features=output_size, < sequence_parallel=self.config.sequence_parallel, < fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion, < tp_group=tp_group, < tp_size=tp_size, < get_rng_state_tracker=( < get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None < ), < init_method=condition_init_method(config, init_method), < bias=bias, < return_bias=self.te_return_bias, < parallel_mode=parallel_mode, < **extra_kwargs, < ) < < for param in self.parameters(): < setattr(param, 'allreduce', not (is_expert and self.expert_parallel)) < < def forward(self, x, m_splits): < """Forward.""" < _is_first_microbatch = ( < None if self.disable_parameter_transpose_cache else self.is_first_microbatch < ) < out = super().forward(x, m_splits, is_first_microbatch=_is_first_microbatch) < self.is_first_microbatch = False < < # TE only returns a tuple when return_bias is True, otherwise < # it returns a single Tensor, we always want to return two < # values regardless of the arguments. < if self.te_return_bias: < return out < return out, None < < def _sharded_state_dict_grouped( < self, tp_axis_map, prefix='', sharded_offsets=(), metadata=None < ): < """ < prefix should be module_name to make keys identical to sequetial ones. 
< """ < sharded_state_dict = {} < full_state_dict = self.state_dict(prefix='', keep_vars=True) < num_global_experts = ( < parallel_state.get_expert_model_parallel_world_size() * self.num_gemms < ) < local_expert_indices_offset = ( < parallel_state.get_expert_model_parallel_rank() * self.num_gemms < ) < ep_axis = len(sharded_offsets) < for gemm_idx in range(self.num_gemms): < state_dict = { < f'{gemm_idx}.weight': full_state_dict[f'weight{gemm_idx}'], < f'{gemm_idx}._extra_state': full_state_dict['_extra_state'], < } < if self.use_bias: < state_dict[f'{gemm_idx}.bias'] = full_state_dict[f'bias{gemm_idx}'] < sub_sd = make_sharded_tensors_for_checkpoint( < state_dict, < '', < tp_axis_map, < ( < *sharded_offsets, < (ep_axis, local_expert_indices_offset + gemm_idx, num_global_experts), < ), < ) < # Remove expert layers indexing from sharded keys < replace_prefix_for_sharding(sub_sd, f'{gemm_idx}.', prefix) < sharded_state_dict.update( < { < f'{prefix}weight{gemm_idx}': sub_sd[f'{gemm_idx}.weight'], < # TODO: TE's GroupedLinear only has one _extra_state for all experts. < # We need sharding or build/merge fn to handle _extra_state correctly. < f'{prefix}_extra_state{"" if gemm_idx == 0 else gemm_idx}': sub_sd[ < f'{gemm_idx}._extra_state' < ], < } < ) < if self.use_bias: < sharded_state_dict[f'{prefix}bias{gemm_idx}'] = sub_sd[f'{gemm_idx}.bias'] < # Adjust replica ids - replication along DP modulo EP < for k, sh_ten in sharded_state_dict.items(): < replica_id = sh_ten.replica_id < assert ( < len(replica_id) == 3 < ), f'Expected replica_id for {k} to be in (PP, TP, DP) format, got: {replica_id}' < sh_ten.replica_id = ( < *replica_id[:2], < parallel_state.get_data_modulo_expert_parallel_rank(), < ) < return sharded_state_dict < < class TEColumnParallelGroupedLinear(TEGroupedLinear): < """ < Wrapper for the Transformer-Engine's `GroupedLinear` layer but specialized < to column-parallel style. < """ < < def __init__( < self, < num_gemms: int, < input_size: int, < output_size: int, < *, < config: ModelParallelConfig, < init_method: Callable, < bias: bool, < skip_bias_add: bool, < is_expert: bool, < tp_comm_buffer_name: str = None, < ): < < super().__init__( < num_gemms=num_gemms, < input_size=input_size, < output_size=output_size, < parallel_mode="column", < config=config, < init_method=condition_init_method(config, init_method), < bias=bias, < skip_bias_add=skip_bias_add, < is_expert=is_expert, < tp_comm_buffer_name=tp_comm_buffer_name, < ) < < def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): < """ < For each gemm, sharding along axis 0, bias sharded. < Assume sharded_offsets[-1] is the expert parallel offset. < """ < tp_axis_map = {} < for gemm_idx in range(self.num_gemms): < tp_axis_map.update({f'{gemm_idx}.weight': 0, f'{gemm_idx}.bias': 0}) < return super()._sharded_state_dict_grouped( < tp_axis_map, prefix, sharded_offsets, metadata < ) < < class TERowParallelGroupedLinear(TEGroupedLinear): < """ < Wrapper for the Transformer-Engine's `GroupedLinear` layer but specialized < to row-parallel style. 
< """ < < def __init__( < self, < num_gemms: int, < input_size: int, < output_size: int, < *, < config: ModelParallelConfig, < init_method: Callable, < bias: bool, < skip_bias_add: bool, < is_expert: bool, < tp_comm_buffer_name: str = None, < ): < < super().__init__( < num_gemms=num_gemms, < input_size=input_size, < output_size=output_size, < parallel_mode="row", < config=config, < init_method=condition_init_method(config, init_method), < bias=bias, < skip_bias_add=skip_bias_add, < is_expert=is_expert, < tp_comm_buffer_name=tp_comm_buffer_name, < ) < < def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): < """ < For each gemm, sharding along axis 1, bias not sharded. < Assume sharded_offsets[-1] is the expert parallel offset. < """ < tp_axis_map = {f'{gemm_idx}.weight': 1 for gemm_idx in range(self.num_gemms)} < return super()._sharded_state_dict_grouped( < tp_axis_map, prefix, sharded_offsets, metadata < ) < < else: < < TEGroupedLinear = None < TEColumnParallelGroupedLinear = None < TERowParallelGroupedLinear = None < < < class TEDelayedScaling(te.common.recipe.DelayedScaling): < """ < Wrapper for the Transformer-Engine's `DelayedScaling` layer. < """ < < def __init__( < self, < config: ModelParallelConfig, < fp8_format: int, < override_linear_precision: tuple = (False, False, False), < ): < extra_kwargs = _get_extra_te_kwargs(config) < if is_te_min_version("1.6.0.dev0"): < extra_kwargs["fp8_dpa"] = config.fp8_dot_product_attention < extra_kwargs["fp8_mha"] = config.fp8_multi_head_attention < if get_te_version() < PkgVersion("1.8.0"): < extra_kwargs["interval"] = config.fp8_interval < elif config.fp8_interval != 1: < warnings.warn("fp8_interval is deprecated and ignored from Transformer-Engine v1.8.0.") < < super().__init__( < margin=config.fp8_margin, < fp8_format=fp8_format, < amax_compute_algo=config.fp8_amax_compute_algo, < amax_history_len=config.fp8_amax_history_len, < override_linear_precision=override_linear_precision, < **extra_kwargs, < ) < < < class TECudaRNGStatesTracker(te.pytorch.distributed.CudaRNGStatesTracker): < """Wraps TransformerEngine's CudaRNGStatesTracker so that it is < interchangeable with Megatron's RNG tracker""" < < def is_initialized(self): < """Checks if the internal RNG state has been set wirth set_states().""" < return self._is_initialized < < def reset(self): < """Reset the internal RNG state.""" < super().reset() < self._is_initialized = False < < def set_states(self, states): < """Set the internal RNG state.""" < super().set_states(states) < self._is_initialized = True < < def add(self, name, seed): < """Track the rng state.""" < super().add(name, seed) < self._is_initialized = True < < < def te_checkpoint( < forward_func, < distribute_saved_activations, < get_rng_state_tracker, < tp_group, < hidden_states, < attention_mask, < context, < context_mask, < rotary_pos_emb, < ): < """Checkpointing with Transformer-Engine.""" < from transformer_engine.pytorch.distributed import checkpoint < < if is_te_min_version("1.5.0"): < return checkpoint( < forward_func, < hidden_states, < attention_mask, < context, < context_mask, < rotary_pos_emb, < distribute_saved_activations=distribute_saved_activations, < get_rng_state_tracker=get_rng_state_tracker, < tp_group=tp_group, < ) < else: < return checkpoint( < forward_func, < distribute_saved_activations, < get_rng_state_tracker, < tp_group, < hidden_states, < attention_mask, < context, < context_mask, < rotary_pos_emb, < ) < < < try: < < from transformer_engine.pytorch.attention import 
_SplitAlongDim < < SplitAlongDim = _SplitAlongDim.apply < < except ImportError: < < SplitAlongDim = None < < try: < < from transformer_engine.pytorch.cpu_offload import ( < get_cpu_offload_context as _get_cpu_offload_context, < ) < < def get_cpu_offload_context( < enabled, num_layers, model_layers, activation_offloading, weight_offloading < ): < """Get CPU offload context and sync function.""" < if is_te_min_version("1.10.0.dev0"): < context, sync_func = _get_cpu_offload_context( < enabled, num_layers, model_layers, activation_offloading, weight_offloading < ) < else: < context, sync_func = _get_cpu_offload_context( < enabled, num_layers, activation_offloading, weight_offloading < ) < < return context, sync_func < < except ImportError: < < get_cpu_offload_context = None Binary files ./megatron/core/fusions/__pycache__/fused_bias_dropout.cpython-310.pyc and ../megatron-lm/megatron/core/fusions/__pycache__/fused_bias_dropout.cpython-310.pyc differ Binary files ./megatron/core/fusions/__pycache__/fused_bias_geglu.cpython-310.pyc and ../megatron-lm/megatron/core/fusions/__pycache__/fused_bias_geglu.cpython-310.pyc differ Binary files ./megatron/core/fusions/__pycache__/fused_bias_gelu.cpython-310.pyc and ../megatron-lm/megatron/core/fusions/__pycache__/fused_bias_gelu.cpython-310.pyc differ Binary files ./megatron/core/fusions/__pycache__/fused_bias_swiglu.cpython-310.pyc and ../megatron-lm/megatron/core/fusions/__pycache__/fused_bias_swiglu.cpython-310.pyc differ Binary files ./megatron/core/fusions/__pycache__/fused_cross_entropy.cpython-310.pyc and ../megatron-lm/megatron/core/fusions/__pycache__/fused_cross_entropy.cpython-310.pyc differ Binary files ./megatron/core/fusions/__pycache__/fused_layer_norm.cpython-310.pyc and ../megatron-lm/megatron/core/fusions/__pycache__/fused_layer_norm.cpython-310.pyc differ Binary files ./megatron/core/fusions/__pycache__/fused_softmax.cpython-310.pyc and ../megatron-lm/megatron/core/fusions/__pycache__/fused_softmax.cpython-310.pyc differ Binary files ./megatron/core/fusions/__pycache__/__init__.cpython-310.pyc and ../megatron-lm/megatron/core/fusions/__pycache__/__init__.cpython-310.pyc differ diff -rN ./megatron/core/inference/modelopt_support/gpt/model_specs.py ../megatron-lm/megatron/core/inference/modelopt_support/gpt/model_specs.py 3d2 < from megatron.core.extensions.transformer_engine import TEDotProductAttention, TENorm 6a6 > from megatron.core.transformer.custom_layers.transformer_engine import TEDotProductAttention, TENorm diff -rN ./megatron/core/models/bert/bert_layer_specs.py ../megatron-lm/megatron/core/models/bert/bert_layer_specs.py 13c13 < from megatron.core.extensions.transformer_engine import ( --- > from megatron.core.transformer.custom_layers.transformer_engine import ( 24c24 < import apex # pylint: disable=unused-import --- > import apex diff -rN ./megatron/core/models/bert/bert_model.py ../megatron-lm/megatron/core/models/bert/bert_model.py 3,4c3,4 < import warnings < from typing import Literal, Optional --- > from importlib.metadata import version > from typing import Dict, Literal, Optional 6a7 > from pkg_resources import packaging 11c12,13 < from megatron.core.models.bert.bert_layer_specs import bert_layer_local_spec --- > from megatron.core.dist_checkpointing.mapping import ShardedStateDict > from megatron.core.models.bert.bert_layer_specs import bert_layer_with_transformer_engine_spec 22,23d23 < from megatron.core.utils import get_te_version as _get_te_version < from megatron.core.utils import is_te_min_version 27,29c27 < 
"""Included for backwards compatibility.""" < warnings.warn("`get_te_version` will be deprecated in a future release") < return _get_te_version() --- > return packaging.version.Version(version("transformer-engine")) 32d29 < # pylint: disable=line-too-long 94c91,93 < self.attn_mask_dimensions = self._sanity_check_attention_and_get_attn_mask_dimension() --- > self.attn_mask_dimensions = self._santiy_check_attention_and_get_attn_mask_dimension( > transformer_layer_spec > ) 153,154c152,154 < # pylint: disable=line-too-long < def _sanity_check_attention_and_get_attn_mask_dimension(self) -> str: --- > def _santiy_check_attention_and_get_attn_mask_dimension( > self, transformer_layer_spec: ModuleSpec > ) -> str: 157,159c157 < Transformer engine library underwent a lot of change. So we need to change dimensions of < the attention mask depending on the TE version. We also santiy check some arguments. < --- > Transformer engine library underwent a lot of change. So we need to change dimensions of the attention mask depending on the TE version. We also santiy check some arguments. 161,168c159,160 < 2. If we use transformer engine > 1.10 we support all 3 backends with padding mask and [b,1,s,s] < 3. If we use transformer engine >= 1.7 but less than 1.10 < a ) Flash and Fused attention uses padding mask with [b,1,1,s] < b ) Unfused attention works with arbitrary mask with [b,1,s,s] < 4. If we use transformer engine < 1.7 < Flash and fused attention is not supported. Unfused attention will work with padding mask [b,1,s,s] < < Default if you dont set any NVTE_ATTN flag will it will just use the fused path for transformer engine version >= 1.7 and unfused path for other --- > 2. If we use transformer engine < 1.7 (Flash and Fused attention not supported. We use unfused path). Attn mask dimension is [b,1,s,s] > 2. If we use transformer engine >= 1.7 (Flash and fused attention supported with attn mask dimension [b,1,1,s]). Unfused path will use attn mask dimension [b,1,s,s] with attn mask type arbitrary. Default if you dont set any NVTE_ATTN flag will just use unfused path. 171c163 < transformer_layer_spec (ModuleSpec): The transformer layer spec --- > transformer_layer_spec (ModuleSpec): _description_ 174c166 < str: A string showing the format of the attn mask dimensions --- > str: _description_ 176,199c168,175 < attn_mask_dimensions = None < # For local layer spec we just use b1ss < if self.transformer_layer_spec == bert_layer_local_spec: < attn_mask_dimensions = "b1ss" < else: < attn_mask_type = self.transformer_layer_spec.submodules.self_attention.params[ < 'attn_mask_type' < ] < flash_attention_enabled = os.getenv('NVTE_FLASH_ATTN') == '1' < fused_attention_enabled = os.getenv('NVTE_FUSED_ATTN') == '1' < # For TE >= 1.10 (We always use padding mask and use b11s) < if is_te_min_version("1.10.0"): < attn_mask_dimensions = "b11s" < if attn_mask_type != AttnMaskType.padding: < warnings.warn( < f'For TE versions >= 1.10 , flash/fused/unfused support padding mask. 
Setting attention mask from {attn_mask_type} to padding' < ) < self.transformer_layer_spec.submodules.self_attention.params[ < 'attn_mask_type' < ] = AttnMaskType.padding < # For 1.7 >= TE < 1.10 flash and fused path use padding mask with b11s and unfused path uses arbitrary mask with b1ss < elif is_te_min_version("1.7.0"): < if flash_attention_enabled or fused_attention_enabled: < attn_mask_dimensions = "b11s" --- > attn_mask_dimensions = "b1ss" > if transformer_layer_spec == bert_layer_with_transformer_engine_spec: > if get_te_version() >= packaging.version.Version("1.7.0"): > if os.getenv('NVTE_FLASH_ATTN') == '0' and os.getenv('NVTE_FUSED_ATTN') == '0': > assert ( > transformer_layer_spec.submodules.self_attention.params['attn_mask_type'] > == AttnMaskType.arbitrary > ), "Set env variable NVTE_FLASH_ATTN to 1 or NVTE_FUSED_ATTN to 1 to use a more optimized attention kernal. Currently using unfused attention path. If you want to proceed with this path set AttnMaskType in module spec to be arbitrary" 201,209c177 < if attn_mask_type != AttnMaskType.arbitrary: < warnings.warn( < f'For TE versions >= 1.7 but < 1.10 , unfused path supports only arbitrary mask. Setting attention mask from {attn_mask_type} to arbitray' < ) < self.transformer_layer_spec.submodules.self_attention.params[ < 'attn_mask_type' < ] = AttnMaskType.arbitrary < attn_mask_dimensions = "b1ss" < # For TE < 1.7 we only support unfused attention with b1ss and padding mask --- > attn_mask_dimensions = "b11s" 211,217c179,181 < attn_mask_dimensions = "b1ss" < assert not flash_attention_enabled and not fused_attention_enabled, ( < "Flash and fused attention is not supported with transformer engine version " < "< 1.7. Set NVTE_FLASH_ATTN=0 and NVTE_FUSED_ATTN=0 or upgrade transformer " < "engine >= 1.7" < ) < --- > assert os.getenv('NVTE_ALLOW_NONDETERMINISTIC_ALGO') == '0' or ( > os.getenv('NVTE_FLASH_ATTN') == '0' and os.getenv('NVTE_FUSED_ATTN') == '0' > ), "Flash and fused attention is not supported with transformer engine version < 1.7. 
Set NVTE_FLASH_ATTN=0 and NVTE_FUSED_ATTN=0 or upgrade transformer engine >= 1.7 or set NVTE_ALLOW_NONDETERMINISTIC_ALGO=0" 251d214 < """Position ids for bert model""" Binary files ./megatron/core/models/common/embeddings/__pycache__/__init__.cpython-310.pyc and ../megatron-lm/megatron/core/models/common/embeddings/__pycache__/__init__.cpython-310.pyc differ Binary files ./megatron/core/models/common/embeddings/__pycache__/language_model_embedding.cpython-310.pyc and ../megatron-lm/megatron/core/models/common/embeddings/__pycache__/language_model_embedding.cpython-310.pyc differ Binary files ./megatron/core/models/common/embeddings/__pycache__/rotary_pos_embedding.cpython-310.pyc and ../megatron-lm/megatron/core/models/common/embeddings/__pycache__/rotary_pos_embedding.cpython-310.pyc differ Binary files ./megatron/core/models/common/language_module/__pycache__/__init__.cpython-310.pyc and ../megatron-lm/megatron/core/models/common/language_module/__pycache__/__init__.cpython-310.pyc differ Binary files ./megatron/core/models/common/language_module/__pycache__/language_module.cpython-310.pyc and ../megatron-lm/megatron/core/models/common/language_module/__pycache__/language_module.cpython-310.pyc differ Binary files ./megatron/core/models/common/__pycache__/__init__.cpython-310.pyc and ../megatron-lm/megatron/core/models/common/__pycache__/__init__.cpython-310.pyc differ diff -rN ./megatron/core/models/gpt/gpt_layer_specs.py ../megatron-lm/megatron/core/models/gpt/gpt_layer_specs.py 17c17 < from megatron.core.extensions.transformer_engine import ( --- > from megatron.core.transformer.custom_layers.transformer_engine import ( 19d18 < TEColumnParallelLinear, 51d49 < fp8: Optional[str] = None, 60d57 < fp8 (str, optional): Flag to decide the linear layer spec for MoE. Defaults to None. 66c63 < use_te=True, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm, fp8=fp8 --- > use_te=True, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm 142d138 < fp8: Optional[str] = None, 159,161d154 < elif use_te and fp8: < linear_fc1 = TEColumnParallelLinear < linear_fc2 = TERowParallelLinear Binary files ./megatron/core/models/gpt/__pycache__/gpt_layer_specs.cpython-310.pyc and ../megatron-lm/megatron/core/models/gpt/__pycache__/gpt_layer_specs.cpython-310.pyc differ Binary files ./megatron/core/models/gpt/__pycache__/gpt_model.cpython-310.pyc and ../megatron-lm/megatron/core/models/gpt/__pycache__/gpt_model.cpython-310.pyc differ Binary files ./megatron/core/models/gpt/__pycache__/__init__.cpython-310.pyc and ../megatron-lm/megatron/core/models/gpt/__pycache__/__init__.cpython-310.pyc differ diff -rN ./megatron/core/models/mamba/mamba_layer_specs.py ../megatron-lm/megatron/core/models/mamba/mamba_layer_specs.py 3,7d2 < from megatron.core.extensions.transformer_engine import ( < TEDotProductAttention, < TELayerNormColumnParallelLinear, < TERowParallelLinear, < ) 12a8,12 > from megatron.core.transformer.custom_layers.transformer_engine import ( > TEDotProductAttention, > TELayerNormColumnParallelLinear, > TERowParallelLinear, > ) diff -rN ./megatron/core/models/mamba/mamba_model.py ../megatron-lm/megatron/core/models/mamba/mamba_model.py 24,33c24,27 < max_sequence_length (int): maximum size of sequence. < This is used for positional embedding < pre_process (bool, optional): Include embedding layer < (used with pipeline parallelism). Defaults to True. < mamba_ssm_ngroups (int, optional): Specifies the number of groups to use. 
< The default value is 8, as in the NVIDIA Mamba2 (pure and hybrid) 8b. < However, in the original Mamba2 paper, the checkpoints use a setting of 1. < Defaults to 8. < hybrid_attention_ratio (float, optional): The target ratio of attention < layers to total layers --- > max_sequence_length (int): maximum size of sequence. This is used for positional embedding > pre_process (bool, optional): Include embedding layer (used with pipeline parallelism). Defaults to True. > mamba_ssm_ngroups (int, optional): Specifies the number of groups to use. The default value is 8, as in the NVIDIA Mamba2 (pure and hybrid) 8b. However, in the original Mamba2 paper, the checkpoints use a setting of 1. Defaults to 8. > hybrid_attention_ratio (float, optional): The target ratio of attention layers to total layers 36,37c30 < post_process (bool, optional): Include an output layer (used with pipeline parallelism). < Defaults to True. --- > post_process (bool, optional): Include an output layer (used with pipeline parallelism). Defaults to True. 39,51c32,37 < parallel_output (bool, optional): Do not gather the outputs, keep them split across tensor < parallel ranks. Defaults to True. < share_embeddings_and_output_weights (bool, optional): When True, input embeddings and < output logit weights are shared. Defaults to False. < position_embedding_type (Literal[learned_absolute,rope,none], optional): Position < embedding type. Defaults to 'none'. < rotary_percent (float, optional): Percent of rotary dimension to use for rotary position < embeddings. Ignored unless position_embedding_type is 'rope'. Defaults to 1.0. < rotary_base (int, optional): Base period for rotary position embeddings. Ignored unless < position_embedding_type is 'rope'. Defaults to 10000. < seq_len_interpolation_factor (Optional[float], optional): scale of linearly < interpolating RoPE for longer sequences. The value must be a float larger than 1.0. < Defaults to None. --- > parallel_output (bool, optional): Do not gather the outputs, keep them split across tensor parallel ranks. Defaults to True. > share_embeddings_and_output_weights (bool, optional): When True, input embeddings and output logit weights are shared. Defaults to False. > position_embedding_type (Literal[learned_absolute,rope,none], optional): Position embedding type. Defaults to 'none'. > rotary_percent (float, optional): Percent of rotary dimension to use for rotary position embeddings. Ignored unless position_embedding_type is 'rope'. Defaults to 1.0. > rotary_base (int, optional): Base period for rotary position embeddings. Ignored unless position_embedding_type is 'rope'. Defaults to 10000. > seq_len_interpolation_factor (Optional[float], optional): scale of linearly interpolating RoPE for longer sequences. The value must be a float larger than 1.0. Defaults to None. 
diff -rN ./megatron/core/models/multimodal/llava_spec.py ../megatron-lm/megatron/core/models/multimodal/llava_spec.py 2c2,12 < from megatron.core.extensions.transformer_engine import ( --- > from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add > from megatron.core.models.gpt.gpt_layer_specs import _get_mlp_module_spec > from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear > from megatron.core.transformer.attention import ( > CrossAttention, > CrossAttentionSubmodules, > SelfAttention, > SelfAttentionSubmodules, > ) > from megatron.core.transformer.custom_layers.transformer_engine import ( > TEColumnParallelLinear, 8,11d17 < from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add < from megatron.core.models.gpt.gpt_layer_specs import _get_mlp_module_spec < from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear < from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules 14a21 > from megatron.core.transformer.mlp import MLP, MLPSubmodules 15a23,27 > from megatron.core.transformer.transformer_block import ( > TransformerBlockSubmodules, > get_num_layers_to_build, > ) > from megatron.core.transformer.transformer_config import TransformerConfig 19c31 < import apex # pylint: disable=unused-import --- > import apex Binary files ./megatron/core/models/__pycache__/__init__.cpython-310.pyc and ../megatron-lm/megatron/core/models/__pycache__/__init__.cpython-310.pyc differ diff -rN ./megatron/core/models/retro/config.py ../megatron-lm/megatron/core/models/retro/config.py 5a6 > import types 6a8,10 > from importlib.metadata import version > > from pkg_resources import packaging 9d12 < from megatron.core.utils import is_te_min_version 60d62 < # pylint: disable=line-too-long 67c69,70 < if is_te_min_version("1.3"): --- > te_version = packaging.version.Version(version("transformer-engine")) > if te_version >= packaging.version.Version("1.3"): diff -rN ./megatron/core/models/retro/decoder_spec.py ../megatron-lm/megatron/core/models/retro/decoder_spec.py 28c28 < import apex # pylint: disable=unused-import --- > import apex 43c43 < from megatron.core.extensions.transformer_engine import ( --- > from megatron.core.transformer.custom_layers.transformer_engine import ( 67,68c67 < encoder_block_spec (ModuleSpec): Retro encoder block spec, to be provided for < the first Retro decoder layer. --- > encoder_block_spec (ModuleSpec): Retro encoder block spec, to be provided for the first Retro decoder layer. 101,102c100 < encoder_block_spec (ModuleSpec): Retro encoder block spec, to be provided < for the first Retro decoder layer. --- > encoder_block_spec (ModuleSpec): Retro encoder block spec, to be provided for the first Retro decoder layer. 129,134c127,129 < - The retro decoder block consists of interleaved GPT layers < and customized Retro decoder layers. < - The Retro decoder layers are spaced three layers apart, < and start on layer 6 or 9 (depending on the total number of layers). < - The first decoder layer instantiates an encoder block, < and it therefore passes in an encoder_block_spec. --- > - The retro decoder block consists of interleaved GPT layers and customized Retro decoder layers. > - The Retro decoder layers are spaced three layers apart, and start on layer 6 or 9 (depending on the total number of layers). > - The first decoder layer instantiates an encoder block, and it therefore passes in an encoder_block_spec. 
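The Retro layer-placement note in the decoder_spec.py hunk above (decoder layers spaced three apart, starting on layer 6 or 9 depending on depth) can be made concrete with a small standalone sketch; the 15-layer threshold used below is an assumption for illustration, not taken from this diff.

# Hedged sketch of the Retro decoder layer placement described above.
def retro_decoder_layer_numbers(num_layers: int) -> list:
    # Start at layer 6 for shallower models, 9 for deeper ones (threshold assumed),
    # then place a Retro decoder layer every three layers.
    start = 6 if num_layers <= 15 else 9
    return list(range(start, num_layers + 1, 3))

print(retro_decoder_layer_numbers(12))  # [6, 9, 12]
print(retro_decoder_layer_numbers(24))  # [9, 12, 15, 18, 21, 24]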
diff -rN ./megatron/core/models/retro/encoder_spec.py ../megatron-lm/megatron/core/models/retro/encoder_spec.py 24c24 < from megatron.core.extensions.transformer_engine import ( --- > from megatron.core.transformer.custom_layers.transformer_engine import ( 36c36 < import apex # pylint: disable=unused-import --- > import apex Binary files ./megatron/core/models/retro/__pycache__/base_attention.cpython-310.pyc and ../megatron-lm/megatron/core/models/retro/__pycache__/base_attention.cpython-310.pyc differ Binary files ./megatron/core/models/retro/__pycache__/config.cpython-310.pyc and ../megatron-lm/megatron/core/models/retro/__pycache__/config.cpython-310.pyc differ Binary files ./megatron/core/models/retro/__pycache__/decoder_attention.cpython-310.pyc and ../megatron-lm/megatron/core/models/retro/__pycache__/decoder_attention.cpython-310.pyc differ Binary files ./megatron/core/models/retro/__pycache__/decoder_spec.cpython-310.pyc and ../megatron-lm/megatron/core/models/retro/__pycache__/decoder_spec.cpython-310.pyc differ Binary files ./megatron/core/models/retro/__pycache__/encoder_attention.cpython-310.pyc and ../megatron-lm/megatron/core/models/retro/__pycache__/encoder_attention.cpython-310.pyc differ Binary files ./megatron/core/models/retro/__pycache__/encoder_spec.cpython-310.pyc and ../megatron-lm/megatron/core/models/retro/__pycache__/encoder_spec.cpython-310.pyc differ Binary files ./megatron/core/models/retro/__pycache__/__init__.cpython-310.pyc and ../megatron-lm/megatron/core/models/retro/__pycache__/__init__.cpython-310.pyc differ Binary files ./megatron/core/models/retro/__pycache__/model.cpython-310.pyc and ../megatron-lm/megatron/core/models/retro/__pycache__/model.cpython-310.pyc differ Binary files ./megatron/core/models/retro/__pycache__/utils.cpython-310.pyc and ../megatron-lm/megatron/core/models/retro/__pycache__/utils.cpython-310.pyc differ diff -rN ./megatron/core/models/T5/t5_model.py ../megatron-lm/megatron/core/models/T5/t5_model.py 13d12 < from megatron.core.transformer.enums import ModelType 160,161d158 < < self.model_type = ModelType.encoder_and_decoder diff -rN ./megatron/core/models/T5/t5_spec.py ../megatron-lm/megatron/core/models/T5/t5_spec.py 15c15,19 < from megatron.core.transformer.transformer_block import TransformerBlockSubmodules --- > from megatron.core.transformer.transformer_block import ( > TransformerBlockSubmodules, > get_num_layers_to_build, > ) > from megatron.core.transformer.transformer_config import TransformerConfig 19c23 < from megatron.core.extensions.transformer_engine import ( --- > from megatron.core.transformer.custom_layers.transformer_engine import ( 32c36 < import apex # pylint: disable=unused-import --- > import apex 55c59 < params={"attn_mask_type": AttnMaskType.arbitrary}, --- > params={"attn_mask_type": AttnMaskType.padding}, 97d100 < params={"attn_mask_type": AttnMaskType.arbitrary}, 126c129 < params={"attn_mask_type": AttnMaskType.arbitrary}, --- > params={"attn_mask_type": AttnMaskType.padding}, 174d176 < params={"attn_mask_type": AttnMaskType.arbitrary}, diff -rN ./megatron/core/models/vision/clip_vit_model.py ../megatron-lm/megatron/core/models/vision/clip_vit_model.py 8d7 < from megatron.core.extensions.transformer_engine import TENorm 9a9 > from megatron.core.transformer.custom_layers.transformer_engine import TENorm 92,95c92,93 < # TODO: Follow-up changes will make pre and post_process configurable. < # They are needed for supporting pipeline parallelism. 
< # Note: a final layer norm and/or linear layer present in some implementations < # are omitted here. They can be added separately where needed. --- > # TODO: Follow-up changes will make pre and post_process configurable. They are needed for supporting pipeline parallelism. > # Note: a final layer norm and/or linear layer present in some implementations are omitted here. They can be added separately where needed. 140,141c138,140 < x = x.contiguous() < # contiguous() call required as `permute` can sparsify the tensor and this breaks pipelining --- > x = ( > x.contiguous() > ) # contiguous() call required as `permute` can sparsify the tensor and this breaks pipelining diff -rN ./megatron/core/models/vision/vit_layer_specs.py ../megatron-lm/megatron/core/models/vision/vit_layer_specs.py 3c3,6 < from megatron.core.extensions.transformer_engine import ( --- > from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add > from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear > from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules > from megatron.core.transformer.custom_layers.transformer_engine import ( 8,10d10 < from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add < from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear < from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules 19c19 < import apex # pylint: disable=unused-import --- > import apex diff -rN ./megatron/core/optimizer/distrib_optimizer.py ../megatron-lm/megatron/core/optimizer/distrib_optimizer.py 7d6 < import warnings 34a34 > from ..dist_checkpointing.optimizer import get_param_id_to_sharded_param_map 36,38c36 < from ..distributed.param_and_grad_buffer import _ParamAndGradBuffer, partition_buckets < from ..transformer.module import MegatronModule < from ..utils import is_float8tensor --- > from ..distributed import ParamAndGradBuffer, shard_buffer 40,44c38 < from .optimizer import ( < MixedPrecisionOptimizer, < _multi_tensor_copy_this_to_that, < _zero_grad_group_helper, < ) --- > from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper 47,54d40 < try: < # This will be used when "--fp8-param-gather" is enabled. < # When BF16/FP16 parameters don't exist, we need to cast the FP32 main parameters to < # FP8 directly in the optimizer. < from transformer_engine.pytorch.cpp_extensions import cast_to_fp8 < except: < pass < 159c145 < def _build_model_gbuf_range(cls, param_and_grad_buffer: _ParamAndGradBuffer, bucket_index: int): --- > def _build_model_gbuf_range(cls, param_and_grad_buffer: ParamAndGradBuffer, bucket_index: int): 206c192 < def _build_gbuf_range_map(cls, param_and_grad_buffer: _ParamAndGradBuffer): --- > def _build_gbuf_range_map(cls, param_and_grad_buffer: ParamAndGradBuffer): 216c202 < param_and_grad_buffer (_ParamAndGradBuffer): buffer to build mapping for. --- > param_and_grad_buffer (ParamAndGradBuffer): buffer to build mapping for. 237,240c223,225 < assert param not in param_gbuf_map, ( < "Param should not be in param_gbuf_map; each param only belongs " < "to a single bucket." 
< ) --- > assert ( > param not in param_gbuf_map > ), "Param should not be in param_gbuf_map; each param only belongs to a single bucket" 351,369c336 < < # If we use FP8 params to initialize FP32 main params (compared to using the < # bf16/fp16 params to initialize the main params), there will be a loss of < # precision at the beginning of training (this problem will not occur if the < # training is long enough or if the main params are loaded from a checkpoint). < if is_float8tensor(model_param) and hasattr( < model_param, 'get_high_precision_init_val' < ): < shard_main_param = ( < model_param.get_high_precision_init_val() < .view(-1)[param_range.start : param_range.end] < .clone() < .to(shard_model_param.device) < .float() < ) < model_param.clear_high_precision_init_val() < else: < shard_main_param = shard_model_param.clone().float() < --- > shard_main_param = shard_model_param.clone().float() 425,426c392 < model_chunks: List[MegatronModule], < per_model_buffers: Dict[int, List[_ParamAndGradBuffer]], --- > per_model_buffers: Dict[int, List[ParamAndGradBuffer]], 429a396 > overlap_param_gather_with_optimizer_step: bool = False, 448d414 < model_chunks (List[MegatronModule]): list of model chunks. 459a426,427 > overlap_param_gather_with_optimizer_step (bool, optional): if true, overlap parameter > all-gather with optimizer step. Defaults to False. 470,473d437 < self.model_chunks = model_chunks < self.ddp_config = self.model_chunks[0].ddp_config < for model_chunk in self.model_chunks: < assert self.ddp_config == model_chunk.ddp_config 486d449 < 493,497d455 < < self.per_model_bucket_groups = {} < for model_idx, buffers in self.per_model_buffers.items(): < self.per_model_bucket_groups[model_idx] = partition_buckets(buffers) < 535a494,528 > # Now construct data structures to manage all-gather handles. > self.all_gather_handles = [] > self.all_gather_handle_index_to_bucket_index_map = [] > self.model_index_to_all_gather_handle_index_map = {} > self.all_gather_handle_indices = [] > self.param_to_all_gather_handle_index_map = {} > > self.pbuf_view_items = self._get_model_param_buffer_dp_views() > for gbuf_index, dtype, bucket_index, _, _ in self.pbuf_view_items: > self.all_gather_handle_index_to_bucket_index_map.append( > (gbuf_index, dtype, bucket_index) > ) > all_gather_handle_index = len(self.all_gather_handle_index_to_bucket_index_map) - 1 > self.all_gather_handles.append(None) > > # Store all all_gather_handle_indices. > model_idx = self.gbuf_idx_to_model_idx_map[gbuf_index] > if model_idx not in self.model_index_to_all_gather_handle_index_map: > self.model_index_to_all_gather_handle_index_map[model_idx] = [] > self.model_index_to_all_gather_handle_index_map[model_idx].append( > all_gather_handle_index > ) > > for param in self.buffers[gbuf_index].buckets[bucket_index].params_list: > self.param_to_all_gather_handle_index_map[param] = all_gather_handle_index > self.num_all_gather_handles = len(self.all_gather_handle_index_to_bucket_index_map) > > self.overlap_param_gather = self.config.overlap_param_gather > self.overlap_param_gather_with_optimizer_step = overlap_param_gather_with_optimizer_step > self.remove_pre_hook_handle = None > if self.overlap_param_gather: > self.enable_pre_hook() > > self.update_successful = False > 546,548c539,541 < warnings.warn( < "`DistributedOptimizer.enable_pre_hook` will be deprecated in a future release. " < "Use `DistributedDataParallel.enable_forward_pre_hook` directly." 
--- > assert self.remove_pre_hook_handle is None > self.remove_pre_hook_handle = torch.nn.modules.module.register_module_forward_pre_hook( > self._make_forward_pre_hook() 550,551d542 < for model_chunk in self.model_chunks: < model_chunk.enable_forward_pre_hook() 557,562c548,554 < warnings.warn( < "`DistributedOptimizer.disable_pre_hook` will be deprecated in a future release. " < "Use `DistributedDataParallel.disable_forward_pre_hook` directly." < ) < for model_chunk in self.model_chunks: < model_chunk.disable_forward_pre_hook() --- > assert self.remove_pre_hook_handle is not None > self.remove_pre_hook_handle.remove() > self.remove_pre_hook_handle = None > > # Make sure all-gathers are completed as needed. > self._reset_metadata_and_sync_gather_all_model_params(force_sync=True) > self.update_successful = False 876,878c868,870 < # Copy this bucket's collected all-gather tensors into the right place < # in the tensor for the buffer. The tensor for the buffer gets rid of < # the padding between buckets. --- > # Copy this bucket's collected all-gather tensors into the right place in the > # tensor for the buffer. The tensor for the buffer gets rid of the padding > # between buckets. 1003,1006d994 < key = ( < f'optimizer.distributed.dp_group_idx_{self.data_parallel_group_idx}' < f'.{per_bucket_key}' < ) 1008c996,1000 < key, state[per_bucket_key], (1,), (0,), replica_id=data_parallel_rank --- > f'optimizer.distributed.dp_group_idx_{self.data_parallel_group_idx}.{per_bucket_key}', > state[per_bucket_key], > (1,), > (0,), > replica_id=data_parallel_rank, 1019,1022c1011 < sharded_bucket_key = ( < f'optimizer.distributed.dp_group_idx_{self.data_parallel_group_idx}' < f'.gbuf_idx_{gbuf_idx}.dtype_{dtype}.bucket_idx_{bucket_idx}' < ) --- > sharded_bucket_key = f'optimizer.distributed.dp_group_idx_{self.data_parallel_group_idx}.gbuf_idx_{gbuf_idx}.dtype_{dtype}.bucket_idx_{bucket_idx}' 1123,1126c1112 < < # Not stored in the checkpoint, used only to identify params in < # `sharded_param_state_fs_model_space`. < param_idx = 0 --- > param_idx = 0 # this is not stored in the checkpoint, used only to identify params in `sharded_param_state_fs_model_space` 1138,1139c1124 < # Match optimizer parameter with model ShardedTensor (or < # ShardedTensorFactory). --- > # Match optimizer parameter with model ShardedTensor (or ShardedTensorFactory) 1147c1132 < # Set DP corresponding replica_id coordinate to 0. --- > # Set DP corresponding replica_id coordinate to 0 1153,1154c1138 < # Instantiate ShardedTensor (or ShardedTensorFactory) for optimizer < # params. --- > # Instantiate ShardedTensor (or ShardedTensorFactory) for optimizer params 1259,1260c1243 < """Load parameter state (i.e., parameter & optimizer tensors) from DP 0 rank, < using the legacy checkpoint format as described below. --- > """Load parameter state (i.e., parameter & optimizer tensors) from DP 0 rank, using the legacy checkpoint format as described below. 1329,1330c1312 < # Pad world_tensor to gbuf_world_numel. Don't pad at the front, < # pad at the back. --- > # Pad world_tensor to gbuf_world_numel. Don't pad at the front, pad at the back. 1396,1399d1377 < if data_parallel_rank == 0: < # Do nothing if "--fp8-param-gather" is not used. < self.split_state_dict_if_needed(state_dict) < 1439,1440c1417 < # Pad world_tensor to gbuf_world_numel. Don't pad at the front, < # pad at the back. --- > # Pad world_tensor to gbuf_world_numel. Don't pad at the front, pad at the back. 
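Both sides of the hunks above describe padding world_tensor up to gbuf_world_numel at the back rather than the front. A minimal standalone sketch of that operation follows; the helper name is hypothetical, only the back-padding behaviour comes from the comments above.

import torch

def pad_world_tensor(world_tensor: torch.Tensor, gbuf_world_numel: int) -> torch.Tensor:
    # Hypothetical helper: pad at the back, never at the front, so the offsets
    # of elements already in the buffer remain valid.
    flat = world_tensor.view(-1)
    assert flat.numel() <= gbuf_world_numel
    return torch.nn.functional.pad(flat, (0, gbuf_world_numel - flat.numel()))

print(pad_world_tensor(torch.arange(10, dtype=torch.float32), 16).shape)  # torch.Size([16])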
1481,1613d1457 < def split_state_dict_if_needed(self, state_dict): < """ < When "--fp8-param-gather" is disabled, weights and biases are stored in the same < `ParamAndGradBuffer`. So, when saving a checkpoint, the optimizer's main parameters are < saved in a single continuous tensor (this also applies to "exp_avg" and "exp_avg_sq"). < < However, when "--fp8-param-gather" is enabled, weights(in fp8 dtype) and biases(in bf16/fp16 < dtype) are stored in separate `ParamAndGradBuffer`. Therefore, when we enabled < "--fp8-param-gather", and want to load a checkpoint saved without "--fp8-param-gather", we < need to split the weights(fp8) and biases(bf16/fp16) in the static_dict into two separate < tensors. < """ < # Skip if there is no fp8 buffers. < fp8_gbuf_indices = [] < for gbuf_idx, gbuf_range_maps in enumerate(self.gbuf_ranges): < for dtype, _ in gbuf_range_maps.items(): < if is_float8tensor(self.buffers[gbuf_idx].params[0]): < fp8_gbuf_indices.append(gbuf_idx) < if len(fp8_gbuf_indices) == 0: < return < < dtype_to_gbuf_idx = {} < for key in state_dict.keys(): < if key != 'buckets_coalesced': < for dtype in state_dict[key].keys(): < assert dtype not in dtype_to_gbuf_idx < if dtype[0] == torch.uint8: < # If the `state_dict`` already contains a torch.uint8 buffer, we assumed < # that the fp8 weights and fp16/bf16 biases in the checkpoint are already < # separated. In this case, no action is required, so we can return directly. < return < dtype_to_gbuf_idx[dtype] = key < < # 1. Replace the gbuf_idx in the checkpoint with the new gbuf_idx. < # 2. Copy the non-tensor data (i.e., the "buckets_coalesced") to `new_state_dict`. < new_state_dict = {'buckets_coalesced': state_dict['buckets_coalesced']} < for gbuf_idx, gbuf_range_maps in enumerate(self.gbuf_ranges): < for dtype, _ in gbuf_range_maps.items(): < if not is_float8tensor(self.buffers[gbuf_idx].params[0]): < new_state_dict[gbuf_idx] = state_dict[dtype_to_gbuf_idx[dtype]] < < for fp8_gbuf_idx in fp8_gbuf_indices: < # Note that `self.buffers[fp8_gbuf_idx].params[0].dtype` is the dummy dtype of < # `Float8Tensor`, not torch.uint8. < non_fp8_param_and_grad_dtype = ( < self.buffers[fp8_gbuf_idx].params[0].dtype, < self.buffers[fp8_gbuf_idx].grad_dtype, < ) < < # Iterate through all buffers to find the one that needs to be split. < non_fp8_gbuf_idx = None < for gbuf_idx, gbuf_range_maps in enumerate(self.gbuf_ranges): < for dtype, _ in gbuf_range_maps.items(): < if dtype == non_fp8_param_and_grad_dtype: < non_fp8_gbuf_idx = gbuf_idx < assert non_fp8_gbuf_idx is not None < < # We need the fp8_flags to determine the order of weight (fp8) and bias (fp16/bf16) in < # the buffer. 
< index_to_fp8_map = {} < for index in self.buffers[fp8_gbuf_idx].param_indices: < assert index not in index_to_fp8_map < index_to_fp8_map[index] = True < for index in self.buffers[non_fp8_gbuf_idx].param_indices: < assert index not in index_to_fp8_map < index_to_fp8_map[index] = False < param_indices = ( < self.buffers[fp8_gbuf_idx].param_indices < + self.buffers[non_fp8_gbuf_idx].param_indices < ) < assert min(param_indices) == 0 < assert max(param_indices) == len(param_indices) - 1 < fp8_flags = [] < for i in range(len(param_indices)): < fp8_flags.append(index_to_fp8_map[i]) < < fp8_buffer = self.buffers[fp8_gbuf_idx] < non_fp8_buffer = self.buffers[non_fp8_gbuf_idx] < < fp8_idx = len(fp8_buffer.params) - 1 < non_fp8_idx = len(non_fp8_buffer.params) - 1 < offsets, fp8_offsets, non_fp8_offsets = [0], [0], [0] < < # Because the parameters in `ParamAndGradBuffer` are traversed in reverse order, the < # flag here also needs to be traversed in reverse order. < for fp8_flag in fp8_flags[::-1]: < if fp8_flag: < numel = fp8_buffer.params[fp8_idx].nelement() < fp8_idx -= 1 < offsets.append(offsets[-1] + numel) < fp8_offsets.append(fp8_offsets[-1] + numel) < else: < numel = non_fp8_buffer.params[non_fp8_idx].nelement() < non_fp8_idx -= 1 < offsets.append(offsets[-1] + numel) < non_fp8_offsets.append(non_fp8_offsets[-1] + numel) < < # Split the target buffer into two separate buffers. < fp8_state_dict, non_fp8_state_dict = {}, {} < for key in ['param', 'exp_avg', 'exp_avg_sq']: < tensor = state_dict[non_fp8_gbuf_idx][non_fp8_param_and_grad_dtype][key] < fp8_tensor = torch.empty([fp8_offsets[-1]], dtype=tensor.dtype) < non_fp8_tensor = torch.empty([non_fp8_offsets[-1]], dtype=tensor.dtype) < < fp8_idx, non_fp8_idx = 0, 0 < for i in range(len(offsets) - 1): < if fp8_flags[-(i + 1)]: < fp8_tensor[fp8_offsets[fp8_idx] : fp8_offsets[fp8_idx + 1]].copy_( < tensor[offsets[i] : offsets[i + 1]] < ) < fp8_idx += 1 < else: < non_fp8_tensor[ < non_fp8_offsets[non_fp8_idx] : non_fp8_offsets[non_fp8_idx + 1] < ].copy_(tensor[offsets[i] : offsets[i + 1]]) < non_fp8_idx += 1 < < fp8_state_dict[key] = fp8_tensor < non_fp8_state_dict[key] = non_fp8_tensor < < fp8_state_dict['numel_unpadded'] = fp8_offsets[-1] < non_fp8_state_dict['numel_unpadded'] = non_fp8_offsets[-1] < < # Add the two separate buffers into `new_state_dict`. < new_state_dict[fp8_gbuf_idx] = {} < new_state_dict[fp8_gbuf_idx][(torch.uint8, fp8_buffer.grad_dtype)] = fp8_state_dict < new_state_dict[non_fp8_gbuf_idx][non_fp8_param_and_grad_dtype] = non_fp8_state_dict < < # Inplace update state_dict < state_dict.clear() < for key, value in new_state_dict.items(): < state_dict[key] = value < 1647a1492,1664 > # If overlapping param all-gather with forward compute, launch all-gather > # for first accessed bucket here before forward compute is initiated. > # The all-gather for the next bucket will be launched in the forward > # pre-hook when this all-gather finishes (to ensure that the communication > # kernels don't head-of-line block the compute kernels since we run with > # CUDA_DEVICE_MAX_CONNECTIONS=1 to support sequence parallelism). > # If aligning param all-gather across pipeline stages, all-gather is dispatched > # by start_param_sync calls in core/pipeline_parallelism/schedules.py. > # If overlapping param all-gather with optimizer step, then all-gather has > # already been dispatched in optimizer step.
> skip_dispatch = ( > self.config.align_param_gather or self.overlap_param_gather_with_optimizer_step > ) > if self.overlap_param_gather and not skip_dispatch: > self._dispatch_gather_model_params(all_gather_handle_index=0) > > def _get_model_param_buffer_dp_views(self): > """ > Get shard views of each of the param buffers. > > In this nested list, the top level is grouped by the virtual model > index and the buffer's data type. The sub-level is a list of > shards of that buffer, where each shard in the list represents > a contiguous view of the buffer, that is owned by a data-parallel > rank. The shard boundary does not respect parameter boundaries, and > so the elements of some parameters are split across data parallel > ranks. > > Additionally, return references to the entire buffers, for use > in _all_gather_base. > """ > > # Buffer views. > # Add in reverse order in each model chunk since buckets start from the end of the model but we want > # all-gathers to run first for the start of the model (same order as forward pass). > # We keep the view_items in model chunk order since we want to still first run all_gather and > # all_gather_handle.wait() for the first model chunk. > # In all cases, we want all_gather and all_gather_handle.wait() to be called in the same order, > # and all_gather_handle.wait() needs to be called just before the corresponding forward pass. > view_items = [] > for gbuf_index, buffer in enumerate(self.buffers): > view_items_per_model_chunk = [] > dtype = self.buffers[gbuf_index].param_dtype > for bucket_index, bucket in enumerate(buffer.buckets): > data_parallel_world_size = torch.distributed.get_world_size( > self.data_parallel_group > ) > buf_views = shard_buffer(bucket.param_data, data_parallel_world_size) > view_items_per_model_chunk.insert( > 0, (gbuf_index, dtype, bucket_index, bucket.param_data, buf_views) > ) > view_items.extend(view_items_per_model_chunk) > > return view_items > > def _dispatch_gather_model_params(self, all_gather_handle_index: int, force_sync: bool = False): > """ > All-gather updated model params. > > When using the distributed optimizer, the params are already laid out in a contiguous > buffer (see mcore/distributed/param_and_grad_buffer.py for details), and so the > all-gather will put the results in the right region of memory. > """ > async_op = self.overlap_param_gather and not force_sync > if self.update_successful: > data_parallel_group = self.data_parallel_group > data_parallel_rank = torch.distributed.get_rank(data_parallel_group) > > # All-gather updated main params. > # All param_buf views are guaranteed to have the same number of elements > # across all data-parallel ranks, due to padding done in > # param_and_grad_buffer.py). Thus, all sub-views will have consistent > # start / end indexes across data-parallel ranks. 
> (gbuf_index, dtype, bucket_index, pbuf, pbuf_views) = self.pbuf_view_items[ > all_gather_handle_index > ] > assert all_gather_handle_index < len(self.all_gather_handles) > all_gather_handle = torch.distributed._all_gather_base( > pbuf, pbuf_views[data_parallel_rank], group=data_parallel_group, async_op=async_op > ) > self.all_gather_handles[all_gather_handle_index] = all_gather_handle > assert self.all_gather_handle_index_to_bucket_index_map[all_gather_handle_index] == ( > gbuf_index, > dtype, > bucket_index, > ) > > def _make_forward_pre_hook(self): > """ > Create a forward pre-hook to wait on all-gather handles when necessary (i.e., > when a module uses a parameter in a bucket with a still incomplete all-gather) > and then copy the results from the param_buffer into model_params. > """ > > def hook(module, *unused): > assert ( > self.overlap_param_gather > ), "Should use pre-hook only when overlap_param_gather is True" > > # Make sure all parameters in this module have been all-gathered as necessary. > for param in module.parameters(recurse=False): > # Skip parameters that don't require grad. > if not param.requires_grad: > continue > > # Some params might be handled in another DistributedOptimizer instance; for > # example, we use separate DistributedOptimizer instances for expert and > # non-expert params. > if param in self.param_to_all_gather_handle_index_map: > all_gather_handle_index = self.param_to_all_gather_handle_index_map[param] > # If aligning param all-gather across pipeline stages, all-gather is dispatched > # by start_param_sync calls in core/pipeline_parallelism/schedules.py. > # If overlapping param all-gather with optimizer step, then all-gather has > # already been dispatched in optimizer step. > skip_dispatch = ( > self.config.align_param_gather > or self.overlap_param_gather_with_optimizer_step > ) > self._finish_param_sync_helper( > all_gather_handle_index, skip_dispatch=skip_dispatch > ) > > return hook > > def start_param_sync(self, model_index: int, *unused, force_dispatch: bool = False): > """ > Starts all necessary param syncs for the model_index'th model chunk. > > Args: > model_index (int): index of model chunk to synchronize params. > force_dispatch (bool, optional): force dispatch regardless of other settings. > """ > if model_index not in self.model_index_to_all_gather_handle_index_map: > return > > if self.overlap_param_gather_with_optimizer_step and not force_dispatch: > return > > # If overlapping param AG with optimizer step, AG has already been dispatched. > if self.update_successful: > all_gather_handle_indices = self.model_index_to_all_gather_handle_index_map[model_index] > with torch.distributed._coalescing_manager( > group=self.data_parallel_group, async_ops=self.overlap_param_gather > ) as cm: > for all_gather_handle_index in all_gather_handle_indices: > self._dispatch_gather_model_params(all_gather_handle_index) > if self.overlap_param_gather: > for all_gather_handle_index in all_gather_handle_indices: > self.all_gather_handles[all_gather_handle_index] = cm > > def _finish_param_sync_helper(self, all_gather_handle_index: int, skip_dispatch: bool = False): > """ > Waits on all_gather_handle if necessary, then dispatches the next all-gather > as necessary. > """ > > # First check if there is an outstanding all-gather handle for this param. > # If so, wait on the handle to ensure the communication is finished. 
> assert all_gather_handle_index < len(self.all_gather_handles) > all_gather_handle = self.all_gather_handles[all_gather_handle_index] > if all_gather_handle is not None: > all_gather_handle.wait() > self.all_gather_handles[all_gather_handle_index] = None > > # Launch the all-gather for the next bucket now. > # We can't pre-launch all-gathers for all buckets at once since we don't > # want to head-of-line block the compute kernels with communication kernels > # (since we run with CUDA_DEVICE_MAX_CONNECTIONS=1 to support sequence > # parallelism). > next_all_gather_handle_index = all_gather_handle_index + 1 > if next_all_gather_handle_index < self.num_all_gather_handles and not skip_dispatch: > self._dispatch_gather_model_params(next_all_gather_handle_index) > 1723,1742c1740 < if is_float8tensor(model_param): < # 1. When "--fp8-param-gather" is disabled, the main param is first cast to < # BF16/FP16, and then cast to FP8, so the amax_history is calculated < # using BF16/FP16 param. < # 2. When "--fp8-param-gather" is enabled, we can cast the FP32 main param < # to FP8 directly, which results in slightly different results with < # higher speed. In theory, this does not affect convergence. < # TODO: The following code maintains the logic of the point-1 above. It can < # be deleted if it is not necessary. < shard_main_param = shard_main_param.to(model_param.dtype) < < cast_to_fp8( < shard_main_param.view(1, -1), < model_param._fp8_meta['scaling_fwd'], < model_param._fp8_meta_index, < model_param._fp8_dtype, < out=shard_model_param.view(1, -1), < ) < else: < shard_model_param.data.copy_(shard_main_param) --- > shard_model_param.data.copy_(shard_main_param) 1773c1771 < def _update_fp8_scale_inv_and_amax(self): --- > def _reset_metadata_and_sync_gather_all_model_params(self, force_sync: bool): 1775,1776c1773 < If detect FP8 parameters, update their `_scale_inv` and do reduce-max for their < `amax_history`. --- > Reset metadata needed to track results of all-gathers. 1778,1813c1775,1782 < amaxes = [] < scales = [] < scale_invs = [] < # Iterate over all parameters inside this optimizer to find FP8 parameters. < for buffer in self.buffers: < for bucket in buffer.buckets: < for param in bucket.params_list: < if is_float8tensor(param): < fp8_meta = param._fp8_meta['scaling_fwd'] < fp8_meta_index = param._fp8_meta_index < amaxes.append(fp8_meta.amax_history[0][fp8_meta_index].view(1)) < scales.append(fp8_meta.scale[fp8_meta_index].view(1)) < scale_invs.append(param._scale_inv.view(1)) < # Reset transpose cache < param._reset_caches() < < # If there is no FP8 parameters, skip all operations. < if len(scales) > 0: < dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda') < < # Update scaling factors. < packed_scales = torch.empty(len(scales), dtype=torch.float32, device=scales[0].device) < packed_scale_views = [packed_scales[i].view(1) for i in range(len(scales))] < _multi_tensor_copy_this_to_that(scales, packed_scale_views, dummy_overflow_buf) < torch.reciprocal(packed_scales, out=packed_scales) < _multi_tensor_copy_this_to_that(packed_scale_views, scale_invs, dummy_overflow_buf) < < # Reduce amaxes. < # Note: Assume each param has a separate amax. 
< packed_amaxes = torch.empty(len(amaxes), dtype=torch.float32, device=amaxes[0].device) < packed_amax_views = [packed_amaxes[i].view(1) for i in range(len(amaxes))] < _multi_tensor_copy_this_to_that(amaxes, packed_amax_views, dummy_overflow_buf) < torch.distributed.all_reduce( < packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=self.data_parallel_group < ) < _multi_tensor_copy_this_to_that(packed_amax_views, amaxes, dummy_overflow_buf) --- > self.all_gather_handles = [None for _ in range(len(self.all_gather_handles))] > > # Launch synchronous all-gather if --overlap-param-gather is turned on or if force_sync > # is explicitly set to True (e.g., if we are going to turn off all-gather overlapping for > # validation / test iterations). > if not self.overlap_param_gather or force_sync: > for all_gather_handle_index in range(len(self.all_gather_handles)): > self._dispatch_gather_model_params(all_gather_handle_index, force_sync=force_sync) 1821,1824c1790 < update_successful = super().step_with_ready_grads() < < # If there is no FP8 parameters, this will do nothing. < self._update_fp8_scale_inv_and_amax() --- > self.update_successful = super().step_with_ready_grads() 1831,1835c1797,1800 < # the first all-gather is launched asynchronously in the next optimizer.zero_grad() < # call and subsequent all-gathers are launched in the forward pre-hook. < if not self.ddp_config.overlap_param_gather: < for model_chunk in self.model_chunks: < model_chunk.start_param_sync() --- > # call to _gather_all_model_params is a no-op: the first all-gather is launched > # asynchronously in the next optimizer.zero_grad() call and subsequent all-gathers > # are launched in the forward pre-hook. > self._reset_metadata_and_sync_gather_all_model_params(force_sync=False) 1839c1804 < return update_successful --- > return self.update_successful diff -rN ./megatron/core/optimizer/__init__.py ../megatron-lm/megatron/core/optimizer/__init__.py 21,23c21,22 < # Apex's FusedAdam is a drop-in replacement for torch's AdamW. < # pylint: disable-next=line-too-long. < # See https://github.com/NVIDIA/apex/blob/7b73b12361068a10b0f44844534613f252a5ea75/apex/optimizers/fused_adam.py#L16. --- > ## apex's FusedAdam is a drop-in replacement for torch's AdamW > ## see https://github.com/NVIDIA/apex/blob/7b73b12361068a10b0f44844534613f252a5ea75/apex/optimizers/fused_adam.py#L16 28c27 < from ..distributed.param_and_grad_buffer import _ParamAndGradBuffer --- > from ..distributed import ParamAndGradBuffer 111,112c110 < # For input/embedding and output layer: embedding.word_embeddings.weight / < # output_layer.weight. --- > # For input/embedding and output layer: embedding.word_embeddings.weight / output_layer.weight. 194c192 < ) -> Tuple[List[Dict], Dict[int, List[_ParamAndGradBuffer]]]: --- > ) -> Tuple[List[Dict], Dict[int, ParamAndGradBuffer]]: 237d234 < model_chunks: List[MegatronModule], 239c236 < per_model_buffers: Optional[Dict[int, List[_ParamAndGradBuffer]]] = None, --- > per_model_buffers: Optional[Dict[int, List[ParamAndGradBuffer]]] = None, 243a241 > overlap_param_gather_with_optimizer_step: bool = False, 249d246 < model_chunks (list): list of model chunks. 257a255,256 > overlap_param_gather_with_optimizer_step (bool, optional): if true, overlap parameter > all-gather with optimizer step if using distributed optimizer. Defaults to False. 
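A minimal sketch of how the overlap-related settings referenced in these optimizer hunks are expressed on an OptimizerConfig; values are illustrative, and flags that exist on only one side of this diff (align_param_gather, overlap_param_gather_with_optimizer_step) are deliberately left out.

# Hedged sketch: distributed-optimizer overlap flags discussed around this point.
# All other OptimizerConfig fields keep their defaults.
from megatron.core.optimizer import OptimizerConfig

opt_config = OptimizerConfig(
    lr=1.0e-4,                    # illustrative learning rate
    use_distributed_optimizer=True,
    overlap_grad_reduce=True,     # overlap grad reduce-scatter with backward compute
    overlap_param_gather=True,    # overlap param all-gather with forward compute
)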
323d321 < model_chunks=model_chunks, 327a326 > overlap_param_gather_with_optimizer_step=overlap_param_gather_with_optimizer_step, 391,394d389 < for model_chunk in dense_model_chunks: < model_chunk.overlap_param_gather_with_optimizer_step = ( < overlap_param_gather_with_optimizer_step < ) 398d392 < model_chunks=dense_model_chunks, 406a401 > overlap_param_gather_with_optimizer_step=overlap_param_gather_with_optimizer_step, 427d421 < model_chunks=model_chunks, diff -rN ./megatron/core/optimizer/optimizer_config.py ../megatron-lm/megatron/core/optimizer/optimizer_config.py 98,100c98 < """If true, overlap grad reduce-scatter with backward compute in distributed optimizer. < NOTE: This parameter will be deprecated in a future release. Use `overlap_grad_reduce` < in `megatron/core/distributed/distributed_data_parallel_config.py` instead.""" --- > """If true, overlap grad reduce-scatter with backward compute in distributed optimizer.""" 103,105c101 < """If true, overlap param all-gather with forward compute in distributed optimizer. < NOTE: This parameter will be deprecated in a future release. Use `overlap_param_gather` < in `megatron/core/distributed/distributed_data_parallel_config.py` instead.""" --- > """If true, overlap param all-gather with forward compute in distributed optimizer.""" 108a105,109 > > align_param_gather: bool = False > """If true, all PP stages will launch param all-gathers simultaneously. Otherwise, each > PP stage will independently launch as needed. > """ diff -rN ./megatron/core/optimizer/optimizer.py ../megatron-lm/megatron/core/optimizer/optimizer.py 7d6 < import warnings 16c15 < from transformer_engine.pytorch.optimizers import multi_tensor_applier --- > from transformer_engine.pytorch.optimizers import multi_tensor_applier, multi_tensor_scale 257,258c256,257 < is_loading (bool, optional): flag indicating whether the state dict will be < used to save or load the optimizer state. Defaults to False. --- > is_loading (bool, optional): flag indicating whether the state dict will be used to save or load the optimizer state. > Defaults to False. 882d880 < self.model_chunks = [] 884,888c882 < for optimizer in chained_optimizers: < if hasattr(optimizer, 'model_chunks'): < for model_chunk in optimizer.model_chunks: < if model_chunk not in self.model_chunks: < self.model_chunks.append(model_chunk) --- > for optimizer in chained_optimizers[1:]: 962,963c956 < assert len(optimizer.model_chunks) == 1 < optimizer.model_chunks[0].start_param_sync(force_dispatch=True) --- > optimizer.start_param_sync(model_index=0, force_dispatch=True) 969,974c962,971 < warnings.warn( < "`ChainedOptimizer.disable_pre_hook` will be deprecated in a future release. " < "Use `DistributedDataParallel.disable_forward_pre_hook` directly." < ) < for model_chunk in self.model_chunks: < model_chunk.disable_forward_pre_hook() --- > for optimizer in self.chained_optimizers: > if ( > not optimizer.config.use_distributed_optimizer > or not optimizer.config.overlap_param_gather > ): > raise ValueError( > "disable_pre_hook should only be called with 'use_distributed_optimizer' " > "and 'overlap_param_gather' both enabled." > ) > optimizer.disable_pre_hook() 978,983c975,984 < warnings.warn( < "`ChainedOptimizer.enable_pre_hook` will be deprecated in a future release. " < "Use `DistributedDataParallel.enable_forward_pre_hook` directly." 
< ) < for model_chunk in self.model_chunks: < model_chunk.enable_forward_pre_hook() --- > for optimizer in self.chained_optimizers: > if ( > not optimizer.config.use_distributed_optimizer > or not optimizer.config.overlap_param_gather > ): > raise ValueError( > "enable_pre_hook should only be called with 'use_distributed_optimizer' " > "and 'overlap_param_gather' both enabled." > ) > optimizer.enable_pre_hook() Binary files ./megatron/core/optimizer/__pycache__/clip_grads.cpython-310.pyc and ../megatron-lm/megatron/core/optimizer/__pycache__/clip_grads.cpython-310.pyc differ Binary files ./megatron/core/optimizer/__pycache__/distrib_optimizer.cpython-310.pyc and ../megatron-lm/megatron/core/optimizer/__pycache__/distrib_optimizer.cpython-310.pyc differ Binary files ./megatron/core/optimizer/__pycache__/grad_scaler.cpython-310.pyc and ../megatron-lm/megatron/core/optimizer/__pycache__/grad_scaler.cpython-310.pyc differ Binary files ./megatron/core/optimizer/__pycache__/__init__.cpython-310.pyc and ../megatron-lm/megatron/core/optimizer/__pycache__/__init__.cpython-310.pyc differ Binary files ./megatron/core/optimizer/__pycache__/optimizer_config.cpython-310.pyc and ../megatron-lm/megatron/core/optimizer/__pycache__/optimizer_config.cpython-310.pyc differ Binary files ./megatron/core/optimizer/__pycache__/optimizer.cpython-310.pyc and ../megatron-lm/megatron/core/optimizer/__pycache__/optimizer.cpython-310.pyc differ diff -rN ./megatron/core/package_info.py ../megatron-lm/megatron/core/package_info.py 7c7 < PRE_RELEASE = '' --- > PRE_RELEASE = 'rc0' diff -rN ./megatron/core/parallel_state.py ../megatron-lm/megatron/core/parallel_state.py 230,231d229 < """A class for generating rank groups for different modes of parallelism.""" < 282,288d279 < """Create a mask for the specified tokens based on the given order. < < Args: < order (str): The order of parallelism types (e.g., 'tp-dp-pp'). < token (str): The specific parallelism types to include in the mask, < separated by hyphens (e.g., 'tp-dp'). < """ 297c288 < """Get rank group by input token. --- > '''Get rank group by input token. 
299c290 < Args: --- > Arguments: 312c303 < """ --- > ''' 887c878 < """Check if model- and data-parallel groups are initialized.""" --- > """Check if model and data parallel groups are initialized.""" 898c889 < """Get the model-parallel group the caller rank belongs to.""" --- > """Get the model parallel group the caller rank belongs to.""" 909c900 < """Get the tensor-model-parallel group the caller rank belongs to.""" --- > """Get the tensor model parallel group the caller rank belongs to.""" 918c909 < """Get the pipeline-model-parallel group the caller rank belongs to.""" --- > """Get the pipeline model parallel group the caller rank belongs to.""" 926c917 < """Get the data-parallel group the caller rank belongs to.""" --- > """Get the data parallel group the caller rank belongs to.""" 938c929 < """Get the Gloo data-parallel group the caller rank belongs to.""" --- > """Get the data parallel group-gloo the caller rank belongs to.""" 950c941 < """Get the context-parallel group the caller rank belongs to.""" --- > """Get the context parallel group the caller rank belongs to.""" 957c948 < """Get all global ranks of the context-parallel group that the caller rank belongs to.""" --- > """Get all global ranks of the context parallel group that the caller rank belongs to.""" 977c968 < def get_amax_reduction_group(with_context_parallel=False, tp_only_amax_red=False): --- > def get_amax_reduction_group(with_context_parallel=False): 980,1000c971,979 < if not tp_only_amax_red: < assert ( < _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP is not None < ), 'FP8 amax reduction group is not initialized' < return _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP < else: < assert ( < _TENSOR_AND_CONTEXT_PARALLEL_GROUP is not None < ), 'FP8 amax reduction group is not initialized' < return _TENSOR_AND_CONTEXT_PARALLEL_GROUP < else: < if not tp_only_amax_red: < assert ( < _TENSOR_AND_DATA_PARALLEL_GROUP is not None < ), 'FP8 amax reduction group is not initialized' < return _TENSOR_AND_DATA_PARALLEL_GROUP < else: < assert ( < _TENSOR_MODEL_PARALLEL_GROUP is not None < ), 'FP8 amax reduction group is not initialized' < return _TENSOR_MODEL_PARALLEL_GROUP --- > assert ( > _TENSOR_AND_CONTEXT_PARALLEL_GROUP is not None > ), 'FP8 amax reduction group is not initialized' > return _TENSOR_AND_CONTEXT_PARALLEL_GROUP > else: > assert ( > _TENSOR_MODEL_PARALLEL_GROUP is not None > ), 'FP8 amax reduction group is not initialized' > return _TENSOR_MODEL_PARALLEL_GROUP 1004c983 < """Get the tensor- and data-parallel group the caller rank belongs to.""" --- > """Get the tensor and data parallel group the caller rank belongs to.""" 1018c997 < """Get the tensor- and context-parallel group the caller rank belongs to.""" --- > """Get the tensor and context parallel group the caller rank belongs to.""" 1026c1005 < """Get the expert-model-parallel group the caller rank belongs to.""" --- > """Get the expert model parallel group the caller rank belongs to.""" 1034c1013 < """Get the tensor- and expert-parallel group the caller rank belongs to.""" --- > """Get the tensor and expert parallel group the caller rank belongs to.""" 1042c1021 < """Get the data-modulo-expert-parallel group the caller rank belongs to.""" --- > """Get the data modulo expert parallel group the caller rank belongs to.""" 1056c1035 < """Get the Gloo data-modulo-expert-parallel group the caller rank belongs to.""" --- > """Get the data modulo expert parallel group gloo the caller rank belongs to.""" 1070c1049 < """Sets the expert-model-parallel world size.""" --- > """Sets the 
expert model parallel world size.""" 1076c1055 < """Set the tensor-model-parallel size""" --- > """Set the tensor model parallel size""" 1082c1061 < """Set the pipeline-model-parallel size""" --- > """Set the pipeline model parallel size""" 1088c1067 < """Set the pipeline-model-parallel size""" --- > """Set the pipeline model parallel size""" 1094c1073 < """Return world size for the tensor-model-parallel group.""" --- > """Return world size for the tensor model parallel group.""" 1102c1081 < """Return world size for the pipeline-model-parallel group.""" --- > """Return world size for the pipeline model parallel group.""" 1109c1088 < # Implicit assumption that each PP group is the same size. --- > # I am assuming that each pp group is the same size. 1120c1099 < """Set expert-model-parallel rank.""" --- > """Set expert model parallel rank.""" 1126c1105 < """Set tensor-model-parallel rank.""" --- > """Set tensor model parallel rank.""" 1132c1111 < """Set pipeline-model-parallel rank.""" --- > """Set pipeline model parallel rank.""" 1138c1117 < """Set pipeline-model-parallel split rank. DEPRECATED.""" --- > """Set pipeline model parallel split rank. DEPRECATED.""" 1144c1123 < """Return caller's rank for the tensor-model-parallel group.""" --- > """Return my rank for the tensor model parallel group.""" 1152c1131 < """Return caller's rank for the pipeline-model-parallel group.""" --- > """Return my rank for the pipeline model parallel group.""" 1159c1138 < # Assume that if the caller exist in multiple PP groups, then it has the same index. --- > # I am assuming that if i exist in multiple pp groups, then I am in the same index. 1172c1151 < """Return pipeline-model-parallel split rank.""" --- > """Return pipeline model parallel split rank.""" 1189c1168 < """Return True if in the last pipeline-model-parallel stage, False otherwise.""" --- > """Return True if in the last pipeline model-parallel stage, False otherwise.""" 1337c1316,1317 < """Return the global rank of the first stage in the current rank's pipeline.""" --- > """Return the global rank of the first process in the pipeline for the > current tensor parallel group""" 1349c1329,1330 < """Return the global rank of the last stage in the current rank's pipeline.""" --- > """Return the global rank of the last process in the pipeline for the > current tensor parallel group""" 1356,1359c1337,1339 < """Return the global rank that follows the caller in the pipeline, for each < pipeline-parallel group that the rank is part of. < < If it is just part of one group, an int is returned, otherwise a list of ints. --- > """Return the global rank that follows the caller in the pipeline, for each pipeline group that > the rank is part of. If it's just part of one group, an int is returned, > otherwise a list of ints. 1374,1377c1354,1356 < """Return the global rank that precedes the caller in the pipeline, for each < pipeline-parallel group that the rank is part of. < < If it is just part of one group, an int is returned, otherwise a list of ints. --- > """Return the global rank that preceeds the caller in the pipeline, for each pipeline group that > the rank is part of. If it's just part of one group, an int is returned, > otherwise a list of ints. 
1411c1390 < """Return caller's rank in the data-parallel group.""" --- > """Return my rank for the data parallel group.""" 1432c1411 < """Return caller's rank in the context-parallel group.""" --- > """Return my rank for the context parallel group.""" 1440c1419 < """Return world size for the tensor and context-parallel group.""" --- > """Return world size for the tensor and context parallel group""" 1448c1427 < """Return caller's rank in the joint tensor-model-parallel and context-parallel group.""" --- > """Return my rank for the tensor and context parallel group.""" 1456c1435 < """Return world size for the expert-model-parallel group.""" --- > """Return world size for the expert model parallel group""" 1482c1461 < """Return caller's rank in the expert-model-parallel group.""" --- > """Return my rank for the expert parallel group""" 1495c1474 < """Return caller's rank in the context-parallel group.""" --- > """Return my rank for the context parallel group.""" 1505c1484 < """Return caller's rank in the joint tensor- and expert-model-parallel group.""" --- > """Return my rank for the tensor and expert parallel group""" 1513c1492 < """Initialize global buffer.""" --- > """Initialize global buffer""" 1532,1533d1510 < """Get caller's rank in tensor-model-parallel, data-parallel, context-parallel, < pipeline-model-parallel and expert-model-parallel groups.""" 1645,1647d1621 < < global _MOE_LAYER_WISE_LOGGING_TRACKER < _MOE_LAYER_WISE_LOGGING_TRACKER = {} Binary files ./megatron/core/pipeline_parallel/__pycache__/__init__.cpython-310.pyc and ../megatron-lm/megatron/core/pipeline_parallel/__pycache__/__init__.cpython-310.pyc differ Binary files ./megatron/core/pipeline_parallel/__pycache__/p2p_communication.cpython-310.pyc and ../megatron-lm/megatron/core/pipeline_parallel/__pycache__/p2p_communication.cpython-310.pyc differ Binary files ./megatron/core/pipeline_parallel/__pycache__/schedules.cpython-310.pyc and ../megatron-lm/megatron/core/pipeline_parallel/__pycache__/schedules.cpython-310.pyc differ diff -rN ./megatron/core/pipeline_parallel/schedules.py ../megatron-lm/megatron/core/pipeline_parallel/schedules.py 594,600d593 < # Disable config.grad_sync_func and config.param_sync_func if only running forward passes. < # They will be re-enabled at the end of this function. < grad_sync_func, param_sync_func = None, None < if forward_only: < grad_sync_func, param_sync_func = config.grad_sync_func, config.param_sync_func < config.grad_sync_func, config.param_sync_func = None, None < 1150,1153d1142 < < # Restore config.grad_sync_func and config.param_sync_func. 
< if forward_only: < config.grad_sync_func, config.param_sync_func = grad_sync_func, param_sync_func Binary files ./megatron/core/__pycache__/config_logger.cpython-310.pyc and ../megatron-lm/megatron/core/__pycache__/config_logger.cpython-310.pyc differ Binary files ./megatron/core/__pycache__/enums.cpython-310.pyc and ../megatron-lm/megatron/core/__pycache__/enums.cpython-310.pyc differ Binary files ./megatron/core/__pycache__/inference_params.cpython-310.pyc and ../megatron-lm/megatron/core/__pycache__/inference_params.cpython-310.pyc differ Binary files ./megatron/core/__pycache__/__init__.cpython-310.pyc and ../megatron-lm/megatron/core/__pycache__/__init__.cpython-310.pyc differ Binary files ./megatron/core/__pycache__/jit.cpython-310.pyc and ../megatron-lm/megatron/core/__pycache__/jit.cpython-310.pyc differ Binary files ./megatron/core/__pycache__/model_parallel_config.cpython-310.pyc and ../megatron-lm/megatron/core/__pycache__/model_parallel_config.cpython-310.pyc differ Binary files ./megatron/core/__pycache__/num_microbatches_calculator.cpython-310.pyc and ../megatron-lm/megatron/core/__pycache__/num_microbatches_calculator.cpython-310.pyc differ Binary files ./megatron/core/__pycache__/optimizer_param_scheduler.cpython-310.pyc and ../megatron-lm/megatron/core/__pycache__/optimizer_param_scheduler.cpython-310.pyc differ Binary files ./megatron/core/__pycache__/package_info.cpython-310.pyc and ../megatron-lm/megatron/core/__pycache__/package_info.cpython-310.pyc differ Binary files ./megatron/core/__pycache__/packed_seq_params.cpython-310.pyc and ../megatron-lm/megatron/core/__pycache__/packed_seq_params.cpython-310.pyc differ Binary files ./megatron/core/__pycache__/parallel_state.cpython-310.pyc and ../megatron-lm/megatron/core/__pycache__/parallel_state.cpython-310.pyc differ Binary files ./megatron/core/__pycache__/timers.cpython-310.pyc and ../megatron-lm/megatron/core/__pycache__/timers.cpython-310.pyc differ Binary files ./megatron/core/__pycache__/utils.cpython-310.pyc and ../megatron-lm/megatron/core/__pycache__/utils.cpython-310.pyc differ diff -rN ./megatron/core/requirements.txt ../megatron-lm/megatron/core/requirements.txt 1,2c1 < torch < packaging --- > torch \ No newline at end of file diff -rN ./megatron/core/ssm/mamba_block.py ../megatron-lm/megatron/core/ssm/mamba_block.py 17,19d16 < from megatron.core.dist_checkpointing.mapping import ShardedStateDict < from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding < from megatron.core.extensions.transformer_engine import TENorm 22a20 > from megatron.core.transformer.custom_layers.transformer_engine import TENorm 27d24 < from megatron.core.transformer.utils import sharded_state_dict_default 55,58c52,53 < # > A modified initialization which accounts for the accumulation on the < # > residual path with model depth. Scale < # > the weights of residual layers at initialization by a factor of < # > 1/√N where N is the # of residual layers. --- > # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale > # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 61,62c56 < # Reference (Megatron-LM): < # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py --- > # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py 75,78d68 < """ < A class for the module specs for the MambaStack. 
< """ < 85,110d74 < """ < Constructor for the MambaStack class. < < Args: < config (TransformerConfig): the transformer configuration < submodules (MambaStackSubmodules): the submodules for the stack < mamba_ssm_ngroups (int, optional): the number of groups for the < MAMBA SSM. Defaults to 8. < residual_in_fp32 (bool, optional): whether to do residual connections < in fp32. Defaults to False. < pre_process (bool, optional): whether to include an embedding layer. < Defaults to True. < hybrid_attention_ratio (float, optional): the target ratio of attention layers to < total layers. Defaults to 0.0. < hybrid_mlp_ratio (float, optional): the target ratio of mlp layers to total < layers. Defaults to 0.0. < hybrid_override_pattern (str, optional): the hybrid layer pattern to override < with. Defaults to None. < post_layer_norm (bool, optional): whether to include a final layer norm. < Defaults to True. < post_process (bool, optional): whether to include an output layer. < Defaults to True. < device (optional): the device to use. Defaults to None. < dtype (optional): the data type to use. Defaults to None. < """ < 204,213d167 < """ < Allocate inference cache for each layer. < < Args: < batch_size (int): The batch size to use for inference. < max_seqlen (int): The maximum sequence length to use < for inference. < dtype (optional): The data type to use for allocation. < Defaults to the data type of the model. < """ 236,250d189 < """ < Forward function of the MambaStack class. < < It either returns the Loss values if labels are given or the < final hidden units < < Args: < hidden_states (Tensor): the input tensor. < attention_mask (Tensor): the attention mask. < inference_params (InferenceParams): the inference parameters. < rotary_pos_emb (Tensor, optional): the rotary positional embeddings. < Defaults to None. < Returns: < Tensor: the output tensor. < """ 256,257c195 < # NOTE(bnorick): match InferenceParams attributes for < # mamba_ssm.utils.generation.InferenceParams, --- > # NOTE(bnorick): match InferenceParams attributes for mamba_ssm.utils.generation.InferenceParams, 287,337d224 < < def sharded_state_dict( < self, prefix: str = '', sharded_offsets: tuple = (), metadata: dict = None < ) -> ShardedStateDict: < """ < Returns a sharded state dictionary for the current object. < < This function constructs a sharded state dictionary by iterating over the layers < in the current object, computing the sharded state dictionary for each layer, < and combining the results into a single dictionary. < < Parameters: < prefix (str): The prefix to use for the state dictionary keys. < sharded_offsets (tuple): The sharded offsets to use for the state dictionary. < metadata (dict): Additional metadata to use when computing the sharded state dictionary. < < Returns: < dict: The sharded state dictionary for the current object. < """ < < sharded_state_dict = {} < layer_prefix = f'{prefix}layers.' < < for local_layer_idx, layer in enumerate(self.layers): < < global_layer_offset = layer.layer_number - 1 # self.layer_number starts at 1 < state_dict_prefix = ( < f'{layer_prefix}{local_layer_idx}.' # module list index in MambaBlock < ) < < sharded_prefix = f'{layer_prefix}{global_layer_offset}.' 
< sharded_pp_offset = [] < < layer_sharded_state_dict = layer.sharded_state_dict( < state_dict_prefix, sharded_pp_offset, metadata < ) < < replace_prefix_for_sharding(layer_sharded_state_dict, state_dict_prefix, sharded_prefix) < < sharded_state_dict.update(layer_sharded_state_dict) < < # Add modules other than self.layers < for name, module in self.named_children(): < if not module is self.layers: < sharded_state_dict.update( < sharded_state_dict_default( < module, f'{prefix}{name}.', sharded_offsets, metadata < ) < ) < < return sharded_state_dict diff -rN ./megatron/core/ssm/mamba_mixer.py ../megatron-lm/megatron/core/ssm/mamba_mixer.py 10,11c10,11 < from dataclasses import dataclass, replace < from typing import List, Optional, Union --- > from dataclasses import dataclass > from typing import Union 17,18d16 < from megatron.core.dist_checkpointing import ShardedTensor < from megatron.core.dist_checkpointing.mapping import ReplicaId, ShardedTensorFactory 24,27d21 < from megatron.core.transformer.utils import ( < make_sharded_tensors_for_checkpoint, < sharded_state_dict_default, < ) 55,67d48 < class ExtendedRMSNorm(RMSNormGated): < """ < RMSNormGated with sharded state dict. < """ < < def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): < """Sharding along axis 0, bias not sharded""" < state_dict = self.state_dict(prefix='', keep_vars=True) < return make_sharded_tensors_for_checkpoint( < state_dict, prefix, {'weight': 0}, sharded_offsets < ) < < 70,73d50 < """ < Contains the module specs for the input and output linear layers. < """ < 79,106d55 < """ < Args: < config: The config of the model. < submodules: Contains the module specs for the input and output linear layers. < d_model: The hidden size of the model. < d_state: The state size of the SSM. < d_conv: The number of channels in the causal convolution. < conv_init: The initialization range for the causal convolution weights. < expand: The expansion factor for the SSM. < headdim: The hidden size of each attention head. < ngroups: The number of attention heads. < A_init_range: The initialization range for the attention weights. < D_has_hdim: Whether the D parameter has the same number of dimensions as the hidden < state. < rmsnorm: Whether to use root mean square normalization. < norm_before_gate: Whether to apply normalization before the gating mechanism. < dt_min: The minimum value of the dt parameter. < dt_max: The maximum value of the dt parameter. < dt_init: The initialization value of the dt parameter. < dt_scale: The scaling factor for the dt parameter. < dt_init_floor: The minimum value of the dt parameter after initialization. < bias: Whether to use bias in the linear layers. < conv_bias: Whether to use bias in the causal convolution. < chunk_size: The chunk size for the fused kernel. < use_mem_eff_path: Whether to use the memory-efficient path for the Mamba model. < layer_number: The layer number of this Mamba layer. 
< """ < 171c120 < self.d_inner * 2 + 2 * self.ngroups * self.d_state + self.nheads, # AB CD E --- > self.d_inner * 2 + 2 * self.ngroups * self.d_state + self.nheads, 181c130 < conv_dim = self.d_inner_local + 2 * self.ngroups_local * self.d_state # A CD --- > conv_dim = self.d_inner_local + 2 * self.ngroups_local * self.d_state 183d131 < # weight dim: [conv_dim, conv_dim, d_conv] 216,217c164 < # Our initialization would set all Linear.bias to zero, < # need to mark this one as _no_reinit --- > # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit 219,221c166 < # Just to be explicit. Without this we already don't < # put wd on dt_bias because of the check < --- > # Just to be explicit. Without this we already don't put wd on dt_bias because of the check 246c191 < self.norm = ExtendedRMSNorm( --- > self.norm = RMSNormGated( 408,410d352 < """ < Performs inference step for decoding < """ 535,537d476 < """ < allocate inference cache < """ 581,718d519 < < def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): < sharded_state_dict = {} < # Parameters < self._save_to_state_dict(sharded_state_dict, '', keep_vars=True) < sharded_state_dict = make_sharded_tensors_for_checkpoint( < sharded_state_dict, < prefix, < tensor_parallel_layers_axis_map={ < 'A_log': 0, < 'dt_bias': 0, < 'D': 0, < }, # parameters sharded across TP < sharded_offsets=sharded_offsets, < ) < # Submodules < for name, module in self.named_children(): < if name == 'conv1d': < # Add TP sharding for Conv1d < module_sd = module.state_dict(prefix='', keep_vars=True) < module_sharded_sd = make_sharded_tensors_for_checkpoint( < module_sd, f'{prefix}{name}.', {f'weight': 0, f'bias': 0}, sharded_offsets < ) < < else: < module_sharded_sd = sharded_state_dict_default( < module, f'{prefix}{name}.', sharded_offsets, metadata < ) < < sharded_state_dict.update(module_sharded_sd) < < # At this point the TP sharding is correctly defined fo each tensor, but some of the tensors < # must be additionally split into separate parts < # in_proj < in_proj_dim = ( < self.d_inner_local * 2 + 2 * self.ngroups_local * self.d_state + self.nheads_local < ) < assert sharded_state_dict[f'{prefix}in_proj.weight'].data.size(0) == in_proj_dim, ( < in_proj_dim, < sharded_state_dict[f'{prefix}in_proj.weight'], < ) < < sharded_state_dict[f'{prefix}in_proj.weight'] = _split_tensor_factory( < sharded_state_dict[f'{prefix}in_proj.weight'], < [ < self.d_inner_local, < self.d_inner_local, < self.ngroups_local * self.d_state, < self.ngroups_local * self.d_state, < self.nheads_local, < ], < ['z', 'x', 'B', 'C', 'dt'], < 0, < ) < < conv_dim = self.d_inner_local + 2 * self.ngroups_local * self.d_state < assert sharded_state_dict[f'{prefix}conv1d.weight'].data.size(0) == conv_dim, ( < conv_dim, < sharded_state_dict[f'{prefix}conv1d.weight'], < ) < assert sharded_state_dict[f'{prefix}conv1d.bias'].data.size(0) == conv_dim, ( < conv_dim, < sharded_state_dict[f'{prefix}conv1d.bias'], < ) < < for conv_layer_name in ['conv1d.weight', 'conv1d.bias']: < sharded_state_dict[f'{prefix}{conv_layer_name}'] = _split_tensor_factory( < sharded_state_dict[f'{prefix}{conv_layer_name}'], < [ < self.d_inner_local, < self.ngroups_local * self.d_state, < self.ngroups_local * self.d_state, < ], < ['x', 'B', 'C'], < 0, < ) < < return sharded_state_dict < < < def _split_tensor_factory( < orig_sh_ten: ShardedTensor, split_sections: List[int], split_names: List[str], split_dim: int < ) -> ShardedTensorFactory: < """Builds a factory that splits a 
given ShardedTensor into several independent chunks.""" < assert isinstance(orig_sh_ten, ShardedTensor), type(orig_sh_ten) < orig_sh_ten_no_data = orig_sh_ten.without_data() # remove `data` reference < < if sum(split_sections) != orig_sh_ten_no_data.local_shape[split_dim]: < raise ValueError( < f'Split sections must cover the whole dimension size, ' < f'got {split_sections=} vs dimensions size ' < f'{orig_sh_ten_no_data.local_shape[split_dim]}' < ) < < assert not isinstance( < split_sections, int < ), 'Splitting into predefined section sizes is supported (`split_sections` must be a list)' < assert len(split_sections) == len(split_names), (len(split_sections), len(split_names)) < < @torch.no_grad() < def sh_ten_build_fn( < key: str, t: torch.Tensor, replica_id: ReplicaId, flattened_range: Optional[slice] < ): < factory_sh_ten = replace( < orig_sh_ten_no_data, < key=key, < data=t, < dtype=t.dtype, < replica_id=replica_id, < flattened_range=flattened_range, < ) < < chunk_sh_tens = [] < split_start = 0 < for split_size, split_name in zip(split_sections, split_names): < split_chunks = factory_sh_ten.narrow(split_dim, split_start, split_size) < for sh_ten in split_chunks: < sh_ten.key = f'{sh_ten.key}.{split_name}' < chunk_sh_tens.extend(split_chunks) < split_start += split_size < < assert split_start == orig_sh_ten_no_data.local_shape[split_dim], ( < split_start, < orig_sh_ten_no_data.local_shape[split_dim], < ) < assert sum(sh_ten.data.numel() for sh_ten in chunk_sh_tens) == t.numel(), ( < chunk_sh_tens, < t.shape, < ) < return chunk_sh_tens < < @torch.no_grad() < def sh_ten_merge_fn(sub_state_dict): < return torch.cat(sub_state_dict) < < return ShardedTensorFactory( < orig_sh_ten.key, orig_sh_ten.data, sh_ten_build_fn, sh_ten_merge_fn, orig_sh_ten.replica_id < ) diff -rN ./megatron/core/tensor_parallel/layers.py ../megatron-lm/megatron/core/tensor_parallel/layers.py 44a45,47 > import torch._dynamo > torch._dynamo.config.suppress_errors = True > 378a382 > @torch.compile(mode="max-autotune-no-cudagraphs") 743a748 > self.is_mlp = True 769c774,784 < self.weight = Parameter( --- > if self.is_mlp and self.input_size % 2048 == 0: > print("+++++padding is done here") > tmp_weight = Parameter(torch.empty( > self.output_size_per_partition, > self.input_size+32, > device=torch.cuda.current_device(), > dtype=config.params_dtype, > )) > self.weight = tmp_weight[:,0:self.input_size] > else: > self.weight = Parameter( 775,776c790,791 < ) < ) --- > ) > ) 1033c1048 < --- > self.is_mlp = True 1059c1074,1084 < self.weight = Parameter( --- > if self.is_mlp and self.input_size_per_partition % 2048 == 0: > print("------------padding is done here") > tmp_weight = Parameter(torch.empty( > self.output_size, > self.input_size_per_partition+32, > device=torch.cuda.current_device(), > dtype=config.params_dtype, > )) > self.weight = tmp_weight[:,0:self.input_size_per_partition] > else: > self.weight = Parameter( 1065,1066c1090,1091 < ) < ) --- > ) > ) Binary files ./megatron/core/tensor_parallel/__pycache__/cross_entropy.cpython-310.pyc and ../megatron-lm/megatron/core/tensor_parallel/__pycache__/cross_entropy.cpython-310.pyc differ Binary files ./megatron/core/tensor_parallel/__pycache__/data.cpython-310.pyc and ../megatron-lm/megatron/core/tensor_parallel/__pycache__/data.cpython-310.pyc differ Binary files ./megatron/core/tensor_parallel/__pycache__/__init__.cpython-310.pyc and ../megatron-lm/megatron/core/tensor_parallel/__pycache__/__init__.cpython-310.pyc differ Binary files 
./megatron/core/tensor_parallel/__pycache__/layers.cpython-310.pyc and ../megatron-lm/megatron/core/tensor_parallel/__pycache__/layers.cpython-310.pyc differ Binary files ./megatron/core/tensor_parallel/__pycache__/mappings.cpython-310.pyc and ../megatron-lm/megatron/core/tensor_parallel/__pycache__/mappings.cpython-310.pyc differ Binary files ./megatron/core/tensor_parallel/__pycache__/random.cpython-310.pyc and ../megatron-lm/megatron/core/tensor_parallel/__pycache__/random.cpython-310.pyc differ Binary files ./megatron/core/tensor_parallel/__pycache__/utils.cpython-310.pyc and ../megatron-lm/megatron/core/tensor_parallel/__pycache__/utils.cpython-310.pyc differ diff -rN ./megatron/core/tensor_parallel/random.py ../megatron-lm/megatron/core/tensor_parallel/random.py 7a8 > from importlib.metadata import version 9a11 > from pkg_resources import packaging 15a18 > get_data_parallel_rank, 16a20 > get_tensor_model_parallel_group, 17a22 > get_tensor_model_parallel_world_size, 19c24 < from megatron.core.utils import is_te_min_version, safely_set_viewless_tensor_data --- > from megatron.core.utils import safely_set_viewless_tensor_data 64d68 < """Get the expert parallel rng tracker name""" 70d73 < """Get the data parallel rng tracker name""" 88d90 < """Checks if the internal RNG state has been set wirth set_states().""" 167,171d168 < """Create the RNG tracker. 'use_te_rng_tracker' determines whether to use < Megatron or TransformerEngine's implementation. < In particular, TransformerEngine's implementation is cudagraphable and supports FP8. < """ < 176d172 < 178,180c174,175 < if not is_te_min_version("1.5.0"): < raise RuntimeError("use_te_rng_tracker requires TransformerEngine version >= 1.5") < from megatron.core.extensions.transformer_engine import TECudaRNGStatesTracker --- > try: > import transformer_engine.pytorch as te 182c177,183 < _CUDA_RNG_STATE_TRACKER = TECudaRNGStatesTracker() --- > _te_version = packaging.version.Version(version("transformer-engine")) > if _te_version < packaging.version.Version("1.5.0"): > raise RuntimeError("use_te_rng_tracker requires TransformerEngine version >= 1.5") > except ImportError: > raise RuntimeError("use_te_rng_tracker requires TransformerEngine, but not installed") > if use_te_rng_tracker: > _CUDA_RNG_STATE_TRACKER = te.distributed.CudaRNGStatesTracker() 188c189 < def get_cuda_rng_tracker(use_te_rng_tracker=False): --- > def get_cuda_rng_tracker(): 190c191 < initialize_rng_tracker(use_te_rng_tracker) --- > initialize_rng_tracker() 202,207c203,204 < default state: This is for data parallelism and is the same among a set of model parallel GPUs < but different across different model parallel groups. This is used for example for dropout < in the non-tensor-model-parallel regions. < tensor-model-parallel state: This state is different among a set of model parallel GPUs, < but the same across data parallel groups. This is used for example for dropout < in model parallel regions. --- > default state: This is for data parallelism and is the same among a set of model parallel GPUs but different across different model paralle groups. This is used for example for dropout in the non-tensor-model-parallel regions. > tensor-model-parallel state: This state is different among a set of model parallel GPUs, but the same across data parallel groups. This is used for example for dropout in model parallel regions. 
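The dual RNG-state scheme described in this random.py hunk (a default state shared for dropout in non-tensor-model-parallel regions, plus a tensor-model-parallel state that differs per rank) can be illustrated with nothing more than the standard torch.cuda RNG calls. The tracker below is a simplified, hypothetical stand-in for illustration only, not Megatron's CudaRNGStatesTracker or the TE tracker referenced above, and it needs a CUDA device to run:

import contextlib
import torch

class TinyRNGTracker:
    """Keeps named CUDA RNG states and temporarily swaps them in."""

    def __init__(self):
        self.states = {}

    def add(self, name, seed):
        # Seed the device RNG, capture that state under `name`, then restore the original state.
        orig = torch.cuda.get_rng_state()
        torch.cuda.manual_seed(seed)
        self.states[name] = torch.cuda.get_rng_state()
        torch.cuda.set_rng_state(orig)

    @contextlib.contextmanager
    def fork(self, name):
        # Run the enclosed region (e.g. dropout) under the named state, then store the
        # advanced state back and restore whatever state was active before.
        orig = torch.cuda.get_rng_state()
        torch.cuda.set_rng_state(self.states[name])
        try:
            yield
        finally:
            self.states[name] = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(orig)

# Seeding the tensor-model-parallel state with `base_seed + tp_rank` makes dropout masks differ
# across model-parallel GPUs while the default state stays identical within a model-parallel set:
#   tracker = TinyRNGTracker()
#   tracker.add("model-parallel-rng", base_seed + tp_rank)
#   with tracker.fork("model-parallel-rng"):
#       x = torch.nn.functional.dropout(x, p=0.1)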
240d236 < """Forward call""" 267d262 < """Backward call""" diff -rN ./megatron/core/transformer/cuda_graphs.py ../megatron-lm/megatron/core/transformer/cuda_graphs.py 1,313d0 < # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. < < import logging < import time < from enum import Enum < < import torch < < from megatron.core.transformer.module import MegatronModule < < try: < from transformer_engine.pytorch import make_graphed_callables < from transformer_engine.pytorch.fp8 import FP8GlobalStateManager < < HAVE_TE_GRAPHS = True < except: < HAVE_TE_GRAPHS = False < < < class GraphStatus(Enum): < """An Enum to track if a cudagraph is ready to perform a forward or backward pass.""" < < FWD_READY = 0 < BWD_READY = 1 < < < class GraphStatusFunc(torch.autograd.Function): < """Inserts a node into the autograd graph that tracks whether an object has an outstanding < backward pass by toggling the value of GraphStatus. This is mainly used to detect when to create < multiple graphs per transformer layer for pipeline parallelism. < We don't use backward module hooks as they change forward output tensors to views, see: < https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_full_backward_hook < """ < < @staticmethod < def forward(ctx, runner, obj): < """Occurs immediately before the graph's forward pass. < Marks the graph's backward pass as ready.""" < ctx.runner = runner < runner.status = GraphStatus.BWD_READY < return obj < < @staticmethod < def backward(ctx, grad): < """Occurs immediately after the graph's backward pass. < Marks the graph's forward pass as ready.""" < assert ctx.runner.status == GraphStatus.BWD_READY < ctx.runner.status = GraphStatus.FWD_READY < return None, grad < < < class TensorDescription: < """Records the attributes of a tensor. Used to check if a < tensor argument matches the tensor with which the module < was graph captured with.""" < < def __init__(self, tensor): < self.shape = tuple(tensor.shape) < self.dtype = tensor.dtype < self.device = tensor.device < < def matches_tensor(self, tensor): < """Check if 'tensor' matches the attributes of this TensorDescription.""" < < assert torch.is_tensor(tensor) < return ( < tensor.shape == self.shape < and tensor.dtype == self.dtype < and tensor.device == self.device < ) < < < class CudaGraphCallable(torch.nn.Module): < """Wraps a module to be cudagraphable, records the output of the cudagraph. < Reinserts non-tensor args, kwargs that were previously filtered out by 'get_tensor_args'. < """ < < def __init__(self, module, groundtruth_args, groundtruth_kwargs): < super().__init__() < self.add_module('base_module', module) < < # The Pytorch cudagraph API requires only tensor inputs, so we strip < # non-tensor arguments and reinsert them in forward() using these groundtruth attributes. < # We will also check future calls to the cudagraph against these to ensure the cudagraph < # is called with the same inputs as it was captured with. < self.groundtruth_outputs = [] < self.groundtruth_args = tuple( < TensorDescription(a) if torch.is_tensor(a) else a for a in groundtruth_args < ) < self.groundtruth_kwargs = { < k: TensorDescription(v) if torch.is_tensor(v) else v < for k, v in groundtruth_kwargs.items() < } < < def forward(self, *arg_tensors, **kwarg_tensors): < """Call the forward pass of the cudagraph. 
Also checks the outputs < of the cudagraph matches what the graph was traced with.""" < < args = list(self.groundtruth_args) < arg_tensors = list(arg_tensors) < for idx, groundtruth_arg in enumerate(self.groundtruth_args): < if isinstance(groundtruth_arg, TensorDescription): < args[idx] = arg_tensors.pop(0) < < kwargs = dict(self.groundtruth_kwargs) < for k, v in self.groundtruth_kwargs.items(): < if isinstance(v, TensorDescription): < kwargs[k] = kwarg_tensors[k] < < # Use forward() instead of __call__ to avoid triggering hooks < out = self.base_module.forward(*args, **kwargs) < if torch.is_tensor(out): < out = tuple(out) < < self.groundtruth_outputs = [TensorDescription(o) if torch.is_tensor(o) else o for o in out] < < out = tuple(o for o in out if torch.is_tensor(o)) < assert ( < len(out) > 0 < ), """A graphed module returned no tensors in training mode, however the graphed module < must output at least one tensor, so that a corresponding backward node < may be registered in the autograd graph.""" < < if len(out) == 1: < return out[0] < return out < < < class CudaGraphRunner(torch.nn.Module): < """Wraps a single cudagraph and its expected arguments. Checks that < the provided args are the same as what the graph was traced with. < """ < < def __init__(self, graphed_module, wrapped_module): < super().__init__() < < self.graphed_module = graphed_module < self.groundtruth_args = wrapped_module.groundtruth_args < self.groundtruth_kwargs = wrapped_module.groundtruth_kwargs < self.groundtruth_outputs = wrapped_module.groundtruth_outputs < self.status = GraphStatus.FWD_READY < < def static_args_match(self, args, kwargs): < """Check the the passed args, kwargs match with the arg, kwargs < the graph was created with.""" < < def check(val, ref): < if isinstance(ref, TensorDescription): < return ref.matches_tensor(val) < return ref == val < < if len(args) != len(self.groundtruth_args): < return False < for idx, groundtruth_arg in enumerate(self.groundtruth_args): < if not check(args[idx], groundtruth_arg): < return False < < if kwargs.keys() != self.groundtruth_kwargs.keys(): < return False < for k, v in self.groundtruth_kwargs.items(): < if not check(kwargs[k], v): < return False < return True < < def forward(self, args, kwargs, is_first_microbatch=None): < """Call the forward pass of the cuda graph.""" < if self.training and torch.is_grad_enabled(): < args = list(args) < for pos in range(len(args)): < if torch.is_tensor(args[pos]): < args[pos] = GraphStatusFunc.apply(self, args[pos]) < for k, v in kwargs.items(): < if torch.is_tensor(v): < kwargs[k] = GraphStatusFunc.apply(self, v) < < ret_tensors = self.graphed_module(is_first_microbatch=is_first_microbatch, *args, **kwargs) < ret_tensors = [ret_tensors] if torch.is_tensor(ret_tensors) else list(ret_tensors) < out = tuple( < ret_tensors.pop(0) if isinstance(o, TensorDescription) else o < for o in self.groundtruth_outputs < ) < < # Check that the static graph matches what was recorded during graph capture < assert len(out) == len(self.groundtruth_outputs) < for idx, o in enumerate(self.groundtruth_outputs): < if isinstance(o, TensorDescription): < assert o.matches_tensor(out[idx]) < else: < assert o == out[idx] < < if len(out) == 1: < return out[0] < return out < < < class CudaGraphManager(torch.nn.Module): < """Creates and runs cudagraphs for a megatron module.""" < < def __init__(self): < super().__init__() < self.cudagraph_runners = [] < self.is_first_microbatch = True < assert HAVE_TE_GRAPHS, "CudaGraphManager currently requires 
TransformerEngine" < < # Cudagraph stream capture requires no operations on the default stream prior to the < # capture, so change to a side stream. At graph capture change it back. < self.stream = torch.cuda.current_stream() < torch.cuda.set_stream(torch.cuda.Stream()) < < def __call__(self, megatron_module, args, kwargs): < """Calls the forward pass of the cudagraphed module. < < Args: < megatron_module (torch.nn.module): The megatron module to be graphed and run < < args (tuple): The positional args to be passed to the module. < < kwargs (dict): The keyword args to be passed to the module. < < """ < < # param.data_ptr() below is used to trigger any hooks that have attached to the parameter. < # Specifically, this is trying to trigger the param sync hook for the APEX optimizer, which < # triggers param syncs by hooking into any param references. < # However cudagraphs disables this, so we workaround by manually referencing params here. < # For more information see: < # https://github.com/NVIDIA/apex/blob/7001836/apex/contrib/optimizers/distributed_fused_adam.py#L885C9 < for param in megatron_module.parameters(): < param.data_ptr() < < runner = None < for _runner in self.cudagraph_runners: < if _runner.static_args_match(args, kwargs) and _runner.status == GraphStatus.FWD_READY: < runner = _runner < break < < if runner is None: < if self.training and torch.is_grad_enabled(): < runner = self.create_cudagraph_module(megatron_module, args, kwargs) < self.cudagraph_runners.append(runner) < logging.getLogger(__name__).info( < f"Creating cudagraph; now have {len(self.cudagraph_runners)}" < ) < else: < # No cudagraphs were found in inference mode, so fallback to eager since < # tensor.requires_grad is needed to correctly trace the backward graph. < return super(MegatronModule, megatron_module).__call__(*args, **kwargs) < < tensor_args, tensor_kwargs = self.get_tensor_args(args, kwargs) < out = runner(tensor_args, tensor_kwargs, is_first_microbatch=self.is_first_microbatch) < self.is_first_microbatch = False < return out < < def get_tensor_args(self, args, kwargs): < """Filter out non-tensor arguments from args and kwargs. < Needed since 'make_graphed_callables' expects Torch.tensor arg, kwargs.""" < tensor_kwargs = {} < for k, v in kwargs.items(): < if torch.is_tensor(v): < tensor_kwargs[k] = v < tensor_args = tuple(arg for arg in args if torch.is_tensor(arg)) < return tensor_args, tensor_kwargs < < def create_cudagraph_module(self, megatron_module, args, kwargs): < """Record the graph capture stream. Runs warmup iterations of < megatron_module, and creates a autograd function, where the < forward, backward functions are the cudagraphs of module's forward, < backward passes. Finally wraps this cudagraph function with a CudaGraphRunner. 
< """ < < torch.cuda.synchronize() < torch.cuda.set_stream(self.stream) < start = time.time() < < wrapped_module = CudaGraphCallable(megatron_module, args, kwargs) < sample_args, sample_kwargs = self.get_tensor_args(args, kwargs) < < # Cudagraphs require no autograd history recorded on sample inputs < sample_args_detached = tuple(n.detach() for n in sample_args) < sample_kwargs_detached = {k: v.detach() for k, v in sample_kwargs.items()} < sample_args_copy = tuple(torch.clone(n) for n in sample_args_detached) < sample_kwargs_copy = {k: torch.clone(v) for k, v in sample_kwargs_detached.items()} < < # Zero out input args inplace so cudagraph warmup doesnt affect grads < for orig, detach in zip(sample_args, sample_args_detached): < detach.zero_() < detach.requires_grad = orig.requires_grad < for k, detach in sample_kwargs_detached.items(): < detach.zero_() < detach.requires_grad = sample_kwargs[k].requires_grad < < fp8_enabled = megatron_module.config.fp8 is not None < fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8_enabled else None < graphed_module = make_graphed_callables( < modules=wrapped_module, < sample_args=sample_args_detached, < sample_kwargs=sample_kwargs_detached, < _order=[1, -1], < allow_unused_input=True, < fp8_enabled=fp8_enabled, < fp8_recipe=fp8_recipe, < fp8_weight_caching=True, < ) < < # Restore zeroed out sample args < # Detach again since pytorch prohibits inplace ops on leaf nodes < for orig, copy in zip(sample_args, sample_args_copy): < orig.detach().copy_(copy) < for k, orig in sample_kwargs.items(): < orig.detach().copy_(sample_kwargs_copy[k]) < < logging.getLogger(__name__).info(f'Time spent in cudagraph capture: {time.time() - start}s') < return CudaGraphRunner(graphed_module, wrapped_module) Binary files ./megatron/core/transformer/custom_layers/__pycache__/__init__.cpython-310.pyc and ../megatron-lm/megatron/core/transformer/custom_layers/__pycache__/__init__.cpython-310.pyc differ Binary files ./megatron/core/transformer/custom_layers/__pycache__/transformer_engine.cpython-310.pyc and ../megatron-lm/megatron/core/transformer/custom_layers/__pycache__/transformer_engine.cpython-310.pyc differ diff -rN ./megatron/core/transformer/custom_layers/transformer_engine.py ../megatron-lm/megatron/core/transformer/custom_layers/transformer_engine.py 3c3,6 < import warnings --- > import dataclasses > import os > from importlib.metadata import version > from typing import Callable 5,10c8,19 < warnings.warn( < """The 'megatron.core.transformer.custom_layers.transformer_engine' < module is deprecated and will be removed in 0.10.0. 
Please use < 'megatron.core.extensions.transformer_engine' instead.""", < DeprecationWarning, < stacklevel=2, --- > import torch > import transformer_engine as te > from pkg_resources import packaging > from torch import Tensor > > from megatron.core import ModelParallelConfig, parallel_state > from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding > from megatron.core.packed_seq_params import PackedSeqParams > from megatron.core.parallel_state import ( > get_context_parallel_global_ranks, > get_context_parallel_group, > get_tensor_model_parallel_group, 12c21,913 < from megatron.core.extensions.transformer_engine import * --- > from megatron.core.tensor_parallel import get_cuda_rng_tracker, get_expert_parallel_rng_tracker_name > from megatron.core.tensor_parallel.utils import divide > from megatron.core.transformer.enums import AttnMaskType > from megatron.core.transformer.transformer_config import TransformerConfig > from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint > > > def get_te_version(): > def get_te_version_str(): > if hasattr(te, '__version__'): > return str(te.__version__) > else: > return version("transformer-engine") > > return packaging.version.Version(get_te_version_str()) > > > _te_version = get_te_version() > > > def _get_extra_te_kwargs(config: TransformerConfig): > extra_transformer_engine_kwargs = {"params_dtype": config.params_dtype} > > if _te_version >= packaging.version.Version("0.12.0"): > if config.use_cpu_initialization: > extra_transformer_engine_kwargs["device"] = 'cpu' > else: > extra_transformer_engine_kwargs["device"] = torch.cuda.current_device() > return extra_transformer_engine_kwargs > > > def condition_init_method(config, init_method): > return init_method if config.perform_initialization else (lambda w: None) > > > class TENorm: > """ > A conditional wrapper to initialize an instance of Transformer-Engine's > `LayerNorm` or `RMSNorm` based on input > """ > > # TODO should we ditch normalization config and just use spec to choose LayerNorm vs RMSNorm? > def __new__(cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5): > if config.normalization == "LayerNorm": > instance = te.pytorch.LayerNorm( > hidden_size=hidden_size, > eps=eps, > sequence_parallel=config.sequence_parallel, > zero_centered_gamma=config.layernorm_zero_centered_gamma, > **_get_extra_te_kwargs(config), > ) > elif config.normalization == "RMSNorm": > assert hasattr( > te.pytorch, "RMSNorm" > ), "Transformer-Engine >= v0.11 required to use this feature" > instance = te.pytorch.RMSNorm( > hidden_size=hidden_size, > eps=eps, > sequence_parallel=config.sequence_parallel, > zero_centered_gamma=config.layernorm_zero_centered_gamma, > **_get_extra_te_kwargs(config), > ) > else: > raise Exception('Only LayerNorm and RMSNorm are curently supported') > > return instance > > > class TELinear(te.pytorch.Linear): > """ > Wrapper for the Transformer-Engine's `Linear` layer. > > Note that if Megatron's parallel_state has not been initialized > yet, the tp_group passed to TE will be None and must be set later > via set_tensor_parallel_group(). > """ > > def __init__( > self, > input_size: int, > output_size: int, > *, > parallel_mode: str, > config: ModelParallelConfig, > init_method: Callable, > bias: bool, > skip_bias_add: bool, > skip_weight_param_allocation: bool, > tp_comm_buffer_name: str = None, > ): > self.config = config > > # TE returns a zero length Tensor when bias=False and > # return_bias=True, but we prefer None. 
So in that case we > # tell TE to not return the bias, and return None > # ourselves. This way our forward always returns two values > # and we don't have to deal with the zero length Tensor. > self.te_return_bias = skip_bias_add and bias > self.is_first_microbatch = True > self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache > if skip_weight_param_allocation: > raise ValueError( > 'Transformer Engine linear layers do not support skip_weight_param_allocation' > ) > > extra_kwargs = _get_extra_te_kwargs(config) > > if _te_version >= packaging.version.Version("0.8.0"): > if self.config.tp_comm_overlap: > if _te_version > packaging.version.Version("1.5.0"): > # Use old overlap flags if they were supplied instead > extra_kwargs["ub_overlap_ag"] = ( > self.config.tp_comm_overlap_ag > if hasattr(self.config, "tp_comm_overlap_ag") > else self.config.tp_comm_split_ag or self.config.tp_comm_atomic_ag > ) > extra_kwargs["ub_overlap_rs"] = ( > self.config.tp_comm_overlap_rs > if hasattr(self.config, "tp_comm_overlap_rs") > else self.config.tp_comm_split_rs or self.config.tp_comm_atomic_rs > ) > else: > extra_kwargs["ub_split_ag"] = self.config.tp_comm_split_ag > extra_kwargs["ub_atomic_gemm_ag"] = self.config.tp_comm_atomic_ag > extra_kwargs["ub_split_rs"] = self.config.tp_comm_split_rs > extra_kwargs["ub_atomic_gemm_rs"] = self.config.tp_comm_atomic_rs > if _te_version > packaging.version.Version("1.0.0"): > assert ( > tp_comm_buffer_name is not None > ), "Buffer name should be set to configure communication overlap settings" > extra_kwargs["ub_name"] = tp_comm_buffer_name > > super().__init__( > in_features=input_size, > out_features=output_size, > sequence_parallel=self.config.sequence_parallel, > fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion, > tp_group=get_tensor_model_parallel_group(check_initialized=False), > tp_size=self.config.tensor_model_parallel_size, > get_rng_state_tracker=( > get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None > ), > init_method=condition_init_method(config, init_method), > bias=bias, > return_bias=self.te_return_bias, > parallel_mode=parallel_mode, > **extra_kwargs, > ) > > def forward(self, x): > _is_first_microbatch = ( > None if self.disable_parameter_transpose_cache else self.is_first_microbatch > ) > out = super().forward(x, is_first_microbatch=_is_first_microbatch) > self.is_first_microbatch = False > > # TE only returns a tuple when return_bias is True, otherwise > # it returns a single Tensor, we always want to return two > # values regardless of the arguments. 
> if self.te_return_bias: > return out > return out, None > > > class TELayerNormColumnParallelLinear(te.pytorch.LayerNormLinear): > """ > Wrapper for the Transformer-Engine's `LayerNormLinear` layer that combines > layernorm and linear layers > """ > > def __init__( > self, > input_size: int, > output_size: int, > *, > config: TransformerConfig, > init_method: Callable, > gather_output: bool, > bias: bool, > skip_bias_add: bool, > is_expert: bool, > skip_weight_param_allocation: bool = False, > tp_comm_buffer_name: str = None, > ): > self.config = config > > if gather_output: > raise ValueError('Transformer Engine linear layers do not support gather_output = True') > > if is_expert: > raise ValueError('Transformer Engine linear layers do not yet support MoE') > > if skip_weight_param_allocation: > raise ValueError( > 'Transformer Engine linear layers do not support skip_weight_param_allocation' > ) > > # TE returns a zero length Tensor when bias=False and > # return_bias=True, but we prefer None. So in that case we > # tell TE to not return the bias, and return None > # ourselves. This way our forward always returns two values > # and we don't have to deal with the zero length Tensor. > self.te_return_bias = skip_bias_add and bias > self.is_first_microbatch = True > self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache > extra_kwargs = _get_extra_te_kwargs(config) > > # Only Transformer-Engine version >= 0.11.0 supports `RMSNorm` > if _te_version >= packaging.version.Version("0.11.0"): > extra_kwargs["normalization"] = self.config.normalization > elif self.config.normalization != "LayerNorm": > raise ValueError( > f"Transformer Engine v{_te_version} does not support {self.config.normalization}." > ) > > if _te_version >= packaging.version.Version("0.8.0"): > if self.config.tp_comm_overlap: > extra_kwargs["ub_bulk_wgrad"] = self.config.tp_comm_bulk_wgrad > extra_kwargs["ub_bulk_dgrad"] = self.config.tp_comm_bulk_dgrad > if _te_version > packaging.version.Version("1.5.0"): > # Use old overlap flags if they were supplied instead > extra_kwargs["ub_overlap_ag"] = ( > self.config.tp_comm_overlap_ag > if hasattr(self.config, "tp_comm_overlap_ag") > else self.config.tp_comm_split_ag or self.config.tp_comm_atomic_ag > ) > if _te_version > packaging.version.Version("1.6.0.dev0"): > extra_kwargs["ub_overlap_rs_dgrad"] = ( > self.config.tp_comm_overlap_rs_dgrad > if hasattr(self.config, "tp_comm_overlap_rs_dgrad") > else False > ) > if tp_comm_buffer_name == 'qkv' and self.config.tp_comm_overlap_disable_qkv: > extra_kwargs["ub_overlap_ag"] = False > extra_kwargs["ub_overlap_rs_dgrad"] = False > > if tp_comm_buffer_name == 'fc1' and self.config.tp_comm_overlap_disable_fc1: > extra_kwargs["ub_overlap_ag"] = False > extra_kwargs["ub_overlap_rs_dgrad"] = False > else: > extra_kwargs["ub_atomic_gemm_ag"] = self.config.tp_comm_atomic_ag > extra_kwargs["ub_split_ag"] = self.config.tp_comm_split_ag > if _te_version > packaging.version.Version("1.0.0"): > assert ( > tp_comm_buffer_name is not None > ), "Buffer name should be set to configure communication overlap settings" > extra_kwargs["ub_name"] = tp_comm_buffer_name > > super().__init__( > in_features=input_size, > out_features=output_size, > eps=self.config.layernorm_epsilon, > sequence_parallel=self.config.sequence_parallel, > fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion, > tp_group=get_tensor_model_parallel_group(check_initialized=False), > tp_size=self.config.tensor_model_parallel_size, > 
get_rng_state_tracker=( > get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None > ), > init_method=condition_init_method(config, init_method), > bias=bias, > return_bias=self.te_return_bias, > parallel_mode="column", > return_layernorm_output=False, > zero_centered_gamma=self.config.layernorm_zero_centered_gamma, > **extra_kwargs, > ) > > def forward(self, x): > _is_first_microbatch = ( > None if self.disable_parameter_transpose_cache else self.is_first_microbatch > ) > out = super().forward(x, is_first_microbatch=_is_first_microbatch) > self.is_first_microbatch = False > > # TE only returns a tuple when return_bias is True, otherwise > # it returns a single Tensor, we always want to return two > # values regardless of the arguments. > if self.te_return_bias: > return out > return out, None > > def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): > """Sharding along axis 0, bias sharded""" > state_dict = self.state_dict(prefix='', keep_vars=True) > return make_sharded_tensors_for_checkpoint( > state_dict, prefix, {'weight': 0, 'bias': 0}, sharded_offsets > ) > > > class TEColumnParallelLinear(TELinear): > """ > Wrapper for the Transformer-Engine's `Linear` layer but specialized similar > to megatron's `ColumnParallelLinear` layer. > """ > > def __init__( > self, > input_size: int, > output_size: int, > *, > config: ModelParallelConfig, > init_method: Callable, > gather_output: bool, > bias: bool, > skip_bias_add: bool, > is_expert: bool, > skip_weight_param_allocation: bool = False, > tp_comm_buffer_name: str = None, > ): > if gather_output: > raise ValueError('Transformer Engine linear layers do not support gather_output = True') > > if is_expert: > raise ValueError('Transformer Engine linear layers do not yet support MoE') > > super().__init__( > input_size=input_size, > output_size=output_size, > parallel_mode="column", > config=config, > init_method=condition_init_method(config, init_method), > bias=bias, > skip_bias_add=skip_bias_add, > skip_weight_param_allocation=skip_weight_param_allocation, > tp_comm_buffer_name=tp_comm_buffer_name, > ) > > def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): > """Sharding along axis 0, bias sharded""" > state_dict = self.state_dict(prefix='', keep_vars=True) > return make_sharded_tensors_for_checkpoint( > state_dict, prefix, {'weight': 0, 'bias': 0}, sharded_offsets > ) > > > class TERowParallelLinear(TELinear): > """ > Wrapper for the Transformer-Engine's `Linear` layer but specialized similar > to megatron's `RowParallelLinear` layer. 
> """ > > def __init__( > self, > input_size: int, > output_size: int, > *, > config: ModelParallelConfig, > init_method: Callable, > bias: bool, > input_is_parallel: bool, > skip_bias_add: bool, > is_expert: bool, > tp_comm_buffer_name: str = None, > ): > if not input_is_parallel: > raise ValueError( > "Transformer Engine linear layers do not support input_is_parallel = False" > ) > > if is_expert: > raise ValueError('Transformer Engine linear layers do not yet support MoE') > > super().__init__( > input_size=input_size, > output_size=output_size, > parallel_mode="row", > config=config, > init_method=condition_init_method(config, init_method), > bias=bias, > skip_bias_add=skip_bias_add, > skip_weight_param_allocation=False, # We don't currently use this for row parallel layers # pylint: disable=line-too-long > tp_comm_buffer_name=tp_comm_buffer_name, > ) > > def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): > """Sharding along axis 1, bias not sharded""" > state_dict = self.state_dict(prefix='', keep_vars=True) > return make_sharded_tensors_for_checkpoint( > state_dict, prefix, {'weight': 1}, sharded_offsets > ) > > > class TEDotProductAttention(te.pytorch.DotProductAttention): > """ > Wrapper for the Transformer-Engine's `DotProductAttention` layer that also > has "flash attention" enabled. > > Note that if Megatron's parallel_state has not been initialized yet, the > tp_group and cp_group passed to TE will be None and must be set later > via set_tensor_parallel_group() and set_context_parallel_group(). > """ > > cp_stream: torch.cuda.Stream = None > > def __init__( > self, > config: TransformerConfig, > layer_number: int, > attn_mask_type: AttnMaskType, > attention_type: str, > attention_dropout: float = None, > ): > self.config = config > self.te_forward_mask_type = False > self.qkv_format: str = 'sbhd' > > if self.config.apply_query_key_layer_scaling != bool( > int(os.getenv('NVTE_APPLY_QK_LAYER_SCALING', '0')) > ): > raise ValueError( > f"apply_query_key_layer_scaling is {self.config.apply_query_key_layer_scaling} " > f"but environment variable NVTE_APPLY_QK_LAYER_SCALING is " > f"{os.getenv('NVTE_APPLY_QK_LAYER_SCALING')}. Transformer Engine does not support " > f"setting query key layer scaling via argument, so these two must match." > ) > > extra_kwargs = {} > if _te_version >= packaging.version.Version("0.11.0"): > extra_kwargs["num_gqa_groups"] = self.config.num_query_groups > elif self.config.num_query_groups != self.config.num_attention_heads: > raise ValueError( > f"Transformer Engine v{_te_version} does not support Grouped Query Attention, " > f"use a newer version of Transformer Engine. 
" > f"(num_query_groups ({self.config.num_query_groups}) != " > f"num_attention_heads ({self.config.num_attention_heads}))" > ) > > if _te_version >= packaging.version.Version("0.10.0"): > extra_kwargs["attention_type"] = attention_type > # older version don't need attention_type > > if _te_version > packaging.version.Version("0.12.0"): > self.te_forward_mask_type = True > > # Only Transformer-Engine version >= 1.0.0 supports context parallelism > if _te_version >= packaging.version.Version("1.0.0"): > if getattr(TEDotProductAttention, "cp_stream") is None: > TEDotProductAttention.cp_stream = torch.cuda.Stream() > extra_kwargs["cp_group"] = get_context_parallel_group(check_initialized=False) > extra_kwargs["cp_global_ranks"] = get_context_parallel_global_ranks( > check_initialized=False > ) > extra_kwargs["cp_stream"] = TEDotProductAttention.cp_stream > else: > assert ( > self.config.context_parallel_size == 1 > ), "Only Transformer-Engine version >= 1.0.0 supports context parallelism!" > > if self.config.deterministic_mode: > if int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")) != 0: > raise RuntimeError( > "deterministic_mode is on and we are using DotProductAttention from " > "Transformer Engine, but NVTE_ALLOW_NONDETERMINISTIC_ALGO is not 0. " > f"Currently set to: {os.getenv('NVTE_ALLOW_NONDETERMINISTIC_ALGO', 'not set')}." > ) > > if config.window_size is not None: > # Check version > assert _te_version >= packaging.version.Version("1.2.0"), ( > f"Transformer-Engine version ({str(_te_version)}) must be >= 1.2.0 to support" > "sliding window attention." > ) > extra_kwargs['window_size'] = config.window_size > > super().__init__( > num_attention_heads=self.config.num_attention_heads, > kv_channels=self.config.kv_channels, > attention_dropout=( > self.config.attention_dropout if attention_dropout is None else attention_dropout > ), > attn_mask_type=attn_mask_type.name, > sequence_parallel=self.config.sequence_parallel, > tp_size=self.config.tensor_model_parallel_size, > get_rng_state_tracker=( > get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None > ), > tp_group=get_tensor_model_parallel_group(check_initialized=False), > layer_number=layer_number, > **extra_kwargs, > ) > > def forward( > self, > query: Tensor, > key: Tensor, > value: Tensor, > attention_mask: Tensor, > attn_mask_type: AttnMaskType, > packed_seq_params: PackedSeqParams = None, > ): > packed_seq_kwargs = ( > dataclasses.asdict(packed_seq_params) if packed_seq_params is not None else {} > ) > # overwrite self.qkv_format depending on self.config.apply_rope_fusion, which can be set > # after init > if self.config.apply_rope_fusion and _te_version > packaging.version.Version("0.13.0"): > self.qkv_format = 'bshd' > > qkv_format = packed_seq_kwargs.get('qkv_format', self.qkv_format) > > if _te_version < packaging.version.Version("1.3.0"): > # TE 1.3.0 introduces precomputing max_seqlen to remove unnecessary kernels and D2H > # copies (#555) > # These two arguments did not exist prior to 1.3.0 > packed_seq_kwargs.pop("max_seqlen_q", None) > packed_seq_kwargs.pop("max_seqlen_kv", None) > > if self.config.apply_rope_fusion and qkv_format == 'bshd': > query, key, value = [x.transpose(0, 1).contiguous() for x in (query, key, value)] > # In PyTorch, the following two tensors are in fact the same: > # Tensor with shape (1, S, H, D) and stride (S*H*D, H*D, D, 1) > # Tensor with shape (1, S, H, D) and stride (H*D, H*D, D, 1) > # Stride for a dimension that is 1 has no meaning, so tensors created two different 
ways > # can have same shape but different strides. > # We unify them to the first one to pass the stride check in TE > if value.shape == key.shape and value.shape[0] == 1 and value.stride() != key.stride(): > value = value.as_strided(value.shape, key.stride()) > > if self.te_forward_mask_type: > if qkv_format == 'thd' and _te_version >= packaging.version.Version("1.7.0"): > # thd format uses flash attention with cuDNN kernel which requires is_padding=True, > # so the only acceptable mask types are `padding_causal` and `padding`. These do not > # necessarily indicate there are padded tokens in the sequence. > if attn_mask_type == AttnMaskType.causal: > attn_mask_type = AttnMaskType.padding_causal > elif attn_mask_type == AttnMaskType.no_mask: > attn_mask_type = AttnMaskType.padding > core_attn_out = super().forward( > query, > key, > value, > attention_mask, > attn_mask_type=attn_mask_type.name, > **packed_seq_kwargs, > ) > else: > core_attn_out = super().forward(query, key, value, attention_mask, **packed_seq_kwargs) > > if self.config.apply_rope_fusion and qkv_format == 'bshd': > return core_attn_out.transpose(0, 1) > else: > return core_attn_out > > > if _te_version >= packaging.version.Version("1.9.0.dev0"): > > class TEGroupedLinear(te.pytorch.GroupedLinear): > """ > Wrapper for the Transformer-Engine's `GroupedLinear` layer. > > Note that if Megatron's parallel_state has not been initialized > yet, the tp_group passed to TE will be None and must be set later > via set_tensor_parallel_group(). > """ > > def __init__( > self, > num_gemms: int, > input_size: int, > output_size: int, > *, > parallel_mode: str, > config: ModelParallelConfig, > init_method: Callable, > bias: bool, > skip_bias_add: bool, > is_expert: bool = False, > tp_comm_buffer_name: str = None, > ): > self.config = config > > # TE returns a zero length Tensor when bias=False and > # return_bias=True, but we prefer None. So in that case we > # tell TE to not return the bias, and return None > # ourselves. This way our forward always returns two values > # and we don't have to deal with the zero length Tensor. > self.te_return_bias = skip_bias_add and bias > self.is_first_microbatch = True > self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache > > extra_kwargs = _get_extra_te_kwargs(config) > extra_kwargs["ub_name"] = tp_comm_buffer_name > > self.expert_parallel = self.config.expert_model_parallel_size > 1 > if self.expert_parallel: > extra_kwargs["rng_tracker_name"] = get_expert_parallel_rng_tracker_name() > > # For MoE models, the comms between TP and EP group is explicitly handled by > # MoE token dispatcher. So we disable comms by making TE agnostic of model parallel. 
> self.explicit_expert_comm = is_expert and ( > config.tensor_model_parallel_size > 1 or self.expert_parallel > ) > tp_group = get_tensor_model_parallel_group(check_initialized=False) > if self.explicit_expert_comm and config.moe_extended_tp: > tp_size = parallel_state.get_tensor_and_expert_parallel_world_size() > else: > tp_size = parallel_state.get_tensor_model_parallel_world_size() > if self.explicit_expert_comm: > if parallel_mode == "column": > output_size = divide(output_size, tp_size) > elif parallel_mode == "row": > input_size = divide(input_size, tp_size) > parallel_mode = None > tp_size = 1 > tp_group = None > > super().__init__( > num_gemms=num_gemms, > in_features=input_size, > out_features=output_size, > sequence_parallel=self.config.sequence_parallel, > fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion, > tp_group=tp_group, > tp_size=tp_size, > get_rng_state_tracker=( > get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None > ), > init_method=condition_init_method(config, init_method), > bias=bias, > return_bias=self.te_return_bias, > parallel_mode=parallel_mode, > **extra_kwargs, > ) > > for param in self.parameters(): > setattr(param, 'allreduce', not (is_expert and self.expert_parallel)) > > def forward(self, x, m_splits): > _is_first_microbatch = ( > None if self.disable_parameter_transpose_cache else self.is_first_microbatch > ) > out = super().forward(x, m_splits, is_first_microbatch=_is_first_microbatch) > self.is_first_microbatch = False > > # TE only returns a tuple when return_bias is True, otherwise > # it returns a single Tensor, we always want to return two > # values regardless of the arguments. > if self.te_return_bias: > return out > return out, None > > def _sharded_state_dict_grouped( > self, tp_axis_map, prefix='', sharded_offsets=(), metadata=None > ): > """ > prefix should be module_name to make keys identical to sequetial ones. > """ > sharded_state_dict = {} > full_state_dict = self.state_dict(prefix='', keep_vars=True) > num_global_experts = ( > parallel_state.get_expert_model_parallel_world_size() * self.num_gemms > ) > local_expert_indices_offset = ( > parallel_state.get_expert_model_parallel_rank() * self.num_gemms > ) > ep_axis = len(sharded_offsets) > for gemm_idx in range(self.num_gemms): > state_dict = { > f'{gemm_idx}.weight': full_state_dict[f'weight{gemm_idx}'], > f'{gemm_idx}._extra_state': full_state_dict['_extra_state'], > } > if self.use_bias: > state_dict[f'{gemm_idx}.bias'] = full_state_dict[f'bias{gemm_idx}'] > sub_sd = make_sharded_tensors_for_checkpoint( > state_dict, > '', > tp_axis_map, > ( > *sharded_offsets, > (ep_axis, local_expert_indices_offset + gemm_idx, num_global_experts), > ), > ) > # Remove expert layers indexing from sharded keys > replace_prefix_for_sharding(sub_sd, f'{gemm_idx}.', prefix) > sharded_state_dict.update( > { > f'{prefix}weight{gemm_idx}': sub_sd[f'{gemm_idx}.weight'], > # TODO: TE's GroupedLinear only has one _extra_state for all experts. > # We need sharding or build/merge fn to handle _extra_state correctly. 
> f'{prefix}_extra_state{"" if gemm_idx == 0 else gemm_idx}': sub_sd[ > f'{gemm_idx}._extra_state' > ], > } > ) > if self.use_bias: > sharded_state_dict[f'{prefix}bias{gemm_idx}'] = sub_sd[f'{gemm_idx}.bias'] > # Adjust replica ids - replication along DP modulo EP > for k, sh_ten in sharded_state_dict.items(): > replica_id = sh_ten.replica_id > assert ( > len(replica_id) == 3 > ), f'Expected replica_id for {k} to be in (PP, TP, DP) format, got: {replica_id}' > sh_ten.replica_id = ( > *replica_id[:2], > parallel_state.get_data_modulo_expert_parallel_rank(), > ) > return sharded_state_dict > > class TEColumnParallelGroupedLinear(TEGroupedLinear): > """ > Wrapper for the Transformer-Engine's `GroupedLinear` layer but specialized > to column-parallel style. > """ > > def __init__( > self, > num_gemms: int, > input_size: int, > output_size: int, > *, > config: ModelParallelConfig, > init_method: Callable, > bias: bool, > skip_bias_add: bool, > is_expert: bool, > tp_comm_buffer_name: str = None, > ): > > super().__init__( > num_gemms=num_gemms, > input_size=input_size, > output_size=output_size, > parallel_mode="column", > config=config, > init_method=condition_init_method(config, init_method), > bias=bias, > skip_bias_add=skip_bias_add, > is_expert=is_expert, > tp_comm_buffer_name=tp_comm_buffer_name, > ) > > def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): > """ > For each gemm, sharding along axis 0, bias sharded. > Assume sharded_offsets[-1] is the expert parallel offset. > """ > tp_axis_map = {} > for gemm_idx in range(self.num_gemms): > tp_axis_map.update({f'{gemm_idx}.weight': 0, f'{gemm_idx}.bias': 0}) > return super()._sharded_state_dict_grouped( > tp_axis_map, prefix, sharded_offsets, metadata > ) > > class TERowParallelGroupedLinear(TEGroupedLinear): > """ > Wrapper for the Transformer-Engine's `GroupedLinear` layer but specialized > to row-parallel style. > """ > > def __init__( > self, > num_gemms: int, > input_size: int, > output_size: int, > *, > config: ModelParallelConfig, > init_method: Callable, > bias: bool, > skip_bias_add: bool, > is_expert: bool, > tp_comm_buffer_name: str = None, > ): > > super().__init__( > num_gemms=num_gemms, > input_size=input_size, > output_size=output_size, > parallel_mode="row", > config=config, > init_method=condition_init_method(config, init_method), > bias=bias, > skip_bias_add=skip_bias_add, > is_expert=is_expert, > tp_comm_buffer_name=tp_comm_buffer_name, > ) > > def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): > """ > For each gemm, sharding along axis 1, bias not sharded. > Assume sharded_offsets[-1] is the expert parallel offset. > """ > tp_axis_map = {f'{gemm_idx}.weight': 1 for gemm_idx in range(self.num_gemms)} > return super()._sharded_state_dict_grouped( > tp_axis_map, prefix, sharded_offsets, metadata > ) > > else: > > TEGroupedLinear = None > TEColumnParallelGroupedLinear = None > TERowParallelGroupedLinear = None > > > class TEDelayedScaling(te.common.recipe.DelayedScaling): > """ > Wrapper for the Transformer-Engine's `DelayedScaling` layer. 
> """ > > def __init__( > self, > config: ModelParallelConfig, > fp8_format: int, > override_linear_precision: tuple = (False, False, False), > ): > extra_kwargs = _get_extra_te_kwargs(config) > if _te_version >= packaging.version.Version("1.6.0.dev0"): > extra_kwargs["fp8_dpa"] = config.fp8_dot_product_attention > extra_kwargs["fp8_mha"] = config.fp8_multi_head_attention > > super().__init__( > margin=config.fp8_margin, > interval=config.fp8_interval, > fp8_format=fp8_format, > amax_compute_algo=config.fp8_amax_compute_algo, > amax_history_len=config.fp8_amax_history_len, > override_linear_precision=override_linear_precision, > **extra_kwargs, > ) > > > def te_checkpoint( > forward_func, > distribute_saved_activations, > get_rng_state_tracker, > tp_group, > hidden_states, > attention_mask, > context, > context_mask, > rotary_pos_emb, > ): > from transformer_engine.pytorch.distributed import checkpoint > > if _te_version >= packaging.version.Version("1.5.0"): > return checkpoint( > forward_func, > hidden_states, > attention_mask, > context, > context_mask, > rotary_pos_emb, > distribute_saved_activations=distribute_saved_activations, > get_rng_state_tracker=get_rng_state_tracker, > tp_group=tp_group, > ) > else: > return checkpoint( > forward_func, > distribute_saved_activations, > get_rng_state_tracker, > tp_group, > hidden_states, > attention_mask, > context, > context_mask, > rotary_pos_emb, > ) > > > try: > > from transformer_engine.pytorch.attention import _SplitAlongDim > > SplitAlongDim = _SplitAlongDim.apply > > except ImportError: > > SplitAlongDim = None > > try: > > from transformer_engine.pytorch.cpu_offload import ( > get_cpu_offload_context as _get_cpu_offload_context, > ) > > def get_cpu_offload_context( > enabled, num_layers, model_layers, activation_offloading, weight_offloading > ): > if _te_version > packaging.version.Version("1.8.0"): > context, sync_func = _get_cpu_offload_context( > enabled, num_layers, model_layers, activation_offloading, weight_offloading > ) > else: > context, sync_func = _get_cpu_offload_context( > enabled, num_layers, activation_offloading, weight_offloading > ) > > return context, sync_func > > except ImportError: > > get_cpu_offload_context = None diff -rN ./megatron/core/transformer/module.py ../megatron-lm/megatron/core/transformer/module.py 91,100c91,94 < """Sets the is_first_microbatch flag if it exists and config.fp8==True. < When this flag is set, TE modules will update their fp8 parameter cache. < """ < if self.config.fp8 is not None: < if not hasattr(self, "modules_with_is_first_microbatch"): < self.modules_with_is_first_microbatch = [] < for m in self.modules(): < if hasattr(m, "is_first_microbatch"): < self.modules_with_is_first_microbatch.append(m) < for m in self.modules_with_is_first_microbatch: --- > """Sets the is_first_microbatch flag if it exists. When this flag is set, TE modules will > update their fp8 parameter cache.""" > for m in self.modules(): > if hasattr(m, "is_first_microbatch"): diff -rN ./megatron/core/transformer/moe/experts.py ../megatron-lm/megatron/core/transformer/moe/experts.py 5d4 < from math import ceil 38c37 < """An efficient implementation of the Experts layer using GroupedGEMM. --- > """An efficient implementation of the Experts layer using CUTLASS GroupedGEMM. 40c39 < Executes multiple experts in parallel to maximize computational efficiency. --- > This class is designed to execute multiple experts in parallel, thereby maximizing computational efficiency. 
50c49 < ), "bias not supported in Grouped GEMM yet, please set '--disable-bias-linear' instead." --- > ), "bias in the expert layer is not supported in Grouped GEMM yet, please set '--disable-bias-linear' instead." 165,166c164 < def forward(self, permuted_local_hidden_states: torch.Tensor, tokens_per_expert: torch.Tensor): < """Forward step of the GroupedMLP.""" --- > def forward(self, permuted_local_hidden_states, tokens_per_expert): 183c181 < # Make sure params of experts still have gradients even given zero tokens. --- > # Make sure parameters still have gradients when no tokens are routed to this set of experts. 348c346 < Executes multiple experts in parallel to maximize computational efficiency. --- > This class is designed to execute multiple experts in parallel, thereby maximizing computational efficiency. 357c355 < # Double the output width with gated linear unit, see https://arxiv.org/pdf/2002.05202.pdf --- > # If this is a gated linear unit we double the output width, see https://arxiv.org/pdf/2002.05202.pdf 504,532c502 < def _pad_tensor_for_fp8(self, hidden): < """Padding tensor shape to multiples of 16.""" < actual_num_tokens = hidden.shape[0] < divisor = 16 < padded_num_tokens = ceil(actual_num_tokens / divisor) * divisor - actual_num_tokens < if padded_num_tokens > 0: < pad_tensor = torch.zeros( < padded_num_tokens, hidden.shape[1], dtype=hidden.dtype, device=hidden.device < ) < hidden = torch.cat((hidden, pad_tensor), dim=0) < return hidden < < def forward(self, permuted_local_hidden_states: torch.Tensor, tokens_per_expert: torch.Tensor): < """Forward step of the SequentialMLP.""" < if self.num_local_experts == 1: < if self.config.fp8: < hidden = self._pad_tensor_for_fp8(permuted_local_hidden_states) < output, output_bias = self.local_experts[0](hidden) < output = output[: permuted_local_hidden_states.shape[0]] < else: < output, output_bias = self.local_experts[0](permuted_local_hidden_states) < < return output, output_bias < else: < tokens_per_expert = tokens_per_expert.tolist() < tokens_list = torch.split(permuted_local_hidden_states, tokens_per_expert) < < output_local_list = [] < output_bias_list = [] --- > def forward(self, permuted_local_hidden_states, tokens_per_expert): 534,543c504,517 < for expert, tokens in zip(self.local_experts, tokens_list): < if self.config.fp8: < hidden = self._pad_tensor_for_fp8(tokens) < output, output_bias = expert(hidden) < output = output[: tokens.shape[0]] < else: < output, output_bias = expert(tokens) < output_local_list.append(output) < if self.add_bias: < output_bias_list.append(output_bias.expand_as(output)) --- > output_local = torch.zeros_like(permuted_local_hidden_states) > output_bias_local = None > if self.add_bias: > output_bias_local = torch.zeros_like(permuted_local_hidden_states) > > cumsum_num_tokens = torch.cumsum(tokens_per_expert, dim=0) > # Insert zero at the begining for offset index's convenience > zero_tensor = torch.zeros(1, dtype=torch.long, device=cumsum_num_tokens.device) > cumsum_num_tokens = torch.cat((zero_tensor, cumsum_num_tokens)) > for expert_num, expert in enumerate(self.local_experts): > start = cumsum_num_tokens[expert_num] > end = cumsum_num_tokens[expert_num + 1] > hidden = permuted_local_hidden_states[start:end] > output, output_bias = expert(hidden) 545c519 < output_local = torch.cat(output_local_list, dim=0) --- > output_local[start:end] = output 547,549c521,522 < output_bias_local = torch.cat(output_bias_list, dim=0) < else: < output_bias_local = None --- > output_bias = 
output_bias.expand_as(output) > output_bias_local[start:end, :] = output_bias 551c524 < return output_local, output_bias_local --- > return output_local, output_bias_local diff -rN ./megatron/core/transformer/moe/moe_utils.py ../megatron-lm/megatron/core/transformer/moe/moe_utils.py 21,24c21,22 < probs (torch.Tensor): Softmax probabilities output by the router for each token. < Shape in [num_tokens, num_experts]. < tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert. < Shape in [num_experts] --- > probs (torch.Tensor): Softmax probabilities output by the router for each token. [num_tokens, num_experts] > tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert. [num_experts] 27,29c25 < sequence_partition_group (optional): The parallel group over which the sequence is < partitioned. If None, no partitioning is applied. < Defaults to None. --- > sequence_partition_group (optional): The parallel group over which the sequence is partitioned. If None, no partitioning is applied. Defaults to None. 36,38c32 < # If the sequence is partitioned by certain parallelism strategies like Sequence Parallelism < # or Context Parallelism, compute the gradient of the auxiliary loss with respect to the full < # sequence. --- > # If the sequence is partitioned by certain parallelism strategies like Sequence Parallelism or Context Parallelism, compute the gradient of the auxiliary loss with respect to the full sequence. 40,41c34 < # We can keep `aggregated_probs_per_expert` local since we don't need the gradient for < # `tokens_per_expert`, saving one allreduce operation for `aggregated_probs_per_expert`. --- > # We can keep `aggregated_probs_per_expert` local since we don't need the gradient for `tokens_per_expert`, saving one allreduce operation for `aggregated_probs_per_expert`. 48,49c41 < # The formula of aux_loss: aux_loss = sum((probs_per_expert/num_tokens) * < # (tokens_per_expert/(num_tokens*topk))) * num_experts * moe_aux_loss_coeff. --- > # The formula of aux_loss: aux_loss = sum((probs_per_expert/num_tokens) * (tokens_per_expert/(num_tokens*topk))) * num_experts * moe_aux_loss_coeff. 136,137c128 < Tuple[torch.Tensor, torch.Tensor]: The gradient of the output, scaled auxiliary loss < gradient. --- > Tuple[torch.Tensor, torch.Tensor]: The gradient of the output, scaled auxiliary loss gradient. 149,150c140 < scale (torch.Tensor): The scale value to set. Please ensure that the scale passed in < matches the scale of the main_loss. --- > scale (torch.Tensor): The scale value to set. Please ensure that the scale passed in matches the scale of the main_loss. 157,158c147 < The input indices shape is [tokens, top_k], it indicates which experts were selected by each < token separately. --- > The input indices shape is [tokens, top_k], it indicates which experts were selected by each token separately. 161,169c150,152 < indices (torch.Tensor): The token to expert indices tensor, should have a shape of < [num_tokens] or [num_tokens, topk]. < num_out_tokens (int, optional): The effective output token count, when enabling the < capacity factor, should equal the number of tokens not < dropped. By default, set to None, meaning no tokens are < dropped. < padded_mode (bool, optional): If True, indicating the indices are padded to < [num_expert, capacity] to denote selected tokens per expert. < Defaults to False. --- > indices (torch.Tensor): The token to expert indices tensor, should have a shape of [num_tokens] or [num_tokens, topk]. 
> num_out_tokens (int, optional): The effective output token count, when enabling the capacity factor, should equal the number of tokens not dropped. By default, set to None, meaning no tokens are dropped. > padded_mode (bool, optional): If True, indicating the indices are padded to [num_expert, capacity] to denote selected tokens per expert. Defaults to False. 179,181c162,164 < indices = indices.unsqueeze(1) < < topk = indices.size(1) --- > topk = 1 > else: > topk = indices.size(1) 186,188c169 < moe_gather_indices = (sorted_indices // topk).unsqueeze(1).expand(-1, tokens.size(-1)) < permuted_tokens = moe_gather.apply(tokens, moe_gather_indices) < --- > permuted_tokens = tokens.index_select(0, sorted_indices // topk) 199,200c180 < """Unpermute a tensor of permuted tokens based on sorted indices, and optionally merge the < tokens with their corresponding probabilities. --- > """Unpermute a tensor of permuted tokens based on sorted indices, and optionally merge the tokens with their corresponding probabilities. 203,215c183,187 < permuted_tokens (torch.Tensor): 2D tensor [num_tokens*topk, hidden]. The tensor of permuted < tokens to be unpermuted. < sorted_indices (torch.Tensor): 1D tensor [num_tokens*topk]. The tensor of sorted indices < used to unpermute the tokens. < probs (torch.Tensor, optional): 2D tensor [num_tokens, topk]. The tensor of probabilities < corresponding to the permuted tokens. If provided, < the unpermuted tokens will be merged with their respective < probabilities. < padded_mode (bool, optional): If True, indicating the indices are padded to < [num_expert, capacity] to denote selected tokens per expert. < Defaults to False. < restore_shape (torch.Size, optional): The input shape before permutation, only used in < padding mode. Defaults to None. --- > permuted_tokens (torch.Tensor): The tensor of permuted tokens to be unpermuted. > sorted_indices (torch.Tensor): The tensor of sorted indices used to unpermute the tokens. > probs (torch.Tensor, optional): The tensor of probabilities corresponding to the permuted tokens. If provided, the unpermuted tokens will be merged with their respective probabilities. > padded_mode (bool, optional): If True, indicating the indices are padded to [num_expert, capacity] to denote selected tokens per expert. Defaults to False. > restore_shape (torch.Size, optional): The input shape before permutation, only used in padding mode. Defaults to None. 231d202 < assert probs.dim() == 2, f"Expected 2D tensor for probs, got {probs.dim()} dims." 238,240c209,214 < output_size = [num_unpermuted_tokens, permuted_tokens.shape[-1]] < moe_scatter_indices = sorted_indices.unsqueeze(1).expand(-1, permuted_tokens.size(-1)) < unpermuted_tokens = moe_scatter.apply(permuted_tokens, moe_scatter_indices, output_size) --- > unpermuted_tokens = torch.zeros( > [num_unpermuted_tokens, permuted_tokens.shape[-1]], > dtype=permuted_tokens.dtype, > device=permuted_tokens.device, > ) > unpermuted_tokens.index_copy_(0, sorted_indices, permuted_tokens) 251,252c225 < The input indices shape is [num_expert, capacity], it indicates which tokens were selected < by each expert separately. --- > The input indices shape is [num_expert, capacity], it indicates which tokens were selected by each expert separately. 255,256c228 < indices (torch.Tensor): A tensor with shape [num_expert, capacity], indicating the selected < tokens for each expert. --- > indices (torch.Tensor): A tensor with shape [num_expert, capacity], indicating the selected tokens for each expert. 
274,275c246 < Unpermutes a padded permuted tokens based on sorted indices and merges the tokens with their < corresponding probabilities. --- > Unpermutes a padded permuted tokens based on sorted indices and merges the tokens with their corresponding probabilities. 277,278c248 < This function takes a tensor of permuted tokens and reorders them according to the provided < indices. It also combines the tokens with their associated probabilities. --- > This function takes a tensor of permuted tokens and reorders them according to the provided indices. It also combines the tokens with their associated probabilities. 282,285c252,253 < indices (torch.Tensor): A tensor with shape [num_expert, capacity], indicating the selected < tokens for each expert. < probs (torch.Tensor): A tensor with the same shape as indices, containing probabilities < corresponding to each token. --- > indices (torch.Tensor): A tensor with shape [num_expert, capacity], indicating the selected tokens for each expert. > probs (torch.Tensor): A tensor with the same shape as indices, containing probabilities corresponding to each token. 330d297 < deterministic_mode: bool = False, 336,337c303 < capacity_factor (int): The capacity factor of each expert. Will drop tokens if the number < of tokens exceeds the capacity. --- > capacity_factor (int): The capacity factor of each expert. Will drop tokens if the number of tokens exceeds the capacity. 339,341c305 < drop_policy (str): The policy to drop tokens. Can be either "prob" or "position". < If "prob", the tokens with the lowest probabilities will be dropped. < If "position", tokens at the end of each batch will be dropped. --- > drop_policy (str): The policy to drop tokens. Can be either "prob" or "position". If "prob", the tokens with the lowest probabilities will be dropped. If "position", tokens at the end of each batch will be dropped. 344,345c308 < Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Probs, indices and tokens_per_expert < tensor. --- > Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Probs, indices and tokens_per_expert tensor. 347,350c310,311 < (1) If there's no token padding, the shape of probs and indices is [tokens, top_k], < indicating the selected experts for each token. < (2) If there's token padding, the shape of probs and indices is [num_expert, capacity], < indicating the tokens selected for each expert. --- > (1) If there's no token padding, the shape of probs and indices is [tokens, top_k], indicating the selected experts for each token. > (2) If there's token padding, the shape of probs and indices is [num_expert, capacity], indicating the tokens selected for each expert. 362,363c323 < # Requires applying softmax before selecting the top-k when k is 1, < # since softmax on a [num_tokens, 1] would yield a zero gradient. --- > # Requires applying softmax before selecting the top-k when k is 1, since softmax on a [num_tokens, 1] would yield a zero gradient. 
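The permute/unpermute docstrings above describe the token shuffle that groups each routed token copy by its target expert and later restores the original order, optionally merging the top-k copies with their routing probabilities. The following is a minimal sketch of that mapping using plain index_select / index_copy_, as on the right-hand side of the diff; the function names are illustrative, and the moe_gather / moe_scatter autograd ops on the left-hand side realize the same index arithmetic with custom backward passes.

import torch

def permute_tokens(tokens, indices):
    # tokens: [num_tokens, hidden]; indices: [num_tokens, topk] expert ids.
    # Replicate each token topk times, grouped by expert id, and return the
    # argsort needed to undo the shuffle.
    topk = indices.size(1)
    sorted_indices = torch.argsort(indices.view(-1))
    permuted = tokens.index_select(0, sorted_indices // topk)
    return permuted, sorted_indices

def unpermute_tokens(permuted, sorted_indices, probs=None):
    # Inverse of permute_tokens; with probs ([num_tokens, topk]) the topk
    # copies of each token are merged as a probability-weighted sum.
    rows = probs.numel() if probs is not None else permuted.size(0)
    out = torch.zeros(rows, permuted.size(-1), dtype=permuted.dtype, device=permuted.device)
    out.index_copy_(0, sorted_indices, permuted)
    if probs is not None:
        out = (out.view(-1, probs.size(1), permuted.size(-1)) * probs.unsqueeze(-1)).sum(dim=1)
    return out

# Tiny round-trip check: 4 tokens, hidden size 2, top-1 routing over 3 experts.
tokens = torch.arange(8, dtype=torch.float32).view(4, 2)
indices = torch.tensor([[2], [0], [1], [0]])
permuted, order = permute_tokens(tokens, indices)
assert torch.equal(unpermute_tokens(permuted, order), tokens)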
370,373c330 < if deterministic_mode: < tokens_per_expert = torch.bincount(top_indices.view(-1), minlength=num_experts) < else: < tokens_per_expert = torch.histc(top_indices, bins=num_experts, min=0, max=num_experts) --- > tokens_per_expert = torch.bincount(top_indices.view(-1), minlength=num_experts) 546c503,505 < output = torch.zeros(output_size, dtype=input_.dtype, device=input_.device) --- > output = torch.zeros( > output_size, dtype=input_.dtype, device=torch.cuda.current_device() > ) Binary files ./megatron/core/transformer/moe/__pycache__/experts.cpython-310.pyc and ../megatron-lm/megatron/core/transformer/moe/__pycache__/experts.cpython-310.pyc differ Binary files ./megatron/core/transformer/moe/__pycache__/grouped_gemm_util.cpython-310.pyc and ../megatron-lm/megatron/core/transformer/moe/__pycache__/grouped_gemm_util.cpython-310.pyc differ Binary files ./megatron/core/transformer/moe/__pycache__/__init__.cpython-310.pyc and ../megatron-lm/megatron/core/transformer/moe/__pycache__/__init__.cpython-310.pyc differ Binary files ./megatron/core/transformer/moe/__pycache__/legacy_a2a_token_dispatcher.cpython-310.pyc and ../megatron-lm/megatron/core/transformer/moe/__pycache__/legacy_a2a_token_dispatcher.cpython-310.pyc differ Binary files ./megatron/core/transformer/moe/__pycache__/moe_layer.cpython-310.pyc and ../megatron-lm/megatron/core/transformer/moe/__pycache__/moe_layer.cpython-310.pyc differ Binary files ./megatron/core/transformer/moe/__pycache__/moe_utils.cpython-310.pyc and ../megatron-lm/megatron/core/transformer/moe/__pycache__/moe_utils.cpython-310.pyc differ Binary files ./megatron/core/transformer/moe/__pycache__/router.cpython-310.pyc and ../megatron-lm/megatron/core/transformer/moe/__pycache__/router.cpython-310.pyc differ Binary files ./megatron/core/transformer/moe/__pycache__/token_dispatcher.cpython-310.pyc and ../megatron-lm/megatron/core/transformer/moe/__pycache__/token_dispatcher.cpython-310.pyc differ diff -rN ./megatron/core/transformer/moe/README.md ../megatron-lm/megatron/core/transformer/moe/README.md 64d63 < | --moe-use-upcycling | Load the dense model checkpoint, convert it into an MoE model at runtime and start training. The converted model will be saved to the path specified by `--save` before training begins. Upcycling is implemented on the top of distributed checkpointing, so it supports parallel modes different from the dense model.| 121,126d119 < ### Upcycling < < Use `--moe-use-upcycling` to enable the upcycling feature, which will load the dense model from the directory specified by `--load`, convert it into an MoE model at runtime and start training. The converted model will be saved to the path specified by `--save` before training begins. Upcycling is implemented on the top of distributed checkpointing, so it supports parallel modes different from the dense model. < < The MoE model structure is defined through script arguments. All MoE-related arguments (such as `--num-experts`) can be customized; however, other model structure arguments must be consistent with those of the dense model. < 252c245 < ### Tuning Guide of Parallel Mappings --- > ### Tuning Guide of Paralell Mappings diff -rN ./megatron/core/transformer/moe/router.py ../megatron-lm/megatron/core/transformer/moe/router.py 77,78c77 < Tuple[torch.Tensor, torch.Tensor]: < Tuple of tensors representing max probs and the indices. --- > Tuple[torch.Tensor, torch.Tensor]: Tuple of tensors representing max probs and the indices. 
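Both the moe_utils.py and token_dispatcher.py hunks in this diff switch between torch.histc and torch.bincount for the per-expert token counts, keyed on config.deterministic_mode. A small sketch of the two equivalent counting paths follows, with an explicit float cast in front of histc since histc operates on floating-point inputs; the function name is illustrative.

import torch

def tokens_per_expert(top_indices, num_experts, deterministic_mode=False):
    # top_indices: integer tensor of routed expert ids, e.g. [num_tokens, topk].
    if deterministic_mode:
        # bincount returns integer counts and is the deterministic path in the diff.
        return torch.bincount(top_indices.view(-1), minlength=num_experts)
    # histc bins the same ids into [0, num_experts) and returns float counts.
    return torch.histc(top_indices.float(), bins=num_experts, min=0, max=num_experts)

top_indices = torch.tensor([[0], [2], [2], [1]])
print(tokens_per_expert(top_indices, 4, deterministic_mode=True))  # tensor([1, 1, 2, 0])
print(tokens_per_expert(top_indices, 4))                           # tensor([1., 1., 2., 0.])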
159d157 < deterministic_mode=self.config.deterministic_mode, 177,180c175,176 < probs (torch.Tensor): < The probs output by the router for each token. [num_tokens, num_experts] < num_local_tokens_per_expert (torch.Tensor): < The number of tokens per expert. [num_experts] --- > probs (torch.Tensor): The probs output by the router for each token. [num_tokens, num_experts] > num_local_tokens_per_expert (torch.Tensor): The number of tokens per expert. [num_experts] 286d281 < deterministic_mode=self.config.deterministic_mode, diff -rN ./megatron/core/transformer/moe/token_dispatcher.py ../megatron-lm/megatron/core/transformer/moe/token_dispatcher.py 23,32d22 < """ We use the following notation throughout this file: < H: hidden size < B: micro batch size < S: sequence length < TP: tensor model parallel size < EP: expert model parallel size < num_local_tokens: S/TP*B < num_global_tokens: num_local_tokens*TP*EP < """ < 78d67 < Note that this allgather spans the communication domain of TP*EP: 97a87,90 > # self.indices: The indices of `local_indices` (which holds the un-sorted expert indices of > # tokens that local expert can process) that give its sorted order along dim 0. > self.indices = None > 115,121c108,110 < hidden_states: 3D tensor [S/TP, B, H]. Input tokens. < max_prob: 2D tensor [S/TP*B, topk]. Each row of max_prob contains < the probility distribution across `topk` experts for one local token. < For 'aux_loss' load balancing, the sum of the values in each row is 1, < thus for `top1` gating, it degenerates into a full 1 tensor. < max_ind: 2D tensor [num_local_tokens, topk], where < `num_local_tokens=S/TP*B`. Token assignment to global experts. --- > hidden_states: input tokens of shape [SeqLen/TP, MBS, HiddenSize] > max_prob: probs of local token assignment to global experts. > max_ind: token assignment to local experts. 135d123 < ## local_indices calculation 137,138d124 < # [num_local_tokens, topk] -> [num_global_tokens, topk], where: < # num_local_tokens=(S/TP)*B, num_global_tokens=S*B*EP 149,155c135,141 < ## local_probs calculation < # max_prob: [S/TP*B, topk] -> global_probs: [S*B*EP, topk] < global_probs = tensor_parallel.gather_from_sequence_parallel_region_to_moe(max_prob) < self.local_probs = global_probs.masked_select(global_local_mask) < self.local_probs = self.local_probs.view(-1, 1) < # Note that this allgather spans the communication domain of TP*EP. 
< # [(S/TP)*B, H] -> [((S/TP)*B)*(TP*EP), H] = [S*B*EP, H] --- > if self.router_topk > 1: # k > 1 > global_probs = tensor_parallel.gather_from_sequence_parallel_region_to_moe(max_prob) > self.local_probs = global_probs.masked_select(global_local_mask) > else: > self.local_probs = max_prob > > # [S*B/TP, H] -> [S*B, H] 168d153 < self.local_probs = self.local_probs.view(-1, 1) 176c161 < self.local_probs = max_prob.view(-1, 1) --- > self.local_probs = max_prob 183,197c168,174 < if self.config.deterministic_mode: < tokens_per_expert = torch.bincount( < local_indices.view(-1), minlength=self.config.num_moe_experts < ) < if self.num_local_experts < self.config.num_moe_experts: < tokens_per_expert = tokens_per_expert[ < self.local_expert_indices[0] : self.local_expert_indices[-1] + 1 < ] < else: < tokens_per_expert = torch.histc( < local_indices, < bins=self.num_local_experts, < min=self.local_expert_indices[0], < max=self.local_expert_indices[-1], < ) --- > tokens_per_expert = torch.bincount( > local_indices.view(-1), minlength=self.config.num_moe_experts > ) > if self.num_local_experts < self.config.num_moe_experts: > tokens_per_expert = tokens_per_expert[ > self.local_expert_indices[0] : self.local_expert_indices[-1] + 1 > ] 202,207c179,184 < < permuted_local_hidden_states, self.reversed_local_input_permutation_mapping = permute( < local_hidden_states, local_indices < ) < < return permuted_local_hidden_states, tokens_per_expert --- > self.indices = self.indices.view(-1, 1).expand(-1, hidden_states.shape[-1]) > if self.num_local_experts > 1: > permuted_local_hidden_states = moe_gather.apply(local_hidden_states, self.indices) > else: > permuted_local_hidden_states = local_hidden_states > return (permuted_local_hidden_states, tokens_per_expert) 211c188 < Reverse process of `dispatch()` which permutes the output of local --- > Reverse process of `dispatch()` which permutes the ouput of local 216,217c193,194 < hidden_states: 2D tensor [num_permuted_tokens_for_local_experts, H], < output of local experts. --- > hidden_states: 2D tensor of shape [sum_tokens_of_all_local_experts, HiddenSize], > ouput of local experts. 222c199 < with shape of [S/TP, B, H] --- > with shape of [SeqLen/TP, MBS, HiddenSize] 225c202,207 < # Scale the expert output prior to reduction and subsequent to local unpermutation if k > 1. --- > scores = self.local_probs.to(dtype=hidden_states.dtype) > if self.num_local_experts > 1: > assert self.indices.shape == hidden_states.shape > unpermuted_local_hidden = moe_scatter.apply(hidden_states, self.indices) > else: > unpermuted_local_hidden = hidden_states 227,230c209,211 < unpermuted_local_hidden = unpermute( < hidden_states, self.reversed_local_input_permutation_mapping < ) < unpermuted_local_hidden = unpermuted_local_hidden * self.local_probs --- > # Scale the expert output prior to reduction and subsequent to local unpermutation if k > 1. 
> if self.router_topk > 1: > unpermuted_local_hidden = unpermuted_local_hidden * scores.view(-1, 1) 236,237c217,220 < unpermuted_local_bias = unpermute(bias, self.reversed_local_input_permutation_mapping) < unpermuted_local_bias = unpermuted_local_bias * self.local_probs --- > assert self.indices.shape == bias.shape > unpermuted_local_bias = unpermuted_local_bias.scatter(0, self.indices, bias) > if self.router_topk > 1: > unpermuted_local_bias = unpermuted_local_bias * scores.view(-1, 1) 250c233 < # hidden_shape: [S/TP, B, H], gloal_num_tokens = S/TP*B*(TP*EP) --- > # hidden_shape: [SeqLen/TP, MBS, HiddenSize], glboal_num_tokens = SeqLen/TP*MBS*(TP*EP) 293a277,278 > if self.router_topk == 1: > output_total = output_total * scores 295a281,283 > assert output_bias_total is not None > if self.router_topk == 1: > output_bias_total = output_bias_total * scores 387,394c375 < if self.config.deterministic_mode: < num_local_tokens_per_expert = torch.bincount( < indices.view(-1), minlength=self.num_experts < ) < else: < num_local_tokens_per_expert = torch.histc( < indices, bins=self.num_experts, min=0, max=self.num_experts < ) --- > num_local_tokens_per_expert = torch.bincount(indices.view(-1), minlength=self.num_experts) 512c493 < self.hidden_shape_before_permute = hidden_states.shape --- > self.hiddden_shape_before_permute = hidden_states.shape 601c582 < restore_shape=self.hidden_shape_before_permute, --- > restore_shape=self.hiddden_shape_before_permute, diff -rN ./megatron/core/transformer/moe/upcycling_utils.py ../megatron-lm/megatron/core/transformer/moe/upcycling_utils.py 1,196d0 < # Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. < """ Helpers for converting a dense model to a MoE model in runtime """ < from megatron.core import mpu < < < def _get_keys_endswith(model, suffix): < """ < Retrieve keys from the model that end with a specified suffix. < """ < return [k for k in model if k.endswith(suffix)] < < < def _covert_to_moe_state_dict(state_dict, moe_model): < """ < Convert a dense model's state_dict to a MoE model's state_dict. < < This function takes the state dictionary of a dense model and modifies it to fit the < structure required by a Mixture of Experts model. It handles the necessary < transformations for weights and biases specific to the MoE architecture. < < Args: < state_dict (dict): The dense model's state_dict. < moe_model (nn.Module): The MoE model instance from which to get the submodule < and state_dict, must be a model without FP16 and/or < DDP wrapper. < < Returns: < dict: The converted MoE model state_dict, ready for use in the MoE architecture. 
< """ < < mlp = moe_model.get_submodule('decoder.layers.0.mlp') < < moe_state_dict = moe_model.state_dict() < new_state_dict = state_dict < < mlp_lm_weight_keys = _get_keys_endswith(new_state_dict, 'mlp.linear_fc1.layer_norm_weight') < mlp_lm_bias_keys = _get_keys_endswith(new_state_dict, 'mlp.linear_fc1.layer_norm_bias') < mlp_fc1_weight_keys = _get_keys_endswith(new_state_dict, 'mlp.linear_fc1.weight') < mlp_fc2_weight_keys = _get_keys_endswith(new_state_dict, 'mlp.linear_fc2.weight') < mlp_fc1_bias_keys = _get_keys_endswith(new_state_dict, 'mlp.linear_fc1.bias') < mlp_fc2_bias_keys = _get_keys_endswith(new_state_dict, 'mlp.linear_fc2.bias') < mlp_fc1_extra_state_keys = _get_keys_endswith(new_state_dict, 'mlp.linear_fc1._extra_state') < mlp_fc2_extra_state_keys = _get_keys_endswith(new_state_dict, 'mlp.linear_fc2._extra_state') < < for key in mlp_lm_weight_keys: < params = new_state_dict.pop(key) < new_key = key.replace('mlp.linear_fc1.layer_norm_weight', 'pre_mlp_layernorm.weight') < new_state_dict[new_key] = params < < for key in mlp_lm_bias_keys: < params = new_state_dict.pop(key) < new_key = key.replace('mlp.linear_fc1.layer_norm_bias', 'pre_mlp_layernorm.bias') < new_state_dict[new_key] = params < < for mlp_weight_key in mlp_fc1_weight_keys: < router_key = mlp_weight_key.replace('mlp.linear_fc1.weight', 'mlp.router.weight') < new_state_dict[router_key] = moe_state_dict[router_key].data.data.clone() < < use_te_grouped_gemm = 'decoder.layers.0.mlp.experts.linear_fc1.weight0' in moe_state_dict < < if mlp.config.moe_grouped_gemm and use_te_grouped_gemm: < for mlp_weight_key in mlp_fc1_weight_keys: < weight_tensor = new_state_dict.pop(mlp_weight_key) < for expert_i in range(mlp.num_local_experts): < new_key = mlp_weight_key.replace( < 'mlp.linear_fc1.weight', f'mlp.experts.linear_fc1.weight{expert_i}' < ) < new_state_dict[new_key] = weight_tensor.clone() < < for mlp_weight_key in mlp_fc2_weight_keys: < weight_tensor = new_state_dict.pop(mlp_weight_key) < for expert_i in range(mlp.num_local_experts): < new_key = mlp_weight_key.replace( < 'mlp.linear_fc2.weight', f'mlp.experts.linear_fc2.weight{expert_i}' < ) < new_state_dict[new_key] = weight_tensor.clone() < < for extra_state_key in mlp_fc1_extra_state_keys: < new_state_dict.pop(extra_state_key) < new_key = extra_state_key.replace( < 'mlp.linear_fc1._extra_state', 'mlp.experts.linear_fc1._extra_state' < ) < new_state_dict[new_key] = None < < for extra_state_key in mlp_fc2_extra_state_keys: < new_state_dict.pop(extra_state_key) < new_key = extra_state_key.replace( < 'mlp.linear_fc2._extra_state', 'mlp.experts.linear_fc2._extra_state' < ) < new_state_dict[new_key] = None < < elif mlp.config.moe_grouped_gemm: < for mlp_weight_key in mlp_fc1_weight_keys: < weight_tensor = new_state_dict.pop(mlp_weight_key) < shape = weight_tensor.shape < weight_tensor = weight_tensor.repeat(mlp.num_local_experts, 1, 1) < weight_tensor = weight_tensor.permute(0, 2, 1).reshape( < shape[1], mlp.num_local_experts * shape[0] < ) < new_key = mlp_weight_key.replace('mlp.linear_fc1.weight', 'mlp.experts.weight1') < new_state_dict[new_key] = weight_tensor < < for mlp_weight_key in mlp_fc2_weight_keys: < weight_tensor = new_state_dict.pop(mlp_weight_key) < shape = weight_tensor.shape < weight_tensor = weight_tensor.repeat(mlp.num_local_experts, 1, 1) < weight_tensor = weight_tensor.permute(0, 2, 1).reshape( < mlp.num_local_experts * shape[1], shape[0] < ) < new_key = mlp_weight_key.replace('mlp.linear_fc2.weight', 'mlp.experts.weight2') < new_state_dict[new_key] = 
weight_tensor < < else: < < def covert_to_experts(keys): < for key in keys: < params = new_state_dict.pop(key) < new_key_format_str = key.replace('mlp', 'mlp.experts.local_experts.{}') < for expert_i in range(mlp.num_local_experts): < new_key = new_key_format_str.format(expert_i) < if hasattr(params, 'clone'): < new_state_dict[new_key] = params.clone() < else: < # set extra_state to None for now < new_state_dict[new_key] = None < < covert_to_experts(mlp_fc1_weight_keys) < covert_to_experts(mlp_fc2_weight_keys) < covert_to_experts(mlp_fc1_bias_keys) < covert_to_experts(mlp_fc2_bias_keys) < covert_to_experts(mlp_fc1_extra_state_keys) < covert_to_experts(mlp_fc2_extra_state_keys) < < return new_state_dict < < < def upcycle_state_dict(moe_model, dense_model): < """ < Convert a dense model's state_dict to a MoE model's state_dict. < < This function facilitates the conversion of the state_dict from a dense model to < a MoE model, ensuring that the parameters are correctly mapped for each model. < < Args: < moe_model (nn.Module): The MoE model, must be a model without FP16 and/or DDP wrapper. < dense_model (nn.Module): The dense model instance. < < Returns: < dict: A dictionary containing the converted state_dict for the MoE model. < """ < < state_dict = {} < if len(moe_model) == 1: < assert len(dense_model) == 1 < state_dict['model'] = _covert_to_moe_state_dict(dense_model[0].state_dict(), moe_model[0]) < else: < assert len(moe_model) == len(dense_model) < for i in range(len(moe_model)): < mpu.set_virtual_pipeline_model_parallel_rank(i) < state_dict['model%d' % i] = _covert_to_moe_state_dict( < dense_model[i].state_dict(), moe_model[i] < ) < return state_dict < < < def load_and_upcycle_model( < load_dense_ckpt_func, moe_model, dense_model, strict=True, load_args=(), load_kwargs={} < ): < """ < Load a dense model checkpoint and convert it to a MoE model. < < This function loads a checkpoint for a dense model and converts it to the MoE model format, < allowing for the integration of the dense model's parameters into the MoE architecture. < < Args: < load_dense_ckpt_func (callable): The function to load the dense model checkpoint. < moe_model (nn.Module): The MoE model instance. < dense_model (nn.Module): The dense model instance. < strict (bool): Whether to strictly load the state dictionary (default is True). < load_args (tuple): Positional arguments to pass to the loading function. < load_kwargs (dict): Keyword arguments to pass to the loading function. 
< """ < < iteration, num_floating_point_operations_so_far = load_dense_ckpt_func( < *load_args, **load_kwargs < ) < state_dict = upcycle_state_dict(moe_model, dense_model) < < if len(moe_model) == 1: < moe_model[0].load_state_dict(state_dict['model'], strict=strict) < else: < for i in range(len(moe_model)): < mpu.set_virtual_pipeline_model_parallel_rank(i) < moe_model[i].load_state_dict(state_dict['model%d' % i], strict=strict) < < return iteration, num_floating_point_operations_so_far Binary files ./megatron/core/transformer/__pycache__/attention.cpython-310.pyc and ../megatron-lm/megatron/core/transformer/__pycache__/attention.cpython-310.pyc differ Binary files ./megatron/core/transformer/__pycache__/dot_product_attention.cpython-310.pyc and ../megatron-lm/megatron/core/transformer/__pycache__/dot_product_attention.cpython-310.pyc differ Binary files ./megatron/core/transformer/__pycache__/enums.cpython-310.pyc and ../megatron-lm/megatron/core/transformer/__pycache__/enums.cpython-310.pyc differ Binary files ./megatron/core/transformer/__pycache__/identity_op.cpython-310.pyc and ../megatron-lm/megatron/core/transformer/__pycache__/identity_op.cpython-310.pyc differ Binary files ./megatron/core/transformer/__pycache__/__init__.cpython-310.pyc and ../megatron-lm/megatron/core/transformer/__pycache__/__init__.cpython-310.pyc differ Binary files ./megatron/core/transformer/__pycache__/mlp.cpython-310.pyc and ../megatron-lm/megatron/core/transformer/__pycache__/mlp.cpython-310.pyc differ Binary files ./megatron/core/transformer/__pycache__/module.cpython-310.pyc and ../megatron-lm/megatron/core/transformer/__pycache__/module.cpython-310.pyc differ Binary files ./megatron/core/transformer/__pycache__/spec_utils.cpython-310.pyc and ../megatron-lm/megatron/core/transformer/__pycache__/spec_utils.cpython-310.pyc differ Binary files ./megatron/core/transformer/__pycache__/transformer_block.cpython-310.pyc and ../megatron-lm/megatron/core/transformer/__pycache__/transformer_block.cpython-310.pyc differ Binary files ./megatron/core/transformer/__pycache__/transformer_config.cpython-310.pyc and ../megatron-lm/megatron/core/transformer/__pycache__/transformer_config.cpython-310.pyc differ Binary files ./megatron/core/transformer/__pycache__/transformer_layer.cpython-310.pyc and ../megatron-lm/megatron/core/transformer/__pycache__/transformer_layer.cpython-310.pyc differ Binary files ./megatron/core/transformer/__pycache__/utils.cpython-310.pyc and ../megatron-lm/megatron/core/transformer/__pycache__/utils.cpython-310.pyc differ diff -rN ./megatron/core/transformer/transformer_block.py ../megatron-lm/megatron/core/transformer/transformer_block.py 23c23 < from megatron.core.extensions.transformer_engine import ( --- > from megatron.core.transformer.custom_layers.transformer_engine import ( 48,84c48,51 < """ < Determine the number of transformer layers to build for the current pipeline stage. < Args: < config (TransformerConfig): Configuration object containing transformer model parameters. < < Returns: < int: The number of layers to be built for the current pipeline stage. 
< """ < if config.first_pipeline_num_layers is not None or config.last_pipeline_num_layers is not None: < assert ( < parallel_state.get_virtual_pipeline_model_parallel_world_size() is None < ), "Uneven number of layer not compatible with interleaved pipeline schedule" < < # Number of layers to distribute over rest of pipeline stages < layers_to_distribute = config.num_layers < # Number of pipeline stages left for distributing transformer layers < pipeline_stages_left = parallel_state.get_pipeline_model_parallel_world_size() < < if config.first_pipeline_num_layers is not None: < layers_to_distribute -= config.first_pipeline_num_layers < pipeline_stages_left -= 1 < if parallel_state.is_pipeline_first_stage(): < return config.first_pipeline_num_layers < < if config.last_pipeline_num_layers is not None: < layers_to_distribute -= config.last_pipeline_num_layers < pipeline_stages_left -= 1 < if parallel_state.is_pipeline_last_stage(): < return config.last_pipeline_num_layers < < assert ( < layers_to_distribute % pipeline_stages_left == 0 < ), "With uneven pipelineing the left over layers must be divisible by left over stages" < num_layers_per_pipeline_rank = layers_to_distribute // pipeline_stages_left < else: < pipeline_ranks = config.pipeline_model_parallel_size < num_layers_per_pipeline_rank = config.num_layers // pipeline_ranks --- > > pipeline_ranks = config.pipeline_model_parallel_size > > num_layers_per_pipeline_rank = config.num_layers // pipeline_ranks 116,129d82 < """ < Dataclass for specifying the submodules of a transformer block. < < This class defines the structure for configuring the layers and normalization < within a transformer block, allowing for flexible and customizable architecture designs. < < Args: < layer_specs (List[ModuleSpec], optional): A list of module specifications for < the layers within the transformer block. Each specification typically < defines a complete transformer layer (e.g., self-attention, feed-forward network). < layer_norm (Optional[Union[ModuleSpec, torch.nn.Module]], optional): Specification < or instance of the layer normalization to be applied. < """ < 137,148d89 < """ < Retrieve or construct TransformerBlockSubmodules based on the provided specification. < < Args: < config (TransformerConfig): Configuration object for the transformer model. < spec (Union[TransformerBlockSubmodules, ModuleSpec]): Specification for the < transformer block submodules. Can be either a TransformerBlockSubmodules < instance or a ModuleSpec. < < Returns: < TransformerBlockSubmodules: The submodules for the transformer block. < """ 223d163 < self.tp_only_amax_red = config.tp_only_amax_red 370,392c310,311 < """ < Perform the forward pass through the transformer block. < < This method handles the core computation of the transformer, including < self-attention, optional cross-attention, and feed-forward operations. < < Args: < hidden_states (Tensor): Input tensor of shape [s, b, h] where s is the < sequence length, b is the batch size, and h is the hidden size. < attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking < self-attention. < context (Tensor, optional): Context tensor for cross-attention. < context_mask (Tensor, optional): Mask for cross-attention context < rotary_pos_emb (Tensor, optional): Rotary positional embeddings. < inference_params (InferenceParams, optional): Parameters for inference-time < optimizations. < packed_seq_params (PackedSeqParams, optional): Parameters for packed sequence < processing. 
< < Returns: < Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape < [s, b, h], and optionally the updated context tensor if cross-attention is used. < """ --- > # hidden_states (float): [s, b, h] > # attention_mask (bool): [1, 1, s, s] 437,439c356 < fp8_group = parallel_state.get_amax_reduction_group( < with_context_parallel=True, tp_only_amax_red=self.tp_only_amax_red < ) --- > fp8_group = parallel_state.get_amax_reduction_group(with_context_parallel=True) 460d376 < layer.use_cudagraph = True 470a387,392 > # CUDA graph doesn't output context and is expected to be None > assert ( > (context is None) > or (not self.config.enable_cuda_graph) > or (not self.training) > ) 507,519d428 < """ < Generate a sharded state dictionary for the transformer block. < < Args: < prefix (str, optional): Prefix to be added to all keys in the state dict. < Defaults to an empty string. < sharded_offsets (tuple, optional): Tuple of sharding offsets. < metadata (dict, optional): Additional metadata for sharding. < Can specify if layers are non-homogeneous. Defaults to None. < < Returns: < ShardedStateDict: A dictionary containing the sharded state of the model. < """ diff -rN ./megatron/core/transformer/transformer_config.py ../megatron-lm/megatron/core/transformer/transformer_config.py 9c9 < from ..utils import get_te_version, init_method_normal, is_te_min_version, scaled_init_method_normal --- > from ..utils import init_method_normal, scaled_init_method_normal 26,33d25 < first_pipeline_num_layers: int = None < """Number of transformer layers on first pipeline stage. < None implies equal layer division across PP ranks.""" < < last_pipeline_num_layers: int = None < """Number of transformer layers on last pipeline stage. < None implies equal layer division across PP ranks.""" < 168a161 > recompute_granularity: str = None 207,209c200 < """DEPRECATED from TransformerEngine v1.8.0. This flag is ignored. < Controls how often the scaling factor is recomputed. < """ --- > """Controls how often the scaling factor is recomputed.""" 231,233d221 < tp_only_amax_red: bool = False < """When set to True, reduce the FP8 AMAX only in the TP or TP-CP domain""" < 305,308c293 < """When set to true, TransformerLayer layers are swapped with a CUDA graphed version.""" < < external_cuda_graph: bool = False < """When set to true, TransformerLayer layers are swapped with user provided CUDA graphs.""" --- > """When set to true, TransformerLayer blocks are wrapped with CUDA graph.""" 478,488d462 < < if self.num_moe_experts and self.fp8: < # TE version below 1.7.0 will raise Error when handle zeros tokens for expert < if not is_te_min_version("1.7.0.dev0"): < raise ValueError( < "Only transformer-engine>=1.7.0 supports MoE FP8 training, " < f"but your version is {get_te_version()}." < ) < < if self.moe_grouped_gemm: < raise ValueError("Grouped GEMM of MoE not support fp8 for now.") diff -rN ./megatron/core/transformer/transformer_layer.py ../megatron-lm/megatron/core/transformer/transformer_layer.py 12d11 < from megatron.core.transformer.cuda_graphs import CudaGraphManager 22,47d20 < """ < Configuration class for specifying the submodules of a transformer layer. < < This class defines the structure and default implementations for various < components of a transformer layer, allowing for flexible customization < of the layer's architecture. < < Args: < input_layernorm (Union[ModuleSpec, type]): Specification for the input layer normalization. 
< self_attention (Union[ModuleSpec, type]): Specification for the self-attention mechanism. < self_attn_bda (Union[ModuleSpec, type]): Specification for the bias-dropout-add operation < after self-attention. < pre_cross_attn_layernorm (Union[ModuleSpec, type]): Specification for the layer < normalization before cross-attention. < cross_attention (Union[ModuleSpec, type]): Specification for the cross-attention mechanism. < cross_attn_bda (Union[ModuleSpec, type]): Specification for the bias-dropout-add operation < after cross-attention. < pre_mlp_layernorm (Union[ModuleSpec, type]): Specification for the layer normalization < before the MLP. < mlp (Union[ModuleSpec, type]): Specification for the MLP. < mlp_bda (Union[ModuleSpec, type]): Specification for the bias-dropout-add operation < after the MLP. < sharded_state_dict_keys_map (Dict[str, str]): Mapping for sharded tensor keys to be applied < in the `sharded_state_dict` method. < """ < 95,101d67 < < if config.enable_cuda_graph and self.training: < assert ( < not config.cpu_offloading and config.recompute_granularity is None < ), "Cudagraphs not supported" < self.cudagraph_manager = CudaGraphManager() < 102a69 > 166c133 < """Get the index number of this layer, given the level of pipelining.""" --- > 170c137 < self.config.num_layers // self.config.pipeline_model_parallel_size --- > self.config.num_layers // parallel_state.get_pipeline_model_parallel_world_size() 185,235c152 < if ( < self.config.first_pipeline_num_layers is not None < or self.config.last_pipeline_num_layers is not None < ): < # Calculate number of pipelines for distributing layers < middle_pipeline_stages = parallel_state.get_pipeline_model_parallel_world_size() < middle_pipeline_stages -= sum( < [ < 1 if x is not None else 0 < for x in ( < self.config.first_pipeline_num_layers, < self.config.last_pipeline_num_layers, < ) < ] < ) < < # Calculate layers to distribute < first_pipeline_offset = ( < 0 < if self.config.first_pipeline_num_layers is None < else self.config.first_pipeline_num_layers < ) < last_pipeline_offset = ( < 0 < if self.config.last_pipeline_num_layers is None < else self.config.last_pipeline_num_layers < ) < < middle_num_layers = ( < self.config.num_layers - first_pipeline_offset - last_pipeline_offset < ) < < if middle_pipeline_stages > 0: < num_layers_per_pipeline_rank = middle_num_layers // middle_pipeline_stages < else: < num_layers_per_pipeline_rank = 0 < < middle_pipeline_rank = ( < pipeline_rank < if self.config.first_pipeline_num_layers is None < else pipeline_rank - 1 < ) < < if pipeline_rank == 0: < offset = 0 < else: < offset = ( < middle_pipeline_rank * num_layers_per_pipeline_rank < ) + first_pipeline_offset < else: < offset = pipeline_rank * num_layers_per_pipeline_rank --- > offset = pipeline_rank * num_layers_per_pipeline_rank 251,272c168 < """ < Perform a forward pass through the transformer layer. < < This method implements the core computation of a transformer layer, including < self-attention, cross-attention (if applicable), and feed-forward operations. < < Args: < hidden_states (Tensor): Input tensor of shape [s, b, h] where s is sequence length, < b is batch size, and h is hidden size. < attention_mask (Tensor): Mask tensor for self-attention. < context (Tensor, optional): Context tensor for cross-attention. < context_mask (Tensor, optional): Mask tensor for cross-attention. < rotary_pos_emb (Tensor, optional): Rotary positional embeddings. < inference_params (object, optional): Parameters for inference-time optimizations. 
< packed_seq_params (object, optional): Parameters for packed sequence processing. < < Returns: < Tuple[Tensor, Tensor]: A tuple containing: < output (Tensor): Transformed hidden states of shape [s, b, h]. < context (Tensor): Updated context tensor if cross-attention is used, < otherwise None. < """ --- > # hidden_states: [s, b, h] 351,361d246 < """ < Generate a sharded state dictionary for the transformer layer. < < Args: < prefix (str, optional): Prefix to be added to all keys in the state dict. < sharded_offsets (tuple, optional): Tuple of sharding offsets. < metadata (Optional[dict], optional): Additional metadata for sharding. < < Returns: < ShardedStateDict: A dictionary containing the sharded state of the transformer layer. < """ 370,374d254 < < def __call__(self, *args, **kwargs): < if hasattr(self, 'cudagraph_manager'): < return self.cudagraph_manager(self, args, kwargs) < return super(MegatronModule, self).__call__(*args, **kwargs) diff -rN ./megatron/core/utils.py ../megatron-lm/megatron/core/utils.py 18d17 < from importlib.metadata import version 23d21 < from packaging.version import Version as PkgVersion 31,57d28 < _te_version = None < < < def get_te_version(): < """Get TE version from __version__; if not available use pip's. Use caching.""" < < def get_te_version_str(): < import transformer_engine as te < < if hasattr(te, '__version__'): < return str(te.__version__) < else: < return version("transformer-engine") < < global _te_version < if _te_version is None: < _te_version = PkgVersion(get_te_version_str()) < return _te_version < < < def is_te_min_version(version, check_equality=True): < """Check if minimum version of `transformer-engine` is installed.""" < if check_equality: < return get_te_version() >= PkgVersion(version) < return get_te_version() > PkgVersion(version) < < 99d69 < """Returns model_type attribute""" 104d73 < """Returns whether the model has the xattn_needed attribute""" 112d80 < """Returns the config attribute, allowed to return None""" 125,127d92 < """ < Returns (potentially) a sub-tensor from the self.buffer for the given shape. < """ 141c106 < """Make a viewless tensor. --- > '''Make a viewless tensor. 148c113 < """ --- > ''' 155c120 < """ --- > ''' 162c127 < """ --- > ''' 166d130 < """Runs the fwd pass of _kernel_make_viewless_tensor""" 171d134 < """No-op""" 176c139 < """ --- > ''' 183c146 < """ --- > ''' 197,198c160,161 < """Assert that a tensor is not a view (i.e., its '._base' field is < not set).""" --- > '''Assert that a tensor is not a view (i.e., its '._base' field is > not set).''' 213c176 < """Safely set tensor's '.data' field. --- > '''Safely set tensor's '.data' field. 217c180 < """ --- > ''' 283,285c246 < def check_param_hashes_across_dp_replicas( < model: List[torch.nn.Module], cross_check: bool = False < ) -> bool: --- > def check_param_hashes_across_dp_replicas(model: List[torch.nn.Module]) -> bool: 287c248,249 < and then checks for equality between the locally-computed hashes and those of other ranks. --- > and then checks for equality between the locally-computed hashes and the hashes > from DP replica 0. 296d257 < cross_check (bool): If true, will check whether hashes match across all DP replicas. 299,300c260,261 < True if all param hashes match with corresponding hash on DP replica 0 or < across all replicas if cross_check is enabled, False otherwise. --- > True if all param hashes match with corresponding hash on DP replica 0, False > otherwise. 337,341c298 < if cross_check: < # Make sure all ranks have the same hash. 
< return all(map(lambda x: torch.equal(local_param_hashes, x), all_param_hashes)) < else: < return param_hashes_match --- > return param_hashes_match 399c356 < """Ensure grad_output is stored in a contiguous buffer.""" --- > 506d462 < """Multi tensor op applier""" 513,516d468 < """ < Computes l2 norm for a list of contiguous tensors < works as a drop-in replacement for amp_C.multi_tensor_l2norm < """ 525d476 < """Works as a drop-in replacement for amp_C.multi_tensor_scale.""" 1292,1307d1242 < < < # Check if Transformer Engine has Float8Tensor class < HAVE_TE_FLOAT8TENSOR = False < try: < from transformer_engine.pytorch.float8_tensor import Float8Tensor < < HAVE_TE_FLOAT8TENSOR = True < except (ImportError, ModuleNotFoundError): < # Float8Tensor not found < pass < < < def is_float8tensor(tensor: torch.Tensor) -> bool: < """Check if a tensor is a Transformer Engine Float8Tensor""" < return HAVE_TE_FLOAT8TENSOR and isinstance(tensor, Float8Tensor) Binary files ./megatron/legacy/data/__pycache__/data_samplers.cpython-310.pyc and ../megatron-lm/megatron/legacy/data/__pycache__/data_samplers.cpython-310.pyc differ Binary files ./megatron/legacy/data/__pycache__/__init__.cpython-310.pyc and ../megatron-lm/megatron/legacy/data/__pycache__/__init__.cpython-310.pyc differ Binary files ./megatron/legacy/fused_kernels/__pycache__/__init__.cpython-310.pyc and ../megatron-lm/megatron/legacy/fused_kernels/__pycache__/__init__.cpython-310.pyc differ Binary files ./megatron/legacy/model/__pycache__/bert_model.cpython-310.pyc and ../megatron-lm/megatron/legacy/model/__pycache__/bert_model.cpython-310.pyc differ Binary files ./megatron/legacy/model/__pycache__/enums.cpython-310.pyc and ../megatron-lm/megatron/legacy/model/__pycache__/enums.cpython-310.pyc differ Binary files ./megatron/legacy/model/__pycache__/fused_bias_gelu.cpython-310.pyc and ../megatron-lm/megatron/legacy/model/__pycache__/fused_bias_gelu.cpython-310.pyc differ Binary files ./megatron/legacy/model/__pycache__/fused_layer_norm.cpython-310.pyc and ../megatron-lm/megatron/legacy/model/__pycache__/fused_layer_norm.cpython-310.pyc differ Binary files ./megatron/legacy/model/__pycache__/fused_softmax.cpython-310.pyc and ../megatron-lm/megatron/legacy/model/__pycache__/fused_softmax.cpython-310.pyc differ Binary files ./megatron/legacy/model/__pycache__/gpt_model.cpython-310.pyc and ../megatron-lm/megatron/legacy/model/__pycache__/gpt_model.cpython-310.pyc differ Binary files ./megatron/legacy/model/__pycache__/__init__.cpython-310.pyc and ../megatron-lm/megatron/legacy/model/__pycache__/__init__.cpython-310.pyc differ Binary files ./megatron/legacy/model/__pycache__/language_model.cpython-310.pyc and ../megatron-lm/megatron/legacy/model/__pycache__/language_model.cpython-310.pyc differ Binary files ./megatron/legacy/model/__pycache__/module.cpython-310.pyc and ../megatron-lm/megatron/legacy/model/__pycache__/module.cpython-310.pyc differ Binary files ./megatron/legacy/model/__pycache__/rms_norm.cpython-310.pyc and ../megatron-lm/megatron/legacy/model/__pycache__/rms_norm.cpython-310.pyc differ Binary files ./megatron/legacy/model/__pycache__/t5_model.cpython-310.pyc and ../megatron-lm/megatron/legacy/model/__pycache__/t5_model.cpython-310.pyc differ Binary files ./megatron/legacy/model/__pycache__/transformer.cpython-310.pyc and ../megatron-lm/megatron/legacy/model/__pycache__/transformer.cpython-310.pyc differ Binary files ./megatron/legacy/model/__pycache__/utils.cpython-310.pyc and 
../megatron-lm/megatron/legacy/model/__pycache__/utils.cpython-310.pyc differ diff -rN ./megatron/legacy/model/rms_norm.py ../megatron-lm/megatron/legacy/model/rms_norm.py 4a5,7 > import torch._dynamo > torch._dynamo.config.suppress_errors = True > 25a29 > @torch.compile(mode="max-autotune-no-cudagraphs") 28a33 > @torch.compile(mode="max-autotune-no-cudagraphs") diff -rN ./megatron/legacy/model/transformer.py ../megatron-lm/megatron/legacy/model/transformer.py 42a43,45 > import torch._dynamo > torch._dynamo.config.suppress_errors = True > 58a62,65 > try: > from flash_attn.flash_attn_triton import flash_attn_func > except ImportError: > flash_attn_func = None 135a143 > @torch.compile(mode="max-autotune-no-cudagraphs") 159c167 < --- > @torch.compile(mode="max-autotune-no-cudagraphs") 469a478,481 > # Use FlashAttention-2 when args.use_flash_attn_ck is True > args = get_args() > self.flash_attn_func = flash_attn_unpadded_func > 510a523,554 > class FlashSelfAttentionTriton(torch.nn.Module): > """Implement the scaled dot product attention with softmax. > Arguments > --------- > softmax_scale: The temperature to use for the softmax attention. > (default: 1/sqrt(d_keys) where d_keys is computed at > runtime) > attention_dropout: The dropout rate to apply to the attention > (default: 0.0) > """ > def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, > device=None, dtype=None): > super().__init__() > assert flash_attn_func is not None, ('Triton version of FlashAttention is not installed.') > assert rearrange is not None, 'Please install einops first, e.g., with pip install einops' > self.causal = causal > self.softmax_scale = softmax_scale > self.dropout_p = attention_dropout > > def forward(self, q, k, v): > """Implements the multihead softmax attention. > Arguments > --------- > q, k, v: The tensor containing the query, key, and value. (B, S, H, D) > """ > assert q.dtype in [torch.float16, torch.bfloat16] > assert q.is_cuda > q, k, v = [rearrange(x, 's b h d -> b h s d').contiguous() > for x in (q, k, v)] > output = flash_attn_func(q, k, v, self.causal) > output = rearrange(output, 'b s h d -> h b (s d)').contiguous() > return output 539c583 < self.use_flash_attn = args.use_flash_attn \ --- > self.use_flash_attn = (args.use_flash_attn_ck or args.use_flash_attn_triton) \ 541a586,587 > self.use_flash_attn_triton = args.use_flash_attn_triton > 543,544c589,591 < if flash_attn_unpadded_func is None: < raise ImportError('FlashAttention is not installed, please install with ' --- > if args.use_flash_attn_ck: > if flash_attn_unpadded_func is None: > raise ImportError('FlashAttention is not installed, please install with ' 545a593,595 > if args.use_flash_attn_triton: > assert flash_attn_func != None, "Cannot import FlashAttention triton " > 605c655,658 < if self.use_flash_attn: --- > # Currently FlashAttention only works with causal mask > if self.use_flash_attn_triton: > self.core_attention_flash = FlashSelfAttentionTriton(causal=True, attention_dropout=args.attention_dropout) > elif self.use_flash_attn: 713c766 < query_layer = query_layer.view(query_layer.size(0), query_layer.size(1), -1, self.hidden_size_per_attention_head) --- > query_layer = query_layer.contiguous().view(query_layer.size(0), query_layer.size(1), -1, self.hidden_size_per_attention_head) 818c871,873 < q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous() --- > if not self.use_flash_attn_triton: > query_layer, key_layer, value_layer = [rearrange(x, 's b ... 
-> b s ...').contiguous() > #q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous() 822c877 < context_layer = self.core_attention_flash(q, k, v) --- > context_layer = self.core_attention_flash(query_layer, key_layer, value_layer) 824,825c879,881 < context_layer = self.core_attention_flash(q, k, v) < context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous() --- > context_layer = self.core_attention_flash(query_layer, key_layer, value_layer) > if not self.use_flash_attn_triton: > context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous() 1176a1233,1234 > #from unsloth.kernels.rms_layernorm import fast_rms_layernorm > #norm_output = self.input_norm(hidden_states) if not args.use_fast_rms_layernorm else fast_rms_layernorm(self.input_norm, hidden_states) 1408a1467,1468 > from importlib.metadata import version > 1409a1470 > from pkg_resources import packaging 1411c1472,1473 < if core.utils.is_te_min_version("0.8.0"): --- > te_version = packaging.version.Version(version("transformer-engine")) > if te_version >= packaging.version.Version("0.8.0"): 1413c1475 < if core.utils.is_te_min_version("0.10.0"): --- > if te_version >= packaging.version.Version("0.10.0"): 1415c1477 < if core.utils.is_te_min_version("0.11.0"): --- > if te_version >= packaging.version.Version("0.11.0"): 1417a1480,1481 > del version, packaging > 1427c1491 < self.fp8_group = mpu.get_amax_reduction_group(tp_only_amax_red=config.tp_only_amax_red) --- > self.fp8_group = mpu.get_amax_reduction_group() Binary files ./megatron/legacy/model/.transformer.py.swp and ../megatron-lm/megatron/legacy/model/.transformer.py.swp differ diff -rN ./megatron/legacy/model/utils.py ../megatron-lm/megatron/legacy/model/utils.py 11a12,13 > import torch._dynamo > torch._dynamo.config.suppress_errors = True 61c63 < --- > @torch.compile(mode="max-autotune-no-cudagraphs") diff -rN ./megatron/training/arguments.py ../megatron-lm/megatron/training/arguments.py 53a54 > parser = _add_unsloth_args(parser) 74,75c75,76 < args.rank = int(os.getenv('RANK', '0')) < args.world_size = int(os.getenv("WORLD_SIZE", '1')) --- > #args.rank = int(os.getenv('RANK', '0')) > #args.world_size = int(os.getenv("WORLD_SIZE", '1')) 290d290 < args.align_param_gather = False 292,294c292,293 < print('WARNING: Setting args.overlap_p2p_comm and args.align_param_gather to False ' < 'since non-interleaved schedule does not support overlapping p2p communication ' < 'and aligned param AG') --- > print('WARNING: Setting args.overlap_p2p_comm to False since non-interleaved ' > 'schedule does not support overlapping p2p communication') 314,316c313,315 < if args.fp8_param_gather: < assert args.use_distributed_optimizer, \ < '--fp8-param-gather only supported with distributed optimizer' --- > if args.align_param_gather: > assert args.virtual_pipeline_model_parallel_size is not None, \ > '--align-param-gather only supported with interleaved pipeline parallelism' 539a539,540 > # FlashAttention > args.use_flash_attn = args.use_flash_attn_ck or args.use_flash_attn_triton 572a574 > args.use_dist_ckpt = False 621,630d622 < # MoE upcycling check < if args.moe_use_upcycling: < assert args.save is not None, "When using upcycling, the --save option must be specified." 
< if not args.no_load_optim: < args.no_load_optim = True < print('Warning: disabling --no-load-optim for upcycling.') < if not args.no_load_rng: < args.no_load_rng = True < print('Warning: disabling --no-load-rng for upcycling.') < 674,675d665 < kw_args['first_pipeline_num_layers']= args.decoder_first_pipeline_num_layers < kw_args['last_pipeline_num_layers']= args.decoder_last_pipeline_num_layers 709c699 < help='DEPRECATED. This flag is ignored. Scaling update interval for fp8', --- > help='Scaling update interval for fp8', 724,726d713 < group.add_argument('--fp8-param-gather', action='store_true', < help='Keep the compute param in fp8 (do not use any other intermediate ' < 'dtype) and perform the param all-gather in fp8.') 1127,1130d1113 < group.add_argument('--use-pytorch-profiler', action='store_true', < help='Use the built-in pytorch profiler. ' < 'Useful if you wish to view profiles in tensorboard.', < dest='use_pytorch_profiler') 1219c1202 < group.add_argument('--use-flash-attn', action='store_true', --- > group.add_argument('--use-flash-attn-ck', action='store_true', 1221a1205,1206 > group.add_argument('--use-flash-attn-triton', action='store_true', > help='use FlashAttention implementation of attention using Triton.') 1390,1394d1374 < group.add_argument('--non-persistent-local-ckpt-dir', type=str, default=None, < help='Directory containing local non-persistent model checkpoints.') < group.add_argument('--non-persistent-local-ckpt-algo', type=str, default='fully_parallel', < choices=['fully_parallel', 'atomic'], < help='Algorithm for local non-persistent checkpointing.') 1518,1525d1497 < group.add_argument('--decoder-first-pipeline-num-layers', < type=int, default=None, < help=('The number of transformer layers on the first pipeline stage of the decoder. ' < 'Default None is even split of transformer layers across all pipeline stages')) < group.add_argument('--decoder-last-pipeline-num-layers', < type=int, default=None, < help=('The number of transformer layers on the last pipeline stage of the decoder. ' < 'Default None is even split of transformer layers across all pipeline stages')) 1560,1563c1532,1534 < group.add_argument('--no-align-param-gather', action='store_false', < help='If not set, all PP stages will launch param all-gathers simultaneously. ' < 'Otherwise, each PP stage will independently launch as needed.', < dest='align_param_gather') --- > group.add_argument('--align-param-gather', action='store_true', default=False, > help='If set, all PP stages will launch param all-gathers simultaneously. ' > 'Otherwise, each PP stage will independently launch as needed.') 1571c1542,1544 < group.add_argument('--local-rank', type=int, default=int(os.getenv('LOCAL_RANK', '0')), --- > # group.add_argument('--local-rank', type=int, default=int(os.getenv('LOCAL_RANK', '0')), > # help='local rank passed from distributed launcher.') > group.add_argument('--local_rank', type=int, default=None, 1596a1570,1575 > group.add_argument('--rank', default=-1, type=int, > help='node rank for distributed training') > group.add_argument('--world_size', type=int, default=8, > help='number of nodes for distributed training') > group.add_argument('--dist_url', > help='Which master node url for distributed training.') 1899,1901d1877 < group.add_argument('--moe-use-upcycling', action='store_true', < help='Load a checkpoint of a dense model, convert it into an MoE model, and save the converted model to the path specified by --save. 
' < 'Upcycling is implemented on the top of distributed checkpointing, so it supports parallel modes different from the dense model.') 1928a1905,1914 > return parser > > def _add_unsloth_args(parser): > group = parser.add_argument_group(title='unsloth') > > group.add_argument('--use-fast-cross-entropy-loss', action='store_true', > help='Use fast_cross_entropy_loss of unsloth more faster in calculating loss') > group.add_argument('--use-fast-rms-layernorm', action='store_true', > help='Use fast_rms_layernorm of unsloth more faster in Layer Normalization') > diff -rN ./megatron/training/checkpointing.py ../megatron-lm/megatron/training/checkpointing.py 5d4 < from enum import Enum, auto 22,25d20 < from megatron.core.dist_checkpointing.state_dict_transformation import ( < prepare_state_dict_for_save, < recreate_state_dict_after_load, < ) 29d23 < from megatron.core.utils import is_float8tensor 36a31,32 > import pdb > 299,302d294 < class CheckpointType(Enum): < LEGACY = auto() < LOCAL = auto() < GLOBAL = auto() 333c325 < ckpt_type = CheckpointType.GLOBAL if args.use_dist_ckpt else CheckpointType.LEGACY --- > use_dist_ckpt = args.use_dist_ckpt or non_persistent_ckpt 336,353c328,334 < if args.non_persistent_ckpt_type == 'global': < ckpt_type = CheckpointType.GLOBAL < save_dir = ( < args.non_persistent_global_ckpt_dir < if args.non_persistent_global_ckpt_dir < else os.path.join(save_dir, _NON_PERSISTENT_CKPT_SUBDIR) < ) < # TODO Can we ensure the previous checkpoint is saved? We don't want to allow two saves in parallel. < cleanup_old_non_persistent_checkpoint( < save_dir, leave_ckpt_num=1, do_async=args.async_save < ) < elif args.non_persistent_ckpt_type == 'local': < raise RuntimeError('LocalCheckpointManagers are not yet integrated') < ckpt_type = CheckpointType.LOCAL < save_dir = checkpointing_context['local_checkpoint_manager'].local_ckpt_dir < else: < assert False, 'Please use local or global non-persistent checkpoints' \ < f'(got: {args.non_persistent_ckpt_type})' --- > save_dir = ( > args.non_persistent_global_ckpt_dir > if args.non_persistent_global_ckpt_dir > else os.path.join(save_dir, _NON_PERSISTENT_CKPT_SUBDIR) > ) > # TODO Can we ensure the previous checkpoint is saved? We don't want to allow two saves in parallel. 
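
The _add_unsloth_args group in the arguments.py hunk above registers --use-fast-cross-entropy-loss and --use-fast-rms-layernorm, and the commented-out transformer.py lines earlier in this diff hint at how the second flag would be consumed. A hedged sketch of that gating, with the unsloth import path and call signature copied from the commented-out lines rather than verified against unsloth itself:

    # Import path taken from the commented-out hunk; assumed, not verified.
    from unsloth.kernels.rms_layernorm import fast_rms_layernorm


    def apply_input_norm(input_norm, hidden_states, use_fast_rms_layernorm: bool = False):
        # Route through unsloth's fused RMSNorm kernel when the flag is set,
        # otherwise fall back to the stock norm module.
        if use_fast_rms_layernorm:
            return fast_rms_layernorm(input_norm, hidden_states)
        return input_norm(hidden_states)
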
> cleanup_old_non_persistent_checkpoint(save_dir, leave_ckpt_num=1, do_async=args.async_save) 355c336 < ckpt_format = args.ckpt_format if ckpt_type == CheckpointType.GLOBAL else 'torch' --- > ckpt_format = args.ckpt_format if use_dist_ckpt else 'torch' 360c341 < rng_state = get_rng_state(ckpt_type != CheckpointType.LEGACY) --- > rng_state = get_rng_state(use_dist_ckpt) 363d343 < return_base_dir = (ckpt_type != CheckpointType.LEGACY) 365c345 < tensor_rank=tensor_rank, pipeline_rank=pipeline_rank, expert_parallel=expert_parallel, expert_rank=expert_rank, return_base_dir=return_base_dir) --- > tensor_rank=tensor_rank, pipeline_rank=pipeline_rank, expert_parallel=expert_parallel, expert_rank=expert_rank, return_base_dir=use_dist_ckpt) 371,376c351 < if ( < args.use_distributed_optimizer < and not args.no_save_optim < and optimizer is not None < and ckpt_type == CheckpointType.LEGACY < ): --- > if args.use_distributed_optimizer and not args.no_save_optim and optimizer is not None and not use_dist_ckpt: 384c359 < if ckpt_type == CheckpointType.LEGACY: --- > if not args.use_dist_ckpt: 386c361 < elif ckpt_type == CheckpointType.GLOBAL and args.ckpt_format != 'torch_dist': --- > elif args.ckpt_format != 'torch_dist': 394c369 < or ckpt_type != CheckpointType.LEGACY: --- > or use_dist_ckpt: 396c371 < if ckpt_type != CheckpointType.LEGACY and args.use_distributed_optimizer: --- > if use_dist_ckpt and args.use_distributed_optimizer: 401,410c376,377 < state_dict = generate_state_dict( < args, < model, < optimizer, < opt_param_scheduler, < rng_state, < ckpt_type != CheckpointType.LEGACY, < iteration, < optim_sd_kwargs=optim_sd_kwargs, < ) --- > state_dict = generate_state_dict(args, model, optimizer, opt_param_scheduler, rng_state, > use_dist_ckpt, iteration, optim_sd_kwargs=optim_sd_kwargs) 415c382,386 < if ckpt_type == CheckpointType.GLOBAL: --- > if use_dist_ckpt: > if non_persistent_ckpt and args.non_persistent_ckpt_type != 'global': > raise NotImplementedError( > 'Local and online checkpoints are not yet supported, please use global non-persistent checkpoints' > ) 447,458c418,420 < if ckpt_type == CheckpointType.LOCAL: < state_dict_for_save = prepare_state_dict_for_save( < state_dict, algo=args.non_persistent_local_ckpt_algo < ) < async_save_request = checkpointing_context['local_checkpoint_manager'].save( < state_dict_for_save, iteration, is_async=bool(args.async_save) < ) < else: < assert ckpt_type == CheckpointType.LEGACY < # Save. < ensure_directory_exists(checkpoint_name) < torch.save(state_dict, checkpoint_name) --- > # Save. 
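
The checkpointing.py hunks above show the original file routing on a small CheckpointType enum where the modified file keeps a use_dist_ckpt boolean. A minimal sketch of the enum-based dispatch, reduced to the format-selection line visible in the hunk:

    from enum import Enum, auto


    class CheckpointType(Enum):
        # Mirrors the enum the original checkpointing.py defines.
        LEGACY = auto()
        LOCAL = auto()
        GLOBAL = auto()


    def resolve_ckpt_format(ckpt_type: CheckpointType, configured_format: str) -> str:
        # Only GLOBAL (distributed) checkpoints honour the configured format;
        # LEGACY and LOCAL saves fall back to plain torch serialization.
        return configured_format if ckpt_type == CheckpointType.GLOBAL else 'torch'
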
> ensure_directory_exists(checkpoint_name) > torch.save(state_dict, checkpoint_name) 468c430 < or torch.distributed.get_rank() == 0: --- > or torch.distributed.get_rank() == 0: 471,486c433,440 < if ckpt_type == CheckpointType.LOCAL: < def iter_finalize_fn(): < print_rank_0(' successfully saved local checkpoint from iteration {:7d}' < .format(iteration)) < if args.log_progress and args.async_save: < append_to_progress_log(f'Saved async local checkpoint\tIteration: {iteration}', < barrier=False) < else: < def iter_finalize_fn(): < with open(tracker_filename, 'w') as f: < f.write(str(iteration)) < print_rank_0(' successfully saved checkpoint from iteration {:7d} to {}' < .format(iteration, args.save)) < if args.log_progress and args.async_save: < append_to_progress_log(f'Saved async checkpoint\tIteration: {iteration}', < barrier=False) --- > def iter_finalize_fn(): > with open(tracker_filename, 'w') as f: > f.write(str(iteration)) > print_rank_0(' successfully saved checkpoint from iteration {:7d} to {}' > .format(iteration, args.save)) > if args.log_progress and args.async_save: > append_to_progress_log(f'Saved async checkpoint\tIteration: {iteration}', > barrier=False) 508c462 < .format(iteration, save_dir)) --- > .format(iteration, args.save)) 598a553 > #pdb.set_trace() 602a558 > #print("state_dict['model'] are:",state_dict['model']) 621a578 > #print("++++++ state_dict are:",state_dict) 691c648 < print_rank_0(" successfully fixed query-key-values ordering for" --- > print_rank_0(" succesfully fixed query-key-values ordering for" 695,699c652,654 < def _get_non_persistent_iteration(non_persistent_global_dir, args, checkpointing_context=None): < if args.non_persistent_ckpt_type is None: < return -1 < elif args.non_persistent_ckpt_type == "global": < tracker_filename = get_checkpoint_tracker_filename(non_persistent_global_dir) --- > def _get_non_persistent_iteration(non_persistent_dir, args): > if args.non_persistent_ckpt_type == "global": > tracker_filename = get_checkpoint_tracker_filename(non_persistent_dir) 709,711c664,665 < elif args.non_persistent_ckpt_type == "local": < raise RuntimeError('LocalCheckpointManagers are not yet integrated') < return checkpointing_context['local_checkpoint_manager'].get_latest_checkpoint_iteration() --- > elif args.non_persistent_ckpt_type is None: > return -1 713,714c667,669 < assert False, 'Please use local or global non-persistent checkpoints' \ < f'(got: {args.non_persistent_ckpt_type})' --- > raise NotImplementedError( > 'Local and online checkpoints are not yet supported, please use global non-persistent checkpoints' > ) 718,723c673 < non_persistent_global_dir, < args, < rank0, < sharded_state_dict, < non_persistent_iteration, < checkpointing_context=None, --- > non_persistent_dir, args, rank0, sharded_state_dict, non_persistent_iteration 729a680,685 > checkpoint_name = get_checkpoint_name( > non_persistent_dir, non_persistent_iteration, False, return_base_dir=True > ) > # "non_persistent" checkpoint is only used for distributed checkpoints > # Skipping the assert to avoid unnecessary disk access. 
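
Both sides of the checkpointing.py hunk above register a finalize callback that records the just-saved iteration in a tracker file once the (possibly async) save completes. A condensed, standalone sketch of that closure; print_rank_0 and append_to_progress_log are replaced by plain print, and the tracker path is illustrative:

    def make_iter_finalize_fn(tracker_filename: str, iteration: int, save_dir: str):
        # Returned closure is invoked after the save has finished.
        def iter_finalize_fn():
            with open(tracker_filename, 'w') as f:
                f.write(str(iteration))
            print(' successfully saved checkpoint from iteration {:7d} to {}'
                  .format(iteration, save_dir))
        return iter_finalize_fn


    # usage sketch:
    # finalize = make_iter_finalize_fn('/ckpts/latest_iteration.txt', 1000, '/ckpts')
    # finalize()
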
> # assert dist_checkpointing.check_is_distributed_checkpoint(checkpoint_name) 735c691 < non_persistent_global_dir, args, rank0, sharded_state_dict, non_persistent_iteration, False --- > non_persistent_dir, args, rank0, sharded_state_dict, non_persistent_iteration, False 737,747d692 < elif args.non_persistent_ckpt_type == "local": < raise RuntimeError('LocalCheckpointManagers are not yet integrated') < intermediate_state_dict, checkpoint_name = checkpointing_context[ < 'local_checkpoint_manager' < ].load() < state_dict = recreate_state_dict_after_load( < sharded_state_dict, < intermediate_state_dict, < algo=args.non_persistent_local_ckpt_algo, < ) < return state_dict, checkpoint_name, False, CheckpointType.LOCAL 749,750c694,696 < assert False, 'Please use local or global non-persistent checkpoints' \ < f'(got: {args.non_persistent_ckpt_type})' --- > raise NotImplementedError( > 'Local and online checkpoints are not yet supported, please use global non-persistent checkpoints' > ) 760c706 < return state_dict, checkpoint_name, release, CheckpointType.GLOBAL --- > return state_dict, checkpoint_name, release 779c725 < return state_dict, checkpoint_name, release, CheckpointType.GLOBAL --- > return state_dict, checkpoint_name, release 783,787c729 < load_dir, < args, < rank0=False, < sharded_state_dict=None, < checkpointing_context=None, --- > load_dir, args, rank0=False, sharded_state_dict=None 794c736 < non_persistent_global_dir = ( --- > non_persistent_dir = ( 796c738 < if args.non_persistent_global_ckpt_dir or load_dir is None --- > if args.non_persistent_global_ckpt_dir 799,807c741,746 < non_persistent_iteration = _get_non_persistent_iteration( < non_persistent_global_dir, args, checkpointing_context < ) < iteration, release = -1, False < tracker_filename = 'because load directory is not defined' < if load_dir is not None: < tracker_filename = get_checkpoint_tracker_filename(load_dir) < if os.path.isfile(tracker_filename): < iteration, release = read_metadata(tracker_filename) --- > non_persistent_iteration = _get_non_persistent_iteration(non_persistent_dir, args) > tracker_filename = get_checkpoint_tracker_filename(load_dir) > if os.path.isfile(tracker_filename): > iteration, release = read_metadata(tracker_filename) > else: > iteration, release = -1, False 811,816c750 < non_persistent_global_dir, < args, < rank0, < sharded_state_dict, < non_persistent_iteration, < checkpointing_context, --- > non_persistent_dir, args, rank0, sharded_state_dict, non_persistent_iteration 834c768 < return None, "", False, None --- > return None, "", False 852a787 > 880c815 < return state_dict, checkpoint_name, release, CheckpointType.LEGACY --- > return state_dict, checkpoint_name, release 883,885c818 < def load_args_from_checkpoint( < args, load_arg='load', checkpointing_context=None < ): --- > def load_args_from_checkpoint(args, load_arg='load'): 904,908c837,838 < state_dict, checkpoint_name, release, ckpt_type = _load_base_checkpoint( < load_dir, < args, < rank0=True, < checkpointing_context=checkpointing_context, --- > state_dict, checkpoint_name, release = _load_base_checkpoint( > load_dir, args, rank0=True 978,991d907 < def fix_fp8_params_lose_precision_when_loading_dist_ckpt(state_dict): < """ < When "--fp8-param-gather" and "--use-dist-ckpt" are both enabled, the state dict read from < dist-checkpoint loses precision (the weights read from checkpoint go through the process of < bf16/fp16 -> fp8 -> bf16/fp16). This function is implemented to solve this problem. 
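
On the load side, the original _load_base_checkpoint above additionally tolerates a missing load directory before consulting the tracker file. A small sketch of that guarded lookup, with the two Megatron helpers passed in as callables so the snippet stays self-contained:

    import os


    def find_latest_iteration(load_dir, get_tracker_filename, read_metadata):
        # Default to "no checkpoint found" (-1, not a release) when there is no load
        # directory or the tracker file has not been written yet.
        iteration, release = -1, False
        if load_dir is not None:
            tracker_filename = get_tracker_filename(load_dir)
            if os.path.isfile(tracker_filename):
                iteration, release = read_metadata(tracker_filename)
        return iteration, release
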
< When "--fp8-param-gather" is disabled, this function doesn't modify anything. < """ < for key in state_dict.keys(): < if key.startswith('model'): < for _, sharded_tensor in state_dict[key].items(): < if is_float8tensor(sharded_tensor.data): < sharded_tensor.data = sharded_tensor.data.from_float8().cpu() < < 993c909 < ft_client=None, checkpointing_context=None): --- > ft_client=None): 1022,1026c938,939 < state_dict, checkpoint_name, release, ckpt_type = _load_base_checkpoint( < load_dir, < args, < rank0=True, < checkpointing_context=checkpointing_context, --- > state_dict, checkpoint_name, release = _load_base_checkpoint( > load_dir, args, rank0=True 1027a941 > 1033,1036c947,948 < is_dist_ckpt = ( < ckpt_type == CheckpointType.LOCAL < or dist_checkpointing.check_is_distributed_checkpoint(checkpoint_name) < ) --- > > is_dist_ckpt = dist_checkpointing.check_is_distributed_checkpoint(checkpoint_name) 1086,1087d997 < # When "--fp8-param-gather" is disabled, this function doesn't modify anything. < fix_fp8_params_lose_precision_when_loading_dist_ckpt(load_kwargs['sharded_state_dict']) 1089,1091c999,1000 < state_dict, checkpoint_name, release, ckpt_type = _load_base_checkpoint( < load_dir, args, rank0=False, checkpointing_context=checkpointing_context, < **load_kwargs --- > state_dict, checkpoint_name, release = _load_base_checkpoint( > load_dir, args, rank0=False, **load_kwargs 1142,1146c1051 < if ckpt_type == CheckpointType.LOCAL: < raise NotImplementedError('Local checkpointing does not support model opt') < if not args.use_dist_ckpt: < restore_modelopt_state(model, state_dict) < else: --- > if args.use_dist_ckpt: 1147a1053,1054 > else: > restore_modelopt_state(model, state_dict) diff -rN ./megatron/training/ft_integration.py ../megatron-lm/megatron/training/ft_integration.py 92c92 < from nvidia_resiliency_ext.fault_tolerance import RankMonitorClient --- > from fault_tolerance import RankMonitorClient diff -rN ./megatron/training/initialize.py ../megatron-lm/megatron/training/initialize.py 6a7,8 > import packaging > import packaging.version 173c175 < fused_kernels.load(args) --- > #fused_kernels.load(args) 177c179 < fused_kernels.load(args) --- > #fused_kernels.load(args) 243,244c245,254 < torch.cuda.set_device(args.local_rank) < device_id = torch.device(f'cuda:{args.local_rank}') --- > #torch.cuda.set_device(args.local_rank) > #device_id = torch.device(f'cuda:{args.local_rank}') > device_id = args.rank % device_count > if args.local_rank is not None: > assert ( > args.local_rank == device_id > ), "expected local-rank to be the same as rank % device-count." 
> else: > args.local_rank = device_id > torch.cuda.set_device(device_id) 249,254c259,273 < init_process_group_kwargs = { < 'backend' : args.distributed_backend, < 'world_size': args.world_size, < 'rank': args.rank, < 'timeout': timedelta(minutes=args.distributed_timeout_minutes), < } --- > torch.distributed.init_process_group( > backend=args.distributed_backend, > world_size=args.world_size, > rank=args.rank, > init_method=args.dist_url, > timeout=timedelta(minutes=args.distributed_timeout_minutes), > ) > #init_process_group_kwargs = { > # 'backend' : args.distributed_backend, > # 'world_size': args.world_size, > # 'rank': args.rank, > # 'timeout': timedelta(minutes=args.distributed_timeout_minutes), > #} > #if packaging.version.Version(torch.__version__) >= packaging.version.Version("2.3.0"): > # init_process_group_kwargs['device_id'] = device_id 256c275 < torch.distributed.init_process_group(**init_process_group_kwargs) --- > #torch.distributed.init_process_group(**init_process_group_kwargs) 337c356 < torch._C._jit_set_nvfuser_enabled(True) --- > torch._C._jit_set_nvfuser_enabled(False) #True Binary files ./megatron/training/__pycache__/activations.cpython-310.pyc and ../megatron-lm/megatron/training/__pycache__/activations.cpython-310.pyc differ Binary files ./megatron/training/__pycache__/arguments.cpython-310.pyc and ../megatron-lm/megatron/training/__pycache__/arguments.cpython-310.pyc differ Binary files ./megatron/training/__pycache__/async_utils.cpython-310.pyc and ../megatron-lm/megatron/training/__pycache__/async_utils.cpython-310.pyc differ Binary files ./megatron/training/__pycache__/checkpointing.cpython-310.pyc and ../megatron-lm/megatron/training/__pycache__/checkpointing.cpython-310.pyc differ Binary files ./megatron/training/__pycache__/dist_signal_handler.cpython-310.pyc and ../megatron-lm/megatron/training/__pycache__/dist_signal_handler.cpython-310.pyc differ Binary files ./megatron/training/__pycache__/ft_integration.cpython-310.pyc and ../megatron-lm/megatron/training/__pycache__/ft_integration.cpython-310.pyc differ Binary files ./megatron/training/__pycache__/global_vars.cpython-310.pyc and ../megatron-lm/megatron/training/__pycache__/global_vars.cpython-310.pyc differ Binary files ./megatron/training/__pycache__/__init__.cpython-310.pyc and ../megatron-lm/megatron/training/__pycache__/__init__.cpython-310.pyc differ Binary files ./megatron/training/__pycache__/initialize.cpython-310.pyc and ../megatron-lm/megatron/training/__pycache__/initialize.cpython-310.pyc differ Binary files ./megatron/training/__pycache__/log_handler.cpython-310.pyc and ../megatron-lm/megatron/training/__pycache__/log_handler.cpython-310.pyc differ Binary files ./megatron/training/__pycache__/one_logger_utils.cpython-310.pyc and ../megatron-lm/megatron/training/__pycache__/one_logger_utils.cpython-310.pyc differ Binary files ./megatron/training/__pycache__/theoretical_memory_usage.cpython-310.pyc and ../megatron-lm/megatron/training/__pycache__/theoretical_memory_usage.cpython-310.pyc differ Binary files ./megatron/training/__pycache__/training.cpython-310.pyc and ../megatron-lm/megatron/training/__pycache__/training.cpython-310.pyc differ Binary files ./megatron/training/__pycache__/utils.cpython-310.pyc and ../megatron-lm/megatron/training/__pycache__/utils.cpython-310.pyc differ Binary files ./megatron/training/__pycache__/yaml_arguments.cpython-310.pyc and ../megatron-lm/megatron/training/__pycache__/yaml_arguments.cpython-310.pyc differ Binary files 
./megatron/training/tokenizer/__pycache__/bert_tokenization.cpython-310.pyc and ../megatron-lm/megatron/training/tokenizer/__pycache__/bert_tokenization.cpython-310.pyc differ Binary files ./megatron/training/tokenizer/__pycache__/gpt2_tokenization.cpython-310.pyc and ../megatron-lm/megatron/training/tokenizer/__pycache__/gpt2_tokenization.cpython-310.pyc differ Binary files ./megatron/training/tokenizer/__pycache__/__init__.cpython-310.pyc and ../megatron-lm/megatron/training/tokenizer/__pycache__/__init__.cpython-310.pyc differ Binary files ./megatron/training/tokenizer/__pycache__/tokenizer.cpython-310.pyc and ../megatron-lm/megatron/training/tokenizer/__pycache__/tokenizer.cpython-310.pyc differ diff -rN ./megatron/training/training.py ../megatron-lm/megatron/training/training.py 23,28c23 < from megatron.core.utils import ( < check_param_hashes_across_dp_replicas, < get_model_config, < StragglerDetector, < is_float8tensor, < ) --- > from megatron.core.utils import check_param_hashes_across_dp_replicas, get_model_config, StragglerDetector 31d25 < from megatron.training.checkpointing import checkpoint_exists 43d36 < from megatron.core.transformer.moe import upcycling_utils 81c74 < stimer = StragglerDetector() --- > import pdb 82a76 > stimer = StragglerDetector() 89c83 < --- > 275,285d268 < # Context used for persisting some state between checkpoint saves. < if args.non_persistent_ckpt_type == 'local': < raise RuntimeError('LocalCheckpointManagers are not yet integrated') < checkpointing_context = { < 'local_checkpoint_manager': BasicLocalCheckpointManager( < args.non_persistent_local_ckpt_dir < ) < } < else: < checkpointing_context = {} < 290c273 < model_provider, model_type, checkpointing_context=checkpointing_context) --- > model_provider, model_type) 325a309,311 > # Context used for persisting some state between checkpoint saves. > checkpointing_context = {} > 505,519d490 < # The model_module.bfloat16()/model_module.half() above will call the inplace copy of TE's < # Float8Tensor, which will write an unwanted value (amax calculated from the current fp8 < # param) to its amax_history. The following logic will correct the amax_history back. 
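
The initialize.py hunk further above replaces the LOCAL_RANK-driven device selection with one derived from the global rank and passes an explicit init_method URL (args.dist_url) to init_process_group. A standalone sketch of that wiring, with argument names borrowed from the hunk:

    from datetime import timedelta

    import torch


    def init_distributed(rank: int, world_size: int, dist_url: str,
                         backend: str = 'nccl', timeout_minutes: int = 10,
                         local_rank=None) -> int:
        # Derive the local CUDA device from the global rank and the visible GPU count.
        device_count = torch.cuda.device_count()
        device_id = rank % device_count
        if local_rank is not None:
            assert local_rank == device_id, \
                "expected local-rank to be the same as rank % device-count."
        torch.cuda.set_device(device_id)
        # Join the process group through the explicit rendezvous URL rather than env:// defaults.
        torch.distributed.init_process_group(
            backend=backend,
            world_size=world_size,
            rank=rank,
            init_method=dist_url,
            timeout=timedelta(minutes=timeout_minutes),
        )
        return device_id
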
< for model_module in model: < for param in model_module.parameters(): < if is_float8tensor(param) and param._fp8_meta is not None: < fp8_meta = param._fp8_meta['scaling_fwd'] < fp8_meta_index = param._fp8_meta_index < if hasattr(param, 'get_high_precision_init_val'): < fp8_meta.amax_history[0][fp8_meta_index].copy_( < param.get_high_precision_init_val().abs().max() < ) < else: < fp8_meta.amax_history[0][fp8_meta_index] = 0 < 522,532c493,499 < < kwargs = {} < for f in dataclasses.fields(DistributedDataParallelConfig): < if hasattr(args, f.name): < kwargs[f.name] = getattr(args, f.name) < kwargs['grad_reduce_in_fp32'] = args.accumulate_allreduce_grads_in_fp32 < kwargs['check_for_nan_in_grad'] = args.check_for_nan_in_loss_and_grad < kwargs['bucket_size'] = args.ddp_bucket_size < kwargs['average_in_collective'] = args.ddp_average_in_collective < ddp_config = DistributedDataParallelConfig(**kwargs) < --- > ddp_config = DistributedDataParallelConfig( > grad_reduce_in_fp32=args.accumulate_allreduce_grads_in_fp32, > overlap_grad_reduce=args.overlap_grad_reduce, > use_distributed_optimizer=args.use_distributed_optimizer, > check_for_nan_in_grad=args.check_for_nan_in_loss_and_grad, > bucket_size=args.ddp_bucket_size, > average_in_collective=args.ddp_average_in_collective) 610,611c577 < lr_mult=1.0, < checkpointing_context=None): --- > lr_mult=1.0): 630,659c596 < if args.moe_use_upcycling: < torch.distributed.barrier() < assert not checkpoint_exists( < args.save < ), ("The upcycling destination directory already exists. " < "Please check if --moe-use-upcycling is mistakenly enabled. " < "Upcycling should only be set for the first run when converting the dense model. " < "All subsequent runs should remove this flag. ") < num_experts = args.num_experts < args.num_experts = None < expert_model_parallel_size = args.expert_model_parallel_size < args.expert_model_parallel_size = 1 < dense_model_for_upcycling = get_model(model_provider_func, model_type) < args.num_experts = num_experts < args.expert_model_parallel_size = expert_model_parallel_size < _, args.num_floating_point_operations_so_far = upcycling_utils.load_and_upcycle_model( < load_checkpoint, < unwrapped_model, < dense_model_for_upcycling, < load_kwargs = {'model': dense_model_for_upcycling, 'optimizer': None, 'opt_param_scheduler': None} < ) < args.iteration = 1 < save_checkpoint(args.iteration, model, None, None, args.num_floating_point_operations_so_far) < torch.distributed.barrier() < del dense_model_for_upcycling < if (args.fp16 or args.bf16) and optimizer is not None: < optimizer.reload_model_params() < print_rank_0(f'Upcycled checkpoint saved to {args.save}') < < if (args.load is not None or args.pretrained_checkpoint is not None) and not args.moe_use_upcycling: --- > if args.load is not None or args.pretrained_checkpoint is not None: 667c604,605 < ft_client=ft_integration.get_rank_monitor_client(), checkpointing_context=checkpointing_context) --- > ft_client=ft_integration.get_rank_monitor_client()) > 692c630 < --- > 702a641 > 1060a1000 > 1067a1008 > 1139c1080,1081 < config.param_sync_func = [model_chunk.start_param_sync for model_chunk in model] --- > config.param_sync_func = [functools.partial(optimizer.start_param_sync, model_index) > for model_index in range(len(model))] 1192,1203d1133 < if args.profile and torch.distributed.get_rank() in args.profile_ranks and args.use_pytorch_profiler: < prof = torch.profiler.profile( < schedule=torch.profiler.schedule( < wait=max(args.profile_step_start-1, 0), < warmup=1 if args.profile_step_start 
> 0 else 0, < active=args.profile_step_end-args.profile_step_start, < repeat=1), < on_trace_ready=torch.profiler.tensorboard_trace_handler(args.tensorboard_dir), < record_shapes=True, < with_stack=True) < prof.start() < 1205,1210c1135,1139 < if args.profile and torch.distributed.get_rank() in args.profile_ranks: < if args.use_pytorch_profiler: < prof.step() < elif iteration == args.profile_step_start: < torch.cuda.cudart().cudaProfilerStart() < torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__() --- > if args.profile and \ > iteration == args.profile_step_start and \ > torch.distributed.get_rank() in args.profile_ranks: > torch.cuda.cudart().cudaProfilerStart() > torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__() 1217a1147 > #pdb.set_trace() 1305c1235 < assert check_param_hashes_across_dp_replicas(model, cross_check=True), \ --- > assert check_param_hashes_across_dp_replicas(model), \ 1378d1307 < checkpointing_context, 1415,1420c1344,1346 < iteration == args.profile_step_end and \ < torch.distributed.get_rank() in args.profile_ranks: < if args.use_pytorch_profiler: < prof.stop() < else: < torch.cuda.cudart().cudaProfilerStop() --- > iteration == args.profile_step_end and \ > torch.distributed.get_rank() in args.profile_ranks: > torch.cuda.cudart().cudaProfilerStop()
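
The training.py hunk above replaces a dataclass-field-driven construction of DistributedDataParallelConfig with an explicitly enumerated one. A generic sketch of the field-driven pattern, using a stand-in config type with illustrative fields rather than the real Megatron class:

    import dataclasses
    from typing import Optional


    @dataclasses.dataclass
    class DDPConfigSketch:
        # Stand-in for DistributedDataParallelConfig; only a few illustrative fields.
        grad_reduce_in_fp32: bool = False
        overlap_grad_reduce: bool = False
        bucket_size: Optional[int] = None


    def build_ddp_config(args) -> DDPConfigSketch:
        # Copy every argparse attribute whose name matches a config field, then apply
        # the explicit overrides the original code keeps (shown for one field here).
        kwargs = {f.name: getattr(args, f.name)
                  for f in dataclasses.fields(DDPConfigSketch) if hasattr(args, f.name)}
        kwargs['grad_reduce_in_fp32'] = getattr(args, 'accumulate_allreduce_grads_in_fp32', False)
        return DDPConfigSketch(**kwargs)

The final hunks above concern the built-in PyTorch profiler path that the original training loop wires up behind --use-pytorch-profiler (the modified file keeps only the CUDA/nsys path). A hedged sketch of the profiler the original file constructs:

    import torch


    def make_pytorch_profiler(step_start: int, step_end: int, tensorboard_dir: str):
        # Mirrors the schedule built in the original training loop: wait until step_start,
        # warm up for one step, profile the requested window once, stream traces to TensorBoard.
        return torch.profiler.profile(
            schedule=torch.profiler.schedule(
                wait=max(step_start - 1, 0),
                warmup=1 if step_start > 0 else 0,
                active=step_end - step_start,
                repeat=1),
            on_trace_ready=torch.profiler.tensorboard_trace_handler(tensorboard_dir),
            record_shapes=True,
            with_stack=True)


    # usage sketch: prof = make_pytorch_profiler(10, 12, './tb_logs')
    # prof.start(); then prof.step() once per training iteration; prof.stop() at the end.
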