Unverified commit bbb2c21f authored by Hongxin Liu, committed by GitHub

[shardformer] fix chatglm implementation (#5644)

* [shardformer] fix chatglm policy

* [shardformer] fix chatglm flash attn

* [shardformer] update readme

* [shardformer] fix chatglm init

* [shardformer] fix chatglm test

* [pipeline] fix chatglm merge batch
parent 5d88ef1a
@@ -7,7 +7,7 @@ from torch.nn import Module
 from torch.utils._pytree import tree_map

 from colossalai.accelerator import get_accelerator
-from colossalai.interface import OptimizerWrapper
+from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.utils import get_current_device

@@ -327,7 +327,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
             self.send_forward(output_obj)

         if outputs is not None:
-            outputs = merge_batch(outputs)
+            if isinstance(model, ModelWrapper):
+                model = model.unwrap()
+            batch_size_dim = getattr(model, "batch_size_dim", 0)
+            outputs = merge_batch(outputs, batch_size_dim)
         return {"loss": accum_loss, "outputs": outputs}

     def run_forward_backward(

@@ -410,7 +413,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)

         if outputs is not None:
-            outputs = merge_batch(outputs)
+            if isinstance(model, ModelWrapper):
+                model = model.unwrap()
+            batch_size_dim = getattr(model, "batch_size_dim", 0)
+            outputs = merge_batch(outputs, batch_size_dim)
         return {"loss": accum_loss, "outputs": outputs}

     def forward_backward_step(
......
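The schedule change above is the core of the "fix chatglm merge batch" item: `merge_batch` now concatenates micro-batch outputs along a model-specific dimension (read from a `batch_size_dim` attribute on the unwrapped model, defaulting to 0) instead of always using dim 0. The snippet below is a minimal, illustrative sketch of the idea, not the actual ColossalAI implementation (which also handles nested output structures); the helper name `merge_batch_sketch` is hypothetical.

```python
# Illustrative sketch only -- not the actual ColossalAI merge_batch.
from typing import Any, List

import torch


def merge_batch_sketch(micro_batch_outputs: List[Any], batch_size_dim: int = 0) -> Any:
    """Concatenate per-micro-batch outputs back into one batch along `batch_size_dim`."""
    if isinstance(micro_batch_outputs[0], torch.Tensor):
        return torch.cat(micro_batch_outputs, dim=batch_size_dim)
    if isinstance(micro_batch_outputs[0], dict):
        return {
            k: torch.cat([o[k] for o in micro_batch_outputs], dim=batch_size_dim)
            for k in micro_batch_outputs[0]
        }
    raise TypeError("unsupported output structure in this sketch")


# For a model whose outputs are [seq_len, batch, hidden] (as ChatGLMModel's are),
# micro-batches must be merged along dim=1, not dim=0.
chunks = [torch.randn(16, 2, 8) for _ in range(4)]  # 4 micro-batches of batch size 2
merged = merge_batch_sketch(chunks, batch_size_dim=1)
assert merged.shape == (16, 8, 8)
```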
@@ -114,30 +114,30 @@ We will follow this roadmap to develop Shardformer:
  - [x] Unit Testing
- [ ] Policy Implementation

| model | tensor parallel | pipeline parallel | lazy initialization | xformer | flash attn2 | jit fused operator | fused layernorm | sequence parallel | overlap |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| bert | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
| t5 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| llama V1/V2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| gpt2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
| opt | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| bloom | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
| chatglm2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
| vit | [√] | [√] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| whisper | [√] | [√] | [√] | [√] | [√] | [ ] | [√] | [ ] | [ ] |
| sam | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| blip2 | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| falcon | [√] | [√] | [√] | [√] | [√] | [ ] | [√] | [ ] | [ ] |
| roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| albert | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| ernie | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| gpt-neo | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| gpt-j | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| beit | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| swin | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| swin V2 | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| qwen | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| mistral | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |

## 💡 API Design

@@ -391,6 +391,43 @@ _POLICY_LIST = {
}
```
#### How to support models that are on the Hugging Face model hub but not in the `transformers` library
There are two cases:
1. The modeling file is in the `transformers` library but the model weights are not. E.g. the model structure of "01-ai/Yi-34B" is the same as LLaMA, but its weights are only available on the model hub. In this case, we support LLaMA as usual, and Yi-34B is covered by the llama policy; we do not need to add a new policy for Yi-34B.
2. The modeling file itself is not in the `transformers` library, e.g. "THUDM/chatglm2-6b".

Taking "THUDM/chatglm2-6b" as an example, we illustrate how to support such a model in `shardformer`.

Unlike llama, which is in the `transformers` library, the chatglm2 model cannot be imported directly. Therefore, the key in the policy should be the class name as a string rather than the class itself.
E.g. for llama:
```python
policy[LlamaDecoderLayer] = ModulePolicyDescription(...)
```
for chatglm2:
```python
policy["GLMBlock"] = ModulePolicyDescription(...)
```
Then, when registering such models in the autopolicy, use the following format:
```python
"transformers_modules.<modeling_filename>.<class_name>": PolicyLocation(
file_name="<policy_filename>", class_name="<policy_class_name>"
)
```
For the chatglm2 model, the entry becomes:
```python
"transformers_modules.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"
)
```
When using such models, `AutoModel` is supported as usual; the corresponding policy is loaded automatically by the autopolicy.
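For completeness, here is a minimal, hedged sketch of loading such a hub model and sharding it. It assumes a distributed environment has already been launched (e.g. via `colossalai.launch_from_torch`) and that `tp_group` is an existing tensor-parallel process group; `empty_init=False` mirrors the usage in this repository's test model zoo, and `trust_remote_code=True` is required because the modeling file lives in the model repo.

```python
# Minimal sketch, assuming a launched ColossalAI distributed environment and a
# tensor-parallel process group `tp_group` created elsewhere.
from transformers import AutoConfig, AutoModelForCausalLM

from colossalai.shardformer import ShardConfig, ShardFormer

config = AutoConfig.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
model = AutoModelForCausalLM.from_config(config, empty_init=False, trust_remote_code=True)

shard_config = ShardConfig(tensor_parallel_process_group=tp_group, enable_tensor_parallelism=True)
shard_former = ShardFormer(shard_config=shard_config)
# No policy is passed explicitly: the autopolicy resolves
# "transformers_modules.modeling_chatglm.ChatGLMForConditionalGeneration"
# to ChatGLMForConditionalGenerationPolicy.
sharded_model, shared_params = shard_former.optimize(model)
```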
### Write Your Unit Testing

This section serves as the guideline for testing the `shardformer` module.

@@ -424,13 +461,13 @@ We conducted [benchmark tests](./examples/performance_benchmark.py) to evaluate
We set the batch size to 4, the number of attention heads to 8, and the head dimension to 64. 'N_CTX' refers to the sequence length.

In the case of using 2 GPUs, the training times are as follows.

| N_CTX | org_model | shard_model |
| :---: | :---: | :---: |
| 256 | 11.2ms | 17.2ms |
| 512 | 9.8ms | 19.5ms |
| 1024 | 19.6ms | 18.9ms |
| 2048 | 46.6ms | 30.8ms |
| 4096 | 160.5ms | 90.4ms |

<p align="center">

@@ -440,13 +477,13 @@ In the case of using 2 GPUs, the training times are as follows.
In the case of using 4 GPUs, the training times are as follows.

| N_CTX | org_model | shard_model |
| :---: | :---: | :---: |
| 256 | 10.0ms | 21.1ms |
| 512 | 11.5ms | 20.2ms |
| 1024 | 22.1ms | 20.6ms |
| 2048 | 46.9ms | 24.8ms |
| 4096 | 160.4ms | 68.0ms |

@@ -475,10 +512,10 @@ warmup_fraction = 0.03
| accuracy | f1 | loss | GPU number | model sharded |
| :---: | :---: | :---: | :---: | :---: |
| 0.82971 | 0.87713 | 0.23194 | 4 | True |
| 0.83797 | 0.88006 | 0.22683 | 2 | True |
| 0.84521 | 0.88700 | 0.21822 | 1 | False |
Overall, the results demonstrate that using Shardformer during model training does not affect convergence.
@@ -281,19 +281,16 @@ class FusedRMSNorm(BaseLayerNorm):
             )

         LazyInitContext.materialize(module)

-        # to check if it is huggingface LlamaRMSNorm or MistralRMSNorm
-        if module.__class__.__name__ in ["LlamaRMSNorm", "MistralRMSNorm"]:
-            normalized_shape = module.weight.shape[0]
-            eps = module.variance_epsilon
-            elementwise_affine = True
-        else:
-            # get the attributes of the module
-            normalized_shape = module.normalized_shape
-            eps = module.eps
-            elementwise_affine = module.elementwise_affine
+        # try to get normalized_shape, eps, elementwise_affine from the module
+        normalized_shape = getattr(module, "normalized_shape", module.weight.shape[0])
+        eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps
+        elementwise_affine = getattr(module, "elementwise_affine", True)

         rmsnorm = FusedRMSNormWithHook(
-            normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine
+            normalized_shape=normalized_shape,
+            eps=eps,
+            elementwise_affine=elementwise_affine,
         )

         rmsnorm.weight = module.weight
......
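A quick illustration of why the duck-typed lookup above covers both module flavours. The `HFStyleRMSNorm` class below is a hypothetical stand-in for HF-style RMSNorm modules (a `weight` plus `variance_epsilon`); `nn.LayerNorm` represents modules that expose `normalized_shape`, `eps`, and `elementwise_affine` directly.

```python
# Hypothetical stand-ins, only to show the attribute fallbacks at work.
import torch
import torch.nn as nn


class HFStyleRMSNorm(nn.Module):  # shaped like LlamaRMSNorm / MistralRMSNorm
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps


def extract_norm_attrs(module: nn.Module):
    normalized_shape = getattr(module, "normalized_shape", module.weight.shape[0])
    eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps
    elementwise_affine = getattr(module, "elementwise_affine", True)
    return normalized_shape, eps, elementwise_affine


print(extract_norm_attrs(HFStyleRMSNorm(64)))  # (64, 1e-06, True)
print(extract_norm_attrs(nn.LayerNorm(64)))    # ((64,), 1e-05, True)
```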
@@ -12,7 +12,6 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig
 from colossalai.shardformer.layer import AttnMaskType, ColoAttention
 from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
-from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel


 def get_flash_core_attention_forward():

@@ -31,7 +30,12 @@ def get_flash_core_attention_forward():
                 device=query_layer.device,
             )
             temp_mask = (
-                torch.ones(query_layer.shape[2], key_layer.shape[2], dtype=torch.bool, device=query_layer.device)
+                torch.ones(
+                    query_layer.shape[2],
+                    key_layer.shape[2],
+                    dtype=torch.bool,
+                    device=query_layer.device,
+                )
                 .tril(diagonal=0)
                 .expand(query_layer.shape[0], 1, -1, -1)
             )

@@ -49,6 +53,7 @@ def get_flash_core_attention_forward():
             attention_mask=attn_bias,
             attention_mask_type=attention_mask_type,
             dropout_p=dropout_p,
+            scale=1.0 / self.norm_factor,
         )
         context_layer = context_layer.permute(2, 0, 1, 3)
         new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)

@@ -115,7 +120,7 @@ class ChatGLMPipelineForwards:
     @staticmethod
     def chatglm_model_forward(
-        self: ChatGLMModel,
+        self: "ChatGLMModel",
         input_ids,
         position_ids: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.BoolTensor] = None,

@@ -194,7 +199,9 @@ class ChatGLMPipelineForwards:
         if shard_config and shard_config.enable_sequence_parallelism:
             if shard_config.sequence_parallelism_mode == "split_gather":
                 hidden_states = split_forward_gather_backward(
-                    hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
+                    hidden_states,
+                    dim=0,
+                    process_group=shard_config.tensor_parallel_process_group,
                 )
         for idx in range(start_idx, end_idx):
             layer = self.encoder._get_layer(idx)

@@ -224,7 +231,9 @@ class ChatGLMPipelineForwards:
         if shard_config and shard_config.enable_sequence_parallelism:
             if shard_config.sequence_parallelism_mode == "split_gather":
                 hidden_states = gather_forward_split_backward(
-                    hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
+                    hidden_states,
+                    dim=0,
+                    process_group=shard_config.tensor_parallel_process_group,
                 )
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)

@@ -254,7 +263,7 @@ class ChatGLMPipelineForwards:
     @staticmethod
     def chatglm_for_conditional_generation_forward(
-        self: ChatGLMForConditionalGeneration,
+        self: "ChatGLMForConditionalGeneration",
         input_ids: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
......
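The added `scale=1.0 / self.norm_factor` matters because ChatGLM's `CoreAttention` does not use the default `1/sqrt(head_dim)` scaling; its `norm_factor` is assumed here to follow the upstream ChatGLM2 modeling code (square root of the head dimension, optionally multiplied by the layer number when query/key layer scaling is enabled). A minimal sketch of passing an explicit scale to PyTorch's SDPA (the `scale` argument requires PyTorch ≥ 2.1):

```python
import math

import torch
import torch.nn.functional as F

# Assumed to mirror ChatGLM2's CoreAttention normalization factor.
head_dim, layer_number, apply_qk_layer_scaling = 16, 3, True
norm_factor = math.sqrt(head_dim)
if apply_qk_layer_scaling:
    norm_factor *= layer_number

q = torch.randn(2, 8, 32, head_dim)  # [batch, heads, seq, head_dim]
k = torch.randn(2, 8, 32, head_dim)
v = torch.randn(2, 8, 32, head_dim)

# Without an explicit scale, SDPA would silently fall back to 1/sqrt(head_dim).
out = F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=1.0 / norm_factor)
print(out.shape)  # torch.Size([2, 8, 32, 16])
```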
@@ -151,10 +151,10 @@ _POLICY_LIST = {
         file_name="blip2", class_name="Blip2ForConditionalGenerationPolicy"
     ),
     # ChatGLM
-    "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel": PolicyLocation(
+    "transformers_modules.modeling_chatglm.ChatGLMModel": PolicyLocation(
         file_name="chatglm2", class_name="ChatGLMModelPolicy"
     ),
-    "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
+    "transformers_modules.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
         file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"
     ),
     # Falcon

@@ -202,6 +202,13 @@ def _fullname(obj):
     module = klass.__module__
     if module == "builtins":
         return klass.__qualname__  # avoid outputs like 'builtins.str'
+    # patch custom models which are not in transformers
+    # it can be like 'transformers_modules.THUDM.chatglm3-6b.103caa40027ebfd8450289ca2f278eac4ff26405.modeling_chatglm' (from huggingface hub)
+    # or like 'transformers_modules.chatglm.modeling_chatglm' (from local directory)
+    if module.startswith("transformers_modules"):
+        split_module = module.split(".")
+        if len(split_module) >= 2:
+            module = f"{split_module[0]}.{split_module[-1]}"
     return module + "." + klass.__qualname__
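A quick sketch of what the patched `_fullname` normalization produces. The hash-bearing module path is only an example of the kind of name `transformers` generates for hub downloads, not a literal value:

```python
# Sketch of the normalization performed by the patched _fullname.
def normalize_module(module: str) -> str:
    if module.startswith("transformers_modules"):
        parts = module.split(".")
        if len(parts) >= 2:
            module = f"{parts[0]}.{parts[-1]}"
    return module


# A hub download (repo name and revision hash in the path) and a local
# directory both collapse to the same policy-key prefix.
print(normalize_module("transformers_modules.THUDM.chatglm2-6b.abc123.modeling_chatglm"))
# -> transformers_modules.modeling_chatglm
print(normalize_module("transformers_modules.chatglm.modeling_chatglm"))
# -> transformers_modules.modeling_chatglm
```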
@@ -220,7 +227,7 @@ def get_autopolicy(model: nn.Module) -> Policy:
     if policy_location is None:
         raise NotImplementedError(
-            f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}"
+            f"Auto policy for {model.__class__.__qualname__} ({full_name}) is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}"
         )
     else:
         policy = import_policy(policy_location)
......
@@ -7,7 +7,6 @@ from torch import Tensor
 import colossalai.shardformer.layer as col_nn
 from colossalai.shardformer.modeling.chatglm2 import ChatGLMPipelineForwards
-from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel

 from ..modeling.chatglm2 import (
     get_chatglm_sequence_parallel_forward_fn,

@@ -17,7 +16,11 @@ from ..modeling.chatglm2 import (
 from ..modeling.jit import get_jit_fused_dropout_add_func
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

-__all__ = ["ChatGLMPolicy", "ChatGLMModelPolicy", "ChatGLMForConditionalGenerationPolicy"]
+__all__ = [
+    "ChatGLMPolicy",
+    "ChatGLMModelPolicy",
+    "ChatGLMForConditionalGenerationPolicy",
+]


 class ChatGLMPolicy(Policy):

@@ -34,8 +37,6 @@ class ChatGLMPolicy(Policy):
         return self.model

     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMModel, CoreAttention, GLMBlock
-
         policy = {}

         embedding_cls = None
@@ -67,7 +68,27 @@ class ChatGLMPolicy(Policy):
             sp_partial_derived = sp_mode == "split_gather"

         if self.shard_config.enable_tensor_parallelism:
-            policy[GLMBlock] = ModulePolicyDescription(
+            assert (
+                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"num_attention_heads {self.model.config.num_attention_heads} should be divisible by tensor_parallel_size {self.shard_config.tensor_parallel_size}"
+            attn_kwargs = {
+                "self_attention.qkv_hidden_size": (
+                    self.model.config.kv_channels * self.model.config.num_attention_heads * 3
+                )
+                // self.shard_config.tensor_parallel_size,
+            }
+            if self.model.config.multi_query_attention:
+                assert (
+                    self.model.config.multi_query_group_num % self.shard_config.tensor_parallel_size == 0
+                ), f"multi_query_group_num {self.model.config.multi_query_group_num} should be divisible by tensor_parallel_size {self.shard_config.tensor_parallel_size}"
+                attn_kwargs["self_attention.num_multi_query_groups_per_partition"] = (
+                    self.model.config.multi_query_group_num // self.shard_config.tensor_parallel_size
+                )
+                attn_kwargs["self_attention.qkv_hidden_size"] = (
+                    self.model.config.kv_channels * self.model.config.num_attention_heads
+                    + 2 * self.model.config.kv_channels * self.model.config.multi_query_group_num
+                ) // self.shard_config.tensor_parallel_size
+            policy["GLMBlock"] = ModulePolicyDescription(
                 attribute_replacement={
                     "self_attention.num_attention_heads_per_partition": self.model.config.num_attention_heads
                     // self.shard_config.tensor_parallel_size,

@@ -75,22 +96,23 @@ class ChatGLMPolicy(Policy):
                         self.model.config.kv_channels * self.model.config.num_attention_heads
                     )
                     // self.shard_config.tensor_parallel_size,
-                    "self_attention.qkv_hidden_size": (
-                        self.model.config.kv_channels * self.model.config.num_attention_heads * 3
-                    )
-                    // self.shard_config.tensor_parallel_size,
                     "self_attention.core_attention.num_attention_heads_per_partition": self.model.config.num_attention_heads
                     // self.shard_config.tensor_parallel_size,
                     "self_attention.core_attention.hidden_size_per_partition": self.model.config.kv_channels
                     * self.model.config.num_attention_heads
                     // self.shard_config.tensor_parallel_size,
+                    **attn_kwargs,
                 },
                 param_replacement=[],
                 sub_module_replacement=[
                     SubModuleReplacementDescription(
                         suffix="self_attention.query_key_value",
                         target_module=col_nn.Linear1D_Col,
-                        kwargs={"seq_parallel_mode": sp_mode, "seq_parallel_dim": 0, "overlap": overlap},
+                        kwargs={
+                            "seq_parallel_mode": sp_mode,
+                            "seq_parallel_dim": 0,
+                            "overlap": overlap,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attention.dense",

@@ -114,7 +136,7 @@ class ChatGLMPolicy(Policy):
                 ),
             ],
             policy=policy,
-            target_key=ChatGLMModel,
+            target_key="ChatGLMModel",
         )
         # optimization configuration
         self.append_or_create_submodule_replacement(

@@ -131,7 +153,7 @@ class ChatGLMPolicy(Policy):
                 ),
             ],
             policy=policy,
-            target_key=GLMBlock,
+            target_key="GLMBlock",
         )

         if self.model.config.post_layer_norm:

@@ -143,7 +165,7 @@ class ChatGLMPolicy(Policy):
                 )
             ],
             policy=policy,
-            target_key=ChatGLMModel,
+            target_key="ChatGLMModel",
         )

         # use flash attention

@@ -153,7 +175,7 @@ class ChatGLMPolicy(Policy):
                 "forward": get_flash_core_attention_forward(),
             },
             policy=policy,
-            target_key=CoreAttention,
+            target_key="CoreAttention",
         )

         # use sequence parallel

@@ -161,7 +183,7 @@ class ChatGLMPolicy(Policy):
             self.append_or_create_method_replacement(
                 description={"forward": get_chatglm_sequence_parallel_forward_fn(self.shard_config)},
                 policy=policy,
-                target_key=ChatGLMModel,
+                target_key="ChatGLMModel",
             )

         # use jit fused operator

@@ -172,7 +194,7 @@ class ChatGLMPolicy(Policy):
                 "dropout_add": get_jit_fused_dropout_add_func(),
             },
             policy=policy,
-            target_key=GLMBlock,
+            target_key="GLMBlock",
         )

         return policy
@@ -220,7 +242,10 @@ class ChatGLMPolicy(Policy):
         stage_index = stage_manager.get_stage_index(layers_per_stage)
         method_replacement = {
             "forward": partial(
-                new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
+                new_forward,
+                stage_manager=stage_manager,
+                stage_index=stage_index,
+                shard_config=self.shard_config,
             )
         }
         self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)

@@ -234,7 +259,9 @@ class ChatGLMModelPolicy(ChatGLMPolicy):
         if self.pipeline_stage_manager is not None:
             self.set_pipeline_forward(
-                model_cls=ChatGLMModel, new_forward=ChatGLMPipelineForwards.chatglm_model_forward, policy=policy
+                model_cls="ChatGLMModel",
+                new_forward=ChatGLMPipelineForwards.chatglm_model_forward,
+                policy=policy,
             )
         return policy

@@ -252,7 +279,7 @@ class ChatGLMForConditionalGenerationPolicy(ChatGLMModelPolicy):
         if self.pipeline_stage_manager is not None:
             self.set_pipeline_forward(
-                model_cls=ChatGLMForConditionalGeneration,
+                model_cls="ChatGLMForConditionalGeneration",
                 new_forward=ChatGLMPipelineForwards.chatglm_for_conditional_generation_forward,
                 policy=policy,
             )
......
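To make the partitioned QKV arithmetic in the policy above concrete, here is a worked example. The config values (kv_channels=128, num_attention_heads=32, multi_query_group_num=2) are assumed from ChatGLM2-6B's published configuration; adjust them for your checkpoint.

```python
# Worked example of the qkv_hidden_size computed in the policy above.
kv_channels, num_attention_heads = 128, 32
multi_query_group_num, tp_size = 2, 2

# Without multi-query attention: Q, K and V all have full width.
qkv_plain = (kv_channels * num_attention_heads * 3) // tp_size  # -> 6144
# With multi-query attention: K and V are shared across only a few groups.
qkv_mqa = (
    kv_channels * num_attention_heads               # Q: 4096
    + 2 * kv_channels * multi_query_group_num       # K + V: 512
) // tp_size                                        # -> 2304
print(qkv_plain, qkv_mqa)
```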
@@ -310,13 +310,6 @@ if dist.get_world_size() > 1:
 2. When you use Shardformer to process classification models such as `GPT2ForSequenceClassification` and `ViTForImageClassification`, please ensure that the total number of labels is an integer multiple of the tensor parallel size; otherwise Shardformer can't process the classifier layer correctly. A simple fix is to append dummy labels in the transformers config. This bug will be fixed in a future version of Shardformer.

-3. The case of training ChatGLM-2 6B is a little special: since Huggingface transformers doesn't officially support ChatGLM at present, please import the configuration/model classes through
-   ```python
-   from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
-   from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
-   ```
-   when training ChatGLM-2 with Shardformer, and initialize your model with these imported classes.

 ## How Shardformer Works

 ### Main Idea
......
@@ -303,13 +303,6 @@ if dist.get_world_size() > 1:
 2. When using Shardformer to process classification models such as `GPT2ForSequenceClassification` and `ViTForImageClassification`, please make sure the total number of labels is an integer multiple of the tensor parallel size; otherwise Shardformer cannot handle the classifier layer correctly. A simple fix is to add dummy labels in the transformers config. This bug will be fixed in a future version of Shardformer.

-3. Training ChatGLM-2 6B is a special case: since Hugging Face Transformers does not yet officially support ChatGLM, please import the config/model classes as follows when training ChatGLM-2 with Shardformer:
-   ```python
-   from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
-   from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
-   ```
-   and initialize your model with these imported classes.

 ## How Shardformer Works
......
 import torch
+from torch.nn import init
-from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
-from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
+from transformers import AutoConfig, AutoModelForCausalLM

 from ..registry import ModelAttribute, model_zoo

@@ -34,19 +33,26 @@ loss_fn_for_chatglm_model = lambda x: torch.nn.functional.mse_loss(
 )
 loss_fn = lambda x: x["loss"]

-config = ChatGLMConfig(
+config = AutoConfig.from_pretrained(
+    "THUDM/chatglm2-6b",
+    trust_remote_code=True,
     num_layers=2,
     padded_vocab_size=65024,
     hidden_size=64,
+    ffn_hidden_size=214,
     num_attention_heads=8,
     kv_channels=16,
     rmsnorm=True,
     original_rope=True,
     use_cache=True,
-    multi_query_attention=False,
     torch_dtype=torch.float32,
 )

-infer_config = ChatGLMConfig(
+infer_config = AutoConfig.from_pretrained(
+    "THUDM/chatglm2-6b",
+    trust_remote_code=True,
     num_layers=2,
     padded_vocab_size=65024,
     hidden_size=128,

@@ -60,18 +66,18 @@ infer_config = ChatGLMConfig(
     torch_dtype=torch.float32,
 )

-model_zoo.register(
-    name="transformers_chatglm",
-    model_fn=lambda: ChatGLMModel(config, empty_init=False),
-    data_gen_fn=data_gen,
-    output_transform_fn=output_transform_fn,
-    loss_fn=loss_fn_for_chatglm_model,
-    model_attribute=ModelAttribute(has_control_flow=True),
-)
+def init_chatglm():
+    model = AutoModelForCausalLM.from_config(config, empty_init=False, trust_remote_code=True)
+    for m in model.modules():
+        if m.__class__.__name__ == "RMSNorm":
+            init.ones_(m.weight)
+    return model
+

 model_zoo.register(
     name="transformers_chatglm_for_conditional_generation",
-    model_fn=lambda: ChatGLMForConditionalGeneration(config, empty_init=False),
+    model_fn=init_chatglm,
     data_gen_fn=data_gen_for_conditional_generation,
     output_transform_fn=output_transform_fn,
     loss_fn=loss_fn,
......
@@ -227,7 +227,7 @@ def check_output_hidden_state(
 def check_loss(org_loss: Tensor, sharded_loss: Tensor, atol: float = 1e-5, rtol: float = 1e-3):
-    assert torch.allclose(org_loss.float(), sharded_loss.float(), atol=atol, rtol=rtol)
+    assert_close(org_loss.float(), sharded_loss.float(), atol=atol, rtol=rtol)


 def check_weight(
......
@@ -11,6 +11,7 @@ from tests.test_shardformer.test_model._utils import (
     build_model_from_hybrid_plugin,
     check_all_grad_tensors,
     check_loss,
+    check_output_hidden_state,
     check_weight,
     get_grad_tensors_for_check,
     run_forward_backward_with_hybrid_plugin,

@@ -103,8 +104,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         atol, rtol = 5e-3, 5e-3
     # TODO: ChatGLMModel output is [S, B, H], merging batch of pipeline is wrong
-    # if org_model.__class__.__name__ == "ChatGLMModel":
-    #     check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol, dim=1)
+    if org_model.__class__.__name__ == "ChatGLMModel":
+        check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol, dim=1)

     check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)

@@ -177,14 +178,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         {
             "tp_size": 4,
             "pp_size": 1,
-            "enable_all_optimization": True,
+            "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "fp32",
         },
         {
             "tp_size": 2,
             "pp_size": 1,
-            "enable_all_optimization": True,
+            "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "fp32",
         },
......