Unverified commit db92ee13, authored by Jithun Nair and committed by GitHub

Merge pull request #64 from ROCmSoftwarePlatform/IFU-master-2021-12-08

IFU-master-2021-12-08
Parents: d150afdc 68364b49
# coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,17 +21,17 @@ import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.parameter import Parameter
-from ..._autocast_utils import _cast_if_autocast_enabled
-from ..parallel_state import get_tensor_model_parallel_group
-from ..parallel_state import get_tensor_model_parallel_rank
-from ..parallel_state import get_tensor_model_parallel_world_size
-from .mappings import copy_to_tensor_model_parallel_region
-from .mappings import gather_from_tensor_model_parallel_region
-from .mappings import reduce_from_tensor_model_parallel_region
-from .mappings import scatter_to_tensor_model_parallel_region
-from .random import get_cuda_rng_tracker
-from .utils import divide
-from .utils import VocabUtility
+from apex._autocast_utils import _cast_if_autocast_enabled
+from apex.transformer.parallel_state import get_tensor_model_parallel_group
+from apex.transformer.parallel_state import get_tensor_model_parallel_rank
+from apex.transformer.parallel_state import get_tensor_model_parallel_world_size
+from apex.transformer.utils import divide
+from apex.transformer.tensor_parallel.mappings import copy_to_tensor_model_parallel_region
+from apex.transformer.tensor_parallel.mappings import gather_from_tensor_model_parallel_region
+from apex.transformer.tensor_parallel.mappings import reduce_from_tensor_model_parallel_region
+from apex.transformer.tensor_parallel.mappings import scatter_to_tensor_model_parallel_region
+from apex.transformer.tensor_parallel.random import get_cuda_rng_tracker
+from apex.transformer.tensor_parallel.utils import VocabUtility

_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
...
# coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,10 +14,10 @@
# limitations under the License.
import torch

-from ..parallel_state import get_tensor_model_parallel_group
-from ..parallel_state import get_tensor_model_parallel_world_size
-from ..parallel_state import get_tensor_model_parallel_rank
-from .utils import split_tensor_along_last_dim
+from apex.transformer.parallel_state import get_tensor_model_parallel_group
+from apex.transformer.parallel_state import get_tensor_model_parallel_world_size
+from apex.transformer.parallel_state import get_tensor_model_parallel_rank
+from apex.transformer.tensor_parallel.utils import split_tensor_along_last_dim

def _reduce(input_):
...
# coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
# coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,6 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+# TODO (mkozuki): Audit this file.
+# I don't think some functions strongly relate to `random` in tensor_parallel.
+# Rather, some functions are mainly for gradient checkpointing (torch.utils.checkpoint).
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
@@ -23,18 +26,17 @@ from torch import _C
from torch.cuda import _lazy_call, device as device_ctx_manager
from torch.utils.checkpoint import detach_variable

-from ..parallel_state import get_data_parallel_rank
-from ..parallel_state import get_tensor_model_parallel_group
-from ..parallel_state import get_tensor_model_parallel_rank
-from ..parallel_state import get_tensor_model_parallel_world_size
-from .memory import allocate_mem_buff
+from apex.transformer.parallel_state import get_tensor_model_parallel_rank
+from apex.transformer.tensor_parallel.memory import allocate_mem_buff
+from apex.transformer.utils import split_tensor_into_1d_equal_chunks
+from apex.transformer.utils import gather_split_1d_tensor

# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = "model-parallel-rng"
-# Whether apply model parallelsim to checkpointed hidden states.
+# Whether apply model parallelism to checkpointed hidden states.
_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = None
@@ -108,26 +110,6 @@ def _set_cuda_rng_state(new_state, device=-1):
    _lazy_call(cb)

-def split_tensor_into_1d_equal_chunks(tensor):
-    """Break a tensor into equal 1D chunks."""
-    data = tensor.view(-1)
-    partition_size = torch.numel(data) // get_tensor_model_parallel_world_size()
-    start_index = partition_size * get_tensor_model_parallel_rank()
-    end_index = start_index + partition_size
-    return data[start_index:end_index]
-
-def gather_split_1d_tensor(tensor):
-    """Opposite of above function, gather values from model parallel ranks."""
-    world_size = get_tensor_model_parallel_world_size()
-    numel = torch.numel(tensor)
-    numel_gathered = world_size * numel
-    gathered = torch.empty(numel_gathered, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False)
-    chunks = [gathered[i * numel : (i + 1) * numel] for i in range(world_size)]
-    torch.distributed.all_gather(chunks, tensor, group=get_tensor_model_parallel_group())
-    return gathered

class CudaRNGStatesTracker:
    """Tracker for the cuda RNG states.
@@ -238,6 +220,7 @@ def model_parallel_cuda_manual_seed(seed):
    _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, tensor_model_parallel_seed)

+# TODO (mkozuki): Move the below gradient checkpoint related features to another (new) file.
class CheckpointFunction(torch.autograd.Function):
    """This function is adapted from torch.utils.checkpoint with
    two main changes:
...
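The two helpers removed here now live in apex.transformer.utils (see the new imports above). A minimal single-process sketch of the partition arithmetic they implement, using plain slicing in place of the parallel-state queries and torch.distributed (world size and rank are illustrative values):

import torch

world_size, rank = 4, 1                        # stand-ins for the tensor-model-parallel queries
data = torch.arange(16.0).view(4, 4).view(-1)  # flatten, as split_tensor_into_1d_equal_chunks does
partition_size = torch.numel(data) // world_size          # 4 elements per rank
chunk = data[rank * partition_size:(rank + 1) * partition_size]  # this rank's 1D slice
# gather_split_1d_tensor is the inverse: all-gathering every rank's chunk rebuilds the flat tensor.
rebuilt = torch.cat([data[r * partition_size:(r + 1) * partition_size] for r in range(world_size)])
assert torch.equal(rebuilt, data)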
# coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,17 +14,7 @@
# limitations under the License.
import torch

+from apex.transformer.utils import divide

-def ensure_divisibility(numerator, denominator):
-    """Ensure that numerator is divisible by the denominator."""
-    assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
-
-def divide(numerator, denominator):
-    """Ensure that numerator is divisible by the denominator and return
-    the division value."""
-    ensure_divisibility(numerator, denominator)
-    return numerator // denominator

def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
@@ -48,9 +38,9 @@ def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=
class VocabUtility:
-    """Split the vocabulary into `world_size` chunks amd return the
+    """Split the vocabulary into `world_size` chunks and return the
    first and last index of the vocabulary belonging to the `rank`
-    partition: Note that indecies in [fist, last)"""
+    partition: Note that indices in [fist, last)"""
    @staticmethod
    def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size):
...
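With divide now coming from apex.transformer.utils, a small usage sketch of the two helpers this file keeps (the tensor shapes are illustrative, not from the diff):

import torch
from apex.transformer.utils import divide
from apex.transformer.tensor_parallel.utils import split_tensor_along_last_dim

hidden = torch.randn(2, 4, 12)              # [batch, seq, hidden]
per_partition = divide(hidden.size(-1), 3)  # 4; asserts that 12 is divisible by 3
q, k, v = split_tensor_along_last_dim(hidden, 3)  # three [2, 4, 4] views along the last dim
assert q.shape == (2, 4, per_partition)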
# coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,13 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron arguments."""
import argparse
import os
import torch

def parse_args(extra_args_provider=None, defaults={},
               ignore_unknown_args=False):
    """Parse all arguments."""
@@ -79,6 +80,12 @@ def parse_args(extra_args_provider=None, defaults={},
        args.world_size, args.data_parallel_size,
        args.tensor_model_parallel_size,
        args.pipeline_model_parallel_size), flush=True)
+    if args.pipeline_model_parallel_size > 1:
+        if args.pipeline_model_parallel_split_rank is not None:
+            assert args.pipeline_model_parallel_split_rank < \
+                args.pipeline_model_parallel_size, 'split rank needs'\
+                ' to be less than pipeline model parallel size ({})'.format(
+                    args.pipeline_model_parallel_size)

    # Deprecated arguments
    assert args.batch_size is None, '--batch-size argument is no longer ' \
@@ -90,6 +97,13 @@ def parse_args(extra_args_provider=None, defaults={},
    assert args.model_parallel_size is None, '--model-parallel-size is no ' \
        'longer valid, use --tensor-model-parallel-size instead'
    del args.model_parallel_size
+    if args.checkpoint_activations:
+        args.activations_checkpoint_method = 'uniform'
+        if args.rank == 0:
+            print('--checkpoint-activations is no longer valid, '
+                  'use --activation-checkpoint-method instead. '
+                  'Defaulting to activation-checkpoint-method=uniform.')
+    del args.checkpoint_activations

    # Set input defaults.
    for key in defaults:
@@ -147,16 +161,15 @@ def parse_args(extra_args_provider=None, defaults={},
        print('using {} for parameters ...'.format(args.params_dtype),
              flush=True)
-    # If we do accumulation and all-reduces in fp32, we need to have
-    # local DDP and we should set the use-contiguous-buffers-in-ddp.
+    # If we do accumulation and all-reduces in fp32, we need to have local DDP
+    # and we should make sure use-contiguous-buffers-in-local-ddp is not off.
    if args.accumulate_allreduce_grads_in_fp32:
        assert args.DDP_impl == 'local'
-        args.use_contiguous_buffers_in_ddp = True
+        assert args.use_contiguous_buffers_in_local_ddp

-    # If we use a contiguous buffer to hold main grads, we need to have
-    # local DDP.
-    if args.use_contiguous_buffers_in_ddp:
-        assert args.DDP_impl == 'local'
+    # For torch DDP, we do not use contiguous buffer
+    if args.DDP_impl == 'torch':
+        args.use_contiguous_buffers_in_local_ddp = False

    if args.dataloader_type is None:
        args.dataloader_type = 'single'
@@ -233,9 +246,15 @@ def parse_args(extra_args_provider=None, defaults={},
        'residual connection in fp32 only supported when using fp16 or bf16.'
    # Activation checkpointing.
    if args.distribute_checkpointed_activations:
-        assert args.checkpoint_activations, \
+        assert args.tensor_model_parallel_size > 1, 'can distribute ' \
+            'checkpointed activations only across tensor model ' \
+            'parallel groups'
+        assert args.activations_checkpoint_method is not None, \
            'for distribute-checkpointed-activations to work you '\
-            'need to enable checkpoint-activations'
+            'need to use a activation-checkpoint method '
+        assert args.num_layers_per_virtual_pipeline_stage is None, \
+            'currently distrobuted checkpoint activations only supported for ' \
+            'nointerleaved pipeline parallelism'
    _print_args(args)
    return args
@@ -401,8 +420,20 @@ def _add_training_args(parser):
                       action='store_true',
                       help='If set, distribute checkpointed activations '
                       'across model parallel group.')
-    group.add_argument('--checkpoint-num-layers', type=int, default=1,
-                       help='chunk size (number of layers) for checkpointing.')
+    group.add_argument('--activations-checkpoint-method', type=str, default=None,
+                       choices=['uniform', 'block'],
+                       help='1) uniform: uniformly divide the total number of '
+                       'Transformer layers and checkpoint the input activation of '
+                       'each divided chunk, '
+                       '2) checkpoint the input activations of only a set number of '
+                       'individual Transformer layers per pipeline stage and do the '
+                       'rest without any checkpointing'
+                       'default) do not apply activations checkpoint to any layers')
+    group.add_argument('--activations-checkpoint-num-layers', type=int, default=1,
+                       help='1) uniform: the number of Transformer layers in each '
+                       'uniformly divided checkpoint unit, '
+                       '2) block: the number of individual Transformer layers '
+                       'to checkpoint within each pipeline stage.')
    group.add_argument('--train-iters', type=int, default=None,
                       help='Total number of iterations to train over all '
                       'training runs. Note that either train-iters or '
@@ -437,6 +468,11 @@ def _add_training_args(parser):
    group.add_argument('--dataloader-type', type=str, default=None,
                       choices=['single', 'cyclic'],
                       help='Single pass vs multiple pass data loader')
+    group.add_argument('--no-async-tensor-model-parallel-allreduce',
+                       action='store_true',
+                       help='Disable asynchronous execution of '
+                       'tensor-model-parallel all-reduce with weight '
+                       'gradient compuation of a column-linear layer.')
    return parser
@@ -571,6 +607,9 @@ def _add_distributed_args(parser):
                       help='Degree of tensor model parallelism.')
    group.add_argument('--pipeline-model-parallel-size', type=int, default=1,
                       help='Degree of pipeline model parallelism.')
+    group.add_argument('--pipeline-model-parallel-split-rank',
+                       type=int, default=None,
+                       help='Rank where encoder and decoder should be split.')
    group.add_argument('--model-parallel-size', type=int, default=None,
                       help='Old model parallel argument, do not use. Use '
                       '--tensor-model-parallel-size instead.')
@@ -583,9 +622,10 @@
                       choices=['local', 'torch'],
                       help='which DistributedDataParallel implementation '
                       'to use.')
-    group.add_argument('--use-contiguous-buffers-in-ddp', action='store_true',
-                       help='If set, use contiguous buffer in DDP. Note that '
-                       'this option only works woth local DDP.' )
+    group.add_argument('--no-contiguous-buffers-in-local-ddp',
+                       action='store_false', help='If set, dont use '
+                       'contiguous buffer in local DDP.',
+                       dest='use_contiguous_buffers_in_local_ddp')
    group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
                       help='Use scatter/gather to optimize communication of tensors in pipeline',
                       dest='scatter_gather_tensors_in_pipeline')
...
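As a quick illustration of the renamed knobs above, here is a standalone sketch (a separate toy parser, not the module's own) showing how the new activation checkpointing and local-DDP flags parse:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--activations-checkpoint-method', type=str, default=None,
                    choices=['uniform', 'block'])
parser.add_argument('--activations-checkpoint-num-layers', type=int, default=1)
parser.add_argument('--no-contiguous-buffers-in-local-ddp', action='store_false',
                    dest='use_contiguous_buffers_in_local_ddp')

args = parser.parse_args(['--activations-checkpoint-method', 'block',
                          '--activations-checkpoint-num-layers', '2'])
assert args.activations_checkpoint_method == 'block'
assert args.activations_checkpoint_num_layers == 2
assert args.use_contiguous_buffers_in_local_ddp  # store_false flag defaults to True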
# coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,17 +14,53 @@
# limitations under the License.
import os
import random
+from typing import Optional, Union, List

import numpy
import torch
+import torch.nn as nn

from apex import transformer
-from apex.transformer.tensor_parallel.tests import global_vars
+from apex.transformer.testing import global_vars

TEST_SUCCESS_MESSAGE = ">> passed the test :-)"

+# note (mkozuki): `pre_process` and `post_process` are a placeholder until interleaving schedule test comes.
+class MyLayer(nn.Module):
+    def __init__(self, hidden_size: int, pre_process: bool, post_process: bool):
+        super().__init__()
+        self.pre_process = pre_process
+        self.post_process = post_process
+        self.layer = nn.Linear(hidden_size, hidden_size)
+
+    def forward(self, x):
+        return self.layer(x)
+
+class MyModel(nn.Module):
+    def __init__(self, hidden_size: int, pre_process: bool = False, post_process: bool = False) -> None:
+        super().__init__()
+        self.pre_process = pre_process
+        self.post_process = post_process
+        self.layer = MyLayer(hidden_size=hidden_size, pre_process=pre_process, post_process=post_process)
+        self.input_tensor = None
+
+    def set_input_tensor(self, input_tensor: Union[torch.Tensor, List[torch.Tensor]]) -> None:
+        self.input_tensor = input_tensor
+
+    def forward(self, x: Optional[torch.Tensor]) -> torch.Tensor:
+        if self.input_tensor is None:
+            return self.layer(x)
+        return self.layer(self.input_tensor)
+
+def model_provider_func(hidden_size, pre_process, post_process) -> MyModel:
+    return MyModel(hidden_size, pre_process, post_process)

class IdentityLayer(torch.nn.Module):
    def __init__(self, size, scale=1.0):
        super(IdentityLayer, self).__init__()
...
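A short usage sketch of the test helpers added above (with the MyModel class defined there in scope): a non-first pipeline stage gets its activation injected via set_input_tensor rather than through forward()'s argument.

import torch

model = MyModel(hidden_size=8)
x = torch.randn(4, 8)
out = model(x)                        # no injected input: the argument is used
model.set_input_tensor(torch.randn(4, 8))
out = model(None)                     # injected input: the argument is ignored
assert out.shape == (4, 8)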
# coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,8 +20,8 @@ import time
import torch

-from apex.transformer.tensor_parallel.microbatches import build_num_microbatches_calculator
-from apex.transformer.tensor_parallel.tests.arguments import parse_args
+from apex.transformer.microbatches import build_num_microbatches_calculator
+from .arguments import parse_args

_GLOBAL_ARGS = None
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
@@ -37,17 +37,27 @@ def get_args():
    return _GLOBAL_ARGS

-def get_num_microbatches():
+def get_num_microbatches() -> int:
    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()

-def get_current_global_batch_size():
+def get_current_global_batch_size() -> int:
    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_global_batch_size()

-def update_num_microbatches(consumed_samples, consistency_check=True):
-    _GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples,
-                                               consistency_check)
+def update_num_microbatches(consumed_samples: int, *, consistency_check: bool = True) -> None:
+    """Update the number of microbatches upon the number of consumed samples.
+
+    .. note::
+        This function has no effect unless ``rampup_batch_size`` is set.
+
+    Args:
+        consumed_samples: The number of consumed samples so far. Basically this is equal to
+            :math:`num_iter * global_batch_size`.
+        consistency_check: If :obj:`True`, sanity checks the consumed samples, i.e., check if
+            ``consumed_samples`` is divisible by :math:`micro_batch_size \times data_parallel_size`.
+    """
+    _GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples, consistency_check)

# def get_tokenizer():
@@ -80,7 +90,7 @@ def set_global_variables(extra_args_provider=None, args_defaults={},
    args = _parse_args(extra_args_provider=extra_args_provider,
                       defaults=args_defaults,
                       ignore_unknown_args=ignore_unknown_args)
-    _build_num_microbatches_calculator(args)
+    # _build_num_microbatches_calculator(args)
    # if args.vocab_file:
    #     _ = _build_tokenizer(args)
    _set_tensorboard_writer(args)
...
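To make the bookkeeping in the new docstring concrete, here is an illustrative numeric sketch; the relation global_batch_size = micro_batch_size * data_parallel_size * num_microbatches is the usual Megatron-style assumption, not something defined in this file.

micro_batch_size, data_parallel_size, num_microbatches = 4, 2, 8
global_batch_size = micro_batch_size * data_parallel_size * num_microbatches  # 64

for iteration in range(1, 4):
    consumed_samples = iteration * global_batch_size  # the docstring's num_iter * global_batch_size
    # The consistency check requires divisibility by micro_batch_size * data_parallel_size.
    assert consumed_samples % (micro_batch_size * data_parallel_size) == 0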
import torch
from apex.normalization import FusedLayerNorm as LayerNorm
from apex.transformer import tensor_parallel
from apex.transformer.enums import AttnMaskType
from apex.transformer.testing.global_vars import get_args
from .standalone_gpt import get_language_model, get_linear_layer, init_method_normal, parallel_lm_logits, scaled_init_method_normal
from .standalone_gpt import MegatronModule
def bert_extended_attention_mask(attention_mask):
# We create a 3D attention mask from a 2D tensor mask.
# [b, 1, s]
attention_mask_b1s = attention_mask.unsqueeze(1)
# [b, s, 1]
attention_mask_bs1 = attention_mask.unsqueeze(2)
# [b, s, s]
attention_mask_bss = attention_mask_b1s * attention_mask_bs1
# [b, 1, s, s]
extended_attention_mask = attention_mask_bss.unsqueeze(1)
# Convert attention mask to binary:
extended_attention_mask = (extended_attention_mask < 0.5)
return extended_attention_mask
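# Illustrative shapes (annotation, not in the original file): for attention_mask of
# shape [b, s] = [2, 5], attention_mask_b1s is [2, 1, 5], attention_mask_bs1 is
# [2, 5, 1], their broadcasted product is [2, 5, 5], and the final unsqueeze gives a
# [2, 1, 5, 5] boolean mask in which True (product < 0.5) marks positions to mask out.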
def bert_position_ids(token_ids):
# Create position ids
seq_length = token_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long,
device=token_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(token_ids)
return position_ids
class BertLMHead(MegatronModule):
"""Masked LM head for Bert
Arguments:
mpu_vocab_size: model parallel size of vocabulary.
hidden_size: hidden size
init_method: init method for weight initialization
layernorm_epsilon: tolerance for layer norm divisions
parallel_output: whether output logits being distributed or not.
"""
def __init__(self, mpu_vocab_size, hidden_size, init_method,
layernorm_epsilon, parallel_output):
super(BertLMHead, self).__init__()
args = get_args()
self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
# TODO: do we need this?
# mpu.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
self.parallel_output = parallel_output
self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
self.gelu = torch.nn.functional.gelu
if args.openai_gelu:
self.gelu = openai_gelu
elif args.onnx_safe:
self.gelu = erf_gelu
def forward(self, hidden_states, word_embeddings_weight):
hidden_states = self.dense(hidden_states)
hidden_states = self.gelu(hidden_states)
hidden_states = self.layernorm(hidden_states)
output = parallel_lm_logits(hidden_states,
word_embeddings_weight,
self.parallel_output,
bias=self.bias)
return output
def post_language_model_processing(lm_output, pooled_output,
lm_head, binary_head,
lm_labels,
logit_weights,
fp16_lm_cross_entropy):
# Output.
lm_logits = lm_head(
lm_output, logit_weights)
binary_logits = None
if binary_head is not None:
binary_logits = binary_head(pooled_output)
if lm_labels is None:
return lm_logits, binary_logits
else:
if fp16_lm_cross_entropy:
assert lm_logits.dtype == torch.half
lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels)
else:
lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits.float(),
lm_labels)
return lm_loss, binary_logits
class BertModel(MegatronModule):
"""Bert Language model."""
def __init__(self,
num_tokentypes=2,
add_binary_head=True,
parallel_output=True,
pre_process=True,
post_process=True):
super(BertModel, self).__init__()
args = get_args()
self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
self.add_binary_head = add_binary_head
self.parallel_output = parallel_output
self.pre_process = pre_process
self.post_process = post_process
init_method = init_method_normal(args.init_method_std)
scaled_init_method = scaled_init_method_normal(args.init_method_std,
args.num_layers)
self.language_model, self._language_model_key = get_language_model(
num_tokentypes=num_tokentypes,
add_pooler=self.add_binary_head,
encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method,
scaled_init_method=scaled_init_method,
pre_process=self.pre_process,
post_process=self.post_process)
self.initialize_word_embeddings(init_method_normal)
if self.post_process:
self.lm_head = BertLMHead(
self.word_embeddings_weight().size(0),
args.hidden_size, init_method, args.layernorm_epsilon, parallel_output)
self._lm_head_key = 'lm_head'
self.binary_head = None
if self.add_binary_head:
self.binary_head = get_linear_layer(args.hidden_size, 2,
init_method)
self._binary_head_key = 'binary_head'
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
self.language_model.set_input_tensor(input_tensor)
def forward(self, bert_model_input, attention_mask,
tokentype_ids=None, lm_labels=None):
extended_attention_mask = bert_extended_attention_mask(attention_mask)
input_ids = bert_model_input
position_ids = bert_position_ids(input_ids)
lm_output = self.language_model(
input_ids,
position_ids,
extended_attention_mask,
tokentype_ids=tokentype_ids
)
if self.post_process and self.add_binary_head:
lm_output, pooled_output = lm_output
else:
pooled_output = None
if self.post_process:
return post_language_model_processing(lm_output, pooled_output,
self.lm_head, self.binary_head,
lm_labels,
self.word_embeddings_weight(),
self.fp16_lm_cross_entropy)
else:
return lm_output
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load when model is combined with other heads,
add an extra key."""
state_dict_ = {}
state_dict_[self._language_model_key] \
= self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if self.post_process:
state_dict_[self._lm_head_key] \
= self.lm_head.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if self.post_process and self.add_binary_head:
state_dict_[self._binary_head_key] \
= self.binary_head.state_dict(destination, prefix, keep_vars)
# Save word_embeddings.
if self.post_process and not self.pre_process:
state_dict_[self._word_embeddings_for_head_key] \
= self.word_embeddings.state_dict(destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
self.language_model.load_state_dict(
state_dict[self._language_model_key], strict=strict)
if self.post_process:
self.lm_head.load_state_dict(
state_dict[self._lm_head_key], strict=strict)
if self.post_process and self.add_binary_head:
self.binary_head.load_state_dict(
state_dict[self._binary_head_key], strict=strict)
# Load word_embeddings.
if self.post_process and not self.pre_process:
self.word_embeddings.load_state_dict(
state_dict[self._word_embeddings_for_head_key], strict=strict)
def bert_model_provider(pre_process=True, post_process=True):
model = BertModel(num_tokentypes=0, add_binary_head=False, pre_process=pre_process, post_process=post_process)
return model
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GPT-2 model."""
import enum
import math
import torch
import torch.nn.functional as F
import apex.transformer.utils
from apex.normalization import FusedLayerNorm as LayerNorm
from apex.transformer.functional import FusedScaleMaskSoftmax
from apex.transformer import tensor_parallel
from apex.transformer import parallel_state
from apex.transformer.testing.global_vars import get_args
from apex.transformer.enums import LayerType
from apex.transformer.enums import AttnType
from apex.transformer.enums import AttnMaskType
_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
_HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
_BF16_TYPES = (torch.BFloat16Tensor, torch.cuda.BFloat16Tensor)
class ModelType(enum.Enum):
encoder_or_decoder = 1
encoder_and_decoder = 2
###### BIAS GELU FUSION/ NO AUTOGRAD ################
# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2) -> 0.70710678
# sqrt(2/pi) -> 0.79788456
# this function is tanh approximation of gelu
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
@torch.jit.script
def bias_gelu(bias, y):
x = bias + y
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def bias_gelu_back(g, bias, y):
x = bias + y
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
return ff * g
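# Quick numeric check of the approximation (illustrative snippet, not part of the
# original file): the tanh form above tracks the exact erf-based gelu noted in the
# comments to well within 1e-2 over a typical activation range.
def _exact_gelu(x):
    return x * 0.5 * (1.0 + torch.erf(x * 0.70710678))

_x = torch.linspace(-3.0, 3.0, steps=101)
assert torch.allclose(bias_gelu(torch.zeros_like(_x), _x), _exact_gelu(_x), atol=1e-2)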
class MegatronModule(torch.nn.Module):
"""Megatron specific extensions of torch Module with support
for pipelining."""
def __init__(self, share_word_embeddings=True):
super(MegatronModule, self).__init__()
self.share_word_embeddings = share_word_embeddings
def state_dict_for_save_checkpoint(self, destination=None, prefix="",
keep_vars=False):
"""Use this function to override the state dict for
saving checkpoints."""
return self.state_dict(destination, prefix, keep_vars)
def word_embeddings_weight(self):
if not parallel_state.is_pipeline_last_stage(ignore_virtual=True) or \
parallel_state.get_pipeline_model_parallel_world_size() == 1:
return self.language_model.embedding.word_embeddings.weight
else:
if not self.share_word_embeddings:
raise Exception("word_embeddings_weight() called for last "
"stage, but share_word_embeddings is false")
return self.word_embeddings.weight
def initialize_word_embeddings(self, init_method_normal):
args = get_args()
if not self.share_word_embeddings:
raise Exception("initialize_word_embeddings() was called but "
"share_word_embeddings is false")
# This function just initializes the word embeddings in the final stage
# when we are using pipeline parallelism. Nothing to do if we aren't
# using pipeline parallelism.
if args.pipeline_model_parallel_size == 1:
return
# Parameters are shared between the word embeddings layers, and the
# heads at the end of the model. In a pipelined setup with more than
# one stage, the initial embedding layer and the head are on different
# workers, so we do the following:
# 1. Create a second copy of word_embeddings on the last stage, with
# initial parameters of 0.0.
# 2. Do an all-reduce between the first and last stage to ensure that
# the two copies of word_embeddings start off with the same
# parameter values.
# 3. In the training loop, all-reduce the grads of
# the two word_embeddings layers to ensure that every applied weight
# update is the same on both stages.
if parallel_state.is_pipeline_last_stage():
assert not parallel_state.is_pipeline_first_stage()
self._word_embeddings_for_head_key = "word_embeddings_for_head"
# set word_embeddings weights to 0 here, then copy first
# stage's weights using all_reduce below.
self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
args.padded_vocab_size, args.hidden_size,
init_method=init_method_normal(args.init_method_std),
use_cpu_initialization=args.use_cpu_initialization)
self.word_embeddings.weight.data.fill_(0)
self.word_embeddings.weight.shared = True
# Zero out initial weights for decoder embedding.
# NOTE: We don't currently support T5 with the interleaved schedule.
if not parallel_state.is_pipeline_first_stage(ignore_virtual=True) and \
not parallel_state.is_pipeline_last_stage(ignore_virtual=True) and \
parallel_state.is_rank_in_embedding_group():
self.language_model.embedding.zero_parameters()
# Ensure that first and last stages have the same initial parameter
# values.
if torch.distributed.is_initialized():
if parallel_state.is_rank_in_embedding_group():
torch.distributed.all_reduce(self.word_embeddings_weight().data,
group=parallel_state.get_embedding_group())
# All-reduce other embeddings as well as necessary. The last stage
# does not have these other embeddings, so just create placeholder
# tensors of the right shape with all zeros.
# NOTE: We don't currently support T5 with the interleaved schedule.
if args.pipeline_model_parallel_split_rank is not None:
# TODO: Support tokentype embedding.
dimensions = (args.max_position_embeddings, args.hidden_size)
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
position_embeddings = torch.nn.Embedding(*dimensions).cuda()
position_embeddings.weight.data.fill_(0)
else:
self.language_model.embedding.cuda()
position_embeddings = self.language_model.embedding.position_embeddings
torch.distributed.all_reduce(position_embeddings.weight.data,
group=parallel_state.get_embedding_group())
else:
print("WARNING! Distributed processes aren't initialized, so "
"word embeddings in the last layer are not initialized. "
"If you are just manipulating a model this is fine, but "
"this needs to be handled manually. If you are training "
"something is definitely wrong.")
class GeLUFunction(torch.autograd.Function):
@staticmethod
# bias is an optional argument
def forward(ctx, input, bias):
ctx.save_for_backward(input, bias)
return bias_gelu(bias, input)
@staticmethod
def backward(ctx, grad_output):
input, bias = ctx.saved_tensors
tmp = bias_gelu_back(grad_output, bias, input)
return tmp, tmp
bias_gelu_impl = GeLUFunction.apply
def get_linear_layer(rows, columns, init_method):
"""Simple linear layer with weight initialization."""
layer = torch.nn.Linear(rows, columns)
init_method(layer.weight)
with torch.no_grad():
layer.bias.zero_()
return layer
def attention_mask_func(attention_scores, attention_mask):
attention_scores.masked_fill_(attention_mask, -10000.0)
return attention_scores
@torch.jit.script
def gelu_impl(x):
"""OpenAI's gelu implementation."""
return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))
def openai_gelu(x):
return gelu_impl(x)
# This is actually Python equivalent of torch.nn.functional.gelu(), also with type hints for ONNX exporter
@torch.jit.script
def erf_gelu(x):
return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype) + torch.ones_like(x).to(dtype=x.dtype))
def init_method_normal(sigma):
"""Init method based on N(0, sigma)."""
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
return init_
def scaled_init_method_normal(sigma, num_layers):
"""Init method based on N(0, sigma/sqrt(2*num_layers)."""
std = sigma / math.sqrt(2.0 * num_layers)
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
return init_
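# Worked example (annotation): with sigma = 0.02 and num_layers = 24, the scaled
# std is 0.02 / sqrt(2 * 24) ≈ 0.00289, i.e. the 1/sqrt(2*num_layers) rescaling
# applied to output layers that feed the residual stream.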
class ParallelMLP(MegatronModule):
"""MLP.
MLP will take the input with h hidden state, project it to 4*h
hidden dimension, perform nonlinear transformation, and project the
state back into h hidden dimension.
"""
def __init__(self, init_method, output_layer_init_method):
super().__init__()
args = get_args()
# Project to 4h.
self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
args.hidden_size, args.ffn_hidden_size, gather_output=False, init_method=init_method, skip_bias_add=True,
use_cpu_initialization=args.use_cpu_initialization)
self.bias_gelu_fusion = args.bias_gelu_fusion
self.activation_func = F.gelu
if args.openai_gelu:
self.activation_func = openai_gelu
elif args.onnx_safe:
self.activation_func = erf_gelu
# Project back to h.
self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
args.ffn_hidden_size,
args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True,
use_cpu_initialization=args.use_cpu_initialization
)
def forward(self, hidden_states):
# [s, b, 4hp]
intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)
if self.bias_gelu_fusion:
intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
else:
intermediate_parallel = self.activation_func(intermediate_parallel + bias_parallel)
# [s, b, h]
output, output_bias = self.dense_4h_to_h(intermediate_parallel)
return output, output_bias
class ParallelAttention(MegatronModule):
"""Parallel self-attention layer abstract class.
Self-attention layer takes input with size [b, s, h]
and returns output of the same size.
"""
def __init__(
self,
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=AttnMaskType.padding,
):
super().__init__()
args = get_args()
self.fp16 = args.fp16
self.bf16 = args.bf16
self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
if self.apply_query_key_layer_scaling:
self.attention_softmax_in_fp32 = True
self.layer_number = max(1, layer_number)
self.attention_type = attention_type
self.attn_mask_type = attn_mask_type
self.params_dtype = args.params_dtype
projection_size = args.kv_channels * args.num_attention_heads
# Per attention head and per partition values.
world_size = parallel_state.get_tensor_model_parallel_world_size()
self.hidden_size_per_partition = apex.transformer.utils.divide(projection_size, world_size)
self.hidden_size_per_attention_head = apex.transformer.utils.divide(projection_size, args.num_attention_heads)
self.num_attention_heads_per_partition = apex.transformer.utils.divide(args.num_attention_heads, world_size)
# Strided linear layer.
if attention_type == AttnType.self_attn:
self.query_key_value = tensor_parallel.ColumnParallelLinear(
args.hidden_size, 3 * projection_size, gather_output=False, init_method=init_method, use_cpu_initialization=args.use_cpu_initialization)
else:
assert attention_type == AttnType.cross_attn
self.query = tensor_parallel.ColumnParallelLinear(
args.hidden_size, projection_size, gather_output=False, init_method=init_method, use_cpu_initialization=args.use_cpu_initialization)
self.key_value = tensor_parallel.ColumnParallelLinear(
args.hidden_size, 2 * projection_size, gather_output=False, init_method=init_method, use_cpu_initialization=args.use_cpu_initialization)
coeff = None
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
if self.apply_query_key_layer_scaling:
coeff = self.layer_number
self.norm_factor *= coeff
self.scale_mask_softmax = FusedScaleMaskSoftmax(
self.fp16,
self.bf16,
self.attn_mask_type,
args.masked_softmax_fusion,
attention_mask_func,
self.attention_softmax_in_fp32,
coeff,
)
# Dropout. Note that for a single iteration, this layer will generate
# different outputs on different number of parallel partitions but
# on average it should not be partition dependent.
self.attention_dropout = torch.nn.Dropout(args.attention_dropout)
# Output.
self.dense = tensor_parallel.RowParallelLinear(
projection_size,
args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True,
use_cpu_initialization=args.use_cpu_initialization
)
# Inference key-value memory
self.inference_key_memory = None
self.inference_value_memory = None
self.inference_current_sequence_len = 0
def _allocate_memory(self, inference_max_sequence_len, batch_size):
return torch.empty(
inference_max_sequence_len,
batch_size,
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
dtype=self.params_dtype,
device=torch.cuda.current_device(),
)
def forward(
self,
hidden_states,
attention_mask,
encoder_output=None,
set_inference_key_value_memory=False,
inference_max_sequence_len=None,
):
# hidden_states: [sq, b, h]
# =================================================
# Pre-allocate memory for key-values for inference.
# =================================================
if set_inference_key_value_memory:
assert inference_max_sequence_len and inference_max_sequence_len > 0
self.inference_key_memory = self._allocate_memory(inference_max_sequence_len, hidden_states.size(1))
self.inference_value_memory = self._allocate_memory(inference_max_sequence_len, hidden_states.size(1))
self.inference_current_sequence_len = 0
# Some consistency check.
if inference_max_sequence_len:
assert self.inference_current_sequence_len < self.inference_key_memory.size(0)
assert inference_max_sequence_len == self.inference_key_memory.size(0)
# This is added for safety. In case inference_max_sequence_len
# is not provided, make sure there is no potential memory left
# from previous inference.
if not inference_max_sequence_len:
self.inference_key_memory = None
self.inference_value_memory = None
# =====================
# Query, Key, and Value
# =====================
if self.attention_type == AttnType.self_attn:
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
mixed_x_layer, _ = self.query_key_value(hidden_states)
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head,
)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(query_layer, key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3)
else:
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_kv_layer, _ = self.key_value(encoder_output)
# [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
new_tensor_shape = mixed_kv_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
2 * self.hidden_size_per_attention_head,
)
mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
# [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
(key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)
# Attention head [sq, b, h] --> [sq, b, hp]
query_layer, _ = self.query(hidden_states)
# [sq, b, hp] --> [sq, b, np, hn]
new_tensor_shape = query_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
)
query_layer = query_layer.view(*new_tensor_shape)
# ===================================================
# Adjust key, value, and attention mask for inference
# ===================================================
if inference_max_sequence_len:
# Adjust the range variables.
start = self.inference_current_sequence_len
self.inference_current_sequence_len += key_layer.size(0)
end = self.inference_current_sequence_len
# Copy key and values.
self.inference_key_memory[start:end, ...] = key_layer
self.inference_value_memory[start:end, ...] = value_layer
key_layer = self.inference_key_memory[:end, ...]
value_layer = self.inference_value_memory[:end, ...]
# Adjust attention mask
attention_mask = attention_mask[..., start:end, :end]
# ===================================
# Raw attention scores. [b, np, s, s]
# ===================================
# [b, np, sq, sk]
output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
# preallocating result tensor: [b * np, sq, sk]
matmul_result = torch.empty(
output_size[0] * output_size[1],
output_size[2],
output_size[3],
dtype=query_layer.dtype,
device=torch.cuda.current_device(),
)
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(
matmul_result,
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0,
alpha=(1.0 / self.norm_factor),
)
# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
# ===========================
# Attention probs and dropout
# ===========================
# attention scores and attention mask [b, np, sq, sk]
attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
with tensor_parallel.get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs)
# =========================
# Context layer. [sq, b, hp]
# =========================
# value_layer -> context layer.
# [sk, b, np, hn] --> [b, np, sq, hn]
# context layer shape: [b, np, sq, hn]
output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
# change view [sk, b * np, hn]
value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
# change view [b * np, sq, sk]
attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
# change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size)
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
# =================
# Output. [sq, b, h]
# =================
output, bias = self.dense(context_layer)
return output, bias
@torch.jit.script
def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
out = residual + out
return out
def get_bias_dropout_add(training):
def _bias_dropout_add(x, bias, residual, prob):
return bias_dropout_add(x, bias, residual, prob, training)
return _bias_dropout_add
@torch.jit.script
def bias_dropout_add_fused_train(
x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
) -> torch.Tensor:
return bias_dropout_add(x, bias, residual, prob, True)
@torch.jit.script
def bias_dropout_add_fused_inference(
x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
) -> torch.Tensor:
return bias_dropout_add(x, bias, residual, prob, False)
class ParallelTransformerLayer(MegatronModule):
"""A single transformer layer.
Transformer layer takes input with size [b, s, h] and returns an
output of the same size.
"""
def __init__(
self,
init_method,
output_layer_init_method,
layer_number,
layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding,
):
args = get_args()
super().__init__()
self.layer_number = layer_number
self.layer_type = layer_type
self.apply_residual_connection_post_layernorm = args.apply_residual_connection_post_layernorm
self.bf16 = args.bf16
self.fp32_residual_connection = args.fp32_residual_connection
# Layernorm on the input data.
self.input_layernorm = LayerNorm(args.hidden_size, eps=args.layernorm_epsilon)
# Self attention.
self.self_attention = ParallelAttention(
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=self_attn_mask_type,
)
self.hidden_dropout = args.hidden_dropout
self.bias_dropout_fusion = args.bias_dropout_fusion
# Layernorm on the attention output
self.post_attention_layernorm = LayerNorm(args.hidden_size, eps=args.layernorm_epsilon)
if self.layer_type == LayerType.decoder:
self.inter_attention = ParallelAttention(
init_method, output_layer_init_method, layer_number, attention_type=AttnType.cross_attn
)
# Layernorm on the attention output.
self.post_inter_attention_layernorm = LayerNorm(args.hidden_size, eps=args.layernorm_epsilon)
# MLP
self.mlp = ParallelMLP(init_method, output_layer_init_method)
def forward(
self,
hidden_states,
attention_mask,
encoder_output=None,
enc_dec_attn_mask=None,
set_inference_key_value_memory=False,
inference_max_sequence_len=None,
):
# hidden_states: [b, s, h]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output, attention_bias = self.self_attention(
layernorm_output,
attention_mask,
set_inference_key_value_memory=set_inference_key_value_memory,
inference_max_sequence_len=inference_max_sequence_len,
)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
# jit scripting for an nn.module (with dropout) is not
# triggering the fusion kernel. For now, we use two
# different nn.functional routines to account for varying
# dropout semantics during training and inference phases.
if self.bias_dropout_fusion:
if self.training:
bias_dropout_add_func = bias_dropout_add_fused_train
else:
bias_dropout_add_func = bias_dropout_add_fused_inference
else:
bias_dropout_add_func = get_bias_dropout_add(self.training)
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output, attention_bias.expand_as(residual), residual, self.hidden_dropout
)
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
if self.layer_type == LayerType.decoder:
attention_output, attention_bias = self.inter_attention(
layernorm_output, enc_dec_attn_mask, encoder_output=encoder_output
)
# residual connection
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output, attention_bias.expand_as(residual), residual, self.hidden_dropout
)
# Layer norm post the decoder attention
layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
# MLP.
mlp_output, mlp_bias = self.mlp(layernorm_output)
# Second residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
output = bias_dropout_add_func(mlp_output, mlp_bias.expand_as(residual), residual, self.hidden_dropout)
return output
class ParallelTransformer(MegatronModule):
"""Transformer class."""
def __init__(
self,
init_method,
output_layer_init_method,
layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding,
pre_process=True,
post_process=True,
):
super().__init__()
args = get_args()
self.bf16 = args.bf16
self.fp32_residual_connection = args.fp32_residual_connection
self.pre_process = pre_process
self.post_process = post_process
self.input_tensor = None
# Store activation checkpointing flag.
self.activations_checkpoint_method = args.activations_checkpoint_method
self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers
self.distribute_checkpointed_activations = args.distribute_checkpointed_activations
num_layers = args.num_layers
# Number of layers.
assert (
num_layers % parallel_state.get_pipeline_model_parallel_world_size() == 0
), "num_layers must be divisible by pipeline_model_parallel_size"
self.num_layers = num_layers // parallel_state.get_pipeline_model_parallel_world_size()
# Transformer layers.
def build_layer(layer_number):
return ParallelTransformerLayer(
init_method,
output_layer_init_method,
layer_number,
layer_type=layer_type,
self_attn_mask_type=self_attn_mask_type,
)
if args.virtual_pipeline_model_parallel_size is not None:
assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, (
"num_layers_per_stage must be divisible by " "virtual_pipeline_model_parallel_size"
)
# Number of layers in each model chunk is the number of layers in the stage,
# divided by the number of model chunks in a stage.
self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
# With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
# layers to stages like (each list is a model chunk):
# Stage 0: [0] [2] [4] [6]
# Stage 1: [1] [3] [5] [7]
# With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
# layers to stages like (each list is a model chunk):
# Stage 0: [0, 1] [4, 5]
# Stage 1: [2, 3] [6, 7]
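# Worked example (annotation): with num_layers=8, pipeline world size 2 and
# virtual_pipeline_model_parallel_size=4, self.num_layers becomes 1 and
# offset = virtual_rank * (8 // 4) + pipeline_rank * 1, so pipeline rank 0
# builds layers {0, 2, 4, 6} across its four model chunks and rank 1 builds
# {1, 3, 5, 7}, matching the assignment listed above.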
offset = parallel_state.get_virtual_pipeline_model_parallel_rank() * (
args.num_layers // args.virtual_pipeline_model_parallel_size
) + (parallel_state.get_pipeline_model_parallel_rank() * self.num_layers)
else:
# Each stage gets a contiguous set of layers.
offset = parallel_state.get_pipeline_model_parallel_rank() * self.num_layers
self.layers = torch.nn.ModuleList([build_layer(i + 1 + offset) for i in range(self.num_layers)])
if self.post_process:
# Final layer norm before output.
self.final_layernorm = LayerNorm(args.hidden_size, eps=args.layernorm_epsilon)
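# --- Illustrative sketch (editor's addition, not part of the diff) ---------
# Worked example of the interleaved offset computed in __init__ above:
# 8 layers, pipeline size 2, virtual pipeline size 2 -> 2 layers per chunk.
num_layers, pp_size, vp_size = 8, 2, 2
layers_per_chunk = num_layers // pp_size // vp_size   # self.num_layers == 2

def chunk_offset(pp_rank, vp_rank):
    return vp_rank * (num_layers // vp_size) + pp_rank * layers_per_chunk

stage0 = [list(range(chunk_offset(0, v), chunk_offset(0, v) + layers_per_chunk)) for v in range(vp_size)]
stage1 = [list(range(chunk_offset(1, v), chunk_offset(1, v) + layers_per_chunk)) for v in range(vp_size)]
assert stage0 == [[0, 1], [4, 5]] and stage1 == [[2, 3], [6, 7]]
# ----------------------------------------------------------------------------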
def _get_layer(self, layer_number):
return self.layers[layer_number]
def _checkpointed_forward(self, hidden_states, attention_mask, encoder_output, enc_dec_attn_mask):
"""Forward method with activation checkpointing."""
def custom(start, end):
def custom_forward(*inputs):
x_ = inputs[0]
attention_mask = inputs[1]
encoder_output = inputs[2]
enc_dec_attn_mask = inputs[3]
for index in range(start, end):
layer = self._get_layer(index)
x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
return x_
return custom_forward
def distribute_checkpointed_activations_helper(layer_number):
"""Distribute checkpointed activations across the tensor model
Parallel ranks if the `distribute-checkpointed-activations
is on and either of the following conditions is met:
- it is not the first layer in the in the pipeline stage.
The first layer is used in the pipeline parallelism
and changing its shape throws error in the backward pass.
- we are at the first pipline stage so the input tensor is
not used in pipeline parallelism. Note that no pipeline
parallelism is a special case of this.
"""
not_first_layer_in_pipeline_stage = layer_number > 0
is_first_pipeline_stage = parallel_state.get_pipeline_model_parallel_rank() == 0
return self.distribute_checkpointed_activations and (
not_first_layer_in_pipeline_stage or is_first_pipeline_stage
)
if self.activations_checkpoint_method == "uniform":
# Uniformly divide the total number of Transformer layers and checkpoint
# the input activation of each divided chunk.
# This trades compute for memory: only each chunk's input activation is kept,
# and the rest is re-computed in the backward pass.
l = 0
while l < self.num_layers:
hidden_states = tensor_parallel.checkpoint(
custom(l, l + self.activations_checkpoint_num_layers),
distribute_checkpointed_activations_helper(l),
hidden_states,
attention_mask,
encoder_output,
enc_dec_attn_mask,
)
l += self.activations_checkpoint_num_layers
elif self.activations_checkpoint_method == "block":
# Checkpoint the input activation of only a set number of individual
# Transformer layers and skip the rest.
# This makes fuller use of device memory while avoiding redundant re-computation.
for l in range(self.num_layers):
if l < self.activations_checkpoint_num_layers:
hidden_states = tensor_parallel.checkpoint(
custom(l, l + 1),
distribute_checkpointed_activations_helper(l),
hidden_states,
attention_mask,
encoder_output,
enc_dec_attn_mask,
)
else:
hidden_states = custom(l, l + 1)(hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
else:
raise ValueError("Invalid activation checkpoint method.")
return hidden_states
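# --- Illustrative sketch (editor's addition, not part of the diff) ---------
# How the two checkpointing methods above partition layers, shown on a toy
# stage with 8 layers and activations_checkpoint_num_layers = 2:
num_layers, chunk = 8, 2

uniform_chunks = [(l, l + chunk) for l in range(0, num_layers, chunk)]
# -> [(0, 2), (2, 4), (4, 6), (6, 8)]: every chunk's input is checkpointed.

block_checkpointed = [l for l in range(num_layers) if l < chunk]
# -> [0, 1]: only the first `chunk` layers are checkpointed; the rest run normally.
# ----------------------------------------------------------------------------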
def set_input_tensor(self, input_tensor):
"""Set input tensor to be used instead of forward()'s input.
When doing pipeline parallelism the input from the previous
stage comes from communication, not from the input, so the
model's forward_step_func won't have it. This function is thus
used by internal code to bypass the input provided by the
forward_step_func"""
self.input_tensor = input_tensor
def forward(
self,
hidden_states,
attention_mask,
encoder_output=None,
enc_dec_attn_mask=None,
set_inference_key_value_memory=False,
inference_max_sequence_len=None,
):
# Checks.
if inference_max_sequence_len:
assert self.activations_checkpoint_method is None, "inference does not work with activation checkpointing"
if self.pre_process:
# Data format change to avoid explicit transposes: [b s h] --> [s b h].
# If the fp32 residual connection flag is set, also convert to float.
if self.fp32_residual_connection:
hidden_states = hidden_states.transpose(0, 1).contiguous().float()
# Otherwise, leave it as is.
else:
hidden_states = hidden_states.transpose(0, 1).contiguous()
else:
# See set_input_tensor()
hidden_states = self.input_tensor
if encoder_output is not None:
encoder_output = encoder_output.transpose(0, 1).contiguous()
if self.activations_checkpoint_method is not None:
hidden_states = self._checkpointed_forward(hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
else:
for index in range(self.num_layers):
layer = self._get_layer(index)
hidden_states = layer(
hidden_states,
attention_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
set_inference_key_value_memory=set_inference_key_value_memory,
inference_max_sequence_len=inference_max_sequence_len,
)
# Final layer norm.
if self.post_process:
# Reverting data format change [s b h] --> [b s h].
hidden_states = hidden_states.transpose(0, 1).contiguous()
output = self.final_layernorm(hidden_states)
else:
output = hidden_states
return output
def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None):
"""LM logits using word embedding weights."""
# Parallel logits.
input_parallel = tensor_parallel.copy_to_tensor_model_parallel_region(input_)
# Matrix multiply.
if bias is None:
logits_parallel = F.linear(input_parallel, word_embeddings_weight)
else:
logits_parallel = F.linear(input_parallel, word_embeddings_weight, bias)
# Gather if needed.
if parallel_output:
return logits_parallel
return tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel)
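# --- Illustrative sketch (editor's addition, not part of the diff) ---------
# Shape check for parallel_lm_logits on a single rank (tensor-parallel size 1),
# with made-up sizes: s = seq len, b = batch, h = hidden, v = vocab partition.
import torch
import torch.nn.functional as F

s, b, h, v = 4, 2, 8, 16
hidden = torch.randn(s, b, h)
word_embeddings_weight = torch.randn(v, h)
logits = F.linear(hidden, word_embeddings_weight)   # [s, b, v]
assert logits.shape == (s, b, v)
# With tensor parallelism each rank holds a v/world_size slice of the vocab, and
# gather_from_tensor_model_parallel_region concatenates the slices on the last dim.
# ----------------------------------------------------------------------------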
def get_language_model(
num_tokentypes,
add_pooler,
encoder_attn_mask_type,
init_method=None,
scaled_init_method=None,
add_encoder=True,
add_decoder=False,
decoder_attn_mask_type=AttnMaskType.causal,
pre_process=True,
post_process=True,
):
"""Build language model and return along with the key to save."""
args = get_args()
if init_method is None:
init_method = init_method_normal(args.init_method_std)
if scaled_init_method is None:
scaled_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers)
# Language model.
language_model = TransformerLanguageModel(
init_method,
scaled_init_method,
encoder_attn_mask_type,
num_tokentypes=num_tokentypes,
add_encoder=add_encoder,
add_decoder=add_decoder,
decoder_attn_mask_type=decoder_attn_mask_type,
add_pooler=add_pooler,
pre_process=pre_process,
post_process=post_process,
)
# key used for checkpoints.
language_model_key = "language_model"
return language_model, language_model_key
class Pooler(MegatronModule):
"""Pooler layer.
Pool hidden states of a specific token (for example start of the
sequence) and add a linear transformation followed by a tanh.
Arguments:
hidden_size: hidden size
init_method: weight initialization method for the linear layer.
bias is set to zero.
"""
def __init__(self, hidden_size, init_method):
super(Pooler, self).__init__()
self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
def forward(self, hidden_states, sequence_index=0):
# hidden_states: [b, s, h]
# sequence_index: index of the token to pool.
pooled = hidden_states[:, sequence_index, :]
pooled = self.dense(pooled)
pooled = torch.tanh(pooled)
return pooled
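# --- Illustrative sketch (editor's addition, not part of the diff) ---------
# The pooler simply selects one token's hidden state and applies dense + tanh.
# Stand-alone equivalent with illustrative sizes (b = batch, s = seq, h = hidden):
import torch

b, s, h = 2, 5, 8
hidden_states = torch.randn(b, s, h)
dense = torch.nn.Linear(h, h)
pooled = torch.tanh(dense(hidden_states[:, 0, :]))   # pool the first token
assert pooled.shape == (b, h)
# ----------------------------------------------------------------------------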
class Embedding(MegatronModule):
"""Language model embeddings.
Arguments:
hidden_size: hidden size
vocab_size: vocabulary size
max_sequence_length: maximum size of sequence. This
is used for positional embedding
embedding_dropout_prob: dropout probability for embeddings
init_method: weight initialization method
num_tokentypes: size of the token-type embeddings. 0 value
will ignore this embedding
"""
def __init__(
self, hidden_size, vocab_size, max_sequence_length, embedding_dropout_prob, init_method, num_tokentypes=0
):
super().__init__()
self.hidden_size = hidden_size
self.init_method = init_method
self.num_tokentypes = num_tokentypes
args = get_args()
# Word embeddings (parallel).
self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
vocab_size, self.hidden_size, init_method=self.init_method,
use_cpu_initialization=args.use_cpu_initialization
)
self._word_embeddings_key = "word_embeddings"
# Position embedding (serial).
self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size)
self._position_embeddings_key = "position_embeddings"
# Initialize the position embeddings.
self.init_method(self.position_embeddings.weight)
# Token type embedding.
# Add this as an optional field that can be added through
# a method call, so we can load a pretrained model without
# token types and add them as needed.
self._tokentype_embeddings_key = "tokentype_embeddings"
if self.num_tokentypes > 0:
self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size)
# Initialize the token-type embeddings.
self.init_method(self.tokentype_embeddings.weight)
else:
self.tokentype_embeddings = None
# Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
print("FINISH WORD EMBEDDING", self.word_embeddings)
def zero_parameters(self):
"""Zero out all parameters in embedding."""
self.word_embeddings.weight.data.fill_(0)
self.word_embeddings.weight.shared = True
self.position_embeddings.weight.data.fill_(0)
self.position_embeddings.weight.shared = True
if self.num_tokentypes > 0:
self.tokentype_embeddings.weight.data.fill_(0)
self.tokentype_embeddings.weight.shared = True
def add_tokentype_embeddings(self, num_tokentypes):
"""Add token-type embedding. This function is provided so we can add
token-type embeddings in case the pretrained model does not have it.
This allows us to load the model normally and then add this embedding.
"""
if self.tokentype_embeddings is not None:
raise Exception("tokentype embeddings is already initialized")
if torch.distributed.get_rank() == 0:
print("adding embedding for {} tokentypes".format(num_tokentypes), flush=True)
self.num_tokentypes = num_tokentypes
self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size)
# Initialize the token-type embeddings.
self.init_method(self.tokentype_embeddings.weight)
def forward(self, input_ids, position_ids, tokentype_ids=None):
# Embeddings.
words_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
embeddings = words_embeddings + position_embeddings
if tokentype_ids is not None:
assert self.tokentype_embeddings is not None
embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
else:
assert self.tokentype_embeddings is None
# Dropout.
embeddings = self.embedding_dropout(embeddings)
return embeddings
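# --- Illustrative sketch (editor's addition, not part of the diff) ---------
# Minimal non-parallel version of the embedding sum above, with toy sizes:
import torch

vocab, max_len, h, b, s = 100, 16, 8, 2, 6
word = torch.nn.Embedding(vocab, h)
pos = torch.nn.Embedding(max_len, h)
input_ids = torch.randint(0, vocab, (b, s))
position_ids = torch.arange(s).unsqueeze(0).expand(b, s)
embeddings = word(input_ids) + pos(position_ids)     # [b, s, h]
assert embeddings.shape == (b, s, h)
# ----------------------------------------------------------------------------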
def state_dict_for_save_checkpoint(self, destination=None, prefix="", keep_vars=False):
"""For easy load."""
state_dict_ = {}
state_dict_[self._word_embeddings_key] = self.word_embeddings.state_dict(destination, prefix, keep_vars)
state_dict_[self._position_embeddings_key] = self.position_embeddings.state_dict(destination, prefix, keep_vars)
if self.num_tokentypes > 0:
state_dict_[self._tokentype_embeddings_key] = self.tokentype_embeddings.state_dict(
destination, prefix, keep_vars
)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
# Word embedding.
if self._word_embeddings_key in state_dict:
state_dict_ = state_dict[self._word_embeddings_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if "word_embeddings" in key:
state_dict_[key.split("word_embeddings.")[1]] = state_dict[key]
self.word_embeddings.load_state_dict(state_dict_, strict=strict)
# Position embedding.
if self._position_embeddings_key in state_dict:
state_dict_ = state_dict[self._position_embeddings_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if "position_embeddings" in key:
state_dict_[key.split("position_embeddings.")[1]] = state_dict[key]
self.position_embeddings.load_state_dict(state_dict_, strict=strict)
# Tokentype embedding.
if self.num_tokentypes > 0:
state_dict_ = {}
if self._tokentype_embeddings_key in state_dict:
state_dict_ = state_dict[self._tokentype_embeddings_key]
else:
# for backward compatibility.
for key in state_dict.keys():
if "tokentype_embeddings" in key:
state_dict_[key.split("tokentype_embeddings.")[1]] = state_dict[key]
if len(state_dict_.keys()) > 0:
self.tokentype_embeddings.load_state_dict(state_dict_, strict=strict)
else:
print(
"***WARNING*** expected tokentype embeddings in the " "checkpoint but could not find it", flush=True
)
class TransformerLanguageModel(MegatronModule):
"""Transformer language model.
Arguments:
transformer_hparams: transformer hyperparameters
vocab_size: vocabulary size
max_sequence_length: maximum size of sequence. This
is used for positional embedding
embedding_dropout_prob: dropout probability for embeddings
num_tokentypes: size of the token-type embeddings. 0 value
will ignore this embedding
"""
def __init__(
self,
init_method,
output_layer_init_method,
encoder_attn_mask_type,
num_tokentypes=0,
add_encoder=True,
add_decoder=False,
decoder_attn_mask_type=AttnMaskType.causal,
add_pooler=False,
pre_process=True,
post_process=True,
):
super().__init__()
args = get_args()
self.pre_process = pre_process
self.post_process = post_process
self.hidden_size = args.hidden_size
self.num_tokentypes = num_tokentypes
self.init_method = init_method
self.add_encoder = add_encoder
self.encoder_attn_mask_type = encoder_attn_mask_type
self.add_decoder = add_decoder
self.decoder_attn_mask_type = decoder_attn_mask_type
self.add_pooler = add_pooler
self.encoder_hidden_state = None
# Embeddings.
if self.pre_process:
self.embedding = Embedding(
self.hidden_size,
args.padded_vocab_size,
args.max_position_embeddings,
args.hidden_dropout,
self.init_method,
self.num_tokentypes,
)
self._embedding_key = "embedding"
# Transformer.
# Encoder (usually set to True, False if part of an encoder-decoder
# architecture and in encoder-only stage).
if self.add_encoder:
self.encoder = ParallelTransformer(
self.init_method,
output_layer_init_method,
self_attn_mask_type=self.encoder_attn_mask_type,
pre_process=self.pre_process,
post_process=self.post_process,
)
self._encoder_key = "encoder"
else:
self.encoder = None
# Decoder (usually set to False, True if part of an encoder-decoder
# architecture and in decoder-only stage).
if self.add_decoder:
# Temporary assertion until we verify correctness of pipeline parallelism
# implementation of T5.
assert (
args.pipeline_model_parallel_size == 1
), "pipeline parallelism is not supported in the presence of decoder"
self.decoder = ParallelTransformer(
self.init_method,
output_layer_init_method,
layer_type=LayerType.decoder,
self_attn_mask_type=self.decoder_attn_mask_type,
pre_process=self.pre_process,
post_process=self.post_process,
)
self._decoder_key = "decoder"
else:
self.decoder = None
if self.post_process:
# Pooler.
if self.add_pooler:
self.pooler = Pooler(self.hidden_size, self.init_method)
self._pooler_key = "pooler"
def set_input_tensor(self, input_tensor):
""" See megatron.model.transformer.set_input_tensor()"""
# This is usually handled in schedules.py but some inference code still
# gives us non-lists or None
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
if self.add_encoder and self.add_decoder:
assert (
len(input_tensor) == 1
), "input_tensor should only be length 1 for stage with both encoder and decoder"
self.encoder.set_input_tensor(input_tensor[0])
elif self.add_encoder:
assert len(input_tensor) == 1, "input_tensor should only be length 1 for stage with only encoder"
self.encoder.set_input_tensor(input_tensor[0])
elif self.add_decoder:
if len(input_tensor) == 2:
self.decoder.set_input_tensor(input_tensor[0])
self.encoder_hidden_state = input_tensor[1]
elif len(input_tensor) == 1:
self.decoder.set_input_tensor(None)
self.encoder_hidden_state = input_tensor[0]
else:
raise Exception("input_tensor must have either length 1 or 2")
else:
raise Exception("Stage must have at least either encoder or decoder")
def forward(
self,
enc_input_ids,
enc_position_ids,
enc_attn_mask,
dec_input_ids=None,
dec_position_ids=None,
dec_attn_mask=None,
enc_dec_attn_mask=None,
tokentype_ids=None,
set_inference_key_value_memory=False,
inference_max_sequence_len=None,
pooling_sequence_index=0,
enc_hidden_states=None,
output_enc_hidden=False,
):
# Encoder embedding.
if self.pre_process:
encoder_input = self.embedding(enc_input_ids, enc_position_ids, tokentype_ids=tokentype_ids)
else:
encoder_input = None
# Run encoder.
if enc_hidden_states is None:
if self.encoder is not None:
encoder_output = self.encoder(
encoder_input,
enc_attn_mask,
set_inference_key_value_memory=set_inference_key_value_memory,
inference_max_sequence_len=inference_max_sequence_len,
)
else:
encoder_output = self.encoder_hidden_state
else:
encoder_output = enc_hidden_states.to(encoder_input.dtype)
if self.post_process:
if self.add_pooler:
pooled_output = self.pooler(encoder_output, pooling_sequence_index)
# output_enc_hidden is used when we only need the encoder's output,
# e.g. to compute the similarity between two sequences by average
# pooling their encoder hidden states.
if not self.add_decoder or output_enc_hidden:
if self.add_pooler and self.post_process:
return encoder_output, pooled_output
else:
return encoder_output
# Decoder embedding.
if self.pre_process:
decoder_input = self.embedding(dec_input_ids, dec_position_ids)
else:
decoder_input = None
# Run decoder.
decoder_output = self.decoder(
decoder_input,
dec_attn_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
set_inference_key_value_memory=set_inference_key_value_memory,
inference_max_sequence_len=inference_max_sequence_len,
)
if self.add_pooler and self.post_process:
return decoder_output, encoder_output, pooled_output
else:
return decoder_output, encoder_output
def state_dict_for_save_checkpoint(self, destination=None, prefix="", keep_vars=False):
"""For easy load."""
state_dict_ = {}
if self.pre_process:
state_dict_[self._embedding_key] = self.embedding.state_dict_for_save_checkpoint(
destination, prefix, keep_vars
)
if self.add_encoder:
state_dict_[self._encoder_key] = self.encoder.state_dict_for_save_checkpoint(destination, prefix, keep_vars)
if self.post_process:
if self.add_pooler:
state_dict_[self._pooler_key] = self.pooler.state_dict_for_save_checkpoint(
destination, prefix, keep_vars
)
if self.add_decoder:
state_dict_[self._decoder_key] = self.decoder.state_dict_for_save_checkpoint(destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
# Embedding.
if self.pre_process:
if self._embedding_key in state_dict:
state_dict_ = state_dict[self._embedding_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if "_embeddings" in key:
state_dict_[key] = state_dict[key]
self.embedding.load_state_dict(state_dict_, strict=strict)
# Encoder.
if self.add_encoder:
if self._encoder_key in state_dict:
state_dict_ = state_dict[self._encoder_key]
# For backward compatibility.
elif "transformer" in state_dict:
state_dict_ = state_dict["transformer"]
else:
# For backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if "transformer." in key:
state_dict_[key.split("transformer.")[1]] = state_dict[key]
# For backward compatibility.
state_dict_self_attention = {}
for key in state_dict_.keys():
if ".attention." in key:
state_dict_self_attention[key.replace(".attention.", ".self_attention.")] = state_dict_[key]
else:
state_dict_self_attention[key] = state_dict_[key]
state_dict_ = state_dict_self_attention
self.encoder.load_state_dict(state_dict_, strict=strict)
# Pooler.
if self.post_process:
if self.add_pooler:
assert "pooler" in state_dict, "could not find data for pooler in the checkpoint"
self.pooler.load_state_dict(state_dict[self._pooler_key], strict=strict)
# Decoder.
if self.add_decoder:
assert "decoder" in state_dict, "could not find data for pooler in the checkpoint"
self.decoder.load_state_dict(state_dict[self._decoder_key], strict=strict)
def post_language_model_processing(lm_output, labels, logit_weights, parallel_output, fp16_lm_cross_entropy):
# Output.
output = parallel_lm_logits(lm_output, logit_weights, parallel_output)
if labels is None:
return output
else:
if fp16_lm_cross_entropy:
assert output.dtype == torch.half
loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels)
else:
loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels)
return loss
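# --- Illustrative sketch (editor's addition, not part of the diff) ---------
# On a single tensor-parallel rank, vocab_parallel_cross_entropy reduces to a
# plain per-token cross entropy over the vocab dimension (no reduction):
import torch
import torch.nn.functional as F

s, b, v = 4, 2, 16
logits = torch.randn(s, b, v)
labels = torch.randint(0, v, (s, b))
loss = F.cross_entropy(logits.reshape(-1, v), labels.reshape(-1), reduction="none").reshape(s, b)
assert loss.shape == (s, b)
# ----------------------------------------------------------------------------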
class GPTModel(MegatronModule):
"""GPT-2 Language model."""
def __init__(self, num_tokentypes=0, parallel_output=True, pre_process=True, post_process=True):
super(GPTModel, self).__init__()
args = get_args()
self.parallel_output = parallel_output
self.pre_process = pre_process
self.post_process = post_process
self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
self.language_model, self._language_model_key = get_language_model(
num_tokentypes=num_tokentypes,
add_pooler=False,
encoder_attn_mask_type=AttnMaskType.causal,
init_method=init_method_normal(args.init_method_std),
scaled_init_method=scaled_init_method_normal(args.init_method_std, args.num_layers),
pre_process=self.pre_process,
post_process=self.post_process,
)
self.initialize_word_embeddings(init_method_normal)
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
self.language_model.set_input_tensor(input_tensor)
def forward(
self,
input_ids,
position_ids,
attention_mask,
labels=None,
tokentype_ids=None,
set_inference_key_value_memory=False,
inference_max_sequence_len=None,
):
lm_output = self.language_model(
input_ids,
position_ids,
attention_mask,
set_inference_key_value_memory=set_inference_key_value_memory,
inference_max_sequence_len=inference_max_sequence_len,
)
if self.post_process:
return post_language_model_processing(
lm_output, labels, self.word_embeddings_weight(), self.parallel_output, self.fp16_lm_cross_entropy
)
else:
return lm_output
def state_dict_for_save_checkpoint(self, destination=None, prefix="", keep_vars=False):
state_dict_ = {}
state_dict_[self._language_model_key] = self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars
)
# Save word_embeddings.
if self.post_process and not self.pre_process:
state_dict_[self._word_embeddings_for_head_key] = self.word_embeddings.state_dict(
destination, prefix, keep_vars
)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
# Load word_embeddings.
if self.post_process and not self.pre_process:
self.word_embeddings.load_state_dict(state_dict[self._word_embeddings_for_head_key], strict=strict)
if self._language_model_key in state_dict:
state_dict = state_dict[self._language_model_key]
self.language_model.load_state_dict(state_dict, strict=strict)
def gpt_model_provider(pre_process=True, post_process=True):
model = GPTModel(num_tokentypes=0, parallel_output=True, pre_process=pre_process, post_process=post_process)
return model
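# --- Illustrative sketch (editor's addition, not part of the diff) ---------
# Typical use of the provider under pipeline parallelism: only the first
# stage embeds inputs (pre_process) and only the last stage produces logits
# (post_process). Assumes parallel_state has been initialized and exposes the
# usual first/last-stage helpers (as in Megatron-LM's mpu).
from apex.transformer import parallel_state

def gpt_model_provider_for_this_stage():
    pre = parallel_state.is_pipeline_first_stage()
    post = parallel_state.is_pipeline_last_stage()
    return gpt_model_provider(pre_process=pre, post_process=post)
# ----------------------------------------------------------------------------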
"""Utility functions used by both `pipeline_parallel` and `tensor_parallel`"""
import torch
from apex.transformer import parallel_state
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
def divide(numerator, denominator):
"""Ensure that numerator is divisible by the denominator and return
the division value."""
ensure_divisibility(numerator, denominator)
return numerator // denominator
def split_tensor_into_1d_equal_chunks(tensor):
"""Break a tensor into equal 1D chunks."""
data = tensor.view(-1)
partition_size = torch.numel(data) // parallel_state.get_tensor_model_parallel_world_size()
start_index = partition_size * parallel_state.get_tensor_model_parallel_rank()
end_index = start_index + partition_size
return data[start_index:end_index]
def gather_split_1d_tensor(tensor):
"""Opposite of above function, gather values from model parallel ranks."""
world_size = parallel_state.get_tensor_model_parallel_world_size()
numel = torch.numel(tensor)
numel_gathered = world_size * numel
gathered = torch.empty(numel_gathered, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False)
chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)]
torch.distributed.all_gather(chunks, tensor, group=parallel_state.get_tensor_model_parallel_group())
return gathered
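# --- Illustrative sketch (editor's addition, not part of the diff) ---------
# The chunking arithmetic behind the two helpers above, shown without any
# distributed setup (world_size and rank are plain ints here):
import torch

def split_1d(tensor, world_size, rank):
    data = tensor.view(-1)
    partition = data.numel() // world_size
    return data[rank * partition:(rank + 1) * partition]

t = torch.arange(8.0)
chunks = [split_1d(t, 4, r) for r in range(4)]
assert torch.equal(torch.cat(chunks), t)   # gathering the chunks restores the tensor
# ----------------------------------------------------------------------------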
...@@ -33,6 +33,12 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda( ...@@ -33,6 +33,12 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
std::vector<std::vector<at::Tensor>> tensor_lists, std::vector<std::vector<at::Tensor>> tensor_lists,
at::optional<bool> per_tensor_python); at::optional<bool> per_tensor_python);
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_mp_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::optional<bool> per_tensor_python);
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_scale_cuda( std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_scale_cuda(
int chunk_size, int chunk_size,
at::Tensor noop_flag, at::Tensor noop_flag,
...@@ -119,6 +125,25 @@ void multi_tensor_lamb_cuda( ...@@ -119,6 +125,25 @@ void multi_tensor_lamb_cuda(
const float max_grad_norm, const float max_grad_norm,
at::optional<bool> use_nvlamb_python); at::optional<bool> use_nvlamb_python);
void multi_tensor_lamb_mp_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor lr,
const float beta1,
const float beta2,
const float epsilon,
at::Tensor step,
const int bias_correction,
const float weight_decay,
const int grad_averaging,
const int mode,
at::Tensor global_grad_norm,
at::Tensor max_grad_norm,
at::optional<bool> use_nvlamb_python,
at::Tensor found_inf,
at::Tensor inv_scale);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("multi_tensor_scale", &multi_tensor_scale_cuda, m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
"Fused overflow check + scale for a list of contiguous tensors"); "Fused overflow check + scale for a list of contiguous tensors");
...@@ -128,6 +153,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -128,6 +153,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"out = a*x + b*y for a list of contiguous tensors"); "out = a*x + b*y for a list of contiguous tensors");
m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda, m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
"Computes L2 norm for a list of contiguous tensors"); "Computes L2 norm for a list of contiguous tensors");
m.def("multi_tensor_l2norm_mp", &multi_tensor_l2norm_mp_cuda,
"Computes L2 norm for a list of contiguous tensors");
m.def("multi_tensor_l2norm_scale", &multi_tensor_l2norm_scale_cuda, m.def("multi_tensor_l2norm_scale", &multi_tensor_l2norm_scale_cuda,
"Computes L2 norm for a list of contiguous tensors and does scaling"); "Computes L2 norm for a list of contiguous tensors and does scaling");
m.def("multi_tensor_lamb_stage1_cuda", &multi_tensor_lamb_stage1_cuda, m.def("multi_tensor_lamb_stage1_cuda", &multi_tensor_lamb_stage1_cuda,
...@@ -142,4 +169,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -142,4 +169,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Compute and apply gradient update to parameters for Adam optimizer"); "Compute and apply gradient update to parameters for Adam optimizer");
m.def("multi_tensor_lamb", &multi_tensor_lamb_cuda, m.def("multi_tensor_lamb", &multi_tensor_lamb_cuda,
"Computes and apply update for LAMB optimizer"); "Computes and apply update for LAMB optimizer");
m.def("multi_tensor_lamb_mp", &multi_tensor_lamb_mp_cuda,
"Computes and apply update for LAMB optimizer");
} }
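# --- Illustrative sketch (editor's addition, not part of the diff) ---------
# The bindings above are usually driven from Python via apex's
# multi_tensor_applier helper, e.g. the L2-norm kernel with per-tensor norms
# disabled. Assumes apex was built with the amp_C extension and a CUDA device
# is available.
import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

overflow_buf = torch.zeros(1, dtype=torch.int, device="cuda")
params = [torch.randn(16, device="cuda") for _ in range(3)]
norm, _ = multi_tensor_applier(amp_C.multi_tensor_l2norm, overflow_buf, [params], False)
# ----------------------------------------------------------------------------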
/* coding=utf-8 /* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
......
/* coding=utf-8 /* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
......
/* coding=utf-8 /* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
......
/* coding=utf-8 /* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
......
/* coding=utf-8 /* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
......
/* coding=utf-8 /* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h> #include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <THC/THC.h>
#include "compat.h" #include "compat.h"
#include <assert.h> #include <assert.h>
......