Unverified commit efba0f44 authored by Hongxin Liu, committed by GitHub

Merge pull request #4612 from hpcaitech/feature/shardformer

[shardformer] update hybrid parallel plugin and fix bugs
parents ac178ca5 fae6c92e
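The changes below wire sequence parallelism (sequence-dimension activation sharding inside tensor-parallel groups) through the shardformer layers, the model-specific forwards, and the policies. As a rough usage sketch, assuming the hybrid parallel plugin updated in this PR exposes the two ShardConfig switches referenced by the policy code further down (enable_sequence_parallelism, enable_sequence_overlap; exact argument names may differ in the released API):

# Illustrative only: turning on the sequence-parallel path added in this PR.
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch(config={})
plugin = HybridParallelPlugin(
    tp_size=2,                          # tensor-parallel group size
    pp_size=1,                          # no pipeline parallelism in this sketch
    enable_sequence_parallelism=True,   # shard activations along the sequence dim
    enable_sequence_overlap=True,       # overlap the input all-gather with grad computation
)
booster = Booster(plugin=plugin)
# model, optimizer, ... = booster.boost(model, optimizer, ...)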
@@ -24,6 +24,8 @@ from colossalai.tensor.d_tensor.api import (
from ._operation import (
gather_forward_split_backward,
linear_gather_forward_reducescatter_backward,
linear_reducescatter_forward_gather_backward,
linear_with_async_comm,
reduce_forward,
split_forward_gather_backward,
@@ -50,6 +52,8 @@ class Linear1D_Col(ParallelModule):
gather_output (bool, optional): If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output
which is :math:`Y_i = XA_i`, defaults to False
seq_parallel (`bool`): If set to ``True``, it will use sequence parallelism, defaults to False.
overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False.
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (`typing.Callable`):
@@ -69,6 +73,9 @@ class Linear1D_Col(ParallelModule):
device: torch.device = None,
process_group: ProcessGroup = None,
gather_output: bool = False,
seq_parallel: bool = False,
seq_parallel_dim: int = 1,
overlap: bool = False,
skip_bias_add: bool = False,
weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None,
@@ -80,6 +87,9 @@ class Linear1D_Col(ParallelModule):
self.in_features = in_features
self.out_features = out_features
self.gather_output = gather_output
self.seq_parallel = seq_parallel
self.seq_parallel_dim = seq_parallel_dim
self.overlap = overlap
self.skip_bias_add = skip_bias_add
self.device = device
self.process_group = process_group
@@ -180,6 +190,11 @@ class Linear1D_Col(ParallelModule):
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
if self.seq_parallel:
output_parallel = linear_gather_forward_reducescatter_backward(input_parallel, self.weight, bias,
self.process_group, True,
self.seq_parallel_dim, self.overlap)
else:
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
if self.gather_output:
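For reference, a toy sketch (not the repo's implementation) of the communication pattern behind linear_gather_forward_reducescatter_backward used above: the sequence-sharded input is all-gathered in forward so the column-parallel matmul sees the full sequence, and the input gradient is reduce-scattered in backward. It needs an initialized torch.distributed process group to actually run.

import torch
import torch.distributed as dist

class GatherForwardReduceScatterBackward(torch.autograd.Function):

    @staticmethod
    def forward(ctx, local_input, process_group, dim):
        ctx.group, ctx.dim = process_group, dim
        world_size = dist.get_world_size(process_group)
        pieces = [torch.empty_like(local_input) for _ in range(world_size)]
        dist.all_gather(pieces, local_input.contiguous(), group=process_group)
        return torch.cat(pieces, dim=dim)    # full-sequence activation for the matmul

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.group)
        chunks = [c.contiguous() for c in grad_output.chunk(world_size, dim=ctx.dim)]
        grad_input = torch.empty_like(chunks[0])
        dist.reduce_scatter(grad_input, chunks, group=ctx.group)    # sum partial grads, keep local shard
        return grad_input, None, None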
@@ -203,6 +218,8 @@ class Linear1D_Row(ParallelModule):
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False.
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
seq_parallel (`bool`): If set to ``True``, it will use sequence parallelism, defaults to False.
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (:class:`typing.Callable`, optional):
@@ -221,6 +238,8 @@ class Linear1D_Row(ParallelModule):
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
seq_parallel: bool = False,
seq_parallel_dim: int = 1,
parallel_input: bool = True,
skip_bias_add: bool = False,
weight: Optional[Parameter] = None,
@@ -238,6 +257,8 @@ class Linear1D_Row(ParallelModule):
self.parallel_input = parallel_input
self.skip_bias_add = skip_bias_add
self.process_group = process_group
self.seq_parallel = seq_parallel
self.seq_parallel_dim = seq_parallel_dim
self.num_partitions = dist.get_world_size(self.process_group)
if skip_bias_add and not bias:
@@ -373,6 +394,10 @@ class Linear1D_Row(ParallelModule):
output = torch.cat(output_parallel_list, dim=-1)
else:
output_parallel = F.linear(input_, self.weight)
if self.seq_parallel:
output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group,
self.seq_parallel_dim)
else:
output = reduce_forward(output_parallel, self.process_group)
if not self.skip_bias_add:
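A quick shape sketch of what the seq_parallel branch above changes (illustrative, no distributed setup needed): reduce_forward all-reduces the partial row-parallel results into a full [B, S, H] output on every rank, whereas the reduce-scatter path leaves each rank with only its [B, S/tp, H] slice of the sequence.

import torch
tp_size, rank = 4, 0
B, S, H = 2, 16, 8
partials = [torch.randn(B, S, H) for _ in range(tp_size)]      # per-rank partial results of X_i @ A_i
all_reduced = sum(partials)                                    # what reduce_forward yields on every rank
reduce_scattered = sum(partials).chunk(tp_size, dim=1)[rank]   # what reduce-scatter keeps locally
assert all_reduced.shape == (B, S, H)
assert reduce_scattered.shape == (B, S // tp_size, H)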
...
@@ -10,6 +10,7 @@ import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, Module
from colossalai.checkpoint_io.utils import gather_distributed_param
from colossalai.tensor.d_tensor import (
distribute_tensor,
distribute_tensor_with_customization,
@@ -56,13 +57,7 @@ class ParallelModule(nn.Module, ABC):
"""
for name, param in self._parameters.items():
if param is not None:
param_ = param if keep_vars else param.detach()
if is_distributed_tensor(param_):
destination[prefix + name] = to_global(param_)
elif is_customized_distributed_tensor(param_):
destination[prefix + name] = to_global_for_customized_distributed_tensor(param_)
else:
destination[prefix + name] = param_
destination[prefix + name] = gather_distributed_param(param, keep_vars=keep_vars)
for name, buf in self._buffers.items():
if buf is not None and name not in self._non_persistent_buffers_set:
...
@@ -25,7 +25,9 @@ from colossalai.tensor.d_tensor.api import (
from ._operation import (
gather_forward_split_backward,
linear_reducescatter_forward_gather_backward,
linear_with_async_comm,
matmul_gather_forward_reducescatter_backward,
matmul_with_async_comm,
reduce_backward,
reduce_forward,
@@ -150,6 +152,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
device (`torch.device`): The device of parameters, defaults to None.
n_fused (int): The number of items fused, defaults to 3 (QKV).
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
seq_parallel (`bool`): If set to ``True``, it will use sequence parallelism, defaults to False.
gather_output (bool, optional): If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output
which is :math:`Y_i = XA_i`, defaults to False
@@ -173,6 +176,8 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
process_group: ProcessGroup = None,
async_communication: bool = False,
gather_output: bool = False,
seq_parallel: bool = False,
overlap: bool = False,
skip_bias_add: bool = False,
n_fused: int = 3,
weight: Optional[Parameter] = None,
@@ -185,6 +190,8 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
self.in_features = in_features
self.out_features = out_features
self.gather_output = gather_output
self.seq_parallel = seq_parallel
self.overlap = overlap
self.skip_bias_add = skip_bias_add
self.device = device
self.n_fused = n_fused
@@ -296,13 +303,17 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
assert input_.shape[-1] == self.weight.shape[0], \
'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1])
# Set up backprop all-reduce.
input_parallel = reduce_backward(input_, self.process_group)
# input_parallel = input_
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
if self.seq_parallel:
input_parallel = input_
output_parallel = matmul_gather_forward_reducescatter_backward(input_parallel, self.weight, bias,
self.process_group, True, 1, self.overlap)
else:
# Set up backprop all-reduce.
input_parallel = reduce_backward(input_, self.process_group)
output_parallel = matmul_with_async_comm(input_parallel, self.weight, bias, self.process_group,
self.async_communication)
@@ -329,6 +340,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False.
seq_parallel (`bool`): If set to ``True``, it will use sequence parallelism, defaults to False.
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (:class:`typing.Callable`, optional):
The initializer of weight, defaults to kaiming uniform initializer.
@@ -346,6 +358,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
seq_parallel: bool = False,
parallel_input: bool = True,
skip_bias_add: bool = False,
weight: Optional[Parameter] = None,
@@ -363,6 +376,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
self.parallel_input = parallel_input
self.skip_bias_add = skip_bias_add
self.process_group = process_group
self.seq_parallel = seq_parallel
self.num_partitions = dist.get_world_size(self.process_group)
if skip_bias_add and not bias:
@@ -499,6 +513,9 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
output = torch.cat(output_parallel_list, dim=-1)
else:
output_parallel = torch.matmul(input_, self.weight)
if self.seq_parallel:
output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, 1)
else:
output = reduce_forward(output_parallel, self.process_group)
if not self.skip_bias_add:
...
import math
import warnings
from typing import Dict, List, Optional, Tuple, Union
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -29,6 +29,8 @@ from transformers.models.bert.modeling_bert import (
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
class BertPipelineForwards:
@@ -56,6 +58,7 @@ class BertPipelineForwards:
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
):
# TODO(jianghai): add explanation of the output here.
r"""
@@ -177,6 +180,17 @@ class BertPipelineForwards:
start_idx, end_idx = stage_index[0], stage_index[1]
# layer_outputs
layer_outputs = hidden_states if hidden_states is not None else None
# split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
if shard_config is not None and shard_config.enable_sequence_parallelism:
hidden_states = split_forward_gather_backward(hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group)
if encoder_hidden_states is not None:
encoder_hidden_states = split_forward_gather_backward(
encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group)
for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx):
if stage_manager.is_first_stage() and idx == 0:
encoder_attention_mask = encoder_extended_attention_mask
@@ -223,11 +237,17 @@ class BertPipelineForwards:
all_cross_attentions = all_cross_attentions + \
(layer_outputs[2],)
# When sequence parallelism is done, gather the output tensor in forward and split it in backward
if shard_config is not None and shard_config.enable_sequence_parallelism:
hidden_states = gather_forward_split_backward(hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
# end of a stage loop
sequence_output = layer_outputs[0] if layer_outputs is not None else None
sequence_output = hidden_states if hidden_states is not None else None
if stage_manager.is_last_stage():
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
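The split/gather calls added to this forward are the model-level entry and exit points of sequence parallelism. A toy sketch (not the library code) of the two resharding primitives used here: split_forward_gather_backward keeps only this rank's sequence shard in forward and all-gathers gradients in backward; gather_forward_split_backward is the mirror image applied after the encoder layers.

import torch
import torch.distributed as dist

class SplitForwardGatherBackward(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, dim, process_group):
        ctx.dim, ctx.group = dim, process_group
        rank = dist.get_rank(process_group)
        world_size = dist.get_world_size(process_group)
        return x.chunk(world_size, dim=dim)[rank].contiguous()    # keep this rank's sequence shard

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.group)
        pieces = [torch.empty_like(grad_output) for _ in range(world_size)]
        dist.all_gather(pieces, grad_output.contiguous(), group=ctx.group)
        return torch.cat(pieces, dim=ctx.dim), None, None         # reassemble the full-sequence gradient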
@@ -268,6 +288,7 @@ class BertPipelineForwards:
hidden_states: Optional[torch.FloatTensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
):
logger = logging.get_logger(__name__)
@@ -294,6 +315,7 @@ class BertPipelineForwards:
stage_manager=stage_manager,
hidden_states=hidden_states if hidden_states is not None else None,
stage_index=stage_index,
shard_config=shard_config,
)
past_key_values = None
all_hidden_states = None
@@ -350,6 +372,7 @@ class BertPipelineForwards:
hidden_states: Optional[torch.FloatTensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
):
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -404,7 +427,8 @@ class BertPipelineForwards:
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states if hidden_states is not None else None,
stage_index=stage_index,
shard_config=shard_config)
past_key_values = None
all_hidden_states = None
all_self_attentions = None
@@ -457,6 +481,7 @@ class BertPipelineForwards:
hidden_states: Optional[torch.Tensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -491,6 +516,7 @@ class BertPipelineForwards:
hidden_states=hidden_states,
stage_manager=stage_manager,
stage_index=stage_index,
shard_config=shard_config,
)
if stage_manager.is_last_stage():
@@ -532,6 +558,7 @@ class BertPipelineForwards:
hidden_states: Optional[torch.Tensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
**kwargs,
):
# -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
@@ -594,7 +621,8 @@ class BertPipelineForwards:
return_dict=return_dict,
hidden_states=hidden_states,
stage_manager=stage_manager,
stage_index=stage_index,
shard_config=shard_config)
if stage_manager.is_last_stage():
pooled_output = outputs[1]
@@ -636,6 +664,7 @@ class BertPipelineForwards:
hidden_states: Optional[torch.Tensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -666,7 +695,8 @@ class BertPipelineForwards:
return_dict=return_dict,
hidden_states=hidden_states,
stage_manager=stage_manager,
stage_index=stage_index,
shard_config=shard_config)
if stage_manager.is_last_stage():
pooled_output = outputs[1]
@@ -726,6 +756,7 @@ class BertPipelineForwards:
hidden_states: Optional[torch.Tensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -742,8 +773,7 @@ class BertPipelineForwards:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
outputs = BertPipelineForwards.bert_model_forward(self.bert,
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
@@ -756,7 +786,7 @@ class BertPipelineForwards:
hidden_states=hidden_states,
stage_manager=stage_manager,
stage_index=stage_index,
shard_config=shard_config)
if stage_manager.is_last_stage():
sequence_output = outputs[0]
@@ -799,6 +829,7 @@ class BertPipelineForwards:
hidden_states: Optional[torch.Tensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -843,6 +874,7 @@ class BertPipelineForwards:
hidden_states=hidden_states,
stage_manager=stage_manager,
stage_index=stage_index,
shard_config=shard_config,
)
if stage_manager.is_last_stage():
pooled_output = outputs[1]
@@ -886,6 +918,7 @@ class BertPipelineForwards:
hidden_states: Optional[torch.Tensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
):
# NOTE: the arg start_position and end_position are used only for the last stage
r"""
@@ -909,8 +942,7 @@ class BertPipelineForwards:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
outputs = BertPipelineForwards.bert_model_forward(self.bert,
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
@@ -923,7 +955,7 @@ class BertPipelineForwards:
hidden_states=hidden_states,
stage_manager=stage_manager,
stage_index=stage_index,
shard_config=shard_config)
if stage_manager.is_last_stage():
sequence_output = outputs[0]
@@ -1101,3 +1133,153 @@ def get_jit_fused_bert_output_forward():
return hidden_states
return forward
def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.config.is_decoder:
use_cache = use_cache if use_cache is not None else self.config.use_cache
else:
use_cache = False
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
batch_size, seq_length = input_shape
device = input_ids.device if input_ids is not None else inputs_embeds.device
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
if token_type_ids is None:
if hasattr(self.embeddings, "token_type_ids"):
buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
token_type_ids = buffered_token_type_ids_expanded
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.is_decoder and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)
# split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
embedding_output = split_forward_gather_backward(embedding_output,
dim=1,
process_group=shard_config.tensor_parallel_process_group)
if encoder_hidden_states is not None:
encoder_hidden_states = split_forward_gather_backward(
encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = encoder_outputs[0]
# When sequence parallelism is done, gather the output tensor in forward and split it in backward
sequence_output = gather_forward_split_backward(sequence_output,
dim=1,
process_group=shard_config.tensor_parallel_process_group)
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
past_key_values=encoder_outputs.past_key_values,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
)
return forward
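A hypothetical illustration of how the forward returned by bert_sequence_parallel_forward_fn could be attached to a model instance; in this PR the actual wiring happens through the shardformer policy's method-replacement mechanism rather than manual binding, and shard_config here is assumed to come from the plugin.

import types
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
seq_parallel_forward = bert_sequence_parallel_forward_fn(shard_config)   # shard_config provided elsewhere
model.forward = types.MethodType(seq_parallel_forward, model)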
@@ -23,6 +23,10 @@ from transformers.models.bloom.modeling_bloom import (
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
from colossalai.shardformer.shard import ShardConfig
logger = logging.get_logger(__name__)
def build_bloom_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor:
@@ -111,6 +115,7 @@ class BloomPipelineForwards:
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
**deprecated_arguments,
) -> Union[Tuple[torch.Tensor, ...], 'BaseModelOutputWithPastAndCrossAttentions']:
@@ -205,6 +210,13 @@ class BloomPipelineForwards:
past_key_values_length=past_key_values_length,
)
# split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
if shard_config.enable_sequence_parallelism:
hidden_states = split_forward_gather_backward(hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group)
start_idx, end_idx = stage_index[0], stage_index[1]
for i, (block, layer_past) in enumerate(zip(self.h[start_idx:end_idx], past_key_values[start_idx:end_idx]),
start=start_idx):
@@ -248,6 +260,12 @@ class BloomPipelineForwards:
all_self_attentions = all_self_attentions + \
(outputs[2 if use_cache else 1],)
# When sequence parallelism is done, gather the output tensor in forward and split it in backward
if shard_config.enable_sequence_parallelism:
hidden_states = gather_forward_split_backward(hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group)
if stage_manager.is_last_stage():
# Add last hidden state
hidden_states = self.ln_f(hidden_states)
@@ -287,6 +305,7 @@ class BloomPipelineForwards:
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
**deprecated_arguments):
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -327,7 +346,8 @@ class BloomPipelineForwards:
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config)
past_key_values = None
all_hidden_states = None
all_self_attentions = None
@@ -380,6 +400,7 @@ class BloomPipelineForwards:
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
**deprecated_arguments,
):
r"""
@@ -424,6 +445,7 @@ class BloomPipelineForwards:
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
)
past_key_values = None
all_hidden_states = None
@@ -503,6 +525,7 @@ class BloomPipelineForwards:
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
**deprecated_arguments,
):
r"""
@@ -547,6 +570,7 @@ class BloomPipelineForwards:
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
)
past_key_values = None
all_hidden_states = None
@@ -597,6 +621,7 @@ class BloomPipelineForwards:
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
):
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -632,6 +657,7 @@ class BloomPipelineForwards:
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
)
past_key_values = None
all_hidden_states = None
@@ -700,8 +726,7 @@ def get_bloom_flash_attention_forward(enabel_jit_fused=False):
fused_qkv = self.query_key_value(hidden_states)
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
batch_size, tgt_len, _ = hidden_states.size()
batch_size, tgt_len, _ = query_layer.size()
assert tgt_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
_, kv_length, _, _ = key_layer.size()
@@ -896,3 +921,156 @@ def get_jit_fused_bloom_gelu_forward():
return self.bloom_gelu_forward(x, bias)
return forward
def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
from transformers import BloomModel
def forward(
self: BloomModel,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**deprecated_arguments,
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
if deprecated_arguments.pop("position_ids", False) is not False:
# `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
warnings.warn(
"`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
" passing `position_ids`.",
FutureWarning,
)
if len(deprecated_arguments) > 0:
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if past_key_values is None:
past_key_values = tuple([None] * len(self.h))
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape batch_size x num_heads x N x N
# head_mask has shape n_layer x batch x num_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
use_cache = False
# Compute alibi tensor: check build_alibi_tensor documentation
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values[0] is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
else:
attention_mask = attention_mask.to(hidden_states.device)
alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
causal_mask = self._prepare_attn_mask(
attention_mask,
input_shape=(batch_size, seq_length),
past_key_values_length=past_key_values_length,
)
# split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
hidden_states = split_forward_gather_backward(hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
alibi,
causal_mask,
layer_past,
head_mask[i],
)
else:
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=causal_mask,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
alibi=alibi,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
# When sequence parallelism is done, gather the output tensor in forward and split it in backward
hidden_states = gather_forward_split_backward(hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group)
# Add last hidden state
hidden_states = self.ln_f(hidden_states)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
return forward
@@ -9,6 +9,8 @@ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutpu
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
ChatGLMForConditionalGeneration,
@@ -146,6 +148,7 @@ class ChatGLMPipelineForwards:
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
):
logger = logging.get_logger(__name__)
output_hidden_states = (output_hidden_states
@@ -198,6 +201,11 @@ class ChatGLMPipelineForwards:
all_self_attentions = None
all_hidden_states = () if output_hidden_states else None
start_idx, end_idx = stage_index[0], stage_index[1]
if shard_config.enable_sequence_parallelism:
hidden_states = split_forward_gather_backward(hidden_states,
dim=0,
process_group=shard_config.tensor_parallel_process_group)
for idx in range(start_idx, end_idx):
layer = self.encoder._get_layer(idx)
if output_hidden_states:
@@ -214,6 +222,11 @@ class ChatGLMPipelineForwards:
hidden_states, kv_cache = layer_ret
if use_cache:
presents = presents + (kv_cache,)
if shard_config.enable_sequence_parallelism:
hidden_states = gather_forward_split_backward(hidden_states,
dim=0,
process_group=shard_config.tensor_parallel_process_group)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if stage_manager.is_last_stage():
@@ -233,8 +246,7 @@ class ChatGLMPipelineForwards:
return {'hidden_states': hidden_states}
@staticmethod
def chatglm_for_conditional_generation_forward(self: ChatGLMForConditionalGeneration,
input_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
@@ -249,7 +261,7 @@ class ChatGLMPipelineForwards:
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None):
logger = logging.get_logger(__name__)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
@@ -266,6 +278,7 @@ class ChatGLMPipelineForwards:
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
)
if stage_manager.is_last_stage():
hidden_states = transformer_outputs[0]
@@ -296,3 +309,91 @@ class ChatGLMPipelineForwards:
)
else:
return transformer_outputs
def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig):
def forward(
self,
input_ids,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
full_attention_mask: Optional[torch.BoolTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
batch_size, seq_length = input_ids.shape
if inputs_embeds is None:
inputs_embeds = self.embedding(input_ids)
if self.pre_seq_len is not None:
if past_key_values is None:
past_key_values = self.get_prompt(
batch_size=batch_size,
device=input_ids.device,
dtype=inputs_embeds.dtype,
)
if attention_mask is not None:
attention_mask = torch.cat(
[
attention_mask.new_ones((batch_size, self.pre_seq_len)),
attention_mask,
],
dim=-1,
)
if full_attention_mask is None:
if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
# Rotary positional embeddings
rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
if position_ids is not None:
rotary_pos_emb = rotary_pos_emb[position_ids]
else:
rotary_pos_emb = rotary_pos_emb[None, :seq_length]
rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
# Run encoder.
# [seq_len, batch_size, hidden_size] -> [seq_len/TP_size, batch_size, hidden_size]
inputs_embeds = split_forward_gather_backward(inputs_embeds,
dim=0,
process_group=shard_config.tensor_parallel_process_group)
hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
inputs_embeds,
full_attention_mask,
rotary_pos_emb=rotary_pos_emb,
kv_caches=past_key_values,
use_cache=use_cache,
output_hidden_states=output_hidden_states,
)
hidden_states = gather_forward_split_backward(hidden_states,
dim=0,
process_group=shard_config.tensor_parallel_process_group)
if not return_dict:
return tuple(v for v in [
hidden_states,
presents,
all_hidden_states,
all_self_attentions,
] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
return forward
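One layout detail worth noting: the ChatGLM2 forward above splits and gathers along dim=0 because its encoder works on [seq_len, batch_size, hidden_size] activations, whereas the BERT and BLOOM forwards split dim=1 of [batch_size, seq_len, hidden_size]. A tiny illustrative check:

import torch
tp_size, rank = 4, 0
seq_len, batch_size, hidden_size = 32, 2, 8
x = torch.randn(seq_len, batch_size, hidden_size)          # ChatGLM2-style layout
local = x.chunk(tp_size, dim=0)[rank]                      # the shard kept by split_forward_gather_backward
assert local.shape == (seq_len // tp_size, batch_size, hidden_size)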
...@@ -125,9 +125,9 @@ _POLICY_LIST = { ...@@ -125,9 +125,9 @@ _POLICY_LIST = {
# ChatGLM # ChatGLM
"colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel": "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel":
PolicyLocation(file_name="chatglm", class_name="ChatGLMModelPolicy"), PolicyLocation(file_name="chatglm2", class_name="ChatGLMModelPolicy"),
"colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration":
PolicyLocation(file_name="chatglm", class_name="ChatGLMForConditionalGenerationPolicy"), PolicyLocation(file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"),
} }
......
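The registry entries above now point at the renamed policy file (chatglm2). As a hedged sketch of how a PolicyLocation is presumably resolved at runtime (module path and helper name are illustrative, not taken from this PR):

import importlib

def load_policy(policy_location):
    # illustrative: import the policy file named in the registry and instantiate the class
    module = importlib.import_module(f"colossalai.shardformer.policies.{policy_location.file_name}")
    return getattr(module, policy_location.class_name)()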
...@@ -11,17 +11,12 @@ from torch.nn import Module ...@@ -11,17 +11,12 @@ from torch.nn import Module
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from ..layer.parallel_module import ParallelModule
from ..shard.shard_config import ShardConfig from ..shard.shard_config import ShardConfig
__all__ = ["ParallelModule", "SubModuleReplacementDescription", "ModulePolicyDescription", "Policy"] __all__ = ["ParallelModule", "SubModuleReplacementDescription", "ModulePolicyDescription", "Policy"]
class ParallelModule():
def __init__(self):
pass
@dataclass @dataclass
class SubModuleReplacementDescription: class SubModuleReplacementDescription:
r""" r"""
......
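With the local placeholder class removed, policies import the real ParallelModule from the layer package. A rough sketch of the interface that import is assumed to provide (signature approximated for illustration only):

from abc import ABC, abstractmethod

import torch.nn as nn
from torch.distributed import ProcessGroup


class ParallelModule(nn.Module, ABC):

    @abstractmethod
    def from_native_module(module: nn.Module, process_group: ProcessGroup = None, **kwargs) -> "ParallelModule":
        # sketch: construct the parallel replacement layer from the original nn.Module
        ...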
...@@ -10,6 +10,7 @@ import colossalai.shardformer.layer as col_nn ...@@ -10,6 +10,7 @@ import colossalai.shardformer.layer as col_nn
from .._utils import getattr_, setattr_ from .._utils import getattr_, setattr_
from ..modeling.bert import ( from ..modeling.bert import (
BertPipelineForwards, BertPipelineForwards,
bert_sequence_parallel_forward_fn,
get_bert_flash_attention_forward, get_bert_flash_attention_forward,
get_jit_fused_bert_output_forward, get_jit_fused_bert_output_forward,
get_jit_fused_bert_self_output_forward, get_jit_fused_bert_self_output_forward,
...@@ -47,13 +48,15 @@ class BertPolicy(Policy): ...@@ -47,13 +48,15 @@ class BertPolicy(Policy):
from transformers.models.bert.modeling_bert import ( from transformers.models.bert.modeling_bert import (
BertEmbeddings, BertEmbeddings,
BertLayer, BertLayer,
BertModel,
BertOutput, BertOutput,
BertSelfAttention, BertSelfAttention,
BertSelfOutput, BertSelfOutput,
) )
policy = {} policy = {}
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
policy[BertLayer] = ModulePolicyDescription(attribute_replacement={ policy[BertLayer] = ModulePolicyDescription(attribute_replacement={
"attention.self.all_head_size": "attention.self.all_head_size":
...@@ -69,14 +72,26 @@ class BertPolicy(Policy): ...@@ -69,14 +72,26 @@ class BertPolicy(Policy):
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attention.self.query", suffix="attention.self.query",
target_module=col_nn.Linear1D_Col, target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel": use_sequence_parallel,
"overlap": overlap
},
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attention.self.key", suffix="attention.self.key",
target_module=col_nn.Linear1D_Col, target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel": use_sequence_parallel,
"overlap": overlap
},
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attention.self.value", suffix="attention.self.value",
target_module=col_nn.Linear1D_Col, target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel": use_sequence_parallel,
"overlap": overlap
},
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attention.self.dropout", suffix="attention.self.dropout",
...@@ -85,6 +100,7 @@ class BertPolicy(Policy): ...@@ -85,6 +100,7 @@ class BertPolicy(Policy):
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attention.output.dense", suffix="attention.output.dense",
target_module=col_nn.Linear1D_Row, target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel": use_sequence_parallel},
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attention.output.dropout", suffix="attention.output.dropout",
...@@ -93,10 +109,15 @@ class BertPolicy(Policy): ...@@ -93,10 +109,15 @@ class BertPolicy(Policy):
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="intermediate.dense", suffix="intermediate.dense",
target_module=col_nn.Linear1D_Col, target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel": use_sequence_parallel,
"overlap": overlap
},
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="output.dense", suffix="output.dense",
target_module=col_nn.Linear1D_Row, target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel": use_sequence_parallel},
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="output.dropout", suffix="output.dropout",
...@@ -115,6 +136,12 @@ class BertPolicy(Policy): ...@@ -115,6 +136,12 @@ class BertPolicy(Policy):
) )
]) ])
if use_sequence_parallel:
self.append_or_create_method_replacement(
description={'forward': bert_sequence_parallel_forward_fn(self.shard_config)},
policy=policy,
target_key=BertModel)
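For orientation, a minimal usage sketch of how these flags reach the policy: the names enable_sequence_parallelism and enable_sequence_overlap are taken from the code above, while the ShardFormer entry point and its return value are assumptions based on the surrounding API rather than something this hunk shows.

import torch.distributed as dist

from colossalai.shardformer import ShardConfig, ShardFormer    # entry point assumed

tp_group = dist.new_group()    # placeholder: the user's tensor-parallel process group

shard_config = ShardConfig(
    tensor_parallel_process_group=tp_group,
    enable_tensor_parallelism=True,
    enable_sequence_parallelism=True,    # read as use_sequence_parallel in BertPolicy
    enable_sequence_overlap=True,        # read as overlap in BertPolicy
)
# assumed to return the sharded model plus any shared parameters
sharded_model, shared_params = ShardFormer(shard_config=shard_config).optimize(bert_model)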
# optimization configuration # optimization configuration
if self.shard_config.enable_fused_normalization: if self.shard_config.enable_fused_normalization:
# Handle bert layer # Handle bert layer
...@@ -141,20 +168,26 @@ class BertPolicy(Policy): ...@@ -141,20 +168,26 @@ class BertPolicy(Policy):
# use flash attention # use flash attention
if self.shard_config.enable_flash_attention: if self.shard_config.enable_flash_attention:
policy[BertSelfAttention] = ModulePolicyDescription(method_replacement={ self.append_or_create_method_replacement(description={
'forward': get_bert_flash_attention_forward(), 'forward': get_bert_flash_attention_forward(),
}) },
policy=policy,
target_key=BertSelfAttention)
# use jit operator # use jit operator
if self.shard_config.enable_jit_fused: if self.shard_config.enable_jit_fused:
policy[BertSelfOutput] = ModulePolicyDescription(method_replacement={ self.append_or_create_method_replacement(description={
'forward': get_jit_fused_bert_self_output_forward(), 'forward': get_jit_fused_bert_self_output_forward(),
'dropout_add': get_jit_fused_dropout_add_func(), 'dropout_add': get_jit_fused_dropout_add_func(),
}) },
policy[BertOutput] = ModulePolicyDescription(method_replacement={ policy=policy,
target_key=BertSelfOutput)
self.append_or_create_method_replacement(description={
'forward': get_jit_fused_bert_output_forward(), 'forward': get_jit_fused_bert_output_forward(),
'dropout_add': get_jit_fused_dropout_add_func(), 'dropout_add': get_jit_fused_dropout_add_func(),
}) },
policy=policy,
target_key=BertOutput)
return policy return policy
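The flash-attention and JIT hunks above swap direct policy[cls] = ModulePolicyDescription(...) assignments for append_or_create_method_replacement, so a method replacement merges into whatever entry the tensor-parallel branch already created instead of overwriting it. A sketch of the assumed helper semantics (not copied from the library):

def append_or_create_method_replacement(self, description, policy, target_key):
    # merge into an existing policy entry if present, otherwise create a fresh one
    if target_key in policy:
        if policy[target_key].method_replacement is None:
            policy[target_key].method_replacement = {}
        policy[target_key].method_replacement.update(description)
    else:
        policy[target_key] = ModulePolicyDescription(method_replacement=description)
    return policy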
...@@ -205,7 +238,13 @@ class BertPolicy(Policy): ...@@ -205,7 +238,13 @@ class BertPolicy(Policy):
layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} method_replacement = {
'forward':
partial(new_forward,
stage_manager=stage_manager,
stage_index=stage_index,
shard_config=self.shard_config)
}
self.append_or_create_method_replacement(description=method_replacement, self.append_or_create_method_replacement(description=method_replacement,
policy=policy, policy=policy,
target_key=model_cls) target_key=model_cls)
......
...@@ -285,34 +285,30 @@ class BlipPolicy(Policy): ...@@ -285,34 +285,30 @@ class BlipPolicy(Policy):
# use flash attention # use flash attention
if self.shard_config.enable_flash_attention: if self.shard_config.enable_flash_attention:
policy[Blip2Attention] = ModulePolicyDescription(method_replacement={ self.append_or_create_method_replacement(description={
'forward': get_blip2_flash_attention_forward(), 'forward': get_blip2_flash_attention_forward(),
}) },
policy=policy,
target_key=Blip2Attention)
# use jit operator # use jit operator
if self.shard_config.enable_jit_fused: if self.shard_config.enable_jit_fused:
policy[Blip2QFormerSelfOutput] = ModulePolicyDescription( self.append_or_create_method_replacement(description={
method_replacement={
'forward': get_jit_fused_blip2_QFormer_self_output_forward(), 'forward': get_jit_fused_blip2_QFormer_self_output_forward(),
'dropout_add': get_jit_fused_dropout_add_func(), 'dropout_add': get_jit_fused_dropout_add_func(),
}) },
policy[Blip2QFormerOutput] = ModulePolicyDescription(method_replacement={ policy=policy,
target_key=Blip2QFormerSelfOutput)
self.append_or_create_method_replacement(description={
'forward': get_jit_fused_blip2_QFormer_output_forward(), 'forward': get_jit_fused_blip2_QFormer_output_forward(),
'dropout_add': get_jit_fused_dropout_add_func(), 'dropout_add': get_jit_fused_dropout_add_func(),
}) },
policy=policy,
target_key=Blip2QFormerOutput)
return policy return policy
def postprocess(self): def postprocess(self):
binding_map = {
'language_model.model.decoder.embed_tokens': 'language_model.lm_head',
}
for k, v in binding_map.items():
src_mod = getattr_(self.model, k)
dst_mod = getattr_(self.model, v)
dst_mod.weight = src_mod.weight
return self.model return self.model
......
...@@ -12,6 +12,7 @@ from ..modeling.bloom import ( ...@@ -12,6 +12,7 @@ from ..modeling.bloom import (
BloomPipelineForwards, BloomPipelineForwards,
build_bloom_alibi_tensor_fn, build_bloom_alibi_tensor_fn,
get_bloom_flash_attention_forward, get_bloom_flash_attention_forward,
get_bloom_sequence_parallel_forward_fn,
get_jit_fused_bloom_attention_forward, get_jit_fused_bloom_attention_forward,
get_jit_fused_bloom_gelu_forward, get_jit_fused_bloom_gelu_forward,
get_jit_fused_bloom_mlp_forward, get_jit_fused_bloom_mlp_forward,
...@@ -43,6 +44,8 @@ class BloomPolicy(Policy): ...@@ -43,6 +44,8 @@ class BloomPolicy(Policy):
policy = {} policy = {}
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={ policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={
"self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, "self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
...@@ -53,11 +56,14 @@ class BloomPolicy(Policy): ...@@ -53,11 +56,14 @@ class BloomPolicy(Policy):
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="self_attention.query_key_value", suffix="self_attention.query_key_value",
target_module=col_nn.Linear1D_Col, target_module=col_nn.Linear1D_Col,
), kwargs={
'seq_parallel': use_sequence_parallel,
'overlap': overlap
}),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="self_attention.dense", suffix="self_attention.dense",
target_module=col_nn.Linear1D_Row, target_module=col_nn.Linear1D_Row,
), kwargs={'seq_parallel': use_sequence_parallel}),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="self_attention.attention_dropout", suffix="self_attention.attention_dropout",
target_module=col_nn.DropoutForParallelInput, target_module=col_nn.DropoutForParallelInput,
...@@ -65,11 +71,14 @@ class BloomPolicy(Policy): ...@@ -65,11 +71,14 @@ class BloomPolicy(Policy):
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="mlp.dense_h_to_4h", suffix="mlp.dense_h_to_4h",
target_module=col_nn.Linear1D_Col, target_module=col_nn.Linear1D_Col,
), kwargs={
'seq_parallel': use_sequence_parallel,
'overlap': overlap
}),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="mlp.dense_4h_to_h", suffix="mlp.dense_4h_to_h",
target_module=col_nn.Linear1D_Row, target_module=col_nn.Linear1D_Row,
), kwargs={'seq_parallel': use_sequence_parallel}),
]) ])
policy[BloomModel] = ModulePolicyDescription( policy[BloomModel] = ModulePolicyDescription(
...@@ -116,26 +125,40 @@ class BloomPolicy(Policy): ...@@ -116,26 +125,40 @@ class BloomPolicy(Policy):
policy=policy, policy=policy,
target_key=BloomBlock) target_key=BloomBlock)
if use_sequence_parallel:
self.append_or_create_method_replacement(
description={'forward': get_bloom_sequence_parallel_forward_fn(self.shard_config)},
policy=policy,
target_key=BloomModel)
if self.shard_config.enable_flash_attention: if self.shard_config.enable_flash_attention:
policy[BloomAttention] = ModulePolicyDescription(method_replacement={ self.append_or_create_method_replacement(description={
'forward': get_bloom_flash_attention_forward(), 'forward': get_bloom_flash_attention_forward(),
'dropout_add': get_dropout_add_func() 'dropout_add': get_dropout_add_func(),
}) },
policy=policy,
target_key=BloomAttention)
# enable jit fused operator # enable jit fused operator
if self.shard_config.enable_jit_fused: if self.shard_config.enable_jit_fused:
policy[BloomAttention] = ModulePolicyDescription(method_replacement={ self.append_or_create_method_replacement(description={
'forward': get_jit_fused_bloom_attention_forward(), 'forward': get_jit_fused_bloom_attention_forward(),
'dropout_add': get_jit_fused_dropout_add_func(), 'dropout_add': get_jit_fused_dropout_add_func(),
}) },
policy[BloomMLP] = ModulePolicyDescription(method_replacement={ policy=policy,
target_key=BloomAttention)
self.append_or_create_method_replacement(description={
'forward': get_jit_fused_bloom_mlp_forward(), 'forward': get_jit_fused_bloom_mlp_forward(),
'dropout_add': get_jit_fused_dropout_add_func(), 'dropout_add': get_jit_fused_dropout_add_func(),
}) },
policy[BloomGelu] = ModulePolicyDescription(method_replacement={ policy=policy,
target_key=BloomMLP)
self.append_or_create_method_replacement(description={
'forward': get_jit_fused_bloom_gelu_forward(), 'forward': get_jit_fused_bloom_gelu_forward(),
'bloom_gelu_forward': get_jit_fused_gelu_forward_func(), 'bloom_gelu_forward': get_jit_fused_gelu_forward_func(),
}) },
policy=policy,
target_key=BloomGelu)
return policy return policy
...@@ -154,7 +177,13 @@ class BloomPolicy(Policy): ...@@ -154,7 +177,13 @@ class BloomPolicy(Policy):
layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} method_replacement = {
'forward':
partial(new_forward,
stage_manager=stage_manager,
stage_index=stage_index,
shard_config=self.shard_config)
}
self.append_or_create_method_replacement(description=method_replacement, self.append_or_create_method_replacement(description=method_replacement,
policy=policy, policy=policy,
target_key=model_cls) target_key=model_cls)
......