"docs/vscode:/vscode.git/clone" did not exist on "079f29bcc35ab8fec6d1b15aa9e4ba4b56dcb478"
Unverified commit 63d5dd63, authored by Masaki Kozuki and committed by GitHub

Pipeline Model Parallel (#1202)

* Init apex.ppu (pipeline model parallel utility)

Reference commit:

```
commit 5ab646376d67831601d5552c193241d017f1b35c (HEAD -> main, internal/main)
Merge: 14f2c684 7b293d9b
Author: Mohammad Shoeybi <mshoeybi@nvidia.com>
Date:   Wed Sep 22 22:57:54 2021 -0700

    Merge branch 'add_BOS' into 'main'

    Add Beginning of Sentence token option and adding semaphore while multi-threading to prevent crashes and hangs due to connection keep-alives

    See merge request ADLR/megatron-lm!328
```

* removing get_args and replace import - phase 1

* removing get_args and replace import - phase 2

* move ppu to apex.transformer.pipeline_parallel

* update two __init__.py

* update READMEs

* mpu -> parallel_state & tensor_parallel

* fix

* remove non-pipeline files

* separate schedules.py - phase 1

* dissect schedules.py

* data_iterators -> batch

* remove optimizer from forward_backward_step funcs

* init test

* Apply 2 suggestion(s...
parent 3303b3e7
# coding=utf-8
- # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,9 +14,9 @@
# limitations under the License.
import torch
- from ..parallel_state import get_tensor_model_parallel_group
- from ..parallel_state import get_tensor_model_parallel_rank
- from ..parallel_state import get_tensor_model_parallel_src_rank
+ from apex.transformer.parallel_state import get_tensor_model_parallel_group
+ from apex.transformer.parallel_state import get_tensor_model_parallel_rank
+ from apex.transformer.parallel_state import get_tensor_model_parallel_src_rank
_MAX_DATA_DIM = 5
...
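The hunk above is representative of the whole commit: relative imports inside `apex.transformer.tensor_parallel` become absolute `apex.transformer.*` imports. A minimal, hedged sketch of using the new paths from user code (the initialization steps mentioned in the comments are assumptions about the surrounding apex API, not part of this diff):

```python
# Hedged sketch, not part of the commit: exercising the absolute import paths
# this diff introduces. Assumes torch.distributed has been initialized and
# apex.transformer.parallel_state has been set up (e.g. via its
# initialize_model_parallel entry point) before these queries are made.
from apex.transformer.parallel_state import (
    get_tensor_model_parallel_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_src_rank,
)


def describe_tensor_parallel_rank() -> str:
    """Report where this process sits inside its tensor-model-parallel group."""
    return (
        f"tp rank {get_tensor_model_parallel_rank()}, "
        f"src rank {get_tensor_model_parallel_src_rank()}, "
        f"group {get_tensor_model_parallel_group()}"
    )
```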
# coding=utf-8
- # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,17 +21,17 @@ import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.parameter import Parameter
- from ..._autocast_utils import _cast_if_autocast_enabled
- from ..parallel_state import get_tensor_model_parallel_group
- from ..parallel_state import get_tensor_model_parallel_rank
- from ..parallel_state import get_tensor_model_parallel_world_size
- from .mappings import copy_to_tensor_model_parallel_region
- from .mappings import gather_from_tensor_model_parallel_region
- from .mappings import reduce_from_tensor_model_parallel_region
- from .mappings import scatter_to_tensor_model_parallel_region
- from .random import get_cuda_rng_tracker
- from .utils import divide
- from .utils import VocabUtility
+ from apex._autocast_utils import _cast_if_autocast_enabled
+ from apex.transformer.parallel_state import get_tensor_model_parallel_group
+ from apex.transformer.parallel_state import get_tensor_model_parallel_rank
+ from apex.transformer.parallel_state import get_tensor_model_parallel_world_size
+ from apex.transformer.utils import divide
+ from apex.transformer.tensor_parallel.mappings import copy_to_tensor_model_parallel_region
+ from apex.transformer.tensor_parallel.mappings import gather_from_tensor_model_parallel_region
+ from apex.transformer.tensor_parallel.mappings import reduce_from_tensor_model_parallel_region
+ from apex.transformer.tensor_parallel.mappings import scatter_to_tensor_model_parallel_region
+ from apex.transformer.tensor_parallel.random import get_cuda_rng_tracker
+ from apex.transformer.tensor_parallel.utils import VocabUtility
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
...
# coding=utf-8
- # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,10 +14,10 @@
# limitations under the License.
import torch
- from ..parallel_state import get_tensor_model_parallel_group
- from ..parallel_state import get_tensor_model_parallel_world_size
- from ..parallel_state import get_tensor_model_parallel_rank
- from .utils import split_tensor_along_last_dim
+ from apex.transformer.parallel_state import get_tensor_model_parallel_group
+ from apex.transformer.parallel_state import get_tensor_model_parallel_world_size
+ from apex.transformer.parallel_state import get_tensor_model_parallel_rank
+ from apex.transformer.tensor_parallel.utils import split_tensor_along_last_dim
def _reduce(input_):
...
# coding=utf-8
- # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
# coding=utf-8
- # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,6 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+ # TODO (mkozuki): Audit this file.
+ # I don't think some functions strongly relate to `random` in tensor_parallel.
+ # Rather, some functions are mainly for gradient checkpointing (torch.utils.checkpoint).
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
@@ -23,18 +26,17 @@ from torch import _C
from torch.cuda import _lazy_call, device as device_ctx_manager
from torch.utils.checkpoint import detach_variable
- from ..parallel_state import get_data_parallel_rank
- from ..parallel_state import get_tensor_model_parallel_group
- from ..parallel_state import get_tensor_model_parallel_rank
- from ..parallel_state import get_tensor_model_parallel_world_size
- from .memory import allocate_mem_buff
+ from apex.transformer.parallel_state import get_tensor_model_parallel_rank
+ from apex.transformer.tensor_parallel.memory import allocate_mem_buff
+ from apex.transformer.utils import split_tensor_into_1d_equal_chunks
+ from apex.transformer.utils import gather_split_1d_tensor
# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = "model-parallel-rng"
- # Whether apply model parallelsim to checkpointed hidden states.
+ # Whether apply model parallelism to checkpointed hidden states.
_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = None
@@ -108,26 +110,6 @@ def _set_cuda_rng_state(new_state, device=-1):
    _lazy_call(cb)
- def split_tensor_into_1d_equal_chunks(tensor):
-     """Break a tensor into equal 1D chunks."""
-     data = tensor.view(-1)
-     partition_size = torch.numel(data) // get_tensor_model_parallel_world_size()
-     start_index = partition_size * get_tensor_model_parallel_rank()
-     end_index = start_index + partition_size
-     return data[start_index:end_index]
- def gather_split_1d_tensor(tensor):
-     """Opposite of above function, gather values from model parallel ranks."""
-     world_size = get_tensor_model_parallel_world_size()
-     numel = torch.numel(tensor)
-     numel_gathered = world_size * numel
-     gathered = torch.empty(numel_gathered, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False)
-     chunks = [gathered[i * numel : (i + 1) * numel] for i in range(world_size)]
-     torch.distributed.all_gather(chunks, tensor, group=get_tensor_model_parallel_group())
-     return gathered
class CudaRNGStatesTracker:
    """Tracker for the cuda RNG states.
@@ -238,6 +220,7 @@ def model_parallel_cuda_manual_seed(seed):
    _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, tensor_model_parallel_seed)
+ # TODO (mkozuki): Move the below gradient checkpoint related features to another (new) file.
class CheckpointFunction(torch.autograd.Function):
    """This function is adapted from torch.utils.checkpoint with
    two main changes:
...
# coding=utf-8
- # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,17 +14,7 @@
# limitations under the License.
import torch
+ from apex.transformer.utils import divide
- def ensure_divisibility(numerator, denominator):
-     """Ensure that numerator is divisible by the denominator."""
-     assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
- def divide(numerator, denominator):
-     """Ensure that numerator is divisible by the denominator and return
-     the division value."""
-     ensure_divisibility(numerator, denominator)
-     return numerator // denominator
def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
@@ -48,9 +38,9 @@ def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=
class VocabUtility:
-     """Split the vocabulary into `world_size` chunks amd return the
+     """Split the vocabulary into `world_size` chunks and return the
    first and last index of the vocabulary belonging to the `rank`
-     partition: Note that indecies in [fist, last)"""
+     partition: Note that indices in [fist, last)"""
    @staticmethod
    def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size):
...
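With `divide` and `ensure_divisibility` relocated to `apex.transformer.utils` (the new file appears further down), the `VocabUtility` helper that stays in this module pairs with them as in the short, hedged sketch below; the import paths are the ones this commit introduces, and the vocabulary size and rank are made-up illustrative values:

```python
# Hedged usage sketch, not part of the commit.
from apex.transformer.utils import divide
from apex.transformer.tensor_parallel.utils import VocabUtility

world_size = 8
global_vocab_size = 50304                                # divisible by 8, as divide() asserts
per_partition = divide(global_vocab_size, world_size)    # 6288 embedding rows per rank
first, last = VocabUtility.vocab_range_from_per_partition_vocab_size(
    per_partition, rank=3, world_size=world_size
)
print(first, last)  # rank 3 owns vocab indices [18864, 25152)
```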
# coding=utf-8
- # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,13 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron arguments."""
import argparse
import os
import torch
def parse_args(extra_args_provider=None, defaults={},
               ignore_unknown_args=False):
    """Parse all arguments."""
@@ -79,6 +80,12 @@ def parse_args(extra_args_provider=None, defaults={},
        args.world_size, args.data_parallel_size,
        args.tensor_model_parallel_size,
        args.pipeline_model_parallel_size), flush=True)
+     if args.pipeline_model_parallel_size > 1:
+         if args.pipeline_model_parallel_split_rank is not None:
+             assert args.pipeline_model_parallel_split_rank < \
+                 args.pipeline_model_parallel_size, 'split rank needs'\
+                 ' to be less than pipeline model parallel size ({})'.format(
+                     args.pipeline_model_parallel_size)
    # Deprecated arguments
    assert args.batch_size is None, '--batch-size argument is no longer ' \
@@ -90,6 +97,13 @@ def parse_args(extra_args_provider=None, defaults={},
    assert args.model_parallel_size is None, '--model-parallel-size is no ' \
        'longer valid, use --tensor-model-parallel-size instead'
    del args.model_parallel_size
+     if args.checkpoint_activations:
+         args.activations_checkpoint_method = 'uniform'
+         if args.rank == 0:
+             print('--checkpoint-activations is no longer valid, '
+                   'use --activation-checkpoint-method instead. '
+                   'Defaulting to activation-checkpoint-method=uniform.')
+     del args.checkpoint_activations
    # Set input defaults.
    for key in defaults:
@@ -147,16 +161,15 @@ def parse_args(extra_args_provider=None, defaults={},
        print('using {} for parameters ...'.format(args.params_dtype),
              flush=True)
-     # If we do accumulation and all-reduces in fp32, we need to have
-     # local DDP and we should set the use-contiguous-buffers-in-ddp.
+     # If we do accumulation and all-reduces in fp32, we need to have local DDP
+     # and we should make sure use-contiguous-buffers-in-local-ddp is not off.
    if args.accumulate_allreduce_grads_in_fp32:
        assert args.DDP_impl == 'local'
-         args.use_contiguous_buffers_in_ddp = True
+         assert args.use_contiguous_buffers_in_local_ddp
-     # If we use a contiguous buffer to hold main grads, we need to have
-     # local DDP.
-     if args.use_contiguous_buffers_in_ddp:
-         assert args.DDP_impl == 'local'
+     # For torch DDP, we do not use contiguous buffer
+     if args.DDP_impl == 'torch':
+         args.use_contiguous_buffers_in_local_ddp = False
    if args.dataloader_type is None:
        args.dataloader_type = 'single'
@@ -233,9 +246,15 @@ def parse_args(extra_args_provider=None, defaults={},
        'residual connection in fp32 only supported when using fp16 or bf16.'
    # Activation checkpointing.
    if args.distribute_checkpointed_activations:
-         assert args.checkpoint_activations, \
+         assert args.tensor_model_parallel_size > 1, 'can distribute ' \
+             'checkpointed activations only across tensor model ' \
+             'parallel groups'
+         assert args.activations_checkpoint_method is not None, \
            'for distribute-checkpointed-activations to work you '\
-             'need to enable checkpoint-activations'
+             'need to use a activation-checkpoint method '
+         assert args.num_layers_per_virtual_pipeline_stage is None, \
+             'currently distrobuted checkpoint activations only supported for ' \
+             'nointerleaved pipeline parallelism'
    _print_args(args)
    return args
@@ -401,8 +420,20 @@ def _add_training_args(parser):
                       action='store_true',
                       help='If set, distribute checkpointed activations '
                       'across model parallel group.')
-     group.add_argument('--checkpoint-num-layers', type=int, default=1,
-                        help='chunk size (number of layers) for checkpointing.')
+     group.add_argument('--activations-checkpoint-method', type=str, default=None,
+                        choices=['uniform', 'block'],
+                        help='1) uniform: uniformly divide the total number of '
+                        'Transformer layers and checkpoint the input activation of '
+                        'each divided chunk, '
+                        '2) checkpoint the input activations of only a set number of '
+                        'individual Transformer layers per pipeline stage and do the '
+                        'rest without any checkpointing'
+                        'default) do not apply activations checkpoint to any layers')
+     group.add_argument('--activations-checkpoint-num-layers', type=int, default=1,
+                        help='1) uniform: the number of Transformer layers in each '
+                        'uniformly divided checkpoint unit, '
+                        '2) block: the number of individual Transformer layers '
+                        'to checkpoint within each pipeline stage.')
    group.add_argument('--train-iters', type=int, default=None,
                       help='Total number of iterations to train over all '
                       'training runs. Note that either train-iters or '
@@ -437,6 +468,11 @@ def _add_training_args(parser):
    group.add_argument('--dataloader-type', type=str, default=None,
                       choices=['single', 'cyclic'],
                       help='Single pass vs multiple pass data loader')
+     group.add_argument('--no-async-tensor-model-parallel-allreduce',
+                        action='store_true',
+                        help='Disable asynchronous execution of '
+                        'tensor-model-parallel all-reduce with weight '
+                        'gradient compuation of a column-linear layer.')
    return parser
@@ -571,6 +607,9 @@ def _add_distributed_args(parser):
                       help='Degree of tensor model parallelism.')
    group.add_argument('--pipeline-model-parallel-size', type=int, default=1,
                       help='Degree of pipeline model parallelism.')
+     group.add_argument('--pipeline-model-parallel-split-rank',
+                        type=int, default=None,
+                        help='Rank where encoder and decoder should be split.')
    group.add_argument('--model-parallel-size', type=int, default=None,
                       help='Old model parallel argument, do not use. Use '
                       '--tensor-model-parallel-size instead.')
@@ -583,9 +622,10 @@ def _add_distributed_args(parser):
                       choices=['local', 'torch'],
                       help='which DistributedDataParallel implementation '
                       'to use.')
-     group.add_argument('--use-contiguous-buffers-in-ddp', action='store_true',
-                        help='If set, use contiguous buffer in DDP. Note that '
-                        'this option only works woth local DDP.' )
+     group.add_argument('--no-contiguous-buffers-in-local-ddp',
+                        action='store_false', help='If set, dont use '
+                        'contiguous buffer in local DDP.',
+                        dest='use_contiguous_buffers_in_local_ddp')
    group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
                       help='Use scatter/gather to optimize communication of tensors in pipeline',
                       dest='scatter_gather_tensors_in_pipeline')
...
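The help text added above for `--activations-checkpoint-method` and `--activations-checkpoint-num-layers` is terse, so the standalone sketch below (plain Python, not apex or Megatron code; layer counts are illustrative) shows which layer inputs each method would checkpoint under the semantics described in that help text:

```python
# Standalone sketch of the two --activations-checkpoint-method choices.
def checkpointed_layer_inputs(num_layers: int, method: str, num_ckpt_layers: int) -> list:
    if method == "uniform":
        # Divide the stage's layers into chunks of num_ckpt_layers and
        # checkpoint the input activation of each chunk.
        return list(range(0, num_layers, num_ckpt_layers))
    if method == "block":
        # Checkpoint only the first num_ckpt_layers individual layers of the stage.
        return list(range(min(num_ckpt_layers, num_layers)))
    return []  # default: no activation checkpointing


print(checkpointed_layer_inputs(12, "uniform", 4))  # [0, 4, 8]
print(checkpointed_layer_inputs(12, "block", 2))    # [0, 1]
```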
# coding=utf-8
- # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,7 +19,7 @@ import numpy
import torch
from apex import transformer
- from apex.transformer.tensor_parallel.tests import global_vars
+ from apex.transformer.testing import global_vars
TEST_SUCCESS_MESSAGE = ">> passed the test :-)"
...
# coding=utf-8
- # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,8 +20,8 @@ import time
import torch
- from apex.transformer.tensor_parallel.microbatches import build_num_microbatches_calculator
- from apex.transformer.tensor_parallel.tests.arguments import parse_args
+ from apex.transformer.microbatches import build_num_microbatches_calculator
+ from .arguments import parse_args
_GLOBAL_ARGS = None
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
@@ -80,7 +80,7 @@ def set_global_variables(extra_args_provider=None, args_defaults={},
    args = _parse_args(extra_args_provider=extra_args_provider,
                       defaults=args_defaults,
                       ignore_unknown_args=ignore_unknown_args)
-     _build_num_microbatches_calculator(args)
+     # _build_num_microbatches_calculator(args)
    # if args.vocab_file:
    #     _ = _build_tokenizer(args)
    _set_tensorboard_writer(args)
...
This diff is collapsed.
"""Utility functions used by both `pipeline_parallel` and `tensor_parallel`"""
import torch
from apex.transformer import parallel_state
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
def divide(numerator, denominator):
"""Ensure that numerator is divisible by the denominator and return
the division value."""
ensure_divisibility(numerator, denominator)
return numerator // denominator
def split_tensor_into_1d_equal_chunks(tensor):
"""Break a tensor into equal 1D chunks."""
data = tensor.view(-1)
partition_size = torch.numel(data) // parallel_state.get_tensor_model_parallel_world_size()
start_index = partition_size * parallel_state.get_tensor_model_parallel_rank()
end_index = start_index + partition_size
return data[start_index:end_index]
def gather_split_1d_tensor(tensor):
"""Opposite of above function, gather values from model parallel ranks."""
world_size = parallel_state.get_tensor_model_parallel_world_size()
numel = torch.numel(tensor)
numel_gathered = world_size * numel
gathered = torch.empty(numel_gathered, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False)
chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)]
torch.distributed.all_gather(chunks, tensor, group=parallel_state.get_tensor_model_parallel_group())
return gathered
# TODO(mkozuki): Rewrite this using `logging`.
def rank_print(msg):
"""Print the given msg with rank information"""
print(
f"tensor rank: {parallel_state.get_tensor_model_parallel_rank()}"
f"pipeline rank: {parallel_state.get_pipeline_model_parallel_rank()}, "
f"virtual pipeline rank: {parallel_state.get_virtual_pipeline_model_parallel_rank()}, "
f"data rank: {parallel_state.get_data_parallel_rank()} | {msg}"
)
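The two 1D-chunk helpers above are the ones the `random.py` diff now imports from `apex.transformer.utils`. A hedged usage sketch of their round trip (not part of the commit; it assumes one process per GPU launched with a distributed launcher and `apex.transformer.parallel_state` already initialized for a tensor-model-parallel group):

```python
# Hedged sketch: scatter a tensor into equal per-rank 1D slices, then rebuild it.
import torch
from apex.transformer.utils import split_tensor_into_1d_equal_chunks, gather_split_1d_tensor


def roundtrip(full: torch.Tensor) -> torch.Tensor:
    local = split_tensor_into_1d_equal_chunks(full)  # this rank's contiguous 1D slice
    flat = gather_split_1d_tensor(local)             # all-gather across the TP group
    return flat.view_as(full)                        # identical to `full` on every rank
```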
/* coding=utf-8
- * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
...
(The same copyright-year update is applied to five more C/CUDA source files; those diffs are identical to the one above.)
@@ -77,28 +77,7 @@ class TestFusedLayerNormElemWiseHalf(TestFusedLayerNormElemWise):
        self.skipTest("Skip to save time")
- # Megatron style Layer Norm
- class TestFusedLayerNormElemWiseMixedDtypes(TestFusedLayerNorm):
-     def setUp(self):
-         self.module_cpu_ = apex.normalization.MixedFusedLayerNorm(
-             normalized_shape=self.normalized_shape, elementwise_affine=True).cpu()
-         self.module_cuda_ = apex.normalization.MixedFusedLayerNorm(
-             normalized_shape=self.normalized_shape, elementwise_affine=True).to(device="cuda", dtype=self.dtype)
-     def test_init_exception(self):
-         with self.assertRaisesRegex(RuntimeError, "MixedFusedLayerNorm does not support `elementwise_affine = False`"):
-             apex.normalization.MixedFusedLayerNorm(normalized_shape=[32, 16], elementwise_affine=False).cuda()
- class TestFusedLayerNormElemWiseMixedDtypesHalf(TestFusedLayerNormElemWiseMixedDtypes):
-     dtype = torch.half
-     def test_large_batch(self):
-         self.skipTest("Skip to save time")
- # NOTE (mkozuki): With the larger threshold values, still flaky.
- class TestFusedLayerNormElemWiseMixedDtypesBFloat16(TestFusedLayerNormElemWiseMixedDtypesHalf):
+ class TestFusedLayerNormElemWiseBFloat16(TestFusedLayerNormElemWise):
    dtype = torch.bfloat16
    # NOTE (mkozuki): [BFloat16 Layer Norm flakiness]
    # Use thresholds larger than those used in pytorch, see
@@ -106,13 +85,6 @@ class TestFusedLayerNormElemWiseMixedDtypesBFloat16(TestFusedLayerNormElemWiseMixedDtypesHalf):
    fwd_thresholds = dict(rtol=1.6e-2, atol=3e-4)
    bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3)
- class TestFusedLayerNormElemWiseBFloat16(TestFusedLayerNormElemWise):
-     dtype = torch.bfloat16
-     # See [BFloat16 Layer Norm flakiness]
-     fwd_thresholds = dict(rtol=1.6e-2, atol=3e-4)
-     bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3)
    def test_large_batch(self):
        self.skipTest("Skip to save time")
...
# coding=utf-8
- # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,15 +15,15 @@
import torch
import torch.nn.functional as F
- from apex.transformer.tensor_parallel.tests.commons import set_random_seed
- from apex.transformer.tensor_parallel.tests.commons import IdentityLayer
- from apex.transformer.tensor_parallel.tests.commons import print_separator
- from apex.transformer.tensor_parallel.tests.commons import initialize_distributed
- from apex.transformer.tensor_parallel.tests.commons import TEST_SUCCESS_MESSAGE
from apex.transformer import parallel_state
from apex.transformer import tensor_parallel
from apex.transformer.tensor_parallel.cross_entropy import vocab_parallel_cross_entropy
- from apex.transformer.tensor_parallel.tests import global_vars
+ from apex.transformer.testing import global_vars
+ from apex.transformer.testing.commons import set_random_seed
+ from apex.transformer.testing.commons import IdentityLayer
+ from apex.transformer.testing.commons import print_separator
+ from apex.transformer.testing.commons import initialize_distributed
+ from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
global_vars.set_global_variables()
@@ -97,6 +97,8 @@ def test_cross_entropy(tensor_model_parallel_size):
if __name__ == '__main__':
+     torch.backends.cuda.matmul.allow_tf32 = False
+     torch.backends.cudnn.allow_tf32 = False
    initialize_distributed()
    world_size = torch.distributed.get_world_size()
...
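The two `torch.backends` toggles added to the test's `__main__` block disable TF32 so that matmul results are comparable against an fp32 reference; on Ampere GPUs, TF32 accumulates with a truncated mantissa, which can push small numerical differences past the test thresholds. A hedged sketch of the same guard packaged as a reusable context manager (the wrapper itself is illustrative, not apex API):

```python
# Sketch: temporarily disable TF32 for numerically strict comparisons.
import contextlib
import torch


@contextlib.contextmanager
def tf32_disabled():
    prev_matmul = torch.backends.cuda.matmul.allow_tf32
    prev_cudnn = torch.backends.cudnn.allow_tf32
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
    try:
        yield
    finally:
        torch.backends.cuda.matmul.allow_tf32 = prev_matmul
        torch.backends.cudnn.allow_tf32 = prev_cudnn
```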