Unverified Commit a0ed4151 authored by Masaki Kozuki, committed by GitHub

[transformer] Format & Test Refactoring (#1325)

* try PyTorch custom TestCase class

* revert

* initial working example

* update

* data utils

* fix imports

* hardcode backend to nccl

* fix signature

* fix typo

* mapping

* set device

* init

* refactor x entropy

* remove unused import & destroy model parallel

* refactor random

* fix test

* remove migrated tests

* refactor

* init

* separate affine weight init

* init model parallel

* split more

* weight init fix part 1

* use cpu init for consistency btwn native and tensor parallel

* black

* add col parallel

* use a 3D tensor of square matrix for column parallel linear

* skip the failing cases

* migrate layers test

* pipeline parallel forward/backward

* fix typo

* fix typo

* fix

* fix pipeline world size

* black

* rm `run_pipeline_parallel_test` in favor of test_pipeline_parallel_fwd_bwd.py

* stop logging

* set log level

* black

* license and format

* fix

* skip tf32 as matrices are small

* remove potentially inappropriate license

* Apply suggestions from code review

* remove `TODO` comment

* `torch.testing.assert_allclose` -> `torch.testing.assert_close` (see the sketch after this commit list)

* remove comment-outs

* remove unused import

* minor fix
parent f10b4b89
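
One of the review-driven changes above replaces the deprecated `torch.testing.assert_allclose` with `torch.testing.assert_close` in the tests. A minimal sketch of that replacement (the tensors and tolerances here are illustrative, not taken from the apex test suite):

import torch

expected = torch.tensor([1.0, 2.0, 3.0])
actual = expected + 1e-7  # tiny numerical difference, as between two parallelism modes

# Old style (deprecated): torch.testing.assert_allclose(actual, expected)
# New style: assert_close also checks dtype and device, with dtype-aware default tolerances.
torch.testing.assert_close(actual, expected)

# Tolerances can be loosened explicitly, e.g. when TF32 or fp16 kernels are involved.
torch.testing.assert_close(actual, expected, rtol=1e-3, atol=1e-3)
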
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict

import torch

...
@@ -31,7 +31,9 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
        import scaled_upper_triang_masked_softmax_cuda

        scale_t = torch.tensor([scale])
        softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(
            inputs, scale_t[0]
        )

        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

@@ -41,7 +43,9 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
        import scaled_upper_triang_masked_softmax_cuda

        softmax_results, scale_t = ctx.saved_tensors
        input_grads = scaled_upper_triang_masked_softmax_cuda.backward(
            output_grads, softmax_results, scale_t[0]
        )
        return input_grads, None

@@ -65,10 +69,10 @@ def scaled_upper_triang_masked_softmax(inputs, _, scale):
# 2. Apply the mask.
# 3. Perform softmax.
class ScaledMaskedSoftmax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputs, mask, scale):
        import scaled_masked_softmax_cuda

        scale_t = torch.tensor([scale])
        softmax_results = scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0])

@@ -81,7 +85,9 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
        softmax_results, scale_t = ctx.saved_tensors
        input_grads = scaled_masked_softmax_cuda.backward(
            output_grads, softmax_results, scale_t[0]
        )
        return input_grads, None, None

@@ -120,7 +126,9 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
        self.input_in_fp16 = input_in_fp16
        self.input_in_bf16 = input_in_bf16
        if self.input_in_fp16 and self.input_in_bf16:
            raise RuntimeError(
                "both fp16 and bf16 flags cannot be active at the same time."
            )
        self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
        self.attn_mask_type = attn_mask_type
        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
...
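
For readers unfamiliar with what the fused kernels above compute, here is a minimal eager-mode PyTorch sketch of a scaled, upper-triangular (causal) masked softmax. The function name and the tensor shape are illustrative assumptions, not part of the fused extension's API.

import torch

def scaled_upper_triang_masked_softmax_reference(inputs: torch.Tensor, scale: float) -> torch.Tensor:
    # inputs: [attn_batches, seq_len, seq_len] attention scores (illustrative shape).
    # 1. Scale the tensor.
    scaled = inputs * scale
    # 2. Apply the causal (upper-triangular) mask by filling future positions with -inf.
    seq_len = scaled.size(-1)
    causal_mask = torch.triu(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=scaled.device), diagonal=1
    )
    scaled = scaled.masked_fill(causal_mask, float("-inf"))
    # 3. Perform softmax over the last dimension.
    return torch.softmax(scaled, dim=-1)

# Usage: probs = scaled_upper_triang_masked_softmax_reference(torch.randn(4, 16, 16), 0.125)
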
from typing import Optional
import logging
import os
import threading


def get_transformer_logger(name: str) -> logging.Logger:

@@ -16,4 +14,5 @@ def set_logging_level(verbosity) -> None:
        verbosity
    """
    from apex import _library_root_logger

    _library_root_logger.setLevel(verbosity)

@@ -24,11 +24,11 @@ _logger = get_transformer_logger(__name__)
def build_num_microbatches_calculator(
    rank: int,
    rampup_batch_size: Optional[List[int]],
    global_batch_size: int,
    micro_batch_size: int,
    data_parallel_size: int,
):
    # Constant num micro-batches.
    if rampup_batch_size is None:

@@ -37,7 +37,9 @@ def build_num_microbatches_calculator(
        )
        if rank == 0:
            _logger.info(
                "setting number of micro-batches to constant {}".format(
                    num_microbatches_calculator.get()
                )
            )
    else:

@@ -54,7 +56,10 @@ def build_num_microbatches_calculator(
                "will use batch size rampup starting from global batch "
                "size {} to global batch size {} with batch size increments "
                "{} over {} samples.".format(
                    start_batch_size,
                    global_batch_size,
                    batch_size_increment,
                    ramup_samples,
                ),
                flush=True,
            )

@@ -91,7 +96,9 @@ class ConstantNumMicroBatches(NumMicroBatchesCalculator):
        micro_batch_times_data_parallel = micro_batch_size * data_parallel_size
        assert global_batch_size % micro_batch_times_data_parallel == 0, (
            "global batch size ({}) is not divisible by micro batch size ({})"
            " times data parallel size ({})".format(
                global_batch_size, micro_batch_size, data_parallel_size
            )
        )
        self.num_micro_batches = global_batch_size // micro_batch_times_data_parallel
        assert self.num_micro_batches >= 1

@@ -131,7 +138,9 @@ class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):
        self.micro_batch_size = micro_batch_size
        self.data_parallel_size = data_parallel_size
        self.micro_batch_times_data_parallel_size = (
            self.micro_batch_size * self.data_parallel_size
        )

        assert self.micro_batch_times_data_parallel_size > 0
        assert start_batch_size > 0

@@ -163,15 +172,25 @@ class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):
            self.current_global_batch_size = self.global_batch_size
        else:
            steps = int(consumed_samples / self.rampup_samples_per_increment)
            self.current_global_batch_size = (
                self.start_batch_size + steps * self.batch_size_increment
            )
            assert self.current_global_batch_size <= self.global_batch_size

        if consistency_check:
            assert (
                self.current_global_batch_size
                % self.micro_batch_times_data_parallel_size
                == 0
            ), (
                "current global "
                "batch size ({}) is not divisible by micro-batch-size ({}) times"
                "data parallel size ({})".format(
                    self.current_global_batch_size,
                    self.micro_batch_size,
                    self.data_parallel_size,
                )
            )
        self.num_micro_batches = (
            self.current_global_batch_size // self.micro_batch_times_data_parallel_size
        )
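
As a quick illustration of the arithmetic in the calculators above, here is a small standalone sketch that mirrors the constant and ramp-up cases; the helper name and the concrete numbers are ours, not part of apex.

def num_microbatches(global_batch_size: int, micro_batch_size: int, data_parallel_size: int) -> int:
    # Each micro-batch consumes micro_batch_size samples on every data-parallel replica,
    # so one micro-batch accounts for micro_batch_size * data_parallel_size samples of the global batch.
    micro_times_dp = micro_batch_size * data_parallel_size
    assert global_batch_size % micro_times_dp == 0
    return global_batch_size // micro_times_dp

# Constant case: 512 global / (4 micro * 8 data-parallel ranks) -> 16 micro-batches per step.
assert num_microbatches(512, 4, 8) == 16

# Ramp-up case: the global batch size grows from start_batch_size in fixed increments
# as samples are consumed, exactly as in RampupBatchsizeNumMicroBatches.update().
start_batch_size, increment, samples_per_increment = 64, 64, 10_000
consumed_samples = 25_000
steps = consumed_samples // samples_per_increment        # 2 completed increments
current_global = start_batch_size + steps * increment    # 192
assert num_microbatches(current_global, 4, 8) == 6
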
@@ -24,6 +24,7 @@ from apex.transformer.log_util import get_transformer_logger
_logger = get_transformer_logger(__name__)

# N.B. (mkozuki): Diff btwn Megatron-LM & apex parallel_state
# set(megatron_mpu_initialize_funcs) - set(apex.transformer.parallel_state) =
# {
#     'get_num_layers',
...
import warnings

from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import (
    forward_backward_no_pipelining,
)
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import (
    _forward_backward_pipelining_with_interleaving,
)
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import (
    forward_backward_pipelining_without_interleaving,
)

__all__ = [
    "get_forward_backward_func",
]


class ExperimentalWarning(Warning):
    pass

@@ -21,19 +27,9 @@ def get_forward_backward_func(
            if get_num_microbatches() % pipeline_model_parallel_size != 0:
                msg = "number of microbatches is not divisible by pipeline-parallel size when using interleaved schedule"
                raise RuntimeError(msg)
            warnings.warn(
                "Pipeline Model Parallel with interleaving scheduling is experimental. "
                f"To use Pipeline Parallel without interleaving, set `virtual_pipeline_model_parallel_size` to `None`: {virtual_pipeline_model_parallel_size}",
                ExperimentalWarning,
            )
            forward_backward_func = _forward_backward_pipelining_with_interleaving
        else:
            forward_backward_func = forward_backward_pipelining_without_interleaving
    else:
        forward_backward_func = forward_backward_no_pipelining
    return forward_backward_func
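
The selection logic above boils down to a three-way dispatch. A condensed, self-contained sketch of that decision, with the world sizes passed in explicitly rather than read from parallel_state (the function name is ours, purely for illustration):

def pick_schedule(pipeline_world_size: int, virtual_pipeline_size, num_microbatches: int) -> str:
    # Mirrors which forward/backward function get_forward_backward_func would return.
    if pipeline_world_size > 1:
        if virtual_pipeline_size is not None:
            if num_microbatches % pipeline_world_size != 0:
                raise RuntimeError(
                    "number of microbatches is not divisible by pipeline-parallel size "
                    "when using interleaved schedule"
                )
            return "pipelining_with_interleaving"
        return "pipelining_without_interleaving"
    return "no_pipelining"

# Examples:
assert pick_schedule(1, None, 8) == "no_pipelining"
assert pick_schedule(4, None, 8) == "pipelining_without_interleaving"
assert pick_schedule(4, 2, 8) == "pipelining_with_interleaving"
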
@@ -10,7 +10,9 @@ from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.utils import listify_model
from apex.transformer.pipeline_parallel.utils import unwrap_model
from apex.transformer.pipeline_parallel.utils import get_model_type
from apex.transformer.tensor_parallel.layers import (
    set_defaults_if_not_set_tensor_model_parallel_attributes,
)
from apex.transformer.log_util import get_transformer_logger

@@ -19,16 +21,18 @@ _logger = get_transformer_logger(__name__)
Batch = Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]
LossFunc = Callable[[torch.Tensor], torch.Tensor]
FwdStepFunc = Callable[
    [Optional[Batch], torch.nn.Module], Tuple[torch.Tensor, LossFunc]
]

def build_model(
    model_provider_func: Callable[[Any, Dict[str, Any]], torch.nn.Module],
    wrap_with_ddp: bool = True,
    virtual_pipeline_model_parallel_size: Optional[int] = None,
    model_type: ModelType = ModelType.encoder_or_decoder,
    *args: Any,
    **kwargs: Any,
) -> List[torch.nn.Module]:
    """Build the model satisfying pipeline model parallel requirements.

@@ -49,8 +53,8 @@ def build_model(
        the list has multiple models, otherwise one.
    """
    if (
        parallel_state.get_pipeline_model_parallel_world_size() > 1
        and virtual_pipeline_model_parallel_size is not None
    ):
        model = []
        for i in range(virtual_pipeline_model_parallel_size):

@@ -60,10 +64,9 @@ def build_model(
            # Set pre_process and post_process only after virtual rank is set.
            pre_process = parallel_state.is_pipeline_first_stage()
            post_process = parallel_state.is_pipeline_last_stage()
            cur_kwargs.update(
                {"pre_process": pre_process, "post_process": post_process,}
            )
            this_model = model_provider_func(*cur_args, **cur_kwargs)
            model.append(this_model)
    else:

@@ -72,10 +75,9 @@ def build_model(
        if model_type == ModelType.encoder_or_decoder:
            pre_process = parallel_state.is_pipeline_first_stage()
            post_process = parallel_state.is_pipeline_last_stage()
            cur_kwargs.update(
                {"pre_process": pre_process, "post_process": post_process,}
            )
            model = model_provider_func(*cur_args, **cur_kwargs)
        elif model_type == ModelType.encoder_and_decoder:
            pre_process = parallel_state.is_pipeline_first_stage()

@@ -94,12 +96,14 @@ def build_model(
            post_process = rank == (split_rank - 1) or rank == (world_size - 1)
            add_encoder = parallel_state.is_pipeline_stage_before_split()
            add_decoder = parallel_state.is_pipeline_stage_after_split()
            cur_kwargs.update(
                {
                    "pre_process": pre_process,
                    "post_process": post_process,
                    "add_encoder": add_encoder,
                    "add_decoder": add_decoder,
                }
            )
            model = model_provider_func(*cur_args, **cur_kwargs)
        model.model_type = model_type

@@ -115,7 +119,10 @@ def build_model(
            set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Print number of parameters.
    if (
        parallel_state.model_parallel_is_initialized()
        and parallel_state.get_data_parallel_rank() == 0
    ):
        msg = " > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}".format(
            parallel_state.get_tensor_model_parallel_rank(),
            parallel_state.get_pipeline_model_parallel_rank(),

@@ -143,41 +150,54 @@ def build_model(
def _calc_number_of_params(model: List[torch.nn.Module]) -> int:
    assert isinstance(model, list)
    return sum(
        [
            sum([p.nelement() for p in model_module.parameters()])
            for model_module in model
        ]
    )

def _get_params_for_weight_decay_optimization(
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    *,
    no_weight_decay_modules=(FusedLayerNorm,),
) -> Dict[str, torch.nn.Parameter]:
    """Divide params into with-weight-decay and without-weight-decay groups.

    Layernorms and biases will have no weight decay but the rest will.
    """
    modules = listify_model(model)
    weight_decay_params = {"params": []}
    no_weight_decay_params = {"params": [], "weight_decay": 0.0}
    for module in modules:
        for module_ in module.modules():
            if isinstance(module_, no_weight_decay_modules):
                no_weight_decay_params["params"].extend(
                    [p for p in list(module_._parameters.values()) if p is not None]
                )
            else:
                weight_decay_params["params"].extend(
                    [
                        p
                        for n, p in list(module_._parameters.items())
                        if p is not None and n != "bias"
                    ]
                )
                no_weight_decay_params["params"].extend(
                    [
                        p
                        for n, p in list(module_._parameters.items())
                        if p is not None and n == "bias"
                    ]
                )

    return weight_decay_params, no_weight_decay_params

def free_output_tensor(
    output_tensors: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]],
    deallocate_pipeline_outputs: bool = False,
) -> None:
    """Pseudo-free the output tensor's `.data` field.

@@ -202,9 +222,15 @@ def custom_backward(output: torch.Tensor, grad_output: Optional[torch.Tensor]) -
    directly, bypassing PyTorch's `torch.autograd.backward`. PyTorch's `backward` checks that the
    output and grad have the same shape, while C++ `backward` does not.
    """
    assert (
        output.numel() == 1
    ), "output should be pseudo-freed in schedule, to optimize memory consumption"
    assert isinstance(output, torch.Tensor), "output == {}.".format(
        type(output).__name__
    )
    assert isinstance(
        grad_output, (torch.Tensor, type(None))
    ), "grad_outptu == {}.".format(type(grad_output).__name__)

    # Handle scalar output
    if grad_output is None:

@@ -224,13 +250,13 @@ def custom_backward(output: torch.Tensor, grad_output: Optional[torch.Tensor]) -
def forward_step(
    forward_step_func: FwdStepFunc,
    batch: Optional[Batch],
    model: torch.nn.Module,
    input_tensor: Optional[Union[torch.Tensor, List[torch.Tensor]]],
    losses_reduced: List[torch.Tensor],
    dtype: torch.dtype,
    disable_autocast: bool = False,
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
    """Forward step for passed-in model.

@@ -264,8 +290,8 @@ def forward_step(
    unwrapped_model.set_input_tensor(input_tensor)
    with torch.cuda.amp.autocast(
        enabled=not disable_autocast and dtype in (torch.half, torch.bfloat16),
        dtype=dtype,
    ):
        output_tensor, loss_func = forward_step_func(batch, model)
        if parallel_state.is_pipeline_last_stage():

@@ -278,7 +304,10 @@ def forward_step(
    # If T5 model (or other model with encoder and decoder)
    # and in decoder stack, then send encoder_hidden_state
    # downstream as well.
    if (
        parallel_state.is_pipeline_stage_after_split()
        and model_type == ModelType.encoder_and_decoder
    ):
        return [output_tensor, input_tensor[-1]]
    if unwrap_output_tensor:
        return output_tensor

@@ -286,13 +315,13 @@
def backward_step(
    input_tensor: Optional[torch.Tensor],
    output_tensor: torch.Tensor,
    output_tensor_grad: Optional[torch.Tensor],
    model_type: ModelType,
    *,
    grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
    deallocate_pipeline_outputs: bool = False,
) -> Union[None, torch.Tensor, Sequence[torch.Tensor]]:
    """Backward step through passed-in output tensor.

@@ -343,9 +372,9 @@ def backward_step(
    # Handle single skip connection if it exists (encoder_hidden_state in model with encoder and decoder).
    if (
        parallel_state.get_pipeline_model_parallel_world_size() > 1
        and parallel_state.is_pipeline_stage_after_split()
        and model_type == ModelType.encoder_and_decoder
    ):
        if output_tensor_grad[1] is not None:
            # todo (mkozuki): Replace the inplace add with `+= output_tensor_grad[1]`?
...
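
The `_get_params_for_weight_decay_optimization` helper above returns two optimizer parameter groups, one with the default weight decay and one with `weight_decay=0.0` for LayerNorm parameters and biases. A hedged sketch of how such groups are typically consumed; the toy model and the choice of AdamW are ours, not prescribed by apex:

import torch

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.LayerNorm(16))

# Mimic the split performed above: LayerNorm parameters and biases get no weight decay.
decay, no_decay = [], []
for module in model.modules():
    if isinstance(module, torch.nn.LayerNorm):
        no_decay.extend(p for p in module._parameters.values() if p is not None)
    else:
        decay.extend(p for n, p in module._parameters.items() if p is not None and n != "bias")
        no_decay.extend(p for n, p in module._parameters.items() if p is not None and n == "bias")

optimizer = torch.optim.AdamW(
    [{"params": decay}, {"params": no_decay, "weight_decay": 0.0}],
    lr=1e-3,
    weight_decay=0.01,  # applies only to the first group; the second overrides it with 0.0
)
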
@@ -29,16 +29,16 @@ def placeholder_handler():
def forward_backward_no_pipelining(
    forward_step_func: FwdStepFunc,
    batch: Batch,
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    *,
    forward_only: bool,
    dtype: Optional[torch.dtype] = None,
    grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
    disable_autocast: bool = False,
    custom_sync_context_handler=None,
    **kwargs,
):
    """Run forward and backward passes with no pipeline parallelism (no inter-stage communication).

@@ -98,7 +98,13 @@ def forward_backward_no_pipelining(
            )
            if not forward_only:
                _logger.debug("Call `backward_step`")
                backward_step(
                    input_tensor,
                    output_tensor,
                    output_tensor_grad,
                    model_type=model_type,
                    grad_scaler=grad_scaler,
                )

    # Run computation for last microbatch out of context handler (want to
    # synchronize gradients).

@@ -115,6 +121,12 @@ def forward_backward_no_pipelining(
    )
    if not forward_only:
        _logger.debug("Call `backward_step`")
        backward_step(
            input_tensor,
            output_tensor,
            output_tensor_grad,
            model_type=model_type,
            grad_scaler=grad_scaler,
        )

    return losses_reduced
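
The schedule above runs all but the last microbatch inside a "no grad sync" context handler and only the final microbatch outside it, so the data-parallel gradient all-reduce fires once per step. A minimal sketch of that pattern with plain `DistributedDataParallel.no_sync`; `model`, `loss_fn`, `microbatches`, and `optimizer` are assumed inputs, not apex names:

import contextlib

def accumulate_microbatches(model, loss_fn, microbatches, optimizer):
    # All microbatches except the last run under no_sync(), so gradients are
    # accumulated locally; the last microbatch triggers the gradient all-reduce.
    *head, last = microbatches
    sync_context = model.no_sync if hasattr(model, "no_sync") else contextlib.nullcontext
    with sync_context():
        for inputs, targets in head:
            loss_fn(model(inputs), targets).backward()
    inputs, targets = last
    loss_fn(model(inputs), targets).backward()
    optimizer.step()
    optimizer.zero_grad()
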
@@ -23,17 +23,17 @@ _logger = get_transformer_logger(__name__)
# TODO(mkozuki): Reduce cyclomatic complexity
def _forward_backward_pipelining_with_interleaving(
    forward_step_func: FwdStepFunc,
    batch: List[Optional[Batch]],
    model: List[torch.nn.Module],
    *,
    forward_only: bool,
    tensor_shape: Optional[Union[List[int], torch.Size]] = None,
    dtype: Optional[torch.dtype] = None,
    grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
    disable_autocast: bool = False,
    deallocate_pipeline_outputs: bool = False,
    **kwargs,
) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
    """Run interleaved 1F1B schedule with communication between pipeline stages as needed.

@@ -71,12 +71,18 @@ def _forward_backward_pipelining_with_interleaving(
        raise RuntimeError("`model` must be a list of `nn.Module`'s'")

    num_model_chunks: int = len(model)
    input_tensors: List[List[Union[None, torch.Tensor]]] = [
        [] for _ in range(num_model_chunks)
    ]
    output_tensors: List[List[Union[None, torch.Tensor]]] = [
        [] for _ in range(num_model_chunks)
    ]
    curr_iters: List[int] = [0 for _ in range(num_model_chunks)]
    losses_reduced: List[Union[None, torch.Tensor]] = []
    if not forward_only:
        output_tensor_grads: List[List[Union[None, torch.Tensor]]] = [
            [] for _ in range(num_model_chunks)
        ]

    pipeline_parallel_size: int = parallel_state.get_pipeline_model_parallel_world_size()
    pipeline_parallel_rank: int = parallel_state.get_pipeline_model_parallel_rank()

@@ -97,7 +103,9 @@ def _forward_backward_pipelining_with_interleaving(
        num_warmup_microbatches = num_microbatches
        all_warmup_microbatches = True
    else:
        num_warmup_microbatches = (
            pipeline_parallel_size - pipeline_parallel_rank - 1
        ) * 2
        num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
        num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
    num_microbatches_remaining: int = num_microbatches - num_warmup_microbatches

@@ -114,7 +122,9 @@ def _forward_backward_pipelining_with_interleaving(
    def get_model_chunk_id(microbatch_id: int, forward: bool) -> int:
        """Helper function to get the model chunk ID given the iteration number."""
        pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
        microbatch_id_in_group = microbatch_id % (
            pipeline_parallel_size * num_model_chunks
        )
        model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
        if not forward:
            model_chunk_id = num_model_chunks - model_chunk_id - 1

@@ -129,10 +139,9 @@ def _forward_backward_pipelining_with_interleaving(
        parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        # forward step
        if parallel_state.is_pipeline_first_stage() and len(
            input_tensors[model_chunk_id]
        ) == len(output_tensors[model_chunk_id]):
            input_tensors[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id][-1]
        output_tensor = forward_step(

@@ -169,7 +178,14 @@ def _forward_backward_pipelining_with_interleaving(
        input_tensor = input_tensors[model_chunk_id].pop(0)
        output_tensor = output_tensors[model_chunk_id].pop(0)
        output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
        input_tensor_grad = backward_step(
            input_tensor,
            output_tensor,
            output_tensor_grad,
            model_type=model_type,
            grad_scaler=grad_scaler,
            deallocate_pipeline_outputs=deallocate_pipeline_outputs,
        )
        return input_tensor_grad

@@ -177,7 +193,9 @@ def _forward_backward_pipelining_with_interleaving(
    # Run warmup forward passes.
    ###################################################################################################################
    parallel_state.set_virtual_pipeline_model_parallel_rank(0)
    input_tensors[0].append(
        p2p_communication.recv_forward(tensor_shape=tensor_shape, dtype=dtype)
    )
_logger.info("Warmup phase") _logger.info("Warmup phase")
for k in range(num_warmup_microbatches): for k in range(num_warmup_microbatches):
_logger.debug(f"warmup iter: {k} / {num_warmup_microbatches}") _logger.debug(f"warmup iter: {k} / {num_warmup_microbatches}")
...@@ -191,7 +209,9 @@ def _forward_backward_pipelining_with_interleaving( ...@@ -191,7 +209,9 @@ def _forward_backward_pipelining_with_interleaving(
recv_prev = False recv_prev = False
if k == (num_microbatches - 1): if k == (num_microbatches - 1):
recv_prev = False recv_prev = False
_logger.debug(f"next fwd model chunk ID: {next_forward_model_chunk_id}, recv_prev: {recv_prev}") _logger.debug(
f"next fwd model chunk ID: {next_forward_model_chunk_id}, recv_prev: {recv_prev}"
)
# Don't send tensor downstream if on last stage. # Don't send tensor downstream if on last stage.
if parallel_state.is_pipeline_last_stage(): if parallel_state.is_pipeline_last_stage():
...@@ -200,7 +220,11 @@ def _forward_backward_pipelining_with_interleaving( ...@@ -200,7 +220,11 @@ def _forward_backward_pipelining_with_interleaving(
# Send and receive tensors as appropriate (send tensors computed # Send and receive tensors as appropriate (send tensors computed
# in this iteration; receive tensors for next iteration). # in this iteration; receive tensors for next iteration).
if k == (num_warmup_microbatches - 1) and not forward_only and not all_warmup_microbatches: if (
k == (num_warmup_microbatches - 1)
and not forward_only
and not all_warmup_microbatches
):
input_tensor_grad = None input_tensor_grad = None
recv_next = True recv_next = True
if parallel_state.is_pipeline_last_stage(ignore_virtual=True): if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
...@@ -221,7 +245,11 @@ def _forward_backward_pipelining_with_interleaving( ...@@ -221,7 +245,11 @@ def _forward_backward_pipelining_with_interleaving(
else: else:
_logger.debug("send fwd and receive fwd") _logger.debug("send fwd and receive fwd")
input_tensor = p2p_communication.send_forward_recv_forward( input_tensor = p2p_communication.send_forward_recv_forward(
output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape, dtype=dtype) output_tensor,
recv_prev=recv_prev,
tensor_shape=tensor_shape,
dtype=dtype,
)
free_output_tensor(output_tensor, deallocate_pipeline_outputs) free_output_tensor(output_tensor, deallocate_pipeline_outputs)
input_tensors[next_forward_model_chunk_id].append(input_tensor) input_tensors[next_forward_model_chunk_id].append(input_tensor)
...@@ -251,7 +279,9 @@ def _forward_backward_pipelining_with_interleaving( ...@@ -251,7 +279,9 @@ def _forward_backward_pipelining_with_interleaving(
backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False) backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id) parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
_logger.debug(f"fwd/bwd model chunk id: {forward_model_chunk_id}/{backward_model_chunk_id}") _logger.debug(
f"fwd/bwd model chunk id: {forward_model_chunk_id}/{backward_model_chunk_id}"
)
if parallel_state.is_pipeline_first_stage(): if parallel_state.is_pipeline_first_stage():
input_tensor_grad = None input_tensor_grad = None
...@@ -267,7 +297,9 @@ def _forward_backward_pipelining_with_interleaving( ...@@ -267,7 +297,9 @@ def _forward_backward_pipelining_with_interleaving(
recv_prev = False recv_prev = False
next_forward_model_chunk_id += 1 next_forward_model_chunk_id += 1
else: else:
next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True) next_forward_model_chunk_id = get_model_chunk_id(
forward_k + 1, forward=True
)
recv_next = True recv_next = True
if parallel_state.is_pipeline_last_stage(ignore_virtual=True): if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
...@@ -279,7 +311,9 @@ def _forward_backward_pipelining_with_interleaving( ...@@ -279,7 +311,9 @@ def _forward_backward_pipelining_with_interleaving(
recv_next = False recv_next = False
next_backward_model_chunk_id -= 1 next_backward_model_chunk_id -= 1
else: else:
next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False) next_backward_model_chunk_id = get_model_chunk_id(
backward_k + 1, forward=False
)
# If last iteration, don't receive; we already received one extra # If last iteration, don't receive; we already received one extra
# before the start of the for loop. # before the start of the for loop.
@@ -314,9 +348,13 @@ def _forward_backward_pipelining_with_interleaving(
    _logger.info("Cooldown phase")
    if not forward_only:
        if all_warmup_microbatches:
            output_tensor_grads[num_model_chunks - 1].append(
                p2p_communication.recv_backward(tensor_shape=tensor_shape, dtype=dtype)
            )
        for k in range(num_microbatches_remaining, num_microbatches):
            _logger.debug(
                f"cooldown iter {k} in range({num_microbatches_remaining}, {num_microbatches})"
            )
            input_tensor_grad = backward_step_helper(k)
            next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
            recv_next = True

@@ -327,7 +365,11 @@ def _forward_backward_pipelining_with_interleaving(
                recv_next = False
            output_tensor_grads[next_backward_model_chunk_id].append(
                p2p_communication.send_backward_recv_backward(
                    input_tensor_grad,
                    recv_next=recv_next,
                    tensor_shape=tensor_shape,
                    dtype=dtype,
                )
            )

    return losses_reduced
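
To make the warmup bookkeeping above concrete, here is a tiny sketch of the warmup-microbatch formula used by the interleaved schedule; the helper name and the example numbers are ours:

def interleaved_num_warmup_microbatches(
    pipeline_parallel_size: int,
    pipeline_parallel_rank: int,
    num_model_chunks: int,
    num_microbatches: int,
) -> int:
    # Later pipeline stages need fewer warmup forward passes before steady-state 1F1B;
    # extra model chunks per rank add one full round of warmup per chunk.
    warmup = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
    warmup += (num_model_chunks - 1) * pipeline_parallel_size
    return min(warmup, num_microbatches)

# With 4 pipeline stages, 2 chunks per stage, and 16 microbatches,
# rank 0 warms up for 10 microbatches while the last rank warms up for 4.
assert interleaved_num_warmup_microbatches(4, 0, 2, 16) == 10
assert interleaved_num_warmup_microbatches(4, 3, 2, 16) == 4
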
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
...
@@ -25,8 +25,9 @@ _MAX_DATA_DIM = 5
def _check_data_types(keys, data, target_dtype):
    """Check that all the keys have the same target data type."""
    for key in keys:
        assert data[key].dtype == target_dtype, (
            "{} has data type {} which "
            "is different than {}".format(key, data[key].dtype, target_dtype)
        )

@@ -48,7 +49,9 @@ def _build_key_size_numel_dictionaries(keys, data):
    # Move to GPU and broadcast.
    sizes_cuda = torch.cuda.LongTensor(sizes)
    torch.distributed.broadcast(
        sizes_cuda,
        get_tensor_model_parallel_src_rank(),
        group=get_tensor_model_parallel_group(),
    )

    # Move back to cpu and unpack.

@@ -92,13 +95,19 @@ def broadcast_data(keys, data, datatype):
        # Check that all keys have the same data type.
        _check_data_types(keys, data, datatype)

        # Flatten the data associated with the keys
        flatten_data = torch.cat(
            [data[key].contiguous().view(-1) for key in keys], dim=0
        ).cuda()
    else:
        flatten_data = torch.empty(
            total_numel, device=torch.cuda.current_device(), dtype=datatype
        )

    # Broadcast
    torch.distributed.broadcast(
        flatten_data,
        get_tensor_model_parallel_src_rank(),
        group=get_tensor_model_parallel_group(),
    )

    # Unpack
...
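
`broadcast_data` above packs all per-key tensors into one flat buffer, broadcasts it from the tensor-parallel source rank, and unpacks it on the other ranks. A distributed-free sketch of just the pack/unpack bookkeeping (the helper names are hypothetical; the actual broadcast is only indicated by a comment):

import torch

def pack(data: dict) -> tuple:
    # Flatten every tensor and record each key's shape and element count for unpacking.
    keys = sorted(data)
    key_size = {k: data[k].shape for k in keys}
    key_numel = {k: data[k].numel() for k in keys}
    flat = torch.cat([data[k].contiguous().view(-1) for k in keys], dim=0)
    return flat, key_size, key_numel

def unpack(flat: torch.Tensor, key_size: dict, key_numel: dict) -> dict:
    # torch.distributed.broadcast(flat, src, group=...) would sit between pack() and unpack().
    output, offset = {}, 0
    for k in sorted(key_size):
        numel = key_numel[k]
        output[k] = flat.narrow(0, offset, numel).view(key_size[k])
        offset += numel
    return output

data = {"text": torch.arange(6).view(2, 3), "labels": torch.arange(4)}
flat, sizes, numels = pack(data)
restored = unpack(flat, sizes, numels)
assert torch.equal(restored["text"], data["text"]) and torch.equal(restored["labels"], data["labels"])
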
# coding=utf-8
# Copyright (c) 2021-22, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -26,10 +26,18 @@ from apex.transformer.parallel_state import get_tensor_model_parallel_group
from apex.transformer.parallel_state import get_tensor_model_parallel_rank
from apex.transformer.parallel_state import get_tensor_model_parallel_world_size
from apex.transformer.utils import divide
from apex.transformer.tensor_parallel.mappings import (
    copy_to_tensor_model_parallel_region,
)
from apex.transformer.tensor_parallel.mappings import (
    gather_from_tensor_model_parallel_region,
)
from apex.transformer.tensor_parallel.mappings import (
    reduce_from_tensor_model_parallel_region,
)
from apex.transformer.tensor_parallel.mappings import (
    scatter_to_tensor_model_parallel_region,
)
from apex.transformer.tensor_parallel.random import get_cuda_rng_tracker
from apex.transformer.tensor_parallel.utils import VocabUtility
from apex.transformer.log_util import get_transformer_logger

@@ -53,9 +61,9 @@ _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
def param_is_not_tensor_parallel_duplicate(param):
    return (
        hasattr(param, "tensor_model_parallel") and param.tensor_model_parallel
    ) or (get_tensor_model_parallel_rank() == 0)


def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
@@ -89,7 +97,9 @@ def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
def _initialize_affine_weight_gpu(weight, init_method, partition_dim, stride=1):
    """Initialize affine weight for model parallel on GPU."""
    set_tensor_model_parallel_attributes(
        tensor=weight, is_parallel=True, dim=partition_dim, stride=stride
    )

    with get_cuda_rng_tracker().fork():
        init_method(weight)

@@ -114,16 +124,22 @@ def _initialize_affine_weight_cpu(
    Build the master weight on all processes and scatter
    the relevant chunk."""
    set_tensor_model_parallel_attributes(
        tensor=weight, is_parallel=True, dim=partition_dim, stride=stride
    )

    # Initialize master weight
    master_weight = torch.empty(
        output_size, input_size, dtype=torch.float, requires_grad=False
    )
    init_method(master_weight)
    master_weight = master_weight.to(dtype=params_dtype)

    # Split and copy
    per_partition_per_stride_size = divide(per_partition_size, stride)
    weight_list = torch.split(
        master_weight, per_partition_per_stride_size, dim=partition_dim
    )
    rank = get_tensor_model_parallel_rank()
    world_size = get_tensor_model_parallel_world_size()
    my_weight_list = weight_list[rank::world_size]
@@ -147,7 +163,13 @@ class VocabParallelEmbedding(torch.nn.Module):
    """

    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        init_method=init.xavier_normal_,
        *,
        params_dtype=torch.float32,
        use_cpu_initialization=False,
    ):
        super(VocabParallelEmbedding, self).__init__()
        # Keep the input dimensions.

@@ -162,18 +184,34 @@ class VocabParallelEmbedding(torch.nn.Module):
        self._weight = None
        self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
        # Divide the weight matrix along the vocaburaly dimension.
        (
            self.vocab_start_index,
            self.vocab_end_index,
        ) = VocabUtility.vocab_range_from_global_vocab_size(
            self.num_embeddings,
            get_tensor_model_parallel_rank(),
            self.tensor_model_parallel_size,
        )
        self.num_embeddings_per_partition = (
            self.vocab_end_index - self.vocab_start_index
        )

        # Allocate weights and initialize.
        if use_cpu_initialization:
            self.weight = Parameter(
                torch.empty(
                    self.num_embeddings_per_partition,
                    self.embedding_dim,
                    dtype=params_dtype,
                )
            )
            _initialize_affine_weight_cpu(
                self.weight,
                self.num_embeddings,
                self.embedding_dim,
                self.num_embeddings_per_partition,
                0,
                init_method,
                params_dtype=params_dtype,
            )
        else:

@@ -185,12 +223,16 @@ class VocabParallelEmbedding(torch.nn.Module):
                    dtype=params_dtype,
                )
            )
            _initialize_affine_weight_gpu(
                self.weight, init_method, partition_dim=0, stride=1
            )

    def forward(self, input_):
        if self.tensor_model_parallel_size > 1:
            # Build the mask.
            input_mask = (input_ < self.vocab_start_index) | (
                input_ >= self.vocab_end_index
            )
            # Mask the input.
            masked_input = input_.clone() - self.vocab_start_index
            masked_input[input_mask] = 0
@@ -216,8 +258,11 @@ class VocabParallelEmbedding(torch.nn.Module):


class LinearWithGradAccumulationAndAsyncAllreduce(torch.autograd.Function):
    """Linear layer execution with asynchronous all-reduce and gradient accumulation fusion in backprop."""

    @staticmethod
    def forward(
        ctx, input, weight, bias, gradient_accumulation_fusion, async_grad_allreduce
    ):
        ctx.save_for_backward(input, weight)
        ctx.use_bias = bias is not None
        ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
@@ -233,17 +278,23 @@ class LinearWithGradAccumulationAndAsyncAllreduce(torch.autograd.Function):
        use_bias = ctx.use_bias

        grad_input = grad_output.matmul(weight)
        # Convert the tensor shapes to 2D for execution compatibility
        grad_output = grad_output.view(
            grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]
        )
        input = input.view(input.shape[0] * input.shape[1], input.shape[2])

        if ctx.async_grad_allreduce:
            # Asynchronous all-reduce
            handle = torch.distributed.all_reduce(
                grad_input, group=get_tensor_model_parallel_group(), async_op=True
            )
            # Delay the start of weight gradient computation shortly (3us) to have
            # all-reduce scheduled first and have GPU resources allocated
            _ = torch.empty(1, device=grad_output.device) + 1

        if ctx.gradient_accumulation_fusion:
            fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
                input, grad_output, weight.main_grad
            )
            grad_weight = None
        else:
            grad_weight = grad_output.t().matmul(input)
@@ -255,21 +306,22 @@ class LinearWithGradAccumulationAndAsyncAllreduce(torch.autograd.Function):


def linear_with_grad_accumulation_and_async_allreduce(
    input, weight, bias, gradient_accumulation_fusion, async_grad_allreduce,
):
    args = _cast_if_autocast_enabled(
        input, weight, bias, gradient_accumulation_fusion, async_grad_allreduce
    )
    with torch.cuda.amp.autocast(enabled=False):
        return LinearWithGradAccumulationAndAsyncAllreduce.apply(*args)
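The pattern in the backward pass above is the point of this class: issue an asynchronous all-reduce of grad_input, run the weight-gradient GEMM while the communication is in flight, then synchronize. A stripped-down sketch of the same overlap outside apex (assumes torch.distributed is already initialized; names are illustrative only):

import torch
import torch.distributed as dist

def overlapped_backward(grad_output, input, weight, group=None):
    # dL/dX can be reduced across ranks while dL/dW is still being computed.
    grad_input = grad_output.matmul(weight)
    handle = dist.all_reduce(grad_input, group=group, async_op=True)
    grad_weight = grad_output.t().matmul(input)  # overlaps with the all-reduce
    handle.wait()
    return grad_input, grad_weight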
class LinearWithGradAccumulationAndAsyncAllreduceIn16Bit(torch.autograd.Function):
    """Linear layer execution with asynchronous all-reduce and gradient accumulation fusion in backprop."""

    @staticmethod
    def forward(
        ctx, input, weight, bias, gradient_accumulation_fusion, async_grad_allreduce
    ):
        ctx.save_for_backward(input, weight)
        ctx.use_bias = bias is not None
        ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
@@ -285,17 +337,23 @@ class LinearWithGradAccumulationAndAsyncAllreduceIn16Bit(torch.autograd.Function):
        use_bias = ctx.use_bias

        grad_input = grad_output.matmul(weight)
        # Convert the tensor shapes to 2D for execution compatibility
        grad_output = grad_output.view(
            grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]
        )
        input = input.view(input.shape[0] * input.shape[1], input.shape[2])

        if ctx.async_grad_allreduce:
            # Asynchronous all-reduce
            handle = torch.distributed.all_reduce(
                grad_input, group=get_tensor_model_parallel_group(), async_op=True
            )
            # Delay the start of weight gradient computation shortly (3us) to have
            # all-reduce scheduled first and have GPU resources allocated
            _ = torch.empty(1, device=grad_output.device) + 1

        if ctx.gradient_accumulation_fusion:
            fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
                input, grad_output, weight.main_grad
            )
            grad_weight = None
        else:
            grad_weight = grad_output.t().matmul(input)
@@ -307,13 +365,11 @@ class LinearWithGradAccumulationAndAsyncAllreduceIn16Bit(torch.autograd.Function):


def linear_with_grad_accumulation_and_async_allreduce_in16bit(
    input, weight, bias, gradient_accumulation_fusion, async_grad_allreduce,
):
    args = _cast_if_autocast_enabled(
        input, weight, bias, gradient_accumulation_fusion, async_grad_allreduce
    )
    with torch.cuda.amp.autocast(enabled=False):
        return LinearWithGradAccumulationAndAsyncAllreduceIn16Bit.apply(*args)
@@ -382,7 +438,11 @@ class ColumnParallelLinear(torch.nn.Module):
        # we allocate the transpose.
        # Initialize weight.
        if use_cpu_initialization:
            self.weight = Parameter(
                torch.empty(
                    self.output_size_per_partition, self.input_size, dtype=params_dtype
                )
            )
            self.master_weight = _initialize_affine_weight_cpu(
                self.weight,
                self.output_size,
@@ -403,14 +463,22 @@ class ColumnParallelLinear(torch.nn.Module):
                    dtype=params_dtype,
                )
            )
            _initialize_affine_weight_gpu(
                self.weight, init_method, partition_dim=0, stride=stride
            )
        if bias:
            if use_cpu_initialization:
                self.bias = Parameter(
                    torch.empty(self.output_size_per_partition, dtype=params_dtype)
                )
            else:
                self.bias = Parameter(
                    torch.empty(
                        self.output_size_per_partition,
                        device=torch.cuda.current_device(),
                        dtype=params_dtype,
                    )
                )
            set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
            # Always initialize bias to zero.
@@ -420,8 +488,8 @@ class ColumnParallelLinear(torch.nn.Module):
            self.register_parameter("bias", None)

        self.async_tensor_model_parallel_allreduce = (
            not no_async_tensor_model_parallel_allreduce and world_size > 1
        )
        if gradient_accumulation_fusion:
            if not _grad_accum_fusion_available:
                # Basically, apex.transformer module users are expected to install APEX's
@@ -429,6 +497,7 @@ class ColumnParallelLinear(torch.nn.Module):
                # `pip install --global-option="--cpp_ext" --global-option="--cuda_ext ."
                # at the root of APEX repository.
                import warnings

                warnings.warn(
                    "`gradient_accumulation_fusion` is set to `True` but "
                    "the custom CUDA extension of `fused_weight_gradient_mlp_cuda` module not "
@@ -438,7 +507,11 @@ class ColumnParallelLinear(torch.nn.Module):
                gradient_accumulation_fusion = False
        self.gradient_accumulation_fusion = gradient_accumulation_fusion
        self._forward_impl = (
            linear_with_grad_accumulation_and_async_allreduce_in16bit
            if accumulation_in_fp16
            else linear_with_grad_accumulation_and_async_allreduce
        )

    def forward(self, input_):
        bias = self.bias if not self.skip_bias_add else None
@@ -450,8 +523,12 @@ class ColumnParallelLinear(torch.nn.Module):
            input_parallel = input_
        # Matrix multiply.
        output_parallel = self._forward_impl(
            input_parallel,
            self.weight,
            bias,
            self.gradient_accumulation_fusion,
            self.async_tensor_model_parallel_allreduce,
        )
        if self.gather_output:
            # All-gather across the partitions.
            output = gather_from_tensor_model_parallel_region(output_parallel)
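ColumnParallelLinear splits the weight along the output dimension (partition_dim=0 above), so each rank produces a distinct slice of the output and gather_output simply concatenates the slices. A single-process sketch of the arithmetic with a fake two-way split and no communication:

import torch

x = torch.randn(3, 8)           # [batch, input_size]
w = torch.randn(6, 8)           # [output_size, input_size], as in nn.Linear
w0, w1 = w.chunk(2, dim=0)      # each rank holds output_size/2 rows

y_full = x @ w.t()
y_gathered = torch.cat([x @ w0.t(), x @ w1.t()], dim=-1)  # the "all-gather"
assert torch.allclose(y_full, y_gathered, atol=1e-6)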
@@ -522,7 +599,11 @@ class RowParallelLinear(torch.nn.Module):
        # we allocate the transpose.
        # Initialize weight.
        if use_cpu_initialization:
            self.weight = Parameter(
                torch.empty(
                    self.output_size, self.input_size_per_partition, dtype=params_dtype
                )
            )
            self.master_weight = _initialize_affine_weight_cpu(
                self.weight,
                self.output_size,
@@ -543,13 +624,19 @@ class RowParallelLinear(torch.nn.Module):
                    dtype=params_dtype,
                )
            )
            _initialize_affine_weight_gpu(
                self.weight, init_method, partition_dim=1, stride=stride
            )
        if bias:
            if use_cpu_initialization:
                self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
            else:
                self.bias = Parameter(
                    torch.empty(
                        self.output_size,
                        device=torch.cuda.current_device(),
                        dtype=params_dtype,
                    )
                )
            # Always initialize bias to zero.
            with torch.no_grad():
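RowParallelLinear partitions along the input dimension instead (partition_dim=1 above), so each rank computes a partial product over its slice of the input features and the partial results are summed with an all-reduce. The same single-process check:

import torch

x = torch.randn(3, 8)
w = torch.randn(6, 8)
x0, x1 = x.chunk(2, dim=-1)     # each rank sees half the input features
w0, w1 = w.chunk(2, dim=1)      # matching halves of the weight

y_full = x @ w.t()
y_reduced = x0 @ w0.t() + x1 @ w1.t()  # the "+" stands in for the all-reduce
assert torch.allclose(y_full, y_reduced, atol=1e-5)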
...
@@ -66,7 +66,9 @@ def _gather(input_):
    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    tensor_list[rank] = input_
    torch.distributed.all_gather(
        tensor_list, input_, group=get_tensor_model_parallel_group()
    )

    # Note: torch.cat already creates a contiguous tensor.
    output = torch.cat(tensor_list, dim=last_dim).contiguous()
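For reference, the concatenation along the last dimension is exact: gathering every rank's shard and concatenating reproduces the original tensor. A single-process sketch with a fake world_size and no torch.distributed:

import torch

world_size = 4
full = torch.randn(2, 8)
tensor_list = list(full.chunk(world_size, dim=-1))  # one shard per rank
# all_gather would fill a list like this on every rank; cat restores the tensor.
output = torch.cat(tensor_list, dim=-1)
assert torch.equal(output, full)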
...
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -49,13 +50,20 @@ class MemoryBuffer:
        element_size = torch.tensor([], dtype=dtype).element_size()
        print(
            "> building the {} memory buffer with {} num elements "
            "and {} dtype ({:.1f} MB)...".format(
                name, numel, dtype, numel * element_size / 1024 / 1024
            ),
            flush=True,
        )
        self.name = name
        self.numel = numel
        self.dtype = dtype
        self.data = torch.empty(
            self.numel,
            dtype=self.dtype,
            device=torch.cuda.current_device(),
            requires_grad=False,
        )

        # Index tracking the start of the free memory.
        self._start = 0
@@ -81,13 +89,17 @@ class MemoryBuffer:
    def add(self, tensor):
        """Allocate a chunk of memory from the buffer to tensor and copy
        the values."""
        assert (
            tensor.dtype == self.dtype
        ), "Input tensor type {} different from buffer type {}".format(
            tensor.dtype, self.dtype
        )
        # Number of elements of the input tensor.
        tensor_numel = torch.numel(tensor)
        new_start = self._start + tensor_numel
        assert (
            new_start <= self.numel
        ), "Not enough memory left in the buffer ({} > {})".format(
            tensor_numel, self.numel - self._start
        )
        # New tensor is a view into the memory.
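The buffer hands out views of one pre-allocated tensor and only advances a start index, which is what makes it cheap to reuse across layers. A minimal sketch of the same bookkeeping with plain tensors (hypothetical sizes, CPU only):

import torch

storage = torch.empty(16, dtype=torch.float32)  # stands in for self.data
start = 0

def add(tensor):
    global start
    n = tensor.numel()
    assert start + n <= storage.numel(), "Not enough memory left in the buffer"
    view = storage[start:start + n].view(tensor.shape)  # no new allocation
    view.copy_(tensor)
    start += n
    return view

a = add(torch.ones(2, 4))   # occupies elements [0, 8)
b = add(torch.zeros(8))     # occupies elements [8, 16)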
@@ -124,7 +136,8 @@ class RingMemBuffer:
    def __init__(self, name, num_buffers, numel, dtype, track_usage):
        self.num_buffers = num_buffers
        self.buffers = [
            allocate_mem_buff(name + " {}".format(i), numel, dtype, track_usage)
            for i in range(num_buffers)
        ]
        self._index = -1
...
@@ -53,8 +53,15 @@ def init_checkpointed_activations_memory_buffer(
):
    """Initialize the memory buffer for the checkpointed activations."""
    per_layer = (
        micro_batch_size
        * max_position_embeddings
        * hidden_size
        // tensor_model_parallel_size
    )
    assert (
        num_layers % checkpoint_num_layers == 0
    ), "number of layers is not divisible by checkpoint-num-layers"
    num_checkpointer_layers = num_layers // checkpoint_num_layers
    numel = per_layer * num_checkpointer_layers
    dtype = torch.half
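The buffer size follows directly from that per_layer formula. For example, with hypothetical GPT-like settings the element count and fp16 footprint work out as:

micro_batch_size = 4
max_position_embeddings = 2048
hidden_size = 1024
tensor_model_parallel_size = 2
num_layers, checkpoint_num_layers = 24, 1

per_layer = (
    micro_batch_size * max_position_embeddings * hidden_size
    // tensor_model_parallel_size
)                                   # 4_194_304 elements per layer
numel = per_layer * (num_layers // checkpoint_num_layers)
print(numel * 2 / 1024 / 1024)      # 192.0 MB of half-precision activations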
@@ -217,7 +224,9 @@ def model_parallel_cuda_manual_seed(seed):
    # Set the default state.
    torch.cuda.manual_seed(data_parallel_seed)
    # and model parallel state.
    _CUDA_RNG_STATE_TRACKER.add(
        _MODEL_PARALLEL_RNG_TRACKER_NAME, tensor_model_parallel_seed
    )


# TODO (mkozuki): Move the below gradient checkpoint related features to another (new) file.
@@ -255,7 +264,10 @@ class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad(), "
                "please use .backward() if possible"
            )
        inputs = ctx.saved_tensors
        if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
            inputs[0].data = gather_split_1d_tensor(inputs[0].data)
@@ -284,7 +296,10 @@ class CheckpointFunction(torch.autograd.Function):
        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        torch.autograd.backward(outputs, args)
        grads = tuple(
            inp.grad if isinstance(inp, torch.Tensor) else inp
            for inp in detached_inputs
        )
        return (None,) + grads
...
@@ -43,7 +43,9 @@ class VocabUtility:
    partition: Note that indices in [first, last)"""

    @staticmethod
    def vocab_range_from_per_partition_vocab_size(
        per_partition_vocab_size, rank, world_size
    ):
        index_f = rank * per_partition_vocab_size
        index_l = index_f + per_partition_vocab_size
        return index_f, index_l
@@ -51,4 +53,6 @@ class VocabUtility:
    @staticmethod
    def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
        per_partition_vocab_size = divide(global_vocab_size, world_size)
        return VocabUtility.vocab_range_from_per_partition_vocab_size(
            per_partition_vocab_size, rank, world_size
        )
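These two helpers give every tensor-parallel rank a contiguous, non-overlapping slice of the vocabulary. For instance, a padded vocabulary of 50304 split over 4 ranks (hypothetical numbers) is partitioned as:

global_vocab_size, world_size = 50304, 4
per_partition = global_vocab_size // world_size          # 12576
ranges = [
    (rank * per_partition, (rank + 1) * per_partition)
    for rank in range(world_size)
]
# [(0, 12576), (12576, 25152), (25152, 37728), (37728, 50304)]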
@@ -21,7 +21,9 @@ import torch
import torch.nn as nn

from apex import transformer
from apex.transformer.pipeline_parallel.utils import (
    average_losses_across_data_parallel_group,
)
from apex.transformer.testing import global_vars
@@ -30,7 +32,6 @@ TEST_SUCCESS_MESSAGE = ">> passed the test :-)"

# note (mkozuki): `pre_process` and `post_process` are a placeholder until interleaving schedule test comes.
class MyLayer(nn.Module):
    def __init__(self, hidden_size: int, pre_process: bool, post_process: bool):
        super().__init__()
        self.pre_process = pre_process
@@ -40,16 +41,22 @@ class MyLayer(nn.Module):
    def forward(self, x):
        return self.layer(x)


class MyModel(nn.Module):
    def __init__(
        self, hidden_size: int, pre_process: bool = False, post_process: bool = False
    ) -> None:
        super().__init__()
        self.pre_process = pre_process
        self.post_process = post_process
        self.layer = MyLayer(
            hidden_size=hidden_size, pre_process=pre_process, post_process=post_process
        )
        self.input_tensor = None

    def set_input_tensor(
        self, input_tensor: Union[torch.Tensor, List[torch.Tensor]]
    ) -> None:
        if not isinstance(input_tensor, list):
            input_tensor = [input_tensor]
        self.input_tensor = input_tensor[0]
@@ -81,7 +88,8 @@ def fwd_step_func(batch, model):
    def loss_func(x):
        loss = torch.sum(x)
        averaged_loss = average_losses_across_data_parallel_group([loss])
        return loss, {"avg": averaged_loss}

    return y, loss_func
@@ -102,7 +110,7 @@ def set_random_seed(seed):
    transformer.tensor_parallel.model_parallel_cuda_manual_seed(seed)


def initialize_distributed(backend="nccl"):
    """Initialize torch.distributed."""
    # Get local rank in case it is provided.
    # parser = argparse.ArgumentParser()
@@ -113,11 +121,13 @@ def initialize_distributed(backend='nccl'):
    local_rank = args.local_rank

    # Get rank and world size.
    rank = int(os.getenv("RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))

    print(
        "> initializing torch.distributed with local rank: {}, "
        "rank: {}, world size: {}".format(local_rank, rank, world_size)
    )

    # Set the device id.
    device = rank % torch.cuda.device_count()
@@ -126,22 +136,20 @@ def initialize_distributed(backend='nccl'):
    torch.cuda.set_device(device)

    # Call the init process.
    init_method = "tcp://"
    master_ip = os.getenv("MASTER_ADDR", "localhost")
    master_port = os.getenv("MASTER_PORT", "6000")
    init_method += master_ip + ":" + master_port
    torch.distributed.init_process_group(
        backend=backend, world_size=world_size, rank=rank, init_method=init_method
    )


def print_separator(message):
    torch.distributed.barrier()
    filler_len = (78 - len(message)) // 2
    filler = "-" * filler_len
    string = "\n" + filler + " {} ".format(message) + filler
    if torch.distributed.get_rank() == 0:
        print(string, flush=True)
    torch.distributed.barrier()
import sys
import torch
from torch import distributed as dist
from torch.testing._internal import common_utils
from torch.testing._internal import common_distributed
class DistributedTestBase(common_distributed.MultiProcessTestCase):
BACKEND_NCCL = "nccl"
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
def setUp(self) -> None:
super().setUp()
self._spawn_processes()
def tearDown(self) -> None:
super().tearDown()
@property
def world_size(self) -> int:
return min(torch.cuda.device_count(), 4)
@property
def init_method(self):
return f"{common_utils.FILE_SCHEMA}{self.file_name}"
@classmethod
def _run(cls, rank, test_name, file_name, pipe):
self = cls(test_name)
self.assertTrue(torch.cuda.is_available())
self.rank = rank
self.file_name = file_name
print(f"[dist init] rank = {self.rank}, world_size = {self.world_size}")
try:
dist.init_process_group(
init_method=self.init_method,
backend=DistributedTestBase.BACKEND_NCCL,
world_size=int(self.world_size),
rank=self.rank,
)
except RuntimeError as e:
if "recompile" in e.args[0]:
print(f"Backend of {DistributedTestBase.BACKEND_NCCL} not available")
sys.exit(0)
raise
torch.cuda.set_device(self.rank % torch.cuda.device_count())
dist.barrier()
self.run_test(test_name, pipe)
dist.barrier()
dist.destroy_process_group()
sys.exit(0)
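A test file built on this base would subclass it and let MultiProcessTestCase spawn one process per rank. The sketch below is a hypothetical smoke test, not part of the commit, reusing the torch/dist/common_utils imports above to show the intended usage:

class AllReduceSmokeTest(DistributedTestBase):
    def test_all_reduce_sum(self):
        # Each rank contributes rank + 1; the default all_reduce op is SUM.
        t = torch.full((4,), float(self.rank + 1), device="cuda")
        dist.all_reduce(t)
        expected = float(sum(range(1, self.world_size + 1)))
        self.assertTrue(torch.equal(t, torch.full((4,), expected, device="cuda")))


if __name__ == "__main__":
    common_utils.run_tests()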
@@ -6,7 +6,9 @@ from apex.transformer import parallel_state

def ensure_divisibility(numerator, denominator):
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, "{} is not divisible by {}".format(
        numerator, denominator
    )


def divide(numerator, denominator):
@@ -19,7 +21,9 @@ def divide(numerator, denominator):
def split_tensor_into_1d_equal_chunks(tensor):
    """Break a tensor into equal 1D chunks."""
    data = tensor.view(-1)
    partition_size = (
        torch.numel(data) // parallel_state.get_tensor_model_parallel_world_size()
    )
    start_index = partition_size * parallel_state.get_tensor_model_parallel_rank()
    end_index = start_index + partition_size
    return data[start_index:end_index]
@@ -30,7 +34,14 @@ def gather_split_1d_tensor(tensor):
    world_size = parallel_state.get_tensor_model_parallel_world_size()
    numel = torch.numel(tensor)
    numel_gathered = world_size * numel
    gathered = torch.empty(
        numel_gathered,
        dtype=tensor.dtype,
        device=torch.cuda.current_device(),
        requires_grad=False,
    )
    chunks = [gathered[i * numel : (i + 1) * numel] for i in range(world_size)]
    torch.distributed.all_gather(
        chunks, tensor, group=parallel_state.get_tensor_model_parallel_group()
    )
    return gathered
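split_tensor_into_1d_equal_chunks and gather_split_1d_tensor are inverses across the tensor-parallel group: each rank keeps one flat slice, and all-gathering the slices reproduces the flattened original. A single-process sketch of the bookkeeping (fake world_size, no torch.distributed):

import torch

world_size = 4
x = torch.arange(16.0).reshape(4, 4)
data = x.view(-1)
partition_size = data.numel() // world_size
chunks = [
    data[rank * partition_size:(rank + 1) * partition_size]
    for rank in range(world_size)
]
gathered = torch.cat(chunks)              # what all_gather reconstructs
assert torch.equal(gathered.view_as(x), x)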
import subprocess
import os

from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE


def run_gpt(cmd):
    args = list(cmd.split(" "))
    p = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    outs, errs = p.communicate()
    outs = list(str((outs).decode("utf-8")).splitlines())
    success = False
    runtime = 0
    num_params = 0
    for out in outs:
        out = str(out)
        if "Average Iteration Time:" in str(out):
            slicey = out[out.find(":") + 2 :]
            try:
                runtime = float(slicey)
            except:
                print(slicey)
                quit()
        if "Number of Parameters:" in str(out):
            slicey = out[out.find(":") + 2 :]
            try:
                num_params = int(slicey)
            except:
                print(slicey)
                quit()
        if str(out) == str(TEST_SUCCESS_MESSAGE):
            success = True
    return runtime, round(float(int(num_params)) / 10.0 ** 9, 3), success, errs


def plot(runtimes):
    import matplotlib.pyplot as plt

    for distributed_setting in runtimes.keys():
        plt.scatter(
            runtimes[distributed_setting].keys(),
            runtimes[distributed_setting].values(),
            label=distributed_setting,
        )
    plt.legend()
    plt.xlabel("Parameters (Billions)")
    plt.ylabel("Training Iteration time (s)")
    plt.title(str("GPT Scaling w/ Offloading"))
    plt.savefig("offload_gpt_scaling.png")
    plt.close()
    if not os.path.exists("/my_workspace/"):
        os.system("mkdir /my_workspace/")
    os.system("cp *.png /my_workspace/")


def main():
    runtimes = {}
    nlist = (
        list(range(2000, 10000, 2000))
        + list(range(10000, 50000, 5000))
        + list(range(50000, 100000, 10000))
    )
    print("N-List:", nlist)
    for data_parr, tens_parr, pipe_parr in [(8, 1, 1), (4, 2, 1), (2, 1, 4), (1, 2, 4)]:
        for offload in [True, False]:
            dist_setting = (
                "ddp="
                + str(data_parr)
                + ", tensor_parr="
                + str(tens_parr)
                + ", pipe_parr="
                + str(pipe_parr)
                + ", offload="
                + str(offload)
            )
            runtimes[dist_setting] = {}
            print("Beginning Testing for", dist_setting)
            for n in nlist:
                cmd = "python3 -m torch.distributed.launch --nproc_per_node=8 run_gpt_minimal_test.py"
                cmd += (
                    " --micro-batch-size 1 --num-layers "
                    + str(n)
                    + " --hidden-size 128 --num-attention-heads 16"
                )
                cmd += (
                    " --max-position-embeddings 128 --seq-length 128 --tensor-model-parallel-size "
                    + str(tens_parr)
                )
                cmd += (
                    " --pipeline-model-parallel-size "
                    + str(pipe_parr)
                    + (" --cpu-offload" if offload else "")
                )
                print(cmd)
                runtime, bill_params, success, errs = run_gpt(cmd)
                if success:
                    runtimes[dist_setting][bill_params] = runtime
                    print(
                        str(runtime) + "s per training iter for",
                        str(bill_params) + "B parameter GPT-2",
                    )
                    if n >= 10000:
                        plot(runtimes)
                else:
                    print("GPT-2 w/", n, "layers failed using", dist_setting)
                    print("Moving on to the next distributed setting...")
                    print("#" * (25))
                    print()
                    plot(runtimes)
                    break
    print(runtimes)
    plot(runtimes)


if __name__ == "__main__":
    main()
\ No newline at end of file