Commit 051f58f1 authored by liangjing's avatar liangjing
Browse files

v1

parent 0024a5c6
Pipeline #829 passed with stage
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Various loading and saving strategies """
import logging
logger = logging.getLogger(__name__)
try:
import tensorstore
import zarr
from .tensorstore import _import_trigger
from .zarr import _import_trigger
except ImportError:
logger.warning('Zarr-based strategies will not be registered because of missing packages')
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
from abc import ABC, abstractmethod
from collections import defaultdict
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional
from ..mapping import CheckpointingException, ShardedStateDict, ShardedTensor, StateDict
class StrategyAction(Enum):
LOAD_COMMON = 'load_common'
LOAD_SHARDED = 'load_sharded'
SAVE_COMMON = 'save_common'
SAVE_SHARDED = 'save_sharded'
default_strategies = defaultdict(dict)
def get_default_strategy(action: StrategyAction, backend: str, version: int):
try:
return default_strategies[action.value][(backend, version)]
except KeyError as e:
raise CheckpointingException(
f'Cannot find default strategy for: {(action, backend, version)}'
) from e
class LoadStrategyBase(ABC):
@abstractmethod
def check_backend_compatibility(self, loaded_version):
raise NotImplementedError
@abstractmethod
def check_version_compatibility(self, loaded_version):
raise NotImplementedError
class SaveStrategyBase(ABC):
def __init__(self, backend: str, version: int):
self.backend = backend
self.version = version
class LoadCommonStrategy(LoadStrategyBase):
@abstractmethod
def load(self, checkpoint_dir: Path):
raise NotImplementedError
class LoadShardedStrategy(LoadStrategyBase):
@abstractmethod
def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
raise NotImplementedError
class SaveCommonStrategy(SaveStrategyBase):
@abstractmethod
def save(self, common_state_dict: StateDict, checkpoint_dir: Path):
raise NotImplementedError
class SaveShardedStrategy(SaveStrategyBase):
@abstractmethod
def save(self, sharded_tensors: List[ShardedTensor], checkpoint_dir: Path):
raise NotImplementedError
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Strategies using TensorStore to load and save Zarr arrays. """
from functools import partial
from itertools import starmap
from pathlib import Path
import tensorstore as ts
import torch
from ..core import CheckpointingException
from ..dict_utils import dict_list_map_inplace
from ..mapping import ShardedStateDict, ShardedTensor
from .base import LoadShardedStrategy, StrategyAction, default_strategies
from .zarr import postprocess_numpy_array
_import_trigger = None
class TensorStoreLoadShardedStrategy(LoadShardedStrategy):
def __init__(self, load_directly_on_device: bool = False):
super().__init__()
self.load_directly_on_device = load_directly_on_device
def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
if torch.distributed.get_rank() == 0:
print(f'Loading distributed checkpoint with {self.__class__.__name__}')
if self.load_directly_on_device:
print(f'Loading distributed checkpoint directly on the GPU')
load_fn = partial(
_load_from_array,
checkpoint_dir=checkpoint_dir,
load_directly_on_device=self.load_directly_on_device,
)
dict_list_map_inplace(load_fn, sharded_state_dict)
return sharded_state_dict
def check_backend_compatibility(self, loaded_version):
pass # TODO
def check_version_compatibility(self, loaded_version):
pass # TODO
def merge_global_slice_with_shape(global_slice, actual_shape, key):
def _merge_slice(dim_slice, dim_size):
if isinstance(dim_slice, slice):
assert (
dim_slice.start < dim_size
), f'Got empty slice for ShardedTensor {key} ({dim_slice}, {dim_size})'
if dim_slice.stop > dim_size:
dim_slice = slice(dim_slice.start, dim_size, dim_slice.step)
return dim_slice
assert len(global_slice) == len(actual_shape), (global_slice, actual_shape, key)
return tuple(starmap(_merge_slice, zip(global_slice, actual_shape)))
def _load_from_array(
sharded_tensor: ShardedTensor,
checkpoint_dir: Path,
load_directly_on_device: bool = False,
apply_flattened_range: bool = True,
):
x = _load_regular_chunk(sharded_tensor, checkpoint_dir)
ten = postprocess_numpy_array(x, sharded_tensor, apply_flattened_range)
if load_directly_on_device:
sharded_tensor.data.data.copy_(ten)
return sharded_tensor.data
else:
return ten
def _load_regular_chunk(sharded_tensor: ShardedTensor, checkpoint_dir: Path):
assert isinstance(sharded_tensor, ShardedTensor), type(sharded_tensor)
spec = {'driver': 'zarr', 'metadata_key': '.zarray', 'kvstore': {}}
spec['kvstore'] = {
'driver': 'file',
'path': str(checkpoint_dir / sharded_tensor.key),
}
try:
arr = ts.open(ts.Spec(spec), open=True).result()
except Exception as e:
raise CheckpointingException(
f'Array {checkpoint_dir / sharded_tensor.key} could not be loaded. Error: {e}'
) from e
if sharded_tensor.global_shape == arr.shape:
x = (
arr[sharded_tensor.global_slice()].read().result()
) # flattened tensors loading is delayed
elif sharded_tensor.allow_shape_mismatch:
global_slice = merge_global_slice_with_shape(
sharded_tensor.global_slice(), arr.shape, sharded_tensor.key
)
x = arr[global_slice].read().result() # flattened tensors loading is delayed
else:
_msg = (
f'Global shape mismatch for loaded ({arr.shape})'
f' and expected ({sharded_tensor.global_shape}) tensor'
f' for key {sharded_tensor.key}'
)
raise CheckpointingException(_msg)
return x
default_strategies[StrategyAction.LOAD_SHARDED.value][
('zarr', 1)
] = TensorStoreLoadShardedStrategy()
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" 2-stage checkpoint loading. """
import os
import time
from collections import defaultdict
from dataclasses import dataclass
from functools import partial, wraps
from itertools import chain
from logging import DEBUG, INFO, StreamHandler, getLogger
from operator import attrgetter, itemgetter
from pathlib import Path
from typing import Iterable, List, NamedTuple, Optional, Tuple, Union
import torch
from ..dict_utils import dict_list_map_inplace, map_reduce, nested_values
from ..mapping import ShardedStateDict, ShardedTensor, StateDict
from .base import LoadShardedStrategy
from .tensorstore import _load_from_array
from .zarr import flatten_range
_import_trigger = None
timers = defaultdict(list)
logger = getLogger(__name__)
def timed(verbose=True):
def timed_dec(fn):
name = fn.__name__
@wraps(fn)
def wrapped(*args, **kwargs):
if verbose:
logger.debug(f'{name} init')
start = time.time()
ret = fn(*args, **kwargs)
took = time.time() - start
if verbose:
logger.debug(f'{name} took {took}s')
timers[name].append(took)
return ret
return wrapped
return timed_dec
@dataclass
class _ShardedTensorMetadata:
global_rank: int
sharded_tensor_no_data: ShardedTensor
dist_group_rank: Tuple[int] # id of distributed group
dist_group_ranks: Tuple[int] # id of distributed group
data_size: Optional[int] = None # bytes
def sharded_tensor_chunk_id(sharded_tensor: ShardedTensor):
return (
sharded_tensor.key,
sharded_tensor.global_offset,
)
class TwoStageDataParallelLoadShardedStrategy(LoadShardedStrategy):
""" Loads one checkpoint replica from storage and broadcasts to other nodes.
This strategy loads checkpoint from storage on minimal set of nodes
and distributes the checkpoint to other nodes with torch.distributed.
Loading is performed with tensorstore.
Steps:
0. (optional) create Gloo distributed groups
1. Exchange ShardedTensors metadata between all nodes
2. Align needed tensors within DP groups
3. For each globally unique tensor:
a) on one of the ranks load it from storage to CPU and move to CUDA
b) allocate CUDA tensor on other ranks
c) broadcast within DP group
d) copy tensor content to the model param location
e) free tensor buffers from a) and b)
Notes:
1. Loading and broadcasting is done sequentially to avoid both host and device OOMs
2. There is a lot of overlap potential between all three steps done for each tensor:
a) loading from storage to numpy
b) moving CPU tensors to CUDA
c) broadcast
"""
def __init__(self, data_parallel_group, cpu_transfer=True):
super().__init__()
self.cpu_transfer = cpu_transfer
self.data_parallel_group_orig = data_parallel_group
self.data_parallel_group = None if cpu_transfer else data_parallel_group
self.dp_group_ranks = tuple(
sorted(torch.distributed.get_process_group_ranks(data_parallel_group))
)
self.dp_group_rank = torch.distributed.get_rank(self.data_parallel_group_orig)
self.global_rank = torch.distributed.get_rank()
def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
self.maybe_init_gloo_group()
all_tensors_sorted = self._build_load_plan(sharded_state_dict)
self._exchange_loaded_tensors(all_tensors_sorted, sharded_state_dict, checkpoint_dir)
self.summarize_load_times()
return sharded_state_dict
def summarize_load_times(self):
torch.distributed.barrier()
logger.info('Checkpoint loading finished. Summary:')
for key, times in sorted(timers.items()):
times_sum = sum(times)
max_times = torch.tensor([times_sum], device='cuda')
avg_times = torch.tensor([times_sum], device='cuda')
torch.distributed.all_reduce(max_times, op=torch.distributed.ReduceOp.MAX)
torch.distributed.all_reduce(avg_times, op=torch.distributed.ReduceOp.SUM)
avg_times /= torch.distributed.get_world_size()
if torch.distributed.get_rank() == 0:
logger.info(f'{key}: max {max_times[0]}, avg {avg_times[0]}')
@timed(verbose=False)
def load_tensor_from_storage(self, checkpoint_dir, ten_meta: _ShardedTensorMetadata):
logger.debug(f'_load_from_array({ten_meta.sharded_tensor_no_data.key}) init')
ret = _load_from_array(
ten_meta.sharded_tensor_no_data,
checkpoint_dir,
load_directly_on_device=False,
apply_flattened_range=False,
)
logger.debug(f'_load_from_array({ten_meta.sharded_tensor_no_data.key}) DONE')
return ret
@timed()
def maybe_init_gloo_group(self):
if not self.cpu_transfer:
return
all_groups = [None] * torch.distributed.get_world_size()
torch.distributed.all_gather_object(all_groups, self.dp_group_ranks)
all_groups = set(tuple(sorted(gr)) for gr in all_groups)
for group_ranks in sorted(all_groups):
gloo_pg = torch.distributed.new_group(ranks=group_ranks, backend='gloo')
if self.global_rank in group_ranks:
self.data_parallel_group = gloo_pg
assert self.dp_group_rank == torch.distributed.get_rank(self.data_parallel_group)
def check_backend_compatibility(self, loaded_version):
pass # TODO
def check_version_compatibility(self, loaded_version):
pass # TODO
@timed()
def _build_load_plan(
self, sharded_state_dict: ShardedStateDict
) -> List[_ShardedTensorMetadata]:
local_meta = [
_ShardedTensorMetadata(
self.global_rank,
sharded_ten.without_data(),
self.dp_group_rank,
self.dp_group_ranks,
)
for sharded_ten in nested_values(sharded_state_dict)
]
all_meta = [None] * torch.distributed.get_world_size(group=self.data_parallel_group)
torch.distributed.all_gather_object(all_meta, local_meta, group=self.data_parallel_group)
all_meta = list(chain.from_iterable(all_meta))
all_tensors_sorted = self.deduplicate_chunks(all_meta)
return all_tensors_sorted
@timed()
def deduplicate_chunks(self, ten_metas: List[_ShardedTensorMetadata]):
""" Group tensors by chunk and then pick the tensor with the lowest rank.
NOTE: with proper loading overlap, loading from randomized ranks
(instead of the smallest one) could be beneficial here.
"""
ten_metas = map_reduce(
ten_metas,
key_fn=lambda meta: sharded_tensor_chunk_id(meta.sharded_tensor_no_data),
reduce_fn=partial(min, key=attrgetter('dist_group_rank')),
)
all_metas_sorted = list(map(itemgetter(1), sorted(ten_metas.items())))
return all_metas_sorted
@timed()
def _exchange_loaded_tensors(
self, ten_metas: List[_ShardedTensorMetadata], sharded_state_dict, checkpoint_dir
):
logger.debug(f'_exchange_loaded_tensors, num ten_metas: {len(ten_metas)}')
for ten_meta in ten_metas:
src_rank = torch.distributed.get_global_rank(
self.data_parallel_group, ten_meta.dist_group_rank
)
if self.dp_group_rank == ten_meta.dist_group_rank:
exchange_tensor = self.load_tensor_from_storage(checkpoint_dir, ten_meta)
if not self.cpu_transfer:
exchange_tensor = exchange_tensor.cuda()
else:
# TODO: for non-flattened ranges we could reuse the buffer from the start here
exchange_tensor = torch.empty(
ten_meta.sharded_tensor_no_data.local_shape,
device='cpu' if self.cpu_transfer else 'cuda',
dtype=ten_meta.sharded_tensor_no_data.dtype,
)
logger.debug(
f'exchange {ten_meta.sharded_tensor_no_data.key}, {exchange_tensor.shape}({exchange_tensor.numel()}), broadcast({src_rank} -> {self.dp_group_ranks})'
)
torch.distributed.broadcast(
exchange_tensor, group=self.data_parallel_group, src=src_rank
)
self._distribute_data_to_state_dict(ten_meta, exchange_tensor, sharded_state_dict)
logger.debug(f'exchange {ten_meta.sharded_tensor_no_data.key} done')
# free buffer memory
exchange_tensor = None
@timed(verbose=False)
def _distribute_data_to_state_dict(
self,
ten_meta: _ShardedTensorMetadata,
loaded_ten: torch.Tensor,
sharded_state_dict: ShardedStateDict,
):
tensor_key = sharded_tensor_chunk_id(ten_meta.sharded_tensor_no_data)
def _fill_in_data(t: Union[ShardedTensor, torch.Tensor]):
if not isinstance(t, ShardedTensor) or sharded_tensor_chunk_id(t) != tensor_key:
# already filled-in or key not matching
return t
sharded_tensor: ShardedTensor = t
x = loaded_ten
if sharded_tensor.flattened_range is not None:
x = flatten_range(sharded_tensor, x)
# Reuse existing buffer
sharded_tensor.data.data.copy_(x)
return sharded_tensor.data
dict_list_map_inplace(_fill_in_data, sharded_state_dict)
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Strategies using Zarr as an underlying format. """
import os
from functools import partial
from pathlib import Path
from typing import List
import numpy as np
import torch
import zarr
from ..core import CheckpointingException
from ..dict_utils import dict_list_map_inplace
from ..mapping import ShardedStateDict, ShardedTensor, is_main_replica
from .base import LoadShardedStrategy, SaveShardedStrategy, StrategyAction, default_strategies
numpy_to_torch_dtype_dict = {
np.bool_: torch.bool,
np.uint8: torch.uint8,
np.int8: torch.int8,
np.int16: torch.int16,
np.int32: torch.int32,
np.int64: torch.int64,
np.float16: torch.float16,
np.float32: torch.float32,
np.float64: torch.float64,
np.complex64: torch.complex64,
np.complex128: torch.complex128,
}
torch_to_numpy_dtype_dict = {v: k for k, v in numpy_to_torch_dtype_dict.items()}
try:
import tensorstore
HAS_BFLOAT16 = True
numpy_to_torch_dtype_dict[np.dtype('bfloat16')] = torch.bfloat16
torch_to_numpy_dtype_dict[torch.bfloat16] = np.dtype('bfloat16')
except ImportError:
HAS_BFLOAT16 = False
_import_trigger = None
class ZarrSaveShardedStrategy(SaveShardedStrategy):
def save(self, sharded_tensors: List[ShardedTensor], checkpoint_dir: Path):
arrays = _create_or_open_zarr_arrays(sharded_tensors, checkpoint_dir)
for ten, arr in zip(sharded_tensors, arrays):
_save_to_existing_array(ten, arr)
torch.distributed.barrier()
def _create_or_open_zarr_arrays(
sharded_tensors: List[ShardedTensor], checkpoint_dir: Path
) -> List[zarr.Array]:
arrays = []
for ten in sharded_tensors:
if _should_create_array(ten):
_create_zarr_array(ten, checkpoint_dir)
# TODO: maybe reuse the opened arrays
torch.distributed.barrier()
for ten in sharded_tensors:
# if is_main_replica(ten.replica_id) and set(ten.global_offset) == {0}:
# continue
open_kwargs = {}
if ten.flattened_range is not None:
open_kwargs['synchronizer'] = zarr.ProcessSynchronizer(
str(checkpoint_dir / f'{ten.key}.sync')
)
arr = zarr.open(checkpoint_dir / ten.key, 'r+', **open_kwargs)
arrays.append(arr)
return arrays
def _should_create_array(ten: ShardedTensor):
return (
is_main_replica(ten.replica_id)
and set(ten.global_offset) == {0}
and (ten.flattened_range is None or ten.flattened_range.start == 0)
)
def _save_to_existing_array(sharded_tensor: ShardedTensor, arr: zarr.Array):
if not is_main_replica(sharded_tensor.replica_id):
return
x = sharded_tensor.data
x = x.detach().cpu()
torch.cuda.synchronize()
if x.dtype == torch.bfloat16:
x = x.float()
x = x.numpy()
x = x.astype('bfloat16')
else:
x = x.numpy()
if sharded_tensor.flattened_range is None:
arr[sharded_tensor.global_slice()] = x
else:
arr.set_coordinate_selection(sharded_tensor.global_coordinates(), x)
def _create_zarr_array(sharded_tensor: ShardedTensor, checkpoint_dir: Path):
np_dtype = torch_to_numpy_dtype_dict[sharded_tensor.dtype]
try:
arr = zarr.create(
sharded_tensor.global_shape,
dtype=np_dtype,
store=checkpoint_dir / sharded_tensor.key,
chunks=sharded_tensor.max_allowed_chunks(),
compressor=None,
fill_value=None,
write_empty_chunks=True,
)
except zarr.errors.ContainsArrayError as e:
raise CheckpointingException(
f'Array {checkpoint_dir / sharded_tensor.key} already exists'
) from e
if HAS_BFLOAT16 and np_dtype == np.dtype('bfloat16'):
arr._dtype = np_dtype
zarray = arr.store['.zarray']
arr.store['.zarray'] = zarray.replace(b'<V2', b'bfloat16')
return arr
class ZarrLoadShardedStrategy(LoadShardedStrategy):
def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
dict_list_map_inplace(
partial(_load_from_array, checkpoint_dir=checkpoint_dir), sharded_state_dict
)
return sharded_state_dict
def check_backend_compatibility(self, loaded_version):
pass # TODO
def check_version_compatibility(self, loaded_version):
pass # TODO
def _load_from_array(sharded_tensor: ShardedTensor, checkpoint_dir: Path):
assert isinstance(sharded_tensor, ShardedTensor), type(sharded_tensor)
try:
arr = zarr.open(checkpoint_dir / sharded_tensor.key, 'r')
except zarr.errors.PathNotFoundError as e:
raise CheckpointingException(
f'Array {checkpoint_dir / sharded_tensor.key} not found'
) from e
if not sharded_tensor.allow_shape_mismatch and sharded_tensor.global_shape != arr.shape:
_msg = (
f'Global shape mismatch for loaded ({arr.shape})'
f' and expected ({sharded_tensor.global_shape}) tensor'
f' for key {sharded_tensor.key}'
)
raise CheckpointingException(_msg)
x = arr[sharded_tensor.global_slice()] # flattened tensors loading is delayed
return postprocess_numpy_array(x, sharded_tensor)
def postprocess_numpy_array(loaded_array, sharded_tensor, apply_flattened_range=True):
x = loaded_array
if HAS_BFLOAT16 and x.dtype == np.dtype('bfloat16'):
x = x.astype(np.dtype('float32'))
x = torch.from_numpy(x)
x = x.bfloat16()
else:
x = torch.from_numpy(x)
# TODO: consider some other consistency checks
if x.shape != sharded_tensor.local_shape:
if sharded_tensor.allow_shape_mismatch:
x = pad_to_expected_shape(x, sharded_tensor)
else:
_msg = (
f'Local shape mismatch for loaded ({x.shape})'
f' and expected ({sharded_tensor.local_shape}) tensor'
f' for key {sharded_tensor.key}'
)
raise CheckpointingException(_msg)
if apply_flattened_range and sharded_tensor.flattened_range is not None:
x = flatten_range(sharded_tensor, x)
# TODO: consider cuda() tensors support
return x
def flatten_range(sharded_tensor, x):
return x.flatten()[sharded_tensor.flattened_range]
def pad_to_expected_shape(x: torch.Tensor, expected_sharded_ten: ShardedTensor):
pad_args = []
assert len(x.shape) == len(expected_sharded_ten.local_shape)
# Reversed iteration order because F.pad expects so
for x_sh, exp_sh, axis_fragm in reversed(
list(
zip(x.shape, expected_sharded_ten.local_shape, expected_sharded_ten.axis_fragmentations)
)
):
if x_sh == exp_sh:
pad_args.extend((0, 0))
elif x_sh > exp_sh:
assert (
False
), f'Expected shape ({exp_sh}) smaller than actual ({x_sh}) for {repr(expected_sharded_ten)}'
else:
pad_args.extend((0, exp_sh - x_sh))
# TODO: behavior control with envvar is for testing purposes only, remove it
if not int(os.environ.get('DIST_CKPT_PAD_REPLICATE', 0)):
return torch.nn.functional.pad(x, pad_args)
# unsqueeze and squeeze to get shapes supported by cudnn
print(f'Replicating last row for {expected_sharded_ten.key}')
if x.dtype == torch.bfloat16:
return (
torch.nn.functional.pad(x.float().unsqueeze(0), pad_args, mode='replicate')
.squeeze(0)
.bfloat16()
)
return torch.nn.functional.pad(x.unsqueeze(0), pad_args, mode='replicate').squeeze(0)
# default_strategies[StrategyAction.LOAD_SHARDED.value][('zarr', 1)] = ZarrLoadShardedStrategy()
default_strategies[StrategyAction.SAVE_SHARDED.value][('zarr', 1)] = ZarrSaveShardedStrategy(
'zarr', 1
)
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
from typing import Tuple
from .dict_utils import dict_list_map_inplace, extract_matching_values
from .mapping import (
LocalNonpersitentObject,
ShardedStateDict,
ShardedTensor,
ShardedTensorFactory,
StateDict,
)
def extract_sharded_tensors(
sharded_state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, StateDict]:
return extract_matching_values(sharded_state_dict, lambda v: isinstance(v, ShardedTensor))
def extract_sharded_tensors_and_factories(
sharded_state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, StateDict]:
return extract_matching_values(
sharded_state_dict, lambda v: isinstance(v, (ShardedTensor, ShardedTensorFactory))
)
def extract_sharded_tensors_or_nonpersistent(
sharded_state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, StateDict]:
return extract_matching_values(
sharded_state_dict,
lambda v: isinstance(v, (ShardedTensor, LocalNonpersitentObject, ShardedTensorFactory)),
)
def add_prefix_for_sharding(sharded_state_dict: ShardedStateDict, prefix: str):
def add_prefix(t):
if isinstance(t, ShardedTensor):
t.key = f'{prefix}.{t.key}'
return t
dict_list_map_inplace(add_prefix, sharded_state_dict)
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import enum
class ModelType(enum.Enum):
encoder_or_decoder = 1
encoder_and_decoder = 2
retro_encoder = 3
retro_decoder = 4
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from typing import Optional, Tuple
import torch
def _bias_dropout_add_func(x, bias, residual, prob, training):
# type: (Tensor, Optional[Tensor], Tensor, float, bool) -> Tensor
# NOTE: Previously, the argument `bias` used to be passed as
# `bias.expand_as(residual)` when the `bias_dropout_func` is called from the
# transformer layer but broadcasting should automatically take care of that.
# Also, looking at broadcasting semantics, `expand_as` and broadcasting
# seem to be identical performance-wise (both just change the view).
# If we want to train mixed precision, then the output of this function
# should be half precision. However, in AMP O1, the input (residual) is
# in fp32, and it will up-cast the result to fp32, causing pipeline parallel
# GPU communication to hang. Therefore, we need to cast residual to the same
# dtype as x.
residual = residual if residual.dtype == x.dtype else residual.to(x.dtype)
if bias is not None:
x = x + bias
out = torch.nn.functional.dropout(x, p=prob, training=training)
out = residual + out
return out
@torch.jit.script
def bias_dropout_add_fused_train(
x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]], residual: torch.Tensor, prob: float,
) -> torch.Tensor:
x, bias = x_with_bias # unpack
return _bias_dropout_add_func(x, bias, residual, prob, True)
@torch.jit.script
def bias_dropout_add_fused_inference(
x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]], residual: torch.Tensor, prob: float,
) -> torch.Tensor:
x, bias = x_with_bias # unpack
return _bias_dropout_add_func(x, bias, residual, prob, False)
def get_bias_dropout_add(training, fused):
def unfused_bias_dropout_add(x_with_bias, residual, prob):
x, bias = x_with_bias # unpack
return _bias_dropout_add_func(x, bias, residual, prob, training)
if fused:
# jit scripting for a nn.module (with dropout) is not
# triggering the fusion kernel. For now, we use two
# different nn.functional routines to account for varying
# dropout semantics during training and inference phases.
if training:
return bias_dropout_add_fused_train
else:
return bias_dropout_add_fused_inference
else:
return unfused_bias_dropout_add
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import torch
###### BIAS GELU FUSION/ NO AUTOGRAD ################
# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2) -> 0.70710678
# sqrt(2/pi) -> 0.79788456
# this function is tanh approximation of gelu
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
@torch.jit.script
def bias_gelu(bias, y):
x = bias + y
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def bias_gelu_back(g, bias, y):
x = bias + y
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
1 + tanh_out
)
return ff * g
class GeLUFunction(torch.autograd.Function):
@staticmethod
# bias is an optional argument
def forward(ctx, input, bias):
ctx.save_for_backward(input, bias)
return bias_gelu(bias, input)
@staticmethod
def backward(ctx, grad_output):
input, bias = ctx.saved_tensors
tmp = bias_gelu_back(grad_output, bias, input)
return tmp, tmp
bias_gelu_impl = GeLUFunction.apply
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import importlib
import numbers
import torch
from torch.nn import init
from torch.nn.parameter import Parameter
from megatron.core.utils import make_viewless_tensor
try:
from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
HAVE_PERSIST_LAYER_NORM = True
except:
HAVE_PERSIST_LAYER_NORM = False
try:
from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction
HAVE_FUSED_LAYER_NORM = True
except:
HAVE_FUSED_LAYER_NORM = False
class FusedLayerNorm(torch.nn.Module):
def __init__(
self,
hidden_size,
eps=1e-5,
persist_layer_norm=True,
sequence_parallel=False,
zero_centered_gamma=False,
):
super().__init__()
self.zero_centered_gamma = zero_centered_gamma
# List of hiddens sizes supported in the persistent layer norm kernel
# If the hidden size is not supported, fall back to the non-persistent
# kernel.
persist_ln_hidden_sizes = [
1024,
1536,
2048,
2304,
3072,
3840,
4096,
5120,
6144,
8192,
10240,
12288,
12800,
15360,
16384,
18432,
20480,
24576,
25600,
30720,
32768,
40960,
49152,
65536,
]
if hidden_size not in persist_ln_hidden_sizes or not HAVE_PERSIST_LAYER_NORM:
persist_layer_norm = False
if not persist_layer_norm and not HAVE_FUSED_LAYER_NORM:
# TODO: Add pytorch only layer norm
raise ValueError(f'Apex must currently be installed to use megatron core.')
if isinstance(hidden_size, numbers.Integral):
hidden_size = (hidden_size,)
self.hidden_size = torch.Size(hidden_size)
self.eps = eps
self.weight = Parameter(torch.Tensor(*hidden_size))
self.bias = Parameter(torch.Tensor(*hidden_size))
self.reset_parameters()
self.persist_layer_norm = persist_layer_norm
self.sequence_parallel = sequence_parallel
# set sequence parallelism flag on weight and bias parameters
setattr(self.weight, 'sequence_parallel', self.sequence_parallel)
setattr(self.bias, 'sequence_parallel', self.sequence_parallel)
def reset_parameters(self):
if self.zero_centered_gamma:
init.zeros_(self.weight)
init.zeros_(self.bias)
else:
init.ones_(self.weight)
init.zeros_(self.bias)
def forward(self, input):
weight = self.weight + 1 if self.zero_centered_gamma else self.weight
if self.persist_layer_norm:
output = FastLayerNormFN.apply(input, weight, self.bias, self.eps)
# Apex's fast layer norm function outputs a 'view' tensor (i.e., has
# a populated '_base' field). This will result in schedule.py's
# deallocate_output_tensor() throwing an error, so a viewless tensor is
# created to prevent this.
output = make_viewless_tensor(
inp=output, requires_grad=input.requires_grad, keep_graph=True
)
else:
output = FusedLayerNormAffineFunction.apply(
input, weight, self.bias, self.hidden_size, self.eps
)
return output
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import torch
import torch.nn as nn
from megatron.core.transformer.enums import AttnMaskType
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
"""
Fused operation which performs following three operations in sequence
1. Scale the tensor.
2. Apply upper triangular mask (typically used in gpt models).
3. Perform softmax.
"""
@staticmethod
def forward(ctx, inputs, scale):
import scaled_upper_triang_masked_softmax_cuda
scale_t = torch.tensor([scale])
softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(inputs, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads):
import scaled_upper_triang_masked_softmax_cuda
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_upper_triang_masked_softmax_cuda.backward(
output_grads, softmax_results, scale_t[0]
)
return input_grads, None
class ScaledMaskedSoftmax(torch.autograd.Function):
"""
Fused operation which performs following three operations in sequence
1. Scale the tensor.
2. Apply the mask.
3. Perform softmax.
"""
@staticmethod
def forward(ctx, inputs, mask, scale):
import scaled_masked_softmax_cuda
scale_t = torch.tensor([scale])
softmax_results = scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads):
import scaled_masked_softmax_cuda
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_masked_softmax_cuda.backward(output_grads, softmax_results, scale_t[0])
return input_grads, None, None
class ScaledSoftmax(torch.autograd.Function):
"""
Fused operation which performs following two operations in sequence
1. Scale the tensor.
2. Perform softmax.
"""
@staticmethod
def forward(ctx, inputs, scale):
import scaled_softmax_cuda
scale_t = torch.tensor([scale])
softmax_results = scaled_softmax_cuda.forward(inputs, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads):
import scaled_softmax_cuda
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_softmax_cuda.backward(output_grads, softmax_results, scale_t[0])
return input_grads, None, None
class FusedScaleMaskSoftmax(nn.Module):
"""
fused operation: scaling + mask + softmax
Arguments:
input_in_fp16: flag to indicate if input in fp16 data format.
input_in_bf16: flag to indicate if input in bf16 data format.
attn_mask_type: attention mask type (pad or causal)
scaled_masked_softmax_fusion: flag to indicate user want to use softmax fusion
mask_func: mask function to be applied.
softmax_in_fp32: if true, softmax in performed at fp32 precision.
scale: scaling factor used in input tensor scaling.
"""
def __init__(
self,
input_in_fp16,
input_in_bf16,
attn_mask_type,
scaled_masked_softmax_fusion,
mask_func,
softmax_in_fp32,
scale,
):
super(FusedScaleMaskSoftmax, self).__init__()
self.input_in_fp16 = input_in_fp16
self.input_in_bf16 = input_in_bf16
assert not (
self.input_in_fp16 and self.input_in_bf16
), "both fp16 and bf16 flags cannot be active at the same time."
self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
self.attn_mask_type = attn_mask_type
self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
self.mask_func = mask_func
self.softmax_in_fp32 = softmax_in_fp32
self.scale = scale
assert self.scale is None or softmax_in_fp32, "softmax should be in fp32 when scaled"
def forward(self, input, mask):
# [b, np, sq, sk]
assert input.dim() == 4
if self.is_kernel_available(mask, *input.size()):
return self.forward_fused_softmax(input, mask)
else:
return self.forward_torch_softmax(input, mask)
def is_kernel_available(self, mask, b, np, sq, sk):
attn_batches = b * np
if (
self.scaled_masked_softmax_fusion # user want to fuse
and self.input_in_float16 # input must be fp16
and 16 < sk <= 4096 # sk must be 16 ~ 2048
and sq % 4 == 0 # sq must be divisor of 4
and sk % 4 == 0 # sk must be divisor of 4
and attn_batches % 4 == 0 # np * b must be divisor of 4
):
if 0 <= sk <= 4096:
batch_per_block = self.get_batch_per_block(sq, sk, b, np)
if self.attn_mask_type == AttnMaskType.causal:
if attn_batches % batch_per_block == 0:
return True
else:
if sq % batch_per_block == 0:
return True
return False
def forward_fused_softmax(self, input, mask):
b, np, sq, sk = input.size()
scale = self.scale if self.scale is not None else 1.0
if self.attn_mask_type == AttnMaskType.causal:
assert sq == sk, "causal mask is only for self attention"
# input is 3D tensor (attn_batches, sq, sk)
input = input.view(-1, sq, sk)
probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
return probs.view(b, np, sq, sk)
else:
# input is 4D tensor (b, np, sq, sk)
if mask is not None:
return ScaledMaskedSoftmax.apply(input, mask, scale)
else:
return ScaledSoftmax.apply(input, scale)
def forward_torch_softmax(self, input, mask):
if self.input_in_float16 and self.softmax_in_fp32:
input = input.float()
if self.scale is not None:
input = input * self.scale
mask_output = self.mask_func(input, mask) if mask is not None else input
probs = torch.nn.Softmax(dim=-1)(mask_output)
if self.input_in_float16 and self.softmax_in_fp32:
if self.input_in_fp16:
probs = probs.half()
else:
probs = probs.bfloat16()
return probs
@staticmethod
def get_batch_per_block(sq, sk, b, np):
import scaled_masked_softmax_cuda
return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np)
class InferenceParams:
"""Inference parameters that are passed to the main model in order
to efficienly calculate and store the context during inference."""
def __init__(self, max_batch_size, max_sequence_length):
self.max_sequence_length = max_sequence_length
self.max_batch_size = max_batch_size
self.sequence_len_offset = 0
self.batch_size_offset = 0
self.key_value_memory_dict = {}
def swap_key_value_dict(self, batch_idx):
"swap between batches"
if len(self.key_value_memory_dict) == 0:
raise ValueError("should not swap when dict in empty")
for layer_number in self.key_value_memory_dict.keys():
inference_key_memory, inference_value_memory = self.key_value_memory_dict[layer_number]
assert (
len(batch_idx) == inference_key_memory.shape[1]
) # make sure batch size is the same
new_inference_key_memory = inference_key_memory[:, batch_idx]
new_inference_value_memory = inference_value_memory[:, batch_idx]
self.key_value_memory_dict[layer_number] = (
new_inference_key_memory,
new_inference_value_memory,
)
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from dataclasses import dataclass
from typing import Callable, Optional
import torch
@dataclass
class ModelParallelConfig:
"""Base configuration for Megatron Core
Model Parallelism
-----------------
tensor_model_parallel_size (int): Intra-layer model parallelism. Splits tensors across GPU ranks. Defaults to 1.
pipeline_model_parallel_size (int): Inter-layer model parallelism. Splits transformer layers across GPU
ranks. Defaults to 1.
virtual_pipeline_model_parallel_size (int): Interleaved pipeline parallelism is used to improve performance by
reducing the pipeline bubble. Considers a transformer block as a list of smaller transformer (virtual) blocks.
The number of virtual blocks per pipeline model parallel rank is the virtual model parallel size. See Efficient
Large-Scale Language Model Training on GPU Clusters Using Megatron-LM: https://arxiv.org/pdf/2104.04473.pdf for
more details. Defaults to None.
sequence_parallel (bool): Makes tensor parallelism more memory efficient for LLMs (20B+) by
parallelizing layer norms and dropout sequentially. See Reducing Activation Recomputation in Large Transformer
Models: https://arxiv.org/abs/2205.05198 for more details. Defaults to False.
Initialization
--------------
perform_initialization (bool, default=True): If true, weights are initialized. This option can be useful when you
know you are going to load values from a checkpoint.
use_cpu_initialization: (bool, default=False): When set to False, we initialize the weights directly on the GPU.
Transferring weights from CPU to GPU can take a significant amount of time for large models. Defaults to False.
Training
--------
fp16 (bool): If true, train with fp16 mixed precision training. Defaults to False.
bf16 (bool): If true, train with bf16 mixed precision training. Defaults to False.
params_dtype (torch.dtype): dtype used when intializing the weights. Defaults to torch.float32
timers (optional, default=None): TODO
Optimizations
-------------
gradient_accumulation_fusion (bool): If true, fuses weight gradient accumulation to GEMMs. Requires the custom CUDA
extension fused_weight_gradient_mlp_cuda module. To use gradient_accumulation_fusion you must install APEX with
--cpp_ext and --cuda_ext. For example: "pip install --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\"
". Note that the extension requires CUDA>=11. Otherwise, you must turn off gradient accumulation fusion.
Defaults to False.
async_tensor_model_parallel_allreduce (bool, default=True): If true, enables asynchronous execution of
tensor-model-parallel all-reduce with weight gradient compuation of a column-linear layer. Defaults to False.
Pipeline Parallelism
--------------------
pipeline_dtype (required): dtype used in p2p communication, usually params_dtype
grad_scale_func (optional, default=None): If using loss scaling, this function should take the loss and return the
scaled loss. If None, no function is called on the loss.
enable_autocast (bool): If true runs the forward step function inside torch.autocast context. Default is False.
autocast_dtype (torch.dtype): dtype to pass to torch.amp.autocast when enabled. Default is pipeline_dtype.
variable_seq_lengths (bool, default=False): Support for variable sequence lengths across microbatches. Setting this
communicates the size of tensors during pipeline parallelism communication, because of this extra overhead it
should only be set if the sequence length varies by microbatch within a global batch.
num_microbatches_with_partial_activation_checkpoints (int, default=None): If int, set the number of microbatches
where not all of the layers will be checkpointed and recomputed. The rest of the microbatches within the window
of maximum outstanding microbatches will recompute all layers (either full recompute or selective recompute). If
None, the checkpoint and recompute will be left up to the forward_step function.
overlap_p2p_comm (bool, optional, default=False): When True some of the peer to peer communication for pipeline
parallelism will overlap with computation. Must be False if batch_p2p_comm is true.
batch_p2p_comm (bool, default=True): Use batch_isend_irecv instead of individual isend/irecv calls. Must be False
if overlap_p2p_comm is True.
batch_p2p_sync (bool, default=True): When using batch_isend_irecv, do a cuda.device.synchronize afterward to work
around a bug in older version of PyTorch.
use_ring_exchange_p2p (bool, default = False): Use custom ring_exchange kernel instead of
torch.distributed.batch_isend_irecv(). Requires custom built torch with torch.distributed.ring_exchange.
deallocate_pipeline_outputs (optional, default=False): If True, output data is deallocated after the tensor is sent
to the next pipeline stage. Helps with saving memory, does nothing when pipeline parallel is not used.
no_sync_func (optional): Function that creates a context that suppresses asynchronous data-parallel
communication. If the model is an instance of torch.nn.DistributedDataParallel, the default is to use
torch.nn.DistributedDataParallel.no_sync.
grad_sync_func (optional): Function that launches asynchronous gradient reductions (e.g. distributed optimizer
gradient reduce-scatters). The function should take one argument: an iterable of parameters whose gradients are
to be synchronized.
param_sync_func (optional): Function that launches asynchronous parameter synchronizations (e.g. distributed
optimizer parameter all-gathers). The function should take one argument: an iterable of parameters to be
synchronized.
"""
# Model parallelism
tensor_model_parallel_size: int = 1
pipeline_model_parallel_size: int = 1
virtual_pipeline_model_parallel_size: Optional[int] = None
sequence_parallel: bool = False
# Initialization
perform_initialization: bool = True
use_cpu_initialization: bool = False
# Training
fp16: bool = False
bf16: bool = False
params_dtype: torch.dtype = torch.float32
timers: Callable = None
# Optimizations
gradient_accumulation_fusion: bool = False
async_tensor_model_parallel_allreduce: bool = False
# Pipeline Parallel
pipeline_dtype: torch.dtype = None
grad_scale_func: Callable = None
enable_autocast: bool = False
autocast_dtype: torch.dtype = None
variable_seq_lengths: bool = False
num_microbatches_with_partial_activation_checkpoints: Optional[int] = None
overlap_p2p_comm: bool = False
batch_p2p_comm: bool = True
batch_p2p_sync: bool = True
use_ring_exchange_p2p: bool = False
deallocate_pipeline_outputs: bool = False
no_sync_func: Callable = None
grad_sync_func: Callable = None
param_sync_func: Callable = None
def __post_init__(self):
""" Python dataclass method that is used to modify attributes after initialization.
See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more details.
"""
if self.sequence_parallel:
if self.tensor_model_parallel_size <= 1:
raise ValueError("Can not use sequence paralllelism without tensor parallelism")
if self.async_tensor_model_parallel_allreduce:
# sequence_parallelism already does this async
self.async_tensor_model_parallel_allreduce = False
if self.pipeline_model_parallel_size > 1:
if self.pipeline_dtype is None:
raise ValueError(
"When using pipeline parallelism, pipeline_dtype must be specified"
)
if self.autocast_dtype is None:
self.autocast_dtype = self.params_dtype
# coding=utf-8
# The following code has been taken from https://github.com/NVIDIA/NeMo/blob/ \
# 782b4e1652aaa43c8be390d9db0dc89544afa080/nemo/collections/nlp/modules/ \
# common/megatron/rotary_pos_embedding.py
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import importlib.util
import torch
import torch
from torch import einsum, nn
__all__ = ['RotaryEmbedding', 'apply_rotary_pos_emb']
class RotaryEmbedding(nn.Module):
def __init__(self, dim):
def __init__(self, dim, seq_len_interpolation_factor=None):
super().__init__()
self.seq_len_interpolation_factor = seq_len_interpolation_factor
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
if importlib.util.find_spec('einops') is None:
raise RuntimeError("einops is required for Rotary Embedding")
self.register_buffer('inv_freq', inv_freq, persistent=False)
def forward(self, max_seq_len, offset=0):
seq = torch.arange(max_seq_len, device=self.inv_freq.device) + offset
if self.seq_len_interpolation_factor is not None:
seq = seq.type_as(self.inv_freq)
seq *= 1 / self.seq_len_interpolation_factor
freqs = einsum('i , j -> i j', seq.type_as(self.inv_freq), self.inv_freq)
# first part even vector components, second part odd vector components,
# 2 * dim in dimension size
emb = torch.cat((freqs, freqs), dim=-1)
# emb [seq_length, .., dim]
from einops import rearrange
return rearrange(emb, 'n d -> n 1 1 d')
return emb[:, None, None, :]
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
state_dict.pop(f'{prefix}inv_freq', None)
return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
def _rotate_half(x):
"""
change sign so the last dimension becomes [-odd, +even]
"""
from einops import rearrange
x = rearrange(x, '... (j d) -> ... j d', j=2)
x1, x2 = x.unbind(dim=-2)
x1, x2 = torch.chunk(x, 2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
......
from .gpt_model import GPTModel
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import torch
from megatron.core import tensor_parallel
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import (
make_sharded_tensor_for_checkpoint,
make_tp_sharded_tensor_for_checkpoint,
)
class GPTEmbedding(MegatronModule):
"""Language model embeddings.
Arguments:
config (TransformerConfig): config object with all necessary configs for TransformerBlock
vocab_size (int): vocabulary size
max_sequence_length (int): maximum size of sequence. This
is used for positional embedding
add_position_embedding (bool): Add a position embedding.
embedding_dropout_prob float): dropout probability for embeddings
"""
def __init__(
self,
config: TransformerConfig,
vocab_size: int,
max_sequence_length: int,
add_position_embedding: bool,
):
super().__init__(config=config)
self.config: TransformerConfig = config
self.vocab_size: int = vocab_size
self.max_sequence_length: int = max_sequence_length
self.add_position_embedding: bool = add_position_embedding
# Word embeddings (parallel).
self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
num_embeddings=self.vocab_size,
embedding_dim=self.config.hidden_size,
init_method=self.config.init_method,
config=self.config,
)
# Position embedding (serial).
if self.add_position_embedding:
self.position_embeddings = torch.nn.Embedding(
self.max_sequence_length, self.config.hidden_size
)
# Initialize the position embeddings.
if self.config.perform_initialization:
self.config.init_method(self.position_embeddings.weight)
# Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(self.config.hidden_dropout)
def zero_parameters(self):
"""Zero out all parameters in embedding."""
self.word_embeddings.weight.data.fill_(0)
self.word_embeddings.weight.shared = True
self.position_embeddings.weight.data.fill_(0)
self.position_embeddings.weight.shared = True
def forward(self, input_ids, position_ids):
# Embeddings.
word_embeddings = self.word_embeddings(input_ids)
if self.add_position_embedding:
position_embeddings = self.position_embeddings(position_ids)
embeddings = word_embeddings + position_embeddings
else:
embeddings = word_embeddings
# Data format change to avoid explicit tranposes : [b s h] --> [s b h].
embeddings = embeddings.transpose(0, 1).contiguous()
# If the input flag for fp32 residual connection is set, convert for float.
if self.config.fp32_residual_connection:
embeddings = embeddings.float()
# Dropout.
if self.config.sequence_parallel:
embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
with tensor_parallel.get_cuda_rng_tracker().fork():
embeddings = self.embedding_dropout(embeddings)
else:
embeddings = self.embedding_dropout(embeddings)
return embeddings
def sharded_state_dict(self, prefix=''):
sharded_state_dict = {}
word_embeddings_prefix = f'{prefix}word_embeddings.'
word_embeddings_state_dict = self.word_embeddings.state_dict(
prefix=word_embeddings_prefix, keep_vars=True
)
sharded_word_embeddings_key = f'{word_embeddings_prefix}weight'
sharded_word_embeddings_tensor = make_tp_sharded_tensor_for_checkpoint(
tensor=word_embeddings_state_dict[sharded_word_embeddings_key],
key=sharded_word_embeddings_key,
allow_shape_mismatch=True,
)
sharded_state_dict[sharded_word_embeddings_key] = sharded_word_embeddings_tensor
if self.add_position_embedding:
position_embeddings_prefix = f'{prefix}position_embeddings.'
position_embeddings_state_dict = self.position_embeddings.state_dict(
prefix=position_embeddings_prefix, keep_vars=True
)
sharded_position_embeddings_key = f'{position_embeddings_prefix}weight'
sharded_position_embeddings_tensor = make_sharded_tensor_for_checkpoint(
tensor=position_embeddings_state_dict[sharded_position_embeddings_key],
key=sharded_position_embeddings_key,
)
sharded_state_dict[sharded_position_embeddings_key] = sharded_position_embeddings_tensor
return sharded_state_dict
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import logging
from typing import Literal, Optional
import torch
from torch import Tensor
from megatron.core import parallel_state, tensor_parallel
from megatron.core.models.common.rotary_pos_embedding import RotaryEmbedding
from megatron.core.models.gpt.gpt_embedding import GPTEmbedding
from megatron.core.transformer.enums import AttnMaskType, ModelType
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import make_tp_sharded_tensor_for_checkpoint
class GPTModel(MegatronModule):
"""Transformer language model.
Arguments:
config (TransformerConfig): transformer config
vocab_size (int): vocabulary size
max_sequence_length (int): maximum size of sequence. This is used for positional embedding
pre_process (bool): Include embedding layer (used with pipeline parallelism)
post_process (bool): Include an output layer (used with pipeline parallelism)
parallel_output (bool): Do not gather the outputs, keep them split across tensor parallel ranks
share_embeddings_and_output_weights (bool): When True, input embeddings and output logit weights are
shared. Defaults to False.
position_embedding_type (string): Position embedding type. Options ['learned_absolute', 'rope'].
Defaults is 'learned_absolute'.
rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings.
Defaults to 1.0 (100%). Ignored unless position_embedding_type is 'rope'.
seq_len_interpolation_factor (float): scale of linearly interpolating RoPE for longer sequences.
The value must be a float larger than 1.0. Defaults to None.
"""
def __init__(
self,
config: TransformerConfig,
vocab_size: int,
max_sequence_length: int,
pre_process: bool = True,
post_process: bool = True,
fp16_lm_cross_entropy: bool = False,
parallel_output: bool = True,
share_embeddings_and_output_weights: bool = False,
position_embedding_type: Literal['learned_absolute', 'rope'] = 'learned_absolute',
rotary_percent: float = 1.0,
seq_len_interpolation_factor: Optional[float] = None,
):
super(GPTModel, self).__init__(config=config)
self.config: TransformerConfig = config
self.vocab_size = vocab_size
self.max_sequence_length = max_sequence_length
self.pre_process = pre_process
self.post_process = post_process
self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
self.parallel_output = parallel_output
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
self.position_embedding_type = position_embedding_type
# megatron core pipelining currently depends on model type
# TODO: remove this dependency ?
self.model_type = ModelType.encoder_or_decoder
# Embeddings.
if self.pre_process:
self.embedding = GPTEmbedding(
config=self.config,
vocab_size=self.vocab_size,
max_sequence_length=self.max_sequence_length,
add_position_embedding=(self.position_embedding_type == 'learned_absolute'),
)
# Rotary Position Embeddings
if self.position_embedding_type == 'rope':
rotary_dim = self.config.kv_channels
if rotary_percent < 1.0:
rotary_dim = int(rotary_dim * rotary_percent)
self.rotary_pos_emb = RotaryEmbedding(rotary_dim, seq_len_interpolation_factor)
else:
self.rotary_pos_emb = None
# Transformer.
self.decoder = TransformerBlock(
config=self.config,
self_attn_mask_type=AttnMaskType.causal,
pre_process=self.pre_process,
post_process=self.post_process,
)
# Output
if post_process:
self.output_layer = tensor_parallel.ColumnParallelLinear(
config.hidden_size,
self.vocab_size,
config=config,
init_method=config.init_method,
bias=False,
skip_bias_add=False,
gather_output=not self.parallel_output,
skip_weight_param_allocation=self.pre_process
and self.share_embeddings_and_output_weights,
)
if self.share_embeddings_and_output_weights and (self.pre_process or self.post_process):
self.initialize_last_stage_with_word_embeddings()
def set_input_tensor(self, input_tensor):
""" See megatron.model.transformer.set_input_tensor()"""
# This is usually handled in schedules.py but some inference code still
# gives us non-lists or None
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt'
self.decoder.set_input_tensor(input_tensor[0])
def forward(
self,
input_ids: Tensor,
position_ids: Tensor,
attention_mask: Tensor,
decoder_input: Tensor = None,
labels: Tensor = None,
inference_params=None,
):
# If decoder_input is provided (not None), then input_ids and position_ids are ignored.
# Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.
# Decoder embedding.
if decoder_input is not None:
pass
elif self.pre_process:
decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
else:
# intermediate stage of pipeline
# decoder will get hidden_states from encoder.input_tensor
decoder_input = None
# Rotary positional embeddings
rotary_pos_emb = None
if self.rotary_pos_emb is not None:
if inference_params is not None:
rotary_seq_len = inference_params.max_sequence_length
else:
if self.decoder.input_tensor is not None:
rotary_seq_len = self.decoder.input_tensor.size(0)
else:
rotary_seq_len = decoder_input.size(0)
# Decoder input is split along sequence dimension, but RoPE is applied in tensor parallel region
if self.config.sequence_parallel:
rotary_seq_len *= self.config.tensor_model_parallel_size
rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)
# Run decoder.
hidden_states = self.decoder(
hidden_states=decoder_input,
attention_mask=attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
)
if not self.post_process:
return hidden_states
# logits and loss
output_weight = None
if self.share_embeddings_and_output_weights:
output_weight = self.shared_embedding_or_output_weight()
logits, _ = self.output_layer(hidden_states, weight=output_weight)
if labels is None:
# [s b h] => [b s h]
return logits.transpose(0, 1).contiguous()
# [b s] => [s b]
labels = labels.transpose(0, 1).contiguous()
loss = tensor_parallel.vocab_parallel_cross_entropy(logits.float(), labels)
# [s b] => [b, s]
loss = loss.transpose(0, 1).contiguous()
return loss
def shared_embedding_or_output_weight(self):
if self.pre_process:
return self.embedding.word_embeddings.weight
elif self.post_process:
return self.output_layer.weight
return None
def initialize_last_stage_with_word_embeddings(self):
# This function just initializes the word embeddings in the final stage
# when we are using pipeline parallelism and sharing word
# embeddings. Nothing to do if we aren't sharing weights or aren't using
# pipeline parallelism.
if not self.share_embeddings_and_output_weights or (self.pre_process and self.post_process):
return
if self.post_process and not self.pre_process:
assert not parallel_state.is_pipeline_first_stage()
# set word_embeddings weights to 0 here, then copy first
# stage's weights using all_reduce below.
self.output_layer.weight.data.fill_(0)
self.output_layer.weight.shared = True
# Parameters are shared between the word embeddings layers, and the
# heads at the end of the model. In a pipelined setup with more than
# one stage, the initial embedding layer and the head are on different
# workers, so we do the following:
# 1. Create a second copy of word_embeddings on the last stage, with
# initial parameters of 0.0.
# 2. Do an all-reduce between the first and last stage to ensure that
# the two copies of word_embeddings start off with the same
# parameter values.
# 3. In the training loop, before an all-reduce between the grads of
# the two word_embeddings layers to ensure that every applied weight
# update is the same on both stages.
# Ensure that first and last stages have the same initial parameter
# values.
if torch.distributed.is_initialized():
if parallel_state.is_rank_in_embedding_group():
weight = self.shared_embedding_or_output_weight()
torch.distributed.all_reduce(
weight.data, group=parallel_state.get_embedding_group()
)
elif not getattr(GPTModel, "embedding_warning_printed", False):
logging.getLogger(__name__).warning(
"Distributed processes aren't initialized, so the output layer "
"is not initialized with weights from the word embeddings. "
"If you are just manipulating a model this is fine, but "
"this needs to be handled manually. If you are training "
"something is definitely wrong."
)
GPTModel.embedding_warning_printed = True
def sharded_state_dict(self, prefix=''):
sharded_state_dict = {}
if self.pre_process:
embedding_prefix = f'{prefix}embedding.'
embedding_sharded_state_dict = self.embedding.sharded_state_dict(
prefix=embedding_prefix
)
sharded_state_dict.update(embedding_sharded_state_dict)
decoder_prefix = f'{prefix}decoder.'
decoder_sharded_state_dict = self.decoder.sharded_state_dict(prefix=decoder_prefix)
sharded_state_dict.update(decoder_sharded_state_dict)
if self.post_process:
output_layer_prefix = f'{prefix}output_layer.'
output_layer_key = f'{output_layer_prefix}weight'
if self.share_embeddings_and_output_weights:
if not self.pre_process:
# when sharing embeddings with last stage, we need to use the weights from the first stage
# on pipeline first rank, word embeddings are saved to {prefix}embedding.word_embeddings.weight
tensor = self.shared_embedding_or_output_weight()
first_stage_word_emb_key = f'{prefix}embedding.word_embeddings.weight'
dp_rank = parallel_state.get_data_parallel_rank()
dp_size = parallel_state.get_data_parallel_world_size()
last_stage_word_emb_replica_id = (
dp_rank + dp_size
) # copy of first stage embedding
sharded_output_layer_tensor = make_tp_sharded_tensor_for_checkpoint(
tensor=tensor,
key=first_stage_word_emb_key,
replica_id=last_stage_word_emb_replica_id,
allow_shape_mismatch=True,
)
sharded_state_dict[output_layer_key] = sharded_output_layer_tensor
else:
output_layer_state_dict = self.output_layer.state_dict(
prefix=output_layer_prefix, keep_vars=True
)
output_layer_tensor = output_layer_state_dict[output_layer_key]
# independent output layer
sharded_output_layer_tensor = make_tp_sharded_tensor_for_checkpoint(
tensor=output_layer_tensor,
key=output_layer_key,
replica_id=parallel_state.get_data_parallel_rank(),
allow_shape_mismatch=True,
)
sharded_state_dict[output_layer_key] = sharded_output_layer_tensor
return sharded_state_dict
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment