Unverified Commit a0ed4151 authored by Masaki Kozuki, committed by GitHub

[transformer] Format & Test Refactoring (#1325)

* try PyTorch custom TestCase class

* revert

* initial working example

* update

* data utils

* fix imports

* hardcode backend to nccl

* fix signature

* fix typo

* mapping

* set device

* init

* refactor x entropy

* remove unused import & destroy model parallel

* refactor random

* fix test

* remove migrated tests

* refactor

* init

* separate affine weight init

* init model parallel

* split more

* weight init fix part 1

* use cpu init for consistency btwn native and tensor parallel

* black

* add col parallel

* use a 3D tensor of square matrix for column parallel linear

* skip the failing cases

* migrate layers test

* pipeline parallel forward/backward

* fix typo

* fix typo

* fix

* fix pipeline world size

* black

* rm `run_pipeline_parallel_test` in favor of test_pipeline_parallel_fwd_bwd.py

* stop logging

* set log level

* black

* license and format

* fix

* skip tf32 as matrices are small

* remove potentially inappropriate license

* Apply suggestions from code review

* remove `TODO` comment

* `torch.testing.assert_allclose` -> `torch.testing.assert_close`

* remove comment-outs

* remove unused import

* minor fix
parent f10b4b89
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
import torch
......
......@@ -31,7 +31,9 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
import scaled_upper_triang_masked_softmax_cuda
scale_t = torch.tensor([scale])
softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(inputs, scale_t[0])
softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(
inputs, scale_t[0]
)
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
......@@ -41,7 +43,9 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
import scaled_upper_triang_masked_softmax_cuda
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_upper_triang_masked_softmax_cuda.backward(output_grads, softmax_results, scale_t[0])
input_grads = scaled_upper_triang_masked_softmax_cuda.backward(
output_grads, softmax_results, scale_t[0]
)
return input_grads, None
......@@ -65,10 +69,10 @@ def scaled_upper_triang_masked_softmax(inputs, _, scale):
# 2. Apply the mask.
# 3. Perform softmax.
class ScaledMaskedSoftmax(torch.autograd.Function):
@staticmethod
def forward(ctx, inputs, mask, scale):
import scaled_masked_softmax_cuda
scale_t = torch.tensor([scale])
softmax_results = scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0])
......@@ -81,7 +85,9 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_masked_softmax_cuda.backward(output_grads, softmax_results, scale_t[0])
input_grads = scaled_masked_softmax_cuda.backward(
output_grads, softmax_results, scale_t[0]
)
return input_grads, None, None
......@@ -120,7 +126,9 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
self.input_in_fp16 = input_in_fp16
self.input_in_bf16 = input_in_bf16
if self.input_in_fp16 and self.input_in_bf16:
raise RuntimeError("both fp16 and bf16 flags cannot be active at the same time.")
raise RuntimeError(
"both fp16 and bf16 flags cannot be active at the same time."
)
self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
self.attn_mask_type = attn_mask_type
self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
......
from typing import Optional
import logging
import os
import threading
def get_transformer_logger(name: str) -> logging.Logger:
......@@ -16,4 +14,5 @@ def set_logging_level(verbosity) -> None:
verbosity
"""
from apex import _library_root_logger
_library_root_logger.setLevel(verbosity)
......@@ -37,7 +37,9 @@ def build_num_microbatches_calculator(
)
if rank == 0:
_logger.info(
"setting number of micro-batches to constant {}".format(num_microbatches_calculator.get())
"setting number of micro-batches to constant {}".format(
num_microbatches_calculator.get()
)
)
else:
......@@ -54,7 +56,10 @@ def build_num_microbatches_calculator(
"will use batch size rampup starting from global batch "
"size {} to global batch size {} with batch size increments "
"{} over {} samples.".format(
start_batch_size, global_batch_size, batch_size_increment, ramup_samples
start_batch_size,
global_batch_size,
batch_size_increment,
ramup_samples,
),
flush=True,
)
......@@ -91,7 +96,9 @@ class ConstantNumMicroBatches(NumMicroBatchesCalculator):
micro_batch_times_data_parallel = micro_batch_size * data_parallel_size
assert global_batch_size % micro_batch_times_data_parallel == 0, (
"global batch size ({}) is not divisible by micro batch size ({})"
" times data parallel size ({})".format(global_batch_size, micro_batch_size, data_parallel_size)
" times data parallel size ({})".format(
global_batch_size, micro_batch_size, data_parallel_size
)
)
self.num_micro_batches = global_batch_size // micro_batch_times_data_parallel
assert self.num_micro_batches >= 1
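# Worked example (illustrative numbers, not from this commit): with global_batch_size=512,
# micro_batch_size=4 and data_parallel_size=16, the constant calculator yields
# 512 // (4 * 16) = 8 microbatches per step, i.e. ConstantNumMicroBatches(512, 4, 16).get() == 8.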
......@@ -131,7 +138,9 @@ class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):
self.micro_batch_size = micro_batch_size
self.data_parallel_size = data_parallel_size
self.micro_batch_times_data_parallel_size = self.micro_batch_size * self.data_parallel_size
self.micro_batch_times_data_parallel_size = (
self.micro_batch_size * self.data_parallel_size
)
assert self.micro_batch_times_data_parallel_size > 0
assert start_batch_size > 0
......@@ -163,15 +172,25 @@ class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):
self.current_global_batch_size = self.global_batch_size
else:
steps = int(consumed_samples / self.rampup_samples_per_increment)
self.current_global_batch_size = self.start_batch_size + steps * self.batch_size_increment
self.current_global_batch_size = (
self.start_batch_size + steps * self.batch_size_increment
)
assert self.current_global_batch_size <= self.global_batch_size
if consistency_check:
assert self.current_global_batch_size % self.micro_batch_times_data_parallel_size == 0, (
assert (
self.current_global_batch_size
% self.micro_batch_times_data_parallel_size
== 0
), (
"current global "
"batch size ({}) is not divisible by micro-batch-size ({}) times"
"data parallel size ({})".format(
self.current_global_batch_size, self.micro_batch_size, self.data_parallel_size
self.current_global_batch_size,
self.micro_batch_size,
self.data_parallel_size,
)
)
self.num_micro_batches = (
self.current_global_batch_size // self.micro_batch_times_data_parallel_size
)
self.num_micro_batches = self.current_global_batch_size // self.micro_batch_times_data_parallel_size
......@@ -24,6 +24,7 @@ from apex.transformer.log_util import get_transformer_logger
_logger = get_transformer_logger(__name__)
# N.B. (mkozuki): Diff btwn Megatron-LM & apex parallel_state
# set(megatron_mpu_initialize_funcs) - set(apex.transformer.parallel_state) =
# {
# 'get_num_layers',
......
import warnings
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import forward_backward_no_pipelining
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import _forward_backward_pipelining_with_interleaving
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import (
forward_backward_no_pipelining,
)
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import (
_forward_backward_pipelining_with_interleaving,
)
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import (
forward_backward_pipelining_without_interleaving,
)
__all__ = [
"get_forward_backward_func",
]
class ExperimentalWarning(Warning):
pass
......@@ -21,19 +27,9 @@ def get_forward_backward_func(
if get_num_microbatches() % pipeline_model_parallel_size != 0:
msg = "number of microbatches is not divisible by pipeline-parallel size when using interleaved schedule"
raise RuntimeError(msg)
warnings.warn(
"Pipeline Model Parallel with interleaving scheduling is experimental. "
f"To use Pipeline Parallel without interleaving, set `virtual_pipeline_model_parallel_size` to `None`: {virtual_pipeline_model_parallel_size}",
ExperimentalWarning
)
forward_backward_func = _forward_backward_pipelining_with_interleaving
else:
forward_backward_func = forward_backward_pipelining_without_interleaving
else:
forward_backward_func = forward_backward_no_pipelining
return forward_backward_func
__all__ = [
"get_forward_backward_func",
]
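A minimal usage sketch of the selector above, assuming the import path and call pattern implied by this diff; `fwd_step_func`, `batch`, `model`, and `tensor_shape` are placeholders the caller must provide:

from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel import get_forward_backward_func

forward_backward_func = get_forward_backward_func(
    virtual_pipeline_model_parallel_size=None,  # None selects a non-interleaved path
    pipeline_model_parallel_size=parallel_state.get_pipeline_model_parallel_world_size(),
)
losses_reduced = forward_backward_func(
    fwd_step_func,   # FwdStepFunc: (batch, model) -> (output_tensor, loss_func)
    batch,
    model,
    forward_only=False,
    tensor_shape=tensor_shape,
)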
......@@ -10,7 +10,9 @@ from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.utils import listify_model
from apex.transformer.pipeline_parallel.utils import unwrap_model
from apex.transformer.pipeline_parallel.utils import get_model_type
from apex.transformer.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes
from apex.transformer.tensor_parallel.layers import (
set_defaults_if_not_set_tensor_model_parallel_attributes,
)
from apex.transformer.log_util import get_transformer_logger
......@@ -19,7 +21,9 @@ _logger = get_transformer_logger(__name__)
Batch = Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]
LossFunc = Callable[[torch.Tensor], torch.Tensor]
FwdStepFunc = Callable[[Optional[Batch], torch.nn.Module], Tuple[torch.Tensor, LossFunc]]
FwdStepFunc = Callable[
[Optional[Batch], torch.nn.Module], Tuple[torch.Tensor, LossFunc]
]
def build_model(
......@@ -49,8 +53,8 @@ def build_model(
the list has multiple models, otherwise one.
"""
if (
parallel_state.get_pipeline_model_parallel_world_size() > 1 and
virtual_pipeline_model_parallel_size is not None
parallel_state.get_pipeline_model_parallel_world_size() > 1
and virtual_pipeline_model_parallel_size is not None
):
model = []
for i in range(virtual_pipeline_model_parallel_size):
......@@ -60,10 +64,9 @@ def build_model(
# Set pre_process and post_process only after virtual rank is set.
pre_process = parallel_state.is_pipeline_first_stage()
post_process = parallel_state.is_pipeline_last_stage()
cur_kwargs.update({
"pre_process": pre_process,
"post_process": post_process,
})
cur_kwargs.update(
{"pre_process": pre_process, "post_process": post_process,}
)
this_model = model_provider_func(*cur_args, **cur_kwargs)
model.append(this_model)
else:
......@@ -72,10 +75,9 @@ def build_model(
if model_type == ModelType.encoder_or_decoder:
pre_process = parallel_state.is_pipeline_first_stage()
post_process = parallel_state.is_pipeline_last_stage()
cur_kwargs.update({
"pre_process": pre_process,
"post_process": post_process,
})
cur_kwargs.update(
{"pre_process": pre_process, "post_process": post_process,}
)
model = model_provider_func(*cur_args, **cur_kwargs)
elif model_type == ModelType.encoder_and_decoder:
pre_process = parallel_state.is_pipeline_first_stage()
......@@ -94,12 +96,14 @@ def build_model(
post_process = rank == (split_rank - 1) or rank == (world_size - 1)
add_encoder = parallel_state.is_pipeline_stage_before_split()
add_decoder = parallel_state.is_pipeline_stage_after_split()
cur_kwargs.update({
cur_kwargs.update(
{
"pre_process": pre_process,
"post_process": post_process,
"add_encoder": add_encoder,
"add_decoder": add_decoder,
})
}
)
model = model_provider_func(*cur_args, **cur_kwargs)
model.model_type = model_type
......@@ -115,7 +119,10 @@ def build_model(
set_defaults_if_not_set_tensor_model_parallel_attributes(param)
# Print number of parameters.
if parallel_state.model_parallel_is_initialized() and parallel_state.get_data_parallel_rank() == 0:
if (
parallel_state.model_parallel_is_initialized()
and parallel_state.get_data_parallel_rank() == 0
):
msg = " > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}".format(
parallel_state.get_tensor_model_parallel_rank(),
parallel_state.get_pipeline_model_parallel_rank(),
......@@ -143,7 +150,12 @@ def build_model(
def _calc_number_of_params(model: List[torch.nn.Module]) -> int:
assert isinstance(model, list)
return sum([sum([p.nelement() for p in model_module.parameters()]) for model_module in model])
return sum(
[
sum([p.nelement() for p in model_module.parameters()])
for model_module in model
]
)
def _get_params_for_weight_decay_optimization(
......@@ -156,28 +168,36 @@ def _get_params_for_weight_decay_optimization(
Layernorms and biases will have no weight decay but the rest will.
"""
modules = listify_model(model)
weight_decay_params = {'params': []}
no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
weight_decay_params = {"params": []}
no_weight_decay_params = {"params": [], "weight_decay": 0.0}
for module in modules:
for module_ in module.modules():
if isinstance(module_, no_weight_decay_modules):
no_weight_decay_params['params'].extend(
[p for p in list(module_._parameters.values())
if p is not None])
no_weight_decay_params["params"].extend(
[p for p in list(module_._parameters.values()) if p is not None]
)
else:
weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n != 'bias'])
no_weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n == 'bias'])
weight_decay_params["params"].extend(
[
p
for n, p in list(module_._parameters.items())
if p is not None and n != "bias"
]
)
no_weight_decay_params["params"].extend(
[
p
for n, p in list(module_._parameters.items())
if p is not None and n == "bias"
]
)
return weight_decay_params, no_weight_decay_params
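A hedged usage note, not part of this commit: the two groups returned above can be handed straight to a torch optimizer as parameter groups, so biases and no-weight-decay modules (LayerNorm by default) skip decay while everything else keeps it. `model` and the hyperparameters below are illustrative.

import torch

weight_decay_params, no_weight_decay_params = _get_params_for_weight_decay_optimization(model)
optimizer = torch.optim.AdamW(
    [weight_decay_params, no_weight_decay_params],  # the second group carries weight_decay=0.0
    lr=1e-4,            # illustrative hyperparameters
    weight_decay=0.01,
)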
def free_output_tensor(
output_tensors: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]],
deallocate_pipeline_outputs: bool = False
deallocate_pipeline_outputs: bool = False,
) -> None:
"""Pseudo-free the output tensor's `.data` field.
......@@ -202,9 +222,15 @@ def custom_backward(output: torch.Tensor, grad_output: Optional[torch.Tensor]) -
directly, bypassing PyTorch's `torch.autograd.backward`. PyTorch's `backward` checks that the
output and grad have the same shape, while C++ `backward` does not.
"""
assert output.numel() == 1, "output should be pseudo-freed in schedule, to optimize memory consumption"
assert isinstance(output, torch.Tensor), "output == {}.".format(type(output).__name__)
assert isinstance(grad_output, (torch.Tensor, type(None))), "grad_output == {}.".format(type(grad_output).__name__)
assert (
output.numel() == 1
), "output should be pseudo-freed in schedule, to optimize memory consumption"
assert isinstance(output, torch.Tensor), "output == {}.".format(
type(output).__name__
)
assert isinstance(
grad_output, (torch.Tensor, type(None))
), "grad_outptu == {}.".format(type(grad_output).__name__)
# Handle scalar output
if grad_output is None:
......@@ -278,7 +304,10 @@ def forward_step(
# If T5 model (or other model with encoder and decoder)
# and in decoder stack, then send encoder_hidden_state
# downstream as well.
if parallel_state.is_pipeline_stage_after_split() and model_type == ModelType.encoder_and_decoder:
if (
parallel_state.is_pipeline_stage_after_split()
and model_type == ModelType.encoder_and_decoder
):
return [output_tensor, input_tensor[-1]]
if unwrap_output_tensor:
return output_tensor
......@@ -343,9 +372,9 @@ def backward_step(
# Handle single skip connection if it exists (encoder_hidden_state in model with encoder and decoder).
if (
parallel_state.get_pipeline_model_parallel_world_size() > 1 and
parallel_state.is_pipeline_stage_after_split() and
model_type == ModelType.encoder_and_decoder
parallel_state.get_pipeline_model_parallel_world_size() > 1
and parallel_state.is_pipeline_stage_after_split()
and model_type == ModelType.encoder_and_decoder
):
if output_tensor_grad[1] is not None:
# todo (mkozuki): Replace the inplace add with `+= output_tensor_grad[1]`?
......
......@@ -37,7 +37,7 @@ def forward_backward_no_pipelining(
dtype: Optional[torch.dtype] = None,
grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
disable_autocast: bool = False,
custom_sync_context_handler = None,
custom_sync_context_handler=None,
**kwargs,
):
"""Run forward and backward passes with no pipeline parallelism (no inter-stage communication).
......@@ -98,7 +98,13 @@ def forward_backward_no_pipelining(
)
if not forward_only:
_logger.debug("Call `backward_step`")
backward_step(input_tensor, output_tensor, output_tensor_grad, model_type=model_type, grad_scaler=grad_scaler)
backward_step(
input_tensor,
output_tensor,
output_tensor_grad,
model_type=model_type,
grad_scaler=grad_scaler,
)
# Run computation for last microbatch out of context handler (want to
# synchronize gradients).
......@@ -115,6 +121,12 @@ def forward_backward_no_pipelining(
)
if not forward_only:
_logger.debug("Call `backward_step`")
backward_step(input_tensor, output_tensor, output_tensor_grad, model_type=model_type, grad_scaler=grad_scaler)
backward_step(
input_tensor,
output_tensor,
output_tensor_grad,
model_type=model_type,
grad_scaler=grad_scaler,
)
return losses_reduced
......@@ -71,12 +71,18 @@ def _forward_backward_pipelining_with_interleaving(
raise RuntimeError("`model` must be a list of `nn.Module`'s'")
num_model_chunks: int = len(model)
input_tensors: List[List[Union[None, torch.Tensor]]] = [[] for _ in range(num_model_chunks)]
output_tensors: List[List[Union[None, torch.Tensor]]] = [[] for _ in range(num_model_chunks)]
input_tensors: List[List[Union[None, torch.Tensor]]] = [
[] for _ in range(num_model_chunks)
]
output_tensors: List[List[Union[None, torch.Tensor]]] = [
[] for _ in range(num_model_chunks)
]
curr_iters: List[int] = [0 for _ in range(num_model_chunks)]
losses_reduced: List[Union[None, torch.Tensor]] = []
if not forward_only:
output_tensor_grads: List[List[Union[None, torch.Tensor]]] = [[] for _ in range(num_model_chunks)]
output_tensor_grads: List[List[Union[None, torch.Tensor]]] = [
[] for _ in range(num_model_chunks)
]
pipeline_parallel_size: int = parallel_state.get_pipeline_model_parallel_world_size()
pipeline_parallel_rank: int = parallel_state.get_pipeline_model_parallel_rank()
......@@ -97,7 +103,9 @@ def _forward_backward_pipelining_with_interleaving(
num_warmup_microbatches = num_microbatches
all_warmup_microbatches = True
else:
num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
num_warmup_microbatches = (
pipeline_parallel_size - pipeline_parallel_rank - 1
) * 2
num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
num_microbatches_remaining: int = num_microbatches - num_warmup_microbatches
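# Worked example (illustrative): on pipeline rank 0 with pipeline_parallel_size=2,
# num_model_chunks=2 and num_microbatches=8, warmup = (2 - 0 - 1) * 2 + (2 - 1) * 2 = 4,
# leaving num_microbatches_remaining = 4 for the steady-state 1F1B phase.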
......@@ -114,7 +122,9 @@ def _forward_backward_pipelining_with_interleaving(
def get_model_chunk_id(microbatch_id: int, forward: bool) -> int:
"""Helper function to get the model chunk ID given the iteration number."""
pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
microbatch_id_in_group = microbatch_id % (
pipeline_parallel_size * num_model_chunks
)
model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
if not forward:
model_chunk_id = num_model_chunks - model_chunk_id - 1
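# Worked example (illustrative): with pipeline_parallel_size=2 and num_model_chunks=2,
# microbatch ids 0..7 map to forward chunk ids [0, 0, 1, 1, 0, 0, 1, 1] and, after the
# reversal just above, to backward chunk ids [1, 1, 0, 0, 1, 1, 0, 0].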
......@@ -129,10 +139,9 @@ def _forward_backward_pipelining_with_interleaving(
parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
# forward step
if (
parallel_state.is_pipeline_first_stage() and
len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id])
):
if parallel_state.is_pipeline_first_stage() and len(
input_tensors[model_chunk_id]
) == len(output_tensors[model_chunk_id]):
input_tensors[model_chunk_id].append(None)
input_tensor = input_tensors[model_chunk_id][-1]
output_tensor = forward_step(
......@@ -169,7 +178,14 @@ def _forward_backward_pipelining_with_interleaving(
input_tensor = input_tensors[model_chunk_id].pop(0)
output_tensor = output_tensors[model_chunk_id].pop(0)
output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad, model_type=model_type, grad_scaler=grad_scaler, deallocate_pipeline_outputs=deallocate_pipeline_outputs)
input_tensor_grad = backward_step(
input_tensor,
output_tensor,
output_tensor_grad,
model_type=model_type,
grad_scaler=grad_scaler,
deallocate_pipeline_outputs=deallocate_pipeline_outputs,
)
return input_tensor_grad
......@@ -177,7 +193,9 @@ def _forward_backward_pipelining_with_interleaving(
# Run warmup forward passes.
###################################################################################################################
parallel_state.set_virtual_pipeline_model_parallel_rank(0)
input_tensors[0].append(p2p_communication.recv_forward(tensor_shape=tensor_shape, dtype=dtype))
input_tensors[0].append(
p2p_communication.recv_forward(tensor_shape=tensor_shape, dtype=dtype)
)
_logger.info("Warmup phase")
for k in range(num_warmup_microbatches):
_logger.debug(f"warmup iter: {k} / {num_warmup_microbatches}")
......@@ -191,7 +209,9 @@ def _forward_backward_pipelining_with_interleaving(
recv_prev = False
if k == (num_microbatches - 1):
recv_prev = False
_logger.debug(f"next fwd model chunk ID: {next_forward_model_chunk_id}, recv_prev: {recv_prev}")
_logger.debug(
f"next fwd model chunk ID: {next_forward_model_chunk_id}, recv_prev: {recv_prev}"
)
# Don't send tensor downstream if on last stage.
if parallel_state.is_pipeline_last_stage():
......@@ -200,7 +220,11 @@ def _forward_backward_pipelining_with_interleaving(
# Send and receive tensors as appropriate (send tensors computed
# in this iteration; receive tensors for next iteration).
if k == (num_warmup_microbatches - 1) and not forward_only and not all_warmup_microbatches:
if (
k == (num_warmup_microbatches - 1)
and not forward_only
and not all_warmup_microbatches
):
input_tensor_grad = None
recv_next = True
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
......@@ -221,7 +245,11 @@ def _forward_backward_pipelining_with_interleaving(
else:
_logger.debug("send fwd and receive fwd")
input_tensor = p2p_communication.send_forward_recv_forward(
output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape, dtype=dtype)
output_tensor,
recv_prev=recv_prev,
tensor_shape=tensor_shape,
dtype=dtype,
)
free_output_tensor(output_tensor, deallocate_pipeline_outputs)
input_tensors[next_forward_model_chunk_id].append(input_tensor)
......@@ -251,7 +279,9 @@ def _forward_backward_pipelining_with_interleaving(
backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
_logger.debug(f"fwd/bwd model chunk id: {forward_model_chunk_id}/{backward_model_chunk_id}")
_logger.debug(
f"fwd/bwd model chunk id: {forward_model_chunk_id}/{backward_model_chunk_id}"
)
if parallel_state.is_pipeline_first_stage():
input_tensor_grad = None
......@@ -267,7 +297,9 @@ def _forward_backward_pipelining_with_interleaving(
recv_prev = False
next_forward_model_chunk_id += 1
else:
next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)
next_forward_model_chunk_id = get_model_chunk_id(
forward_k + 1, forward=True
)
recv_next = True
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
......@@ -279,7 +311,9 @@ def _forward_backward_pipelining_with_interleaving(
recv_next = False
next_backward_model_chunk_id -= 1
else:
next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)
next_backward_model_chunk_id = get_model_chunk_id(
backward_k + 1, forward=False
)
# If last iteration, don't receive; we already received one extra
# before the start of the for loop.
......@@ -314,9 +348,13 @@ def _forward_backward_pipelining_with_interleaving(
_logger.info("Cooldown phase")
if not forward_only:
if all_warmup_microbatches:
output_tensor_grads[num_model_chunks - 1].append(p2p_communication.recv_backward(tensor_shape=tensor_shape, dtype=dtype))
output_tensor_grads[num_model_chunks - 1].append(
p2p_communication.recv_backward(tensor_shape=tensor_shape, dtype=dtype)
)
for k in range(num_microbatches_remaining, num_microbatches):
_logger.debug(f"cooldown iter {k} in range({num_microbatches_remaining}, {num_microbatches})")
_logger.debug(
f"cooldown iter {k} in range({num_microbatches_remaining}, {num_microbatches})"
)
input_tensor_grad = backward_step_helper(k)
next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
recv_next = True
......@@ -327,7 +365,11 @@ def _forward_backward_pipelining_with_interleaving(
recv_next = False
output_tensor_grads[next_backward_model_chunk_id].append(
p2p_communication.send_backward_recv_backward(
input_tensor_grad, recv_next=recv_next, tensor_shape=tensor_shape, dtype=dtype)
input_tensor_grad,
recv_next=recv_next,
tensor_shape=tensor_shape,
dtype=dtype,
)
)
return losses_reduced
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......
......@@ -25,8 +25,9 @@ _MAX_DATA_DIM = 5
def _check_data_types(keys, data, target_dtype):
"""Check that all the keys have the same target data type."""
for key in keys:
assert data[key].dtype == target_dtype, "{} has data type {} which " "is different than {}".format(
key, data[key].dtype, target_dtype
assert data[key].dtype == target_dtype, (
"{} has data type {} which "
"is different than {}".format(key, data[key].dtype, target_dtype)
)
......@@ -48,7 +49,9 @@ def _build_key_size_numel_dictionaries(keys, data):
# Move to GPU and broadcast.
sizes_cuda = torch.cuda.LongTensor(sizes)
torch.distributed.broadcast(
sizes_cuda, get_tensor_model_parallel_src_rank(), group=get_tensor_model_parallel_group(),
sizes_cuda,
get_tensor_model_parallel_src_rank(),
group=get_tensor_model_parallel_group(),
)
# Move back to cpu and unpack.
......@@ -92,13 +95,19 @@ def broadcast_data(keys, data, datatype):
# Check that all keys have the same data type.
_check_data_types(keys, data, datatype)
# Flatten the data associated with the keys
flatten_data = torch.cat([data[key].contiguous().view(-1) for key in keys], dim=0).cuda()
flatten_data = torch.cat(
[data[key].contiguous().view(-1) for key in keys], dim=0
).cuda()
else:
flatten_data = torch.empty(total_numel, device=torch.cuda.current_device(), dtype=datatype)
flatten_data = torch.empty(
total_numel, device=torch.cuda.current_device(), dtype=datatype
)
# Broadcast
torch.distributed.broadcast(
flatten_data, get_tensor_model_parallel_src_rank(), group=get_tensor_model_parallel_group(),
flatten_data,
get_tensor_model_parallel_src_rank(),
group=get_tensor_model_parallel_group(),
)
# Unpack
......
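A hedged usage sketch for `broadcast_data` above; the keys, tensors, and the `None` on non-source ranks are illustrative assumptions, not taken from this commit.

keys = ["text", "labels"]                              # hypothetical batch keys
if get_tensor_model_parallel_rank() == 0:
    data = {"text": tokens_cpu, "labels": labels_cpu}  # hypothetical int64 CPU tensors
else:
    data = None                                        # only the source rank reads `data`
data_b = broadcast_data(keys, data, torch.int64)       # every rank now holds the same GPU tensors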
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021-22, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -26,10 +26,18 @@ from apex.transformer.parallel_state import get_tensor_model_parallel_group
from apex.transformer.parallel_state import get_tensor_model_parallel_rank
from apex.transformer.parallel_state import get_tensor_model_parallel_world_size
from apex.transformer.utils import divide
from apex.transformer.tensor_parallel.mappings import copy_to_tensor_model_parallel_region
from apex.transformer.tensor_parallel.mappings import gather_from_tensor_model_parallel_region
from apex.transformer.tensor_parallel.mappings import reduce_from_tensor_model_parallel_region
from apex.transformer.tensor_parallel.mappings import scatter_to_tensor_model_parallel_region
from apex.transformer.tensor_parallel.mappings import (
copy_to_tensor_model_parallel_region,
)
from apex.transformer.tensor_parallel.mappings import (
gather_from_tensor_model_parallel_region,
)
from apex.transformer.tensor_parallel.mappings import (
reduce_from_tensor_model_parallel_region,
)
from apex.transformer.tensor_parallel.mappings import (
scatter_to_tensor_model_parallel_region,
)
from apex.transformer.tensor_parallel.random import get_cuda_rng_tracker
from apex.transformer.tensor_parallel.utils import VocabUtility
from apex.transformer.log_util import get_transformer_logger
......@@ -53,9 +61,9 @@ _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
def param_is_not_tensor_parallel_duplicate(param):
return (hasattr(param, "tensor_model_parallel") and param.tensor_model_parallel) or (
get_tensor_model_parallel_rank() == 0
)
return (
hasattr(param, "tensor_model_parallel") and param.tensor_model_parallel
) or (get_tensor_model_parallel_rank() == 0)
def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
......@@ -89,7 +97,9 @@ def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
def _initialize_affine_weight_gpu(weight, init_method, partition_dim, stride=1):
"""Initialize affine weight for model parallel on GPU."""
set_tensor_model_parallel_attributes(tensor=weight, is_parallel=True, dim=partition_dim, stride=stride)
set_tensor_model_parallel_attributes(
tensor=weight, is_parallel=True, dim=partition_dim, stride=stride
)
with get_cuda_rng_tracker().fork():
init_method(weight)
......@@ -114,16 +124,22 @@ def _initialize_affine_weight_cpu(
Build the master weight on all processes and scatter
the relevant chunk."""
set_tensor_model_parallel_attributes(tensor=weight, is_parallel=True, dim=partition_dim, stride=stride)
set_tensor_model_parallel_attributes(
tensor=weight, is_parallel=True, dim=partition_dim, stride=stride
)
# Initialize master weight
master_weight = torch.empty(output_size, input_size, dtype=torch.float, requires_grad=False)
master_weight = torch.empty(
output_size, input_size, dtype=torch.float, requires_grad=False
)
init_method(master_weight)
master_weight = master_weight.to(dtype=params_dtype)
# Split and copy
per_partition_per_stride_size = divide(per_partition_size, stride)
weight_list = torch.split(master_weight, per_partition_per_stride_size, dim=partition_dim)
weight_list = torch.split(
master_weight, per_partition_per_stride_size, dim=partition_dim
)
rank = get_tensor_model_parallel_rank()
world_size = get_tensor_model_parallel_world_size()
my_weight_list = weight_list[rank::world_size]
......@@ -147,7 +163,13 @@ class VocabParallelEmbedding(torch.nn.Module):
"""
def __init__(
self, num_embeddings, embedding_dim, init_method=init.xavier_normal_, *, params_dtype=torch.float32, use_cpu_initialization=False,
self,
num_embeddings,
embedding_dim,
init_method=init.xavier_normal_,
*,
params_dtype=torch.float32,
use_cpu_initialization=False,
):
super(VocabParallelEmbedding, self).__init__()
# Keep the input dimensions.
......@@ -162,18 +184,34 @@ class VocabParallelEmbedding(torch.nn.Module):
self._weight = None
self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
# Divide the weight matrix along the vocabulary dimension.
self.vocab_start_index, self.vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size(
self.num_embeddings, get_tensor_model_parallel_rank(), self.tensor_model_parallel_size
(
self.vocab_start_index,
self.vocab_end_index,
) = VocabUtility.vocab_range_from_global_vocab_size(
self.num_embeddings,
get_tensor_model_parallel_rank(),
self.tensor_model_parallel_size,
)
self.num_embeddings_per_partition = (
self.vocab_end_index - self.vocab_start_index
)
self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index
# Allocate weights and initialize.
if use_cpu_initialization:
self.weight = Parameter(
torch.empty(self.num_embeddings_per_partition, self.embedding_dim, dtype=params_dtype)
torch.empty(
self.num_embeddings_per_partition,
self.embedding_dim,
dtype=params_dtype,
)
)
_initialize_affine_weight_cpu(
self.weight, self.num_embeddings, self.embedding_dim, self.num_embeddings_per_partition, 0, init_method,
self.weight,
self.num_embeddings,
self.embedding_dim,
self.num_embeddings_per_partition,
0,
init_method,
params_dtype=params_dtype,
)
else:
......@@ -185,12 +223,16 @@ class VocabParallelEmbedding(torch.nn.Module):
dtype=params_dtype,
)
)
_initialize_affine_weight_gpu(self.weight, init_method, partition_dim=0, stride=1)
_initialize_affine_weight_gpu(
self.weight, init_method, partition_dim=0, stride=1
)
def forward(self, input_):
if self.tensor_model_parallel_size > 1:
# Build the mask.
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
input_mask = (input_ < self.vocab_start_index) | (
input_ >= self.vocab_end_index
)
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
......@@ -216,8 +258,11 @@ class VocabParallelEmbedding(torch.nn.Module):
class LinearWithGradAccumulationAndAsyncAllreduce(torch.autograd.Function):
"""Linear layer execution with asynchronous all-reduce and gradient accumulation fusion in backprop."""
@staticmethod
def forward(ctx, input, weight, bias, gradient_accumulation_fusion, async_grad_allreduce):
def forward(
ctx, input, weight, bias, gradient_accumulation_fusion, async_grad_allreduce
):
ctx.save_for_backward(input, weight)
ctx.use_bias = bias is not None
ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
......@@ -233,17 +278,23 @@ class LinearWithGradAccumulationAndAsyncAllreduce(torch.autograd.Function):
use_bias = ctx.use_bias
grad_input = grad_output.matmul(weight)
# Convert the tensor shapes to 2D for execution compatibility
grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2])
grad_output = grad_output.view(
grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]
)
input = input.view(input.shape[0] * input.shape[1], input.shape[2])
if ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = torch.distributed.all_reduce(grad_input, group=get_tensor_model_parallel_group(), async_op=True)
handle = torch.distributed.all_reduce(
grad_input, group=get_tensor_model_parallel_group(), async_op=True
)
# Delay the start of weight gradient computation shortly (3us) to have
# all-reduce scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
if ctx.gradient_accumulation_fusion:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(input, grad_output, weight.main_grad)
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
input, grad_output, weight.main_grad
)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(input)
......@@ -255,21 +306,22 @@ class LinearWithGradAccumulationAndAsyncAllreduce(torch.autograd.Function):
def linear_with_grad_accumulation_and_async_allreduce(
input,
weight,
bias,
gradient_accumulation_fusion,
async_grad_allreduce,
input, weight, bias, gradient_accumulation_fusion, async_grad_allreduce,
):
args = _cast_if_autocast_enabled(input, weight, bias, gradient_accumulation_fusion, async_grad_allreduce)
args = _cast_if_autocast_enabled(
input, weight, bias, gradient_accumulation_fusion, async_grad_allreduce
)
with torch.cuda.amp.autocast(enabled=False):
return LinearWithGradAccumulationAndAsyncAllreduce.apply(*args)
class LinearWithGradAccumulationAndAsyncAllreduceIn16Bit(torch.autograd.Function):
"""Linear layer execution with asynchronous all-reduce and gradient accumulation fusion in backprop."""
@staticmethod
def forward(ctx, input, weight, bias, gradient_accumulation_fusion, async_grad_allreduce):
def forward(
ctx, input, weight, bias, gradient_accumulation_fusion, async_grad_allreduce
):
ctx.save_for_backward(input, weight)
ctx.use_bias = bias is not None
ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
......@@ -285,17 +337,23 @@ class LinearWithGradAccumulationAndAsyncAllreduceIn16Bit(torch.autograd.Function
use_bias = ctx.use_bias
grad_input = grad_output.matmul(weight)
# Convert the tensor shapes to 2D for execution compatibility
grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2])
grad_output = grad_output.view(
grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]
)
input = input.view(input.shape[0] * input.shape[1], input.shape[2])
if ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = torch.distributed.all_reduce(grad_input, group=get_tensor_model_parallel_group(), async_op=True)
handle = torch.distributed.all_reduce(
grad_input, group=get_tensor_model_parallel_group(), async_op=True
)
# Delay the start of weight gradient computation shortly (3us) to have
# all-reduce scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
if ctx.gradient_accumulation_fusion:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(input, grad_output, weight.main_grad)
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
input, grad_output, weight.main_grad
)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(input)
......@@ -307,13 +365,11 @@ class LinearWithGradAccumulationAndAsyncAllreduceIn16Bit(torch.autograd.Function
def linear_with_grad_accumulation_and_async_allreduce_in16bit(
input,
weight,
bias,
gradient_accumulation_fusion,
async_grad_allreduce,
input, weight, bias, gradient_accumulation_fusion, async_grad_allreduce,
):
args = _cast_if_autocast_enabled(input, weight, bias, gradient_accumulation_fusion, async_grad_allreduce)
args = _cast_if_autocast_enabled(
input, weight, bias, gradient_accumulation_fusion, async_grad_allreduce
)
with torch.cuda.amp.autocast(enabled=False):
return LinearWithGradAccumulationAndAsyncAllreduceIn16Bit.apply(*args)
......@@ -382,7 +438,11 @@ class ColumnParallelLinear(torch.nn.Module):
# we allocate the transpose.
# Initialize weight.
if use_cpu_initialization:
self.weight = Parameter(torch.empty(self.output_size_per_partition, self.input_size, dtype=params_dtype))
self.weight = Parameter(
torch.empty(
self.output_size_per_partition, self.input_size, dtype=params_dtype
)
)
self.master_weight = _initialize_affine_weight_cpu(
self.weight,
self.output_size,
......@@ -403,14 +463,22 @@ class ColumnParallelLinear(torch.nn.Module):
dtype=params_dtype,
)
)
_initialize_affine_weight_gpu(self.weight, init_method, partition_dim=0, stride=stride)
_initialize_affine_weight_gpu(
self.weight, init_method, partition_dim=0, stride=stride
)
if bias:
if use_cpu_initialization:
self.bias = Parameter(torch.empty(self.output_size_per_partition, dtype=params_dtype))
self.bias = Parameter(
torch.empty(self.output_size_per_partition, dtype=params_dtype)
)
else:
self.bias = Parameter(
torch.empty(self.output_size_per_partition, device=torch.cuda.current_device(), dtype=params_dtype)
torch.empty(
self.output_size_per_partition,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
# Always initialize bias to zero.
......@@ -420,8 +488,8 @@ class ColumnParallelLinear(torch.nn.Module):
self.register_parameter("bias", None)
self.async_tensor_model_parallel_allreduce = (
not no_async_tensor_model_parallel_allreduce and
world_size > 1)
not no_async_tensor_model_parallel_allreduce and world_size > 1
)
if gradient_accumulation_fusion:
if not _grad_accum_fusion_available:
# Basically, apex.transformer module users are expected to install APEX's
......@@ -429,6 +497,7 @@ class ColumnParallelLinear(torch.nn.Module):
# `pip install --global-option="--cpp_ext" --global-option="--cuda_ext ."
# at the root of APEX repository.
import warnings
warnings.warn(
"`gradient_accumulation_fusion` is set to `True` but "
"the custom CUDA extension of `fused_weight_gradient_mlp_cuda` module not "
......@@ -438,7 +507,11 @@ class ColumnParallelLinear(torch.nn.Module):
gradient_accumulation_fusion = False
self.gradient_accumulation_fusion = gradient_accumulation_fusion
self._forward_impl = linear_with_grad_accumulation_and_async_allreduce_in16bit if accumulation_in_fp16 else linear_with_grad_accumulation_and_async_allreduce
self._forward_impl = (
linear_with_grad_accumulation_and_async_allreduce_in16bit
if accumulation_in_fp16
else linear_with_grad_accumulation_and_async_allreduce
)
def forward(self, input_):
bias = self.bias if not self.skip_bias_add else None
......@@ -450,8 +523,12 @@ class ColumnParallelLinear(torch.nn.Module):
input_parallel = input_
# Matrix multiply.
output_parallel = self._forward_impl(
input_parallel, self.weight, bias, self.gradient_accumulation_fusion,
self.async_tensor_model_parallel_allreduce)
input_parallel,
self.weight,
bias,
self.gradient_accumulation_fusion,
self.async_tensor_model_parallel_allreduce,
)
if self.gather_output:
# All-gather across the partitions.
output = gather_from_tensor_model_parallel_region(output_parallel)
......@@ -522,7 +599,11 @@ class RowParallelLinear(torch.nn.Module):
# we allocate the transpose.
# Initialize weight.
if use_cpu_initialization:
self.weight = Parameter(torch.empty(self.output_size, self.input_size_per_partition, dtype=params_dtype))
self.weight = Parameter(
torch.empty(
self.output_size, self.input_size_per_partition, dtype=params_dtype
)
)
self.master_weight = _initialize_affine_weight_cpu(
self.weight,
self.output_size,
......@@ -543,13 +624,19 @@ class RowParallelLinear(torch.nn.Module):
dtype=params_dtype,
)
)
_initialize_affine_weight_gpu(self.weight, init_method, partition_dim=1, stride=stride)
_initialize_affine_weight_gpu(
self.weight, init_method, partition_dim=1, stride=stride
)
if bias:
if use_cpu_initialization:
self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
else:
self.bias = Parameter(
torch.empty(self.output_size, device=torch.cuda.current_device(), dtype=params_dtype)
torch.empty(
self.output_size,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
# Always initialize bias to zero.
with torch.no_grad():
......
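A hedged sketch of how the two layers above are typically paired (sizes are illustrative and an initialized tensor-model-parallel group is assumed): a column-parallel up-projection keeps its output partitioned and feeds a row-parallel down-projection, so the forward pass communicates only in the final all-reduce. Both layers are assumed here to return an (output, bias) pair.

import torch
from apex.transformer.tensor_parallel import ColumnParallelLinear, RowParallelLinear

dense_h_to_4h = ColumnParallelLinear(1024, 4096, gather_output=False)
dense_4h_to_h = RowParallelLinear(4096, 1024, input_is_parallel=True)

def parallel_mlp(hidden_states):
    intermediate, _ = dense_h_to_4h(hidden_states)                      # partitioned along the 4h dim
    output, _ = dense_4h_to_h(torch.nn.functional.gelu(intermediate))   # all-reduced, full-size output
    return output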
......@@ -66,7 +66,9 @@ def _gather(input_):
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_
torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group())
torch.distributed.all_gather(
tensor_list, input_, group=get_tensor_model_parallel_group()
)
# Note: torch.cat already creates a contiguous tensor.
output = torch.cat(tensor_list, dim=last_dim).contiguous()
......
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -49,13 +50,20 @@ class MemoryBuffer:
element_size = torch.tensor([], dtype=dtype).element_size()
print(
"> building the {} memory buffer with {} num elements "
"and {} dtype ({:.1f} MB)...".format(name, numel, dtype, numel * element_size / 1024 / 1024),
"and {} dtype ({:.1f} MB)...".format(
name, numel, dtype, numel * element_size / 1024 / 1024
),
flush=True,
)
self.name = name
self.numel = numel
self.dtype = dtype
self.data = torch.empty(self.numel, dtype=self.dtype, device=torch.cuda.current_device(), requires_grad=False)
self.data = torch.empty(
self.numel,
dtype=self.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
# Index tracking the start of the free memory.
self._start = 0
......@@ -81,13 +89,17 @@ class MemoryBuffer:
def add(self, tensor):
"""Allocate a chunk of memory from the buffer to tensor and copy
the values."""
assert tensor.dtype == self.dtype, "Input tensor type {} different from buffer type {}".format(
assert (
tensor.dtype == self.dtype
), "Input tensor type {} different from buffer type {}".format(
tensor.dtype, self.dtype
)
# Number of elements of the input tensor.
tensor_numel = torch.numel(tensor)
new_start = self._start + tensor_numel
assert new_start <= self.numel, "Not enough memory left in the buffer ({} > {})".format(
assert (
new_start <= self.numel
), "Not enough memory left in the buffer ({} > {})".format(
tensor_numel, self.numel - self._start
)
# New tensor is a view into the memory.
......@@ -124,7 +136,8 @@ class RingMemBuffer:
def __init__(self, name, num_buffers, numel, dtype, track_usage):
self.num_buffers = num_buffers
self.buffers = [
allocate_mem_buff(name + " {}".format(i), numel, dtype, track_usage) for i in range(num_buffers)
allocate_mem_buff(name + " {}".format(i), numel, dtype, track_usage)
for i in range(num_buffers)
]
self._index = -1
......
......@@ -53,8 +53,15 @@ def init_checkpointed_activations_memory_buffer(
):
"""Initializ the memory buffer for the checkpointed activations."""
per_layer = micro_batch_size * max_position_embeddings * hidden_size // tensor_model_parallel_size
assert num_layers % checkpoint_num_layers == 0, "number of layers is not divisible by checkpoint-num-layers"
per_layer = (
micro_batch_size
* max_position_embeddings
* hidden_size
// tensor_model_parallel_size
)
assert (
num_layers % checkpoint_num_layers == 0
), "number of layers is not divisible by checkpoint-num-layers"
num_checkpointer_layers = num_layers // checkpoint_num_layers
numel = per_layer * num_checkpointer_layers
dtype = torch.half
......@@ -217,7 +224,9 @@ def model_parallel_cuda_manual_seed(seed):
# Set the default state.
torch.cuda.manual_seed(data_parallel_seed)
# and model parallel state.
_CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, tensor_model_parallel_seed)
_CUDA_RNG_STATE_TRACKER.add(
_MODEL_PARALLEL_RNG_TRACKER_NAME, tensor_model_parallel_seed
)
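A hedged sketch of the intended call order, inferred from the code above rather than prescribed by this commit: seed once after model-parallel initialization, then fork the tracked RNG wherever partitioned weights are initialized on GPU, so ranks within a tensor-parallel group draw different values.

import torch
from apex.transformer import tensor_parallel
from apex.transformer.tensor_parallel.random import get_cuda_rng_tracker

tensor_parallel.model_parallel_cuda_manual_seed(1234)               # once, after parallel_state is initialized
weight = torch.nn.Parameter(torch.empty(1024, 256, device="cuda"))  # hypothetical partitioned parameter
with get_cuda_rng_tracker().fork():
    torch.nn.init.xavier_normal_(weight)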
# TODO (mkozuki): Move the below gradient checkpoint related features to another (new) file.
......@@ -255,7 +264,10 @@ class CheckpointFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, *args):
if not torch.autograd._is_checkpoint_valid():
raise RuntimeError("Checkpointing is not compatible with .grad(), " "please use .backward() if possible")
raise RuntimeError(
"Checkpointing is not compatible with .grad(), "
"please use .backward() if possible"
)
inputs = ctx.saved_tensors
if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
inputs[0].data = gather_split_1d_tensor(inputs[0].data)
......@@ -284,7 +296,10 @@ class CheckpointFunction(torch.autograd.Function):
if isinstance(outputs, torch.Tensor):
outputs = (outputs,)
torch.autograd.backward(outputs, args)
grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs)
grads = tuple(
inp.grad if isinstance(inp, torch.Tensor) else inp
for inp in detached_inputs
)
return (None,) + grads
......
......@@ -43,7 +43,9 @@ class VocabUtility:
partition: Note that indices in [first, last)"""
@staticmethod
def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size):
def vocab_range_from_per_partition_vocab_size(
per_partition_vocab_size, rank, world_size
):
index_f = rank * per_partition_vocab_size
index_l = index_f + per_partition_vocab_size
return index_f, index_l
......@@ -51,4 +53,6 @@ class VocabUtility:
@staticmethod
def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
per_partition_vocab_size = divide(global_vocab_size, world_size)
return VocabUtility.vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size)
return VocabUtility.vocab_range_from_per_partition_vocab_size(
per_partition_vocab_size, rank, world_size
)
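A worked example for the helper above with illustrative numbers: a 50,304-token vocabulary split over a tensor-model-parallel world size of 4 gives 12,576 entries per rank, so rank 1 owns the half-open range [12576, 25152).

first, last = VocabUtility.vocab_range_from_global_vocab_size(50304, rank=1, world_size=4)
assert (first, last) == (12576, 25152)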
......@@ -21,7 +21,9 @@ import torch
import torch.nn as nn
from apex import transformer
from apex.transformer.pipeline_parallel.utils import average_losses_across_data_parallel_group
from apex.transformer.pipeline_parallel.utils import (
average_losses_across_data_parallel_group,
)
from apex.transformer.testing import global_vars
......@@ -30,7 +32,6 @@ TEST_SUCCESS_MESSAGE = ">> passed the test :-)"
# note (mkozuki): `pre_process` and `post_process` are placeholders until the interleaving schedule test arrives.
class MyLayer(nn.Module):
def __init__(self, hidden_size: int, pre_process: bool, post_process: bool):
super().__init__()
self.pre_process = pre_process
......@@ -40,16 +41,22 @@ class MyLayer(nn.Module):
def forward(self, x):
return self.layer(x)
class MyModel(nn.Module):
def __init__(self, hidden_size: int, pre_process: bool = False, post_process: bool = False) -> None:
class MyModel(nn.Module):
def __init__(
self, hidden_size: int, pre_process: bool = False, post_process: bool = False
) -> None:
super().__init__()
self.pre_process = pre_process
self.post_process = post_process
self.layer = MyLayer(hidden_size=hidden_size, pre_process=pre_process, post_process=post_process)
self.layer = MyLayer(
hidden_size=hidden_size, pre_process=pre_process, post_process=post_process
)
self.input_tensor = None
def set_input_tensor(self, input_tensor: Union[torch.Tensor, List[torch.Tensor]]) -> None:
def set_input_tensor(
self, input_tensor: Union[torch.Tensor, List[torch.Tensor]]
) -> None:
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
self.input_tensor = input_tensor[0]
......@@ -81,7 +88,8 @@ def fwd_step_func(batch, model):
def loss_func(x):
loss = torch.sum(x)
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'avg': averaged_loss}
return loss, {"avg": averaged_loss}
return y, loss_func
......@@ -102,7 +110,7 @@ def set_random_seed(seed):
transformer.tensor_parallel.model_parallel_cuda_manual_seed(seed)
def initialize_distributed(backend='nccl'):
def initialize_distributed(backend="nccl"):
"""Initialize torch.distributed."""
# Get local rank in case it is provided.
# parser = argparse.ArgumentParser()
......@@ -113,11 +121,13 @@ def initialize_distributed(backend='nccl'):
local_rank = args.local_rank
# Get rank and world size.
rank = int(os.getenv('RANK', '0'))
world_size = int(os.getenv("WORLD_SIZE", '1'))
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
print('> initializing torch.distributed with local rank: {}, '
'rank: {}, world size: {}'.format(local_rank, rank, world_size))
print(
"> initializing torch.distributed with local rank: {}, "
"rank: {}, world size: {}".format(local_rank, rank, world_size)
)
# Set the device id.
device = rank % torch.cuda.device_count()
......@@ -126,22 +136,20 @@ def initialize_distributed(backend='nccl'):
torch.cuda.set_device(device)
# Call the init process.
init_method = 'tcp://'
master_ip = os.getenv('MASTER_ADDR', 'localhost')
master_port = os.getenv('MASTER_PORT', '6000')
init_method += master_ip + ':' + master_port
init_method = "tcp://"
master_ip = os.getenv("MASTER_ADDR", "localhost")
master_port = os.getenv("MASTER_PORT", "6000")
init_method += master_ip + ":" + master_port
torch.distributed.init_process_group(
backend=backend,
world_size=world_size,
rank=rank,
init_method=init_method)
backend=backend, world_size=world_size, rank=rank, init_method=init_method
)
def print_separator(message):
torch.distributed.barrier()
filler_len = (78 - len(message)) // 2
filler = '-' * filler_len
string = '\n' + filler + ' {} '.format(message) + filler
filler = "-" * filler_len
string = "\n" + filler + " {} ".format(message) + filler
if torch.distributed.get_rank() == 0:
print(string, flush=True)
torch.distributed.barrier()
import sys
import torch
from torch import distributed as dist
from torch.testing._internal import common_utils
from torch.testing._internal import common_distributed
class DistributedTestBase(common_distributed.MultiProcessTestCase):
BACKEND_NCCL = "nccl"
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
def setUp(self) -> None:
super().setUp()
self._spawn_processes()
def tearDown(self) -> None:
super().tearDown()
@property
def world_size(self) -> int:
return min(torch.cuda.device_count(), 4)
@property
def init_method(self):
return f"{common_utils.FILE_SCHEMA}{self.file_name}"
@classmethod
def _run(cls, rank, test_name, file_name, pipe):
self = cls(test_name)
self.assertTrue(torch.cuda.is_available())
self.rank = rank
self.file_name = file_name
print(f"[dist init] rank = {self.rank}, world_size = {self.world_size}")
try:
dist.init_process_group(
init_method=self.init_method,
backend=DistributedTestBase.BACKEND_NCCL,
world_size=int(self.world_size),
rank=self.rank,
)
except RuntimeError as e:
if "recompile" in e.args[0]:
print(f"Backend of {DistributedTestBase.BACKEND_NCCL} not available")
sys.exit(0)
raise
torch.cuda.set_device(self.rank % torch.cuda.device_count())
dist.barrier()
self.run_test(test_name, pipe)
dist.barrier()
dist.destroy_process_group()
sys.exit(0)
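A hedged sketch of how a test might build on this base class (the test body is made up): each spawned process reaches the test method with NCCL already initialized by `_run`, so collectives and `self.rank` / `self.world_size` can be used directly.

class AllReduceSmokeTest(DistributedTestBase):
    def test_all_reduce_sum(self):
        t = torch.ones(1, device="cuda") * (self.rank + 1)
        dist.all_reduce(t)  # default op is SUM across all ranks
        self.assertEqual(int(t.item()), sum(range(1, self.world_size + 1)))

if __name__ == "__main__":
    common_utils.run_tests()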
......@@ -6,7 +6,9 @@ from apex.transformer import parallel_state
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
assert numerator % denominator == 0, "{} is not divisible by {}".format(
numerator, denominator
)
def divide(numerator, denominator):
......@@ -19,7 +21,9 @@ def divide(numerator, denominator):
def split_tensor_into_1d_equal_chunks(tensor):
"""Break a tensor into equal 1D chunks."""
data = tensor.view(-1)
partition_size = torch.numel(data) // parallel_state.get_tensor_model_parallel_world_size()
partition_size = (
torch.numel(data) // parallel_state.get_tensor_model_parallel_world_size()
)
start_index = partition_size * parallel_state.get_tensor_model_parallel_rank()
end_index = start_index + partition_size
return data[start_index:end_index]
......@@ -30,7 +34,14 @@ def gather_split_1d_tensor(tensor):
world_size = parallel_state.get_tensor_model_parallel_world_size()
numel = torch.numel(tensor)
numel_gathered = world_size * numel
gathered = torch.empty(numel_gathered, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False)
chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)]
torch.distributed.all_gather(chunks, tensor, group=parallel_state.get_tensor_model_parallel_group())
gathered = torch.empty(
numel_gathered,
dtype=tensor.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
chunks = [gathered[i * numel : (i + 1) * numel] for i in range(world_size)]
torch.distributed.all_gather(
chunks, tensor, group=parallel_state.get_tensor_model_parallel_group()
)
return gathered
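An illustrative round trip for the two helpers above, assuming an initialized tensor-model-parallel group and an element count divisible by the world size: each rank keeps one contiguous slice, and gathering the slices reproduces the flattened original on every rank.

full = torch.arange(16, dtype=torch.float, device="cuda")
local = split_tensor_into_1d_equal_chunks(full)    # this rank's 16 // world_size elements
restored = gather_split_1d_tensor(local)           # matches full.view(-1) on every rank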
import subprocess
import os
from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
def run_gpt(cmd):
args = list(cmd.split(' '))
args = list(cmd.split(" "))
p = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
outs, errs = p.communicate()
outs = list(str((outs).decode('utf-8')).splitlines())
outs = list(str((outs).decode("utf-8")).splitlines())
success = False
runtime = 0
num_params = 0
for out in outs:
out=str(out)
out = str(out)
if "Average Iteration Time:" in str(out):
slicey = out[out.find(':')+2:]
slicey = out[out.find(":") + 2 :]
try:
runtime = float(slicey)
except:
print(slicey)
quit()
if "Number of Parameters:" in str(out):
slicey = out[out.find(':')+2:]
slicey = out[out.find(":") + 2 :]
try:
num_params = int(slicey)
except:
print(slicey)
quit()
if str(out) == str(TEST_SUCCESS_MESSAGE):
success=True
return runtime, round(float(int(num_params))/10.0**9,3), success, errs
success = True
return runtime, round(float(int(num_params)) / 10.0 ** 9, 3), success, errs
def plot(runtimes):
import matplotlib.pyplot as plt
for distributed_setting in runtimes.keys():
plt.scatter(runtimes[distributed_setting].keys(), runtimes[distributed_setting].values(), label=distributed_setting)
plt.scatter(
runtimes[distributed_setting].keys(),
runtimes[distributed_setting].values(),
label=distributed_setting,
)
plt.legend()
plt.xlabel('Parameters (Billions)')
plt.ylabel('Training Iteration time (s)')
plt.xlabel("Parameters (Billions)")
plt.ylabel("Training Iteration time (s)")
plt.title(str("GPT Scaling w/ Offloading"))
plt.savefig('offload_gpt_scaling.png')
plt.savefig("offload_gpt_scaling.png")
plt.close()
if not os.path.exists('/my_workspace/'):
os.system('mkdir /my_workspace/')
os.system('cp *.png /my_workspace/')
if not os.path.exists("/my_workspace/"):
os.system("mkdir /my_workspace/")
os.system("cp *.png /my_workspace/")
def main():
runtimes = {}
nlist = list(range(2000,10000,2000)) + list(range(10000,50000,5000)) + list(range(50000,100000,10000))
nlist = (
list(range(2000, 10000, 2000))
+ list(range(10000, 50000, 5000))
+ list(range(50000, 100000, 10000))
)
print("N-List:", nlist)
for data_parr, tens_parr, pipe_parr in [(8,1,1), (4,2,1), (2,1,4), (1,2,4)]:
for data_parr, tens_parr, pipe_parr in [(8, 1, 1), (4, 2, 1), (2, 1, 4), (1, 2, 4)]:
for offload in [True, False]:
dist_setting = 'ddp=' + str(data_parr) + ', tensor_parr=' + str(tens_parr) + ', pipe_parr=' + str(pipe_parr) + ', offload=' + str(offload)
dist_setting = (
"ddp="
+ str(data_parr)
+ ", tensor_parr="
+ str(tens_parr)
+ ", pipe_parr="
+ str(pipe_parr)
+ ", offload="
+ str(offload)
)
runtimes[dist_setting] = {}
print("Beginning Testing for", dist_setting)
for n in nlist:
cmd = "python3 -m torch.distributed.launch --nproc_per_node=8 run_gpt_minimal_test.py"
cmd += " --micro-batch-size 1 --num-layers " + str(n) + " --hidden-size 128 --num-attention-heads 16"
cmd += ' --max-position-embeddings 128 --seq-length 128 --tensor-model-parallel-size ' + str(tens_parr)
cmd += " --pipeline-model-parallel-size " + str(pipe_parr) + (' --cpu-offload' if offload else '')
cmd += (
" --micro-batch-size 1 --num-layers "
+ str(n)
+ " --hidden-size 128 --num-attention-heads 16"
)
cmd += (
" --max-position-embeddings 128 --seq-length 128 --tensor-model-parallel-size "
+ str(tens_parr)
)
cmd += (
" --pipeline-model-parallel-size "
+ str(pipe_parr)
+ (" --cpu-offload" if offload else "")
)
print(cmd)
runtime, bill_params, success, errs = run_gpt(cmd)
if success:
runtimes[dist_setting][bill_params] = runtime
print(str(runtime) + 's per training iter for', str(bill_params) + 'B parameter GPT-2')
print(
str(runtime) + "s per training iter for",
str(bill_params) + "B parameter GPT-2",
)
if n >= 10000:
plot(runtimes)
else:
print("GPT-2 w/", n, "layers failed using", dist_setting)
print("Moving on to the next distributed setting...")
print("#"*(25))
print("#" * (25))
print()
plot(runtimes)
break
print(runtimes)
plot(runtimes)
if __name__ == "__main__":
main()