"vscode:/vscode.git/clone" did not exist on "d481f2d88f16c6a6c1f143629876bf536072921b"
Commit 58edb19a authored by Vijay Korthikanti's avatar Vijay Korthikanti
Browse files

Merge branch 'main' into vision_transformer

parents adbba962 f34cc86b
......@@ -165,6 +165,20 @@ def parse_args(extra_args_provider=None, defaults={},
_check_arg_is_not_none(args, req_arg)
# Checks.
if args.ffn_hidden_size is None:
args.ffn_hidden_size = 4 * args.hidden_size
if args.kv_channels is None:
assert args.hidden_size % args.num_attention_heads == 0
args.kv_channels = args.hidden_size // args.num_attention_heads
if args.seq_length is not None:
assert args.encoder_seq_length is None
args.encoder_seq_length = args.seq_length
else:
assert args.encoder_seq_length is not None
args.seq_length = args.encoder_seq_length
assert args.hidden_size % args.num_attention_heads == 0
if args.seq_length is not None:
assert args.max_position_embeddings >= args.seq_length
......@@ -183,16 +197,11 @@ def parse_args(extra_args_provider=None, defaults={},
assert args.checkpoint_activations, \
'for distribute-checkpointed-activations to work you '\
'need to enable checkpoint-activations'
if args.scaled_masked_softmax_fusion:
if args.scaled_upper_triang_masked_softmax_fusion:
fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
else:
fused_kernels.load_scaled_masked_softmax_fusion_kernel()
else:
# This argument will eventually go away, for now make sure it is off
# if scaled_masked_softmax_fusion is off.
args.scaled_upper_triang_masked_softmax_fusion = False
# Load scaled_masked_softmax_fusion_kernels
if args.masked_softmax_fusion:
fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
fused_kernels.load_scaled_masked_softmax_fusion_kernel()
# Load mixed precision fused layer norm.
if args.fp32_residual_connection:
......@@ -228,8 +237,14 @@ def _add_network_size_args(parser):
help='Number of transformer layers.')
group.add_argument('--hidden-size', type=int, default=None,
help='Tansformer hidden size.')
group.add_argument('--ffn-hidden-size', type=int, default=None,
help='Transformer Feed-Forward Network hidden size. This is set to 4*hidden-size if not '
'provided')
group.add_argument('--num-attention-heads', type=int, default=None,
help='Number of transformer attention heads.')
group.add_argument('--kv-channels', type=int, default=None,
help='Projection weights dimension in multi-head attention. '
'This is set to args.hidden_size // args.num_attention_heads if not provided.')
group.add_argument('--max-position-embeddings', type=int, default=None,
help='Maximum number of position embeddings to use. '
'This is the size of position embedding.')
......@@ -333,16 +348,11 @@ def _add_training_args(parser):
help='Exit the program after this many minutes.')
group.add_argument('--tensorboard-dir', type=str, default=None,
help='Write TensorBoard logs to this directory.')
group.add_argument('--no-scaled-masked-softmax-fusion',
group.add_argument('--no-masked-softmax-fusion',
action='store_false',
help='Disable fusion of query_key_value scaling, '
'masking, and softmax.',
dest='scaled_masked_softmax_fusion')
group.add_argument('--scaled-upper-triang-masked-softmax-fusion',
type=bool,
help='Use upper triangular version of fused '
'scale, mask, softmax fusion kernel (default for GPT). '
'- DEPRECATED')
dest='masked_softmax_fusion')
group.add_argument('--no-bias-gelu-fusion', action='store_false',
help='Disable bias and gelu fusion.',
dest='bias_gelu_fusion')
......@@ -537,7 +547,12 @@ def _add_data_args(parser):
group.add_argument('--merge-file', type=str, default=None,
help='Path to the BPE merge file.')
group.add_argument('--seq-length', type=int, default=None,
help="Maximum sequence length to process.")
help='Maximum sequence length to process.')
group.add_argument('--encoder-seq-length', type=int, default=None,
help='Maximum encoder sequence length to process.'
'This should be exclusive of --seq-length')
group.add_argument('--decoder-seq-length', type=int, default=None,
help="Maximum decoder sequence length to process.")
group.add_argument('--mask-prob', type=float, default=0.15,
help='Probability of replacing a token with mask.')
group.add_argument('--short-seq-prob', type=float, default=0.1,
......
......@@ -81,7 +81,6 @@ __global__ void scaled_masked_softmax_warp_forward(
const uint8_t *mask,
const acc_t scale,
int micro_batch_size,
int stride,
int element_count,
int pad_batches)
{
......@@ -111,9 +110,9 @@ __global__ void scaled_masked_softmax_warp_forward(
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
src += first_batch * stride + local_idx;
dst += first_batch * stride + local_idx;
mask += pad_first_batch * stride + local_idx;
src += first_batch * element_count + local_idx;
dst += first_batch * element_count + local_idx;
mask += pad_first_batch * element_count + local_idx;
// load data from global memory
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
......@@ -185,7 +184,6 @@ __global__ void scaled_masked_softmax_warp_backward(
const input_t *output,
acc_t scale,
int micro_batch_size,
int stride,
int element_count)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
......@@ -209,7 +207,7 @@ __global__ void scaled_masked_softmax_warp_backward(
int local_idx = threadIdx.x;
// the first element to process by the current thread
int thread_offset = first_batch * stride + local_idx;
int thread_offset = first_batch * element_count + local_idx;
grad += thread_offset;
output += thread_offset;
gradInput += thread_offset;
......@@ -277,20 +275,19 @@ void dispatch_scaled_masked_softmax_forward(
const input_t *src,
const uint8_t *mask,
const input_t scale,
int softmax_elements,
int softmax_elements_stride,
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads,
int pad_batches)
{
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 );
if (softmax_elements == 0) {
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048 );
if (key_seq_len == 0) {
return;
} else {
int log2_elements = log2_ceil(softmax_elements);
int log2_elements = log2_ceil(key_seq_len);
const int next_power_of_two = 1 << log2_elements;
int seq_len = softmax_elements;
int batch_count = batches * attn_heads * seq_len;
int batch_count = batches * attn_heads * query_seq_len;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
......@@ -302,59 +299,59 @@ void dispatch_scaled_masked_softmax_forward(
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(seq_len%batches_per_block == 0);
dim3 blocks(seq_len/batches_per_block, attn_heads, batches);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0);
dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 1: // 2
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 2: // 4
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 3: // 8
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 4: // 16
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 5: // 32
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 6: // 64
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 7: // 128
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 8: // 256
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 9: // 512
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 10: // 1024
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 11: // 2048
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
default:
break;
......@@ -368,19 +365,18 @@ void dispatch_scaled_masked_softmax_backward(
input_t *grad,
const input_t *output,
const acc_t scale,
int softmax_elements,
int softmax_elements_stride,
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads)
{
TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 );
if (softmax_elements == 0) {
TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 2048 );
if (key_seq_len == 0) {
return;
} else {
int log2_elements = log2_ceil(softmax_elements);
int log2_elements = log2_ceil(key_seq_len);
const int next_power_of_two = 1 << log2_elements;
int seq_len = softmax_elements;
int batch_count = batches * attn_heads * seq_len;
int batch_count = batches * attn_heads * query_seq_len;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
......@@ -399,51 +395,51 @@ void dispatch_scaled_masked_softmax_backward(
switch (log2_elements) {
case 0: // 1
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 1: // 2
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 2: // 4
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 3: // 8
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 4: // 16
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 5: // 32
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 6: // 64
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 7: // 128
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 8: // 256
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 9: // 512
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 10: // 1024
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 11: // 2048
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
default:
break;
......
......@@ -37,17 +37,19 @@ torch::Tensor fwd_cuda(
const int batches = input.size(0);
const int pad_batches = mask.size(0);
const int attn_heads = input.size(1);
const int seq_len = input.size(2);
TORCH_INTERNAL_ASSERT(seq_len <= 2048);
const int query_seq_len = input.size(2);
const int key_seq_len = input.size(3);
TORCH_INTERNAL_ASSERT(key_seq_len <= 2048);
TORCH_INTERNAL_ASSERT(query_seq_len > 1);
TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
TORCH_INTERNAL_ASSERT(mask.size(2) == seq_len);
TORCH_INTERNAL_ASSERT(mask.size(3) == seq_len);
TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);
// Output
auto act_options = input.options().requires_grad(false);
torch::Tensor softmax_results =
torch::empty({batches, attn_heads, seq_len, seq_len}, act_options);
torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
// Softmax Intermediate Result Ptr
void* input_ptr = static_cast<void*>(input.data_ptr());
......@@ -59,8 +61,8 @@ torch::Tensor fwd_cuda(
reinterpret_cast<const half*>(input_ptr),
reinterpret_cast<const uint8_t*>(mask_ptr),
scale_factor,
seq_len,
seq_len,
query_seq_len,
key_seq_len,
batches,
attn_heads,
pad_batches);
......@@ -78,8 +80,8 @@ torch::Tensor bwd_cuda(
//output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
const int batches = output_grads.size(0);
const int attn_heads = output_grads.size(1);
const int seq_len = output_grads.size(2);
TORCH_INTERNAL_ASSERT(output_grads.size(2) == output_grads.size(3));
const int query_seq_len = output_grads.size(2);
const int key_seq_len = output_grads.size(3);
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
......@@ -89,8 +91,8 @@ torch::Tensor bwd_cuda(
reinterpret_cast<half*>(output_grads_ptr),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
scale_factor,
seq_len,
seq_len,
query_seq_len,
key_seq_len,
batches,
attn_heads);
......
......@@ -19,6 +19,7 @@ import torch
from megatron import get_args
from megatron import mpu
from megatron.model.enums import AttnMaskType
from megatron.model.language_model import parallel_lm_logits
from megatron.model.language_model import get_language_model
from megatron.model import import_layernorm
......@@ -143,6 +144,7 @@ class BertModelBase(MegatronModule):
self.language_model, self._language_model_key = get_language_model(
num_tokentypes=num_tokentypes,
add_pooler=self.add_binary_head,
encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method,
scaled_init_method=scaled_init_method)
......
......@@ -19,6 +19,7 @@ import torch
from megatron import get_args, print_rank_last
from megatron import mpu
from megatron.model.enums import AttnMaskType
from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
......@@ -39,6 +40,7 @@ class ClassificationBase(MegatronModule):
self.language_model, self._language_model_key = get_language_model(
num_tokentypes=num_tokentypes,
add_pooler=True,
encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method,
scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers))
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
class LayerType(enum.Enum):
encoder = 1
decoder = 2
class AttnType(enum.Enum):
self_attn = 1
cross_attn = 2
class AttnMaskType(enum.Enum):
padding = 1
causal = 2
......@@ -14,6 +14,7 @@
# limitations under the License.
import torch
from megatron.model.enums import AttnMaskType
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
......@@ -85,8 +86,7 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
fused operation: scaling + mask + softmax
Arguments:
input_in_fp16: flag to indicate if input in fp16 data format.
upper_triang_mask: if true, apply upper triangular masking.
(used in gpt family networks)
attn_mask_type: attention mask type (pad or causal)
mask_func: mask function to be applied.
softmax_in_fp32: if true, softmax in performed at fp32 precision.
scale: scaling factor used in input tensor scaling.
......@@ -96,16 +96,16 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
def __init__(
self,
input_in_fp16,
upper_triang_mask_fusion,
general_mask_fusion,
attn_mask_type,
scaled_masked_softmax_fusion,
mask_func,
softmax_in_fp32,
scale,
):
super(FusedScaleMaskSoftmax, self).__init__()
self.input_in_fp16 = input_in_fp16
self.upper_triang_mask_fusion = upper_triang_mask_fusion
self.general_mask_fusion = general_mask_fusion
self.attn_mask_type = attn_mask_type
self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
self.mask_func = mask_func
self.softmax_in_fp32 = softmax_in_fp32
self.scale = scale
......@@ -115,23 +115,26 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
), "softmax should be in fp32 when scaled"
def forward(self, input, mask):
# [b, np, s, s]
# [b, np, sq, sk]
data_size = input.size()
query_seq_len = data_size[-2]
key_seq_len = data_size[-1]
assert input.dim() == 4
# invoke custom kernel
if (
self.input_in_fp16
and data_size[-1] <= 2048
and (self.upper_triang_mask_fusion or self.general_mask_fusion)
and input.size()[2] == input.size()[3]
):
if self.input_in_fp16 and key_seq_len <= 2048 and \
query_seq_len % 4 == 0 and self.scaled_masked_softmax_fusion:
scale = self.scale if self.scale is not None else 1.0
if self.upper_triang_mask_fusion:
input = input.view(-1, data_size[2], data_size[3])
if self.attn_mask_type == AttnMaskType.causal:
assert query_seq_len == key_seq_len, \
"causal mask is only for self attention"
input = input.view(-1, query_seq_len, key_seq_len)
probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
probs = probs.view(*data_size)
else:
assert self.attn_mask_type == AttnMaskType.padding
probs = ScaledMaskedSoftmax.apply(input, mask, scale)
else:
if self.input_in_fp16 and self.softmax_in_fp32:
......
......@@ -21,6 +21,7 @@ from megatron import get_args
from megatron import mpu
from .module import MegatronModule
from .enums import AttnMaskType
from .language_model import parallel_lm_logits
from .language_model import get_language_model
from .utils import init_method_normal
......@@ -69,6 +70,7 @@ class GPTModelBase(MegatronModule):
self.language_model, self._language_model_key = get_language_model(
num_tokentypes=num_tokentypes,
add_pooler=False,
encoder_attn_mask_type=AttnMaskType.causal,
init_method=init_method_normal(args.init_method_std),
scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers))
......
......@@ -21,6 +21,7 @@ import torch.nn.functional as F
from megatron import get_args
from megatron import mpu
from .module import MegatronModule
from megatron.model.enums import LayerType, AttnMaskType
from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal, scaled_init_method_normal
......@@ -43,7 +44,9 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
def get_language_model(num_tokentypes, add_pooler,
init_method=None, scaled_init_method=None):
encoder_attn_mask_type, init_method=None,
scaled_init_method=None, add_decoder=False,
decoder_attn_mask_type=AttnMaskType.causal):
"""Build language model and return along with the key to save."""
args = get_args()
......@@ -51,15 +54,18 @@ def get_language_model(num_tokentypes, add_pooler,
init_method = init_method_normal(args.init_method_std)
if scaled_init_method is None:
scaled_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers)
scaled_init_method = scaled_init_method_normal(args.init_method_std,
args.num_layers)
# Language model.
args = [init_method, scaled_init_method]
args = [init_method, scaled_init_method, encoder_attn_mask_type]
kwargs = {}
cls = None
if mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage():
cls = TransformerLanguageModel
kwargs['num_tokentypes'] = num_tokentypes
kwargs['add_decoder'] = add_decoder
kwargs['decoder_attn_mask_type'] = decoder_attn_mask_type
kwargs['add_pooler'] = add_pooler
elif mpu.is_pipeline_first_stage() and not mpu.is_pipeline_last_stage():
cls = TransformerLanguageModelFirstStage
......@@ -273,7 +279,10 @@ class TransformerLanguageModelBase(MegatronModule):
def __init__(self,
init_method,
output_layer_init_method,
encoder_attn_mask_type,
num_tokentypes=0,
add_decoder=False,
decoder_attn_mask_type=AttnMaskType.causal,
add_pooler=False):
super(TransformerLanguageModelBase, self).__init__()
args = get_args()
......@@ -281,6 +290,9 @@ class TransformerLanguageModelBase(MegatronModule):
self.hidden_size = args.hidden_size
self.num_tokentypes = num_tokentypes
self.init_method = init_method
self.encoder_attn_mask_type = encoder_attn_mask_type
self.add_decoder = add_decoder
self.decoder_attn_mask_type = decoder_attn_mask_type
self.add_pooler = add_pooler
# Embeddings.
......@@ -294,40 +306,83 @@ class TransformerLanguageModelBase(MegatronModule):
self._embedding_key = 'embedding'
# Transformer.
self.transformer = ParallelTransformer(
self.init_method, output_layer_init_method)
self._transformer_key = 'transformer'
# Pooler.
if mpu.is_pipeline_last_stage() and self.add_pooler:
self.pooler = Pooler(self.hidden_size, self.init_method)
self._pooler_key = 'pooler'
def forward(self, language_model_input, attention_mask,
tokentype_ids=None, layer_past=None, get_key_value=False,
pooling_sequence_index=0):
self.encoder = ParallelTransformer(
self.init_method,
output_layer_init_method,
self_attn_mask_type=self.encoder_attn_mask_type)
self._encoder_key = 'encoder'
# Decoder
if self.add_decoder:
assert args.pipeline_model_parallel_size == 1, \
'pipeline parallelism is not supported in the presence of decoder'
self.decoder = ParallelTransformer(
self.init_method,
output_layer_init_method,
layer_type=LayerType.decoder,
self_attn_mask_type=self.decoder_attn_mask_type)
self._decoder_key = 'decoder'
if mpu.is_pipeline_last_stage():
# Pooler.
if self.add_pooler:
self.pooler = Pooler(self.hidden_size, self.init_method)
self._pooler_key = 'pooler'
def forward(self, enc_language_model_input, enc_attn_mask,
dec_language_model_input=None, dec_attn_mask=None,
enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
get_key_value=False, pooling_sequence_index=0,
enc_hidden_states=None, output_enc_hidden=False):
# Embeddings.
if mpu.is_pipeline_first_stage():
(input_ids, position_ids) = language_model_input
(input_ids, position_ids) = enc_language_model_input
embedding_output = self.embedding(input_ids, position_ids,
tokentype_ids=tokentype_ids)
transformer_input = embedding_output
encoder_input = embedding_output
else:
transformer_input = language_model_input
# Transformer.
transformer_output = self.transformer(transformer_input,
attention_mask,
layer_past=layer_past,
get_key_value=get_key_value)
if mpu.is_pipeline_last_stage() and self.add_pooler:
pooled_output = self.pooler(transformer_output,
pooling_sequence_index)
return transformer_output, pooled_output
return transformer_output
encoder_input = enc_language_model_input
# encoder.
if enc_hidden_states is None:
encoder_output = self.encoder(encoder_input,
enc_attn_mask,
layer_past=layer_past,
get_key_value=get_key_value)
else:
encoder_output = enc_hidden_states.to(encoder_input.dtype)
if mpu.is_pipeline_last_stage():
if self.add_pooler:
pooled_output = self.pooler(encoder_output,
pooling_sequence_index)
# output_enc_hidden refers to when we just need the encoder's
# output. For example, it is helpful to compute
# similarity between two sequences by average pooling
if not self.add_decoder or output_enc_hidden:
if self.add_pooler and mpu.is_pipeline_last_stage():
return encoder_output, pooled_output
else:
return encoder_output
# Decoder Embedding
(dec_input_ids, dec_position_ids) = dec_language_model_input
dec_embedding_output = self.embedding(dec_input_ids,
dec_position_ids)
# decoder
decoder_output = self.decoder(dec_embedding_output,
dec_attn_mask,
layer_past=layer_past,
get_key_value=get_key_value,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask)
if self.add_pooler and mpu.is_pipeline_last_stage():
return decoder_output, encoder_output, pooled_output
else:
return decoder_output, encoder_output
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
......@@ -338,12 +393,17 @@ class TransformerLanguageModelBase(MegatronModule):
state_dict_[self._embedding_key] \
= self.embedding.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
state_dict_[self._transformer_key] \
= self.transformer.state_dict_for_save_checkpoint(
state_dict_[self._encoder_key] \
= self.encoder.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if mpu.is_pipeline_last_stage() and self.add_pooler:
state_dict_[self._pooler_key] \
= self.pooler.state_dict_for_save_checkpoint(
if mpu.is_pipeline_last_stage():
if self.add_pooler:
state_dict_[self._pooler_key] \
= self.pooler.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if self.add_decoder:
state_dict_[self._decoder_key] \
= self.decoder.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
return state_dict_
......@@ -363,23 +423,44 @@ class TransformerLanguageModelBase(MegatronModule):
state_dict_[key] = state_dict[key]
self.embedding.load_state_dict(state_dict_, strict=strict)
# Transformer.
if self._transformer_key in state_dict:
state_dict_ = state_dict[self._transformer_key]
# Encoder.
if self._encoder_key in state_dict:
state_dict_ = state_dict[self._encoder_key]
# for backward compatibility.
elif 'transformer' in state_dict:
state_dict_ = state_dict['transformer']
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if 'transformer.' in key:
state_dict_[key.split('transformer.')[1]] = state_dict[key]
self.transformer.load_state_dict(state_dict_, strict=strict)
# Pooler.
if mpu.is_pipeline_last_stage() and self.add_pooler:
assert 'pooler' in state_dict, \
# for backward compatibility.
state_dict_self_attention = {}
for key in state_dict_.keys():
if '.attention.' in key:
state_dict_self_attention[key.replace(".attention.",
".self_attention.")] = state_dict_[key]
else:
state_dict_self_attention[key] = state_dict_[key]
state_dict_ = state_dict_self_attention
self.encoder.load_state_dict(state_dict_, strict=strict)
if mpu.is_pipeline_last_stage():
# pooler
if self.add_pooler:
assert 'pooler' in state_dict, \
'could not find data for pooler in the checkpoint'
self.pooler.load_state_dict(state_dict[self._pooler_key],
strict=strict)
# decoder
if self.add_decoder:
assert 'decoder' in state_dict, \
'could not find data for pooler in the checkpoint'
self.pooler.load_state_dict(state_dict[self._pooler_key],
strict=strict)
self.decoder.load_state_dict(state_dict[self._decoder_key],
strict=strict)
class TransformerLanguageModel(TransformerLanguageModelBase):
......@@ -390,24 +471,37 @@ class TransformerLanguageModel(TransformerLanguageModelBase):
def __init__(self,
init_method,
output_layer_init_method,
encoder_attn_mask_type,
num_tokentypes=0,
decoder_attn_mask_type=AttnMaskType.causal,
add_decoder=False,
add_pooler=False):
super(TransformerLanguageModel, self).__init__(
init_method,
output_layer_init_method,
encoder_attn_mask_type,
num_tokentypes=num_tokentypes,
add_decoder=add_decoder,
decoder_attn_mask_type=decoder_attn_mask_type,
add_pooler=add_pooler)
def forward(self, input_ids, position_ids, attention_mask,
tokentype_ids=None, layer_past=None, get_key_value=False,
pooling_sequence_index=0):
def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
get_key_value=False, pooling_sequence_index=0,
enc_hidden_states=None, output_enc_hidden=False):
return super(TransformerLanguageModel, self).forward(
(input_ids, position_ids),
attention_mask,
(enc_input_ids, enc_position_ids),
enc_attn_mask,
dec_language_model_input=(dec_input_ids, dec_position_ids),
dec_attn_mask=dec_attn_mask,
enc_dec_attn_mask=enc_dec_attn_mask,
tokentype_ids=tokentype_ids,
layer_past=layer_past,
get_key_value=get_key_value,
pooling_sequence_index=pooling_sequence_index
pooling_sequence_index=pooling_sequence_index,
enc_hidden_states=enc_hidden_states,
output_enc_hidden=output_enc_hidden
)
......@@ -419,10 +513,12 @@ class TransformerLanguageModelFirstStage(TransformerLanguageModelBase):
def __init__(self,
init_method,
output_layer_init_method,
encoder_attn_mask_type,
num_tokentypes=0):
super(TransformerLanguageModelFirstStage, self).__init__(
init_method,
output_layer_init_method,
encoder_attn_mask_type,
num_tokentypes=num_tokentypes)
def forward(self, input_ids, position_ids, attention_mask,
......@@ -443,10 +539,12 @@ class TransformerLanguageModelIntermediateStage(TransformerLanguageModelBase):
def __init__(self,
init_method,
output_layer_init_method):
output_layer_init_method,
encoder_attn_mask_type):
super(TransformerLanguageModelIntermediateStage, self).__init__(
init_method,
output_layer_init_method)
output_layer_init_method,
encoder_attn_mask_type)
def forward(self, hidden_states, attention_mask,
layer_past=None, get_key_value=False):
......@@ -466,10 +564,12 @@ class TransformerLanguageModelLastStage(TransformerLanguageModelBase):
def __init__(self,
init_method,
output_layer_init_method,
encoder_attn_mask_type,
add_pooler=False):
super(TransformerLanguageModelLastStage, self).__init__(
init_method,
output_layer_init_method,
encoder_attn_mask_type,
add_pooler=add_pooler)
def forward(self, hidden_states, attention_mask,
......@@ -480,5 +580,5 @@ class TransformerLanguageModelLastStage(TransformerLanguageModelBase):
attention_mask,
layer_past=layer_past,
get_key_value=get_key_value,
pooling_sequence_index=pooling_sequence_index
pooling_sequence_index=pooling_sequence_index,
)
......@@ -19,6 +19,7 @@ import torch
from megatron import get_args, print_rank_last
from megatron import mpu
from megatron.model.enums import AttnMaskType
from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
......@@ -38,6 +39,7 @@ class MultipleChoiceBase(MegatronModule):
self.language_model, self._language_model_key = get_language_model(
num_tokentypes=num_tokentypes,
add_pooler=True,
encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method,
scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers))
......
......@@ -6,6 +6,7 @@ from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoi
from megatron.model import BertModel
from .module import MegatronModule
from megatron import mpu
from megatron.model.enums import AttnMaskType
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.language_model import get_language_model
......@@ -158,6 +159,7 @@ class IREncoderBertModel(MegatronModule):
self.language_model, self._language_model_key = get_language_model(
num_tokentypes=num_tokentypes,
add_pooler=True,
encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method,
scaled_init_method=scaled_init_method)
......
......@@ -14,7 +14,6 @@
# limitations under the License.
"""Transformer."""
import math
import torch
import torch.nn.functional as F
......@@ -23,6 +22,7 @@ from megatron import get_args
from megatron import mpu
from .module import MegatronModule
from megatron.checkpointing import get_checkpoint_version
from megatron.model.enums import AttnMaskType, LayerType, AttnType
from megatron.model import import_layernorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
......@@ -65,7 +65,7 @@ class ParallelMLP(MegatronModule):
# Project to 4h.
self.dense_h_to_4h = mpu.ColumnParallelLinear(
args.hidden_size,
4 * args.hidden_size,
args.ffn_hidden_size,
gather_output=False,
init_method=init_method,
skip_bias_add=True)
......@@ -79,7 +79,7 @@ class ParallelMLP(MegatronModule):
# Project back to h.
self.dense_4h_to_h = mpu.RowParallelLinear(
4 * args.hidden_size,
args.ffn_hidden_size,
args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
......@@ -103,15 +103,18 @@ class ParallelMLP(MegatronModule):
return output, output_bias
class ParallelSelfAttention(MegatronModule):
class ParallelAttention(MegatronModule):
"""Parallel self-attention layer abstract class.
Self-attention layer takes input with size [b, s, h]
and returns output of the same size.
"""
def __init__(self, init_method, output_layer_init_method, layer_number):
super(ParallelSelfAttention, self).__init__()
def __init__(self, init_method,
output_layer_init_method, layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=AttnMaskType.padding):
super(ParallelAttention, self).__init__()
args = get_args()
self.fp16 = args.fp16
......@@ -120,22 +123,40 @@ class ParallelSelfAttention(MegatronModule):
if self.apply_query_key_layer_scaling:
self.attention_softmax_in_fp32 = True
self.layer_number = max(1, layer_number)
self.attention_type = attention_type
self.attn_mask_type = attn_mask_type
projection_size = args.kv_channels * args.num_attention_heads
# Per attention head and per partition values.
world_size = mpu.get_tensor_model_parallel_world_size()
self.hidden_size_per_partition = mpu.divide(args.hidden_size,
self.hidden_size_per_partition = mpu.divide(projection_size,
world_size)
self.hidden_size_per_attention_head = mpu.divide(
args.hidden_size, args.num_attention_heads)
projection_size, args.num_attention_heads)
self.num_attention_heads_per_partition = mpu.divide(
args.num_attention_heads, world_size)
# Strided linear layer.
self.query_key_value = mpu.ColumnParallelLinear(
args.hidden_size,
3 * args.hidden_size,
gather_output=False,
init_method=init_method)
if attention_type == AttnType.self_attn:
self.query_key_value = mpu.ColumnParallelLinear(
args.hidden_size,
3 * projection_size,
gather_output=False,
init_method=init_method)
else:
assert attention_type == AttnType.cross_attn
self.query = mpu.ColumnParallelLinear(
args.hidden_size,
projection_size,
gather_output=False,
init_method=init_method)
self.key_value = mpu.ColumnParallelLinear(
args.hidden_size,
2 * projection_size,
gather_output=False,
init_method=init_method)
coeff = None
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
......@@ -145,8 +166,8 @@ class ParallelSelfAttention(MegatronModule):
self.scale_mask_softmax = FusedScaleMaskSoftmax(
self.fp16,
args.scaled_upper_triang_masked_softmax_fusion,
args.scaled_masked_softmax_fusion,
self.attn_mask_type,
args.masked_softmax_fusion,
attention_mask_func,
self.attention_softmax_in_fp32,
coeff)
......@@ -158,7 +179,7 @@ class ParallelSelfAttention(MegatronModule):
# Output.
self.dense = mpu.RowParallelLinear(
args.hidden_size,
projection_size,
args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
......@@ -191,39 +212,70 @@ class ParallelSelfAttention(MegatronModule):
mixed_layer = mixed_layer.view(*intermediate_shape)
mixed_layer = mixed_layer.transpose(-1, -2).contiguous()
mixed_layer = mixed_layer.view(*input_shape)
return mixed_layer
def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False):
get_key_value=False, encoder_output=None):
# hidden_states: [sq, b, h]
# =====================
# Query, Key, and Value
# =====================
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
mixed_x_layer, _ = self.query_key_value(hidden_states)
checkpoint_version = get_checkpoint_version()
if checkpoint_version is not None:
if checkpoint_version == 0:
# [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)]
mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, True)
elif checkpoint_version == 1.0:
# [s, b, (np * hn * 3)] --> [s, b, (np * 3 * hn)]
mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, False)
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(query_layer,
key_layer,
value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
if self.attention_type == AttnType.self_attn:
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
mixed_x_layer, _ = self.query_key_value(hidden_states)
checkpoint_version = get_checkpoint_version()
if checkpoint_version is not None:
if checkpoint_version == 0:
# [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)]
mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, True)
elif checkpoint_version == 1.0:
# [s, b, (np * hn * 3)] --> [s, b, (np * 3 * hn)]
mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, False)
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(query_layer,
key_layer,
value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
else:
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_kv_layer, _ = self.key_value(encoder_output)
checkpoint_version = get_checkpoint_version()
if checkpoint_version is not None:
if checkpoint_version == 0:
# [s, b, (2 * np * hn)] --> [s, b, (np * 2 * hn)]
mixed_kv_layer = self._transpose_last_dim(mixed_kv_layer, 2, True)
elif checkpoint_version == 1.0:
# [s, b, (np * hn * 2)] --> [s, b, (np * 2 * hn)]
mixed_kv_layer = self._transpose_last_dim(mixed_kv_layer, 2, False)
# [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
new_tensor_shape = mixed_kv_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
2 * self.hidden_size_per_attention_head)
mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
# [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
(key_layer,
value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)
# Attention head [sq, b, h] --> [sq, b, hp]
query_layer, _ = self.query(hidden_states)
# [sq, b, hp] --> [sq, b, np, hn]
new_tensor_shape = query_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
query_layer = query_layer.view(*new_tensor_shape)
# ==================================
# Adjust key and value for inference
......@@ -251,6 +303,7 @@ class ParallelSelfAttention(MegatronModule):
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.view(output_size[2],
output_size[0] * output_size[1], -1)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.view(output_size[3],
output_size[0] * output_size[1], -1)
......@@ -377,15 +430,18 @@ def bias_dropout_add_fused_inference(x, bias, residual, prob):
class ParallelTransformerLayer(MegatronModule):
"""A single transformer layer.
Transformore layer takes input with size [b, s, h] and returns an
Transformer layer takes input with size [b, s, h] and returns an
output of the same size.
"""
def __init__(self, init_method, output_layer_init_method, layer_number):
def __init__(self, init_method, output_layer_init_method,
layer_number, layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding):
args = get_args()
super(ParallelTransformerLayer, self).__init__()
self.layer_number = layer_number
self.layer_type = layer_type
self.apply_residual_connection_post_layernorm \
= args.apply_residual_connection_post_layernorm
......@@ -397,33 +453,48 @@ class ParallelTransformerLayer(MegatronModule):
eps=args.layernorm_epsilon)
# Self attention.
self.attention = ParallelSelfAttention(init_method,
output_layer_init_method,
layer_number)
self.self_attention = ParallelAttention(
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=self_attn_mask_type)
self.hidden_dropout = args.hidden_dropout
self.bias_dropout_fusion = args.bias_dropout_fusion
# Layernorm on the input data.
# Layernorm on the attention output
self.post_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon)
if self.layer_type == LayerType.decoder:
self.inter_attention = ParallelAttention(
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.cross_attn)
# Layernorm on the attention output.
self.post_inter_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon)
# MLP
self.mlp = ParallelMLP(init_method,
output_layer_init_method)
def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False):
def forward(self, hidden_states, attention_mask,
encoder_output=None, enc_dec_attn_mask=None,
layer_past=None, get_key_value=False):
# hidden_states: [b, s, h]
# Layer norm at the begining of the transformer layer.
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output, attention_bias = \
self.attention(layernorm_output,
attention_mask,
layer_past=layer_past,
get_key_value=get_key_value)
self.self_attention(layernorm_output,
attention_mask,
layer_past=layer_past,
get_key_value=get_key_value)
if get_key_value:
attention_output, presents = attention_output
......@@ -457,9 +528,31 @@ class ParallelTransformerLayer(MegatronModule):
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
if self.layer_type == LayerType.decoder:
attention_output, attention_bias = \
self.inter_attention(layernorm_output,
enc_dec_attn_mask,
encoder_output=encoder_output)
# residual connection
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
# Layer norm post the decoder attention
layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
# MLP.
mlp_output, mlp_bias = self.mlp(layernorm_output)
# Second residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
......@@ -483,7 +576,9 @@ class ParallelTransformerLayer(MegatronModule):
class ParallelTransformer(MegatronModule):
"""Transformer class."""
def __init__(self, init_method, output_layer_init_method):
def __init__(self, init_method, output_layer_init_method,
layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding):
super(ParallelTransformer, self).__init__()
args = get_args()
......@@ -501,7 +596,11 @@ class ParallelTransformer(MegatronModule):
# Transformer layers.
def build_layer(layer_number):
return ParallelTransformerLayer(
init_method, output_layer_init_method, layer_number)
init_method,
output_layer_init_method,
layer_number,
layer_type=layer_type,
self_attn_mask_type=self_attn_mask_type)
offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
self.layers = torch.nn.ModuleList(
[build_layer(i + 1 + offset) for i in range(self.num_layers)])
......@@ -516,14 +615,18 @@ class ParallelTransformer(MegatronModule):
def _get_layer(self, layer_number):
return self.layers[layer_number]
def _checkpointed_forward(self, hidden_states, attention_mask):
def _checkpointed_forward(self, hidden_states, attention_mask,
encoder_output, enc_dec_attn_mask):
"""Forward method with activation checkpointing."""
def custom(start, end):
def custom_forward(*inputs):
x_ = inputs[0]
attention_mask = inputs[1]
encoder_output = inputs[2]
enc_dec_attn_mask = inputs[3]
for index in range(start, end):
layer = self._get_layer(index)
x_ = layer(x_, inputs[1])
x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
return x_
return custom_forward
......@@ -533,13 +636,13 @@ class ParallelTransformer(MegatronModule):
while l < self.num_layers:
hidden_states = mpu.checkpoint(
custom(l, l + self.checkpoint_num_layers),
hidden_states, attention_mask)
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
l += self.checkpoint_num_layers
return hidden_states
def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False):
get_key_value=False, encoder_output=None, enc_dec_attn_mask=None):
# Checks.
if layer_past is not None:
......@@ -560,9 +663,14 @@ class ParallelTransformer(MegatronModule):
else:
hidden_states = hidden_states.transpose(0, 1).contiguous()
if encoder_output is not None:
encoder_output = encoder_output.transpose(0, 1).contiguous()
if self.checkpoint_activations:
hidden_states = self._checkpointed_forward(hidden_states,
attention_mask)
attention_mask,
encoder_output,
enc_dec_attn_mask)
else:
if get_key_value:
presents = []
......@@ -573,6 +681,8 @@ class ParallelTransformer(MegatronModule):
past = layer_past[index]
hidden_states = layer(hidden_states,
attention_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
layer_past=past,
get_key_value=get_key_value)
if get_key_value:
......
......@@ -144,5 +144,4 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
args_defaults={'tokenizer_type': 'GPT2BPETokenizer',
'scaled_upper_triang_masked_softmax_fusion': True})
args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment