"vscode:/vscode.git/clone" did not exist on "dec24561cf4048219dac98401b70e9fc35e985ad"
Commit 14846934 authored by ver217

Merge branch 'main' into sync/npu

parents 9102d655 5d9a0ae7
from .colo_init_context import ColoInitContext, post_process_colo_init_ctx
from .ophooks import BaseOpHook, register_ophooks_recursively
from .stateful_tensor import StatefulTensor
from .stateful_tensor_mgr import StatefulTensorMgr
...@@ -11,4 +12,6 @@ __all__ = [
    "AutoTensorPlacementPolicy",
    "register_ophooks_recursively",
    "BaseOpHook",
    "ColoInitContext",
    "post_process_colo_init_ctx",
]
...@@ -4,23 +4,20 @@
import io
import pickle
import re
from collections import namedtuple
from typing import Any, Callable, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
from packaging.version import Version
from torch.distributed import ProcessGroup
from torch.distributed import distributed_c10d as c10d
from torch.utils._pytree import tree_flatten, tree_unflatten

from .stage_manager import PipelineStageManager
def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -> object:
    """transform tensor to object with unpickle.
    Info of the device in bytes stream will be modified into current device before unpickling
...@@ -42,27 +39,13 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
        buf = bytes(buf_array)

    io_bytes = io.BytesIO(buf)
    byte_pickler = pickle.Unpickler(io_bytes)
    unpickle = byte_pickler.load()

    return unpickle
# NOTE: FIXME: NPU DOES NOT support isend nor irecv, so broadcast is kept for future use
def _broadcast_object_list(
    object_list: List[Any], src: int, group: ProcessGroup, device: Optional[Union[torch.device, str, int]] = None
):
...@@ -70,20 +53,18 @@ def _broadcast_object_list(
    The only difference is that object will be move to correct device after unpickled.
    If local_rank = src, then object list will be sent to rank src. Otherwise, object list will
    be updated with data sent from rank src.

    Args:
        object_list (List[Any]): list of object to broadcast
        src (int): source rank to broadcast
        dst (int): dst rank to broadcast
        device (:class:`torch.device`): device to do broadcast. current device in default
    """
    if c10d._rank_not_in_group(group):
        c10d._warn_not_in_group("broadcast_object_list")
        return

    is_nccl_backend = _check_for_nccl_backend(group)
    current_device = None

    if device is not None:
...@@ -131,7 +112,7 @@ def _broadcast_object_list(
    if my_rank != src:
        for i, obj_size in enumerate(object_sizes_tensor):
            obj_view = object_tensor[offset : offset + obj_size]
            obj_view = obj_view.type(torch.uint8)
            if obj_view.device != torch.device("cpu"):
                obj_view = obj_view.cpu()
...@@ -149,80 +130,107 @@ def _broadcast_object_list(
            object_list[i] = unpickle_object
def _check_for_nccl_backend(group):
    pg = group or c10d._get_default_group()
    # Gate PG wrapper check on Gloo availability.
    if c10d._GLOO_AVAILABLE:
        # It is not expected for PG to be wrapped many times, but support it just in case
        while isinstance(pg, c10d._ProcessGroupWrapper):
            pg = pg.wrapped_pg
    return c10d.is_nccl_available() and pg.name() == c10d.Backend.NCCL


def _check_device(group):
    is_nccl_backend = _check_for_nccl_backend(group)
    current_device = torch.device("cpu")
    if is_nccl_backend:
        current_device = torch.device("cuda", torch.cuda.current_device())
    return current_device, is_nccl_backend
TensorMetadata = namedtuple("TensorMetadata", ["shape", "dtype", "requires_grad"])
P2PMetadata = namedtuple("P2PMetadata", ["tree_spec", "tensor_metadata", "non_tensor_obj_idx", "non_tensor_objs"])
def create_send_metadata(
    object: Any, strict: bool = True, return_tensor: bool = False
) -> Union[P2PMetadata, Tuple[P2PMetadata, List[torch.Tensor]]]:
    """
    Args:
        object (Any): object needed to be sent
        strict (bool, optional): whether to check if the object is supported for fast send
        return_tensor (bool, optional): whether to return tensor objects
    """
    objs, tree_spec = tree_flatten(object)
    tensor_metadata, tensor_objs = [], []
    non_tensor_obj_idx, non_tensor_objs = [], []
    for idx, obj in enumerate(objs):
        if isinstance(obj, torch.Tensor):
            tensor_objs.append(obj)
            tensor_metadata.append(TensorMetadata(obj.shape, obj.dtype, obj.requires_grad))
        else:
            non_tensor_obj_idx.append(idx)
            non_tensor_objs.append(obj)
    assert not strict or len(non_tensor_objs) == 0, "Only support tensor for fast send"
    metadata = P2PMetadata(tree_spec, tensor_metadata, non_tensor_obj_idx, non_tensor_objs)
    return metadata if not return_tensor else (metadata, tensor_objs)
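# Editorial sketch (illustrative, not part of this diff): how the tree_flatten/tree_unflatten
# pair used by create_send_metadata round-trips a payload. Tensor leaves travel as raw P2P
# buffers; every other leaf rides along inside the metadata and is re-inserted by index.
# The payload and names below are made up for the example.
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

example_payload = {"hidden_states": torch.ones(2, 4), "step": 7}
leaves, example_spec = tree_flatten(example_payload)
tensor_leaves = [x for x in leaves if isinstance(x, torch.Tensor)]
other_leaves = [(i, x) for i, x in enumerate(leaves) if not isinstance(x, torch.Tensor)]

# Receiver side: put the non-tensor leaves back at their recorded indices, then rebuild the tree.
recv_leaves = list(tensor_leaves)
for i, x in other_leaves:
    recv_leaves.insert(i, x)
restored = tree_unflatten(recv_leaves, example_spec)
assert restored["step"] == 7 and torch.equal(restored["hidden_states"], example_payload["hidden_states"])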
def _filling_ops_queue(
    obj: Union[torch.Tensor, List[torch.Tensor]],
    comm_op: Callable,
    comm_rank: int,
    ops_queue: List,
    group: ProcessGroup,
):
    if isinstance(obj, torch.Tensor):
        obj = obj.contiguous()
        op_to_add = dist.P2POp(comm_op, obj, comm_rank, group)
        ops_queue.append(op_to_add)
    else:
        for tensor_to_comm in obj:
            assert isinstance(tensor_to_comm, torch.Tensor)
            _filling_ops_queue(tensor_to_comm, comm_op, comm_rank, ops_queue, group)
def _create_recv_buffer(tensor_metadata: List[TensorMetadata], current_device) -> List[torch.Tensor]:
    buffer_recv = []
    for metadata in tensor_metadata:
        tensor_recv = torch.empty(
            metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype
        )
        buffer_recv.append(tensor_recv)
    return buffer_recv
def _batch_send_recv_tensor(
    send_tensor_list: Optional[List[torch.Tensor]],
    recv_tensor_metadata: Optional[List[TensorMetadata]],
    send_dst: Optional[int],
    recv_src: Optional[int],
    send_group: Optional[ProcessGroup],
    recv_group: Optional[ProcessGroup],
    current_device: Any,
) -> Optional[Union[torch.Tensor, List[torch.Tensor]]]:
    buffer_recv = None
    if recv_tensor_metadata is not None:
        buffer_recv = _create_recv_buffer(recv_tensor_metadata, current_device)

    ops = []
    if send_dst is not None and send_tensor_list is not None:
        assert send_group is not None
        _filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
    if recv_src is not None and buffer_recv is not None:
        assert recv_group is not None
        _filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)

    if len(ops) > 0:
        reqs = dist.batch_isend_irecv(ops)
        for req in reqs:
            req.wait()

    # Remove synchronization according to Pytorch's documentation
    # However, the Megatron-LM does synchronization here
    # https://github.com/microsoft/Megatron-DeepSpeed/blob/ef13d099c2a1609225a4ce4c1a1753cc76dd90a1/megatron/p2p_communication.py#L111-L112
...@@ -233,12 +241,16 @@ def _batch_send_recv_tensor(send_tensor_list, recv_tensor_metadata, send_dst, re
def _send_recv_serialization_object(
    object: Optional[P2PMetadata],
    send_dst: Optional[int],
    recv_src: Optional[int],
    send_group: Optional[ProcessGroup],
    recv_group: Optional[ProcessGroup],
    current_device: Any,
    is_nccl_backend: bool,
) -> Optional[P2PMetadata]:
    ops = []

    send_object_tensor = None
    if object is not None and send_dst is not None:
        if Version(torch.__version__) >= Version("1.13.0"):
...@@ -250,44 +262,40 @@ def _send_recv_serialization_object(
            send_object_size_tensor = send_object_size_tensor.to(current_device)
            send_object_tensor = send_object_tensor.to(current_device)

        _filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group)

    recv_object_size_tensor = None
    if recv_src is not None:
        recv_object_size_tensor = torch.empty(1, dtype=torch.long)
        if is_nccl_backend:
            recv_object_size_tensor = recv_object_size_tensor.to(current_device)
        _filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group)

    if len(ops) > 0:
        reqs = dist.batch_isend_irecv(ops)
        for req in reqs:
            req.wait()

        # See the comment in `_batch_send_recv_tensor`
        # torch.cuda.synchronize()

    ops = []

    if send_dst is not None and send_object_tensor is not None:
        _filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group)

    recv_object_tensor = None
    if recv_src is not None and recv_object_size_tensor is not None:
        recv_object_tensor = torch.empty(recv_object_size_tensor.item(), dtype=torch.uint8)
        if is_nccl_backend:
            recv_object_tensor = recv_object_tensor.to(current_device)
        _filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group)

    if len(ops) > 0:
        reqs = dist.batch_isend_irecv(ops)
        for req in reqs:
            req.wait()

        # See the comment in `_batch_send_recv_tensor`
        # torch.cuda.synchronize()
...@@ -296,112 +304,119 @@ def _send_recv_serialization_object(
        if recv_object_tensor.device != torch.device("cpu"):
            recv_object_tensor = recv_object_tensor.cpu()

        unpickle_object = _cuda_safe_tensor_to_object(recv_object_tensor, recv_object_size_tensor.item())

        if isinstance(unpickle_object, torch.Tensor) and unpickle_object.device.index != torch.cuda.current_device():
            unpickle_object = unpickle_object.cuda()

    return unpickle_object
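# Editorial sketch (illustrative, not part of this diff): the size-then-payload handshake that
# _send_recv_serialization_object performs with isend/irecv, shown locally without any process
# group. A one-element long tensor announces the byte length first, so the peer can allocate an
# exact uint8 buffer for the pickled object.
import pickle

import torch

example_obj = {"microbatch": 3, "shapes": [(2, 8)]}
payload = torch.frombuffer(bytearray(pickle.dumps(example_obj)), dtype=torch.uint8)
size = torch.tensor([payload.numel()], dtype=torch.long)  # phase 1: exchange the size
recv_buf = torch.empty(size.item(), dtype=torch.uint8)    # phase 2: exchange the payload
recv_buf.copy_(payload)                                    # stands in for dist.irecv
assert pickle.loads(recv_buf.numpy().tobytes()) == example_obj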
def _communicate(
    object: Any,
    send_dst: Optional[int],
    recv_src: Optional[int],
    send_group: Optional[ProcessGroup] = None,
    recv_group: Optional[ProcessGroup] = None,
    send_metadata: bool = True,
    metadata_recv: Optional[P2PMetadata] = None,
    send_prior_fallback: Optional[bool] = None,
) -> Any:
    """
    Send and receive object from send_dst and recv_src respectively

    Args:
        object (Any): object needed to be sent
        send_dst (int): rank of the destination
        recv_src (int): rank of the source
        send_group (ProcessGroup, optional): process group of sender
        recv_group (ProcessGroup, optional): process group of receiver
        send_metadata (bool, optional): whether to send metadata
        metadata_recv (P2PMetadata, optional): metadata of the object to be received
    """
    assert send_dst is not None or recv_src is not None, "send_dst and recv_src cannot be both None"
    assert send_dst is None or send_group is not None, "send_group must be specified when send_dst is not None"
    assert recv_src is None or recv_group is not None, "recv_group must be specified when recv_src is not None"
    assert (
        metadata_recv is None or len(metadata_recv.non_tensor_obj_idx) == 0
    ), "metadata_recv should not contain non-tensor objects"

    metadata_send, tensor_objs = None, None
    if object is not None:
        # NOTE: if object contains non-tensor objects, we have to send metadata
        metadata_send, tensor_objs = create_send_metadata(object, strict=False, return_tensor=True)
        send_metadata = send_metadata or len(metadata_send.non_tensor_obj_idx) > 0

    # NOTE: send & recv should be atomic operations. However, if we need to send metadata or receive metadata,
    # we are not able to do that (1. send & recv metadata 2. send & recv). So we need to split the send & recv into two parts in this case.
    if (send_dst is not None and recv_src is not None) and (send_metadata or metadata_recv is None):
        assert send_prior_fallback is not None, "Priority must be set if fallback happens"
        if send_prior_fallback:
            _communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
            return _communicate(
                None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv
            )
        else:
            recv_data = _communicate(
                None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv
            )
            _communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
            return recv_data

    # NOTE: only the following 5 cases are valid:
    # 1. send() [needs extra metadata] and no recv()
    # 2. recv() [needs extra metadata] and no send()
    # 3. neither send() nor recv() need extra metadata
    assert not (send_dst is not None and send_metadata) or recv_src is None
    assert not (recv_src is not None and metadata_recv is None) or send_dst is None
    assert not (send_dst is not None and recv_src is not None) or (not send_metadata and metadata_recv is not None)
    assert not c10d._rank_not_in_group(send_group) and not c10d._rank_not_in_group(recv_group)

    current_send_device, is_send_nccl_backend = _check_device(send_group)
    current_recv_device, is_recv_nccl_backend = _check_device(recv_group)

    is_nccl_backend = is_send_nccl_backend and is_recv_nccl_backend

    assert current_send_device == current_recv_device
    current_device = current_send_device

    if (send_dst is not None and send_metadata) or (recv_src is not None and metadata_recv is None):
        # Send and receive metadata
        _metadata_recv = _send_recv_serialization_object(
            object=metadata_send,
            send_dst=send_dst if send_metadata else None,
            recv_src=recv_src if metadata_recv is None else None,
            send_group=send_group if send_metadata else None,
            recv_group=recv_group if metadata_recv is None else None,
            current_device=current_device,
            is_nccl_backend=is_nccl_backend,
        )
        assert metadata_recv is None or _metadata_recv is None
        metadata_recv = _metadata_recv if metadata_recv is None else metadata_recv

    # Send and receive data
    recv_tensor_metadata = None if metadata_recv is None else metadata_recv.tensor_metadata
    recv_tensor_objs = _batch_send_recv_tensor(
        tensor_objs, recv_tensor_metadata, send_dst, recv_src, send_group, recv_group, current_device
    )

    if metadata_recv is not None:
        assert isinstance(metadata_recv, P2PMetadata)
        tree_spec = metadata_recv.tree_spec
        non_tensor_obj_idx = metadata_recv.non_tensor_obj_idx
        non_tensor_objs = metadata_recv.non_tensor_objs

        if recv_tensor_objs is None:
            recv_tensor_objs = []

        for idx in non_tensor_obj_idx:
            recv_tensor_objs.insert(idx, non_tensor_objs.pop(0))
        recv_object = tree_unflatten(recv_tensor_objs, tree_spec)

        return recv_object
def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, **kwargs) -> None:
"""send anything to dst rank """send anything to dst rank
Args: Args:
...@@ -411,10 +426,10 @@ def _send_object(object: Any, src: int, dst: int, group: ProcessGroup) -> None: ...@@ -411,10 +426,10 @@ def _send_object(object: Any, src: int, dst: int, group: ProcessGroup) -> None:
Returns: Returns:
None None
""" """
_communicate(object, send_dst=dst, recv_src=None, send_group=group) _communicate(object, send_dst=dst, recv_src=None, send_group=group, **kwargs)
def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any: def _recv_object(src: int, dst: int, group: ProcessGroup, **kwargs) -> Any:
"""recv anything from src """recv anything from src
Args: Args:
...@@ -423,7 +438,7 @@ def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any: ...@@ -423,7 +438,7 @@ def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any:
Returns: Returns:
Any: Object received from src. Any: Object received from src.
""" """
return _communicate(None, send_dst=None, recv_src=src, recv_group=group) return _communicate(None, send_dst=None, recv_src=src, recv_group=group, **kwargs)
def _p2p_comm(
...@@ -436,7 +451,7 @@ def _p2p_comm(
    """
    Send and recv tensor using P2P communication, used when pipeline size is 2 to solve the race communication.

    Args:
        tensor_send_next (torch.Tensor): tensor to be sent to next stage
        recv_prev (bool): whether to receive tensor from previous stage
        peer (int): rank of the peer
...@@ -467,7 +482,6 @@ def _p2p_comm(
            group=group,
        )
        ops.append(recv_prev_op)
    if len(ops) > 0:
        reqs = dist.batch_isend_irecv(ops)
        for req in reqs:
...@@ -490,7 +504,6 @@ def _p2p_comm(
            group=group,
        )
        ops.append(send_next_op)
    if tensor_recv_prev is not None:
        recv_prev_op = dist.P2POp(
            dist.irecv,
...@@ -510,7 +523,7 @@ class PipelineP2PCommunication:
    def __init__(self, stage_manager: PipelineStageManager) -> None:
        self.stage_manager = stage_manager

    def recv_forward(self, prev_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None) -> Any:
        """Copy the forward output from the previous stage in pipeline as the input tensor of this stage.

        Args:
...@@ -522,11 +535,16 @@ class PipelineP2PCommunication:
        if prev_rank is None:
            prev_rank = self.stage_manager.get_prev_rank()
        cur_rank = self.stage_manager.get_rank()
        input_tensor = _recv_object(
            prev_rank,
            cur_rank,
            self.stage_manager.get_p2p_process_group(prev_rank, cur_rank),
            metadata_recv=metadata_recv,
        )

        return input_tensor

    def recv_backward(self, next_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None) -> Any:
        """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.

        Args:
...@@ -539,12 +557,15 @@ class PipelineP2PCommunication:
            next_rank = self.stage_manager.get_next_rank()
        cur_rank = self.stage_manager.get_rank()
        output_tensor_grad = _recv_object(
            next_rank,
            cur_rank,
            self.stage_manager.get_p2p_process_group(next_rank, cur_rank),
            metadata_recv=metadata_recv,
        )

        return output_tensor_grad
    def send_forward(self, output_object: Any, next_rank: Optional[int] = None, send_metadata: bool = True) -> None:
        """Sends the input tensor to the next stage in pipeline.

        Args:
...@@ -554,9 +575,15 @@ class PipelineP2PCommunication:
        if next_rank is None:
            next_rank = self.stage_manager.get_next_rank()
        cur_rank = self.stage_manager.get_rank()
        _send_object(
            output_object,
            cur_rank,
            next_rank,
            self.stage_manager.get_p2p_process_group(cur_rank, next_rank),
            send_metadata=send_metadata,
        )

    def send_backward(self, input_object: Any, prev_rank: Optional[int] = None, send_metadata: bool = True) -> None:
        """Sends the gradient tensor to the previous stage in pipeline.

        Args:
...@@ -566,9 +593,22 @@ class PipelineP2PCommunication:
        if prev_rank is None:
            prev_rank = self.stage_manager.get_prev_rank()
        cur_rank = self.stage_manager.get_rank()
        _send_object(
            input_object,
            cur_rank,
            prev_rank,
            self.stage_manager.get_p2p_process_group(cur_rank, prev_rank),
            send_metadata=send_metadata,
        )
    def send_forward_recv_backward(
        self,
        input_object: Any,
        next_rank: Optional[int] = None,
        send_metadata: bool = True,
        metadata_recv: Optional[P2PMetadata] = None,
        send_prior_fallback: Optional[bool] = None,
    ) -> Any:
        """Sends the gradient tensor to and copy the gradient tensor from the next stage in pipeline

        Args:
...@@ -581,11 +621,24 @@ class PipelineP2PCommunication:
        cur_rank = self.stage_manager.get_rank()
        group = self.stage_manager.get_p2p_process_group(cur_rank, next_rank)
        return _communicate(
            input_object,
            next_rank,
            next_rank,
            send_group=group,
            recv_group=group,
            send_metadata=send_metadata,
            metadata_recv=metadata_recv,
            send_prior_fallback=send_prior_fallback,
        )
    def send_backward_recv_forward(
        self,
        input_object: Any,
        prev_rank: Optional[int] = None,
        send_metadata: bool = True,
        metadata_recv: Optional[P2PMetadata] = None,
        send_prior_fallback: Optional[bool] = None,
    ) -> Any:
        """Sends the gradient tensor to and copy the gradient tensor from the previous stage in pipeline

        Args:
...@@ -597,37 +650,23 @@ class PipelineP2PCommunication:
        cur_rank = self.stage_manager.get_rank()
        group = self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)

        return _communicate(
            input_object,
            prev_rank,
            prev_rank,
            send_group=group,
            recv_group=group,
            send_metadata=send_metadata,
            metadata_recv=metadata_recv,
            send_prior_fallback=send_prior_fallback,
        )
    def p2p_communicate(
        self,
        output_object: Any,
        recv_pre: bool,
        next_rank: Optional[int] = None,
        comm_dtype: torch.dtype = torch.float16,
    ) -> None:
        """
        Sends the input tensor to the next stage in pipeline, using `P2Pop` in torch.

...@@ -636,10 +675,14 @@ class PipelineP2PCommunication:
            output_object (Any): Object to be sent.
            next_rank (int, optional): The rank of the recipient of the tensor.
        """
        if next_rank is None:
            next_rank = self.stage_manager.get_next_rank()
        cur_rank = self.stage_manager.get_rank()
        recv_tensor = _p2p_comm(
            output_object,
            recv_pre,
            next_rank,
            self.stage_manager.get_p2p_process_group(cur_rank, next_rank),
            comm_dtype,
        )
        return recv_tensor
from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional, Union

import torch
import torch.cuda
from torch.nn import Module, ModuleList
from torch.utils._pytree import tree_map

from colossalai.accelerator import get_accelerator
from colossalai.interface import OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
from colossalai.pipeline.stage_manager import PipelineStageManager

from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
...@@ -16,18 +16,35 @@ from .base import PipelineSchedule


class InterleavedSchedule(PipelineSchedule):
    def __init__(
        self,
        stage_manager: PipelineStageManager,
        num_model_chunks: int,
        num_microbatch: Optional[int] = None,
        microbatch_size: Optional[int] = None,
        enable_metadata_cache: bool = True,
    ) -> None:
        super().__init__(stage_manager)
        assert (
            num_microbatch is not None or microbatch_size is not None
        ), "Either num_microbatch or microbatch_size should be provided"

        self.comm = PipelineP2PCommunication(stage_manager)
        self.num_microbatch = num_microbatch
        self.microbatch_size = microbatch_size
        self.num_model_chunks = num_model_chunks

        self.batch: Any
        self.batch_size: int
        self.last_batch_size: Optional[int] = None
        self.microbatch_offset: List[int]

        # P2PMeta cache
        self.enable_metadata_cache = enable_metadata_cache
        self.send_tensor_metadata = True
        self.send_grad_metadata = True
        self.tensor_metadata_recv = None
        self.grad_metadata_recv = None
    def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
        """Load a batch from data iterator.

...@@ -39,11 +56,37 @@ class InterleavedSchedule(PipelineSchedule):
        batch = next(data_iter)
        if device is not None:
            batch = tree_map(partial(to_device, device=device), batch)

        self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
        self.batch = batch
        self.batch_size = get_batch_size(batch)

        if self.microbatch_size is None:
            assert self.batch_size % self.num_microbatch == 0, "Batch size should divided by the number of microbatch"
            self.microbatch_size = self.batch_size // self.num_microbatch
        if self.num_microbatch is None:
            assert self.batch_size % self.microbatch_size == 0, "Batch size should divided by the microbatch size"
            self.num_microbatch = self.batch_size // self.microbatch_size

        if not self.forward_only:
            assert self.last_batch_size is None or self.last_batch_size == self.batch_size
            assert self.batch_size == self.microbatch_size * self.num_microbatch
            assert (
                self.num_microbatch % self.stage_manager.num_stages == 0
            ), "Number of microbatch should be an integer multiple of number of pipeline parallel devices"

        if self.forward_only:
            self.num_microbatch = (self.batch_size - 1) // self.microbatch_size + 1
            # NOTE: disable metadata cache when batch size changes (not valid anymore)
            if self.batch_size != self.last_batch_size:
                self.enable_metadata_cache = False
                self.send_tensor_metadata = True
                self.send_grad_metadata = True
                self.tensor_metadata_recv = None
                self.grad_metadata_recv = None

        self.last_batch_size = self.batch_size
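        # Worked example with made-up numbers (editorial, not part of this diff): a global batch of
        # 32 samples with microbatch_size = 4 gives num_microbatch = 8; on 4 pipeline stages the
        # check 8 % 4 == 0 above passes, and with 2 model chunks the schedule runs 8 * 2 = 16
        # microbatch slots per step.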
    def load_micro_batch(self, model_chunk_id: int) -> Any:
        """Load a micro batch from the current batch.

...@@ -54,11 +97,12 @@ class InterleavedSchedule(PipelineSchedule):
        Returns:
            Any: Micro batch.
        """
        assert self.microbatch_offset[model_chunk_id] <= self.batch_size, "Microbatches exhausted"
        micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size)
        self.microbatch_offset[model_chunk_id] += self.microbatch_size
        return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch)
    def get_model_chunk_id(self, microbatch_id: int, is_forward: bool) -> int:
        """Helper method to get the model chunk ID given the iteration number.

        Args:
...@@ -68,38 +112,13 @@ class InterleavedSchedule(PipelineSchedule):
        Returns:
            int: The model chunk idx of the input microbatch_id
        """
        assert microbatch_id < self.num_microbatch * self.num_model_chunks
        microbatch_id_in_group = microbatch_id % (self.stage_manager.num_stages * self.num_model_chunks)
        model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages
        if not is_forward:
            model_chunk_id = self.num_model_chunks - model_chunk_id - 1
        return model_chunk_id
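        # Worked example with made-up numbers (editorial, not part of this diff): with
        # num_stages = 4 and num_model_chunks = 2 the group size is 8, so forward microbatch ids
        # 0-3 map to chunk 0, ids 4-7 to chunk 1, ids 8-11 back to chunk 0, and so on; in the
        # backward direction the order is reversed, so ids 0-3 map to chunk 1 and ids 4-7 to chunk 0.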
    def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Any:
        """Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
        For interleaved 1F1B.

...@@ -111,12 +130,13 @@ class InterleavedSchedule(PipelineSchedule):
        Returns:
            Any: The input tensor or input tensor list.
        """
        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
            if not self.stage_manager.is_first_stage():
                input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
                if self.enable_metadata_cache and self.tensor_metadata_recv is None:
                    self.tensor_metadata_recv = create_send_metadata(input_tensor)

                return input_tensor
    def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Any:
        """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.

...@@ -129,14 +149,15 @@ class InterleavedSchedule(PipelineSchedule):
        Returns:
            Any: The input gradient tensor or gradient tensor list.
        """
        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
            if not self.stage_manager.is_last_stage():
                output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
                if self.enable_metadata_cache and self.grad_metadata_recv is None:
                    self.grad_metadata_recv = create_send_metadata(output_tensor_grad)

                return output_tensor_grad
    def send_forward(self, model_chunk_id: int, output_tensor: Any, next_rank: int = None) -> None:
        """Sends the input tensor to the next stage in pipeline.
        For interleaved 1F1B.

...@@ -145,10 +166,12 @@ class InterleavedSchedule(PipelineSchedule):
            output_object (Any): Object to be sent.
            next_rank (int, optional): The rank of the recipient of the tensor.
        """
        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
            if not self.stage_manager.is_last_stage():
                self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
                self.send_tensor_metadata = not self.enable_metadata_cache
    def send_backward(self, model_chunk_id: int, input_tensor_grad: Any, prev_rank: int = None) -> None:
        """Sends the gradient tensor to the previous stage in pipeline.
        For interleaved 1F1B.

...@@ -157,12 +180,102 @@ class InterleavedSchedule(PipelineSchedule):
            input_object (Any): Object to be sent.
            prev_rank (int, optional): The rank of the recipient of the tensor
        """
        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
            if not self.stage_manager.is_first_stage():
                self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata)
                self.send_grad_metadata = not self.enable_metadata_cache
    def send_forward_recv_backward(
        self,
        model_chunk_id_send: int,
        model_chunk_id_recv: int,
        output_tensor: Any,
        next_rank: Optional[int] = None,
        send_prior_fallback: Optional[bool] = None,
    ) -> Any:
        with self.stage_manager.switch_model_chunk_id(model_chunk_id_send):
            send_data = not self.stage_manager.is_last_stage()
        with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
            recv_data = not self.stage_manager.is_last_stage()

        if send_data and recv_data:
            if not self.send_tensor_metadata and self.grad_metadata_recv is not None:
                send_prior_fallback = None  # must not fallback
            output_tensor_grad = self.comm.send_forward_recv_backward(
                output_tensor,
                next_rank,
                send_metadata=self.send_tensor_metadata,
                metadata_recv=self.grad_metadata_recv,
                send_prior_fallback=send_prior_fallback,
            )
            self.send_tensor_metadata = not self.enable_metadata_cache
            if self.enable_metadata_cache and self.grad_metadata_recv is None:
                self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
            return output_tensor_grad

        # send only or recv only
        self.send_forward(model_chunk_id_send, output_tensor)
        return self.recv_backward(model_chunk_id_recv)
    def send_backward_recv_forward(
        self,
        model_chunk_id_send: int,
        model_chunk_id_recv: int,
        input_tensor_grad: Any,
        prev_rank: Optional[int] = None,
        send_prior_fallback: Optional[bool] = None,
    ) -> Any:
        with self.stage_manager.switch_model_chunk_id(model_chunk_id_send):
            send_data = not self.stage_manager.is_first_stage()
        with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
            recv_data = not self.stage_manager.is_first_stage()

        if send_data and recv_data:
            if not self.send_grad_metadata and self.tensor_metadata_recv is not None:
                send_prior_fallback = None  # must not fallback
            input_tensor = self.comm.send_backward_recv_forward(
                input_tensor_grad,
                prev_rank,
                send_metadata=self.send_grad_metadata,
                metadata_recv=self.tensor_metadata_recv,
                send_prior_fallback=send_prior_fallback,
            )
            self.send_grad_metadata = not self.enable_metadata_cache
            if self.enable_metadata_cache and self.tensor_metadata_recv is None:
                self.tensor_metadata_recv = create_send_metadata(input_tensor)
            return input_tensor

        # send only or recv only
        self.send_backward(model_chunk_id_send, input_tensor_grad)
        return self.recv_forward(model_chunk_id_recv)
    def send_forward_recv_forward(
        self, model_chunk_id_send: int, model_chunk_id_recv: int, output_tensor: Any, send_prior: bool
    ):
        if send_prior:
            self.send_forward(model_chunk_id_send, output_tensor)
            input_tensor = self.recv_forward(model_chunk_id_recv)
        else:
            input_tensor = self.recv_forward(model_chunk_id_recv)
            self.send_forward(model_chunk_id_send, output_tensor)

        return input_tensor

    def send_backward_recv_backward(
        self, model_chunk_id_send: int, model_chunk_id_recv: int, input_tensor_grad: Any, send_prior: bool
    ):
        if send_prior:
            self.send_backward(model_chunk_id_send, input_tensor_grad)
            output_tensor_grad = self.recv_backward(model_chunk_id_recv)
        else:
            output_tensor_grad = self.recv_backward(model_chunk_id_recv)
            self.send_backward(model_chunk_id_send, input_tensor_grad)

        return output_tensor_grad
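    # Editorial sketch (not part of this diff): the send_prior flag above is what keeps two
    # neighbouring stages from blocking on each other. Callers alternate the order by stage parity
    # (send_prior=self.stage_manager.stage % 2 == 0), so an even stage that sends first is always
    # paired with an odd neighbour that receives first:
    #     even stage: send(x); y = recv()          odd stage: y = recv(); send(x)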
    def forward_step(
        self,
        model_chunk: Union[ModuleList, Module],
        model_chunk_id: int,
        input_obj: Optional[dict],
        criterion: Callable,
...@@ -171,7 +284,7 @@ class InterleavedSchedule(PipelineSchedule):
    ) -> Union[torch.Tensor, dict]:
        """Forward one step of the pipeline
        Args:
            model (ModuleList or Module): Model Chunk to be run
            input_obj (Optional[dict]): The output from the previous stage. If it is the first stage, the `input_obj` is None.
            criterion (Callable): Criterion to calculate loss.
            accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
...@@ -184,17 +297,25 @@ class InterleavedSchedule(PipelineSchedule):
        # for the first stage, input_obj is None
        # for the non-first stage, input_obj is the output of the previous stage and it must be a dict
        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
            if isinstance(model_chunk, ModuleList):
                output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj)
            else:
                # NOTE: in shardformer, each device still has the entire model, so we need to use relevant stage layers
                internal_inputs = {} if input_obj is None else input_obj
                internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
                output_obj = model_forward(model_chunk, micro_batch, internal_inputs)

        if self.stage_manager.is_last_stage():
            loss = criterion(output_obj, micro_batch) / self.num_microbatch
            if accum_loss is not None:
                accum_loss.add_(loss.detach())
            if outputs is not None:
                outputs.append(tree_map(detach, output_obj))
            return loss
        else:
            return output_obj
    def backward_step(
        self,
...@@ -241,140 +362,211 @@ class InterleavedSchedule(PipelineSchedule):
                    input_obj_grad[k] = v.grad
        return input_obj_grad
    def run_forward_only(
        self,
        model_chunk: Union[ModuleList, Module],
        data_iter: Iterable,
        criterion: Callable[..., Any],
        return_loss: bool = False,
        return_outputs: bool = False,
    ) -> Dict:
        assert self.forward_only

        self.load_batch(data_iter)

        outputs = [] if return_outputs and self.stage_manager.is_last_stage(ignore_chunk=True) else None

        accum_loss = None
        if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
            accum_loss = torch.scalar_tensor(0, device=get_current_device())

        model_chunk_id = self.get_model_chunk_id(0, is_forward=True)
        input_obj = self.recv_forward(model_chunk_id)

        for i in range(self.num_microbatch * self.num_model_chunks):
            last_iteration = i == self.num_microbatch * self.num_model_chunks - 1
            model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
            output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)

            if not last_iteration:
                input_obj = self.send_forward_recv_forward(
                    model_chunk_id_send=model_chunk_id,
                    model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=True),
                    output_tensor=output_obj,
                    send_prior=self.stage_manager.stage % 2 == 0,
                )
            else:
                self.send_forward(model_chunk_id, output_obj)

        if outputs is not None:
            outputs = merge_batch(outputs)
        return {"loss": accum_loss, "outputs": outputs}
outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None def run_forward_backward(
self,
model_chunk: Union[ModuleList, Module],
data_iter: Iterable,
criterion: Callable[..., Any],
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = False,
return_outputs: bool = False,
) -> Dict:
"""
Runs interleaved schedule, with communication between pipeline stages.
"""
assert not self.forward_only
if return_loss and self.stage_manager.is_last_stage(): self.load_batch(data_iter)
accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
else:
accum_loss = None
# for ranks except the first one, get into recv state num_microbatch = self.num_microbatch * self.num_model_chunks
# print(self.stage_manager.stage,num_microbatches, num_warmup_microbatches, num_microbatches_remaining) num_warmup_microbatch = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2
input_obj = self.recv_forward(0) num_warmup_microbatch += (self.num_model_chunks - 1) * self.stage_manager.num_stages
input_objs[0].append(input_obj) num_warmup_microbatch = min(num_warmup_microbatch, num_microbatch)
# Run warmup forward passes. num_microbatch_remaining = num_microbatch - num_warmup_microbatch
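# NOTE (illustrative example, not in the original source): with num_stages=4, stage=0,
# num_model_chunks=2 and num_microbatch=8, the formula above gives
# (4 - 0 - 1) * 2 + (2 - 1) * 4 = 10 warmup steps, capped at 8 * 2 = 16 total forward steps,
# so this stage runs 10 warmup forward passes before entering the steady 1F1B phase.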
for i in range(num_warmup_microbatches):
model_chunk_id = self.get_model_chunk_id(i, forward=True)
# recv first on first rank to avoid sending or recving at the same time # Input, output tensors only need to be saved when doing backward passes
if self.stage_manager.is_first_stage(): input_objs = [[] for _ in range(self.num_model_chunks)]
input_obj = self.recv_forward(model_chunk_id) output_objs = [[] for _ in range(self.num_model_chunks)]
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
self.send_forward(model_chunk_id, output_obj)
if not forward_only:
input_objs[model_chunk_id].append(input_obj)
output_objs[model_chunk_id].append(output_obj)
else:
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
if not forward_only:
output_objs[model_chunk_id].append(output_obj)
self.send_forward(model_chunk_id, output_obj)
if num_microbatches_remaining == 0 and i + 1 == num_warmup_microbatches:
break
else:
model_chunk_id = self.get_model_chunk_id(i + 1, forward=True)
input_obj = self.recv_forward(model_chunk_id) outputs = [] if return_outputs and self.stage_manager.is_last_stage(ignore_chunk=True) else None
if not forward_only:
input_objs[model_chunk_id].append(input_obj)
# Run 1F1B in steady state. accum_loss = None
for i in range(num_microbatches_remaining): if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches, forward=True) accum_loss = torch.scalar_tensor(0, device=get_current_device())
last_iteration = i == (num_microbatches_remaining - 1)
model_chunk_id = self.get_model_chunk_id(0, is_forward=True)
input_obj = self.recv_forward(model_chunk_id)
# Run warmup forward passes.
for i in range(num_warmup_microbatch):
last_iteration = i == num_warmup_microbatch - 1
model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
if forward_only: input_objs[model_chunk_id].append(input_obj)
self.send_forward(model_chunk_id, output_obj) output_objs[model_chunk_id].append(output_obj)
if not last_iteration: if last_iteration and num_microbatch_remaining == 0:
input_obj = self.recv_forward(model_chunk_id)
else:
self.send_forward(model_chunk_id, output_obj) self.send_forward(model_chunk_id, output_obj)
# Add input_obj and output_obj to end of list. else:
input_objs[model_chunk_id].append(input_obj) input_obj = self.send_forward_recv_forward(
output_objs[model_chunk_id].append(output_obj) model_chunk_id_send=model_chunk_id,
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=True),
model_chunk_id = self.get_model_chunk_id(i, forward=False) output_tensor=output_obj,
output_obj_grad = self.recv_backward(model_chunk_id) send_prior=self.stage_manager.stage % 2 == 0,
)
# Pop output_obj and output_obj from the start of the list for if num_microbatch_remaining > 0:
# the backward pass. model_chunk_id = self.get_model_chunk_id(0, is_forward=False)
input_obj = input_objs[model_chunk_id].pop(0) output_obj_grad = self.recv_backward(model_chunk_id)
output_obj = output_objs[model_chunk_id].pop(0)
# backward # Run 1F1B in steady state.
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) for i in range(num_microbatch_remaining):
last_iteration = i == num_microbatch_remaining - 1
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True)
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
# Add input_obj and output_obj to end of list.
input_objs[model_chunk_id].append(input_obj)
output_objs[model_chunk_id].append(output_obj)
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
# Pop output_obj and output_obj from the start of the list for the backward pass.
_input_obj = input_objs[model_chunk_id].pop(0)
_output_obj = output_objs[model_chunk_id].pop(0)
input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
# NOTE: perform 2x communication for forward and backward
def send_forward_recv_backward():
if last_iteration and num_microbatch == num_microbatch_remaining:
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True)
self.send_forward(model_chunk_id, output_obj)
else:
output_obj_grad = self.send_forward_recv_backward(
model_chunk_id_send=self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True),
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False),
output_tensor=output_obj,
send_prior_fallback=self.stage_manager.stage % 2 == 0,
)
return output_obj_grad
def send_backward_recv_forward():
if last_iteration: if last_iteration:
input_obj = None model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
self.send_backward(model_chunk_id, input_obj_grad)
else: else:
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches + 1, forward=True) input_obj = self.send_backward_recv_forward(
input_obj = self.recv_forward(model_chunk_id) model_chunk_id_send=self.get_model_chunk_id(i, is_forward=False),
model_chunk_id = self.get_model_chunk_id(i, forward=False) model_chunk_id_recv=self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True),
self.send_backward(model_chunk_id, input_obj_grad) input_tensor_grad=input_obj_grad,
send_prior_fallback=self.stage_manager.stage % 2 == 0 and i > 0,
)
return input_obj
if self.stage_manager.stage % 2 == 0:
output_obj_grad = send_forward_recv_backward()
input_obj = send_backward_recv_forward()
else:
input_obj = send_backward_recv_forward()
output_obj_grad = send_forward_recv_backward()
if num_microbatch_remaining == 0:
model_chunk_id = self.get_model_chunk_id(0, is_forward=False)
output_obj_grad = self.recv_backward(model_chunk_id)
# Run cooldown backward passes. # Run cooldown backward passes.
if not forward_only: for i in range(num_microbatch_remaining, num_microbatch):
for i in range(num_microbatches_remaining, num_microbatches): last_iteration = i == num_microbatch - 1
model_chunk_id = self.get_model_chunk_id(i, forward=False) model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
# print(f"{self.stage_manager.stage}/{model_chunk_id}: {len(input_objs[model_chunk_id])} {len(output_objs[model_chunk_id])} {i}") _input_obj = input_objs[model_chunk_id].pop(0)
input_obj = input_objs[model_chunk_id].pop(0) _output_obj = output_objs[model_chunk_id].pop(0)
output_obj = output_objs[model_chunk_id].pop(0) # output_obj_grad = self.recv_backward(model_chunk_id)
input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
output_obj_grad = self.recv_backward(model_chunk_id)
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) if not last_iteration:
output_obj_grad = self.send_backward_recv_backward(
model_chunk_id_send=self.get_model_chunk_id(i, is_forward=False),
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False),
input_tensor_grad=input_obj_grad,
send_prior=self.stage_manager.stage % 2 == 0 and i > num_microbatch_remaining,
)
else:
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
self.send_backward(model_chunk_id, input_obj_grad) self.send_backward(model_chunk_id, input_obj_grad)
assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
if outputs is not None: if outputs is not None:
outputs = merge_batch(outputs) outputs = merge_batch(outputs)
return {"loss": accum_loss, "outputs": outputs} return {"loss": accum_loss, "outputs": outputs}
def forward_backward_step(
self,
model_chunk: Union[ModuleList, Module],
data_iter: Iterable,
criterion: Callable[..., Any],
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = False,
return_outputs: bool = False,
) -> dict:
"""
Args:
model_chunk (ModuleList or Module): Model chunk to be trained. The original interleaved schedule uses a module list, whereas Shardformer passes the whole model together with a layer specification.
data_iter (Iterable): Data iterator.
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
return_loss (bool, optional): Whether to return loss. Defaults to False.
return_outputs (bool, optional): Whether to return model outputs. Defaults to False.
Returns:
dict: A dict with keys: 'loss' and 'outputs'.
"""
self.forward_only = not torch.is_grad_enabled()
if optimizer is None:
assert self.forward_only, "Optimizer should be passed when doing backward."
if self.forward_only:
result = self.run_forward_only(model_chunk, data_iter, criterion, return_loss, return_outputs)
else:
result = self.run_forward_backward(
model_chunk, data_iter, criterion, optimizer, return_loss, return_outputs
)
return result
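# Usage sketch (editor's illustration; `schedule`, `model_chunks`, `loader`, `criterion` and
# `optimizer` are assumed to exist): the schedule picks run_forward_only vs. run_forward_backward
# purely from torch.is_grad_enabled().

# training step: gradients enabled, so the full interleaved 1F1B path runs
train_out = schedule.forward_backward_step(model_chunks, iter(loader), criterion, optimizer, return_loss=True)

# evaluation step: wrapping the call in no_grad makes the schedule take the forward-only path
with torch.no_grad():
    eval_out = schedule.forward_backward_step(model_chunks, iter(loader), criterion, return_loss=True)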
from functools import partial from functools import partial
from typing import Any, Callable, Iterable, List, Optional, Union from typing import Any, Callable, Dict, Iterable, List, Optional, Union
import torch import torch
import torch.cuda import torch.cuda
...@@ -8,7 +8,7 @@ from torch.utils._pytree import tree_map ...@@ -8,7 +8,7 @@ from torch.utils._pytree import tree_map
from colossalai.accelerator import get_accelerator from colossalai.accelerator import get_accelerator
from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from ._utils import ( from ._utils import (
...@@ -30,6 +30,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -30,6 +30,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
stage_manager: PipelineStageManager, stage_manager: PipelineStageManager,
num_microbatches: Optional[int] = None, num_microbatches: Optional[int] = None,
microbatch_size: Optional[int] = None, microbatch_size: Optional[int] = None,
enable_metadata_cache: bool = True,
) -> None: ) -> None:
"""1F1B pipeline schedule. """1F1B pipeline schedule.
...@@ -42,13 +43,21 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -42,13 +43,21 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
assert ( assert (
num_microbatches is not None or microbatch_size is not None num_microbatches is not None or microbatch_size is not None
), "Either num_microbatches or microbatch_size should be provided" ), "Either num_microbatches or microbatch_size should be provided"
self.comm = PipelineP2PCommunication(stage_manager) self.comm = PipelineP2PCommunication(stage_manager)
self.num_microbatches = num_microbatches self.num_microbatches = num_microbatches
self.microbatch_size = microbatch_size self.microbatch_size = microbatch_size
self.batch: Optional[Any] = None self.batch: Optional[Any] = None
self.batch_size: Optional[int] = None self.batch_size: Optional[int] = None
self.last_batch_size: Optional[int] = None
self.microbatch_offset: Optional[int] = None self.microbatch_offset: Optional[int] = None
self._use_microbatch_size = num_microbatches is None
# P2PMeta cache
self.enable_metadata_cache = enable_metadata_cache
self.send_tensor_metadata = True
self.send_grad_metadata = True
self.tensor_metadata_recv = None
self.grad_metadata_recv = None
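# (editor's note, summarizing the cache protocol behind the four fields above)
# - send_tensor_metadata / send_grad_metadata: attach shape/dtype metadata only to the first
#   forward/backward send; once the peer has seen it, later sends skip the extra handshake.
# - tensor_metadata_recv / grad_metadata_recv: after the first receive, remember the peer's
#   metadata so subsequent receives can be posted without waiting for it again.
# The cache is invalidated in load_batch() whenever the batch size changes.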
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
"""Load a batch from data iterator. """Load a batch from data iterator.
...@@ -60,24 +69,45 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -60,24 +69,45 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
batch = next(data_iter) batch = next(data_iter)
if device is not None: if device is not None:
batch = tree_map(partial(to_device, device=device), batch) batch = tree_map(partial(to_device, device=device), batch)
self.microbatch_offset = 0
self.batch = batch self.batch = batch
self.batch_size = get_batch_size(batch) self.batch_size = get_batch_size(batch)
self.microbatch_offset = 0
if self.microbatch_size is None:
    assert self.batch_size % self.num_microbatches == 0, "Batch size should be divisible by the number of microbatches"
    self.microbatch_size = self.batch_size // self.num_microbatches
if self.num_microbatches is None:
    assert self.batch_size % self.microbatch_size == 0, "Batch size should be divisible by the microbatch size"
    self.num_microbatches = self.batch_size // self.microbatch_size
if not self.forward_only:
assert self.last_batch_size is None or self.last_batch_size == self.batch_size
assert self.batch_size == self.microbatch_size * self.num_microbatches
assert (
self.num_microbatches >= self.stage_manager.num_stages
), "Number of microbatch should be larger than number of stages"
if self.forward_only:
self.num_microbatches = (self.batch_size - 1) // self.microbatch_size + 1
# NOTE: disable metadata cache when batch size changes (not valid anymore)
if self.batch_size != self.last_batch_size:
self.enable_metadata_cache = False
self.send_tensor_metadata = True
self.send_grad_metadata = True
self.tensor_metadata_recv = None
self.grad_metadata_recv = None
self.last_batch_size = self.batch_size
def load_micro_batch(self) -> Any: def load_micro_batch(self) -> Any:
"""Load a micro batch from the current batch. """Load a micro batch from the current batch.
Returns: Returns:
Any: Micro batch. Any: Micro batch.
""" """
assert self.microbatch_offset <= self.batch_size, "Microbatches exhausted"
micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size) micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size)
self.microbatch_offset += self.microbatch_size self.microbatch_offset += self.microbatch_size
return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch) return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch)
...@@ -92,12 +122,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -92,12 +122,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
Returns: Returns:
Any: The input tensor or input tensor list. Any: The input tensor or input tensor list.
""" """
if self.stage_manager.is_first_stage(): if not self.stage_manager.is_first_stage():
input_tensor = None input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
else: if self.enable_metadata_cache and self.tensor_metadata_recv is None:
input_tensor = self.comm.recv_forward(prev_rank) self.tensor_metadata_recv = create_send_metadata(input_tensor)
return input_tensor return input_tensor
def recv_backward(self, next_rank: int = None) -> Any: def recv_backward(self, next_rank: int = None) -> Any:
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
...@@ -109,14 +139,14 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -109,14 +139,14 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
Returns: Returns:
Any: The input gradient tensor or gradient tensor list. Any: The input gradient tensor or gradient tensor list.
""" """
if self.stage_manager.is_last_stage(): if not self.stage_manager.is_last_stage():
output_tensor_grad = None output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
else: if self.enable_metadata_cache and self.grad_metadata_recv is None:
output_tensor_grad = self.comm.recv_backward(next_rank) self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
return output_tensor_grad return output_tensor_grad
def send_forward(self, output_object: Any, next_rank: int = None) -> None: def send_forward(self, output_tensor: Any, next_rank: int = None) -> None:
"""Sends the input tensor to the next stage in pipeline. """Sends the input tensor to the next stage in pipeline.
For 1F1B. For 1F1B.
...@@ -125,20 +155,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -125,20 +155,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
next_rank (int, optional): The rank of the recipient of the tensor. next_rank (int, optional): The rank of the recipient of the tensor.
""" """
if not self.stage_manager.is_last_stage(): if not self.stage_manager.is_last_stage():
self.comm.send_forward(output_object, next_rank) self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
self.send_tensor_metadata = not self.enable_metadata_cache
def send_forward_recv_backward(self, output_object: Any, next_rank: int = None) -> Any:
"""Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
For 1F1B.
Args:
output_object (Any): Object to be sent.
next_rank (int, optional): The rank of the recipient of the tensor.
"""
if not self.stage_manager.is_last_stage():
return self.comm.send_forward_recv_backward(output_object, next_rank)
def send_backward(self, input_object: Any, prev_rank: int = None) -> None: def send_backward(self, input_tensor_grad: Any, prev_rank: int = None) -> None:
"""Sends the gradient tensor to the previous stage in pipeline. """Sends the gradient tensor to the previous stage in pipeline.
For 1F1B. For 1F1B.
...@@ -147,34 +167,60 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -147,34 +167,60 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
prev_rank (int, optional): The rank of the recipient of the tensor prev_rank (int, optional): The rank of the recipient of the tensor
""" """
if not self.stage_manager.is_first_stage(): if not self.stage_manager.is_first_stage():
self.comm.send_backward(input_object, prev_rank) self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata)
self.send_grad_metadata = not self.enable_metadata_cache
def send_backward_recv_forward(self, output_object: Any, prev_rank: int = None) -> Any: def send_forward_recv_backward(
"""Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline. self, output_tensor: Any, next_rank: int = None, send_prior_fallback: Optional[bool] = None
) -> Any:
"""Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
For 1F1B. For 1F1B.
Args: Args:
output_object (Any): Object to be sent. output_object (Any): Object to be sent.
prev_rank (int, optional): The rank of the recipient of the tensor. next_rank (int, optional): The rank of the recipient of the tensor.
""" """
if not self.stage_manager.is_first_stage(): if not self.stage_manager.is_last_stage():
return self.comm.send_backward_recv_forward(output_object, prev_rank) if not self.send_tensor_metadata and self.grad_metadata_recv is not None:
send_prior_fallback = None # must not fallback
def send_forward_recv_forward(self, input_object: Any, prev_rank: int = None, next_rank: int = None) -> Any: output_tensor_grad = self.comm.send_forward_recv_backward(
"""Sends the input tensor to the next stage and copy the input tensor from the previous stage in pipeline. output_tensor,
next_rank,
send_metadata=self.send_tensor_metadata,
metadata_recv=self.grad_metadata_recv,
send_prior_fallback=send_prior_fallback,
)
self.send_tensor_metadata = not self.enable_metadata_cache
if self.enable_metadata_cache and self.grad_metadata_recv is None:
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
return output_tensor_grad
def send_backward_recv_forward(
self, input_tensor_grad: Any, prev_rank: int = None, send_prior_fallback: Optional[bool] = None
) -> Any:
"""Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline.
For 1F1B. For 1F1B.
Args: Args:
input_object (Any): Object to be sent. output_object (Any): Object to be sent.
prev_rank (int, optional): The previous rank of the recipient of the tensor. prev_rank (int, optional): The rank of the recipient of the tensor.
next_rank (int, optional): The next rank of the recipient of the tensor.
""" """
if self.stage_manager.is_first_stage(): if not self.stage_manager.is_first_stage():
return self.comm.send_forward(input_object, next_rank) if not self.send_grad_metadata and self.tensor_metadata_recv is not None:
elif self.stage_manager.is_last_stage(): send_prior_fallback = None # must not fallback
return self.comm.recv_forward(prev_rank) input_tensor = self.comm.send_backward_recv_forward(
else: input_tensor_grad,
return self.comm.send_forward_recv_forward(input_object, prev_rank, next_rank) prev_rank,
send_metadata=self.send_grad_metadata,
metadata_recv=self.tensor_metadata_recv,
send_prior_fallback=send_prior_fallback,
)
self.send_grad_metadata = not self.enable_metadata_cache
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
self.tensor_metadata_recv = create_send_metadata(input_tensor)
return input_tensor
def forward_step( def forward_step(
self, self,
...@@ -254,31 +300,50 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -254,31 +300,50 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
input_obj_grad[k] = v.grad input_obj_grad[k] = v.grad
return input_obj_grad return input_obj_grad
def forward_backward_step( def run_forward_only(
self, self,
model: Module, model: Module,
data_iter: Iterable, data_iter: Iterable,
criterion: Callable[..., Any], criterion: Callable[..., Any],
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = False, return_loss: bool = False,
return_outputs: bool = False, return_outputs: bool = False,
) -> dict: ) -> Dict:
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages. """
Runs forward only schedule, with communication between pipeline stages.
"""
assert self.forward_only
Args: self.load_batch(data_iter)
model (Module): Model to be trained.
data_iter (Iterable): Data iterator.
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
return_loss (bool, optional): Whether to return loss. Defaults to False.
return_outputs (bool, optional): Whether to return model outputs. Defaults to False.
Returns: accum_loss = None
dict: A dict with keys: 'loss' and 'outputs'. if return_loss and self.stage_manager.is_last_stage():
accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device())
outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
for _ in range(self.num_microbatches):
input_obj = self.recv_forward()
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
self.send_forward(output_obj)
if outputs is not None:
if isinstance(model, ModelWrapper):
model = model.unwrap()
outputs = merge_batch(outputs, getattr(model, "batch_size_dim", 0))
return {"loss": accum_loss, "outputs": outputs}
def run_forward_backward(
self,
model: Module,
data_iter: Iterable,
criterion: Callable[..., Any],
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = False,
return_outputs: bool = False,
) -> Dict:
""" """
forward_only = not torch.is_grad_enabled() Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
if optimizer is None: """
assert forward_only, "Optimizer should be passed when doing backward." assert not self.forward_only
self.load_batch(data_iter) self.load_batch(data_iter)
...@@ -288,30 +353,20 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -288,30 +353,20 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches
# Input, output tensors only need to be saved when doing backward passes # Input, output tensors only need to be saved when doing backward passes
input_objs = None input_objs, output_objs = [], []
output_objs = None
if not forward_only:
input_objs = []
output_objs = []
outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None accum_loss = None
if return_loss and self.stage_manager.is_last_stage(): if return_loss and self.stage_manager.is_last_stage():
accum_loss = torch.zeros(1, device=get_accelerator().get_current_device()) accum_loss = torch.scalar_tensor(0, device=get_current_device())
else: outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
accum_loss = None
# Run warmup forward passes. # Run warmup forward passes.
for i in range(num_warmup_microbatches): for i in range(num_warmup_microbatches):
input_obj = self.recv_forward() input_obj = self.recv_forward()
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs) output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
self.send_forward(output_obj) self.send_forward(output_obj)
input_objs.append(input_obj)
if not forward_only: output_objs.append(output_obj)
input_objs.append(input_obj)
output_objs.append(output_obj)
# Before running 1F1B, need to receive first forward tensor. # Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to # If all microbatches are run in warmup / cooldown phase, then no need to
...@@ -324,44 +379,72 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -324,44 +379,72 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
last_iteration = i == (num_microbatches_remaining - 1) last_iteration = i == (num_microbatches_remaining - 1)
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs) output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
if forward_only: output_obj_grad = self.send_forward_recv_backward(
self.send_forward(output_obj) output_obj, send_prior_fallback=self.stage_manager.stage % 2 == 0
)
if not last_iteration: # Add input_obj and output_obj to end of list.
input_obj = self.recv_forward() input_objs.append(input_obj)
else: output_objs.append(output_obj)
# TODO adjust here
self.send_forward(output_obj) # Pop output_obj and output_obj from the start of the list for
output_obj_grad = self.recv_backward() # the backward pass.
input_obj = input_objs.pop(0)
# Add input_obj and output_obj to end of list. output_obj = output_objs.pop(0)
input_objs.append(input_obj) input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
output_objs.append(output_obj)
if last_iteration:
# Pop output_obj and output_obj from the start of the list for
# the backward pass.
input_obj = input_objs.pop(0)
output_obj = output_objs.pop(0)
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
if last_iteration:
input_obj = None
else:
input_obj = self.recv_forward()
self.send_backward(input_obj_grad) self.send_backward(input_obj_grad)
else:
input_obj = self.send_backward_recv_forward(
input_obj_grad, send_prior_fallback=self.stage_manager.stage % 2 == 0
)
# Run cooldown backward passes. # Run cooldown backward passes.
if not forward_only: for i in range(num_warmup_microbatches):
for i in range(num_warmup_microbatches): input_obj = input_objs.pop(0)
input_obj = input_objs.pop(0) output_obj = output_objs.pop(0)
output_obj = output_objs.pop(0)
output_obj_grad = self.recv_backward() output_obj_grad = self.recv_backward()
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
self.send_backward(input_obj_grad) self.send_backward(input_obj_grad)
assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
if outputs is not None: if outputs is not None:
if isinstance(model, ModelWrapper): if isinstance(model, ModelWrapper):
model = model.unwrap() model = model.unwrap()
outputs = merge_batch(outputs, getattr(model, "batch_size_dim", 0)) outputs = merge_batch(outputs, getattr(model, "batch_size_dim", 0))
return {"loss": accum_loss, "outputs": outputs} return {"loss": accum_loss, "outputs": outputs}
def forward_backward_step(
self,
model: Module,
data_iter: Iterable,
criterion: Callable[..., Any],
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = False,
return_outputs: bool = False,
) -> dict:
"""
Args:
model (Module): Model to be trained.
data_iter (Iterable): Data iterator.
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
return_loss (bool, optional): Whether to return loss. Defaults to False.
return_outputs (bool, optional): Whether to return model outputs. Defaults to False.
Returns:
dict: Dictionary containing loss and outputs.
"""
self.forward_only = not torch.is_grad_enabled()
if optimizer is None:
assert self.forward_only, "Optimizer should be passed when doing backward."
if self.forward_only:
result = self.run_forward_only(model, data_iter, criterion, return_loss, return_outputs)
else:
result = self.run_forward_backward(model, data_iter, criterion, optimizer, return_loss, return_outputs)
return result
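# Construction sketch (editor's illustration; `pg_mesh` and `PP_AXIS` are assumed to exist
# elsewhere, and the import paths follow the usual colossalai package layout):
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager

stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=PP_AXIS)
schedule = OneForwardOneBackwardSchedule(
    stage_manager,
    num_microbatches=4,          # or pass microbatch_size instead
    enable_metadata_cache=True,  # reuse P2P shape/dtype metadata after the first step
)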
import contextlib
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import torch.distributed as dist import torch.distributed as dist
...@@ -19,7 +20,15 @@ class PipelineStageManager: ...@@ -19,7 +20,15 @@ class PipelineStageManager:
stage (int): The current stage. stage (int): The current stage.
""" """
def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int, is_virtual: bool = False) -> None: def __init__(
self,
pg_mesh: ProcessGroupMesh,
pipeline_axis: int,
enable_interleave: bool = False,
num_model_chunks: int = 1,
) -> None:
assert enable_interleave or num_model_chunks == 1, "num_model_chunks must be 1 when enable_interleave is False"
self.pg_mesh = pg_mesh self.pg_mesh = pg_mesh
self.pipeline_axis = pipeline_axis self.pipeline_axis = pipeline_axis
self.prev_rank: Optional[Tuple[int, ...]] = None self.prev_rank: Optional[Tuple[int, ...]] = None
...@@ -43,29 +52,56 @@ class PipelineStageManager: ...@@ -43,29 +52,56 @@ class PipelineStageManager:
ranks_in_group = self.pg_mesh.get_ranks_in_group(group) ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
self.p2p_groups[tuple(ranks_in_group)] = group self.p2p_groups[tuple(ranks_in_group)] = group
if is_virtual: self.is_interleave = enable_interleave
if enable_interleave:
# use circle p2p communication
# add the process group of the first rank and the last rank # add the process group of the first rank and the last rank
# only used in interleaved pipeline for now
group = self.pg_mesh.get_group_along_axis(self.pipeline_axis, [stages[0], stages[-1]]) group = self.pg_mesh.get_group_along_axis(self.pipeline_axis, [stages[0], stages[-1]])
if self.stage in [stages[0], stages[-1]]: if self.stage in [stages[0], stages[-1]]:
ranks_in_group = self.pg_mesh.get_ranks_in_group(group) ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
self.p2p_groups[tuple(ranks_in_group)] = group self.p2p_groups[tuple(ranks_in_group)] = group
def is_first_stage(self) -> bool: # for interleaved pipeline parallel, each device is responsible for multiple chunk of layers
self.num_model_chunks: int = num_model_chunks
# for shardformer, hold stage indices of model
self.stage_indices: List[Tuple[int, int]]
# for shardformer, hold model chunk id
self.model_chunk_id: Optional[int] = None
def is_first_stage(self, ignore_chunk: bool = False) -> bool:
"""Is the current stage the first stage. """Is the current stage the first stage.
NOTE:
1. if using interleaved pipeline parallel, the first stage is the first chunk of the first device.
2. invoking is_first_stage() with ignore_chunk=True is equivalent to invoking is_first_device()
Returns: Returns:
bool: Whether the current stage is the first stage. bool: Whether the current stage is the first stage.
""" """
return self.stage == 0 assert isinstance(ignore_chunk, bool)
assert not self.is_interleave or (ignore_chunk or self.model_chunk_id is not None)
def is_last_stage(self) -> bool: if not self.is_interleave or ignore_chunk:
return self.stage == 0
else:
return self.stage == 0 and self.model_chunk_id == 0
def is_last_stage(self, ignore_chunk: bool = False) -> bool:
"""Is the current stage the last stage. """Is the current stage the last stage.
NOTE:
1. if using interleaved pipeline parallel, the last stage is the last chunk of the last device.
2. invoking is_last_stage() with ignore_chunk=True is equivalent to invoking is_last_device()
Returns: Returns:
bool: Whether the current stage is the last stage. bool: Whether the current stage is the last stage.
""" """
return self.stage == self.num_stages - 1 assert isinstance(ignore_chunk, bool)
assert not self.is_interleave or (ignore_chunk or self.model_chunk_id is not None)
if not self.is_interleave or ignore_chunk:
return self.stage == self.num_stages - 1
else:
return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1
@property @property
def num_stages(self) -> int: def num_stages(self) -> int:
...@@ -133,3 +169,10 @@ class PipelineStageManager: ...@@ -133,3 +169,10 @@ class PipelineStageManager:
ProcessGroup: Process group of the given stages. ProcessGroup: Process group of the given stages.
""" """
return self.pg_mesh.get_group_along_axis(self.pipeline_axis, stages) return self.pg_mesh.get_group_along_axis(self.pipeline_axis, stages)
@contextlib.contextmanager
def switch_model_chunk_id(self, model_chunk_id: int):
old_model_chunk_id = self.model_chunk_id
self.model_chunk_id = model_chunk_id
yield
self.model_chunk_id = old_model_chunk_id
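# Usage sketch (editor's illustration): with enable_interleave=True, chunk-aware checks only
# make sense while a model chunk id is active.
#
#   with stage_manager.switch_model_chunk_id(model_chunk_id=1):
#       stage_manager.is_last_stage()               # True only on the last device *and* the last chunk
#   stage_manager.is_last_stage(ignore_chunk=True)  # "is this the last device?", chunk ignored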
...@@ -79,9 +79,9 @@ Following are the description `ShardConfig`'s arguments: ...@@ -79,9 +79,9 @@ Following are the description `ShardConfig`'s arguments:
- `enable_sequence_overlap`: Whether to turn on sequence overlap, which overlaps the computation and communication in sequence parallelism. It can only be used when `enable_sequence_parallelism` is True. Defaults to False.
- `enable_all_optimization`: Whether to turn on all optimization tools including `fused normalization`, `flash attention`, `JIT fused operators`, `sequence parallelism` and `sequence overlap`. Defaults to False.
- `extra_kwargs`: A dict to store extra kwargs for ShardFormer.
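For example, a minimal configuration that turns on sequence parallelism together with sequence overlap could look like the sketch below (treat the exact constructor signature as illustrative rather than authoritative):

```python
from colossalai.shardformer import ShardConfig, ShardFormer

# illustrative only: assumes the flags documented above are plain keyword arguments
shard_config = ShardConfig(
    enable_sequence_parallelism=True,
    enable_sequence_overlap=True,  # only valid when enable_sequence_parallelism is True
)
shard_former = ShardFormer(shard_config=shard_config)
```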
### Write your own policy
...@@ -116,17 +116,18 @@ We will follow this roadmap to develop Shardformer:
| model | tensor parallel | pipeline parallel | lazy initialization | xformer | flash attn2 | jit fused operator | fused layernorm | sequence parallel | overlap |
| :------: | :-----: | :-----: | :--------: | :---------: | :------: | :-----: | :-----: | :--------: | :---------: |
| bert | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
| t5 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| llama V1/V2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| gpt2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
| opt | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| bloom | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
| chatglm2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
| vit | [√] | [√] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| whisper | [√] | [√] | [√] | [√] | [√] | [ ] | [√] | [ ] | [ ] |
| sam | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| blip2 | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| falcon | [√] | [√] | [√] | [√] | [√] | [ ] | [√] | [ ] | [ ] |
| roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| albert | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| ernie | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
...@@ -136,6 +137,7 @@ We will follow this roadmap to develop Shardformer:
| swin | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| swin V2 | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| qwen | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| mistral | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
## 💡 API Design ## 💡 API Design
......
...@@ -32,7 +32,7 @@ def set_obj_list_element(obj, attr: str, value): ...@@ -32,7 +32,7 @@ def set_obj_list_element(obj, attr: str, value):
r""" r"""
Set the element to value of a list object Set the element to value of a list object
It is used like set_obj_list_element(obj, 'layers[0]', new_layer); it will set obj.layers[0] to value
Args: Args:
obj (object): The object to set obj (object): The object to set
......
...@@ -7,6 +7,12 @@ try: ...@@ -7,6 +7,12 @@ try:
except: except:
fused_mix_prec_layer_norm_cuda = None fused_mix_prec_layer_norm_cuda = None
try:
import fused_weight_gradient_mlp_cuda
_grad_accum_fusion_available = True
except ImportError:
_grad_accum_fusion_available = False
class FusedLayerNormAffineFunction1D(torch.autograd.Function): class FusedLayerNormAffineFunction1D(torch.autograd.Function):
r"""Layernorm r"""Layernorm
...@@ -141,7 +147,19 @@ class LinearWithAsyncCommunication(torch.autograd.Function): ...@@ -141,7 +147,19 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
# all-reduce scheduled first and have GPU resources allocated # all-reduce scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1 _ = torch.empty(1, device=grad_output.device) + 1
grad_weight = grad_output.t().matmul(total_input) if _grad_accum_fusion_available and weight.grad is not None:
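# If the fused kernel is available and the parameter already owns a .grad buffer, the weight
# gradient is accumulated into weight.grad in-place (fp32/fp16 kernels below) and grad_weight
# is returned as None, skipping the separate matmul-plus-add.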
grad = weight.grad
if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
grad_weight = None
elif grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
else:
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_allreduce: if ctx.async_grad_allreduce:
...@@ -214,7 +232,19 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): ...@@ -214,7 +232,19 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
# reduce-scatter scheduled first and have GPU resources allocated # reduce-scatter scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1 _ = torch.empty(1, device=grad_output.device) + 1
grad_weight = grad_output.t().matmul(total_input) if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad
if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
grad_weight = None
elif grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
else:
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_reduce_scatter: if ctx.async_grad_reduce_scatter:
...@@ -249,7 +279,20 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): ...@@ -249,7 +279,20 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
# calculate gradient # calculate gradient
if len(input_parallel.shape) > 2: if len(input_parallel.shape) > 2:
input_parallel = input_parallel.view(-1, input_parallel.shape[-1]) input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
grad_weight = grad_output.t().matmul(input_parallel)
if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad
if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(input_parallel, grad_output, grad)
grad_weight = None
elif grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(input_parallel, grad_output, grad)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(input_parallel)
else:
grad_weight = grad_output.t().matmul(input_parallel)
# grad_weight = grad_output.t().matmul(input_parallel)
# wait until reduce-scatter finished # wait until reduce-scatter finished
reducescatter_handle.wait() reducescatter_handle.wait()
...@@ -388,7 +431,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): ...@@ -388,7 +431,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
input_parallel = torch.cat(tensor_list, dim=dim).contiguous() input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
# calculate gradient # calculate gradient
if len(input_parallel.shape) > 2: if len(input_parallel.shape) > 2:
input_parallel = input_parallel.view(-1, input_parallel.shape[-1]) input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
grad_weight = input_parallel.t().matmul(grad_output) grad_weight = input_parallel.t().matmul(grad_output)
# wait until reduce-scatter finished # wait until reduce-scatter finished
reducescatter_handle.wait() reducescatter_handle.wait()
...@@ -473,16 +516,17 @@ class _GatherForwardSplitBackward(torch.autograd.Function): ...@@ -473,16 +516,17 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
return _split(grad_output, ctx.dim, ctx.process_group), None, None return _split(grad_output, ctx.dim, ctx.process_group), None, None
class HookParameter(torch.autograd.Function): class HookParameter(torch.autograd.Function):
"""In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm""" """In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm"""
@staticmethod @staticmethod
def forward(ctx, input, weight, bias): def forward(ctx, input, weight, bias):
ctx.save_for_backward(weight, bias) ctx.save_for_backward(weight, bias)
output = input output = input
return output return output
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
weight, bias = ctx.saved_tensors weight, bias = ctx.saved_tensors
...@@ -491,13 +535,12 @@ class HookParameter(torch.autograd.Function): ...@@ -491,13 +535,12 @@ class HookParameter(torch.autograd.Function):
if bias is not None: if bias is not None:
bias = bias.view(bias.shape) bias = bias.view(bias.shape)
return grad_output, None, None return grad_output, None, None
def hook_paramter_in_backward(input, weight=None, bias=None): def hook_paramter_in_backward(input, weight=None, bias=None):
return HookParameter.apply(input, weight, bias) return HookParameter.apply(input, weight, bias)
def _reduce(input_, process_group): def _reduce(input_, process_group):
# skip if only one rank involved # skip if only one rank involved
if dist.get_world_size(process_group) == 1: if dist.get_world_size(process_group) == 1:
...@@ -522,7 +565,7 @@ def _split(input_, dim=-1, process_group=None): ...@@ -522,7 +565,7 @@ def _split(input_, dim=-1, process_group=None):
tensor_list = torch.split(input_, dim_size // world_size, dim=dim) tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
rank = dist.get_rank(process_group) rank = dist.get_rank(process_group)
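# .clone() gives the selected shard its own storage: torch.split returns views, so without the
# copy the returned chunk would keep aliasing (and keeping alive) the full input buffer.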
output = tensor_list[rank].contiguous() output = tensor_list[rank].clone().contiguous()
return output return output
......
...@@ -408,7 +408,7 @@ class Linear1D_Row(ParallelModule): ...@@ -408,7 +408,7 @@ class Linear1D_Row(ParallelModule):
handle.wait() handle.wait()
output = torch.cat(output_parallel_list, dim=-1) output = torch.cat(output_parallel_list, dim=-1)
else: else:
output_parallel = F.linear(input_, self.weight) output_parallel = linear_with_async_comm(input_, self.weight, None, None, False)
if self.seq_parallel: if self.seq_parallel:
output = linear_reducescatter_forward_gather_backward( output = linear_reducescatter_forward_gather_backward(
output_parallel, self.process_group, self.seq_parallel_dim output_parallel, self.process_group, self.seq_parallel_dim
......
...@@ -78,10 +78,13 @@ class DistCrossEntropy(Function): ...@@ -78,10 +78,13 @@ class DistCrossEntropy(Function):
# calculate the loss # calculate the loss
# loss = log(sum(exp(x[i]))) - x[class] # loss = log(sum(exp(x[i]))) - x[class]
loss = torch.where(target == ignore_index, 0.0, torch.log(sum_exp_logits) - pred_logits) loss = torch.where(target == ignore_index, 0.0, torch.log(sum_exp_logits) - pred_logits)
loss = torch.sum(loss).div_(torch.sum(loss != 0.0)) num_non_zero = torch.sum(loss != 0.0)
ctx.inv_num_non_zero = 1.0 / num_non_zero
loss = torch.sum(loss).div_(num_non_zero)
# calculate the softmax # calculate the softmax
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
exp_logits[target == ignore_index] = 0.0
ctx.save_for_backward(exp_logits, mask, masked_target_1d) ctx.save_for_backward(exp_logits, mask, masked_target_1d)
return loss return loss
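# Sanity-check sketch (editor's illustration, single process, no tensor parallelism): the
# per-token loss above, log(sum(exp(x))) - x[class], matches torch's standard cross entropy.
#
#   import torch
#   import torch.nn.functional as F
#
#   logits = torch.randn(4, 10)            # (tokens, vocab)
#   target = torch.randint(0, 10, (4,))
#   manual = (torch.logsumexp(logits, dim=-1) - logits.gather(1, target[:, None]).squeeze(1)).mean()
#   assert torch.allclose(manual, F.cross_entropy(logits, target), atol=1e-5)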
...@@ -89,6 +92,7 @@ class DistCrossEntropy(Function): ...@@ -89,6 +92,7 @@ class DistCrossEntropy(Function):
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
# retrieve the saved tensors # retrieve the saved tensors
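# (editor's note) forward() divided the summed loss by the number of non-ignored tokens, so the
# incoming gradient is rescaled by the cached 1 / num_non_zero factor below.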
grad_output = grad_output * ctx.inv_num_non_zero
exp_logits, mask, masked_target_1d = ctx.saved_tensors exp_logits, mask, masked_target_1d = ctx.saved_tensors
# use exp logits as the input grad # use exp logits as the input grad
...@@ -100,7 +104,7 @@ class DistCrossEntropy(Function): ...@@ -100,7 +104,7 @@ class DistCrossEntropy(Function):
grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update
grad_logits.mul_(grad_output.unsqueeze(dim=-1)) grad_logits.mul_(grad_output.unsqueeze(dim=-1))
return grad_logits, None, None return grad_logits, None, None, None
def cross_entropy_1d( def cross_entropy_1d(
......
...@@ -275,8 +275,8 @@ class FusedRMSNorm(BaseLayerNorm): ...@@ -275,8 +275,8 @@ class FusedRMSNorm(BaseLayerNorm):
) )
LazyInitContext.materialize(module) LazyInitContext.materialize(module)
# to check if it is huggingface LlamaRMSNorm # to check if it is huggingface LlamaRMSNorm or MistralRMSNorm
if module.__class__.__name__ == "LlamaRMSNorm": if module.__class__.__name__ in ["LlamaRMSNorm", "MistralRMSNorm"]:
normalized_shape = module.weight.shape[0] normalized_shape = module.weight.shape[0]
eps = module.variance_epsilon eps = module.variance_epsilon
elementwise_affine = True elementwise_affine = True
......
from typing import List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from transformers.models.falcon.modeling_falcon import (
FalconForCausalLM,
FalconForQuestionAnswering,
FalconForSequenceClassification,
FalconForTokenClassification,
FalconModel,
build_alibi_tensor,
)
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.shard import ShardConfig
def build_falcon_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor:
def build_falcon_alibi_tensor(
self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype
) -> torch.Tensor:
"""
Link to paper: https://arxiv.org/abs/2108.12409. The Alibi tensor is not causal as the original paper mentions; it
relies on a translation invariance of softmax for a quick implementation: with l being a tensor and a being a fixed value,
`softmax(l+a) = softmax(l)`. Based on
https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.

Args:
    attention_mask (`torch.Tensor`):
        Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
    num_heads (`int`, *required*):
        number of heads
    dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
        dtype of the output tensor

Returns:
    Tensor shaped (batch_size * num_heads, 1, max_seq_len)
"""
import math
if dist.is_initialized():
world_size = dist.get_world_size(process_group)
num_heads = num_heads * world_size
batch_size, seq_length = attention_mask.shape
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
base = torch.tensor(
2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
)
powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
slopes = torch.pow(base, powers)
if closest_power_of_2 != num_heads:
extra_base = torch.tensor(
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),
device=attention_mask.device,
dtype=torch.float32,
)
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
extra_powers = torch.arange(
1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32
)
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
# Note: alibi will added to the attention bias that will be applied to the query, key product of attention
# => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
# => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
# => the query_length dimension will then be broadcasted correctly
# This is more or less identical to T5's relative position bias:
# https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
alibi = slopes[..., None] * arange_tensor
if dist.is_initialized():
num_heads_per_rank = int(num_heads / dist.get_world_size(process_group))
offset = dist.get_rank(process_group) * num_heads_per_rank
alibi = alibi.view(batch_size, num_heads, 1, seq_length)
alibi = alibi[:, offset : num_heads_per_rank + offset, :, :]
return alibi.reshape(batch_size * num_heads_per_rank, 1, seq_length).to(dtype)
else:
return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
return build_falcon_alibi_tensor
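# Slope sanity check (editor's illustration): for 8 heads the construction above reduces to the
# geometric sequence 2^-1, 2^-2, ..., 2^-8; in general base = 2^(-8/num_heads) when num_heads is
# a power of two.
#
#   import math, torch
#   num_heads = 8
#   closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))   # 8
#   base = 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3)))   # 0.5
#   powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.float32)
#   slopes = torch.tensor(base) ** powers                        # tensor([0.5, 0.25, ..., 2**-8])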
def get_tp_falcon_decoder_layer_forward():
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, dropout_add
def forward(
self: FalconDecoderLayer,
hidden_states: torch.Tensor,
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
residual = hidden_states
if self.config.new_decoder_architecture:
attention_layernorm_out = self.ln_attn(hidden_states)
mlp_layernorm_out = self.ln_mlp(hidden_states)
else:
attention_layernorm_out = self.input_layernorm(hidden_states)
# Self attention.
attn_outputs = self.self_attention(
attention_layernorm_out,
layer_past=layer_past,
attention_mask=attention_mask,
alibi=alibi,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
attention_output = attn_outputs[0]
if not self.config.new_decoder_architecture:
if self.config.parallel_attn:
mlp_layernorm_out = attention_layernorm_out
else:
residual = dropout_add(
attention_output, residual, self.config.attention_dropout, training=self.training
)
mlp_layernorm_out = self.post_attention_layernorm(residual)
outputs = attn_outputs[1:]
# MLP.
mlp_output = self.mlp(mlp_layernorm_out)
if self.config.new_decoder_architecture or self.config.parallel_attn:
mlp_output = mlp_output + attention_output
output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)
if use_cache:
outputs = (output,) + outputs
else:
outputs = (output,) + outputs[1:]
return outputs # hidden_states, present, attentions
return forward
def get_falcon_flash_attention_forward():
try:
from xformers.ops import memory_efficient_attention as me_attention
except:
raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
from transformers.models.falcon.modeling_falcon import FalconAttention
def forward(
self: FalconAttention,
hidden_states: torch.Tensor,
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
batch_size, query_length, _, _ = query_layer.shape
query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim)
key_layer = key_layer.transpose(1, 2).reshape(
batch_size * num_kv_heads,
query_length,
self.head_dim,
)
value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)
past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
if layer_past is not None:
past_key, past_value = layer_past
# concatenate along seq_length dimension:
# - key: [batch_size * self.num_heads, kv_length, head_dim]
# - value: [batch_size * self.num_heads, kv_length, head_dim]
key_layer = torch.cat((past_key, key_layer), dim=1)
value_layer = torch.cat((past_value, value_layer), dim=1)
_, kv_length, _ = key_layer.shape
if use_cache:
present = (key_layer, value_layer)
else:
present = None
attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim).transpose(1, 2).contiguous()
key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim).transpose(1, 2).contiguous()
value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim).transpose(1, 2).contiguous()
if alibi is not None:
attention_mask_float = (
attention_mask_float + alibi.view(batch_size, self.num_heads, 1, kv_length) * self.beta
)
batch_size, src_len = query_layer_.size()[0], query_layer_.size()[1]
tgt_len = key_layer_.size()[1]
attention_mask_float = attention_mask_float.expand(batch_size, self.num_heads, src_len, tgt_len).contiguous()
context_layer = me_attention(
query_layer_,
key_layer_,
value_layer_,
attn_bias=attention_mask_float,
scale=self.inv_norm_factor,
p=self.attention_dropout.p,
)
batch_size, seq_length, _, _ = context_layer.shape
context_layer = context_layer.reshape(batch_size, seq_length, -1)
output_tensor = self.dense(context_layer)
return output_tensor, present
return forward
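# --- Editorial sketch (not part of the upstream file) ------------------------------------
# The additive-bias convention used above before calling memory_efficient_attention:
# boolean mask positions that are True (i.e. masked out) are pushed to a large negative
# value so they contribute ~0 probability after softmax. Toy shapes only.
def _mask_to_bias_sketch():
    import torch

    attention_mask = torch.tensor([[False, False, True]])  # True = masked position
    bias = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(torch.float32)
    return bias  # 0.0 for visible positions, -1e9 for masked ones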
class FalconPipelineForwards:
"""
This class serves as a micro library for falcon pipeline forwards.
"""
@staticmethod
def falcon_model_forward(
self: FalconModel,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
logger = logging.get_logger(__name__)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
if use_cache:
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
use_cache = False
if past_key_values is not None:
logger.warning_once("past_key_values is not supported for pipeline models at the moment.")
past_key_values = None
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if past_key_values is None:
past_key_values = tuple([None] * len(self.h))
else:
past_key_values = self._convert_to_rw_cache(past_key_values)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape batch_size x num_heads x N x N
# head_mask has shape n_layer x batch x num_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
# case: First stage of training
if stage_manager.is_first_stage():
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
hidden_states = inputs_embeds
else:
input_shape = hidden_states.shape[:-1]
batch_size, seq_length = input_shape
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
# Compute alibi tensor: check build_alibi_tensor documentation
past_key_values_length = 0
if past_key_values[0] is not None:
past_key_values_length = past_key_values[0][0].shape[1] # 1 because RW-cache, not standard format
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=hidden_states.device)
else:
attention_mask = attention_mask.to(hidden_states.device)
if self.use_alibi:
alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
else:
alibi = None
causal_mask = self._prepare_attn_mask(
attention_mask,
input_shape=(batch_size, seq_length),
past_key_values_length=past_key_values_length,
)
start_idx, end_idx = stage_index[0], stage_index[1]
for i, (block, layer_past) in enumerate(
zip(self.h[start_idx:end_idx], past_key_values[start_idx:end_idx]), start=start_idx
):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
alibi,
causal_mask,
head_mask[i],
)
else:
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=causal_mask,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
alibi=alibi,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
if stage_manager.is_last_stage():
# Add last hidden state
hidden_states = self.ln_f(hidden_states)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if presents is not None:
presents = self._convert_cache_to_standard_format(presents, batch_size)
if stage_manager.is_last_stage():
if not return_dict:
return tuple(
v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
else:
# always return dict for intermediate stage
return {"hidden_states": hidden_states}
@staticmethod
def falcon_for_causal_lm_forward(
self: FalconForCausalLM,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
"""
logger = logging.get_logger(__name__)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if output_attentions:
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
transformer_outputs = FalconPipelineForwards.falcon_model_forward(
self.transformer,
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
)
past_key_values = None
if stage_manager.is_last_stage():
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
batch_size, seq_length, vocab_size = shift_logits.shape
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
)
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
else:
hidden_states = transformer_outputs.get("hidden_states")
return {"hidden_states": hidden_states}
@staticmethod
def falcon_for_sequence_classification_forward(
self: FalconForSequenceClassification,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
logger = logging.get_logger(__name__)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if output_attentions:
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
transformer_outputs = FalconPipelineForwards.falcon_model_forward(
self.transformer,
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
)
past_key_values = None
if stage_manager.is_last_stage():
batch_size = hidden_states.shape[0]
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(dim=-1) - 1).to(logits.device)
else:
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(pooled_logits, labels)
if not return_dict:
output = (pooled_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
else:
hidden_states = transformer_outputs.get("hidden_states")
return {"hidden_states": hidden_states}
@staticmethod
def falcon_for_token_classification_forward(
self: FalconForTokenClassification,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
logger = logging.get_logger(__name__)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if output_attentions:
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
transformer_outputs = FalconPipelineForwards.falcon_model_forward(
self.transformer,
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
)
past_key_values = None
if stage_manager.is_last_stage():
hidden_states = transformer_outputs[0]
hidden_states = self.dropout(hidden_states)
logits = self.classifier(hidden_states)
loss = None
if labels is not None:
batch_size, seq_length = labels.shape
loss_fct = CrossEntropyLoss()
loss = loss_fct(
logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
)
if not return_dict:
output = (logits,) + transformer_outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
else:
hidden_states = transformer_outputs.get("hidden_states")
return {"hidden_states": hidden_states}
@staticmethod
def falcon_for_question_answering_forward(
self: FalconForQuestionAnswering,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
) -> Union[Tuple, QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
"""
logger = logging.get_logger(__name__)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if output_attentions:
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
outputs = FalconPipelineForwards.falcon_model_forward(
self.transformer,
input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
)
if stage_manager.is_last_stage():
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, the split may add an extra dimension; squeeze it away
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1)
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
if not return_dict:
output = (start_logits, end_logits) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
else:
hidden_states = outputs.get("hidden_states")
return {"hidden_states": hidden_states}
from typing import Dict, List, Optional, Tuple, Union
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
)
from transformers.models.gptj.modeling_gptj import (
GPTJForCausalLM,
GPTJForQuestionAnswering,
GPTJForSequenceClassification,
GPTJModel,
apply_rotary_pos_emb,
get_embed_positions,
)
from transformers.utils import is_torch_fx_proxy, logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
from colossalai.shardformer.shard import ShardConfig
class GPTJPipelineForwards:
"""
This class serves as a micro library for forward function substitution of GPTJ models
under pipeline setting.
"""
@staticmethod
def gptj_model_forward(
self: GPTJModel,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
) -> Union[Dict, Tuple, BaseModelOutputWithPast]:
# This function is modified on the basis of transformers.models.gptj.modeling_gptj.GPTJModel.forward.
# Please refer to original code of transformers for more details.
# GPTJ has no cross attention in comparison to GPT2
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
logger = logging.get_logger(__name__)
# Preprocess passed in arguments
# TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if past_key_values:
logger.warning_once("Non-empty past_key_values is not supported for pipeline models at the moment.")
past_key_values = None
if output_attentions:
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
if use_cache:
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
use_cache = False
if stage_manager.is_first_stage():
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
input_shape = input_ids.size()
input_ids = input_ids.view(-1, seq_length)
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
batch_size = inputs_embeds.shape[0]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, seq_length)
else:
if hidden_states is None:
raise ValueError("hidden_states shouldn't be None for stages other than the first stage.")
input_shape = hidden_states.size()[:-1]
batch_size, seq_length = input_shape[0], input_shape[1]
device = hidden_states.device
# Attention mask.
if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
attention_mask = attention_mask.view(batch_size, -1)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is simpler than the triangular masking of causal attention
# used in OpenAI GPT; we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, None, None, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and the dtype's smallest value for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x num_attention_heads x N x N
# head_mask has shape n_layer x batch x num_attention_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
# position ids need to be assigned for all stages (not just the first) as attention input
if position_ids is not None:
position_ids = position_ids.view(-1, seq_length)
else:
position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
if stage_manager.is_first_stage():
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
hidden_states = inputs_embeds
if token_type_ids is not None:
token_type_embeds = self.wte(token_type_ids)
hidden_states = hidden_states + token_type_embeds
hidden_states = self.drop(hidden_states)
output_shape = input_shape + (hidden_states.size(-1),)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
# split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
if shard_config.enable_sequence_parallelism:
hidden_states = split_forward_gather_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
)
# Going through held blocks.
start_idx, end_idx = stage_index[0], stage_index[1]
for i in range(start_idx, end_idx):
block = self.h[i]
torch.cuda.set_device(hidden_states.device)
# Ensure that attention_mask is always on the same device as hidden_states
if attention_mask is not None:
attention_mask = attention_mask.to(hidden_states.device)
if isinstance(head_mask, torch.Tensor):
head_mask = head_mask.to(hidden_states.device)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache, output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
None,
attention_mask,
position_ids,
head_mask[i],
)
else:
outputs = block(
hidden_states=hidden_states,
layer_past=None,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
# When sequence parallelism is enabled, gather the output tensor in forward and split it in backward
if shard_config.enable_sequence_parallelism:
hidden_states = gather_forward_split_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
)
if stage_manager.is_last_stage():
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape)
# Add last hidden state
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if stage_manager.is_last_stage():
if not return_dict:
return tuple(
v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
else:
# always return dict for intermediate stage
return {"hidden_states": hidden_states}
@staticmethod
def gptj_causallm_model_forward(
self: GPTJForCausalLM,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
) -> Union[Dict, Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
# This function is modified on the basis of transformers.models.gptj.modeling_gptj.GPTJForCausalLM.forward.
# Please refer to original code of transformers for more details.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = GPTJPipelineForwards.gptj_model_forward(
self.transformer,
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
)
# If not at the last stage, return hidden_states as in GPTJModel
if not stage_manager.is_last_stage():
return {"hidden_states": transformer_outputs["hidden_states"]}
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
loss = loss.to(hidden_states.dtype)
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
@staticmethod
def gptj_for_sequence_classification_forward(
self: GPTJForSequenceClassification,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
) -> Union[Dict, Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
# This function is modified on the basis of transformers.models.gptj.modeling_gptj.GPTJForSequenceClassification.forward.
# Please refer to original code of transformers for more details.
"""
logger = logging.get_logger(__name__)
if input_ids is not None:
batch_size, _ = input_ids.shape[:2]
else:
batch_size, _ = hidden_states.shape[:2]
assert (
self.config.pad_token_id is not None or batch_size == 1
), "Cannot handle batch sizes > 1 if no padding token is defined."
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = GPTJPipelineForwards.gptj_model_forward(
self.transformer,
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
)
# If not at the last stage, return hidden_states as in GPTJModel
if not stage_manager.is_last_stage():
return {"hidden_states": transformer_outputs["hidden_states"]}
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
else:
sequence_lengths = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
loss = None
if labels is not None:
labels = labels.to(pooled_logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(pooled_logits, labels)
if not return_dict:
output = (pooled_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
@staticmethod
def gptj_for_question_answering_forward(
self: GPTJForQuestionAnswering,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
) -> Union[Dict, Tuple, QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
# This function is modified on the basis of transformers.models.gptj.modeling_gptj.GPTJForQuestionAnswering.forward.
# Please refer to original code of transformers for more details.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = GPTJPipelineForwards.gptj_model_forward(
self.transformer,
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
)
# If not at the last stage, return hidden_states as in GPTJModel
if not stage_manager.is_last_stage():
return {"hidden_states": outputs["hidden_states"]}
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, the split may add an extra dimension; squeeze it away
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1).to(start_logits.device)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1).to(end_logits.device)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1)
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
if not return_dict:
output = (start_logits, end_logits) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
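# --- Editorial sketch (not part of the upstream file) ------------------------------------
# The last-token pooling used by the Falcon and GPTJ sequence-classification forwards above:
# when a pad token id is configured, the logits of the last non-pad token of each sequence
# are selected as the pooled logits. Values below are toy assumptions.
def _last_token_pooling_sketch():
    import torch

    pad_token_id = 0  # assumed pad id
    input_ids = torch.tensor([[5, 6, 7, 0, 0], [3, 4, 0, 0, 0]])
    logits = torch.randn(2, 5, 3)  # [batch, seq, num_labels]
    sequence_lengths = torch.ne(input_ids, pad_token_id).sum(dim=-1) - 1
    pooled_logits = logits[torch.arange(2), sequence_lengths]
    return pooled_logits  # shape [batch, num_labels]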
def get_gptj_flash_attention_forward():
from transformers.models.gptj.modeling_gptj import GPTJAttention
from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
def split_heads(tensor, num_attention_heads, attn_head_size, rotary):
"""
Splits hidden dim into attn_head_size and num_attention_heads
"""
new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
tensor = tensor.view(new_shape)
if rotary or len(tensor.shape) in [4, 5]:
return tensor
else:
raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
def forward(
self: GPTJAttention,
hidden_states: torch.FloatTensor,
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
) -> Union[
Tuple[torch.Tensor, Tuple[torch.Tensor]],
Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
]:
query = self.q_proj(hidden_states)
key = self.k_proj(hidden_states)
value = self.v_proj(hidden_states)
query = split_heads(query, self.num_attention_heads, self.head_dim, True)
key = split_heads(key, self.num_attention_heads, self.head_dim, True)
value = split_heads(value, self.num_attention_heads, self.head_dim, False)
if is_torch_fx_proxy(position_ids) or torch.jit.is_tracing():
# The logic to conditionally copy to GPU could not be traced, so we do this
# every time in the torch.fx case
embed_positions = get_embed_positions(self.embed_positions, position_ids)
else:
embed_positions = self._get_embed_positions(position_ids)
repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
sincos = torch.gather(embed_positions, 1, repeated_position_ids)
sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
if self.rotary_dim is not None:
k_rot = key[:, :, :, : self.rotary_dim]
k_pass = key[:, :, :, self.rotary_dim :]
q_rot = query[:, :, :, : self.rotary_dim]
q_pass = query[:, :, :, self.rotary_dim :]
k_rot = apply_rotary_pos_emb(k_rot, sin, cos)
q_rot = apply_rotary_pos_emb(q_rot, sin, cos)
key = torch.cat([k_rot, k_pass], dim=-1)
query = torch.cat([q_rot, q_pass], dim=-1)
else:
key = apply_rotary_pos_emb(key, sin, cos)
query = apply_rotary_pos_emb(query, sin, cos)
# key = key.permute(0, 2, 1, 3)
# query = query.permute(0, 2, 1, 3)
key = key.to(dtype=value.dtype)  # fp16 compatibility
query = query.to(dtype=value.dtype)
if layer_past is not None:
past_key = layer_past[0]
past_value = layer_past[1]
key = torch.cat((past_key, key), dim=1)
value = torch.cat((past_value, value), dim=1)
if use_cache is True:
present = (key, value)
else:
present = None
# use AttnMaskType and ColoAttention
attn_mask_type = AttnMaskType.causal
flash_attention_mask = None
if attention_mask is not None:
    if attn_mask_type == AttnMaskType.causal:
        attn_mask_type = AttnMaskType.paddedcausal
    else:
        attn_mask_type = AttnMaskType.padding
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
# use coloattention
scale = value.size(-1) ** -0.5
attention = ColoAttention(
embed_dim=self.embed_dim, num_heads=self.num_attention_heads, dropout=self.attn_dropout.p, scale=scale
)
attn_output = attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type)
attn_output = self.out_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
outputs = (attn_output, present, None)
return outputs # a, present, (attentions)
return forward
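# --- Editorial sketch (not part of the upstream file) ------------------------------------
# The partial rotary embedding handled above: when rotary_dim is set, only the first
# rotary_dim channels of each head are rotated and the remainder passes through unchanged.
# Sizes are toy assumptions; the actual rotation is left as a comment.
def _partial_rotary_sketch():
    import torch

    head_dim, rotary_dim = 8, 4  # assumed sizes
    query = torch.randn(1, 3, 2, head_dim)  # [batch, seq, heads, head_dim]
    q_rot, q_pass = query[..., :rotary_dim], query[..., rotary_dim:]
    # q_rot would be rotated by apply_rotary_pos_emb(q_rot, sin, cos) here; q_pass is untouched
    return torch.cat([q_rot, q_pass], dim=-1).shape  # back to [1, 3, 2, head_dim]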
def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
batch_size = input_ids.shape[0]
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
batch_size = inputs_embeds.shape[0]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
if position_ids is not None:
position_ids = position_ids.view(-1, input_shape[-1]).long()
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
# Attention mask.
if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
attention_mask = attention_mask.view(batch_size, -1)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is simpler than the triangular masking of causal attention
# used in OpenAI GPT; we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, None, None, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and the dtype's smallest value for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x num_attention_heads x N x N
# head_mask has shape n_layer x batch x num_attention_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
hidden_states = inputs_embeds
if token_type_ids is not None:
token_type_embeds = self.wte(token_type_ids)
hidden_states = hidden_states + token_type_embeds
hidden_states = self.drop(hidden_states)
output_shape = input_shape + (hidden_states.size(-1),)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
# split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
hidden_states = split_forward_gather_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
# Model parallel
if self.model_parallel:
torch.cuda.set_device(hidden_states.device)
# Ensure layer_past is on same device as hidden_states (might not be correct)
if layer_past is not None:
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
# Ensure that attention_mask is always on the same device as hidden_states
if attention_mask is not None:
attention_mask = attention_mask.to(hidden_states.device)
if isinstance(head_mask, torch.Tensor):
head_mask = head_mask.to(hidden_states.device)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache, output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
None,
attention_mask,
position_ids,
head_mask[i],
)
else:
outputs = block(
hidden_states=hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
# Model Parallel: If it's the last layer for that device, put things on the next device
if self.model_parallel:
for k, v in self.device_map.items():
if i == v[-1] and "cuda:" + str(k) != self.last_device:
hidden_states = hidden_states.to("cuda:" + str(k + 1))
# When sequence parallelism is enabled, gather the output tensor in forward and split it in backward
hidden_states = gather_forward_split_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
)
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape)
# Add last hidden state
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
return forward
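# --- Editorial sketch (not part of the upstream file) ------------------------------------
# Shape bookkeeping behind the sequence-parallel split/gather used above: activations are
# chunked along the sequence dimension across TP ranks before the blocks and re-assembled
# afterwards. Plain torch ops stand in for split_forward_gather_backward /
# gather_forward_split_backward, which additionally handle the mirrored backward pass.
def _sequence_parallel_shapes_sketch():
    import torch

    tp_size, rank = 4, 1  # assumed values
    hidden_states = torch.randn(2, 16, 32)  # [batch, seq_len, hidden]
    local_chunk = hidden_states.chunk(tp_size, dim=1)[rank]  # [2, 4, 32] held by this rank
    # ... the transformer blocks run on the local chunk here ...
    gathered_shape = (local_chunk.shape[0], local_chunk.shape[1] * tp_size, local_chunk.shape[2])
    return local_chunk.shape, gathered_shape  # (torch.Size([2, 4, 32]), (2, 16, 32))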
@@ -2,6 +2,7 @@ import warnings
 from typing import List, Optional, Tuple

 import torch
+import torch.nn.functional as F
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
@@ -12,6 +13,9 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForS
 from transformers.utils import logging

 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.shard import ShardConfig
+
+from ..layer import cross_entropy_1d

 try:
     from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
@@ -42,6 +46,7 @@ class LlamaPipelineForwards:
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
+        shard_config: ShardConfig = None,
     ):
         logger = logging.get_logger(__name__)
@@ -200,6 +205,7 @@ class LlamaPipelineForwards:
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
+        shard_config: ShardConfig = None,
     ):
         r"""
         Args:
@@ -269,11 +275,18 @@ class LlamaPipelineForwards:
             shift_labels = labels[..., 1:].contiguous()
             # Flatten the tokens
             loss_fct = CrossEntropyLoss()
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
             shift_labels = shift_labels.view(-1)
             # Enable model parallelism
             shift_labels = shift_labels.to(shift_logits.device)
-            loss = loss_fct(shift_logits, shift_labels)
+            if shard_config.enable_tensor_parallelism:
+                new_vocab_size = logits.shape[-1]
+                shift_logits = shift_logits.view(-1, new_vocab_size)
+                loss = cross_entropy_1d(
+                    shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
+                )
+            else:
+                shift_logits = shift_logits.view(-1, self.config.vocab_size)
+                loss = loss_fct(shift_logits, shift_labels)

         if not return_dict:
             output = (logits,) + outputs[1:]
@@ -306,6 +319,7 @@ class LlamaPipelineForwards:
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
+        shard_config: ShardConfig = None,
     ):
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -403,7 +417,7 @@ class LlamaPipelineForwards:
         return {"hidden_states": hidden_states}

-def get_llama_flash_attention_forward():
+def get_llama_flash_attention_forward(shard_config: ShardConfig):
     from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb

     from colossalai.kernel import AttnMaskType, ColoAttention
@@ -459,14 +473,13 @@ def get_llama_flash_attention_forward():
         flash_attention_mask = None
         attn_mask_type = AttnMaskType.causal
-        if attention_mask != None:
+        if not getattr(shard_config, "causal_lm", False) and attention_mask != None:
             if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                 raise ValueError(
                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                 )
             flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
-            if not torch.all(flash_attention_mask):
-                attn_mask_type = AttnMaskType.paddedcausal
+            attn_mask_type = AttnMaskType.paddedcausal

         attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
         attn_output = attention(
@@ -483,3 +496,108 @@ def get_llama_flash_attention_forward():
         return attn_output, None, past_key_value

     return forward
def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
from transformers import LlamaForCausalLM
def forward(
self: LlamaForCausalLM,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, LlamaForCausalLM
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
if self.config.pretraining_tp > 1:
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
logits = torch.cat(logits, dim=-1)
else:
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
if shard_config.enable_tensor_parallelism:
new_vocab_size = logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
loss = cross_entropy_1d(
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
)
else:
shift_logits = shift_logits.view(-1, self.config.vocab_size)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
return forward
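Note: outside tensor parallelism, the branch above reduces to the standard shift-by-one causal LM loss. A minimal standalone sketch of that shape handling (plain PyTorch only; no `cross_entropy_1d` or `shard_config`):

```python
import torch
import torch.nn.functional as F

batch, seq_len, vocab = 2, 8, 32
logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))

# Token at position i is trained to predict the label at position i + 1.
shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab)
shift_labels = labels[..., 1:].contiguous().view(-1)

loss = F.cross_entropy(shift_logits, shift_labels)
print(loss.item())
```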
from typing import Optional, Tuple
import torch
def get_mistral_flash_attention_forward():
from transformers.models.mistral.modeling_mistral import MistralAttention, apply_rotary_pos_emb, repeat_kv
from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
def forward(
self: MistralAttention,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = (
self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
)
value_states = (
self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
me_input_shape = (bsz, q_len, self.num_heads, self.head_dim)
query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape)
key_states = key_states.transpose(1, 2).contiguous().view(*me_input_shape)
value_states = value_states.transpose(1, 2).contiguous().view(*me_input_shape)
flash_attention_mask = None
attn_mask_type = AttnMaskType.causal
if attention_mask != None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
attn_mask_type = AttnMaskType.paddedcausal
attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
attn_output = attention(
query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type
)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
return forward
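The Mistral forward above calls `repeat_kv` to expand the grouped-query key/value heads so they match the number of query heads before `ColoAttention` runs. A standalone sketch of that expansion (illustrative re-implementation, not the `transformers` helper itself):

```python
import torch

def repeat_kv_sketch(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # x: (batch, num_kv_heads, seq_len, head_dim)
    # returns: (batch, num_kv_heads * n_rep, seq_len, head_dim)
    bsz, num_kv_heads, seq_len, head_dim = x.shape
    if n_rep == 1:
        return x
    x = x[:, :, None, :, :].expand(bsz, num_kv_heads, n_rep, seq_len, head_dim)
    return x.reshape(bsz, num_kv_heads * n_rep, seq_len, head_dim)

kv = torch.randn(1, 2, 4, 8)              # 2 KV heads
print(repeat_kv_sketch(kv, 4).shape)      # torch.Size([1, 8, 4, 8]) -> matches 8 query heads
```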
...@@ -85,6 +85,17 @@ _POLICY_LIST = { ...@@ -85,6 +85,17 @@ _POLICY_LIST = {
"transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification": PolicyLocation( "transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification": PolicyLocation(
file_name="gpt2", class_name="GPT2ForSequenceClassificationPolicy" file_name="gpt2", class_name="GPT2ForSequenceClassificationPolicy"
), ),
# GPTJ
"transformers.models.gptj.modeling_gptj.GPTJModel": PolicyLocation(file_name="gptj", class_name="GPTJModelPolicy"),
"transformers.models.gptj.modeling_gptj.GPTJForCausalLM": PolicyLocation(
file_name="gptj", class_name="GPTJForCausalLMPolicy"
),
"transformers.models.gptj.modeling_gptj.GPTJForQuestionAnswering": PolicyLocation(
file_name="gptj", class_name="GPTJForQuestionAnsweringPolicy"
),
"transformers.models.gptj.modeling_gptj.GPTJForSequenceClassification": PolicyLocation(
file_name="gptj", class_name="GPTJForSequenceClassificationPolicy"
),
# ViT # ViT
"transformers.models.vit.modeling_vit.ViTModel": PolicyLocation(file_name="vit", class_name="ViTModelPolicy"), "transformers.models.vit.modeling_vit.ViTModel": PolicyLocation(file_name="vit", class_name="ViTModelPolicy"),
"transformers.models.vit.modeling_vit.ViTForImageClassification": PolicyLocation( "transformers.models.vit.modeling_vit.ViTForImageClassification": PolicyLocation(
...@@ -146,6 +157,31 @@ _POLICY_LIST = { ...@@ -146,6 +157,31 @@ _POLICY_LIST = {
"colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation( "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy" file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"
), ),
# Falcon
"transformers.models.falcon.modeling_falcon.FalconModel": PolicyLocation(
file_name="falcon", class_name="FalconModelPolicy"
),
"transformers.models.falcon.modeling_falcon.FalconForCausalLM": PolicyLocation(
file_name="falcon", class_name="FalconForCausalLMPolicy"
),
"transformers.models.falcon.modeling_falcon.FalconForSequenceClassification": PolicyLocation(
file_name="falcon", class_name="FalconForSequenceClassificationPolicy"
),
"transformers.models.falcon.modeling_falcon.FalconForTokenClassification": PolicyLocation(
file_name="falcon", class_name="FalconForTokenClassificationPolicy"
),
"transformers.models.falcon.modeling_falcon.FalconForQuestionAnswering": PolicyLocation(
file_name="falcon", class_name="FalconForQuestionAnsweringPolicy"
),
"transformers.models.mistral.modeling_mistral.MistralModel": PolicyLocation(
file_name="mistral", class_name="MistralModelPolicy"
),
"transformers.models.mistral.modeling_mistral.MistralForCausalLM": PolicyLocation(
file_name="mistral", class_name="MistralForCausalLMPolicy"
),
"transformers.models.mistral.modeling_mistral.MistralForSequenceClassification": PolicyLocation(
file_name="mistral", class_name="MistralForSequenceClassificationPolicy"
),
} }
......
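The new GPTJ, Falcon, and Mistral entries follow the existing pattern: the registry is keyed on the fully qualified class name of the wrapped model. A hypothetical lookup helper (names are illustrative only; the real resolution logic lives elsewhere in the auto-policy module):

```python
def _qualified_name(model: object) -> str:
    cls = model.__class__
    return f"{cls.__module__}.{cls.__qualname__}"

def lookup_policy_location(model: object, policy_list: dict):
    key = _qualified_name(model)
    if key not in policy_list:
        raise NotImplementedError(f"No shardformer policy registered for {key}")
    return policy_list[key]
```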
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch.nn as nn import torch.nn as nn
...@@ -214,13 +214,32 @@ class Policy(ABC): ...@@ -214,13 +214,32 @@ class Policy(ABC):
return layers_per_stage return layers_per_stage
@staticmethod @staticmethod
def get_stage_index(layers_per_stage: List[int], stage: int) -> List[int]: def get_stage_index(
layers_per_stage: List[int],
stage: int,
num_model_chunks: int = 1,
num_stages: int = 0,
) -> Union[Tuple[int, int], List[Tuple[int, int]]]:
""" """
get the start index and end index of layers for each stage. Get the start index and end index of layers for each stage.
Args:
layers_per_stage (List[int]): number of layers for each stage
stage (int): the stage index
num_stages (int): number of stages
num_model_chunks (int): number of model chunks
Returns:
- Tuple[int, int]: the start index and end index of this stage
- List[Tuple[int, int]]: the start index and end index of this stage for each model chunk
""" """
num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0) num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
start_idx = num_layers_per_stage_accumulated[stage] stage_indices = []
end_idx = num_layers_per_stage_accumulated[stage + 1] for model_chunk in range(num_model_chunks):
start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages]
end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1]
stage_indices.append([start_idx, end_idx])
return [start_idx, end_idx] return stage_indices[0] if num_model_chunks == 1 else stage_indices
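A worked example of the interleaved case (standalone sketch mirroring the indexing logic above): 8 layers split over 2 pipeline stages with 2 model chunks per stage, so `layers_per_stage` has `num_stages * num_model_chunks` entries.

```python
import numpy as np

layers_per_stage = [2, 2, 2, 2]                      # num_stages * num_model_chunks entries
acc = np.insert(np.cumsum(layers_per_stage), 0, 0)   # [0, 2, 4, 6, 8]

def stage_indices(stage: int, num_model_chunks: int, num_stages: int):
    out = []
    for chunk in range(num_model_chunks):
        start = int(acc[stage + chunk * num_stages])
        end = int(acc[stage + chunk * num_stages + 1])
        out.append((start, end))
    return out

print(stage_indices(0, num_model_chunks=2, num_stages=2))  # [(0, 2), (4, 6)]
print(stage_indices(1, num_model_chunks=2, num_stages=2))  # [(2, 4), (6, 8)]
```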
...@@ -21,7 +21,7 @@ __all__ = [ ...@@ -21,7 +21,7 @@ __all__ = [
"BertPolicy", "BertPolicy",
"BertModelPolicy", "BertModelPolicy",
"BertForPreTrainingPolicy", "BertForPreTrainingPolicy",
"BertLMdHeadModelPolicy", "BertLMHeadModelPolicy",
"BertForMaskedLMPolicy", "BertForMaskedLMPolicy",
"BertForNextSentencePredictionPolicy", "BertForNextSentencePredictionPolicy",
"BertForSequenceClassificationPolicy", "BertForSequenceClassificationPolicy",
...@@ -249,15 +249,34 @@ class BertPolicy(Policy): ...@@ -249,15 +249,34 @@ class BertPolicy(Policy):
return self.model return self.model
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
"""If under pipeline parallel setting, replacing the original forward method of huggingface """
to customized forward method, and add this changing to policy.""" If under pipeline parallel setting, replacing the original forward method of huggingface
if self.pipeline_stage_manager: to customized forward method, and add this changing to policy.
stage_manager = self.pipeline_stage_manager """
if self.model.__class__.__name__ == "BertModel": if self.pipeline_stage_manager is None:
module = self.model return
else:
module = self.model.bert
stage_manager = self.pipeline_stage_manager
if self.model.__class__.__name__ == "BertModel":
module = self.model
else:
module = self.model.bert
if stage_manager.is_interleave:
layers_per_stage = self.distribute_layers(
len(module.encoder.layer), stage_manager.num_stages * stage_manager.num_model_chunks
)
stage_manager.stage_indices = Policy.get_stage_index(
layers_per_stage,
stage_manager.stage,
num_model_chunks=stage_manager.num_model_chunks,
num_stages=stage_manager.num_stages,
)
method_replacement = {
"forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config)
}
else:
layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = { method_replacement = {
...@@ -265,11 +284,8 @@ class BertPolicy(Policy): ...@@ -265,11 +284,8 @@ class BertPolicy(Policy):
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
) )
} }
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=model_cls
)
return self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
def get_held_layers(self) -> List[Module]: def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage.""" """Get pipeline layers for current stage."""
...@@ -282,13 +298,32 @@ class BertPolicy(Policy): ...@@ -282,13 +298,32 @@ class BertPolicy(Policy):
stage_manager = self.pipeline_stage_manager stage_manager = self.pipeline_stage_manager
held_layers = [] held_layers = []
layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) if stage_manager.is_interleave:
if stage_manager.is_first_stage(): assert stage_manager.num_model_chunks is not None
held_layers.append(module.embeddings) layers_per_stage = self.distribute_layers(
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) len(module.encoder.layer), stage_manager.num_stages * stage_manager.num_model_chunks
held_layers.extend(module.encoder.layer[start_idx:end_idx]) )
if stage_manager.is_last_stage(): stage_indices = Policy.get_stage_index(
held_layers.append(module.pooler) layers_per_stage,
stage_manager.stage,
num_model_chunks=stage_manager.num_model_chunks,
num_stages=stage_manager.num_stages,
)
if stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(module.embeddings)
for start_idx, end_idx in stage_indices:
held_layers.extend(module.encoder.layer[start_idx:end_idx])
if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(module.pooler)
else:
layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
if stage_manager.is_first_stage():
held_layers.append(module.embeddings)
start_idx, end_idx = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
held_layers.extend(module.encoder.layer[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.pooler)
return held_layers return held_layers
...@@ -335,7 +370,7 @@ class BertForPreTrainingPolicy(BertPolicy): ...@@ -335,7 +370,7 @@ class BertForPreTrainingPolicy(BertPolicy):
"""Get pipeline layers for current stage""" """Get pipeline layers for current stage"""
held_layers = super().get_held_layers() held_layers = super().get_held_layers()
stage_manager = self.pipeline_stage_manager stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage(): if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.cls) held_layers.append(self.model.cls)
return held_layers return held_layers
...@@ -374,7 +409,7 @@ class BertLMHeadModelPolicy(BertPolicy): ...@@ -374,7 +409,7 @@ class BertLMHeadModelPolicy(BertPolicy):
""" """
held_layers = super().get_held_layers() held_layers = super().get_held_layers()
stage_manager = self.pipeline_stage_manager stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage(): if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.cls) held_layers.append(self.model.cls)
return held_layers return held_layers
...@@ -412,7 +447,7 @@ class BertForMaskedLMPolicy(BertPolicy): ...@@ -412,7 +447,7 @@ class BertForMaskedLMPolicy(BertPolicy):
""" """
held_layers = super().get_held_layers() held_layers = super().get_held_layers()
stage_manager = self.pipeline_stage_manager stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage(): if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.cls) held_layers.append(self.model.cls)
return held_layers return held_layers
...@@ -464,7 +499,7 @@ class BertForSequenceClassificationPolicy(BertPolicy): ...@@ -464,7 +499,7 @@ class BertForSequenceClassificationPolicy(BertPolicy):
""" """
held_layers = super().get_held_layers() held_layers = super().get_held_layers()
stage_manager = self.pipeline_stage_manager stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage(): if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.dropout) held_layers.append(self.model.dropout)
held_layers.append(self.model.classifier) held_layers.append(self.model.classifier)
return held_layers return held_layers
...@@ -508,7 +543,7 @@ class BertForTokenClassificationPolicy(BertPolicy): ...@@ -508,7 +543,7 @@ class BertForTokenClassificationPolicy(BertPolicy):
""" """
held_layers = super().get_held_layers() held_layers = super().get_held_layers()
stage_manager = self.pipeline_stage_manager stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage(): if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.dropout) held_layers.append(self.model.dropout)
held_layers.append(self.model.classifier) held_layers.append(self.model.classifier)
return held_layers return held_layers
...@@ -539,7 +574,7 @@ class BertForNextSentencePredictionPolicy(BertPolicy): ...@@ -539,7 +574,7 @@ class BertForNextSentencePredictionPolicy(BertPolicy):
""" """
held_layers = super().get_held_layers() held_layers = super().get_held_layers()
stage_manager = self.pipeline_stage_manager stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage(): if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.cls) held_layers.append(self.model.cls)
return held_layers return held_layers
...@@ -582,7 +617,7 @@ class BertForMultipleChoicePolicy(BertPolicy): ...@@ -582,7 +617,7 @@ class BertForMultipleChoicePolicy(BertPolicy):
""" """
held_layers = super().get_held_layers() held_layers = super().get_held_layers()
stage_manager = self.pipeline_stage_manager stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage(): if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.dropout) held_layers.append(self.model.dropout)
held_layers.append(self.model.classifier) held_layers.append(self.model.classifier)
return held_layers return held_layers
...@@ -612,7 +647,7 @@ class BertForQuestionAnsweringPolicy(BertPolicy): ...@@ -612,7 +647,7 @@ class BertForQuestionAnsweringPolicy(BertPolicy):
""" """
held_layers = super().get_held_layers() held_layers = super().get_held_layers()
stage_manager = self.pipeline_stage_manager stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage(): if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.qa_outputs) held_layers.append(self.model.qa_outputs)
return held_layers return held_layers
......
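With interleaving enabled, a stage may own several non-contiguous slices of the encoder, so the new `get_held_layers` branch iterates over a list of `(start, end)` pairs rather than a single pair. A toy illustration (hypothetical layer names, values taken from the worked example above):

```python
encoder_layers = [f"layer_{i}" for i in range(8)]
stage_indices = [(0, 2), (4, 6)]   # e.g. stage 0 of 2 stages, 2 model chunks

held_layers = []
for start_idx, end_idx in stage_indices:
    held_layers.extend(encoder_layers[start_idx:end_idx])

print(held_layers)  # ['layer_0', 'layer_1', 'layer_4', 'layer_5']
```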
...@@ -21,6 +21,15 @@ from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDe ...@@ -21,6 +21,15 @@ from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDe
class BloomPolicy(Policy): class BloomPolicy(Policy):
def __init__(self) -> None:
super().__init__()
import transformers
from packaging.version import Version
assert Version(transformers.__version__) <= Version(
"4.33.0"
), "The Bloom model should run on a transformers version not greater than 4.33.0."
def config_sanity_check(self): def config_sanity_check(self):
pass pass
......
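The new `BloomPolicy.__init__` gates on the installed transformers release with `packaging.Version`, which compares version numbers semantically rather than as strings. A quick check of that behavior (assumes `packaging` is installed):

```python
from packaging.version import Version

print(Version("4.33.0") <= Version("4.33.0"))  # True  -> Bloom policy allowed
print(Version("4.34.1") <= Version("4.33.0"))  # False -> the assert above would fire
print(Version("4.9.0") < Version("4.33.0"))    # True  -> not a lexicographic comparison
```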