Unverified Commit 4fa689fc authored by Wenhao Chen's avatar Wenhao Chen Committed by GitHub
Browse files

[pipeline]: fix p2p comm, add metadata cache and support llama interleaved pp (#5134)

* test: add more p2p tests

* fix: remove send_forward_recv_forward as p2p op list need to use the same group

* fix: make send and receive atomic

* feat: update P2PComm fn

* feat: add metadata cache in 1f1b

* feat: add metadata cache in interleaved pp

* feat: modify is_xx_stage fn

* revert: add _broadcast_object_list

* feat: add interleaved pp in llama policy

* feat: set NCCL_BUFFSIZE in HybridParallelPlugin
parent af952673
import ctypes
import os
import random
from contextlib import contextmanager
from functools import partial
......@@ -21,7 +22,8 @@ from torch.utils.data.distributed import DistributedSampler
from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper, AMPModelMixin
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.logging import get_dist_logger
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
......@@ -982,6 +984,13 @@ class HybridParallelPlugin(PipelinePluginBase):
self.custom_policy = custom_policy
assert zero_stage in (0, 1, 2)
if self.pp_size > 1:
if os.getenv("NCCL_BUFFSIZE") is None:
logger = get_dist_logger()
logger.warning(
"Setting NCCL_BUFFSIZE to 128MB to avoid p2p hangs. " "Please increase it if hangs still happen."
)
os.environ["NCCL_BUFFSIZE"] = "134217728"
assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
assert (
......
This diff is collapsed.
......@@ -7,7 +7,7 @@ from torch.nn import Module, ModuleList
from torch.utils._pytree import tree_map
from colossalai.interface import OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_fast_send_metadata
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils.device import get_current_device
......@@ -27,6 +27,7 @@ class InterleavedSchedule(PipelineSchedule):
assert (
num_microbatch is not None or microbatch_size is not None
), "Either num_microbatch or microbatch_size should be provided"
self.comm = PipelineP2PCommunication(stage_manager)
self.num_microbatch = num_microbatch
self.microbatch_size = microbatch_size
......@@ -34,8 +35,15 @@ class InterleavedSchedule(PipelineSchedule):
self.batch: Any
self.batch_size: int
self.last_batch_size: Optional[int] = None
self.microbatch_offset: List[int]
# P2PMeta cache
self.send_metadata_forward = True
self.send_metadata_backward = True
self.metadata_recv_forward = None
self.metadata_recv_backward = None
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
"""Load a batch from data iterator.
......@@ -48,6 +56,11 @@ class InterleavedSchedule(PipelineSchedule):
batch = tree_map(partial(to_device, device=device), batch)
self.batch = batch
self.batch_size = get_batch_size(batch)
if self.last_batch_size is None:
self.last_batch_size = self.batch_size
else:
assert self.forward_only or self.last_batch_size == self.batch_size
# TODO: support arbitrary batch size when forward_only=True
self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
if self.num_microbatch is not None:
assert self.batch_size % self.num_microbatch == 0, "Batch size should divided by the number of microbatch"
......@@ -106,12 +119,13 @@ class InterleavedSchedule(PipelineSchedule):
Returns:
Any: The input tensor or input tensor list.
"""
if self.stage_manager.is_first_stage(model_chunk_id):
input_tensor = None
else:
input_tensor = self.comm.recv_forward(prev_rank)
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if not self.stage_manager.is_first_stage():
input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.metadata_recv_forward)
if self.metadata_recv_forward is None:
self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
return input_tensor
return input_tensor
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Any:
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
......@@ -124,14 +138,15 @@ class InterleavedSchedule(PipelineSchedule):
Returns:
Any: The input gradient tensor or gradient tensor list.
"""
if self.stage_manager.is_last_stage(model_chunk_id):
output_tensor_grad = None
else:
output_tensor_grad = self.comm.recv_backward(next_rank)
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if not self.stage_manager.is_last_stage():
output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.metadata_recv_backward)
if self.metadata_recv_backward is None:
self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
return output_tensor_grad
return output_tensor_grad
def send_forward(self, model_chunk_id, output_object: Any, next_rank: int = None) -> None:
def send_forward(self, model_chunk_id: int, output_object: Any, next_rank: int = None) -> None:
"""Sends the input tensor to the next stage in pipeline.
For interleaved 1F1B.
......@@ -140,10 +155,12 @@ class InterleavedSchedule(PipelineSchedule):
output_object (Any): Object to be sent.
next_rank (int, optional): The rank of the recipient of the tensor.
"""
if not self.stage_manager.is_last_stage(model_chunk_id):
self.comm.send_forward(output_object, next_rank)
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if not self.stage_manager.is_last_stage():
self.comm.send_forward(output_object, next_rank, send_metadata=self.send_metadata_forward)
self.send_metadata_forward = False
def send_backward(self, model_chunk_id, input_object: Any, prev_rank: int = None) -> None:
def send_backward(self, model_chunk_id: int, input_object: Any, prev_rank: int = None) -> None:
"""Sends the gradient tensor to the previous stage in pipeline.
For interleaved 1F1B.
......@@ -152,8 +169,44 @@ class InterleavedSchedule(PipelineSchedule):
input_object (Any): Object to be sent.
prev_rank (int, optional): The rank of the recipient of the tensor
"""
if not self.stage_manager.is_first_stage(model_chunk_id):
self.comm.send_backward(input_object, prev_rank)
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if not self.stage_manager.is_first_stage():
self.comm.send_backward(input_object, prev_rank, send_metadata=self.send_metadata_backward)
self.send_metadata_backward = False
def send_forward_recv_backward(
self, model_chunk_id: int, output_object: Any, next_rank: Optional[int] = None
) -> Any:
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if not self.stage_manager.is_last_stage():
output_tensor_grad = self.comm.send_forward_recv_backward(
output_object,
next_rank,
send_metadata=self.send_metadata_forward,
metadata_recv=self.metadata_recv_backward,
)
self.send_metadata_forward = False
if self.metadata_recv_backward is None:
self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
return output_tensor_grad
def send_backward_recv_forward(
self, model_chunk_id: int, output_object: Any, prev_rank: Optional[int] = None
) -> Any:
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if not self.stage_manager.is_first_stage():
input_tensor = self.comm.send_backward_recv_forward(
output_object,
prev_rank,
send_metadata=self.send_metadata_backward,
metadata_recv=self.metadata_recv_forward,
)
self.send_metadata_backward = False
if self.metadata_recv_forward is None:
self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
return input_tensor
def forward_step(
self,
......@@ -180,25 +233,24 @@ class InterleavedSchedule(PipelineSchedule):
# for the first stage, input_obj is None
# for the non-first stage, input_obj is the output of the previous stage and it's must be a dict
self.stage_manager.model_chunk_id = model_chunk_id
if isinstance(model_chunk, ModuleList):
output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj)
else:
# NOTE: in shardformer, each device still has the entire model, so we need to use relevant stage layers
internal_inputs = {} if input_obj is None else input_obj
internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
output_obj = model_forward(model_chunk, micro_batch, internal_inputs)
self.stage_manager.model_chunk_id = None
if self.stage_manager.is_last_stage(model_chunk_id):
loss = criterion(output_obj, micro_batch) / self.num_microbatch
if accum_loss is not None:
accum_loss.add_(loss.detach())
if outputs is not None:
outputs.append(tree_map(detach, output_obj))
return loss
else:
return output_obj
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if isinstance(model_chunk, ModuleList):
output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj)
else:
# NOTE: in shardformer, each device still has the entire model, so we need to use relevant stage layers
internal_inputs = {} if input_obj is None else input_obj
internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
output_obj = model_forward(model_chunk, micro_batch, internal_inputs)
if self.stage_manager.is_last_stage():
loss = criterion(output_obj, micro_batch) / self.num_microbatch
if accum_loss is not None:
accum_loss.add_(loss.detach())
if outputs is not None:
outputs.append(tree_map(detach, output_obj))
return loss
else:
return output_obj
def backward_step(
self,
......@@ -267,15 +319,14 @@ class InterleavedSchedule(PipelineSchedule):
Returns:
dict: A dict with keys: 'loss' and 'outputs'.
"""
# TODO: handle arbitrary batch size when forward_only == True
forward_only = not torch.is_grad_enabled()
self.forward_only = not torch.is_grad_enabled()
if optimizer is None:
assert forward_only, "Optimizer should be passed when doing backward."
assert self.forward_only, "Optimizer should be passed when doing backward."
self.load_batch(data_iter)
num_microbatch = self.num_microbatch * self.num_model_chunks
if forward_only:
if self.forward_only:
num_warmup_microbatch = num_microbatch
else:
num_warmup_microbatch = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2
......@@ -288,43 +339,29 @@ class InterleavedSchedule(PipelineSchedule):
input_objs = None
output_objs = None
if not forward_only:
if not self.forward_only:
input_objs = [[] for _ in range(self.num_model_chunks)]
output_objs = [[] for _ in range(self.num_model_chunks)]
outputs = [] if return_outputs and self.stage_manager.is_last_stage(-1) else None
outputs = [] if return_outputs and self.stage_manager.is_last_stage(ignore_chunk=True) else None
if return_loss and self.stage_manager.is_last_stage(-1):
accum_loss = None
if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
accum_loss = torch.zeros(1, device=get_current_device())
else:
accum_loss = None
# for ranks except the first one, get into recv state
input_obj = self.recv_forward(0)
# Run warmup forward passes.
for i in range(num_warmup_microbatch):
model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
# recv first on first rank to avoid sending or receiving at the same time
if self.stage_manager.is_first_stage(-1):
input_obj = self.recv_forward(model_chunk_id)
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
self.send_forward(model_chunk_id, output_obj)
if not forward_only:
input_objs[model_chunk_id].append(input_obj)
output_objs[model_chunk_id].append(output_obj)
else:
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
if not forward_only:
input_objs[model_chunk_id].append(input_obj)
output_objs[model_chunk_id].append(output_obj)
self.send_forward(model_chunk_id, output_obj)
if num_microbatch_remaining == 0 and i + 1 == num_warmup_microbatch:
break
input_obj = self.recv_forward(model_chunk_id)
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
if not self.forward_only:
input_objs[model_chunk_id].append(input_obj)
output_objs[model_chunk_id].append(output_obj)
self.send_forward(model_chunk_id, output_obj)
model_chunk_id = self.get_model_chunk_id(i + 1, is_forward=True)
input_obj = self.recv_forward(model_chunk_id)
if num_microbatch_remaining > 0:
model_chunk_id = self.get_model_chunk_id(num_warmup_microbatch, is_forward=True)
input_obj = self.recv_forward(model_chunk_id)
# Run 1F1B in steady state.
for i in range(num_microbatch_remaining):
......@@ -332,11 +369,11 @@ class InterleavedSchedule(PipelineSchedule):
last_iteration = i == num_microbatch_remaining - 1
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
if forward_only:
self.send_forward(model_chunk_id, output_obj)
if self.forward_only:
if not last_iteration:
input_obj = self.recv_forward(model_chunk_id)
input_obj = self.send_forward_recv_backward(model_chunk_id, output_obj)
else:
self.send_forward(model_chunk_id, output_obj)
else:
self.send_forward(model_chunk_id, output_obj)
......@@ -354,18 +391,14 @@ class InterleavedSchedule(PipelineSchedule):
# backward
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
self.send_backward(model_chunk_id, input_obj_grad)
if last_iteration:
input_obj = None
else:
if not last_iteration:
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True)
input_obj = self.recv_forward(model_chunk_id)
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
self.send_backward(model_chunk_id, input_obj_grad)
# Run cooldown backward passes.
if not forward_only:
if not self.forward_only:
for i in range(num_microbatch_remaining, num_microbatch):
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
input_obj = input_objs[model_chunk_id].pop(0)
......@@ -374,7 +407,7 @@ class InterleavedSchedule(PipelineSchedule):
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
self.send_backward(model_chunk_id, input_obj_grad)
if not forward_only:
if not self.forward_only:
assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
if outputs is not None:
......
......@@ -7,7 +7,7 @@ from torch.nn import Module
from torch.utils._pytree import tree_map
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_fast_send_metadata
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils.device import get_current_device
......@@ -42,14 +42,22 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
assert (
num_microbatches is not None or microbatch_size is not None
), "Either num_microbatches or microbatch_size should be provided"
self.comm = PipelineP2PCommunication(stage_manager)
self.num_microbatches = num_microbatches
self.microbatch_size = microbatch_size
self.batch: Optional[Any] = None
self.batch_size: Optional[int] = None
self.last_batch_size: Optional[int] = None
self.microbatch_offset: Optional[int] = None
self._use_microbatch_size = num_microbatches is None
# P2PMeta cache
self.send_metadata_forward = True
self.send_metadata_backward = True
self.metadata_recv_forward = None
self.metadata_recv_backward = None
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
"""Load a batch from data iterator.
......@@ -60,8 +68,14 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
batch = next(data_iter)
if device is not None:
batch = tree_map(partial(to_device, device=device), batch)
self.batch = batch
self.batch_size = get_batch_size(batch)
if self.last_batch_size is None:
self.last_batch_size = self.batch_size
else:
assert self.forward_only or self.last_batch_size == self.batch_size
# TODO: support arbitrary batch size when forward_only=True
self.microbatch_offset = 0
if not self._use_microbatch_size:
assert (
......@@ -92,12 +106,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
Returns:
Any: The input tensor or input tensor list.
"""
if self.stage_manager.is_first_stage():
input_tensor = None
else:
input_tensor = self.comm.recv_forward(prev_rank)
if not self.stage_manager.is_first_stage():
input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.metadata_recv_forward)
if self.metadata_recv_forward is None:
self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
return input_tensor
return input_tensor
def recv_backward(self, next_rank: int = None) -> Any:
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
......@@ -109,12 +123,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
Returns:
Any: The input gradient tensor or gradient tensor list.
"""
if self.stage_manager.is_last_stage():
output_tensor_grad = None
else:
output_tensor_grad = self.comm.recv_backward(next_rank)
if not self.stage_manager.is_last_stage():
output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.metadata_recv_backward)
if self.metadata_recv_backward is None:
self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
return output_tensor_grad
return output_tensor_grad
def send_forward(self, output_object: Any, next_rank: int = None) -> None:
"""Sends the input tensor to the next stage in pipeline.
......@@ -125,18 +139,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
next_rank (int, optional): The rank of the recipient of the tensor.
"""
if not self.stage_manager.is_last_stage():
self.comm.send_forward(output_object, next_rank)
def send_forward_recv_backward(self, output_object: Any, next_rank: int = None) -> Any:
"""Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
For 1F1B.
Args:
output_object (Any): Object to be sent.
next_rank (int, optional): The rank of the recipient of the tensor.
"""
if not self.stage_manager.is_last_stage():
return self.comm.send_forward_recv_backward(output_object, next_rank)
self.comm.send_forward(output_object, next_rank, send_metadata=self.send_metadata_forward)
self.send_metadata_forward = False
def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
"""Sends the gradient tensor to the previous stage in pipeline.
......@@ -147,34 +151,50 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
prev_rank (int, optional): The rank of the recipient of the tensor
"""
if not self.stage_manager.is_first_stage():
self.comm.send_backward(input_object, prev_rank)
self.comm.send_backward(input_object, prev_rank, send_metadata=self.send_metadata_backward)
self.send_metadata_backward = False
def send_backward_recv_forward(self, output_object: Any, prev_rank: int = None) -> Any:
"""Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline.
def send_forward_recv_backward(self, output_object: Any, next_rank: int = None) -> Any:
"""Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
For 1F1B.
Args:
output_object (Any): Object to be sent.
prev_rank (int, optional): The rank of the recipient of the tensor.
next_rank (int, optional): The rank of the recipient of the tensor.
"""
if not self.stage_manager.is_first_stage():
return self.comm.send_backward_recv_forward(output_object, prev_rank)
if not self.stage_manager.is_last_stage():
output_tensor_grad = self.comm.send_forward_recv_backward(
output_object,
next_rank,
send_metadata=self.send_metadata_forward,
metadata_recv=self.metadata_recv_backward,
)
self.send_metadata_forward = False
if self.metadata_recv_backward is None:
self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
return output_tensor_grad
def send_forward_recv_forward(self, input_object: Any, prev_rank: int = None, next_rank: int = None) -> Any:
"""Sends the input tensor to the next stage and copy the input tensor from the previous stage in pipeline.
def send_backward_recv_forward(self, output_object: Any, prev_rank: int = None) -> Any:
"""Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline.
For 1F1B.
Args:
input_object (Any): Object to be sent.
prev_rank (int, optional): The previous rank of the recipient of the tensor.
next_rank (int, optional): The next rank of the recipient of the tensor.
output_object (Any): Object to be sent.
prev_rank (int, optional): The rank of the recipient of the tensor.
"""
if self.stage_manager.is_first_stage():
return self.comm.send_forward(input_object, next_rank)
elif self.stage_manager.is_last_stage():
return self.comm.recv_forward(prev_rank)
else:
return self.comm.send_forward_recv_forward(input_object, prev_rank, next_rank)
if not self.stage_manager.is_first_stage():
input_tensor = self.comm.send_backward_recv_forward(
output_object,
prev_rank,
send_metadata=self.send_metadata_backward,
metadata_recv=self.metadata_recv_forward,
)
self.send_metadata_backward = False
if self.metadata_recv_forward is None:
self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
return input_tensor
def forward_step(
self,
......@@ -276,9 +296,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
Returns:
dict: A dict with keys: 'loss' and 'outputs'.
"""
forward_only = not torch.is_grad_enabled()
self.forward_only = not torch.is_grad_enabled()
if optimizer is None:
assert forward_only, "Optimizer should be passed when doing backward."
assert self.forward_only, "Optimizer should be passed when doing backward."
self.load_batch(data_iter)
......@@ -291,25 +312,22 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
input_objs = None
output_objs = None
if not forward_only:
if not self.forward_only:
input_objs = []
output_objs = []
outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
accum_loss = None
if return_loss and self.stage_manager.is_last_stage():
accum_loss = torch.zeros(1, device=get_current_device())
else:
accum_loss = None
outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
input_obj = self.recv_forward()
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
self.send_forward(output_obj)
if not forward_only:
if not self.forward_only:
input_objs.append(input_obj)
output_objs.append(output_obj)
......@@ -324,16 +342,15 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
last_iteration = i == (num_microbatches_remaining - 1)
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
if forward_only:
if self.forward_only:
self.send_forward(output_obj)
if not last_iteration:
input_obj = self.recv_forward()
else:
# TODO adjust here
self.send_forward(output_obj)
output_obj_grad = self.recv_backward()
else:
output_obj_grad = self.send_forward_recv_backward(output_obj)
# Add input_obj and output_obj to end of list.
input_objs.append(input_obj)
output_objs.append(output_obj)
......@@ -345,13 +362,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
if last_iteration:
input_obj = None
self.send_backward(input_obj_grad)
else:
input_obj = self.recv_forward()
self.send_backward(input_obj_grad)
input_obj = self.send_backward_recv_forward(input_obj_grad)
# Run cooldown backward passes.
if not forward_only:
if not self.forward_only:
for i in range(num_warmup_microbatches):
input_obj = input_objs.pop(0)
output_obj = output_objs.pop(0)
......@@ -360,6 +376,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
self.send_backward(input_obj_grad)
if not self.forward_only:
assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
if outputs is not None:
if isinstance(model, ModelWrapper):
model = model.unwrap()
......
import contextlib
from typing import Dict, List, Optional, Tuple
import torch.distributed as dist
......@@ -68,45 +69,39 @@ class PipelineStageManager:
# for shardformer, hold model chunk id
self.model_chunk_id: Optional[int] = None
def is_first_stage(self, model_chunk_id: Optional[int] = None) -> bool:
def is_first_stage(self, ignore_chunk: bool = False) -> bool:
"""Is the current stage the first stage.
NOTE:
1. if using interleaved pipeline parallel, the first stage is the first chunk of the first device.
2. invoke is_first_stage() with model_chunk_id < 0 is equivalent to invoke is_first_device()
2. invoke is_first_stage() with ignore_chunk=True is equivalent to invoke is_first_device()
Returns:
bool: Whether the current stage is the first stage.
"""
if self.is_interleave and model_chunk_id is None:
model_chunk_id = self.model_chunk_id
assert self.is_interleave ^ (
model_chunk_id is None
), "model_chunk_id must be specified when using interleaved pipeline"
if not self.is_interleave or model_chunk_id < 0:
assert isinstance(ignore_chunk, bool)
assert not self.is_interleave or (ignore_chunk or self.model_chunk_id is not None)
if not self.is_interleave or ignore_chunk:
return self.stage == 0
else:
return self.stage == 0 and model_chunk_id == 0
return self.stage == 0 and self.model_chunk_id == 0
def is_last_stage(self, model_chunk_id: Optional[int] = None) -> bool:
def is_last_stage(self, ignore_chunk: bool = False) -> bool:
"""Is the current stage the last stage.
NOTE:
1. if using interleaved pipeline parallel, the last stage is the last chunk of the last device.
2. invoke is_last_stage() with model_chunk_id < 0 is equivalent to invoke is_last_device()
2. invoke is_last_stage() with ignore_chunk=True is equivalent to invoke is_last_device()
Returns:
bool: Whether the current stage is the last stage.
"""
if self.is_interleave and model_chunk_id is None:
model_chunk_id = self.model_chunk_id
assert self.is_interleave ^ (
model_chunk_id is None
), "model_chunk_id must be specified when using interleaved pipeline"
if not self.is_interleave or model_chunk_id < 0:
assert isinstance(ignore_chunk, bool)
assert not self.is_interleave or (ignore_chunk or self.model_chunk_id is not None)
if not self.is_interleave or ignore_chunk:
return self.stage == self.num_stages - 1
else:
return self.stage == self.num_stages - 1 and model_chunk_id == self.num_model_chunks - 1
return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1
@property
def num_stages(self) -> int:
......@@ -174,3 +169,10 @@ class PipelineStageManager:
ProcessGroup: Process group of the given stages.
"""
return self.pg_mesh.get_group_along_axis(self.pipeline_axis, stages)
@contextlib.contextmanager
def switch_model_chunk_id(self, model_chunk_id: int):
old_model_chunk_id = self.model_chunk_id
self.model_chunk_id = model_chunk_id
yield
self.model_chunk_id = old_model_chunk_id
......@@ -309,11 +309,11 @@ class BertPolicy(Policy):
num_model_chunks=stage_manager.num_model_chunks,
num_stages=stage_manager.num_stages,
)
if stage_manager.is_first_stage(-1):
if stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(module.embeddings)
for start_idx, end_idx in stage_indices:
held_layers.extend(module.encoder.layer[start_idx:end_idx])
if stage_manager.is_last_stage(-1):
if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(module.pooler)
else:
......@@ -370,7 +370,7 @@ class BertForPreTrainingPolicy(BertPolicy):
"""Get pipeline layers for current stage"""
held_layers = super().get_held_layers()
stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage():
if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.cls)
return held_layers
......@@ -409,7 +409,7 @@ class BertLMHeadModelPolicy(BertPolicy):
"""
held_layers = super().get_held_layers()
stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage():
if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.cls)
return held_layers
......@@ -447,7 +447,7 @@ class BertForMaskedLMPolicy(BertPolicy):
"""
held_layers = super().get_held_layers()
stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage():
if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.cls)
return held_layers
......@@ -499,7 +499,7 @@ class BertForSequenceClassificationPolicy(BertPolicy):
"""
held_layers = super().get_held_layers()
stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage(None if not stage_manager.is_interleave else -1):
if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.dropout)
held_layers.append(self.model.classifier)
return held_layers
......@@ -543,7 +543,7 @@ class BertForTokenClassificationPolicy(BertPolicy):
"""
held_layers = super().get_held_layers()
stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage():
if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.dropout)
held_layers.append(self.model.classifier)
return held_layers
......@@ -574,7 +574,7 @@ class BertForNextSentencePredictionPolicy(BertPolicy):
"""
held_layers = super().get_held_layers()
stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage():
if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.cls)
return held_layers
......@@ -617,7 +617,7 @@ class BertForMultipleChoicePolicy(BertPolicy):
"""
held_layers = super().get_held_layers()
stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage():
if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.dropout)
held_layers.append(self.model.classifier)
return held_layers
......@@ -647,7 +647,7 @@ class BertForQuestionAnsweringPolicy(BertPolicy):
"""
held_layers = super().get_held_layers()
stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage():
if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.qa_outputs)
return held_layers
......
......@@ -8,7 +8,11 @@ from torch.nn import Module
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, RMSNorm, VocabParallelEmbedding1D
from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward, get_lm_forward_with_dist_cross_entropy
from ..modeling.llama import (
LlamaPipelineForwards,
get_llama_flash_attention_forward,
get_lm_forward_with_dist_cross_entropy,
)
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ["LlamaPolicy", "LlamaForCausalLMPolicy", "LlamaForSequenceClassificationPolicy"]
......@@ -140,21 +144,42 @@ class LlamaPolicy(Policy):
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
"""If under pipeline parallel setting, replacing the original forward method of huggingface
to customized forward method, and add this changing to policy."""
if self.pipeline_stage_manager:
stage_manager = self.pipeline_stage_manager
if self.model.__class__.__name__ == "LlamaModel":
module = self.model
else:
module = self.model.model
if self.pipeline_stage_manager is None:
return
stage_manager = self.pipeline_stage_manager
if self.model.__class__.__name__ == "LlamaModel":
module = self.model
else:
module = self.model.model
if stage_manager.is_interleave:
layers_per_stage = self.distribute_layers(
len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks
)
stage_manager.stage_indices = Policy.get_stage_index(
layers_per_stage,
stage_manager.stage,
num_model_chunks=stage_manager.num_model_chunks,
num_stages=stage_manager.num_stages,
)
method_replacement = {
"forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config)
}
else:
layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config)}
method_replacement = {
"forward": partial(
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
)
}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=model_cls
)
return
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
......@@ -167,13 +192,32 @@ class LlamaPolicy(Policy):
stage_manager = self.pipeline_stage_manager
held_layers = []
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
if stage_manager.is_first_stage():
held_layers.append(module.embed_tokens)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.norm)
if stage_manager.is_interleave:
assert stage_manager.num_model_chunks is not None
layers_per_stage = self.distribute_layers(
len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks
)
stage_indices = Policy.get_stage_index(
layers_per_stage,
stage_manager.stage,
num_model_chunks=stage_manager.num_model_chunks,
num_stages=stage_manager.num_stages,
)
if stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(module.embed_tokens)
for start_idx, end_idx in stage_indices:
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(module.norm)
else:
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
if stage_manager.is_first_stage():
held_layers.append(module.embed_tokens)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.norm)
return held_layers
......@@ -211,11 +255,9 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
new_item = {
LlamaForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", target_module=Linear1D_Col
)
SubModuleReplacementDescription(suffix="lm_head", target_module=Linear1D_Col)
],
method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}
method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
)
}
policy.update(new_item)
......@@ -232,7 +274,7 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_last_stage():
if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.lm_head)
return held_layers
......@@ -285,7 +327,7 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_last_stage():
if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.score)
return held_layers
......
......@@ -57,9 +57,7 @@ def evaluate_model(
def evaluate_subset(dataloader: DataLoader):
use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage(
None if not booster.plugin.stage_manager.is_interleave else -1
)
is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage(ignore_chunk=True)
accum_loss = torch.zeros(1, device=get_current_device())
for batch in dataloader:
......@@ -136,9 +134,7 @@ def train_epoch(
coordinator: DistCoordinator,
):
use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage(
None if not booster.plugin.stage_manager.is_interleave else -1
)
is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage(ignore_chunk=True)
print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_device)
total_step = len(train_dataloader)
......
......@@ -133,7 +133,9 @@ def main():
plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=args.pp,
pp_style="interleaved",
zero_stage=args.zero,
num_model_chunks=2,
enable_fused_normalization=torch.cuda.is_available(),
num_microbatches=args.mbs,
precision="bf16",
......
import warnings
import pytest
import torch
import torch.distributed as dist
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.p2p import P2PDataType, P2PMetadata, PipelineP2PCommunication, TensorMetadata
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
WORLD_SIZE = 2
def check_p2p_communication():
pg_mesh = ProcessGroupMesh(2)
pg_mesh = ProcessGroupMesh(WORLD_SIZE)
stage_manager = PipelineStageManager(pg_mesh, 0)
p2p = PipelineP2PCommunication(stage_manager)
rank = dist.get_rank()
tensor = torch.ones(1, device=get_current_device())
data = [
"tensor",
tensor,
[tensor],
{"tensor": tensor},
]
if rank == 0:
p2p.send_forward(tensor)
p2p.send_forward([tensor])
p2p.send_forward({"tensor": tensor})
else:
obj = p2p.recv_forward()
assert torch.equal(obj, tensor)
obj = p2p.recv_forward()
assert type(obj) == list and len(obj) == 1 and torch.equal(obj[0], tensor)
obj = p2p.recv_forward()
assert type(obj) == dict and "tensor" in obj and torch.equal(obj["tensor"], tensor)
for obj in data:
p2p.send_forward(obj)
for i in range(len(data)):
recv_obj = p2p.send_forward_recv_backward(data[i])
assert recv_obj == data[-(i + 1)]
elif rank == 1:
for obj in data:
recv_obj = p2p.recv_forward()
assert recv_obj == obj
for i in range(len(data)):
p2p.send_backward(data[-(i + 1)])
recv_obj = p2p.recv_forward()
assert recv_obj == data[i]
if rank == 1:
p2p.send_backward(tensor)
p2p.send_backward([tensor])
p2p.send_backward({"tensor": tensor})
else:
obj = p2p.recv_backward()
assert torch.equal(obj, tensor)
obj = p2p.recv_backward()
assert type(obj) == list and len(obj) == 1 and torch.equal(obj[0], tensor)
obj = p2p.recv_backward()
assert type(obj) == dict and "tensor" in obj and torch.equal(obj["tensor"], tensor)
for obj in data:
p2p.send_backward(obj)
for i in range(len(data)):
recv_obj = p2p.send_backward_recv_forward(data[i])
assert recv_obj == data[-(i + 1)]
elif rank == 0:
for obj in data:
recv_obj = p2p.recv_backward()
assert recv_obj == obj
for i in range(len(data)):
recv_obj = p2p.recv_backward()
p2p.send_forward(data[-(i + 1)])
assert recv_obj == data[i]
warnings.filterwarnings("error")
tensor_metadata = TensorMetadata(
key=None, shape=tensor.shape, dtype=tensor.dtype, requires_grad=tensor.requires_grad
)
comm_metadata = P2PMetadata(data_type=P2PDataType.Tensor, content=tensor_metadata)
if rank == 0:
recv_obj = p2p.send_forward_recv_backward(
tensor,
send_metadata=False,
metadata_recv=comm_metadata,
)
assert recv_obj == tensor
elif rank == 1:
recv_obj = p2p.recv_forward(metadata_recv=comm_metadata)
assert recv_obj == tensor
p2p.send_backward(tensor, send_metadata=False)
def run_dist(rank, world_size, port):
......@@ -52,7 +85,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_pipeline_p2p():
spawn(run_dist, 2)
spawn(run_dist, WORLD_SIZE)
if __name__ == "__main__":
......
......@@ -37,12 +37,13 @@ def pp_linear_fwd(
stage_mgr: PipelineStageManager = None,
model_chunk_id: int = None,
):
if stage_mgr.is_first_stage(model_chunk_id):
return {"input_obj": forward(data)}
elif stage_mgr.is_last_stage(model_chunk_id):
return forward(input_obj)
else:
return {"input_obj": forward(input_obj)}
with stage_mgr.switch_model_chunk_id(model_chunk_id):
if stage_mgr.is_first_stage():
return {"input_obj": forward(data)}
elif stage_mgr.is_last_stage():
return forward(input_obj)
else:
return {"input_obj": forward(input_obj)}
def run_pp(
......@@ -107,7 +108,7 @@ def run_pp(
)
# check loss
if stage_manager.is_last_stage(-1):
if stage_manager.is_last_stage(ignore_chunk=True):
assert torch.allclose(torch_loss, pp_ret["loss"])
# check gradients
......@@ -119,6 +120,7 @@ def run_pp(
# step
torch_optimizer.step()
pp_optimizer.step()
pp_optimizer.zero_grad()
# check updated param
for i in range(num_model_chunk):
......@@ -126,6 +128,24 @@ def run_pp(
assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight)
assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias)
# forward only
with torch.no_grad():
torch_output = torch_model(input_list[0])
torch_loss = criterion(torch_output)
pp_ret = schedule.forward_backward_step(
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
)
if stage_manager.is_last_stage(ignore_chunk=True):
assert torch.allclose(torch_loss, pp_ret["loss"])
for layer in sharded_model:
if layer.weight.grad is None:
assert layer.weight.grad is None and layer.bias.grad is None
else:
assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad))
assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad))
@pytest.mark.dist
@pytest.mark.parametrize("num_microbatch", [4, 12])
......
......@@ -4,6 +4,7 @@ from types import MethodType
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
import colossalai
......@@ -14,21 +15,26 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
DIM = 8
NUM_LAYER = 8
class MlpModel(nn.Module):
def __init__(self):
super(MlpModel, self).__init__()
self.linear1 = nn.Linear(4, 8)
self.linear2 = nn.Linear(8, 4)
super().__init__()
self.layers = nn.ModuleList([nn.Linear(DIM, DIM) for _ in range(NUM_LAYER)])
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
for layer in self.layers:
x = layer(x)
return x
def pp_linear_fwd(
forward, data: torch.Tensor = None, input_obj: torch.Tensor = None, stage_mgr: PipelineStageManager = None
forward,
data: torch.Tensor = None,
input_obj: torch.Tensor = None,
stage_mgr: PipelineStageManager = None,
):
if stage_mgr.is_first_stage():
return {"input_obj": forward(data)}
......@@ -38,34 +44,45 @@ def pp_linear_fwd(
return {"input_obj": forward(input_obj)}
def examine_pp():
def examine_pp(num_microbatch: int, batch_size: int):
"""
This test is to examine the correctness of 1F1B, compared with torch.
Be aware it contains some hardcodes.
"""
world_size = torch.distributed.get_world_size()
local_rank = torch.distributed.get_rank()
world_size = dist.get_world_size()
dist.get_rank()
seed_all(1453)
NUM_MICRO_BATCHS = 4
BATCH_SIZE = 4
# create models
torch_model = MlpModel().cuda()
pp_model = copy.deepcopy(torch_model).cuda()
DP_DIM, PP_DIM, TP_DIM = 0, 1, 2
pg_mesh = ProcessGroupMesh(1, world_size, 1)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
schedule = OneForwardOneBackwardSchedule(stage_manager, num_microbatches=NUM_MICRO_BATCHS)
for idx, (_, sub_model) in enumerate(pp_model.named_children()):
if idx % (world_size) == local_rank:
sharded_model = sub_model.cuda()
pg_mesh = ProcessGroupMesh(world_size)
stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0)
schedule = OneForwardOneBackwardSchedule(stage_manager, num_microbatches=num_microbatch)
rank = dist.get_rank()
sharded_model = torch.nn.ModuleList()
num_local_layer = NUM_LAYER // world_size
for idx, sub_model in enumerate(pp_model.layers):
if idx // num_local_layer == rank:
sharded_model.append(sub_model.cuda())
assert len(sharded_model) == num_local_layer
def custom_fwd(self, x):
for layer in self._modules.values():
x = layer(x)
return x
sharded_model._forward = sharded_model.forward
sharded_model.forward = MethodType(partial(pp_linear_fwd, stage_mgr=stage_manager), sharded_model._forward)
sharded_model._forward = MethodType(custom_fwd, sharded_model)
sharded_model.forward = MethodType(
partial(
pp_linear_fwd,
stage_mgr=stage_manager,
),
sharded_model._forward,
)
# create optimizer
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
......@@ -73,19 +90,15 @@ def examine_pp():
# create
seed_all(1453)
if stage_manager.is_first_stage():
input_list = [torch.rand(BATCH_SIZE, 4).cuda()]
else:
input_list = [torch.zeros(BATCH_SIZE, 4).cuda()]
torch.distributed.all_reduce(input_list[0])
input_list = [torch.rand(batch_size, DIM).cuda()]
dist.all_reduce(input_list[0])
criterion = lambda x, y: torch.mean(x)
criterion = lambda x, *arg, **kwargs: (x * x).mean()
# forward and backward
torch_output = torch_model(input_list[0])
torch_loss = criterion(torch_output, _)
torch_loss = criterion(torch_output)
torch_loss.backward()
pp_ret = schedule.forward_backward_step(
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
)
......@@ -95,34 +108,66 @@ def examine_pp():
assert torch.allclose(torch_loss, pp_ret["loss"])
# check gradients
torch_grad = []
for torch_p in torch_model.parameters():
torch_grad.append(torch_p.grad.data)
for idx, pp_p in enumerate(sharded_model.parameters()):
assert torch.allclose(torch_grad[idx + local_rank * 2], pp_p.grad.data)
for i in range(len(sharded_model)):
idx = rank * num_local_layer + i
assert torch.allclose(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
assert torch.allclose(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)
# step
torch_optimizer.step()
pp_optimizer.step()
pp_optimizer.zero_grad()
# check updated param
torch_param = []
for torch_p in torch_model.parameters():
torch_param.append(torch_p.data)
for idx, pp_p in enumerate(sharded_model.parameters()):
assert torch.allclose(torch_param[idx + local_rank * 2], pp_p.data)
def run_dist(rank, world_size, port):
for i in range(len(sharded_model)):
idx = rank * num_local_layer + i
assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight)
assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias)
# forward only
with torch.no_grad():
torch_output = torch_model(input_list[0])
torch_loss = criterion(torch_output)
pp_ret = schedule.forward_backward_step(
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
)
if stage_manager.is_last_stage():
assert torch.allclose(torch_loss, pp_ret["loss"])
for layer in sharded_model:
if layer.weight.grad is None:
assert layer.weight.grad is None and layer.bias.grad is None
else:
assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad))
assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad))
def run_dist(
rank: int,
world_size: int,
port: int,
num_microbatch: int,
batch_size: int,
):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
examine_pp()
examine_pp(num_microbatch, batch_size)
@pytest.mark.dist
@pytest.mark.parametrize("num_microbatch", [4, 12])
@pytest.mark.parametrize("batch_size", [12])
@pytest.mark.parametrize("world_size", [2, 4])
@rerun_if_address_is_in_use()
def test_pp():
spawn(run_dist, 2)
def test_pp(num_microbatch: int, batch_size: int, world_size: int):
assert NUM_LAYER % world_size == 0
spawn(
run_dist,
world_size,
num_microbatch=num_microbatch,
batch_size=batch_size,
)
if __name__ == "__main__":
test_pp()
test_pp(num_microbatch=4, batch_size=4, world_size=4)
......@@ -203,7 +203,7 @@ def check_output_hidden_state(
):
org_hidden_state = org_output.last_hidden_state
if stage_manager and stage_manager.is_last_stage():
if stage_manager and stage_manager.is_last_stage(ignore_chunk=True):
sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"]
else:
sharded_hidden_state = sharded_output.last_hidden_state
......@@ -229,6 +229,10 @@ def check_weight(
org_weight = getattr_(org_model, suffix).weight
sharded_weight = getattr_(sharded_model, suffix).weight
# skip if layer is not held by this process
if sharded_weight is None:
continue
if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight):
sharded_weight_list = [
torch.zeros_like(sharded_weight).to("cuda") for _ in range(dist.get_world_size(tp_group))
......
......@@ -37,6 +37,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
norm_layer_for_check = ["encoder.layer[0].attention.output.LayerNorm", "embeddings.LayerNorm"]
col_layer_for_check = ["encoder.layer[0].output.dense"]
row_layer_for_check = ["embeddings.word_embeddings", "encoder.layer[0].intermediate.dense"]
weight_layer_for_check = ["encoder.layer[0].output.dense", "encoder.layer[1].output.dense"]
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
grads_to_check = {}
......@@ -44,7 +45,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
atol, rtol = 1e-4, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
if (stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True)) and booster.plugin.zero_stage == 0:
col_layer_grads = get_grad_tensors_for_check(
bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
)
......@@ -72,7 +73,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
sharded_optimizer.step()
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if stage_manager is None or stage_manager.is_last_stage(ignore_chunk=True):
if test_config["precision"] == "fp32":
atol, rtol = 1e-5, 1e-3
else:
......@@ -87,8 +88,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
atol, rtol = 5e-3, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if stage_manager is None or stage_manager.is_first_stage():
check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1)
# check grads
check_all_grad_tensors(grads_to_check)
......@@ -183,6 +184,17 @@ def run_bert_test(test_config):
"zero_stage": 1,
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 2,
"pp_style": "interleaved",
"num_model_chunks": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"precision": "fp16",
"zero_stage": 1,
"initial_scale": 1,
},
],
)
def run_bert_3d_test(test_config):
......
......@@ -44,7 +44,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
grads_to_check = {}
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
if (stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True)) and booster.plugin.zero_stage == 0:
if test_config["precision"] == "fp32":
atol, rtol = 1e-6, 1e-4
else:
......@@ -63,7 +63,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
sharded_optimizer.step()
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if stage_manager is None or stage_manager.is_last_stage(ignore_chunk=True):
if test_config["precision"] == "fp32":
atol, rtol = 1e-5, 1e-3
else:
......@@ -75,7 +75,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# check weights
if stage_manager is None or stage_manager.is_first_stage():
if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
if test_config["precision"] == "fp32":
atol, rtol = 1e-4, 1e-3
else:
......@@ -179,6 +179,17 @@ def run_llama_test(test_config):
"zero_stage": 1,
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 2,
"pp_style": "interleaved",
"num_model_chunks": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"precision": "fp16",
"zero_stage": 1,
"initial_scale": 1,
},
],
)
def run_llama_3d_test(test_config):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment