If you have a custom model, you can also use Shardformer to parallelize it by writing your own sharding policy. More information about the sharding policy can be found in [API Design](#-api-design).
If you wanna parallel the model in a custom way, just overwrite the policy class for the Hugging Face model. Please refer to any policy that we have pre-established, like [bert policy](./policies/bert.py) or [gpt2 policy](./policies/gpt2.py).
```python
fromcolossalai.shardformerimportPolicy
You should do:
1. Inherit Policy class
2. Overwrite `argument_policy` method
- In this method, you need to list which layers class you wanna modify and the attributes and parameters in those layers. Shardformer will replace all the layer belonging to the class you specified.
-`attr_dict` is dict contains all the attributes need to be modified in this layer.
-`param_funcs` is a list contains some functions which will return the path of the weight and bias from the layer.
3. Overwrite `inject_policy` method (Optional)
- Shardformer will inject the model according to this method. If you need to modify the forward or backward progress (like distributed corssentropy loss in Bert) you need to overwrite this method.
4. Overwrite or add the param functions
- These functions use a suffix to record the path of weight or bias for the layer.
- The return is a list contains some `Col_Layer`, `Row_Layer` or `Dropout_Layer` objects, which means slice along col and row respectively or as dropout layer, refer to CLASS `Layer` for more details.
5. Overwrite `binding_policy` (Optional)
- Overwrite to specify Shardformer will bind some weight between layers, like embedding and unembedding layers.
- This function will return a dict, the key and value are the suffix of weight need to be binded.
More details can be found in shardformer/policies/basepolicy.py
<b>This diagram is deprecated, need to update it</b>
</p>
@staticmethod
defmlp_in()->Union[List,None]:
r"""
h -> 4h mlp layer
Returns:
List[Layer]: List of layer object
"""
returnNone
@staticmethod
defmlp_out()->Union[List,None]:
r"""
4h -> h mlp layer
### Distributed Modules
Returns:
List[Layer]: List of layer object
"""
returnNone
`ShardFormer` replaces the original PyTorch module with a distributed module.
The distributed module keeps the same attributes as the original module but replaces the original parameters with distributed parameters and defines a new `forward` function to execute distributed computation.
Each distributed module implements its `from_native_module` static method to convert the PyTorch module to its corresponding distributed module.
This is the user api to use shardformer, just create a model from transformers and define a custom policy or use shardformer autopolicy to make a shard model.
- CLASS `Layer`:
Parameters:
- suffix: (str): the suffix of the layer to indicate the attribute of the layer.
- replace_layer (:class:`colosalai.nn`): The layer to replace the original layer
- ignore (bool): Whether to ignore this layer if it is not in the model
- reversed (bool): Whether the weight in layer is reversed, commonly the weight in `torch.nn.Linear` is [out, in], but in GPT2 `Conv1D` layer is [in, out] which is reversed.
- n_cast (int): The number of weight will cast to, like q, k, v in attention layer, n_cast should be 3. commonly in TP, we just chunk the weight with the number of devices, but in multi-head attention, we need to chunk the weight with the number of $ devices * n\_head $, and each device should have a part of Q, K and V weight.
```python
@dataclass
classShardConfig:
data_parallel_size:int
tensor_parallel_size:int
...
This class is a base class used to specify the replacement policy and the suffix the layer for a particular layer.
# Some possible future config fields
pipeline_parallel_size:int# Support pipeline parallelism
tensor_parallel_mode:Choice['1d','2d','2.5d','3d']# support different tensor parallel mode
inference_only:bool# only inject inference-suitable sharding policy
gather_output:bool# gather the model output
use_flash_attention:bool# whether to use flash attention to speed up attention
```
CLASS `Col_Layer(Layer)`:
- weight (str): The weight suffix of the layer
- bias (str): The bias suffix of the layer
- gather_output (bool): Whether the output of this layer can be gathered, like the last layer can be gathered, but most of the time, the intermediate layers of the model do not need to be gathered.
### Policy
This class inherited from `Layer`, representing the layer will be sliced along colum and indicate the attributes of weight and bias. Setting `bias` to `None` means ignoring bias, regardless of whether or not it originally exists.
The `Policy` class describes how to handle the model sharding.
It is merely a description, the actual sharding will be performed by `ModelSharder`.
We abstract the policy into four stages:
CLASS `Row_Layer(Layer)`:
- weight (str): The weight suffix of the layer
- bias (str): The bias suffix of the layer
1. Preprocessing: call `Policy.preprocess` to do some prior work before sharding, for example, resizing the embedding
2. Providing a new class: call `Policy.new_model_class` to get a new class for the model, this class replaces attributes and the forward function
3. Providing `ModulePolicyDescription`: call `Policy.module_policy` to get a bunch of `ModulePolicyDescription` to tell `ModelSharder` how the submodules's attributes, child parameters, and deeper submodules will be substituted.
4. Postprocessing: call `Policy.postprocess` to perform some postprocessing work, for example, binding the embedding and classifier head weights of the BERT model.
This class inherited from `Layer`, representing the layer will be sliced along row. Just like `Col_Layer` but in tensor parrallel, there is no need to gather the output of layer sliced by row.
``` python
@dataclass
classModulePolicyDescription:
"""
Describe how the attributes and parameters will be transformed in a policy
- CLASS `Policy`:
Args:
attribute_replacement (Dict[str, Any]): key is the attribute name, value is the attribute value after sharding
param_replacement (List[Callable]): a list of functions to perform in-place param replacement. The function must receive two arguments: module, process_group. One example is
sub_module_replacement: each element in the list is a ParamReplacementDescription object which specifies the module to be replaced and the target module used to replacement
In Shardformer, this class holds significant importance as it defines the model partitioning methods, required parameter modifications, and model injection techniques all within a single Policy class.
These functions define the partitioning methods of the parameters at different locations in the model. Each function returns a list of objects of Layer class that specify the replacement approach for these parameters. Shardformer also supports user-defined functions for modifying their models, in addition to the listed functions.
Args:
suffix (str): used to get the submodule object
target_module (ParallelModule): specifies the module class used to replace to submodule
kwargs (Dict[str, Any]): the dictionary used to pass extra arguments to the `ParallelModule.from_native_module` method.
"""
suffix:str
target_module:ParallelModule
kwargs:Dict[str,Any]=None
- `Policy.argument_policy()`
In this function, the user should use multiple dict to define which class of layers will require replacement. This includes the attributes and parameters that need to be modified or replaced. Attributes are stored in the form of a "suffix-string: value" dict, while parameters are stored via multiple static methods that return the replacement approach.
classPolicy(ABC):
- `Policy.inject_policy()`
def__init__(self)
self.model=None
This function will return the injected model to replace the original model. The new model should be a nn.Module class which includes modified forward or backward functions or anything else.
defset_model(self,model:nn.Module)->None:
"""
Set model as an attribute of the Policy object so that we can access the model's attributes.
"""
self.model=model
- `Policy.binding_policy()`
@abstractmethod
defpreprocess(self)->nn.Module:
"""
Perform some preprocessing on the model, such as resizing the embedding size
"""
...
This function will return the weight sharing information in the model in some dict. The key and value are both the suffixes of the shared parameters.
replace the class of the model to substitute the forward and attributes
"""
...
- CLASS `ModelSharder(model, policy)`:
@abstractmethods
defpostprocess(self)->nn.Module:
"""
Perform some postprocessing on the model, such as binding the embedding with the weight of the classifier head
"""
...
```
This class helps shard the model, the parameter is the created transformers model and the custom policy. If custom policy is None, shardformer will automatically get already defined policy for the model.
- `ModelShard.inject_model()`
### ModelSharder
This function is used to inject the model to modify the forward and backward progress.
`ModelSharder` is the class in charge of sharding the model based on the given policy.
- `ModelShard.replace_layer()`
```python
classModelSharder:
This function is used to replace the original layers with colossalai layer to make them paralleled and can do distributed communication.
Shard model with parallelelism with the help of pre-processing, replace_model_class, replace_module, and post-processing.
"""
...
This function is used to help different layers share weight or bias.
defreplace_model_class(self)->None:
"""
Replace the model's methods and attributes with our own defined class.
E.g. we can replace the forward function of the original BertForMaskedLM object
with the forward function we define in BertForMaskedLM_ class.
"""
...
- CLASS `Slicer`:
defreplace_module(self)->None:
"""
Replace the layer according to the policy. Call Policy.module_policy() to get the module. Call _replace_module recursively.
"""
...
```
This class is used to slice tensor according to policy.
### User-facing API
We only expose a limited number of APIs to the user to keep their user experience simple and clean.
3. DistCrossEntropy Loss
- Overview
```python
classShardFormer:
"""
Parallelize model based on the given config and policy
In order to reduce the communication size, caculate the crossentropy before all gather, refer to [Megatron-LM](https://github.com/NVIDIA/Megatron-LM), reduce the communication size from [batch_size * seq_length * vocab_size] to [batch_size * seq_length]. The origin loss function is:
$$ loss = -\log(\frac{\exp(x[class])}{\sum_i\exp(x[i])})$$
model = shard_former.shard_model(model, policy=policy)
dataloader = shard_former.shard_dataset(dataset)
$$ loss = \log(\sum_i\exp(x[i])) - x[class]$$
"""
- Step
def__init__(self,shard_config:ShardConfig):
"""
Do two things:
1. Create a colossalai.cluster.process_group_manager to manage process groups for dp, tp and pp
2. serve as a store for shard config
"""
self.shard_config=shard_config
self.pg_manager=None
- First get the maximum logits across all the devices, make all the logist minus the maximun value to scale the value less than zero to avoid the value of exp being too large