Commit 4b506832 authored by Vijay Korthikanti

decoder support in transformers

parent f5eac3d1
@@ -164,6 +164,20 @@ def parse_args(extra_args_provider=None, defaults={},
        _check_arg_is_not_none(args, req_arg)

    # Checks.
+    if args.ffn_hidden_size is None:
+        args.ffn_hidden_size = 4 * args.hidden_size
+
+    if args.kv_channels is None:
+        assert args.hidden_size % args.num_attention_heads == 0
+        args.kv_channels = args.hidden_size // args.num_attention_heads
+
+    if args.seq_length is not None:
+        assert args.encoder_seq_length is None
+        args.encoder_seq_length = args.seq_length
+    else:
+        assert args.encoder_seq_length is not None
+        args.seq_length = args.encoder_seq_length
+
    assert args.hidden_size % args.num_attention_heads == 0
    if args.seq_length is not None:
        assert args.max_position_embeddings >= args.seq_length
@@ -183,15 +197,10 @@ def parse_args(extra_args_provider=None, defaults={},
        'for distribute-checkpointed-activations to work you '\
        'need to enable checkpoint-activations'
-    if args.scaled_masked_softmax_fusion:
-        if args.scaled_upper_triang_masked_softmax_fusion:
-            fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
-        else:
-            fused_kernels.load_scaled_masked_softmax_fusion_kernel()
-    else:
-        # This argument will eventually go away, for now make sure it is off
-        # if scaled_masked_softmax_fusion is off.
-        args.scaled_upper_triang_masked_softmax_fusion = False
+    # Load scaled_masked_softmax_fusion kernels
+    if args.masked_softmax_fusion:
+        fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
+        fused_kernels.load_scaled_masked_softmax_fusion_kernel()

    # Load mixed precision fused layer norm.
    if args.fp32_residual_connection:
@@ -227,8 +236,14 @@ def _add_network_size_args(parser):
                       help='Number of transformer layers.')
    group.add_argument('--hidden-size', type=int, default=None,
                       help='Tansformer hidden size.')
+    group.add_argument('--ffn-hidden-size', type=int, default=None,
+                       help='Transformer Feed-Forward Network hidden size. '
+                       'This is set to 4*hidden-size if not provided')
    group.add_argument('--num-attention-heads', type=int, default=None,
                       help='Number of transformer attention heads.')
+    group.add_argument('--kv-channels', type=int, default=None,
+                       help='Projection weights dimension in multi-head attention. '
+                       'This is set to args.hidden_size // args.num_attention_heads '
+                       'if not provided.')
    group.add_argument('--max-position-embeddings', type=int, default=None,
                       help='Maximum number of position embeddings to use. '
                       'This is the size of position embedding.')
@@ -330,16 +345,11 @@ def _add_training_args(parser):
                       help='Exit the program after this many minutes.')
    group.add_argument('--tensorboard-dir', type=str, default=None,
                       help='Write TensorBoard logs to this directory.')
-    group.add_argument('--no-scaled-masked-softmax-fusion',
+    group.add_argument('--no-masked-softmax-fusion',
                       action='store_false',
                       help='Disable fusion of query_key_value scaling, '
                       'masking, and softmax.',
-                       dest='scaled_masked_softmax_fusion')
-    group.add_argument('--scaled-upper-triang-masked-softmax-fusion',
-                       type=bool,
-                       help='Use upper triangular version of fused '
-                       'scale, mask, softmax fusion kernel (default for GPT). '
-                       '- DEPRECATED')
+                       dest='masked_softmax_fusion')
    group.add_argument('--no-bias-gelu-fusion', action='store_false',
                       help='Disable bias and gelu fusion.',
                       dest='bias_gelu_fusion')
@@ -530,6 +540,10 @@ def _add_data_args(parser):
                       help='Path to the BPE merge file.')
    group.add_argument('--seq-length', type=int, default=None,
                       help="Maximum sequence length to process.")
+    group.add_argument('--encoder-seq-length', type=int, default=None,
+                       help="Maximum encoder sequence length to process.")
+    group.add_argument('--decoder-seq-length', type=int, default=None,
+                       help="Maximum decoder sequence length to process.")
    group.add_argument('--mask-prob', type=float, default=0.15,
                       help='Probability of replacing a token with mask.')
    group.add_argument('--short-seq-prob', type=float, default=0.1,
...
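As a quick illustration of the argument handling added above, here is a minimal standalone sketch (plain argparse, not Megatron's parse_args; the hidden-size and head-count defaults are made-up values) showing how the new encoder/decoder sequence-length, ffn-hidden-size, and kv-channels defaults resolve:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--seq-length', type=int, default=None)
parser.add_argument('--encoder-seq-length', type=int, default=None)
parser.add_argument('--decoder-seq-length', type=int, default=None)
parser.add_argument('--hidden-size', type=int, default=1024)
parser.add_argument('--num-attention-heads', type=int, default=16)
parser.add_argument('--ffn-hidden-size', type=int, default=None)
parser.add_argument('--kv-channels', type=int, default=None)

args = parser.parse_args(['--seq-length', '512', '--decoder-seq-length', '128'])

# Mirror the checks added in parse_args above.
if args.ffn_hidden_size is None:
    args.ffn_hidden_size = 4 * args.hidden_size                        # 4096
if args.kv_channels is None:
    assert args.hidden_size % args.num_attention_heads == 0
    args.kv_channels = args.hidden_size // args.num_attention_heads    # 64
if args.seq_length is not None:
    # --seq-length and --encoder-seq-length are two names for the same thing.
    assert args.encoder_seq_length is None
    args.encoder_seq_length = args.seq_length                          # 512
else:
    assert args.encoder_seq_length is not None
    args.seq_length = args.encoder_seq_length

print(args.encoder_seq_length, args.decoder_seq_length, args.kv_channels)
```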
@@ -81,7 +81,6 @@ __global__ void scaled_masked_softmax_warp_forward(
    const uint8_t *mask,
    const acc_t scale,
    int micro_batch_size,
-    int stride,
    int element_count,
    int pad_batches)
{
@@ -111,9 +110,9 @@ __global__ void scaled_masked_softmax_warp_forward(
    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

-    src += first_batch * stride + local_idx;
-    dst += first_batch * stride + local_idx;
-    mask += pad_first_batch * stride + local_idx;
+    src += first_batch * element_count + local_idx;
+    dst += first_batch * element_count + local_idx;
+    mask += pad_first_batch * element_count + local_idx;

    // load data from global memory
    acc_t elements[WARP_BATCH][WARP_ITERATIONS];
@@ -185,7 +184,6 @@ __global__ void scaled_masked_softmax_warp_backward(
    const input_t *output,
    acc_t scale,
    int micro_batch_size,
-    int stride,
    int element_count)
{
    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
@@ -209,7 +207,7 @@ __global__ void scaled_masked_softmax_warp_backward(
    int local_idx = threadIdx.x;

    // the first element to process by the current thread
-    int thread_offset = first_batch * stride + local_idx;
+    int thread_offset = first_batch * element_count + local_idx;
    grad += thread_offset;
    output += thread_offset;
    gradInput += thread_offset;
@@ -277,20 +275,19 @@ void dispatch_scaled_masked_softmax_forward(
    const input_t *src,
    const uint8_t *mask,
    const input_t scale,
-    int softmax_elements,
-    int softmax_elements_stride,
+    int query_seq_len,
+    int key_seq_len,
    int batches,
    int attn_heads,
    int pad_batches)
{
-    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 );
-    if (softmax_elements == 0) {
+    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048 );
+    if (key_seq_len == 0) {
        return;
    } else {
-        int log2_elements = log2_ceil(softmax_elements);
+        int log2_elements = log2_ceil(key_seq_len);
        const int next_power_of_two = 1 << log2_elements;
-        int seq_len = softmax_elements;
-        int batch_count = batches * attn_heads * seq_len;
+        int batch_count = batches * attn_heads * query_seq_len;

        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
@@ -303,58 +300,58 @@ void dispatch_scaled_masked_softmax_forward(
        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
-        TORCH_INTERNAL_ASSERT(seq_len%batches_per_block == 0);
-        dim3 blocks(seq_len/batches_per_block, attn_heads, batches);
+        TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0);
+        dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
        dim3 threads(warp_size, warps_per_block, 1);
        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
        switch (log2_elements) {
            case 0: // 1
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 1: // 2
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 2: // 4
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 3: // 8
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 4: // 16
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 5: // 32
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 6: // 64
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 7: // 128
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 8: // 256
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 9: // 512
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 10: // 1024
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 11: // 2048
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            default:
                break;
@@ -368,19 +365,18 @@ void dispatch_scaled_masked_softmax_backward(
    input_t *grad,
    const input_t *output,
    const acc_t scale,
-    int softmax_elements,
-    int softmax_elements_stride,
+    int query_seq_len,
+    int key_seq_len,
    int batches,
    int attn_heads)
{
-    TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 );
-    if (softmax_elements == 0) {
+    TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 2048 );
+    if (key_seq_len == 0) {
        return;
    } else {
-        int log2_elements = log2_ceil(softmax_elements);
+        int log2_elements = log2_ceil(key_seq_len);
        const int next_power_of_two = 1 << log2_elements;
-        int seq_len = softmax_elements;
-        int batch_count = batches * attn_heads * seq_len;
+        int batch_count = batches * attn_heads * query_seq_len;

        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
@@ -399,51 +395,51 @@ void dispatch_scaled_masked_softmax_backward(
        switch (log2_elements) {
            case 0: // 1
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 1: // 2
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 2: // 4
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 3: // 8
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 4: // 16
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 5: // 32
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 6: // 64
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 7: // 128
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 8: // 256
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 9: // 512
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 10: // 1024
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 11: // 2048
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            default:
                break;
...
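The dispatch functions above now take query_seq_len and key_seq_len separately instead of a single softmax_elements/stride pair, so the attention-score tensor no longer has to be square. The following is a hedged PyTorch reference of what the fused forward computes for a [batches, attn_heads, query_seq_len, key_seq_len] input with a padding mask; the -10000.0 fill value is an assumption mirroring Megatron's attention_mask_func convention, and the real kernel applies the mask internally in fp16:

```python
import torch

def reference_scaled_masked_softmax(scores, mask, scale):
    # scores: [batches, attn_heads, query_seq_len, key_seq_len]; sq may differ from sk.
    # mask:   [batches or 1, 1, query_seq_len, key_seq_len]; nonzero where attention is disallowed.
    scores = scores * scale
    scores = scores.masked_fill(mask.bool(), -10000.0)
    return torch.softmax(scores, dim=-1)

b, heads, sq, sk = 2, 4, 8, 24                 # non-square: sq != sk
scores = torch.randn(b, heads, sq, sk)
mask = torch.zeros(b, 1, sq, sk, dtype=torch.uint8)
probs = reference_scaled_masked_softmax(scores, mask, scale=1.0)
print(probs.shape)                             # torch.Size([2, 4, 8, 24])
```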
@@ -37,17 +37,19 @@ torch::Tensor fwd_cuda(
    const int batches = input.size(0);
    const int pad_batches = mask.size(0);
    const int attn_heads = input.size(1);
-    const int seq_len = input.size(2);
-    TORCH_INTERNAL_ASSERT(seq_len <= 2048);
+    const int query_seq_len = input.size(2);
+    const int key_seq_len = input.size(3);
+    TORCH_INTERNAL_ASSERT(key_seq_len <= 2048);
+    TORCH_INTERNAL_ASSERT(query_seq_len > 1);
    TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
    TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
-    TORCH_INTERNAL_ASSERT(mask.size(2) == seq_len);
-    TORCH_INTERNAL_ASSERT(mask.size(3) == seq_len);
+    TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
+    TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);

    // Output
    auto act_options = input.options().requires_grad(false);
    torch::Tensor softmax_results =
-        torch::empty({batches, attn_heads, seq_len, seq_len}, act_options);
+        torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);

    // Softmax Intermediate Result Ptr
    void* input_ptr = static_cast<void*>(input.data_ptr());
@@ -59,8 +61,8 @@ torch::Tensor fwd_cuda(
        reinterpret_cast<const half*>(input_ptr),
        reinterpret_cast<const uint8_t*>(mask_ptr),
        scale_factor,
-        seq_len,
-        seq_len,
+        query_seq_len,
+        key_seq_len,
        batches,
        attn_heads,
        pad_batches);
@@ -78,8 +80,8 @@ torch::Tensor bwd_cuda(
    //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
    const int batches = output_grads.size(0);
    const int attn_heads = output_grads.size(1);
-    const int seq_len = output_grads.size(2);
-    TORCH_INTERNAL_ASSERT(output_grads.size(2) == output_grads.size(3));
+    const int query_seq_len = output_grads.size(2);
+    const int key_seq_len = output_grads.size(3);

    void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
@@ -89,8 +91,8 @@ torch::Tensor bwd_cuda(
        reinterpret_cast<half*>(output_grads_ptr),
        reinterpret_cast<half const*>(softmax_results.data_ptr()),
        scale_factor,
-        seq_len,
-        seq_len,
+        query_seq_len,
+        key_seq_len,
        batches,
        attn_heads);
...
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum


class LayerType(enum.Enum):
    encoder = 1
    decoder = 2


class AttnType(enum.Enum):
    self_attn = 1
    cross_attn = 2


class AttnMaskType(enum.Enum):
    padding = 1
    causal = 2
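These enums are what the rest of the commit branches on. A small sketch of the intended pairing (the enums are re-declared locally so the snippet runs without Megatron installed, and mask_type_for is an illustrative helper, not part of this commit): decoder self-attention is causal, while encoder self-attention and encoder-decoder cross-attention use a padding mask.

```python
from enum import Enum

class LayerType(Enum):
    encoder = 1
    decoder = 2

class AttnType(Enum):
    self_attn = 1
    cross_attn = 2

class AttnMaskType(Enum):
    padding = 1
    causal = 2

def mask_type_for(layer_type, attention_type):
    # Decoder self-attention is causal; everything else uses a padding mask.
    if layer_type == LayerType.decoder and attention_type == AttnType.self_attn:
        return AttnMaskType.causal
    return AttnMaskType.padding

print(mask_type_for(LayerType.decoder, AttnType.self_attn))   # AttnMaskType.causal
print(mask_type_for(LayerType.decoder, AttnType.cross_attn))  # AttnMaskType.padding
```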
@@ -14,8 +14,10 @@
 # limitations under the License.

import torch
+from megatron.model.enums import AttnMaskType

-class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function) :
+class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs following three operations in sequence
    1. Scale the tensor.
@@ -43,7 +45,8 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function) :
                                                                 scale_t[0])
        return input_grads, None

-class ScaledMaskedSoftmax(torch.autograd.Function) :
+class ScaledMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs following three operations in sequence
    1. Scale the tensor.
@@ -71,24 +74,25 @@ class ScaledMaskedSoftmax(torch.autograd.Function) :
                                                           scale_t[0])
        return input_grads, None, None

class FusedScaleMaskSoftmax(torch.nn.Module):
    """
    fused operation: scaling + mask + softmax
    Arguments:
        input_in_fp16: flag to indicate if input in fp16 data format.
-        upper_triang_mask: if true, apply upper triangular masking.
-                           (used in gpt family networks)
+        attn_mask_type: attention mask type (pad or causal)
        mask_func: mask function to be applied.
        softmax_in_fp32: if true, softmax in performed at fp32 precision.
        scale: scaling factor used in input tensor scaling.
    """

-    def __init__(self, input_in_fp16, upper_triang_mask_fusion,
-                 general_mask_fusion, mask_func, softmax_in_fp32, scale):
+    def __init__(self, input_in_fp16, attn_mask_type,
+                 scaled_masked_softmax_fusion, mask_func,
+                 softmax_in_fp32, scale):
        super(FusedScaleMaskSoftmax, self).__init__()
        self.input_in_fp16 = input_in_fp16
-        self.upper_triang_mask_fusion = upper_triang_mask_fusion
-        self.general_mask_fusion = general_mask_fusion
+        self.attn_mask_type = attn_mask_type
+        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
        self.mask_func = mask_func
        self.softmax_in_fp32 = softmax_in_fp32
        self.scale = scale
@@ -97,20 +101,26 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
            'softmax should be in fp32 when scaled'

    def forward(self, input, mask):
-        # [b, np, s, s]
+        # [b, np, sq, sk]
        data_size = input.size()
+        query_seq_len = data_size[-2]
+        key_seq_len = data_size[-1]
        assert input.dim() == 4

        # invoke custom kernel
-        if self.input_in_fp16 and data_size[-1] <= 2048 and \
-                (self.upper_triang_mask_fusion or self.general_mask_fusion) and \
-                input.size()[2] == input.size()[3]:
+        if self.input_in_fp16 and key_seq_len <= 2048 and \
+                query_seq_len % 4 == 0 and self.scaled_masked_softmax_fusion:
            scale = self.scale if self.scale is not None else 1.0
-            if self.upper_triang_mask_fusion:
-                input = input.view(-1, data_size[2], data_size[3])
+
+            if self.attn_mask_type == AttnMaskType.causal:
+                assert query_seq_len == key_seq_len, \
+                    "causal mask is only for self attention"
+                input = input.view(-1, query_seq_len, key_seq_len)
                probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
                probs = probs.view(*data_size)
            else:
+                assert self.attn_mask_type == AttnMaskType.padding
                probs = ScaledMaskedSoftmax.apply(input, mask, scale)
        else:
            if self.input_in_fp16 and self.softmax_in_fp32:
...
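With the new signature, constructing the module looks like the sketch below. This is hedged: it assumes the fused kernels have been built (e.g. --masked-softmax-fusion left enabled), a CUDA device is available, and padding_mask_func plus the tensor shapes are illustrative. Per the forward logic above, the fused path is only taken when the input is fp16, key_seq_len <= 2048, and query_seq_len is a multiple of 4; otherwise it falls back to the unfused masked softmax.

```python
import torch
from megatron.model.enums import AttnMaskType
from megatron.model.fused_softmax import FusedScaleMaskSoftmax

def padding_mask_func(scores, mask):
    # Illustrative mask function following Megatron's convention of filling
    # disallowed positions with a large negative value before the softmax.
    return scores.masked_fill(mask, -10000.0)

softmax = FusedScaleMaskSoftmax(
    input_in_fp16=True,
    attn_mask_type=AttnMaskType.padding,          # AttnMaskType.causal for GPT-style layers
    scaled_masked_softmax_fusion=True,
    mask_func=padding_mask_func,
    softmax_in_fp32=True,
    scale=None)

b, heads, sq, sk = 2, 8, 16, 48                   # sq % 4 == 0 and sk <= 2048 -> fused path
scores = torch.randn(b, heads, sq, sk, dtype=torch.half, device='cuda')
mask = torch.zeros(b, 1, sq, sk, dtype=torch.bool, device='cuda')
probs = softmax(scores, mask)                     # [b, np, sq, sk]
print(probs.shape)
```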
@@ -21,6 +21,7 @@ from megatron import get_args
from megatron import mpu
from .module import MegatronModule
+from .enums import AttnMaskType
from .language_model import parallel_lm_logits
from .language_model import get_language_model
from .utils import init_method_normal
@@ -75,6 +76,7 @@ class GPT2ModelBase(MegatronModule):
            attention_mask_func=gpt2_attention_mask_func,
            num_tokentypes=num_tokentypes,
            add_pooler=False,
+            self_attn_mask_type=AttnMaskType.causal,
            init_method=init_method_normal(args.init_method_std),
            scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                         args.num_layers))
...
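GPT-2 now declares its masking intent explicitly through self_attn_mask_type=AttnMaskType.causal instead of relying on the old global scaled-upper-triangular fusion flag. As a minimal illustration of what the causal type means (the fused ScaledUpperTriangMaskedSoftmax kernel applies this triangular visibility pattern internally rather than taking a mask tensor):

```python
import torch

# Causal masking: position i may only attend to positions j <= i.
# A padding-type mask would instead only hide padded key positions.
seq_len = 5
visible = torch.tril(torch.ones(seq_len, seq_len)).bool()
disallowed = ~visible            # True where attention must be masked out
print(disallowed.int())
```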
@@ -21,6 +21,7 @@ import torch.nn.functional as F
from megatron import get_args
from megatron import mpu
from .module import MegatronModule
+from megatron.model.enums import LayerType, AttnMaskType
from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal, scaled_init_method_normal
@@ -43,7 +44,9 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,

def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
-                       init_method=None, scaled_init_method=None):
+                       add_decoder=False, init_method=None,
+                       scaled_init_method=None,
+                       self_attn_mask_type=AttnMaskType.padding):
    """Build language model and return along with the key to save."""
    args = get_args()
@@ -51,7 +54,8 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
        init_method = init_method_normal(args.init_method_std)

    if scaled_init_method is None:
-        scaled_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers)
+        scaled_init_method = scaled_init_method_normal(args.init_method_std,
+                                                       args.num_layers)

    # Language model.
    args = [attention_mask_func, init_method, scaled_init_method]
@@ -60,6 +64,8 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
    if mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage():
        cls = TransformerLanguageModel
        kwargs['num_tokentypes'] = num_tokentypes
+        kwargs['self_attn_mask_type'] = self_attn_mask_type
+        kwargs['add_decoder'] = add_decoder
        kwargs['add_pooler'] = add_pooler
    elif mpu.is_pipeline_first_stage() and not mpu.is_pipeline_last_stage():
        cls = TransformerLanguageModelFirstStage
@@ -186,8 +192,6 @@ class Embedding(MegatronModule):
        if tokentype_ids is not None:
            assert self.tokentype_embeddings is not None
            embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
-        else:
-            assert self.tokentype_embeddings is None

        # Dropout.
        embeddings = self.embedding_dropout(embeddings)
@@ -281,6 +285,8 @@ class TransformerLanguageModelBase(MegatronModule):
                 init_method,
                 output_layer_init_method,
                 num_tokentypes=0,
+                 self_attn_mask_type=AttnMaskType.padding,
+                 add_decoder=False,
                 add_pooler=False):
        super(TransformerLanguageModelBase, self).__init__()
        args = get_args()
@@ -288,6 +294,8 @@ class TransformerLanguageModelBase(MegatronModule):
        self.hidden_size = args.hidden_size
        self.num_tokentypes = num_tokentypes
        self.init_method = init_method
+        self.self_attn_mask_type = self_attn_mask_type
+        self.add_decoder = add_decoder
        self.add_pooler = add_pooler

        # Embeddings.
@@ -301,41 +309,87 @@ class TransformerLanguageModelBase(MegatronModule):
            self._embedding_key = 'embedding'

        # Transformer.
-        self.transformer = ParallelTransformer(
-            attention_mask_func, self.init_method,
-            output_layer_init_method)
-        self._transformer_key = 'transformer'
+        self.encoder = ParallelTransformer(
+            attention_mask_func,
+            self.init_method,
+            output_layer_init_method,
+            self_attn_mask_type=self_attn_mask_type)
+        self._encoder_key = 'encoder'
+
+        # assuming pooler and decoder are in the last stage
+        # of the pipeline(to be revised)
+        if mpu.is_pipeline_last_stage():
+            # decoder
+            if self.add_decoder:
+                self.decoder = ParallelTransformer(
+                    attention_mask_func,
+                    self.init_method,
+                    output_layer_init_method,
+                    layer_type=LayerType.decoder,
+                    self_attn_mask_type=AttnMaskType.causal)
+                self._decoder_key = 'decoder'

            # Pooler.
-        if mpu.is_pipeline_last_stage() and self.add_pooler:
-            self.pooler = Pooler(self.hidden_size, self.init_method)
-            self._pooler_key = 'pooler'
+            if self.add_pooler:
+                self.pooler = Pooler(self.hidden_size, self.init_method)
+                self._pooler_key = 'pooler'

-    def forward(self, language_model_input, attention_mask,
-                tokentype_ids=None, layer_past=None, get_key_value=False,
-                pooling_sequence_index=0):
+    def forward(self, enc_language_model_input, enc_attention_mask,
+                dec_language_model_input=None, dec_attn_mask=None,
+                enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
+                get_key_value=False, pooling_sequence_index=0,
+                enc_hidden_states=None, output_enc_hidden=False):

        # Embeddings.
        if mpu.is_pipeline_first_stage():
-            (input_ids, position_ids) = language_model_input
+            (input_ids, position_ids) = enc_language_model_input
            embedding_output = self.embedding(input_ids, position_ids,
                                              tokentype_ids=tokentype_ids)
-            transformer_input = embedding_output
+            encoder_input = embedding_output
        else:
-            transformer_input = language_model_input
+            encoder_input = enc_language_model_input

-        # Transformer.
-        transformer_output = self.transformer(transformer_input,
-                                              attention_mask,
-                                              layer_past=layer_past,
-                                              get_key_value=get_key_value)
+        # encoder.
+        if enc_hidden_states is None:
+            encoder_output = self.encoder(encoder_input,
+                                          enc_attention_mask,
+                                          layer_past=layer_past,
+                                          get_key_value=get_key_value)
+        else:
+            encoder_output = enc_hidden_states.to(encoder_input.dtype)

-        if mpu.is_pipeline_last_stage() and self.add_pooler:
-            pooled_output = self.pooler(transformer_output,
-                                        pooling_sequence_index)
-            return transformer_output, pooled_output
-
-        return transformer_output
+        if mpu.is_pipeline_last_stage():
+            if self.add_pooler:
+                pooled_output = self.pooler(encoder_output,
+                                            pooling_sequence_index)
+
+            # output_enc_hidden refers to when we just need the encoder's
+            # output. For example, it is helpful to compute
+            # similarity between two sequences by average pooling
+            if not self.add_decoder or output_enc_hidden:
+                if self.add_pooler:
+                    return encoder_output, pooled_output
+                else:
+                    return encoder_output
+
+            # Decoder Embedding
+            (dec_input_ids, dec_position_ids) = dec_language_model_input
+            dec_embedding_output = self.embedding(dec_input_ids,
+                                                  dec_position_ids)
+            # decoder
+            decoder_output = self.decoder(dec_embedding_output,
+                                          dec_attn_mask,
+                                          layer_past=layer_past,
+                                          get_key_value=get_key_value,
+                                          encoder_output=encoder_output,
+                                          enc_dec_attn_mask=enc_dec_attn_mask)
+
+            if self.add_pooler:
+                return decoder_output, encoder_output, pooled_output
+            else:
+                return decoder_output, encoder_output
+
+        return encoder_output
    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
@@ -346,13 +400,18 @@ class TransformerLanguageModelBase(MegatronModule):
            state_dict_[self._embedding_key] \
                = self.embedding.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)
-        state_dict_[self._transformer_key] \
-            = self.transformer.state_dict_for_save_checkpoint(
+        state_dict_[self._encoder_key] \
+            = self.encoder.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
-        if mpu.is_pipeline_last_stage() and self.add_pooler:
-            state_dict_[self._pooler_key] \
-                = self.pooler.state_dict_for_save_checkpoint(
-                    destination, prefix, keep_vars)
+        if mpu.is_pipeline_last_stage():
+            if self.add_pooler:
+                state_dict_[self._pooler_key] \
+                    = self.pooler.state_dict_for_save_checkpoint(
+                        destination, prefix, keep_vars)
+            if self.add_decoder:
+                state_dict_[self._decoder_key] \
+                    = self.decoder.state_dict_for_save_checkpoint(
+                        destination, prefix, keep_vars)

        return state_dict_
@@ -371,23 +430,44 @@ class TransformerLanguageModelBase(MegatronModule):
                    state_dict_[key] = state_dict[key]
            self.embedding.load_state_dict(state_dict_, strict=strict)

-        # Transformer.
-        if self._transformer_key in state_dict:
-            state_dict_ = state_dict[self._transformer_key]
-        else:
-            # for backward compatibility.
-            state_dict_ = {}
-            for key in state_dict.keys():
-                if 'transformer.' in key:
-                    state_dict_[key.split('transformer.')[1]] = state_dict[key]
-        self.transformer.load_state_dict(state_dict_, strict=strict)
-
-        # Pooler.
-        if mpu.is_pipeline_last_stage() and self.add_pooler:
-            assert 'pooler' in state_dict, \
-                'could not find data for pooler in the checkpoint'
-            self.pooler.load_state_dict(state_dict[self._pooler_key],
-                                        strict=strict)
+        # Encoder.
+        if self._encoder_key in state_dict:
+            state_dict_ = state_dict[self._encoder_key]
+        # for backward compatibility.
+        elif 'transformer' in state_dict:
+            state_dict_ = state_dict['transformer']
+        else:
+            # for backward compatibility.
+            state_dict_ = {}
+            for key in state_dict.keys():
+                if 'encoder.' in key:
+                    state_dict_[key.split('encoder.')[1]] = state_dict[key]
+
+        # for backward compatibility.
+        state_dict_self_attention = {}
+        for key in state_dict_.keys():
+            if '.attention.' in key:
+                state_dict_self_attention[key.replace(".attention.",
+                    ".self_attention.")] = state_dict_[key]
+            else:
+                state_dict_self_attention[key] = state_dict_[key]
+        state_dict_ = state_dict_self_attention
+
+        self.encoder.load_state_dict(state_dict_, strict=strict)
+
+        if mpu.is_pipeline_last_stage():
+            # pooler
+            if self.add_pooler:
+                assert 'pooler' in state_dict, \
+                    'could not find data for pooler in the checkpoint'
+                self.pooler.load_state_dict(state_dict[self._pooler_key],
+                                            strict=strict)
+            # decoder
+            if self.add_decoder:
+                assert 'decoder' in state_dict, \
+                    'could not find data for pooler in the checkpoint'
+                self.decoder.load_state_dict(state_dict[self._decoder_key],
+                                             strict=strict)
class TransformerLanguageModel(TransformerLanguageModelBase):
@@ -400,24 +480,35 @@ class TransformerLanguageModel(TransformerLanguageModelBase):
                 init_method,
                 output_layer_init_method,
                 num_tokentypes=0,
+                 self_attn_mask_type=AttnMaskType.padding,
+                 add_decoder=False,
                 add_pooler=False):
        super(TransformerLanguageModel, self).__init__(
            attention_mask_func,
            init_method,
            output_layer_init_method,
            num_tokentypes=num_tokentypes,
+            self_attn_mask_type=self_attn_mask_type,
+            add_decoder=add_decoder,
            add_pooler=add_pooler)

-    def forward(self, input_ids, position_ids, attention_mask,
-                tokentype_ids=None, layer_past=None, get_key_value=False,
-                pooling_sequence_index=0):
+    def forward(self, enc_input_ids, enc_position_ids, enc_attention_mask,
+                dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
+                enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
+                get_key_value=False, pooling_sequence_index=0,
+                enc_hidden_states=None, output_enc_hidden=False):
        return super(TransformerLanguageModel, self).forward(
-            (input_ids, position_ids),
-            attention_mask,
+            (enc_input_ids, enc_position_ids),
+            enc_attention_mask,
+            dec_language_model_input=(dec_input_ids, dec_position_ids),
+            dec_attn_mask=dec_attn_mask,
+            enc_dec_attn_mask=enc_dec_attn_mask,
            tokentype_ids=tokentype_ids,
            layer_past=layer_past,
            get_key_value=get_key_value,
-            pooling_sequence_index=pooling_sequence_index
+            pooling_sequence_index=pooling_sequence_index,
+            enc_hidden_states=enc_hidden_states,
+            output_enc_hidden=output_enc_hidden
        )
@@ -430,12 +521,14 @@ class TransformerLanguageModelFirstStage(TransformerLanguageModelBase):
                 attention_mask_func,
                 init_method,
                 output_layer_init_method,
-                 num_tokentypes=0):
+                 num_tokentypes=0,
+                 self_attn_mask_type=AttnMaskType.padding):
        super(TransformerLanguageModelFirstStage, self).__init__(
            attention_mask_func,
            init_method,
            output_layer_init_method,
-            num_tokentypes=num_tokentypes)
+            num_tokentypes=num_tokentypes,
+            self_attn_mask_type=self_attn_mask_type)

    def forward(self, input_ids, position_ids, attention_mask,
                tokentype_ids=None, layer_past=None, get_key_value=False):
@@ -456,11 +549,13 @@ class TransformerLanguageModelIntermediateStage(TransformerLanguageModelBase):
    def __init__(self,
                 attention_mask_func,
                 init_method,
-                 output_layer_init_method):
+                 output_layer_init_method,
+                 self_attn_mask_type=AttnMaskType.padding):
        super(TransformerLanguageModelIntermediateStage, self).__init__(
            attention_mask_func,
            init_method,
-            output_layer_init_method)
+            output_layer_init_method,
+            self_attn_mask_type=self_attn_mask_type)

    def forward(self, hidden_states, attention_mask,
                layer_past=None, get_key_value=False):
@@ -481,20 +576,31 @@ class TransformerLanguageModelLastStage(TransformerLanguageModelBase):
                 attention_mask_func,
                 init_method,
                 output_layer_init_method,
+                 self_attn_mask_type=AttnMaskType.padding,
+                 add_decoder=False,
                 add_pooler=False):
        super(TransformerLanguageModelLastStage, self).__init__(
            attention_mask_func,
            init_method,
            output_layer_init_method,
+            self_attn_mask_type=AttnMaskType.padding,
+            add_decoder=add_decoder,
            add_pooler=add_pooler)

-    def forward(self, hidden_states, attention_mask,
-                layer_past=None, get_key_value=False,
-                pooling_sequence_index=0):
+    def forward(self, hidden_states, enc_attention_mask,
+                dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
+                enc_dec_attn_mask=None, layer_past=None, get_key_value=False,
+                pooling_sequence_index=0, enc_hidden_states=None,
+                output_enc_hidden=False):
        return super(TransformerLanguageModelLastStage, self).forward(
            hidden_states,
-            attention_mask,
+            enc_attention_mask,
+            dec_language_input=(dec_input_ids, dec_position_ids),
+            dec_attn_mask=dec_attn_mask,
+            enc_dec_attn_mask=enc_dec_attn_mask,
            layer_past=layer_past,
            get_key_value=get_key_value,
-            pooling_sequence_index=pooling_sequence_index
+            pooling_sequence_index=pooling_sequence_index,
+            enc_hidden_states=enc_hidden_states,
+            ouput_enc_hidden=output_enc_hidden
        )
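To make the new control flow of TransformerLanguageModelBase.forward concrete without requiring Megatron's distributed setup, here is a self-contained stand-in built from plain torch.nn (none of these classes are Megatron's; all names are illustrative). It mirrors the path above: run the encoder (or reuse cached encoder states), optionally return only the encoder output, otherwise embed the decoder inputs and run a causally masked decoder that cross-attends to the encoder output.

```python
import torch
import torch.nn as nn

class TinyEncDec(nn.Module):
    """Illustrative stand-in for the encoder-decoder flow added above."""
    def __init__(self, hidden=32, heads=4, add_decoder=True):
        super().__init__()
        self.add_decoder = add_decoder
        self.embedding = nn.Embedding(100, hidden)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(hidden, heads), num_layers=2)
        if add_decoder:
            self.decoder = nn.TransformerDecoder(
                nn.TransformerDecoderLayer(hidden, heads), num_layers=2)

    def forward(self, enc_input_ids, dec_input_ids=None,
                enc_hidden_states=None, output_enc_hidden=False):
        # Encoder (or reuse previously computed encoder hidden states).
        enc_in = self.embedding(enc_input_ids)
        encoder_output = self.encoder(enc_in) if enc_hidden_states is None \
            else enc_hidden_states
        if not self.add_decoder or output_enc_hidden:
            return encoder_output
        # Decoder embedding, causal self-attention mask, cross-attention to encoder.
        dec_in = self.embedding(dec_input_ids)
        sq = dec_in.size(0)
        causal = torch.triu(torch.full((sq, sq), float('-inf')), diagonal=1)
        decoder_output = self.decoder(dec_in, encoder_output, tgt_mask=causal)
        return decoder_output, encoder_output

model = TinyEncDec()
enc_ids = torch.randint(0, 100, (16, 2))   # [encoder_seq_len, batch]
dec_ids = torch.randint(0, 100, (8, 2))    # [decoder_seq_len, batch]
dec_out, enc_out = model(enc_ids, dec_ids)
print(dec_out.shape, enc_out.shape)        # [8, 2, 32], [16, 2, 32]
```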
@@ -14,7 +14,7 @@
 # limitations under the License.

"""Transformer."""
-import enum
import math
import torch
import torch.nn.functional as F
@@ -23,6 +23,7 @@ from megatron import get_args
from megatron import mpu
from .module import MegatronModule
from megatron.checkpointing import get_checkpoint_version
+from megatron.model.enums import AttnMaskType, LayerType, AttnType
from megatron.model import import_layernorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
@@ -71,7 +72,7 @@ class ParallelMLP(MegatronModule):
        # Project to 4h.
        self.dense_h_to_4h = mpu.ColumnParallelLinear(
            args.hidden_size,
-            4 * args.hidden_size,
+            args.ffn_hidden_size,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True)
@@ -85,13 +86,12 @@ class ParallelMLP(MegatronModule):
        # Project back to h.
        self.dense_4h_to_h = mpu.RowParallelLinear(
-            4 * args.hidden_size,
+            args.ffn_hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def forward(self, hidden_states):
        # [s, b, 4hp]
@@ -109,7 +109,7 @@ class ParallelMLP(MegatronModule):
        return output, output_bias

-class ParallelSelfAttention(MegatronModule):
+class ParallelAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [b, s, h]
@@ -117,8 +117,10 @@ class ParallelSelfAttention(MegatronModule):
    """

    def __init__(self, attention_mask_func, init_method,
-                 output_layer_init_method, layer_number):
-        super(ParallelSelfAttention, self).__init__()
+                 output_layer_init_method, layer_number,
+                 attention_type=AttnType.self_attn,
+                 attn_mask_type=AttnMaskType.padding):
+        super(ParallelAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16
@@ -128,20 +130,38 @@ class ParallelSelfAttention(MegatronModule):
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
+        self.attention_type = attention_type
+        self.attn_mask_type = attn_mask_type
+
+        projection_size = args.kv_channels * args.num_attention_heads

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
-        self.hidden_size_per_partition = mpu.divide(args.hidden_size,
+        self.hidden_size_per_partition = mpu.divide(projection_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
-            args.hidden_size, args.num_attention_heads)
+            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)

        # Strided linear layer.
-        self.query_key_value = mpu.ColumnParallelLinear(
-            args.hidden_size,
-            3 * args.hidden_size,
-            gather_output=False,
-            init_method=init_method)
+        if attention_type == AttnType.self_attn:
+            self.query_key_value = mpu.ColumnParallelLinear(
+                args.hidden_size,
+                3 * projection_size,
+                gather_output=False,
+                init_method=init_method)
+        else:
+            assert attention_type == AttnType.cross_attn
+            self.query = mpu.ColumnParallelLinear(
+                args.hidden_size,
+                projection_size,
+                gather_output=False,
+                init_method=init_method)
+            self.key_value = mpu.ColumnParallelLinear(
+                args.hidden_size,
+                2 * projection_size,
+                gather_output=False,
+                init_method=init_method)
@@ -153,8 +173,8 @@ class ParallelSelfAttention(MegatronModule):
        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16,
-            args.scaled_upper_triang_masked_softmax_fusion,
-            args.scaled_masked_softmax_fusion,
+            self.attn_mask_type,
+            args.masked_softmax_fusion,
            self.attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)
...@@ -166,14 +186,14 @@ class ParallelSelfAttention(MegatronModule): ...@@ -166,14 +186,14 @@ class ParallelSelfAttention(MegatronModule):
# Output. # Output.
self.dense = mpu.RowParallelLinear( self.dense = mpu.RowParallelLinear(
args.hidden_size, projection_size,
args.hidden_size, args.hidden_size,
input_is_parallel=True, input_is_parallel=True,
init_method=output_layer_init_method, init_method=output_layer_init_method,
skip_bias_add=True) skip_bias_add=True)
def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first): def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first):
input_shape = mixed_layer.size(); input_shape = mixed_layer.size()
if num_splits_first: if num_splits_first:
"""[s, b, num_splits * np * hn] """[s, b, num_splits * np * hn]
-->(view) [s, b, num_splits, np, hn] -->(view) [s, b, num_splits, np, hn]
...@@ -203,13 +223,14 @@ class ParallelSelfAttention(MegatronModule): ...@@ -203,13 +223,14 @@ class ParallelSelfAttention(MegatronModule):
return mixed_layer return mixed_layer
def forward(self, hidden_states, attention_mask, layer_past=None, def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False): get_key_value=False, encoder_output=None):
# hidden_states: [sq, b, h] # hidden_states: [sq, b, h]
# ===================== # =====================
# Query, Key, and Value # Query, Key, and Value
# ===================== # =====================
if self.attention_type == AttnType.self_attn:
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
mixed_x_layer, _ = self.query_key_value(hidden_states) mixed_x_layer, _ = self.query_key_value(hidden_states)
...@@ -232,6 +253,36 @@ class ParallelSelfAttention(MegatronModule): ...@@ -232,6 +253,36 @@ class ParallelSelfAttention(MegatronModule):
(query_layer, (query_layer,
key_layer, key_layer,
value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3) value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
else:
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_kv_layer, _ = self.key_value(encoder_output)
checkpoint_version = get_checkpoint_version()
if checkpoint_version is not None:
if checkpoint_version == 0:
# [s, b, (2 * np * hn)] --> [s, b, (np * 2 * hn)]
mixed_kv_layer = self._transpose_last_dim(mixed_kv_layer, 2, True)
elif checkpoint_version == 1.0:
# [s, b, (np * hn * 2)] --> [s, b, (np * 2 * hn)]
mixed_kv_layer = self._transpose_last_dim(mixed_kv_layer, 2, False)
# [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
new_tensor_shape = mixed_kv_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
2 * self.hidden_size_per_attention_head)
mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
# [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
(key_layer,
value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)
# Attention head [sq, b, h] --> [sq, b, hp]
query_layer, _ = self.query(hidden_states)
# [sq, b, hp] --> [sq, b, np, hn]
new_tensor_shape = query_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
query_layer = query_layer.view(*new_tensor_shape)
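A quick shape check for the cross-attention branch added above (assumed sizes; torch.chunk stands in for mpu.split_tensor_along_last_dim):

import torch

sq, sk, b, np_, hn = 5, 7, 2, 8, 16

mixed_kv_layer = torch.randn(sk, b, np_, 2 * hn)       # after the view above
key_layer, value_layer = torch.chunk(mixed_kv_layer, 2, dim=-1)
query_layer = torch.randn(sq, b, np_, hn)              # after the query projection + view

assert key_layer.shape == (sk, b, np_, hn)
assert value_layer.shape == (sk, b, np_, hn)
assert query_layer.shape == (sq, b, np_, hn)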
# ================================== # ==================================
# Adjust key and value for inference # Adjust key and value for inference
...@@ -246,7 +297,6 @@ class ParallelSelfAttention(MegatronModule): ...@@ -246,7 +297,6 @@ class ParallelSelfAttention(MegatronModule):
if get_key_value: if get_key_value:
present = (key_layer, value_layer) present = (key_layer, value_layer)
# =================================== # ===================================
# Raw attention scores. [b, np, s, s] # Raw attention scores. [b, np, s, s]
# =================================== # ===================================
...@@ -260,6 +310,7 @@ class ParallelSelfAttention(MegatronModule): ...@@ -260,6 +310,7 @@ class ParallelSelfAttention(MegatronModule):
# [sq, b, np, hn] -> [sq, b * np, hn] # [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.view(output_size[2], query_layer = query_layer.view(output_size[2],
output_size[0] * output_size[1], -1) output_size[0] * output_size[1], -1)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.view(output_size[3], key_layer = key_layer.view(output_size[3],
output_size[0] * output_size[1], -1) output_size[0] * output_size[1], -1)
...@@ -272,15 +323,15 @@ class ParallelSelfAttention(MegatronModule): ...@@ -272,15 +323,15 @@ class ParallelSelfAttention(MegatronModule):
device=torch.cuda.current_device()) device=torch.cuda.current_device())
# Raw attention scores. [b * np, sq, sk] # Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(matmul_result, matmul_result = torch.baddbmm(
matmul_result,
query_layer.transpose(0, 1), # [b * np, sq, hn] query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0,1).transpose(1, 2), #[b * np, hn, sk] key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0, alpha=(1.0/self.norm_factor)) beta=0.0, alpha=(1.0/self.norm_factor))
# change view to [b, np, sq, sk] # change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size) attention_scores = matmul_result.view(*output_size)
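The baddbmm call above (with beta=0.0) is a batched, scaled Q·Kᵀ written into a preallocated buffer. An unfused reference, with assumed sizes and norm_factor = sqrt(hn) in the simplest configuration:

import math
import torch

b_times_np, sq, sk, hn = 16, 5, 7, 16
query_layer = torch.randn(b_times_np, sq, hn)      # query_layer.transpose(0, 1)
key_layer = torch.randn(b_times_np, sk, hn)        # key_layer.transpose(0, 1)
norm_factor = math.sqrt(hn)

attention_scores = torch.bmm(query_layer, key_layer.transpose(1, 2)) / norm_factor
# attention_scores: [b * np, sq, sk]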
# ================================================== # ==================================================
# Update attention mask for inference. [b, np, sq, sk] # Update attention mask for inference. [b, np, sq, sk]
# ================================================== # ==================================================
...@@ -298,7 +349,6 @@ class ParallelSelfAttention(MegatronModule): ...@@ -298,7 +349,6 @@ class ParallelSelfAttention(MegatronModule):
:attention_scores.size(3), :attention_scores.size(3),
:attention_scores.size(3)] :attention_scores.size(3)]
# =========================== # ===========================
# Attention probs and dropout # Attention probs and dropout
# =========================== # ===========================
...@@ -312,7 +362,6 @@ class ParallelSelfAttention(MegatronModule): ...@@ -312,7 +362,6 @@ class ParallelSelfAttention(MegatronModule):
with mpu.get_cuda_rng_tracker().fork(): with mpu.get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs) attention_probs = self.attention_dropout(attention_probs)
# ========================= # =========================
# Context layer. [sq, b, hp] # Context layer. [sq, b, hp]
# ========================= # =========================
...@@ -335,7 +384,7 @@ class ParallelSelfAttention(MegatronModule): ...@@ -335,7 +384,7 @@ class ParallelSelfAttention(MegatronModule):
output_size[2], -1) output_size[2], -1)
# matmul: [b * np, sq, hn] # matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.transpose(0,1)) context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
# change view [b, np, sq, hn] # change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size) context_layer = context_layer.view(*output_size)
...@@ -348,7 +397,6 @@ class ParallelSelfAttention(MegatronModule): ...@@ -348,7 +397,6 @@ class ParallelSelfAttention(MegatronModule):
(self.hidden_size_per_partition,) (self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape) context_layer = context_layer.view(*new_context_layer_shape)
# ================= # =================
# Output. [sq, b, h] # Output. [sq, b, h]
# ================= # =================
...@@ -389,16 +437,19 @@ def bias_dropout_add_fused_inference(x, bias, residual, prob) : ...@@ -389,16 +437,19 @@ def bias_dropout_add_fused_inference(x, bias, residual, prob) :
class ParallelTransformerLayer(MegatronModule): class ParallelTransformerLayer(MegatronModule):
"""A single transformer layer. """A single transformer layer.
Transformore layer takes input with size [b, s, h] and returns an Transformer layer takes input with size [b, s, h] and returns an
output of the same size. output of the same size.
""" """
def __init__(self, attention_mask_func, init_method, def __init__(self, attention_mask_func, init_method,
output_layer_init_method, layer_number): output_layer_init_method, layer_number,
layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding):
args = get_args() args = get_args()
super(ParallelTransformerLayer, self).__init__() super(ParallelTransformerLayer, self).__init__()
self.layer_number = layer_number self.layer_number = layer_number
self.layer_type = layer_type
self.apply_residual_connection_post_layernorm \ self.apply_residual_connection_post_layernorm \
= args.apply_residual_connection_post_layernorm = args.apply_residual_connection_post_layernorm
...@@ -410,30 +461,47 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -410,30 +461,47 @@ class ParallelTransformerLayer(MegatronModule):
eps=args.layernorm_epsilon) eps=args.layernorm_epsilon)
# Self attention. # Self attention.
self.attention = ParallelSelfAttention(attention_mask_func, init_method, self.self_attention = ParallelAttention(
attention_mask_func,
init_method,
output_layer_init_method, output_layer_init_method,
layer_number) layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=self_attn_mask_type)
self.hidden_dropout = args.hidden_dropout self.hidden_dropout = args.hidden_dropout
self.bias_dropout_fusion = args.bias_dropout_fusion self.bias_dropout_fusion = args.bias_dropout_fusion
# Layernorm on the input data. # Layernorm on the attention output
self.post_attention_layernorm = LayerNorm( self.post_attention_layernorm = LayerNorm(
args.hidden_size, args.hidden_size,
eps=args.layernorm_epsilon) eps=args.layernorm_epsilon)
if self.layer_type == LayerType.decoder:
self.inter_attention = ParallelAttention(
attention_mask_func,
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.cross_attn)
# Layernorm on the attention output.
self.post_inter_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon)
# MLP # MLP
self.mlp = ParallelMLP(init_method, self.mlp = ParallelMLP(init_method,
output_layer_init_method) output_layer_init_method)
def forward(self, hidden_states, attention_mask, layer_past=None, def forward(self, hidden_states, attention_mask,
get_key_value=False): encoder_output=None, enc_dec_attn_mask=None,
layer_past=None, get_key_value=False):
# hidden_states: [b, s, h] # hidden_states: [b, s, h]
# Layer norm at the begining of the transformer layer. # Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states) layernorm_output = self.input_layernorm(hidden_states)
# Self attention. # Self attention.
attention_output, attention_bias = \ attention_output, attention_bias = \
self.attention(layernorm_output, self.self_attention(layernorm_output,
attention_mask, attention_mask,
layer_past=layer_past, layer_past=layer_past,
get_key_value=get_key_value) get_key_value=get_key_value)
...@@ -459,7 +527,7 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -459,7 +527,7 @@ class ParallelTransformerLayer(MegatronModule):
else: else:
bias_dropout_add_func = get_bias_dropout_add(self.training) bias_dropout_add_func = get_bias_dropout_add(self.training)
#re-enable torch grad to enable fused optimization. # re-enable torch grad to enable fused optimization.
with torch.enable_grad(): with torch.enable_grad():
layernorm_input = bias_dropout_add_func( layernorm_input = bias_dropout_add_func(
attention_output, attention_output,
...@@ -470,6 +538,28 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -470,6 +538,28 @@ class ParallelTransformerLayer(MegatronModule):
# Layer norm post the self attention. # Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input) layernorm_output = self.post_attention_layernorm(layernorm_input)
if self.layer_type == LayerType.decoder:
attention_output, attention_bias = \
self.inter_attention(layernorm_output,
enc_dec_attn_mask,
encoder_output=encoder_output)
# residual connection
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
# Layer norm post the decoder attention
layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
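With the decoder path added, the layer runs self-attention, then cross-attention over encoder_output, then the MLP, each wrapped in a pre-LayerNorm residual block. A minimal single-GPU sketch of that ordering (torch.nn.MultiheadAttention stands in for ParallelAttention, plain dropout for the fused bias-dropout-add, and the residual-placement option is omitted):

import torch
from torch import nn

class TinyDecoderLayer(nn.Module):
    def __init__(self, h=64, heads=4, p=0.1):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(h)
        self.self_attention = nn.MultiheadAttention(h, heads)
        self.post_attention_layernorm = nn.LayerNorm(h)
        self.inter_attention = nn.MultiheadAttention(h, heads)
        self.post_inter_attention_layernorm = nn.LayerNorm(h)
        self.mlp = nn.Sequential(nn.Linear(h, 4 * h), nn.GELU(), nn.Linear(4 * h, h))
        self.dropout = nn.Dropout(p)

    def forward(self, hidden_states, encoder_output):
        # Self-attention block.
        x = self.input_layernorm(hidden_states)
        attn, _ = self.self_attention(x, x, x)
        hidden_states = hidden_states + self.dropout(attn)

        # Cross-attention block over the encoder output.
        x = self.post_attention_layernorm(hidden_states)
        attn, _ = self.inter_attention(x, encoder_output, encoder_output)
        hidden_states = hidden_states + self.dropout(attn)

        # MLP block.
        x = self.post_inter_attention_layernorm(hidden_states)
        return hidden_states + self.dropout(self.mlp(x))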
# MLP. # MLP.
mlp_output, mlp_bias = self.mlp(layernorm_output) mlp_output, mlp_bias = self.mlp(layernorm_output)
...@@ -479,7 +569,7 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -479,7 +569,7 @@ class ParallelTransformerLayer(MegatronModule):
else: else:
residual = layernorm_input residual = layernorm_input
#re-enable torch grad to enable fused optimization. # re-enable torch grad to enable fused optimization.
with torch.enable_grad(): with torch.enable_grad():
output = bias_dropout_add_func( output = bias_dropout_add_func(
mlp_output, mlp_output,
...@@ -497,7 +587,9 @@ class ParallelTransformer(MegatronModule): ...@@ -497,7 +587,9 @@ class ParallelTransformer(MegatronModule):
"""Transformer class.""" """Transformer class."""
def __init__(self, attention_mask_func, def __init__(self, attention_mask_func,
init_method, output_layer_init_method): init_method, output_layer_init_method,
layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding):
super(ParallelTransformer, self).__init__() super(ParallelTransformer, self).__init__()
args = get_args() args = get_args()
...@@ -516,7 +608,9 @@ class ParallelTransformer(MegatronModule): ...@@ -516,7 +608,9 @@ class ParallelTransformer(MegatronModule):
def build_layer(layer_number): def build_layer(layer_number):
return ParallelTransformerLayer( return ParallelTransformerLayer(
attention_mask_func, init_method, attention_mask_func, init_method,
output_layer_init_method, layer_number) output_layer_init_method, layer_number,
layer_type=layer_type,
self_attn_mask_type=self_attn_mask_type)
offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
self.layers = torch.nn.ModuleList( self.layers = torch.nn.ModuleList(
[build_layer(i + 1 + offset) for i in range(self.num_layers)]) [build_layer(i + 1 + offset) for i in range(self.num_layers)])
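With pipeline parallelism, each rank builds only its own slice of the stack, while the 1-based layer_number passed to build_layer stays global. A small illustration with assumed values (the real offset comes from mpu.get_pipeline_model_parallel_rank()):

num_layers_per_rank = 6          # layers owned by this pipeline stage
pipeline_rank = 2                # stand-in for mpu.get_pipeline_model_parallel_rank()
offset = pipeline_rank * num_layers_per_rank

layer_numbers = [i + 1 + offset for i in range(num_layers_per_rank)]
print(layer_numbers)             # [13, 14, 15, 16, 17, 18]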
...@@ -531,14 +625,18 @@ class ParallelTransformer(MegatronModule): ...@@ -531,14 +625,18 @@ class ParallelTransformer(MegatronModule):
def _get_layer(self, layer_number): def _get_layer(self, layer_number):
return self.layers[layer_number] return self.layers[layer_number]
def _checkpointed_forward(self, hidden_states, attention_mask): def _checkpointed_forward(self, hidden_states, attention_mask,
encoder_output, enc_dec_attn_mask):
"""Forward method with activation checkpointing.""" """Forward method with activation checkpointing."""
def custom(start, end): def custom(start, end):
def custom_forward(*inputs): def custom_forward(*inputs):
x_ = inputs[0] x_ = inputs[0]
attention_mask = inputs[1]
encoder_output = inputs[2]
enc_dec_attn_mask = inputs[3]
for index in range(start, end): for index in range(start, end):
layer = self._get_layer(index) layer = self._get_layer(index)
x_ = layer(x_, inputs[1]) x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
return x_ return x_
return custom_forward return custom_forward
...@@ -548,13 +646,13 @@ class ParallelTransformer(MegatronModule): ...@@ -548,13 +646,13 @@ class ParallelTransformer(MegatronModule):
while l < self.num_layers: while l < self.num_layers:
hidden_states = mpu.checkpoint( hidden_states = mpu.checkpoint(
custom(l, l + self.checkpoint_num_layers), custom(l, l + self.checkpoint_num_layers),
hidden_states, attention_mask) hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
l += self.checkpoint_num_layers l += self.checkpoint_num_layers
return hidden_states return hidden_states
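The checkpointed path now threads encoder_output and enc_dec_attn_mask through every re-computed chunk as well. A rough single-GPU analogue using torch.utils.checkpoint.checkpoint in place of mpu.checkpoint (assumed chunk size; None arguments simply pass through unchanged):

import torch
from torch.utils.checkpoint import checkpoint

def checkpointed_forward(layers, hidden_states, attention_mask,
                         encoder_output=None, enc_dec_attn_mask=None,
                         chunk=2):
    def custom(start, end):
        def custom_forward(x, mask, enc_out, enc_dec_mask):
            # Recompute a contiguous slice of layers inside one checkpoint region.
            for layer in layers[start:end]:
                x = layer(x, mask, encoder_output=enc_out,
                          enc_dec_attn_mask=enc_dec_mask)
            return x
        return custom_forward

    for start in range(0, len(layers), chunk):
        hidden_states = checkpoint(custom(start, start + chunk),
                                   hidden_states, attention_mask,
                                   encoder_output, enc_dec_attn_mask)
    return hidden_states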
def forward(self, hidden_states, attention_mask, layer_past=None, def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False): get_key_value=False, encoder_output=None, enc_dec_attn_mask=None):
# Checks. # Checks.
if layer_past is not None: if layer_past is not None:
...@@ -577,7 +675,9 @@ class ParallelTransformer(MegatronModule): ...@@ -577,7 +675,9 @@ class ParallelTransformer(MegatronModule):
if self.checkpoint_activations: if self.checkpoint_activations:
hidden_states = self._checkpointed_forward(hidden_states, hidden_states = self._checkpointed_forward(hidden_states,
attention_mask) attention_mask,
encoder_output,
enc_dec_attn_mask)
else: else:
if get_key_value: if get_key_value:
presents = [] presents = []
...@@ -588,6 +688,8 @@ class ParallelTransformer(MegatronModule): ...@@ -588,6 +688,8 @@ class ParallelTransformer(MegatronModule):
past = layer_past[index] past = layer_past[index]
hidden_states = layer(hidden_states, hidden_states = layer(hidden_states,
attention_mask, attention_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
layer_past=past, layer_past=past,
get_key_value=get_key_value) get_key_value=get_key_value)
if get_key_value: if get_key_value:
...@@ -141,5 +141,4 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): ...@@ -141,5 +141,4 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
if __name__ == "__main__": if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider, model_provider, forward_step, pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
args_defaults={'tokenizer_type': 'GPT2BPETokenizer', args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
'scaled_upper_triang_masked_softmax_fusion': True})