Commit f206de08 authored by shenggan

refactor kernel implementation

parent 1c0a3d39
......@@ -79,10 +79,6 @@ If you want to benchmark with [OpenFold](https://github.com/aqlaboratory/openfol
torchrun --nproc_per_node=1 perf.py --msa-length 128 --res-length 256 --openfold
```
## Acknowledgement
The CUDA implementations of LayerNorm and Softmax are modified from [OneFlow](https://github.com/Oneflow-Inc/oneflow). We thank OneFlow for the high-performance CUDA kernels; our main addition is support for bfloat16 precision.
## Cite us
Please cite this paper if you use FastFold in your research.
......
This diff is collapsed.
// part of code modified from https://github.com/NVIDIA/apex
#include <cooperative_groups.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <THC/THCDeviceUtils.cuh>
#include "ATen/ATen.h"
#include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDAContext.h"
#include "compat.h"
#include "layer_norm.cuh"
#include "type_shim.h"
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
......@@ -12,19 +19,175 @@
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
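// Welford's online update: fold one sample into a running (mean, m2, count)
// triple, where m2 is the running sum of squared deviations from the mean and
// variance = m2 / count.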
inline __device__ void WelfordOnline(float val, float* mean, float* m2, float* count) {
*count += 1;
float delta1 = val - *mean;
*mean += delta1 / (*count);
float delta2 = val - *mean;
*m2 += delta1 * delta2;
}
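// Chan et al.'s parallel variant of the same update: merge two partial
// (mean, m2, count) triples into one.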
inline __device__ void WelfordOnline(float b_mean, float b_m2, float b_count, float* mean,
float* m2, float* count) {
if (b_count == 0) {
return;
}
float new_count = *count + b_count;
float nb_n = b_count / new_count;
float delta = b_mean - *mean;
*mean += delta * nb_n;
*m2 += b_m2 + delta * delta * (*count) * nb_n;
*count = new_count;
}
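// Tree-reduce the per-lane Welford triples with shuffle-down, then broadcast
// lane 0's result so every lane of the warp ends up with the row statistics.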
__inline__ __device__ void WelfordWarpAllReduce(float thread_mean, float thread_m2,
float thread_count, float* mean, float* m2,
float* count) {
*mean = thread_mean;
*m2 = thread_m2;
*count = thread_count;
for (int mask = 1; mask < 32; mask *= 2) {
float b_mean = __shfl_down_sync(0xffffffff, *mean, mask);
float b_m2 = __shfl_down_sync(0xffffffff, *m2, mask);
float b_count = __shfl_down_sync(0xffffffff, *count, mask);
WelfordOnline(b_mean, b_m2, b_count, mean, m2, count);
}
*mean = __shfl_sync(0xffffffff, *mean, 0, 32);
*m2 = __shfl_sync(0xffffffff, *m2, 0, 32);
*count = __shfl_sync(0xffffffff, *count, 0, 32);
}
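// Forward kernels: each 128-thread block holds four warps and each warp
// normalizes one row. Lane k of a warp owns a contiguous chunk of
// cols / 32 elements, so these kernels assume cols is a multiple of 32 and
// at most 32 * 32 = 1024 (the per-thread buffer size).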
__global__ void fastfold_layernorm_fp32(float* input, float* output, float* gamma, float* beta,
float* mean, float* invvar, int rows, int cols,
double epsilon) {
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
int row_offset = blockIdx.x * 4 + threadidx_x;
if (row_offset >= rows) return;  // guard the tail when rows is not a multiple of 4
int cols_per_thread = cols / 32;
int lane_id = threadidx_y;
float buf[32];
float thread_mean = 0.f;
float thread_m2 = 0.f;
float thread_count = 0.f;
float warp_mean;
float warp_m2;
float warp_count;
float* row_input = input + row_offset * cols;
float* row_output = output + row_offset * cols;
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
buf[i] = row_input[lane_id * cols_per_thread + i];
}
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
WelfordOnline(buf[i], &thread_mean, &thread_m2, &thread_count);
}
WelfordWarpAllReduce(thread_mean, thread_m2, thread_count, &warp_mean, &warp_m2, &warp_count);
float row_mean = warp_mean;
float row_variance = max(warp_m2 / warp_count, 0.f);
float row_inv_var = rsqrt(row_variance + epsilon);
if (lane_id == 0) {
mean[row_offset] = row_mean;
invvar[row_offset] = row_inv_var;
}
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
buf[i] = (buf[i] - row_mean) * row_inv_var;
}
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
row_output[lane_id * cols_per_thread + i] =
buf[i] * gamma[lane_id * cols_per_thread + i] + beta[lane_id * cols_per_thread + i];
}
}
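// Same layout as the fp32 kernel above; values are loaded as bfloat16 and
// widened to fp32 so the Welford statistics accumulate in full precision.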
__global__ void fastfold_layernorm_bfp16(at::BFloat16* input, at::BFloat16* output,
at::BFloat16* gamma, at::BFloat16* beta, float* mean,
float* invvar, int rows, int cols, double epsilon) {
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
int row_offset = blockIdx.x * 4 + threadidx_x;
if (row_offset >= rows) return;  // guard the tail when rows is not a multiple of 4
int cols_per_thread = cols / 32;
int lane_id = threadidx_y;
float buf[32];
float thread_mean = 0.f;
float thread_m2 = 0.f;
float thread_count = 0.f;
float warp_mean;
float warp_m2;
float warp_count;
at::BFloat16* row_input = input + row_offset * cols;
at::BFloat16* row_output = output + row_offset * cols;
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
buf[i] = static_cast<float>(row_input[lane_id * cols_per_thread + i]);
}
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
WelfordOnline(buf[i], &thread_mean, &thread_m2, &thread_count);
}
WelfordWarpAllReduce(thread_mean, thread_m2, thread_count, &warp_mean, &warp_m2, &warp_count);
float row_mean = warp_mean;
float row_variance = max(warp_m2 / warp_count, 0.f);
float row_inv_var = rsqrt(row_variance + epsilon);
if (lane_id == 0) {
mean[row_offset] = row_mean;
invvar[row_offset] = row_inv_var;
}
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
buf[i] = (buf[i] - row_mean) * row_inv_var;
}
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
row_output[lane_id * cols_per_thread + i] =
static_cast<at::BFloat16>(buf[i]) * gamma[lane_id * cols_per_thread + i] +
beta[lane_id * cols_per_thread + i];
}
}
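// Host-side dispatcher: selects the fp32 or bf16 kernel from the output dtype
// and launches one warp per row on the current CUDA stream.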
void cuda_layer_norm(at::Tensor* output, at::Tensor* mean, at::Tensor* invvar, at::Tensor* input,
int n1, int n2, at::IntArrayRef normalized_shape, at::Tensor* gamma,
int rows, int cols, at::IntArrayRef normalized_shape, at::Tensor* gamma,
at::Tensor* beta, double epsilon) {
at::Tensor normalized = at::empty_like(*output);
fastfold::layer_norm::DirectLoad<at::BFloat16, float> load((at::BFloat16*)input->data_ptr(),
n2);
fastfold::layer_norm::AffineStore<float, at::BFloat16, true, true> store(
(at::BFloat16*)normalized.data_ptr(), (at::BFloat16*)output->data_ptr(), n2,
(at::BFloat16*)gamma->data_ptr(), (at::BFloat16*)beta->data_ptr());
auto cuda_stream = at::cuda::getCurrentCUDAStream().stream();
fastfold::layer_norm::DispatchLayerNorm<decltype(load), decltype(store), float>(
cuda_stream, load, store, n1, n2, epsilon, (float*)mean->data_ptr(),
(float*)invvar->data_ptr());
int grid = (rows + 3) / 4;  // ceil-div: each four-warp block covers four rows
dim3 block(128);
auto stream = at::cuda::getCurrentCUDAStream().stream();
if (output->dtype() == torch::kFloat32) {
fastfold_layernorm_fp32<<<grid, block, 0, stream>>>(
(float*)input->data_ptr(), (float*)output->data_ptr(), (float*)gamma->data_ptr(),
(float*)beta->data_ptr(), (float*)mean->data_ptr(), (float*)invvar->data_ptr(), rows,
cols, epsilon);
} else {
fastfold_layernorm_bfp16<<<grid, block, 0, stream>>>(
(at::BFloat16*)input->data_ptr(), (at::BFloat16*)output->data_ptr(),
(at::BFloat16*)gamma->data_ptr(), (at::BFloat16*)beta->data_ptr(),
(float*)mean->data_ptr(), (float*)invvar->data_ptr(), rows, cols, epsilon);
}
}
template <typename T>
......@@ -208,6 +371,116 @@ __global__ void cuComputeGradGammaBeta(const U* part_grad_gamma, const U* part_g
}
}
template <typename T, typename U, typename V>
__global__ void cuComputeGradInput(const V* __restrict__ dout, const T* __restrict__ input,
const int n1, const int n2, const U* __restrict__ mean,
const U* __restrict__ invvar, U epsilon, const V* gamma,
T* grad_input) {
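// For y = (x - mean) * invvar * gamma + beta, the per-row input gradient is
//     dx = (invvar / H) * (H * dy * gamma - sum_loss1 - xhat * sum_loss2)
// where xhat = (x - mean) * invvar, H = n2, sum_loss1 = sum(dy * gamma), and
// sum_loss2 = sum(dy * gamma * xhat); both sums are block-reduced below.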
for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
U sum_loss1 = U(0);
U sum_loss2 = U(0);
const U c_mean = mean[i1];
const U c_invvar = invvar[i1];
const T* k_input = input + i1 * n2;
const V* k_dout = dout + i1 * n2;
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
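// Accumulate the two row sums over this thread's strided slice of the row,
// unrolled by four; the gamma == NULL branch treats gamma as all ones.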
if (gamma != NULL) {
int l = 4 * thrx;
for (; l + 3 < n2; l += 4 * numx) {
for (int k = 0; k < 4; ++k) {
const U c_h = static_cast<U>(k_input[l + k]);
const U c_loss = static_cast<U>(k_dout[l + k]);
sum_loss1 += c_loss * gamma[l + k];
sum_loss2 += c_loss * gamma[l + k] * (c_h - c_mean) * c_invvar;
}
}
for (; l < n2; ++l) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
sum_loss1 += c_loss * gamma[l];
sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
}
} else {
int l = 4 * thrx;
for (; l + 3 < n2; l += 4 * numx) {
for (int k = 0; k < 4; ++k) {
const U c_h = static_cast<U>(k_input[l + k]);
const U c_loss = static_cast<U>(k_dout[l + k]);
sum_loss1 += c_loss;
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
}
}
for (; l < n2; ++l) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
sum_loss1 += c_loss;
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
}
}
// intra-warp reductions
for (int mask = blockDim.x / 2; mask > 0; mask /= 2) {
sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);
sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);
}
// inter-warp reductions
if (blockDim.y > 1) {
SharedMemory<U> shared;
U* buf = shared.getPointer();
for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
// upper half of warps write to shared
if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
buf[2 * wrt_i] = sum_loss1;
buf[2 * wrt_i + 1] = sum_loss2;
}
__syncthreads();
// lower half merges
if (threadIdx.y < offset) {
const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
sum_loss1 += buf[2 * read_i];
sum_loss2 += buf[2 * read_i + 1];
}
__syncthreads();
}
if (threadIdx.y == 0) {
buf[2 * threadIdx.x] = sum_loss1;
buf[2 * threadIdx.x + 1] = sum_loss2;
}
__syncthreads();
if (threadIdx.y != 0) {
sum_loss1 = buf[2 * threadIdx.x];
sum_loss2 = buf[2 * threadIdx.x + 1];
}
}
// all threads now have the two sums over l
U fH = (U)n2;
U term1 = (U(1) / fH) * c_invvar;
T* k_grad_input = grad_input + i1 * n2;
if (gamma != NULL) {
for (int l = thrx; l < n2; l += numx) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
U f_grad_input = fH * c_loss * gamma[l];
f_grad_input -= sum_loss1;
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
f_grad_input *= term1;
k_grad_input[l] = static_cast<T>(f_grad_input);
}
} else {
for (int l = thrx; l < n2; l += numx) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
U f_grad_input = fH * c_loss;
f_grad_input -= sum_loss1;
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
f_grad_input *= term1;
k_grad_input[l] = static_cast<T>(f_grad_input);
}
}
}
}
template <typename T, typename U, typename V>
void HostLayerNormGradient(const V* dout, const U* mean, const U* invvar, at::Tensor* input, int n1,
int n2, const V* gamma, const V* beta, double epsilon, T* grad_input,
......@@ -236,6 +509,14 @@ void HostLayerNormGradient(const V* dout, const U* mean, const U* invvar, at::Te
part_grad_gamma.DATA_PTR<U>(), part_grad_beta.DATA_PTR<U>(), part_size, n1, n2,
grad_gamma, grad_beta);
}
// compute grad_input
const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
const dim3 threads1(32, 4, 1);
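// Shared memory stages one (sum_loss1, sum_loss2) pair per writing thread at
// the widest inter-warp step: threads1.y * threads1.x values of type U.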
int nshared = threads1.y > 1 ? threads1.y * threads1.x * sizeof(U) : 0;
cuComputeGradInput<<<blocks1, threads1, nshared, stream>>>(
dout, input->DATA_PTR<T>(), n1, n2, mean, invvar, U(epsilon), gamma, grad_input);
}
void cuda_layer_norm_gradient(at::Tensor* dout, at::Tensor* mean, at::Tensor* invvar,
......@@ -243,20 +524,6 @@ void cuda_layer_norm_gradient(at::Tensor* dout, at::Tensor* mean, at::Tensor* in
at::Tensor* gamma, at::Tensor* beta, double epsilon,
at::Tensor* grad_input, at::Tensor* grad_gamma,
at::Tensor* grad_beta) {
at::Tensor add_to_output = at::empty_like(*grad_input);
fastfold::layer_norm::DirectLoad<at::BFloat16, float> load_x((at::BFloat16*)input->data_ptr(),
n2);
fastfold::layer_norm::ScaleLoad<at::BFloat16, float, true> load_scaled_dy(
(at::BFloat16*)dout->data_ptr(), (at::BFloat16*)gamma->data_ptr(), n2);
fastfold::layer_norm::AddStore<float, at::BFloat16, true> store(
(at::BFloat16*)add_to_output.data_ptr(), (at::BFloat16*)grad_input->data_ptr(), n2);
auto cuda_stream = at::cuda::getCurrentCUDAStream().stream();
fastfold::layer_norm::DispatchLayerNormGrad<decltype(load_x), decltype(load_scaled_dy),
decltype(store), float>(
cuda_stream, load_x, load_scaled_dy, store, (float*)mean->data_ptr(),
(float*)invvar->data_ptr(), n1, n2);
using namespace at;
DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
input->scalar_type(), gamma->scalar_type(), "cuda_layer_norm_gradient_kernel",
......
This diff is collapsed.
#include <torch/extension.h>
at::Tensor softmax(at::Tensor input, int rows, int cols);
at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor input, int rows, int cols);
at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor output, int rows, int cols);
at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, int rows, int cols,
float scale);
......@@ -15,8 +15,8 @@ at::Tensor fused_scale_mask_bias_softmax_backward(at::Tensor d_output, at::Tenso
int cols, float scale);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward_affine", &softmax, "Softmax forward (CUDA)");
m.def("backward_affine", &softmax_gradient, "Softmax backward (CUDA)");
m.def("forward", &softmax, "Softmax forward (CUDA)");
m.def("backward", &softmax_gradient, "Softmax backward (CUDA)");
m.def("fused_scale_mask_softmax_forward", &fused_scale_mask_softmax_forward,
"Softmax forward (CUDA)");
......
// modified from https://github.com/NVIDIA/apex
#include <ATen/ATen.h>
#include "compat.h"
......
......@@ -14,7 +14,7 @@ class SoftmaxAffineFunction(torch.autograd.Function):
input_ = input.contiguous()
ctx.cols = input_.shape[-1]
ctx.rows = reduce(mul, input.shape[:-1])
output = fastfold_softmax_cuda.forward_affine(input_, ctx.rows, ctx.cols)
output = fastfold_softmax_cuda.forward(input_, ctx.rows, ctx.cols)
ctx.save_for_backward(output)
return output
......@@ -25,7 +25,7 @@ class SoftmaxAffineFunction(torch.autograd.Function):
output = ctx.saved_tensors[0]
grad_input = None
grad_input = fastfold_softmax_cuda.backward_affine(grad_output.contiguous(), output,
grad_input = fastfold_softmax_cuda.backward(grad_output.contiguous(), output,
ctx.rows, ctx.cols)
return grad_input
......