Unverified Commit 4fa689fc authored by Wenhao Chen's avatar Wenhao Chen Committed by GitHub
Browse files

[pipeline]: fix p2p comm, add metadata cache and support llama interleaved pp (#5134)

* test: add more p2p tests

* fix: remove send_forward_recv_forward as the p2p op list needs to use the same group

* fix: make send and receive atomic

* feat: update P2PComm fn

* feat: add metadata cache in 1f1b

* feat: add metadata cache in interleaved pp

* feat: modify is_xx_stage fn

* revert: add _broadcast_object_list

* feat: add interleaved pp in llama policy

* feat: set NCCL_BUFFSIZE in HybridParallelPlugin
parent af952673
import ctypes import ctypes
import os
import random import random
from contextlib import contextmanager from contextlib import contextmanager
from functools import partial from functools import partial
...@@ -21,7 +22,8 @@ from torch.utils.data.distributed import DistributedSampler ...@@ -21,7 +22,8 @@ from torch.utils.data.distributed import DistributedSampler
from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
from colossalai.cluster import ProcessGroupMesh from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper, AMPModelMixin from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.logging import get_dist_logger
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.shardformer import ShardConfig, ShardFormer
...@@ -982,6 +984,13 @@ class HybridParallelPlugin(PipelinePluginBase): ...@@ -982,6 +984,13 @@ class HybridParallelPlugin(PipelinePluginBase):
self.custom_policy = custom_policy self.custom_policy = custom_policy
assert zero_stage in (0, 1, 2) assert zero_stage in (0, 1, 2)
if self.pp_size > 1: if self.pp_size > 1:
if os.getenv("NCCL_BUFFSIZE") is None:
logger = get_dist_logger()
logger.warning(
"Setting NCCL_BUFFSIZE to 128MB to avoid p2p hangs. " "Please increase it if hangs still happen."
)
os.environ["NCCL_BUFFSIZE"] = "134217728"
assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style" assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b" assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
assert ( assert (
......
This diff is collapsed.
...@@ -7,7 +7,7 @@ from torch.nn import Module, ModuleList ...@@ -7,7 +7,7 @@ from torch.nn import Module, ModuleList
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
from colossalai.interface import OptimizerWrapper from colossalai.interface import OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication from colossalai.pipeline.p2p import PipelineP2PCommunication, create_fast_send_metadata
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils.device import get_current_device from colossalai.utils.device import get_current_device
...@@ -27,6 +27,7 @@ class InterleavedSchedule(PipelineSchedule): ...@@ -27,6 +27,7 @@ class InterleavedSchedule(PipelineSchedule):
assert ( assert (
num_microbatch is not None or microbatch_size is not None num_microbatch is not None or microbatch_size is not None
), "Either num_microbatch or microbatch_size should be provided" ), "Either num_microbatch or microbatch_size should be provided"
self.comm = PipelineP2PCommunication(stage_manager) self.comm = PipelineP2PCommunication(stage_manager)
self.num_microbatch = num_microbatch self.num_microbatch = num_microbatch
self.microbatch_size = microbatch_size self.microbatch_size = microbatch_size
...@@ -34,8 +35,15 @@ class InterleavedSchedule(PipelineSchedule): ...@@ -34,8 +35,15 @@ class InterleavedSchedule(PipelineSchedule):
self.batch: Any self.batch: Any
self.batch_size: int self.batch_size: int
self.last_batch_size: Optional[int] = None
self.microbatch_offset: List[int] self.microbatch_offset: List[int]
# P2PMeta cache
self.send_metadata_forward = True
self.send_metadata_backward = True
self.metadata_recv_forward = None
self.metadata_recv_backward = None
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
"""Load a batch from data iterator. """Load a batch from data iterator.
...@@ -48,6 +56,11 @@ class InterleavedSchedule(PipelineSchedule): ...@@ -48,6 +56,11 @@ class InterleavedSchedule(PipelineSchedule):
batch = tree_map(partial(to_device, device=device), batch) batch = tree_map(partial(to_device, device=device), batch)
self.batch = batch self.batch = batch
self.batch_size = get_batch_size(batch) self.batch_size = get_batch_size(batch)
if self.last_batch_size is None:
self.last_batch_size = self.batch_size
else:
assert self.forward_only or self.last_batch_size == self.batch_size
# TODO: support arbitrary batch size when forward_only=True
self.microbatch_offset = [0 for _ in range(self.num_model_chunks)] self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
if self.num_microbatch is not None: if self.num_microbatch is not None:
assert self.batch_size % self.num_microbatch == 0, "Batch size should divided by the number of microbatch" assert self.batch_size % self.num_microbatch == 0, "Batch size should divided by the number of microbatch"
...@@ -106,10 +119,11 @@ class InterleavedSchedule(PipelineSchedule): ...@@ -106,10 +119,11 @@ class InterleavedSchedule(PipelineSchedule):
Returns: Returns:
Any: The input tensor or input tensor list. Any: The input tensor or input tensor list.
""" """
if self.stage_manager.is_first_stage(model_chunk_id): with self.stage_manager.switch_model_chunk_id(model_chunk_id):
input_tensor = None if not self.stage_manager.is_first_stage():
else: input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.metadata_recv_forward)
input_tensor = self.comm.recv_forward(prev_rank) if self.metadata_recv_forward is None:
self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
return input_tensor return input_tensor
...@@ -124,14 +138,15 @@ class InterleavedSchedule(PipelineSchedule): ...@@ -124,14 +138,15 @@ class InterleavedSchedule(PipelineSchedule):
Returns: Returns:
Any: The input gradient tensor or gradient tensor list. Any: The input gradient tensor or gradient tensor list.
""" """
if self.stage_manager.is_last_stage(model_chunk_id): with self.stage_manager.switch_model_chunk_id(model_chunk_id):
output_tensor_grad = None if not self.stage_manager.is_last_stage():
else: output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.metadata_recv_backward)
output_tensor_grad = self.comm.recv_backward(next_rank) if self.metadata_recv_backward is None:
self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
return output_tensor_grad return output_tensor_grad
def send_forward(self, model_chunk_id, output_object: Any, next_rank: int = None) -> None: def send_forward(self, model_chunk_id: int, output_object: Any, next_rank: int = None) -> None:
"""Sends the input tensor to the next stage in pipeline. """Sends the input tensor to the next stage in pipeline.
For interleaved 1F1B. For interleaved 1F1B.
...@@ -140,10 +155,12 @@ class InterleavedSchedule(PipelineSchedule): ...@@ -140,10 +155,12 @@ class InterleavedSchedule(PipelineSchedule):
output_object (Any): Object to be sent. output_object (Any): Object to be sent.
next_rank (int, optional): The rank of the recipient of the tensor. next_rank (int, optional): The rank of the recipient of the tensor.
""" """
if not self.stage_manager.is_last_stage(model_chunk_id): with self.stage_manager.switch_model_chunk_id(model_chunk_id):
self.comm.send_forward(output_object, next_rank) if not self.stage_manager.is_last_stage():
self.comm.send_forward(output_object, next_rank, send_metadata=self.send_metadata_forward)
self.send_metadata_forward = False
def send_backward(self, model_chunk_id, input_object: Any, prev_rank: int = None) -> None: def send_backward(self, model_chunk_id: int, input_object: Any, prev_rank: int = None) -> None:
"""Sends the gradient tensor to the previous stage in pipeline. """Sends the gradient tensor to the previous stage in pipeline.
For interleaved 1F1B. For interleaved 1F1B.
...@@ -152,8 +169,44 @@ class InterleavedSchedule(PipelineSchedule): ...@@ -152,8 +169,44 @@ class InterleavedSchedule(PipelineSchedule):
input_object (Any): Object to be sent. input_object (Any): Object to be sent.
prev_rank (int, optional): The rank of the recipient of the tensor prev_rank (int, optional): The rank of the recipient of the tensor
""" """
if not self.stage_manager.is_first_stage(model_chunk_id): with self.stage_manager.switch_model_chunk_id(model_chunk_id):
self.comm.send_backward(input_object, prev_rank) if not self.stage_manager.is_first_stage():
self.comm.send_backward(input_object, prev_rank, send_metadata=self.send_metadata_backward)
self.send_metadata_backward = False
def send_forward_recv_backward(
    self, model_chunk_id: int, output_object: Any, next_rank: Optional[int] = None
) -> Any:
    """Send a forward output to the next stage and receive its gradient in one call.
    For interleaved 1F1B.

    The stage manager is switched to ``model_chunk_id`` for the duration of
    the call, so ``is_last_stage()`` is evaluated for that chunk.

    Args:
        model_chunk_id (int): The model chunk the output belongs to.
        output_object (Any): Object to be sent to the next stage.
        next_rank (int, optional): Rank passed through to the communicator.

    Returns:
        Any: The value returned by ``self.comm.send_forward_recv_backward``,
        or ``None`` when this chunk is the last stage (nothing is exchanged).
    """
    with self.stage_manager.switch_model_chunk_id(model_chunk_id):
        # The last stage has no successor: nothing to send, no gradient to wait for.
        if self.stage_manager.is_last_stage():
            return None

        output_tensor_grad = self.comm.send_forward_recv_backward(
            output_object,
            next_rank,
            send_metadata=self.send_metadata_forward,
            metadata_recv=self.metadata_recv_backward,
        )
        # Forward metadata is only requested to be sent until the first exchange has happened.
        self.send_metadata_forward = False
        # Cache metadata built from the first received gradient so later
        # receives can pass it as ``metadata_recv``.
        if self.metadata_recv_backward is None:
            self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)

        return output_tensor_grad
