Commit a1c29028 authored by zhangqha

update uni-fold
// !!! This is a file automatically generated by hipify!!!
#include <ATen/ATen.h>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
//#include <cuda_bf16.h>
#include <ATen/hip/HIPContext.h>
#include <torch/extension.h>
void fused_fp32_to_bf16_sr_cuda(at::Tensor & input, at::Tensor & output);
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
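// Host-side wrapper: checks that both tensors are CUDA tensors, contiguous,
// and have the same number of elements, then dispatches to the device kernel
// declared above.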
void fused_fp32_to_bf16_sr(at::Tensor & input, at::Tensor & output) {
CHECK_INPUT(input);
CHECK_INPUT(output);
int64_t num_elem = input.numel();
AT_ASSERTM(output.numel() == num_elem, "number of elements in input and output tensors should be equal");
fused_fp32_to_bf16_sr_cuda(input, output);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fp32_to_bf16_sr", &fused_fp32_to_bf16_sr, "fused fp32 to bf16 random rounding");
}
#include <torch/extension.h>
#include <vector>
#include <cassert>
namespace {
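// Flattens the input to a 2-D view: n1 is the product of the leading (batch)
// dimensions, n2 the product of the dimensions covered by normalized_shape;
// also asserts that the trailing dims of the input match normalized_shape.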
void compute_n1_n2(
at::Tensor input,
at::IntArrayRef normalized_shape,
int& n1,
int& n2)
{
int idiff = input.ndimension() - normalized_shape.size();
n2 = 1;
for (int i = 0; i < (int)normalized_shape.size(); ++i) {
assert( input.sizes()[i+idiff] == normalized_shape[i] );
n2 *= normalized_shape[i];
}
n1 = 1;
for (int i = 0; i < idiff; ++i) {
n1 *= input.sizes()[i];
}
}
void check_args(
at::IntArrayRef normalized_shape,
at::Tensor gamma,
at::Tensor beta
)
{
TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
}
void check_args(
at::Tensor input,
at::IntArrayRef normalized_shape,
int& n1,
int& n2
)
{
int64_t normalized_ndim = normalized_shape.size();
if (normalized_ndim < 1) {
std::stringstream ss;
ss << "Expected normalized_shape to be at least 1-dimensional, i.e., "
<< "containing at least one element, but got normalized_shape="
<< normalized_shape;
throw std::runtime_error(ss.str());
}
auto input_shape = input.sizes();
auto input_ndim = input.dim();
if (input_ndim < normalized_ndim ||
!input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) {
std::stringstream ss;
ss << "Given normalized_shape=" << normalized_shape
<< ", expected input with shape [*";
for (auto size : normalized_shape) {
ss << ", " << size;
}
ss << "], but got input of size" << input_shape;
throw std::runtime_error(ss.str());
}
compute_n1_n2(input,normalized_shape,n1,n2);
}
void check_args(
at::Tensor input,
at::IntArrayRef normalized_shape,
at::Tensor gamma,
at::Tensor beta,
int& n1,
int& n2
)
{
check_args(input,normalized_shape,n1,n2);
check_args(normalized_shape,gamma,beta);
}
}
void cuda_layer_norm(
at::Tensor* output,
at::Tensor* mean,
at::Tensor* invvar,
at::Tensor* input,
int n1,
int n2,
at::IntArrayRef normalized_shape,
at::Tensor* gamma,
at::Tensor* beta,
double epsilon);
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
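// Forward binding: validates the tensors, checks that the hidden size is
// supported, and returns {output, mean, invvar}; mean and invvar are kept in
// fp32 when the input is half or bfloat16.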
std::vector<at::Tensor> layer_norm(
at::Tensor input,
at::IntArrayRef normalized_shape,
at::Tensor gamma,
at::Tensor beta,
double epsilon) {
CHECK_INPUT(input);
CHECK_INPUT(gamma);
CHECK_INPUT(beta);
int n1,n2;
check_args(input,normalized_shape,gamma,beta,n1,n2);
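// Only hidden sizes with a pre-instantiated kernel template are supported;
// the list below mirrors the switch statements in the kernel launchers.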
TORCH_CHECK(n2 == 64 || n2 == 128 || n2 == 256 || n2 == 320 || n2 == 384 || n2 == 512 || n2 == 640 || n2 == 768 || n2 == 1024 || n2 == 1280 ||
n2 == 1536 || n2 == 1792 || n2 == 2048 || n2 == 2560 || n2 == 5120, "dimension is not supported");
at::Tensor output = at::empty_like(input);
at::Tensor mean = at::empty({n1}, input.options().dtype((input.scalar_type()==at::ScalarType::Half || input.scalar_type()==at::ScalarType::BFloat16) ? at::ScalarType::Float : input.scalar_type()));
at::Tensor invvar = at::empty_like(mean);
cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2,
normalized_shape,&gamma,&beta,epsilon);
return {output, mean, invvar};
}
void cuda_layer_norm_gradient(
at::Tensor* dout,
at::Tensor* mean,
at::Tensor* invvar,
at::Tensor* input,
int n1,
int n2,
at::IntArrayRef normalized_shape,
at::Tensor* gamma,
at::Tensor* beta,
double epsilon,
at::Tensor* grad_input
);
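// Backward binding for the input gradient only; grad_gamma / grad_beta are
// produced by the separate backward extension further below.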
at::Tensor layer_norm_gradient(
at::Tensor dout,
at::Tensor mean,
at::Tensor invvar,
at::Tensor input,
at::IntArrayRef normalized_shape,
at::Tensor gamma,
at::Tensor beta,
double epsilon) {
CHECK_INPUT(dout);
CHECK_INPUT(mean);
CHECK_INPUT(invvar);
CHECK_INPUT(input);
CHECK_INPUT(gamma);
CHECK_INPUT(beta);
int n1,n2;
check_args(input,normalized_shape,gamma,beta,n1,n2);
TORCH_CHECK(n2 == 64 || n2 == 128 || n2 == 256 || n2 == 320 || n2 == 384 || n2 == 512 || n2 == 640 || n2 == 768 || n2 == 1024 || n2 == 1280 ||
n2 == 1536 || n2 == 1792 || n2 == 2048 || n2 == 2560 || n2 == 5120, "dimension is not supported");
at::Tensor grad_input = at::empty_like(input);
cuda_layer_norm_gradient(&dout,&mean,&invvar,&input,n1,n2,
normalized_shape,&gamma,&beta,epsilon,
&grad_input);
return grad_input;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &layer_norm, "LayerNorm fast forward (CUDA)");
m.def("backward", &layer_norm_gradient, "LayerNorm fast backward (CUDA)");
}
#include <torch/extension.h>
#include <vector>
#include <cassert>
namespace {
void compute_n1_n2(
at::Tensor input,
at::IntArrayRef normalized_shape,
int& n1,
int& n2)
{
int idiff = input.ndimension() - normalized_shape.size();
n2 = 1;
for (int i = 0; i < (int)normalized_shape.size(); ++i) {
assert( input.sizes()[i+idiff] == normalized_shape[i] );
n2 *= normalized_shape[i];
}
n1 = 1;
for (int i = 0; i < idiff; ++i) {
n1 *= input.sizes()[i];
}
}
void check_args(
at::IntArrayRef normalized_shape,
at::Tensor gamma,
at::Tensor beta
)
{
TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
}
void check_args(
at::Tensor input,
at::IntArrayRef normalized_shape,
int& n1,
int& n2
)
{
int64_t normalized_ndim = normalized_shape.size();
if (normalized_ndim < 1) {
std::stringstream ss;
ss << "Expected normalized_shape to be at least 1-dimensional, i.e., "
<< "containing at least one element, but got normalized_shape="
<< normalized_shape;
throw std::runtime_error(ss.str());
}
auto input_shape = input.sizes();
auto input_ndim = input.dim();
if (input_ndim < normalized_ndim ||
!input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) {
std::stringstream ss;
ss << "Given normalized_shape=" << normalized_shape
<< ", expected input with shape [*";
for (auto size : normalized_shape) {
ss << ", " << size;
}
ss << "], but got input of size" << input_shape;
throw std::runtime_error(ss.str());
}
compute_n1_n2(input,normalized_shape,n1,n2);
}
void check_args(
at::Tensor input,
at::IntArrayRef normalized_shape,
at::Tensor gamma,
at::Tensor beta,
int& n1,
int& n2
)
{
check_args(input,normalized_shape,n1,n2);
check_args(normalized_shape,gamma,beta);
}
}
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
void cuda_layer_norm_gradient(
at::Tensor* dout,
at::Tensor* mean,
at::Tensor* invvar,
at::Tensor* input,
int n1,
int n2,
at::IntArrayRef normalized_shape,
at::Tensor* gamma,
at::Tensor* beta,
double epsilon,
at::Tensor* grad_gamma,
at::Tensor* grad_beta
);
std::vector<at::Tensor> layer_norm_gradient(
at::Tensor dout,
at::Tensor mean,
at::Tensor invvar,
at::Tensor input,
at::IntArrayRef normalized_shape,
at::Tensor gamma,
at::Tensor beta,
double epsilon) {
CHECK_INPUT(dout);
CHECK_INPUT(mean);
CHECK_INPUT(invvar);
CHECK_INPUT(input);
CHECK_INPUT(gamma);
CHECK_INPUT(beta);
int n1,n2;
check_args(input,normalized_shape,gamma,beta,n1,n2);
TORCH_CHECK(n2 == 64 || n2 == 128 || n2 == 256 || n2 == 320 || n2 == 384 || n2 == 512 || n2 == 640 || n2 == 768 || n2 == 1024 || n2 == 1280 ||
n2 == 1536 || n2 == 1792 || n2 == 2048 || n2 == 2560 || n2 == 5120, "dimension is not supported");
at::Tensor grad_gamma = at::empty_like(gamma);
at::Tensor grad_beta = at::empty_like(beta);
cuda_layer_norm_gradient(&dout,&mean,&invvar,&input,n1,n2,
normalized_shape,&gamma,&beta,epsilon,
&grad_gamma,&grad_beta);
return {grad_gamma, grad_beta};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("backward", &layer_norm_gradient,
"LayerNorm fast backward for computing gamma and beta (CUDA)");
}
// !!! This is a file automatically generated by hipify!!!
#include <iostream>
#include "ATen/ATen.h"
#include "ATen/AccumulateType.h"
#include "ATen/hip/HIPContext.h"
#include <hip/hip_runtime.h>
//#include <cuda_bf16.h>
#include "util_hip.h"
template <int Dim_, int VecSize_, int BatchesPerBlock_, int WarpsForOneBatchPerBlock_>
struct LNParameters {
static constexpr int Dim = Dim_;
static constexpr int VecSize = VecSize_;
static constexpr int WarpSize = 32;
static constexpr int BatchesPerBlock = BatchesPerBlock_;
static constexpr int WarpStride = WarpSize * VecSize;
static constexpr int WarpsForOneBatchPerBlock = WarpsForOneBatchPerBlock_;
static constexpr int Iterations = Dim / WarpStride / WarpsForOneBatchPerBlock;
static constexpr int BatchStride = Dim / WarpsForOneBatchPerBlock;
static constexpr int ThreadsPerBlock = BatchesPerBlock * WarpSize * WarpsForOneBatchPerBlock;
static_assert(Dim == WarpsForOneBatchPerBlock * WarpStride * Iterations, "");
static_assert(Dim == BatchStride * WarpsForOneBatchPerBlock, "");
};
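// Forward kernel: one warp handles one row of length Dim. Each thread loads
// VecSize-wide vectors, the warp computes the mean and variance via shuffle
// reductions, normalizes in registers, applies gamma and beta, and writes the
// result; lane 0 also stores the per-row mean and inverse sigma for the
// backward pass.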
template <typename IndexType, typename input_t, typename output_t, typename acc_t, typename Parameters>
__global__ void layernorm_forward(output_t *dst, const input_t *src, const input_t *gamma, const input_t *beta,
acc_t *mean, acc_t *invvar, IndexType bsz, acc_t epsilon) {
static_assert(Parameters::WarpsForOneBatchPerBlock == 1, "");
IndexType batch = blockIdx.x * Parameters::BatchesPerBlock + threadIdx.y;
if (batch < bsz) {
src += batch * Parameters::Dim + threadIdx.x * Parameters::VecSize;
dst += batch * Parameters::Dim + threadIdx.x * Parameters::VecSize;
gamma += threadIdx.x * Parameters::VecSize;
beta += threadIdx.x * Parameters::VecSize;
using VecInType = VecType<input_t, Parameters::VecSize>;
VecInType elements[Parameters::Iterations];
VecInType gamma_reg[Parameters::Iterations];
VecInType beta_reg[Parameters::Iterations];
#pragma unroll
for (int i = 0; i < Parameters::Iterations; ++i) {
elements[i] = *(VecInType *)(src + i * Parameters::WarpStride);
gamma_reg[i] = *(VecInType *)(gamma + i * Parameters::WarpStride);
beta_reg[i] = *(VecInType *)(beta + i * Parameters::WarpStride);
}
input_t *elements_l = (input_t *)elements;
input_t *gamma_l = (input_t *)gamma_reg;
input_t *beta_l = (input_t *)beta_reg;
acc_t sum = 0.0;
#pragma unroll
for (int i = 0; i < Parameters::Iterations * Parameters::VecSize; ++i) {
sum += (acc_t)elements_l[i];
}
#pragma unroll
for (int offset = Parameters::WarpSize / 2; offset > 0; offset /= 2) {
sum += SHFL_XOR(sum, offset, Parameters::WarpSize);
}
acc_t mu = sum / Parameters::Dim;
acc_t var = 0.0;
#pragma unroll
for (int i = 0; i < Parameters::Iterations * Parameters::VecSize; ++i) {
acc_t diff = (acc_t)elements_l[i] - mu;
var += diff * diff;
}
#pragma unroll
for (int offset = Parameters::WarpSize / 2; offset > 0; offset /= 2) {
var += SHFL_XOR(var, offset, Parameters::WarpSize);
}
const acc_t rsigma = rsqrtf(var / Parameters::Dim + epsilon);
#pragma unroll
for (int i = 0; i < Parameters::Iterations * Parameters::VecSize; ++i) {
elements_l[i] = (input_t)(((acc_t)elements_l[i] - mu) * rsigma) * gamma_l[i] + beta_l[i];
}
#pragma unroll
for (int i = 0; i < Parameters::Iterations; ++i) {
*(VecInType *)(dst + i * Parameters::WarpStride) = elements[i];
}
if (threadIdx.x == 0) {
mean[batch] = mu;
invvar[batch] = rsigma;
}
}
}
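// Backward kernel for dL/dx: reloads the row, recenters it with the saved
// mean, forms the two warp-reduced sums needed by the LayerNorm gradient
// (sum of g*gamma*(x-mu) and sum of g*gamma), and combines them with the
// saved inverse sigma to produce the input gradient.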
template <typename IndexType, typename input_t, typename output_t, typename acc_t, typename Parameters>
__global__ void layernorm_backward_x(output_t *dst, const input_t *input, const input_t *grad, const input_t *gamma,
const acc_t *mean, const acc_t *invvar, IndexType bsz) {
IndexType batch = blockIdx.x * Parameters::BatchesPerBlock + threadIdx.y;
if (batch < bsz) {
input += batch * Parameters::Dim + threadIdx.x * Parameters::VecSize;
dst += batch * Parameters::Dim + threadIdx.x * Parameters::VecSize;
grad += batch * Parameters::Dim + threadIdx.x * Parameters::VecSize;
gamma += threadIdx.x * Parameters::VecSize;
using VecInType = VecType<input_t, Parameters::VecSize>;
VecInType elements[Parameters::Iterations];
VecInType grad_reg[Parameters::Iterations];
VecInType gamma_reg[Parameters::Iterations];
#pragma unroll
for (int i = 0; i < Parameters::Iterations; ++i) {
elements[i] = *(VecInType *)(input + i * Parameters::WarpStride);
grad_reg[i] = *(VecInType *)(grad + i * Parameters::WarpStride);
gamma_reg[i] = *(VecInType *)(gamma + i * Parameters::WarpStride);
}
input_t *elements_l = (input_t *)elements;
input_t *grad_l = (input_t *)grad_reg;
input_t *gamma_l = (input_t *)gamma_reg;
const acc_t mu = mean[batch];
const acc_t var = invvar[batch];
acc_t sum1 = 0.0, sum2 = 0.0;
#pragma unroll
for (int i = 0; i < Parameters::Iterations * Parameters::VecSize; ++i) {
elements_l[i] = elements_l[i] - (input_t)mu;
sum1 += (acc_t)(elements_l[i] * grad_l[i] * gamma_l[i]);
sum2 += (acc_t)(grad_l[i] * gamma_l[i]);
}
#pragma unroll
for (int offset = Parameters::WarpSize / 2; offset > 0; offset /= 2) {
sum1 += SHFL_XOR(sum1, offset, Parameters::WarpSize);
}
#pragma unroll
for (int offset = Parameters::WarpSize / 2; offset > 0; offset /= 2) {
sum2 += SHFL_XOR(sum2, offset, Parameters::WarpSize);
}
sum1 *= var * var * var / Parameters::Dim;
sum2 *= var / Parameters::Dim;
#pragma unroll
for (int i = 0; i < Parameters::Iterations * Parameters::VecSize; ++i) {
elements_l[i] = grad_l[i] * gamma_l[i] * (input_t)var - (input_t)sum1 * elements_l[i] - (input_t)sum2;
}
#pragma unroll
for (int i = 0; i < Parameters::Iterations; ++i) {
*(VecInType *)(dst + i * Parameters::WarpStride) = elements[i];
}
}
}
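// Dispatch macros: instantiate the kernel template for a given hidden size,
// vector width, and rows-per-block, and launch it on the current HIP stream.
// The trailing `break;` means these macros are only valid inside a switch.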
#define LAUNCH_FORWARD_KERNEL(len, vec, batches, type) \
{ \
dim3 threads(32, batches); \
int blocks = DIV_CELL(n1, batches); \
hipLaunchKernelGGL(( layernorm_forward<size_t, type, type, float, LNParameters<len, vec, batches, 1>>) \
, dim3(blocks), dim3(threads), 0, stream, \
(type *)output->data_ptr(), (type *)input->data_ptr(), (type *)gamma->data_ptr(), \
(type *)beta->data_ptr(), (float *)mean->data_ptr(), (float *)invvar->data_ptr(), n1, epsilon); \
break; \
}
#define LAUNCH_BACKWARD_KERNEL(len, vec, batches, type) \
{ \
dim3 threads(32, batches); \
int blocks = DIV_CELL(n1, batches); \
hipLaunchKernelGGL(( layernorm_backward_x<size_t, type, type, float, LNParameters<len, vec, batches, 1>>) \
, dim3(blocks), dim3(threads), 0, stream, \
(type *)grad_input->data_ptr(), (type *)input->data_ptr(), (type *)dout->data_ptr(), \
(type *)gamma->data_ptr(), (float *)mean->data_ptr(), (float *)invvar->data_ptr(), n1); \
break; \
}
void cuda_layer_norm(
at::Tensor* output,
at::Tensor* mean,
at::Tensor* invvar,
at::Tensor* input,
int n1,
int n2,
at::IntArrayRef normalized_shape,
at::Tensor* gamma,
at::Tensor* beta,
double epsilon)
{
using namespace at;
hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
auto type = input->scalar_type();
// if (type == at::ScalarType::BFloat16) {
// switch (n2) {
// case 64: LAUNCH_FORWARD_KERNEL(64, 2, 4, nv_bfloat16)
// case 128: LAUNCH_FORWARD_KERNEL(128, 2, 4, nv_bfloat16)
// case 256: LAUNCH_FORWARD_KERNEL(256, 2, 4, nv_bfloat16)
// case 320: LAUNCH_FORWARD_KERNEL(320, 2, 4, nv_bfloat16)
// case 384: LAUNCH_FORWARD_KERNEL(384, 2, 4, nv_bfloat16)
// case 512: LAUNCH_FORWARD_KERNEL(512, 2, 4, nv_bfloat16)
// case 640: LAUNCH_FORWARD_KERNEL(640, 2, 4, nv_bfloat16)
// case 768: LAUNCH_FORWARD_KERNEL(768, 2, 4, nv_bfloat16)
// case 1024: LAUNCH_FORWARD_KERNEL(1024, 2, 4, nv_bfloat16)
// case 1280: LAUNCH_FORWARD_KERNEL(1280, 2, 4, nv_bfloat16)
// case 1536: LAUNCH_FORWARD_KERNEL(1536, 2, 4, nv_bfloat16)
// case 1792: LAUNCH_FORWARD_KERNEL(1792, 2, 4, nv_bfloat16)
// case 2048: LAUNCH_FORWARD_KERNEL(2048, 2, 4, nv_bfloat16)
// case 2560: LAUNCH_FORWARD_KERNEL(2560, 2, 4, nv_bfloat16)
// case 5120: LAUNCH_FORWARD_KERNEL(5120, 2, 4, nv_bfloat16)
// }
// } else if (type == at::ScalarType::Half) {
if (type == at::ScalarType::Half) {
switch (n2) {
case 64: LAUNCH_FORWARD_KERNEL(64, 2, 4, half)
case 128: LAUNCH_FORWARD_KERNEL(128, 2, 4, half)
case 256: LAUNCH_FORWARD_KERNEL(256, 2, 4, half)
case 320: LAUNCH_FORWARD_KERNEL(320, 2, 4, half)
case 384: LAUNCH_FORWARD_KERNEL(384, 2, 4, half)
case 512: LAUNCH_FORWARD_KERNEL(512, 2, 4, half)
case 640: LAUNCH_FORWARD_KERNEL(640, 2, 4, half)
case 768: LAUNCH_FORWARD_KERNEL(768, 2, 4, half)
case 1024: LAUNCH_FORWARD_KERNEL(1024, 2, 4, half)
case 1280: LAUNCH_FORWARD_KERNEL(1280, 2, 4, half)
case 1536: LAUNCH_FORWARD_KERNEL(1536, 2, 4, half)
case 1792: LAUNCH_FORWARD_KERNEL(1792, 2, 4, half)
case 2048: LAUNCH_FORWARD_KERNEL(2048, 2, 4, half)
case 2560: LAUNCH_FORWARD_KERNEL(2560, 2, 4, half)
case 5120: LAUNCH_FORWARD_KERNEL(5120, 2, 4, half)
}
} else if (type == at::ScalarType::Float) {
switch (n2) {
case 64: LAUNCH_FORWARD_KERNEL(64, 1, 4, float)
case 128: LAUNCH_FORWARD_KERNEL(128, 1, 4, float)
case 256: LAUNCH_FORWARD_KERNEL(256, 1, 4, float)
case 320: LAUNCH_FORWARD_KERNEL(320, 1, 4, float)
case 384: LAUNCH_FORWARD_KERNEL(384, 1, 4, float)
case 512: LAUNCH_FORWARD_KERNEL(512, 1, 4, float)
case 640: LAUNCH_FORWARD_KERNEL(640, 1, 4, float)
case 768: LAUNCH_FORWARD_KERNEL(768, 1, 4, float)
case 1024: LAUNCH_FORWARD_KERNEL(1024, 1, 4, float)
case 1280: LAUNCH_FORWARD_KERNEL(1280, 1, 4, float)
case 1536: LAUNCH_FORWARD_KERNEL(1536, 1, 4, float)
case 1792: LAUNCH_FORWARD_KERNEL(1792, 1, 4, float)
case 2048: LAUNCH_FORWARD_KERNEL(2048, 1, 4, float)
case 2560: LAUNCH_FORWARD_KERNEL(2560, 1, 4, float)
case 5120: LAUNCH_FORWARD_KERNEL(5120, 1, 4, float)
}
}
}
void cuda_layer_norm_gradient(
at::Tensor* dout,
at::Tensor* mean,
at::Tensor* invvar,
at::Tensor* input,
int n1,
int n2,
at::IntArrayRef normalized_shape,
at::Tensor* gamma,
at::Tensor* beta,
double epsilon,
at::Tensor* grad_input)
{
hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
auto type = input->scalar_type();
// if (type == at::ScalarType::BFloat16) {
// switch (n2) {
// case 64: LAUNCH_BACKWARD_KERNEL(64, 2, 4, nv_bfloat16)
// case 128: LAUNCH_BACKWARD_KERNEL(128, 2, 4, nv_bfloat16)
// case 256: LAUNCH_BACKWARD_KERNEL(256, 2, 4, nv_bfloat16)
// case 320: LAUNCH_BACKWARD_KERNEL(320, 2, 4, nv_bfloat16)
// case 384: LAUNCH_BACKWARD_KERNEL(384, 2, 4, nv_bfloat16)
// case 512: LAUNCH_BACKWARD_KERNEL(512, 2, 4, nv_bfloat16)
// case 640: LAUNCH_BACKWARD_KERNEL(640, 2, 4, nv_bfloat16)
// case 768: LAUNCH_BACKWARD_KERNEL(768, 2, 4, nv_bfloat16)
// case 1024: LAUNCH_BACKWARD_KERNEL(1024, 2, 4, nv_bfloat16)
// case 1280: LAUNCH_BACKWARD_KERNEL(1280, 2, 4, nv_bfloat16)
// case 1536: LAUNCH_BACKWARD_KERNEL(1536, 2, 4, nv_bfloat16)
// case 1792: LAUNCH_BACKWARD_KERNEL(1792, 2, 4, nv_bfloat16)
// case 2048: LAUNCH_BACKWARD_KERNEL(2048, 2, 4, nv_bfloat16)
// case 2560: LAUNCH_BACKWARD_KERNEL(2560, 2, 4, nv_bfloat16)
// case 5120: LAUNCH_BACKWARD_KERNEL(5120, 2, 4, nv_bfloat16)
// }
// } else if (type == at::ScalarType::Half) {
if (type == at::ScalarType::Half) {
switch (n2) {
case 64: LAUNCH_BACKWARD_KERNEL(64, 2, 4, half)
case 128: LAUNCH_BACKWARD_KERNEL(128, 2, 4, half)
case 256: LAUNCH_BACKWARD_KERNEL(256, 2, 4, half)
case 320: LAUNCH_BACKWARD_KERNEL(320, 2, 4, half)
case 384: LAUNCH_BACKWARD_KERNEL(384, 2, 4, half)
case 512: LAUNCH_BACKWARD_KERNEL(512, 2, 4, half)
case 640: LAUNCH_BACKWARD_KERNEL(640, 2, 4, half)
case 768: LAUNCH_BACKWARD_KERNEL(768, 2, 4, half)
case 1024: LAUNCH_BACKWARD_KERNEL(1024, 2, 4, half)
case 1280: LAUNCH_BACKWARD_KERNEL(1280, 2, 4, half)
case 1536: LAUNCH_BACKWARD_KERNEL(1536, 2, 4, half)
case 1792: LAUNCH_BACKWARD_KERNEL(1792, 2, 4, half)
case 2048: LAUNCH_BACKWARD_KERNEL(2048, 2, 4, half)
case 2560: LAUNCH_BACKWARD_KERNEL(2560, 2, 4, half)
case 5120: LAUNCH_BACKWARD_KERNEL(5120, 2, 4, half)
}
} else if (type == at::ScalarType::Float) {
switch (n2) {
case 64: LAUNCH_BACKWARD_KERNEL(64, 1, 4, float)
case 128: LAUNCH_BACKWARD_KERNEL(128, 1, 4, float)
case 256: LAUNCH_BACKWARD_KERNEL(256, 1, 4, float)
case 320: LAUNCH_BACKWARD_KERNEL(320, 1, 4, float)
case 384: LAUNCH_BACKWARD_KERNEL(384, 1, 4, float)
case 512: LAUNCH_BACKWARD_KERNEL(512, 1, 4, float)
case 640: LAUNCH_BACKWARD_KERNEL(640, 1, 4, float)
case 768: LAUNCH_BACKWARD_KERNEL(768, 1, 4, float)
case 1024: LAUNCH_BACKWARD_KERNEL(1024, 1, 4, float)
case 1280: LAUNCH_BACKWARD_KERNEL(1280, 1, 4, float)
case 1536: LAUNCH_BACKWARD_KERNEL(1536, 1, 4, float)
case 1792: LAUNCH_BACKWARD_KERNEL(1792, 1, 4, float)
case 2048: LAUNCH_BACKWARD_KERNEL(2048, 1, 4, float)
case 2560: LAUNCH_BACKWARD_KERNEL(2560, 1, 4, float)
case 5120: LAUNCH_BACKWARD_KERNEL(5120, 1, 4, float)
}
}
}
// !!! This is a file automatically generated by hipify!!!
#include "ATen/ATen.h"
#include "ATen/AccumulateType.h"
#include "ATen/hip/HIPContext.h"
#include <THH/THHDeviceUtils.cuh>
#include <hip/hip_runtime.h>
//#include <cuda_bf16.h>
#include "type_shim_hip.h"
namespace {
// This is the un-specialized struct. Note that we prevent instantiation of this
// struct by putting an undefined symbol in the function body so it won't compile.
// template <typename T>
// struct SharedMemory
// {
// // Ensure that we won't compile any un-specialized types
// __device__ T *getPointer()
// {
// extern __device__ void error(void);
// error();
// return NULL;
// }
// };
// https://github.com/NVIDIA/apex/issues/246
template <typename T>
struct SharedMemory;
template <>
struct SharedMemory <float>
{
__device__ float *getPointer()
{
HIP_DYNAMIC_SHARED( float, s_float)
return s_float;
}
};
template <>
struct SharedMemory <double>
{
__device__ double *getPointer()
{
HIP_DYNAMIC_SHARED( double, s_double)
return s_double;
}
};
}
template<typename T, typename U> __device__
void cuLoadWriteStridedInputs(
const int i1_block,
const int thr_load_row_off,
const int thr_load_col_off,
const int i2_off,
const int row_stride,
U* warp_buf1,
U* warp_buf2,
const T* input,
const T* dout,
const int i1_end,
const int n2,
const U* __restrict__ mean,
const U* __restrict__ invvar
)
{
int i1 = i1_block+thr_load_row_off;
if (i1 < i1_end) {
U curr_mean = mean[i1];
U curr_invvar = invvar[i1];
for (int k = 0; k < blockDim.y; ++k) {
const int i2 = i2_off + k;
const int load_idx = i1*n2+i2;
const int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
if (i2<n2) {
U curr_input = static_cast<U>(input[load_idx]);
U curr_dout = static_cast<U>(dout[load_idx]);
warp_buf1[write_idx] = curr_dout;
warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar;
} else {
warp_buf1[write_idx] = U(0);
warp_buf2[write_idx] = U(0);
}
}
} else {
for (int k = 0; k < blockDim.y; ++k) {
const int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
warp_buf1[write_idx] = U(0);
warp_buf2[write_idx] = U(0);
}
}
}
template<typename T, typename U> __device__
void cuLoadAddStridedInputs(
const int i1_block,
const int thr_load_row_off,
const int thr_load_col_off,
const int i2_off,
const int row_stride,
U* warp_buf1,
U* warp_buf2,
const T* input,
const T* dout,
const int i1_end,
const int n2,
const U* __restrict__ mean,
const U* __restrict__ invvar
)
{
int i1 = i1_block+thr_load_row_off;
if (i1 < i1_end) {
U curr_mean = mean[i1];
U curr_invvar = invvar[i1];
for (int k = 0; k < blockDim.y; ++k) {
const int i2 = i2_off + k;
const int load_idx = i1*n2+i2;
const int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
if (i2<n2) {
U curr_input = static_cast<U>(input[load_idx]);
U curr_dout = static_cast<U>(dout[load_idx]);
warp_buf1[write_idx] += curr_dout;
warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar;
}
}
}
}
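// Stage 1 of the gamma/beta gradient: each block accumulates partial column
// sums of dout and of dout*(input-mean)*invvar over a strip of rows in shared
// memory, reduces them across the block, and writes one partial result per
// (blockIdx.y, column) into part_grad_beta / part_grad_gamma.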
template<typename T, typename U> __global__
void cuComputePartGradGammaBeta(
const T* __restrict__ dout,
const T* __restrict__ input,
const int n1,
const int n2,
const U* __restrict__ mean,
const U* __restrict__ invvar,
U epsilon,
U* part_grad_gamma,
U* part_grad_beta)
{
const int numsegs_n1 = (n1+blockDim.y*blockDim.y-1) / (blockDim.y*blockDim.y);
const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
const int i1_beg = blockIdx.y * segs_per_block * blockDim.y*blockDim.y;
const int i1_beg_plus_one = (blockIdx.y+1) * segs_per_block * blockDim.y*blockDim.y;
const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1;
const int row_stride = blockDim.x+1;
const int thr_load_col_off = (threadIdx.x*blockDim.y)&(blockDim.x-1);
const int thr_load_row_off = (threadIdx.x*blockDim.y)/blockDim.x + threadIdx.y*blockDim.y;
const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;
SharedMemory<U> shared;
U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * blockDim.y + (blockDim.y - 1)*(blockDim.x/blockDim.y) elements
U* warp_buf1 = (U*)buf;
U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;
// compute partial sums from strided inputs
// do this to increase number of loads in flight
cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar);
for (int i1_block = i1_beg+blockDim.y*blockDim.y; i1_block < i1_end; i1_block+=blockDim.y*blockDim.y) {
cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar);
}
__syncthreads();
// inter-warp reductions
// sum within each warp
U acc1 = U(0);
U acc2 = U(0);
for (int k = 0; k < blockDim.y; ++k) {
const int row1 = threadIdx.y + k*blockDim.y;
const int idx1 = row1*row_stride + threadIdx.x;
acc1 += warp_buf1[idx1];
acc2 += warp_buf2[idx1];
}
warp_buf1[threadIdx.y*row_stride+threadIdx.x] = acc1;
warp_buf2[threadIdx.y*row_stride+threadIdx.x] = acc2;
__syncthreads();
// sum all warps
for (int offset = blockDim.y/2; offset > 1; offset /= 2) {
if (threadIdx.y < offset) {
const int row1 = threadIdx.y;
const int row2 = threadIdx.y + offset;
const int idx1 = row1*row_stride + threadIdx.x;
const int idx2 = row2*row_stride + threadIdx.x;
warp_buf1[idx1] += warp_buf1[idx2];
warp_buf2[idx1] += warp_buf2[idx2];
}
__syncthreads();
}
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
if (threadIdx.y == 0 && i2 < n2) {
const int row1 = threadIdx.y;
const int row2 = threadIdx.y + 1;
const int idx1 = row1*row_stride + threadIdx.x;
const int idx2 = row2*row_stride + threadIdx.x;
part_grad_beta[blockIdx.y*n2+i2] = warp_buf1[idx1] + warp_buf1[idx2];
part_grad_gamma[blockIdx.y*n2+i2] = warp_buf2[idx1] + warp_buf2[idx2];
}
}
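// Stage 2: reduces the part_size partial results per column to the final
// grad_gamma / grad_beta, first sequentially within each warp and then across
// warps through shared memory.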
template<typename T, typename U> __global__
void cuComputeGradGammaBeta(
const U* part_grad_gamma,
const U* part_grad_beta,
const int part_size,
const int n1,
const int n2,
T* grad_gamma,
T* grad_beta)
{
// sum partial gradients for gamma and beta
SharedMemory<U> shared;
U* buf = shared.getPointer();
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
if (i2 < n2) {
// each warp does sequential reductions until reduced part_size is num_warps
int num_warp_reductions = part_size / blockDim.y;
U sum_gamma = U(0);
U sum_beta = U(0);
const U* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) {
sum_gamma += part_grad_gamma_ptr[warp_offset*n2];
sum_beta += part_grad_beta_ptr[warp_offset*n2];
}
// inter-warp reductions
const int nbsize3 = blockDim.x * blockDim.y / 2;
for (int offset = blockDim.y/2; offset >= 1; offset /= 2) {
// top half write to shared memory
if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
buf[write_idx] = sum_gamma;
buf[write_idx+nbsize3] = sum_beta;
}
__syncthreads();
// bottom half sums
if (threadIdx.y < offset) {
const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
sum_gamma += buf[read_idx];
sum_beta += buf[read_idx+nbsize3];
}
__syncthreads();
}
// write out fully summed gradients
if (threadIdx.y == 0) {
grad_gamma[i2] = sum_gamma;
grad_beta[i2] = sum_beta;
}
}
}
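// Host launcher for the weight-gradient path: sizes the shared memory,
// allocates fp32 scratch for the partial sums, and runs the two reduction
// kernels above on the current stream. Gradients are only computed when both
// gamma and beta are present.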
template<typename T, typename U>
void HostLayerNormGradient(
const T* dout,
const U* mean,
const U* invvar,
at::Tensor* input,
int n1,
int n2,
const T* gamma,
const T* beta,
double epsilon,
T* grad_gamma,
T* grad_beta
)
{
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
if (gamma != NULL && beta != NULL) {
// compute grad_gamma(j) and grad_beta(j)
const int part_size = 16;
const dim3 threads2(32,4,1);
const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1);
const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
const int nshared2_b = threads2.x * threads2.y * sizeof(U);
const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype((input->scalar_type()==at::ScalarType::Half || input->scalar_type()==at::ScalarType::BFloat16) ? at::ScalarType::Float : input->scalar_type()));
at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
hipLaunchKernelGGL(( cuComputePartGradGammaBeta), dim3(blocks2), dim3(threads2), nshared2, stream,
dout,
input->data_ptr<T>(),
n1,n2,
mean,
invvar,
U(epsilon),
part_grad_gamma.data_ptr<U>(),
part_grad_beta.data_ptr<U>());
const dim3 threads3(32,8,1);
const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1);
const int nshared3 = threads3.x * threads3.y * sizeof(U);
hipLaunchKernelGGL(( cuComputeGradGammaBeta), dim3(blocks3), dim3(threads3), nshared3, stream,
part_grad_gamma.data_ptr<U>(),
part_grad_beta.data_ptr<U>(),
part_size,
n1,n2,
grad_gamma,
grad_beta);
}
}
void cuda_layer_norm_gradient(
at::Tensor* dout,
at::Tensor* mean,
at::Tensor* invvar,
at::Tensor* input,
int n1,
int n2,
at::IntArrayRef normalized_shape,
at::Tensor* gamma,
at::Tensor* beta,
double epsilon,
at::Tensor* grad_gamma,
at::Tensor* grad_beta)
{
using namespace at;
DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BF16(input->scalar_type(), 0, "cuComputeGradInput",
using accscalar_t = at::acc_type<scalar_t_0, true>;
HostLayerNormGradient(
dout->data_ptr<scalar_t_0>(),
mean->data_ptr<accscalar_t>(),
invvar->data_ptr<accscalar_t>(),
input,
n1,n2,
// TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
// if gamma Tensor is NULL on input.
gamma->data_ptr<scalar_t_0>(),
beta->data_ptr<scalar_t_0>(),
epsilon,
grad_gamma->data_ptr<scalar_t_0>(),
grad_beta->data_ptr<scalar_t_0>());
)
}
#include <torch/extension.h>
at::Tensor multi_tensor_l2norm_cuda(
int chunk_size,
std::vector<std::vector<at::Tensor>> tensor_lists);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("l2norm", &multi_tensor_l2norm_cuda,
"Computes L2 norm for a list of contiguous tensors");
}
#include "hip/hip_runtime.h"
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/HIPContext.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include <assert.h>
#include <iostream>
constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
template<int n> struct TensorListMetadata
{
void* addresses[n][depth_to_max_tensors[n-1]];
int sizes[depth_to_max_tensors[n-1]];
unsigned char block_to_tensor[depth_to_max_blocks[n-1]];
int block_to_chunk[depth_to_max_blocks[n-1]];
int start_tensor_this_launch;
};
template<typename T, typename U, typename... ArgTypes>
__global__ void multi_tensor_apply_kernel(
int chunk_size,
T tl,
U callable,
ArgTypes... args)
{
callable(chunk_size, tl, args...);
}
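// multi_tensor_apply packs chunk metadata (tensor addresses, sizes, and the
// block->tensor / block->chunk maps) into a TensorListMetadata struct passed
// by value as a kernel argument. Tensors are split into chunk_size pieces;
// whenever the struct fills up, or the last chunk is reached, the kernel is
// launched and the bookkeeping is reset (carrying over a partially processed
// tensor when needed).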
template<int depth, typename T, typename... ArgTypes>
void multi_tensor_apply(
int block_size,
int chunk_size,
const std::vector<std::vector<at::Tensor>>& tensor_lists,
T callable,
ArgTypes... args)
{
TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
int len0 = tensor_lists[0].size();
TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
auto ref_device = tensor_lists[0][0].device();
TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
auto ref_dtype = tensor_lists[0][0].scalar_type();
for (int l = 0; l < tensor_lists.size(); l++)
{
TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
for(int t = 0; t < tensor_lists[l].size(); t++)
{
bool contiguous_memory = tensor_lists[l][t].is_contiguous();
#ifdef VERSION_GE_1_5
contiguous_memory = (contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast) || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast3d));
#endif
TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
TORCH_CHECK(tensor_lists[l][t].device() == ref_device, "A tensor was not on the same device as the first tensor");
TORCH_CHECK(tensor_lists[l][t].scalar_type() == ref_dtype, "A tensor was not the same dtype as the first tensor");
TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
}
}
int ntensors = tensor_lists[0].size();
TensorListMetadata<depth> tl;
const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
auto stream = at::cuda::getCurrentCUDAStream();
tl.start_tensor_this_launch = 0;
int loc_block_info = 0;
int loc_tensor_info = 0;
for(int t = 0; t < ntensors; t++)
{
tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
for(int d = 0; d < depth; d++)
tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
loc_tensor_info++;
int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size;
for(int chunk = 0; chunk < chunks_this_tensor; chunk++)
{
tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
tl.block_to_chunk[loc_block_info] = chunk;
loc_block_info++;
bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth-1] &&
chunk == chunks_this_tensor - 1);
bool blocks_full = (loc_block_info == depth_to_max_blocks[depth-1]);
bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
if(tensors_full || blocks_full || last_chunk)
{
multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(
chunk_size,
tl,
callable,
args...);
AT_CUDA_CHECK(cudaGetLastError());
loc_block_info = 0;
if(chunk == chunks_this_tensor - 1)
{
loc_tensor_info = 0;
tl.start_tensor_this_launch = t + 1;
}
else
{
tl.sizes[0] = tl.sizes[loc_tensor_info-1];
for(int d = 0; d < depth; d++)
tl.addresses[d][0] = tl.addresses[d][loc_tensor_info-1];
loc_tensor_info = 1;
tl.start_tensor_this_launch = t;
}
}
}
}
}
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/hip/HIPContext.h>
#include <ATen/hip/Exceptions.h>
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
#include <assert.h>
#include <iostream>
constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
template<int n> struct TensorListMetadata
{
void* addresses[n][depth_to_max_tensors[n-1]];
int sizes[depth_to_max_tensors[n-1]];
unsigned char block_to_tensor[depth_to_max_blocks[n-1]];
int block_to_chunk[depth_to_max_blocks[n-1]];
int start_tensor_this_launch;
};
template<typename T, typename U, typename... ArgTypes>
__global__ void multi_tensor_apply_kernel(
int chunk_size,
T tl,
U callable,
ArgTypes... args)
{
callable(chunk_size, tl, args...);
}
template<int depth, typename T, typename... ArgTypes>
void multi_tensor_apply(
int block_size,
int chunk_size,
const std::vector<std::vector<at::Tensor>>& tensor_lists,
T callable,
ArgTypes... args)
{
TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
int len0 = tensor_lists[0].size();
TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
auto ref_device = tensor_lists[0][0].device();
TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
auto ref_dtype = tensor_lists[0][0].scalar_type();
for (int l = 0; l < tensor_lists.size(); l++)
{
TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
for(int t = 0; t < tensor_lists[l].size(); t++)
{
bool contiguous_memory = tensor_lists[l][t].is_contiguous();
#ifdef VERSION_GE_1_5
contiguous_memory = (contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast) || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast3d));
#endif
TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
TORCH_CHECK(tensor_lists[l][t].device() == ref_device, "A tensor was not on the same device as the first tensor");
TORCH_CHECK(tensor_lists[l][t].scalar_type() == ref_dtype, "A tensor was not the same dtype as the first tensor");
TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
}
}
int ntensors = tensor_lists[0].size();
TensorListMetadata<depth> tl;
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(tensor_lists[0][0]));
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
tl.start_tensor_this_launch = 0;
int loc_block_info = 0;
int loc_tensor_info = 0;
for(int t = 0; t < ntensors; t++)
{
tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
for(int d = 0; d < depth; d++)
tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
loc_tensor_info++;
int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size;
for(int chunk = 0; chunk < chunks_this_tensor; chunk++)
{
tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
tl.block_to_chunk[loc_block_info] = chunk;
loc_block_info++;
bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth-1] &&
chunk == chunks_this_tensor - 1);
bool blocks_full = (loc_block_info == depth_to_max_blocks[depth-1]);
bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
if(tensors_full || blocks_full || last_chunk)
{
hipLaunchKernelGGL(( multi_tensor_apply_kernel), dim3(loc_block_info), dim3(block_size), 0, stream,
chunk_size,
tl,
callable,
args...);
AT_CUDA_CHECK(hipGetLastError());
loc_block_info = 0;
if(chunk == chunks_this_tensor - 1)
{
loc_tensor_info = 0;
tl.start_tensor_this_launch = t + 1;
}
else
{
tl.sizes[0] = tl.sizes[loc_tensor_info-1];
for(int d = 0; d < depth; d++)
tl.addresses[d][0] = tl.addresses[d][loc_tensor_info-1];
loc_tensor_info = 1;
tl.start_tensor_this_launch = t;
}
}
}
}
}
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/hip/HIPContext.h>
#include <ATen/hip/Exceptions.h>
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
//#include <cuda_bf16.h>
#include <assert.h>
#include "type_shim_hip.h"
#include "multi_tensor_apply_hip.cuh"
//#define BLOCK_SIZE 512
#define BLOCK_SIZE 256
#define ILP 4
template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}
template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}
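// Per-chunk functor for multi_tensor_apply: each block accumulates the sum of
// squares of its chunk (using ILP-wide vectorized loads when the pointers are
// aligned), reduces it across the block, and adds the result into
// output[blockIdx.x]. The partial sums are combined by the cleanup kernel
// below.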
template<typename x_t>
struct L2NormFunctor
{
__device__ __forceinline__ void operator()(
int chunk_size,
TensorListMetadata<1>& tl,
float* output)
{
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
x_t* x = (x_t*)tl.addresses[0][tensor_loc];
x += chunk_idx*chunk_size;
n -= chunk_idx*chunk_size;
__shared__ float s_vals[512];
float vals[ILP];
x_t r_x[ILP];
for(int i = 0; i < ILP; i++)
{
vals[i] = 0.0f;
r_x[i] = (x_t)0.0f;
}
if(n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x))
{
for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
{
// load
load_store(r_x, x, 0 , i_start);
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
float next = static_cast<float>(r_x[ii]);
vals[ii] += next*next;
}
}
}
else
{
for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP)
{
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
float next = static_cast<float>(x[i]);
vals[ii] += next*next;
}
}
}
}
float val = 0.f;
for(int i = 0; i < ILP; i++)
val += vals[i];
float res = reduce_block_into_lanes(s_vals, val);
if(threadIdx.x == 0)
{
output[blockIdx.x] += res;
}
}
};
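// Single-block reduction of the 320 per-block partial sums: adds them up and
// writes sqrt of the total, i.e. the global L2 norm, into ret.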
__global__ void cleanup(
float* output,
float* ret)
{
__shared__ float vals[512];
if(blockIdx.x == 0)
{
float val = 0;
if(threadIdx.x < 320)
val = output[threadIdx.x];
float final = reduce_block_into_lanes(vals, val);
if(threadIdx.x == 0)
*ret = sqrt(final);
}
}
at::Tensor multi_tensor_l2norm_cuda(
int chunk_size,
std::vector<std::vector<at::Tensor>> tensor_lists)
{
auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
auto output = at::zeros({320}, float_options);
switch (tensor_lists[0][0].scalar_type()){
case at::ScalarType::Float: {
multi_tensor_apply<1>(
BLOCK_SIZE,
chunk_size,
tensor_lists,
L2NormFunctor<float>(),
output.data_ptr<float>()
);
break;
}
case at::ScalarType::Half: {
multi_tensor_apply<1>(
BLOCK_SIZE,
chunk_size,
tensor_lists,
L2NormFunctor<half>(),
output.data_ptr<float>()
);
break;
}
// case at::ScalarType::BFloat16: {
// multi_tensor_apply<1>(
// BLOCK_SIZE,
// chunk_size,
// tensor_lists,
// L2NormFunctor<bfloat16>(),
// output.data_ptr<float>()
// );
// break;
// }
}
AT_CUDA_CHECK(hipGetLastError());
auto ret = at::empty({1}, output.options());
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(output));
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
hipLaunchKernelGGL(( cleanup), dim3(1), dim3(512), 0, stream,
output.data_ptr<float>(),
ret.data_ptr<float>());
return ret;
}
// !!! This is a file automatically generated by hipify!!!
#include <vector>
#include <ATen/ATen.h>
#include <ATen/CUDAGeneratorImpl.h>
#include <ATen/hip/detail/IndexUtils.cuh>
#include <ATen/hip/detail/TensorInfo.cuh>
#include <c10/hip/HIPMathCompat.h>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
//#include <cuda_bf16.h>
#include <hiprand/hiprand_kernel.h>
#include <ATen/hip/HIPContext.h>
#include <torch/extension.h>
#include <math.h>
#include <iostream>
union float_int_32
{
uint32_t i;
float f;
};
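// Stochastic rounding helper: reinterpret the fp32 value as raw bits and add
// a uniformly random 16-bit value drawn from a per-element Philox stream, so
// that truncating the low 16 bits afterwards rounds to bfloat16 with
// probability proportional to the discarded fraction. The perturbed value is
// written to output.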
__global__ void fp32_to_bf16(
const float* input,
float* output,
const int tsize,
uint64_t seed,
uint64_t offset) {
int i = threadIdx.x + blockIdx.x * blockDim.x;
if (i < tsize) {
float_int_32 d;
d.f = input[i];
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed, i, offset, &state);
d.i += hiprand(&state) & 0x0000ffff;
output[i] = (d.f);
}
}
void fused_fp32_to_bf16_sr_cuda(
at::Tensor & input,
at::Tensor & output)
{
int tsize = input.numel();
const int threadsPerBlock = 512;
const int blocks = (tsize + threadsPerBlock - 1) / threadsPerBlock;
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(input), "parameter tensor is too large to be indexed with int32");
AT_ASSERTM(input.scalar_type() == at::ScalarType::Float, "expected input to be float32 tensor");
AT_ASSERTM(output.scalar_type() == at::ScalarType::BFloat16, "expected output to be bfloat16 tensor");
auto gen = at::cuda::detail::getDefaultCUDAGenerator();
std::pair<uint64_t, uint64_t> rng_engine_inputs;
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen.mutex());
rng_engine_inputs = at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(1);
}
uint64_t seed = std::get<0>(rng_engine_inputs);
uint64_t offset = std::get<1>(rng_engine_inputs);
hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
hipLaunchKernelGGL(( fp32_to_bf16), dim3(blocks), dim3(threadsPerBlock), 0, stream,
(const float*)input.data_ptr(),
(float*)output.data_ptr(),
tsize,
seed,
offset);
AT_CUDA_CHECK(hipGetLastError());
}
// !!! This is a file automatically generated by hipify!!!
#include <ATen/ATen.h>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
//#include <cuda_bf16.h>
#include <ATen/hip/HIPContext.h>
#include <torch/extension.h>
void fused_fp32_to_bf16_sr_cuda(at::Tensor & input, at::Tensor & output);
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
void fused_fp32_to_bf16_sr(at::Tensor & input, at::Tensor & output) {
CHECK_INPUT(input);
CHECK_INPUT(output);
int64_t num_elem = input.numel();
AT_ASSERTM(output.numel() == num_elem, "number of elements in input and output tensors should be equal");
fused_fp32_to_bf16_sr_cuda(input, output);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fp32_to_bf16_sr", &fused_fp32_to_bf16_sr, "fused fp32 to bf16 random rounding");
}
// !!! This is a file automatically generated by hipify!!!
#include <torch/extension.h>
#include <ATen/Generator.h>
#include <ATen/CUDAGeneratorImpl.h>
#include <vector>
std::vector<c10::optional<torch::Tensor>> fwd_cuda(
bool is_training,
torch::Tensor &input,
const c10::optional<torch::Tensor> &attn_mask,
const c10::optional<torch::Tensor> &bias,
float dropout_prob,
c10::optional<at::Generator> gen_);
torch::Tensor bwd_cuda(
torch::Tensor &output_grads,
const torch::Tensor &softmax_results,
const c10::optional<torch::Tensor> &dropout_mask,
float dropout_prob);
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
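// Forward entry point. Expects a 3-D input of shape [batches, q_len, k_len];
// an optional 3-D additive bias whose first dim divides the input's first dim
// and whose last two dims match the input; and an optional 3-D attention
// mask. Dispatches to the fused softmax(+bias,+mask)+dropout device path.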
std::vector<c10::optional<torch::Tensor>> fwd(
bool is_training,
torch::Tensor &input,
const c10::optional<torch::Tensor> &attn_mask,
const c10::optional<torch::Tensor> &bias,
float dropout_prob,
c10::optional<at::Generator> gen_)
{
CHECK_INPUT(input);
if (attn_mask)
{
CHECK_INPUT(attn_mask.value());
AT_ASSERTM(attn_mask->dim() == 3, "expected 3D tensor");
}
if (bias)
{
CHECK_INPUT(bias.value());
AT_ASSERTM(bias->dim() == 3, "expected 3D tensor");
AT_ASSERTM(input.size(0) % bias->size(0) == 0, "wrong first dim of bias.");
AT_ASSERTM(bias->size(1) == input.size(1) && bias->size(2) == input.size(2), "the last two dims of bias and input should be the same.");
}
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input.scalar_type() == at::ScalarType::Half ||
input.scalar_type() == at::ScalarType::BFloat16 ||
input.scalar_type() == at::ScalarType::Float,
"Only HALF/BFloat16/Float is supported");
return fwd_cuda(is_training, input, attn_mask, bias, dropout_prob, gen_);
}
torch::Tensor bwd(
torch::Tensor &output_grads,
const torch::Tensor &softmax_results,
const c10::optional<torch::Tensor> &dropout_mask,
float dropout_prob)
{
CHECK_INPUT(output_grads);
CHECK_INPUT(softmax_results);
if (dropout_mask)
{
CHECK_INPUT(dropout_mask.value());
}
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(!dropout_mask || dropout_mask->dim() == 1, "expected 1D tensor");
AT_ASSERTM(output_grads.scalar_type() == at::ScalarType::Half ||
output_grads.scalar_type() == at::ScalarType::BFloat16 ||
output_grads.scalar_type() == at::ScalarType::Float,
"Only HALF/BFloat16/Float is supported");
AT_ASSERTM(softmax_results.scalar_type() == at::ScalarType::Half ||
softmax_results.scalar_type() == at::ScalarType::BFloat16 ||
softmax_results.scalar_type() == at::ScalarType::Float,
"Only HALF/BFloat16/Float is supported");
AT_ASSERTM(output_grads.scalar_type() == softmax_results.scalar_type(), "output_grads and softmax_results must have the same dtype");
return bwd_cuda(output_grads, softmax_results, dropout_mask, dropout_prob);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("forward", &fwd, "softmax dropout -- Forward.");
m.def("backward", &bwd, "softmax dropout -- Backward.");
}
// !!! This is a file automatically generated by hipify!!!
#include <vector>
#include <iostream>
#include <ATen/ATen.h>
#include <ATen/CUDAGeneratorImpl.h>
#include <ATen/hip/detail/IndexUtils.cuh>
#include <ATen/hip/detail/TensorInfo.cuh>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
//#include <cuda_bf16.h>
//#include <cuda_profiler_api.h>
#include <ATen/hip/HIPContext.h>
#include <torch/extension.h>
#include <math.h>
#include "type_shim_hip.h"
#include "softmax_fast_hip.h"
std::vector<c10::optional<torch::Tensor>> fwd_cuda(
bool is_training,
torch::Tensor &input,
const c10::optional<torch::Tensor> &attn_mask,
const c10::optional<torch::Tensor> &bias,
float dropout_prob,
c10::optional<at::Generator> gen_)
{
const int64_t attn_batches = input.size(0);
const int q_seq_len = input.size(1);
const int k_seq_len = input.size(2);
void *bias_ptr = nullptr;
int64_t bias_batches = 0;
if (bias)
{
bias_ptr = reinterpret_cast<void *>(bias->data_ptr());
bias_batches = bias->size(0);
}
void *attn_mask_ptr = nullptr;
int64_t mask_inner_skip = 0;
if (attn_mask)
{
attn_mask_ptr = reinterpret_cast<void *>(attn_mask->data_ptr());
mask_inner_skip = static_cast<int64_t>(attn_batches / attn_mask->size(0) * q_seq_len / attn_mask->size(1));
}
auto act_options = input.options().requires_grad(false);
auto mask_options = act_options.dtype(softmax_mask_dtype(k_seq_len));
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void *input_ptr = reinterpret_cast<void *>(input.data_ptr());
void *softmax_results_ptr = reinterpret_cast<void *>(input.data_ptr());
// Padded Softmax
bool softmax_success = false;
auto scalar_type = input.scalar_type();
if (is_training && dropout_prob > 0.0f)
{
torch::Tensor dropout_results = torch::empty({static_cast<int64_t>(attn_batches), q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask = torch::empty(
{softmax_mask_size(static_cast<int64_t>(attn_batches * q_seq_len), k_seq_len)}, mask_options);
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());
std::pair<uint64_t, uint64_t> rng_engine_inputs;
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
rng_engine_inputs = gen->philox_engine_inputs(softmax_rng_delta_offset(k_seq_len));
}
uint64_t seed = std::get<0>(rng_engine_inputs);
uint64_t offset = std::get<1>(rng_engine_inputs);
if (bias)
{
if (attn_mask)
{
DISPATCH_FLOAT_AND_HALF_AND_BF16(scalar_type, 0, "softmax_forward",
softmax_success = dispatch_softmax_forward<scalar_t_0, scalar_t_0, float, true, true, true>(
reinterpret_cast<scalar_t_0 *>(dropout_results.data_ptr()),
reinterpret_cast<scalar_t_0 *>(softmax_results_ptr),
reinterpret_cast<const scalar_t_0 *>(input_ptr),
reinterpret_cast<const scalar_t_0 *>(attn_mask_ptr),
reinterpret_cast<const scalar_t_0 *>(bias_ptr),
reinterpret_cast<void *>(dropout_mask.data_ptr()),
1.0f - dropout_prob,
k_seq_len,
attn_batches * q_seq_len,
mask_inner_skip,
bias_batches * q_seq_len,
seed, offset);)
}
else
{
DISPATCH_FLOAT_AND_HALF_AND_BF16(scalar_type, 0, "softmax_forward",
softmax_success = dispatch_softmax_forward<scalar_t_0, scalar_t_0, float, true, true, false>(
reinterpret_cast<scalar_t_0 *>(dropout_results.data_ptr()),
reinterpret_cast<scalar_t_0 *>(softmax_results_ptr),
reinterpret_cast<const scalar_t_0 *>(input_ptr),
nullptr,
reinterpret_cast<const scalar_t_0 *>(bias_ptr),
reinterpret_cast<void *>(dropout_mask.data_ptr()),
1.0f - dropout_prob,
k_seq_len,
attn_batches * q_seq_len,
mask_inner_skip,
bias_batches * q_seq_len,
seed, offset);)
}
}
else
{
if (attn_mask)
{
DISPATCH_FLOAT_AND_HALF_AND_BF16(scalar_type, 0, "softmax_forward",
softmax_success = dispatch_softmax_forward<scalar_t_0, scalar_t_0, float, true, false, true>(
reinterpret_cast<scalar_t_0 *>(dropout_results.data_ptr()),
reinterpret_cast<scalar_t_0 *>(softmax_results_ptr),
reinterpret_cast<const scalar_t_0 *>(input_ptr),
reinterpret_cast<const scalar_t_0 *>(attn_mask_ptr),
nullptr,
reinterpret_cast<void *>(dropout_mask.data_ptr()),
1.0f - dropout_prob,
k_seq_len,
attn_batches * q_seq_len,
mask_inner_skip,
bias_batches * q_seq_len,
seed, offset);)
}
else
{
DISPATCH_FLOAT_AND_HALF_AND_BF16(scalar_type, 0, "softmax_forward",
softmax_success = dispatch_softmax_forward<scalar_t_0, scalar_t_0, float, true, false, false>(
reinterpret_cast<scalar_t_0 *>(dropout_results.data_ptr()),
reinterpret_cast<scalar_t_0 *>(softmax_results_ptr),
reinterpret_cast<const scalar_t_0 *>(input_ptr),
nullptr,
nullptr,
reinterpret_cast<void *>(dropout_mask.data_ptr()),
1.0f - dropout_prob,
k_seq_len,
attn_batches * q_seq_len,
mask_inner_skip,
bias_batches * q_seq_len,
seed, offset);)
}
}
if (softmax_success)
{
return {dropout_results, dropout_mask, input};
}
else
{
return {c10::optional<torch::Tensor>(), c10::optional<torch::Tensor>(), c10::optional<torch::Tensor>()};
}
}
else
{
if (bias)
{
if (attn_mask)
{
DISPATCH_FLOAT_AND_HALF_AND_BF16(scalar_type, 0, "softmax_forward",
softmax_success = dispatch_softmax_forward<scalar_t_0, scalar_t_0, float, false, true, true>(
reinterpret_cast<scalar_t_0 *>(softmax_results_ptr),
nullptr,
reinterpret_cast<const scalar_t_0 *>(input_ptr),
reinterpret_cast<const scalar_t_0 *>(attn_mask_prt),
reinterpret_cast<const scalar_t_0 *>(bias_ptr),
nullptr,
1.0,
k_seq_len,
attn_batches * q_seq_len,
mask_inner_skip,
bias_batches * q_seq_len,
0, 0);)
}
else
{
DISPATCH_FLOAT_AND_HALF_AND_BF16(scalar_type, 0, "softmax_forward",
softmax_success = dispatch_softmax_forward<scalar_t_0, scalar_t_0, float, false, true, false>(
reinterpret_cast<scalar_t_0 *>(softmax_results_ptr),
nullptr,
reinterpret_cast<const scalar_t_0 *>(input_ptr),
nullptr,
reinterpret_cast<const scalar_t_0 *>(bias_ptr),
nullptr,
1.0,
k_seq_len,
attn_batches * q_seq_len,
mask_inner_skip,
bias_batches * q_seq_len,
0, 0);)
}
}
else
{
if (attn_mask)
{
DISPATCH_FLOAT_AND_HALF_AND_BF16(scalar_type, 0, "softmax_forward",
softmax_success = dispatch_softmax_forward<scalar_t_0, scalar_t_0, float, false, false, true>(
reinterpret_cast<scalar_t_0 *>(softmax_results_ptr),
nullptr,
reinterpret_cast<const scalar_t_0 *>(input_ptr),
reinterpret_cast<const scalar_t_0 *>(attn_mask_prt),
nullptr,
nullptr,
1.0,
k_seq_len,
attn_batches * q_seq_len,
mask_inner_skip,
bias_batches * q_seq_len,
0, 0);)
}
else
{
DISPATCH_FLOAT_AND_HALF_AND_BF16(scalar_type, 0, "softmax_forward",
softmax_success = dispatch_softmax_forward<scalar_t_0, scalar_t_0, float, false, false, false>(
reinterpret_cast<scalar_t_0 *>(softmax_results_ptr),
nullptr,
reinterpret_cast<const scalar_t_0 *>(input_ptr),
nullptr,
nullptr,
nullptr,
1.0,
k_seq_len,
attn_batches * q_seq_len,
mask_inner_skip,
bias_batches * q_seq_len,
0, 0);)
}
}
if (softmax_success)
{
return {input, c10::optional<torch::Tensor>(), input};
}
else
{
return {c10::optional<torch::Tensor>(), c10::optional<torch::Tensor>(), c10::optional<torch::Tensor>()};
}
}
}
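// Return-value contract of fwd_cuda above: in the training-with-dropout branch the
// outputs are {dropout_results, packed dropout_mask, softmax_results}, where the
// softmax probabilities are written in place into `input`; in the no-dropout branch
// they are {softmax_results, empty, softmax_results}, again computed in place in
// `input`. If the dispatcher reports failure, three empty optionals are returned,
// presumably so the caller can fall back to an unfused softmax path.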
torch::Tensor bwd_cuda(
torch::Tensor &output_grads,
const torch::Tensor &softmax_results,
const c10::optional<torch::Tensor> &dropout_mask,
float dropout_prob)
{
const int64_t attn_batches = output_grads.size(0);
const int q_seq_len = output_grads.size(1);
const int k_seq_len = output_grads.size(2);
auto scalar_type = output_grads.scalar_type();
if (dropout_mask)
{
DISPATCH_FLOAT_AND_HALF_AND_BF16(scalar_type, 0, "softmax_backward",
dispatch_softmax_backward<scalar_t_0, scalar_t_0, float, false, true>(
reinterpret_cast<scalar_t_0 *>(output_grads.data_ptr()),
reinterpret_cast<const scalar_t_0 *>(output_grads.data_ptr()),
reinterpret_cast<const scalar_t_0 *>(softmax_results.data_ptr()),
reinterpret_cast<const void *>(dropout_mask->data_ptr()),
1.0f - dropout_prob,
k_seq_len,
attn_batches * q_seq_len);)
}
else
{
DISPATCH_FLOAT_AND_HALF_AND_BF16(scalar_type, 0, "softmax_backward",
dispatch_softmax_backward<scalar_t_0, scalar_t_0, float, false, false>(
reinterpret_cast<scalar_t_0 *>(output_grads.data_ptr()),
reinterpret_cast<scalar_t_0 *>(output_grads.data_ptr()),
reinterpret_cast<const scalar_t_0 *>(softmax_results.data_ptr()),
nullptr,
1.0f,
k_seq_len,
attn_batches * q_seq_len);)
}
// backward pass is completely in-place
return output_grads;
}
// !!! This is a file automatically generated by hipify!!!
#pragma once
#include <iostream>
#include <type_traits>
#include <limits>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
//#include <cuda_bf16.h>
#include <hiprand/hiprand_kernel.h>
#include <hipcub/hipcub.hpp>
#include "util_hip.h"
template <int N>
using IntegerBits = typename std::conditional<N <= 8, uint8_t,
typename std::conditional<N <= 16, uint16_t,
typename std::conditional<N <= 32, uint32_t,
typename std::conditional<N <= 64, uint64_t, void>::type>::type>::type>::type;
template <int LogElements>
struct SoftmaxParameters
{
static_assert(LogElements <= 11, "");
static constexpr int Elements = 1 << LogElements;
static constexpr int WarpBatch = Elements <= 128 ? 2 : 1;
static constexpr int WarpIterations = Elements <= 32 ? 1 : Elements / 32;
using MaskType = IntegerBits<WarpIterations>;
static constexpr int WarpSize = Elements <= 32 ? Elements : 32; // must match the warp_size computed in dispatch_softmax_forward/backward and the per-row stride assumed by softmax_mask_size
static constexpr int MaskStride = WarpSize;
};
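// For example, LogElements = 5 (rows of 32 elements) gives Elements = 32,
// WarpBatch = 2, WarpIterations = 1, WarpSize = MaskStride = 32 and an 8-bit
// MaskType. Larger rows keep WarpSize at 32 and instead grow WarpIterations
// (and with it the MaskType width), up to 64 iterations and a 64-bit mask word
// at the 2048-element maximum allowed by the static_assert above.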
inline int log2_ceil(int value)
{
int log2_value = 0;
while ((1 << log2_value) < value)
++log2_value;
return log2_value;
}
inline at::ScalarType softmax_mask_dtype(int elements)
{
if (elements > 1024)
{
return torch::kInt64;
}
else if (elements > 512)
{
return torch::kInt32;
}
else if (elements > 256)
{
return torch::kInt16;
}
return torch::kInt8;
}
inline int64_t softmax_mask_size(int64_t batch_size, int elements)
{
int log2_elements = log2_ceil(elements);
int e = 1 << log2_elements;
int warp_size = e < 32 ? e : 32;
return batch_size * warp_size;
}
inline int softmax_rng_delta_offset(int elements)
{
int log2_elements = log2_ceil(elements);
int e = 1 << log2_elements;
int warp_iterations = e <= 32 ? 1 : e / 32;
int warp_batch = e <= 128 ? 2 : 1;
return warp_iterations * warp_batch;
}
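// Worked example for the three helpers above, with k_seq_len = 256:
//   log2_ceil(256) = 8, so the padded row length is 256;
//   softmax_mask_dtype(256) = kInt8, since each lane covers 256 / 32 = 8
//   iterations and therefore needs 8 dropout bits;
//   softmax_mask_size(batch, 256) = batch * 32, i.e. one mask word per lane per row;
//   softmax_rng_delta_offset(256) = 8 * 1 = 8, the number of uniform draws each
//   thread consumes, used as the Philox offset increment in the forward pass.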
inline hipError_t GetNumBlocks(int64_t block_size, int64_t max_blocks, int64_t waves,
int *num_blocks) {
int dev;
{
hipError_t err = hipGetDevice(&dev);
if (err != hipSuccess) {
return err;
}
}
int sm_count;
{
hipError_t err = hipDeviceGetAttribute(&sm_count, hipDeviceAttributeMultiprocessorCount, dev);
if (err != hipSuccess) {
return err;
}
}
int tpm;
{
hipError_t err = hipDeviceGetAttribute(&tpm, hipDeviceAttributeMaxThreadsPerMultiProcessor, dev);
if (err != hipSuccess) {
return err;
}
}
*num_blocks =
std::max<int>(1, std::min<int64_t>(max_blocks, sm_count * tpm / block_size * waves));
return hipSuccess;
}
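// GetNumBlocks sizes a persistent-style grid: at most `waves` blocks per resident
// block slot, i.e. min(max_blocks, sm_count * max_threads_per_sm / block_size * waves),
// and never fewer than one block. As an illustration only (hypothetical device with
// 104 compute units and 2048 threads each), block_size = 128 and waves = 32 give
// 104 * 2048 / 128 * 32 = 53248 candidate blocks before the max_blocks (= rows) cap.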
template <typename T>
struct SumOp {
__device__ __forceinline__ T operator()(const T &a, const T &b) const { return a + b; }
};
template <typename T>
struct MaxOp {
__device__ __forceinline__ T operator()(const T &a, const T &b) const { return max(a, b); }
};
template <template <typename> class ReductionOp, typename T, int block_size>
__inline__ __device__ T BlockAllReduce(T val) {
typedef hipcub::BlockReduce<T, block_size> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T result_broadcast;
T result = BlockReduce(temp_storage).Reduce(val, ReductionOp<T>());
if (threadIdx.x == 0) {
result_broadcast = result;
}
__syncthreads();
return result_broadcast;
}
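// BlockAllReduce must be called by every thread in the block with its partial value;
// the block-wide result (max or sum, selected by ReductionOp) is broadcast back to
// all threads through the shared `result_broadcast` slot, so each caller receives
// the same reduced value.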
// modified from https://github.com/Oneflow-Inc/oneflow/blob/5d74efa4d07adfd0acbc8e0074778687f1006b86/oneflow/core/cuda/softmax.cuh#L480-L529
// Copyright 2020 The OneFlow Authors. All rights reserved.
template <typename input_t, typename output_t, typename acc_t, int block_size, bool NeedBias, bool NeedAttnMask>
__global__ void softmax_block_forward(const input_t *input, output_t *output, const input_t *attn_mask, const input_t *bias,
int64_t rows, int cols, int64_t attn_inner_skip_batch, int64_t bias_batch_size) {
extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[];
auto *buf = reinterpret_cast<acc_t *>(shared_buf);
const int tid = threadIdx.x;
auto element_count = cols;
int64_t bias_mod_size = bias_batch_size * cols;
int64_t attn_mask_div_size = element_count;
if IF_CONSTEXPR (NeedAttnMask)
{
attn_mask_div_size = attn_inner_skip_batch * element_count;
}
for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {
acc_t thread_max = -std::numeric_limits<acc_t>::infinity();
int64_t idx_offset = row * cols;
const input_t* input_ptr = input + idx_offset;
output_t* output_ptr = output + idx_offset;
const input_t* attn_mask_ptr = nullptr;
if IF_CONSTEXPR (NeedAttnMask){
attn_mask_ptr = attn_mask + static_cast<int64_t>(idx_offset / attn_mask_div_size) * element_count;
}
const input_t* bias_ptr = nullptr;
if IF_CONSTEXPR (NeedBias) {
bias_ptr = bias + idx_offset % bias_mod_size;
}
// TODO: enable pack as oneflow
for (int col = tid; col < cols; col += block_size) {
buf[col] = static_cast<acc_t>(input_ptr[col]);
if IF_CONSTEXPR (NeedAttnMask)
{
buf[col] += attn_mask_ptr[col];
}
if IF_CONSTEXPR (NeedBias)
{
buf[col] += bias_ptr[col];
}
thread_max = max(thread_max, buf[col]);
}
const acc_t row_max = BlockAllReduce<MaxOp, acc_t, block_size>(thread_max);
acc_t thread_sum = 0;
for (int col = tid; col < cols; col += block_size) {
buf[col] = std::exp(buf[col] - row_max);
thread_sum += buf[col];
}
const acc_t row_sum = BlockAllReduce<SumOp, acc_t, block_size>(thread_sum);
for (int col = tid; col < cols; col += block_size) {
output_ptr[col] = static_cast<output_t>(buf[col] / row_sum);
}
}
}
template<typename input_t, typename output_t, typename acc_t, int block_size>
__global__ void softmax_block_backward(output_t* store, const input_t* dy, const input_t* y,
const int64_t rows, const int64_t cols) {
extern __shared__ __align__(sizeof(double)) unsigned char grad_shared_buf[];
auto* dy_buf = reinterpret_cast<acc_t*>(grad_shared_buf);
auto* y_buf = reinterpret_cast<input_t*>(dy_buf + cols);
const int tid = threadIdx.x;
for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {
acc_t thread_sum = 0;
auto dy_ptr = dy + row * cols;
auto y_ptr = y + row * cols;
auto store_ptr = store + row * cols;
for (int col = tid; col < cols; col += block_size) {
y_buf[col] = y_ptr[col];
dy_buf[col] = dy_ptr[col] * (acc_t)y_ptr[col];
}
for (int col = tid; col < cols; col += block_size) {
thread_sum += dy_buf[col];
}
const acc_t row_sum = BlockAllReduce<SumOp, acc_t, block_size>(thread_sum);
for (int col = tid; col < cols; col += block_size) {
store_ptr[col] = static_cast<output_t>(dy_buf[col] - y_buf[col] * row_sum);
}
}
}
template <
typename input_t, typename output_t, typename acc_t,
typename Parameters, bool NeedMask, bool NeedBias, bool NeedAttnMask>
__global__ void softmax_warp_forward(input_t *dst, input_t *dst_orig, const output_t *src, const input_t *attn_mask, const input_t *bias,
typename Parameters::MaskType *mask, acc_t p, int64_t batch_size, int64_t attn_inner_skip_batch, int64_t bias_batch_size, int element_count, uint64_t seed, uint64_t rand_offset)
{
using MaskType = typename Parameters::MaskType;
hiprandStatePhilox4_32_10_t state;
int64_t first_batch = (static_cast<int64_t>(blockDim.y) * static_cast<int64_t>(blockIdx.x) + threadIdx.y) * Parameters::WarpBatch;
// there might be multiple batches per warp; compute the index within the batch
int64_t local_idx = threadIdx.x;
const int64_t thread_offset = first_batch * element_count + local_idx;
if IF_CONSTEXPR (NeedMask)
{
hiprand_init(seed, thread_offset, rand_offset, &state);
}
// batch_size might not be a multiple of Parameters::WarpBatch. Check how
// many batches have to be computed within this WARP.
int local_batches = batch_size - first_batch;
if (local_batches > Parameters::WarpBatch)
local_batches = Parameters::WarpBatch;
src += thread_offset;
dst += thread_offset;
if IF_CONSTEXPR (NeedMask)
{
dst_orig += thread_offset;
mask += first_batch * Parameters::MaskStride;
}
int64_t bias_mod_size = bias_batch_size * element_count;
int64_t attn_mask_div_size = element_count;
if IF_CONSTEXPR (NeedAttnMask)
{
attn_mask_div_size = attn_inner_skip_batch * element_count;
}
// load data from global memory
input_t elements_input[Parameters::WarpBatch][Parameters::WarpIterations];
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i)
{
int batch_element_count = (i >= local_batches) ? 0 : element_count;
auto src_ptr = src + i * element_count;
#pragma unroll
for (int it = 0; it < Parameters::WarpIterations; ++it)
{
int element_index = local_idx + it * Parameters::WarpSize;
elements_input[i][it] = -std::numeric_limits<float>::infinity();
if (element_index < batch_element_count)
{
elements_input[i][it] = src_ptr[it * Parameters::WarpSize];
}
}
}
// convert input_t to acc_t
acc_t elements[Parameters::WarpBatch][Parameters::WarpIterations];
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i)
{
int batch_element_count = (i >= local_batches) ? 0 : element_count;
int64_t idx_offset = (first_batch + i) * element_count;
const input_t* attn_mask_ptr = nullptr;
if IF_CONSTEXPR (NeedAttnMask){
attn_mask_ptr = attn_mask + static_cast<int64_t>(idx_offset / attn_mask_div_size) * element_count + local_idx;
}
const input_t* bias_ptr = nullptr;
if IF_CONSTEXPR (NeedBias){
bias_ptr = bias + idx_offset % bias_mod_size + local_idx;
}
#pragma unroll
for (int it = 0; it < Parameters::WarpIterations; ++it)
{
elements[i][it] = elements_input[i][it];
int element_index = local_idx + it * Parameters::WarpSize;
if (element_index < batch_element_count)
{
if IF_CONSTEXPR (NeedAttnMask)
{
elements[i][it] += attn_mask_ptr[it * Parameters::WarpSize];
}
if IF_CONSTEXPR (NeedBias)
{
elements[i][it] += bias_ptr[it * Parameters::WarpSize];
}
}
}
}
// compute local max_value
// take the max_value of the first element to avoid one max call
acc_t max_value[Parameters::WarpBatch];
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i)
{
max_value[i] = elements[i][0];
}
#pragma unroll
for (int it = 1; it < Parameters::WarpIterations; ++it)
{
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i)
{
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
// reduction max_value
#pragma unroll
for (int offset = Parameters::WarpSize / 2; offset > 0; offset /= 2)
{
float val[Parameters::WarpBatch];
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i)
{
val[i] = SHFL_XOR(max_value[i], offset, Parameters::WarpSize);
}
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i)
{
max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];
}
}
// compute local sum
acc_t sum[Parameters::WarpBatch]{0.0f};
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i)
{
#pragma unroll
for (int it = 0; it < Parameters::WarpIterations; ++it)
{
elements[i][it] = std::exp(elements[i][it] - max_value[i]);
sum[i] += elements[i][it];
}
}
// reduction sum
#pragma unroll
for (int offset = Parameters::WarpSize / 2; offset > 0; offset /= 2)
{
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i)
{
sum[i] += SHFL_XOR(sum[i], offset, Parameters::WarpSize);
}
}
// store result
if IF_CONSTEXPR (NeedMask)
{
const acc_t pinv = 1.0 / p;
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i)
{
if (i >= local_batches)
break;
MaskType m = 0;
if IF_CONSTEXPR (Parameters::WarpIterations == 1)
{
float rand = hiprand_uniform(&state);
m = rand < p;
}
else if IF_CONSTEXPR (Parameters::WarpIterations == 2)
{
m = hiprand_uniform(&state) < p;
m |= (hiprand_uniform(&state) < p) << 1;
}
else
{
#pragma unroll
for (int j = 0; j < DIV_CELL(Parameters::WarpIterations, 4); ++j)
{
float4 rand4 = hiprand_uniform4(&state);
m |= (((MaskType)(rand4.x < p)) << (j * 4)) | (((MaskType)(rand4.y < p)) << (j * 4 + 1)) | (((MaskType)(rand4.z < p)) << (j * 4 + 2)) | (((MaskType)(rand4.w < p)) << (j * 4 + 3));
}
}
mask[i * Parameters::MaskStride + local_idx] = m;
auto dst_ptr = dst + i * element_count;
auto dst_orig_ptr = dst_orig + i * element_count;
#pragma unroll
for (int it = 0; it < Parameters::WarpIterations; ++it)
{
int element_index = local_idx + it * Parameters::WarpSize;
if (element_index < element_count)
{
const output_t d = elements[i][it] / sum[i];
dst_ptr[it * Parameters::WarpSize] = (acc_t)d * ((acc_t)((m >> it) & 1) * pinv);
dst_orig_ptr[it * Parameters::WarpSize] = d;
}
else
{
break;
}
}
}
}
else
{
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i)
{
auto dst_ptr = dst + i * element_count;
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < Parameters::WarpIterations; ++it)
{
int element_index = local_idx + it * Parameters::WarpSize;
if (element_index < element_count)
{
dst_ptr[it * Parameters::WarpSize] = elements[i][it] / sum[i];
}
else
{
break;
}
}
}
}
}
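// In the NeedMask (training) path above, each lane draws one keep/drop bit per warp
// iteration, packs them into a single MaskType word stored at
// mask[row * Parameters::MaskStride + lane], and writes two outputs: `dst` receives
// the dropout-applied probabilities scaled by 1/p (inverted dropout), while
// `dst_orig` keeps the plain softmax output that the backward pass needs.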
#define LAUNCH_FORWARD_KERNEL(l) \
hipLaunchKernelGGL(( softmax_warp_forward<input_t, output_t, acc_t, SoftmaxParameters<l>, NeedMask, NeedBias, NeedAttnMask>) \
, dim3(blocks), dim3(threads), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), \
dst, dst_orig, src, attn_mask, bias, (typename SoftmaxParameters<l>::MaskType *)mask, p, \
batch_count, attn_inner_skip_batch, bias_batch_count, softmax_elements, seed, offset); \
return true;
template <typename input_t, typename output_t, typename acc_t, bool NeedMask, bool NeedBias, bool NeedAttnMask>
bool dispatch_softmax_forward(output_t *dst, output_t *dst_orig, const input_t *src, const input_t *attn_mask, const input_t *bias, void *mask, acc_t p,
int softmax_elements, int64_t batch_count, int64_t attn_inner_skip_batch, int64_t bias_batch_count, uint64_t seed, uint64_t offset)
{
if (softmax_elements == 0)
{
return false;
}
else
{
int log2_elements = log2_ceil(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
// This value must match the Parameters::WarpSize constexpr value computed inside softmax_warp_forward.
int warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
// This value must match the Parameters::WarpBatch constexpr value computed inside softmax_warp_forward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
// use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements)
{
case 0:
LAUNCH_FORWARD_KERNEL(0)
case 1:
LAUNCH_FORWARD_KERNEL(1)
case 2:
LAUNCH_FORWARD_KERNEL(2)
case 3:
LAUNCH_FORWARD_KERNEL(3)
case 4:
LAUNCH_FORWARD_KERNEL(4)
case 5:
LAUNCH_FORWARD_KERNEL(5)
case 6:
LAUNCH_FORWARD_KERNEL(6)
case 7:
LAUNCH_FORWARD_KERNEL(7)
case 8:
LAUNCH_FORWARD_KERNEL(8)
case 9:
LAUNCH_FORWARD_KERNEL(9)
case 10:
LAUNCH_FORWARD_KERNEL(10)
default:
{
int grid_dim;
constexpr int block_size = 128;
constexpr int waves = 32;
auto cols = softmax_elements;
auto rows = batch_count;
GetNumBlocks(block_size, rows, waves, &grid_dim);
dim3 block(block_size);
const size_t smem = cols * sizeof(acc_t);
hipLaunchKernelGGL(( softmax_block_forward<input_t, output_t, acc_t, block_size, NeedBias, NeedAttnMask>), dim3(grid_dim), dim3(block), smem, 0,
src, dst, attn_mask, bias, rows, cols, attn_inner_skip_batch, bias_batch_count);
return true;
}
}
}
return false;
}
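// A minimal sketch of how this dispatcher is invoked from the host side (variable
// names here are placeholders, not symbols defined in this file):
//   dispatch_softmax_forward<at::Half, at::Half, float,
//                            /*NeedMask=*/true, /*NeedBias=*/true, /*NeedAttnMask=*/false>(
//       probs, probs_no_dropout, logits, /*attn_mask=*/nullptr, bias, mask_buf,
//       /*p=*/1.0f - dropout_prob, k_seq_len,
//       /*batch_count=*/attn_batches * q_seq_len,
//       /*attn_inner_skip_batch=*/1, /*bias_batch_count=*/bias_batches * q_seq_len,
//       seed, offset);
// Note that the block-wide fallback in the default case computes a plain softmax
// only: it does not write `dst_orig` or the dropout mask, so rows longer than 1024
// elements are effectively handled without dropout.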
template <
typename input_t, typename output_t, typename acc_t, typename Parameters,
bool IsLogSoftmax, bool NeedMask>
__global__ void softmax_warp_backward(output_t *gradInput, const input_t *grad, const input_t *output,
const typename Parameters::MaskType *mask, acc_t p, int64_t batch_size, int element_count)
{
using MaskType = typename Parameters::MaskType;
int64_t first_batch = (static_cast<int64_t>(blockDim.y) * static_cast<int64_t>(blockIdx.x) + threadIdx.y) * Parameters::WarpBatch;
// batch_size might not be a multiple of Parameters::WarpBatch. Check how
// many batches have to be computed within this WARP.
int local_batches = batch_size - first_batch;
if (local_batches > Parameters::WarpBatch)
local_batches = Parameters::WarpBatch;
// there might be multiple batches per warp; compute the index within the batch
int64_t local_idx = threadIdx.x;
// the first element to process by the current thread
int64_t thread_offset = first_batch * element_count + local_idx;
grad += thread_offset;
output += thread_offset;
gradInput += thread_offset;
if IF_CONSTEXPR (NeedMask)
{
mask += first_batch * Parameters::MaskStride;
}
// The nested loops over Parameters::WarpBatch and then Parameters::WarpIterations can be simplified to one loop,
// but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep
// the nested loops.
// This should have no impact on performance because the loops are unrolled anyway.
// load data from global memory
acc_t grad_reg[Parameters::WarpBatch][Parameters::WarpIterations];
input_t output_reg[Parameters::WarpBatch][Parameters::WarpIterations];
if IF_CONSTEXPR (NeedMask)
{
MaskType mask_reg[Parameters::WarpBatch];
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i)
{
if (i >= local_batches)
break;
mask_reg[i] = mask[i * Parameters::MaskStride + local_idx];
}
const acc_t pinv = 1.0 / p;
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i)
{
int batch_element_count = (i >= local_batches) ? 0 : element_count;
MaskType m = mask_reg[i];
auto output_ptr = output + i * element_count;
auto grad_ptr = grad + i * element_count;
#pragma unroll
for (int it = 0; it < Parameters::WarpIterations; ++it)
{
int element_index = local_idx + it * Parameters::WarpSize;
if (element_index < batch_element_count)
{
grad_reg[i][it] =
(acc_t)((m >> it) & 1) *
(acc_t)grad_ptr[it * Parameters::WarpSize] *
pinv *
output_ptr[it * Parameters::WarpSize];
output_reg[i][it] = output_ptr[it * Parameters::WarpSize];
}
else
{
grad_reg[i][it] = acc_t(0);
output_reg[i][it] = input_t(0);
}
}
}
}
else
{
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i)
{
int batch_element_count = (i >= local_batches) ? 0 : element_count;
auto output_ptr = output + i * element_count;
auto grad_ptr = grad + i * element_count;
#pragma unroll
for (int it = 0; it < Parameters::WarpIterations; ++it)
{
int element_index = local_idx + it * Parameters::WarpSize;
if (element_index < batch_element_count)
{
output_reg[i][it] = output_ptr[it * Parameters::WarpSize];
grad_reg[i][it] = grad_ptr[it * Parameters::WarpSize] *
(acc_t)output_ptr[it * Parameters::WarpSize];
}
else
{
grad_reg[i][it] = acc_t(0);
output_reg[i][it] = input_t(0);
}
}
}
}
acc_t sum[Parameters::WarpBatch];
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i)
{
sum[i] = grad_reg[i][0];
#pragma unroll
for (int it = 1; it < Parameters::WarpIterations; ++it)
{
sum[i] += grad_reg[i][it];
}
}
#pragma unroll
for (int offset = Parameters::WarpSize / 2; offset > 0; offset /= 2)
{
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i)
{
sum[i] += SHFL_XOR(sum[i], offset, Parameters::WarpSize);
}
}
// store result
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i)
{
if (i >= local_batches)
break;
auto gradInput_ptr = gradInput + i * element_count;
#pragma unroll
for (int it = 0; it < Parameters::WarpIterations; ++it)
{
int element_index = local_idx + it * Parameters::WarpSize;
if (element_index < element_count)
{
// compute gradients
if IF_CONSTEXPR (IsLogSoftmax)
{
gradInput_ptr[it * Parameters::WarpSize] =
(grad_reg[i][it] - std::exp((acc_t)output_reg[i][it]) * sum[i]);
}
else
{
gradInput_ptr[it * Parameters::WarpSize] =
(grad_reg[i][it] - output_reg[i][it] * sum[i]);
}
}
}
}
}
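// For reference, the (non-log) backward computed above is, per row,
//   dx_i = y_i * dy_i - y_i * sum_j(y_j * dy_j),
// where dy has first been unscaled with the saved dropout mask and 1/p when
// NeedMask is set; the IsLogSoftmax variant replaces the y_i factor in the
// subtracted term with exp(y_i), exactly as written in the store loop above.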
#define LAUNCH_BACKWARD_KERNEL(l) \
hipLaunchKernelGGL(( softmax_warp_backward<input_t, output_t, acc_t, SoftmaxParameters<l>, IsLogSoftmax, NeedMask>) \
, dim3(blocks), dim3(threads), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), \
grad_input, grad, output, (const typename SoftmaxParameters<l>::MaskType *)mask, p, \
batch_count, softmax_elements); \
break;
template <typename input_t, typename output_t, typename acc_t, bool IsLogSoftmax, bool NeedMask>
void dispatch_softmax_backward(output_t *grad_input, const input_t *grad, const input_t *output,
const void *mask, acc_t p, int softmax_elements, int64_t batch_count)
{
if (softmax_elements == 0)
{
return;
}
else
{
int log2_elements = log2_ceil(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
// This value must match the Parameters::WarpSize constexpr value computed inside softmax_warp_backward.
int warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
// This value must match the Parameters::WarpBatch constexpr value computed inside softmax_warp_backward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
// use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements)
{
case 0:
LAUNCH_BACKWARD_KERNEL(0)
case 1:
LAUNCH_BACKWARD_KERNEL(1)
case 2:
LAUNCH_BACKWARD_KERNEL(2)
case 3:
LAUNCH_BACKWARD_KERNEL(3)
case 4:
LAUNCH_BACKWARD_KERNEL(4)
case 5:
LAUNCH_BACKWARD_KERNEL(5)
case 6:
LAUNCH_BACKWARD_KERNEL(6)
case 7:
LAUNCH_BACKWARD_KERNEL(7)
case 8:
LAUNCH_BACKWARD_KERNEL(8)
case 9:
LAUNCH_BACKWARD_KERNEL(9)
case 10:
LAUNCH_BACKWARD_KERNEL(10)
default:
{
int grid_dim;
constexpr int block_size = 128;
constexpr int waves = 32;
auto cols = softmax_elements;
auto rows = batch_count;
GetNumBlocks(block_size, rows, waves, &grid_dim);
dim3 block(block_size);
const size_t smem = cols * sizeof(acc_t) + cols * sizeof(input_t);
hipLaunchKernelGGL(( softmax_block_backward<input_t, output_t, acc_t, block_size>), dim3(grid_dim), dim3(block), smem, 0,
grad_input, grad, output, rows, cols);
}
}
}
}
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include <ATen/ATen.h>
// Forward/backward compatibility hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
// const at::Type& payload;
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_AND_BF16(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Byte: \
{ \
using scalar_t_##LEVEL = uint8_t; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BF16(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
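// Typical usage of these dispatch macros (illustrative sketch; `my_launcher` is a
// placeholder, not a function defined in this file):
//   DISPATCH_FLOAT_AND_HALF_AND_BF16(input.scalar_type(), 0, "my_op",
//       my_launcher<scalar_t_0>(static_cast<scalar_t_0 *>(input.data_ptr()), input.numel());)
// The LEVEL argument only changes the name of the injected alias (scalar_t_0,
// scalar_t_1, ...), so nested dispatches over differently-typed tensors do not collide.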
template<typename T>
__device__ __forceinline__ T reduce_block_into_lanes
(T *x,
T val,
int lanes=1,
bool share_result=false) // lanes is intended to be <= 32.
{
int tid = threadIdx.x + threadIdx.y*blockDim.x;
int blockSize = blockDim.x*blockDim.y; // blockSize is intended to be a multiple of 32.
if(blockSize >= 64)
{
x[tid] = val;
__syncthreads();
}
#pragma unroll
for(int i = (blockSize >> 1); i >= 64; i >>= 1)
{
if(tid < i)
x[tid] = x[tid] + x[tid+i];
__syncthreads();
}
T final;
if(tid < 32)
{
if(blockSize >= 64)
final = x[tid] + x[tid+32];
else
final = val;
// __SYNCWARP();
#pragma unroll
for(int i = 16; i >= lanes; i >>= 1)
final = final + __shfl_down(final, i);
}
if(share_result)
{
if(tid < lanes)
x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}
template<typename T>
__device__ __forceinline__ T reduce_block_into_lanes_max_op
(T *x,
T val,
int lanes=1,
bool share_result=false) // lanes is intended to be <= 32.
{
int tid = threadIdx.x + threadIdx.y*blockDim.x;
int blockSize = blockDim.x*blockDim.y; // blockSize is intended to be a multiple of 32.
if(blockSize >= 64)
{
x[tid] = val;
__syncthreads();
}
#pragma unroll
for(int i = (blockSize >> 1); i >= 64; i >>= 1)
{
if(tid < i)
x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid+i]));
__syncthreads();
}
T final;
if(tid < 32)
{
if(blockSize >= 64)
final = fmaxf(fabsf(x[tid]), fabsf(x[tid+32]));
else
final = val;
// __SYNCWARP();
#pragma unroll
for(int i = 16; i >= lanes; i >>= 1)
final = fmaxf(fabsf(final), fabsf(__shfl_down( final, i)));
}
if(share_result)
{
if(tid < lanes)
x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}
// !!! This is a file automatically generated by hipify!!!
#pragma once
#define DIV_CELL(a, b) (((a) + (b) - 1) / (b))
#if __cplusplus >= 201703L
#define IF_CONSTEXPR constexpr
#else
#define IF_CONSTEXPR
#endif
template <typename T>
__device__ __forceinline__ T SHFL_XOR(T value, int laneMask, int width, unsigned int mask = 0xffffffff)
{
#if TORCH_HIP_VERSION >= 9000
return __shfl_xor_sync(mask, value, laneMask, width);
#else
return __shfl_xor(value, laneMask, width);
#endif
}
template <typename T, int N>
struct VecTypeImpl;
#define DEFINE_VEC_TYPE(t, n, tn) \
template <> \
struct VecTypeImpl<t, n> { \
using type = tn; \
};
DEFINE_VEC_TYPE(half, 1, half)
//DEFINE_VEC_TYPE(__nv_bfloat16, 1, __nv_bfloat16)
DEFINE_VEC_TYPE(float, 1, float)
DEFINE_VEC_TYPE(half, 2, half2)
//DEFINE_VEC_TYPE(__nv_bfloat16, 2, __nv_bfloat162)
DEFINE_VEC_TYPE(float, 2, float2)
DEFINE_VEC_TYPE(half, 4, uint64_t)
//DEFINE_VEC_TYPE(__nv_bfloat16, 4, uint64_t)
DEFINE_VEC_TYPE(float, 4, float4)
template <typename T, int N>
using VecType = typename VecTypeImpl<T, N>::type;
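// VecType maps an element type and a pack width to the widest type intended for
// vectorized loads/stores, e.g. VecType<float, 4> = float4, VecType<half, 2> = half2,
// and VecType<half, 4> = uint64_t (four halves moved as one 64-bit word). The
// __nv_bfloat16 specializations are commented out, matching the disabled
// cuda_bf16.h include in the hipified headers above.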
FROM nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04
ENV LANG C.UTF-8
ENV OFED_VERSION=5.3-1.0.0.1
RUN APT_INSTALL="apt-get install -y --no-install-recommends" && \
rm -rf /var/lib/apt/lists/* \
/etc/apt/sources.list.d/cuda.list \
/etc/apt/sources.list.d/nvidia-ml.list && \
apt-get update && \
DEBIAN_FRONTEND=noninteractive $APT_INSTALL \
software-properties-common \
&& \
apt-get update && \
DEBIAN_FRONTEND=noninteractive $APT_INSTALL \
build-essential \
apt-utils \
ca-certificates \
wget \
git \
vim \
libssl-dev \
curl \
unzip \
unrar \
cmake \
net-tools \
sudo \
autotools-dev \
rsync \
jq \
openssh-server \
tmux \
screen \
htop \
pdsh \
openssh-client \
lshw \
dmidecode \
util-linux \
automake \
autoconf \
libtool \
net-tools \
pciutils \
libpci-dev \
libaio-dev \
libcap2 \
libtinfo5 \
fakeroot \
devscripts \
debhelper \
nfs-common
# ==================================================================
# InfiniBand & RDMA
# ------------------------------------------------------------------
RUN cd /tmp && \
wget -q http://content.mellanox.com/ofed/MLNX_OFED-${OFED_VERSION}/MLNX_OFED_LINUX-${OFED_VERSION}-ubuntu20.04-x86_64.tgz && \
tar xzf MLNX_OFED_LINUX-${OFED_VERSION}-ubuntu20.04-x86_64.tgz && \
MLNX_OFED_LINUX-${OFED_VERSION}-ubuntu20.04-x86_64/mlnxofedinstall --user-space-only --without-fw-update --force --all && \
rm -rf /tmp/MLNX_OFED_LINUX-${OFED_VERSION}*
RUN cd /tmp && \
mkdir -p /usr/local/nccl-rdma-sharp-plugins && \
DEBIAN_FRONTEND=noninteractive apt-get install -y zlib1g-dev && \
git clone --depth=1 https://github.com/Mellanox/nccl-rdma-sharp-plugins.git && \
cd nccl-rdma-sharp-plugins && \
./autogen.sh && \
./configure --prefix=/usr/local/nccl-rdma-sharp-plugins --with-cuda=/usr/local/cuda && \
make && \
make install
# ==================================================================
# python
# ------------------------------------------------------------------
# Set timezone
RUN ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime
ENV PATH /usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
ENV LD_LIBRARY_PATH /usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
ENV PYTHON_VERSION=3.8
RUN wget -O ~/miniconda.sh https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
chmod +x ~/miniconda.sh && \
~/miniconda.sh -b -p /opt/conda && \
rm ~/miniconda.sh
ENV PATH /opt/conda/bin:$PATH
RUN conda install -y python=3.8 && conda clean -ya
RUN conda install -y scipy scikit-learn pyyaml tensorboard tensorboardX && \
conda clean -ya
RUN ldconfig
# ==================================================================
# pytorch
# ------------------------------------------------------------------
ENV TORCH_CUDA_ARCH_LIST "7.0;7.5;8.0"
RUN conda install -y numpy pyyaml scipy ipython mkl mkl-include ninja cython typing pandas && \
conda clean -ya
RUN conda install pytorch=1.12.1 cudatoolkit=11.6 -c pytorch -c conda-forge && \
conda clean -ya
RUN cd /tmp && \
git clone https://github.com/dptech-corp/Uni-Core && \
cd Uni-Core && \
python setup.py install &&\
rm -rf /tmp/*
RUN pip install --no-cache-dir tokenizers lmdb biopython ml-collections timeout-decorator urllib3 tree dm-tree
ENV LD_LIBRARY_PATH=/usr/local/nccl-rdma-sharp-plugins/lib:$LD_LIBRARY_PATH
ENV PATH=/usr/mpi/gcc/openmpi-4.1.0rc5/bin:$PATH
ENV LD_LIBRARY_PATH=/usr/mpi/gcc/openmpi-4.1.0rc5/lib:$LD_LIBRARY_PATH
RUN ldconfig && \
apt-get clean && \
apt-get autoremove && \
rm -rf /var/lib/apt/lists/* /tmp/* && \
conda clean -ya
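# Example build command (illustrative; the Dockerfile path and image tag are placeholders):
#   docker build -f docker/Dockerfile.cu116 -t uni-fold:cu116-torch1.12 .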
FROM nvcr.io/nvidia/pytorch:22.04-py3
RUN APT_INSTALL="apt-get install -y --no-install-recommends" && \
rm -rf /var/lib/apt/lists/* \
/etc/apt/sources.list.d/cuda.list \
/etc/apt/sources.list.d/nvidia-ml.list && \
apt-get update && \
DEBIAN_FRONTEND=noninteractive $APT_INSTALL \
software-properties-common \
&& \
apt-get update && \
DEBIAN_FRONTEND=noninteractive $APT_INSTALL \
build-essential \
apt-utils \
ca-certificates \
wget \
git \
vim \
libssl-dev \
curl \
unzip \
unrar \
cmake \
net-tools \
sudo \
autotools-dev \
rsync \
jq \
openssh-server \
tmux \
screen \
htop \
pdsh \
openssh-client \
lshw \
dmidecode \
util-linux \
automake \
autoconf \
libtool \
net-tools \
pciutils \
libpci-dev \
libaio-dev \
libcap2 \
libtinfo5 \
fakeroot \
devscripts \
debhelper \
nfs-common
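# The NGC PyTorch base image ships its own torch/torchvision/torchtext builds; they are
# removed below (the uninstall is run twice, presumably to catch a second copy left on
# the path) so that the pinned torch==1.12.1+cu116 wheel installed later is the only
# PyTorch in the image.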
RUN pip uninstall -y torch torchvision torchtext && \
pip uninstall -y torch torchvision torchtext && \
rm -rf ~/.cache/pip && \
conda clean -ya
RUN conda install -y pyyaml tensorboardX && \
conda clean -ya
# RUN ldconfig
# # ==================================================================
# # pytorch
# # ------------------------------------------------------------------
ENV TORCH_CUDA_ARCH_LIST "7.0;7.5;8.0"
RUN conda install -y ninja typing && \
conda clean -ya
RUN pip3 install --no-cache-dir torch==1.12.1 --extra-index-url https://download.pytorch.org/whl/cu116 && rm -rf ~/.cache/pip
RUN cd /tmp && \
git clone https://github.com/dptech-corp/Uni-Core && \
cd Uni-Core && \
python setup.py install && \
rm -rf /tmp/* && rm -rf ~/.cache/pip
RUN pip3 install --no-cache-dir tokenizers lmdb biopython ml-collections timeout-decorator urllib3 tree dm-tree && rm -rf ~/.cache/pip
RUN ldconfig && \
apt-get clean && \
apt-get autoremove && \
rm -rf /var/lib/apt/lists/* /tmp/* && \
conda clean -ya
import bert.task
import bert.model