Unverified Commit 63d5dd63 authored by Masaki Kozuki's avatar Masaki Kozuki Committed by GitHub
Browse files

Pipeline Model Parallel (#1202)

* Init apex.ppu (pipeline model parallel utility)

Reference commit:

```
commit 5ab646376d67831601d5552c193241d017f1b35c (HEAD -> main, internal/main)
Merge: 14f2c684 7b293d9b
Author: Mohammad Shoeybi <mshoeybi@nvidia.com>
Date:   Wed Sep 22 22:57:54 2021 -0700

    Merge branch 'add_BOS' into 'main'

    Add Beginning of Sentence token option and adding semaphore while multi-threading to prevent crashes and hangs due to connection keep-alives

    See merge request ADLR/megatron-lm!328
```

* removing get_args and replace import - phase 1

* removing get_args and replace import - phase 2

* move ppu to apex.transformer.pipeline_parallel

* update two __init__.py

* update READMEs

* mpu -> parallel_state & tensor_parallel

* fix

* remove not pipeline files

* separate schedules.py - phase 1

* dissect schedules.py

* data_iterators -> batch

* remove optimizer from forward_backward_step funcs

* init test

* Apply 2 suggestion(s...
parent 3303b3e7
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -14,9 +14,9 @@
# limitations under the License.
import torch
from ..parallel_state import get_tensor_model_parallel_group
from ..parallel_state import get_tensor_model_parallel_rank
from ..parallel_state import get_tensor_model_parallel_src_rank
from apex.transformer.parallel_state import get_tensor_model_parallel_group
from apex.transformer.parallel_state import get_tensor_model_parallel_rank
from apex.transformer.parallel_state import get_tensor_model_parallel_src_rank
_MAX_DATA_DIM = 5
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -21,17 +21,17 @@ import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.parameter import Parameter
from ..._autocast_utils import _cast_if_autocast_enabled
from ..parallel_state import get_tensor_model_parallel_group
from ..parallel_state import get_tensor_model_parallel_rank
from ..parallel_state import get_tensor_model_parallel_world_size
from .mappings import copy_to_tensor_model_parallel_region
from .mappings import gather_from_tensor_model_parallel_region
from .mappings import reduce_from_tensor_model_parallel_region
from .mappings import scatter_to_tensor_model_parallel_region
from .random import get_cuda_rng_tracker
from .utils import divide
from .utils import VocabUtility
from apex._autocast_utils import _cast_if_autocast_enabled
from apex.transformer.parallel_state import get_tensor_model_parallel_group
from apex.transformer.parallel_state import get_tensor_model_parallel_rank
from apex.transformer.parallel_state import get_tensor_model_parallel_world_size
from apex.transformer.utils import divide
from apex.transformer.tensor_parallel.mappings import copy_to_tensor_model_parallel_region
from apex.transformer.tensor_parallel.mappings import gather_from_tensor_model_parallel_region
from apex.transformer.tensor_parallel.mappings import reduce_from_tensor_model_parallel_region
from apex.transformer.tensor_parallel.mappings import scatter_to_tensor_model_parallel_region
from apex.transformer.tensor_parallel.random import get_cuda_rng_tracker
from apex.transformer.tensor_parallel.utils import VocabUtility
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -14,10 +14,10 @@
# limitations under the License.
import torch
from ..parallel_state import get_tensor_model_parallel_group
from ..parallel_state import get_tensor_model_parallel_world_size
from ..parallel_state import get_tensor_model_parallel_rank
from .utils import split_tensor_along_last_dim
from apex.transformer.parallel_state import get_tensor_model_parallel_group
from apex.transformer.parallel_state import get_tensor_model_parallel_world_size
from apex.transformer.parallel_state import get_tensor_model_parallel_rank
from apex.transformer.tensor_parallel.utils import split_tensor_along_last_dim
def _reduce(input_):
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -13,6 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO (mkozuki): Audit this file.
# I don't think some functions strongly relate to `random` in tensor_parallel.
# Rather, some functions are mainly for gradient checkpointing (torch.utils.checkpoint).
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
......@@ -23,18 +26,17 @@ from torch import _C
from torch.cuda import _lazy_call, device as device_ctx_manager
from torch.utils.checkpoint import detach_variable
from ..parallel_state import get_data_parallel_rank
from ..parallel_state import get_tensor_model_parallel_group
from ..parallel_state import get_tensor_model_parallel_rank
from ..parallel_state import get_tensor_model_parallel_world_size
from .memory import allocate_mem_buff
from apex.transformer.parallel_state import get_tensor_model_parallel_rank
from apex.transformer.tensor_parallel.memory import allocate_mem_buff
from apex.transformer.utils import split_tensor_into_1d_equal_chunks
from apex.transformer.utils import gather_split_1d_tensor
# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = "model-parallel-rng"
# Whether apply model parallelsim to checkpointed hidden states.
# Whether apply model parallelism to checkpointed hidden states.
_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = None
......@@ -108,26 +110,6 @@ def _set_cuda_rng_state(new_state, device=-1):
_lazy_call(cb)
def split_tensor_into_1d_equal_chunks(tensor):
"""Break a tensor into equal 1D chunks."""
data = tensor.view(-1)
partition_size = torch.numel(data) // get_tensor_model_parallel_world_size()
start_index = partition_size * get_tensor_model_parallel_rank()
end_index = start_index + partition_size
return data[start_index:end_index]
def gather_split_1d_tensor(tensor):
"""Opposite of above function, gather values from model parallel ranks."""
world_size = get_tensor_model_parallel_world_size()
numel = torch.numel(tensor)
numel_gathered = world_size * numel
gathered = torch.empty(numel_gathered, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False)
chunks = [gathered[i * numel : (i + 1) * numel] for i in range(world_size)]
torch.distributed.all_gather(chunks, tensor, group=get_tensor_model_parallel_group())
return gathered
class CudaRNGStatesTracker:
"""Tracker for the cuda RNG states.
......@@ -238,6 +220,7 @@ def model_parallel_cuda_manual_seed(seed):
_CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, tensor_model_parallel_seed)
# TODO (mkozuki): Move the below gradient checkpoint related features to another (new) file.
class CheckpointFunction(torch.autograd.Function):
"""This function is adapted from torch.utils.checkpoint with
two main changes:
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -14,17 +14,7 @@
# limitations under the License.
import torch
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
def divide(numerator, denominator):
"""Ensure that numerator is divisible by the denominator and return
the division value."""
ensure_divisibility(numerator, denominator)
return numerator // denominator
from apex.transformer.utils import divide
def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
......@@ -48,9 +38,9 @@ def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=
class VocabUtility:
"""Split the vocabulary into `world_size` chunks amd return the
first and last index of the vocabulary belonging to the `rank`
partition: Note that indecies in [fist, last)"""
"""Split the vocabulary into `world_size` chunks and return the
first and last index of the vocabulary belonging to the `rank`
partition: Note that indices in [fist, last)"""
@staticmethod
def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size):
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,13 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron arguments."""
import argparse
import os
import torch
def parse_args(extra_args_provider=None, defaults={},
ignore_unknown_args=False):
"""Parse all arguments."""
......@@ -79,6 +80,12 @@ def parse_args(extra_args_provider=None, defaults={},
args.world_size, args.data_parallel_size,
args.tensor_model_parallel_size,
args.pipeline_model_parallel_size), flush=True)
if args.pipeline_model_parallel_size > 1:
if args.pipeline_model_parallel_split_rank is not None:
assert args.pipeline_model_parallel_split_rank < \
args.pipeline_model_parallel_size, 'split rank needs'\
' to be less than pipeline model parallel size ({})'.format(
args.pipeline_model_parallel_size)
# Deprecated arguments
assert args.batch_size is None, '--batch-size argument is no longer ' \
......@@ -90,6 +97,13 @@ def parse_args(extra_args_provider=None, defaults={},
assert args.model_parallel_size is None, '--model-parallel-size is no ' \
'longer valid, use --tensor-model-parallel-size instead'
del args.model_parallel_size
if args.checkpoint_activations:
args.activations_checkpoint_method = 'uniform'
if args.rank == 0:
print('--checkpoint-activations is no longer valid, '
'use --activation-checkpoint-method instead. '
'Defaulting to activation-checkpoint-method=uniform.')
del args.checkpoint_activations
# Set input defaults.
for key in defaults:
......@@ -147,16 +161,15 @@ def parse_args(extra_args_provider=None, defaults={},
print('using {} for parameters ...'.format(args.params_dtype),
flush=True)
# If we do accumulation and all-reduces in fp32, we need to have
# local DDP and we should set the use-contiguous-buffers-in-ddp.
# If we do accumulation and all-reduces in fp32, we need to have local DDP
# and we should make sure use-contiguous-buffers-in-local-ddp is not off.
if args.accumulate_allreduce_grads_in_fp32:
assert args.DDP_impl == 'local'
args.use_contiguous_buffers_in_ddp = True
assert args.use_contiguous_buffers_in_local_ddp
# If we use a contiguous buffer to hold main grads, we need to have
# local DDP.
if args.use_contiguous_buffers_in_ddp:
assert args.DDP_impl == 'local'
# For torch DDP, we do not use contiguous buffer
if args.DDP_impl == 'torch':
args.use_contiguous_buffers_in_local_ddp = False
if args.dataloader_type is None:
args.dataloader_type = 'single'
......@@ -233,9 +246,15 @@ def parse_args(extra_args_provider=None, defaults={},
'residual connection in fp32 only supported when using fp16 or bf16.'
# Activation checkpointing.
if args.distribute_checkpointed_activations:
assert args.checkpoint_activations, \
assert args.tensor_model_parallel_size > 1, 'can distribute ' \
'checkpointed activations only across tensor model ' \
'parallel groups'
assert args.activations_checkpoint_method is not None, \
'for distribute-checkpointed-activations to work you '\
'need to enable checkpoint-activations'
'need to use a activation-checkpoint method '
assert args.num_layers_per_virtual_pipeline_stage is None, \
'currently distrobuted checkpoint activations only supported for ' \
'nointerleaved pipeline parallelism'
_print_args(args)
return args
......@@ -401,8 +420,20 @@ def _add_training_args(parser):
action='store_true',
help='If set, distribute checkpointed activations '
'across model parallel group.')
group.add_argument('--checkpoint-num-layers', type=int, default=1,
help='chunk size (number of layers) for checkpointing.')
group.add_argument('--activations-checkpoint-method', type=str, default=None,
choices=['uniform', 'block'],
help='1) uniform: uniformly divide the total number of '
'Transformer layers and checkpoint the input activation of '
'each divided chunk, '
'2) checkpoint the input activations of only a set number of '
'individual Transformer layers per pipeline stage and do the '
'rest without any checkpointing'
'default) do not apply activations checkpoint to any layers')
group.add_argument('--activations-checkpoint-num-layers', type=int, default=1,
help='1) uniform: the number of Transformer layers in each '
'uniformly divided checkpoint unit, '
'2) block: the number of individual Transformer layers '
'to checkpoint within each pipeline stage.')
group.add_argument('--train-iters', type=int, default=None,
help='Total number of iterations to train over all '
'training runs. Note that either train-iters or '
......@@ -437,6 +468,11 @@ def _add_training_args(parser):
group.add_argument('--dataloader-type', type=str, default=None,
choices=['single', 'cyclic'],
help='Single pass vs multiple pass data loader')
group.add_argument('--no-async-tensor-model-parallel-allreduce',
action='store_true',
help='Disable asynchronous execution of '
'tensor-model-parallel all-reduce with weight '
'gradient compuation of a column-linear layer.')
return parser
......@@ -571,6 +607,9 @@ def _add_distributed_args(parser):
help='Degree of tensor model parallelism.')
group.add_argument('--pipeline-model-parallel-size', type=int, default=1,
help='Degree of pipeline model parallelism.')
group.add_argument('--pipeline-model-parallel-split-rank',
type=int, default=None,
help='Rank where encoder and decoder should be split.')
group.add_argument('--model-parallel-size', type=int, default=None,
help='Old model parallel argument, do not use. Use '
'--tensor-model-parallel-size instead.')
......@@ -583,9 +622,10 @@ def _add_distributed_args(parser):
choices=['local', 'torch'],
help='which DistributedDataParallel implementation '
'to use.')
group.add_argument('--use-contiguous-buffers-in-ddp', action='store_true',
help='If set, use contiguous buffer in DDP. Note that '
'this option only works woth local DDP.' )
group.add_argument('--no-contiguous-buffers-in-local-ddp',
action='store_false', help='If set, dont use '
'contiguous buffer in local DDP.',
dest='use_contiguous_buffers_in_local_ddp')
group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
help='Use scatter/gather to optimize communication of tensors in pipeline',
dest='scatter_gather_tensors_in_pipeline')
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -19,7 +19,7 @@ import numpy
import torch
from apex import transformer
from apex.transformer.tensor_parallel.tests import global_vars
from apex.transformer.testing import global_vars
TEST_SUCCESS_MESSAGE = ">> passed the test :-)"
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -20,8 +20,8 @@ import time
import torch
from apex.transformer.tensor_parallel.microbatches import build_num_microbatches_calculator
from apex.transformer.tensor_parallel.tests.arguments import parse_args
from apex.transformer.microbatches import build_num_microbatches_calculator
from .arguments import parse_args
_GLOBAL_ARGS = None
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
......@@ -80,7 +80,7 @@ def set_global_variables(extra_args_provider=None, args_defaults={},
args = _parse_args(extra_args_provider=extra_args_provider,
defaults=args_defaults,
ignore_unknown_args=ignore_unknown_args)
_build_num_microbatches_calculator(args)
# _build_num_microbatches_calculator(args)
# if args.vocab_file:
# _ = _build_tokenizer(args)
_set_tensorboard_writer(args)
......
This diff is collapsed.
"""Utility functions used by both `pipeline_parallel` and `tensor_parallel`"""
import torch
from apex.transformer import parallel_state
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
def divide(numerator, denominator):
"""Ensure that numerator is divisible by the denominator and return
the division value."""
ensure_divisibility(numerator, denominator)
return numerator // denominator
def split_tensor_into_1d_equal_chunks(tensor):
"""Break a tensor into equal 1D chunks."""
data = tensor.view(-1)
partition_size = torch.numel(data) // parallel_state.get_tensor_model_parallel_world_size()
start_index = partition_size * parallel_state.get_tensor_model_parallel_rank()
end_index = start_index + partition_size
return data[start_index:end_index]
def gather_split_1d_tensor(tensor):
"""Opposite of above function, gather values from model parallel ranks."""
world_size = parallel_state.get_tensor_model_parallel_world_size()
numel = torch.numel(tensor)
numel_gathered = world_size * numel
gathered = torch.empty(numel_gathered, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False)
chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)]
torch.distributed.all_gather(chunks, tensor, group=parallel_state.get_tensor_model_parallel_group())
return gathered
# TODO(mkozuki): Rewrite this using `logging`.
def rank_print(msg):
"""Print the given msg with rank information"""
print(
f"tensor rank: {parallel_state.get_tensor_model_parallel_rank()}"
f"pipeline rank: {parallel_state.get_pipeline_model_parallel_rank()}, "
f"virtual pipeline rank: {parallel_state.get_virtual_pipeline_model_parallel_rank()}, "
f"data rank: {parallel_state.get_data_parallel_rank()} | {msg}"
)
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......
......@@ -77,28 +77,7 @@ class TestFusedLayerNormElemWiseHalf(TestFusedLayerNormElemWise):
self.skipTest("Skip to save time")
# Megatron style Layer Norm
class TestFusedLayerNormElemWiseMixedDtypes(TestFusedLayerNorm):
def setUp(self):
self.module_cpu_ = apex.normalization.MixedFusedLayerNorm(
normalized_shape=self.normalized_shape, elementwise_affine=True).cpu()
self.module_cuda_ = apex.normalization.MixedFusedLayerNorm(
normalized_shape=self.normalized_shape, elementwise_affine=True).to(device="cuda", dtype=self.dtype)
def test_init_exception(self):
with self.assertRaisesRegex(RuntimeError, "MixedFusedLayerNorm does not support `elementwise_affine = False`"):
apex.normalization.MixedFusedLayerNorm(normalized_shape=[32, 16], elementwise_affine=False).cuda()
class TestFusedLayerNormElemWiseMixedDtypesHalf(TestFusedLayerNormElemWiseMixedDtypes):
dtype = torch.half
def test_large_batch(self):
self.skipTest("Skip to save time")
# NOTE (mkozuki): With the larger threshold values, still flaky.
class TestFusedLayerNormElemWiseMixedDtypesBFloat16(TestFusedLayerNormElemWiseMixedDtypesHalf):
class TestFusedLayerNormElemWiseBFloat16(TestFusedLayerNormElemWise):
dtype = torch.bfloat16
# NOTE (mkozuki): [BFloat16 Layer Norm flakiness]
# Use thresholds larger than those used in pytorch, see
......@@ -106,13 +85,6 @@ class TestFusedLayerNormElemWiseMixedDtypesBFloat16(TestFusedLayerNormElemWiseMi
fwd_thresholds = dict(rtol=1.6e-2, atol=3e-4)
bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3)
class TestFusedLayerNormElemWiseBFloat16(TestFusedLayerNormElemWise):
dtype = torch.bfloat16
# See [BFloat16 Layer Norm flakiness]
fwd_thresholds = dict(rtol=1.6e-2, atol=3e-4)
bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3)
def test_large_batch(self):
self.skipTest("Skip to save time")
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -15,15 +15,15 @@
import torch
import torch.nn.functional as F
from apex.transformer.tensor_parallel.tests.commons import set_random_seed
from apex.transformer.tensor_parallel.tests.commons import IdentityLayer
from apex.transformer.tensor_parallel.tests.commons import print_separator
from apex.transformer.tensor_parallel.tests.commons import initialize_distributed
from apex.transformer.tensor_parallel.tests.commons import TEST_SUCCESS_MESSAGE
from apex.transformer import parallel_state
from apex.transformer import tensor_parallel
from apex.transformer.tensor_parallel.cross_entropy import vocab_parallel_cross_entropy
from apex.transformer.tensor_parallel.tests import global_vars
from apex.transformer.testing import global_vars
from apex.transformer.testing.commons import set_random_seed
from apex.transformer.testing.commons import IdentityLayer
from apex.transformer.testing.commons import print_separator
from apex.transformer.testing.commons import initialize_distributed
from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
global_vars.set_global_variables()
......@@ -97,6 +97,8 @@ def test_cross_entropy(tensor_model_parallel_size):
if __name__ == '__main__':
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
initialize_distributed()
world_size = torch.distributed.get_world_size()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment