Commit e4575be9 authored by huaerkl's avatar huaerkl
Browse files

v1.0

parents
Pipeline #489 canceled with stage
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""For backward compatibility, we need the class definitions to deserialize."""
class LossScaler:
def __init__(self, scale=1):
self.cur_scale = scale
class DynamicLossScaler:
def __init__(self,
init_scale=2**32,
scale_factor=2.,
scale_window=1000,
min_scale=1,
delayed_shift=1,
consecutive_hysteresis=False):
self.cur_scale = init_scale
self.cur_iter = 0
self.last_overflow_iter = -1
self.scale_factor = scale_factor
self.scale_window = scale_window
self.min_scale = min_scale
self.delayed_shift = delayed_shift
self.cur_hysteresis = delayed_shift
self.consecutive_hysteresis = consecutive_hysteresis
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pathlib
import subprocess
from torch.utils import cpp_extension
def load(args):
# Setting this param to a list has a problem of generating different
# compilation commands (with diferent order of architectures) and
# leading to recompilation of fused kernels. Set it to empty string
# to avoid recompilation and assign arch flags explicity in
# extra_cuda_cflags below
#
# but if a user wants to set an explicit list of archs to compile to, then let that list
# through:
arch_list = os.environ.get('TORCH_CUDA_ARCH_LIST', None)
if arch_list is None:
os.environ["TORCH_CUDA_ARCH_LIST"] = ""
# # Check if cuda 11 is installed for compute capability 8.0
# cc_flag = []
# _, bare_metal_major, _ = _get_cuda_bare_metal_version(
# cpp_extension.CUDA_HOME)
# if int(bare_metal_major) >= 11:
# cc_flag.append('-gencode')
# cc_flag.append('arch=compute_80,code=sm_80')
# Build path
srcpath = pathlib.Path(__file__).parent.absolute()
buildpath = srcpath / 'build'
buildpath.mkdir(parents=True, exist_ok=True)
# Helper function to build the kernels.
def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
return cpp_extension.load(
name=name,
sources=sources,
build_directory=buildpath,
extra_cflags=['-O3',],
extra_cuda_cflags=['-O3'] + extra_cuda_flags,
verbose=(args.rank == 0)
)
# '-gencode', 'arch=compute_70,code=sm_70',
# ==============
# Fused softmax.
# ==============
if args.masked_softmax_fusion:
extra_cuda_flags = ['-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__']
# Upper triangular softmax.
sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp',
srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu']
scaled_upper_triang_masked_softmax_cuda = _cpp_extention_load_helper(
"scaled_upper_triang_masked_softmax_cuda",
sources, extra_cuda_flags)
# Masked softmax.
sources=[srcpath / 'scaled_masked_softmax.cpp',
srcpath / 'scaled_masked_softmax_cuda.cu']
scaled_masked_softmax_cuda = _cpp_extention_load_helper(
"scaled_masked_softmax_cuda", sources, extra_cuda_flags)
# =================================
# Mixed precision fused layer norm.
# =================================
extra_cuda_flags = []
sources=[srcpath / 'layer_norm_cuda.cpp',
srcpath / 'layer_norm_cuda_kernel.cu']
fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper(
"fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags)
def _get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/hipcc", "--version"],
universal_newlines=True)
output = raw_output.split()
release_idx = output.index("version:") + 1
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
return raw_output, bare_metal_major, bare_metal_minor
def _create_build_dir(buildpath):
try:
os.mkdir(buildpath)
except OSError:
if not os.path.isdir(buildpath):
print(f"Creation of the build directory {buildpath} failed")
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*This code is copied fron NVIDIA apex:
* https://github.com/NVIDIA/apex
* with minor changes. */
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*This code is copied fron NVIDIA apex:
* https://github.com/NVIDIA/apex
* with minor changes. */
#include <torch/extension.h>
#include <vector>
#include <cassert>
#include "compat.h"
namespace {
void compute_n1_n2(
at::Tensor input,
at::IntArrayRef normalized_shape,
int& n1,
int& n2) {
int idiff = input.ndimension() - normalized_shape.size();
n2 = 1;
for (int i = 0; i < (int)normalized_shape.size(); ++i) {
assert( input.sizes()[i+idiff] == normalized_shape[i] );
n2 *= normalized_shape[i];
}
n1 = 1;
for (int i = 0; i < idiff; ++i) {
n1 *= input.sizes()[i];
}
}
void check_args(
at::IntArrayRef normalized_shape,
at::Tensor gamma,
at::Tensor beta
)
{
TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
}
void check_args(
at::Tensor input,
at::IntArrayRef normalized_shape,
int& n1,
int& n2
)
{
int64_t normalized_ndim = normalized_shape.size();
if (normalized_ndim < 1) {
std::stringstream ss;
ss << "Expected normalized_shape to be at least 1-dimensional, i.e., "
<< "containing at least one element, but got normalized_shape="
<< normalized_shape;
throw std::runtime_error(ss.str());
}
auto input_shape = input.sizes();
auto input_ndim = input.dim();
if (input_ndim < normalized_ndim ||
!input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) {
std::stringstream ss;
ss << "Given normalized_shape=" << normalized_shape
<< ", expected input with shape [*";
for (auto size : normalized_shape) {
ss << ", " << size;
}
ss << "], but got input of size" << input_shape;
throw std::runtime_error(ss.str());
}
compute_n1_n2(input,normalized_shape,n1,n2);
}
void check_args(
at::Tensor input,
at::IntArrayRef normalized_shape,
at::Tensor gamma,
at::Tensor beta,
int& n1,
int& n2
)
{
check_args(input,normalized_shape,n1,n2);
check_args(normalized_shape,gamma,beta);
}
}
void cuda_layer_norm(
at::Tensor* output,
at::Tensor* mean,
at::Tensor* invvar,
at::Tensor* input,
int n1,
int n2,
at::IntArrayRef normalized_shape,
at::Tensor* gamma,
at::Tensor* beta,
double epsilon);
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<at::Tensor> layer_norm_affine(
at::Tensor input,
at::IntArrayRef normalized_shape,
at::Tensor gamma,
at::Tensor beta,
double epsilon) {
CHECK_INPUT(input);
CHECK_INPUT(gamma);
CHECK_INPUT(beta);
int n1, n2;
check_args(input, normalized_shape, gamma, beta, n1, n2);
at::Tensor output = at::empty_like(
input, gamma.options().dtype(gamma.scalar_type()));
at::Tensor mean = at::empty(
{n1}, input.options().dtype(at::ScalarType::Float));
at::Tensor invvar = at::empty_like(mean);
cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2,
normalized_shape, &gamma, &beta, epsilon);
return {output, mean, invvar};
}
void cuda_layer_norm_gradient(
at::Tensor* dout,
at::Tensor* mean,
at::Tensor* invvar,
at::Tensor* input,
int n1,
int n2,
at::IntArrayRef normalized_shape,
at::Tensor* gamma,
at::Tensor* beta,
double epsilon,
at::Tensor* grad_input,
at::Tensor* grad_gamma,
at::Tensor* grad_beta
);
std::vector<at::Tensor> layer_norm_gradient_affine(
at::Tensor dout,
at::Tensor mean,
at::Tensor invvar,
at::Tensor input,
at::IntArrayRef normalized_shape,
at::Tensor gamma,
at::Tensor beta,
double epsilon) {
CHECK_INPUT(dout);
CHECK_INPUT(mean);
CHECK_INPUT(invvar);
CHECK_INPUT(input);
CHECK_INPUT(gamma);
CHECK_INPUT(beta);
int n1, n2;
check_args(input, normalized_shape, gamma, beta, n1, n2);
at::Tensor grad_input = at::empty_like(input);
at::Tensor grad_gamma = at::empty_like(gamma);
at::Tensor grad_beta = at::empty_like(beta);
cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2,
normalized_shape, &gamma, &beta, epsilon,
&grad_input, &grad_gamma, &grad_beta);
return {grad_input, grad_gamma, grad_beta};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward_affine", &layer_norm_affine,
"LayerNorm forward (CUDA)");
m.def("backward_affine", &layer_norm_gradient_affine,
"LayerNorm backward (CUDA)");
}
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*This code is copied fron NVIDIA apex:
* https://github.com/NVIDIA/apex
* with minor changes. */
#include "ATen/ATen.h"
#include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/DeviceUtils.cuh"
#include <cuda.h>
#include <cuda_runtime.h>
#include "type_shim.h"
template<typename U> __device__
void cuWelfordOnlineSum(
const U curr,
U& mu,
U& sigma2,
U& count)
{
count = count + U(1);
U delta = curr - mu;
U lmean = mu + delta / count;
mu = lmean;
U delta2 = curr - lmean;
sigma2 = sigma2 + delta * delta2;
}
template<typename U> __device__
void cuChanOnlineSum(
const U muB,
const U sigma2B,
const U countB,
U& mu,
U& sigma2,
U& count)
{
U delta = muB - mu;
U nA = count;
U nB = countB;
count = count + countB;
U nX = count;
if (nX > U(0)) {
nA = nA / nX;
nB = nB / nX;
mu = nA*mu + nB*muB;
sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX;
} else {
mu = U(0);
sigma2 = U(0);
}
}
template<typename T, typename U> __device__
void cuWelfordMuSigma2(
const T* __restrict__ vals,
const int n1,
const int n2,
const int i1,
U& mu,
U& sigma2,
U* buf)
{
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensor is contiguous
// 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
//
// compute variance and mean over n2
U count = U(0);
mu= U(0);
sigma2 = U(0);
if (i1 < n1) {
// one warp normalizes one n1 index,
// synchronization is implicit
// initialize with standard Welford algorithm
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
const T* lvals = vals + i1*n2;
int l = 4*thrx;
for (; l+3 < n2; l+=4*numx) {
for (int k = 0; k < 4; ++k) {
U curr = static_cast<U>(lvals[l+k]);
cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
}
}
for (; l < n2; ++l) {
U curr = static_cast<U>(lvals[l]);
cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
}
// intra-warp reductions
for (int l = 0; l <= 4; ++l) {
int srcLaneB = (threadIdx.x+(1<<l))&31;
U muB = WARP_SHFL(mu, srcLaneB);
U countB = WARP_SHFL(count, srcLaneB);
U sigma2B = WARP_SHFL(sigma2, srcLaneB);
cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
}
// threadIdx.x == 0 has correct values for each warp
// inter-warp reductions
if (blockDim.y > 1) {
U* ubuf = (U*)buf;
U* ibuf = (U*)(ubuf + blockDim.y);
for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
// upper half of warps write to shared
if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {
const int wrt_y = threadIdx.y - offset;
ubuf[2*wrt_y] = mu;
ubuf[2*wrt_y+1] = sigma2;
ibuf[wrt_y] = count;
}
__syncthreads();
// lower half merges
if (threadIdx.x == 0 && threadIdx.y < offset) {
U muB = ubuf[2*threadIdx.y];
U sigma2B = ubuf[2*threadIdx.y+1];
U countB = ibuf[threadIdx.y];
cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
}
__syncthreads();
}
// threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values
if (threadIdx.x == 0 && threadIdx.y == 0) {
ubuf[0] = mu;
ubuf[1] = sigma2;
}
__syncthreads();
mu = ubuf[0];
sigma2 = ubuf[1]/U(n2);
// don't care about final value of count, we know count == n2
} else {
mu = WARP_SHFL(mu, 0);
sigma2 = WARP_SHFL(sigma2/U(n2), 0);
}
}
}
template<> __device__
void cuWelfordMuSigma2(
const at::Half* __restrict__ vals,
const int n1,
const int n2,
const int i1,
float& mu,
float& sigma2,
float* buf)
{
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensor is contiguous
// 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
//
// compute variance and mean over n2
float count = 0.0f;
mu= float(0);
sigma2 = float(0);
if (i1 < n1) {
// one warp normalizes one n1 index,
// synchronization is implicit
// initialize with standard Welford algorithm
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
const at::Half* lvals = vals + i1*n2;
int l = 8*thrx;
if ((((size_t)lvals)&3) != 0) {
// 16 bit alignment
// first thread consumes first point
if (thrx == 0) {
float curr = static_cast<float>(lvals[0]);
cuWelfordOnlineSum(curr,mu,sigma2,count);
}
++l;
}
// at this point, lvals[l] are 32 bit aligned for all threads.
for (; l+7 < n2; l+=8*numx) {
for (int k = 0; k < 8; k+=2) {
float2 curr = __half22float2(*((__half2*)(lvals+l+k)));
cuWelfordOnlineSum(curr.x,mu,sigma2,count);
cuWelfordOnlineSum(curr.y,mu,sigma2,count);
}
}
for (; l < n2; ++l) {
float curr = static_cast<float>(lvals[l]);
cuWelfordOnlineSum(curr,mu,sigma2,count);
}
// intra-warp reductions
for (int l = 0; l <= 4; ++l) {
int srcLaneB = (threadIdx.x+(1<<l))&31;
float muB = WARP_SHFL(mu, srcLaneB);
float countB = WARP_SHFL(count, srcLaneB);
float sigma2B = WARP_SHFL(sigma2, srcLaneB);
cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
}
// threadIdx.x == 0 has correct values for each warp
// inter-warp reductions
if (blockDim.y > 1) {
float* ubuf = (float*)buf;
float* ibuf = (float*)(ubuf + blockDim.y);
for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
// upper half of warps write to shared
if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {
const int wrt_y = threadIdx.y - offset;
ubuf[2*wrt_y] = mu;
ubuf[2*wrt_y+1] = sigma2;
ibuf[wrt_y] = count;
}
__syncthreads();
// lower half merges
if (threadIdx.x == 0 && threadIdx.y < offset) {
float muB = ubuf[2*threadIdx.y];
float sigma2B = ubuf[2*threadIdx.y+1];
float countB = ibuf[threadIdx.y];
cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
}
__syncthreads();
}
// threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values
if (threadIdx.x == 0 && threadIdx.y == 0) {
ubuf[0] = mu;
ubuf[1] = sigma2;
}
__syncthreads();
mu = ubuf[0];
sigma2 = ubuf[1]/float(n2);
// don't care about final value of count, we know count == n2
} else {
mu = WARP_SHFL(mu, 0);
sigma2 = WARP_SHFL(sigma2/float(n2), 0);
}
}
}
template<typename U> __device__ U rsqrt(U v) {
return U(1) / sqrt(v);
}
template<> __device__ float rsqrt(float v) {
return rsqrtf(v);
}
template<> __device__ double rsqrt(double v) {
return rsqrt(v);
}
namespace {
// This is the un-specialized struct. Note that we prevent instantiation of this
// struct by putting an undefined symbol in the function body so it won't compile.
// template <typename T>
// struct SharedMemory
// {
// // Ensure that we won't compile any un-specialized types
// __device__ T *getPointer()
// {
// extern __device__ void error(void);
// error();
// return NULL;
// }
// };
// https://github.com/NVIDIA/apex/issues/246
template <typename T>
struct SharedMemory;
template <>
struct SharedMemory <float>
{
__device__ float *getPointer()
{
extern __shared__ float s_float[];
return s_float;
}
};
}
template<typename T, typename U, typename V> __global__
void cuApplyLayerNorm(
V* __restrict__ output_vals,
U* __restrict__ mean,
U* __restrict__ invvar,
const T* __restrict__ vals,
const int n1,
const int n2,
const U epsilon,
const V* __restrict__ gamma,
const V* __restrict__ beta
)
{
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensors are contiguous
//
for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
SharedMemory<U> shared;
U* buf = shared.getPointer();
U mu,sigma2;
cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf);
const T* lvals = vals + i1*n2;
V* ovals = output_vals + i1*n2;
U c_invvar = rsqrt(sigma2 + epsilon);
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
if (gamma != NULL && beta != NULL) {
for (int i = thrx; i < n2; i+=numx) {
U curr = static_cast<U>(lvals[i]);
ovals[i] = (curr - mu) * c_invvar * static_cast<U>(gamma[i]) + static_cast<U>(beta[i]);
}
} else {
for (int i = thrx; i < n2; i+=numx) {
U curr = static_cast<U>(lvals[i]);
ovals[i] = static_cast<V>(c_invvar * (curr - mu));
}
}
if (threadIdx.x == 0 && threadIdx.y == 0) {
mean[i1] = mu;
invvar[i1] = c_invvar;
}
__syncthreads();
}
}
template<typename T, typename U, typename V> __device__
void cuLoadWriteStridedInputs(
const int i1_block,
const int thr_load_row_off,
const int thr_load_col_off,
const int i2_off,
const int row_stride,
U* warp_buf1,
U* warp_buf2,
const T* input,
const V* dout,
const int i1_end,
const int n2,
const U* __restrict__ mean,
const U* __restrict__ invvar
)
{
int i1 = i1_block+thr_load_row_off;
if (i1 < i1_end) {
U curr_mean = mean[i1];
U curr_invvar = invvar[i1];
for (int k = 0; k < blockDim.y; ++k) {
int i2 = i2_off + k;
int load_idx = i1*n2+i2;
int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
if (i2<n2) {
U curr_input = static_cast<U>(input[load_idx]);
U curr_dout = static_cast<U>(dout[load_idx]);
warp_buf1[write_idx] = curr_dout;
warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar;
} else {
warp_buf1[write_idx] = U(0);
warp_buf2[write_idx] = U(0);
}
}
} else {
for (int k = 0; k < blockDim.y; ++k) {
int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
warp_buf1[write_idx] = U(0);
warp_buf2[write_idx] = U(0);
}
}
}
template<typename T, typename U, typename V> __device__
void cuLoadAddStridedInputs(
const int i1_block,
const int thr_load_row_off,
const int thr_load_col_off,
const int i2_off,
const int row_stride,
U* warp_buf1,
U* warp_buf2,
const T* input,
const V* dout,
const int i1_end,
const int n2,
const U* __restrict__ mean,
const U* __restrict__ invvar
)
{
int i1 = i1_block+thr_load_row_off;
if (i1 < i1_end) {
U curr_mean = mean[i1];
U curr_invvar = invvar[i1];
for (int k = 0; k < blockDim.y; ++k) {
int i2 = i2_off + k;
int load_idx = i1*n2+i2;
int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
if (i2<n2) {
U curr_input = static_cast<U>(input[load_idx]);
U curr_dout = static_cast<U>(dout[load_idx]);
warp_buf1[write_idx] += curr_dout;
warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar;
}
}
}
}
template<typename T, typename U, typename V> __global__
void cuComputePartGradGammaBeta(
const V* __restrict__ dout,
const T* __restrict__ input,
const int n1,
const int n2,
const U* __restrict__ mean,
const U* __restrict__ invvar,
U epsilon,
U* part_grad_gamma,
U* part_grad_beta)
{
const int numsegs_n1 = (n1+blockDim.y*blockDim.y-1) / (blockDim.y*blockDim.y);
const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
const int i1_beg = blockIdx.y * segs_per_block * blockDim.y*blockDim.y;
const int i1_beg_plus_one = (blockIdx.y+1) * segs_per_block * blockDim.y*blockDim.y;
const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1;
const int row_stride = blockDim.x+1;
const int thr_load_col_off = (threadIdx.x*blockDim.y)&(blockDim.x-1);
const int thr_load_row_off = (threadIdx.x*blockDim.y)/blockDim.x + threadIdx.y*blockDim.y;
const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;
SharedMemory<U> shared;
U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * blockDim.y + (blockDim.y - 1)*(blockDim.x/blockDim.y) elements
U* warp_buf1 = (U*)buf;
U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;
// compute partial sums from strided inputs
// do this to increase number of loads in flight
cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar);
for (int i1_block = i1_beg+blockDim.y*blockDim.y; i1_block < i1_end; i1_block+=blockDim.y*blockDim.y) {
cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar);
}
__syncthreads();
// inter-warp reductions
// sum within each warp
U acc1 = U(0);
U acc2 = U(0);
for (int k = 0; k < blockDim.y; ++k) {
int row1 = threadIdx.y + k*blockDim.y;
int idx1 = row1*row_stride + threadIdx.x;
acc1 += warp_buf1[idx1];
acc2 += warp_buf2[idx1];
}
warp_buf1[threadIdx.y*row_stride+threadIdx.x] = acc1;
warp_buf2[threadIdx.y*row_stride+threadIdx.x] = acc2;
__syncthreads();
// sum all warps
for (int offset = blockDim.y/2; offset > 1; offset /= 2) {
if (threadIdx.y < offset) {
int row1 = threadIdx.y;
int row2 = threadIdx.y + offset;
int idx1 = row1*row_stride + threadIdx.x;
int idx2 = row2*row_stride + threadIdx.x;
warp_buf1[idx1] += warp_buf1[idx2];
warp_buf2[idx1] += warp_buf2[idx2];
}
__syncthreads();
}
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
if (threadIdx.y == 0 && i2 < n2) {
int row1 = threadIdx.y;
int row2 = threadIdx.y + 1;
int idx1 = row1*row_stride + threadIdx.x;
int idx2 = row2*row_stride + threadIdx.x;
part_grad_beta[blockIdx.y*n2+i2] = warp_buf1[idx1] + warp_buf1[idx2];
part_grad_gamma[blockIdx.y*n2+i2] = warp_buf2[idx1] + warp_buf2[idx2];
}
}
template<typename U, typename V> __global__
void cuComputeGradGammaBeta(
const U* part_grad_gamma,
const U* part_grad_beta,
const int part_size,
const int n1,
const int n2,
V* grad_gamma,
V* grad_beta)
{
// sum partial gradients for gamma and beta
SharedMemory<U> shared;
U* buf = shared.getPointer();
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
if (i2 < n2) {
// each warp does sequential reductions until reduced part_size is num_warps
int num_warp_reductions = part_size / blockDim.y;
U sum_gamma = U(0);
U sum_beta = U(0);
const U* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) {
sum_gamma += part_grad_gamma_ptr[warp_offset*n2];
sum_beta += part_grad_beta_ptr[warp_offset*n2];
}
// inter-warp reductions
const int nbsize3 = blockDim.x * blockDim.y / 2;
for (int offset = blockDim.y/2; offset >= 1; offset /= 2) {
// top half write to shared memory
if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
buf[write_idx] = sum_gamma;
buf[write_idx+nbsize3] = sum_beta;
}
__syncthreads();
// bottom half sums
if (threadIdx.y < offset) {
const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
sum_gamma += buf[read_idx];
sum_beta += buf[read_idx+nbsize3];
}
__syncthreads();
}
// write out fully summed gradients
if (threadIdx.y == 0) {
grad_gamma[i2] = sum_gamma;
grad_beta[i2] = sum_beta;
}
}
}
template<typename T, typename U, typename V> __global__
void cuComputeGradInput(
const V* __restrict__ dout,
const T* __restrict__ input,
const int n1,
const int n2,
const U* __restrict__ mean,
const U* __restrict__ invvar,
U epsilon,
const V* gamma,
T* grad_input)
{
for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
U sum_loss1 = U(0);
U sum_loss2 = U(0);
const U c_mean = mean[i1];
const U c_invvar = invvar[i1];
const T* k_input = input + i1*n2;
const V* k_dout = dout + i1*n2;
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
if (gamma != NULL) {
int l = 4*thrx;
for (; l+3 < n2; l+=4*numx) {
for (int k = 0; k < 4; ++k) {
const U c_h = static_cast<U>(k_input[l+k]);
const U c_loss = static_cast<U>(k_dout[l+k]);
sum_loss1 += c_loss * gamma[l+k];
sum_loss2 += c_loss * gamma[l+k] * (c_h - c_mean) * c_invvar;
}
}
for (; l < n2; ++l) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
sum_loss1 += c_loss * gamma[l];
sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
}
} else {
int l = 4*thrx;
for (; l+3 < n2; l+=4*numx) {
for (int k = 0; k < 4; ++k) {
const U c_h = static_cast<U>(k_input[l+k]);
const U c_loss = static_cast<U>(k_dout[l+k]);
sum_loss1 += c_loss;
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
}
}
for (; l < n2; ++l) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
sum_loss1 += c_loss;
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
}
}
// intra-warp reductions
for (int mask = blockDim.x/2; mask > 0; mask /= 2) {
sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);
sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);
}
// inter-warp reductions
if (blockDim.y > 1) {
SharedMemory<U> shared;
U* buf = shared.getPointer();
for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
// upper half of warps write to shared
if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
buf[2*wrt_i] = sum_loss1;
buf[2*wrt_i+1] = sum_loss2;
}
__syncthreads();
// lower half merges
if (threadIdx.y < offset) {
const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
sum_loss1 += buf[2*read_i];
sum_loss2 += buf[2*read_i+1];
}
__syncthreads();
}
if (threadIdx.y == 0) {
buf[2*threadIdx.x] = sum_loss1;
buf[2*threadIdx.x+1] = sum_loss2;
}
__syncthreads();
if (threadIdx.y !=0) {
sum_loss1 = buf[2*threadIdx.x];
sum_loss2 = buf[2*threadIdx.x+1];
}
}
// all threads now have the two sums over l
U fH = (U)n2;
U term1 = (U(1) / fH) * c_invvar;
T* k_grad_input = grad_input + i1*n2;
if (gamma != NULL) {
for (int l = thrx; l < n2; l+=numx) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
U f_grad_input = fH * c_loss * gamma[l];
f_grad_input -= sum_loss1;
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
f_grad_input *= term1;
k_grad_input[l] = static_cast<T>(f_grad_input);
}
} else {
for (int l = thrx; l < n2; l+=numx) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
U f_grad_input = fH * c_loss;
f_grad_input -= sum_loss1;
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
f_grad_input *= term1;
k_grad_input[l] = static_cast<T>(f_grad_input);
}
}
// prevent race where buf is written again before reads are done
__syncthreads();
}
}
template<typename T, typename U, typename V>
void HostApplyLayerNorm(
V* output,
U* mean,
U* invvar,
const T* input,
int n1,
int n2,
double epsilon,
const V* gamma,
const V* beta
)
{
auto stream = at::cuda::getCurrentCUDAStream().stream();
const dim3 threads(32,4,1);
const uint64_t maxGridY =
at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
int nshared =
threads.y > 1 ?
threads.y*sizeof(U)+(threads.y/2)*sizeof(U) :
0;
cuApplyLayerNorm<<<blocks, threads, nshared, stream>>>(
output,
mean,
invvar,
input,
n1,n2,
U(epsilon),
gamma,beta);
}
void cuda_layer_norm(
at::Tensor* output,
at::Tensor* mean,
at::Tensor* invvar,
at::Tensor* input,
int n1,
int n2,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor* gamma,
at::Tensor* beta,
double epsilon)
{
using namespace at;
DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
input->scalar_type(), output->scalar_type(), "cuda_layer_norm_kernel",
HostApplyLayerNorm(
output->DATA_PTR<scalar_t_out>(),
mean->DATA_PTR<float>(),
invvar->DATA_PTR<float>(),
input->DATA_PTR<scalar_t_in>(),
n1,n2,
epsilon,
gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
beta != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL);
)
}
template<typename T, typename U, typename V>
void HostLayerNormGradient(
const V* dout,
const U* mean,
const U* invvar,
at::Tensor* input,
int n1,
int n2,
const V* gamma,
const V* beta,
double epsilon,
T* grad_input,
V* grad_gamma,
V* grad_beta
)
{
auto stream = at::cuda::getCurrentCUDAStream().stream();
if (gamma != NULL && beta != NULL) {
// compute grad_gamma(j) and grad_beta(j)
const int part_size = 16;
const dim3 threads2(32,4,1);
const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1);
const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y *
(threads2.x + 1);
const int nshared2_b = threads2.x * threads2.y * sizeof(U);
const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
at::Tensor part_grad_gamma = at::empty(
{part_size,n2}, input->options().dtype(at::ScalarType::Float));
at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
dout,
input->DATA_PTR<T>(),
n1,n2,
mean,
invvar,
U(epsilon),
part_grad_gamma.DATA_PTR<U>(),
part_grad_beta.DATA_PTR<U>());
const dim3 threads3(32,8,1);
const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1);
const int nshared3 = threads3.x * threads3.y * sizeof(U);
cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(
part_grad_gamma.DATA_PTR<U>(),
part_grad_beta.DATA_PTR<U>(),
part_size,
n1,n2,
grad_gamma,
grad_beta);
}
// compute grad_input
const uint64_t maxGridY =
at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
const dim3 threads1(32,4,1);
int nshared =
threads1.y > 1 ?
threads1.y*threads1.x*sizeof(U) :
0;
cuComputeGradInput<<<blocks1, threads1, nshared, stream>>>(
dout,
input->DATA_PTR<T>(),
n1,n2,
mean,
invvar,
U(epsilon),
gamma,
grad_input);
}
void cuda_layer_norm_gradient(
at::Tensor* dout,
at::Tensor* mean,
at::Tensor* invvar,
at::Tensor* input,
int n1,
int n2,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor* gamma,
at::Tensor* beta,
double epsilon,
at::Tensor* grad_input,
at::Tensor* grad_gamma,
at::Tensor* grad_beta)
{
using namespace at;
DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
input->scalar_type(), gamma->scalar_type(),
"cuda_layer_norm_gradient_kernel",
HostLayerNormGradient(
dout->DATA_PTR<scalar_t_out>(),
mean->DATA_PTR<float>(),
invvar->DATA_PTR<float>(),
input,
n1,n2,
// TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
// if gamma Tensor is NULL on input.
gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
gamma != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL,
epsilon,
grad_input->DATA_PTR<scalar_t_in>(),
gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL,
gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL);
)
}
// !!! This is a file automatically generated by hipify!!!
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*This code is copied fron NVIDIA apex:
* https://github.com/NVIDIA/apex
* with minor changes. */
#include "ATen/ATen.h"
#include "ATen/AccumulateType.h"
#include "ATen/hip/HIPContext.h"
#include "ATen/hip/DeviceUtils.cuh"
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
#include "type_shim.h"
template<typename U> __device__
void cuWelfordOnlineSum(
const U curr,
U& mu,
U& sigma2,
U& count)
{
count = count + U(1);
U delta = curr - mu;
U lmean = mu + delta / count;
mu = lmean;
U delta2 = curr - lmean;
sigma2 = sigma2 + delta * delta2;
}
template<typename U> __device__
void cuChanOnlineSum(
const U muB,
const U sigma2B,
const U countB,
U& mu,
U& sigma2,
U& count)
{
U delta = muB - mu;
U nA = count;
U nB = countB;
count = count + countB;
U nX = count;
if (nX > U(0)) {
nA = nA / nX;
nB = nB / nX;
mu = nA*mu + nB*muB;
sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX;
} else {
mu = U(0);
sigma2 = U(0);
}
}
template<typename T, typename U> __device__
void cuWelfordMuSigma2(
const T* __restrict__ vals,
const int n1,
const int n2,
const int i1,
U& mu,
U& sigma2,
U* buf)
{
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensor is contiguous
// 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
//
// compute variance and mean over n2
U count = U(0);
mu= U(0);
sigma2 = U(0);
if (i1 < n1) {
// one warp normalizes one n1 index,
// synchronization is implicit
// initialize with standard Welford algorithm
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
const T* lvals = vals + i1*n2;
int l = 4*thrx;
for (; l+3 < n2; l+=4*numx) {
for (int k = 0; k < 4; ++k) {
U curr = static_cast<U>(lvals[l+k]);
cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
}
}
for (; l < n2; ++l) {
U curr = static_cast<U>(lvals[l]);
cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
}
// intra-warp reductions
for (int l = 0; l <= 4; ++l) {
int srcLaneB = (threadIdx.x+(1<<l))&31;
U muB = WARP_SHFL(mu, srcLaneB);
U countB = WARP_SHFL(count, srcLaneB);
U sigma2B = WARP_SHFL(sigma2, srcLaneB);
cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
}
// threadIdx.x == 0 has correct values for each warp
// inter-warp reductions
if (blockDim.y > 1) {
U* ubuf = (U*)buf;
U* ibuf = (U*)(ubuf + blockDim.y);
for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
// upper half of warps write to shared
if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {
const int wrt_y = threadIdx.y - offset;
ubuf[2*wrt_y] = mu;
ubuf[2*wrt_y+1] = sigma2;
ibuf[wrt_y] = count;
}
__syncthreads();
// lower half merges
if (threadIdx.x == 0 && threadIdx.y < offset) {
U muB = ubuf[2*threadIdx.y];
U sigma2B = ubuf[2*threadIdx.y+1];
U countB = ibuf[threadIdx.y];
cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
}
__syncthreads();
}
// threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values
if (threadIdx.x == 0 && threadIdx.y == 0) {
ubuf[0] = mu;
ubuf[1] = sigma2;
}
__syncthreads();
mu = ubuf[0];
sigma2 = ubuf[1]/U(n2);
// don't care about final value of count, we know count == n2
} else {
mu = WARP_SHFL(mu, 0);
sigma2 = WARP_SHFL(sigma2/U(n2), 0);
}
}
}
template<> __device__
void cuWelfordMuSigma2(
const at::Half* __restrict__ vals,
const int n1,
const int n2,
const int i1,
float& mu,
float& sigma2,
float* buf)
{
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensor is contiguous
// 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
//
// compute variance and mean over n2
float count = 0.0f;
mu= float(0);
sigma2 = float(0);
if (i1 < n1) {
// one warp normalizes one n1 index,
// synchronization is implicit
// initialize with standard Welford algorithm
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
const at::Half* lvals = vals + i1*n2;
int l = 8*thrx;
if ((((size_t)lvals)&3) != 0) {
// 16 bit alignment
// first thread consumes first point
if (thrx == 0) {
float curr = static_cast<float>(lvals[0]);
cuWelfordOnlineSum(curr,mu,sigma2,count);
}
++l;
}
// at this point, lvals[l] are 32 bit aligned for all threads.
for (; l+7 < n2; l+=8*numx) {
for (int k = 0; k < 8; k+=2) {
float2 curr = __half22float2(*((__half2*)(lvals+l+k)));
cuWelfordOnlineSum(curr.x,mu,sigma2,count);
cuWelfordOnlineSum(curr.y,mu,sigma2,count);
}
}
for (; l < n2; ++l) {
float curr = static_cast<float>(lvals[l]);
cuWelfordOnlineSum(curr,mu,sigma2,count);
}
// intra-warp reductions
for (int l = 0; l <= 4; ++l) {
int srcLaneB = (threadIdx.x+(1<<l))&31;
float muB = WARP_SHFL(mu, srcLaneB);
float countB = WARP_SHFL(count, srcLaneB);
float sigma2B = WARP_SHFL(sigma2, srcLaneB);
cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
}
// threadIdx.x == 0 has correct values for each warp
// inter-warp reductions
if (blockDim.y > 1) {
float* ubuf = (float*)buf;
float* ibuf = (float*)(ubuf + blockDim.y);
for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
// upper half of warps write to shared
if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {
const int wrt_y = threadIdx.y - offset;
ubuf[2*wrt_y] = mu;
ubuf[2*wrt_y+1] = sigma2;
ibuf[wrt_y] = count;
}
__syncthreads();
// lower half merges
if (threadIdx.x == 0 && threadIdx.y < offset) {
float muB = ubuf[2*threadIdx.y];
float sigma2B = ubuf[2*threadIdx.y+1];
float countB = ibuf[threadIdx.y];
cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
}
__syncthreads();
}
// threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values
if (threadIdx.x == 0 && threadIdx.y == 0) {
ubuf[0] = mu;
ubuf[1] = sigma2;
}
__syncthreads();
mu = ubuf[0];
sigma2 = ubuf[1]/float(n2);
// don't care about final value of count, we know count == n2
} else {
mu = WARP_SHFL(mu, 0);
sigma2 = WARP_SHFL(sigma2/float(n2), 0);
}
}
}
template<typename U> __device__ U rsqrt(U v) {
return U(1) / sqrt(v);
}
template<> __device__ float rsqrt(float v) {
return rsqrtf(v);
}
template<> __device__ double rsqrt(double v) {
return rsqrt(v);
}
namespace {
// This is the un-specialized struct. Note that we prevent instantiation of this
// struct by putting an undefined symbol in the function body so it won't compile.
// template <typename T>
// struct SharedMemory
// {
// // Ensure that we won't compile any un-specialized types
// __device__ T *getPointer()
// {
// extern __device__ void error(void);
// error();
// return NULL;
// }
// };
// https://github.com/NVIDIA/apex/issues/246
template <typename T>
struct SharedMemory;
template <>
struct SharedMemory <float>
{
__device__ float *getPointer()
{
HIP_DYNAMIC_SHARED( float, s_float)
return s_float;
}
};
}
template<typename T, typename U, typename V> __global__
void cuApplyLayerNorm(
V* __restrict__ output_vals,
U* __restrict__ mean,
U* __restrict__ invvar,
const T* __restrict__ vals,
const int n1,
const int n2,
const U epsilon,
const V* __restrict__ gamma,
const V* __restrict__ beta
)
{
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensors are contiguous
//
for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
SharedMemory<U> shared;
U* buf = shared.getPointer();
U mu,sigma2;
cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf);
const T* lvals = vals + i1*n2;
V* ovals = output_vals + i1*n2;
U c_invvar = rsqrt(sigma2 + epsilon);
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
if (gamma != NULL && beta != NULL) {
for (int i = thrx; i < n2; i+=numx) {
U curr = static_cast<U>(lvals[i]);
ovals[i] = (curr - mu) * c_invvar * static_cast<U>(gamma[i]) + static_cast<U>(beta[i]);
}
} else {
for (int i = thrx; i < n2; i+=numx) {
U curr = static_cast<U>(lvals[i]);
ovals[i] = static_cast<V>(c_invvar * (curr - mu));
}
}
if (threadIdx.x == 0 && threadIdx.y == 0) {
mean[i1] = mu;
invvar[i1] = c_invvar;
}
__syncthreads();
}
}
template<typename T, typename U, typename V> __device__
void cuLoadWriteStridedInputs(
const int i1_block,
const int thr_load_row_off,
const int thr_load_col_off,
const int i2_off,
const int row_stride,
U* warp_buf1,
U* warp_buf2,
const T* input,
const V* dout,
const int i1_end,
const int n2,
const U* __restrict__ mean,
const U* __restrict__ invvar
)
{
int i1 = i1_block+thr_load_row_off;
if (i1 < i1_end) {
U curr_mean = mean[i1];
U curr_invvar = invvar[i1];
for (int k = 0; k < blockDim.y; ++k) {
int i2 = i2_off + k;
int load_idx = i1*n2+i2;
int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
if (i2<n2) {
U curr_input = static_cast<U>(input[load_idx]);
U curr_dout = static_cast<U>(dout[load_idx]);
warp_buf1[write_idx] = curr_dout;
warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar;
} else {
warp_buf1[write_idx] = U(0);
warp_buf2[write_idx] = U(0);
}
}
} else {
for (int k = 0; k < blockDim.y; ++k) {
int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
warp_buf1[write_idx] = U(0);
warp_buf2[write_idx] = U(0);
}
}
}
template<typename T, typename U, typename V> __device__
void cuLoadAddStridedInputs(
const int i1_block,
const int thr_load_row_off,
const int thr_load_col_off,
const int i2_off,
const int row_stride,
U* warp_buf1,
U* warp_buf2,
const T* input,
const V* dout,
const int i1_end,
const int n2,
const U* __restrict__ mean,
const U* __restrict__ invvar
)
{
int i1 = i1_block+thr_load_row_off;
if (i1 < i1_end) {
U curr_mean = mean[i1];
U curr_invvar = invvar[i1];
for (int k = 0; k < blockDim.y; ++k) {
int i2 = i2_off + k;
int load_idx = i1*n2+i2;
int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
if (i2<n2) {
U curr_input = static_cast<U>(input[load_idx]);
U curr_dout = static_cast<U>(dout[load_idx]);
warp_buf1[write_idx] += curr_dout;
warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar;
}
}
}
}
template<typename T, typename U, typename V> __global__
void cuComputePartGradGammaBeta(
const V* __restrict__ dout,
const T* __restrict__ input,
const int n1,
const int n2,
const U* __restrict__ mean,
const U* __restrict__ invvar,
U epsilon,
U* part_grad_gamma,
U* part_grad_beta)
{
const int numsegs_n1 = (n1+blockDim.y*blockDim.y-1) / (blockDim.y*blockDim.y);
const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
const int i1_beg = blockIdx.y * segs_per_block * blockDim.y*blockDim.y;
const int i1_beg_plus_one = (blockIdx.y+1) * segs_per_block * blockDim.y*blockDim.y;
const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1;
const int row_stride = blockDim.x+1;
const int thr_load_col_off = (threadIdx.x*blockDim.y)&(blockDim.x-1);
const int thr_load_row_off = (threadIdx.x*blockDim.y)/blockDim.x + threadIdx.y*blockDim.y;
const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;
SharedMemory<U> shared;
U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * blockDim.y + (blockDim.y - 1)*(blockDim.x/blockDim.y) elements
U* warp_buf1 = (U*)buf;
U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;
// compute partial sums from strided inputs
// do this to increase number of loads in flight
cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar);
for (int i1_block = i1_beg+blockDim.y*blockDim.y; i1_block < i1_end; i1_block+=blockDim.y*blockDim.y) {
cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar);
}
__syncthreads();
// inter-warp reductions
// sum within each warp
U acc1 = U(0);
U acc2 = U(0);
for (int k = 0; k < blockDim.y; ++k) {
int row1 = threadIdx.y + k*blockDim.y;
int idx1 = row1*row_stride + threadIdx.x;
acc1 += warp_buf1[idx1];
acc2 += warp_buf2[idx1];
}
warp_buf1[threadIdx.y*row_stride+threadIdx.x] = acc1;
warp_buf2[threadIdx.y*row_stride+threadIdx.x] = acc2;
__syncthreads();
// sum all warps
for (int offset = blockDim.y/2; offset > 1; offset /= 2) {
if (threadIdx.y < offset) {
int row1 = threadIdx.y;
int row2 = threadIdx.y + offset;
int idx1 = row1*row_stride + threadIdx.x;
int idx2 = row2*row_stride + threadIdx.x;
warp_buf1[idx1] += warp_buf1[idx2];
warp_buf2[idx1] += warp_buf2[idx2];
}
__syncthreads();
}
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
if (threadIdx.y == 0 && i2 < n2) {
int row1 = threadIdx.y;
int row2 = threadIdx.y + 1;
int idx1 = row1*row_stride + threadIdx.x;
int idx2 = row2*row_stride + threadIdx.x;
part_grad_beta[blockIdx.y*n2+i2] = warp_buf1[idx1] + warp_buf1[idx2];
part_grad_gamma[blockIdx.y*n2+i2] = warp_buf2[idx1] + warp_buf2[idx2];
}
}
template<typename U, typename V> __global__
void cuComputeGradGammaBeta(
const U* part_grad_gamma,
const U* part_grad_beta,
const int part_size,
const int n1,
const int n2,
V* grad_gamma,
V* grad_beta)
{
// sum partial gradients for gamma and beta
SharedMemory<U> shared;
U* buf = shared.getPointer();
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
if (i2 < n2) {
// each warp does sequential reductions until reduced part_size is num_warps
int num_warp_reductions = part_size / blockDim.y;
U sum_gamma = U(0);
U sum_beta = U(0);
const U* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) {
sum_gamma += part_grad_gamma_ptr[warp_offset*n2];
sum_beta += part_grad_beta_ptr[warp_offset*n2];
}
// inter-warp reductions
const int nbsize3 = blockDim.x * blockDim.y / 2;
for (int offset = blockDim.y/2; offset >= 1; offset /= 2) {
// top half write to shared memory
if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
buf[write_idx] = sum_gamma;
buf[write_idx+nbsize3] = sum_beta;
}
__syncthreads();
// bottom half sums
if (threadIdx.y < offset) {
const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
sum_gamma += buf[read_idx];
sum_beta += buf[read_idx+nbsize3];
}
__syncthreads();
}
// write out fully summed gradients
if (threadIdx.y == 0) {
grad_gamma[i2] = sum_gamma;
grad_beta[i2] = sum_beta;
}
}
}
template<typename T, typename U, typename V> __global__
void cuComputeGradInput(
const V* __restrict__ dout,
const T* __restrict__ input,
const int n1,
const int n2,
const U* __restrict__ mean,
const U* __restrict__ invvar,
U epsilon,
const V* gamma,
T* grad_input)
{
for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
U sum_loss1 = U(0);
U sum_loss2 = U(0);
const U c_mean = mean[i1];
const U c_invvar = invvar[i1];
const T* k_input = input + i1*n2;
const V* k_dout = dout + i1*n2;
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
if (gamma != NULL) {
int l = 4*thrx;
for (; l+3 < n2; l+=4*numx) {
for (int k = 0; k < 4; ++k) {
const U c_h = static_cast<U>(k_input[l+k]);
const U c_loss = static_cast<U>(k_dout[l+k]);
sum_loss1 += c_loss * gamma[l+k];
sum_loss2 += c_loss * gamma[l+k] * (c_h - c_mean) * c_invvar;
}
}
for (; l < n2; ++l) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
sum_loss1 += c_loss * gamma[l];
sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
}
} else {
int l = 4*thrx;
for (; l+3 < n2; l+=4*numx) {
for (int k = 0; k < 4; ++k) {
const U c_h = static_cast<U>(k_input[l+k]);
const U c_loss = static_cast<U>(k_dout[l+k]);
sum_loss1 += c_loss;
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
}
}
for (; l < n2; ++l) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
sum_loss1 += c_loss;
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
}
}
// intra-warp reductions
for (int mask = blockDim.x/2; mask > 0; mask /= 2) {
sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);
sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);
}
// inter-warp reductions
if (blockDim.y > 1) {
SharedMemory<U> shared;
U* buf = shared.getPointer();
for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
// upper half of warps write to shared
if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
buf[2*wrt_i] = sum_loss1;
buf[2*wrt_i+1] = sum_loss2;
}
__syncthreads();
// lower half merges
if (threadIdx.y < offset) {
const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
sum_loss1 += buf[2*read_i];
sum_loss2 += buf[2*read_i+1];
}
__syncthreads();
}
if (threadIdx.y == 0) {
buf[2*threadIdx.x] = sum_loss1;
buf[2*threadIdx.x+1] = sum_loss2;
}
__syncthreads();
if (threadIdx.y !=0) {
sum_loss1 = buf[2*threadIdx.x];
sum_loss2 = buf[2*threadIdx.x+1];
}
}
// all threads now have the two sums over l
U fH = (U)n2;
U term1 = (U(1) / fH) * c_invvar;
T* k_grad_input = grad_input + i1*n2;
if (gamma != NULL) {
for (int l = thrx; l < n2; l+=numx) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
U f_grad_input = fH * c_loss * gamma[l];
f_grad_input -= sum_loss1;
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
f_grad_input *= term1;
k_grad_input[l] = static_cast<T>(f_grad_input);
}
} else {
for (int l = thrx; l < n2; l+=numx) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
U f_grad_input = fH * c_loss;
f_grad_input -= sum_loss1;
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
f_grad_input *= term1;
k_grad_input[l] = static_cast<T>(f_grad_input);
}
}
// prevent race where buf is written again before reads are done
__syncthreads();
}
}
template<typename T, typename U, typename V>
void HostApplyLayerNorm(
V* output,
U* mean,
U* invvar,
const T* input,
int n1,
int n2,
double epsilon,
const V* gamma,
const V* beta
)
{
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
const dim3 threads(32,4,1);
const uint64_t maxGridY =
at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const dim3 blocks(1, ::min((uint64_t)n1, maxGridY), 1);
int nshared =
threads.y > 1 ?
threads.y*sizeof(U)+(threads.y/2)*sizeof(U) :
0;
hipLaunchKernelGGL(( cuApplyLayerNorm), dim3(blocks), dim3(threads), nshared, stream,
output,
mean,
invvar,
input,
n1,n2,
U(epsilon),
gamma,beta);
}
void cuda_layer_norm(
at::Tensor* output,
at::Tensor* mean,
at::Tensor* invvar,
at::Tensor* input,
int n1,
int n2,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor* gamma,
at::Tensor* beta,
double epsilon)
{
using namespace at;
DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
input->scalar_type(), output->scalar_type(), "cuda_layer_norm_kernel",
HostApplyLayerNorm(
output->DATA_PTR<scalar_t_out>(),
mean->DATA_PTR<float>(),
invvar->DATA_PTR<float>(),
input->DATA_PTR<scalar_t_in>(),
n1,n2,
epsilon,
gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
beta != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL);
)
}
template<typename T, typename U, typename V>
void HostLayerNormGradient(
const V* dout,
const U* mean,
const U* invvar,
at::Tensor* input,
int n1,
int n2,
const V* gamma,
const V* beta,
double epsilon,
T* grad_input,
V* grad_gamma,
V* grad_beta
)
{
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
if (gamma != NULL && beta != NULL) {
// compute grad_gamma(j) and grad_beta(j)
const int part_size = 16;
const dim3 threads2(32,4,1);
const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1);
const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y *
(threads2.x + 1);
const int nshared2_b = threads2.x * threads2.y * sizeof(U);
const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
at::Tensor part_grad_gamma = at::empty(
{part_size,n2}, input->options().dtype(at::ScalarType::Float));
at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
hipLaunchKernelGGL(( cuComputePartGradGammaBeta), dim3(blocks2), dim3(threads2), nshared2, stream,
dout,
input->DATA_PTR<T>(),
n1,n2,
mean,
invvar,
U(epsilon),
part_grad_gamma.DATA_PTR<U>(),
part_grad_beta.DATA_PTR<U>());
const dim3 threads3(32,8,1);
const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1);
const int nshared3 = threads3.x * threads3.y * sizeof(U);
hipLaunchKernelGGL(( cuComputeGradGammaBeta), dim3(blocks3), dim3(threads3), nshared3, stream,
part_grad_gamma.DATA_PTR<U>(),
part_grad_beta.DATA_PTR<U>(),
part_size,
n1,n2,
grad_gamma,
grad_beta);
}
// compute grad_input
const uint64_t maxGridY =
at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const dim3 blocks1(1, ::min((uint64_t)n1, maxGridY), 1);
const dim3 threads1(32,4,1);
int nshared =
threads1.y > 1 ?
threads1.y*threads1.x*sizeof(U) :
0;
hipLaunchKernelGGL(( cuComputeGradInput), dim3(blocks1), dim3(threads1), nshared, stream,
dout,
input->DATA_PTR<T>(),
n1,n2,
mean,
invvar,
U(epsilon),
gamma,
grad_input);
}
void cuda_layer_norm_gradient(
at::Tensor* dout,
at::Tensor* mean,
at::Tensor* invvar,
at::Tensor* input,
int n1,
int n2,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor* gamma,
at::Tensor* beta,
double epsilon,
at::Tensor* grad_input,
at::Tensor* grad_gamma,
at::Tensor* grad_beta)
{
using namespace at;
DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
input->scalar_type(), gamma->scalar_type(),
"cuda_layer_norm_gradient_kernel",
HostLayerNormGradient(
dout->DATA_PTR<scalar_t_out>(),
mean->DATA_PTR<float>(),
invvar->DATA_PTR<float>(),
input,
n1,n2,
// TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
// if gamma Tensor is NULL on input.
gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
gamma != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL,
epsilon,
grad_input->DATA_PTR<scalar_t_in>(),
gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL,
gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL);
)
}
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#define __HIP_PLATFORM_HCC__
#include <hip/hip_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
torch::Tensor const& mask,
float scale_factor);
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor);
int get_batch_per_block_cuda(
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads);
torch::Tensor fwd(
torch::Tensor const& input,
torch::Tensor const& mask,
float scale_factor) {
AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");
return fwd_cuda(input, mask, scale_factor);
}
torch::Tensor bwd(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor) {
AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor");
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return bwd_cuda(output_grads, softmax_results, scale_factor);
}
int get_batch_per_block(
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads) {
return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads);
}
} // end namespace scaled_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward",
&multihead_attn::fused_softmax::scaled_masked_softmax::fwd,
"Self Multihead Attention scaled, time masked softmax -- Forward.");
m.def("backward",
&multihead_attn::fused_softmax::scaled_masked_softmax::bwd,
"Self Multihead Attention scaled, time masked softmax -- Backward.");
m.def("get_batch_per_block",
&multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block,
"Return Batch per block size."
);
}
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <assert.h>
#define __HIP_PLATFORM_HCC__
#include <hip/hip_fp16.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
// #include <cuda_fp16.h>
#include <hip/hip_fp16.h>
#include <c10/macros/Macros.h>
namespace {
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }
template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }
template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }
template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }
template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); }
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_zero_vector(Datatype *dst);
template <>
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 1>(c10::BFloat16 *dst) { *dst = 0.0; }
template <>
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 4>(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }
template <>
__device__ __inline__ void copy_zero_vector<c10::Half, 1>(c10::Half *dst) { *dst = 0.0; }
template <>
__device__ __inline__ void copy_zero_vector<c10::Half, 4>(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }
int log2_ceil(int value) {
int log2_value = 0;
while ((1 << log2_value) < value) ++log2_value;
return log2_value;
}
template<typename T>
struct Add {
__device__ __forceinline__ T operator()(T a, T b) const {
return a + b;
}
};
template<typename T>
struct Max {
__device__ __forceinline__ T operator()(T a, T b) const {
return a < b ? b : a;
}
};
template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if CUDA_VERSION >= 9000
return __shfl_xor(value, laneMask, width);
#else
return __shfl_xor(value, laneMask, width);
#endif
}
template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t* sum) {
ReduceOp<acc_t> r;
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
sum[i] = r(sum[i], b);
}
}
}
/*
* Extended softmax (from native aten pytorch) with following additional features
* 1) input scaling
*/
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_softmax_warp_forward(
output_t *dst,
const input_t *src,
const acc_t scale,
int micro_batch_size,
int element_count)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_forward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
// gridDim/blockIdx = (seq_len, attn_heads, batches)
int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH;
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
int local_batches = micro_batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
// load data from global memory
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
input_t temp_data[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
int itr_idx = i*element_count+it*WARP_SIZE;
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
elements[i][it + element] = (acc_t)temp_data[element] * scale;
}
} else {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
}
}
}
}
// compute max_value
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = elements[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
acc_t sum[WARP_BATCH] { 0.0f };
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
sum[i] += elements[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = elements[i][it + element] / sum[i];
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
} else {
break;
}
}
}
}
/*
* Extended softmax (from native aten pytorch) with following additional features
* 1) input scaling
* 2) Explicit masking
*/
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_masked_softmax_warp_forward(
output_t *dst,
const input_t *src,
const uint8_t *mask,
const acc_t scale,
int micro_batch_size,
int element_count,
int pad_batches)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_forward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
// gridDim/blockIdx = (seq_len, attn_heads, batches)
int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH;
int pad_first_batch = 0;
if (pad_batches != 1) { // bert style
pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH;
} else { // gpt2 style
pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
}
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
int local_batches = micro_batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
// load data from global memory
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
input_t temp_data[ELEMENTS_PER_LDG_STG];
uint8_t temp_mask[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
int itr_idx = i*element_count+it*WARP_SIZE;
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(temp_mask, mask + itr_idx);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (temp_mask[element] != 1) {
elements[i][it + element] = (acc_t)temp_data[element] * scale;
} else {
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
}
}
} else {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
}
}
}
}
// compute max_value
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = elements[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
acc_t sum[WARP_BATCH] { 0.0f };
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
if (elements[i][it] <= -std::numeric_limits<acc_t>::infinity()) {
elements[i][it] = 0.0f;
} else {
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
}
sum[i] += elements[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
if (sum[i] == 0.0f) {
copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE);
} else {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = elements[i][it + element] / sum[i];
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
}
} else {
break;
}
}
}
}
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_masked_softmax_warp_backward(
output_t *gradInput,
input_t *grad,
const input_t *output,
acc_t scale,
int micro_batch_size,
int element_count)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_backward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
// gridDim/blockIdx = (seq_len, attn_heads, batches)
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
int local_batches = micro_batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
// the first element to process by the current thread
int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
grad += thread_offset;
output += thread_offset;
gradInput += thread_offset;
// load data from global memory
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
input_t temp_grad[ELEMENTS_PER_LDG_STG];
input_t temp_output[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count + it * WARP_SIZE);
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count + it * WARP_SIZE);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
output_reg[i][it + element] = (acc_t)temp_output[element];
}
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
}
}
}
}
acc_t sum[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] = grad_reg[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
sum[i] += grad_reg[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// compute gradients
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count + it * WARP_SIZE, out);
}
}
}
}
} // end of anonymous namespace
int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){
int log2_elements = log2_ceil(key_seq_len);
const int next_power_of_two = 1 << log2_elements;
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
return batches_per_block;
}
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_softmax_forward(
output_t *dst,
const input_t *src,
const input_t scale,
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads)
{
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096 );
if (key_seq_len == 0) {
return;
} else {
int log2_elements = log2_ceil(key_seq_len);
const int next_power_of_two = 1 << log2_elements;
int batch_count = batches * attn_heads * query_seq_len;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
// use 128 threads per block to maximimize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0);
dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_softmax_warp_forward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 1: // 2
scaled_softmax_warp_forward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 2: // 4
scaled_softmax_warp_forward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 3: // 8
scaled_softmax_warp_forward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 4: // 16
scaled_softmax_warp_forward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 5: // 32
scaled_softmax_warp_forward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 6: // 64
scaled_softmax_warp_forward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 7: // 128
scaled_softmax_warp_forward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 8: // 256
scaled_softmax_warp_forward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 9: // 512
scaled_softmax_warp_forward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 10: // 1024
scaled_softmax_warp_forward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 11: // 2048
scaled_softmax_warp_forward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 12: // 4096
scaled_softmax_warp_forward<input_t, output_t, acc_t, 12>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
default:
break;
}
}
}
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_forward(
output_t *dst,
const input_t *src,
const uint8_t *mask,
const input_t scale,
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads,
int pad_batches)
{
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096 );
if (key_seq_len == 0) {
return;
} else {
int log2_elements = log2_ceil(key_seq_len);
const int next_power_of_two = 1 << log2_elements;
int batch_count = batches * attn_heads * query_seq_len;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
// use 128 threads per block to maximimize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0);
dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 1: // 2
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 2: // 4
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 3: // 8
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 4: // 16
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 5: // 32
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 6: // 64
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 7: // 128
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 8: // 256
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 9: // 512
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 10: // 1024
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 11: // 2048
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 12: // 4096
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 12>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
default:
break;
}
}
}
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_backward(
output_t *grad_input,
input_t *grad,
const input_t *output,
const acc_t scale,
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads)
{
TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 4096 );
if (key_seq_len == 0) {
return;
} else {
int log2_elements = log2_ceil(key_seq_len);
const int next_power_of_two = 1 << log2_elements;
int batch_count = batches * attn_heads * query_seq_len;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
// use 128 threads per block to maximimize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = batch_count/batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 1: // 2
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 2: // 4
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 3: // 8
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 4: // 16
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 5: // 32
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 6: // 64
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 7: // 128
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 8: // 256
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 9: // 512
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 10: // 1024
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 11: // 2048
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 12: // 4096
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 12>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
default:
break;
}
}
}
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
// #include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_masked_softmax.h"
#include "type_shim.h"
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {
int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){
return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
}
torch::Tensor fwd_cuda(
torch::Tensor const& input,
torch::Tensor const& mask,
float scale_factor)
{
// input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
const int batches = input.size(0);
const int pad_batches = mask.size(0);
const int attn_heads = input.size(1);
const int query_seq_len = input.size(2);
const int key_seq_len = input.size(3);
TORCH_INTERNAL_ASSERT(key_seq_len <= 4096);
TORCH_INTERNAL_ASSERT(query_seq_len > 1);
TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);
// Output
auto act_options = input.options().requires_grad(false);
torch::Tensor softmax_results =
torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
// Softmax Intermediate Result Ptr
void* input_ptr = static_cast<void*>(input.data_ptr());
void* mask_ptr = static_cast<void*>(mask.data_ptr());
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
DISPATCH_HALF_AND_BFLOAT(
input.scalar_type(),
"dispatch_scaled_masked_softmax_forward",
dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(softmax_results_ptr),
reinterpret_cast<const scalar_t*>(input_ptr),
reinterpret_cast<const uint8_t*>(mask_ptr),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads,
pad_batches);
);
return softmax_results;
}
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads_,
torch::Tensor const& softmax_results_,
float scale_factor) {
auto output_grads = output_grads_.contiguous();
auto softmax_results = softmax_results_.contiguous();
//output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
const int batches = output_grads.size(0);
const int attn_heads = output_grads.size(1);
const int query_seq_len = output_grads.size(2);
const int key_seq_len = output_grads.size(3);
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
//Softmax Grad
DISPATCH_HALF_AND_BFLOAT(
output_grads_.scalar_type(),
"dispatch_scaled_masked_softmax_backward",
dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads);
);
//backward pass is completely in-place
return output_grads;
}
}
}
}
// !!! This is a file automatically generated by hipify!!!
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
// #include <cuda_profiler_api.h>
#include <ATen/hip/HIPContext.h>
#include <torch/extension.h>
#include "scaled_masked_softmax.h"
#include "type_shim.h"
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {
int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){
return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
}
torch::Tensor fwd_cuda(
torch::Tensor const& input,
torch::Tensor const& mask,
float scale_factor)
{
// input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
const int batches = input.size(0);
const int pad_batches = mask.size(0);
const int attn_heads = input.size(1);
const int query_seq_len = input.size(2);
const int key_seq_len = input.size(3);
TORCH_INTERNAL_ASSERT(key_seq_len <= 4096);
TORCH_INTERNAL_ASSERT(query_seq_len > 1);
TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);
// Output
auto act_options = input.options().requires_grad(false);
torch::Tensor softmax_results =
torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
// Softmax Intermediate Result Ptr
void* input_ptr = static_cast<void*>(input.data_ptr());
void* mask_ptr = static_cast<void*>(mask.data_ptr());
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
DISPATCH_HALF_AND_BFLOAT(
input.scalar_type(),
"dispatch_scaled_masked_softmax_forward",
dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(softmax_results_ptr),
reinterpret_cast<const scalar_t*>(input_ptr),
reinterpret_cast<const uint8_t*>(mask_ptr),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads,
pad_batches);
);
return softmax_results;
}
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads_,
torch::Tensor const& softmax_results_,
float scale_factor) {
auto output_grads = output_grads_.contiguous();
auto softmax_results = softmax_results_.contiguous();
//output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
const int batches = output_grads.size(0);
const int attn_heads = output_grads.size(1);
const int query_seq_len = output_grads.size(2);
const int key_seq_len = output_grads.size(3);
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
//Softmax Grad
DISPATCH_HALF_AND_BFLOAT(
output_grads_.scalar_type(),
"dispatch_scaled_masked_softmax_backward",
dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads);
);
//backward pass is completely in-place
return output_grads;
}
}
}
}
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#define __HIP_PLATFORM_HCC__
#include <hip/hip_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
float scale_factor);
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor);
torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return fwd_cuda(input, scale_factor);
}
torch::Tensor bwd(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return bwd_cuda(output_grads, softmax_results, scale_factor);
}
} // end namespace scaled_upper_triang_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward",
&multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
"Self Multihead Attention scaled, time masked softmax -- Forward.");
m.def("backward",
&multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
"Self Multihead Attention scaled, time masked softmax -- Backward.");
}
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <assert.h>
#define __HIP_PLATFORM_HCC__
#include <hip/hip_fp16.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <c10/macros/Macros.h>
namespace {
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }
template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }
template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }
template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }
template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); }
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_zero_vector(Datatype *dst);
template <>
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 1>(c10::BFloat16 *dst) { *dst = 0.0; }
template <>
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 4>(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }
template <>
__device__ __inline__ void copy_zero_vector<c10::Half, 1>(c10::Half *dst) { *dst = 0.0; }
template <>
__device__ __inline__ void copy_zero_vector<c10::Half, 4>(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }
int log2_ceil(int value) {
int log2_value = 0;
while ((1 << log2_value) < value) ++log2_value;
return log2_value;
}
template<typename T>
struct Add {
__device__ __forceinline__ T operator()(T a, T b) const {
return a + b;
}
};
template<typename T>
struct Max {
__device__ __forceinline__ T operator()(T a, T b) const {
return a < b ? b : a;
}
};
template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if CUDA_VERSION >= 9000
return __shfl_xor(value, laneMask, width);
#else
return __shfl_xor(value, laneMask, width);
#endif
}
template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t* sum) {
ReduceOp<acc_t> r;
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
sum[i] = r(sum[i], b);
}
}
}
/*
* Extended softmax (from native aten pytorch) with following additional features
* 1) input scaling
* 2) Implicit time (diagonal masking)
*/
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_forward(
output_t *dst,
const input_t *src,
const acc_t scale,
int micro_batch_size,
int stride,
int element_count)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_forward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
int local_seq = blockIdx.x + 1;
int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE;
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
int local_batches = micro_batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
// load data from global memory
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
input_t temp_data[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : local_seq;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + i*element_count*stride + it*WARP_SIZE);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if ((element_index + element) < batch_element_count) {
elements[i][it+element] = (acc_t)temp_data[element] * scale;
} else {
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
}
}
} else {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
}
}
}
}
// compute max_value
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = elements[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
acc_t sum[WARP_BATCH] { 0.0f };
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
if (it < warp_iteration_limit) {
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
sum[i] += elements[i][it];
}
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < local_seq) {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (element_index + element < local_seq) {
out[element] = elements[i][it + element] / sum[i];
} else {
out[element] = 0;
}
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE, out);
} else if (element_index < element_count) {
copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE);
} else {
break;
}
}
}
}
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_backward(
output_t *gradInput,
input_t *grad,
const input_t *output,
acc_t scale,
int micro_batch_size,
int stride,
int element_count)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_backward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
int local_seq = blockIdx.x + 1;
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
int local_batches = micro_batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
// the first element to process by the current thread
int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
grad += thread_offset;
output += thread_offset;
gradInput += thread_offset;
// load data from global memory
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
input_t temp_grad[ELEMENTS_PER_LDG_STG];
input_t temp_output[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : local_seq;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count * stride + it * WARP_SIZE);
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count * stride + it * WARP_SIZE);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (element_index + element < batch_element_count) {
output_reg[i][it + element] = (acc_t)temp_output[element];
}
}
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (element_index + element < batch_element_count) {
grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
}
}
}
}
}
acc_t sum[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] = grad_reg[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
sum[i] += grad_reg[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// compute gradients
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count * stride + it * WARP_SIZE, out);
}
}
}
}
} // end of anonymous namespace
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_forward(
output_t *dst,
const input_t *src,
const input_t scale,
int softmax_elements,
int softmax_elements_stride,
int attn_batches)
{
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 );
if (softmax_elements == 0) {
return;
} else {
int log2_elements = log2_ceil(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
int seq_len = softmax_elements;
int batch_count = attn_batches * seq_len;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
// use 128 threads per block to maximimize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
int blocks_per_seq = attn_batches / batches_per_block;
dim3 blocks(seq_len, blocks_per_seq, 1);
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 1: // 2
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 2: // 4
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 3: // 8
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 4: // 16
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 5: // 32
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 6: // 64
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 7: // 128
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 8: // 256
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 9: // 512
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 10: // 1024
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 11: // 2048
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
default:
break;
}
}
}
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_backward(
output_t *grad_input,
input_t *grad,
const input_t *output,
const acc_t scale,
int softmax_elements,
int softmax_elements_stride,
int attn_batches)
{
TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 );
if (softmax_elements == 0) {
return;
} else {
int log2_elements = log2_ceil(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
int seq_len = softmax_elements;
int batch_count = attn_batches * seq_len;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
// use 128 threads per block to maximimize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
int blocks_per_seq = attn_batches / batches_per_block;
dim3 blocks(seq_len, blocks_per_seq, 1);
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 1: // 2
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 2: // 4
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 3: // 8
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 4: // 16
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 5: // 32
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 6: // 64
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 7: // 128
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 8: // 256
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 9: // 512
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 10: // 1024
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 11: // 2048
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
default:
break;
}
}
}
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
// #include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_upper_triang_masked_softmax.h"
#include "type_shim.h"
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
float scale_factor)
{
// input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
const int attn_batches = input.size(0);
const int seq_len = input.size(1);
TORCH_INTERNAL_ASSERT(seq_len <= 2048);
// Output
auto act_options = input.options().requires_grad(false);
torch::Tensor softmax_results =
torch::empty({attn_batches, seq_len, seq_len}, act_options);
// Softmax Intermediate Result Ptr
void* input_ptr = static_cast<void*>(input.data_ptr());
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
DISPATCH_HALF_AND_BFLOAT(
input.scalar_type(),
"dispatch_scaled_upper_triang_masked_softmax_forward",
dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(softmax_results_ptr),
reinterpret_cast<const scalar_t*>(input_ptr),
scale_factor,
seq_len,
seq_len,
attn_batches);
);
return softmax_results;
}
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads_,
torch::Tensor const& softmax_results_,
float scale_factor) {
auto output_grads = output_grads_.contiguous();
auto softmax_results = softmax_results_.contiguous();
//output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
const int attn_batches = output_grads.size(0);
const int seq_len = output_grads.size(1);
TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2));
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
//Softmax Grad
DISPATCH_HALF_AND_BFLOAT(
output_grads_.scalar_type(),
"dispatch_scaled_upper_triang_masked_softmax_backward",
dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
scale_factor,
seq_len,
seq_len,
attn_batches);
);
//backward pass is completely in-place
return output_grads;
}
}
}
}
// !!! This is a file automatically generated by hipify!!!
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
// #include <cuda_profiler_api.h>
#include <ATen/hip/HIPContext.h>
#include <torch/extension.h>
#include "scaled_upper_triang_masked_softmax.h"
#include "type_shim.h"
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
float scale_factor)
{
// input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
const int attn_batches = input.size(0);
const int seq_len = input.size(1);
TORCH_INTERNAL_ASSERT(seq_len <= 2048);
// Output
auto act_options = input.options().requires_grad(false);
torch::Tensor softmax_results =
torch::empty({attn_batches, seq_len, seq_len}, act_options);
// Softmax Intermediate Result Ptr
void* input_ptr = static_cast<void*>(input.data_ptr());
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
DISPATCH_HALF_AND_BFLOAT(
input.scalar_type(),
"dispatch_scaled_upper_triang_masked_softmax_forward",
dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(softmax_results_ptr),
reinterpret_cast<const scalar_t*>(input_ptr),
scale_factor,
seq_len,
seq_len,
attn_batches);
);
return softmax_results;
}
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads_,
torch::Tensor const& softmax_results_,
float scale_factor) {
auto output_grads = output_grads_.contiguous();
auto softmax_results = softmax_results_.contiguous();
//output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
const int attn_batches = output_grads.size(0);
const int seq_len = output_grads.size(1);
TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2));
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
//Softmax Grad
DISPATCH_HALF_AND_BFLOAT(
output_grads_.scalar_type(),
"dispatch_scaled_upper_triang_masked_softmax_backward",
dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
scale_factor,
seq_len,
seq_len,
attn_batches);
);
//backward pass is completely in-place
return output_grads;
}
}
}
}
import math
import torch
from torch.nn import LayerNorm
from megatron.model.enums import AttnMaskType
from megatron.model.fused_layer_norm import MixedFusedLayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.utils import attention_mask_func
def test_load_fused_kernels():
try:
import fused_mix_prec_layer_norm_cuda
import scaled_masked_softmax_cuda
import scaled_upper_triang_masked_softmax_cuda
import torch
print("[Success] load_fused_kernels")
except ImportError as e:
print("[Fail] load_fused_kernels")
raise e
def test_fused_softmax():
bert = BertModel.from_pretrained("bert-base-cased").cuda().half()
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
test_text = (
"Hello. How are you? I am fine thank you and you? yes Good. "
"hi hi hi hi hi hi hi hi hi hi hi hi hi" # 32
)
tokens = tokenizer(
[test_text] * 4,
return_tensors="pt",
)
embedding_output = bert.embeddings(
input_ids=tokens["input_ids"].cuda(),
position_ids=None,
token_type_ids=tokens["token_type_ids"].cuda(),
inputs_embeds=None,
past_key_values_length=0,
)
# (bsz, 1, 1, seq_len)
mask = bert.get_extended_attention_mask(
attention_mask=tokens["attention_mask"].cuda(),
input_shape=tokens["input_ids"].shape,
device=bert.device,
)
# (bsz, 1, seq_len, seq_len)
mask = mask.repeat(1, 1, mask.size()[-1], 1)
attention = bert.encoder.layer[0].attention.self
key_layer = attention.transpose_for_scores(attention.key(embedding_output))
query_layer = attention.transpose_for_scores(attention.query(embedding_output))
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores /= math.sqrt(key_layer.size()[-1])
fused_softmax = (
FusedScaleMaskSoftmax(
input_in_fp16=True,
input_in_bf16=False,
mask_func=attention_mask_func,
scale=None,
softmax_in_fp32=False,
attn_mask_type=AttnMaskType.padding,
scaled_masked_softmax_fusion=True,
)
.cuda()
.half()
)
fused_softmax_output = fused_softmax(
attention_scores,
(mask != 0),
)
torch_softmax = (
FusedScaleMaskSoftmax(
input_in_fp16=True,
input_in_bf16=False,
mask_func=attention_mask_func,
scale=None,
softmax_in_fp32=False,
attn_mask_type=AttnMaskType.padding,
scaled_masked_softmax_fusion=False,
)
.cuda()
.half()
)
torch_softmax_output = torch_softmax(
attention_scores,
(mask != 0),
)
test_result = (fused_softmax_output - torch_softmax_output).abs()
while test_result.dim() != 1:
test_result = test_result.mean(dim=-1)
diff = test_result.mean(dim=-1)
if diff <= 1e-3:
print(
f"\n[Success] test_fused_softmax"
f"\n > mean_difference={diff}"
f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}"
f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
)
else:
print(
f"\n[Fail] test_fused_softmax"
f"\n > mean_difference={diff}, "
f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, "
f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
)
def test_fused_upper_triangle_mask_softmax():
gpt = GPT2Model.from_pretrained("gpt2").cuda().half()
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
test_text = (
"Hello. How are you? I am fine thank you and you? yes Good. "
"hi hi hi hi hi hi hi" # 24
)
tokens = tokenizer(
[test_text] * 4,
return_tensors="pt",
)
attention_mask = tokens["attention_mask"].cuda()
attention_mask = attention_mask.view(attention_mask.size(0), -1)
attention_mask = attention_mask[:, None, None, :]
attention_mask = (1.0 - attention_mask) * -10000.0
attention_mask = attention_mask.repeat(1, 1, attention_mask.size()[-1], 1)
attn = gpt.h[0]
hidden_states = gpt.wte(tokens["input_ids"].cuda())
q, k, v = attn.attn.c_attn(hidden_states).split(768, dim=-1)
q = attn.attn._split_heads(q, attn.attn.num_heads, attn.attn.head_dim)
k = attn.attn._split_heads(k, attn.attn.num_heads, attn.attn.head_dim)
attn_weights = torch.matmul(q, k.transpose(-1, -2))
sq, sk = q.size(-2), k.size(-2)
causal_mask = attn.attn.bias[:, :, sk - sq : sk, :sk].bool()
total_mask = ~(causal_mask & (attention_mask == 0))
"""
tensor([[[[False, True, True, ..., True, True, True],
[False, False, True, ..., True, True, True],
[False, False, False, ..., True, True, True],
...,
[False, False, False, ..., False, True, True],
[False, False, False, ..., False, False, True],
[False, False, False, ..., False, False, False]]]
"""
fused_softmax = (
FusedScaleMaskSoftmax(
input_in_fp16=True,
input_in_bf16=False,
mask_func=attention_mask_func,
scale=None,
softmax_in_fp32=False,
attn_mask_type=AttnMaskType.causal,
scaled_masked_softmax_fusion=True,
)
.cuda()
.half()
)
fused_softmax_output = fused_softmax(
attn_weights,
total_mask,
)
torch_softmax = (
FusedScaleMaskSoftmax(
input_in_fp16=True,
input_in_bf16=False,
mask_func=attention_mask_func,
scale=None,
softmax_in_fp32=False,
attn_mask_type=AttnMaskType.causal,
scaled_masked_softmax_fusion=False,
)
.cuda()
.half()
)
torch_softmax_output = torch_softmax(
attn_weights,
total_mask,
)
test_result = (fused_softmax_output - torch_softmax_output).abs()
while test_result.dim() != 1:
test_result = test_result.mean(dim=-1)
diff = test_result.mean(dim=-1)
if diff <= 1e-3:
print(
f"\n[Success] test_fused_upper_triangle_mask_softmax"
f"\n > mean_difference={diff}"
f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}"
f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
)
else:
print(
f"\n[Fail] test_fused_upper_triangle_mask_softmax"
f"\n > mean_difference={diff}, "
f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, "
f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
)
def test_layer_norm():
bert = BertModel.from_pretrained("bert-base-cased").cuda().half()
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
test_text = (
"Hello. How are you? I am fine thank you and you? yes Good. "
"hi hi hi hi hi hi hi hi hi hi hi hi hi" # 32
)
tokens = tokenizer(
[test_text] * 4,
return_tensors="pt",
)
# [bsz, seq_len, d_model]
embedding_output = (
bert.embeddings(
input_ids=tokens["input_ids"].cuda(),
position_ids=None,
token_type_ids=tokens["token_type_ids"].cuda(),
inputs_embeds=None,
past_key_values_length=0,
)
.cuda()
.half()
)
fused_layernorm_layer = (
MixedFusedLayerNorm(normalized_shape=embedding_output.size(-1)).cuda().half()
)
torch_layernorm_layer = (
LayerNorm(normalized_shape=embedding_output.size(-1)).cuda().half()
)
fused_output = fused_layernorm_layer(embedding_output)
torch_output = torch_layernorm_layer(embedding_output)
test_result = (fused_output - torch_output).abs()
while test_result.dim() != 1:
test_result = test_result.mean(dim=-1)
diff = test_result.mean(dim=-1)
if diff <= 1e-3:
print(
f"\n[Success] test_layer_norm"
f"\n > mean_difference={diff}"
f"\n > fused_values={fused_output[-1][-1][:5].tolist()}"
f"\n > torch_values={torch_output[-1][-1][:5].tolist()}"
)
else:
print(
f"\n[Fail] test_layer_norm"
f"\n > mean_difference={diff}, "
f"\n > fused_values={fused_output[-1][-1][:5].tolist()}, "
f"\n > torch_values={torch_output[-1][-1][:5].tolist()}"
)
if __name__ == "__main__":
try:
from transformers import BertTokenizer, GPT2Tokenizer
from transformers.models.bert.modeling_bert import BertModel
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
import transformers
transformers.logging.set_verbosity(
transformers.logging.FATAL,
)
except:
print("\n[Fail] Please install `transformers` package to test fused kernels\n")
exit(-1)
test_load_fused_kernels()
test_fused_softmax()
test_fused_upper_triangle_mask_softmax()
test_layer_norm()
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include "compat.h"
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Half: \
{ \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch(TYPEIN) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_in = float; \
switch(TYPEOUT) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_in = at::Half; \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_in = at::BFloat16; \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
}
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron global variables."""
import functools
import os
import sys
import time
from packaging import version
from pathlib import Path
import torch
from megatron.tokenizer import build_tokenizer
from .arguments import parse_args
from .microbatches import build_num_microbatches_calculator
_GLOBAL_ARGS = None
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
_GLOBAL_TOKENIZER = None
_GLOBAL_TENSORBOARD_WRITER = None
_GLOBAL_CODECARBON_TRACKER = None
_GLOBAL_ADLR_AUTORESUME = None
_GLOBAL_TIMERS = None
def get_args():
"""Return arguments."""
_ensure_var_is_initialized(_GLOBAL_ARGS, 'args')
return _GLOBAL_ARGS
def get_num_microbatches():
return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()
def get_current_global_batch_size():
return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_global_batch_size()
def update_num_microbatches(consumed_samples, consistency_check=True):
_GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples,
consistency_check)
def get_tokenizer():
"""Return tokenizer."""
_ensure_var_is_initialized(_GLOBAL_TOKENIZER, 'tokenizer')
return _GLOBAL_TOKENIZER
def get_tensorboard_writer():
"""Return tensorboard writer. It can be None so no need
to check if it is initialized."""
return _GLOBAL_TENSORBOARD_WRITER
def get_codecarbon_tracker():
"""Return codecarbon tracker. It can be None so no need
to check if it is initialized."""
return _GLOBAL_CODECARBON_TRACKER
def get_adlr_autoresume():
"""ADLR autoresume object. It can be None so no need
to check if it is initialized."""
return _GLOBAL_ADLR_AUTORESUME
def get_timers():
"""Return timers."""
_ensure_var_is_initialized(_GLOBAL_TIMERS, 'timers')
return _GLOBAL_TIMERS
def set_global_variables(extra_args_provider=None, args_defaults={},
ignore_unknown_args=False):
"""Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers."""
args = _parse_args(extra_args_provider=extra_args_provider,
defaults=args_defaults,
ignore_unknown_args=ignore_unknown_args)
_build_num_microbatches_calculator(args)
if args.vocab_file or args.tokenizer_name_or_path:
_ = _build_tokenizer(args)
_set_tensorboard_writer(args)
_set_codecarbon_tracker(args)
_set_adlr_autoresume(args)
_set_timers()
def _parse_args(extra_args_provider=None, defaults={},
ignore_unknown_args=False):
"""Parse entire arguments."""
global _GLOBAL_ARGS
_ensure_var_is_not_initialized(_GLOBAL_ARGS, 'args')
_GLOBAL_ARGS = parse_args(extra_args_provider=extra_args_provider,
defaults=defaults,
ignore_unknown_args=ignore_unknown_args)
return _GLOBAL_ARGS
def _build_num_microbatches_calculator(args):
global _GLOBAL_NUM_MICROBATCHES_CALCULATOR
_ensure_var_is_not_initialized(_GLOBAL_NUM_MICROBATCHES_CALCULATOR,
'num microbatches calculator')
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_calculator(
args)
def _build_tokenizer(args):
"""Initialize tokenizer."""
global _GLOBAL_TOKENIZER
_ensure_var_is_not_initialized(_GLOBAL_TOKENIZER, 'tokenizer')
_GLOBAL_TOKENIZER = build_tokenizer(args)
return _GLOBAL_TOKENIZER
def rebuild_tokenizer(args):
global _GLOBAL_TOKENIZER
_GLOBAL_TOKENIZER = None
return _build_tokenizer(args)
def _set_tensorboard_writer(args):
"""Set tensorboard writer."""
global _GLOBAL_TENSORBOARD_WRITER
_ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER,
'tensorboard writer')
if hasattr(args, 'tensorboard_dir') and \
args.tensorboard_dir and args.rank == (args.world_size - 1):
try:
from torch.utils.tensorboard import SummaryWriter
print('> setting tensorboard ...')
_GLOBAL_TENSORBOARD_WRITER = SummaryWriter(
log_dir=args.tensorboard_dir,
max_queue=args.tensorboard_queue_size)
# this is supposed to make the data load in TB faster
if version.parse(torch.__version__) >= version.parse("1.9"):
_GLOBAL_TENSORBOARD_WRITER.add_scalar = functools.partial(
_GLOBAL_TENSORBOARD_WRITER.add_scalar, new_style=True
)
except ModuleNotFoundError:
print('WARNING: TensorBoard writing requested but is not '
'available (are you using PyTorch 1.1.0 or later?), '
'no TensorBoard logs will be written.', flush=True)
# Important: the codecarbon is very unstable and its latest incarnation using the python scheduler interferes with the asyncio library we use in the test suite which breaks everything, so making this a no-op for now.
def _set_codecarbon_tracker(args):
return # turned off
global _GLOBAL_CODECARBON_TRACKER
if not hasattr(args, 'codecarbon_dir') or args.codecarbon_dir is None:
return
import codecarbon
if args.rank == 0:
print('> setting codecarbon ...')
output_dir = args.codecarbon_dir
output_file = f"emissions-{args.rank:03d}.csv"
logger_preamble = f"r{args.rank:03d}"
log_level = "error"
country_iso_code = "FRA"
# CC was emitting all kinds of warnings about issues with measurements, so the following
# settings are supposed to help
misfire_grace_time = 3
measure_power_secs = 60
max_instances = 3
Path(output_dir).mkdir(parents=True, exist_ok=True)
_GLOBAL_CODECARBON_TRACKER = codecarbon.OfflineEmissionsTracker(
output_dir=output_dir,
output_file=output_file,
logger_preamble=logger_preamble,
log_level=log_level,
misfire_grace_time=misfire_grace_time,
measure_power_secs=measure_power_secs,
max_instances=max_instances,
country_iso_code=country_iso_code,
)
def codecarbon_tracker_start():
return # turned off, see the notes above
global _GLOBAL_CODECARBON_TRACKER
if _GLOBAL_CODECARBON_TRACKER is None:
return
#print("CC START")
_GLOBAL_CODECARBON_TRACKER.start()
def codecarbon_tracker_stop():
return # turned off, see the notes above
global _GLOBAL_CODECARBON_TRACKER
if _GLOBAL_CODECARBON_TRACKER is None:
return
#print("CC STOP")
_GLOBAL_CODECARBON_TRACKER.stop()
def codecarbon_tracker_flush():
return # turned off, see the notes above
global _GLOBAL_CODECARBON_TRACKER
if _GLOBAL_CODECARBON_TRACKER is None:
return
#print("CC FLUSH")
_GLOBAL_CODECARBON_TRACKER.flush()
def _set_adlr_autoresume(args):
"""Initialize ADLR autoresume."""
global _GLOBAL_ADLR_AUTORESUME
_ensure_var_is_not_initialized(_GLOBAL_ADLR_AUTORESUME, 'adlr autoresume')
if args.adlr_autoresume:
if args.rank == 0:
print('enabling autoresume ...', flush=True)
sys.path.append(os.environ.get('SUBMIT_SCRIPTS', '.'))
try:
from userlib.auto_resume import AutoResume
except BaseException:
print('ADLR autoresume is not available, exiting ...')
sys.exit()
_GLOBAL_ADLR_AUTORESUME = AutoResume
def _set_timers():
"""Initialize timers."""
global _GLOBAL_TIMERS
_ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers')
_GLOBAL_TIMERS = Timers()
def _ensure_var_is_initialized(var, name):
"""Make sure the input variable is not None."""
assert var is not None, '{} is not initialized.'.format(name)
def _ensure_var_is_not_initialized(var, name):
"""Make sure the input variable is not None."""
assert var is None, '{} is already initialized.'.format(name)
class _Timer:
"""Timer."""
def __init__(self, name):
self.name_ = name
self.elapsed_ = 0.0
self.started_ = False
self.start_time = time.time()
def start(self):
"""Start the timer."""
assert not self.started_, 'timer has already been started'
torch.cuda.synchronize()
self.start_time = time.time()
self.started_ = True
def stop(self):
"""Stop the timer."""
assert self.started_, 'timer is not started'
torch.cuda.synchronize()
self.elapsed_ += (time.time() - self.start_time)
self.started_ = False
def reset(self):
"""Reset timer."""
self.elapsed_ = 0.0
self.started_ = False
def elapsed(self, reset=True):
"""Calculate the elapsed time."""
started_ = self.started_
# If the timing in progress, end it first.
if self.started_:
self.stop()
# Get the elapsed time.
elapsed_ = self.elapsed_
# Reset the elapsed time
if reset:
self.reset()
# If timing was in progress, set it back.
if started_:
self.start()
return elapsed_
class Timers:
"""Group of timers."""
def __init__(self):
self.timers = {}
def __call__(self, name):
if name not in self.timers:
self.timers[name] = _Timer(name)
return self.timers[name]
def write(self, names, writer, iteration, normalizer=1.0, reset=False):
"""Write timers to a tensorboard writer"""
# currently when using add_scalars,
# torch.utils.add_scalars makes each timer its own run, which
# polutes the runs list, so we just add each as a scalar
assert normalizer > 0.0
for name in names:
value = self.timers[name].elapsed(reset=reset) / normalizer
writer.add_scalar(f'time/{name}-time', value, iteration)
def log(self, names, normalizer=1.0, reset=True):
"""Log a group of timers."""
assert normalizer > 0.0
string = ''
for name in names:
elapsed_time = self.timers[name].elapsed(
reset=reset) * 1000.0 / normalizer
string += ' | {}: {:.2f}'.format(name, elapsed_time)
if not len(string):
return
string = 'time (ms)' + string
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == (
torch.distributed.get_world_size() - 1):
print(string, flush=True)
else:
print(string, flush=True)
import sys
import torch
import torch.distributed as dist
from megatron import get_args
from megatron import mpu
from megatron.checkpointing import load_biencoder_checkpoint
from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
from megatron.data.orqa_wiki_dataset import get_open_retrieval_batch
from megatron.data.biencoder_dataset_utils import get_one_epoch_dataloader
from megatron.data.realm_index import detach, OpenRetreivalDataStore
from megatron.model.biencoder_model import biencoder_model_provider
from megatron.training import get_model
class IndexBuilder(object):
"""
Object for taking one pass over a dataset and creating a BlockData of its
embeddings
"""
def __init__(self):
args = get_args()
self.model = None
self.dataloader = None
self.evidence_embedder_obj = None
self.biencoder_shared_query_context_model = \
args.biencoder_shared_query_context_model
# need to know whether we're using a REALM checkpoint (args.load)
# or ICT checkpoint
assert not (args.load and args.ict_load)
#self.using_realm_chkpt = args.ict_load is None
self.log_interval = args.indexer_log_interval
self.batch_size = args.indexer_batch_size
self.load_attributes()
self.is_main_builder = mpu.get_data_parallel_rank() == 0
self.num_total_builders = mpu.get_data_parallel_world_size()
self.iteration = self.total_processed = 0
def load_attributes(self):
"""
Load the necessary attributes: model, dataloader and empty BlockData
"""
only_context_model = True
if self.biencoder_shared_query_context_model:
only_context_model = False
model = get_model(lambda: biencoder_model_provider(only_context_model \
= only_context_model, biencoder_shared_query_context_model = \
self.biencoder_shared_query_context_model))
self.model = load_biencoder_checkpoint(model,
only_context_model=only_context_model)
assert len(self.model) == 1
self.model[0].eval()
self.dataset = get_open_retrieval_wiki_dataset()
self.dataloader = iter(get_one_epoch_dataloader(self.dataset, \
self.batch_size))
self.evidence_embedder_obj = OpenRetreivalDataStore( \
load_from_path=False)
def track_and_report_progress(self, batch_size):
"""
Utility function for tracking progress
"""
self.iteration += 1
self.total_processed += batch_size * self.num_total_builders
if self.is_main_builder and self.iteration % self.log_interval == 0:
print('Batch {:10d} | Total {:10d}'.format(self.iteration,
self.total_processed), flush=True)
def build_and_save_index(self):
"""
Goes through one epoch of the dataloader and adds all data to this
instance's BlockData.
The copy of BlockData is saved as a shard, which when run in a
distributed setting will be consolidated by the rank 0 process
and saved as a final pickled BlockData.
"""
assert len(self.model) == 1
unwrapped_model = self.model[0]
while not hasattr(unwrapped_model, 'embed_text'):
unwrapped_model = unwrapped_model.module
while True:
try:
# batch also has query_tokens and query_pad_data
row_id, context_tokens, context_mask, context_types, \
context_pad_mask = get_open_retrieval_batch( \
self.dataloader)
except (StopIteration, IndexError):
break
# TODO: can we add with torch.no_grad() to reduce memory usage
# detach, separate fields and add to BlockData
assert context_mask.dtype == torch.bool
context_logits = unwrapped_model.embed_text(
unwrapped_model.context_model, context_tokens, context_mask,
context_types)
context_logits = detach(context_logits)
row_id = detach(row_id)
self.evidence_embedder_obj.add_block_data(row_id, context_logits)
self.track_and_report_progress(batch_size=len(row_id))
# This process signals to finalize its shard and then synchronize with
# the other processes
self.evidence_embedder_obj.save_shard()
torch.distributed.barrier()
del self.model
# rank 0 process builds the final copy
if self.is_main_builder:
self.evidence_embedder_obj.merge_shards_and_save()
# make sure that every single piece of data was embedded
assert len(self.evidence_embedder_obj.embed_data) == \
len(self.dataset)
self.evidence_embedder_obj.clear()
# complete building the final copy
torch.distributed.barrier()
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron initialization."""
import random
import os
import sys
import time
import numpy as np
import torch
import logging as lg
import subprocess
from megatron import fused_kernels, logging
from megatron import get_adlr_autoresume
from megatron import get_args
from megatron import get_tensorboard_writer
from megatron import mpu
from megatron.global_vars import set_global_variables
from megatron.mpu import (set_tensor_model_parallel_rank,
set_tensor_model_parallel_world_size)
import deepspeed
def git_ds_info():
args = get_args()
if not args.deepspeed:
return
from deepspeed.env_report import main as ds_report
ds_report(hide_operator_status=True, hide_errors_and_warnings=True)
def command_exists(cmd):
result = subprocess.Popen(f'type {cmd}', stdout=subprocess.PIPE, shell=True)
return result.wait() == 0
# Write out version/git info
git_hash_cmd = "git rev-parse --short HEAD 2>&1"
git_branch_cmd = "git rev-parse --abbrev-ref HEAD 2>&1"
if command_exists('git'):
try:
result = subprocess.check_output(git_hash_cmd, shell=True)
git_hash = result.decode('utf-8').strip()
result = subprocess.check_output(git_branch_cmd, shell=True)
git_branch = result.decode('utf-8').strip()
except subprocess.CalledProcessError:
git_hash = "unknown"
git_branch = "unknown"
else:
git_hash = "unknown"
git_branch = "unknown"
print(f'**** Git info for Megatron: git_hash={git_hash} git_branch={git_branch} ****')
def initialize_megatron(extra_args_provider=None, args_defaults={},
ignore_unknown_args=False, allow_no_cuda=False):
"""Set global variables, initialize distributed, and
set autoresume and random seeds.
`allow_no_cuda` should not be set unless using megatron for cpu only
data processing. In general this arg should not be set unless you know
what you are doing.
Returns a function to finalize distributed env initialization
(optionally, only when args.lazy_mpu_init == True)
"""
if not allow_no_cuda:
# Make sure cuda is available.
assert torch.cuda.is_available(), 'Megatron requires CUDA.'
# Parse args, build tokenizer, and set adlr-autoresume,
# tensorboard-writer, and timers.
set_global_variables(extra_args_provider=extra_args_provider,
args_defaults=args_defaults,
ignore_unknown_args=ignore_unknown_args)
# torch.distributed initialization
def finish_mpu_init():
args = get_args()
# Pytorch distributed.
_initialize_distributed()
# Random seeds for reproducibility.
if args.rank == 0:
print('> setting random seeds to {} ...'.format(args.seed))
def set_verbosity(logging_level: str):
log_level = logging.log_levels[logging_level]
logging.set_verbosity(log_level)
logging.disable_default_handler()
handler = lg.StreamHandler(sys.stdout)
handler.setLevel(log_level)
handler.flush = sys.stderr.flush
logging.add_handler(handler)
def set_verbosity_deepspeed(logging_level: str):
if not args.deepspeed:
return
from deepspeed.utils import logger as ds_logger
log_level = logging.log_levels[logging_level]
ds_logger.setLevel(log_level)
def set_verbosity_transformers(logging_level: str):
try:
# XXX: perhaps we need a better way of knowing when to override transformers logging
# currently it's only when using `--tokenizer-type PretrainedFromHF`
from transformers.utils import logging as transformers_logging
log_level = logging.log_levels[logging_level]
logging.set_verbosity(log_level)
except:
pass
if args.rank == 0:
if args.log_level is not None:
set_verbosity(args.log_level)
set_verbosity_deepspeed(args.log_level)
set_verbosity_transformers(args.log_level)
else:
if args.log_level_replica is not None:
set_verbosity(args.log_level_replica)
set_verbosity_deepspeed(args.log_level_replica)
set_verbosity_transformers(args.log_level_replica)
_set_random_seed(args.seed)
args = get_args()
if args.rank == 0:
git_ds_info()
if args.lazy_mpu_init:
args.use_cpu_initialization=True
# delayed initialization of DDP-related stuff
# We only set basic DDP globals
set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
# and return function for external DDP manager
# to call when it has DDP initialized
set_tensor_model_parallel_rank(args.rank)
return finish_mpu_init
else:
# Megatron's MPU is the master. Complete initialization right away.
finish_mpu_init()
# Initialize memory buffers.
_initialize_mem_buffs()
# Autoresume.
_init_autoresume()
# Compile dependencies.
_compile_dependencies()
# No continuation function
return None
def _compile_dependencies():
args = get_args()
# =========================
# Compile dataset C++ code.
# =========================
# TODO: move this to ninja
if torch.distributed.get_rank() == 0:
start_time = time.time()
print('> compiling dataset index builder ...')
from megatron.data.dataset_utils import compile_helper
compile_helper()
print('>>> done with dataset index builder. Compilation time: {:.3f} '
'seconds'.format(time.time() - start_time), flush=True)
# ==================
# Load fused kernels
# ==================
# Custom kernel constraints check.
seq_len = args.seq_length
attn_batch_size = \
(args.num_attention_heads / args.tensor_model_parallel_size) * \
args.micro_batch_size
# Constraints on sequence length and attn_batch_size to enable warp based
# optimization and upper triangular optimization (for causal mask)
custom_kernel_constraint = seq_len > 16 and seq_len <=2048 and \
seq_len % 4 == 0 and attn_batch_size % 4 == 0
# Print a warning.
if not ((args.fp16 or args.bf16) and
custom_kernel_constraint and
args.masked_softmax_fusion):
if args.rank == 0:
error = "constraints for invoking optimized fused softmax kernel are not met"
if args.abort_on_unmet_fused_kernel_constraints:
sys.exit(f"\n\nERROR: {error} and --abort-on-unmet-fused-kernel-constraints was passed. Aborting.\n\n")
else:
print(f'WARNING: {error}. We default back to unfused kernel invocations.', flush=True)
# Always build on rank zero first.
if torch.distributed.get_rank() == 0:
start_time = time.time()
print('> compiling and loading fused kernels ...', flush=True)
fused_kernels.load(args)
torch.distributed.barrier()
else:
torch.distributed.barrier()
import warnings
with warnings.catch_warnings():
# ignore loading noise
warnings.simplefilter("ignore")
fused_kernels.load(args)
# Simple barrier to make sure all ranks have passed the
# compilation phase successfully before moving on to the
# rest of the program. We think this might ensure that
# the lock is released.
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print('>>> done with compiling and loading fused kernels. '
'Compilation time: {:.3f} seconds'.format(
time.time() - start_time), flush=True)
def setup_deepspeed_random_and_activation_checkpointing(args):
'''Optional DeepSpeed Activation Checkpointing features.
Gives access to partition activations, contiguous memory optimizations
and cpu checkpointing.
Activation checkpoint requires keep track of the random states
and setting the random seed for each MP process. Megatron uses
mpu.get_cuda_rng_tracker and mpu.model_parallel_cuda_manual_seed
for keeping track of the random states and setting the random seeds.
Since they are used in places outside of activation checkpointing,
we overwrite them to maintain consistency.
This must be called before all the calls to mpu.model_parallel_cuda_manual_seed
'''
num_layers = args.num_layers // args.checkpoint_num_layers
num_layers = num_layers if args.num_layers % args.checkpoint_num_layers == 0 else num_layers + 1
if args.split_transformers:
num_layers *= 2
deepspeed.checkpointing.configure(
mpu,
partition_activations=args.partition_activations,
contiguous_checkpointing=args.contigious_checkpointing,
num_checkpoints=num_layers,
checkpoint_in_cpu=args.checkpoint_in_cpu,
synchronize=args.synchronize_each_layer,
profile=args.profile_backward)
mpu.checkpoint = deepspeed.checkpointing.checkpoint
mpu.get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
mpu.model_parallel_cuda_manual_seed = deepspeed.checkpointing.model_parallel_cuda_manual_seed
def _initialize_distributed():
"""Initialize torch.distributed and mpu."""
args = get_args()
device_count = torch.cuda.device_count()
if torch.distributed.is_initialized():
if args.rank == 0:
print('torch distributed is already initialized, '
'skipping initialization ...', flush=True)
args.rank = torch.distributed.get_rank()
args.world_size = torch.distributed.get_world_size()
else:
if args.rank == 0:
print('> initializing torch distributed ...', flush=True)
# Manually set the device ids.
if device_count > 0:
device = args.rank % device_count
if args.local_rank is not None:
assert args.local_rank == device, \
'expected local-rank to be the same as rank % device-count.'
else:
args.local_rank = device
torch.cuda.set_device(device)
# Call the init process
deepspeed.init_distributed(args.distributed_backend)
# Set the tensor model-parallel, pipeline model-parallel, and
# data-parallel communicators.
if device_count > 0:
if mpu.model_parallel_is_initialized():
print('model parallel is already initialized')
else:
mpu.initialize_model_parallel(args.tensor_model_parallel_size,
args.pipeline_model_parallel_size,
args.virtual_pipeline_model_parallel_size)
if args.deepspeed and args.deepspeed_activation_checkpointing:
setup_deepspeed_random_and_activation_checkpointing(args)
def _init_autoresume():
"""Set autoresume start time."""
autoresume = get_adlr_autoresume()
if autoresume:
torch.distributed.barrier()
autoresume.init()
torch.distributed.barrier()
def _set_random_seed(seed_):
"""Set random seed for reproducability."""
if seed_ is not None and seed_ > 0:
# Ensure that different pipeline MP stages get different seeds.
seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank())
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.device_count() > 0:
mpu.model_parallel_cuda_manual_seed(seed)
else:
raise ValueError('Seed ({}) should be a positive integer.'.format(seed))
def write_args_to_tensorboard():
"""Write arguments to tensorboard."""
args = get_args()
writer = get_tensorboard_writer()
if writer:
for arg in vars(args):
writer.add_text(arg, str(getattr(args, arg)),
global_step=args.iteration)
def log_restart_to_tensorboard():
"""
Log new start and world size - the key is to denote a restart, and use world_size as another
useful info which can help to track changes in resource allocation.
"""
args = get_args()
writer = get_tensorboard_writer()
if writer:
# emulate a blip to avoid flatline
writer.add_scalar('iteration-time/world_size', args.world_size, args.iteration)
writer.add_scalar('iteration-time/world_size', 0, args.iteration+1)
def _initialize_mem_buffs():
"""Initialize manually allocated static memory."""
args = get_args()
# Initialize memory for checkpointed activations.
if args.distribute_checkpointed_activations:
mpu.init_checkpointed_activations_memory_buffer()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment