Unverified Commit 96850dfa authored by Jithun Nair's avatar Jithun Nair Committed by GitHub
Merge pull request #80 from ROCmSoftwarePlatform/IFU-master-2022-07-29

IFU-master-2022-07-29
parents 87fc4125 cc5f83b5
import os
import sys
import unittest
from packaging.version import Version, parse
import torch
from torch import distributed as dist
from torch.utils import collect_env
from torch.testing._internal import common_utils
from torch.testing._internal import common_distributed
HAS_TORCH_UCC = None
try:
import torch_ucc
HAS_TORCH_UCC = True
except ImportError:
HAS_TORCH_UCC = False
# NOTE(mkozuki): Version guard for ucc. ref: https://github.com/openucx/ucc/issues/496
_TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION = Version("470.42.01")
_driver_version = None
if torch.cuda.is_available():
driver_version_str = collect_env.get_nvidia_driver_version(collect_env.run)
if driver_version_str is not None:
_driver_version = parse(driver_version_str)
else:
_driver_version = None
HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER = _driver_version is not None and _driver_version >= _TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION
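# Hedged illustration (not part of the original file): packaging.version
# compares driver strings component-wise, so e.g. "470.57.02" satisfies the
# guard above while "460.91.03" does not. The helper name is illustrative only.
def _example_driver_guard(driver_string: str = "470.57.02") -> bool:
    return parse(driver_string) >= _TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION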
class DistributedTestBase(common_distributed.MultiProcessTestCase):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
def setUp(self) -> None:
super().setUp()
self._setup_pre_spawn()
self._spawn_processes()
def tearDown(self) -> None:
super().tearDown()
@property
def world_size(self) -> int:
return min(torch.cuda.device_count(), 4)
@property
def init_method(self):
return f"{common_utils.FILE_SCHEMA}{self.file_name}"
@classmethod
def _run(cls, rank, test_name, file_name, pipe):
self = cls(test_name)
self.assertTrue(torch.cuda.is_available())
self.assertTrue(hasattr(self, "DISTRIBUTED_BACKEND"))
self.rank = rank
self.file_name = file_name
print(f"[dist init] rank = {self.rank}, world_size = {self.world_size}")
try:
dist.init_process_group(
init_method=self.init_method,
backend=self.DISTRIBUTED_BACKEND,
world_size=int(self.world_size),
rank=self.rank,
)
except RuntimeError as e:
if "recompile" in e.args[0]:
print(f"Backend of {self.DISTRIBUTED_BACKEND} not available")
sys.exit(0)
raise
torch.cuda.set_device(self.rank % torch.cuda.device_count())
dist.barrier()
self.run_test(test_name, pipe)
dist.barrier()
dist.destroy_process_group()
sys.exit(0)
def _setup_pre_spawn(self):
pass
class NcclDistributedTestBase(DistributedTestBase):
DISTRIBUTED_BACKEND = "nccl"
@unittest.skipUnless(
HAS_TORCH_UCC,
"Requires [`torch_ucc`](https://github.com/facebookresearch/torch_ucc)",
)
@unittest.skipUnless(
HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER,
f"`torch_ucc` requires NVIDIA driver >= {_TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION} but {_driver_version} found. "
"See https://github.com/openucx/ucc/issues/496",
)
class UccDistributedTestBase(DistributedTestBase):
DISTRIBUTED_BACKEND = "ucc"
def _setup_pre_spawn(self) -> None:
self.master_addr = "localhost"
os.environ["MASTER_ADDR"] = "localhost"
self._has_master_port = "MASTER_PORT" in os.environ
if self._has_master_port:
self.master_port = os.environ["MASTER_PORT"]
else:
try:
from caffe2.torch.fb.common.utils import get_free_port
self.master_port = str(get_free_port())
except ImportError:
self.master_port = "12375"
os.environ["MASTER_PORT"] = self.master_port
self._has_ucx_tls = "UCX_TLS" in os.environ
if not self._has_ucx_tls:
os.environ["UCX_TLS"] = "tcp,cuda"
print('os.environ["UCX_TLS"] = {}'.format(os.environ["UCX_TLS"]))
def tearDown(self) -> None:
super().tearDown()
if not self._has_master_port:
del os.environ["MASTER_PORT"]
if not self._has_ucx_tls:
del os.environ["UCX_TLS"]
@property
def init_method(self):
return "tcp://localhost:" + os.environ["MASTER_PORT"]
import contextlib
import torch
from apex.normalization import FusedLayerNorm as LayerNorm
from apex.transformer import tensor_parallel
from apex.transformer.enums import AttnMaskType
from apex.transformer.enums import ModelType
from apex.transformer.layers import FusedLayerNorm as LayerNorm
from apex.transformer.testing.global_vars import get_args
from .standalone_gpt import get_language_model, get_linear_layer, init_method_normal, parallel_lm_logits, scaled_init_method_normal
from .standalone_gpt import MegatronModule
from apex.transformer.testing.standalone_transformer_lm import (
MegatronModule,
get_language_model,
get_linear_layer,
init_method_normal,
scaled_init_method_normal,
parallel_lm_logits,
)
def bert_extended_attention_mask(attention_mask):
# We create a 3D attention mask from a 2D tensor mask.
@@ -23,6 +33,7 @@ def bert_extended_attention_mask(attention_mask):
return extended_attention_mask
def bert_position_ids(token_ids):
# Create position ids
seq_length = token_ids.size(1)
@@ -32,6 +43,7 @@ def bert_position_ids(token_ids):
return position_ids
class BertLMHead(MegatronModule):
"""Masked LM head for Bert
@@ -56,13 +68,18 @@ class BertLMHead(MegatronModule):
self.parallel_output = parallel_output
self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
setattr(self.dense.weight, 'sequence_parallel', args.sequence_parallel)
setattr(self.dense.bias, 'sequence_parallel', args.sequence_parallel)
self.layernorm = LayerNorm(
hidden_size, eps=layernorm_epsilon, sequence_parallel_enabled=args.sequence_parallel)
self.gelu = torch.nn.functional.gelu
if args.openai_gelu:
self.gelu = openai_gelu
elif args.onnx_safe:
self.gelu = erf_gelu
def forward(self, hidden_states, word_embeddings_weight):
hidden_states = self.dense(hidden_states)
hidden_states = self.gelu(hidden_states)
@@ -73,6 +90,7 @@ class BertLMHead(MegatronModule):
bias=self.bias)
return output
def post_language_model_processing(lm_output, pooled_output,
lm_head, binary_head,
lm_labels,
@@ -87,8 +105,12 @@ def post_language_model_processing(lm_output, pooled_output,
binary_logits = binary_head(pooled_output)
if lm_labels is None:
return lm_logits, binary_logits
# [s b h] => [b s h]
return lm_logits.transpose(0, 1).contiguous(), binary_logits
else:
# [b s] => [s b]
lm_labels = lm_labels.transpose(0, 1).contiguous()
# lm_logits: [s b h] lm_labels: [s b]
if fp16_lm_cross_entropy:
assert lm_logits.dtype == torch.half
lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels)
@@ -106,7 +128,8 @@ class BertModel(MegatronModule):
add_binary_head=True,
parallel_output=True,
pre_process=True,
post_process=True):
post_process=True,
cpu_offload=False):
super(BertModel, self).__init__()
args = get_args()
@@ -141,39 +164,43 @@ class BertModel(MegatronModule):
init_method)
self._binary_head_key = 'binary_head'
self.forward_context = contextlib.nullcontext
if cpu_offload:
self.forward_context = torch.autograd.graph.save_on_cpu
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
self.language_model.set_input_tensor(input_tensor)
def forward(self, bert_model_input, attention_mask,
tokentype_ids=None, lm_labels=None):
extended_attention_mask = bert_extended_attention_mask(attention_mask)
input_ids = bert_model_input
position_ids = bert_position_ids(input_ids)
lm_output = self.language_model(
input_ids,
position_ids,
extended_attention_mask,
tokentype_ids=tokentype_ids
)
if self.post_process and self.add_binary_head:
lm_output, pooled_output = lm_output
else:
pooled_output = None
if self.post_process:
return post_language_model_processing(lm_output, pooled_output,
self.lm_head, self.binary_head,
lm_labels,
self.word_embeddings_weight(),
self.fp16_lm_cross_entropy)
else:
return lm_output
with self.forward_context():
extended_attention_mask = bert_extended_attention_mask(attention_mask)
input_ids = bert_model_input
position_ids = bert_position_ids(input_ids)
lm_output = self.language_model(
input_ids,
position_ids,
extended_attention_mask,
tokentype_ids=tokentype_ids
)
if self.post_process and self.add_binary_head:
lm_output, pooled_output = lm_output
else:
pooled_output = None
if self.post_process:
return post_language_model_processing(lm_output, pooled_output,
self.lm_head, self.binary_head,
lm_labels,
self.word_embeddings_weight(),
self.fp16_lm_cross_entropy)
else:
return lm_output
# NOTE(mkozuki): This method is not maintained as apex only tests forward_backward with best effort.
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load when model is combined with other heads,
@@ -196,6 +223,7 @@ class BertModel(MegatronModule):
= self.word_embeddings.state_dict(destination, prefix, keep_vars)
return state_dict_
# NOTE(mkozuki): This method is not maintained as apex only tests forward_backward with best effort.
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
@@ -212,6 +240,16 @@ class BertModel(MegatronModule):
self.word_embeddings.load_state_dict(
state_dict[self._word_embeddings_for_head_key], strict=strict)
def bert_model_provider(pre_process=True, post_process=True):
model = BertModel(num_tokentypes=0, add_binary_head=False, pre_process=pre_process, post_process=post_process)
def bert_model_provider(pre_process=True, post_process=True, cpu_offload=False):
args = get_args()
num_tokentypes = 2 if args.bert_binary_head else 0
model = BertModel(
num_tokentypes=num_tokentypes,
add_binary_head=args.bert_binary_head,
parallel_output=True,
pre_process=pre_process,
post_process=post_process,
cpu_offload=cpu_offload,
)
return model
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021-22, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,1422 +12,54 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""GPT-2 model."""
import enum
import math
import contextlib
import torch
import torch.nn.functional as F
import apex.transformer.utils
from apex.normalization import FusedLayerNorm as LayerNorm
from apex.transformer.functional import FusedScaleMaskSoftmax
from apex.transformer.enums import AttnMaskType
from apex.transformer.enums import ModelType
from apex.transformer import tensor_parallel
from apex.transformer import parallel_state
from apex.transformer.testing.global_vars import get_args
from apex.transformer.enums import LayerType
from apex.transformer.enums import AttnType
from apex.transformer.enums import AttnMaskType
_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
_HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
_BF16_TYPES = (torch.BFloat16Tensor, torch.cuda.BFloat16Tensor)
class ModelType(enum.Enum):
encoder_or_decoder = 1
encoder_and_decoder = 2
###### BIAS GELU FUSION/ NO AUTOGRAD ################
# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2) -> 0.70710678
# sqrt(2/pi) -> 0.79788456
# this function is a tanh approximation of gelu
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
@torch.jit.script
def bias_gelu(bias, y):
x = bias + y
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def bias_gelu_back(g, bias, y):
x = bias + y
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
return ff * g
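# Hedged sanity sketch (not part of the original file): the tanh approximation
# implemented by bias_gelu() above tracks the exact erf-based GELU (F.gelu)
# closely for typical activations. The helper name is illustrative only.
def _gelu_tanh_approximation_check() -> bool:
    x = torch.randn(1024)
    bias = torch.zeros(1024)
    approx = bias_gelu(bias, x)
    exact = F.gelu(x)
    # atol=1e-2 comfortably covers the small approximation error of the tanh form.
    return bool(torch.allclose(approx, exact, atol=1e-2))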
class MegatronModule(torch.nn.Module):
"""Megatron specific extensions of torch Module with support
for pipelining."""
def __init__(self, share_word_embeddings=True):
super(MegatronModule, self).__init__()
self.share_word_embeddings = share_word_embeddings
def state_dict_for_save_checkpoint(self, destination=None, prefix="",
keep_vars=False):
"""Use this function to override the state dict for
saving checkpoints."""
return self.state_dict(destination, prefix, keep_vars)
def word_embeddings_weight(self):
if not parallel_state.is_pipeline_last_stage(ignore_virtual=True) or \
parallel_state.get_pipeline_model_parallel_world_size() == 1:
return self.language_model.embedding.word_embeddings.weight
else:
if not self.share_word_embeddings:
raise Exception("word_embeddings_weight() called for last "
"stage, but share_word_embeddings is false")
return self.word_embeddings.weight
def initialize_word_embeddings(self, init_method_normal):
args = get_args()
if not self.share_word_embeddings:
raise Exception("initialize_word_embeddings() was called but "
"share_word_embeddings is false")
# This function just initializes the word embeddings in the final stage
# when we are using pipeline parallelism. Nothing to do if we aren't
# using pipeline parallelism.
if args.pipeline_model_parallel_size == 1:
return
# Parameters are shared between the word embeddings layers, and the
# heads at the end of the model. In a pipelined setup with more than
# one stage, the initial embedding layer and the head are on different
# workers, so we do the following:
# 1. Create a second copy of word_embeddings on the last stage, with
# initial parameters of 0.0.
# 2. Do an all-reduce between the first and last stage to ensure that
# the two copies of word_embeddings start off with the same
# parameter values.
# 3. In the training loop, do an all-reduce between the grads of
# the two word_embeddings layers to ensure that every applied weight
# update is the same on both stages.
if parallel_state.is_pipeline_last_stage():
assert not parallel_state.is_pipeline_first_stage()
self._word_embeddings_for_head_key = "word_embeddings_for_head"
# set word_embeddings weights to 0 here, then copy first
# stage's weights using all_reduce below.
self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
args.padded_vocab_size, args.hidden_size,
init_method=init_method_normal(args.init_method_std),
use_cpu_initialization=args.use_cpu_initialization)
self.word_embeddings.weight.data.fill_(0)
self.word_embeddings.weight.shared = True
# Zero out initial weights for decoder embedding.
# NOTE: We don't currently support T5 with the interleaved schedule.
if not parallel_state.is_pipeline_first_stage(ignore_virtual=True) and \
not parallel_state.is_pipeline_last_stage(ignore_virtual=True) and \
parallel_state.is_rank_in_embedding_group():
self.language_model.embedding.zero_parameters()
# Ensure that first and last stages have the same initial parameter
# values.
if torch.distributed.is_initialized():
if parallel_state.is_rank_in_embedding_group():
torch.distributed.all_reduce(self.word_embeddings_weight().data,
group=parallel_state.get_embedding_group())
# All-reduce other embeddings as well as necessary. The last stage
# does not have these other embeddings, so just create placeholder
# tensors of the right shape with all zeros.
# NOTE: We don't currently support T5 with the interleaved schedule.
if args.pipeline_model_parallel_split_rank is not None:
# TODO: Support tokentype embedding.
dimensions = (args.max_position_embeddings, args.hidden_size)
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
position_embeddings = torch.nn.Embedding(*dimensions).cuda()
position_embeddings.weight.data.fill_(0)
else:
self.language_model.embedding.cuda()
position_embeddings = self.language_model.embedding.position_embeddings
torch.distributed.all_reduce(position_embeddings.weight.data,
group=parallel_state.get_embedding_group())
else:
print("WARNING! Distributed processes aren't initialized, so "
"word embeddings in the last layer are not initialized. "
"If you are just manipulating a model this is fine, but "
"this needs to be handled manually. If you are training "
"something is definitely wrong.")
class GeLUFunction(torch.autograd.Function):
@staticmethod
# bias is an optional argument
def forward(ctx, input, bias):
ctx.save_for_backward(input, bias)
return bias_gelu(bias, input)
@staticmethod
def backward(ctx, grad_output):
input, bias = ctx.saved_tensors
tmp = bias_gelu_back(grad_output, bias, input)
return tmp, tmp
bias_gelu_impl = GeLUFunction.apply
def get_linear_layer(rows, columns, init_method):
"""Simple linear layer with weight initialization."""
layer = torch.nn.Linear(rows, columns)
init_method(layer.weight)
with torch.no_grad():
layer.bias.zero_()
return layer
def attention_mask_func(attention_scores, attention_mask):
attention_scores.masked_fill_(attention_mask, -10000.0)
return attention_scores
@torch.jit.script
def gelu_impl(x):
"""OpenAI's gelu implementation."""
return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))
def openai_gelu(x):
return gelu_impl(x)
# This is actually the Python equivalent of torch.nn.functional.gelu(), also with type hints for the ONNX exporter
@torch.jit.script
def erf_gelu(x):
return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype) + torch.ones_like(x).to(dtype=x.dtype))
def init_method_normal(sigma):
"""Init method based on N(0, sigma)."""
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
return init_
def scaled_init_method_normal(sigma, num_layers):
"""Init method based on N(0, sigma/sqrt(2*num_layers)."""
std = sigma / math.sqrt(2.0 * num_layers)
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
return init_
class ParallelMLP(MegatronModule):
"""MLP.
MLP will take the input with hidden size h, project it to 4*h,
perform a nonlinear transformation, and project the state
back into hidden size h.
"""
def __init__(self, init_method, output_layer_init_method):
super().__init__()
args = get_args()
# Project to 4h.
self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
args.hidden_size, args.ffn_hidden_size, gather_output=False, init_method=init_method, skip_bias_add=True,
use_cpu_initialization=args.use_cpu_initialization)
self.bias_gelu_fusion = args.bias_gelu_fusion
self.activation_func = F.gelu
if args.openai_gelu:
self.activation_func = openai_gelu
elif args.onnx_safe:
self.activation_func = erf_gelu
# Project back to h.
self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
args.ffn_hidden_size,
args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True,
use_cpu_initialization=args.use_cpu_initialization
)
def forward(self, hidden_states):
# [s, b, 4hp]
intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)
if self.bias_gelu_fusion:
intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
else:
intermediate_parallel = self.activation_func(intermediate_parallel + bias_parallel)
# [s, b, h]
output, output_bias = self.dense_4h_to_h(intermediate_parallel)
return output, output_bias
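# Hedged illustration (plain torch, not the tensor-parallel layers above): the
# MLP's shape flow is [s, b, h] -> [s, b, 4h] -> gelu -> [s, b, h]. The helper
# name and sizes are illustrative only.
def _mlp_shape_demo(s: int = 4, b: int = 2, h: int = 8) -> torch.Size:
    dense_h_to_4h = torch.nn.Linear(h, 4 * h)
    dense_4h_to_h = torch.nn.Linear(4 * h, h)
    x = torch.randn(s, b, h)
    return dense_4h_to_h(F.gelu(dense_h_to_4h(x))).shape  # torch.Size([4, 2, 8])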
class ParallelAttention(MegatronModule):
"""Parallel self-attention layer abstract class.
Self-attention layer takes input with size [b, s, h]
and returns output of the same size.
"""
def __init__(
self,
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=AttnMaskType.padding,
):
super().__init__()
args = get_args()
self.fp16 = args.fp16
self.bf16 = args.bf16
self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
if self.apply_query_key_layer_scaling:
self.attention_softmax_in_fp32 = True
self.layer_number = max(1, layer_number)
self.attention_type = attention_type
self.attn_mask_type = attn_mask_type
self.params_dtype = args.params_dtype
projection_size = args.kv_channels * args.num_attention_heads
# Per attention head and per partition values.
world_size = parallel_state.get_tensor_model_parallel_world_size()
self.hidden_size_per_partition = apex.transformer.utils.divide(projection_size, world_size)
self.hidden_size_per_attention_head = apex.transformer.utils.divide(projection_size, args.num_attention_heads)
self.num_attention_heads_per_partition = apex.transformer.utils.divide(args.num_attention_heads, world_size)
# Strided linear layer.
if attention_type == AttnType.self_attn:
self.query_key_value = tensor_parallel.ColumnParallelLinear(
args.hidden_size, 3 * projection_size, gather_output=False, init_method=init_method, use_cpu_initialization=args.use_cpu_initialization)
else:
assert attention_type == AttnType.cross_attn
self.query = tensor_parallel.ColumnParallelLinear(
args.hidden_size, projection_size, gather_output=False, init_method=init_method, use_cpu_initialization=args.use_cpu_initialization)
self.key_value = tensor_parallel.ColumnParallelLinear(
args.hidden_size, 2 * projection_size, gather_output=False, init_method=init_method, use_cpu_initialization=args.use_cpu_initialization)
coeff = None
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
if self.apply_query_key_layer_scaling:
coeff = self.layer_number
self.norm_factor *= coeff
self.scale_mask_softmax = FusedScaleMaskSoftmax(
self.fp16,
self.bf16,
self.attn_mask_type,
args.masked_softmax_fusion,
attention_mask_func,
self.attention_softmax_in_fp32,
coeff,
)
# Dropout. Note that for a single iteration, this layer will generate
# different outputs on different number of parallel partitions but
# on average it should not be partition dependent.
self.attention_dropout = torch.nn.Dropout(args.attention_dropout)
# Output.
self.dense = tensor_parallel.RowParallelLinear(
projection_size,
args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True,
use_cpu_initialization=args.use_cpu_initialization
)
# Inference key-value memory
self.inference_key_memory = None
self.inference_value_memory = None
self.inference_current_sequence_len = 0
def _allocate_memory(self, inference_max_sequence_len, batch_size):
return torch.empty(
inference_max_sequence_len,
batch_size,
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
dtype=self.params_dtype,
device=torch.cuda.current_device(),
)
def forward(
self,
hidden_states,
attention_mask,
encoder_output=None,
set_inference_key_value_memory=False,
inference_max_sequence_len=None,
):
# hidden_states: [sq, b, h]
# =================================================
# Pre-allocate memory for key-values for inference.
# =================================================
if set_inference_key_value_memory:
assert inference_max_sequence_len and inference_max_sequence_len > 0
self.inference_key_memory = self._allocate_memory(inference_max_sequence_len, hidden_states.size(1))
self.inference_value_memory = self._allocate_memory(inference_max_sequence_len, hidden_states.size(1))
self.inference_current_sequence_len = 0
# Some consistency check.
if inference_max_sequence_len:
assert self.inference_current_sequence_len < self.inference_key_memory.size(0)
assert inference_max_sequence_len == self.inference_key_memory.size(0)
# This is added for safety. In case inference_max_sequence_len
# is not provided, make sure there is no potential memory left
# from previous inference.
if not inference_max_sequence_len:
self.inference_key_memory = None
self.inference_value_memory = None
# =====================
# Query, Key, and Value
# =====================
if self.attention_type == AttnType.self_attn:
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
mixed_x_layer, _ = self.query_key_value(hidden_states)
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head,
)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(query_layer, key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3)
else:
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_kv_layer, _ = self.key_value(encoder_output)
# [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
new_tensor_shape = mixed_kv_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
2 * self.hidden_size_per_attention_head,
)
mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
# [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
(key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)
# Attention head [sq, b, h] --> [sq, b, hp]
query_layer, _ = self.query(hidden_states)
# [sq, b, hp] --> [sq, b, np, hn]
new_tensor_shape = query_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
)
query_layer = query_layer.view(*new_tensor_shape)
# ===================================================
# Adjust key, value, and attention mask for inference
# ===================================================
if inference_max_sequence_len:
# Adjust the range variables.
start = self.inference_current_sequence_len
self.inference_current_sequence_len += key_layer.size(0)
end = self.inference_current_sequence_len
# Copy key and values.
self.inference_key_memory[start:end, ...] = key_layer
self.inference_value_memory[start:end, ...] = value_layer
key_layer = self.inference_key_memory[:end, ...]
value_layer = self.inference_value_memory[:end, ...]
# Adjust attention mask
attention_mask = attention_mask[..., start:end, :end]
# ===================================
# Raw attention scores. [b, np, s, s]
# ===================================
# [b, np, sq, sk]
output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
# preallocating result tensor: [b * np, sq, sk]
matmul_result = torch.empty(
output_size[0] * output_size[1],
output_size[2],
output_size[3],
dtype=query_layer.dtype,
device=torch.cuda.current_device(),
)
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(
matmul_result,
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0,
alpha=(1.0 / self.norm_factor),
)
# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
# ===========================
# Attention probs and dropout
# ===========================
# attention scores and attention mask [b, np, sq, sk]
attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
with tensor_parallel.get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs)
# =========================
# Context layer. [sq, b, hp]
# =========================
# value_layer -> context layer.
# [sk, b, np, hn] --> [b, np, sq, hn]
# context layer shape: [b, np, sq, hn]
output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
# change view [sk, b * np, hn]
value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
# change view [b * np, sq, sk]
attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
# change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size)
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
# =================
# Output. [sq, b, h]
# =================
output, bias = self.dense(context_layer)
return output, bias
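# Hedged illustration (plain torch): the raw-score computation above maps
# query [sq, b*np, hn] and key [sk, b*np, hn] to scores [b*np, sq, sk] with a
# single scaled baddbmm. Names and sizes are illustrative only.
def _attention_score_shape_demo(sq: int = 5, sk: int = 5, bnp: int = 6, hn: int = 4) -> torch.Size:
    query = torch.randn(sq, bnp, hn)
    key = torch.randn(sk, bnp, hn)
    scores = torch.empty(bnp, sq, sk)
    scores = torch.baddbmm(
        scores,
        query.transpose(0, 1),                # [b*np, sq, hn]
        key.transpose(0, 1).transpose(1, 2),  # [b*np, hn, sk]
        beta=0.0,
        alpha=1.0 / math.sqrt(hn),
    )
    return scores.shape  # torch.Size([6, 5, 5])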
@torch.jit.script
def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
out = residual + out
return out
def get_bias_dropout_add(training):
def _bias_dropout_add(x, bias, residual, prob):
return bias_dropout_add(x, bias, residual, prob, training)
return _bias_dropout_add
@torch.jit.script
def bias_dropout_add_fused_train(
x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
) -> torch.Tensor:
return bias_dropout_add(x, bias, residual, prob, True)
@torch.jit.script
def bias_dropout_add_fused_inference(
x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
) -> torch.Tensor:
return bias_dropout_add(x, bias, residual, prob, False)
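# Hedged usage sketch (not part of the original file): the fused variants above
# are selected per phase in ParallelTransformerLayer.forward below; in eval
# mode dropout is a no-op, so the call reduces to (x + bias) + residual.
def _bias_dropout_add_demo() -> torch.Size:
    x, bias, residual = torch.randn(4, 3), torch.randn(3), torch.randn(4, 3)
    out = bias_dropout_add_fused_inference(x, bias, residual, 0.1)
    return out.shape  # torch.Size([4, 3])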
class ParallelTransformerLayer(MegatronModule):
"""A single transformer layer.
Transformer layer takes input with size [b, s, h] and returns an
output of the same size.
"""
def __init__(
self,
init_method,
output_layer_init_method,
layer_number,
layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding,
):
args = get_args()
super().__init__()
self.layer_number = layer_number
self.layer_type = layer_type
self.apply_residual_connection_post_layernorm = args.apply_residual_connection_post_layernorm
self.bf16 = args.bf16
self.fp32_residual_connection = args.fp32_residual_connection
# Layernorm on the input data.
self.input_layernorm = LayerNorm(args.hidden_size, eps=args.layernorm_epsilon)
# Self attention.
self.self_attention = ParallelAttention(
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=self_attn_mask_type,
)
self.hidden_dropout = args.hidden_dropout
self.bias_dropout_fusion = args.bias_dropout_fusion
# Layernorm on the attention output
self.post_attention_layernorm = LayerNorm(args.hidden_size, eps=args.layernorm_epsilon)
if self.layer_type == LayerType.decoder:
self.inter_attention = ParallelAttention(
init_method, output_layer_init_method, layer_number, attention_type=AttnType.cross_attn
)
# Layernorm on the attention output.
self.post_inter_attention_layernorm = LayerNorm(args.hidden_size, eps=args.layernorm_epsilon)
# MLP
self.mlp = ParallelMLP(init_method, output_layer_init_method)
def forward(
self,
hidden_states,
attention_mask,
encoder_output=None,
enc_dec_attn_mask=None,
set_inference_key_value_memory=False,
inference_max_sequence_len=None,
):
# hidden_states: [b, s, h]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output, attention_bias = self.self_attention(
layernorm_output,
attention_mask,
set_inference_key_value_memory=set_inference_key_value_memory,
inference_max_sequence_len=inference_max_sequence_len,
)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
# JIT scripting for a nn.Module (with dropout) is not
# triggering the fusion kernel. For now, we use two
# different nn.functional routines to account for varying
# dropout semantics during training and inference phases.
if self.bias_dropout_fusion:
if self.training:
bias_dropout_add_func = bias_dropout_add_fused_train
else:
bias_dropout_add_func = bias_dropout_add_fused_inference
else:
bias_dropout_add_func = get_bias_dropout_add(self.training)
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output, attention_bias.expand_as(residual), residual, self.hidden_dropout
)
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
if self.layer_type == LayerType.decoder:
attention_output, attention_bias = self.inter_attention(
layernorm_output, enc_dec_attn_mask, encoder_output=encoder_output
)
# residual connection
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output, attention_bias.expand_as(residual), residual, self.hidden_dropout
)
# Layer norm post the decoder attention
layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
# MLP.
mlp_output, mlp_bias = self.mlp(layernorm_output)
# Second residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
output = bias_dropout_add_func(mlp_output, mlp_bias.expand_as(residual), residual, self.hidden_dropout)
return output
class ParallelTransformer(MegatronModule):
"""Transformer class."""
def __init__(
self,
init_method,
output_layer_init_method,
layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding,
pre_process=True,
post_process=True,
):
super().__init__()
args = get_args()
self.bf16 = args.bf16
self.fp32_residual_connection = args.fp32_residual_connection
self.pre_process = pre_process
self.post_process = post_process
self.input_tensor = None
# Store activation checkpointing flag.
self.activations_checkpoint_method = args.activations_checkpoint_method
self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers
self.distribute_checkpointed_activations = args.distribute_checkpointed_activations
num_layers = args.num_layers
# Number of layers.
assert (
num_layers % parallel_state.get_pipeline_model_parallel_world_size() == 0
), "num_layers must be divisible by pipeline_model_parallel_size"
self.num_layers = num_layers // parallel_state.get_pipeline_model_parallel_world_size()
# Transformer layers.
def build_layer(layer_number):
return ParallelTransformerLayer(
init_method,
output_layer_init_method,
layer_number,
layer_type=layer_type,
self_attn_mask_type=self_attn_mask_type,
)
if args.virtual_pipeline_model_parallel_size is not None:
assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, (
"num_layers_per_stage must be divisible by " "virtual_pipeline_model_parallel_size"
)
# Number of layers in each model chunk is the number of layers in the stage,
# divided by the number of model chunks in a stage.
self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
# With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
# layers to stages like (each list is a model chunk):
# Stage 0: [0] [2] [4] [6]
# Stage 1: [1] [3] [5] [7]
# With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
# layers to stages like (each list is a model chunk):
# Stage 0: [0, 1] [4, 5]
# Stage 1: [2, 3] [6, 7]
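# Worked example (sketch): with num_layers=8, 2 pipeline stages, and
# virtual_pipeline_model_parallel_size=2, self.num_layers becomes 2 per chunk
# and the offset computed just below is v * (8 // 2) + p * 2 for virtual rank v
# and pipeline rank p, i.e. offsets 0, 2, 4, 6 -> build_layer numbers (1, 2),
# (3, 4), (5, 6), (7, 8), matching the assignment listed above.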
offset = parallel_state.get_virtual_pipeline_model_parallel_rank() * (
args.num_layers // args.virtual_pipeline_model_parallel_size
) + (parallel_state.get_pipeline_model_parallel_rank() * self.num_layers)
else:
# Each stage gets a contiguous set of layers.
offset = parallel_state.get_pipeline_model_parallel_rank() * self.num_layers
self.layers = torch.nn.ModuleList([build_layer(i + 1 + offset) for i in range(self.num_layers)])
if self.post_process:
# Final layer norm before output.
self.final_layernorm = LayerNorm(args.hidden_size, eps=args.layernorm_epsilon)
def _get_layer(self, layer_number):
return self.layers[layer_number]
from apex.transformer.testing.standalone_transformer_lm import MegatronModule
from apex.transformer.testing.standalone_transformer_lm import parallel_lm_logits
from apex.transformer.testing.standalone_transformer_lm import post_language_model_processing
from apex.transformer.testing.standalone_transformer_lm import get_language_model
from apex.transformer.testing.standalone_transformer_lm import init_method_normal
from apex.transformer.testing.standalone_transformer_lm import (
scaled_init_method_normal,
)
def _checkpointed_forward(self, hidden_states, attention_mask, encoder_output, enc_dec_attn_mask):
"""Forward method with activation checkpointing."""
def custom(start, end):
def custom_forward(*inputs):
x_ = inputs[0]
attention_mask = inputs[1]
encoder_output = inputs[2]
enc_dec_attn_mask = inputs[3]
for index in range(start, end):
layer = self._get_layer(index)
x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
return x_
return custom_forward
def distribute_checkpointed_activations_helper(layer_number):
"""Distribute checkpointed activations across the tensor model
Parallel ranks if the `distribute-checkpointed-activations
is on and either of the following conditions is met:
- it is not the first layer in the in the pipeline stage.
The first layer is used in the pipeline parallelism
and changing its shape throws error in the backward pass.
- we are at the first pipline stage so the input tensor is
not used in pipeline parallelism. Note that no pipeline
parallelism is a special case of this.
"""
not_first_layer_in_pipeline_stage = layer_number > 0
is_first_pipeline_stage = parallel_state.get_pipeline_model_parallel_rank() == 0
return self.distribute_checkpointed_activations and (
not_first_layer_in_pipeline_stage or is_first_pipeline_stage
)
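# Worked example (sketch): with num_layers=4 and
# activations_checkpoint_num_layers=2, "uniform" below checkpoints the inputs
# of chunks (0, 1) and (2, 3), while "block" checkpoints layers 0 and 1
# individually and runs layers 2 and 3 without checkpointing.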
if self.activations_checkpoint_method == "uniform":
# Uniformly divide the total number of Transformer layers and checkpoint
# the input activation of each divided chunk.
# A method to further reduce memory usage by reducing the number of checkpoints.
l = 0
while l < self.num_layers:
hidden_states = tensor_parallel.checkpoint(
custom(l, l + self.activations_checkpoint_num_layers),
distribute_checkpointed_activations_helper(l),
hidden_states,
attention_mask,
encoder_output,
enc_dec_attn_mask,
)
l += self.activations_checkpoint_num_layers
elif self.activations_checkpoint_method == "block":
# Checkpoint the input activation of only a set number of individual
# Transformer layers and skip the rest.
# A method to make full use of device memory while removing redundant re-computation.
for l in range(self.num_layers):
if l < self.activations_checkpoint_num_layers:
hidden_states = tensor_parallel.checkpoint(
custom(l, l + 1),
distribute_checkpointed_activations_helper(l),
hidden_states,
attention_mask,
encoder_output,
enc_dec_attn_mask,
)
else:
hidden_states = custom(l, l + 1)(hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
else:
raise ValueError("Invalid activation checkpoint method.")
return hidden_states
def set_input_tensor(self, input_tensor):
"""Set input tensor to be used instead of forward()'s input.
When doing pipeline parallelism the input from the previous
stage comes from communication, not from the input, so the
model's forward_step_func won't have it. This function is thus
used by internal code to bypass the input provided by the
forward_step_func"""
self.input_tensor = input_tensor
def forward(
self,
hidden_states,
attention_mask,
encoder_output=None,
enc_dec_attn_mask=None,
set_inference_key_value_memory=False,
inference_max_sequence_len=None,
):
# Checks.
if inference_max_sequence_len:
assert self.activations_checkpoint_method is None, "inference does not work with activation checkpointing"
if self.pre_process:
# Data format change to avoid explicit transposes: [b s h] --> [s b h].
# If the input flag for fp32 residual connection is set, convert to float.
if self.fp32_residual_connection:
hidden_states = hidden_states.transpose(0, 1).contiguous().float()
# Otherwise, leave it as is.
else:
hidden_states = hidden_states.transpose(0, 1).contiguous()
else:
# See set_input_tensor()
hidden_states = self.input_tensor
if encoder_output is not None:
encoder_output = encoder_output.transpose(0, 1).contiguous()
if self.activations_checkpoint_method is not None:
hidden_states = self._checkpointed_forward(hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
else:
for index in range(self.num_layers):
layer = self._get_layer(index)
hidden_states = layer(
hidden_states,
attention_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
set_inference_key_value_memory=set_inference_key_value_memory,
inference_max_sequence_len=inference_max_sequence_len,
)
# Final layer norm.
if self.post_process:
# Reverting data format change [s b h] --> [b s h].
hidden_states = hidden_states.transpose(0, 1).contiguous()
output = self.final_layernorm(hidden_states)
else:
output = hidden_states
return output
def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None):
"""LM logits using word embedding weights."""
# Parallel logits.
input_parallel = tensor_parallel.copy_to_tensor_model_parallel_region(input_)
# Matrix multiply.
if bias is None:
logits_parallel = F.linear(input_parallel, word_embeddings_weight)
else:
logits_parallel = F.linear(input_parallel, word_embeddings_weight, bias)
# Gather if needed.
if parallel_output:
return logits_parallel
return tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel)
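# Hedged shape sketch (plain torch): with hidden states [s, b, h] and an
# embedding weight [v, h], F.linear yields logits [s, b, v]; with
# parallel_output=True each rank keeps its vocab shard, otherwise the shards
# are gathered along the last dimension. Names and sizes are illustrative only.
def _lm_logits_shape_demo(s: int = 3, b: int = 2, h: int = 8, v: int = 16) -> torch.Size:
    hidden = torch.randn(s, b, h)
    word_embeddings_weight = torch.randn(v, h)
    return F.linear(hidden, word_embeddings_weight).shape  # torch.Size([3, 2, 16])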
def get_language_model(
num_tokentypes,
add_pooler,
encoder_attn_mask_type,
init_method=None,
scaled_init_method=None,
add_encoder=True,
add_decoder=False,
decoder_attn_mask_type=AttnMaskType.causal,
pre_process=True,
post_process=True,
):
"""Build language model and return along with the key to save."""
def gpt_model_provider(pre_process: bool = True, post_process: bool = True, cpu_offload: bool = False,) -> "GPTModel":
args = get_args()
if init_method is None:
init_method = init_method_normal(args.init_method_std)
if scaled_init_method is None:
scaled_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers)
# Language model.
language_model = TransformerLanguageModel(
init_method,
scaled_init_method,
encoder_attn_mask_type,
num_tokentypes=num_tokentypes,
add_encoder=add_encoder,
add_decoder=add_decoder,
decoder_attn_mask_type=decoder_attn_mask_type,
add_pooler=add_pooler,
model = GPTModel(
num_tokentypes=0,
parallel_output=True,
pre_process=pre_process,
post_process=post_process,
cpu_offload=args.cpu_offload,
)
# key used for checkpoints.
language_model_key = "language_model"
return language_model, language_model_key
class Pooler(MegatronModule):
"""Pooler layer.
Pool hidden states of a specific token (for example start of the
sequence) and add a linear transformation followed by a tanh.
Arguments:
hidden_size: hidden size
init_method: weight initialization method for the linear layer.
bias is set to zero.
"""
def __init__(self, hidden_size, init_method):
super(Pooler, self).__init__()
self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
def forward(self, hidden_states, sequence_index=0):
# hidden_states: [b, s, h]
# sequence_index: index of the token to pool.
pooled = hidden_states[:, sequence_index, :]
pooled = self.dense(pooled)
pooled = torch.tanh(pooled)
return pooled
class Embedding(MegatronModule):
"""Language model embeddings.
Arguments:
hidden_size: hidden size
vocab_size: vocabulary size
max_sequence_length: maximum size of sequence. This
is used for positional embedding
embedding_dropout_prob: dropout probability for embeddings
init_method: weight initialization method
num_tokentypes: size of the token-type embeddings. 0 value
will ignore this embedding
"""
def __init__(
self, hidden_size, vocab_size, max_sequence_length, embedding_dropout_prob, init_method, num_tokentypes=0
):
super().__init__()
self.hidden_size = hidden_size
self.init_method = init_method
self.num_tokentypes = num_tokentypes
args = get_args()
# Word embeddings (parallel).
self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
vocab_size, self.hidden_size, init_method=self.init_method,
use_cpu_initialization=args.use_cpu_initialization
)
self._word_embeddings_key = "word_embeddings"
# Position embedding (serial).
self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size)
self._position_embeddings_key = "position_embeddings"
# Initialize the position embeddings.
self.init_method(self.position_embeddings.weight)
# Token type embedding.
# Add this as an optional field that can be added through
# method call so we can load a pretrain model without
# token types and add them as needed.
self._tokentype_embeddings_key = "tokentype_embeddings"
if self.num_tokentypes > 0:
self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size)
# Initialize the token-type embeddings.
self.init_method(self.tokentype_embeddings.weight)
else:
self.tokentype_embeddings = None
# Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
print("FINISH WORD EMBEDDING", self.word_embeddings)
def zero_parameters(self):
"""Zero out all parameters in embedding."""
self.word_embeddings.weight.data.fill_(0)
self.word_embeddings.weight.shared = True
self.position_embeddings.weight.data.fill_(0)
self.position_embeddings.weight.shared = True
if self.num_tokentypes > 0:
self.tokentype_embeddings.weight.data.fill_(0)
self.tokentype_embeddings.weight.shared = True
def add_tokentype_embeddings(self, num_tokentypes):
"""Add token-type embedding. This function is provided so we can add
token-type embeddings in case the pretrained model does not have it.
This allows us to load the model normally and then add this embedding.
"""
if self.tokentype_embeddings is not None:
raise Exception("tokentype embeddings is already initialized")
if torch.distributed.get_rank() == 0:
print("adding embedding for {} tokentypes".format(num_tokentypes), flush=True)
self.num_tokentypes = num_tokentypes
self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size)
# Initialize the token-type embeddings.
self.init_method(self.tokentype_embeddings.weight)
def forward(self, input_ids, position_ids, tokentype_ids=None):
# Embeddings.
words_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
embeddings = words_embeddings + position_embeddings
if tokentype_ids is not None:
assert self.tokentype_embeddings is not None
embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
else:
assert self.tokentype_embeddings is None
# Dropout.
embeddings = self.embedding_dropout(embeddings)
return embeddings
def state_dict_for_save_checkpoint(self, destination=None, prefix="", keep_vars=False):
"""For easy load."""
state_dict_ = {}
state_dict_[self._word_embeddings_key] = self.word_embeddings.state_dict(destination, prefix, keep_vars)
state_dict_[self._position_embeddings_key] = self.position_embeddings.state_dict(destination, prefix, keep_vars)
if self.num_tokentypes > 0:
state_dict_[self._tokentype_embeddings_key] = self.tokentype_embeddings.state_dict(
destination, prefix, keep_vars
)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
# Word embedding.
if self._word_embeddings_key in state_dict:
state_dict_ = state_dict[self._word_embeddings_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if "word_embeddings" in key:
state_dict_[key.split("word_embeddings.")[1]] = state_dict[key]
self.word_embeddings.load_state_dict(state_dict_, strict=strict)
# Position embedding.
if self._position_embeddings_key in state_dict:
state_dict_ = state_dict[self._position_embeddings_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if "position_embeddings" in key:
state_dict_[key.split("position_embeddings.")[1]] = state_dict[key]
self.position_embeddings.load_state_dict(state_dict_, strict=strict)
# Tokentype embedding.
if self.num_tokentypes > 0:
state_dict_ = {}
if self._tokentype_embeddings_key in state_dict:
state_dict_ = state_dict[self._tokentype_embeddings_key]
else:
# for backward compatibility.
for key in state_dict.keys():
if "tokentype_embeddings" in key:
state_dict_[key.split("tokentype_embeddings.")[1]] = state_dict[key]
if len(state_dict_.keys()) > 0:
self.tokentype_embeddings.load_state_dict(state_dict_, strict=strict)
else:
print(
"***WARNING*** expected tokentype embeddings in the " "checkpoint but could not find it", flush=True
)
return model
class TransformerLanguageModel(MegatronModule):
"""Transformer language model.
Arguments:
transformer_hparams: transformer hyperparameters
vocab_size: vocabulary size
max_sequence_length: maximum size of sequence. This
is used for positional embedding
embedding_dropout_prob: dropout probability for embeddings
num_tokentypes: size of the token-type embeddings. 0 value
will ignore this embedding
"""
class GPTModel(MegatronModule):
"""GPT-2 Language model."""
def __init__(
self,
init_method,
output_layer_init_method,
encoder_attn_mask_type,
num_tokentypes=0,
add_encoder=True,
add_decoder=False,
decoder_attn_mask_type=AttnMaskType.causal,
add_pooler=False,
pre_process=True,
post_process=True,
num_tokentypes: int = 0,
parallel_output: bool = True,
pre_process: bool = True,
post_process: bool = True,
cpu_offload: bool = False,
):
super().__init__()
args = get_args()
self.pre_process = pre_process
self.post_process = post_process
self.hidden_size = args.hidden_size
self.num_tokentypes = num_tokentypes
self.init_method = init_method
self.add_encoder = add_encoder
self.encoder_attn_mask_type = encoder_attn_mask_type
self.add_decoder = add_decoder
self.decoder_attn_mask_type = decoder_attn_mask_type
self.add_pooler = add_pooler
self.encoder_hidden_state = None
# Embeddings.
if self.pre_process:
self.embedding = Embedding(
self.hidden_size,
args.padded_vocab_size,
args.max_position_embeddings,
args.hidden_dropout,
self.init_method,
self.num_tokentypes,
)
self._embedding_key = "embedding"
# Transformer.
# Encoder (usually set to True, False if part of an encoder-decoder
# architecture and in encoder-only stage).
if self.add_encoder:
self.encoder = ParallelTransformer(
self.init_method,
output_layer_init_method,
self_attn_mask_type=self.encoder_attn_mask_type,
pre_process=self.pre_process,
post_process=self.post_process,
)
self._encoder_key = "encoder"
else:
self.encoder = None
# Decoder (usually set to False, True if part of an encoder-decoder
# architecture and in decoder-only stage).
if self.add_decoder:
# Temporary assertion until we verify correctness of pipeline parallelism
# implementation of T5.
assert (
args.pipeline_model_parallel_size == 1
), "pipeline parallelism is not supported in the presence of decoder"
self.decoder = ParallelTransformer(
self.init_method,
output_layer_init_method,
layer_type=LayerType.decoder,
self_attn_mask_type=self.decoder_attn_mask_type,
pre_process=self.pre_process,
post_process=self.post_process,
)
self._decoder_key = "decoder"
else:
self.decoder = None
if self.post_process:
# Pooler.
if self.add_pooler:
self.pooler = Pooler(self.hidden_size, self.init_method)
self._pooler_key = "pooler"
def set_input_tensor(self, input_tensor):
""" See megatron.model.transformer.set_input_tensor()"""
# This is usually handled in schedules.py but some inference code still
# gives us non-lists or None
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
if self.add_encoder and self.add_decoder:
assert (
len(input_tensor) == 1
), "input_tensor should only be length 1 for stage with both encoder and decoder"
self.encoder.set_input_tensor(input_tensor[0])
elif self.add_encoder:
assert len(input_tensor) == 1, "input_tensor should only be length 1 for stage with only encoder"
self.encoder.set_input_tensor(input_tensor[0])
elif self.add_decoder:
if len(input_tensor) == 2:
self.decoder.set_input_tensor(input_tensor[0])
self.encoder_hidden_state = input_tensor[1]
elif len(input_tensor) == 1:
self.decoder.set_input_tensor(None)
self.encoder_hidden_state = input_tensor[0]
else:
raise Exception("input_tensor must have either length 1 or 2")
else:
raise Exception("Stage must have at least either encoder or decoder")
def forward(
self,
enc_input_ids,
enc_position_ids,
enc_attn_mask,
dec_input_ids=None,
dec_position_ids=None,
dec_attn_mask=None,
enc_dec_attn_mask=None,
tokentype_ids=None,
set_inference_key_value_memory=False,
inference_max_sequence_len=None,
pooling_sequence_index=0,
enc_hidden_states=None,
output_enc_hidden=False,
):
# Encoder embedding.
if self.pre_process:
encoder_input = self.embedding(enc_input_ids, enc_position_ids, tokentype_ids=tokentype_ids)
else:
encoder_input = None
# Run encoder.
if enc_hidden_states is None:
if self.encoder is not None:
encoder_output = self.encoder(
encoder_input,
enc_attn_mask,
set_inference_key_value_memory=set_inference_key_value_memory,
inference_max_sequence_len=inference_max_sequence_len,
)
else:
encoder_output = self.encoder_hidden_state
else:
encoder_output = enc_hidden_states.to(encoder_input.dtype)
if self.post_process:
if self.add_pooler:
pooled_output = self.pooler(encoder_output, pooling_sequence_index)
# output_enc_hidden refers to when we just need the encoder's
# output. For example, it is helpful to compute
# similarity between two sequences by average pooling
if not self.add_decoder or output_enc_hidden:
if self.add_pooler and self.post_process:
return encoder_output, pooled_output
else:
return encoder_output
# Decoder embedding.
if self.pre_process:
decoder_input = self.embedding(dec_input_ids, dec_position_ids)
else:
decoder_input = None
# Run decoder.
decoder_output = self.decoder(
decoder_input,
dec_attn_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
set_inference_key_value_memory=set_inference_key_value_memory,
inference_max_sequence_len=inference_max_sequence_len,
)
if self.add_pooler and self.post_process:
return decoder_output, encoder_output, pooled_output
else:
return decoder_output, encoder_output
def state_dict_for_save_checkpoint(self, destination=None, prefix="", keep_vars=False):
"""For easy load."""
state_dict_ = {}
if self.pre_process:
state_dict_[self._embedding_key] = self.embedding.state_dict_for_save_checkpoint(
destination, prefix, keep_vars
)
if self.add_encoder:
state_dict_[self._encoder_key] = self.encoder.state_dict_for_save_checkpoint(destination, prefix, keep_vars)
if self.post_process:
if self.add_pooler:
state_dict_[self._pooler_key] = self.pooler.state_dict_for_save_checkpoint(
destination, prefix, keep_vars
)
if self.add_decoder:
state_dict_[self._decoder_key] = self.decoder.state_dict_for_save_checkpoint(destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
# Embedding.
if self.pre_process:
if self._embedding_key in state_dict:
state_dict_ = state_dict[self._embedding_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if "_embeddings" in key:
state_dict_[key] = state_dict[key]
self.embedding.load_state_dict(state_dict_, strict=strict)
# Encoder.
if self.add_encoder:
if self._encoder_key in state_dict:
state_dict_ = state_dict[self._encoder_key]
# For backward compatibility.
elif "transformer" in state_dict:
state_dict_ = state_dict["transformer"]
else:
# For backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if "transformer." in key:
state_dict_[key.split("transformer.")[1]] = state_dict[key]
# For backward compatibility.
state_dict_self_attention = {}
for key in state_dict_.keys():
if ".attention." in key:
state_dict_self_attention[key.replace(".attention.", ".self_attention.")] = state_dict_[key]
else:
state_dict_self_attention[key] = state_dict_[key]
state_dict_ = state_dict_self_attention
self.encoder.load_state_dict(state_dict_, strict=strict)
self.forward_context = contextlib.nullcontext
if cpu_offload:
self.forward_context = torch.autograd.graph.save_on_cpu
# Pooler.
if self.post_process:
if self.add_pooler:
assert "pooler" in state_dict, "could not find data for pooler in the checkpoint"
self.pooler.load_state_dict(state_dict[self._pooler_key], strict=strict)
# Decoder.
if self.add_decoder:
assert "decoder" in state_dict, "could not find data for pooler in the checkpoint"
self.decoder.load_state_dict(state_dict[self._decoder_key], strict=strict)
def post_language_model_processing(lm_output, labels, logit_weights, parallel_output, fp16_lm_cross_entropy):
# Output.
output = parallel_lm_logits(lm_output, logit_weights, parallel_output)
if labels is None:
return output
else:
if fp16_lm_cross_entropy:
assert output.dtype == torch.half
loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels)
else:
loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels)
return loss
class GPTModel(MegatronModule):
"""GPT-2 Language model."""
def __init__(self, num_tokentypes=0, parallel_output=True, pre_process=True, post_process=True):
super(GPTModel, self).__init__()
args = get_args()
self.parallel_output = parallel_output
self.pre_process = pre_process
self.post_process = post_process
@@ -1439,7 +70,9 @@ class GPTModel(MegatronModule):
add_pooler=False,
encoder_attn_mask_type=AttnMaskType.causal,
init_method=init_method_normal(args.init_method_std),
scaled_init_method=scaled_init_method_normal(args.init_method_std, args.num_layers),
scaled_init_method=scaled_init_method_normal(
args.init_method_std, args.num_layers
),
pre_process=self.pre_process,
post_process=self.post_process,
)
@@ -1457,48 +90,22 @@ class GPTModel(MegatronModule):
attention_mask,
labels=None,
tokentype_ids=None,
set_inference_key_value_memory=False,
inference_max_sequence_len=None,
inference_params=None,
):
lm_output = self.language_model(
input_ids,
position_ids,
attention_mask,
set_inference_key_value_memory=set_inference_key_value_memory,
inference_max_sequence_len=inference_max_sequence_len,
)
if self.post_process:
return post_language_model_processing(
lm_output, labels, self.word_embeddings_weight(), self.parallel_output, self.fp16_lm_cross_entropy
)
else:
return lm_output
def state_dict_for_save_checkpoint(self, destination=None, prefix="", keep_vars=False):
state_dict_ = {}
state_dict_[self._language_model_key] = self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars
)
# Save word_embeddings.
if self.post_process and not self.pre_process:
state_dict_[self._word_embeddings_for_head_key] = self.word_embeddings.state_dict(
destination, prefix, keep_vars
with self.forward_context():
lm_output = self.language_model(
input_ids, position_ids, attention_mask, inference_params=inference_params
)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
# Load word_embeddings.
if self.post_process and not self.pre_process:
self.word_embeddings.load_state_dict(state_dict[self._word_embeddings_for_head_key], strict=strict)
if self._language_model_key in state_dict:
state_dict = state_dict[self._language_model_key]
self.language_model.load_state_dict(state_dict, strict=strict)
def gpt_model_provider(pre_process=True, post_process=True):
model = GPTModel(num_tokentypes=0, parallel_output=True, pre_process=pre_process, post_process=post_process)
return model
if self.post_process:
return post_language_model_processing(
lm_output,
# note(mkozuki): Am I overlooking some order of dim change?
labels.t().contiguous(),
self.word_embeddings_weight(),
self.parallel_output,
self.fp16_lm_cross_entropy,
)
else:
return lm_output
# coding=utf-8
# Copyright (c) 2021-22, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GPT-2 model."""
import enum
import math
import contextlib
import json
import torch
import torch.nn.functional as F
import apex.transformer.utils
from apex.transformer.layers import FusedLayerNorm as LayerNorm
from apex.transformer.functional import FusedScaleMaskSoftmax
from apex.transformer import tensor_parallel
from apex.transformer.tensor_parallel.layers import ColumnParallelLinear
from apex.transformer.tensor_parallel.layers import RowParallelLinear
from apex.transformer.tensor_parallel.layers import VocabParallelEmbedding
from apex.transformer.tensor_parallel.mappings import scatter_to_sequence_parallel_region
from apex.transformer import parallel_state
from apex.transformer.testing.global_vars import get_args
from apex.transformer.enums import ModelType
from apex.transformer.enums import LayerType
from apex.transformer.enums import AttnType
from apex.transformer.enums import AttnMaskType
from apex.transformer.log_util import get_transformer_logger
_logger = get_transformer_logger(__name__)
def param_is_not_shared(param: torch.Tensor) -> bool:
return getattr(param, "shared", False)
class MegatronModule(torch.nn.Module):
"""Megatron specific extensions of torch Module with support for pipelining."""
def __init__(self, share_word_embeddings: bool = True) -> None:
super().__init__()
self.share_word_embeddings = share_word_embeddings
def word_embeddings_weight(self):
if self.pre_process:
return self.language_model.embedding.word_embeddings.weight
else:
if not self.share_word_embeddings:
raise Exception('word_embeddings_weight() called for last stage, but share_word_embeddings is false')
return self.word_embeddings.weight
def initialize_word_embeddings(self, init_method_normal):
args = get_args()
if not self.share_word_embeddings:
raise Exception("initialize_word_embeddings() was called but share_word_embeddings is false")
# This function just initializes the word embeddings in the final stage
# when we are using pipeline parallelism. Nothing to do if we aren't
# using pipeline parallelism.
if args.pipeline_model_parallel_size == 1:
return
# Parameters are shared between the word embeddings layers, and the
# heads at the end of the model. In a pipelined setup with more than
# one stage, the initial embedding layer and the head are on different
# workers, so we do the following:
# 1. Create a second copy of word_embeddings on the last stage, with
# initial parameters of 0.0.
# 2. Do an all-reduce between the first and last stage to ensure that
# the two copies of word_embeddings start off with the same
# parameter values.
# 3. In the training loop, perform an all-reduce between the grads of
# the two word_embeddings layers to ensure that every applied weight
# update is the same on both stages.
if parallel_state.is_pipeline_last_stage() and not self.pre_process:
assert not parallel_state.is_pipeline_first_stage()
self._word_embeddings_for_head_key = 'word_embeddings_for_head'
# set word_embeddings weights to 0 here, then copy first
# stage's weights using all_reduce below.
self.word_embeddings = VocabParallelEmbedding(
args.padded_vocab_size, args.hidden_size,
init_method=init_method_normal(args.init_method_std))
self.word_embeddings.weight.data.fill_(0)
self.word_embeddings.weight.shared = True
# Zero out initial weights for decoder embedding.
# NOTE: We don't currently support T5 with the interleaved schedule.
if not parallel_state.is_pipeline_first_stage(ignore_virtual=True) and self.pre_process:
self.language_model.embedding.zero_parameters()
# Ensure that first and last stages have the same initial parameter
# values.
if torch.distributed.is_initialized():
if parallel_state.is_rank_in_embedding_group():
torch.distributed.all_reduce(self.word_embeddings_weight(),
group=parallel_state.get_embedding_group())
# Ensure that encoder (first stage) and decoder (split stage) position
# embeddings have the same initial parameter values
# NOTE: We don't currently support T5 with the interleaved schedule.
if parallel_state.is_rank_in_position_embedding_group() and \
args.pipeline_model_parallel_split_rank is not None:
# TODO: Support tokentype embedding.
self.language_model.embedding.cuda()
position_embeddings = self.language_model.embedding.position_embeddings
torch.distributed.all_reduce(position_embeddings.weight,
group=parallel_state.get_position_embedding_group())
else:
print("WARNING! Distributed processes aren't initialized, so "
"word embeddings in the last layer are not initialized. "
"If you are just manipulating a model this is fine, but "
"this needs to be handled manually. If you are training "
"something is definitely wrong.")
def get_linear_layer(rows, columns, init_method):
"""Simple linear layer with weight initialization."""
layer = torch.nn.Linear(rows, columns)
init_method(layer.weight)
with torch.no_grad():
layer.bias.zero_()
return layer
# NOTE(mkozuki): Avoid inplace op.
def attention_mask_func(attention_scores: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
# attention_scores.masked_fill_(attention_mask, -10000.0)
# return attention_scores
return attention_scores.masked_fill(attention_mask, -10000.0)
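# Illustrative sketch (not part of the original source; the helper name is
# hypothetical): positions where the mask is True are filled with -10000.0
# before the softmax, which effectively zeroes their attention weights, while
# the input tensor itself is left untouched.
def _attention_mask_func_example():
    scores = torch.zeros(1, 1, 4, 4)  # [b, np, sq, sk]
    causal_mask = torch.ones(4, 4).triu(diagonal=1).bool().view(1, 1, 4, 4)
    masked = attention_mask_func(scores, causal_mask)
    return masked  # upper triangle is -10000.0, the rest is unchanged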
def init_method_normal(sigma):
"""Init method based on N(0, sigma)."""
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
return init_
def scaled_init_method_normal(sigma, num_layers):
"""Init method based on N(0, sigma/sqrt(2*num_layers)."""
std = sigma / math.sqrt(2.0 * num_layers)
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
return init_
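# Worked example (illustrative, not part of the original source): with the
# common defaults sigma=0.02 and num_layers=24, the scaled std is
#   0.02 / math.sqrt(2 * 24) ~= 0.00289
# so output-projection weights start smaller in deeper networks, keeping the
# variance of the residual stream roughly constant at initialization.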
class ParallelMLP(MegatronModule):
"""MLP.
MLP takes the input with hidden size h, projects it to a 4*h
hidden dimension, applies a nonlinear transformation, and projects the
state back to the h hidden dimension.
"""
def __init__(self, init_method, output_layer_init_method):
super().__init__()
args = get_args()
# Project to 4h.
self.dense_h_to_4h = ColumnParallelLinear(
args.hidden_size,
args.ffn_hidden_size,
gather_output=False,
init_method=init_method,
skip_bias_add=True,
no_async_tensor_model_parallel_allreduce=not args.async_tensor_model_parallel_allreduce,
sequence_parallel_enabled=args.sequence_parallel,
)
self.bias_gelu_fusion = args.bias_gelu_fusion
self.activation_func = F.gelu
# Project back to h.
self.dense_4h_to_h = RowParallelLinear(
args.ffn_hidden_size,
args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True,
sequence_parallel_enabled=args.sequence_parallel,
)
def forward(self, hidden_states):
# [s, b, 4hp]
intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)
intermediate_parallel = self.activation_func(intermediate_parallel + bias_parallel)
# [s, b, h]
output, output_bias = self.dense_4h_to_h(intermediate_parallel)
return output, output_bias
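# Illustrative arithmetic (not part of the original source): with
# hidden_size=1024, ffn_hidden_size=4096 and a tensor-parallel world size of 2,
# each rank's dense_h_to_4h holds a [4096/2, 1024] weight shard (column split)
# and dense_4h_to_h holds a [1024, 4096/2] shard (row split). The GeLU is
# applied independently to each rank's partition, and only dense_4h_to_h's
# output needs a reduction across ranks (an all-reduce, or a reduce-scatter
# when sequence parallelism is enabled).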
class CoreAttention(MegatronModule):
def __init__(self, layer_number, attn_mask_type=AttnMaskType.padding):
super().__init__()
args = get_args()
self.fp16 = args.fp16
self.bf16 = args.bf16
self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
if self.apply_query_key_layer_scaling:
self.attention_softmax_in_fp32 = True
self.layer_number = max(1, layer_number)
self.attn_mask_type = attn_mask_type
self.sequence_parallel = args.sequence_parallel
projection_size = args.kv_channels * args.num_attention_heads
# Per attention head and per partition values.
world_size = parallel_state.get_tensor_model_parallel_world_size()
self.hidden_size_per_partition = apex.transformer.utils.divide(
projection_size, world_size
)
self.hidden_size_per_attention_head = apex.transformer.utils.divide(
projection_size, args.num_attention_heads
)
self.num_attention_heads_per_partition = apex.transformer.utils.divide(
args.num_attention_heads, world_size
)
coeff = None
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
if self.apply_query_key_layer_scaling:
coeff = self.layer_number
self.norm_factor *= coeff
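# Illustrative arithmetic (not part of the original source): with
# kv_channels=64 and query-key layer scaling enabled, layer 12 uses
# norm_factor = sqrt(64) * 12 = 96, so the baddbmm below scales the raw
# scores by 1/96; FusedScaleMaskSoftmax then receives coeff=12 and multiplies
# the scores back up in fp32, so the effective softmax scale stays 1/sqrt(64)
# while the half-precision matmul output is kept small.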
self.scale_mask_softmax = FusedScaleMaskSoftmax(
self.fp16,
self.bf16,
self.attn_mask_type,
args.masked_softmax_fusion,
attention_mask_func,
self.attention_softmax_in_fp32,
coeff,
)
# Dropout. Note that for a single iteration, this layer will generate
# different outputs on different numbers of parallel partitions but
# on average it should not be partition dependent.
self.attention_dropout = torch.nn.Dropout(args.attention_dropout)
def forward(self, query_layer, key_layer, value_layer, attention_mask):
# ===================================
# Raw attention scores. [b, np, s, s]
# ===================================
# [b, np, sq, sk]
output_size = (
query_layer.size(1),
query_layer.size(2),
query_layer.size(0),
key_layer.size(0),
)
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.view(
output_size[2], output_size[0] * output_size[1], -1
)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
# preallocating input tensor: [b * np, sq, sk]
matmul_input_buffer = torch.empty(
output_size[0] * output_size[1],
output_size[2],
output_size[3],
dtype=query_layer.dtype,
device=torch.cuda.current_device(),
)
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(
matmul_input_buffer,
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0,
alpha=(1.0 / self.norm_factor),
)
# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
# ===========================
# Attention probs and dropout
# ===========================
# attention scores and attention mask [b, np, sq, sk]
attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
if not self.sequence_parallel:
with tensor_parallel.get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs)
else:
attention_probs = self.attention_dropout(attention_probs)
# =========================
# Context layer. [sq, b, hp]
# =========================
# value_layer -> context layer.
# [sk, b, np, hn] --> [b, np, sq, hn]
# context layer shape: [b, np, sq, hn]
output_size = (
value_layer.size(1),
value_layer.size(2),
query_layer.size(0),
value_layer.size(3),
)
# change view [sk, b * np, hn]
value_layer = value_layer.view(
value_layer.size(0), output_size[0] * output_size[1], -1
)
# change view [b * np, sq, sk]
attention_probs = attention_probs.view(
output_size[0] * output_size[1], output_size[2], -1
)
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
# change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size)
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + (
self.hidden_size_per_partition,
)
context_layer = context_layer.view(*new_context_layer_shape)
return context_layer
class ParallelAttention(MegatronModule):
"""Parallel self-attention layer abstract class.
Self-attention layer takes input with size [s, b, h]
and returns output of the same size.
"""
def __init__(
self,
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=AttnMaskType.padding,
):
super().__init__()
args = get_args()
self.layer_number = max(1, layer_number)
self.attention_type = attention_type
self.attn_mask_type = attn_mask_type
self.params_dtype = args.params_dtype
projection_size = args.kv_channels * args.num_attention_heads
# Per attention head and per partition values.
world_size = parallel_state.get_tensor_model_parallel_world_size()
self.hidden_size_per_attention_head = apex.transformer.utils.divide(
projection_size, args.num_attention_heads
)
self.num_attention_heads_per_partition = apex.transformer.utils.divide(
args.num_attention_heads, world_size
)
# Strided linear layer.
if attention_type == AttnType.self_attn:
self.query_key_value = ColumnParallelLinear(
args.hidden_size,
3 * projection_size,
gather_output=False,
init_method=init_method,
no_async_tensor_model_parallel_allreduce=not args.async_tensor_model_parallel_allreduce,
sequence_parallel_enabled=args.sequence_parallel,
)
else:
assert attention_type == AttnType.cross_attn
self.query = ColumnParallelLinear(
args.hidden_size,
projection_size,
gather_output=False,
init_method=init_method,
no_async_tensor_model_parallel_allreduce=not args.async_tensor_model_parallel_allreduce,
sequence_parallel_enabled=args.sequence_parallel,
)
self.key_value = ColumnParallelLinear(
args.hidden_size,
2 * projection_size,
gather_output=False,
init_method=init_method,
no_async_tensor_model_parallel_allreduce=not args.async_tensor_model_parallel_allreduce,
sequence_parallel_enabled=args.sequence_parallel,
)
self.core_attention = CoreAttention(self.layer_number, self.attn_mask_type)
self.checkpoint_core_attention = args.recompute_granularity == "selective"
# Output.
self.dense = RowParallelLinear(
projection_size,
args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True,
sequence_parallel_enabled=args.sequence_parallel,
)
def _checkpointed_attention_forward(
self, query_layer, key_layer, value_layer, attention_mask
):
"""Forward method with activation checkpointing."""
def custom_forward(*inputs):
query_layer = inputs[0]
key_layer = inputs[1]
value_layer = inputs[2]
attention_mask = inputs[3]
output_ = self.core_attention(
query_layer, key_layer, value_layer, attention_mask
)
return output_
hidden_states = tensor_parallel.checkpoint(
custom_forward, False, query_layer, key_layer, value_layer, attention_mask
)
return hidden_states
def _allocate_memory(self, inference_max_sequence_len, batch_size):
return torch.empty(
inference_max_sequence_len,
batch_size,
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
dtype=self.params_dtype,
device=torch.cuda.current_device(),
)
def forward(
self, hidden_states, attention_mask, encoder_output=None, inference_params=None
):
# hidden_states: [sq, b, h]
# =================================================
# Pre-allocate memory for key-values for inference.
# =================================================
if inference_params:
if self.layer_number not in inference_params.key_value_memory_dict:
inf_max_seq_len = inference_params.max_sequence_len
inf_max_batch_size = inference_params.max_batch_size
inference_key_memory = self._allocate_memory(
inf_max_seq_len, inf_max_batch_size
)
inference_value_memory = self._allocate_memory(
inf_max_seq_len, inf_max_batch_size
)
inference_params.key_value_memory_dict[self.layer_number] = (
inference_key_memory,
inference_value_memory,
)
else:
(
inference_key_memory,
inference_value_memory,
) = inference_params.key_value_memory_dict[self.layer_number]
# =====================
# Query, Key, and Value
# =====================
if self.attention_type == AttnType.self_attn:
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
mixed_x_layer, _ = self.query_key_value(hidden_states)
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head,
)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(
query_layer,
key_layer,
value_layer,
) = tensor_parallel.utils.split_tensor_along_last_dim(mixed_x_layer, 3)
else:
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_kv_layer, _ = self.key_value(encoder_output)
# [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
new_tensor_shape = mixed_kv_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
2 * self.hidden_size_per_attention_head,
)
mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
# [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
(
key_layer,
value_layer,
) = tensor_parallel.utils.split_tensor_along_last_dim(mixed_kv_layer, 2)
# Attention head [sq, b, h] --> [sq, b, hp]
query_layer, _ = self.query(hidden_states)
# [sq, b, hp] --> [sq, b, np, hn]
new_tensor_shape = query_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
)
query_layer = query_layer.view(*new_tensor_shape)
# ==================================
# Adjust key and value for inference
# ==================================
if inference_params:
batch_start = inference_params.batch_size_offset
batch_end = batch_start + key_layer.size(1)
assert batch_end <= inference_key_memory.size(1)
sequence_start = inference_params.sequence_len_offset
sequence_end = sequence_start + key_layer.size(0)
assert sequence_end <= inference_key_memory.size(0)
# Copy key and values.
inference_key_memory[
sequence_start:sequence_end, batch_start:batch_end, ...
] = key_layer
inference_value_memory[
sequence_start:sequence_end, batch_start:batch_end, ...
] = value_layer
key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
value_layer = inference_value_memory[
:sequence_end, batch_start:batch_end, ...
]
# ==================================
# core attention computation
# ==================================
if self.checkpoint_core_attention:
context_layer = self._checkpointed_attention_forward(
query_layer, key_layer, value_layer, attention_mask
)
else:
context_layer = self.core_attention(
query_layer, key_layer, value_layer, attention_mask
)
# =================
# Output. [sq, b, h]
# =================
output, bias = self.dense(context_layer)
return output, bias
def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
out = residual + out
return out
def get_bias_dropout_add(training):
def _bias_dropout_add(x, bias, residual, prob):
return bias_dropout_add(x, bias, residual, prob, training)
return _bias_dropout_add
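# Illustrative sketch (not part of the original source): the returned closure
# only freezes the `training` flag, so in eval mode dropout is a no-op and
#   fn = get_bias_dropout_add(training=False)
#   fn(x, bias, residual, prob)
# reduces to residual + (x + bias).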
class ParallelTransformerLayer(MegatronModule):
"""A single transformer layer.
Transformer layer takes input with size [s, b, h] and returns an
output of the same size.
"""
def __init__(
self,
init_method,
output_layer_init_method,
layer_number,
layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding,
drop_path_rate=0.0,
):
args = get_args()
super().__init__()
self.layer_number = layer_number
self.layer_type = layer_type
self.apply_residual_connection_post_layernorm = (
args.apply_residual_connection_post_layernorm
)
self.bf16 = args.bf16
self.fp32_residual_connection = args.fp32_residual_connection
# Layernorm on the input data.
self.input_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
# no_persist_layer_norm=args.no_persist_layer_norm,
sequence_parallel_enabled=args.sequence_parallel,
)
# Self attention.
self.self_attention = ParallelAttention(
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=self_attn_mask_type,
)
self.hidden_dropout = args.hidden_dropout
self.bias_dropout_fusion = args.bias_dropout_fusion
# note(mkozuki)
# self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None
assert drop_path_rate <= 0.0
self.drop_path = None
# Layernorm on the attention output
self.post_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
# no_persist_layer_norm=args.no_persist_layer_norm,
sequence_parallel_enabled=args.sequence_parallel,
)
if self.layer_type == LayerType.decoder:
self.inter_attention = ParallelAttention(
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.cross_attn,
)
# Layernorm on the attention output.
self.post_inter_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
# no_persist_layer_norm=args.no_persist_layer_norm,
sequence_parallel_enabled=args.sequence_parallel,
)
# MLP
# note(mkozuki)
assert args.num_experts is None
# if args.num_experts is not None:
# self.mlp = SwitchMLP(init_method, output_layer_init_method)
# else:
# self.mlp = ParallelMLP(init_method, output_layer_init_method)
self.mlp = ParallelMLP(init_method, output_layer_init_method)
# Set bias+dropout+add fusion grad_enable execution handler.
TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])
use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
self.bias_dropout_add_exec_handler = (
contextlib.nullcontext if use_nvfuser else torch.enable_grad
)
def forward(
self,
hidden_states,
attention_mask,
encoder_output=None,
enc_dec_attn_mask=None,
inference_params=None,
):
# hidden_states: [s, b, h]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output, attention_bias = self.self_attention(
layernorm_output, attention_mask, inference_params=inference_params
)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
if self.drop_path is None:
bias_dropout_add_func = get_bias_dropout_add(self.training)
with self.bias_dropout_add_exec_handler():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout,
)
else:
out = torch.nn.functional.dropout(
attention_output + attention_bias,
p=self.hidden_dropout,
training=self.training,
)
layernorm_input = residual + self.drop_path(out)
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
if self.layer_type == LayerType.decoder:
attention_output, attention_bias = self.inter_attention(
layernorm_output, enc_dec_attn_mask, encoder_output=encoder_output
)
# residual connection
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
with self.bias_dropout_add_exec_handler():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout,
)
# Layer norm post the decoder attention
layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
# MLP.
mlp_output, mlp_bias = self.mlp(layernorm_output)
# Second residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
if self.drop_path is None:
with self.bias_dropout_add_exec_handler():
output = bias_dropout_add_func(
mlp_output,
mlp_bias.expand_as(residual),
residual,
self.hidden_dropout,
)
else:
out = torch.nn.functional.dropout(
mlp_output + mlp_bias, p=self.hidden_dropout, training=self.training
)
output = residual + self.drop_path(out)
return output
class ParallelTransformer(MegatronModule):
"""Transformer class."""
def __init__(
self,
init_method,
output_layer_init_method,
layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding,
post_layer_norm=True,
pre_process=True,
post_process=True,
drop_path_rate=0.0,
):
super().__init__()
args = get_args()
self.layer_type = layer_type
self.model_type = args.model_type
self.bf16 = args.bf16
self.fp32_residual_connection = args.fp32_residual_connection
self.post_layer_norm = post_layer_norm
self.pre_process = pre_process
self.post_process = post_process
self.input_tensor = None
self.drop_path_rate = drop_path_rate
# Store activation checkpointing flags.
self.recompute_granularity = args.recompute_granularity
self.recompute_method = args.recompute_method
self.recompute_num_layers = args.recompute_num_layers
self.distribute_saved_activations = (
args.distribute_saved_activations and not args.sequence_parallel
)
self.sequence_parallel = args.sequence_parallel
# Number of layers.
self.num_layers = get_num_layers(
args, args.model_type == ModelType.encoder_and_decoder
)
self.drop_path_rates = [
rate.item()
for rate in torch.linspace(0, self.drop_path_rate, args.num_layers)
]
# Transformer layers.
def build_layer(layer_number):
return ParallelTransformerLayer(
init_method,
output_layer_init_method,
layer_number,
layer_type=layer_type,
self_attn_mask_type=self_attn_mask_type,
drop_path_rate=self.drop_path_rates[layer_number - 1],
)
if args.virtual_pipeline_model_parallel_size is not None:
assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, (
"num_layers_per_stage must be divisible by "
"virtual_pipeline_model_parallel_size"
)
assert args.model_type != ModelType.encoder_and_decoder
# Number of layers in each model chunk is the number of layers in the stage,
# divided by the number of model chunks in a stage.
self.num_layers = (
self.num_layers // args.virtual_pipeline_model_parallel_size
)
# With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
# layers to stages like (each list is a model chunk):
# Stage 0: [0] [2] [4] [6]
# Stage 1: [1] [3] [5] [7]
# With 8 layers, 2 stages, and 2 model chunks, we want an assignment of
# layers to stages like (each list is a model chunk):
# Stage 0: [0, 1] [4, 5]
# Stage 1: [2, 3] [6, 7]
offset = parallel_state.get_virtual_pipeline_model_parallel_rank() * (
args.num_layers // args.virtual_pipeline_model_parallel_size
) + (parallel_state.get_pipeline_model_parallel_rank() * self.num_layers)
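# Illustrative check of the formula above (not part of the original source):
# with args.num_layers=8, two pipeline stages and
# virtual_pipeline_model_parallel_size=4, self.num_layers is 8 // 2 // 4 = 1
# layer per model chunk; on pipeline rank 1, virtual rank 2 the offset is
# 2 * (8 // 4) + 1 * 1 = 5, i.e. that chunk builds layer index 5, matching
# the "[5]" entry of Stage 1 in the example above.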
else:
# Each stage gets a contiguous set of layers.
if (
args.model_type == ModelType.encoder_and_decoder
and parallel_state.get_pipeline_model_parallel_world_size() > 1
):
pipeline_rank = parallel_state.get_pipeline_model_parallel_rank()
if layer_type == LayerType.encoder:
offset = pipeline_rank * self.num_layers
else:
num_ranks_in_enc = args.pipeline_model_parallel_split_rank
offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
else:
offset = (
parallel_state.get_pipeline_model_parallel_rank() * self.num_layers
)
if self.num_layers == 0:
# When a standalone embedding stage is used (e.g.,
# args.standalone_embedding_stage == True), virtual pipeline ranks
# on pipeline rank 0 will have zero transformer layers assigned to
# them. This results in the model's input and output tensors being
# the same, which will cause failure for certain output tensor
# optimizations (e.g., pipeline output deallocation). To remedy
# this, we assign a 'no-op' layer on these ranks, which will
# disconnect the input tensor from the output tensor.
self.num_layers = 1
self.layers = torch.nn.ModuleList([NoopTransformerLayer(1)])
else:
self.layers = torch.nn.ModuleList(
[build_layer(i + 1 + offset) for i in range(self.num_layers)]
)
if self.post_process and self.post_layer_norm:
# Final layer norm before output.
self.final_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
# no_persist_layer_norm=args.no_persist_layer_norm,
sequence_parallel_enabled=args.sequence_parallel,
)
def _get_layer(self, layer_number):
return self.layers[layer_number]
def _checkpointed_forward(
self, hidden_states, attention_mask, encoder_output, enc_dec_attn_mask
):
"""Forward method with activation checkpointing."""
def custom(start, end):
def custom_forward(*inputs):
x_ = inputs[0]
attention_mask = inputs[1]
encoder_output = inputs[2]
enc_dec_attn_mask = inputs[3]
for index in range(start, end):
layer = self._get_layer(index)
x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
return x_
return custom_forward
if self.recompute_method == "uniform":
# Uniformly divide the total number of Transformer layers and checkpoint
# the input activation of each divided chunk.
# This further reduces memory usage by reducing the number of checkpoints.
l = 0
while l < self.num_layers:
hidden_states = tensor_parallel.random.checkpoint(
custom(l, l + self.recompute_num_layers),
self.distribute_saved_activations,
hidden_states,
attention_mask,
encoder_output,
enc_dec_attn_mask,
)
l += self.recompute_num_layers
elif self.recompute_method == "block":
# Checkpoint the input activation of only a set number of individual
# Transformer layers and skip the rest.
# This makes fuller use of device memory by removing redundant re-computation.
for l in range(self.num_layers):
if l < self.recompute_num_layers:
hidden_states = tensor_parallel.random.checkpoint(
custom(l, l + 1),
self.distribute_saved_activations,
hidden_states,
attention_mask,
encoder_output,
enc_dec_attn_mask,
)
else:
hidden_states = custom(l, l + 1)(
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask
)
else:
raise ValueError("Invalid activation recompute method.")
return hidden_states
def set_input_tensor(self, input_tensor):
"""Set input tensor to be used instead of forward()'s input.
When doing pipeline parallelism the input from the previous
stage comes from communication, not from the input, so the
model's forward_step_func won't have it. This function is thus
used by internal code to bypass the input provided by the
forward_step_func"""
self.input_tensor = input_tensor
def forward(
self,
hidden_states,
attention_mask,
encoder_output=None,
enc_dec_attn_mask=None,
inference_params=None,
):
# hidden_states: [s, b, h]
# Checks.
if inference_params:
assert (
self.recompute_granularity is None
), "inference does not work with activation checkpointing"
if not self.pre_process:
# See set_input_tensor()
hidden_states = self.input_tensor
# Viewless tensor.
# - We only need to create a viewless tensor in the case of micro batch
# size (mbs) == 1, since in this case, 'hidden_states.transpose()'
# above creates a view tensor, and '.contiguous()' is a pass-through.
# For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
# the need to make it viewless.
#
# However, we don't explicitly check mbs == 1 here because
# make_viewless_tensor() has negligible overhead when its input
# is already viewless.
#
# - For the 'else' case above, calling make_viewless_tensor() here is
# likely redundant, since p2p_communication.py (likely originator)
# already creates viewless tensors. That said, make_viewless_tensor()
# is called here to be future-proof and corner-case-proof.
# hidden_states = mpu.make_viewless_tensor(hidden_states, requires_grad=True, keep_graph=True)
if self.sequence_parallel:
rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
else:
rng_context = contextlib.nullcontext()
with rng_context:
# Forward pass.
if self.recompute_granularity == "full":
hidden_states = self._checkpointed_forward(
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask
)
else:
for index in range(self.num_layers):
layer = self._get_layer(index)
hidden_states = layer(
hidden_states,
attention_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
inference_params=inference_params,
)
# Final layer norm.
if self.post_process and self.post_layer_norm:
hidden_states = self.final_layernorm(hidden_states)
return hidden_states
def get_num_layers(args, is_encoder_and_decoder_model):
"""Compute the number of transformer layers resident on the current rank."""
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
if is_encoder_and_decoder_model:
assert args.pipeline_model_parallel_split_rank is not None
# When a standalone embedding stage is used, a rank is taken from
# the encoder's ranks, to be used for the encoder's embedding
# layer. This way, the rank referenced by the 'split rank' remains
# the same whether or not a standalone embedding stage is used.
num_ranks_in_encoder = (
args.pipeline_model_parallel_split_rank - 1
if args.standalone_embedding_stage
else args.pipeline_model_parallel_split_rank
)
num_ranks_in_decoder = (
args.transformer_pipeline_model_parallel_size - num_ranks_in_encoder
)
assert args.num_layers % num_ranks_in_encoder == 0, (
"num_layers (%d) must be divisible by number of ranks given to encoder (%d)"
% (
args.num_layers,
num_ranks_in_encoder,
)
)
assert args.num_layers % num_ranks_in_decoder == 0, (
"num_layers (%d) must be divisible by number of ranks given to decoder (%d)"
% (
args.num_layers,
num_ranks_in_decoder,
)
)
if parallel_state.is_pipeline_stage_before_split():
num_layers = (
0
if args.standalone_embedding_stage
and parallel_state.get_pipeline_model_parallel_rank() == 0
else args.num_layers // num_ranks_in_encoder
)
else:
num_layers = args.num_layers // num_ranks_in_decoder
else:
assert (
args.num_layers % args.transformer_pipeline_model_parallel_size == 0
), "num_layers must be divisible by transformer_pipeline_model_parallel_size"
# When a standalone embedding stage is used, all transformer layers
# are divided among pipeline rank >= 1, while on pipeline rank 0,
# ranks either contain the input embedding layer (virtual pp rank 0),
# or no layers at all (virtual pp rank >= 1).
num_layers = (
0
if args.standalone_embedding_stage
and parallel_state.get_pipeline_model_parallel_rank() == 0
else args.num_layers // args.transformer_pipeline_model_parallel_size
)
else:
num_layers = args.num_layers
return num_layers
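# Worked example (illustrative, not part of the original source): for a
# decoder-only model with args.num_layers=24, no standalone embedding stage
# and transformer_pipeline_model_parallel_size=4, each pipeline rank builds
# 24 // 4 = 6 layers; with a standalone embedding stage, pipeline rank 0
# gets 0 transformer layers and ParallelTransformer substitutes a
# NoopTransformerLayer (see the class below).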
class NoopTransformerLayer(MegatronModule):
"""A single 'no-op' transformer layer.
The sole purpose of this layer is for when a standalone embedding layer
is used (i.e., args.standalone_embedding_stage == True). In this case,
zero transformer layers are assigned when pipeline rank == 0. Additionally,
when virtual pipeline rank >= 1, zero total model parameters are created
(virtual rank 0 contains the input embedding). This results in the model's
input and output tensors being the same, which causes an error when
performing certain memory optimizations on the output tensor (e.g.,
deallocating it). Thus, this layer disconnects the input from the output
via a clone. Since ranks containing a no-op layer are generally under-
utilized (both compute and memory), there's no worry of any performance
degradation.
"""
def __init__(self, layer_number):
super().__init__()
self.layer_number = layer_number
def forward(
self,
hidden_states,
attention_mask,
encoder_output=None,
enc_dec_attn_mask=None,
inference_params=None,
):
return hidden_states.clone()
def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None):
"""LM logits using word embedding weights."""
args = get_args()
# Parallel logits.
if args.async_tensor_model_parallel_allreduce or args.sequence_parallel:
input_parallel = input_
model_parallel = parallel_state.get_tensor_model_parallel_world_size() > 1
async_grad_allreduce = (
args.async_tensor_model_parallel_allreduce
and model_parallel
and not args.sequence_parallel
)
else:
input_parallel = tensor_parallel.copy_to_tensor_model_parallel_region(input_)
async_grad_allreduce = False
# Matrix multiply.
# logits_parallel = tensor_parallel.layers.LinearWithGradAccumulationAndAsyncCommunication.apply(
# input_parallel, word_embeddings_weight, bias, args.gradient_accumulation_fusion, async_grad_allreduce, args.sequence_parallel)
logits_parallel = (
tensor_parallel.layers.linear_with_grad_accumulation_and_async_allreduce(
input_parallel,
word_embeddings_weight,
bias,
args.gradient_accumulation_fusion,
async_grad_allreduce,
args.sequence_parallel,
)
)
# Gather if needed.
if parallel_output:
return logits_parallel
return tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel)
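# Illustrative shape sketch (not part of the original source): for an input of
# shape [s, b, h] and a vocab-parallel embedding weight of shape
# [padded_vocab_size // tp, h], logits_parallel has shape
# [s, b, padded_vocab_size // tp]; with parallel_output=True each rank keeps
# its vocab shard (as expected by vocab_parallel_cross_entropy), otherwise the
# shards are gathered into full [s, b, padded_vocab_size] logits.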
def get_language_model(
num_tokentypes,
add_pooler,
encoder_attn_mask_type,
init_method=None,
scaled_init_method=None,
add_encoder=True,
add_decoder=False,
decoder_attn_mask_type=AttnMaskType.causal,
pre_process=True,
post_process=True,
):
"""Build language model and return along with the key to save."""
args = get_args()
if init_method is None:
init_method = init_method_normal(args.init_method_std)
if scaled_init_method is None:
scaled_init_method = scaled_init_method_normal(
args.init_method_std, args.num_layers
)
# Language model.
language_model = TransformerLanguageModel(
init_method,
scaled_init_method,
encoder_attn_mask_type,
num_tokentypes=num_tokentypes,
add_encoder=add_encoder,
add_decoder=add_decoder,
decoder_attn_mask_type=decoder_attn_mask_type,
add_pooler=add_pooler,
pre_process=pre_process,
post_process=post_process,
)
# key used for checkpoints.
language_model_key = "language_model"
return language_model, language_model_key
class Pooler(MegatronModule):
"""Pooler layer.
Pool hidden states of a specific token (for example, the start of the
sequence) and add a linear transformation followed by a tanh.
Arguments:
hidden_size: hidden size
init_method: weight initialization method for the linear layer.
bias is set to zero.
"""
def __init__(self, hidden_size, init_method):
super().__init__()
args = get_args()
self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
self.sequence_parallel = args.sequence_parallel
def forward(self, hidden_states, sequence_index=0):
# hidden_states: [s, b, h]
# sequence_index: index of the token to pool.
# gather data along the sequence dimension
# same pooler is run on all tensor parallel nodes
if self.sequence_parallel:
hidden_states = tensor_parallel.mappings.gather_from_sequence_parallel_region(hidden_states)
pooled = hidden_states[sequence_index, :, :]
pooled = self.dense(pooled)
pooled = torch.tanh(pooled)
return pooled
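# Illustrative sketch (not part of the original source): for hidden_states of
# shape [s, b, h] and sequence_index=0, the pooler takes the [b, h] slice for
# the first token, applies the dense projection and tanh, and returns a
# [b, h] summary vector; with sequence parallelism the full sequence is
# gathered first, so the index refers to the global sequence position.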
class Embedding(MegatronModule):
"""Language model embeddings.
Arguments:
hidden_size: hidden size
vocab_size: vocabulary size
max_sequence_length: maximum size of sequence. This
is used for positional embedding
embedding_dropout_prob: dropout probability for embeddings
init_method: weight initialization method
num_tokentypes: size of the token-type embeddings. A value of 0
will ignore this embedding
"""
def __init__(
self,
hidden_size,
vocab_size,
max_sequence_length,
embedding_dropout_prob,
init_method,
num_tokentypes=0,
):
super().__init__()
self.hidden_size = hidden_size
self.init_method = init_method
self.num_tokentypes = num_tokentypes
args = get_args()
# Word embeddings (parallel).
self.word_embeddings = VocabParallelEmbedding(
vocab_size, self.hidden_size, init_method=self.init_method
)
self._word_embeddings_key = "word_embeddings"
# Position embedding (serial).
self.position_embeddings = torch.nn.Embedding(
max_sequence_length, self.hidden_size
)
self._position_embeddings_key = "position_embeddings"
# Initialize the position embeddings.
self.init_method(self.position_embeddings.weight)
# Token type embedding.
# Add this as an optional field that can be added through
# method call so we can load a pretrained model without
# token types and add them as needed.
self._tokentype_embeddings_key = "tokentype_embeddings"
if self.num_tokentypes > 0:
self.tokentype_embeddings = torch.nn.Embedding(
self.num_tokentypes, self.hidden_size
)
# Initialize the token-type embeddings.
self.init_method(self.tokentype_embeddings.weight)
else:
self.tokentype_embeddings = None
self.fp32_residual_connection = args.fp32_residual_connection
self.sequence_parallel = args.sequence_parallel
# Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
def zero_parameters(self):
"""Zero out all parameters in embedding."""
self.word_embeddings.weight.data.fill_(0)
self.word_embeddings.weight.shared = True
self.position_embeddings.weight.data.fill_(0)
self.position_embeddings.weight.shared = True
if self.num_tokentypes > 0:
self.tokentype_embeddings.weight.fill_(0)
self.tokentype_embeddings.weight.shared = True
def add_tokentype_embeddings(self, num_tokentypes):
"""Add token-type embedding. This function is provided so we can add
token-type embeddings in case the pretrained model does not have them.
This allows us to load the model normally and then add this embedding.
"""
if self.tokentype_embeddings is not None:
raise Exception("tokentype embeddings is already initialized")
if torch.distributed.get_rank() == 0:
print(
"adding embedding for {} tokentypes".format(num_tokentypes), flush=True
)
self.num_tokentypes = num_tokentypes
self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size)
# Initialize the token-type embeddings.
self.init_method(self.tokentype_embeddings.weight)
def forward(self, input_ids, position_ids, tokentype_ids=None):
# Embeddings.
words_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
embeddings = words_embeddings + position_embeddings
if tokentype_ids is not None:
assert self.tokentype_embeddings is not None
embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
else:
assert self.tokentype_embeddings is None
# Data format change to avoid explicit transposes: [b s h] --> [s b h].
embeddings = embeddings.transpose(0, 1).contiguous()
# If the input flag for fp32 residual connection is set, convert to float.
if self.fp32_residual_connection:
embeddings = embeddings.float()
# Dropout.
if self.sequence_parallel:
embeddings = scatter_to_sequence_parallel_region(embeddings)
with tensor_parallel.get_cuda_rng_tracker().fork():
embeddings = self.embedding_dropout(embeddings)
else:
embeddings = self.embedding_dropout(embeddings)
return embeddings
class TransformerLanguageModel(MegatronModule):
"""Transformer language model.
Arguments:
transformer_hparams: transformer hyperparameters
vocab_size: vocabulary size
max_sequence_length: maximum size of sequence. This
is used for positional embedding
embedding_dropout_prob: dropout probability for embeddings
num_tokentypes: size of the token-type embeddings. A value of 0
will ignore this embedding
"""
def __init__(
self,
init_method,
output_layer_init_method,
encoder_attn_mask_type,
num_tokentypes=0,
add_encoder=True,
add_decoder=False,
decoder_attn_mask_type=AttnMaskType.causal,
add_pooler=False,
pre_process=True,
post_process=True,
):
super().__init__()
args = get_args()
self.pre_process = pre_process
self.post_process = post_process
self.hidden_size = args.hidden_size
self.num_tokentypes = num_tokentypes
self.init_method = init_method
self.add_encoder = add_encoder
self.encoder_attn_mask_type = encoder_attn_mask_type
self.add_decoder = add_decoder
self.decoder_attn_mask_type = decoder_attn_mask_type
self.add_pooler = add_pooler
self.encoder_hidden_state = None
# Embeddings.
if self.pre_process:
self.embedding = Embedding(
self.hidden_size,
args.padded_vocab_size,
args.max_position_embeddings,
args.hidden_dropout,
self.init_method,
self.num_tokentypes,
)
self._embedding_key = "embedding"
# Transformer.
# Encoder (usually set to True, False if part of an encoder-decoder
# architecture and in a decoder-only stage).
if self.add_encoder:
self.encoder = ParallelTransformer(
self.init_method,
output_layer_init_method,
self_attn_mask_type=self.encoder_attn_mask_type,
pre_process=self.pre_process,
post_process=self.post_process,
)
self._encoder_key = "encoder"
else:
self.encoder = None
# Decoder (usually set to False, True if part of an encoder-decoder
# architecture and in decoder-only stage).
if self.add_decoder:
self.decoder = ParallelTransformer(
self.init_method,
output_layer_init_method,
layer_type=LayerType.decoder,
self_attn_mask_type=self.decoder_attn_mask_type,
pre_process=self.pre_process,
post_process=self.post_process,
)
self._decoder_key = "decoder"
else:
self.decoder = None
if self.post_process:
# Pooler.
if self.add_pooler:
self.pooler = Pooler(self.hidden_size, self.init_method)
self._pooler_key = "pooler"
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
# This is usually handled in schedules.py but some inference code still
# gives us non-lists or None
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
if self.add_encoder and self.add_decoder:
assert (
len(input_tensor) == 1
), "input_tensor should only be length 1 for stage with both encoder and decoder"
self.encoder.set_input_tensor(input_tensor[0])
elif self.add_encoder:
assert (
len(input_tensor) == 1
), "input_tensor should only be length 1 for stage with only encoder"
self.encoder.set_input_tensor(input_tensor[0])
elif self.add_decoder:
if len(input_tensor) == 2:
self.decoder.set_input_tensor(input_tensor[0])
self.encoder_hidden_state = input_tensor[1]
elif len(input_tensor) == 1:
self.decoder.set_input_tensor(None)
self.encoder_hidden_state = input_tensor[0]
else:
raise Exception("input_tensor must have either length 1 or 2")
else:
raise Exception("Stage must have at least either encoder or decoder")
def forward(
self,
enc_input_ids,
enc_position_ids,
enc_attn_mask,
dec_input_ids=None,
dec_position_ids=None,
dec_attn_mask=None,
enc_dec_attn_mask=None,
tokentype_ids=None,
inference_params=None,
pooling_sequence_index=0,
enc_hidden_states=None,
output_enc_hidden=False,
):
args = get_args()
# Encoder embedding.
if self.pre_process:
encoder_input = self.embedding(
enc_input_ids, enc_position_ids, tokentype_ids=tokentype_ids
)
else:
encoder_input = None
# Run encoder.
if enc_hidden_states is None:
if self.encoder is not None:
encoder_output = self.encoder(
encoder_input, enc_attn_mask, inference_params=inference_params
)
else:
encoder_output = self.encoder_hidden_state
else:
encoder_output = enc_hidden_states.to(encoder_input.dtype)
if self.post_process:
if self.add_pooler:
pooled_output = self.pooler(encoder_output, pooling_sequence_index)
# output_enc_hidden refers to when we just need the encoder's
# output. For example, it is helpful when computing
# similarity between two sequences by average pooling.
if not self.add_decoder or output_enc_hidden:
if self.add_pooler and self.post_process:
return encoder_output, pooled_output
else:
return encoder_output
# Decoder embedding.
if self.pre_process:
decoder_input = self.embedding(dec_input_ids, dec_position_ids)
else:
decoder_input = None
# Run decoder.
decoder_output = self.decoder(
decoder_input,
dec_attn_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
inference_params=inference_params,
)
if self.add_pooler and self.post_process:
return decoder_output, encoder_output, pooled_output
else:
return decoder_output, encoder_output
def post_language_model_processing(
lm_output, labels, logit_weights, parallel_output, fp16_lm_cross_entropy
):
# Output.
output = parallel_lm_logits(lm_output, logit_weights, parallel_output)
if labels is None:
return output
else:
if fp16_lm_cross_entropy:
assert output.dtype == torch.half
loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels)
else:
loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels)
return loss
def module_size(m: torch.nn.Module, only_trainable: bool = False):
"""
returns the total number of parameters used by `m` (only counting
shared parameters once); if `only_trainable` is True, then only
includes parameters with `requires_grad = True`
"""
parameters = list(m.parameters())
if only_trainable:
parameters = [p for p in parameters if p.requires_grad]
unique = {p.data_ptr(): p for p in parameters}.values()
return sum(p.numel() for p in unique)
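# Illustrative usage sketch (not part of the original source):
#   linear = torch.nn.Linear(4, 4)
#   module_size(linear)   # 4*4 weights + 4 biases = 20 parameters
# Tying two modules to the same Parameter still counts it once, since
# parameters are de-duplicated by data_ptr().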
......@@ -6,7 +6,9 @@ from apex.transformer import parallel_state
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
assert numerator % denominator == 0, "{} is not divisible by {}".format(
numerator, denominator
)
def divide(numerator, denominator):
......@@ -19,7 +21,9 @@ def divide(numerator, denominator):
def split_tensor_into_1d_equal_chunks(tensor):
"""Break a tensor into equal 1D chunks."""
data = tensor.view(-1)
partition_size = torch.numel(data) // parallel_state.get_tensor_model_parallel_world_size()
partition_size = (
torch.numel(data) // parallel_state.get_tensor_model_parallel_world_size()
)
start_index = partition_size * parallel_state.get_tensor_model_parallel_rank()
end_index = start_index + partition_size
return data[start_index:end_index]
......@@ -30,7 +34,15 @@ def gather_split_1d_tensor(tensor):
world_size = parallel_state.get_tensor_model_parallel_world_size()
numel = torch.numel(tensor)
numel_gathered = world_size * numel
gathered = torch.empty(numel_gathered, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False)
chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)]
torch.distributed.all_gather(chunks, tensor, group=parallel_state.get_tensor_model_parallel_group())
gathered = torch.empty(
numel_gathered,
dtype=tensor.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
torch.distributed._all_gather_base(
gathered,
tensor,
group=parallel_state.get_tensor_model_parallel_group()
)
return gathered
......@@ -1227,5 +1227,3 @@ void cuda_rms_norm_gradient(
gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL);
)
}
#include <torch/extension.h>
#include <cstdio>
#include <vector>
void wgrad_gemm_accum_fp32_cuda_stub(
at::Tensor &input_2d,
at::Tensor &d_output_2d,
at::Tensor &d_weight
);
void wgrad_gemm_accum_fp16_cuda_stub(
at::Tensor &input_2d,
at::Tensor &d_output_2d,
at::Tensor &d_weight
);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("wgrad_gemm_accum_fp32", &wgrad_gemm_accum_fp32_cuda_stub, "wgrad gemm accum in fp32");
m.def("wgrad_gemm_accum_fp16", &wgrad_gemm_accum_fp16_cuda_stub, "wgrad gemm accum in fp16");
}
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
/* Includes, cuda */
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include "type_shim.h"
// BF16 inputs and BF16 accumulation
void gemmex_wrapper_fp16(
cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float* alpha,
at::BFloat16* A,
int lda,
at::BFloat16* B,
int ldb,
const float* beta,
at::BFloat16* C,
int ldc) {
TORCH_CUDABLAS_CHECK(cublasGemmEx(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
CUDA_R_16BF,
lda,
B,
CUDA_R_16BF,
ldb,
beta,
C,
CUDA_R_16BF,
ldc,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
// FP16 inputs and FP16 accumulation
void gemmex_wrapper_fp16(
cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float* alpha,
at::Half* A,
int lda,
at::Half* B,
int ldb,
const float* beta,
at::Half* C,
int ldc) {
TORCH_CUDABLAS_CHECK(cublasGemmEx(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
CUDA_R_16F,
lda,
B,
CUDA_R_16F,
ldb,
beta,
C,
CUDA_R_16F,
ldc,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
template <typename T>
void wgrad_gemm_accum_fp16_cuda(T *input, T *d_output, T *d_weight, int in_dim, int hidden_dim, int out_dim) {
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream;
cublasGetStream(handle, &stream);
const float alpha = 1.0;
const float beta = 1.0;
gemmex_wrapper_fp16(
handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
in_dim,
out_dim,
hidden_dim,
&alpha,
input,
in_dim,
d_output,
out_dim,
&beta,
d_weight,
in_dim);
}
template void wgrad_gemm_accum_fp16_cuda<at::Half>(at::Half *input, at::Half *d_output, at::Half *d_weight, int in_dim, int hidden_dim, int out_dim);
template void wgrad_gemm_accum_fp16_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *d_output, at::BFloat16 *d_weight, int in_dim, int hidden_dim, int out_dim);
void wgrad_gemm_accum_fp16_cuda_stub(
at::Tensor &input,
at::Tensor &d_output,
at::Tensor &d_weight
) {
at::Tensor input_2d, d_output_2d;
// input tensor: collapse leading dimensions into the first dim
auto in_sizes = input.sizes();
if (input.dim() > 2) {
input_2d = input.view({-1, in_sizes[in_sizes.size() - 1]});
} else {
input_2d = input;
}
// d_output tensor: collapse leading dimensions into the first dim
auto d_out_sizes = d_output.sizes();
if (d_output.dim() > 2) {
d_output_2d = d_output.view({-1, d_out_sizes[d_out_sizes.size() - 1]});
} else {
d_output_2d = d_output;
}
const int hidden_dim = input_2d.size(0);
const int in_dim = input_2d.size(1);
const int out_dim = d_weight.size(0);
DISPATCH_HALF_AND_BFLOAT(input_2d.scalar_type(), "wgrad_gemm_accum_fp16",
wgrad_gemm_accum_fp16_cuda<scalar_t>(
input_2d.data_ptr<scalar_t>(),
d_output_2d.data_ptr<scalar_t>(),
d_weight.data_ptr<scalar_t>(),
in_dim,
hidden_dim,
out_dim);
);
}
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
/* Includes, cuda */
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include "type_shim.h"
// BF16 Tensor core wrapper around cublas GEMMEx
void gemmex_wrapper(
cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float* alpha,
at::BFloat16* A,
int lda,
at::BFloat16* B,
int ldb,
const float* beta,
float* C,
int ldc) {
TORCH_CUDABLAS_CHECK(cublasGemmEx(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
CUDA_R_16BF,
lda,
B,
CUDA_R_16BF,
ldb,
beta,
C,
CUDA_R_32F,
ldc,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
// FP16 Tensor core wrapper around cublas GEMMEx
void gemmex_wrapper(
cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float* alpha,
at::Half* A,
int lda,
at::Half* B,
int ldb,
const float* beta,
float* C,
int ldc) {
TORCH_CUDABLAS_CHECK(cublasGemmEx(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
CUDA_R_16F,
lda,
B,
CUDA_R_16F,
ldb,
beta,
C,
CUDA_R_32F,
ldc,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
// FP32 wrapper around cublas GEMMEx
void gemmex_wrapper(
cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float *alpha,
float *A,
int lda,
float *B,
int ldb,
const float *beta,
float *C,
int ldc) {
TORCH_CUDABLAS_CHECK(cublasGemmEx(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
CUDA_R_32F,
lda,
B,
CUDA_R_32F,
ldb,
beta,
C,
CUDA_R_32F,
ldc,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
template <typename T>
void wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim) {
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream;
cublasGetStream(handle, &stream);
const float alpha = 1.0;
const float beta = 1.0;
gemmex_wrapper(
handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
in_dim,
out_dim,
hidden_dim,
&alpha,
input,
in_dim,
d_output,
out_dim,
&beta,
d_weight,
in_dim);
}
template void wgrad_gemm_accum_fp32_cuda<at::Half>(at::Half *input, at::Half *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
template void wgrad_gemm_accum_fp32_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
template void wgrad_gemm_accum_fp32_cuda<float>(float *input, float *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
void wgrad_gemm_accum_fp32_cuda_stub(
at::Tensor &input,
at::Tensor &d_output,
at::Tensor &d_weight
) {
at::Tensor input_2d, d_output_2d;
// input tensor: collapse leading dimensions into the first dim
auto in_sizes = input.sizes();
if (input.dim() > 2) {
input_2d = input.view({-1, in_sizes[in_sizes.size() - 1]});
} else {
input_2d = input;
}
// d_output tensor: collapse to the first dim
auto d_out_sizes = d_output.sizes();
if (d_output.dim() > 2) {
d_output_2d = d_output.view({-1, d_out_sizes[d_out_sizes.size() - 1]});
} else {
d_output_2d = d_output;
}
const int hidden_dim = input_2d.size(0);
const int in_dim = input_2d.size(1);
const int out_dim = d_weight.size(0);
DISPATCH_FLOAT_HALF_AND_BFLOAT(input_2d.scalar_type(), 0, "wgrad_gemm_accum_fp32",
wgrad_gemm_accum_fp32_cuda<scalar_t_0>(
input_2d.data_ptr<scalar_t_0>(),
d_output_2d.data_ptr<scalar_t_0>(),
d_weight.data_ptr<float>(),
in_dim,
hidden_dim,
out_dim);
);
}
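In row-major terms, the cuBLAS call above computes `d_weight += d_output_2d^T @ input_2d` with `alpha = beta = 1`, i.e. a linear layer's weight gradient accumulated into an FP32 buffer. A pure-PyTorch reference for checking that math (a sketch only; it does not call the extension, and the shapes are arbitrary):

import torch

batch, seq, in_dim, out_dim = 4, 8, 32, 16
x = torch.randn(batch, seq, in_dim, dtype=torch.half)
dy = torch.randn(batch, seq, out_dim, dtype=torch.half)
d_weight = torch.zeros(out_dim, in_dim, dtype=torch.float32)   # FP32 accumulator

x_2d = x.view(-1, in_dim)          # collapse leading dims, as the stub does
dy_2d = dy.view(-1, out_dim)
d_weight += dy_2d.float().t() @ x_2d.float()   # what the fused GEMM accumulates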
......@@ -62,7 +62,7 @@ std::vector<at::Tensor> mlp_forward(int use_bias, int activation, std::vector<at
// create output/workspace tensor
auto out = at::empty({batch_size, output_features.back()}, inputs[0].type());
auto reserved_space = at::empty({reserved_size}, inputs[0].type());
auto reserved_space = at::empty({static_cast<long>(reserved_size)}, inputs[0].type());
// allocate a fixed 4MB workspace for cublaslt for now; this guarantees at least 4 MB
auto lt_workspace = at::empty({1 << 22}, inputs[0].type());
......@@ -135,7 +135,7 @@ std::vector<at::Tensor> mlp_backward(
get_mlp_bp_workspace_in_bytes<scalar_t>(batch_size, num_layers, output_features.data());
// auto work_space = at::empty({work_size*4}, at::kByte);
auto work_space = at::empty({work_size / sizeof(scalar_t)}, inputs[0].type());
auto work_space = at::empty({static_cast<long>(work_size / sizeof(scalar_t))}, inputs[0].type());
auto result = mlp_bp<scalar_t>(
inputs[0].data_ptr<scalar_t>(),
......
......@@ -149,7 +149,7 @@ void multi_tensor_adam_cuda(
}
// Assume single type across p,g,m1,m2 now
DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16(
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
tensor_lists[0][0].scalar_type(), 0, "adam",
multi_tensor_apply<4>(
BLOCK_SIZE,
......
......@@ -63,7 +63,7 @@ void multi_tensor_apply(
// TODO: Print which tensor fails.
bool contiguous_memory = (tensor_lists[l][t].is_sparse()) ? tensor_lists[l][t]._values().is_contiguous() : tensor_lists[l][t].is_contiguous();
#ifdef VERSION_GE_1_5
contiguous_memory = (contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast));
contiguous_memory = (contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast) || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast3d));
#endif
TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
TORCH_CHECK(tensor_lists[l][t].device() == ref_device, "A tensor was not on the same device as the first tensor");
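The added `ChannelsLast`/`ChannelsLast3d` cases accept dense tensors that report `False` from the default contiguity check. A quick standalone PyTorch illustration (not from this diff):

import torch

t4d = torch.randn(2, 3, 4, 4).to(memory_format=torch.channels_last)
print(t4d.is_contiguous())                                      # False
print(t4d.is_contiguous(memory_format=torch.channels_last))     # True

t5d = torch.randn(2, 3, 4, 4, 4).to(memory_format=torch.channels_last_3d)
print(t5d.is_contiguous(memory_format=torch.channels_last_3d))  # True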
......
......@@ -112,6 +112,38 @@
}
#define DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
......
......@@ -44,7 +44,6 @@ Some other useful material, including GTC 2019 and Pytorch DevCon 2019 Slides, c
:caption: Deprecated mixed precision API
fp16_util
.. reparameterization
.. RNN
Indices and tables
......
# Base image must at least have pytorch and CUDA installed.
ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:19.07-py3
ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:22.02-py3
FROM $BASE_IMAGE
ARG BASE_IMAGE
RUN echo "Installing Apex on top of ${BASE_IMAGE}"
......
......@@ -3,3 +3,4 @@ tqdm>=4.28.1
numpy>=1.15.3
PyYAML>=5.1
pytest>=3.5.1
packaging>=14.0
......@@ -31,6 +31,7 @@ if os.path.exists(context_file):
found_Backward_Pass_Guard = True
break
def get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
output = raw_output.split()
......@@ -59,17 +60,21 @@ if not torch.cuda.is_available() and not IS_ROCM_PYTORCH:
# https://github.com/NVIDIA/apex/issues/486
# Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),
# which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command).
print('\nWarning: Torch did not find available GPUs on this system.\n',
'If your intention is to cross-compile, this is not an error.\n'
'By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n'
'Volta (compute capability 7.0), Turing (compute capability 7.5),\n'
'and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n'
'If you wish to cross-compile for a single specific architecture,\n'
'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n')
print(
"\nWarning: Torch did not find available GPUs on this system.\n",
"If your intention is to cross-compile, this is not an error.\n"
"By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n"
"Volta (compute capability 7.0), Turing (compute capability 7.5),\n"
"and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n"
"If you wish to cross-compile for a single specific architecture,\n"
'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
)
if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
_, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) == 11:
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
if int(bare_metal_minor) > 0:
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
else:
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
elif not torch.cuda.is_available() and IS_ROCM_PYTORCH:
......@@ -79,28 +84,14 @@ elif not torch.cuda.is_available() and IS_ROCM_PYTORCH:
'used by default in ROCm PyTorch\n')
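As the cross-compilation warning above notes, a build on a machine without visible GPUs can still target specific architectures by exporting `TORCH_CUDA_ARCH_LIST` before running setup.py. A hypothetical invocation for a single architecture (sm_80), shown as a sketch rather than a supported recipe:

import os
import subprocess

os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0"   # cross-compile for compute capability 8.0 only
subprocess.check_call(["python", "setup.py", "install", "--cpp_ext", "--cuda_ext"])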
if TORCH_MAJOR == 0 and TORCH_MINOR < 4:
raise RuntimeError("Apex requires Pytorch 0.4 or newer.\n" +
"The latest stable release can be obtained from https://pytorch.org/")
raise RuntimeError(
"Apex requires Pytorch 0.4 or newer.\nThe latest stable release can be obtained from https://pytorch.org/"
)
cmdclass = {}
ext_modules = []
extras = {}
if "--pyprof" in sys.argv:
string = "\n\nPyprof has been moved to its own dedicated repository and will " + \
"soon be removed from Apex. Please visit\n" + \
"https://github.com/NVIDIA/PyProf\n" + \
"for the latest version."
warnings.warn(string, DeprecationWarning)
with open('requirements.txt') as f:
required_packages = f.read().splitlines()
extras['pyprof'] = required_packages
try:
sys.argv.remove("--pyprof")
except:
pass
else:
warnings.warn("Option --pyprof not specified. Not installing PyProf dependencies!")
if "--cpp_ext" in sys.argv or "--cuda_ext" in sys.argv:
if TORCH_MAJOR == 0:
......@@ -109,35 +100,7 @@ if "--cpp_ext" in sys.argv or "--cuda_ext" in sys.argv:
cmdclass['build_ext'] = BuildExtension
if "--cpp_ext" in sys.argv:
sys.argv.remove("--cpp_ext")
ext_modules.append(
CppExtension('apex_C',
['csrc/flatten_unflatten.cpp',]))
def get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
output = raw_output.split()
release_idx = output.index("release") + 1
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
return raw_output, bare_metal_major, bare_metal_minor
def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
torch_binary_major = torch.version.cuda.split(".")[0]
torch_binary_minor = torch.version.cuda.split(".")[1]
print("\nCompiling cuda extensions with")
print(raw_output + "from " + cuda_dir + "/bin\n")
if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
raise RuntimeError("Cuda extensions are being compiled with a version of Cuda that does " +
"not match the version used to compile Pytorch binaries. " +
"Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda) +
"In some cases, a minor-version mismatch will not cause later errors: " +
"https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. "
"You can try commenting out this check (at your own risk).")
ext_modules.append(CppExtension("apex_C", ["csrc/flatten_unflatten.cpp"]))
# Set up macros for forward/backward compatibility hack around
# https://github.com/pytorch/pytorch/commit/4404762d7dd955383acee92e6f06b48144a0742e
......@@ -146,13 +109,13 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
# https://github.com/pytorch/pytorch/commit/eb7b39e02f7d75c26d8a795ea8c7fd911334da7e#diff-4632522f237f1e4e728cb824300403ac
version_ge_1_1 = []
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0):
version_ge_1_1 = ['-DVERSION_GE_1_1']
version_ge_1_1 = ["-DVERSION_GE_1_1"]
version_ge_1_3 = []
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
version_ge_1_3 = ['-DVERSION_GE_1_3']
version_ge_1_3 = ["-DVERSION_GE_1_3"]
version_ge_1_5 = []
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4):
version_ge_1_5 = ['-DVERSION_GE_1_5']
version_ge_1_5 = ["-DVERSION_GE_1_5"]
version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5
if "--distributed_adam" in sys.argv or "--cuda_ext" in sys.argv:
......@@ -393,11 +356,19 @@ if "--deprecated_fused_lamb" in sys.argv or "--cuda_ext" in sys.argv:
# Check whether ATen/CUDAGeneratorImpl.h is found; otherwise use ATen/cuda/CUDAGeneratorImpl.h
# See https://github.com/pytorch/pytorch/pull/70650
generator_flag = []
if os.path.exists(os.path.join(torch_dir, "include", "ATen", "cuda", "CUDAGeneratorImpl.h")):
generator_flag = ["-DNEW_GENERATOR_PATH"]
torch_dir = torch.__path__[0]
if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
generator_flag = ["-DOLD_GENERATOR_PATH"]
if "--fast_layer_norm" in sys.argv:
sys.argv.remove("--fast_layer_norm")
raise_if_cuda_home_none("--fast_layer_norm")
# Check whether CUDA 11 is installed for compute capability 8.0
cc_flag = []
_, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11:
cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80")
if CUDA_HOME is None and not IS_ROCM_PYTORCH:
raise RuntimeError("--fast_layer_norm was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
......@@ -409,28 +380,16 @@ if "--fast_layer_norm" in sys.argv:
cc_flag.append('-gencode')
cc_flag.append('arch=compute_80,code=sm_80')
ext_modules.append(
CUDAExtension(name='fast_layer_norm',
sources=['apex/contrib/csrc/layer_norm/ln_api.cpp',
'apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu',
'apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu',
],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'-U__CUDA_NO_BFLOAT16_OPERATORS__',
'-U__CUDA_NO_BFLOAT16_CONVERSIONS__',
'-U__CUDA_NO_BFLOAT162_OPERATORS__',
'-U__CUDA_NO_BFLOAT162_CONVERSIONS__',
'-I./apex/contrib/csrc/layer_norm/',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/layer_norm")]))
if "--fmha" in sys.argv:
sys.argv.remove("--fmha")
raise_if_cuda_home_none("--fmha")
# Check whether CUDA 11 is installed for compute capability 8.0
cc_flag = []
_, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) < 11:
raise RuntimeError("--fmha only supported on SM80")
cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80")
if CUDA_HOME is None and not IS_ROCM_PYTORCH:
raise RuntimeError("--fmha was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
......@@ -532,54 +491,104 @@ if "--fast_multihead_attn" in sys.argv or "--cuda_ext" in sys.argv:
if "--transducer" in sys.argv:
sys.argv.remove("--transducer")
raise_if_cuda_home_none("--transducer")
ext_modules.append(
CUDAExtension(
name="transducer_joint_cuda",
sources=[
"apex/contrib/csrc/transducer/transducer_joint.cpp",
"apex/contrib/csrc/transducer/transducer_joint_kernel.cu",
],
extra_compile_args={
"cxx": ["-O3"] + version_dependent_macros + generator_flag,
"nvcc": append_nvcc_threads(["-O3"] + version_dependent_macros + generator_flag),
},
include_dirs=[os.path.join(this_dir, "csrc"), os.path.join(this_dir, "apex/contrib/csrc/multihead_attn")],
)
)
ext_modules.append(
CUDAExtension(
name="transducer_loss_cuda",
sources=[
"apex/contrib/csrc/transducer/transducer_loss.cpp",
"apex/contrib/csrc/transducer/transducer_loss_kernel.cu",
],
include_dirs=[os.path.join(this_dir, "csrc")],
extra_compile_args={
"cxx": ["-O3"] + version_dependent_macros,
"nvcc": append_nvcc_threads(["-O3"] + version_dependent_macros),
},
)
)
if CUDA_HOME is None and not IS_ROCM_PYTORCH:
raise RuntimeError("--transducer was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
ext_modules.append(
CUDAExtension(name='transducer_joint_cuda',
sources=['apex/contrib/csrc/transducer/transducer_joint.cpp',
'apex/contrib/csrc/transducer/transducer_joint_kernel.cu'],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc': ['-O3'] + version_dependent_macros},
include_dirs=[os.path.join(this_dir, 'csrc'), os.path.join(this_dir, "apex/contrib/csrc/multihead_attn")]))
ext_modules.append(
CUDAExtension(name='transducer_loss_cuda',
sources=['apex/contrib/csrc/transducer/transducer_loss.cpp',
'apex/contrib/csrc/transducer/transducer_loss_kernel.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3'] + version_dependent_macros}))
# note (mkozuki): The `--fast_bottleneck` option (i.e. apex/contrib/bottleneck) now depends on `--peer_memory` and `--nccl_p2p`; a build-command sketch follows the `--nccl_p2p` block below.
if "--fast_bottleneck" in sys.argv:
sys.argv.remove("--fast_bottleneck")
raise_if_cuda_home_none("--fast_bottleneck")
if check_cudnn_version_and_warn("--fast_bottleneck", 8400):
subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/cudnn-frontend/"])
ext_modules.append(
CUDAExtension(
name="fast_bottleneck",
sources=["apex/contrib/csrc/bottleneck/bottleneck.cpp"],
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/cudnn-frontend/include")],
extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag},
)
)
if CUDA_HOME is None and not IS_ROCM_PYTORCH:
raise RuntimeError("--fast_bottleneck was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
if "--peer_memory" in sys.argv:
sys.argv.remove("--peer_memory")
raise_if_cuda_home_none("--peer_memory")
ext_modules.append(
CUDAExtension(
name="peer_memory_cuda",
sources=[
"apex/contrib/csrc/peer_memory/peer_memory_cuda.cu",
"apex/contrib/csrc/peer_memory/peer_memory.cpp",
],
extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag},
)
)
if "--nccl_p2p" in sys.argv:
sys.argv.remove("--nccl_p2p")
raise_if_cuda_home_none("--nccl_p2p")
ext_modules.append(
CUDAExtension(
name="nccl_p2p_cuda",
sources=[
"apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu",
"apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp",
],
extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag},
)
)
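Per the note above that `--fast_bottleneck` depends on `--peer_memory` and `--nccl_p2p`, a hypothetical build invocation passing the three flags together (a sketch only; the flag spellings are the ones handled elsewhere in this setup.py):

import subprocess

subprocess.check_call(
    ["python", "setup.py", "install", "--peer_memory", "--nccl_p2p", "--fast_bottleneck"]
)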
if "--fused_conv_bias_relu" in sys.argv:
sys.argv.remove("--fused_conv_bias_relu")
raise_if_cuda_home_none("--fused_conv_bias_relu")
if check_cudnn_version_and_warn("--fused_conv_bias_relu", 8400):
subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/cudnn-frontend/"])
ext_modules.append(
CUDAExtension(name='fast_bottleneck',
sources=['apex/contrib/csrc/bottleneck/bottleneck.cpp'],
include_dirs=[os.path.join(this_dir, 'apex/contrib/csrc/cudnn-frontend/include')],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag}))
CUDAExtension(
name="fused_conv_bias_relu",
sources=["apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp"],
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/cudnn-frontend/include")],
extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag},
)
)
if "--cuda_ext" in sys.argv:
sys.argv.remove("--cuda_ext")
setup(
name='apex',
version='0.1',
packages=find_packages(exclude=('build',
'csrc',
'include',
'tests',
'dist',
'docs',
'tests',
'examples',
'apex.egg-info',)),
description='PyTorch Extensions written by NVIDIA',
name="apex",
version="0.1",
packages=find_packages(
exclude=("build", "csrc", "include", "tests", "dist", "docs", "tests", "examples", "apex.egg-info",)
),
description="PyTorch Extensions written by NVIDIA",
ext_modules=ext_modules,
cmdclass=cmdclass,
#cmdclass={'build_ext': BuildExtension} if ext_modules else {},
......
......@@ -8,7 +8,7 @@ import torch.optim as optim
from apex import amp
from utils import common_init, FLOAT
from apex.testing.common_utils import skipFlakyTest
class MyModel(torch.nn.Module):
def __init__(self):
......@@ -161,6 +161,7 @@ class TestCheckpointing(unittest.TestCase):
# skip tests for different opt_levels
continue
@skipFlakyTest
def test_loss_scale_decrease(self):
num_losses = 3
nb_decrease_loss_scales = [0, 1, 2]
......
......@@ -4,7 +4,7 @@ import unittest
import torch
import apex
from apex.testing.common_utils import skipFlakyTest
class TestFusedLayerNorm(unittest.TestCase):
dtype = torch.float
......@@ -180,6 +180,7 @@ class TestMixedFusedRMSNormElemWise(TestFusedRMSNorm):
elementwise_affine = True
mixed_fused = True
@skipFlakyTest
class TestFusedRMSNormElemWiseHalf(TestFusedRMSNormElemWise):
dtype = torch.half
bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3)
......@@ -188,6 +189,7 @@ class TestFusedRMSNormElemWiseHalf(TestFusedRMSNormElemWise):
self.skipTest("Skip to save time")
@skipFlakyTest
class TestFusedLayerNormElemWiseBFloat16(TestFusedLayerNormElemWise):
dtype = torch.bfloat16
# NOTE (mkozuki): [BFloat16 Layer Norm flakiness]
......
......@@ -7,7 +7,7 @@ import torch
from torch import nn
from apex.mlp import MLP
from apex.testing.common_utils import skipIfRocm
from apex.testing.common_utils import skipFlakyTest
batch_size = 1024
mlp_sizes = [480, 1024, 1024, 512, 256, 1]
......@@ -18,7 +18,7 @@ class TestMLP(unittest.TestCase):
def test_creation(self):
MLP(mlp_sizes)
@skipIfRocm
@skipFlakyTest
def test_numeric(self):
mlp = MLP(mlp_sizes).cuda()
......@@ -53,7 +53,7 @@ class TestMLP(unittest.TestCase):
ref_mlp[0].bias.grad.detach().cpu().numpy(),
atol=1e-7, rtol=1e-5)
@skipIfRocm
@skipFlakyTest
def test_no_bias(self):
for use_activation in ['none', 'relu', 'sigmoid']:
mlp = MLP(mlp_sizes, bias=False, activation=use_activation).cuda()
......@@ -91,7 +91,7 @@ class TestMLP(unittest.TestCase):
ref_mlp[0].weight.grad.detach().cpu().numpy(),
atol=1e-7, rtol=100)
@skipIfRocm
@skipFlakyTest
def test_with_bias(self):
for use_activation in ['none', 'relu', 'sigmoid']:
mlp = MLP(mlp_sizes, bias=True, activation=use_activation).cuda()
......@@ -134,7 +134,7 @@ class TestMLP(unittest.TestCase):
ref_mlp[0].bias.grad.detach().cpu().numpy(),
atol=1e-7, rtol=1e-5)
@skipIfRocm
@skipFlakyTest
def test_no_grad(self):
mlp = MLP(mlp_sizes).cuda()
......@@ -165,7 +165,6 @@ class TestMLP(unittest.TestCase):
ref_mlp[0].weight.grad.detach().cpu().numpy(),
atol=1e-7, rtol=1e-5)
@skipIfRocm
def test_performance_half(self):
mlp = MLP(mlp_sizes).cuda().half()
......@@ -195,7 +194,7 @@ class TestMLP(unittest.TestCase):
mlp.zero_grad()
test_loss.backward()
torch.cuda.profiler.start()
#torch.cuda.profiler.start()
torch.cuda.synchronize()
start_time = time()
for _ in range(num_iters):
......@@ -217,7 +216,7 @@ class TestMLP(unittest.TestCase):
torch.cuda.synchronize()
stop_time = time()
print(F"C++ MLP time {(stop_time - start_time) * 1000. / num_iters:.4f} ms")
torch.cuda.profiler.stop()
#torch.cuda.profiler.stop()
if __name__ == '__main__':
unittest.main()