Commit 51a2e6b0 authored by Vijay Korthikanti, committed by Jared Casper

Various speed optimizations.

parent 12518332
@@ -19,7 +19,7 @@
 import argparse
 import os

 import torch

+from megatron import fused_kernels


 def parse_args(extra_args_provider=None, defaults={},
                ignore_unknown_args=False):
@@ -118,6 +118,10 @@ def parse_args(extra_args_provider=None, defaults={},
         'for distribute-checkpointed-activations to work you '\
         'need to enable checkpoint-activations'

+    # Load the scaled_upper_triang_masked_softmax_fusion kernel.
+    if args.scaled_upper_triang_masked_softmax_fusion:
+        fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
+
     _print_args(args)
     return args
@@ -221,6 +225,14 @@ def _add_training_args(parser):
                        'by this value.')
     group.add_argument('--tensorboard-dir', type=str, default=None,
                        help='Write TensorBoard logs to this directory.')
+    group.add_argument('--scaled-upper-triang-masked-softmax-fusion',
+                       action='store_true',
+                       help='Enable fusion of query-key-value scaling, '
+                            'upper-triangular (causal) masking, and softmax.')
+    group.add_argument('--bias-gelu-fusion', action='store_true',
+                       help='Enable bias and gelu fusion.')
+    group.add_argument('--bias-dropout-fusion', action='store_true',
+                       help='Enable bias and dropout fusion.')
     return parser
...
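To show the shape of the new switches in isolation, here is a minimal, self-contained sketch (plain argparse only; Megatron's real parser adds many more required arguments, so this is an illustration rather than the project's API):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--scaled-upper-triang-masked-softmax-fusion',
                    action='store_true')
parser.add_argument('--bias-gelu-fusion', action='store_true')
parser.add_argument('--bias-dropout-fusion', action='store_true')

# All three default to False and flip to True when passed on the command line.
args = parser.parse_args(['--bias-gelu-fusion', '--bias-dropout-fusion'])
print(args.scaled_upper_triang_masked_softmax_fusion,
      args.bias_gelu_fusion, args.bias_dropout_fusion)  # False True True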
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pathlib
import subprocess
from torch.utils import cpp_extension
def load_scaled_upper_triang_masked_softmax_fusion_kernel():

    def get_cuda_bare_metal_version(cuda_dir):
        raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
                                             universal_newlines=True)
        output = raw_output.split()
        release_idx = output.index("release") + 1
        release = output[release_idx].split(".")
        bare_metal_major = release[0]
        bare_metal_minor = release[1][0]

        return raw_output, bare_metal_major, bare_metal_minor

    # Check if CUDA 11 is installed for compute capability 8.0.
    cc_flag = []
    _, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
    if int(bare_metal_major) >= 11:
        cc_flag.append('-gencode')
        cc_flag.append('arch=compute_80,code=sm_80')

    srcpath = pathlib.Path(__file__).parent.absolute()
    scaled_upper_triang_masked_softmax_cuda = cpp_extension.load(
        name='scaled_upper_triang_masked_softmax_cuda',
        sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp',
                 srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu'],
        extra_cflags=['-O3'],
        extra_cuda_cflags=['-O3',
                           '-gencode', 'arch=compute_70,code=sm_70',
                           '-U__CUDA_NO_HALF_OPERATORS__',
                           '-U__CUDA_NO_HALF_CONVERSIONS__',
                           '--expt-relaxed-constexpr',
                           '--expt-extended-lambda',
                           '--use_fast_math'] + cc_flag,
        verbose=True)
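A brief usage sketch of the loader above (it assumes a CUDA toolkit with nvcc reachable via CUDA_HOME and a GPU; the module path comes from the diff, the rest is illustrative). The argument parser calls the loader once at startup; after that the JIT-built extension is importable by name, which is how fused_softmax.py later picks it up.

from megatron import fused_kernels

# One-time JIT build + load; this is what arguments.py does when
# --scaled-upper-triang-masked-softmax-fusion is passed.
fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()

# The freshly built extension is now importable by name.
import scaled_upper_triang_masked_softmax_cuda  # noqa: E402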
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
float scale_factor);
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor);
torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input.scalar_type() == at::ScalarType::Half,
"Only HALF is supported");
return fwd_cuda(input, scale_factor);
}
torch::Tensor bwd(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.scalar_type() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(softmax_results.scalar_type() == at::ScalarType::Half,
"Only HALF is supported");
return bwd_cuda(output_grads, softmax_results, scale_factor);
}
} // end namespace scaled_upper_triang_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward",
&multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
"Self Multihead Attention scaled, time masked softmax -- Forward.");
m.def("backward",
&multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
"Self Multihead Attention scaled, time masked softmax -- Backward.");
}
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <assert.h>
#include <cuda_fp16.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <c10/macros/Macros.h>
namespace {
int log2_ceil(int value) {
int log2_value = 0;
while ((1 << log2_value) < value) ++log2_value;
return log2_value;
}
template<typename T>
struct Add {
__device__ __forceinline__ T operator()(T a, T b) const {
return a + b;
}
};
template<typename T>
struct Max {
__device__ __forceinline__ T operator()(T a, T b) const {
return a < b ? b : a;
}
};
template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if CUDA_VERSION >= 9000
return __shfl_xor_sync(mask, value, laneMask, width);
#else
return __shfl_xor(value, laneMask, width);
#endif
}
template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t* sum) {
ReduceOp<acc_t> r;
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
sum[i] = r(sum[i], b);
}
}
}
/*
 * Extended softmax (from native aten pytorch) with the following additional features:
 * 1) input scaling
 * 2) implicit upper-triangular (time) masking
 */
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_forward(
output_t *dst,
const input_t *src,
const acc_t scale,
int batch_size,
int stride,
int element_count)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_forward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
int local_seq = blockIdx.x + 1;
int warp_iteration_limit = (local_seq + WARP_SIZE - 1)/WARP_SIZE;
// batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to be computed within this WARP.
int local_batches = batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
src += first_batch * stride + local_idx;
dst += first_batch * stride + local_idx;
// load data from global memory
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : local_seq;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
elements[i][it] = (acc_t)src[i*element_count*stride+it*WARP_SIZE] * scale;
} else {
elements[i][it] = -std::numeric_limits<acc_t>::infinity();
}
}
}
// compute max_value
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = elements[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
acc_t sum[WARP_BATCH] { 0.0f };
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
if (it < warp_iteration_limit) {
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
sum[i] += elements[i][it];
}
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < local_seq) {
dst[i*element_count*stride+it*WARP_SIZE] = (output_t)(elements[i][it] / sum[i]);
} else if (element_index < element_count) {
dst[i*element_count*stride+it*WARP_SIZE] = 0;
} else {
break;
}
}
}
}
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_backward(
output_t *gradInput,
input_t *grad,
const input_t *output,
acc_t scale,
int batch_size,
int stride,
int element_count)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_backward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
int local_seq = blockIdx.x + 1;
// batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to be computed within this WARP.
int local_batches = batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
// the first element to process by the current thread
int thread_offset = first_batch * stride + local_idx;
grad += thread_offset;
output += thread_offset;
gradInput += thread_offset;
// load data from global memory
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : local_seq;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
output_reg[i][it] = output[i*element_count*stride+it*WARP_SIZE];
} else {
output_reg[i][it] = acc_t(0);
}
}
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
grad_reg[i][it] = (acc_t)grad[i*element_count*stride+it*WARP_SIZE] * output_reg[i][it];
} else {
grad_reg[i][it] = acc_t(0);
}
}
}
acc_t sum[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] = grad_reg[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
sum[i] += grad_reg[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// compute gradients
gradInput[i*element_count*stride+it*WARP_SIZE] = (output_t)(scale * (grad_reg[i][it] - output_reg[i][it] * sum[i]));
}
}
}
}
} // end of anonymous namespace
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_forward(
output_t *dst,
const input_t *src,
const input_t scale,
int softmax_elements,
int softmax_elements_stride,
int attn_batches)
{
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 );
if (softmax_elements == 0) {
return;
} else {
int log2_elements = log2_ceil(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
int seq_len = softmax_elements;
int batch_count = attn_batches * seq_len;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
// use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
int blocks_per_seq = attn_batches / batches_per_block;
dim3 blocks(seq_len, blocks_per_seq, 1);
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 1: // 2
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 2: // 4
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 3: // 8
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 4: // 16
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 5: // 32
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 6: // 64
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 7: // 128
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 8: // 256
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 9: // 512
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 10: // 1024
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 11: // 2048
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
default:
break;
}
}
}
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_backward(
output_t *grad_input,
input_t *grad,
const input_t *output,
const acc_t scale,
int softmax_elements,
int softmax_elements_stride,
int attn_batches)
{
TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 );
if (softmax_elements == 0) {
return;
} else {
int log2_elements = log2_ceil(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
int seq_len = softmax_elements;
int batch_count = attn_batches * seq_len;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
// use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
int blocks_per_seq = attn_batches / batches_per_block;
dim3 blocks(seq_len, blocks_per_seq, 1);
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 1: // 2
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 2: // 4
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 3: // 8
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 4: // 16
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 5: // 32
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 6: // 64
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 7: // 128
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 8: // 256
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 9: // 512
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 10: // 1024
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 11: // 2048
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
default:
break;
}
}
}
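As a sanity check on the semantics of the dispatch functions above, here is a hedged, unfused PyTorch reference of what the forward path computes for an [attn_batches, seq_len, seq_len] input: scale, mask the strictly upper triangle, then softmax over the last dimension. This is a reading of the kernel, not part of the commit.

import torch

def scaled_upper_triang_masked_softmax_reference(x, scale):
    # x: [attn_batches, seq_len, seq_len]. Masked positions come out as exact
    # zeros, matching the zero-fill in the kernel's store loop.
    seq_len = x.size(1)
    future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool,
                                   device=x.device), diagonal=1)
    scores = x.float() * scale
    scores = scores.masked_fill(future, float('-inf'))
    return torch.softmax(scores, dim=-1).type_as(x)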
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_upper_triang_masked_softmax.h"
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
float scale_factor)
{
// input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
const int attn_batches = input.size(0);
const int seq_len = input.size(1);
TORCH_INTERNAL_ASSERT(seq_len <= 2048);
// Output
auto act_options = input.options().requires_grad(false);
torch::Tensor softmax_results =
torch::empty({attn_batches, seq_len, seq_len}, act_options);
// Softmax Intermediate Result Ptr
void* input_ptr = static_cast<void*>(input.data_ptr());
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
dispatch_scaled_upper_triang_masked_softmax_forward<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(input_ptr),
scale_factor,
seq_len,
seq_len,
attn_batches);
return softmax_results;
}
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads_,
torch::Tensor const& softmax_results_,
float scale_factor) {
auto output_grads = output_grads_.contiguous();
auto softmax_results = softmax_results_.contiguous();
//output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
const int attn_batches = output_grads.size(0);
const int seq_len = output_grads.size(1);
TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2));
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
//Softmax Grad
dispatch_scaled_upper_triang_masked_softmax_backward<half, half, float>(
reinterpret_cast<half*>(output_grads_ptr),
reinterpret_cast<half*>(output_grads_ptr),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
scale_factor,
seq_len,
seq_len,
attn_batches);
//backward pass is completely in-place
return output_grads;
}
}
}
}
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
###### BIAS GELU FUSION/ NO AUTOGRAD ################
# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2) -> 0.70710678
# sqrt(2/pi) -> 0.79788456
# This function is the tanh approximation of gelu; the exact gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
@torch.jit.script
def bias_gelu(bias, y):
    x = bias + y
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))


# Gradient of the tanh approximation of gelu; the gradient of the exact gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def bias_gelu_back(g, bias, y):
    x = bias + y
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) \
        + 0.5 * (1 + tanh_out)
    return ff * g


class GeLUFunction(torch.autograd.Function):
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, bias):
        ctx.save_for_backward(input, bias)
        return bias_gelu(bias, input)

    @staticmethod
    def backward(ctx, grad_output):
        input, bias = ctx.saved_tensors
        tmp = bias_gelu_back(grad_output, bias, input)
        return tmp, tmp


bias_gelu_impl = GeLUFunction.apply
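A small usage sketch of the fused op (not part of the commit; it assumes the Megatron package is importable, and it runs on CPU since the jitted functions above are plain elementwise TorchScript). It checks the fused result against the same tanh approximation written out inline.

import torch
from megatron.model.fused_bias_gelu import bias_gelu_impl

y = torch.randn(4, 16, requires_grad=True)   # e.g. the un-biased linear output
bias = torch.zeros(16)                        # deferred bias from skip_bias_add

out = bias_gelu_impl(y, bias)

# Same tanh-approximate gelu, written out without the fusion.
x = y + bias
expected = x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
assert torch.allclose(out, expected, atol=1e-5)

out.sum().backward()   # exercises GeLUFunction.backward / bias_gelu_back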
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs the following three operations in sequence:
    1. Scale the tensor.
    2. Apply upper triangular mask (typically used in gpt models).
    3. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs, scale):
        import scaled_upper_triang_masked_softmax_cuda
        scale_t = torch.tensor([scale])

        softmax_results = \
            scaled_upper_triang_masked_softmax_cuda.forward(inputs, scale_t[0])
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        import scaled_upper_triang_masked_softmax_cuda
        softmax_results, scale_t = ctx.saved_tensors

        input_grads = \
            scaled_upper_triang_masked_softmax_cuda.backward(output_grads,
                                                             softmax_results,
                                                             scale_t[0])
        return input_grads, None


class FusedScaleMaskSoftmax(torch.nn.Module):
    """
    Fused operation: scaling + mask + softmax.

    Arguments:
        input_in_fp16: flag to indicate if the input is in fp16 data format.
        upper_triang_mask: if true, apply upper triangular masking
                           (used in gpt family networks).
        mask_func: mask function to be applied.
        softmax_in_fp32: if true, softmax is performed in fp32 precision.
        scale: scaling factor used in input tensor scaling.
    """

    def __init__(self, input_in_fp16, upper_triang_mask,
                 mask_func, softmax_in_fp32, scale):
        super(FusedScaleMaskSoftmax, self).__init__()
        self.input_in_fp16 = input_in_fp16
        self.upper_triang_mask = upper_triang_mask
        self.mask_func = mask_func
        self.softmax_in_fp32 = softmax_in_fp32
        self.scale = scale

        assert self.scale is None or softmax_in_fp32, \
            'softmax should be in fp32 when scaled'

    def forward(self, input, mask):
        # [b, np, s, s]
        data_size = input.size()
        assert input.dim() == 4

        # Invoke the custom kernel for implicit upper triangular masking.
        if self.input_in_fp16 and self.upper_triang_mask and \
                data_size[-1] <= 2048 and input.size()[2] == input.size()[3]:
            input = input.view(-1, data_size[2], data_size[3])
            scale = self.scale if self.scale is not None else 1.0
            probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
            probs = probs.view(*data_size)
        else:
            if self.input_in_fp16 and self.softmax_in_fp32:
                input = input.float()

            mask_output = self.mask_func(input, mask)
            if self.scale is not None:
                mask_output = mask_output * self.scale
            probs = torch.nn.Softmax(dim=-1)(mask_output)

            if self.input_in_fp16 and self.softmax_in_fp32:
                probs = probs.half()

        return probs
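For orientation, a hedged sketch of the unfused fallback path above (runnable without the custom kernel; the mask convention, True = masked and filled with -10000.0, mirrors Megatron's attention_mask_func but is an assumption here):

import torch
from megatron.model.fused_softmax import FusedScaleMaskSoftmax

def example_mask_func(scores, mask):
    # Hypothetical stand-in for Megatron's attention_mask_func.
    return scores.masked_fill(mask, -10000.0)

softmax = FusedScaleMaskSoftmax(
    input_in_fp16=False,        # forces the unfused branch of forward()
    upper_triang_mask=True,
    mask_func=example_mask_func,
    softmax_in_fp32=False,
    scale=None)

b, heads, s = 2, 4, 8
scores = torch.randn(b, heads, s, s)                          # [b, np, s, s]
causal = torch.triu(torch.ones(1, 1, s, s, dtype=torch.bool), diagonal=1)
probs = softmax(scores, causal)   # rows sum to 1 over the unmasked positions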
@@ -22,7 +22,6 @@ from megatron import get_args
 from megatron import mpu
 from megatron.module import MegatronModule
 from megatron.model.transformer import ParallelTransformer
-from megatron.model.utils import openai_gelu, erf_gelu
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal, scaled_init_method_normal
@@ -48,13 +47,6 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
     """Build language model and return along with the key to save."""
     args = get_args()

-    # Use torch gelu unless otherwise forced.
-    gelu = F.gelu
-    if args.openai_gelu:
-        gelu = openai_gelu
-    elif args.onnx_safe:
-        gelu = erf_gelu
-
     if init_method is None:
         init_method = init_method_normal(args.init_method_std)
@@ -64,7 +56,6 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
     # Language model.
     language_model = TransformerLanguageModel(
         attention_mask_func=attention_mask_func,
-        mlp_activation_func=gelu,
         init_method=init_method,
         output_layer_init_method=scaled_init_method,
         num_tokentypes=num_tokentypes,
@@ -271,7 +262,6 @@ class TransformerLanguageModel(MegatronModule):
     def __init__(self,
                  attention_mask_func,
-                 mlp_activation_func,
                  init_method,
                  output_layer_init_method,
                  num_tokentypes=0,
@@ -295,8 +285,8 @@ class TransformerLanguageModel(MegatronModule):
         # Transformer
         self.transformer = ParallelTransformer(
-            attention_mask_func, mlp_activation_func,
-            self.init_method, output_layer_init_method)
+            attention_mask_func, self.init_method,
+            output_layer_init_method)
         self._transformer_key = 'transformer'

         # Pooler
...
@@ -17,12 +17,21 @@
 import math

 import torch
+import torch.nn.functional as F

 from megatron import get_args
 from megatron import mpu
 from megatron.mpu import LayerNorm
 from megatron.module import MegatronModule
+from megatron.model.fused_softmax import FusedScaleMaskSoftmax
+from megatron.model.fused_bias_gelu import bias_gelu_impl
+from megatron.model.utils import openai_gelu, erf_gelu
+
+# flags required to enable jit fusion kernels
+torch._C._jit_set_profiling_mode(False)
+torch._C._jit_set_profiling_executor(False)
+torch._C._jit_override_can_fuse_on_cpu(True)
+torch._C._jit_override_can_fuse_on_gpu(True)

 """ We use the following notation throughout this file:
     h: hidden size
@@ -34,7 +43,7 @@ from megatron.module import MegatronModule
     b: batch size
     s: sequence length
     l: number of layers
-    Transformer takes input of size [b, s, h] and returns a
+    Transformer takes input of size [s, b, h] and returns a
     tensor of the same size. We use the following arguments:
         hyperparameters: transformer hyperparameters
         attention_mask_func: a function that takes `unmasked-attention-scores`
@@ -45,7 +54,6 @@ from megatron.module import MegatronModule
                             unmasked-attention-scores, attention-mask)
 """


 class ParallelMLP(MegatronModule):
     """MLP.
@@ -55,8 +63,7 @@ class ParallelMLP(MegatronModule):
     applied.
     """

-    def __init__(self, mlp_activation_func, init_method,
-                 output_layer_init_method):
+    def __init__(self, init_method, output_layer_init_method):
         super(ParallelMLP, self).__init__()
         args = get_args()
@@ -65,29 +72,40 @@ class ParallelMLP(MegatronModule):
             args.hidden_size,
             4 * args.hidden_size,
             gather_output=False,
-            init_method=init_method)
+            init_method=init_method,
+            skip_bias_add=True)

-        self.activation_func = mlp_activation_func
+        self.bias_gelu_fusion = args.bias_gelu_fusion
+        self.activation_func = F.gelu
+        if args.openai_gelu:
+            self.activation_func = openai_gelu
+        elif args.onnx_safe:
+            self.activation_func = erf_gelu

         # Project back to h.
         self.dense_4h_to_h = mpu.RowParallelLinear(
             4 * args.hidden_size,
             args.hidden_size,
             input_is_parallel=True,
-            init_method=output_layer_init_method)
-
-        self.dropout = torch.nn.Dropout(args.hidden_dropout)
+            init_method=output_layer_init_method,
+            skip_bias_add=True)

     def forward(self, hidden_states):

-        # [b, s, 4hp]
-        intermediate_parallel = self.dense_h_to_4h(hidden_states)
-        intermediate_parallel = self.activation_func(intermediate_parallel)
+        # [s, b, 4hp]
+        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

-        # [b, s, h]
-        output = self.dense_4h_to_h(intermediate_parallel)
-        output = self.dropout(output)
-        return output
+        if self.bias_gelu_fusion:
+            intermediate_parallel = \
+                bias_gelu_impl(intermediate_parallel, bias_parallel)
+        else:
+            intermediate_parallel = \
+                self.activation_func(intermediate_parallel + bias_parallel)
+
+        # [s, b, h]
+        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
+        return output, output_bias


 class ParallelSelfAttention(MegatronModule):
@@ -123,10 +141,22 @@ class ParallelSelfAttention(MegatronModule):
         self.query_key_value = mpu.ColumnParallelLinear(
             args.hidden_size,
             3 * args.hidden_size,
+            stride=3,
             gather_output=False,
             init_method=init_method)

+        coeff = None
+        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
+        if self.apply_query_key_layer_scaling:
+            coeff = self.layer_number
+            self.norm_factor *= coeff
+
+        self.scale_mask_softmax = FusedScaleMaskSoftmax(
+            self.fp16,
+            args.scaled_upper_triang_masked_softmax_fusion,
+            self.attention_mask_func,
+            self.attention_softmax_in_fp32,
+            coeff)
+
         # Dropout. Note that for a single iteration, this layer will generate
         # different outputs on different number of parallel partitions but
         # on average it should not be partition dependent.
@@ -137,110 +167,85 @@ class ParallelSelfAttention(MegatronModule):
             args.hidden_size,
             args.hidden_size,
             input_is_parallel=True,
-            init_method=output_layer_init_method)
-        self.output_dropout = torch.nn.Dropout(args.hidden_dropout)
-
-    def _transpose_for_scores(self, tensor):
-        """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
-        size [b, np, s, hn].
-        """
-        new_tensor_shape = tensor.size()[:-1] + \
-            (self.num_attention_heads_per_partition,
-             self.hidden_size_per_attention_head)
-        tensor = tensor.view(*new_tensor_shape)
-        return tensor.permute(0, 2, 1, 3)
-
-    def _get_query_key_value(self, hidden_states):
-        """Get query, key, and value and transpose to
-        get size [b, np, s, hn].
-        """
-        # Attention heads. [b, s, hp]
-        mixed_x_layer = self.query_key_value(hidden_states)
-        (mixed_query_layer,
-         mixed_key_layer,
-         mixed_value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
-        # Reshape and transpose [b, np, s, hn]
-        query_layer = self._transpose_for_scores(mixed_query_layer)
-        key_layer = self._transpose_for_scores(mixed_key_layer)
-        value_layer = self._transpose_for_scores(mixed_value_layer)
-        return query_layer, key_layer, value_layer
-
-    def _get_unmasked_attention_scores(self, query_layer, key_layer):
-        """Unmasked attention scores with size [b, np, s, s]."""
-        coeff = 1
-        if self.apply_query_key_layer_scaling:
-            coeff = self.layer_number
-        norm_factor = math.sqrt(coeff *
-                                math.sqrt(self.hidden_size_per_attention_head))
-        # Raw attention scores. [b, np, s, s]
-        return torch.matmul(query_layer / norm_factor,
-                            key_layer.transpose(-1, -2) / norm_factor)
-
-    def _get_attention_probs(self, attention_scores):
-        """Attention probabilities with dropout. The output has
-        the size [b, np, s, s].
-        """
-        # Attention probabilities. [b, np, s, s]
-        if self.apply_query_key_layer_scaling:
-            attention_scores = attention_scores * self.layer_number
-        attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
-        # This is actually dropping out entire tokens to attend to, which might
-        # seem a bit unusual, but is taken from the original Transformer paper.
-        with mpu.get_cuda_rng_tracker().fork():
-            attention_probs = self.attention_dropout(attention_probs)
-        return attention_probs
-
-    def _get_attended_context(self, attention_probs, value_layer):
-        """Final attended tensor, transposed back to [b, s, hp]."""
-        # Context layer.
-        # [b, np, s, hn]
-        context_layer = torch.matmul(attention_probs, value_layer)
-        # [b, s, np, hn]
-        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-        new_context_layer_shape = context_layer.size()[:-2] + \
-            (self.hidden_size_per_partition,)
-        # [b, s, hp]
-        context_layer = context_layer.view(*new_context_layer_shape)
-        return context_layer
-
-    def _get_output(self, context_layer):
-        """Output layer with dropout."""
-        # Output. [b, s, h]
-        output = self.dense(context_layer)
-        output = self.output_dropout(output)
-        return output
-
-    def forward(self, hidden_states, attention_mask, layer_past=None,
-                get_key_value=False):
-        # hidden_states: [b, s, h]
-
-        # Attention heads. [b, np, s, hn]
-        query_layer, key_layer, value_layer = self._get_query_key_value(
-            hidden_states)
+            init_method=output_layer_init_method,
+            skip_bias_add=True)
+
+    def forward(self, hidden_states, attention_mask, layer_past=None,
+                get_key_value=False):
+        # hidden_states: [s, b, h]
+
+        # =====================
+        # Query, Key, and Value
+        # =====================
+
+        # Attention heads [s, b, hp] --> [s, b, 3 * hp]
+        mixed_x_layer, _ = self.query_key_value(hidden_states)
+
+        # [s, b, 3 * hp] --> [s, b, np, 3 * hn]
+        new_tensor_shape = mixed_x_layer.size()[:-1] + \
+            (self.num_attention_heads_per_partition,
+             3 * self.hidden_size_per_attention_head)
+        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
+
+        # [s, b, np, 3 * hn] --> 3 [s, b, np, hn]
+        (query_layer,
+         key_layer,
+         value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
+
+        # ==================================
+        # Adjust key and value for inference
+        # ==================================

         if layer_past is not None:
             past_key, past_value = layer_past
             key_layer = torch.cat((past_key.type_as(key_layer),
-                                   key_layer), dim=-2)
+                                   key_layer), dim=0)
             value_layer = torch.cat((past_value.type_as(value_layer),
-                                     value_layer), dim=-2)
+                                     value_layer), dim=0)
         if get_key_value:
             present = (key_layer, value_layer)

-        # Raw attention scores. [b, np, s, s]
-        attention_scores = self._get_unmasked_attention_scores(
-            query_layer, key_layer)
-
-        # fp32 conversion.
-        if self.fp16 and self.attention_softmax_in_fp32:
-            attention_scores = attention_scores.float()
-
-        # Apply attention mask. [b, np, s, s]
+        # ===================================
+        # Raw attention scores. [b, np, s, s]
+        # ===================================
+
+        # [b, np, s, s]
+        output_size = (query_layer.size(1),
+                       query_layer.size(2),
+                       query_layer.size(0),
+                       key_layer.size(0))
+
+        # [s, b, np, hn] -> [s, b * np, hn]
+        query_layer = query_layer.view(output_size[2],
+                                       output_size[0] * output_size[1], -1)
+        key_layer = key_layer.view(output_size[3],
+                                   output_size[0] * output_size[1], -1)
+
+        # Preallocating the result tensor: [b * np, s, s]
+        matmul_result = torch.empty(
+            output_size[0] * output_size[1],
+            output_size[2],
+            output_size[3],
+            dtype=query_layer.dtype,
+            device=torch.cuda.current_device())
+
+        # Raw attention scores. [b * np, s, s]
+        matmul_result = torch.baddbmm(matmul_result,
+            query_layer.transpose(0, 1),                # [b * np, s, hn]
+            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, s]
+            beta=0.0, alpha=(1.0 / self.norm_factor))
+
+        # Change view to [b, np, s, s].
+        attention_scores = matmul_result.view(*output_size)
+
+        # ==================================================
+        # Update attention mask for inference. [b, np, s, s]
+        # ==================================================
+
         if get_key_value:
             with torch.no_grad():
                 if layer_past is not None:
@@ -253,26 +258,93 @@ class ParallelSelfAttention(MegatronModule):
                         ...,
                         :attention_scores.size(3),
                         :attention_scores.size(3)]

-        attention_scores = self.attention_mask_func(attention_scores,
-                                                    attention_mask)
-
-        # Attention probabilities. [b, np, s, s]
-        attention_probs = self._get_attention_probs(attention_scores)
-
-        # fp16 conversion
-        if self.fp16 and self.attention_softmax_in_fp32:
-            attention_probs = attention_probs.half()
-
-        # Context layer. [b, s, hp]
-        context_layer = self._get_attended_context(attention_probs, value_layer)
-
-        # Output. [b, s, h]
-        output = self._get_output(context_layer)
+        # ===========================
+        # Attention probs and dropout
+        # ===========================
+
+        # Attention scores and attention mask. [b, np, s, s]
+        attention_probs = self.scale_mask_softmax(attention_scores,
+                                                  attention_mask)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        with mpu.get_cuda_rng_tracker().fork():
+            attention_probs = self.attention_dropout(attention_probs)
+
+        # =========================
+        # Context layer. [s, b, hp]
+        # =========================
+
+        # value_layer -> context layer.
+        # [s, b, np, hn] --> [b, np, s, hn]
+
+        # context layer shape: [b, np, s, hn]
+        output_size = (value_layer.size(1),
+                       value_layer.size(2),
+                       value_layer.size(0),
+                       value_layer.size(3))
+
+        # change view [s, b * np, hn]
+        value_layer = value_layer.view(output_size[2],
+                                       output_size[0] * output_size[1], -1)
+
+        # change view [b * np, s, s]
+        attention_probs = attention_probs.view(output_size[0] * output_size[1],
+                                               output_size[2], -1)
+
+        # matmul: [b * np, s, hn]
+        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
+
+        # change view [b, np, s, hn]
+        context_layer = context_layer.view(*output_size)
+
+        # [b, np, s, hn] --> [s, b, np, hn]
+        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
+
+        # [s, b, np, hn] --> [s, b, hp]
+        new_context_layer_shape = context_layer.size()[:-2] + \
+            (self.hidden_size_per_partition,)
+        context_layer = context_layer.view(*new_context_layer_shape)
+
+        # =================
+        # Output. [s, b, h]
+        # =================
+
+        output, bias = self.dense(context_layer)

         if get_key_value:
             output = [output, present]

-        return output
+        return output, bias
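To make the new attention-score path above easier to follow outside the class, here is a hedged standalone sketch (shapes invented for illustration): fold batch and heads into one dimension, preallocate the [b * np, s, s] buffer, and let torch.baddbmm apply the 1/norm_factor scaling during the matmul rather than as a separate elementwise op.

import torch

s, b, heads, hn = 16, 2, 4, 32      # sequence, batch, heads per partition, head dim
norm_factor = hn ** 0.5

q = torch.randn(s, b, heads, hn)
k = torch.randn(s, b, heads, hn)

# [s, b, np, hn] -> [s, b * np, hn] -> [b * np, s, hn]
q_ = q.view(s, b * heads, hn).transpose(0, 1)
k_ = k.view(s, b * heads, hn).transpose(0, 1)

# Preallocated result buffer; beta=0.0 tells baddbmm to ignore its contents.
buf = torch.empty(b * heads, s, s, dtype=q.dtype)
scores = torch.baddbmm(buf, q_, k_.transpose(1, 2),
                       beta=0.0, alpha=1.0 / norm_factor)

# Equivalent, less fused formulation.
reference = torch.matmul(q_, k_.transpose(1, 2)) / norm_factor
assert torch.allclose(scores, reference, atol=1e-4)

scores = scores.view(b, heads, s, s)   # back to [b, np, s, s]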
+def bias_dropout_add(x, bias, residual, prob, training):
+    # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
+    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
+    out = residual + out
+    return out
+
+
+def get_bias_dropout_add(training):
+    def _bias_dropout_add(x, bias, residual, prob):
+        return bias_dropout_add(x, bias, residual, prob, training)
+    return _bias_dropout_add
+
+
+@torch.jit.script
+def bias_dropout_add_fused_train(x, bias, residual, prob):
+    # type: (Tensor, Tensor, Tensor, float) -> Tensor
+    return bias_dropout_add(x, bias, residual, prob, True)
+
+
+@torch.jit.script
+def bias_dropout_add_fused_inference(x, bias, residual, prob):
+    # type: (Tensor, Tensor, Tensor, float) -> Tensor
+    return bias_dropout_add(x, bias, residual, prob, False)
+

 class ParallelTransformerLayer(MegatronModule):
@@ -282,8 +354,8 @@ class ParallelTransformerLayer(MegatronModule):
     output of the same size.
     """

-    def __init__(self, attention_mask_func, mlp_activation_func,
-                 init_method, output_layer_init_method, layer_number):
+    def __init__(self, attention_mask_func, init_method,
+                 output_layer_init_method, layer_number):
         args = get_args()
         super(ParallelTransformerLayer, self).__init__()
@@ -301,6 +373,8 @@ class ParallelTransformerLayer(MegatronModule):
         self.attention = ParallelSelfAttention(attention_mask_func, init_method,
                                                output_layer_init_method,
                                                layer_number)
+        self.hidden_dropout = args.hidden_dropout
+        self.bias_dropout_fusion = args.bias_dropout_fusion

         # Layernorm on the input data.
         self.post_attention_layernorm = LayerNorm(
@@ -308,7 +382,7 @@
             eps=args.layernorm_epsilon)

         # MLP
-        self.mlp = ParallelMLP(mlp_activation_func, init_method,
+        self.mlp = ParallelMLP(init_method,
                                output_layer_init_method)

     def forward(self, hidden_states, attention_mask, layer_past=None,
@@ -318,28 +392,60 @@
         # Layer norm at the beginning of the transformer layer.
         layernorm_output = self.input_layernorm(hidden_states)

         # Self attention.
-        attention_output = self.attention(layernorm_output,
-                                          attention_mask,
-                                          layer_past=layer_past,
-                                          get_key_value=get_key_value)
+        attention_output, attention_bias = \
+            self.attention(layernorm_output,
+                           attention_mask,
+                           layer_past=layer_past,
+                           get_key_value=get_key_value)

         if get_key_value:
             attention_output, presents = attention_output

         # Residual connection.
         if self.apply_residual_connection_post_layernorm:
-            layernorm_input = layernorm_output + attention_output
+            residual = layernorm_output
         else:
-            layernorm_input = hidden_states + attention_output
+            residual = hidden_states
+
+        # jit scripting for a nn.module (with dropout) is not
+        # triggering the fusion kernel. For now, we use two
+        # different nn.functional routines to account for varying
+        # dropout semantics during training and inference phases.
+        if self.bias_dropout_fusion:
+            if self.training:
+                bias_dropout_add_func = bias_dropout_add_fused_train
+            else:
+                bias_dropout_add_func = bias_dropout_add_fused_inference
+        else:
+            bias_dropout_add_func = get_bias_dropout_add(self.training)
+
+        # Re-enable torch grad to enable fused optimization.
+        with torch.enable_grad():
+            layernorm_input = bias_dropout_add_func(
+                attention_output,
+                attention_bias.expand_as(residual),
+                residual,
+                self.hidden_dropout)

         # Layer norm post the self attention.
         layernorm_output = self.post_attention_layernorm(layernorm_input)

         # MLP.
-        mlp_output = self.mlp(layernorm_output)
+        mlp_output, mlp_bias = self.mlp(layernorm_output)

         # Second residual connection.
         if self.apply_residual_connection_post_layernorm:
-            output = layernorm_output + mlp_output
+            residual = layernorm_output
         else:
-            output = layernorm_input + mlp_output
+            residual = layernorm_input
+
+        # Re-enable torch grad to enable fused optimization.
+        with torch.enable_grad():
+            output = bias_dropout_add_func(
+                mlp_output,
+                mlp_bias.expand_as(residual),
+                residual,
+                self.hidden_dropout)

         if get_key_value:
             output = [output, presents]
@@ -350,7 +456,7 @@ class ParallelTransformerLayer(MegatronModule):
 class ParallelTransformer(MegatronModule):
     """Transformer class."""

-    def __init__(self, attention_mask_func, mlp_activation_func,
+    def __init__(self, attention_mask_func,
                  init_method, output_layer_init_method):
         super(ParallelTransformer, self).__init__()
         args = get_args()
@@ -371,8 +477,8 @@ class ParallelTransformer(MegatronModule):
         # Transformer layers.
         def build_layer(layer_number):
             return ParallelTransformerLayer(
-                attention_mask_func, mlp_activation_func,
-                init_method, output_layer_init_method, layer_number)
+                attention_mask_func, init_method,
+                output_layer_init_method, layer_number)

         self.layers = torch.nn.ModuleList(
             [build_layer(i + 1) for i in range(self.num_unique_layers)])
@@ -435,6 +541,9 @@ class ParallelTransformer(MegatronModule):
                 'get_key_value does not work with ' \
                 'activation checkpointing'

+        # Data format change to avoid explicit transposes: [b s h] --> [s b h].
+        hidden_states = hidden_states.transpose(0, 1).contiguous()
+
         if self.checkpoint_activations:
             hidden_states = self._checkpointed_forward(hidden_states,
                                                        attention_mask)
@@ -453,6 +562,9 @@
                 if get_key_value:
                     hidden_states, present = hidden_states
                     presents.append(present)

+        # Reverting the data format change: [s b h] --> [b s h].
+        hidden_states = hidden_states.transpose(0, 1).contiguous()
+
         # Final layer norm.
         output = self.final_layernorm(hidden_states)
...
@@ -54,7 +54,7 @@ def _initialize_affine_weight_gpu(weight, init_method,
     weight.model_parallel = True
     weight.partition_dim = partition_dim
     weight.partition_stride = stride

     with get_cuda_rng_tracker().fork():
         init_method(weight)
@@ -186,11 +186,15 @@ class ColumnParallelLinear(torch.nn.Module):
         keep_master_weight_for_test: This was added for testing and should be
                                      set to False. It returns the master weights
                                      used for initialization.
+        skip_bias_add: This was added to enable performance optimizations where
+                       the bias can be fused with other elementwise operations.
+                       We skip adding the bias but instead return it.
     """

     def __init__(self, input_size, output_size, bias=True, gather_output=True,
                  init_method=init.xavier_normal_, stride=1,
-                 keep_master_weight_for_test=False):
+                 keep_master_weight_for_test=False,
+                 skip_bias_add=False):
         super(ColumnParallelLinear, self).__init__()

         # Keep input parameters
@@ -200,6 +204,7 @@ class ColumnParallelLinear(torch.nn.Module):
         # Divide the weight matrix along the last dimension.
         world_size = get_model_parallel_world_size()
         self.output_size_per_partition = divide(output_size, world_size)
+        self.skip_bias_add = skip_bias_add

         # Parameters.
         # Note: torch.nn.functional.linear performs XA^T + b and as a result
@@ -245,13 +250,16 @@ class ColumnParallelLinear(torch.nn.Module):
         # Set up backprop all-reduce.
         input_parallel = copy_to_model_parallel_region(input_)
         # Matrix multiply.
-        output_parallel = F.linear(input_parallel, self.weight, self.bias)
+        bias = self.bias if not self.skip_bias_add else None
+        output_parallel = F.linear(input_parallel, self.weight, bias)
         if self.gather_output:
             # All-gather across the partitions.
             output = gather_from_model_parallel_region(output_parallel)
         else:
             output = output_parallel
-        return output
+        output_bias = self.bias if self.skip_bias_add else None
+        return output, output_bias


 class RowParallelLinear(torch.nn.Module):
@@ -279,12 +287,16 @@ class RowParallelLinear(torch.nn.Module):
         keep_master_weight_for_test: This was added for testing and should be
                                      set to False. It returns the master weights
                                      used for initialization.
+        skip_bias_add: This was added to enable performance optimizations where
+                       the bias can be fused with other elementwise operations.
+                       We skip adding the bias but instead return it.
     """

     def __init__(self, input_size, output_size, bias=True,
                  input_is_parallel=False,
                  init_method=init.xavier_normal_, stride=1,
-                 keep_master_weight_for_test=False):
+                 keep_master_weight_for_test=False,
+                 skip_bias_add=False):
         super(RowParallelLinear, self).__init__()

         # Keep input parameters
@@ -294,6 +306,7 @@ class RowParallelLinear(torch.nn.Module):
         # Divide the weight matrix along the last dimension.
         world_size = get_model_parallel_world_size()
         self.input_size_per_partition = divide(input_size, world_size)
+        self.skip_bias_add = skip_bias_add

         # Parameters.
         # Note: torch.nn.functional.linear performs XA^T + b and as a result
@@ -340,8 +353,11 @@ class RowParallelLinear(torch.nn.Module):
         output_parallel = F.linear(input_parallel, self.weight)
         # All-reduce across all the partitions.
         output_ = reduce_from_model_parallel_region(output_parallel)
-        if self.bias is not None:
-            output = output_ + self.bias
+        if not self.skip_bias_add:
+            output = output_ + self.bias if self.bias is not None else output_
+            output_bias = None
         else:
             output = output_
-        return output
+            output_bias = self.bias
+        return output, output_bias
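To make the skip_bias_add contract concrete, here is a small hedged sketch of the caller-side pattern this diff introduces (a plain nn-style stand-in, not the model-parallel classes): the layer returns (output, bias), and the caller folds the deferred bias add into a later elementwise op such as bias_gelu_impl or bias_dropout_add.

import torch

class LinearSkipBiasAdd(torch.nn.Module):
    # Toy stand-in for ColumnParallelLinear(..., skip_bias_add=True):
    # the bias is returned instead of being added inside F.linear.
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = torch.nn.Parameter(0.02 * torch.randn(out_features, in_features))
        self.bias = torch.nn.Parameter(torch.zeros(out_features))

    def forward(self, x):
        return torch.nn.functional.linear(x, self.weight), self.bias

layer = LinearSkipBiasAdd(8, 32)
x = torch.randn(4, 8)

output, bias = layer(x)
# The caller fuses the deferred bias add with the next elementwise op,
# e.g. a gelu (cf. bias_gelu_impl) or a dropout + residual add.
y = torch.nn.functional.gelu(output + bias)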
@@ -236,29 +236,35 @@ def backward_step(optimizer, model, loss):
     timers = get_timers()

     # Backward pass.
+    timers('backward-backward').start()
     optimizer.zero_grad(set_grads_to_None=True)
     if args.fp16:
         optimizer.backward(loss, update_master_grads=False)
     else:
         loss.backward()
+    timers('backward-backward').stop()

     # All-reduce if needed.
     if args.DDP_impl == 'local':
-        timers('allreduce').start()
+        timers('backward-allreduce').start()
         model.allreduce_params(reduce_after=False,
                                fp32_allreduce=args.fp32_allreduce)
-        timers('allreduce').stop()
+        timers('backward-allreduce').stop()

     # Update master gradients.
+    timers('backward-master-grad').start()
     if args.fp16:
         optimizer.update_master_grads()
+    timers('backward-master-grad').stop()

     # Clipping gradients helps prevent the exploding gradient.
+    timers('backward-clip-grad').start()
     if args.clip_grad > 0:
         if not args.fp16:
             mpu.clip_grad_norm(model.parameters(), args.clip_grad)
         else:
             optimizer.clip_master_grads(args.clip_grad)
+    timers('backward-clip-grad').stop()


 def train_step(forward_step_func, data_iterator,
@@ -311,7 +317,10 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
             timers_to_log.append(name)

     add_to_logging('forward')
     add_to_logging('backward')
-    add_to_logging('allreduce')
+    add_to_logging('backward-backward')
+    add_to_logging('backward-allreduce')
+    add_to_logging('backward-master-grad')
+    add_to_logging('backward-clip-grad')
     add_to_logging('optimizer')
     add_to_logging('batch generator')
...
@@ -46,9 +46,9 @@ def report_memory(name):
         torch.cuda.memory_allocated() / mega_bytes)
     string += ' | max allocated: {}'.format(
         torch.cuda.max_memory_allocated() / mega_bytes)
-    string += ' | cached: {}'.format(torch.cuda.memory_cached() / mega_bytes)
-    string += ' | max cached: {}'.format(
-        torch.cuda.max_memory_cached() / mega_bytes)
+    string += ' | reserved: {}'.format(torch.cuda.memory_reserved() / mega_bytes)
+    string += ' | max reserved: {}'.format(
+        torch.cuda.max_memory_reserved() / mega_bytes)
     print_rank_0(string)
...
@@ -28,7 +28,6 @@ from megatron.training import pretrain
 from megatron.utils import get_ltor_masks_and_position_ids
 from megatron.utils import reduce_losses


 def model_provider():
     """Build the model."""
@@ -80,10 +79,8 @@ def forward_step(data_iterator, model):
     tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
         data_iterator)
     timers('batch generator').stop()

     # Forward model.
     losses = model(tokens, position_ids, attention_mask, labels=labels)
     loss_mask = loss_mask.view(-1)
     loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
...