# These type definitions are just hints to differentiate a plain model state
# dict (StateDict) from a state dict with tensors replaced with ShardedTensors
# (ShardedStateDict).
StateDict = Dict[str, Any]
ShardedStateDict = Dict[str, Any]
# Identifies which copy of a shard a process holds: either a single rank index
# or a tuple of indices (presumably one per parallelism dimension — confirm
# against callers).
ReplicaId = Union[int, Tuple[int, ...]]
class ShardedBase(ABC):
    """Common interface of objects held in a sharded state dict.

    Declares the attributes every sharded item carries; concrete subclasses
    (e.g. ShardedTensor below) supply the actual data layout.
    """

    key: str  # unique identifier of the item within the (global) checkpoint
    data: object  # local payload (tensor or other object); may be None
    replica_id: ReplicaId  # replication id of this local shard wrt other processes
@dataclass
class ShardedTensor(ShardedBase):
    """Represents a mapping between a local tensor and a global tensor.

    Global tensor is assumed to consist of many local tensors distributed
    between different processes.

    Args:
        key: unique identifier of a global tensor
        data: local tensor data. Can be None only for consistency validation
        dtype: tensor dtype
        local_shape: local tensor shape
        global_shape: global tensor shape
        global_offset: offset of a local tensor in a global tensor, specified
            in number of tensor elements
        axis_fragmentations: global tensor fragmentation of each axis
        replica_id: indicates given local tensor's replication wrt. local
            tensors in different processes
        prepend_axis_num: number of axes prepended to the local tensor to
            reflect global tensor shape. The behavior is similar to
            unsqueezing the local tensor.
        allow_shape_mismatch: if True, during loading, the global shape of a
            stored tensor does not have to match the expected global shape.
            Useful for representing tensors with flexible shape, e.g. padded.
        flattened_range: specifies a slice that should be applied to a
            flattened tensor with `local_shape` in order to get the tensor
            stored as `data`
# NOTE(review): fragment of a method whose `def` line (and the `raise`/loop
# wrapping these lines) falls outside this chunk. It appears to validate that
# each global axis size divides evenly by its fragmentation and to collect the
# resulting per-axis chunk sizes — confirm against the full method.
# NOTE(review): the error message below looks like it is missing a closing ')'
# after {axis_fragm}; fix the literal where the full statement is visible.
f'Axis shape ({axis_sh}) not divisible'f' by axis fragmentation ({axis_fragm}'
)
# Integer division: grounded by the divisibility check implied by the message above.
axis_chunk_size = axis_sh // axis_fragm
chunks.append(axis_chunk_size)
return tuple(chunks)
def without_data(self):
    """Return a copy of this sharded object with its `data` payload dropped (set to None)."""
    stripped = replace(self, data=None)
    return stripped
@classmethod
def from_rank_offsets(
    cls,
    key: str,
    data: torch.Tensor,
    *rank_offsets: Tuple[int, int, int],
    replica_id: ReplicaId = 0,
    prepend_axis_num: int = 0,
    **init_kwargs,
):
    """Allows to construct the ShardedTensor given offset specified in process ranks.

    Args:
        key: unique key
        data: local tensor data
        rank_offsets: each tuple (axis, axis_rank_offset, axis_fragm) says that
            if global tensor is divided into `axis_fragm` fragments along
            `axis` axis, then local tensor data corresponds to the
            `axis_rank_offset` chunk.