Unverified commit a27e0bb4, authored by flybird11111, committed by GitHub

[shardformer] bert support sequence parallel. (#4455)

* [shardformer] bert support sequence parallel
parent 0ecd71e0
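This commit threads a new `enable_sequence_parallelism` flag from `ShardConfig` into the BERT policy: the column/row-parallel linear layers receive a `seq_parallel` kwarg, and `BertModel.forward` is replaced by `bert_sequence_parallel_forward_fn(self.shard_config)`. A minimal usage sketch follows; it is not taken from this commit, and it assumes the usual `ShardFormer.optimize` entry point and the `ShardConfig` flags shown (only `enable_sequence_parallelism` is confirmed by the diff below):

```python
import torch.distributed as dist
from transformers import BertForSequenceClassification

import colossalai
from colossalai.shardformer import ShardConfig, ShardFormer

# Illustrative launcher call; a real script passes its own config/args.
colossalai.launch_from_torch(config={})

model = BertForSequenceClassification.from_pretrained("bert-base-uncased")

# Assumption: ShardConfig accepts these keyword arguments;
# enable_sequence_parallelism is the flag this commit reads as
# self.shard_config.enable_sequence_parallelism.
shard_config = ShardConfig(
    tensor_parallel_process_group=dist.group.WORLD,
    enable_tensor_parallelism=True,
    enable_sequence_parallelism=True,
)
sharded_model, shared_params = ShardFormer(shard_config=shard_config).optimize(model)
```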
@@ -154,7 +154,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap):
+    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True):
         ctx.save_for_backward(input_, weight)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
@@ -217,9 +217,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
             # do all gather in default stream
             input_ = input_.contiguous()
             world_size = dist.get_world_size(process_group)
-            rank = dist.get_rank(process_group)
             tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
-            tensor_list[rank] = input_
             gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
             # calculate gradient in calculate_stream
@@ -469,9 +467,7 @@ def _gather(input_, dim=-1, process_group=None):
     # all gather
     input_ = input_.contiguous()
-    rank = dist.get_rank(process_group)
     tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
-    tensor_list[rank] = input_
     torch.distributed.all_gather(tensor_list, input_, group=process_group)
     # concat
......
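The two deletions above (`rank = dist.get_rank(process_group)` and `tensor_list[rank] = input_`) are safe because `torch.distributed.all_gather` already writes every rank's tensor, including the caller's own, into `tensor_list`; pre-filling the local slot was redundant work. A small standalone sketch demonstrating this, illustrative only, launched with e.g. `torchrun --nproc_per_node=2 demo.py`:

```python
import torch
import torch.distributed as dist


def gather_example():
    # Assumption: CPU/gloo backend just for the demo; torchrun provides
    # the rendezvous environment variables.
    dist.init_process_group("gloo")
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    input_ = torch.full((2,), float(rank)).contiguous()
    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]

    # all_gather populates tensor_list[i] with rank i's input_ for every i,
    # so tensor_list[rank] already equals input_ afterwards; no pre-fill needed.
    dist.all_gather(tensor_list, input_)
    assert torch.equal(tensor_list[rank], input_)

    dist.destroy_process_group()


if __name__ == "__main__":
    gather_example()
```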
@@ -10,6 +10,7 @@ import colossalai.shardformer.layer as col_nn
 from .._utils import getattr_, setattr_
 from ..modeling.bert import (
     BertPipelineForwards,
+    bert_sequence_parallel_forward_fn,
     get_bert_flash_attention_forward,
     get_jit_fused_bert_output_forward,
     get_jit_fused_bert_self_output_forward,
@@ -47,13 +48,14 @@ class BertPolicy(Policy):
         from transformers.models.bert.modeling_bert import (
             BertEmbeddings,
             BertLayer,
+            BertModel,
             BertOutput,
             BertSelfAttention,
             BertSelfOutput,
         )

         policy = {}
+        use_sequence_parallel = self.shard_config.enable_sequence_parallelism

         if self.shard_config.enable_tensor_parallelism:
             policy[BertLayer] = ModulePolicyDescription(attribute_replacement={
                 "attention.self.all_head_size":
@@ -69,14 +71,17 @@ class BertPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="attention.self.query",
                     target_module=col_nn.Linear1D_Col,
+                    kwargs={"seq_parallel": use_sequence_parallel},
                 ),
                 SubModuleReplacementDescription(
                     suffix="attention.self.key",
                     target_module=col_nn.Linear1D_Col,
+                    kwargs={"seq_parallel": use_sequence_parallel},
                 ),
                 SubModuleReplacementDescription(
                     suffix="attention.self.value",
                     target_module=col_nn.Linear1D_Col,
+                    kwargs={"seq_parallel": use_sequence_parallel},
                 ),
                 SubModuleReplacementDescription(
                     suffix="attention.self.dropout",
@@ -85,6 +90,7 @@ class BertPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="attention.output.dense",
                     target_module=col_nn.Linear1D_Row,
+                    kwargs={"seq_parallel": use_sequence_parallel},
                 ),
                 SubModuleReplacementDescription(
                     suffix="attention.output.dropout",
@@ -93,10 +99,12 @@ class BertPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="intermediate.dense",
                     target_module=col_nn.Linear1D_Col,
+                    kwargs={"seq_parallel": use_sequence_parallel},
                 ),
                 SubModuleReplacementDescription(
                     suffix="output.dense",
                     target_module=col_nn.Linear1D_Row,
+                    kwargs={"seq_parallel": use_sequence_parallel},
                 ),
                 SubModuleReplacementDescription(
                     suffix="output.dropout",
@@ -115,6 +123,12 @@ class BertPolicy(Policy):
                     )
                 ])

+        if use_sequence_parallel:
+            self.append_or_create_method_replacement(
+                description={'forward': bert_sequence_parallel_forward_fn(self.shard_config)},
+                policy=policy,
+                target_key=BertModel)
+
         # optimization configuration
         if self.shard_config.enable_fused_normalization:
             # Handle bert layer
@@ -205,7 +219,13 @@ class BertPolicy(Policy):
         layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
         stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
-        method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
+        method_replacement = {
+            'forward':
+                partial(new_forward,
+                        stage_manager=stage_manager,
+                        stage_index=stage_index,
+                        shard_config=self.shard_config)
+        }
         self.append_or_create_method_replacement(description=method_replacement,
                                                  policy=policy,
                                                  target_key=model_cls)
......
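For context, the policy hunk above registers `bert_sequence_parallel_forward_fn(self.shard_config)` as `BertModel.forward` whenever sequence parallelism is enabled. The sketch below shows only the general split-then-gather pattern such a forward follows; it is not this PR's actual function body, and the two helper names are assumed to be the shardformer `_operation` utilities rather than confirmed by this diff:

```python
import torch

# Assumption: these helpers live in shardformer's _operation module with
# (input_, dim, process_group) signatures; treat the import as illustrative.
from colossalai.shardformer.layer._operation import (
    gather_forward_split_backward,
    split_forward_gather_backward,
)


def sequence_parallel_encode(embedding_output: torch.Tensor, encoder, process_group):
    """Illustrative pattern: shard [batch, seq, hidden] activations on dim 1."""
    # Each rank processes seq_len / world_size tokens through the encoder;
    # the backward of this op gathers gradients back to the full sequence.
    hidden = split_forward_gather_backward(embedding_output, dim=1, process_group=process_group)
    hidden = encoder(hidden)  # `encoder` stands in for the sharded BertEncoder
    # Restore the full sequence before any module that needs every token
    # (e.g. the pooler); the backward of this op re-splits gradients.
    return gather_forward_split_backward(hidden, dim=1, process_group=process_group)
```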