# These type definitions are just hints to differentiate a plain model state
# dict (StateDict) from a state dict with tensors replaced with ShardedTensors
# (ShardedStateDict).
StateDict = Dict[str, Any]
ShardedStateDict = Dict[str, Any]

# Identifies a replica of a local shard: a single rank id, or a tuple of ids
# when replication spans several dimensions.
ReplicaId = Union[int, Tuple[int, ...]]
class ShardedBase(ABC):
    """Common interface of objects that can live in a sharded state dict.

    Concrete subclasses carry a checkpoint key, a local data payload and a
    replica id describing how the local shard is replicated across processes.
    """

    # key: unique identifier of the object in the checkpoint
    key: str
    # data: local payload; concrete subclasses narrow this type
    data: object
    # replica_id: replication id of the local shard wrt other processes
    replica_id: ReplicaId

    @abstractmethod
    def validate_metadata_integrity(self):
        """Codifies the constraints on metadata attributes."""

    @abstractmethod
    def without_data(self) -> 'ShardedBase':
        # Subclasses must return a metadata-only copy with the payload removed.
        raise NotImplementedError
@dataclass
class ShardedTensor(ShardedBase):
    """Represents a mapping between a local tensor and a global tensor.

    Global tensor is assumed to consist of many local tensors distributed
    between different processes.

    Args:
        key: unique identifier of a global tensor
        data: local tensor data. Can be None only for consistency validation
        dtype: tensor dtype
        local_shape: local tensor shape
        global_shape: global tensor shape
        global_offset: offset of a local tensor in a global tensor,
            specified in number of tensor elements
        axis_fragmentations: global tensor fragmentation of each axis
        replica_id: indicates given local tensor's replication wrt. local
            tensors in different processes
        prepend_axis_num: number of axes prepended to the local tensor to
            reflect global tensor shape. The behavior is similar to
            unsqueezing the local tensor.
        allow_shape_mismatch: if True, during loading, the global shape of a
            stored tensor does not have to match the expected global shape.
            Useful for representing tensors with flexible shape, e.g. padded.
        flattened_range: specifies a slice that should be applied to a
            flattened tensor with `local_shape` in order to get the tensor
            stored as `data`
    """

    key: str
    # repr=False: tensors would flood the dataclass repr
    data: Optional[torch.Tensor] = field(repr=False)
    dtype: torch.dtype
    local_shape: Tuple[int, ...]
    global_shape: Tuple[int, ...]
    global_offset: Tuple[int, ...]
    axis_fragmentations: Optional[Tuple[int, ...]]
    replica_id: ReplicaId = 0
    prepend_axis_num: int = 0
    allow_shape_mismatch: bool = False
    flattened_range: Optional[slice] = None
def __post_init__(self):
    """Validate metadata constraints immediately after dataclass construction."""
    self.validate_metadata_integrity()
def validate_metadata_integrity(self) -> None:
    """Codifies the constraints on metadata attributes.

    Meeting those constraints is guaranteed when instantiating a ShardedTensor
    class with `from_rank_offsets` or `from_rank_offsets_flat` constructors.

    Returns:
        None

    Raises:
        CheckpointingException: if `data` is present but its dtype does not
            match the declared `dtype` attribute.
    """
    if self.data is not None:
        if self.data.dtype != self.dtype:
            raise CheckpointingException(
                f'Data dtype should match `dtype` attribute for {self}'
            )
    # NOTE(review): this chunk appears truncated — checks that presumably
    # validated `data.shape` against `local_shape` / `flattened_range` (and
    # shape-metadata consistency) were lost between the dtype check and the
    # next method; confirm against the upstream implementation.

def max_allowed_chunks(self) -> Tuple[int, ...]:
    """Compute the per-axis chunk size of the global tensor.

    Each axis of `global_shape` is split into the number of fragments given
    by `axis_fragmentations`; the resulting fragment sizes are returned.

    Returns:
        Tuple[int, ...]: chunk size along every axis of the global tensor.

    Raises:
        CheckpointingException: if an axis shape is not evenly divisible by
            its fragmentation and `allow_shape_mismatch` is False.
    """
    # NOTE(review): the method header and loop line were dropped from this
    # chunk; reconstructed from the surviving loop body (`axis_sh`,
    # `axis_fragm`, `chunks`). Assumes `axis_fragmentations` is not None at
    # this point — confirm against upstream.
    chunks = []
    for axis_sh, axis_fragm in zip(self.global_shape, self.axis_fragmentations):
        if not self.allow_shape_mismatch and axis_sh % axis_fragm != 0:
            # Fixed: error message was missing its closing parenthesis.
            raise CheckpointingException(
                f'Axis shape ({axis_sh}) not divisible by axis fragmentation ({axis_fragm})'
            )
        axis_chunk_size = axis_sh // axis_fragm
        chunks.append(axis_chunk_size)
    return tuple(chunks)
def without_data(self):
    """Return a metadata-only copy of this sharded tensor.

    All fields are preserved via `dataclasses.replace`, except `data`,
    which is set to None.
    """
    return replace(self, data=None)
@classmethod
deffrom_rank_offsets(
cls,
key:str,
data:torch.Tensor,
*rank_offsets:Tuple[int,int,int],
replica_id:ReplicaId=0,
prepend_axis_num:int=0,
flattened_range:None=None,
**init_kwargs,
):
"""Allows to construct the ShardedTensor given offset specified in process ranks.
Args:
key (str): unique key
data (torch.Tensor): local tensor data
rank_offsets (Tuple[int, int, int]): each tuple (axis, axis_rank_offset, axis_fragm) says that if global tensor is divided into `axis_fragm` fragment along `axis` axis, then local tensor data corresponds to the `axis_rank_offset` chunk.
replica_id (ReplicaId): see ShardedTensor
prepend_axis_num (int): see ShardedTensor
flattened_range (None): must be None when using this constructor
init_kwargs: passed to ShardedTensor.__init__
"""
ifflattened_rangeisnotNone:
raiseValueError(
'Cannot instantiate a flat ShardedTensor with `from_rank_offsets` method.'