This module contains the implementation of the abstraction of the device topology. It is used to represent the device topology and manage the distributed information related to the network.
## 📝 Design
This module is inspired by the DeviceMesh in the [Alpa project](https://github.com/alpa-projects/alpa) and the device array can be represented as a 1D or 2D mesh. We will be extending the device mesh to support 3D mesh in the future.
## 🔨 Usage
- Create a device mesh
```python
# this is the list of global ranks involved in the device mesh
# assume we have 4 GPUs and the global ranks for these GPUs are 0, 1, 2, 3
# Create a new member `flatten_device_meshes` to distinguish from original flatten methods (Because I'm not sure if there are functions that rely on the self.flatten())
**Shardformer** is a module that automatically parallelizes the mainstream models in libraries such as HuggingFace and TIMM. This module aims to make parallelization hassle-free for users who are not from the system background.
If you wanna parallel the model in a custom way, just overwrite the policy class for the Hugging Face model.
You should do:
1. Inherit Policy class
2. Overwrite `argument_policy` method
- In this method, you need to list which layers class you wanna modify and the attributes and parameters in those layers. Shardformer will replace all the layer belonging to the class you specified.
-`attr_dict` is dict contains all the attributes need to be modified in this layer.
-`param_funcs` is a list contains some functions which will return the path of the weight and bias from the layer.
3. Overwrite `inject_policy` method (Optional)
- Shardformer will inject the model according to this method. If you need to modify the forward or backward progress (like distributed corssentropy loss in Bert) you need to overwrite this method.
4. Overwrite or add the param functions
- These functions use a suffix to record the path of weight or bias for the layer.
- The return is a list contains some `Col_Layer` or `Row_Layer` objects, which means slice along col and row respectively.
5. Overwrite `binding_policy` (Optional)
- Overwrite to specify Shardformer will bind some weight between layers, like embedding and unembedding layers.
- This function will return a dict, the key and value are the suffix of weight need to be binded.
More details can be found in shardformer/policies/basepolicy.py
This is the user api to use shardformer, just create a model from transformers and define a custom policy or use shardformer autopolicy to make a shard model.
- CLASS `Layer`:
Parameters:
- weight (str): The weight suffix of the layer
- bias (str): The bias suffix of the layer
- replace_layer (:class:`colosalai.nn`): The layer to replace the original layer
- ignore (bool): Whether to ignore this layer if it is not in the model
This class is used to specify the replacement policy for a particular layer. If `replace_layer` is None, only parameter partitioning will be performed without replacing the layer class.
CLASS `Col_Layer(Layer)`:
- gather_output (bool): Whether to gather the output of the layer
This class inherited from `Layer`, representing the layer will be sliced along column.
CLASS `Row_Layer(Layer)`:
This class inherited from `Layer`, representing the layer will be sliced along row.
- CLASS `Policy`:
In Shardformer, this class holds significant importance as it defines the model partitioning methods, required parameter modifications, and model injection techniques all within a single Policy class.
These functions define the partitioning methods of the parameters at different locations in the model. Each function returns a list of objects of Layer class that specify the replacement approach for these parameters. Shardformer also supports user-defined functions for modifying their models, in addition to the listed functions.
- `Policy.argument_policy()`
In this function, the user should use multiple dict to define which class of layers will require replacement. This includes the attributes and parameters that need to be modified or replaced. Attributes are stored in the form of a "suffix-string: value" dict, while parameters are stored via multiple static methods that return the replacement approach.
- `Policy.inject_policy()`
This function will return the injected model to replace the original model. The new model should be a nn.Module class which includes modified forward or backward functions or anything else.
- `Policy.binding_policy()`
This function will return the weight sharing information in the model in some dict. The key and value are both the suffixes of the shared parameters.
- CLASS `ModelSharder(model, policy)`:
This class helps shard the model, the parameter is the created transformers model and the custom policy. If custom policy is None, shardformer will automatically get already defined policy for the model.
- `ModelShard.inject_model()`
This function is used to inject the model to modify the forward and backward progress.
- `ModelShard.replace_layer()`
This function is used to replace the original layers with colossalai layer to make them paralleled and can do distributed communication.
- `ModelShard.bind_layer()`
This function is used to help different layers share weight or bias.
- CLASS `Slicer`:
This class is used to slice tensor according to policy.
3. DistCrossEntropy Loss
- Overview
In order to reduce the communication size, caculate the crossentropy before all gather, refer to [Megatron-LM](https://github.com/NVIDIA/Megatron-LM), reduce the communication size from [batch_size * seq_length * vocab_size] to [batch_size * seq_length]. The origin loss function is:
$$ loss = -\log(\frac{\exp(x[class])}{\sum_i\exp(x[i])})$$
alse can be represented as:
$$ loss = \log(\sum_i\exp(x[i])) - x[class]$$
- Step
- First get the maximum logits across all the devices, make all the logist minus the maximun value to scale the value less than zero to avoid the value of exp being too large
- Get a mask to mask the logits not in the local device
- Caculate the loss according to the second formula
# part of code modified from https://github.com/tunib-ai/parallelformers
fromdataclassesimportdataclass
fromtypingimportAny,Callable,Dict,List,Tuple,Type
importtorch.nnasnn
@dataclass
classArgument:
r"""
The argument class for the policy
Args:
attr_dict (Dict[str, Any]): The dict for the param setting
param_funcs (:class:`List[Callable]`): The list for the param functions
"""
attr_dict:Dict[str,Any]
param_funcs:List[Callable]
@dataclass
classLayer:
r"""
The layer object for the policy
Args:
weight (str): The weight suffix of the layer
bias (str): The bias suffix of the layer
replace_layer (:class:`colosalai.nn`): The layer to replace the original layer
ignore (bool): Whether to ignore this layer if it is not in the model
reversed (bool): Whether the weight in layer is reversed, commonly the weight in `torch.nn.Linear` is [out, in],
but in GPT2 `Conv1D` layer is [in, out] which is reversed.
n_cast (int): The number of weight will cast to, like q, k, v in attention layer, n_cast should be 3. commonly in TP, we just chunk the weight with the number of devices,
but in multi-head attention, we need to chunk the weight with the number of devices * n_head, and
each device should have a part of Q, K and V weight.
"""
weight:str=None
bias:str=None
replace_layer:Any=None
ignore:bool=False
reversed:bool=False
n_cast:int=None
@dataclass
classCol_Layer(Layer):
r"""
Class for col shard layer in MegatronLM
Args:
gather_output (bool): Whether to gather the output of the layer
"""
gather_output:bool=False
@dataclass
classRow_Layer(Layer):
r"""
Class for col shard layer in MegatronLM
"""
pass
classPolicy():
r"""
The base class for all the policies
For each different model, it should have a different policy class, like BertPolicy for Bert Model
or OPTPolicy for OPT model.
AutoPolicy:
Shardformer already defined some policies for huggingface model, just set ``custom_policy`` = None
to use the auto policy. In shardformer autopolicy, we define a base policy for one type model,
like BertPolicy, and for each different Bert modle in huggingface like, BertForMaskedLM,
BertForSequenceClassification, etc., for each different Bert model we difine different policy class
and overwrite the method like ``inject_policy`` to modify the forward and backward process.
CustomPolicy:
If you want to define your own policy, you can set ``custom_policy`` = CustomPolicy, and overwrite
all the methods in ``Policy`` class. You can refer to any policy we defined like the ``BertPolicy``