def send_backward_recv_forward(
    self, model_chunk_id: int, output_object: Any, prev_rank: Optional[int] = None
) -> Any:
    """Send a gradient to the previous stage and receive the next forward input in one call.
    For interleaved 1F1B.

    The stage manager is switched to ``model_chunk_id`` for the duration of
    the call, so ``is_first_stage()`` is evaluated for that chunk.

    Args:
        model_chunk_id (int): The model chunk the gradient belongs to.
        output_object (Any): Object to be sent to the previous stage
            (the parameter keeps its original name for compatibility).
        prev_rank (int, optional): Rank passed through to the communicator.

    Returns:
        Any: The value returned by ``self.comm.send_backward_recv_forward``,
        or ``None`` when this chunk is the first stage (nothing is exchanged).
    """
    with self.stage_manager.switch_model_chunk_id(model_chunk_id):
        # The first stage has no predecessor: nothing to send, no input to wait for.
        if self.stage_manager.is_first_stage():
            return None

        input_tensor = self.comm.send_backward_recv_forward(
            output_object,
            prev_rank,
            send_metadata=self.send_metadata_backward,
            metadata_recv=self.metadata_recv_forward,
        )
        # Backward metadata is only requested to be sent until the first exchange has happened.
        self.send_metadata_backward = False
        # Cache metadata built from the first received input so later
        # receives can pass it as ``metadata_recv``.
        if self.metadata_recv_forward is None:
            self.metadata_recv_forward = create_fast_send_metadata(input_tensor)

        return input_tensor
def forward_step( def forward_step(
self, self,
...@@ -180,7 +233,7 @@ class InterleavedSchedule(PipelineSchedule): ...@@ -180,7 +233,7 @@ class InterleavedSchedule(PipelineSchedule):
# for the first stage, input_obj is None # for the first stage, input_obj is None
# for the non-first stage, input_obj is the output of the previous stage and it's must be a dict # for the non-first stage, input_obj is the output of the previous stage and it's must be a dict
self.stage_manager.model_chunk_id = model_chunk_id with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if isinstance(model_chunk, ModuleList): if isinstance(model_chunk, ModuleList):
output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj) output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj)
else: else:
...@@ -188,9 +241,8 @@ class InterleavedSchedule(PipelineSchedule): ...@@ -188,9 +241,8 @@ class InterleavedSchedule(PipelineSchedule):
internal_inputs = {} if input_obj is None else input_obj internal_inputs = {} if input_obj is None else input_obj
internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id] internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
output_obj = model_forward(model_chunk, micro_batch, internal_inputs) output_obj = model_forward(model_chunk, micro_batch, internal_inputs)
self.stage_manager.model_chunk_id = None
if self.stage_manager.is_last_stage(model_chunk_id): if self.stage_manager.is_last_stage():
loss = criterion(output_obj, micro_batch) / self.num_microbatch loss = criterion(output_obj, micro_batch) / self.num_microbatch
if accum_loss is not None: if accum_loss is not None:
accum_loss.add_(loss.detach()) accum_loss.add_(loss.detach())
...@@ -267,15 +319,14 @@ class InterleavedSchedule(PipelineSchedule): ...@@ -267,15 +319,14 @@ class InterleavedSchedule(PipelineSchedule):
Returns: Returns:
dict: A dict with keys: 'loss' and 'outputs'. dict: A dict with keys: 'loss' and 'outputs'.
""" """
# TODO: handle arbitrary batch size when forward_only == True self.forward_only = not torch.is_grad_enabled()
forward_only = not torch.is_grad_enabled()
if optimizer is None: if optimizer is None:
assert forward_only, "Optimizer should be passed when doing backward." assert self.forward_only, "Optimizer should be passed when doing backward."
self.load_batch(data_iter) self.load_batch(data_iter)
num_microbatch = self.num_microbatch * self.num_model_chunks num_microbatch = self.num_microbatch * self.num_model_chunks
if forward_only: if self.forward_only:
num_warmup_microbatch = num_microbatch num_warmup_microbatch = num_microbatch
else: else:
num_warmup_microbatch = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2 num_warmup_microbatch = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2
...@@ -288,42 +339,28 @@ class InterleavedSchedule(PipelineSchedule): ...@@ -288,42 +339,28 @@ class InterleavedSchedule(PipelineSchedule):
input_objs = None input_objs = None
output_objs = None output_objs = None
if not forward_only: if not self.forward_only:
input_objs = [[] for _ in range(self.num_model_chunks)] input_objs = [[] for _ in range(self.num_model_chunks)]
output_objs = [[] for _ in range(self.num_model_chunks)] output_objs = [[] for _ in range(self.num_model_chunks)]
outputs = [] if return_outputs and self.stage_manager.is_last_stage(-1) else None outputs = [] if return_outputs and self.stage_manager.is_last_stage(ignore_chunk=True) else None
if return_loss and self.stage_manager.is_last_stage(-1):
accum_loss = torch.zeros(1, device=get_current_device())
else:
accum_loss = None accum_loss = None
if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
# for ranks except the first one, get into recv state accum_loss = torch.zeros(1, device=get_current_device())
input_obj = self.recv_forward(0)
# Run warmup forward passes. # Run warmup forward passes.
for i in range(num_warmup_microbatch): for i in range(num_warmup_microbatch):
model_chunk_id = self.get_model_chunk_id(i, is_forward=True) model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
# recv first on first rank to avoid sending or receiving at the same time
if self.stage_manager.is_first_stage(-1):
input_obj = self.recv_forward(model_chunk_id) input_obj = self.recv_forward(model_chunk_id)
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
self.send_forward(model_chunk_id, output_obj) if not self.forward_only:
if not forward_only:
input_objs[model_chunk_id].append(input_obj)
output_objs[model_chunk_id].append(output_obj)
else:
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
if not forward_only:
input_objs[model_chunk_id].append(input_obj) input_objs[model_chunk_id].append(input_obj)
output_objs[model_chunk_id].append(output_obj) output_objs[model_chunk_id].append(output_obj)
self.send_forward(model_chunk_id, output_obj) self.send_forward(model_chunk_id, output_obj)
if num_microbatch_remaining == 0 and i + 1 == num_warmup_microbatch: if num_microbatch_remaining > 0:
break model_chunk_id = self.get_model_chunk_id(num_warmup_microbatch, is_forward=True)
model_chunk_id = self.get_model_chunk_id(i + 1, is_forward=True)
input_obj = self.recv_forward(model_chunk_id) input_obj = self.recv_forward(model_chunk_id)
# Run 1F1B in steady state. # Run 1F1B in steady state.
...@@ -332,11 +369,11 @@ class InterleavedSchedule(PipelineSchedule): ...@@ -332,11 +369,11 @@ class InterleavedSchedule(PipelineSchedule):
last_iteration = i == num_microbatch_remaining - 1 last_iteration = i == num_microbatch_remaining - 1
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
if forward_only: if self.forward_only:
self.send_forward(model_chunk_id, output_obj)
if not last_iteration: if not last_iteration:
input_obj = self.recv_forward(model_chunk_id) input_obj = self.send_forward_recv_backward(model_chunk_id, output_obj)
else:
self.send_forward(model_chunk_id, output_obj)
else: else:
self.send_forward(model_chunk_id, output_obj) self.send_forward(model_chunk_id, output_obj)
...@@ -354,18 +391,14 @@ class InterleavedSchedule(PipelineSchedule): ...@@ -354,18 +391,14 @@ class InterleavedSchedule(PipelineSchedule):
# backward # backward
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
self.send_backward(model_chunk_id, input_obj_grad)
if last_iteration: if not last_iteration:
input_obj = None
else:
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True) model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True)
input_obj = self.recv_forward(model_chunk_id) input_obj = self.recv_forward(model_chunk_id)
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
self.send_backward(model_chunk_id, input_obj_grad)
# Run cooldown backward passes. # Run cooldown backward passes.
if not forward_only: if not self.forward_only:
for i in range(num_microbatch_remaining, num_microbatch): for i in range(num_microbatch_remaining, num_microbatch):
model_chunk_id = self.get_model_chunk_id(i, is_forward=False) model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
input_obj = input_objs[model_chunk_id].pop(0) input_obj = input_objs[model_chunk_id].pop(0)
...@@ -374,7 +407,7 @@ class InterleavedSchedule(PipelineSchedule): ...@@ -374,7 +407,7 @@ class InterleavedSchedule(PipelineSchedule):
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
self.send_backward(model_chunk_id, input_obj_grad) self.send_backward(model_chunk_id, input_obj_grad)
if not forward_only: if not self.forward_only:
assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs) assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
if outputs is not None: if outputs is not None:
......
...@@ -7,7 +7,7 @@ from torch.nn import Module ...@@ -7,7 +7,7 @@ from torch.nn import Module
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication from colossalai.pipeline.p2p import PipelineP2PCommunication, create_fast_send_metadata
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils.device import get_current_device from colossalai.utils.device import get_current_device
...@@ -42,14 +42,22 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -42,14 +42,22 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
assert ( assert (
num_microbatches is not None or microbatch_size is not None num_microbatches is not None or microbatch_size is not None
), "Either num_microbatches or microbatch_size should be provided" ), "Either num_microbatches or microbatch_size should be provided"
self.comm = PipelineP2PCommunication(stage_manager) self.comm = PipelineP2PCommunication(stage_manager)
self.num_microbatches = num_microbatches self.num_microbatches = num_microbatches
self.microbatch_size = microbatch_size self.microbatch_size = microbatch_size
self.batch: Optional[Any] = None self.batch: Optional[Any] = None
self.batch_size: Optional[int] = None self.batch_size: Optional[int] = None
self.last_batch_size: Optional[int] = None
self.microbatch_offset: Optional[int] = None self.microbatch_offset: Optional[int] = None
self._use_microbatch_size = num_microbatches is None self._use_microbatch_size = num_microbatches is None
# P2PMeta cache
self.send_metadata_forward = True
self.send_metadata_backward = True
self.metadata_recv_forward = None
self.metadata_recv_backward = None
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
"""Load a batch from data iterator. """Load a batch from data iterator.
...@@ -60,8 +68,14 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -60,8 +68,14 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
batch = next(data_iter) batch = next(data_iter)
if device is not None: if device is not None:
batch = tree_map(partial(to_device, device=device), batch) batch = tree_map(partial(to_device, device=device), batch)
self.batch = batch self.batch = batch
self.batch_size = get_batch_size(batch) self.batch_size = get_batch_size(batch)
if self.last_batch_size is None:
self.last_batch_size = self.batch_size
else:
assert self.forward_only or self.last_batch_size == self.batch_size
# TODO: support arbitrary batch size when forward_only=True
self.microbatch_offset = 0 self.microbatch_offset = 0
if not self._use_microbatch_size: if not self._use_microbatch_size:
assert ( assert (
...@@ -92,10 +106,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -92,10 +106,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
Returns: Returns:
Any: The input tensor or input tensor list. Any: The input tensor or input tensor list.
""" """
if self.stage_manager.is_first_stage(): if not self.stage_manager.is_first_stage():
input_tensor = None input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.metadata_recv_forward)
else: if self.metadata_recv_forward is None:
input_tensor = self.comm.recv_forward(prev_rank) self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
return input_tensor return input_tensor
...@@ -109,10 +123,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -109,10 +123,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
Returns: Returns:
Any: The input gradient tensor or gradient tensor list. Any: The input gradient tensor or gradient tensor list.
""" """
if self.stage_manager.is_last_stage(): if not self.stage_manager.is_last_stage():
output_tensor_grad = None output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.metadata_recv_backward)
else: if self.metadata_recv_backward is None:
output_tensor_grad = self.comm.recv_backward(next_rank) self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
return output_tensor_grad return output_tensor_grad
...@@ -125,7 +139,20 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -125,7 +139,20 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
next_rank (int, optional): The rank of the recipient of the tensor. next_rank (int, optional): The rank of the recipient of the tensor.
""" """
if not self.stage_manager.is_last_stage(): if not self.stage_manager.is_last_stage():
self.comm.send_forward(output_object, next_rank) self.comm.send_forward(output_object, next_rank, send_metadata=self.send_metadata_forward)
self.send_metadata_forward = False
def send_backward(self, input_object: Any, prev_rank: Optional[int] = None) -> None:
    """Sends the gradient tensor to the previous stage in pipeline.
    For 1F1B.

    Nothing is sent when the current stage is the first stage.

    Args:
        input_object (Any): Object to be sent.
        prev_rank (int, optional): The rank of the recipient of the tensor
    """
    if not self.stage_manager.is_first_stage():
        self.comm.send_backward(input_object, prev_rank, send_metadata=self.send_metadata_backward)
        # After the first real send, stop asking the communicator to send metadata.
        self.send_metadata_backward = False
def send_forward_recv_backward(self, output_object: Any, next_rank: int = None) -> Any: def send_forward_recv_backward(self, output_object: Any, next_rank: int = None) -> Any:
"""Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline. """Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
...@@ -136,18 +163,17 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -136,18 +163,17 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
next_rank (int, optional): The rank of the recipient of the tensor. next_rank (int, optional): The rank of the recipient of the tensor.
""" """
if not self.stage_manager.is_last_stage(): if not self.stage_manager.is_last_stage():
return self.comm.send_forward_recv_backward(output_object, next_rank) output_tensor_grad = self.comm.send_forward_recv_backward(
output_object,
next_rank,
send_metadata=self.send_metadata_forward,
metadata_recv=self.metadata_recv_backward,
)
self.send_metadata_forward = False
if self.metadata_recv_backward is None:
self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
def send_backward(self, input_object: Any, prev_rank: int = None) -> None: return output_tensor_grad
"""Sends the gradient tensor to the previous stage in pipeline.
For 1F1B.
Args:
input_object (Any): Object to be sent.
prev_rank (int, optional): The rank of the recipient of the tensor
"""
if not self.stage_manager.is_first_stage():
self.comm.send_backward(input_object, prev_rank)
def send_backward_recv_forward(self, output_object: Any, prev_rank: int = None) -> Any: def send_backward_recv_forward(self, output_object: Any, prev_rank: int = None) -> Any:
"""Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline. """Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline.
...@@ -158,23 +184,17 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -158,23 +184,17 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
prev_rank (int, optional): The rank of the recipient of the tensor. prev_rank (int, optional): The rank of the recipient of the tensor.
""" """
if not self.stage_manager.is_first_stage(): if not self.stage_manager.is_first_stage():
return self.comm.send_backward_recv_forward(output_object, prev_rank) input_tensor = self.comm.send_backward_recv_forward(
output_object,
def send_forward_recv_forward(self, input_object: Any, prev_rank: int = None, next_rank: int = None) -> Any: prev_rank,
"""Sends the input tensor to the next stage and copy the input tensor from the previous stage in pipeline. send_metadata=self.send_metadata_backward,
For 1F1B. metadata_recv=self.metadata_recv_forward,
)
self.send_metadata_backward = False
if self.metadata_recv_forward is None:
self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
Args: return input_tensor
input_object (Any): Object to be sent.
prev_rank (int, optional): The previous rank of the recipient of the tensor.
next_rank (int, optional): The next rank of the recipient of the tensor.
"""
if self.stage_manager.is_first_stage():
return self.comm.send_forward(input_object, next_rank)
elif self.stage_manager.is_last_stage():
return self.comm.recv_forward(prev_rank)
else:
return self.comm.send_forward_recv_forward(input_object, prev_rank, next_rank)
def forward_step( def forward_step(
self, self,
...@@ -276,9 +296,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -276,9 +296,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
Returns: Returns:
dict: A dict with keys: 'loss' and 'outputs'. dict: A dict with keys: 'loss' and 'outputs'.
""" """
forward_only = not torch.is_grad_enabled()
self.forward_only = not torch.is_grad_enabled()
if optimizer is None: if optimizer is None:
assert forward_only, "Optimizer should be passed when doing backward." assert self.forward_only, "Optimizer should be passed when doing backward."
self.load_batch(data_iter) self.load_batch(data_iter)
...@@ -291,25 +312,22 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -291,25 +312,22 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
input_objs = None input_objs = None
output_objs = None output_objs = None
if not forward_only: if not self.forward_only:
input_objs = [] input_objs = []
output_objs = [] output_objs = []
outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None accum_loss = None
if return_loss and self.stage_manager.is_last_stage(): if return_loss and self.stage_manager.is_last_stage():
accum_loss = torch.zeros(1, device=get_current_device()) accum_loss = torch.zeros(1, device=get_current_device())
else: outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
accum_loss = None
# Run warmup forward passes. # Run warmup forward passes.
for i in range(num_warmup_microbatches): for i in range(num_warmup_microbatches):
input_obj = self.recv_forward() input_obj = self.recv_forward()
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs) output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
self.send_forward(output_obj) self.send_forward(output_obj)
if not forward_only: if not self.forward_only:
input_objs.append(input_obj) input_objs.append(input_obj)
output_objs.append(output_obj) output_objs.append(output_obj)
...@@ -324,16 +342,15 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -324,16 +342,15 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
last_iteration = i == (num_microbatches_remaining - 1) last_iteration = i == (num_microbatches_remaining - 1)
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs) output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
if forward_only:
if self.forward_only:
self.send_forward(output_obj) self.send_forward(output_obj)
if not last_iteration: if not last_iteration:
input_obj = self.recv_forward() input_obj = self.recv_forward()
else:
# TODO adjust here
self.send_forward(output_obj)
output_obj_grad = self.recv_backward()
else:
output_obj_grad = self.send_forward_recv_backward(output_obj)
# Add input_obj and output_obj to end of list. # Add input_obj and output_obj to end of list.
input_objs.append(input_obj) input_objs.append(input_obj)
output_objs.append(output_obj) output_objs.append(output_obj)
...@@ -345,13 +362,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -345,13 +362,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
if last_iteration: if last_iteration:
input_obj = None
else:
input_obj = self.recv_forward()
self.send_backward(input_obj_grad) self.send_backward(input_obj_grad)
else:
input_obj = self.send_backward_recv_forward(input_obj_grad)
# Run cooldown backward passes. # Run cooldown backward passes.
if not forward_only: if not self.forward_only:
for i in range(num_warmup_microbatches): for i in range(num_warmup_microbatches):
input_obj = input_objs.pop(0) input_obj = input_objs.pop(0)
output_obj = output_objs.pop(0) output_obj = output_objs.pop(0)
...@@ -360,6 +376,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): ...@@ -360,6 +376,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
self.send_backward(input_obj_grad) self.send_backward(input_obj_grad)
if not self.forward_only:
assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
if outputs is not None: if outputs is not None:
if isinstance(model, ModelWrapper): if isinstance(model, ModelWrapper):
model = model.unwrap() model = model.unwrap()
......
import contextlib
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import torch.distributed as dist import torch.distributed as dist
...@@ -68,45 +69,39 @@ class PipelineStageManager: ...@@ -68,45 +69,39 @@ class PipelineStageManager:
# for shardformer, hold model chunk id # for shardformer, hold model chunk id
self.model_chunk_id: Optional[int] = None self.model_chunk_id: Optional[int] = None
def is_first_stage(self, ignore_chunk: bool = False) -> bool:
    """Tell whether the current stage is the first pipeline stage.

    NOTE:
        1. With interleaved pipeline parallel, the first stage is the first model chunk
           (``model_chunk_id == 0``) on the first device (``stage == 0``).
        2. ``is_first_stage(ignore_chunk=True)`` only checks the device, i.e. ``stage == 0``.

    Args:
        ignore_chunk (bool): When True, the model chunk id is not consulted. Defaults to False.

    Returns:
        bool: Whether the current stage is the first stage.

    Raises:
        AssertionError: If ``ignore_chunk`` is not a bool, or if the pipeline is interleaved,
            ``ignore_chunk`` is False and ``self.model_chunk_id`` has not been set.
    """
    # A strict bool check catches callers that still pass an integer chunk id positionally
    # (e.g. ``is_first_stage(-1)``), which would otherwise be silently treated as truthy.
    assert isinstance(
        ignore_chunk, bool
    ), f"ignore_chunk must be a bool, got {type(ignore_chunk).__name__}"
    assert not self.is_interleave or (
        ignore_chunk or self.model_chunk_id is not None
    ), "model_chunk_id must be set when using interleaved pipeline unless ignore_chunk=True"

    if not self.is_interleave or ignore_chunk:
        return self.stage == 0
    return self.stage == 0 and self.model_chunk_id == 0
def is_last_stage(self, ignore_chunk: bool = False) -> bool:
    """Tell whether the current stage is the last pipeline stage.

    NOTE:
        1. With interleaved pipeline parallel, the last stage is the last model chunk
           (``model_chunk_id == num_model_chunks - 1``) on the last device
           (``stage == num_stages - 1``).
        2. ``is_last_stage(ignore_chunk=True)`` only checks the device,
           i.e. ``stage == num_stages - 1``.

    Args:
        ignore_chunk (bool): When True, the model chunk id is not consulted. Defaults to False.

    Returns:
        bool: Whether the current stage is the last stage.

    Raises:
        AssertionError: If ``ignore_chunk`` is not a bool, or if the pipeline is interleaved,
            ``ignore_chunk`` is False and ``self.model_chunk_id`` has not been set.
    """
    # A strict bool check catches callers that still pass an integer chunk id positionally
    # (e.g. ``is_last_stage(-1)``), which would otherwise be silently treated as truthy.
    assert isinstance(
        ignore_chunk, bool
    ), f"ignore_chunk must be a bool, got {type(ignore_chunk).__name__}"
    assert not self.is_interleave or (
        ignore_chunk or self.model_chunk_id is not None
    ), "model_chunk_id must be set when using interleaved pipeline unless ignore_chunk=True"

    on_last_device = self.stage == self.num_stages - 1
    if not self.is_interleave or ignore_chunk:
        return on_last_device
    return on_last_device and self.model_chunk_id == self.num_model_chunks - 1
