"sgl-kernel/vscode:/vscode.git/clone" did not exist on "1b9175cb23004e3a40dfb97ab80e7e45032c5359"
Commit f34cc86b authored by Jared Casper

Merge branch 'transformer_decoder' into 'main'

Transformer decoder

See merge request ADLR/megatron-lm!205
parents be473a5b 4ae54b55
......@@ -164,6 +164,20 @@ def parse_args(extra_args_provider=None, defaults={},
_check_arg_is_not_none(args, req_arg)
# Checks.
if args.ffn_hidden_size is None:
args.ffn_hidden_size = 4 * args.hidden_size
if args.kv_channels is None:
assert args.hidden_size % args.num_attention_heads == 0
args.kv_channels = args.hidden_size // args.num_attention_heads
if args.seq_length is not None:
assert args.encoder_seq_length is None
args.encoder_seq_length = args.seq_length
else:
assert args.encoder_seq_length is not None
args.seq_length = args.encoder_seq_length
assert args.hidden_size % args.num_attention_heads == 0
if args.seq_length is not None:
assert args.max_position_embeddings >= args.seq_length
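For reference, a minimal standalone sketch of the new argument defaults introduced above (ffn_hidden_size, kv_channels, and the seq_length/encoder_seq_length mirroring); the SimpleNamespace values are illustrative, not taken from the merge request.

```python
from types import SimpleNamespace

def apply_defaults(args):
    # ffn hidden size defaults to 4x the model hidden size
    if args.ffn_hidden_size is None:
        args.ffn_hidden_size = 4 * args.hidden_size
    # kv_channels defaults to the per-head dimension
    if args.kv_channels is None:
        assert args.hidden_size % args.num_attention_heads == 0
        args.kv_channels = args.hidden_size // args.num_attention_heads
    # --seq-length and --encoder-seq-length mirror each other
    if args.seq_length is not None:
        assert args.encoder_seq_length is None
        args.encoder_seq_length = args.seq_length
    else:
        assert args.encoder_seq_length is not None
        args.seq_length = args.encoder_seq_length
    return args

args = apply_defaults(SimpleNamespace(
    hidden_size=1024, num_attention_heads=16, ffn_hidden_size=None,
    kv_channels=None, seq_length=512, encoder_seq_length=None))
assert (args.ffn_hidden_size, args.kv_channels, args.encoder_seq_length) == (4096, 64, 512)
```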
......@@ -183,15 +197,10 @@ def parse_args(extra_args_provider=None, defaults={},
'for distribute-checkpointed-activations to work you '\
'need to enable checkpoint-activations'
if args.scaled_masked_softmax_fusion:
if args.scaled_upper_triang_masked_softmax_fusion:
# Load scaled_masked_softmax_fusion_kernels
if args.masked_softmax_fusion:
fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
else:
fused_kernels.load_scaled_masked_softmax_fusion_kernel()
else:
# This argument will eventually go away, for now make sure it is off
# if scaled_masked_softmax_fusion is off.
args.scaled_upper_triang_masked_softmax_fusion = False
# Load mixed precision fused layer norm.
if args.fp32_residual_connection:
......@@ -227,8 +236,14 @@ def _add_network_size_args(parser):
help='Number of transformer layers.')
group.add_argument('--hidden-size', type=int, default=None,
help='Transformer hidden size.')
group.add_argument('--ffn-hidden-size', type=int, default=None,
help='Transformer Feed-Forward Network hidden size. This is set to 4*hidden-size if not '
'provided')
group.add_argument('--num-attention-heads', type=int, default=None,
help='Number of transformer attention heads.')
group.add_argument('--kv-channels', type=int, default=None,
help='Projection weights dimension in multi-head attention. '
'This is set to args.hidden_size // args.num_attention_heads if not provided.')
group.add_argument('--max-position-embeddings', type=int, default=None,
help='Maximum number of position embeddings to use. '
'This is the size of position embedding.')
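As a rough illustration of what --kv-channels controls further down in this diff (the ParallelAttention projection width and its tensor-parallel split); all numbers below are assumed for the example, not mandated by the code.

```python
hidden_size = 1024
num_attention_heads = 16
kv_channels = hidden_size // num_attention_heads     # default: 64
world_size = 2                                        # assumed tensor-parallel ranks

projection_size = kv_channels * num_attention_heads                       # 1024
hidden_size_per_partition = projection_size // world_size                 # 512 per rank
hidden_size_per_attention_head = projection_size // num_attention_heads   # 64
num_attention_heads_per_partition = num_attention_heads // world_size     # 8
print(projection_size, hidden_size_per_partition,
      hidden_size_per_attention_head, num_attention_heads_per_partition)
```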
......@@ -330,16 +345,11 @@ def _add_training_args(parser):
help='Exit the program after this many minutes.')
group.add_argument('--tensorboard-dir', type=str, default=None,
help='Write TensorBoard logs to this directory.')
group.add_argument('--no-scaled-masked-softmax-fusion',
group.add_argument('--no-masked-softmax-fusion',
action='store_false',
help='Disable fusion of query_key_value scaling, '
'masking, and softmax.',
dest='scaled_masked_softmax_fusion')
group.add_argument('--scaled-upper-triang-masked-softmax-fusion',
type=bool,
help='Use upper triangular version of fused '
'scale, mask, softmax fusion kernel (default for GPT). '
'- DEPRECATED')
dest='masked_softmax_fusion')
group.add_argument('--no-bias-gelu-fusion', action='store_false',
help='Disable bias and gelu fusion.',
dest='bias_gelu_fusion')
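The renamed flag keeps the usual negative-switch / positive-attribute pattern; a self-contained argparse sketch of just this option (outside Megatron) behaves as follows.

```python
import argparse

parser = argparse.ArgumentParser()
# CLI switch is negative, stored attribute stays positive and defaults to True
parser.add_argument('--no-masked-softmax-fusion',
                    action='store_false',
                    help='Disable fusion of query_key_value scaling, '
                         'masking, and softmax.',
                    dest='masked_softmax_fusion')

assert parser.parse_args([]).masked_softmax_fusion is True
assert parser.parse_args(['--no-masked-softmax-fusion']).masked_softmax_fusion is False
```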
......@@ -529,7 +539,12 @@ def _add_data_args(parser):
group.add_argument('--merge-file', type=str, default=None,
help='Path to the BPE merge file.')
group.add_argument('--seq-length', type=int, default=None,
help="Maximum sequence length to process.")
help='Maximum sequence length to process.')
group.add_argument('--encoder-seq-length', type=int, default=None,
help='Maximum encoder sequence length to process. '
'This should be exclusive of --seq-length.')
group.add_argument('--decoder-seq-length', type=int, default=None,
help="Maximum decoder sequence length to process.")
group.add_argument('--mask-prob', type=float, default=0.15,
help='Probability of replacing a token with mask.')
group.add_argument('--short-seq-prob', type=float, default=0.1,
......
......@@ -81,7 +81,6 @@ __global__ void scaled_masked_softmax_warp_forward(
const uint8_t *mask,
const acc_t scale,
int micro_batch_size,
int stride,
int element_count,
int pad_batches)
{
......@@ -111,9 +110,9 @@ __global__ void scaled_masked_softmax_warp_forward(
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
src += first_batch * stride + local_idx;
dst += first_batch * stride + local_idx;
mask += pad_first_batch * stride + local_idx;
src += first_batch * element_count + local_idx;
dst += first_batch * element_count + local_idx;
mask += pad_first_batch * element_count + local_idx;
// load data from global memory
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
......@@ -185,7 +184,6 @@ __global__ void scaled_masked_softmax_warp_backward(
const input_t *output,
acc_t scale,
int micro_batch_size,
int stride,
int element_count)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
......@@ -209,7 +207,7 @@ __global__ void scaled_masked_softmax_warp_backward(
int local_idx = threadIdx.x;
// the first element to process by the current thread
int thread_offset = first_batch * stride + local_idx;
int thread_offset = first_batch * element_count + local_idx;
grad += thread_offset;
output += thread_offset;
gradInput += thread_offset;
......@@ -277,20 +275,19 @@ void dispatch_scaled_masked_softmax_forward(
const input_t *src,
const uint8_t *mask,
const input_t scale,
int softmax_elements,
int softmax_elements_stride,
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads,
int pad_batches)
{
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 );
if (softmax_elements == 0) {
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048 );
if (key_seq_len == 0) {
return;
} else {
int log2_elements = log2_ceil(softmax_elements);
int log2_elements = log2_ceil(key_seq_len);
const int next_power_of_two = 1 << log2_elements;
int seq_len = softmax_elements;
int batch_count = batches * attn_heads * seq_len;
int batch_count = batches * attn_heads * query_seq_len;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
......@@ -303,58 +300,58 @@ void dispatch_scaled_masked_softmax_forward(
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(seq_len%batches_per_block == 0);
dim3 blocks(seq_len/batches_per_block, attn_heads, batches);
TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0);
dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 1: // 2
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 2: // 4
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 3: // 8
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 4: // 16
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 5: // 32
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 6: // 64
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 7: // 128
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 8: // 256
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 9: // 512
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 10: // 1024
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 11: // 2048
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
default:
break;
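The launch geometry after the rename can be reproduced in a few lines of Python; the 128-thread block size and the batches_per_warp rule below are assumptions based on the surrounding (elided) kernel code, and the sequence lengths are illustrative.

```python
def log2_ceil(value):
    log2_value = 0
    while (1 << log2_value) < value:
        log2_value += 1
    return log2_value

C10_WARP_SIZE = 32
batches, attn_heads = 4, 16
query_seq_len, key_seq_len = 512, 384     # rectangular score maps are now allowed

log2_elements = log2_ceil(key_seq_len)
next_power_of_two = 1 << log2_elements               # 512
batch_count = batches * attn_heads * query_seq_len   # number of softmax rows

warp_size = min(next_power_of_two, C10_WARP_SIZE)    # 32
batches_per_warp = 2 if next_power_of_two <= 128 else 1   # assumed rule from kernel
threads_per_block = 128                                    # assumed constant
warps_per_block = threads_per_block // warp_size
batches_per_block = warps_per_block * batches_per_warp

assert query_seq_len % batches_per_block == 0
blocks = (query_seq_len // batches_per_block, attn_heads, batches)
threads = (warp_size, warps_per_block, 1)
print(blocks, threads, batch_count)
```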
......@@ -368,19 +365,18 @@ void dispatch_scaled_masked_softmax_backward(
input_t *grad,
const input_t *output,
const acc_t scale,
int softmax_elements,
int softmax_elements_stride,
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads)
{
TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 );
if (softmax_elements == 0) {
TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 2048 );
if (key_seq_len == 0) {
return;
} else {
int log2_elements = log2_ceil(softmax_elements);
int log2_elements = log2_ceil(key_seq_len);
const int next_power_of_two = 1 << log2_elements;
int seq_len = softmax_elements;
int batch_count = batches * attn_heads * seq_len;
int batch_count = batches * attn_heads * query_seq_len;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
......@@ -399,51 +395,51 @@ void dispatch_scaled_masked_softmax_backward(
switch (log2_elements) {
case 0: // 1
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 1: // 2
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 2: // 4
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 3: // 8
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 4: // 16
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 5: // 32
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 6: // 64
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 7: // 128
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 8: // 256
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 9: // 512
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 10: // 1024
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 11: // 2048
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
default:
break;
......
......@@ -37,17 +37,19 @@ torch::Tensor fwd_cuda(
const int batches = input.size(0);
const int pad_batches = mask.size(0);
const int attn_heads = input.size(1);
const int seq_len = input.size(2);
TORCH_INTERNAL_ASSERT(seq_len <= 2048);
const int query_seq_len = input.size(2);
const int key_seq_len = input.size(3);
TORCH_INTERNAL_ASSERT(key_seq_len <= 2048);
TORCH_INTERNAL_ASSERT(query_seq_len > 1);
TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
TORCH_INTERNAL_ASSERT(mask.size(2) == seq_len);
TORCH_INTERNAL_ASSERT(mask.size(3) == seq_len);
TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);
// Output
auto act_options = input.options().requires_grad(false);
torch::Tensor softmax_results =
torch::empty({batches, attn_heads, seq_len, seq_len}, act_options);
torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
// Softmax Intermediate Result Ptr
void* input_ptr = static_cast<void*>(input.data_ptr());
......@@ -59,8 +61,8 @@ torch::Tensor fwd_cuda(
reinterpret_cast<const half*>(input_ptr),
reinterpret_cast<const uint8_t*>(mask_ptr),
scale_factor,
seq_len,
seq_len,
query_seq_len,
key_seq_len,
batches,
attn_heads,
pad_batches);
......@@ -78,8 +80,8 @@ torch::Tensor bwd_cuda(
//output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
const int batches = output_grads.size(0);
const int attn_heads = output_grads.size(1);
const int seq_len = output_grads.size(2);
TORCH_INTERNAL_ASSERT(output_grads.size(2) == output_grads.size(3));
const int query_seq_len = output_grads.size(2);
const int key_seq_len = output_grads.size(3);
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
......@@ -89,8 +91,8 @@ torch::Tensor bwd_cuda(
reinterpret_cast<half*>(output_grads_ptr),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
scale_factor,
seq_len,
seq_len,
query_seq_len,
key_seq_len,
batches,
attn_heads);
......
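To summarize the new contract of fwd_cuda: the score tensor may be rectangular, the mask is broadcast over heads, and only the key dimension is bounded by the 2048-element kernel limit. The placeholder tensors below only illustrate the accepted shapes; they are not a call into the compiled extension.

```python
import torch

batches, attn_heads, query_seq_len, key_seq_len = 2, 16, 128, 384
scores = torch.randn(batches, attn_heads, query_seq_len, key_seq_len).half()
mask = torch.zeros(batches, 1, query_seq_len, key_seq_len, dtype=torch.uint8)

assert key_seq_len <= 2048 and query_seq_len > 1
assert mask.size(0) in (1, batches) and mask.size(1) == 1
assert mask.size(2) == query_seq_len and mask.size(3) == key_seq_len
# output buffer allocated by fwd_cuda: [batches, attn_heads, query_seq_len, key_seq_len]
softmax_results = torch.empty_like(scores)
```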
......@@ -19,6 +19,7 @@ import torch
from megatron import get_args
from megatron import mpu
from megatron.model.enums import AttnMaskType
from megatron.model.language_model import parallel_lm_logits
from megatron.model.language_model import get_language_model
from megatron.model import import_layernorm
......@@ -147,6 +148,7 @@ class BertModelBase(MegatronModule):
attention_mask_func=bert_attention_mask_func,
num_tokentypes=num_tokentypes,
add_pooler=self.add_binary_head,
encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method,
scaled_init_method=scaled_init_method)
......
......@@ -19,6 +19,7 @@ import torch
from megatron import get_args, print_rank_last
from megatron import mpu
from megatron.model.enums import AttnMaskType
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
......@@ -40,6 +41,7 @@ class ClassificationBase(MegatronModule):
attention_mask_func=bert_attention_mask_func,
num_tokentypes=num_tokentypes,
add_pooler=True,
encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method,
scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers))
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
class LayerType(enum.Enum):
encoder = 1
decoder = 2
class AttnType(enum.Enum):
self_attn = 1
cross_attn = 2
class AttnMaskType(enum.Enum):
padding = 1
causal = 2
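A tiny usage sketch of the new enums (assuming the megatron package is importable): a decoder layer combines causal self-attention with a cross-attention block over the encoder output.

```python
from megatron.model.enums import LayerType, AttnType, AttnMaskType

layer_type = LayerType.decoder
self_attn = (AttnType.self_attn, AttnMaskType.causal)
cross_attn = (AttnType.cross_attn, AttnMaskType.padding)
print(layer_type, self_attn, cross_attn)
```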
......@@ -14,8 +14,10 @@
# limitations under the License.
import torch
from megatron.model.enums import AttnMaskType
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function) :
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
"""
Fused operation which performs following three operations in sequence
1. Scale the tensor.
......@@ -43,7 +45,8 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function) :
scale_t[0])
return input_grads, None
class ScaledMaskedSoftmax(torch.autograd.Function) :
class ScaledMaskedSoftmax(torch.autograd.Function):
"""
Fused operation which performs following three operations in sequence
1. Scale the tensor.
......@@ -71,24 +74,25 @@ class ScaledMaskedSoftmax(torch.autograd.Function) :
scale_t[0])
return input_grads, None, None
class FusedScaleMaskSoftmax(torch.nn.Module):
"""
fused operation: scaling + mask + softmax
Arguments:
input_in_fp16: flag to indicate if input is in fp16 data format.
upper_triang_mask: if true, apply upper triangular masking.
(used in gpt family networks)
attn_mask_type: attention mask type (pad or causal)
mask_func: mask function to be applied.
softmax_in_fp32: if true, softmax is performed at fp32 precision.
scale: scaling factor used in input tensor scaling.
"""
def __init__(self, input_in_fp16, upper_triang_mask_fusion,
general_mask_fusion, mask_func, softmax_in_fp32, scale):
def __init__(self, input_in_fp16, attn_mask_type,
scaled_masked_softmax_fusion, mask_func,
softmax_in_fp32, scale):
super(FusedScaleMaskSoftmax, self).__init__()
self.input_in_fp16 = input_in_fp16
self.upper_triang_mask_fusion = upper_triang_mask_fusion
self.general_mask_fusion = general_mask_fusion
self.attn_mask_type = attn_mask_type
self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
self.mask_func = mask_func
self.softmax_in_fp32 = softmax_in_fp32
self.scale = scale
......@@ -97,20 +101,26 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
'softmax should be in fp32 when scaled'
def forward(self, input, mask):
# [b, np, s, s]
# [b, np, sq, sk]
data_size = input.size()
query_seq_len = data_size[-2]
key_seq_len = data_size[-1]
assert input.dim() == 4
# invoke custom kernel
if self.input_in_fp16 and data_size[-1] <= 2048 and \
(self.upper_triang_mask_fusion or self.general_mask_fusion) and \
input.size()[2] == input.size()[3]:
if self.input_in_fp16 and key_seq_len <= 2048 and \
query_seq_len % 4 == 0 and self.scaled_masked_softmax_fusion:
scale = self.scale if self.scale is not None else 1.0
if self.upper_triang_mask_fusion:
input = input.view(-1, data_size[2], data_size[3])
if self.attn_mask_type == AttnMaskType.causal:
assert query_seq_len == key_seq_len, \
"causal mask is only for self attention"
input = input.view(-1, query_seq_len, key_seq_len)
probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
probs = probs.view(*data_size)
else:
assert self.attn_mask_type == AttnMaskType.padding
probs = ScaledMaskedSoftmax.apply(input, mask, scale)
else:
if self.input_in_fp16 and self.softmax_in_fp32:
......
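A hedged usage sketch of the reworked module: it is now keyed by AttnMaskType instead of two fusion booleans and accepts rectangular [b, np, sq, sk] scores. Shapes are illustrative, the mask function is a hypothetical stand-in, the mask convention (nonzero = masked) mirrors the rest of the codebase, and the fused path additionally needs a GPU with the compiled kernels, fp16 inputs, sk <= 2048, and sq divisible by 4.

```python
import torch
from megatron.model.enums import AttnMaskType
from megatron.model.fused_softmax import FusedScaleMaskSoftmax

def attention_mask_func(attention_scores, attention_mask):
    # hypothetical mask function: fill masked positions with a large negative value
    return attention_scores.masked_fill_(attention_mask, -10000.0)

scale_mask_softmax = FusedScaleMaskSoftmax(
    input_in_fp16=True,
    attn_mask_type=AttnMaskType.padding,
    scaled_masked_softmax_fusion=True,
    mask_func=attention_mask_func,
    softmax_in_fp32=False,
    scale=None)

b, np_, sq, sk = 2, 16, 128, 384
scores = torch.randn(b, np_, sq, sk).half().cuda()
mask = torch.zeros(b, 1, sq, sk, dtype=torch.bool).cuda()   # nonzero = masked
probs = scale_mask_softmax(scores, mask)                    # [b, np, sq, sk]
```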
......@@ -21,6 +21,7 @@ from megatron import get_args
from megatron import mpu
from .module import MegatronModule
from .enums import AttnMaskType
from .language_model import parallel_lm_logits
from .language_model import get_language_model
from .utils import init_method_normal
......@@ -75,6 +76,7 @@ class GPTModelBase(MegatronModule):
attention_mask_func=gpt_attention_mask_func,
num_tokentypes=num_tokentypes,
add_pooler=False,
encoder_attn_mask_type=AttnMaskType.causal,
init_method=init_method_normal(args.init_method_std),
scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers))
......
......@@ -21,6 +21,7 @@ import torch.nn.functional as F
from megatron import get_args
from megatron import mpu
from .module import MegatronModule
from megatron.model.enums import LayerType, AttnMaskType
from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal, scaled_init_method_normal
......@@ -43,7 +44,9 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
init_method=None, scaled_init_method=None):
encoder_attn_mask_type, init_method=None,
scaled_init_method=None, add_decoder=False,
decoder_attn_mask_type=AttnMaskType.causal):
"""Build language model and return along with the key to save."""
args = get_args()
......@@ -51,15 +54,19 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
init_method = init_method_normal(args.init_method_std)
if scaled_init_method is None:
scaled_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers)
scaled_init_method = scaled_init_method_normal(args.init_method_std,
args.num_layers)
# Language model.
args = [attention_mask_func, init_method, scaled_init_method]
args = [attention_mask_func, init_method,
scaled_init_method, encoder_attn_mask_type]
kwargs = {}
cls = None
if mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage():
cls = TransformerLanguageModel
kwargs['num_tokentypes'] = num_tokentypes
kwargs['add_decoder'] = add_decoder
kwargs['decoder_attn_mask_type'] = decoder_attn_mask_type
kwargs['add_pooler'] = add_pooler
elif mpu.is_pipeline_first_stage() and not mpu.is_pipeline_last_stage():
cls = TransformerLanguageModelFirstStage
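A hedged construction sketch using the extended signature, assuming Megatron's global args have already been initialized (e.g. inside pretrain()); the mask function here is a hypothetical stand-in for bert/gpt_attention_mask_func.

```python
from megatron.model.enums import AttnMaskType
from megatron.model.language_model import get_language_model

def attention_mask_func(attention_scores, attention_mask):
    # hypothetical stand-in for bert/gpt_attention_mask_func
    return attention_scores.masked_fill_(attention_mask, -10000.0)

language_model, lm_key = get_language_model(
    attention_mask_func=attention_mask_func,
    num_tokentypes=0,
    add_pooler=False,
    encoder_attn_mask_type=AttnMaskType.padding,
    add_decoder=True,                        # builds the second transformer stack
    decoder_attn_mask_type=AttnMaskType.causal)
```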
......@@ -280,7 +287,10 @@ class TransformerLanguageModelBase(MegatronModule):
attention_mask_func,
init_method,
output_layer_init_method,
encoder_attn_mask_type,
num_tokentypes=0,
add_decoder=False,
decoder_attn_mask_type=AttnMaskType.causal,
add_pooler=False):
super(TransformerLanguageModelBase, self).__init__()
args = get_args()
......@@ -288,6 +298,9 @@ class TransformerLanguageModelBase(MegatronModule):
self.hidden_size = args.hidden_size
self.num_tokentypes = num_tokentypes
self.init_method = init_method
self.encoder_attn_mask_type = encoder_attn_mask_type
self.add_decoder = add_decoder
self.decoder_attn_mask_type = decoder_attn_mask_type
self.add_pooler = add_pooler
# Embeddings.
......@@ -301,41 +314,85 @@ class TransformerLanguageModelBase(MegatronModule):
self._embedding_key = 'embedding'
# Transformer.
self.transformer = ParallelTransformer(
attention_mask_func, self.init_method,
output_layer_init_method)
self._transformer_key = 'transformer'
self.encoder = ParallelTransformer(
attention_mask_func,
self.init_method,
output_layer_init_method,
self_attn_mask_type=self.encoder_attn_mask_type)
self._encoder_key = 'encoder'
# Decoder
if self.add_decoder:
assert args.pipeline_model_parallel_size == 1, \
'pipeline parallelism is not supported in the presence of a decoder'
self.decoder = ParallelTransformer(
attention_mask_func,
self.init_method,
output_layer_init_method,
layer_type=LayerType.decoder,
self_attn_mask_type=self.decoder_attn_mask_type)
self._decoder_key = 'decoder'
if mpu.is_pipeline_last_stage():
# Pooler.
if mpu.is_pipeline_last_stage() and self.add_pooler:
if self.add_pooler:
self.pooler = Pooler(self.hidden_size, self.init_method)
self._pooler_key = 'pooler'
def forward(self, language_model_input, attention_mask,
tokentype_ids=None, layer_past=None, get_key_value=False,
pooling_sequence_index=0):
def forward(self, enc_language_model_input, enc_attn_mask,
dec_language_model_input=None, dec_attn_mask=None,
enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
get_key_value=False, pooling_sequence_index=0,
enc_hidden_states=None, output_enc_hidden=False):
# Embeddings.
if mpu.is_pipeline_first_stage():
(input_ids, position_ids) = language_model_input
(input_ids, position_ids) = enc_language_model_input
embedding_output = self.embedding(input_ids, position_ids,
tokentype_ids=tokentype_ids)
transformer_input = embedding_output
encoder_input = embedding_output
else:
transformer_input = language_model_input
encoder_input = enc_language_model_input
# Transformer.
transformer_output = self.transformer(transformer_input,
attention_mask,
# encoder.
if enc_hidden_states is None:
encoder_output = self.encoder(encoder_input,
enc_attn_mask,
layer_past=layer_past,
get_key_value=get_key_value)
else:
encoder_output = enc_hidden_states.to(encoder_input.dtype)
if mpu.is_pipeline_last_stage() and self.add_pooler:
pooled_output = self.pooler(transformer_output,
if mpu.is_pipeline_last_stage():
if self.add_pooler:
pooled_output = self.pooler(encoder_output,
pooling_sequence_index)
return transformer_output, pooled_output
return transformer_output
# output_enc_hidden refers to when we just need the encoder's
# output. For example, it is helpful to compute
# similarity between two sequences by average pooling
if not self.add_decoder or output_enc_hidden:
if self.add_pooler and mpu.is_pipeline_last_stage():
return encoder_output, pooled_output
else:
return encoder_output
# Decoder Embedding
(dec_input_ids, dec_position_ids) = dec_language_model_input
dec_embedding_output = self.embedding(dec_input_ids,
dec_position_ids)
# decoder
decoder_output = self.decoder(dec_embedding_output,
dec_attn_mask,
layer_past=layer_past,
get_key_value=get_key_value,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask)
if self.add_pooler and mpu.is_pipeline_last_stage():
return decoder_output, encoder_output, pooled_output
else:
return decoder_output, encoder_output
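The extended forward can be driven in three ways: encoder-only (no decoder, or output_enc_hidden=True), full encoder-decoder, or with pre-computed enc_hidden_states. Reusing the language_model built in the earlier get_language_model sketch, a hedged encoder-decoder call with placeholder ids and masks (True = masked, conventions illustrative) looks roughly like this.

```python
import torch

b, enc_len, dec_len, vocab = 2, 512, 128, 30000
enc_input_ids = torch.randint(0, vocab, (b, enc_len))
enc_position_ids = torch.arange(enc_len).unsqueeze(0).expand(b, -1)
dec_input_ids = torch.randint(0, vocab, (b, dec_len))
dec_position_ids = torch.arange(dec_len).unsqueeze(0).expand(b, -1)

enc_attn_mask = torch.zeros(b, 1, enc_len, enc_len, dtype=torch.bool)
dec_attn_mask = torch.triu(torch.ones(b, 1, dec_len, dec_len, dtype=torch.bool),
                           diagonal=1)                       # causal decoder mask
enc_dec_attn_mask = torch.zeros(b, 1, dec_len, enc_len, dtype=torch.bool)

decoder_output, encoder_output = language_model(
    enc_input_ids, enc_position_ids, enc_attn_mask,
    dec_input_ids=dec_input_ids, dec_position_ids=dec_position_ids,
    dec_attn_mask=dec_attn_mask, enc_dec_attn_mask=enc_dec_attn_mask)
```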
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
......@@ -346,13 +403,18 @@ class TransformerLanguageModelBase(MegatronModule):
state_dict_[self._embedding_key] \
= self.embedding.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
state_dict_[self._transformer_key] \
= self.transformer.state_dict_for_save_checkpoint(
state_dict_[self._encoder_key] \
= self.encoder.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if mpu.is_pipeline_last_stage() and self.add_pooler:
if mpu.is_pipeline_last_stage():
if self.add_pooler:
state_dict_[self._pooler_key] \
= self.pooler.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if self.add_decoder:
state_dict_[self._decoder_key] \
= self.decoder.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
return state_dict_
......@@ -371,23 +433,44 @@ class TransformerLanguageModelBase(MegatronModule):
state_dict_[key] = state_dict[key]
self.embedding.load_state_dict(state_dict_, strict=strict)
# Transformer.
if self._transformer_key in state_dict:
state_dict_ = state_dict[self._transformer_key]
# Encoder.
if self._encoder_key in state_dict:
state_dict_ = state_dict[self._encoder_key]
# for backward compatibility.
elif 'transformer' in state_dict:
state_dict_ = state_dict['transformer']
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if 'transformer.' in key:
state_dict_[key.split('transformer.')[1]] = state_dict[key]
self.transformer.load_state_dict(state_dict_, strict=strict)
# Pooler.
if mpu.is_pipeline_last_stage() and self.add_pooler:
# for backward compatibility.
state_dict_self_attention = {}
for key in state_dict_.keys():
if '.attention.' in key:
state_dict_self_attention[key.replace(".attention.",
".self_attention.")] = state_dict_[key]
else:
state_dict_self_attention[key] = state_dict_[key]
state_dict_ = state_dict_self_attention
self.encoder.load_state_dict(state_dict_, strict=strict)
if mpu.is_pipeline_last_stage():
# pooler
if self.add_pooler:
assert 'pooler' in state_dict, \
'could not find data for pooler in the checkpoint'
self.pooler.load_state_dict(state_dict[self._pooler_key],
strict=strict)
# decoder
if self.add_decoder:
assert 'decoder' in state_dict, \
'could not find data for decoder in the checkpoint'
self.decoder.load_state_dict(state_dict[self._decoder_key],
strict=strict)
class TransformerLanguageModel(TransformerLanguageModelBase):
......@@ -399,25 +482,38 @@ class TransformerLanguageModel(TransformerLanguageModelBase):
attention_mask_func,
init_method,
output_layer_init_method,
encoder_attn_mask_type,
num_tokentypes=0,
decoder_attn_mask_type=AttnMaskType.causal,
add_decoder=False,
add_pooler=False):
super(TransformerLanguageModel, self).__init__(
attention_mask_func,
init_method,
output_layer_init_method,
encoder_attn_mask_type,
num_tokentypes=num_tokentypes,
add_decoder=add_decoder,
decoder_attn_mask_type=decoder_attn_mask_type,
add_pooler=add_pooler)
def forward(self, input_ids, position_ids, attention_mask,
tokentype_ids=None, layer_past=None, get_key_value=False,
pooling_sequence_index=0):
def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
get_key_value=False, pooling_sequence_index=0,
enc_hidden_states=None, output_enc_hidden=False):
return super(TransformerLanguageModel, self).forward(
(input_ids, position_ids),
attention_mask,
(enc_input_ids, enc_position_ids),
enc_attn_mask,
dec_language_model_input=(dec_input_ids, dec_position_ids),
dec_attn_mask=dec_attn_mask,
enc_dec_attn_mask=enc_dec_attn_mask,
tokentype_ids=tokentype_ids,
layer_past=layer_past,
get_key_value=get_key_value,
pooling_sequence_index=pooling_sequence_index
pooling_sequence_index=pooling_sequence_index,
enc_hidden_states=enc_hidden_states,
output_enc_hidden=output_enc_hidden
)
......@@ -430,11 +526,13 @@ class TransformerLanguageModelFirstStage(TransformerLanguageModelBase):
attention_mask_func,
init_method,
output_layer_init_method,
encoder_attn_mask_type,
num_tokentypes=0):
super(TransformerLanguageModelFirstStage, self).__init__(
attention_mask_func,
init_method,
output_layer_init_method,
encoder_attn_mask_type,
num_tokentypes=num_tokentypes)
def forward(self, input_ids, position_ids, attention_mask,
......@@ -456,11 +554,13 @@ class TransformerLanguageModelIntermediateStage(TransformerLanguageModelBase):
def __init__(self,
attention_mask_func,
init_method,
output_layer_init_method):
output_layer_init_method,
encoder_attn_mask_type):
super(TransformerLanguageModelIntermediateStage, self).__init__(
attention_mask_func,
init_method,
output_layer_init_method)
output_layer_init_method,
encoder_attn_mask_type)
def forward(self, hidden_states, attention_mask,
layer_past=None, get_key_value=False):
......@@ -481,11 +581,13 @@ class TransformerLanguageModelLastStage(TransformerLanguageModelBase):
attention_mask_func,
init_method,
output_layer_init_method,
encoder_attn_mask_type,
add_pooler=False):
super(TransformerLanguageModelLastStage, self).__init__(
attention_mask_func,
init_method,
output_layer_init_method,
encoder_attn_mask_type,
add_pooler=add_pooler)
def forward(self, hidden_states, attention_mask,
......@@ -496,5 +598,5 @@ class TransformerLanguageModelLastStage(TransformerLanguageModelBase):
attention_mask,
layer_past=layer_past,
get_key_value=get_key_value,
pooling_sequence_index=pooling_sequence_index
pooling_sequence_index=pooling_sequence_index,
)
......@@ -19,6 +19,7 @@ import torch
from megatron import get_args, print_rank_last
from megatron import mpu
from megatron.model.enums import AttnMaskType
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
......@@ -39,6 +40,7 @@ class MultipleChoiceBase(MegatronModule):
attention_mask_func=bert_attention_mask_func,
num_tokentypes=num_tokentypes,
add_pooler=True,
encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method,
scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers))
......
......@@ -6,6 +6,7 @@ from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoi
from megatron.model import BertModel
from .module import MegatronModule
from megatron import mpu
from megatron.model.enums import AttnMaskType
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.language_model import get_language_model
......@@ -159,6 +160,7 @@ class IREncoderBertModel(MegatronModule):
attention_mask_func=bert_attention_mask_func,
num_tokentypes=num_tokentypes,
add_pooler=True,
encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method,
scaled_init_method=scaled_init_method)
......
......@@ -14,7 +14,6 @@
# limitations under the License.
"""Transformer."""
import math
import torch
import torch.nn.functional as F
......@@ -23,6 +22,7 @@ from megatron import get_args
from megatron import mpu
from .module import MegatronModule
from megatron.checkpointing import get_checkpoint_version
from megatron.model.enums import AttnMaskType, LayerType, AttnType
from megatron.model import import_layernorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
......@@ -71,7 +71,7 @@ class ParallelMLP(MegatronModule):
# Project to 4h.
self.dense_h_to_4h = mpu.ColumnParallelLinear(
args.hidden_size,
4 * args.hidden_size,
args.ffn_hidden_size,
gather_output=False,
init_method=init_method,
skip_bias_add=True)
......@@ -85,13 +85,12 @@ class ParallelMLP(MegatronModule):
# Project back to h.
self.dense_4h_to_h = mpu.RowParallelLinear(
4 * args.hidden_size,
args.ffn_hidden_size,
args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True)
def forward(self, hidden_states):
# [s, b, 4hp]
......@@ -109,7 +108,7 @@ class ParallelMLP(MegatronModule):
return output, output_bias
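The MLP change is purely a width parameterization; as a quick shape sketch (assumed sizes, two tensor-parallel ranks), the intermediate activation is now args.ffn_hidden_size wide, split column-wise across ranks, then reduced back to hidden_size.

```python
hidden_size, ffn_hidden_size, world_size = 1024, 4096, 2
s, b = 512, 4                                          # sequence length, micro-batch
x_shape = (s, b, hidden_size)                          # input  [s, b, h]
intermediate = (s, b, ffn_hidden_size // world_size)   # [s, b, 4hp] per rank
out_shape = (s, b, hidden_size)                        # RowParallelLinear reduces to [s, b, h]
print(x_shape, intermediate, out_shape)
```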
class ParallelSelfAttention(MegatronModule):
class ParallelAttention(MegatronModule):
"""Parallel self-attention layer abstract class.
Self-attention layer takes input with size [b, s, h]
......@@ -117,8 +116,10 @@ class ParallelSelfAttention(MegatronModule):
"""
def __init__(self, attention_mask_func, init_method,
output_layer_init_method, layer_number):
super(ParallelSelfAttention, self).__init__()
output_layer_init_method, layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=AttnMaskType.padding):
super(ParallelAttention, self).__init__()
args = get_args()
self.fp16 = args.fp16
......@@ -128,20 +129,38 @@ class ParallelSelfAttention(MegatronModule):
if self.apply_query_key_layer_scaling:
self.attention_softmax_in_fp32 = True
self.layer_number = max(1, layer_number)
self.attention_type = attention_type
self.attn_mask_type = attn_mask_type
projection_size = args.kv_channels * args.num_attention_heads
# Per attention head and per partition values.
world_size = mpu.get_tensor_model_parallel_world_size()
self.hidden_size_per_partition = mpu.divide(args.hidden_size,
self.hidden_size_per_partition = mpu.divide(projection_size,
world_size)
self.hidden_size_per_attention_head = mpu.divide(
args.hidden_size, args.num_attention_heads)
projection_size, args.num_attention_heads)
self.num_attention_heads_per_partition = mpu.divide(
args.num_attention_heads, world_size)
# Strided linear layer.
if attention_type == AttnType.self_attn:
self.query_key_value = mpu.ColumnParallelLinear(
args.hidden_size,
3 * args.hidden_size,
3 * projection_size,
gather_output=False,
init_method=init_method)
else:
assert attention_type == AttnType.cross_attn
self.query = mpu.ColumnParallelLinear(
args.hidden_size,
projection_size,
gather_output=False,
init_method=init_method)
self.key_value = mpu.ColumnParallelLinear(
args.hidden_size,
2 * projection_size,
gather_output=False,
init_method=init_method)
......@@ -153,8 +172,8 @@ class ParallelSelfAttention(MegatronModule):
self.scale_mask_softmax = FusedScaleMaskSoftmax(
self.fp16,
args.scaled_upper_triang_masked_softmax_fusion,
args.scaled_masked_softmax_fusion,
self.attn_mask_type,
args.masked_softmax_fusion,
self.attention_mask_func,
self.attention_softmax_in_fp32,
coeff)
......@@ -166,14 +185,14 @@ class ParallelSelfAttention(MegatronModule):
# Output.
self.dense = mpu.RowParallelLinear(
args.hidden_size,
projection_size,
args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True)
def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first):
input_shape = mixed_layer.size();
input_shape = mixed_layer.size()
if num_splits_first:
"""[s, b, num_splits * np * hn]
-->(view) [s, b, num_splits, np, hn]
......@@ -203,13 +222,14 @@ class ParallelSelfAttention(MegatronModule):
return mixed_layer
def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False):
get_key_value=False, encoder_output=None):
# hidden_states: [sq, b, h]
# =====================
# Query, Key, and Value
# =====================
if self.attention_type == AttnType.self_attn:
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
mixed_x_layer, _ = self.query_key_value(hidden_states)
......@@ -232,6 +252,36 @@ class ParallelSelfAttention(MegatronModule):
(query_layer,
key_layer,
value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
else:
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_kv_layer, _ = self.key_value(encoder_output)
checkpoint_version = get_checkpoint_version()
if checkpoint_version is not None:
if checkpoint_version == 0:
# [s, b, (2 * np * hn)] --> [s, b, (np * 2 * hn)]
mixed_kv_layer = self._transpose_last_dim(mixed_kv_layer, 2, True)
elif checkpoint_version == 1.0:
# [s, b, (np * hn * 2)] --> [s, b, (np * 2 * hn)]
mixed_kv_layer = self._transpose_last_dim(mixed_kv_layer, 2, False)
# [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
new_tensor_shape = mixed_kv_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
2 * self.hidden_size_per_attention_head)
mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
# [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
(key_layer,
value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)
# Attention head [sq, b, h] --> [sq, b, hp]
query_layer, _ = self.query(hidden_states)
# [sq, b, hp] --> [sq, b, np, hn]
new_tensor_shape = query_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
query_layer = query_layer.view(*new_tensor_shape)
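A pure-shape sketch (illustrative sizes, random tensors standing in for the real projections) of the cross-attention branch above: keys and values are projected from encoder_output of length sk, queries from the decoder hidden states of length sq.

```python
import torch

sq, sk, b = 128, 512, 2            # decoder length, encoder length, micro-batch
np_, hn = 8, 64                    # attention heads per partition, head dimension

# stand-in for self.key_value(encoder_output): [sk, b, np*2*hn] -> [sk, b, np, 2*hn]
mixed_kv_layer = torch.randn(sk, b, np_ * 2 * hn).view(sk, b, np_, 2 * hn)
key_layer, value_layer = mixed_kv_layer.split(hn, dim=-1)

# stand-in for self.query(hidden_states): [sq, b, np*hn] -> [sq, b, np, hn]
query_layer = torch.randn(sq, b, np_ * hn).view(sq, b, np_, hn)

assert key_layer.shape == value_layer.shape == (sk, b, np_, hn)
assert query_layer.shape == (sq, b, np_, hn)
```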
# ==================================
# Adjust key and value for inference
......@@ -246,7 +296,6 @@ class ParallelSelfAttention(MegatronModule):
if get_key_value:
present = (key_layer, value_layer)
# ===================================
# Raw attention scores. [b, np, s, s]
# ===================================
......@@ -260,6 +309,7 @@ class ParallelSelfAttention(MegatronModule):
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.view(output_size[2],
output_size[0] * output_size[1], -1)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.view(output_size[3],
output_size[0] * output_size[1], -1)
......@@ -272,15 +322,15 @@ class ParallelSelfAttention(MegatronModule):
device=torch.cuda.current_device())
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(matmul_result,
matmul_result = torch.baddbmm(
matmul_result,
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0,1).transpose(1, 2), #[b * np, hn, sk]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0, alpha=(1.0/self.norm_factor))
# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
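The score computation is unchanged apart from formatting and the rectangular shapes; a plain-torch reproduction with assumed sizes:

```python
import math
import torch

b, np_, sq, sk, hn = 2, 8, 128, 512, 64
norm_factor = math.sqrt(hn)

query_layer = torch.randn(sq, b, np_, hn).view(sq, b * np_, hn)
key_layer = torch.randn(sk, b, np_, hn).view(sk, b * np_, hn)
matmul_result = torch.empty(b * np_, sq, sk)

# Raw attention scores. [b * np, sq, sk]; beta=0.0 ignores the uninitialized buffer
matmul_result = torch.baddbmm(
    matmul_result,
    query_layer.transpose(0, 1),                    # [b * np, sq, hn]
    key_layer.transpose(0, 1).transpose(1, 2),      # [b * np, hn, sk]
    beta=0.0, alpha=(1.0 / norm_factor))

attention_scores = matmul_result.view(b, np_, sq, sk)
```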
# ==================================================
# Update attention mask for inference. [b, np, sq, sk]
# ==================================================
......@@ -298,7 +348,6 @@ class ParallelSelfAttention(MegatronModule):
:attention_scores.size(3),
:attention_scores.size(3)]
# ===========================
# Attention probs and dropout
# ===========================
......@@ -312,7 +361,6 @@ class ParallelSelfAttention(MegatronModule):
with mpu.get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs)
# =========================
# Context layer. [sq, b, hp]
# =========================
......@@ -335,7 +383,7 @@ class ParallelSelfAttention(MegatronModule):
output_size[2], -1)
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.transpose(0,1))
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
# change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size)
......@@ -348,7 +396,6 @@ class ParallelSelfAttention(MegatronModule):
(self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
# =================
# Output. [sq, b, h]
# =================
......@@ -389,16 +436,19 @@ def bias_dropout_add_fused_inference(x, bias, residual, prob) :
class ParallelTransformerLayer(MegatronModule):
"""A single transformer layer.
Transformore layer takes input with size [b, s, h] and returns an
Transformer layer takes input with size [b, s, h] and returns an
output of the same size.
"""
def __init__(self, attention_mask_func, init_method,
output_layer_init_method, layer_number):
output_layer_init_method, layer_number,
layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding):
args = get_args()
super(ParallelTransformerLayer, self).__init__()
self.layer_number = layer_number
self.layer_type = layer_type
self.apply_residual_connection_post_layernorm \
= args.apply_residual_connection_post_layernorm
......@@ -410,30 +460,47 @@ class ParallelTransformerLayer(MegatronModule):
eps=args.layernorm_epsilon)
# Self attention.
self.attention = ParallelSelfAttention(attention_mask_func, init_method,
self.self_attention = ParallelAttention(
attention_mask_func,
init_method,
output_layer_init_method,
layer_number)
layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=self_attn_mask_type)
self.hidden_dropout = args.hidden_dropout
self.bias_dropout_fusion = args.bias_dropout_fusion
# Layernorm on the input data.
# Layernorm on the attention output
self.post_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon)
if self.layer_type == LayerType.decoder:
self.inter_attention = ParallelAttention(
attention_mask_func,
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.cross_attn)
# Layernorm on the attention output.
self.post_inter_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon)
# MLP
self.mlp = ParallelMLP(init_method,
output_layer_init_method)
def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False):
def forward(self, hidden_states, attention_mask,
encoder_output=None, enc_dec_attn_mask=None,
layer_past=None, get_key_value=False):
# hidden_states: [b, s, h]
# Layer norm at the begining of the transformer layer.
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output, attention_bias = \
self.attention(layernorm_output,
self.self_attention(layernorm_output,
attention_mask,
layer_past=layer_past,
get_key_value=get_key_value)
......@@ -459,7 +526,7 @@ class ParallelTransformerLayer(MegatronModule):
else:
bias_dropout_add_func = get_bias_dropout_add(self.training)
#re-enable torch grad to enable fused optimization.
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
......@@ -470,6 +537,28 @@ class ParallelTransformerLayer(MegatronModule):
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
if self.layer_type == LayerType.decoder:
attention_output, attention_bias = \
self.inter_attention(layernorm_output,
enc_dec_attn_mask,
encoder_output=encoder_output)
# residual connection
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
# Layer norm post the decoder attention
layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
# MLP.
mlp_output, mlp_bias = self.mlp(layernorm_output)
......@@ -479,7 +568,7 @@ class ParallelTransformerLayer(MegatronModule):
else:
residual = layernorm_input
#re-enable torch grad to enable fused optimization.
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
output = bias_dropout_add_func(
mlp_output,
......@@ -497,7 +586,9 @@ class ParallelTransformer(MegatronModule):
"""Transformer class."""
def __init__(self, attention_mask_func,
init_method, output_layer_init_method):
init_method, output_layer_init_method,
layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding):
super(ParallelTransformer, self).__init__()
args = get_args()
......@@ -516,7 +607,9 @@ class ParallelTransformer(MegatronModule):
def build_layer(layer_number):
return ParallelTransformerLayer(
attention_mask_func, init_method,
output_layer_init_method, layer_number)
output_layer_init_method, layer_number,
layer_type=layer_type,
self_attn_mask_type=self_attn_mask_type)
offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
self.layers = torch.nn.ModuleList(
[build_layer(i + 1 + offset) for i in range(self.num_layers)])
......@@ -531,14 +624,18 @@ class ParallelTransformer(MegatronModule):
def _get_layer(self, layer_number):
return self.layers[layer_number]
def _checkpointed_forward(self, hidden_states, attention_mask):
def _checkpointed_forward(self, hidden_states, attention_mask,
encoder_output, enc_dec_attn_mask):
"""Forward method with activation checkpointing."""
def custom(start, end):
def custom_forward(*inputs):
x_ = inputs[0]
attention_mask = inputs[1]
encoder_output = inputs[2]
enc_dec_attn_mask = inputs[3]
for index in range(start, end):
layer = self._get_layer(index)
x_ = layer(x_, inputs[1])
x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
return x_
return custom_forward
......@@ -548,13 +645,13 @@ class ParallelTransformer(MegatronModule):
while l < self.num_layers:
hidden_states = mpu.checkpoint(
custom(l, l + self.checkpoint_num_layers),
hidden_states, attention_mask)
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
l += self.checkpoint_num_layers
return hidden_states
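A hedged standalone rendering of the chunked checkpointing above, using torch.utils.checkpoint in place of mpu.checkpoint and hypothetical layer callables, to show how the two extra arguments are simply threaded through each chunk.

```python
import torch
from torch.utils.checkpoint import checkpoint

def checkpointed_forward(layers, hidden_states, attention_mask,
                         encoder_output=None, enc_dec_attn_mask=None,
                         checkpoint_num_layers=1):
    """Re-run each chunk of layers in backward instead of storing activations."""
    def custom(start, end):
        def custom_forward(x_, attention_mask, encoder_output, enc_dec_attn_mask):
            for layer in layers[start:end]:
                x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
            return x_
        return custom_forward

    l = 0
    while l < len(layers):
        hidden_states = checkpoint(custom(l, l + checkpoint_num_layers),
                                   hidden_states, attention_mask,
                                   encoder_output, enc_dec_attn_mask)
        l += checkpoint_num_layers
    return hidden_states
```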
def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False):
get_key_value=False, encoder_output=None, enc_dec_attn_mask=None):
# Checks.
if layer_past is not None:
......@@ -575,9 +672,14 @@ class ParallelTransformer(MegatronModule):
else:
hidden_states = hidden_states.transpose(0, 1).contiguous()
if encoder_output is not None:
encoder_output = encoder_output.transpose(0, 1).contiguous()
if self.checkpoint_activations:
hidden_states = self._checkpointed_forward(hidden_states,
attention_mask)
attention_mask,
encoder_output,
enc_dec_attn_mask)
else:
if get_key_value:
presents = []
......@@ -588,6 +690,8 @@ class ParallelTransformer(MegatronModule):
past = layer_past[index]
hidden_states = layer(hidden_states,
attention_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
layer_past=past,
get_key_value=get_key_value)
if get_key_value:
......
......@@ -144,5 +144,4 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
args_defaults={'tokenizer_type': 'GPT2BPETokenizer',
'scaled_upper_triang_masked_softmax_fusion': True})
args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})