"vscode:/vscode.git/clone" did not exist on "bce9499ed33f7a8359bbb568c7ee18d72e8aa731"
Unverified Commit 7172459e authored by Wenhao Chen, committed by GitHub

[shardformer]: support gpt-j, falcon, Mistral and add interleaved pipeline for bert (#5088)



* [shardformer] implement policy for all GPT-J models and test

* [shardformer] support interleaved pipeline parallel for bert finetune

* [shardformer] shardformer support falcon (#4883)

* [shardformer]: fix interleaved pipeline for bert model (#5048)

* [hotfix]: disable seq parallel for gptj and falcon, and polish code (#5093)

* Add Mistral support for Shardformer (#5103)

* [shardformer] add tests to mistral (#5105)

---------
Co-authored-by: Pengtai Xu <henryxu880@gmail.com>
Co-authored-by: ppt0011 <143150326+ppt0011@users.noreply.github.com>
Co-authored-by: flybird11111 <1829166702@qq.com>
Co-authored-by: eric8607242 <e0928021388@gmail.com>
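
For context, a minimal usage sketch of the options this PR adds to HybridParallelPlugin (a sketch only: the tp_size/pp_size/num_microbatches values are illustrative and the Booster wiring follows the existing ColossalAI API):

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# pp_style and num_model_chunks are the new arguments; the rest already existed
# on HybridParallelPlugin. num_microbatches must be divisible by both
# num_model_chunks and pp_size (checked in InterleavedSchedule.load_batch).
plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=2,
    num_microbatches=8,
    pp_style="interleaved",   # "1f1b" (default) or "interleaved"
    num_model_chunks=2,       # must be > 1 when pp_style="interleaved"
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(...)

With pp_style left at its default "1f1b", behaviour is unchanged and OneForwardOneBackwardSchedule is used.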
parent 126cf180
@@ -79,7 +79,7 @@ jobs:
container:
image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
options: --gpus all --rm -v /data/scratch/examples-data:/data/
- timeout-minutes: 10
+ timeout-minutes: 20
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-run-example-${{ matrix.directory }}
cancel-in-progress: true
...
@@ -22,7 +22,7 @@ from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOpt
from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
- from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
+ from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.layer.utils import SeqParallelUtils
@@ -911,6 +911,8 @@ class HybridParallelPlugin(PipelinePluginBase):
communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
+ pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'.
+ num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1.
"""
def __init__(
@@ -946,6 +948,8 @@ class HybridParallelPlugin(PipelinePluginBase):
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = True,
custom_policy: Policy = None,
+ pp_style: str = "1f1b",
+ num_model_chunks: int = 1,
) -> None:
super().__init__()
assert (
@@ -972,17 +976,38 @@ class HybridParallelPlugin(PipelinePluginBase):
self.custom_policy = custom_policy
assert zero_stage in (0, 1, 2)
if self.pp_size > 1:
+ assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
+ assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
assert (
num_microbatches is not None or microbatch_size is not None
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
- self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
- self.schedule = OneForwardOneBackwardSchedule(
- self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size
+ self.stage_manager = PipelineStageManager(
+ self.pg_mesh,
+ pipeline_axis=PP_AXIS,
+ enable_interleave=pp_style == "interleaved",
+ num_model_chunks=num_model_chunks,
)
+ if pp_style == "interleaved":
+ assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
+ self.schedule = InterleavedSchedule(
+ stage_manager=self.stage_manager,
+ num_model_chunks=num_model_chunks,
+ num_microbatch=num_microbatches,
+ microbatch_size=microbatch_size,
+ )
+ elif pp_style == "1f1b":
+ self.schedule = OneForwardOneBackwardSchedule(
+ self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size
+ )
+ else:
+ raise NotImplementedError()
self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
pipeline_stage_manager=self.stage_manager,
...
+ from .colo_init_context import ColoInitContext, post_process_colo_init_ctx
from .ophooks import BaseOpHook, register_ophooks_recursively
from .stateful_tensor import StatefulTensor
from .stateful_tensor_mgr import StatefulTensorMgr
@@ -11,4 +12,6 @@ __all__ = [
"AutoTensorPlacementPolicy",
"register_ophooks_recursively",
"BaseOpHook",
+ "ColoInitContext",
+ "post_process_colo_init_ctx",
]
@@ -3,7 +3,7 @@ from typing import Any, Callable, Iterable, List, Optional, Union
import torch
import torch.cuda
- from torch.nn import Module
+ from torch.nn import Module, ModuleList
from torch.utils._pytree import tree_map
from colossalai.interface import OptimizerWrapper
@@ -16,18 +16,25 @@ from .base import PipelineSchedule
class InterleavedSchedule(PipelineSchedule):
- def __init__(self, num_microbatches: int, num_model_chunks: int, stage_manager: PipelineStageManager) -> None:
- self.num_model_chunks = num_model_chunks
- assert (
- num_microbatches % self.num_model_chunks == 0
- ), "Number of microbatches should be an integer multiple of number of model chunks"
+ def __init__(
+ self,
+ stage_manager: PipelineStageManager,
+ num_model_chunks: int,
+ num_microbatch: Optional[int] = None,
+ microbatch_size: Optional[int] = None,
+ ) -> None:
super().__init__(stage_manager)
+ assert (
+ num_microbatch is not None or microbatch_size is not None
+ ), "Either num_microbatch or microbatch_size should be provided"
self.comm = PipelineP2PCommunication(stage_manager)
- self.num_microbatches = num_microbatches
- self.batch: Optional[Any] = None
- self.batch_size: Optional[int] = None
- self.microbatch_offset: Optional[int] = None
- self.microbatch_size: Optional[int] = None
+ self.num_microbatch = num_microbatch
+ self.microbatch_size = microbatch_size
+ self.num_model_chunks = num_model_chunks
+ self.batch: Any
+ self.batch_size: int
+ self.microbatch_offset: List[int]
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
"""Load a batch from data iterator.
@@ -42,8 +49,22 @@ class InterleavedSchedule(PipelineSchedule):
self.batch = batch
self.batch_size = get_batch_size(batch)
self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
- assert self.batch_size % self.num_microbatches == 0, "Batch size should divided by the number of microbatches"
- self.microbatch_size = self.batch_size // self.num_microbatches
+ if self.num_microbatch is not None:
+ assert self.batch_size % self.num_microbatch == 0, "Batch size should divided by the number of microbatch"
+ self.microbatch_size = self.batch_size // self.num_microbatch
+ elif self.microbatch_size is not None:
+ assert self.batch_size % self.microbatch_size == 0, "Batch size should divided by the microbatch size"
+ self.num_microbatch = self.batch_size // self.microbatch_size
+ else:
+ raise ValueError("Either num_microbatch or microbatch_size should be provided")
+ assert (
+ self.num_microbatch % self.num_model_chunks == 0
+ ), "Number of microbatch should be an integer multiple of number of model chunks"
+ assert (
+ self.num_microbatch % self.stage_manager.num_stages == 0
+ ), "Number of microbatch should be an integer multiple of number of pipeline parallel devices"
def load_micro_batch(self, model_chunk_id: int) -> Any:
"""Load a micro batch from the current batch.
@@ -58,7 +79,7 @@ class InterleavedSchedule(PipelineSchedule):
self.microbatch_offset[model_chunk_id] += self.microbatch_size
return tree_map(partial(to_device, device=get_current_device()), micro_batch)
- def get_model_chunk_id(self, microbatch_id: int, forward: bool) -> int:
+ def get_model_chunk_id(self, microbatch_id: int, is_forward: bool) -> int:
"""Helper method to get the model chunk ID given the iteration number.
Args:
@@ -70,36 +91,10 @@ class InterleavedSchedule(PipelineSchedule):
"""
microbatch_id_in_group = (microbatch_id) % (self.stage_manager.num_stages * self.num_model_chunks)
model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages
- if not forward:
+ if not is_forward:
model_chunk_id = self.num_model_chunks - model_chunk_id - 1
return model_chunk_id
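# Worked example for the mapping above (not part of the diff): with
# num_stages=2 and num_model_chunks=2, forward microbatch ids 0..3 map to
# chunks [0, 0, 1, 1]; in the backward direction the chunk order within the
# group is reversed, giving [1, 1, 0, 0].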
- def is_first_stage(self, model_chunk_id: int) -> bool:
- """Is the current virtual stage the first stage
- Args:
- model_chunk_id (int): The current model chunk idx.
- Returns:
- bool: Whether the current virtual stage is the first stage.
- """
- if self.stage_manager.is_first_stage() and model_chunk_id == 0:
- return True
- return False
- def is_last_stage(self, model_chunk_id: int) -> bool:
- """Is the current virtual stage the last stage
- Args:
- model_chunk_id (int): The current model chunk idx.
- Returns:
- bool: Whether the current virtual stage is the last stage.
- """
- if self.stage_manager.is_last_stage() and model_chunk_id == self.num_model_chunks - 1:
- return True
- return False
def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Any:
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
For interleaved 1F1B.
@@ -111,7 +106,7 @@ class InterleavedSchedule(PipelineSchedule):
Returns:
Any: The input tensor or input tensor list.
"""
- if self.is_first_stage(model_chunk_id):
+ if self.stage_manager.is_first_stage(model_chunk_id):
input_tensor = None
else:
input_tensor = self.comm.recv_forward(prev_rank)
@@ -129,7 +124,7 @@ class InterleavedSchedule(PipelineSchedule):
Returns:
Any: The input gradient tensor or gradient tensor list.
"""
- if self.is_last_stage(model_chunk_id):
+ if self.stage_manager.is_last_stage(model_chunk_id):
output_tensor_grad = None
else:
output_tensor_grad = self.comm.recv_backward(next_rank)
@@ -145,7 +140,7 @@ class InterleavedSchedule(PipelineSchedule):
output_object (Any): Object to be sent.
next_rank (int, optional): The rank of the recipient of the tensor.
"""
- if not self.is_last_stage(model_chunk_id):
+ if not self.stage_manager.is_last_stage(model_chunk_id):
self.comm.send_forward(output_object, next_rank)
def send_backward(self, model_chunk_id, input_object: Any, prev_rank: int = None) -> None:
@@ -157,12 +152,12 @@ class InterleavedSchedule(PipelineSchedule):
input_object (Any): Object to be sent.
prev_rank (int, optional): The rank of the recipient of the tensor
"""
- if not self.is_first_stage(model_chunk_id):
+ if not self.stage_manager.is_first_stage(model_chunk_id):
self.comm.send_backward(input_object, prev_rank)
def forward_step(
self,
- model_chunk: Module,
+ model_chunk: Union[ModuleList, Module],
model_chunk_id: int,
input_obj: Optional[dict],
criterion: Callable,
@@ -171,7 +166,7 @@ class InterleavedSchedule(PipelineSchedule):
) -> Union[torch.Tensor, dict]:
"""Forward one step of the pipeline
Args:
- model (Module): Model Chunk to be run
+ model (ModuleList or Module): Model Chunk to be run
input_obj (Optional[dict]): The output from the previous stage. If it is the first stage, the `input_obj` is None.
criterion (Callable): Criterion to calculate loss.
accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
@@ -184,10 +179,19 @@ class InterleavedSchedule(PipelineSchedule):
# for the first stage, input_obj is None
# for the non-first stage, input_obj is the output of the previous stage and it's must be a dict
- output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj)
- if self.is_last_stage(model_chunk_id):
- loss = criterion(output_obj, micro_batch) / self.num_microbatches
+ self.stage_manager.model_chunk_id = model_chunk_id
+ if isinstance(model_chunk, ModuleList):
+ output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj)
+ else:
+ # NOTE: in shardformer, each device still has the entire model, so we need to use relevant stage layers
+ internal_inputs = {} if input_obj is None else input_obj
+ internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
+ output_obj = model_forward(model_chunk, micro_batch, internal_inputs)
+ self.stage_manager.model_chunk_id = None
+ if self.stage_manager.is_last_stage(model_chunk_id):
+ loss = criterion(output_obj, micro_batch) / self.num_microbatch
if accum_loss is not None:
accum_loss.add_(loss.detach())
if outputs is not None:
@@ -243,17 +247,17 @@ class InterleavedSchedule(PipelineSchedule):
def forward_backward_step(
self,
- model_chunk: Module,
+ model_chunk: Union[ModuleList, Module],
data_iter: Iterable,
criterion: Callable[..., Any],
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = False,
return_outputs: bool = False,
) -> dict:
- """Runs interleaved 1F1B schedule, with communication between pipeline stages.
+ """Runs interleaved schedule, with communication between pipeline stages.
Args:
- model_chunk (List[Module]): Model Chunk to be trained.
+ model_chunk (ModuleList or Module): Model Chunk to be trained. Original interleaved uses a module list whereas shardformer uses entire model + layer specification
data_iter (Iterable): Data iterator.
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
@@ -263,49 +267,46 @@ class InterleavedSchedule(PipelineSchedule):
Returns:
dict: A dict with keys: 'loss' and 'outputs'.
"""
+ # TODO: handle arbitrary batch size when forward_only == True
forward_only = not torch.is_grad_enabled()
if optimizer is None:
assert forward_only, "Optimizer should be passed when doing backward."
self.load_batch(data_iter)
- num_model_chunks = len(model_chunk)
- # num_warmup_microbatches is the step when not all the processes are working
- num_microbatches = self.num_microbatches * num_model_chunks
+ num_microbatch = self.num_microbatch * self.num_model_chunks
if forward_only:
- num_warmup_microbatches = num_microbatches
+ num_warmup_microbatch = num_microbatch
else:
- num_warmup_microbatches = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2
- num_warmup_microbatches += (num_model_chunks - 1) * self.stage_manager.num_stages
- num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
- num_microbatches_remaining = num_microbatches - num_warmup_microbatches
+ num_warmup_microbatch = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2
+ num_warmup_microbatch += (self.num_model_chunks - 1) * self.stage_manager.num_stages
+ num_warmup_microbatch = min(num_warmup_microbatch, num_microbatch)
+ num_microbatch_remaining = num_microbatch - num_warmup_microbatch
# Input, output tensors only need to be saved when doing backward passes
input_objs = None
output_objs = None
if not forward_only:
- input_objs = [[] for _ in range(num_model_chunks)]
- output_objs = [[] for _ in range(num_model_chunks)]
+ input_objs = [[] for _ in range(self.num_model_chunks)]
+ output_objs = [[] for _ in range(self.num_model_chunks)]
- outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
- if return_loss and self.stage_manager.is_last_stage():
+ outputs = [] if return_outputs and self.stage_manager.is_last_stage(-1) else None
+ if return_loss and self.stage_manager.is_last_stage(-1):
accum_loss = torch.zeros(1, device=get_current_device())
else:
accum_loss = None
# for ranks except the first one, get into recv state
- # print(self.stage_manager.stage,num_microbatches, num_warmup_microbatches, num_microbatches_remaining)
input_obj = self.recv_forward(0)
- input_objs[0].append(input_obj)
# Run warmup forward passes.
- for i in range(num_warmup_microbatches):
- model_chunk_id = self.get_model_chunk_id(i, forward=True)
- # recv first on first rank to avoid sending or recving at the same time
- if self.stage_manager.is_first_stage():
+ for i in range(num_warmup_microbatch):
+ model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
+ # recv first on first rank to avoid sending or receiving at the same time
+ if self.stage_manager.is_first_stage(-1):
input_obj = self.recv_forward(model_chunk_id)
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
self.send_forward(model_chunk_id, output_obj)
@@ -315,21 +316,20 @@ class InterleavedSchedule(PipelineSchedule):
else:
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
if not forward_only:
+ input_objs[model_chunk_id].append(input_obj)
output_objs[model_chunk_id].append(output_obj)
self.send_forward(model_chunk_id, output_obj)
- if num_microbatches_remaining == 0 and i + 1 == num_warmup_microbatches:
+ if num_microbatch_remaining == 0 and i + 1 == num_warmup_microbatch:
break
- else:
- model_chunk_id = self.get_model_chunk_id(i + 1, forward=True)
- input_obj = self.recv_forward(model_chunk_id)
- if not forward_only:
- input_objs[model_chunk_id].append(input_obj)
+ model_chunk_id = self.get_model_chunk_id(i + 1, is_forward=True)
+ input_obj = self.recv_forward(model_chunk_id)
# Run 1F1B in steady state.
- for i in range(num_microbatches_remaining):
- model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches, forward=True)
- last_iteration = i == (num_microbatches_remaining - 1)
+ for i in range(num_microbatch_remaining):
+ model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True)
+ last_iteration = i == num_microbatch_remaining - 1
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
if forward_only:
@@ -344,7 +344,7 @@ class InterleavedSchedule(PipelineSchedule):
input_objs[model_chunk_id].append(input_obj)
output_objs[model_chunk_id].append(output_obj)
- model_chunk_id = self.get_model_chunk_id(i, forward=False)
+ model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
output_obj_grad = self.recv_backward(model_chunk_id)
# Pop output_obj and output_obj from the start of the list for
@@ -358,23 +358,25 @@ class InterleavedSchedule(PipelineSchedule):
if last_iteration:
input_obj = None
else:
- model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches + 1, forward=True)
+ model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True)
input_obj = self.recv_forward(model_chunk_id)
- model_chunk_id = self.get_model_chunk_id(i, forward=False)
+ model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
self.send_backward(model_chunk_id, input_obj_grad)
# Run cooldown backward passes.
if not forward_only:
- for i in range(num_microbatches_remaining, num_microbatches):
- model_chunk_id = self.get_model_chunk_id(i, forward=False)
- # print(f"{self.stage_manager.stage}/{model_chunk_id}: {len(input_objs[model_chunk_id])} {len(output_objs[model_chunk_id])} {i}")
+ for i in range(num_microbatch_remaining, num_microbatch):
+ model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
input_obj = input_objs[model_chunk_id].pop(0)
output_obj = output_objs[model_chunk_id].pop(0)
output_obj_grad = self.recv_backward(model_chunk_id)
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
self.send_backward(model_chunk_id, input_obj_grad)
+ if not forward_only:
+ assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
if outputs is not None:
outputs = merge_batch(outputs)
return {"loss": accum_loss, "outputs": outputs}
@@ -19,7 +19,15 @@ class PipelineStageManager:
stage (int): The current stage.
"""
- def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int, is_virtual: bool = False) -> None:
+ def __init__(
+ self,
+ pg_mesh: ProcessGroupMesh,
+ pipeline_axis: int,
+ enable_interleave: bool = False,
+ num_model_chunks: int = 1,
+ ) -> None:
+ assert enable_interleave or num_model_chunks == 1, "num_model_chunks must be 1 when enable_interleave is False"
self.pg_mesh = pg_mesh
self.pipeline_axis = pipeline_axis
self.prev_rank: Optional[Tuple[int, ...]] = None
@@ -43,29 +51,62 @@ class PipelineStageManager:
ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
self.p2p_groups[tuple(ranks_in_group)] = group
- if is_virtual:
+ self.is_interleave = enable_interleave
+ if enable_interleave:
+ # use circle p2p communication
# add the process group of the first rank and the last rank
- # only used in interleaved pipeline for now
group = self.pg_mesh.get_group_along_axis(self.pipeline_axis, [stages[0], stages[-1]])
if self.stage in [stages[0], stages[-1]]:
ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
self.p2p_groups[tuple(ranks_in_group)] = group
- def is_first_stage(self) -> bool:
+ # for interleaved pipeline parallel, each device is responsible for multiple chunk of layers
+ self.num_model_chunks: int = num_model_chunks
+ # for shardformer, hold stage indices of model
+ self.stage_indices: List[Tuple[int, int]]
+ # for shardformer, hold model chunk id
+ self.model_chunk_id: Optional[int] = None
+ def is_first_stage(self, model_chunk_id: Optional[int] = None) -> bool:
"""Is the current stage the first stage.
+ NOTE:
+ 1. if using interleaved pipeline parallel, the first stage is the first chunk of the first device.
+ 2. invoke is_first_stage() with model_chunk_id < 0 is equivalent to invoke is_first_device()
Returns:
bool: Whether the current stage is the first stage.
"""
- return self.stage == 0
+ if self.is_interleave and model_chunk_id is None:
+ model_chunk_id = self.model_chunk_id
+ assert self.is_interleave ^ (
+ model_chunk_id is None
+ ), "model_chunk_id must be specified when using interleaved pipeline"
+ if not self.is_interleave or model_chunk_id < 0:
+ return self.stage == 0
+ else:
+ return self.stage == 0 and model_chunk_id == 0
- def is_last_stage(self) -> bool:
+ def is_last_stage(self, model_chunk_id: Optional[int] = None) -> bool:
"""Is the current stage the last stage.
+ NOTE:
+ 1. if using interleaved pipeline parallel, the last stage is the last chunk of the last device.
+ 2. invoke is_last_stage() with model_chunk_id < 0 is equivalent to invoke is_last_device()
Returns:
bool: Whether the current stage is the last stage.
"""
- return self.stage == self.num_stages - 1
+ if self.is_interleave and model_chunk_id is None:
+ model_chunk_id = self.model_chunk_id
+ assert self.is_interleave ^ (
+ model_chunk_id is None
+ ), "model_chunk_id must be specified when using interleaved pipeline"
+ if not self.is_interleave or model_chunk_id < 0:
+ return self.stage == self.num_stages - 1
+ else:
+ return self.stage == self.num_stages - 1 and model_chunk_id == self.num_model_chunks - 1
@property
def num_stages(self) -> int:
...
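To illustrate the new virtual-stage semantics above (an illustration, not part of the diff):

# With 2 pipeline devices and num_model_chunks=2 there are 4 virtual stages.
# On device 0: is_first_stage(0) -> True,  is_first_stage(1) -> False
# On device 1: is_last_stage(1)  -> True,  is_last_stage(0)  -> False
# Passing model_chunk_id=-1 ignores the chunk and only checks the device,
# e.g. is_last_stage(-1) is True on device 1 regardless of the chunk.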
@@ -127,6 +127,7 @@ We will follow this roadmap to develop Shardformer:
| whisper | [x] | [x] | [x] | [x] | [x] | [ ] | [x] | [ ] | [ ] |
| sam | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] |
| blip2 | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] |
+ | falcon | [x] | [x] | [x] | [x] | [x] | [ ] | [x] | [ ] | [ ] |
| roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| albert | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| ernie | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
@@ -136,6 +137,7 @@ We will follow this roadmap to develop Shardformer:
| swin | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| swin V2 | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| qwen | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+ | mistral | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] |
## 💡 API Design
...
@@ -275,8 +275,8 @@ class FusedRMSNorm(BaseLayerNorm):
)
LazyInitContext.materialize(module)
- # to check if it is huggingface LlamaRMSNorm
- if module.__class__.__name__ == "LlamaRMSNorm":
+ # to check if it is huggingface LlamaRMSNorm or MistralRMSNorm
+ if module.__class__.__name__ in ["LlamaRMSNorm", "MistralRMSNorm"]:
normalized_shape = module.weight.shape[0]
eps = module.variance_epsilon
elementwise_affine = True
...
from typing import List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from transformers.models.falcon.modeling_falcon import (
FalconForCausalLM,
FalconForQuestionAnswering,
FalconForSequenceClassification,
FalconForTokenClassification,
FalconModel,
build_alibi_tensor,
)
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.shard import ShardConfig
def build_falcon_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor:
def build_falcon_alibi_tensor(
self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype
) -> torch.Tensor:
"""
Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
`softmax(l+a) = softmax(l)`. Based on
https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.
Args:
Returns tensor shaped (batch_size * num_heads, 1, max_seq_len)
attention_mask (`torch.Tensor`):
Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
num_heads (`int`, *required*):
number of heads
dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
dtype of the output tensor
"""
import math
if dist.is_initialized():
world_size = dist.get_world_size(process_group)
num_heads = num_heads * world_size
batch_size, seq_length = attention_mask.shape
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
base = torch.tensor(
2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
)
powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
slopes = torch.pow(base, powers)
if closest_power_of_2 != num_heads:
extra_base = torch.tensor(
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),
device=attention_mask.device,
dtype=torch.float32,
)
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
extra_powers = torch.arange(
1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32
)
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
# Note: alibi will added to the attention bias that will be applied to the query, key product of attention
# => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
# => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
# => the query_length dimension will then be broadcasted correctly
# This is more or less identical to T5's relative position bias:
# https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
alibi = slopes[..., None] * arange_tensor
if dist.is_initialized():
num_heads_per_rank = int(num_heads / dist.get_world_size(process_group))
offset = dist.get_rank(process_group) * num_heads_per_rank
alibi = alibi.view(batch_size, num_heads, 1, seq_length)
alibi = alibi[:, offset : num_heads_per_rank + offset, :, :]
return alibi.reshape(batch_size * num_heads_per_rank, 1, seq_length).to(dtype)
else:
return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
return build_falcon_alibi_tensor
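# Illustrative values for the slope computation above (not part of the patch):
# for num_heads=8, closest_power_of_2=8 and base = 2 ** -1 = 0.5, so the per-head
# slopes are 0.5, 0.25, ..., 1/256. With tensor parallelism, num_heads is first
# scaled by the TP world size and each rank keeps only its own
# num_heads_per_rank slice of the resulting alibi tensor.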
def get_tp_falcon_decoder_layer_forward():
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, dropout_add
def forward(
self: FalconDecoderLayer,
hidden_states: torch.Tensor,
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
residual = hidden_states
if self.config.new_decoder_architecture:
attention_layernorm_out = self.ln_attn(hidden_states)
mlp_layernorm_out = self.ln_mlp(hidden_states)
else:
attention_layernorm_out = self.input_layernorm(hidden_states)
# Self attention.
attn_outputs = self.self_attention(
attention_layernorm_out,
layer_past=layer_past,
attention_mask=attention_mask,
alibi=alibi,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
attention_output = attn_outputs[0]
if not self.config.new_decoder_architecture:
if self.config.parallel_attn:
mlp_layernorm_out = attention_layernorm_out
else:
residual = dropout_add(
attention_output, residual, self.config.attention_dropout, training=self.training
)
mlp_layernorm_out = self.post_attention_layernorm(residual)
outputs = attn_outputs[1:]
# MLP.
mlp_output = self.mlp(mlp_layernorm_out)
if self.config.new_decoder_architecture or self.config.parallel_attn:
mlp_output = mlp_output + attention_output
output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)
if use_cache:
outputs = (output,) + outputs
else:
outputs = (output,) + outputs[1:]
return outputs # hidden_states, present, attentions
return forward
def get_falcon_flash_attention_forward():
try:
from xformers.ops import memory_efficient_attention as me_attention
except:
raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
from transformers.models.falcon.modeling_falcon import FalconAttention
def forward(
self: FalconAttention,
hidden_states: torch.Tensor,
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
batch_size, query_length, _, _ = query_layer.shape
query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim)
key_layer = key_layer.transpose(1, 2).reshape(
batch_size * num_kv_heads,
query_length,
self.head_dim,
)
value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)
past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
if layer_past is not None:
past_key, past_value = layer_past
# concatenate along seq_length dimension:
# - key: [batch_size * self.num_heads, kv_length, head_dim]
# - value: [batch_size * self.num_heads, kv_length, head_dim]
key_layer = torch.cat((past_key, key_layer), dim=1)
value_layer = torch.cat((past_value, value_layer), dim=1)
_, kv_length, _ = key_layer.shape
if use_cache:
present = (key_layer, value_layer)
else:
present = None
attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim).transpose(1, 2).contiguous()
key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim).transpose(1, 2).contiguous()
value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim).transpose(1, 2).contiguous()
if alibi is not None:
attention_mask_float = (
attention_mask_float + alibi.view(batch_size, self.num_heads, 1, kv_length) * self.beta
)
batch_size, src_len = query_layer_.size()[0], query_layer_.size()[1]
tgt_len = key_layer_.size()[1]
attention_mask_float = attention_mask_float.expand(batch_size, self.num_heads, src_len, tgt_len).contiguous()
context_layer = me_attention(
query_layer_,
key_layer_,
value_layer_,
attn_bias=attention_mask_float,
scale=self.inv_norm_factor,
p=self.attention_dropout.p,
)
batch_size, seq_length, _, _ = context_layer.shape
context_layer = context_layer.reshape(batch_size, seq_length, -1)
output_tensor = self.dense(context_layer)
return output_tensor, present
return forward
class FalconPipelineForwards:
"""
This class serves as a micro library for falcon pipeline forwards.
"""
@staticmethod
def falcon_model_forward(
self: FalconModel,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
logger = logging.get_logger(__name__)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
if use_cache:
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
use_cache = False
if past_key_values is not None:
logger.warning_once("past_key_values is not supported for pipeline models at the moment.")
past_key_values = None
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if past_key_values is None:
past_key_values = tuple([None] * len(self.h))
else:
past_key_values = self._convert_to_rw_cache(past_key_values)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape batch_size x num_heads x N x N
# head_mask has shape n_layer x batch x num_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
# case: First stage of training
if stage_manager.is_first_stage():
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
hidden_states = inputs_embeds
else:
input_shape = hidden_states.shape[:-1]
batch_size, seq_length = input_shape
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
# Compute alibi tensor: check build_alibi_tensor documentation
past_key_values_length = 0
if past_key_values[0] is not None:
past_key_values_length = past_key_values[0][0].shape[1] # 1 because RW-cache, not standard format
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=hidden_states.device)
else:
attention_mask = attention_mask.to(hidden_states.device)
if self.use_alibi:
alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
else:
alibi = None
causal_mask = self._prepare_attn_mask(
attention_mask,
input_shape=(batch_size, seq_length),
past_key_values_length=past_key_values_length,
)
start_idx, end_idx = stage_index[0], stage_index[1]
for i, (block, layer_past) in enumerate(
zip(self.h[start_idx:end_idx], past_key_values[start_idx:end_idx]), start=start_idx
):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
alibi,
causal_mask,
head_mask[i],
)
else:
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=causal_mask,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
alibi=alibi,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
if stage_manager.is_last_stage():
# Add last hidden state
hidden_states = self.ln_f(hidden_states)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if presents is not None:
presents = self._convert_cache_to_standard_format(presents, batch_size)
if stage_manager.is_last_stage():
if not return_dict:
return tuple(
v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
else:
# always return dict for intermediate stage
return {"hidden_states": hidden_states}
@staticmethod
def falcon_for_causal_lm_forward(
self: FalconForCausalLM,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
"""
logger = logging.get_logger(__name__)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if output_attentions:
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
transformer_outputs = FalconPipelineForwards.falcon_model_forward(
self.transformer,
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
)
past_key_values = None
if stage_manager.is_last_stage():
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
batch_size, seq_length, vocab_size = shift_logits.shape
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
)
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
else:
hidden_states = transformer_outputs.get("hidden_states")
return {"hidden_states": hidden_states}
@staticmethod
def falcon_for_sequence_classification_forward(
self: FalconForSequenceClassification,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
logger = logging.get_logger(__name__)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if output_attentions:
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
transformer_outputs = FalconPipelineForwards.falcon_model_forward(
self.transformer,
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
)
past_key_values = None
if stage_manager.is_last_stage():
batch_size = hidden_states.shape[0]
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(dim=-1) - 1).to(logits.device)
else:
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(pooled_logits, labels)
if not return_dict:
output = (pooled_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
else:
hidden_states = transformer_outputs.get("hidden_states")
return {"hidden_states": hidden_states}
@staticmethod
def falcon_for_token_classification_forward(
self: FalconForTokenClassification,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
logger = logging.get_logger(__name__)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if output_attentions:
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
transformer_outputs = FalconPipelineForwards.falcon_model_forward(
self.transformer,
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
)
past_key_values = None
if stage_manager.is_last_stage():
hidden_states = transformer_outputs[0]
hidden_states = self.dropout(hidden_states)
logits = self.classifier(hidden_states)
loss = None
if labels is not None:
batch_size, seq_length = labels.shape
loss_fct = CrossEntropyLoss()
loss = loss_fct(
logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
)
if not return_dict:
output = (logits,) + transformer_outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
else:
hidden_states = transformer_outputs.get("hidden_states")
return {"hidden_states": hidden_states}
@staticmethod
def falcon_for_question_answering_forward(
self: FalconForQuestionAnswering,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
) -> Union[Tuple, QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
"""
logger = logging.get_logger(__name__)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if output_attentions:
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
outputs = FalconPipelineForwards.falcon_model_forward(
self.transformer,
input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
)
if stage_manager.is_last_stage():
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, splitting can add an extra dimension; squeeze it
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1)
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
if not return_dict:
output = (start_logits, end_logits) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
else:
hidden_states = outputs.get("hidden_states")
return {"hidden_states": hidden_states}
from typing import Dict, List, Optional, Tuple, Union
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
)
from transformers.models.gptj.modeling_gptj import (
GPTJForCausalLM,
GPTJForQuestionAnswering,
GPTJForSequenceClassification,
GPTJModel,
apply_rotary_pos_emb,
get_embed_positions,
)
from transformers.utils import is_torch_fx_proxy, logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
from colossalai.shardformer.shard import ShardConfig
class GPTJPipelineForwards:
"""
This class serves as a micro library for forward function substitution of GPTJ models
under pipeline setting.
"""
@staticmethod
def gptj_model_forward(
self: GPTJModel,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
) -> Union[Dict, Tuple, BaseModelOutputWithPast]:
# This function is modified on the basis of transformers.models.gptj.modeling_gptj.GPTJModel.forward.
# Please refer to original code of transformers for more details.
# GPTJ has no cross attention in comparison to GPT2
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
logger = logging.get_logger(__name__)
# Preprocess passed in arguments
# TODO(baizhou): recording of kv-cache tensors is left as () or None for now; this feature may be added in the future.
if past_key_values:
logger.warning_once("Non-empty past_key_values is not supported for pipeline models at the moment.")
past_key_values = None
if output_attentions:
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
if use_cache:
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
use_cache = False
if stage_manager.is_first_stage():
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
input_shape = input_ids.size()
input_ids = input_ids.view(-1, seq_length)
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
batch_size = inputs_embeds.shape[0]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, seq_length)
else:
if hidden_states is None:
raise ValueError("hidden_states shouldn't be None for stages other than the first stage.")
input_shape = hidden_states.size()[:-1]
batch_size, seq_length = input_shape[0], input_shape[1]
device = hidden_states.device
# Attention mask.
if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
attention_mask = attention_mask.view(batch_size, -1)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, None, None, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and the dtype's smallest value for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
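# For example (illustrative, fp16): attention_mask = [[1, 1, 0]] becomes
# [[[[0.0, 0.0, -65504.0]]]] after the two lines above, so masked positions sit at
# the dtype minimum and contribute (almost) nothing after softmax.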
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x num_attention_heads x N x N
# head_mask has shape n_layer x batch x num_attention_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
# position ids need to be assigned on every stage (not just the first), since the attention blocks take them as input
if position_ids is not None:
position_ids = position_ids.view(-1, seq_length)
else:
position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
if stage_manager.is_first_stage():
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
hidden_states = inputs_embeds
if token_type_ids is not None:
token_type_embeds = self.wte(token_type_ids)
hidden_states = hidden_states + token_type_embeds
hidden_states = self.drop(hidden_states)
output_shape = input_shape + (hidden_states.size(-1),)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
# split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
if shard_config.enable_sequence_parallelism:
hidden_states = split_forward_gather_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
)
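# For example (illustrative): with seq_len=1024 and a tensor-parallel size of 4, each
# rank keeps a [batch_size, 256, hidden_size] slice from here on; the full sequence is
# reassembled by gather_forward_split_backward further below.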
# Going through held blocks.
start_idx, end_idx = stage_index[0], stage_index[1]
for i in range(start_idx, end_idx):
block = self.h[i]
torch.cuda.set_device(hidden_states.device)
# Ensure that attention_mask is always on the same device as hidden_states
if attention_mask is not None:
attention_mask = attention_mask.to(hidden_states.device)
if isinstance(head_mask, torch.Tensor):
head_mask = head_mask.to(hidden_states.device)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache, output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
None,
attention_mask,
position_ids,
head_mask[i],
)
else:
outputs = block(
hidden_states=hidden_states,
layer_past=None,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
# When sequence parallelism is enabled, gather the output tensor in forward and split it in backward
if shard_config.enable_sequence_parallelism:
hidden_states = gather_forward_split_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
)
if stage_manager.is_last_stage():
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape)
# Add last hidden state
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if stage_manager.is_last_stage():
if not return_dict:
return tuple(
v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
else:
# always return dict for intermediate stage
return {"hidden_states": hidden_states}
@staticmethod
def gptj_causallm_model_forward(
self: GPTJForCausalLM,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
) -> Union[Dict, Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
# This function is modified on the basis of transformers.models.gptj.modeling_gptj.GPTJForCausalLM.forward.
# Please refer to original code of transformers for more details.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = GPTJPipelineForwards.gptj_model_forward(
self.transformer,
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
)
# If not at the last stage, return hidden_states as in GPTJModel
if not stage_manager.is_last_stage():
return {"hidden_states": transformer_outputs["hidden_states"]}
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
loss = loss.to(hidden_states.dtype)
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
@staticmethod
def gptj_for_sequence_classification_forward(
self: GPTJForSequenceClassification,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
) -> Union[Dict, Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
# This function is modified on the basis of transformers.models.gptj.modeling_gptj.GPTJForSequenceClassification.forward.
# Please refer to original code of transformers for more details.
"""
logger = logging.get_logger(__name__)
if input_ids is not None:
batch_size, _ = input_ids.shape[:2]
else:
batch_size, _ = hidden_states.shape[:2]
assert (
self.config.pad_token_id is not None or batch_size == 1
), "Cannot handle batch sizes > 1 if no padding token is defined."
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = GPTJPipelineForwards.gptj_model_forward(
self.transformer,
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
)
# If not at the last stage, return hidden_states as in GPTJModel
if not stage_manager.is_last_stage():
return {"hidden_states": transformer_outputs["hidden_states"]}
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
else:
sequence_lengths = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
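# For example (illustrative): with pad_token_id=0 and input_ids=[[5, 6, 0, 0]],
# torch.ne(input_ids, 0).sum(-1) - 1 gives sequence_lengths=1, so the line below pools
# the logits of the last non-padding token for each sample.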
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
loss = None
if labels is not None:
labels = labels.to(pooled_logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(pooled_logits, labels)
if not return_dict:
output = (pooled_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
@staticmethod
def gptj_for_question_answering_forward(
self: GPTJForQuestionAnswering,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
) -> Union[Dict, Tuple, QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
# This function is modified on the basis of transformers.models.gptj.modeling_gptj.GPTJForQuestionAnswering.forward.
# Please refer to original code of transformers for more details.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = GPTJPipelineForwards.gptj_model_forward(
self.transformer,
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
)
# If not at the last stage, return hidden_states as in GPTJModel
if not stage_manager.is_last_stage():
return {"hidden_states": outputs["hidden_states"]}
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, splitting can add an extra dimension; squeeze it
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1).to(start_logits.device)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1).to(end_logits.device)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1)
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
if not return_dict:
output = (start_logits, end_logits) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def get_gptj_flash_attention_forward():
from transformers.models.gptj.modeling_gptj import GPTJAttention
from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
def split_heads(tensor, num_attention_heads, attn_head_size, rotary):
"""
Splits hidden dim into attn_head_size and num_attention_heads
"""
new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
tensor = tensor.view(new_shape)
if rotary or len(tensor.shape) in [4, 5]:
return tensor
else:
raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
def forward(
self: GPTJAttention,
hidden_states: torch.FloatTensor,
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
) -> Union[
Tuple[torch.Tensor, Tuple[torch.Tensor]],
Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
]:
query = self.q_proj(hidden_states)
key = self.k_proj(hidden_states)
value = self.v_proj(hidden_states)
query = split_heads(query, self.num_attention_heads, self.head_dim, True)
key = split_heads(key, self.num_attention_heads, self.head_dim, True)
value = split_heads(value, self.num_attention_heads, self.head_dim, False)
if is_torch_fx_proxy(position_ids) or torch.jit.is_tracing():
# The logic to conditionally copy to GPU could not be traced, so we do this
# every time in the torch.fx case
embed_positions = get_embed_positions(self.embed_positions, position_ids)
else:
embed_positions = self._get_embed_positions(position_ids)
repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
sincos = torch.gather(embed_positions, 1, repeated_position_ids)
sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
if self.rotary_dim is not None:
k_rot = key[:, :, :, : self.rotary_dim]
k_pass = key[:, :, :, self.rotary_dim :]
q_rot = query[:, :, :, : self.rotary_dim]
q_pass = query[:, :, :, self.rotary_dim :]
k_rot = apply_rotary_pos_emb(k_rot, sin, cos)
q_rot = apply_rotary_pos_emb(q_rot, sin, cos)
key = torch.cat([k_rot, k_pass], dim=-1)
query = torch.cat([q_rot, q_pass], dim=-1)
else:
key = apply_rotary_pos_emb(key, sin, cos)
query = apply_rotary_pos_emb(query, sin, cos)
# key = key.permute(0, 2, 1, 3)
# query = query.permute(0, 2, 1, 3)
key = key.to(dtype=value.dtype) # fp16 compatibility
query = query.to(dtype=value.dtype)
if layer_past is not None:
past_key = layer_past[0]
past_value = layer_past[1]
key = torch.cat((past_key, key), dim=1)
value = torch.cat((past_value, value), dim=1)
if use_cache is True:
present = (key, value)
else:
present = None
# use AttnMaskType and ColoAttention
attn_mask_type = AttnMaskType.causal
flash_attention_mask = None
if attention_mask is not None:
if attn_mask_type == AttnMaskType.causal:
attn_mask_type = AttnMaskType.paddedcausal
else:
attn_mask_type = AttnMaskType.padding
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
# use coloattention
scale = value.size(-1) ** -0.5
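# For example (illustrative, GPT-J-6B-sized): head_dim=256 gives scale = 1/16,
# i.e. the standard 1/sqrt(d_k) attention scaling.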
attention = ColoAttention(
embed_dim=self.embed_dim, num_heads=self.num_attention_heads, dropout=self.attn_dropout.p, scale=scale
)
attn_output = attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type)
attn_output = self.out_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
outputs = (attn_output, present, None)
return outputs # a, present, (attentions)
return forward
def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
batch_size = input_ids.shape[0]
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
batch_size = inputs_embeds.shape[0]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
if position_ids is not None:
position_ids = position_ids.view(-1, input_shape[-1]).long()
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
# Attention mask.
if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
attention_mask = attention_mask.view(batch_size, -1)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, None, None, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and the dtype's smallest value for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x num_attention_heads x N x N
# head_mask has shape n_layer x batch x num_attention_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
hidden_states = inputs_embeds
if token_type_ids is not None:
token_type_embeds = self.wte(token_type_ids)
hidden_states = hidden_states + token_type_embeds
hidden_states = self.drop(hidden_states)
output_shape = input_shape + (hidden_states.size(-1),)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
# split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
hidden_states = split_forward_gather_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
# Model parallel
if self.model_parallel:
torch.cuda.set_device(hidden_states.device)
# Ensure layer_past is on same device as hidden_states (might not be correct)
if layer_past is not None:
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
# Ensure that attention_mask is always on the same device as hidden_states
if attention_mask is not None:
attention_mask = attention_mask.to(hidden_states.device)
if isinstance(head_mask, torch.Tensor):
head_mask = head_mask.to(hidden_states.device)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache, output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
None,
attention_mask,
position_ids,
head_mask[i],
)
else:
outputs = block(
hidden_states=hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
# Model Parallel: If it's the last layer for that device, put things on the next device
if self.model_parallel:
for k, v in self.device_map.items():
if i == v[-1] and "cuda:" + str(k) != self.last_device:
hidden_states = hidden_states.to("cuda:" + str(k + 1))
# When sequence parallelism is enabled, gather the output tensor in forward and split it in backward
hidden_states = gather_forward_split_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
)
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape)
# Add last hidden state
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
return forward
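# A minimal, self-contained sketch of the sequence-parallel round trip used above.
# Illustrative only: plain chunk/cat stand in for the autograd-aware
# split_forward_gather_backward / gather_forward_split_backward pair.
def _sequence_parallel_round_trip_sketch():
    import torch

    hidden = torch.randn(2, 8, 16)  # [batch_size, seq_len, hidden_size]
    tp_size = 4
    shards = torch.chunk(hidden, tp_size, dim=1)  # each rank keeps seq_len / tp_size tokens
    assert shards[0].shape == (2, 2, 16)
    gathered = torch.cat(shards, dim=1)  # the gather step restores the full sequence
    assert torch.equal(gathered, hidden)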
from typing import Optional, Tuple
import torch
def get_mistral_flash_attention_forward():
from transformers.models.mistral.modeling_mistral import MistralAttention, apply_rotary_pos_emb, repeat_kv
from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
def forward(
self: MistralAttention,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = (
self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
)
value_states = (
self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
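# For example (illustrative, Mistral-7B-style config): num_heads=32 and
# num_key_value_heads=8 give num_key_value_groups=4, so each KV head is repeated 4
# times here to match the query head count before attention.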
me_input_shape = (bsz, q_len, self.num_heads, self.head_dim)
query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape)
key_states = key_states.transpose(1, 2).contiguous().view(*me_input_shape)
value_states = value_states.transpose(1, 2).contiguous().view(*me_input_shape)
flash_attention_mask = None
attn_mask_type = AttnMaskType.causal
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
attn_mask_type = AttnMaskType.paddedcausal
attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
attn_output = attention(
query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type
)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
return forward
...@@ -85,6 +85,17 @@ _POLICY_LIST = { ...@@ -85,6 +85,17 @@ _POLICY_LIST = {
"transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification": PolicyLocation( "transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification": PolicyLocation(
file_name="gpt2", class_name="GPT2ForSequenceClassificationPolicy" file_name="gpt2", class_name="GPT2ForSequenceClassificationPolicy"
), ),
# GPTJ
"transformers.models.gptj.modeling_gptj.GPTJModel": PolicyLocation(file_name="gptj", class_name="GPTJModelPolicy"),
"transformers.models.gptj.modeling_gptj.GPTJForCausalLM": PolicyLocation(
file_name="gptj", class_name="GPTJForCausalLMPolicy"
),
"transformers.models.gptj.modeling_gptj.GPTJForQuestionAnswering": PolicyLocation(
file_name="gptj", class_name="GPTJForQuestionAnsweringPolicy"
),
"transformers.models.gptj.modeling_gptj.GPTJForSequenceClassification": PolicyLocation(
file_name="gptj", class_name="GPTJForSequenceClassificationPolicy"
),
# ViT # ViT
"transformers.models.vit.modeling_vit.ViTModel": PolicyLocation(file_name="vit", class_name="ViTModelPolicy"), "transformers.models.vit.modeling_vit.ViTModel": PolicyLocation(file_name="vit", class_name="ViTModelPolicy"),
"transformers.models.vit.modeling_vit.ViTForImageClassification": PolicyLocation( "transformers.models.vit.modeling_vit.ViTForImageClassification": PolicyLocation(
...@@ -146,6 +157,31 @@ _POLICY_LIST = { ...@@ -146,6 +157,31 @@ _POLICY_LIST = {
"colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation( "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy" file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"
), ),
# Falcon
"transformers.models.falcon.modeling_falcon.FalconModel": PolicyLocation(
file_name="falcon", class_name="FalconModelPolicy"
),
"transformers.models.falcon.modeling_falcon.FalconForCausalLM": PolicyLocation(
file_name="falcon", class_name="FalconForCausalLMPolicy"
),
"transformers.models.falcon.modeling_falcon.FalconForSequenceClassification": PolicyLocation(
file_name="falcon", class_name="FalconForSequenceClassificationPolicy"
),
"transformers.models.falcon.modeling_falcon.FalconForTokenClassification": PolicyLocation(
file_name="falcon", class_name="FalconForTokenClassificationPolicy"
),
"transformers.models.falcon.modeling_falcon.FalconForQuestionAnswering": PolicyLocation(
file_name="falcon", class_name="FalconForQuestionAnsweringPolicy"
),
"transformers.models.mistral.modeling_mistral.MistralModel": PolicyLocation(
file_name="mistral", class_name="MistralModelPolicy"
),
"transformers.models.mistral.modeling_mistral.MistralForCausalLM": PolicyLocation(
file_name="mistral", class_name="MistralForCausalLMPolicy"
),
"transformers.models.mistral.modeling_mistral.MistralForSequenceClassification": PolicyLocation(
file_name="mistral", class_name="MistralForSequenceClassificationPolicy"
),
} }
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch.nn as nn import torch.nn as nn
...@@ -214,13 +214,32 @@ class Policy(ABC): ...@@ -214,13 +214,32 @@ class Policy(ABC):
return layers_per_stage return layers_per_stage
@staticmethod @staticmethod
def get_stage_index(layers_per_stage: List[int], stage: int) -> List[int]: def get_stage_index(
layers_per_stage: List[int],
stage: int,
num_model_chunks: int = 1,
num_stages: int = 0,
) -> Union[Tuple[int, int], List[Tuple[int, int]]]:
""" """
get the start index and end index of layers for each stage. Get the start index and end index of layers for each stage.
Args:
layers_per_stage (List[int]): number of layers for each stage
stage (int): the stage index
num_stages (int): number of stages
num_model_chunks (int): number of model chunks
Returns:
- Tuple[int, int]: the start index and end index of this stage
- List[Tuple[int, int]]: the start index and end index of this stage for each model chunk
""" """
num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0) num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
start_idx = num_layers_per_stage_accumulated[stage] stage_indices = []
end_idx = num_layers_per_stage_accumulated[stage + 1] for model_chunk in range(num_model_chunks):
start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages]
end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1]
stage_indices.append([start_idx, end_idx])
return [start_idx, end_idx] return stage_indices[0] if num_model_chunks == 1 else stage_indices
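# A worked example of the interleaved indexing above (illustrative values): with
# layers_per_stage = [2, 2, 2, 2] (8 layers, num_stages=2, num_model_chunks=2), the
# accumulated boundaries are [0, 2, 4, 6, 8]; stage 0 receives [[0, 2], [4, 6]] and
# stage 1 receives [[2, 4], [6, 8]], i.e. one contiguous block of layers per model chunk.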
...@@ -21,7 +21,7 @@ __all__ = [ ...@@ -21,7 +21,7 @@ __all__ = [
"BertPolicy", "BertPolicy",
"BertModelPolicy", "BertModelPolicy",
"BertForPreTrainingPolicy", "BertForPreTrainingPolicy",
"BertLMdHeadModelPolicy", "BertLMHeadModelPolicy",
"BertForMaskedLMPolicy", "BertForMaskedLMPolicy",
"BertForNextSentencePredictionPolicy", "BertForNextSentencePredictionPolicy",
"BertForSequenceClassificationPolicy", "BertForSequenceClassificationPolicy",
...@@ -249,15 +249,34 @@ class BertPolicy(Policy): ...@@ -249,15 +249,34 @@ class BertPolicy(Policy):
return self.model return self.model
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
"""If under pipeline parallel setting, replacing the original forward method of huggingface """
to customized forward method, and add this changing to policy.""" If under pipeline parallel setting, replacing the original forward method of huggingface
if self.pipeline_stage_manager: to customized forward method, and add this changing to policy.
stage_manager = self.pipeline_stage_manager """
if self.model.__class__.__name__ == "BertModel": if self.pipeline_stage_manager is None:
module = self.model return
else:
module = self.model.bert stage_manager = self.pipeline_stage_manager
if self.model.__class__.__name__ == "BertModel":
module = self.model
else:
module = self.model.bert
if stage_manager.is_interleave:
layers_per_stage = self.distribute_layers(
len(module.encoder.layer), stage_manager.num_stages * stage_manager.num_model_chunks
)
stage_manager.stage_indices = Policy.get_stage_index(
layers_per_stage,
stage_manager.stage,
num_model_chunks=stage_manager.num_model_chunks,
num_stages=stage_manager.num_stages,
)
method_replacement = {
"forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config)
}
else:
layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = { method_replacement = {
...@@ -265,11 +284,8 @@ class BertPolicy(Policy): ...@@ -265,11 +284,8 @@ class BertPolicy(Policy):
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
) )
} }
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=model_cls
)
return self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
def get_held_layers(self) -> List[Module]: def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage.""" """Get pipeline layers for current stage."""
...@@ -282,13 +298,32 @@ class BertPolicy(Policy): ...@@ -282,13 +298,32 @@ class BertPolicy(Policy):
stage_manager = self.pipeline_stage_manager stage_manager = self.pipeline_stage_manager
held_layers = [] held_layers = []
layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) if stage_manager.is_interleave:
if stage_manager.is_first_stage(): assert stage_manager.num_model_chunks is not None
held_layers.append(module.embeddings) layers_per_stage = self.distribute_layers(
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) len(module.encoder.layer), stage_manager.num_stages * stage_manager.num_model_chunks
held_layers.extend(module.encoder.layer[start_idx:end_idx]) )
if stage_manager.is_last_stage(): stage_indices = Policy.get_stage_index(
held_layers.append(module.pooler) layers_per_stage,
stage_manager.stage,
num_model_chunks=stage_manager.num_model_chunks,
num_stages=stage_manager.num_stages,
)
if stage_manager.is_first_stage(-1):
held_layers.append(module.embeddings)
for start_idx, end_idx in stage_indices:
held_layers.extend(module.encoder.layer[start_idx:end_idx])
if stage_manager.is_last_stage(-1):
held_layers.append(module.pooler)
else:
layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
if stage_manager.is_first_stage():
held_layers.append(module.embeddings)
start_idx, end_idx = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
held_layers.extend(module.encoder.layer[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.pooler)
return held_layers return held_layers
...@@ -464,7 +499,7 @@ class BertForSequenceClassificationPolicy(BertPolicy): ...@@ -464,7 +499,7 @@ class BertForSequenceClassificationPolicy(BertPolicy):
""" """
held_layers = super().get_held_layers() held_layers = super().get_held_layers()
stage_manager = self.pipeline_stage_manager stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage(): if stage_manager.is_last_stage(None if not stage_manager.is_interleave else -1):
held_layers.append(self.model.dropout) held_layers.append(self.model.dropout)
held_layers.append(self.model.classifier) held_layers.append(self.model.classifier)
return held_layers return held_layers
......
...@@ -21,6 +21,15 @@ from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDe ...@@ -21,6 +21,15 @@ from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDe
class BloomPolicy(Policy): class BloomPolicy(Policy):
def __init__(self) -> None:
super().__init__()
import transformers
from packaging.version import Version
assert Version(transformers.__version__) <= Version(
"4.33.0"
), "The Bloom model should run on a transformers version not greater than 4.33.0."
def config_sanity_check(self): def config_sanity_check(self):
pass pass
......
import warnings
from functools import partial
from typing import Callable, Dict, List
from torch import Tensor, nn
from torch.nn import Module
import colossalai.shardformer.layer as col_nn
from ..modeling.falcon import (
FalconPipelineForwards,
build_falcon_alibi_tensor_fn,
get_falcon_flash_attention_forward,
get_tp_falcon_decoder_layer_forward,
)
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ["FalconPolicy"]
class FalconPolicy(Policy):
def __init__(self) -> None:
super().__init__()
import transformers
from packaging.version import Version
assert Version(transformers.__version__) <= Version(
"4.33.0"
), "The Falcon model should run on a transformers version not greater than 4.33.0."
def config_sanity_check(self):
pass
def preprocess(self):
# reshape the embedding layer
r"""
Reshape the embedding layer so that the vocabulary size is divisible by world_size
"""
if self.shard_config.enable_tensor_parallelism:
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
return self.model
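# For example (illustrative): with vocab_size=32001 and tensor_parallel_size=8,
# 32001 % 8 == 1, so the embedding is resized to 32001 + 8 - 1 = 32008, which splits
# evenly across the 8 tensor-parallel ranks.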
def module_policy(self):
from transformers.models.falcon.modeling_falcon import FalconAttention, FalconDecoderLayer, FalconModel
if not self.model.config.new_decoder_architecture and self.model.config.multi_query:
warnings.warn(
"Falcon dosen't support tensor parallelism when (not new_decoder_architecture and multi_query) is True, will ignore the tensor parallelism flag."
)
self.shard_config.enable_tensor_parallelism = False
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
warnings.warn("Falcon doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
policy = {}
if self.shard_config.enable_tensor_parallelism:
attn_attribute_replacement = {
"self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attention.num_heads": self.model.config.num_attention_heads
// self.shard_config.tensor_parallel_size,
"self_attention.num_kv_heads": self.model.config.num_kv_heads // self.shard_config.tensor_parallel_size,
}
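# For example (illustrative): with hidden_size=4096, num_attention_heads=32 and
# tensor_parallel_size=4, each rank's attention module is updated to see
# hidden_size=1024 and num_heads=8 after this attribute replacement.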
policy[FalconDecoderLayer] = ModulePolicyDescription(
attribute_replacement=attn_attribute_replacement,
method_replacement={"forward": get_tp_falcon_decoder_layer_forward()},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attention.query_key_value",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attention.dense",
target_module=col_nn.Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="self_attention.attention_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="mlp.dense_h_to_4h",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(suffix="mlp.dense_4h_to_h", target_module=col_nn.Linear1D_Row),
],
)
policy[FalconModel] = ModulePolicyDescription(
attribute_replacement={
"num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
},
method_replacement={
"build_alibi_tensor": build_falcon_alibi_tensor_fn(self.shard_config.tensor_parallel_process_group)
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="word_embeddings",
target_module=col_nn.VocabParallelEmbedding1D,
)
],
)
# optimization configuration
if self.shard_config.enable_fused_normalization:
# handle falcon model
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="ln_f",
target_module=col_nn.FusedLayerNorm,
),
],
policy=policy,
target_key=FalconModel,
)
# handle falcon decoder layer
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="ln_attn", target_module=col_nn.FusedLayerNorm, ignore_if_not_exist=True
),
SubModuleReplacementDescription(
suffix="ln_mlp", target_module=col_nn.FusedLayerNorm, ignore_if_not_exist=True
),
SubModuleReplacementDescription(
suffix="input_layernorm", target_module=col_nn.FusedLayerNorm, ignore_if_not_exist=True
),
SubModuleReplacementDescription(
suffix="post_attention_layernorm", target_module=col_nn.FusedLayerNorm, ignore_if_not_exist=True
),
],
policy=policy,
target_key=FalconDecoderLayer,
)
if self.shard_config.enable_flash_attention:
self.append_or_create_method_replacement(
description={"forward": get_falcon_flash_attention_forward()},
policy=policy,
target_key=FalconAttention,
)
return policy
def postprocess(self):
return self.model
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
"""If under pipeline parallel setting, replacing the original forward method of huggingface
to customized forward method, and add this changing to policy."""
if self.pipeline_stage_manager:
stage_manager = self.pipeline_stage_manager
if self.model.__class__.__name__ == "FalconModel":
module = self.model
else:
module = self.model.transformer
layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {
"forward": partial(
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
)
}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=model_cls
)
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
assert self.pipeline_stage_manager is not None
if self.model.__class__.__name__ == "FalconModel":
module = self.model
else:
module = self.model.transformer
stage_manager = self.pipeline_stage_manager
held_layers = []
layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
if stage_manager.is_first_stage():
held_layers.append(module.word_embeddings)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
held_layers.extend(module.h[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.ln_f)
return held_layers
class FalconModelPolicy(FalconPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = super().module_policy()
from transformers.models.falcon.modeling_falcon import FalconModel
if self.pipeline_stage_manager:
self.set_pipeline_forward(
model_cls=FalconModel, new_forward=FalconPipelineForwards.falcon_model_forward, policy=policy
)
return policy
def get_held_layers(self) -> List[Module]:
"""
get pipeline layers for current stage
"""
held_layers = super().get_held_layers()
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""no shared params in falcon model"""
return []
class FalconForCausalLMPolicy(FalconPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.falcon.modeling_falcon import FalconForCausalLM
policy = super().module_policy()
# handle tensor parallelism
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)
),
policy=policy,
target_key=FalconForCausalLM,
)
if self.pipeline_stage_manager:
self.set_pipeline_forward(
model_cls=FalconForCausalLM,
new_forward=FalconPipelineForwards.falcon_for_causal_lm_forward,
policy=policy,
)
return policy
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_last_stage():
held_layers.append(self.model.lm_head)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
falcon_model = self.model
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
if id(falcon_model.transformer.word_embeddings.weight) == id(falcon_model.lm_head.weight):
# tie weights
return [
{
0: falcon_model.transformer.word_embeddings.weight,
self.pipeline_stage_manager.num_stages - 1: falcon_model.lm_head.weight,
}
]
return []
class FalconForSequenceClassificationPolicy(FalconPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.falcon.modeling_falcon import FalconForSequenceClassification
policy = super().module_policy()
# handle tensor parallelism
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)
),
policy=policy,
target_key=FalconForSequenceClassification,
)
if self.pipeline_stage_manager:
self.set_pipeline_forward(
model_cls=FalconForSequenceClassification,
new_forward=FalconPipelineForwards.falcon_for_sequence_classification_forward,
policy=policy,
)
return policy
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_last_stage():
held_layers.append(self.model.score)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in falcon for sequence classification model"""
return []
class FalconForTokenClassificationPolicy(FalconPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.falcon.modeling_falcon import FalconForTokenClassification
policy = super().module_policy()
# handle tensor parallelism
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)
),
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForReplicatedInput,
),
],
policy=policy,
target_key=FalconForTokenClassification,
)
if self.pipeline_stage_manager:
self.set_pipeline_forward(
model_cls=FalconForTokenClassification,
new_forward=FalconPipelineForwards.falcon_for_token_classification_forward,
policy=policy,
)
return policy
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_last_stage():
held_layers.append(self.model.dropout)
held_layers.append(self.model.classifier)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in falcon for token classification model"""
return []
class FalconForQuestionAnsweringPolicy(FalconPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.falcon.modeling_falcon import FalconForQuestionAnswering
policy = super().module_policy()
# handle tensor parallelism
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="qa_outputs", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)
),
policy=policy,
target_key=FalconForQuestionAnswering,
)
if self.pipeline_stage_manager:
self.set_pipeline_forward(
model_cls=FalconForQuestionAnswering,
new_forward=FalconPipelineForwards.falcon_for_question_answering_forward,
policy=policy,
)
return policy
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
held_layers = super().get_held_layers()
stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage():
held_layers.append(self.model.qa_outputs)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in falcon for question answering model"""
return []
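For intuition, here is a minimal sketch of the even layer split that distribute_layers and get_stage_index perform for the pipeline code above. It is illustrative only; the real helpers in base_policy may balance the remainder differently (e.g. favouring middle stages):

from typing import List, Tuple

def distribute_layers_sketch(num_layers: int, num_stages: int) -> List[int]:
    # Give every stage the floor share, then hand out the remainder one block at a time.
    base, extra = divmod(num_layers, num_stages)
    return [base + (1 if stage < extra else 0) for stage in range(num_stages)]

def get_stage_index_sketch(layers_per_stage: List[int], stage: int) -> Tuple[int, int]:
    # Convert per-stage counts into the [start, end) slice of module.h owned by this stage.
    start = sum(layers_per_stage[:stage])
    return start, start + layers_per_stage[stage]

# Example: 32 Falcon decoder blocks over 4 stages -> [8, 8, 8, 8]; stage 1 holds module.h[8:16].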
import warnings
from functools import partial
from typing import Callable, Dict, List
from torch import Tensor, nn
import colossalai.shardformer.layer as col_nn
from ..modeling.gptj import GPTJPipelineForwards, get_gptj_flash_attention_forward
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [
"GPTJPolicy",
"GPTJModelPolicy",
"GPTJForCausalLMPolicy",
"GPTJForSequenceClassificationPolicy",
"GPTJForQuestionAnsweringPolicy",
"FlaxGPTJPolicy",
"FlaxGPTJForCausalLMPolicy",
]
class GPTJPolicy(Policy):
def config_sanity_check(self):
pass
def preprocess(self):
r"""
Resize the token embedding so that the vocabulary size is divisible by the tensor parallel world size.
"""
if self.shard_config.enable_tensor_parallelism:
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
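# Pad the vocabulary up to the next multiple of the tensor parallel size, e.g. a vocab of
# 50257 with world_size=4 is padded to 50260 so the embedding can be evenly column-sharded.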
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
return self.model
def module_policy(self):
from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, GPTJModel
policy = {}
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
warnings.warn("GPTJ doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
policy[GPTJModel] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="wte",
target_module=col_nn.VocabParallelEmbedding1D,
),
SubModuleReplacementDescription(
suffix="drop",
target_module=col_nn.DropoutForParallelInput,
),
]
)
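# Each tensor parallel rank keeps only its shard of attention heads, so the per-rank
# embed_dim and num_attention_heads are divided by the tensor parallel size.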
policy[GPTJBlock] = ModulePolicyDescription(
attribute_replacement={
"attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"attn.num_attention_heads": self.model.config.num_attention_heads
// self.shard_config.tensor_parallel_size,
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attn.k_proj",
target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
),
SubModuleReplacementDescription(
suffix="attn.q_proj",
target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
),
SubModuleReplacementDescription(
suffix="attn.v_proj",
target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
),
SubModuleReplacementDescription(
suffix="attn.out_proj",
target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="mlp.fc_in",
target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="mlp.fc_out",
target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="attn.attn_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="attn.resid_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="mlp.dropout",
target_module=col_nn.DropoutForParallelInput,
),
],
)
# optimization configuration
if self.shard_config.enable_fused_normalization:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="ln_f",
target_module=col_nn.FusedLayerNorm,
),
policy=policy,
target_key=GPTJModel,
)
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="ln_1",
target_module=col_nn.FusedLayerNorm,
)
],
policy=policy,
target_key=GPTJBlock,
)
if self.shard_config.enable_flash_attention:
self.append_or_create_method_replacement(
description={
"forward": get_gptj_flash_attention_forward(),
},
policy=policy,
target_key=GPTJAttention,
)
return policy
def postprocess(self):
return self.model
def get_held_layers(self) -> List[nn.Module]:
"""Get pipeline layers for current stage."""
assert self.pipeline_stage_manager is not None
if self.model.__class__.__name__ == "GPTJModel":
module = self.model
else:
module = self.model.transformer
stage_manager = self.pipeline_stage_manager
held_layers = []
layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
if stage_manager.is_first_stage():
held_layers.append(module.wte)
held_layers.append(module.drop)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
held_layers.extend(module.h[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.ln_f)
return held_layers
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
"""If under pipeline parallel setting, replacing the original forward method of huggingface
to customized forward method, and add this changing to policy."""
if not self.pipeline_stage_manager:
raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.")
stage_manager = self.pipeline_stage_manager
if self.model.__class__.__name__ == "GPTJModel":
module = self.model
else:
module = self.model.transformer
layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {
"forward": partial(
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
)
}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
# GPTJModel
class GPTJModelPolicy(GPTJPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.gptj.modeling_gptj import GPTJModel
policy = super().module_policy()
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(
model_cls=GPTJModel, new_forward=GPTJPipelineForwards.gptj_model_forward, policy=policy
)
return policy
def get_held_layers(self) -> List[nn.Module]:
return super().get_held_layers()
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in GPT2Model."""
return []
# GPTJForCausalLM
class GPTJForCausalLMPolicy(GPTJPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
addon_module = {
GPTJForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}
)
]
)
}
policy.update(addon_module)
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(
model_cls=GPTJForCausalLM, new_forward=GPTJPipelineForwards.gptj_causallm_model_forward, policy=policy
)
return policy
def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.lm_head)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""The weights of wte and lm_head are shared."""
module = self.model
stage_manager = self.pipeline_stage_manager
if stage_manager is not None:
if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight):
first_stage, last_stage = 0, stage_manager.num_stages - 1
return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}]
return []
# GPTJForSequenceClassification
class GPTJForSequenceClassificationPolicy(GPTJPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.gptj.modeling_gptj import GPTJForSequenceClassification
policy = super().module_policy()
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(
model_cls=GPTJForSequenceClassification,
new_forward=GPTJPipelineForwards.gptj_for_sequence_classification_forward,
policy=policy,
)
return policy
def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.score)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in GPTJForSequenceClassification."""
return []
# GPTJForQuestionAnswering
class GPTJForQuestionAnsweringPolicy(GPTJPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.gptj.modeling_gptj import GPTJForQuestionAnswering
policy = super().module_policy()
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(
model_cls=GPTJForQuestionAnswering,
new_forward=GPTJPipelineForwards.gptj_for_question_answering_forward,
policy=policy,
)
return policy
def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.qa_outputs)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in GPT2ForQuestionAnswering."""
return []
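For orientation, a hedged usage sketch (not part of this diff) of how one of the policies above can be applied through the ShardFormer API imported by the plugin; it assumes the distributed environment has already been initialized (e.g. via colossalai.launch_from_torch) and that shard_former.optimize returns the sharded model together with its shared parameters:

from transformers import GPTJForCausalLM

from colossalai.shardformer import ShardConfig, ShardFormer

model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6b")
shard_config = ShardConfig(
    enable_tensor_parallelism=True,   # shard q/k/v/out_proj and the MLP as described by GPTJPolicy
    enable_fused_normalization=True,  # swap LayerNorm for col_nn.FusedLayerNorm
    enable_flash_attention=True,      # replace the attention forward
)
shard_former = ShardFormer(shard_config=shard_config)
# The policy is normally resolved automatically from the model class; it is passed
# explicitly here only to make the link to GPTJForCausalLMPolicy visible.
sharded_model, shared_params = shard_former.optimize(model, policy=GPTJForCausalLMPolicy())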
import warnings
from typing import Dict, Union
import torch.nn as nn
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from ..modeling.mistral import get_mistral_flash_attention_forward
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ["MistralPolicy", "MistralModelPolicy", "MistralForCausalLMPolicy", "MistralForSequenceClassificationPolicy"]
class MistralPolicy(Policy):
def config_sanity_check(self):
pass
def preprocess(self):
if self.shard_config.enable_tensor_parallelism:
# Resize embedding
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
from transformers.models.mistral.modeling_mistral import MistralAttention, MistralDecoderLayer, MistralModel
policy = {}
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
warnings.warn(
"Mistral dosen't support sequence parallelism now, will ignore the sequence parallelism flag."
)
if self.shard_config.enable_tensor_parallelism:
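# Mistral uses grouped-query attention, so the key/value head count is sharded across
# tensor parallel ranks together with the query heads and the hidden size.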
decoder_attribute_replacement = {
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
"self_attn.num_key_value_heads": self.model.config.num_key_value_heads
// self.shard_config.tensor_parallel_size,
}
policy[MistralDecoderLayer] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=Linear1D_Row,
),
],
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=VocabParallelEmbedding1D,
),
policy=policy,
target_key=MistralModel,
)
# optimization configuration
if self.shard_config.enable_fused_normalization:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=FusedRMSNorm,
),
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=FusedRMSNorm,
),
],
policy=policy,
target_key=MistralDecoderLayer,
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="norm",
target_module=FusedRMSNorm,
),
policy=policy,
target_key=MistralModel,
)
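# Flash attention only swaps in a fused attention forward; module parameters are untouched.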
if self.shard_config.enable_flash_attention:
self.append_or_create_method_replacement(
description={
"forward": get_mistral_flash_attention_forward(),
},
policy=policy,
target_key=MistralAttention,
)
return policy
def postprocess(self):
return self.model
class MistralModelPolicy(MistralPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
if self.pipeline_stage_manager:
warnings.warn("Mistral dosen't support pipeline parallelism now.")
return super().module_policy()
class MistralForCausalLMPolicy(MistralPolicy):
def module_policy(self):
from transformers import MistralForCausalLM
policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
# add a new item for causal lm
new_item = {
MistralForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
)
]
)
}
if self.pipeline_stage_manager:
warnings.warn("Mistral dosen't support pipeline parallelism now.")
policy.update(new_item)
return policy
class MistralForSequenceClassificationPolicy(MistralPolicy):
def module_policy(self):
from transformers import MistralForSequenceClassification
policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
# add a new item for sequence classification
new_item = {
MistralForSequenceClassification: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
)
]
)
}
if self.pipeline_stage_manager:
warnings.warn("Mistral dosen't support pipeline parallelism now.")
policy.update(new_item)
return policy
...@@ -22,6 +22,15 @@ __all__ = [
class OPTPolicy(Policy):
def __init__(self) -> None:
super().__init__()
import transformers
from packaging.version import Version
assert Version(transformers.__version__) <= Version(
"4.33.0"
), "The OPT model should run on a transformers version not greater than 4.33.0."
def config_sanity_check(self):
pass
......
...@@ -26,6 +26,15 @@ __all__ = [
class WhisperPolicy(Policy):
def __init__(self) -> None:
super().__init__()
import transformers
from packaging.version import Version
assert Version(transformers.__version__) <= Version(
"4.33.0"
), "The Whisper model should run on a transformers version not greater than 4.33.0."
def config_sanity_check(self):
pass
......