Commit 51a2e6b0 authored by Vijay Korthikanti, committed by Jared Casper

Various speed optimizations.

parent 12518332
......@@ -19,7 +19,7 @@ import argparse
import os
import torch
from megatron import fused_kernels
def parse_args(extra_args_provider=None, defaults={},
ignore_unknown_args=False):
......@@ -118,6 +118,10 @@ def parse_args(extra_args_provider=None, defaults={},
'for distribute-checkpointed-activations to work you '\
'need to enable checkpoint-activations'
# load scaled_upper_triang_masked_softmax_fusion kernel
if args.scaled_upper_triang_masked_softmax_fusion:
fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
_print_args(args)
return args
......@@ -221,6 +225,14 @@ def _add_training_args(parser):
'by this value.')
group.add_argument('--tensorboard-dir', type=str, default=None,
help='Write TensorBoard logs to this directory.')
group.add_argument('--scaled-upper-triang-masked-softmax-fusion',
action='store_true',
help='Enable fusion of query_key_value scaling, '
'time (upper diagonal) masking, and softmax.')
group.add_argument('--bias-gelu-fusion', action='store_true',
help='Enable bias and gelu fusion.')
group.add_argument('--bias-dropout-fusion', action='store_true',
help='Enable bias and dropout fusion.')
return parser
......
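The three flags above are opt-in, so nothing changes unless they are passed on the command line. A standalone sketch (hedged: this only mirrors the argparse behaviour, it is not Megatron's full parser, which has many other required arguments) of how they parse:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--scaled-upper-triang-masked-softmax-fusion',
                    action='store_true')
parser.add_argument('--bias-gelu-fusion', action='store_true')
parser.add_argument('--bias-dropout-fusion', action='store_true')

args = parser.parse_args(['--bias-gelu-fusion', '--bias-dropout-fusion'])
# args.bias_gelu_fusion -> True, args.bias_dropout_fusion -> True,
# args.scaled_upper_triang_masked_softmax_fusion -> False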
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pathlib
import subprocess
from torch.utils import cpp_extension
def load_scaled_upper_triang_masked_softmax_fusion_kernel():
def get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
universal_newlines=True)
output = raw_output.split()
release_idx = output.index("release") + 1
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
return raw_output, bare_metal_major, bare_metal_minor
# Check if CUDA 11 is installed for compute capability 8.0
cc_flag = []
_, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
if int(bare_metal_major) >= 11:
cc_flag.append('-gencode')
cc_flag.append('arch=compute_80,code=sm_80')
srcpath = pathlib.Path(__file__).parent.absolute()
scaled_upper_triang_masked_softmax_cuda = cpp_extension.load(
name='scaled_upper_triang_masked_softmax_cuda',
sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp',
srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu'],
extra_cflags=['-O3',],
extra_cuda_cflags=['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + cc_flag,
verbose=True)
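Once built, the extension is importable under the `name` given to `cpp_extension.load`. A hedged smoke test, assuming an nvcc toolchain and a CUDA GPU are available (tensor sizes are illustrative only):

import torch
from megatron import fused_kernels

# JIT-compile the kernel, then import it by the name registered above.
fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
import scaled_upper_triang_masked_softmax_cuda

# Input is a half-precision [attn_batches, seq_len, seq_len] tensor; the kernel
# scales it, applies an upper-triangular (causal) mask, and takes a row softmax.
x = torch.randn(32, 128, 128, dtype=torch.half, device='cuda')
probs = scaled_upper_triang_masked_softmax_cuda.forward(x, 1.0)
print(probs.shape)  # torch.Size([32, 128, 128])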
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
float scale_factor);
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor);
torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input.scalar_type() == at::ScalarType::Half,
"Only HALF is supported");
return fwd_cuda(input, scale_factor);
}
torch::Tensor bwd(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.scalar_type() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(softmax_results.scalar_type() == at::ScalarType::Half,
"Only HALF is supported");
return bwd_cuda(output_grads, softmax_results, scale_factor);
}
} // end namespace scaled_upper_triang_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward",
&multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
"Self Multihead Attention scaled, time masked softmax -- Forward.");
m.def("backward",
&multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
"Self Multihead Attention scaled, time masked softmax -- Backward.");
}
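For orientation, the fused forward is roughly equivalent to the following plain-PyTorch reference (a hedged sketch that ignores the kernel's fp16-in/fp32-accumulate details): scale the scores, mask out positions above the diagonal, then softmax over the last dimension.

import torch

def scaled_upper_triang_masked_softmax_ref(x, scale):
    # x: [attn_batches, seq_len, seq_len]
    seq_len = x.size(2)
    future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool,
                                   device=x.device), diagonal=1)
    scores = (x * scale).masked_fill(future, float('-inf'))
    return torch.softmax(scores, dim=-1)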
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <assert.h>
#include <cuda_fp16.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <cuda_fp16.h>
#include <c10/macros/Macros.h>
namespace {
int log2_ceil(int value) {
int log2_value = 0;
while ((1 << log2_value) < value) ++log2_value;
return log2_value;
}
template<typename T>
struct Add {
__device__ __forceinline__ T operator()(T a, T b) const {
return a + b;
}
};
template<typename T>
struct Max {
__device__ __forceinline__ T operator()(T a, T b) const {
return a < b ? b : a;
}
};
template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if CUDA_VERSION >= 9000
return __shfl_xor_sync(mask, value, laneMask, width);
#else
return __shfl_xor(value, laneMask, width);
#endif
}
template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t* sum) {
ReduceOp<acc_t> r;
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
sum[i] = r(sum[i], b);
}
}
}
/*
* Extended softmax (from native aten pytorch) with the following additional features:
* 1) input scaling
* 2) implicit time (upper diagonal) masking
*/
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_forward(
output_t *dst,
const input_t *src,
const acc_t scale,
int batch_size,
int stride,
int element_count)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_forward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
int local_seq = blockIdx.x + 1;
int warp_iteration_limit = (local_seq + WARP_SIZE - 1)/WARP_SIZE;
// batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to be computed within this WARP.
int local_batches = batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
src += first_batch * stride + local_idx;
dst += first_batch * stride + local_idx;
// load data from global memory
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : local_seq;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
elements[i][it] = (acc_t)src[i*element_count*stride+it*WARP_SIZE] * scale;
} else {
elements[i][it] = -std::numeric_limits<acc_t>::infinity();
}
}
}
// compute max_value
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = elements[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
acc_t sum[WARP_BATCH] { 0.0f };
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
if (it < warp_iteration_limit) {
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
sum[i] += elements[i][it];
}
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < local_seq) {
dst[i*element_count*stride+it*WARP_SIZE] = (output_t)(elements[i][it] / sum[i]);
} else if (element_index < element_count) {
dst[i*element_count*stride+it*WARP_SIZE] = 0;
} else {
break;
}
}
}
}
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_backward(
output_t *gradInput,
input_t *grad,
const input_t *output,
acc_t scale,
int batch_size,
int stride,
int element_count)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_backward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
int local_seq = blockIdx.x + 1;
// batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to be computed within this WARP.
int local_batches = batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
// the first element to process by the current thread
int thread_offset = first_batch * stride + local_idx;
grad += thread_offset;
output += thread_offset;
gradInput += thread_offset;
// load data from global memory
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : local_seq;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
output_reg[i][it] = output[i*element_count*stride+it*WARP_SIZE];
} else {
output_reg[i][it] = acc_t(0);
}
}
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
grad_reg[i][it] = (acc_t)grad[i*element_count*stride+it*WARP_SIZE] * output_reg[i][it];
} else {
grad_reg[i][it] = acc_t(0);
}
}
}
acc_t sum[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] = grad_reg[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
sum[i] += grad_reg[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// compute gradients
gradInput[i*element_count*stride+it*WARP_SIZE] = (output_t)(scale * (grad_reg[i][it] - output_reg[i][it] * sum[i]));
}
}
}
}
} // end of anonymous namespace
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_forward(
output_t *dst,
const input_t *src,
const input_t scale,
int softmax_elements,
int softmax_elements_stride,
int attn_batches)
{
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 );
if (softmax_elements == 0) {
return;
} else {
int log2_elements = log2_ceil(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
int seq_len = softmax_elements;
int batch_count = attn_batches * seq_len;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
// use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
int blocks_per_seq = attn_batches / batches_per_block;
dim3 blocks(seq_len, blocks_per_seq, 1);
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 1: // 2
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 2: // 4
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 3: // 8
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 4: // 16
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 5: // 32
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 6: // 64
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 7: // 128
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 8: // 256
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 9: // 512
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 10: // 1024
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 11: // 2048
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
default:
break;
}
}
}
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_backward(
output_t *grad_input,
input_t *grad,
const input_t *output,
const acc_t scale,
int softmax_elements,
int softmax_elements_stride,
int attn_batches)
{
TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 );
if (softmax_elements == 0) {
return;
} else {
int log2_elements = log2_ceil(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
int seq_len = softmax_elements;
int batch_count = attn_batches * seq_len;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
// use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
int blocks_per_seq = attn_batches / batches_per_block;
dim3 blocks(seq_len, blocks_per_seq, 1);
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 1: // 2
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 2: // 4
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 3: // 8
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 4: // 16
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 5: // 32
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 6: // 64
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 7: // 128
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 8: // 256
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 9: // 512
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 10: // 1024
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 11: // 2048
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
default:
break;
}
}
}
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_upper_triang_masked_softmax.h"
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
float scale_factor)
{
// input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
const int attn_batches = input.size(0);
const int seq_len = input.size(1);
TORCH_INTERNAL_ASSERT(seq_len <= 2048);
// Output
auto act_options = input.options().requires_grad(false);
torch::Tensor softmax_results =
torch::empty({attn_batches, seq_len, seq_len}, act_options);
// Softmax Intermediate Result Ptr
void* input_ptr = static_cast<void*>(input.data_ptr());
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
dispatch_scaled_upper_triang_masked_softmax_forward<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(input_ptr),
scale_factor,
seq_len,
seq_len,
attn_batches);
return softmax_results;
}
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads_,
torch::Tensor const& softmax_results_,
float scale_factor) {
auto output_grads = output_grads_.contiguous();
auto softmax_results = softmax_results_.contiguous();
//output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
const int attn_batches = output_grads.size(0);
const int seq_len = output_grads.size(1);
TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2));
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
//Softmax Grad
dispatch_scaled_upper_triang_masked_softmax_backward<half, half, float>(
reinterpret_cast<half*>(output_grads_ptr),
reinterpret_cast<half*>(output_grads_ptr),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
scale_factor,
seq_len,
seq_len,
attn_batches);
//backward pass is completely in-place
return output_grads;
}
}
}
}
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
###### BIAS GELU FUSION/ NO AUTOGRAD ################
# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2) -> 0.70710678
# sqrt(2/pi) -> 0.79788456
# this function is the tanh approximation of gelu
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
@torch.jit.script
def bias_gelu(bias, y):
x = bias + y
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def bias_gelu_back(g, bias, y):
x = bias + y
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
return ff*g
class GeLUFunction(torch.autograd.Function):
@staticmethod
# bias is an optional argument
def forward(ctx, input, bias):
ctx.save_for_backward(input, bias)
return bias_gelu(bias, input)
@staticmethod
def backward(ctx, grad_output):
input, bias = ctx.saved_tensors
tmp = bias_gelu_back(grad_output, bias, input)
return tmp, tmp
bias_gelu_impl = GeLUFunction.apply
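A quick numerical check of the fused path (hedged sketch; note that `bias_gelu` uses the tanh approximation, so it only approximately matches PyTorch's exact erf-based `F.gelu`):

import torch
import torch.nn.functional as F
from megatron.model.fused_bias_gelu import bias_gelu_impl

x = torch.randn(4, 16)
bias = torch.randn(16)

fused = bias_gelu_impl(x, bias)
reference = F.gelu(x + bias)            # exact gelu, for comparison
print((fused - reference).abs().max())  # small; the two differ by the tanh approximation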
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function) :
"""
Fused operation which performs the following three operations in sequence:
1. Scale the tensor.
2. Apply upper triangular mask (typically used in gpt models).
3. Perform softmax.
"""
@staticmethod
def forward(ctx, inputs, scale):
import scaled_upper_triang_masked_softmax_cuda
scale_t = torch.tensor([scale])
softmax_results = \
scaled_upper_triang_masked_softmax_cuda.forward(inputs, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads):
import scaled_upper_triang_masked_softmax_cuda
softmax_results, scale_t = ctx.saved_tensors
input_grads = \
scaled_upper_triang_masked_softmax_cuda.backward(output_grads,
softmax_results,
scale_t[0])
return input_grads, None
class FusedScaleMaskSoftmax(torch.nn.Module):
"""
fused operation: scaling + mask + softmax
Arguments:
input_in_fp16: flag to indicate if input is in fp16 data format.
upper_triang_mask: if true, apply upper triangular masking.
(used in gpt family networks)
mask_func: mask function to be applied.
softmax_in_fp32: if true, softmax is performed in fp32 precision.
scale: scaling factor used in input tensor scaling.
"""
def __init__(self, input_in_fp16, upper_triang_mask,
mask_func, softmax_in_fp32, scale):
super(FusedScaleMaskSoftmax, self).__init__()
self.input_in_fp16 = input_in_fp16
self.upper_triang_mask = upper_triang_mask
self.mask_func = mask_func
self.softmax_in_fp32 = softmax_in_fp32
self.scale = scale
assert self.scale is None or softmax_in_fp32, \
'softmax should be in fp32 when scaled'
def forward(self, input, mask):
# [b, np, s, s]
data_size = input.size()
assert input.dim() == 4
# invoke custom kernel for implicit upper triangular masking
if self.input_in_fp16 and self.upper_triang_mask and \
data_size[-1] <= 2048 and input.size()[2] == input.size()[3]:
input = input.view(-1, data_size[2], data_size[3])
scale = self.scale if self.scale is not None else 1.0
probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
probs = probs.view(*data_size)
else:
if self.input_in_fp16 and self.softmax_in_fp32:
input = input.float()
mask_output = self.mask_func(input, mask)
if self.scale is not None:
mask_output = mask_output * self.scale
probs = torch.nn.Softmax(dim=-1)(mask_output)
if self.input_in_fp16 and self.softmax_in_fp32:
probs = probs.half()
return probs
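As wired up in ParallelSelfAttention later in this commit, the module routes square fp16 inputs with causal masking to the custom kernel and everything else through the plain masked-softmax path. A hedged construction sketch (the kernel must have been loaded via megatron.fused_kernels first, and a GPU is assumed):

import torch
from megatron.model.fused_softmax import FusedScaleMaskSoftmax

def attention_mask_func(scores, mask):
    # placeholder mask function for the non-fused fallback path
    return scores.masked_fill(mask, -10000.0)

softmax = FusedScaleMaskSoftmax(
    True,                  # input_in_fp16
    True,                  # upper_triang_mask (GPT-style causal masking)
    attention_mask_func,
    True,                  # softmax_in_fp32
    None)                  # scale

scores = torch.randn(2, 16, 128, 128, dtype=torch.half, device='cuda')  # [b, np, s, s]
probs = softmax(scores, None)  # mask is unused on the fused-kernel path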
......@@ -22,7 +22,6 @@ from megatron import get_args
from megatron import mpu
from megatron.module import MegatronModule
from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import openai_gelu, erf_gelu
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal, scaled_init_method_normal
......@@ -48,13 +47,6 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
"""Build language model and return along with the key to save."""
args = get_args()
# Use torch gelu unless otherwise forced.
gelu = F.gelu
if args.openai_gelu:
gelu = openai_gelu
elif args.onnx_safe:
gelu = erf_gelu
if init_method is None:
init_method = init_method_normal(args.init_method_std)
......@@ -64,7 +56,6 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
# Language model.
language_model = TransformerLanguageModel(
attention_mask_func=attention_mask_func,
mlp_activation_func=gelu,
init_method=init_method,
output_layer_init_method=scaled_init_method,
num_tokentypes=num_tokentypes,
......@@ -271,7 +262,6 @@ class TransformerLanguageModel(MegatronModule):
def __init__(self,
attention_mask_func,
mlp_activation_func,
init_method,
output_layer_init_method,
num_tokentypes=0,
......@@ -295,8 +285,8 @@ class TransformerLanguageModel(MegatronModule):
# Transformer
self.transformer = ParallelTransformer(
attention_mask_func, mlp_activation_func,
self.init_method, output_layer_init_method)
attention_mask_func, self.init_method,
output_layer_init_method)
self._transformer_key = 'transformer'
# Pooler
......
......@@ -17,12 +17,21 @@
import math
import torch
import torch.nn.functional as F
from megatron import get_args
from megatron import mpu
from megatron.mpu import LayerNorm
from megatron.module import MegatronModule
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import openai_gelu, erf_gelu
# flags required to enable jit fusion kernels
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
""" We use the following notation throughout this file:
h: hidden size
......@@ -34,7 +43,7 @@ from megatron.module import MegatronModule
b: batch size
s: sequence length
l: number of layers
Transformer takes input of size [b, s, h] and returns a
Transformer takes input of size [s, b, h] and returns a
tensor of the same size. We use the following arguments:
hyperparameters: transformer hyperparameters
attention_mask_func: a function that takes `unmasked-attention-scores`
......@@ -45,7 +54,6 @@ from megatron.module import MegatronModule
unmasked-attention-scores, attention-mask)
"""
class ParallelMLP(MegatronModule):
"""MLP.
......@@ -55,8 +63,7 @@ class ParallelMLP(MegatronModule):
applied.
"""
def __init__(self, mlp_activation_func, init_method,
output_layer_init_method):
def __init__(self, init_method, output_layer_init_method):
super(ParallelMLP, self).__init__()
args = get_args()
......@@ -65,29 +72,40 @@ class ParallelMLP(MegatronModule):
args.hidden_size,
4 * args.hidden_size,
gather_output=False,
init_method=init_method)
init_method=init_method,
skip_bias_add=True)
self.activation_func = mlp_activation_func
self.bias_gelu_fusion = args.bias_gelu_fusion
self.activation_func = F.gelu
if args.openai_gelu:
self.activation_func = openai_gelu
elif args.onnx_safe:
self.activation_func = erf_gelu
# Project back to h.
self.dense_4h_to_h = mpu.RowParallelLinear(
4 * args.hidden_size,
args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method)
self.dropout = torch.nn.Dropout(args.hidden_dropout)
init_method=output_layer_init_method,
skip_bias_add=True)
def forward(self, hidden_states):
# [b, s, 4hp]
intermediate_parallel = self.dense_h_to_4h(hidden_states)
intermediate_parallel = self.activation_func(intermediate_parallel)
# [s, b, 4hp]
intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)
# [b, s, h]
output = self.dense_4h_to_h(intermediate_parallel)
output = self.dropout(output)
return output
if self.bias_gelu_fusion:
intermediate_parallel = \
bias_gelu_impl(intermediate_parallel, bias_parallel)
else:
intermediate_parallel = \
self.activation_func(intermediate_parallel + bias_parallel)
# [s, b, h]
output, output_bias = self.dense_4h_to_h(intermediate_parallel)
return output, output_bias
class ParallelSelfAttention(MegatronModule):
......@@ -123,10 +141,22 @@ class ParallelSelfAttention(MegatronModule):
self.query_key_value = mpu.ColumnParallelLinear(
args.hidden_size,
3 * args.hidden_size,
stride=3,
gather_output=False,
init_method=init_method)
coeff = None
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
if self.apply_query_key_layer_scaling:
coeff = self.layer_number
self.norm_factor *= coeff
self.scale_mask_softmax = FusedScaleMaskSoftmax(
self.fp16,
args.scaled_upper_triang_masked_softmax_fusion,
self.attention_mask_func,
self.attention_softmax_in_fp32,
coeff)
# Dropout. Note that for a single iteration, this layer will generate
# different outputs on different number of parallel partitions but
# on average it should not be partition dependent.
......@@ -137,110 +167,85 @@ class ParallelSelfAttention(MegatronModule):
args.hidden_size,
args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method)
self.output_dropout = torch.nn.Dropout(args.hidden_dropout)
def _transpose_for_scores(self, tensor):
"""Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
size [b, np, s, hn].
"""
new_tensor_shape = tensor.size()[:-1] + \
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
tensor = tensor.view(*new_tensor_shape)
return tensor.permute(0, 2, 1, 3)
def _get_query_key_value(self, hidden_states):
"""Get query, key, and value and transpose to
get size [b, np, s, hn].
"""
# Attention heads. [b, s, hp]
mixed_x_layer = self.query_key_value(hidden_states)
(mixed_query_layer,
mixed_key_layer,
mixed_value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
# Reshape and transpose [b, np, s, hn]
query_layer = self._transpose_for_scores(mixed_query_layer)
key_layer = self._transpose_for_scores(mixed_key_layer)
value_layer = self._transpose_for_scores(mixed_value_layer)
return query_layer, key_layer, value_layer
def _get_unmasked_attention_scores(self, query_layer, key_layer):
"""Unmasked attention scores with size [b, np, s, s]."""
coeff = 1
if self.apply_query_key_layer_scaling:
coeff = self.layer_number
norm_factor = math.sqrt(coeff *
math.sqrt(self.hidden_size_per_attention_head))
# Raw attention scores. [b, np, s, s]
return torch.matmul(query_layer / norm_factor,
key_layer.transpose(-1, -2) / norm_factor)
def _get_attention_probs(self, attention_scores):
"""Attention probabilies with dropout. The output has
the size [b, np, s, s].
"""
# Attention probabilities. [b, np, s, s]
if self.apply_query_key_layer_scaling:
attention_scores = attention_scores * self.layer_number
attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
with mpu.get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs)
init_method=output_layer_init_method,
skip_bias_add=True)
return attention_probs
def _get_attended_context(self, attention_probs, value_layer):
"""Final attended tesnor and transposed back to [b, s, hp]."""
# Context layer.
# [b, np, s, hn]
context_layer = torch.matmul(attention_probs, value_layer)
# [b, s, np, hn]
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + \
(self.hidden_size_per_partition,)
# [b, s, hp]
context_layer = context_layer.view(*new_context_layer_shape)
def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False):
# hidden_states: [s, b, h]
return context_layer
# =====================
# Query, Key, and Value
# =====================
def _get_output(self, context_layer):
"""Output layer with dropout."""
# Output. [b, s, h]
output = self.dense(context_layer)
output = self.output_dropout(output)
# Attention heads [s, b, hp] --> [s, b, 3 * hp]
mixed_x_layer, _ = self.query_key_value(hidden_states)
return output
# [s, b, 3 * hp] --> [s, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [s, b, np, 3 * hn] --> 3 [s, b, np, hn]
(query_layer,
key_layer,
value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False):
# hidden_states: [b, s, h]
# Attention heads. [b, np, s, hn]
query_layer, key_layer, value_layer = self._get_query_key_value(
hidden_states)
# ==================================
# Adjust key and value for inference
# ==================================
if layer_past is not None:
past_key, past_value = layer_past
key_layer = torch.cat((past_key.type_as(key_layer),
key_layer), dim=-2)
key_layer), dim=0)
value_layer = torch.cat((past_value.type_as(value_layer),
value_layer), dim=-2)
value_layer), dim=0)
if get_key_value:
present = (key_layer, value_layer)
# Raw attention scores. [b, np, s, s]
attention_scores = self._get_unmasked_attention_scores(
query_layer, key_layer)
# fp32 conversion.
if self.fp16 and self.attention_softmax_in_fp32:
attention_scores = attention_scores.float()
# ===================================
# Raw attention scores. [b, np, s, s]
# ===================================
# [b, np, s, s]
output_size = (query_layer.size(1),
query_layer.size(2),
query_layer.size(0),
key_layer.size(0))
# [s, b, np, hn] -> [s, b * np, hn]
query_layer = query_layer.view(output_size[2],
output_size[0] * output_size[1], -1)
key_layer = key_layer.view(output_size[3],
output_size[0] * output_size[1], -1)
# preallocating result tensor: [b * np, s, s]
matmul_result = torch.empty(
output_size[0]*output_size[1],
output_size[2],
output_size[3],
dtype=query_layer.dtype,
device=torch.cuda.current_device())
# Raw attention scores. [b * np, s, s]
matmul_result = torch.baddbmm(matmul_result,
query_layer.transpose(0, 1), # [b * np, s, hn]
key_layer.transpose(0,1).transpose(1, 2), #[b * np, hn, s]
beta=0.0, alpha=(1.0/self.norm_factor))
# change view to [b, np, s, s]
attention_scores = matmul_result.view(*output_size)
# ==================================================
# Update attention mask for inference. [b, np, s, s]
# ==================================================
# Apply attention mask. [b, np, s, s]
if get_key_value:
with torch.no_grad():
if layer_past is not None:
......@@ -253,26 +258,93 @@ class ParallelSelfAttention(MegatronModule):
...,
:attention_scores.size(3),
:attention_scores.size(3)]
attention_scores = self.attention_mask_func(attention_scores,
attention_mask)
# Attention probabilities. [b, np, s, s]
attention_probs = self._get_attention_probs(attention_scores)
# fp16 conversion
if self.fp16 and self.attention_softmax_in_fp32:
attention_probs = attention_probs.half()
# ===========================
# Attention probs and dropout
# ===========================
# Context layer. [b, s, hp]
context_layer = self._get_attended_context(attention_probs, value_layer)
# attention scores and attention mask [b, np, s, s]
attention_probs = self.scale_mask_softmax(attention_scores,
attention_mask)
# Output. [b, s, h]
output = self._get_output(context_layer)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
with mpu.get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs)
# =========================
# Context layer. [s, b, hp]
# =========================
# value_layer -> context layer.
# [s, b, np, hn] --> [b, np, s, hn]
# context layer shape: [b, np, s, hn]
output_size = (value_layer.size(1),
value_layer.size(2),
value_layer.size(0),
value_layer.size(3))
# change view [s, b * np, hn]
value_layer = value_layer.view(output_size[2],
output_size[0] * output_size[1], -1)
# change view [b * np, s, s]
attention_probs = attention_probs.view(output_size[0] * output_size[1],
output_size[2], -1)
# matmul: [b * np, s, hn]
context_layer = torch.bmm(attention_probs, value_layer.transpose(0,1))
# change view [b, np, s, hn]
context_layer = context_layer.view(*output_size)
# [b, np, s, hn] --> [s, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [s, b, np, hn] --> [s, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + \
(self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
# =================
# Output. [s, b, h]
# =================
output, bias = self.dense(context_layer)
if get_key_value:
output = [output, present]
return output
return output, bias
def bias_dropout_add(x, bias, residual, prob, training) :
# type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
out = residual + out
return out
def get_bias_dropout_add(training):
def _bias_dropout_add(x, bias, residual, prob):
return bias_dropout_add(x, bias, residual, prob, training)
return _bias_dropout_add
@torch.jit.script
def bias_dropout_add_fused_train(x, bias, residual, prob) :
# type: (Tensor, Tensor, Tensor, float) -> Tensor
return bias_dropout_add(x, bias, residual, prob, True)
@torch.jit.script
def bias_dropout_add_fused_inference(x, bias, residual, prob) :
# type: (Tensor, Tensor, Tensor, float) -> Tensor
return bias_dropout_add(x, bias, residual, prob, False)
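# Hedged aside (not part of the diff): with dropout probability 0.0 the fused and
# unfused paths coincide, since both compute residual + dropout(x + bias), e.g.
#   x, bias, residual = torch.randn(4, 8), torch.randn(8), torch.randn(4, 8)
#   fused = bias_dropout_add_fused_inference(x, bias.expand_as(residual), residual, 0.0)
#   unfused = get_bias_dropout_add(False)(x, bias.expand_as(residual), residual, 0.0)
#   assert torch.allclose(fused, unfused)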
class ParallelTransformerLayer(MegatronModule):
......@@ -282,8 +354,8 @@ class ParallelTransformerLayer(MegatronModule):
output of the same size.
"""
def __init__(self, attention_mask_func, mlp_activation_func,
init_method, output_layer_init_method, layer_number):
def __init__(self, attention_mask_func, init_method,
output_layer_init_method, layer_number):
args = get_args()
super(ParallelTransformerLayer, self).__init__()
......@@ -301,6 +373,8 @@ class ParallelTransformerLayer(MegatronModule):
self.attention = ParallelSelfAttention(attention_mask_func, init_method,
output_layer_init_method,
layer_number)
self.hidden_dropout = args.hidden_dropout
self.bias_dropout_fusion = args.bias_dropout_fusion
# Layernorm on the input data.
self.post_attention_layernorm = LayerNorm(
......@@ -308,7 +382,7 @@ class ParallelTransformerLayer(MegatronModule):
eps=args.layernorm_epsilon)
# MLP
self.mlp = ParallelMLP(mlp_activation_func, init_method,
self.mlp = ParallelMLP(init_method,
output_layer_init_method)
def forward(self, hidden_states, attention_mask, layer_past=None,
......@@ -318,28 +392,60 @@ class ParallelTransformerLayer(MegatronModule):
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output = self.attention(layernorm_output,
attention_mask,
layer_past=layer_past,
get_key_value=get_key_value)
attention_output, attention_bias = \
self.attention(layernorm_output,
attention_mask,
layer_past=layer_past,
get_key_value=get_key_value)
if get_key_value:
attention_output, presents = attention_output
# Residual connection.
if self.apply_residual_connection_post_layernorm:
layernorm_input = layernorm_output + attention_output
residual = layernorm_output
else:
residual = hidden_states
# jit scripting for an nn.Module (with dropout) is not
# triggering the fusion kernel. For now, we use two
# different nn.functional routines to account for varying
# dropout semantics during training and inference phases.
if self.bias_dropout_fusion:
if self.training:
bias_dropout_add_func = bias_dropout_add_fused_train
else:
bias_dropout_add_func = bias_dropout_add_fused_inference
else:
layernorm_input = hidden_states + attention_output
bias_dropout_add_func = get_bias_dropout_add(self.training)
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
# MLP.
mlp_output = self.mlp(layernorm_output)
mlp_output, mlp_bias = self.mlp(layernorm_output)
# Second residual connection.
if self.apply_residual_connection_post_layernorm:
output = layernorm_output + mlp_output
residual = layernorm_output
else:
output = layernorm_input + mlp_output
residual = layernorm_input
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
output = bias_dropout_add_func(
mlp_output,
mlp_bias.expand_as(residual),
residual,
self.hidden_dropout)
if get_key_value:
output = [output, presents]
......@@ -350,7 +456,7 @@ class ParallelTransformerLayer(MegatronModule):
class ParallelTransformer(MegatronModule):
"""Transformer class."""
def __init__(self, attention_mask_func, mlp_activation_func,
def __init__(self, attention_mask_func,
init_method, output_layer_init_method):
super(ParallelTransformer, self).__init__()
args = get_args()
......@@ -371,8 +477,8 @@ class ParallelTransformer(MegatronModule):
# Transformer layers.
def build_layer(layer_number):
return ParallelTransformerLayer(
attention_mask_func, mlp_activation_func,
init_method, output_layer_init_method, layer_number)
attention_mask_func, init_method,
output_layer_init_method, layer_number)
self.layers = torch.nn.ModuleList(
[build_layer(i + 1) for i in range(self.num_unique_layers)])
......@@ -435,6 +541,9 @@ class ParallelTransformer(MegatronModule):
'get_key_value does not work with ' \
'activation checkpointing'
# data format change to avoid explicit transposes: [b s h] --> [s b h]
hidden_states = hidden_states.transpose(0, 1).contiguous()
if self.checkpoint_activations:
hidden_states = self._checkpointed_forward(hidden_states,
attention_mask)
......@@ -453,6 +562,9 @@ class ParallelTransformer(MegatronModule):
if get_key_value:
hidden_states, present = hidden_states
presents.append(present)
# reverting data format change [s b h] --> [b s h]
hidden_states = hidden_states.transpose(0, 1).contiguous()
# Final layer norm.
output = self.final_layernorm(hidden_states)
......
......@@ -54,7 +54,7 @@ def _initialize_affine_weight_gpu(weight, init_method,
weight.model_parallel = True
weight.partition_dim = partition_dim
weight.partition_stride = stride
with get_cuda_rng_tracker().fork():
init_method(weight)
......@@ -186,11 +186,15 @@ class ColumnParallelLinear(torch.nn.Module):
keep_master_weight_for_test: This was added for testing and should be
set to False. It returns the master weights
used for initialization.
skip_bias_add: This was added to enable performance optimizations where bias
can be fused with other elementwise operations. We skip
adding bias but instead return it.
"""
def __init__(self, input_size, output_size, bias=True, gather_output=True,
init_method=init.xavier_normal_, stride=1,
keep_master_weight_for_test=False):
keep_master_weight_for_test=False,
skip_bias_add=False):
super(ColumnParallelLinear, self).__init__()
# Keep input parameters
......@@ -200,6 +204,7 @@ class ColumnParallelLinear(torch.nn.Module):
# Divide the weight matrix along the last dimension.
world_size = get_model_parallel_world_size()
self.output_size_per_partition = divide(output_size, world_size)
self.skip_bias_add = skip_bias_add
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
......@@ -245,13 +250,16 @@ class ColumnParallelLinear(torch.nn.Module):
# Set up backprop all-reduce.
input_parallel = copy_to_model_parallel_region(input_)
# Matrix multiply.
output_parallel = F.linear(input_parallel, self.weight, self.bias)
bias = self.bias if not self.skip_bias_add else None
output_parallel = F.linear(input_parallel, self.weight, bias)
if self.gather_output:
# All-gather across the partitions.
output = gather_from_model_parallel_region(output_parallel)
else:
output = output_parallel
return output
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
class RowParallelLinear(torch.nn.Module):
......@@ -279,12 +287,16 @@ class RowParallelLinear(torch.nn.Module):
keep_master_weight_for_test: This was added for testing and should be
set to False. It returns the master weights
used for initialization.
skip_bias_add: This was added to enable performance optimizations where bias
can be fused with other elementwise operations. We skip
adding bias but instead return it.
"""
def __init__(self, input_size, output_size, bias=True,
input_is_parallel=False,
init_method=init.xavier_normal_, stride=1,
keep_master_weight_for_test=False):
keep_master_weight_for_test=False,
skip_bias_add=False):
super(RowParallelLinear, self).__init__()
# Keep input parameters
......@@ -294,6 +306,7 @@ class RowParallelLinear(torch.nn.Module):
# Divide the weight matrix along the last dimension.
world_size = get_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, world_size)
self.skip_bias_add = skip_bias_add
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
......@@ -340,8 +353,11 @@ class RowParallelLinear(torch.nn.Module):
output_parallel = F.linear(input_parallel, self.weight)
# All-reduce across all the partitions.
output_ = reduce_from_model_parallel_region(output_parallel)
if self.bias is not None:
output = output_ + self.bias
if not self.skip_bias_add:
output = output_ + self.bias if self.bias is not None else output_
output_bias = None
else:
output = output_
return output
output_bias = self.bias
return output, output_bias
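The skip_bias_add convention is what feeds the fusions above: the linear layer returns its bias instead of adding it, and the caller folds that bias into the next elementwise op (gelu or dropout-add). A toy, self-contained sketch of the pattern (hedged; a plain torch.nn.Linear stands in for the parallel layers):

import torch
import torch.nn.functional as F

class LinearSkipBiasAdd(torch.nn.Linear):
    """Toy stand-in for skip_bias_add=True: return (x @ W^T, bias) without adding the bias."""
    def forward(self, x):
        return F.linear(x, self.weight), self.bias

layer = LinearSkipBiasAdd(8, 32)
x = torch.randn(4, 8)
out, bias = layer(x)
# The caller now fuses the bias into the following elementwise op (here a plain gelu),
# instead of paying a separate bias-add kernel inside the linear layer.
y = F.gelu(out + bias)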
......@@ -236,29 +236,35 @@ def backward_step(optimizer, model, loss):
timers = get_timers()
# Backward pass.
timers('backward-backward').start()
optimizer.zero_grad(set_grads_to_None=True)
if args.fp16:
optimizer.backward(loss, update_master_grads=False)
else:
loss.backward()
timers('backward-backward').stop()
# All-reduce if needed.
if args.DDP_impl == 'local':
timers('allreduce').start()
timers('backward-allreduce').start()
model.allreduce_params(reduce_after=False,
fp32_allreduce=args.fp32_allreduce)
timers('allreduce').stop()
timers('backward-allreduce').stop()
# Update master gradients.
timers('backward-master-grad').start()
if args.fp16:
optimizer.update_master_grads()
timers('backward-master-grad').stop()
# Clipping gradients helps prevent the exploding gradient.
timers('backward-clip-grad').start()
if args.clip_grad > 0:
if not args.fp16:
mpu.clip_grad_norm(model.parameters(), args.clip_grad)
else:
optimizer.clip_master_grads(args.clip_grad)
timers('backward-clip-grad').stop()
def train_step(forward_step_func, data_iterator,
......@@ -311,7 +317,10 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
timers_to_log.append(name)
add_to_logging('forward')
add_to_logging('backward')
add_to_logging('allreduce')
add_to_logging('backward-backward')
add_to_logging('backward-allreduce')
add_to_logging('backward-master-grad')
add_to_logging('backward-clip-grad')
add_to_logging('optimizer')
add_to_logging('batch generator')
......
......@@ -46,9 +46,9 @@ def report_memory(name):
torch.cuda.memory_allocated() / mega_bytes)
string += ' | max allocated: {}'.format(
torch.cuda.max_memory_allocated() / mega_bytes)
string += ' | cached: {}'.format(torch.cuda.memory_cached() / mega_bytes)
string += ' | max cached: {}'.format(
torch.cuda.max_memory_cached() / mega_bytes)
string += ' | reserved: {}'.format(torch.cuda.memory_reserved() / mega_bytes)
string += ' | max reserved: {}'.format(
torch.cuda.max_memory_reserved() / mega_bytes)
print_rank_0(string)
......
......@@ -28,7 +28,6 @@ from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import reduce_losses
def model_provider():
"""Build the model."""
......@@ -80,10 +79,8 @@ def forward_step(data_iterator, model):
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
data_iterator)
timers('batch generator').stop()
# Forward model.
losses = model(tokens, position_ids, attention_mask, labels=labels)
loss_mask = loss_mask.view(-1)
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
......