Unverified commit 0e25fcc4 authored by Masaki Kozuki, committed by GitHub

Cherry-pick Megatron-LM's changes in pipeline model parallel for T5 (#1232)

* update parallel_state

* update pipeline common funcs - forward_step and backward_step

* update pipelining w/o interleaving

* type hint

* merge utils into without_interleaving

  Motivation: the functions in utils are only used by `forward_backward_pipelining_without_interleaving`.

* fix handling of `model_type`

* fix import of DDP

* update set_input_tensor method

* fix

* cosmetic

* update model

* refactor pipeline test scripts
parent ab7af058
...@@ -28,3 +28,8 @@
class AttnMaskType(enum.Enum):
    padding = 1
    causal = 2


class ModelType(enum.Enum):
    encoder_or_decoder = 1
    encoder_and_decoder = 2
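The schedules below branch on this enum: a model opts into the T5-style handling by exposing a `model_type` attribute. A minimal sketch (the `MyT5LikeModel` class is hypothetical; only the enum and the attribute name come from this change):

```python
import torch

from apex.transformer.enums import ModelType


class MyT5LikeModel(torch.nn.Module):
    """Hypothetical encoder-decoder model used only for illustration."""

    def __init__(self, hidden_size: int = 16) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(hidden_size, hidden_size)
        # The pipeline schedules read this attribute (via get_model_type, added later
        # in this change) to decide whether encoder/decoder split handling is needed.
        self.model_type = ModelType.encoder_and_decoder

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)
```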
...@@ -13,12 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model and data parallel groups."""
from typing import Tuple, Optional

import torch

from apex.transformer.utils import ensure_divisibility
...@@ -35,6 +33,7 @@ _DATA_PARALLEL_GROUP = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None

# These values enable us to change the mpu sizes on the fly.
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
...@@ -56,14 +55,19 @@ def is_unitialized():
def initialize_model_parallel(
    tensor_model_parallel_size_: int = 1,
    pipeline_model_parallel_size_: int = 1,
    virtual_pipeline_model_parallel_size_: Optional[int] = None,
    pipeline_model_parallel_split_rank_: Optional[int] = None,
) -> None:
    """
    Initialize model data parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used to parallelize model tensor.
        pipeline_model_parallel_size: number of GPUs used to parallelize model pipeline.
        virtual_pipeline_model_parallel_size: number of virtual stages (interleaved pipeline).
        pipeline_model_parallel_split_rank: for models with both encoder and decoder, rank in pipeline with split point.

    Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
...@@ -106,6 +110,10 @@ def initialize_model_parallel(
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size_

    if pipeline_model_parallel_split_rank_ is not None:
        global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
        _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank_

    rank = torch.distributed.get_rank()

    # Build the data-parallel groups.
...@@ -231,6 +239,44 @@ def is_rank_in_embedding_group(ignore_virtual=False):
    return False


def is_pipeline_stage_before_split(rank=None):
    """Return True if pipeline stage executes encoder block for a model
    with both encoder and decoder."""
    if get_pipeline_model_parallel_world_size() == 1:
        return True
    if rank is None:
        rank = get_pipeline_model_parallel_rank()
    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
        return True
    if rank < _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
        return True
    return False


def is_pipeline_stage_after_split(rank=None):
    """Return True if pipeline stage executes decoder block for a model
    with both encoder and decoder."""
    if get_pipeline_model_parallel_world_size() == 1:
        return True
    if rank is None:
        rank = get_pipeline_model_parallel_rank()
    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
        return True
    if rank >= _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
        return True
    return False


def is_pipeline_stage_at_split():
    """Return true if pipeline stage executes decoder block and next
    stage executes encoder block for a model with both encoder and
    decoder."""
    rank = get_pipeline_model_parallel_rank()
    return is_pipeline_stage_before_split(rank) and is_pipeline_stage_after_split(rank + 1)


def set_tensor_model_parallel_world_size(world_size):
    """Set the tensor model parallel size"""
    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
......
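A usage sketch of the new split-rank plumbing, assuming 8 pipeline stages with the encoder on the first 4 of them (the process-group setup and sizes are illustrative, not part of the change):

```python
import torch

from apex.transformer import parallel_state

# Assumed launch: 8 processes, e.g. via torchrun, one GPU each.
torch.distributed.init_process_group(backend="nccl")
parallel_state.initialize_model_parallel(
    tensor_model_parallel_size_=1,
    pipeline_model_parallel_size_=8,
    virtual_pipeline_model_parallel_size_=None,
    pipeline_model_parallel_split_rank_=4,  # stages 0-3 run the encoder, stages 4-7 the decoder
)

rank = parallel_state.get_pipeline_model_parallel_rank()
print(rank, parallel_state.is_pipeline_stage_before_split(), parallel_state.is_pipeline_stage_after_split())
# e.g. rank 3 -> (True, False), rank 4 -> (False, True);
# is_pipeline_stage_at_split() is True only on rank 3, the last encoder stage.
```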
...@@ -190,8 +190,8 @@ def recv_forward(
    """Receive tensor from previous rank in pipeline (forward receive)."""
    if parallel_state.is_pipeline_first_stage():
        return None
    # if timers is not None:
    #     timers("forward-recv").start()
    input_tensor, _ = _communicate(
        tensor_send_next=None,
        tensor_send_prev=None,
...@@ -201,8 +201,8 @@ def recv_forward(
        override_scatter_gather_tensors_in_pipeline=override_scatter_gather_tensors_in_pipeline,
        dtype_=_get_current_dtype(dtype),
    )
    # if timers is not None:
    #     timers("forward-recv").stop()
    return input_tensor
...@@ -211,12 +211,12 @@ def recv_backward(
    *,
    dtype: Optional[torch.dtype] = None,
    timers: _Timers = None,
) -> torch.Tensor:
    """Receive tensor from next rank in pipeline (backward receive)."""
    if parallel_state.is_pipeline_last_stage():
        return None
    # if timers is not None:
    #     timers("backward-recv").start()
    _, output_tensor_grad = _communicate(
        tensor_send_next=None,
        tensor_send_prev=None,
...@@ -225,8 +225,8 @@ def recv_backward(
        tensor_shape=tensor_shape,
        dtype_=_get_current_dtype(dtype),
    )
    # if timers is not None:
    #     timers("backward-recv").stop()
    return output_tensor_grad
...@@ -241,8 +241,8 @@ def send_forward(
    """Send tensor to next rank in pipeline (forward send)."""
    if parallel_state.is_pipeline_last_stage():
        return
    # if timers is not None:
    #     timers("forward-send").start()
    _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=None,
...@@ -252,8 +252,8 @@ def send_forward(
        tensor_shape=tensor_shape,
        dtype_=_get_current_dtype(dtype),
    )
    # if timers is not None:
    #     timers("forward-send").stop()


def send_backward(
...@@ -266,8 +266,8 @@ def send_backward(
    """Send tensor to previous rank in pipeline (backward send)."""
    if parallel_state.is_pipeline_first_stage():
        return
    # if timers is not None:
    #     timers("backward-send").start()
    _communicate(
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
...@@ -276,8 +276,8 @@ def send_backward(
        tensor_shape=tensor_shape,
        dtype_=_get_current_dtype(dtype),
    )
    # if timers is not None:
    #     timers("backward-send").stop()


def send_forward_recv_backward(
...@@ -286,12 +286,12 @@ def send_forward_recv_backward(
    *,
    dtype: Optional[torch.dtype] = None,
    timers: _Timers = None,
) -> Union[None, torch.Tensor]:
    """Batched send and recv with next rank in pipeline."""
    if parallel_state.is_pipeline_last_stage():
        return None
    # if timers is not None:
    #     timers("forward-send-backward-recv").start()
    _, output_tensor_grad = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=None,
...@@ -300,8 +300,8 @@ def send_forward_recv_backward(
        tensor_shape=tensor_shape,
        dtype_=_get_current_dtype(dtype),
    )
    # if timers is not None:
    #     timers("forward-send-backward-recv").stop()
    return output_tensor_grad
...@@ -311,12 +311,12 @@ def send_backward_recv_forward(
    *,
    dtype: Optional[torch.dtype] = None,
    timers: _Timers = None,
) -> Union[None, torch.Tensor]:
    """Batched send and recv with previous rank in pipeline."""
    if parallel_state.is_pipeline_first_stage():
        return None
    # if timers is not None:
    #     timers("backward-send-forward-recv").start()
    input_tensor, _ = _communicate(
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
...@@ -325,8 +325,8 @@ def send_backward_recv_forward(
        tensor_shape=tensor_shape,
        dtype_=_get_current_dtype(dtype),
    )
    # if timers is not None:
    #     timers("backward-send-forward-recv").stop()
    return input_tensor
...@@ -339,8 +339,8 @@ def send_forward_recv_forward(
    timers: _Timers = None,
) -> torch.Tensor:
    """Batched recv from previous rank and send to next rank in pipeline."""
    # if timers is not None:
    #     timers("forward-send-forward-recv").start()
    input_tensor, _ = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=None,
...@@ -349,8 +349,8 @@ def send_forward_recv_forward(
        tensor_shape=tensor_shape,
        dtype_=_get_current_dtype(dtype),
    )
    # if timers is not None:
    #     timers("forward-send-forward-recv").stop()
    return input_tensor
...@@ -363,8 +363,8 @@ def send_backward_recv_backward(
    timers: _Timers = None,
) -> torch.Tensor:
    """Batched recv from next rank and send to previous rank in pipeline."""
    # if timers is not None:
    #     timers("backward-send-backward-recv").start()
    _, output_tensor_grad = _communicate(
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
...@@ -373,8 +373,8 @@ def send_backward_recv_backward(
        tensor_shape=tensor_shape,
        dtype_=_get_current_dtype(dtype),
    )
    # if timers is not None:
    #     timers("backward-send-backward-recv").stop()
    return output_tensor_grad
...@@ -387,10 +387,10 @@ def send_forward_backward_recv_forward_backward(
    *,
    dtype: Optional[torch.dtype] = None,
    timers: _Timers = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Batched send and recv with previous and next ranks in pipeline."""
    # if timers is not None:
    #     timers("forward-backward-send-forward-backward-recv").start()
    input_tensor, output_tensor_grad = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=input_tensor_grad,
...@@ -399,6 +399,6 @@ def send_forward_backward_recv_forward_backward(
        tensor_shape=tensor_shape,
        dtype_=_get_current_dtype(dtype),
    )
    # if timers is not None:
    #     timers("forward-backward-send-forward-backward-recv").stop()
    return input_tensor, output_tensor_grad
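For orientation, this is roughly how a schedule drives these primitives on an intermediate stage during warmup; the shape and the local forward function are placeholders, not part of this file:

```python
from apex.transformer.pipeline_parallel import p2p_communication

# Illustrative only; the real values come from the schedule that calls into this module.
tensor_shape = [128, 4, 1024]  # [sequence_length, micro_batch_size, hidden_size]


def run_one_warmup_step(my_stage_forward):  # `my_stage_forward` is a hypothetical local forward
    # Pull the activation from the previous stage, compute locally, push downstream.
    input_tensor = p2p_communication.recv_forward(tensor_shape=tensor_shape)
    output_tensor = my_stage_forward(input_tensor)
    p2p_communication.send_forward(output_tensor, tensor_shape=tensor_shape)
    return input_tensor, output_tensor
```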
# NOTE(mkozuki): For simplicity, tentatively `timers` related operations are commented out.
from typing import Any, Callable, Dict, List, Tuple, Union, Optional, Sequence

import torch

from apex.transformer import parallel_state
from apex.transformer.enums import ModelType
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.utils import listify_model
from apex.transformer.pipeline_parallel.utils import unwrap_model
from apex.transformer.pipeline_parallel.utils import get_model_type
from apex.transformer.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes
...@@ -19,8 +21,8 @@ def build_model(
    model_provider_func: Callable[[Any, Dict[str, Any]], torch.nn.Module],
    wrap_with_ddp: bool = True,
    virtual_pipeline_model_parallel_size: Optional[int] = None,
    *args: Any,
    **kwargs: Any,
) -> List[torch.nn.Module]:
    """Build the model satisfying pipeline model parallel requirements.
...@@ -110,6 +112,7 @@ def _get_params_for_weight_decay_optimization(
    model: Union[torch.nn.Module, List[torch.nn.Module]],
) -> Dict[str, torch.nn.Parameter]:
    """Divide params into with-weight-decay and without-weight-decay groups.

    Layernorms and biases will have no weight decay but the rest will.
    """
    modules = listify_model(model)
...@@ -137,13 +140,12 @@ def forward_step(
    forward_step_func: FwdStepFunc,
    batch: Batch,
    model: torch.nn.Module,
    input_tensor: Optional[Union[torch.Tensor, List[torch.Tensor]]],
    losses_reduced: List[torch.Tensor],
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
    """Forward step for passed-in model.

    If first stage, input tensor is obtained from batch, otherwise passed-in input_tensor is used.

    Returns output tensor.
...@@ -161,12 +163,16 @@ def forward_step(
    # timers = get_timers()
    # timers("forward-compute").start()
    unwrapped_model = unwrap_model(model)
    model_type = get_model_type(unwrapped_model)
    # NOTE (mkozuki): The passed `model` is expected to implement `set_input_tensor`.
    # See https://github.com/NVIDIA/Megatron-LM/blob/5ac5571ba0265af4c491ee0af1508ca7589450c6/megatron/model/transformer.py#L679  # NOQA
    # for the details of `set_input_tensor`.
    unwrap_output_tensor = not isinstance(input_tensor, list)
    if unwrap_output_tensor:
        input_tensor = [input_tensor]
    unwrapped_model.set_input_tensor(input_tensor)
    output_tensor, loss_func = forward_step_func(batch, model)
    # print(f"forward_step| pipeline rank: {parallel_state.get_pipeline_model_parallel_rank()} is_pipeline_last_stage?: {parallel_state.is_pipeline_last_stage()}")
    if parallel_state.is_pipeline_last_stage():
        output_tensor = loss_func(output_tensor)
        loss, loss_reduced = output_tensor
...@@ -174,14 +180,22 @@ def forward_step(
        losses_reduced.append(loss_reduced)
    # timers("forward-compute").stop()

    # If T5 model (or other model with encoder and decoder)
    # and in decoder stack, then send encoder_hidden_state
    # downstream as well.
    if parallel_state.is_pipeline_stage_after_split() and model_type == ModelType.encoder_and_decoder:
        return [output_tensor, input_tensor[-1]]
    if unwrap_output_tensor:
        return output_tensor
    return [output_tensor]
def backward_step(
    input_tensor: Optional[torch.Tensor],
    output_tensor: torch.Tensor,
    output_tensor_grad: Optional[torch.Tensor],
    model_type: ModelType,
) -> Union[None, torch.Tensor, Sequence[torch.Tensor]]:
    """Backward step through passed-in output tensor.

    If last stage, output_tensor_grad is None, otherwise gradient of loss
...@@ -200,19 +214,39 @@ def backward_step(
    # timers = get_timers()
    # timers("backward-compute").start()

    # Retain the grad on the input_tensor.
    unwrap_input_tensor_grad = not isinstance(input_tensor, list)
    if unwrap_input_tensor_grad:
        input_tensor = [input_tensor]
    for x in input_tensor:
        if x is not None:
            x.retain_grad()

    if not isinstance(output_tensor, list):
        output_tensor = [output_tensor]
    if not isinstance(output_tensor_grad, list):
        output_tensor_grad = [output_tensor_grad]

    # if parallel_state.get_pipeline_model_parallel_rank() == 0:
    #     print(f"{input_tensor}, {output_tensor}, {output_tensor_grad}")

    # Backward pass.
    torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0])

    # Collect the grad of the input_tensor.
    input_tensor_grad = [None]
    if input_tensor is not None:
        input_tensor_grad = []
        for x in input_tensor:
            input_tensor_grad.append(None if x is None else x.grad)

    # Handle single skip connection if it exists (encoder_hidden_state in model with encoder and decoder).
    if (
        parallel_state.get_pipeline_model_parallel_world_size() > 1 and
        parallel_state.is_pipeline_stage_after_split() and
        model_type == ModelType.encoder_and_decoder
    ):
        if output_tensor_grad[1] is not None:
            # todo (mkozuki): Replace the inplace add with `+= output_tensor_grad[1]`?
            input_tensor_grad[-1].add_(output_tensor_grad[1])

    # timers("backward-compute").stop()
    return input_tensor_grad[0] if unwrap_input_tensor_grad else input_tensor_grad
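A self-contained toy (plain autograd, no pipeline or distributed state) showing the list-based pattern `backward_step` now relies on: retain grads on every stage input, backprop through the first output, then fold in the extra gradient a decoder stage receives for the forwarded encoder hidden state:

```python
import torch

# Two inputs to a hypothetical decoder stage: its own hidden state and the
# encoder hidden state that forward_step forwarded along with the output.
decoder_in = torch.randn(2, 3, requires_grad=True)
encoder_hidden = torch.randn(2, 3, requires_grad=True)
input_tensor = [decoder_in, encoder_hidden]
for x in input_tensor:
    x.retain_grad()

output_tensor = [(decoder_in * 2 + encoder_hidden).sum()]  # stand-in for the stage output / loss
output_tensor_grad = [None, torch.ones(2, 3)]  # second entry: grad w.r.t. the forwarded encoder state

torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0])
input_tensor_grad = [None if x is None else x.grad for x in input_tensor]

# The skip-connection handling: accumulate the encoder-state gradient coming from
# the next stage into the gradient that will be sent back toward the encoder.
if output_tensor_grad[1] is not None:
    input_tensor_grad[-1].add_(output_tensor_grad[1])

print(input_tensor_grad[0].shape, input_tensor_grad[1].shape)
```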
...@@ -6,7 +6,9 @@ import torch
from apex.transformer.pipeline_parallel.utils import listify_model
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.utils import get_kth_microbatch
from apex.transformer.pipeline_parallel.utils import get_model_type
from apex.transformer.pipeline_parallel.schedules.common import Batch
from apex.transformer.pipeline_parallel.schedules.common import FwdStepFunc
from apex.transformer.pipeline_parallel.schedules.common import forward_step
from apex.transformer.pipeline_parallel.schedules.common import backward_step
from apex.transformer.log_util import get_transformer_logger
...@@ -58,6 +60,7 @@ def forward_backward_no_pipelining(
        msg = f"`model` is expected be a `nn.Module`, but {type(model)}"
        raise RuntimeError(msg)
    model = model[0]
    model_type = get_model_type(model)

    context_handler = placeholder_handler
    if isinstance(model, torch.nn.parallel.distributed.DistributedDataParallel):
...@@ -75,7 +78,7 @@ def forward_backward_no_pipelining(
                forward_step_func, cur_micro_batch, model, input_tensor, losses_reduced)
            if not forward_only:
                _logger.debug("Call `backward_step`")
                backward_step(input_tensor, output_tensor, output_tensor_grad, model_type=model_type)

    # Run computation for last microbatch out of context handler (want to
    # synchronize gradients).
...@@ -86,6 +89,6 @@ def forward_backward_no_pipelining(
    )
    if not forward_only:
        _logger.debug("Call `backward_step`")
        backward_step(input_tensor, output_tensor, output_tensor_grad, model_type=model_type)

    return losses_reduced
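A rough usage sketch (module path and calling convention assumed to match the other schedules in this change; distributed state and the microbatch calculator are presumed to be set up already):

```python
import torch

from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import (
    forward_backward_no_pipelining,
)
from apex.transformer.testing.commons import fwd_step_func, model_provider_func

# Toy model and batch from the testing helpers shown later in this change.
model = model_provider_func(hidden_size=16, pre_process=True, post_process=True)
batch = torch.randn(8, 16, 16)  # illustrative batch; microbatching is driven by the schedule
losses_reduced = forward_backward_no_pipelining(
    fwd_step_func, batch, model, forward_only=False
)
```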
from typing import List, Union, Optional, Sequence

import torch

from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel import p2p_communication
from apex.transformer.pipeline_parallel.schedules.common import Batch
from apex.transformer.pipeline_parallel.schedules.common import FwdStepFunc
from apex.transformer.pipeline_parallel.schedules.common import backward_step
from apex.transformer.pipeline_parallel.schedules.common import forward_step
from apex.transformer.pipeline_parallel.utils import get_kth_microbatch
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.utils import get_model_type
from apex.transformer.log_util import get_transformer_logger
...@@ -18,7 +20,7 @@ __all__ = ["_forward_backward_pipelining_with_interleaving"]
_logger = get_transformer_logger(__name__)


# TODO(mkozuki): Reduce cyclomatic complexity
def _forward_backward_pipelining_with_interleaving(
    forward_step_func: FwdStepFunc,
    batch: List[Batch],
...@@ -26,7 +28,7 @@ def _forward_backward_pipelining_with_interleaving(
    *,
    forward_only: bool,
    tensor_shape: Optional[Union[List[int], torch.Size]] = None,
) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
    """Run interleaved 1F1B schedule with communication between pipeline stages as needed.

    This function assumes `batch` and `model` is a list of `Batch`'s and a list of `torch.nn.Module`, respectively.
...@@ -56,22 +58,22 @@ def _forward_backward_pipelining_with_interleaving(
    if not isinstance(model, list):
        raise RuntimeError("`model` must be a list of `nn.Module`'s'")

    num_model_chunks: int = len(model)
    input_tensors: List[List[Union[None, torch.Tensor]]] = [[] for _ in range(num_model_chunks)]
    output_tensors: List[List[Union[None, torch.Tensor]]] = [[] for _ in range(num_model_chunks)]
    curr_iters: List[int] = [0 for _ in range(num_model_chunks)]
    losses_reduced: List[Union[None, torch.Tensor]] = []
    if not forward_only:
        output_tensor_grads: List[List[Union[None, torch.Tensor]]] = [[] for _ in range(num_model_chunks)]

    pipeline_parallel_size: int = parallel_state.get_pipeline_model_parallel_world_size()
    pipeline_parallel_rank: int = parallel_state.get_pipeline_model_parallel_rank()

    # Compute number of warmup and remaining microbatches.
    num_microbatches: int = get_num_microbatches() * num_model_chunks
    all_warmup_microbatches: bool = False
    if forward_only:
        num_warmup_microbatches: int = num_microbatches
    else:
        # Run all forward passes and then all backward passes if number of
        # microbatches is just the number of pipeline stages.
...@@ -86,7 +88,7 @@ def _forward_backward_pipelining_with_interleaving(
            num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
            num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
            num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
    num_microbatches_remaining: int = num_microbatches - num_warmup_microbatches

    _logger.info(
        f"num_microbatches: {num_microbatches}, "
...@@ -106,10 +108,11 @@ def _forward_backward_pipelining_with_interleaving(
            model_chunk_id = num_model_chunks - model_chunk_id - 1
        return model_chunk_id

    def forward_step_helper(microbatch_id: int, curr_iters: List[int]) -> torch.Tensor:
        """Helper method to run forward step with model split into chunks

        (run set_virtual_pipeline_model_parallel_rank() before calling forward_step()).
        """
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
        parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
...@@ -137,11 +140,13 @@ def _forward_backward_pipelining_with_interleaving(
        return output_tensor

    def backward_step_helper(microbatch_id: int) -> torch.Tensor:
        """Helper method to run backward step with model split into chunks

        (run set_virtual_pipeline_model_parallel_rank() before calling backward_step()).
        """
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
        model_type = get_model_type(model[model_chunk_id])
        parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        if parallel_state.is_pipeline_last_stage():
...@@ -150,7 +155,7 @@ def _forward_backward_pipelining_with_interleaving(
        input_tensor = input_tensors[model_chunk_id].pop(0)
        output_tensor = output_tensors[model_chunk_id].pop(0)
        output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
        input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad, model_type=model_type)

        return input_tensor_grad
...@@ -200,7 +205,8 @@ def _forward_backward_pipelining_with_interleaving(
            output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
        else:
            _logger.debug("send fwd and receive fwd")
            input_tensor = p2p_communication.send_forward_recv_forward(
                output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape)
        input_tensors[next_forward_model_chunk_id].append(input_tensor)

    ###################################################################################################################
...@@ -302,7 +308,8 @@ def _forward_backward_pipelining_with_interleaving(
            if k == (num_microbatches - 1):
                recv_next = False
            output_tensor_grads[next_backward_model_chunk_id].append(
                p2p_communication.send_backward_recv_backward(
                    input_tensor_grad, recv_next=recv_next, tensor_shape=tensor_shape)
            )

    return losses_reduced
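To make the warmup arithmetic above concrete, here is the same computation evaluated for one illustrative configuration (4 pipeline stages, 2 model chunks, 8 microbatches per chunk; values are examples only):

```python
pipeline_parallel_size = 4
pipeline_parallel_rank = 1
num_model_chunks = 2
num_microbatches = 8 * num_model_chunks  # get_num_microbatches() * num_model_chunks

num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)

print(num_warmup_microbatches)                     # 8 warmup microbatches on rank 1
print(num_microbatches - num_warmup_microbatches)  # 8 microbatches left for steady-state 1F1B
```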
from typing import Union, List, Optional, Sequence

import torch

from apex.transformer import parallel_state
from apex.transformer.enums import ModelType
from apex.transformer.pipeline_parallel import p2p_communication
from apex.transformer.pipeline_parallel.utils import get_kth_microbatch
from apex.transformer.pipeline_parallel.utils import listify_model
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.utils import get_model_type
from apex.transformer.pipeline_parallel.schedules.common import Batch
from apex.transformer.pipeline_parallel.schedules.common import FwdStepFunc
from apex.transformer.pipeline_parallel.schedules.common import forward_step
from apex.transformer.pipeline_parallel.schedules.common import backward_step
from apex.transformer.log_util import get_transformer_logger
...@@ -19,6 +22,117 @@ __all__ = ["forward_backward_pipelining_without_interleaving"]
_logger = get_transformer_logger(__name__)
def get_tensor_shapes(
    rank: int,
    model_type: ModelType,
    *,
    tensor_shape: Union[List[int], torch.Size],
    decoder_sequence_length: Optional[int] = None,
) -> Sequence[Sequence[int]]:
    # Determine right tensor sizes (based on position of rank with respect to split
    # rank) and model size.
    # Send two tensors if model is T5 and rank is in decoder stage:
    #     first tensor is decoder (pre-transpose),
    #     second tensor is encoder (post-transpose).
    # If model is T5 and rank is at the boundary:
    #     send one tensor (post-transpose from encoder).
    # Otherwise, send one tensor (pre-transpose).
    assert (
        len(tensor_shape) == 3
    ), f"`tensor_shape` should be [sequence_length, micro_batch_size, hidden_size] but {tensor_shape}"
    sequence_length, micro_batch_size, hidden_size = tensor_shape
    tensor_shapes = []
    if model_type == ModelType.encoder_and_decoder:
        if decoder_sequence_length is None:
            raise ValueError("`decoder_sequence_length` is required for `ModelType.encoder_and_decoder`")
        if parallel_state.is_pipeline_stage_before_split(rank):
            # If next rank is after split, then need transpose for encoder_hidden_state.
            if parallel_state.is_pipeline_stage_before_split(rank + 1):
                tensor_shapes.append((sequence_length, micro_batch_size, hidden_size))
            else:
                tensor_shapes.append((micro_batch_size, sequence_length, hidden_size))
        else:
            tensor_shapes.append((decoder_sequence_length, micro_batch_size, hidden_size))
            tensor_shapes.append((micro_batch_size, sequence_length, hidden_size))
    else:
        tensor_shapes.append((sequence_length, micro_batch_size, hidden_size))
    return tensor_shapes


def recv_forward(tensor_shapes: List[Union[None, List[int]]],) -> List[Union[None, torch.Tensor]]:
    input_tensors = []
    for tensor_shape in tensor_shapes:
        if tensor_shape is None:
            input_tensors.append(None)
        else:
            input_tensors.append(p2p_communication.recv_forward(tensor_shape=tensor_shape))
    return input_tensors


def recv_backward(tensor_shapes: List[Union[None, List[int]]],) -> List[Union[None, torch.Tensor]]:
    output_tensor_grads = []
    for tensor_shape in tensor_shapes:
        if tensor_shape is None:
            output_tensor_grads.append(None)
        else:
            output_tensor_grads.append(p2p_communication.recv_backward(tensor_shape=tensor_shape))
    return output_tensor_grads


def send_forward(
    output_tensors: Union[torch.Tensor, List[Union[None, torch.Tensor]]], tensor_shapes: List[Union[None, List[int]]],
) -> None:
    if not isinstance(output_tensors, list):
        output_tensors = [output_tensors]
    for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
        if tensor_shape is None:
            continue
        p2p_communication.send_forward(output_tensor, tensor_shape=tensor_shape)


def send_backward(
    input_tensor_grads: Union[torch.Tensor, List[Union[None, torch.Tensor]]],
    tensor_shapes: List[Union[None, List[int]]],
) -> None:
    if not isinstance(input_tensor_grads, list):
        input_tensor_grads = [input_tensor_grads]
    for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
        if tensor_shape is None:
            continue
        p2p_communication.send_backward(input_tensor_grad, tensor_shape=tensor_shape)


def send_forward_recv_backward(
    output_tensors: Union[torch.Tensor, List[Union[None, torch.Tensor]]], tensor_shapes: List[Union[None, List[int]]],
) -> List[Union[None, torch.Tensor]]:
    if not isinstance(output_tensors, list):
        output_tensors = [output_tensors]
    output_tensor_grads = []
    for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
        if tensor_shape is None:
            output_tensor_grads.append(None)
            continue
        output_tensor_grad = p2p_communication.send_forward_recv_backward(output_tensor, tensor_shape=tensor_shape)
        output_tensor_grads.append(output_tensor_grad)
    return output_tensor_grads


def send_backward_recv_forward(
    input_tensor_grads: Union[torch.Tensor, List[Union[None, torch.Tensor]]],
    tensor_shapes: List[Union[None, List[int]]],
) -> List[Union[None, torch.Tensor]]:
    if not isinstance(input_tensor_grads, list):
        input_tensor_grads = [input_tensor_grads]
    input_tensors = []
    for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
        if tensor_shape is None:
            input_tensors.append(None)
            continue
        input_tensor = p2p_communication.send_backward_recv_forward(input_tensor_grad, tensor_shape=tensor_shape)
        input_tensors.append(input_tensor)
    return input_tensors
def forward_backward_pipelining_without_interleaving(
    forward_step_func: FwdStepFunc,
    batch: Batch,
...@@ -26,7 +140,8 @@ def forward_backward_pipelining_without_interleaving(
    *,
    forward_only: bool,
    tensor_shape: Optional[Union[List[int], torch.Size]] = None,
    decoder_sequence_length: Optional[int] = None,
) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
    """Run non-interleaved 1F1B schedule, with communication between pipeline stages.

    This pipeline parallel scheduling consists of three steps:
...@@ -51,21 +166,28 @@ def forward_backward_pipelining_without_interleaving(
    """
    # timers = get_timers()

    model: List[torch.nn.Module] = listify_model(model)
    if len(model) != 1:
        msg = f"`model` is expected be a `nn.Module`, but {type(model)}"
        raise RuntimeError(msg)
    model: torch.nn.Module = model[0]

    # Compute number of warmup microbatches.
    num_microbatches: int = get_num_microbatches()
    num_warmup_microbatches: int = (
        parallel_state.get_pipeline_model_parallel_world_size() - parallel_state.get_pipeline_model_parallel_rank() - 1
    )
    num_warmup_microbatches: int = min(num_warmup_microbatches, num_microbatches)
    num_microbatches_remaining: int = num_microbatches - num_warmup_microbatches

    model_type = get_model_type(model)
    rank: int = parallel_state.get_pipeline_model_parallel_rank()
    recv_tensor_shapes: List[List[int]] = get_tensor_shapes(
        rank - 1, model_type, tensor_shape=tensor_shape, decoder_sequence_length=decoder_sequence_length
    )
    send_tensor_shapes: List[List[int]] = get_tensor_shapes(
        rank, model_type, tensor_shape=tensor_shape, decoder_sequence_length=decoder_sequence_length
    )

    _logger.info(
        f"num_microbatches: {num_microbatches}, "
...@@ -74,13 +196,9 @@ def forward_backward_pipelining_without_interleaving(
    )

    # Input, output tensors only need to be saved when doing backward passes
    input_tensors: List[Union[None, torch.Tensor]] = []
    output_tensors: List[Union[None, torch.Tensor]] = []
    losses_reduced: List[Union[None, torch.Tensor]] = []

    ###################################################################################################################
    # Run warmup forward passes.
    ###################################################################################################################
...@@ -88,11 +206,11 @@ def forward_backward_pipelining_without_interleaving(
    for i in range(num_warmup_microbatches):
        _logger.debug(f"warmup iter: {i} / {num_warmup_microbatches}")
        _logger.debug("receive fwd")
        input_tensor = recv_forward(tensor_shapes=recv_tensor_shapes)
        cur_microbatch = get_kth_microbatch(batch, i)
        output_tensor = forward_step(forward_step_func, cur_microbatch, model, input_tensor, losses_reduced)
        _logger.debug("send fwd")
        send_forward(output_tensor, tensor_shapes=send_tensor_shapes)

        if not forward_only:
            input_tensors.append(input_tensor)
...@@ -103,7 +221,7 @@ def forward_backward_pipelining_without_interleaving(
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        _logger.debug("recv_forward before steady state start")
        input_tensor: List[Union[None, torch.Tensor]] = recv_forward(tensor_shapes=recv_tensor_shapes)

    ###################################################################################################################
    # Run 1F1B in steady state.
...@@ -111,21 +229,23 @@ def forward_backward_pipelining_without_interleaving(
    _logger.info("Steady phase")
    for i in range(num_microbatches_remaining):
        _logger.debug(f"steady iter: {i} / {num_microbatches_remaining}")
        last_iteration: bool = i == (num_microbatches_remaining - 1)

        cur_microbatch: torch.Tensor = get_kth_microbatch(batch, i + num_warmup_microbatches)
        output_tensor: Union[torch.Tensor, Sequence[torch.Tensor]] = forward_step(
            forward_step_func, cur_microbatch, model, input_tensor, losses_reduced
        )
        if forward_only:
            _logger.debug("send fwd")
            send_forward(output_tensor, tensor_shapes=send_tensor_shapes)

            if not last_iteration:
                _logger.debug("receive fwd (last iteration)")
                input_tensor = recv_forward(tensor_shapes=recv_tensor_shapes)

        else:
            _logger.debug("send fwd & receive bwd")
            output_tensor_grad = send_forward_recv_backward(output_tensor, tensor_shapes=send_tensor_shapes)

            # Add input_tensor and output_tensor to end of list.
            input_tensors.append(input_tensor)
...@@ -135,18 +255,15 @@ def forward_backward_pipelining_without_interleaving(
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad, model_type=model_type)

            if last_iteration:
                input_tensor = None
                _logger.debug("send bwd")
                send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes)
            else:
                _logger.debug("send bwd and receive fwd")
                input_tensor = send_backward_recv_forward(input_tensor_grad, tensor_shapes=recv_tensor_shapes)

    ###################################################################################################################
    # Run cooldown backward passes.
    ###################################################################################################################
...@@ -158,13 +275,11 @@ def forward_backward_pipelining_without_interleaving(
            output_tensor = output_tensors.pop(0)

            _logger.debug("receive bwd")
            output_tensor_grad = recv_backward(tensor_shapes=send_tensor_shapes)

            input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad, model_type=model_type)

            _logger.debug("send bwd")
            send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes)

    return losses_reduced
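A worked example of the `get_tensor_shapes` rule used above (values are illustrative; the split between stages 3 and 4 would correspond to a `pipeline_model_parallel_split_rank` of 4):

```python
# tensor_shape = [sequence_length=128, micro_batch_size=4, hidden_size=1024],
# decoder_sequence_length = 64, encoder on ranks 0-3, decoder on ranks 4 and up:
#
#   rank 2 (encoder, next rank still encoder)    -> [(128, 4, 1024)]
#   rank 3 (last encoder stage, next is decoder) -> [(4, 128, 1024)]   # transposed encoder_hidden_state
#   rank 5 (decoder stage)                       -> [(64, 4, 1024), (4, 128, 1024)]
#
# The schedule computes recv shapes with get_tensor_shapes(rank - 1, ...) and send
# shapes with get_tensor_shapes(rank, ...), so what one stage sends is exactly what
# the next stage expects to receive.
```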
...@@ -21,6 +21,7 @@ from torch.nn.parallel import DistributedDataParallel
from apex.multi_tensor_apply import multi_tensor_applier
from apex.transformer import parallel_state
from apex.transformer.enums import ModelType
from apex.transformer.microbatches import build_num_microbatches_calculator
from apex.transformer.pipeline_parallel._timers import _Timers

if multi_tensor_applier.available:
...@@ -186,6 +187,19 @@ def unwrap_model(model, module_instances=(DistributedDataParallel,)):
    return unwrapped_model


def get_model_type(
    model: torch.nn.Module,
) -> ModelType:
    """Get `model_type` of `model`.

    If ``model`` doesn't have ``model_type`` attribute, return ``ModelType.encoder_or_decoder``.

    Args:
        model
    """
    return getattr(unwrap_model(model), "model_type", ModelType.encoder_or_decoder)


def calc_params_l2_norm(model: torch.nn.Module, bf16: bool):
    """Calculate l2 norm of parameters """
    # args = get_args()
......
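A quick sketch of the fallback behaviour (the two modules are hypothetical stand-ins):

```python
import torch

from apex.transformer.enums import ModelType
from apex.transformer.pipeline_parallel.utils import get_model_type

plain = torch.nn.Linear(4, 4)  # no `model_type` attribute
t5ish = torch.nn.Linear(4, 4)
t5ish.model_type = ModelType.encoder_and_decoder

print(get_model_type(plain))  # ModelType.encoder_or_decoder (the fallback)
print(get_model_type(t5ish))  # ModelType.encoder_and_decoder
```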
...@@ -21,6 +21,7 @@ import torch
import torch.nn as nn

from apex import transformer
from apex.transformer.pipeline_parallel.utils import average_losses_across_data_parallel_group
from apex.transformer.testing import global_vars
...@@ -49,7 +50,9 @@ class MyModel(nn.Module):
        self.input_tensor = None

    def set_input_tensor(self, input_tensor: Union[torch.Tensor, List[torch.Tensor]]) -> None:
        if not isinstance(input_tensor, list):
            input_tensor = [input_tensor]
        self.input_tensor = input_tensor[0]

    def forward(self, x: Optional[torch.Tensor]) -> torch.Tensor:
        if self.input_tensor is None:
...@@ -61,6 +64,27 @@ def model_provider_func(hidden_size, pre_process, post_process) -> MyModel:
    return MyModel(hidden_size, pre_process, post_process)


def process_batch(batch):
    if isinstance(batch, list):
        x = batch[0]
    else:
        x = batch
    return x


def fwd_step_func(batch, model):
    x = process_batch(batch)
    y = model(x)

    # note (mkozuki): I don't think this function is nice but I do think this is enough for now
    # just to check the sanity of ported pipeline functions.
    def loss_func(x):
        loss = torch.sum(x)
        averaged_loss = average_losses_across_data_parallel_group([loss])
        return loss, {'avg': averaged_loss}

    return y, loss_func


class IdentityLayer(torch.nn.Module):
    def __init__(self, size, scale=1.0):
        super(IdentityLayer, self).__init__()
......
...@@ -11,7 +11,6 @@ from apex.transformer.pipeline_parallel.schedules.common import build_model ...@@ -11,7 +11,6 @@ from apex.transformer.pipeline_parallel.schedules.common import build_model
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import ( from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import (
_forward_backward_pipelining_with_interleaving, _forward_backward_pipelining_with_interleaving,
) )
from apex.transformer.pipeline_parallel.utils import average_losses_across_data_parallel_group
from apex.transformer.pipeline_parallel.utils import setup_microbatch_calculator from apex.transformer.pipeline_parallel.utils import setup_microbatch_calculator
from apex.transformer.pipeline_parallel.utils import _reconfigure_microbatch_calculator from apex.transformer.pipeline_parallel.utils import _reconfigure_microbatch_calculator
from apex.transformer.pipeline_parallel.utils import update_num_microbatches from apex.transformer.pipeline_parallel.utils import update_num_microbatches
...@@ -19,6 +18,7 @@ from apex.transformer.testing import global_vars ...@@ -19,6 +18,7 @@ from apex.transformer.testing import global_vars
from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
from apex.transformer.testing.commons import initialize_distributed from apex.transformer.testing.commons import initialize_distributed
from apex.transformer.testing.commons import print_separator from apex.transformer.testing.commons import print_separator
from apex.transformer.testing.commons import fwd_step_func
from apex.transformer.log_util import get_transformer_logger, set_logging_level from apex.transformer.log_util import get_transformer_logger, set_logging_level
from apex.transformer.testing.commons import model_provider_func from apex.transformer.testing.commons import model_provider_func
from apex.transformer._data import MegatronPretrainingRandomSampler from apex.transformer._data import MegatronPretrainingRandomSampler
...@@ -44,29 +44,7 @@ HIDDEN_SIZE = 16 ...@@ -44,29 +44,7 @@ HIDDEN_SIZE = 16
def Dataset(num_samples: int) -> List[Tuple[torch.Tensor, torch.Tensor]]: def Dataset(num_samples: int) -> List[Tuple[torch.Tensor, torch.Tensor]]:
return [(torch.randn(HIDDEN_SIZE), torch.randn(HIDDEN_SIZE // 2)) for _ in range(num_samples)] return [(torch.randn(HIDDEN_SIZE, HIDDEN_SIZE), torch.randn(HIDDEN_SIZE // 2, HIDDEN_SIZE // 2)) for _ in range(num_samples)]
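# --- editor's sketch (not part of this commit): with 2-D samples, a stacked micro-batch is
# 3-D and has to agree with the `tensor_shape` passed to the interleaved schedule below.
# The explicit stacking here is only illustrative of what the sampler/loader produces.
_samples = Dataset(4)
_micro_batch = torch.stack([x for (x, _) in _samples])
assert list(_micro_batch.shape) == [4, HIDDEN_SIZE, HIDDEN_SIZE]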
def process_batch(batch):
if isinstance(batch, (list, tuple)):
x = batch[0]
else:
x = batch
return x
def fwd_step_func(micro_batch, model):
x = process_batch(micro_batch)
y = model(x)
# note (mkozuki): I don't think this function is nice but I do think this is enough for now
# just to check the sanity of ported pipeline functions.
def loss_func(x):
loss = torch.sum(x)
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {"avg": averaged_loss}
return y, loss_func
# Run forward & backward with dynamic batch size. # Run forward & backward with dynamic batch size.
...@@ -122,7 +100,7 @@ def run_interleaved_with_dynamic_batch_size( ...@@ -122,7 +100,7 @@ def run_interleaved_with_dynamic_batch_size(
assert isinstance(batch, (list, tuple)) assert isinstance(batch, (list, tuple))
return [get_num_samples(b) for b in batch] return [get_num_samples(b) for b in batch]
tensor_shape = [micro_batch_size, HIDDEN_SIZE] tensor_shape = [micro_batch_size, HIDDEN_SIZE, HIDDEN_SIZE]
consumed_samples = 0 consumed_samples = 0
for i in range(NUM_ITERATIONS): for i in range(NUM_ITERATIONS):
update_num_microbatches(consumed_samples, consistency_check=False) update_num_microbatches(consumed_samples, consistency_check=False)
......
from typing import Optional, Union, List from typing import Optional
import torch import torch
import torch.nn as nn
import apex
from apex.transformer import parallel_state from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel import get_forward_backward_func from apex.transformer.pipeline_parallel import get_forward_backward_func
from apex.transformer.pipeline_parallel.schedules.common import _get_params_for_weight_decay_optimization from apex.transformer.pipeline_parallel.schedules.common import _get_params_for_weight_decay_optimization
...@@ -11,13 +9,14 @@ from apex.transformer.pipeline_parallel.schedules.common import build_model ...@@ -11,13 +9,14 @@ from apex.transformer.pipeline_parallel.schedules.common import build_model
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import forward_backward_no_pipelining from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import forward_backward_no_pipelining
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import _forward_backward_pipelining_with_interleaving from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import _forward_backward_pipelining_with_interleaving
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import forward_backward_pipelining_without_interleaving from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import forward_backward_pipelining_without_interleaving
from apex.transformer.pipeline_parallel.utils import average_losses_across_data_parallel_group
from apex.transformer.pipeline_parallel.utils import setup_microbatch_calculator from apex.transformer.pipeline_parallel.utils import setup_microbatch_calculator
from apex.transformer.pipeline_parallel.utils import update_num_microbatches from apex.transformer.pipeline_parallel.utils import update_num_microbatches
from apex.transformer.testing import global_vars from apex.transformer.testing import global_vars
from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
from apex.transformer.testing.commons import initialize_distributed from apex.transformer.testing.commons import initialize_distributed
from apex.transformer.testing.commons import print_separator from apex.transformer.testing.commons import print_separator
from apex.transformer.testing.commons import model_provider_func
from apex.transformer.testing.commons import fwd_step_func
from apex.transformer.log_util import get_transformer_logger, set_logging_level from apex.transformer.log_util import get_transformer_logger, set_logging_level
...@@ -35,62 +34,6 @@ fwd_bwd_functions = { ...@@ -35,62 +34,6 @@ fwd_bwd_functions = {
} }
# note (mkozuki): `pre_process` and `post_process` are a placeholder until interleaving schedule test comes.
class MyLayer(nn.Module):
def __init__(self, pre_process: bool, post_process: bool):
super().__init__()
self.pre_process = pre_process
self.post_process = post_process
self.layer = nn.Linear(hidden_size, hidden_size)
def forward(self, x):
return self.layer(x)
class MyModel(nn.Module):
def __init__(self, pre_process: bool = False, post_process: bool = False) -> None:
super().__init__()
self.pre_process = pre_process
self.post_process = post_process
self.layer = MyLayer(pre_process=pre_process, post_process=post_process)
self.input_tensor = None
def set_input_tensor(self, input_tensor: Union[torch.Tensor, List[torch.Tensor]]) -> None:
self.input_tensor = input_tensor
def forward(self, x: Optional[torch.Tensor]) -> torch.Tensor:
if self.input_tensor is None:
return self.layer(x)
return self.layer(self.input_tensor)
def model_provider_func(pre_process, post_process) -> MyModel:
return MyModel(pre_process, post_process)
def process_batch(batch):
if isinstance(batch, list):
x = batch[0]
else:
x = batch
return x
def fwd_step_func(batch, model):
x = process_batch(batch)
y = model(x)
# note (mkozuki): I don't think this function is nice but I do think this is enough for now
# just to check the sanity of ported pipeline functions.
def loss_func(x):
loss = torch.sum(x)
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'avg': averaged_loss}
return y, loss_func
# TODO (mkozuki): Add a case with `autocast` and `GradScaler`. # TODO (mkozuki): Add a case with `autocast` and `GradScaler`.
# Run forward & backward for one minibatch. # Run forward & backward for one minibatch.
def forward_backward_func_template( def forward_backward_func_template(
...@@ -121,13 +64,14 @@ def forward_backward_func_template( ...@@ -121,13 +64,14 @@ def forward_backward_func_template(
model_provider_func, model_provider_func,
wrap_with_ddp=True, wrap_with_ddp=True,
virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size, virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
hidden_size=hidden_size,
) )
assert isinstance(model, list) assert isinstance(model, list)
assert len(model) == (1 if virtual_pipeline_model_parallel_size is None else virtual_pipeline_model_parallel_size) assert len(model) == (1 if virtual_pipeline_model_parallel_size is None else virtual_pipeline_model_parallel_size)
_param_groups = _get_params_for_weight_decay_optimization(model) _param_groups = _get_params_for_weight_decay_optimization(model)
torch.optim.Adam(_param_groups, lr=1e-4) torch.optim.Adam(_param_groups, lr=1e-4)
tensor_shape = [batch_size // parallel_state.get_data_parallel_world_size(), hidden_size] tensor_shape = [batch_size // parallel_state.get_data_parallel_world_size(), hidden_size, hidden_size]
batch = (torch.randn(tensor_shape).cuda(),) batch = (torch.randn(tensor_shape).cuda(),)
tensor_shape[0] = micro_batch_size tensor_shape[0] = micro_batch_size
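# --- editor's sketch (not part of this commit): the shape bookkeeping the two lines above
# rely on. Numbers are assumed; the convention mirrors the microbatch calculator:
#   global_batch_size = micro_batch_size * num_micro_batches * data_parallel_size.
_global_batch, _micro, _dp, _h = 16, 2, 2, 32
_local_batch_shape = [_global_batch // _dp, _h, _h]  # what `torch.randn(tensor_shape)` builds
_p2p_tensor_shape = [_micro, _h, _h]                 # what stages exchange per micro-batch
assert _local_batch_shape[0] == _micro * (_global_batch // (_micro * _dp))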
...@@ -183,6 +127,7 @@ if __name__ == "__main__": ...@@ -183,6 +127,7 @@ if __name__ == "__main__":
f"virtual pipeline rank: {parallel_state.get_virtual_pipeline_model_parallel_rank()}\n" f"virtual pipeline rank: {parallel_state.get_virtual_pipeline_model_parallel_rank()}\n"
f"{str(e)}" f"{str(e)}"
) )
print(failures[-1])
finally: finally:
parallel_state.destroy_model_parallel() parallel_state.destroy_model_parallel()
else: else:
......