strict_ddp_mode (bool): If set to True, there is no tensor sharding, each tensor is replicated.
Defaults to False. Users can set it to True, when they clearly know that they only need DDP.
scatter_after_inference (bool): If set to True, the model will be scattered after inference. This will save memory but slow down the consecutive inference.
mixed_precision (torch.dtype): If set to torch.float16, the model will be trained in fp16. Otherwise, the model will be trained in bf16. Defaults to torch.float16.
@@ -55,6 +55,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
seed (int, optional): Random seed for weight initialization
shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
default_dtype (torch.dtype, optional): If it's not None, parameters will be initialized as ``default_dtype`` then converted to fp16.
bf16 (bool, optional): If it's True, parameters will be initialized as ``torch.bfloat16``. Otherwise, parameters will be initialized as ``torch.float16``. Defaults to False.
model_numel_tensor (torch.Tensor, optional): A tensor which will store the number of elements of model. Defaults to torch.zeros(1, dtype=torch.int).
"""
...
...
@@ -64,6 +65,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
The Colossal-AI project aims to provide a wide array of parallelism techniques for the machine learning community in the big-model era. This project is inspired by quite a few reserach works, some are conducted by some of our developers and the others are research projects open-sourced by other organizations. We would like to credit these amazing projects below in the IEEE citation format.
The Colossal-AI project aims to provide a wide array of parallelism techniques for the machine learning community in the big-model era. This project is inspired by quite a few research works, some are conducted by some of our developers and the others are research projects open-sourced by other organizations. We would like to credit these amazing projects below in the IEEE citation format.