@property @property
def num_stages(self) -> int: def num_stages(self) -> int:
...@@ -174,3 +169,10 @@ class PipelineStageManager: ...@@ -174,3 +169,10 @@ class PipelineStageManager:
ProcessGroup: Process group of the given stages. ProcessGroup: Process group of the given stages.
""" """
return self.pg_mesh.get_group_along_axis(self.pipeline_axis, stages) return self.pg_mesh.get_group_along_axis(self.pipeline_axis, stages)
@contextlib.contextmanager
def switch_model_chunk_id(self, model_chunk_id: int):
    """Temporarily set ``self.model_chunk_id`` for the duration of a ``with`` block.

    Args:
        model_chunk_id (int): Value assigned to ``self.model_chunk_id`` while the block runs.

    Yields:
        None. On exit the previous ``model_chunk_id`` is put back, also when the body raises.
    """
    old_model_chunk_id = self.model_chunk_id
    self.model_chunk_id = model_chunk_id
    try:
        yield
    finally:
        # Restore unconditionally: without this, an exception inside the ``with`` body would
        # leave the manager pointing at the temporary chunk id for every later query.
        self.model_chunk_id = old_model_chunk_id
...@@ -309,11 +309,11 @@ class BertPolicy(Policy): ...@@ -309,11 +309,11 @@ class BertPolicy(Policy):
num_model_chunks=stage_manager.num_model_chunks, num_model_chunks=stage_manager.num_model_chunks,
num_stages=stage_manager.num_stages, num_stages=stage_manager.num_stages,
) )
if stage_manager.is_first_stage(-1): if stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(module.embeddings) held_layers.append(module.embeddings)
for start_idx, end_idx in stage_indices: for start_idx, end_idx in stage_indices:
held_layers.extend(module.encoder.layer[start_idx:end_idx]) held_layers.extend(module.encoder.layer[start_idx:end_idx])
if stage_manager.is_last_stage(-1): if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(module.pooler) held_layers.append(module.pooler)
else: else:
...@@ -370,7 +370,7 @@ class BertForPreTrainingPolicy(BertPolicy): ...@@ -370,7 +370,7 @@ class BertForPreTrainingPolicy(BertPolicy):
"""Get pipeline layers for current stage""" """Get pipeline layers for current stage"""
held_layers = super().get_held_layers() held_layers = super().get_held_layers()
stage_manager = self.pipeline_stage_manager stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage(): if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.cls) held_layers.append(self.model.cls)
return held_layers return held_layers
...@@ -409,7 +409,7 @@ class BertLMHeadModelPolicy(BertPolicy): ...@@ -409,7 +409,7 @@ class BertLMHeadModelPolicy(BertPolicy):
""" """
held_layers = super().get_held_layers() held_layers = super().get_held_layers()
stage_manager = self.pipeline_stage_manager stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage(): if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.cls) held_layers.append(self.model.cls)
return held_layers return held_layers
...@@ -447,7 +447,7 @@ class BertForMaskedLMPolicy(BertPolicy): ...@@ -447,7 +447,7 @@ class BertForMaskedLMPolicy(BertPolicy):
""" """
held_layers = super().get_held_layers() held_layers = super().get_held_layers()
stage_manager = self.pipeline_stage_manager stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage(): if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.cls) held_layers.append(self.model.cls)
return held_layers return held_layers
...@@ -499,7 +499,7 @@ class BertForSequenceClassificationPolicy(BertPolicy): ...@@ -499,7 +499,7 @@ class BertForSequenceClassificationPolicy(BertPolicy):
""" """
held_layers = super().get_held_layers() held_layers = super().get_held_layers()
stage_manager = self.pipeline_stage_manager stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage(None if not stage_manager.is_interleave else -1): if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.dropout) held_layers.append(self.model.dropout)
held_layers.append(self.model.classifier) held_layers.append(self.model.classifier)
return held_layers return held_layers
...@@ -543,7 +543,7 @@ class BertForTokenClassificationPolicy(BertPolicy): ...@@ -543,7 +543,7 @@ class BertForTokenClassificationPolicy(BertPolicy):
""" """
held_layers = super().get_held_layers() held_layers = super().get_held_layers()
stage_manager = self.pipeline_stage_manager stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage(): if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.dropout) held_layers.append(self.model.dropout)
held_layers.append(self.model.classifier) held_layers.append(self.model.classifier)
return held_layers return held_layers
...@@ -574,7 +574,7 @@ class BertForNextSentencePredictionPolicy(BertPolicy): ...@@ -574,7 +574,7 @@ class BertForNextSentencePredictionPolicy(BertPolicy):
""" """
held_layers = super().get_held_layers() held_layers = super().get_held_layers()
stage_manager = self.pipeline_stage_manager stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage(): if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.cls) held_layers.append(self.model.cls)
return held_layers return held_layers
...@@ -617,7 +617,7 @@ class BertForMultipleChoicePolicy(BertPolicy): ...@@ -617,7 +617,7 @@ class BertForMultipleChoicePolicy(BertPolicy):
""" """
held_layers = super().get_held_layers() held_layers = super().get_held_layers()
stage_manager = self.pipeline_stage_manager stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage(): if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.dropout) held_layers.append(self.model.dropout)
held_layers.append(self.model.classifier) held_layers.append(self.model.classifier)
return held_layers return held_layers
...@@ -647,7 +647,7 @@ class BertForQuestionAnsweringPolicy(BertPolicy): ...@@ -647,7 +647,7 @@ class BertForQuestionAnsweringPolicy(BertPolicy):
""" """
held_layers = super().get_held_layers() held_layers = super().get_held_layers()
stage_manager = self.pipeline_stage_manager stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage(): if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.qa_outputs) held_layers.append(self.model.qa_outputs)
return held_layers return held_layers
......
...@@ -8,7 +8,11 @@ from torch.nn import Module ...@@ -8,7 +8,11 @@ from torch.nn import Module
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, RMSNorm, VocabParallelEmbedding1D from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, RMSNorm, VocabParallelEmbedding1D
from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward, get_lm_forward_with_dist_cross_entropy from ..modeling.llama import (
LlamaPipelineForwards,
get_llama_flash_attention_forward,
get_lm_forward_with_dist_cross_entropy,
)
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ["LlamaPolicy", "LlamaForCausalLMPolicy", "LlamaForSequenceClassificationPolicy"] __all__ = ["LlamaPolicy", "LlamaForCausalLMPolicy", "LlamaForSequenceClassificationPolicy"]
...@@ -140,21 +144,42 @@ class LlamaPolicy(Policy): ...@@ -140,21 +144,42 @@ class LlamaPolicy(Policy):
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
    """Register a pipeline-aware ``forward`` for ``model_cls`` in ``policy``.

    Under a pipeline parallel setting, the original HuggingFace forward is replaced by
    ``new_forward`` (bound with the stage manager and shard config) and the replacement is
    recorded in ``policy``. Nothing happens when no pipeline stage manager is configured.

    Args:
        model_cls (nn.Module): Model class used as the target key of the replacement.
        new_forward (Callable): Customized forward to install.
        policy (Dict): Policy mapping that receives the method replacement.
    """
    stage_manager = self.pipeline_stage_manager
    if stage_manager is None:
        return

    # A bare ``LlamaModel`` is the backbone itself; task heads keep it under ``.model``.
    if self.model.__class__.__name__ == "LlamaModel":
        module = self.model
    else:
        module = self.model.model

    if stage_manager.is_interleave:
        # Interleaved: layers are split over every (stage, chunk) pair and the per-chunk
        # index ranges are stored on the stage manager rather than bound into the forward.
        layers_per_stage = self.distribute_layers(
            len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks
        )
        stage_manager.stage_indices = Policy.get_stage_index(
            layers_per_stage,
            stage_manager.stage,
            num_model_chunks=stage_manager.num_model_chunks,
            num_stages=stage_manager.num_stages,
        )
        method_replacement = {
            "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config)
        }
    else:
        # Use ``self.distribute_layers`` like the interleaved branch and ``get_held_layers`` do,
        # so the bound ``stage_index`` is derived from the same layer distribution.
        layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
        stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
        method_replacement = {
            "forward": partial(
                new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
            )
        }

    self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
def get_held_layers(self) -> List[Module]: def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage.""" """Get pipeline layers for current stage."""
...@@ -167,6 +192,25 @@ class LlamaPolicy(Policy): ...@@ -167,6 +192,25 @@ class LlamaPolicy(Policy):
stage_manager = self.pipeline_stage_manager stage_manager = self.pipeline_stage_manager
held_layers = [] held_layers = []
if stage_manager.is_interleave:
assert stage_manager.num_model_chunks is not None
layers_per_stage = self.distribute_layers(
len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks
)
stage_indices = Policy.get_stage_index(
layers_per_stage,
stage_manager.stage,
num_model_chunks=stage_manager.num_model_chunks,
num_stages=stage_manager.num_stages,
)
if stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(module.embed_tokens)
for start_idx, end_idx in stage_indices:
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(module.norm)
else:
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
if stage_manager.is_first_stage(): if stage_manager.is_first_stage():
held_layers.append(module.embed_tokens) held_layers.append(module.embed_tokens)
...@@ -211,11 +255,9 @@ class LlamaForCausalLMPolicy(LlamaPolicy): ...@@ -211,11 +255,9 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
new_item = { new_item = {
LlamaForCausalLM: ModulePolicyDescription( LlamaForCausalLM: ModulePolicyDescription(
sub_module_replacement=[ sub_module_replacement=[
SubModuleReplacementDescription( SubModuleReplacementDescription(suffix="lm_head", target_module=Linear1D_Col)
suffix="lm_head", target_module=Linear1D_Col
)
], ],
method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)} method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
) )
} }
policy.update(new_item) policy.update(new_item)
...@@ -232,7 +274,7 @@ class LlamaForCausalLMPolicy(LlamaPolicy): ...@@ -232,7 +274,7 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
"""Get pipeline layers for current stage.""" """Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers() held_layers = super().get_held_layers()
if stage_manager.is_last_stage(): if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.lm_head) held_layers.append(self.model.lm_head)
return held_layers return held_layers
...@@ -285,7 +327,7 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy): ...@@ -285,7 +327,7 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
"""Get pipeline layers for current stage.""" """Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers() held_layers = super().get_held_layers()
if stage_manager.is_last_stage(): if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.score) held_layers.append(self.model.score)
return held_layers return held_layers
......
...@@ -57,9 +57,7 @@ def evaluate_model( ...@@ -57,9 +57,7 @@ def evaluate_model(
def evaluate_subset(dataloader: DataLoader): def evaluate_subset(dataloader: DataLoader):
use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage( is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage(ignore_chunk=True)
None if not booster.plugin.stage_manager.is_interleave else -1
)
accum_loss = torch.zeros(1, device=get_current_device()) accum_loss = torch.zeros(1, device=get_current_device())
for batch in dataloader: for batch in dataloader:
...@@ -136,9 +134,7 @@ def train_epoch( ...@@ -136,9 +134,7 @@ def train_epoch(
coordinator: DistCoordinator, coordinator: DistCoordinator,
): ):
use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage( is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage(ignore_chunk=True)
None if not booster.plugin.stage_manager.is_interleave else -1
)
print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_device) print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_device)
total_step = len(train_dataloader) total_step = len(train_dataloader)
......
...@@ -133,7 +133,9 @@ def main(): ...@@ -133,7 +133,9 @@ def main():
plugin = HybridParallelPlugin( plugin = HybridParallelPlugin(
tp_size=args.tp, tp_size=args.tp,
pp_size=args.pp, pp_size=args.pp,
pp_style="interleaved",
zero_stage=args.zero, zero_stage=args.zero,
num_model_chunks=2,
enable_fused_normalization=torch.cuda.is_available(), enable_fused_normalization=torch.cuda.is_available(),
num_microbatches=args.mbs, num_microbatches=args.mbs,
precision="bf16", precision="bf16",
......
import warnings
import pytest import pytest
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import colossalai import colossalai
from colossalai.cluster import ProcessGroupMesh from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.p2p import PipelineP2PCommunication from colossalai.pipeline.p2p import P2PDataType, P2PMetadata, PipelineP2PCommunication, TensorMetadata
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
WORLD_SIZE = 2
def check_p2p_communication(): def check_p2p_communication():
pg_mesh = ProcessGroupMesh(2) pg_mesh = ProcessGroupMesh(WORLD_SIZE)
stage_manager = PipelineStageManager(pg_mesh, 0) stage_manager = PipelineStageManager(pg_mesh, 0)
p2p = PipelineP2PCommunication(stage_manager) p2p = PipelineP2PCommunication(stage_manager)
rank = dist.get_rank() rank = dist.get_rank()
tensor = torch.ones(1, device=get_current_device()) tensor = torch.ones(1, device=get_current_device())
data = [
"tensor",
tensor,
[tensor],
{"tensor": tensor},
]
if rank == 0: if rank == 0:
p2p.send_forward(tensor) for obj in data:
p2p.send_forward([tensor]) p2p.send_forward(obj)
p2p.send_forward({"tensor": tensor}) for i in range(len(data)):
else: recv_obj = p2p.send_forward_recv_backward(data[i])
obj = p2p.recv_forward() assert recv_obj == data[-(i + 1)]
assert torch.equal(obj, tensor) elif rank == 1:
obj = p2p.recv_forward() for obj in data:
assert type(obj) == list and len(obj) == 1 and torch.equal(obj[0], tensor) recv_obj = p2p.recv_forward()
obj = p2p.recv_forward() assert recv_obj == obj
assert type(obj) == dict and "tensor" in obj and torch.equal(obj["tensor"], tensor) for i in range(len(data)):
p2p.send_backward(data[-(i + 1)])
recv_obj = p2p.recv_forward()
assert recv_obj == data[i]
if rank == 1: if rank == 1:
p2p.send_backward(tensor) for obj in data:
p2p.send_backward([tensor]) p2p.send_backward(obj)
p2p.send_backward({"tensor": tensor}) for i in range(len(data)):
else: recv_obj = p2p.send_backward_recv_forward(data[i])
obj = p2p.recv_backward() assert recv_obj == data[-(i + 1)]
assert torch.equal(obj, tensor) elif rank == 0:
obj = p2p.recv_backward() for obj in data:
assert type(obj) == list and len(obj) == 1 and torch.equal(obj[0], tensor) recv_obj = p2p.recv_backward()
obj = p2p.recv_backward() assert recv_obj == obj
assert type(obj) == dict and "tensor" in obj and torch.equal(obj["tensor"], tensor) for i in range(len(data)):
recv_obj = p2p.recv_backward()
p2p.send_forward(data[-(i + 1)])
assert recv_obj == data[i]
warnings.filterwarnings("error")
tensor_metadata = TensorMetadata(
key=None, shape=tensor.shape, dtype=tensor.dtype, requires_grad=tensor.requires_grad
)
comm_metadata = P2PMetadata(data_type=P2PDataType.Tensor, content=tensor_metadata)
if rank == 0:
recv_obj = p2p.send_forward_recv_backward(
tensor,
send_metadata=False,
metadata_recv=comm_metadata,
)
assert recv_obj == tensor
elif rank == 1:
recv_obj = p2p.recv_forward(metadata_recv=comm_metadata)
assert recv_obj == tensor
p2p.send_backward(tensor, send_metadata=False)
def run_dist(rank, world_size, port): def run_dist(rank, world_size, port):
...@@ -52,7 +85,7 @@ def run_dist(rank, world_size, port): ...@@ -52,7 +85,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist @pytest.mark.dist
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_pipeline_p2p(): def test_pipeline_p2p():
spawn(run_dist, 2) spawn(run_dist, WORLD_SIZE)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -37,9 +37,10 @@ def pp_linear_fwd( ...@@ -37,9 +37,10 @@ def pp_linear_fwd(
stage_mgr: PipelineStageManager = None, stage_mgr: PipelineStageManager = None,
model_chunk_id: int = None, model_chunk_id: int = None,
): ):
if stage_mgr.is_first_stage(model_chunk_id): with stage_mgr.switch_model_chunk_id(model_chunk_id):
if stage_mgr.is_first_stage():
return {"input_obj": forward(data)} return {"input_obj": forward(data)}
elif stage_mgr.is_last_stage(model_chunk_id): elif stage_mgr.is_last_stage():
return forward(input_obj) return forward(input_obj)
else: else:
return {"input_obj": forward(input_obj)} return {"input_obj": forward(input_obj)}
...@@ -107,7 +108,7 @@ def run_pp( ...@@ -107,7 +108,7 @@ def run_pp(
) )
# check loss # check loss
if stage_manager.is_last_stage(-1): if stage_manager.is_last_stage(ignore_chunk=True):
assert torch.allclose(torch_loss, pp_ret["loss"]) assert torch.allclose(torch_loss, pp_ret["loss"])
# check gradients # check gradients
...@@ -119,6 +120,7 @@ def run_pp( ...@@ -119,6 +120,7 @@ def run_pp(
# step # step
torch_optimizer.step() torch_optimizer.step()
pp_optimizer.step() pp_optimizer.step()
pp_optimizer.zero_grad()
# check updated param # check updated param
for i in range(num_model_chunk): for i in range(num_model_chunk):
...@@ -126,6 +128,24 @@ def run_pp( ...@@ -126,6 +128,24 @@ def run_pp(
assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight) assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight)
assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias) assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias)
# forward only
with torch.no_grad():
torch_output = torch_model(input_list[0])
torch_loss = criterion(torch_output)
pp_ret = schedule.forward_backward_step(
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
)
if stage_manager.is_last_stage(ignore_chunk=True):
assert torch.allclose(torch_loss, pp_ret["loss"])
for layer in sharded_model:
if layer.weight.grad is None:
assert layer.weight.grad is None and layer.bias.grad is None
else:
assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad))
assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad))
@pytest.mark.dist @pytest.mark.dist
@pytest.mark.parametrize("num_microbatch", [4, 12]) @pytest.mark.parametrize("num_microbatch", [4, 12])
......
...@@ -4,6 +4,7 @@ from types import MethodType ...@@ -4,6 +4,7 @@ from types import MethodType
import pytest import pytest
import torch import torch
import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
import colossalai import colossalai
...@@ -14,21 +15,26 @@ from colossalai.pipeline.stage_manager import PipelineStageManager ...@@ -14,21 +15,26 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all from colossalai.testing.random import seed_all
DIM = 8
NUM_LAYER = 8
class MlpModel(nn.Module):
    """Stack of ``NUM_LAYER`` square linear layers (``DIM`` -> ``DIM``) applied in sequence."""

    def __init__(self):
        super().__init__()
        linear_stack = []
        for _ in range(NUM_LAYER):
            linear_stack.append(nn.Linear(DIM, DIM))
        self.layers = nn.ModuleList(linear_stack)

    def forward(self, x):
        """Feed ``x`` through every layer in order and return the final activation."""
        hidden = x
        for linear in self.layers:
            hidden = linear(hidden)
        return hidden
def pp_linear_fwd( def pp_linear_fwd(
forward, data: torch.Tensor = None, input_obj: torch.Tensor = None, stage_mgr: PipelineStageManager = None forward,
data: torch.Tensor = None,
input_obj: torch.Tensor = None,
stage_mgr: PipelineStageManager = None,
): ):
if stage_mgr.is_first_stage(): if stage_mgr.is_first_stage():
return {"input_obj": forward(data)} return {"input_obj": forward(data)}
...@@ -38,34 +44,45 @@ def pp_linear_fwd( ...@@ -38,34 +44,45 @@ def pp_linear_fwd(
return {"input_obj": forward(input_obj)} return {"input_obj": forward(input_obj)}
def examine_pp(): def examine_pp(num_microbatch: int, batch_size: int):
""" """
This test is to examine the correctness of 1F1B, compared with torch. This test is to examine the correctness of 1F1B, compared with torch.
Be aware it contains some hardcodes. Be aware it contains some hardcodes.
""" """
world_size = torch.distributed.get_world_size() world_size = dist.get_world_size()
local_rank = torch.distributed.get_rank() dist.get_rank()
seed_all(1453) seed_all(1453)
NUM_MICRO_BATCHS = 4
BATCH_SIZE = 4
# create models # create models
torch_model = MlpModel().cuda() torch_model = MlpModel().cuda()
pp_model = copy.deepcopy(torch_model).cuda() pp_model = copy.deepcopy(torch_model).cuda()
DP_DIM, PP_DIM, TP_DIM = 0, 1, 2 pg_mesh = ProcessGroupMesh(world_size)
pg_mesh = ProcessGroupMesh(1, world_size, 1) stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM) schedule = OneForwardOneBackwardSchedule(stage_manager, num_microbatches=num_microbatch)
schedule = OneForwardOneBackwardSchedule(stage_manager, num_microbatches=NUM_MICRO_BATCHS)
rank = dist.get_rank()
for idx, (_, sub_model) in enumerate(pp_model.named_children()): sharded_model = torch.nn.ModuleList()
if idx % (world_size) == local_rank: num_local_layer = NUM_LAYER // world_size
sharded_model = sub_model.cuda() for idx, sub_model in enumerate(pp_model.layers):
if idx // num_local_layer == rank:
sharded_model.append(sub_model.cuda())
assert len(sharded_model) == num_local_layer
def custom_fwd(self, x):
for layer in self._modules.values():
x = layer(x)
return x
sharded_model._forward = sharded_model.forward sharded_model._forward = MethodType(custom_fwd, sharded_model)
sharded_model.forward = MethodType(partial(pp_linear_fwd, stage_mgr=stage_manager), sharded_model._forward) sharded_model.forward = MethodType(
partial(
pp_linear_fwd,
stage_mgr=stage_manager,
),
sharded_model._forward,
)
# create optimizer # create optimizer
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1) torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
...@@ -73,19 +90,15 @@ def examine_pp(): ...@@ -73,19 +90,15 @@ def examine_pp():
# create # create
seed_all(1453) seed_all(1453)
if stage_manager.is_first_stage(): input_list = [torch.rand(batch_size, DIM).cuda()]
input_list = [torch.rand(BATCH_SIZE, 4).cuda()] dist.all_reduce(input_list[0])
else:
input_list = [torch.zeros(BATCH_SIZE, 4).cuda()]
torch.distributed.all_reduce(input_list[0])
criterion = lambda x, y: torch.mean(x) criterion = lambda x, *arg, **kwargs: (x * x).mean()
# forward and backward # forward and backward
torch_output = torch_model(input_list[0]) torch_output = torch_model(input_list[0])
torch_loss = criterion(torch_output, _) torch_loss = criterion(torch_output)
torch_loss.backward() torch_loss.backward()
pp_ret = schedule.forward_backward_step( pp_ret = schedule.forward_backward_step(
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
) )
...@@ -95,34 +108,66 @@ def examine_pp(): ...@@ -95,34 +108,66 @@ def examine_pp():
assert torch.allclose(torch_loss, pp_ret["loss"]) assert torch.allclose(torch_loss, pp_ret["loss"])
# check gradients # check gradients
torch_grad = [] for i in range(len(sharded_model)):
for torch_p in torch_model.parameters(): idx = rank * num_local_layer + i
torch_grad.append(torch_p.grad.data) assert torch.allclose(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
for idx, pp_p in enumerate(sharded_model.parameters()): assert torch.allclose(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)
assert torch.allclose(torch_grad[idx + local_rank * 2], pp_p.grad.data)
# step # step
torch_optimizer.step() torch_optimizer.step()
pp_optimizer.step() pp_optimizer.step()
pp_optimizer.zero_grad()
# check updated param # check updated param
torch_param = [] for i in range(len(sharded_model)):
for torch_p in torch_model.parameters(): idx = rank * num_local_layer + i
torch_param.append(torch_p.data) assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight)
for idx, pp_p in enumerate(sharded_model.parameters()): assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias)
assert torch.allclose(torch_param[idx + local_rank * 2], pp_p.data)
# forward only
with torch.no_grad():
torch_output = torch_model(input_list[0])
torch_loss = criterion(torch_output)
def run_dist(rank, world_size, port): pp_ret = schedule.forward_backward_step(
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
)
if stage_manager.is_last_stage():
assert torch.allclose(torch_loss, pp_ret["loss"])
for layer in sharded_model:
if layer.weight.grad is None:
assert layer.weight.grad is None and layer.bias.grad is None
else:
assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad))
assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad))
def run_dist(
rank: int,
world_size: int,
port: int,
num_microbatch: int,
batch_size: int,
):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
examine_pp() examine_pp(num_microbatch, batch_size)
@pytest.mark.dist @pytest.mark.dist
@pytest.mark.parametrize("num_microbatch", [4, 12])
@pytest.mark.parametrize("batch_size", [12])
@pytest.mark.parametrize("world_size", [2, 4])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_pp(): def test_pp(num_microbatch: int, batch_size: int, world_size: int):
spawn(run_dist, 2) assert NUM_LAYER % world_size == 0
spawn(
run_dist,
world_size,
num_microbatch=num_microbatch,
batch_size=batch_size,
)
if __name__ == "__main__": if __name__ == "__main__":
test_pp() test_pp(num_microbatch=4, batch_size=4, world_size=4)
...@@ -203,7 +203,7 @@ def check_output_hidden_state( ...@@ -203,7 +203,7 @@ def check_output_hidden_state(
): ):
org_hidden_state = org_output.last_hidden_state org_hidden_state = org_output.last_hidden_state
if stage_manager and stage_manager.is_last_stage(): if stage_manager and stage_manager.is_last_stage(ignore_chunk=True):
sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"] sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"]
else: else:
sharded_hidden_state = sharded_output.last_hidden_state sharded_hidden_state = sharded_output.last_hidden_state
...@@ -229,6 +229,10 @@ def check_weight( ...@@ -229,6 +229,10 @@ def check_weight(
org_weight = getattr_(org_model, suffix).weight org_weight = getattr_(org_model, suffix).weight
sharded_weight = getattr_(sharded_model, suffix).weight sharded_weight = getattr_(sharded_model, suffix).weight
# skip if layer is not held by this process
if sharded_weight is None:
continue
if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight): if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight):
sharded_weight_list = [ sharded_weight_list = [
torch.zeros_like(sharded_weight).to("cuda") for _ in range(dist.get_world_size(tp_group)) torch.zeros_like(sharded_weight).to("cuda") for _ in range(dist.get_world_size(tp_group))
......
...@@ -37,6 +37,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ...@@ -37,6 +37,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
norm_layer_for_check = ["encoder.layer[0].attention.output.LayerNorm", "embeddings.LayerNorm"] norm_layer_for_check = ["encoder.layer[0].attention.output.LayerNorm", "embeddings.LayerNorm"]
col_layer_for_check = ["encoder.layer[0].output.dense"] col_layer_for_check = ["encoder.layer[0].output.dense"]
row_layer_for_check = ["embeddings.word_embeddings", "encoder.layer[0].intermediate.dense"] row_layer_for_check = ["embeddings.word_embeddings", "encoder.layer[0].intermediate.dense"]
weight_layer_for_check = ["encoder.layer[0].output.dense", "encoder.layer[1].output.dense"]
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step. # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
grads_to_check = {} grads_to_check = {}
...@@ -44,7 +45,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ...@@ -44,7 +45,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
atol, rtol = 1e-4, 1e-3 atol, rtol = 1e-4, 1e-3
else: else:
atol, rtol = 5e-3, 5e-3 atol, rtol = 5e-3, 5e-3
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: if (stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True)) and booster.plugin.zero_stage == 0:
col_layer_grads = get_grad_tensors_for_check( col_layer_grads = get_grad_tensors_for_check(
bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
) )
...@@ -72,7 +73,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ...@@ -72,7 +73,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
sharded_optimizer.step() sharded_optimizer.step()
# check last hidden state & loss # check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage(): if stage_manager is None or stage_manager.is_last_stage(ignore_chunk=True):
if test_config["precision"] == "fp32": if test_config["precision"] == "fp32":
atol, rtol = 1e-5, 1e-3 atol, rtol = 1e-5, 1e-3
else: else:
...@@ -87,8 +88,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ...@@ -87,8 +88,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
atol, rtol = 5e-3, 1e-3 atol, rtol = 5e-3, 1e-3
else: else:
atol, rtol = 5e-3, 5e-3 atol, rtol = 5e-3, 5e-3
if stage_manager is None or stage_manager.is_first_stage(): if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1)
# check grads # check grads
check_all_grad_tensors(grads_to_check) check_all_grad_tensors(grads_to_check)
...@@ -183,6 +184,17 @@ def run_bert_test(test_config): ...@@ -183,6 +184,17 @@ def run_bert_test(test_config):
"zero_stage": 1, "zero_stage": 1,
"initial_scale": 1, "initial_scale": 1,
}, },
{
"tp_size": 2,
"pp_size": 2,
"pp_style": "interleaved",
"num_model_chunks": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"precision": "fp16",
"zero_stage": 1,
"initial_scale": 1,
},
], ],
) )
def run_bert_3d_test(test_config): def run_bert_3d_test(test_config):
......
...@@ -44,7 +44,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ...@@ -44,7 +44,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step. # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
grads_to_check = {} grads_to_check = {}
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: if (stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True)) and booster.plugin.zero_stage == 0:
if test_config["precision"] == "fp32": if test_config["precision"] == "fp32":
atol, rtol = 1e-6, 1e-4 atol, rtol = 1e-6, 1e-4
else: else:
...@@ -63,7 +63,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ...@@ -63,7 +63,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
sharded_optimizer.step() sharded_optimizer.step()
# check last hidden state & loss # check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage(): if stage_manager is None or stage_manager.is_last_stage(ignore_chunk=True):
if test_config["precision"] == "fp32": if test_config["precision"] == "fp32":
atol, rtol = 1e-5, 1e-3 atol, rtol = 1e-5, 1e-3
else: else:
...@@ -75,7 +75,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ...@@ -75,7 +75,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# check weights # check weights
if stage_manager is None or stage_manager.is_first_stage(): if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
if test_config["precision"] == "fp32": if test_config["precision"] == "fp32":
atol, rtol = 1e-4, 1e-3 atol, rtol = 1e-4, 1e-3
else: else:
...@@ -179,6 +179,17 @@ def run_llama_test(test_config): ...@@ -179,6 +179,17 @@ def run_llama_test(test_config):
"zero_stage": 1, "zero_stage": 1,
"initial_scale": 1, "initial_scale": 1,
}, },
{
"tp_size": 2,
"pp_size": 2,
"pp_style": "interleaved",
"num_model_chunks": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"precision": "fp16",
"zero_stage": 1,
"initial_scale": 1,
},
], ],
) )
def run_llama_3d_test(test_config): def run_llama_3d_test(test_config):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment