Commit 4b506832 authored by Vijay Korthikanti

decoder support in transformers

parent f5eac3d1
......@@ -164,6 +164,20 @@ def parse_args(extra_args_provider=None, defaults={},
_check_arg_is_not_none(args, req_arg)
# Checks.
if args.ffn_hidden_size is None:
args.ffn_hidden_size = 4 * args.hidden_size
if args.kv_channels is None:
assert args.hidden_size % args.num_attention_heads == 0
args.kv_channels = args.hidden_size // args.num_attention_heads
if args.seq_length is not None:
assert args.encoder_seq_length is None
args.encoder_seq_length = args.seq_length
else:
assert args.encoder_seq_length is not None
args.seq_length = args.encoder_seq_length
assert args.hidden_size % args.num_attention_heads == 0
if args.seq_length is not None:
assert args.max_position_embeddings >= args.seq_length
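As a quick, stand-alone illustration of the new seq-length defaulting rules (hypothetical values; plain argparse rather than megatron.arguments):

import argparse

# Stand-in parser; the flag names mirror the arguments added in this commit.
parser = argparse.ArgumentParser()
parser.add_argument('--seq-length', type=int, default=None)
parser.add_argument('--encoder-seq-length', type=int, default=None)
args = parser.parse_args(['--encoder-seq-length', '512'])

# Exactly one of the two lengths must be given; the other is mirrored.
if args.seq_length is not None:
    assert args.encoder_seq_length is None
    args.encoder_seq_length = args.seq_length
else:
    assert args.encoder_seq_length is not None
    args.seq_length = args.encoder_seq_length
print(args.seq_length)  # 512, mirrored from --encoder-seq-length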
......@@ -182,16 +196,11 @@ def parse_args(extra_args_provider=None, defaults={},
assert args.checkpoint_activations, \
'for distribute-checkpointed-activations to work you '\
'need to enable checkpoint-activations'
if args.scaled_masked_softmax_fusion:
if args.scaled_upper_triang_masked_softmax_fusion:
fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
else:
fused_kernels.load_scaled_masked_softmax_fusion_kernel()
else:
# This argument will eventually go away, for now make sure it is off
# if scaled_masked_softmax_fusion is off.
args.scaled_upper_triang_masked_softmax_fusion = False
# Load scaled_masked_softmax_fusion_kernels
if args.masked_softmax_fusion:
fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
fused_kernels.load_scaled_masked_softmax_fusion_kernel()
# Load mixed precision fused layer norm.
if args.fp32_residual_connection:
......@@ -227,8 +236,14 @@ def _add_network_size_args(parser):
help='Number of transformer layers.')
group.add_argument('--hidden-size', type=int, default=None,
help='Transformer hidden size.')
group.add_argument('--ffn-hidden-size', type=int, default=None,
help='Transformer Feed-Forward Network hidden size. '
'This is set to 4*hidden-size if not provided.')
group.add_argument('--num-attention-heads', type=int, default=None,
help='Number of transformer attention heads.')
group.add_argument('--kv-channels', type=int, default=None,
help='Projection weights dimension in multi-head attention. '
'This is set to args.hidden_size // args.num_attention_heads if not provided.')
group.add_argument('--max-position-embeddings', type=int, default=None,
help='Maximum number of position embeddings to use. '
'This is the size of position embedding.')
......@@ -330,16 +345,11 @@ def _add_training_args(parser):
help='Exit the program after this many minutes.')
group.add_argument('--tensorboard-dir', type=str, default=None,
help='Write TensorBoard logs to this directory.')
group.add_argument('--no-scaled-masked-softmax-fusion',
group.add_argument('--no-masked-softmax-fusion',
action='store_false',
help='Disable fusion of query_key_value scaling, '
'masking, and softmax.',
dest='scaled_masked_softmax_fusion')
group.add_argument('--scaled-upper-triang-masked-softmax-fusion',
type=bool,
help='Use upper triangular version of fused '
'scale, mask, softmax fusion kernel (default for GPT). '
'- DEPRECATED')
dest='masked_softmax_fusion')
group.add_argument('--no-bias-gelu-fusion', action='store_false',
help='Disable bias and gelu fusion.',
dest='bias_gelu_fusion')
......@@ -530,6 +540,10 @@ def _add_data_args(parser):
help='Path to the BPE merge file.')
group.add_argument('--seq-length', type=int, default=None,
help="Maximum sequence length to process.")
group.add_argument('--encoder-seq-length', type=int, default=None,
help="Maximum encoder sequence length to process.")
group.add_argument('--decoder-seq-length', type=int, default=None,
help="Maximum decoder sequence length to process.")
group.add_argument('--mask-prob', type=float, default=0.15,
help='Probability of replacing a token with mask.')
group.add_argument('--short-seq-prob', type=float, default=0.1,
......
......@@ -81,7 +81,6 @@ __global__ void scaled_masked_softmax_warp_forward(
const uint8_t *mask,
const acc_t scale,
int micro_batch_size,
int stride,
int element_count,
int pad_batches)
{
......@@ -111,9 +110,9 @@ __global__ void scaled_masked_softmax_warp_forward(
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
src += first_batch * stride + local_idx;
dst += first_batch * stride + local_idx;
mask += pad_first_batch * stride + local_idx;
src += first_batch * element_count + local_idx;
dst += first_batch * element_count + local_idx;
mask += pad_first_batch * element_count + local_idx;
// load data from global memory
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
......@@ -185,7 +184,6 @@ __global__ void scaled_masked_softmax_warp_backward(
const input_t *output,
acc_t scale,
int micro_batch_size,
int stride,
int element_count)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
......@@ -209,7 +207,7 @@ __global__ void scaled_masked_softmax_warp_backward(
int local_idx = threadIdx.x;
// the first element to process by the current thread
int thread_offset = first_batch * stride + local_idx;
int thread_offset = first_batch * element_count + local_idx;
grad += thread_offset;
output += thread_offset;
gradInput += thread_offset;
......@@ -277,20 +275,19 @@ void dispatch_scaled_masked_softmax_forward(
const input_t *src,
const uint8_t *mask,
const input_t scale,
int softmax_elements,
int softmax_elements_stride,
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads,
int pad_batches)
{
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 );
if (softmax_elements == 0) {
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048 );
if (key_seq_len == 0) {
return;
} else {
int log2_elements = log2_ceil(softmax_elements);
int log2_elements = log2_ceil(key_seq_len);
const int next_power_of_two = 1 << log2_elements;
int seq_len = softmax_elements;
int batch_count = batches * attn_heads * seq_len;
int batch_count = batches * attn_heads * query_seq_len;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
......@@ -302,59 +299,59 @@ void dispatch_scaled_masked_softmax_forward(
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(seq_len%batches_per_block == 0);
dim3 blocks(seq_len/batches_per_block, attn_heads, batches);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0);
dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
dim3 threads(warp_size, warps_per_block, 1);
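    // Grid sketch: each block processes batches_per_block softmax rows of
    // length key_seq_len, so the grid spans query_seq_len x attn_heads x
    // batches, matching batch_count = batches * attn_heads * query_seq_len.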
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 1: // 2
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 2: // 4
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 3: // 8
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 4: // 16
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 5: // 32
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 6: // 64
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 7: // 128
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 8: // 256
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 9: // 512
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 10: // 1024
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 11: // 2048
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
default:
break;
......@@ -368,19 +365,18 @@ void dispatch_scaled_masked_softmax_backward(
input_t *grad,
const input_t *output,
const acc_t scale,
int softmax_elements,
int softmax_elements_stride,
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads)
{
TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 );
if (softmax_elements == 0) {
TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 2048 );
if (key_seq_len == 0) {
return;
} else {
int log2_elements = log2_ceil(softmax_elements);
int log2_elements = log2_ceil(key_seq_len);
const int next_power_of_two = 1 << log2_elements;
int seq_len = softmax_elements;
int batch_count = batches * attn_heads * seq_len;
int batch_count = batches * attn_heads * query_seq_len;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
......@@ -399,51 +395,51 @@ void dispatch_scaled_masked_softmax_backward(
switch (log2_elements) {
case 0: // 1
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 1: // 2
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 2: // 4
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 3: // 8
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 4: // 16
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 5: // 32
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 6: // 64
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 7: // 128
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 8: // 256
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 9: // 512
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 10: // 1024
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 11: // 2048
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
default:
break;
......
......@@ -37,17 +37,19 @@ torch::Tensor fwd_cuda(
const int batches = input.size(0);
const int pad_batches = mask.size(0);
const int attn_heads = input.size(1);
const int seq_len = input.size(2);
TORCH_INTERNAL_ASSERT(seq_len <= 2048);
const int query_seq_len = input.size(2);
const int key_seq_len = input.size(3);
TORCH_INTERNAL_ASSERT(key_seq_len <= 2048);
TORCH_INTERNAL_ASSERT(query_seq_len > 1);
TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
TORCH_INTERNAL_ASSERT(mask.size(2) == seq_len);
TORCH_INTERNAL_ASSERT(mask.size(3) == seq_len);
TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);
// Output
auto act_options = input.options().requires_grad(false);
torch::Tensor softmax_results =
torch::empty({batches, attn_heads, seq_len, seq_len}, act_options);
torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
// Softmax Intermediate Result Ptr
void* input_ptr = static_cast<void*>(input.data_ptr());
......@@ -59,8 +61,8 @@ torch::Tensor fwd_cuda(
reinterpret_cast<const half*>(input_ptr),
reinterpret_cast<const uint8_t*>(mask_ptr),
scale_factor,
seq_len,
seq_len,
query_seq_len,
key_seq_len,
batches,
attn_heads,
pad_batches);
......@@ -78,8 +80,8 @@ torch::Tensor bwd_cuda(
// output grads is a 4d tensor with dimensions [batches, attn_heads, query_seq_len, key_seq_len]
const int batches = output_grads.size(0);
const int attn_heads = output_grads.size(1);
const int seq_len = output_grads.size(2);
TORCH_INTERNAL_ASSERT(output_grads.size(2) == output_grads.size(3));
const int query_seq_len = output_grads.size(2);
const int key_seq_len = output_grads.size(3);
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
......@@ -89,8 +91,8 @@ torch::Tensor bwd_cuda(
reinterpret_cast<half*>(output_grads_ptr),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
scale_factor,
seq_len,
seq_len,
query_seq_len,
key_seq_len,
batches,
attn_heads);
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
class LayerType(enum.Enum):
encoder = 1
decoder = 2
class AttnType(enum.Enum):
self_attn = 1
cross_attn = 2
class AttnMaskType(enum.Enum):
padding = 1
causal = 2
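A short sketch of how these enums are wired up elsewhere in this commit (helper name hypothetical, assuming megatron is importable): decoder self-attention gets a causal mask, while encoder self-attention and cross-attention use padding masks.

from megatron.model.enums import LayerType, AttnType, AttnMaskType

def default_mask_type(layer_type, attn_type):
    # Mirrors the choices made in transformer.py and language_model.py.
    if layer_type == LayerType.decoder and attn_type == AttnType.self_attn:
        return AttnMaskType.causal
    return AttnMaskType.padding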
......@@ -14,11 +14,13 @@
# limitations under the License.
import torch
from megatron.model.enums import AttnMaskType
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
"""
Fused operation which performs the following three operations in sequence:
1. Scale the tensor.
2. Apply upper triangular mask (typically used in gpt models).
3. Perform softmax.
"""
......@@ -38,15 +40,16 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function) :
softmax_results, scale_t = ctx.saved_tensors
input_grads = \
scaled_upper_triang_masked_softmax_cuda.backward(output_grads,
softmax_results,
scale_t[0])
return input_grads, None
class ScaledMaskedSoftmax(torch.autograd.Function):
"""
Fused operation which performs the following three operations in sequence:
1. Scale the tensor.
2. Apply the mask.
3. Perform softmax.
"""
......@@ -71,24 +74,25 @@ class ScaledMaskedSoftmax(torch.autograd.Function) :
scale_t[0])
return input_grads, None, None
class FusedScaleMaskSoftmax(torch.nn.Module):
"""
fused operation: scaling + mask + softmax
Arguments:
input_in_fp16: flag to indicate if input in fp16 data format.
attn_mask_type: attention mask type (pad or causal)
scaled_masked_softmax_fusion: if true, use the fused softmax kernel.
mask_func: mask function to be applied.
softmax_in_fp32: if true, softmax is performed at fp32 precision.
scale: scaling factor used in input tensor scaling.
"""
def __init__(self, input_in_fp16, upper_triang_mask_fusion,
general_mask_fusion, mask_func, softmax_in_fp32, scale):
def __init__(self, input_in_fp16, attn_mask_type,
scaled_masked_softmax_fusion, mask_func,
softmax_in_fp32, scale):
super(FusedScaleMaskSoftmax, self).__init__()
self.input_in_fp16 = input_in_fp16
self.upper_triang_mask_fusion = upper_triang_mask_fusion
self.general_mask_fusion = general_mask_fusion
self.attn_mask_type = attn_mask_type
self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
self.mask_func = mask_func
self.softmax_in_fp32 = softmax_in_fp32
self.scale = scale
......@@ -97,20 +101,26 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
'softmax should be in fp32 when scaled'
def forward(self, input, mask):
# [b, np, s, s]
# [b, np, sq, sk]
data_size = input.size()
query_seq_len = data_size[-2]
key_seq_len = data_size[-1]
assert input.dim() == 4
# invoke custom kernel
if self.input_in_fp16 and data_size[-1] <= 2048 and \
(self.upper_triang_mask_fusion or self.general_mask_fusion) and \
input.size()[2] == input.size()[3]:
scale = self.scale if self.scale is not None else 1.0
if self.upper_triang_mask_fusion:
input = input.view(-1, data_size[2], data_size[3])
if self.input_in_fp16 and key_seq_len <= 2048 and \
query_seq_len % 4 == 0 and self.scaled_masked_softmax_fusion:
scale = self.scale if self.scale is not None else 1.0
if self.attn_mask_type == AttnMaskType.causal:
assert query_seq_len == key_seq_len, \
"causal mask is only for self attention"
input = input.view(-1, query_seq_len, key_seq_len)
probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
probs = probs.view(*data_size)
else:
assert self.attn_mask_type == AttnMaskType.padding
probs = ScaledMaskedSoftmax.apply(input, mask, scale)
else:
if self.input_in_fp16 and self.softmax_in_fp32:
......
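The fused path above is gated on several conditions; a minimal sketch of the predicate (function name hypothetical; the query_seq_len % 4 == 0 requirement appears to track the batches_per_block divisibility assert in the CUDA dispatch):

def fused_softmax_applicable(input_in_fp16, fusion_enabled,
                             query_seq_len, key_seq_len):
    return (input_in_fp16 and fusion_enabled
            and key_seq_len <= 2048 and query_seq_len % 4 == 0)

assert fused_softmax_applicable(True, True, 1024, 1024)
assert not fused_softmax_applicable(True, True, 1023, 1024)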
......@@ -21,6 +21,7 @@ from megatron import get_args
from megatron import mpu
from .module import MegatronModule
from .enums import AttnMaskType
from .language_model import parallel_lm_logits
from .language_model import get_language_model
from .utils import init_method_normal
......@@ -75,6 +76,7 @@ class GPT2ModelBase(MegatronModule):
attention_mask_func=gpt2_attention_mask_func,
num_tokentypes=num_tokentypes,
add_pooler=False,
self_attn_mask_type=AttnMaskType.causal,
init_method=init_method_normal(args.init_method_std),
scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers))
......
......@@ -21,6 +21,7 @@ import torch.nn.functional as F
from megatron import get_args
from megatron import mpu
from .module import MegatronModule
from megatron.model.enums import LayerType, AttnMaskType
from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal, scaled_init_method_normal
......@@ -43,7 +44,9 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
init_method=None, scaled_init_method=None):
add_decoder=False, init_method=None,
scaled_init_method=None,
self_attn_mask_type=AttnMaskType.padding):
"""Build language model and return along with the key to save."""
args = get_args()
......@@ -51,7 +54,8 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
init_method = init_method_normal(args.init_method_std)
if scaled_init_method is None:
scaled_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers)
scaled_init_method = scaled_init_method_normal(args.init_method_std,
args.num_layers)
# Language model.
args = [attention_mask_func, init_method, scaled_init_method]
......@@ -60,6 +64,8 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
if mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage():
cls = TransformerLanguageModel
kwargs['num_tokentypes'] = num_tokentypes
kwargs['self_attn_mask_type'] = self_attn_mask_type
kwargs['add_decoder'] = add_decoder
kwargs['add_pooler'] = add_pooler
elif mpu.is_pipeline_first_stage() and not mpu.is_pipeline_last_stage():
cls = TransformerLanguageModelFirstStage
......@@ -186,8 +192,6 @@ class Embedding(MegatronModule):
if tokentype_ids is not None:
assert self.tokentype_embeddings is not None
embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
else:
assert self.tokentype_embeddings is None
# Dropout.
embeddings = self.embedding_dropout(embeddings)
......@@ -281,6 +285,8 @@ class TransformerLanguageModelBase(MegatronModule):
init_method,
output_layer_init_method,
num_tokentypes=0,
self_attn_mask_type=AttnMaskType.padding,
add_decoder=False,
add_pooler=False):
super(TransformerLanguageModelBase, self).__init__()
args = get_args()
......@@ -288,6 +294,8 @@ class TransformerLanguageModelBase(MegatronModule):
self.hidden_size = args.hidden_size
self.num_tokentypes = num_tokentypes
self.init_method = init_method
self.self_attn_mask_type = self_attn_mask_type
self.add_decoder = add_decoder
self.add_pooler = add_pooler
# Embeddings.
......@@ -301,41 +309,87 @@ class TransformerLanguageModelBase(MegatronModule):
self._embedding_key = 'embedding'
# Transformer.
self.transformer = ParallelTransformer(
attention_mask_func, self.init_method,
output_layer_init_method)
self._transformer_key = 'transformer'
# Pooler.
if mpu.is_pipeline_last_stage() and self.add_pooler:
self.pooler = Pooler(self.hidden_size, self.init_method)
self._pooler_key = 'pooler'
def forward(self, language_model_input, attention_mask,
tokentype_ids=None, layer_past=None, get_key_value=False,
pooling_sequence_index=0):
self.encoder = ParallelTransformer(
attention_mask_func,
self.init_method,
output_layer_init_method,
self_attn_mask_type=self_attn_mask_type)
self._encoder_key = 'encoder'
# assuming pooler and decoder are in the last stage
# of the pipeline (to be revised)
if mpu.is_pipeline_last_stage():
# decoder
if self.add_decoder:
self.decoder = ParallelTransformer(
attention_mask_func,
self.init_method,
output_layer_init_method,
layer_type=LayerType.decoder,
self_attn_mask_type=AttnMaskType.causal)
self._decoder_key = 'decoder'
# Pooler.
if self.add_pooler:
self.pooler = Pooler(self.hidden_size, self.init_method)
self._pooler_key = 'pooler'
def forward(self, enc_language_model_input, enc_attention_mask,
dec_language_model_input=None, dec_attn_mask=None,
enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
get_key_value=False, pooling_sequence_index=0,
enc_hidden_states=None, output_enc_hidden=False):
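        # (Note) dec_language_model_input carries (dec_input_ids,
        # dec_position_ids); enc_dec_attn_mask masks decoder-over-encoder
        # cross-attention; enc_hidden_states lets a caller supply
        # precomputed encoder states; output_enc_hidden short-circuits
        # after the encoder.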
# Embeddings.
if mpu.is_pipeline_first_stage():
(input_ids, position_ids) = language_model_input
(input_ids, position_ids) = enc_language_model_input
embedding_output = self.embedding(input_ids, position_ids,
tokentype_ids=tokentype_ids)
transformer_input = embedding_output
encoder_input = embedding_output
else:
transformer_input = language_model_input
# Transformer.
transformer_output = self.transformer(transformer_input,
attention_mask,
layer_past=layer_past,
get_key_value=get_key_value)
if mpu.is_pipeline_last_stage() and self.add_pooler:
pooled_output = self.pooler(transformer_output,
pooling_sequence_index)
return transformer_output, pooled_output
encoder_input = enc_language_model_input
# encoder.
if enc_hidden_states is None:
encoder_output = self.encoder(encoder_input,
enc_attention_mask,
layer_past=layer_past,
get_key_value=get_key_value)
else:
encoder_output = enc_hidden_states.to(encoder_input.dtype)
if mpu.is_pipeline_last_stage():
if self.add_pooler:
pooled_output = self.pooler(encoder_output,
pooling_sequence_index)
# output_enc_hidden is set when only the encoder's output is needed,
# e.g. to compute similarity between two sequences via average pooling.
if not self.add_decoder or output_enc_hidden:
if self.add_pooler:
return encoder_output, pooled_output
else:
return encoder_output
# Decoder Embedding
(dec_input_ids, dec_position_ids) = dec_language_model_input
dec_embedding_output = self.embedding(dec_input_ids,
dec_position_ids)
# decoder
decoder_output = self.decoder(dec_embedding_output,
dec_attn_mask,
layer_past=layer_past,
get_key_value=get_key_value,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask)
if self.add_pooler:
return decoder_output, encoder_output, pooled_output
else:
return decoder_output, encoder_output
return transformer_output
return encoder_output
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
......@@ -346,13 +400,18 @@ class TransformerLanguageModelBase(MegatronModule):
state_dict_[self._embedding_key] \
= self.embedding.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
state_dict_[self._transformer_key] \
= self.transformer.state_dict_for_save_checkpoint(
state_dict_[self._encoder_key] \
= self.encoder.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if mpu.is_pipeline_last_stage() and self.add_pooler:
state_dict_[self._pooler_key] \
= self.pooler.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if mpu.is_pipeline_last_stage():
if self.add_pooler:
state_dict_[self._pooler_key] \
= self.pooler.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if self.add_decoder:
state_dict_[self._decoder_key] \
= self.decoder.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
return state_dict_
......@@ -371,23 +430,44 @@ class TransformerLanguageModelBase(MegatronModule):
state_dict_[key] = state_dict[key]
self.embedding.load_state_dict(state_dict_, strict=strict)
# Transformer.
if self._transformer_key in state_dict:
state_dict_ = state_dict[self._transformer_key]
# Encoder.
if self._encoder_key in state_dict:
state_dict_ = state_dict[self._encoder_key]
# for backward compatibility.
elif 'transformer' in state_dict:
state_dict_ = state_dict['transformer']
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if 'transformer.' in key:
state_dict_[key.split('transformer.')[1]] = state_dict[key]
self.transformer.load_state_dict(state_dict_, strict=strict)
# Pooler.
if mpu.is_pipeline_last_stage() and self.add_pooler:
assert 'pooler' in state_dict, \
'could not find data for pooler in the checkpoint'
self.pooler.load_state_dict(state_dict[self._pooler_key],
strict=strict)
if 'encoder.' in key:
state_dict_[key.split('encoder.')[1]] = state_dict[key]
# for backward compatibility.
state_dict_self_attention = {}
for key in state_dict_.keys():
if '.attention.' in key:
state_dict_self_attention[key.replace(".attention.",
".self_attention.")] = state_dict_[key]
else:
state_dict_self_attention[key] = state_dict_[key]
state_dict_ = state_dict_self_attention
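        # (Example) a key such as 'layers.0.attention.query_key_value.weight'
        # from an old checkpoint becomes
        # 'layers.0.self_attention.query_key_value.weight'.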
self.encoder.load_state_dict(state_dict_, strict=strict)
if mpu.is_pipeline_last_stage():
# pooler
if self.add_pooler:
assert 'pooler' in state_dict, \
'could not find data for pooler in the checkpoint'
self.pooler.load_state_dict(state_dict[self._pooler_key],
strict=strict)
# decoder
if self.add_decoder:
assert 'decoder' in state_dict, \
'could not find data for decoder in the checkpoint'
self.decoder.load_state_dict(state_dict[self._decoder_key],
strict=strict)
class TransformerLanguageModel(TransformerLanguageModelBase):
......@@ -400,24 +480,35 @@ class TransformerLanguageModel(TransformerLanguageModelBase):
init_method,
output_layer_init_method,
num_tokentypes=0,
self_attn_mask_type=AttnMaskType.padding,
add_decoder=False,
add_pooler=False):
super(TransformerLanguageModel, self).__init__(
attention_mask_func,
init_method,
output_layer_init_method,
num_tokentypes=num_tokentypes,
self_attn_mask_type=self_attn_mask_type,
add_decoder=add_decoder,
add_pooler=add_pooler)
def forward(self, input_ids, position_ids, attention_mask,
tokentype_ids=None, layer_past=None, get_key_value=False,
pooling_sequence_index=0):
def forward(self, enc_input_ids, enc_position_ids, enc_attention_mask,
dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
get_key_value=False, pooling_sequence_index=0,
enc_hidden_states=None, output_enc_hidden=False):
return super(TransformerLanguageModel, self).forward(
(input_ids, position_ids),
attention_mask,
(enc_input_ids, enc_position_ids),
enc_attention_mask,
dec_language_model_input=(dec_input_ids, dec_position_ids),
dec_attn_mask=dec_attn_mask,
enc_dec_attn_mask=enc_dec_attn_mask,
tokentype_ids=tokentype_ids,
layer_past=layer_past,
get_key_value=get_key_value,
pooling_sequence_index=pooling_sequence_index
pooling_sequence_index=pooling_sequence_index,
enc_hidden_states=enc_hidden_states,
output_enc_hidden=output_enc_hidden
)
......@@ -430,12 +521,14 @@ class TransformerLanguageModelFirstStage(TransformerLanguageModelBase):
attention_mask_func,
init_method,
output_layer_init_method,
num_tokentypes=0):
num_tokentypes=0,
self_attn_mask_type=AttnMaskType.padding):
super(TransformerLanguageModelFirstStage, self).__init__(
attention_mask_func,
init_method,
output_layer_init_method,
num_tokentypes=num_tokentypes)
num_tokentypes=num_tokentypes,
self_attn_mask_type=self_attn_mask_type)
def forward(self, input_ids, position_ids, attention_mask,
tokentype_ids=None, layer_past=None, get_key_value=False):
......@@ -456,11 +549,13 @@ class TransformerLanguageModelIntermediateStage(TransformerLanguageModelBase):
def __init__(self,
attention_mask_func,
init_method,
output_layer_init_method):
output_layer_init_method,
self_attn_mask_type=AttnMaskType.padding):
super(TransformerLanguageModelIntermediateStage, self).__init__(
attention_mask_func,
init_method,
output_layer_init_method)
output_layer_init_method,
self_attn_mask_type=self_attn_mask_type)
def forward(self, hidden_states, attention_mask,
layer_past=None, get_key_value=False):
......@@ -481,20 +576,31 @@ class TransformerLanguageModelLastStage(TransformerLanguageModelBase):
attention_mask_func,
init_method,
output_layer_init_method,
self_attn_mask_type=AttnMaskType.padding,
add_decoder=False,
add_pooler=False):
super(TransformerLanguageModelLastStage, self).__init__(
attention_mask_func,
init_method,
output_layer_init_method,
self_attn_mask_type=self_attn_mask_type,
add_decoder=add_decoder,
add_pooler=add_pooler)
def forward(self, hidden_states, attention_mask,
layer_past=None, get_key_value=False,
pooling_sequence_index=0):
def forward(self, hidden_states, enc_attention_mask,
dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
enc_dec_attn_mask=None, layer_past=None, get_key_value=False,
pooling_sequence_index=0, enc_hidden_states=None,
output_enc_hidden=False):
return super(TransformerLanguageModelLastStage, self).forward(
hidden_states,
attention_mask,
enc_attention_mask,
dec_language_model_input=(dec_input_ids, dec_position_ids),
dec_attn_mask=dec_attn_mask,
enc_dec_attn_mask=enc_dec_attn_mask,
layer_past=layer_past,
get_key_value=get_key_value,
pooling_sequence_index=pooling_sequence_index
pooling_sequence_index=pooling_sequence_index,
enc_hidden_states=enc_hidden_states,
output_enc_hidden=output_enc_hidden
)
......@@ -14,7 +14,7 @@
# limitations under the License.
"""Transformer."""
import enum
import math
import torch
import torch.nn.functional as F
......@@ -23,6 +23,7 @@ from megatron import get_args
from megatron import mpu
from .module import MegatronModule
from megatron.checkpointing import get_checkpoint_version
from megatron.model.enums import AttnMaskType, LayerType, AttnType
from megatron.model import import_layernorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
......@@ -71,7 +72,7 @@ class ParallelMLP(MegatronModule):
# Project to 4h.
self.dense_h_to_4h = mpu.ColumnParallelLinear(
args.hidden_size,
4 * args.hidden_size,
args.ffn_hidden_size,
gather_output=False,
init_method=init_method,
skip_bias_add=True)
......@@ -85,12 +86,11 @@ class ParallelMLP(MegatronModule):
# Project back to h.
self.dense_4h_to_h = mpu.RowParallelLinear(
4 * args.hidden_size,
args.ffn_hidden_size,
args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True)
def forward(self, hidden_states):
......@@ -109,7 +109,7 @@ class ParallelMLP(MegatronModule):
return output, output_bias
class ParallelSelfAttention(MegatronModule):
class ParallelAttention(MegatronModule):
"""Parallel self-attention layer abstract class.
Self-attention layer takes input with size [b, s, h]
......@@ -117,8 +117,10 @@ class ParallelSelfAttention(MegatronModule):
"""
def __init__(self, attention_mask_func, init_method,
output_layer_init_method, layer_number):
super(ParallelSelfAttention, self).__init__()
output_layer_init_method, layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=AttnMaskType.padding):
super(ParallelAttention, self).__init__()
args = get_args()
self.fp16 = args.fp16
......@@ -128,22 +130,40 @@ class ParallelSelfAttention(MegatronModule):
if self.apply_query_key_layer_scaling:
self.attention_softmax_in_fp32 = True
self.layer_number = max(1, layer_number)
self.attention_type = attention_type
self.attn_mask_type = attn_mask_type
projection_size = args.kv_channels * args.num_attention_heads
# Per attention head and per partition values.
world_size = mpu.get_tensor_model_parallel_world_size()
self.hidden_size_per_partition = mpu.divide(args.hidden_size,
self.hidden_size_per_partition = mpu.divide(projection_size,
world_size)
self.hidden_size_per_attention_head = mpu.divide(
args.hidden_size, args.num_attention_heads)
projection_size, args.num_attention_heads)
self.num_attention_heads_per_partition = mpu.divide(
args.num_attention_heads, world_size)
# Strided linear layer.
self.query_key_value = mpu.ColumnParallelLinear(
args.hidden_size,
3 * args.hidden_size,
gather_output=False,
init_method=init_method)
if attention_type == AttnType.self_attn:
self.query_key_value = mpu.ColumnParallelLinear(
args.hidden_size,
3 * projection_size,
gather_output=False,
init_method=init_method)
else:
assert attention_type == AttnType.cross_attn
self.query = mpu.ColumnParallelLinear(
args.hidden_size,
projection_size,
gather_output=False,
init_method=init_method)
self.key_value = mpu.ColumnParallelLinear(
args.hidden_size,
2 * projection_size,
gather_output=False,
init_method=init_method)
coeff = None
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
......@@ -153,8 +173,8 @@ class ParallelSelfAttention(MegatronModule):
self.scale_mask_softmax = FusedScaleMaskSoftmax(
self.fp16,
args.scaled_upper_triang_masked_softmax_fusion,
args.scaled_masked_softmax_fusion,
self.attn_mask_type,
args.masked_softmax_fusion,
self.attention_mask_func,
self.attention_softmax_in_fp32,
coeff)
......@@ -166,18 +186,18 @@ class ParallelSelfAttention(MegatronModule):
# Output.
self.dense = mpu.RowParallelLinear(
args.hidden_size,
projection_size,
args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True)
def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first):
input_shape = mixed_layer.size()
if num_splits_first:
"""[s, b, num_splits * np * hn]
-->(view) [s, b, num_splits, np, hn]
-->(transpose) [s, b, np, num_splits, hn]
-->(view) [s, b, np * num_splits * hn] """
intermediate_shape = input_shape[:-1] +\
......@@ -188,8 +208,8 @@ class ParallelSelfAttention(MegatronModule):
mixed_layer = mixed_layer.transpose(-2, -3).contiguous()
else:
"""[s, b, np * hn * num_splits]
-->(view) [s, b, np, hn, num_splits]
-->(transpose) [s, b, np, num_splits, hn]
-->(view) [s, b, np * num_splits * hn] """
intermediate_shape = input_shape[:-1] +\
......@@ -199,39 +219,70 @@ class ParallelSelfAttention(MegatronModule):
mixed_layer = mixed_layer.view(*intermediate_shape)
mixed_layer = mixed_layer.transpose(-1, -2).contiguous()
mixed_layer = mixed_layer.view(*input_shape)
return mixed_layer
def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False):
get_key_value=False, encoder_output=None):
# hidden_states: [sq, b, h]
# =====================
# Query, Key, and Value
# =====================
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
mixed_x_layer, _ = self.query_key_value(hidden_states)
checkpoint_version = get_checkpoint_version()
if checkpoint_version is not None:
if checkpoint_version == 0:
# [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)]
mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, True)
elif checkpoint_version == 1.0:
# [s, b, (np * hn * 3)] --> [s, b, (np * 3 * hn)]
mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, False)
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(query_layer,
key_layer,
value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
if self.attention_type == AttnType.self_attn:
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
mixed_x_layer, _ = self.query_key_value(hidden_states)
checkpoint_version = get_checkpoint_version()
if checkpoint_version is not None:
if checkpoint_version == 0:
# [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)]
mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, True)
elif checkpoint_version == 1.0:
# [s, b, (np * hn * 3)] --> [s, b, (np * 3 * hn)]
mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, False)
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(query_layer,
key_layer,
value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
else:
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_kv_layer, _ = self.key_value(encoder_output)
checkpoint_version = get_checkpoint_version()
if checkpoint_version is not None:
if checkpoint_version == 0:
# [s, b, (2 * np * hn)] --> [s, b, (np * 2 * hn)]
mixed_kv_layer = self._transpose_last_dim(mixed_kv_layer, 2, True)
elif checkpoint_version == 1.0:
# [s, b, (np * hn * 2)] --> [s, b, (np * 2 * hn)]
mixed_kv_layer = self._transpose_last_dim(mixed_kv_layer, 2, False)
# [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
new_tensor_shape = mixed_kv_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
2 * self.hidden_size_per_attention_head)
mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
# [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
(key_layer,
value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)
# Attention head [sq, b, h] --> [sq, b, hp]
query_layer, _ = self.query(hidden_states)
# [sq, b, hp] --> [sq, b, np, hn]
new_tensor_shape = query_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
query_layer = query_layer.view(*new_tensor_shape)
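            # (Shapes) In the cross-attention branch at this point:
            #   key_layer / value_layer: [sk, b, np, hn], from encoder_output
            #   query_layer:             [sq, b, np, hn], from hidden_states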
# ==================================
# Adjust key and value for inference
......@@ -246,41 +297,41 @@ class ParallelSelfAttention(MegatronModule):
if get_key_value:
present = (key_layer, value_layer)
# ===================================
# Raw attention scores. [b, np, s, s]
# ===================================
# [b, np, sq, sk]
output_size = (query_layer.size(1),
query_layer.size(2),
query_layer.size(0),
key_layer.size(0))
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.view(output_size[2],
output_size[0] * output_size[1], -1)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.view(output_size[3],
output_size[0] * output_size[1], -1)
# preallocting result tensor: [b * np, sq, sk]
matmul_result = torch.empty(
output_size[0]*output_size[1],
output_size[2],
output_size[3],
dtype=query_layer.dtype,
device=torch.cuda.current_device())
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(
matmul_result,
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0, alpha=(1.0/self.norm_factor))
# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
# ==================================================
# Update attention mask for inference. [b, np, sq, sk]
# ==================================================
......@@ -298,7 +349,6 @@ class ParallelSelfAttention(MegatronModule):
:attention_scores.size(3),
:attention_scores.size(3)]
# ===========================
# Attention probs and dropout
# ===========================
......@@ -312,7 +362,6 @@ class ParallelSelfAttention(MegatronModule):
with mpu.get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs)
# =========================
# Context layer. [sq, b, hp]
# =========================
......@@ -321,21 +370,21 @@ class ParallelSelfAttention(MegatronModule):
# [sk, b, np, hn] --> [b, np, sq, hn]
# context layer shape: [b, np, sq, hn]
output_size = (value_layer.size(1),
value_layer.size(2),
query_layer.size(0),
value_layer.size(3))
# change view [sk, b * np, hn]
value_layer = value_layer.view(value_layer.size(0),
output_size[0] * output_size[1], -1)
# change view [b * np, sq, sk]
attention_probs = attention_probs.view(output_size[0] * output_size[1],
output_size[2], -1)
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
# change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size)
......@@ -348,7 +397,6 @@ class ParallelSelfAttention(MegatronModule):
(self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
# =================
# Output. [sq, b, h]
# =================
......@@ -389,16 +437,19 @@ def bias_dropout_add_fused_inference(x, bias, residual, prob) :
class ParallelTransformerLayer(MegatronModule):
"""A single transformer layer.
Transformore layer takes input with size [b, s, h] and returns an
Transformer layer takes input with size [b, s, h] and returns an
output of the same size.
"""
def __init__(self, attention_mask_func, init_method,
output_layer_init_method, layer_number):
def __init__(self, attention_mask_func, init_method,
output_layer_init_method, layer_number,
layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding):
args = get_args()
super(ParallelTransformerLayer, self).__init__()
self.layer_number = layer_number
self.layer_type = layer_type
self.apply_residual_connection_post_layernorm \
= args.apply_residual_connection_post_layernorm
......@@ -410,45 +461,62 @@ class ParallelTransformerLayer(MegatronModule):
eps=args.layernorm_epsilon)
# Self attention.
self.attention = ParallelSelfAttention(attention_mask_func, init_method,
output_layer_init_method,
layer_number)
self.self_attention = ParallelAttention(
attention_mask_func,
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=self_attn_mask_type)
self.hidden_dropout = args.hidden_dropout
self.bias_dropout_fusion = args.bias_dropout_fusion
# Layernorm on the input data.
# Layernorm on the attention output
self.post_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon)
if self.layer_type == LayerType.decoder:
self.inter_attention = ParallelAttention(
attention_mask_func,
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.cross_attn)
# Layernorm on the attention output.
self.post_inter_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon)
# MLP
self.mlp = ParallelMLP(init_method,
output_layer_init_method)
def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False):
def forward(self, hidden_states, attention_mask,
encoder_output=None, enc_dec_attn_mask=None,
layer_past=None, get_key_value=False):
# hidden_states: [b, s, h]
# Layer norm at the begining of the transformer layer.
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output, attention_bias = \
self.attention(layernorm_output,
attention_mask,
layer_past=layer_past,
get_key_value=get_key_value)
self.self_attention(layernorm_output,
attention_mask,
layer_past=layer_past,
get_key_value=get_key_value)
if get_key_value:
attention_output, presents = attention_output
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
# jit scripting for a nn.module (with dropout) is not
# triggering the fusion kernel. For now, we use two
# different nn.functional routines to account for varying
# dropout semantics during training and inference phases.
if self.bias_dropout_fusion:
......@@ -459,7 +527,7 @@ class ParallelTransformerLayer(MegatronModule):
else:
bias_dropout_add_func = get_bias_dropout_add(self.training)
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
......@@ -470,16 +538,38 @@ class ParallelTransformerLayer(MegatronModule):
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
if self.layer_type == LayerType.decoder:
attention_output, attention_bias = \
self.inter_attention(layernorm_output,
enc_dec_attn_mask,
encoder_output=encoder_output)
# residual connection
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
# Layer norm post the decoder attention
layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
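        # (Note) A decoder layer thus runs self-attention, then
        # cross-attention over encoder_output, then the MLP, each wrapped
        # in its own bias-dropout-add residual and layernorm.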
# MLP.
mlp_output, mlp_bias = self.mlp(layernorm_output)
# Second residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
#re-enable torch grad to enable fused optimization.
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
output = bias_dropout_add_func(
mlp_output,
......@@ -497,7 +587,9 @@ class ParallelTransformer(MegatronModule):
"""Transformer class."""
def __init__(self, attention_mask_func,
init_method, output_layer_init_method):
init_method, output_layer_init_method,
layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding):
super(ParallelTransformer, self).__init__()
args = get_args()
......@@ -516,7 +608,9 @@ class ParallelTransformer(MegatronModule):
def build_layer(layer_number):
return ParallelTransformerLayer(
attention_mask_func, init_method,
output_layer_init_method, layer_number)
output_layer_init_method, layer_number,
layer_type=layer_type,
self_attn_mask_type=self_attn_mask_type)
offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
self.layers = torch.nn.ModuleList(
[build_layer(i + 1 + offset) for i in range(self.num_layers)])
......@@ -531,14 +625,18 @@ class ParallelTransformer(MegatronModule):
def _get_layer(self, layer_number):
return self.layers[layer_number]
def _checkpointed_forward(self, hidden_states, attention_mask):
def _checkpointed_forward(self, hidden_states, attention_mask,
encoder_output, enc_dec_attn_mask):
"""Forward method with activation checkpointing."""
def custom(start, end):
def custom_forward(*inputs):
x_ = inputs[0]
attention_mask = inputs[1]
encoder_output = inputs[2]
enc_dec_attn_mask = inputs[3]
for index in range(start, end):
layer = self._get_layer(index)
x_ = layer(x_, inputs[1])
x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
return x_
return custom_forward
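        # (Note) The checkpointed closure now takes four inputs so that
        # backward recomputation also replays cross-attention when
        # encoder_output / enc_dec_attn_mask are present.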
......@@ -548,13 +646,13 @@ class ParallelTransformer(MegatronModule):
while l < self.num_layers:
hidden_states = mpu.checkpoint(
custom(l, l + self.checkpoint_num_layers),
hidden_states, attention_mask)
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
l += self.checkpoint_num_layers
return hidden_states
def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False):
get_key_value=False, encoder_output=None, enc_dec_attn_mask=None):
# Checks.
if layer_past is not None:
......@@ -577,7 +675,9 @@ class ParallelTransformer(MegatronModule):
if self.checkpoint_activations:
hidden_states = self._checkpointed_forward(hidden_states,
attention_mask)
attention_mask,
encoder_output,
enc_dec_attn_mask)
else:
if get_key_value:
presents = []
......@@ -588,12 +688,14 @@ class ParallelTransformer(MegatronModule):
past = layer_past[index]
hidden_states = layer(hidden_states,
attention_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
layer_past=past,
get_key_value=get_key_value)
if get_key_value:
hidden_states, present = hidden_states
presents.append(present)
# Final layer norm.
if mpu.is_pipeline_last_stage():
# Reverting data format change [s b h] --> [b s h].
......
......@@ -141,5 +141,4 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
args_defaults={'tokenizer_type': 'GPT2BPETokenizer',
'scaled_upper_triang_masked_softmax_fusion': True})
args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})