Unverified Commit a27e0bb4 authored by flybird11111's avatar flybird11111 Committed by GitHub
Browse files

[shardformer] bert support sequence parallel. (#4455)

* [shardformer] bert support sequence parallel

[shardformer] bert support sequence parallel

[shardformer] bert support sequence parallel

[shardformer] bert support sequence parallel

[shardformer] bert support sequence parallel

[shardformer] bert support sequence parallel

[shardformer] bert support sequence parallel

[shardformer] bert support sequence parallel

[shardformer] bert support sequence parallel

* [shardformer] bert support sequence parallel

[shardformer] bert support sequence parallel

[shardformer] bert support sequence parallel

* [shardformer] bert support sequence parallel
parent 0ecd71e0
......@@ -154,7 +154,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap):
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True):
ctx.save_for_backward(input_, weight)
ctx.use_bias = bias is not None
ctx.process_group = process_group
......@@ -217,9 +217,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
# do all gather in default stream
input_ = input_.contiguous()
world_size = dist.get_world_size(process_group)
rank = dist.get_rank(process_group)
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_
gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
# calculate gradient in calculate_stream
......@@ -469,9 +467,7 @@ def _gather(input_, dim=-1, process_group=None):
# all gather
input_ = input_.contiguous()
rank = dist.get_rank(process_group)
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_
torch.distributed.all_gather(tensor_list, input_, group=process_group)
# concat
......
This diff is collapsed.
......@@ -10,6 +10,7 @@ import colossalai.shardformer.layer as col_nn
from .._utils import getattr_, setattr_
from ..modeling.bert import (
BertPipelineForwards,
bert_sequence_parallel_forward_fn,
get_bert_flash_attention_forward,
get_jit_fused_bert_output_forward,
get_jit_fused_bert_self_output_forward,
......@@ -47,13 +48,14 @@ class BertPolicy(Policy):
from transformers.models.bert.modeling_bert import (
BertEmbeddings,
BertLayer,
BertModel,
BertOutput,
BertSelfAttention,
BertSelfOutput,
)
policy = {}
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
if self.shard_config.enable_tensor_parallelism:
policy[BertLayer] = ModulePolicyDescription(attribute_replacement={
"attention.self.all_head_size":
......@@ -69,14 +71,17 @@ class BertPolicy(Policy):
SubModuleReplacementDescription(
suffix="attention.self.query",
target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="attention.self.key",
target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="attention.self.value",
target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="attention.self.dropout",
......@@ -85,6 +90,7 @@ class BertPolicy(Policy):
SubModuleReplacementDescription(
suffix="attention.output.dense",
target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="attention.output.dropout",
......@@ -93,10 +99,12 @@ class BertPolicy(Policy):
SubModuleReplacementDescription(
suffix="intermediate.dense",
target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="output.dense",
target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="output.dropout",
......@@ -115,6 +123,12 @@ class BertPolicy(Policy):
)
])
if use_sequence_parallel:
self.append_or_create_method_replacement(
description={'forward': bert_sequence_parallel_forward_fn(self.shard_config)},
policy=policy,
target_key=BertModel)
# optimization configuration
if self.shard_config.enable_fused_normalization:
# Handle bert layer
......@@ -205,7 +219,13 @@ class BertPolicy(Policy):
layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
method_replacement = {
'forward':
partial(new_forward,
stage_manager=stage_manager,
stage_index=stage_index,
shard_config=self.shard_config)
}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=model_cls)
......
Markdown is supported
0% — Attach a file by drag & drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment