Unverified commit a65d5009 authored by shenggan, committed by GitHub

use template for fused softmax & add unittest for fused softmax (#26)

parent 771d4b83
#include <c10/cuda/CUDAGuard.h>
#include <math_constants.h>
#include <torch/extension.h>
#include <iostream>
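Note for readers of this diff: the warp-level reduction helpers referenced in the hunks below (WarpAllReduceSum, WarpAllReduceMax) are defined earlier in this file and are not part of the change. A minimal sketch of what such helpers usually look like, assuming a 32-lane warp and butterfly shuffles:

__inline__ __device__ float WarpAllReduceMax(float val) {
    // Butterfly exchange: after log2(32) steps every lane holds the warp-wide maximum.
    for (int mask = 16; mask > 0; mask /= 2) {
        val = max(val, __shfl_xor_sync(0xffffffff, val, mask));
    }
    return val;
}

__inline__ __device__ float WarpAllReduceSum(float val) {
    // Same pattern with addition: every lane ends up with the warp-wide sum.
    for (int mask = 16; mask > 0; mask /= 2) {
        val += __shfl_xor_sync(0xffffffff, val, mask);
    }
    return val;
}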
@@ -30,7 +30,8 @@ __inline__ __device__ float WarpAllReduceSum(float val) {
////////////////
__global__ void fastfold_softmax_fp32(float *input, float *output, long long rows, long long cols) {
template <typename T>
__global__ void fastfold_softmax(T *input, T *output, long long rows, long long cols) {
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
@@ -41,8 +42,7 @@ __global__ void fastfold_softmax_fp32(float *input, float *output, long long row
if (threadidx_y == last_y) {
cols_this_thread = cols - cols_per_thread * last_y;
}
else if (threadidx_y > last_y) {
} else if (threadidx_y > last_y) {
cols_this_thread = 0;
}
@@ -51,72 +51,17 @@ __global__ void fastfold_softmax_fp32(float *input, float *output, long long row
int lane_id = threadidx_y;
if (row_offset < rows) {
float *row_input = input + row_offset * cols;
float *row_output = output + row_offset * cols;
T *row_input = input + row_offset * cols;
T *row_output = output + row_offset * cols;
float thread_max = -1 * CUDART_INF_F;
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
buf[i] = row_input[lane_id * cols_per_thread + i];
}
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
thread_max = max(thread_max, buf[i]);
}
float warp_max = WarpAllReduceMax(thread_max);
float thread_sum = 0.f;
#pragma unroll
for (int i = 0; i < cols_this_thread; ++i) {
buf[i] = __expf(buf[i] - warp_max);
thread_sum += buf[i];
}
float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
for (int i = 0; i < cols_this_thread; ++i) {
row_output[lane_id * cols_per_thread + i] = __fdividef(buf[i], warp_sum);
}
}
}
__global__ void fastfold_softmax_bfp16(at::BFloat16 *input, at::BFloat16 *output, long long rows,
long long cols) {
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
int cols_per_thread = (cols + 31) / 32;
int cols_this_thread = cols_per_thread;
int last_y = (cols / cols_per_thread);
if (threadidx_y == last_y) {
cols_this_thread = cols - cols_per_thread * last_y;
}
else if (threadidx_y > last_y) {
cols_this_thread = 0;
}
float buf[32];
int lane_id = threadidx_y;
if (row_offset < rows) {
at::BFloat16 *row_input = input + row_offset * cols;
at::BFloat16 *row_output = output + row_offset * cols;
float thread_max = -1 * CUDART_INF_F;
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
buf[i] = static_cast<float>(row_input[lane_id * cols_per_thread + i]);
}
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
thread_max = max(thread_max, buf[i]);
}
@@ -124,74 +69,47 @@ __global__ void fastfold_softmax_bfp16(at::BFloat16 *input, at::BFloat16 *output
float warp_max = WarpAllReduceMax(thread_max);
float thread_sum = 0.f;
#pragma unroll
for (int i = 0; i < cols_this_thread; ++i) {
buf[i] = __expf(buf[i] - warp_max);
thread_sum += buf[i];
}
float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
for (int i = 0; i < cols_this_thread; ++i) {
row_output[lane_id * cols_per_thread + i] =
static_cast<at::BFloat16>(__fdividef(buf[i], warp_sum));
static_cast<T>(__fdividef(buf[i], warp_sum));
}
}
}
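The per-warp work split above is worth spelling out with the odd shape the new unit test exercises. Each warp owns one row and splits its cols elements across 32 lanes in chunks of cols_per_thread (the per-lane float buf[32] caps supported cols at 32 * 32 = 1024, which matches the largest test shape). Worked arithmetic for cols = 129, derived directly from the code above:

// cols = 129:
//   cols_per_thread = (129 + 31) / 32 = 5
//   last_y          = 129 / 5        = 25
//   lanes 0..24  -> 5 columns each (125 total)
//   lane  25     -> 129 - 5 * 25 = 4 columns
//   lanes 26..31 -> 0 columns
// Total: 25 * 5 + 4 = 129, so every column is covered exactly once.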
__global__ void fastfold_softmax_grad_fp32(float *d_output, float *output, float *d_input, long long rows,
long long cols) {
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
int cols_per_thread = (cols + 31) / 32;
int cols_this_thread = cols_per_thread;
int last_y = (cols / cols_per_thread);
if (threadidx_y == last_y) {
cols_this_thread = cols - cols_per_thread * last_y;
}
else if (threadidx_y > last_y) {
cols_this_thread = 0;
}
float y_buf[32];
float dy_buf[32];
int lane_id = threadidx_y;
if (row_offset < rows) {
float *row_d_output = d_output + row_offset * cols;
float *row_output = output + row_offset * cols;
float *row_d_input = d_input + row_offset * cols;
float thread_max = -1 * CUDART_INF_F;
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
y_buf[i] = row_output[lane_id * cols_per_thread + i];
dy_buf[i] = row_d_output[lane_id * cols_per_thread + i];
}
float thread_sum = 0.f;
at::Tensor softmax(at::Tensor input, long long rows, long long cols) {
CHECK_INPUT(input);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
thread_sum += y_buf[i] * dy_buf[i];
}
at::Tensor output = at::empty_like(input);
float warp_sum = WarpAllReduceSum(thread_sum);
int grid = (rows + 3) / 4;
dim3 block(128);
#pragma unroll
for (int i = 0; i < cols_this_thread; ++i) {
row_d_input[lane_id * cols_this_thread + i] = (dy_buf[i] - warp_sum) * y_buf[i];
}
if (input.dtype() == torch::kFloat32) {
fastfold_softmax<float>
<<<grid, block>>>((float *)input.data_ptr(), (float *)output.data_ptr(), rows, cols);
} else if (input.dtype() == torch::kFloat16) {
fastfold_softmax<at::Half><<<grid, block>>>((at::Half *)input.data_ptr(),
(at::Half *)output.data_ptr(), rows, cols);
} else if (input.dtype() == torch::kBFloat16) {
fastfold_softmax<at::BFloat16><<<grid, block>>>(
(at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)output.data_ptr(), rows, cols);
}
return output;
}
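Two things the dispatcher relies on sit outside this hunk. First, each 128-thread block is four warps and each warp handles one row, hence grid = (rows + 3) / 4. Second, CHECK_INPUT is a macro defined elsewhere in the extension; it is commonly the standard PyTorch-extension guard sketched below (an assumption, not taken from this diff):

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
    CHECK_CUDA(x);     \
    CHECK_CONTIGUOUS(x)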
__global__ void fastfold_softmax_grad_bfp16(at::BFloat16 *d_output, at::BFloat16 *output,
at::BFloat16 *d_input, long long rows, long long cols) {
template <typename T>
__global__ void fastfold_softmax_grad(T *d_output, T *output, T *d_input, long long rows,
long long cols) {
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
@@ -202,8 +120,7 @@ __global__ void fastfold_softmax_grad_bfp16(at::BFloat16 *d_output, at::BFloat16
if (threadidx_y == last_y) {
cols_this_thread = cols - cols_per_thread * last_y;
}
else if (threadidx_y > last_y) {
} else if (threadidx_y > last_y) {
cols_this_thread = 0;
}
@@ -213,56 +130,37 @@ __global__ void fastfold_softmax_grad_bfp16(at::BFloat16 *d_output, at::BFloat16
int lane_id = threadidx_y;
if (row_offset < rows) {
at::BFloat16 *row_d_output = d_output + row_offset * cols;
at::BFloat16 *row_output = output + row_offset * cols;
at::BFloat16 *row_d_input = d_input + row_offset * cols;
T *row_d_output = d_output + row_offset * cols;
T *row_output = output + row_offset * cols;
T *row_d_input = d_input + row_offset * cols;
float thread_max = -1 * CUDART_INF_F;
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
y_buf[i] = static_cast<float>(row_output[lane_id * cols_per_thread + i]);
dy_buf[i] = static_cast<float>(row_d_output[lane_id * cols_per_thread + i]);
y_buf[i] = static_cast<T>(row_output[lane_id * cols_per_thread + i]);
dy_buf[i] = static_cast<T>(row_d_output[lane_id * cols_per_thread + i]);
}
float thread_sum = 0.f;
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
thread_sum += y_buf[i] * dy_buf[i];
}
float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
for (int i = 0; i < cols_this_thread; ++i) {
row_d_input[lane_id * cols_per_thread + i] =
static_cast<at::BFloat16>((dy_buf[i] - warp_sum) * y_buf[i]);
static_cast<T>((dy_buf[i] - warp_sum) * y_buf[i]);
}
}
}
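The backward kernel's update line is the standard softmax Jacobian-vector product; a short derivation as comments, using the variable names from the kernel above:

// For y = softmax(x):  dy_i/dx_j = y_i * (delta_ij - y_j), hence
//   dL/dx_j = sum_i (dL/dy_i) * dy_i/dx_j
//           = y_j * (dL/dy_j - sum_i y_i * dL/dy_i)
// warp_sum is the all-reduced sum_i y_i * dy_i, so each lane writes
//   (dy_buf[i] - warp_sum) * y_buf[i].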
at::Tensor softmax(at::Tensor input, long long rows, long long cols) {
CHECK_INPUT(input);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
at::Tensor output = at::empty_like(input);
int grid = (rows + 3) / 4;
dim3 block(128);
if (input.dtype() == torch::kFloat32) {
fastfold_softmax_fp32<<<grid, block>>>((float *)input.data_ptr(),
(float *)output.data_ptr(), rows, cols);
} else {
fastfold_softmax_bfp16<<<grid, block>>>((at::BFloat16 *)input.data_ptr(),
(at::BFloat16 *)output.data_ptr(), rows, cols);
}
return output;
}
at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor output, long long rows, long long cols) {
at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor output, long long rows,
long long cols) {
CHECK_INPUT(output);
const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
at::Tensor grad_input = at::empty_like(output);
@@ -271,11 +169,15 @@ at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor output, long long ro
dim3 block(128);
if (output.dtype() == torch::kFloat32) {
fastfold_softmax_grad_fp32<<<grid, block>>>((float *)d_output.data_ptr(),
(float *)output.data_ptr(),
(float *)grad_input.data_ptr(), rows, cols);
} else {
fastfold_softmax_grad_bfp16<<<grid, block>>>(
fastfold_softmax_grad<float><<<grid, block>>>((float *)d_output.data_ptr(),
(float *)output.data_ptr(),
(float *)grad_input.data_ptr(), rows, cols);
} else if (output.dtype() == torch::kFloat16) {
fastfold_softmax_grad<at::Half>
<<<grid, block>>>((at::Half *)d_output.data_ptr(), (at::Half *)output.data_ptr(),
(at::Half *)grad_input.data_ptr(), rows, cols);
} else if (output.dtype() == torch::kBFloat16) {
fastfold_softmax_grad<at::BFloat16><<<grid, block>>>(
(at::BFloat16 *)d_output.data_ptr(), (at::BFloat16 *)output.data_ptr(),
(at::BFloat16 *)grad_input.data_ptr(), rows, cols);
}
@@ -285,8 +187,9 @@ at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor output, long long ro
////////////////
__global__ void fastfold_softmax_scale_mask_fp32(float *input, float *mask, float *output, long long rows,
long long cols, float scale, int head) {
template <typename T>
__global__ void fastfold_softmax_scale_mask(T *input, T *mask, T *output, long long rows,
long long cols, float scale, int head) {
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
@@ -297,8 +200,7 @@ __global__ void fastfold_softmax_scale_mask_fp32(float *input, float *mask, floa
if (threadidx_y == last_y) {
cols_this_thread = cols - cols_per_thread * last_y;
}
else if (threadidx_y > last_y) {
} else if (threadidx_y > last_y) {
cols_this_thread = 0;
}
@@ -307,80 +209,21 @@ __global__ void fastfold_softmax_scale_mask_fp32(float *input, float *mask, floa
int lane_id = threadidx_y;
if (row_offset < rows) {
float *row_input = input + row_offset * cols;
float *row_output = output + row_offset * cols;
float *mask_ptr = mask + ((row_offset / (head * cols)) * cols);
T *row_input = input + row_offset * cols;
T *row_output = output + row_offset * cols;
T *mask_ptr = mask + ((row_offset / (head * cols)) * cols);
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
buf[i] = -1 * 1e9;
} else {
buf[i] = row_input[lane_id * cols_per_thread + i] * scale;
}
}
float thread_max = -1 * CUDART_INF_F;
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
thread_max = max(thread_max, buf[i]);
}
float warp_max = WarpAllReduceMax(thread_max);
float thread_sum = 0.f;
#pragma unroll
for (int i = 0; i < cols_this_thread; ++i) {
buf[i] = __expf(buf[i] - warp_max);
thread_sum += buf[i];
}
float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
for (int i = 0; i < cols_this_thread; ++i) {
row_output[lane_id * cols_per_thread + i] = __fdividef(buf[i], warp_sum);
}
}
}
__global__ void fastfold_softmax_scale_mask_bfp16(at::BFloat16 *input, at::BFloat16 *mask,
at::BFloat16 *output, long long rows, long long cols,
float scale, int head) {
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
int cols_per_thread = (cols + 31) / 32;
int cols_this_thread = cols_per_thread;
int last_y = (cols / cols_per_thread);
if (threadidx_y == last_y) {
cols_this_thread = cols - cols_per_thread * last_y;
}
else if (threadidx_y > last_y) {
cols_this_thread = 0;
}
float buf[32];
int lane_id = threadidx_y;
if (row_offset < rows) {
at::BFloat16 *row_input = input + row_offset * cols;
at::BFloat16 *row_output = output + row_offset * cols;
at::BFloat16 *mask_ptr = mask + ((row_offset / (head * cols)) * cols);
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
buf[i] = -1 * 10e9;
} else {
buf[i] = static_cast<float>(row_input[lane_id * cols_per_thread + i]) * scale;
buf[i] = static_cast<T>(row_input[lane_id * cols_per_thread + i]) * scale;
}
}
float thread_max = -1 * CUDART_INF_F;
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
thread_max = max(thread_max, buf[i]);
}
@@ -388,23 +231,23 @@ __global__ void fastfold_softmax_scale_mask_bfp16(at::BFloat16 *input, at::BFloa
float warp_max = WarpAllReduceMax(thread_max);
float thread_sum = 0.f;
#pragma unroll
for (int i = 0; i < cols_this_thread; ++i) {
buf[i] = __expf(buf[i] - warp_max);
thread_sum += buf[i];
}
float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
for (int i = 0; i < cols_this_thread; ++i) {
row_output[lane_id * cols_per_thread + i] =
static_cast<at::BFloat16>(__fdividef(buf[i], warp_sum));
static_cast<T>(__fdividef(buf[i], warp_sum));
}
}
}
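The only non-obvious pointer arithmetic in the masked kernel is mask_ptr. A worked example as comments, under the assumption (not stated in this diff) that rows = batch * head * cols, i.e. a square attention map whose mask is shared across heads and query rows:

// Assumption: rows = batch * head * cols (square attention logits).
// For row_offset = b * (head * cols) + h * cols + q:
//   row_offset / (head * cols) == b
//   mask_ptr = mask + b * cols   -> one mask row per batch element,
//                                   broadcast over all heads h and query rows q.
// Masked-out columns (mask == 0) are filled with a large negative value so they
// vanish after the exp/normalize steps.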
at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, long long rows, long long cols,
float scale) {
at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, long long rows,
long long cols, float scale) {
CHECK_INPUT(input);
CHECK_INPUT(mask);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
@@ -415,79 +258,26 @@ at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, l
dim3 block(128);
if (input.dtype() == torch::kFloat32) {
fastfold_softmax_scale_mask_fp32<<<grid, block>>>(
(float *)input.data_ptr(), (float *)mask.data_ptr(), (float *)output.data_ptr(), rows,
cols, scale, head);
} else {
fastfold_softmax_scale_mask_bfp16<<<grid, block>>>(
(at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
(at::BFloat16 *)output.data_ptr(), rows, cols, scale, head);
fastfold_softmax_scale_mask<float>
<<<grid, block>>>((float *)input.data_ptr(), (float *)mask.data_ptr(),
(float *)output.data_ptr(), rows, cols, scale, head);
} else if (input.dtype() == torch::kFloat16) {
fastfold_softmax_scale_mask<at::Half>
<<<grid, block>>>((at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(),
(at::Half *)output.data_ptr(), rows, cols, scale, head);
} else if (input.dtype() == torch::kBFloat16) {
fastfold_softmax_scale_mask<at::BFloat16>
<<<grid, block>>>((at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
(at::BFloat16 *)output.data_ptr(), rows, cols, scale, head);
}
return output;
}
__global__ void fastfold_softmax_scale_mask_grad_fp32(float *d_output, float *output,
float *d_input, float *mask, long long rows,
long long cols, float scale, int head) {
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
int cols_per_thread = (cols + 31) / 32;
int cols_this_thread = cols_per_thread;
int last_y = (cols / cols_per_thread);
if (threadidx_y == last_y) {
cols_this_thread = cols - cols_per_thread * last_y;
}
else if (threadidx_y > last_y) {
cols_this_thread = 0;
}
float y_buf[32];
float dy_buf[32];
int lane_id = threadidx_y;
if (row_offset < rows) {
float *row_d_output = d_output + row_offset * cols;
float *row_output = output + row_offset * cols;
float *row_d_input = d_input + row_offset * cols;
float *mask_ptr = mask + ((row_offset / (head * cols)) * cols);
float thread_max = -1 * CUDART_INF_F;
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
y_buf[i] = row_output[lane_id * cols_per_thread + i];
dy_buf[i] = row_d_output[lane_id * cols_per_thread + i];
}
float thread_sum = 0.f;
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
thread_sum += y_buf[i] * dy_buf[i];
}
float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
for (int i = 0; i < cols_this_thread; ++i) {
if (mask_ptr[lane_id * cols_per_thread + i] != 0) {
row_d_input[lane_id * cols_per_thread + i] =
scale * ((dy_buf[i] - warp_sum) * y_buf[i]);
} else {
row_d_input = 0;
}
}
}
}
__global__ void fastfold_softmax_scale_mask_grad_bfp16(at::BFloat16 *d_output, at::BFloat16 *output,
at::BFloat16 *d_input, at::BFloat16 *mask,
long long rows, long long cols, float scale, int head) {
template <typename T>
__global__ void fastfold_softmax_scale_mask_grad(T *d_output, T *output, T *d_input, T *mask,
long long rows, long long cols, float scale,
int head) {
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
@@ -498,8 +288,7 @@ __global__ void fastfold_softmax_scale_mask_grad_bfp16(at::BFloat16 *d_output, a
if (threadidx_y == last_y) {
cols_this_thread = cols - cols_per_thread * last_y;
}
else if (threadidx_y > last_y) {
} else if (threadidx_y > last_y) {
cols_this_thread = 0;
}
@@ -509,33 +298,33 @@ __global__ void fastfold_softmax_scale_mask_grad_bfp16(at::BFloat16 *d_output, a
int lane_id = threadidx_y;
if (row_offset < rows) {
at::BFloat16 *row_d_output = d_output + row_offset * cols;
at::BFloat16 *row_output = output + row_offset * cols;
at::BFloat16 *row_d_input = d_input + row_offset * cols;
at::BFloat16 *mask_ptr = mask + ((row_offset / (head * cols)) * cols);
T *row_d_output = d_output + row_offset * cols;
T *row_output = output + row_offset * cols;
T *row_d_input = d_input + row_offset * cols;
T *mask_ptr = mask + ((row_offset / (head * cols)) * cols);
float thread_max = -1 * CUDART_INF_F;
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
y_buf[i] = static_cast<float>(row_output[lane_id * cols_per_thread + i]);
dy_buf[i] = static_cast<float>(row_d_output[lane_id * cols_per_thread + i]);
y_buf[i] = static_cast<T>(row_output[lane_id * cols_per_thread + i]);
dy_buf[i] = static_cast<T>(row_d_output[lane_id * cols_per_thread + i]);
}
float thread_sum = 0.f;
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
thread_sum += y_buf[i] * dy_buf[i];
}
float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
for (int i = 0; i < cols_this_thread; ++i) {
if (mask_ptr[lane_id * cols_per_thread + i] != 0) {
row_d_input[lane_id * cols_per_thread + i] =
static_cast<at::BFloat16>(scale * ((dy_buf[i] - warp_sum) * y_buf[i]));
static_cast<T>(scale * ((dy_buf[i] - warp_sum) * y_buf[i]));
} else {
row_d_input = 0;
}
@@ -544,7 +333,8 @@ __global__ void fastfold_softmax_scale_mask_grad_bfp16(at::BFloat16 *d_output, a
}
at::Tensor fused_scale_mask_softmax_backward(at::Tensor d_output, at::Tensor output,
at::Tensor mask, long long rows, long long cols, float scale) {
at::Tensor mask, long long rows, long long cols,
float scale) {
CHECK_INPUT(output);
CHECK_INPUT(mask);
const at::cuda::OptionalCUDAGuard device_guard(device_of(mask));
@@ -555,11 +345,16 @@ at::Tensor fused_scale_mask_softmax_backward(at::Tensor d_output, at::Tensor out
dim3 block(128);
if (output.dtype() == torch::kFloat32) {
fastfold_softmax_scale_mask_grad_fp32<<<grid, block>>>(
fastfold_softmax_scale_mask_grad<float><<<grid, block>>>(
(float *)d_output.data_ptr(), (float *)output.data_ptr(),
(float *)grad_input.data_ptr(), (float *)mask.data_ptr(), rows, cols, scale, head);
} else {
fastfold_softmax_scale_mask_grad_bfp16<<<grid, block>>>(
} else if (output.dtype() == torch::kFloat16) {
fastfold_softmax_scale_mask_grad<at::Half>
<<<grid, block>>>((at::Half *)d_output.data_ptr(), (at::Half *)output.data_ptr(),
(at::Half *)grad_input.data_ptr(), (at::Half *)mask.data_ptr(), rows,
cols, scale, head);
} else if (output.dtype() == torch::kBFloat16) {
fastfold_softmax_scale_mask_grad<at::BFloat16><<<grid, block>>>(
(at::BFloat16 *)d_output.data_ptr(), (at::BFloat16 *)output.data_ptr(),
(at::BFloat16 *)grad_input.data_ptr(), (at::BFloat16 *)mask.data_ptr(), rows, cols,
scale, head);
@@ -570,9 +365,10 @@ at::Tensor fused_scale_mask_softmax_backward(at::Tensor d_output, at::Tensor out
////////////////
__global__ void fastfold_softmax_scale_mask_bias_fp32(float *input, float *mask, float *bias,
float *output, long long rows, long long cols,
float scale, int head) {
template <typename T>
__global__ void fastfold_softmax_scale_mask_bias(T *input, T *mask, T *bias, T *output,
long long rows, long long cols, float scale,
int head) {
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
@@ -583,69 +379,7 @@ __global__ void fastfold_softmax_scale_mask_bias_fp32(float *input, float *mask,
if (threadidx_y == last_y) {
cols_this_thread = cols - cols_per_thread * last_y;
}
else if (threadidx_y > last_y) {
cols_this_thread = 0;
}
float buf[32];
int lane_id = threadidx_y;
if (row_offset < rows) {
float *row_input = input + row_offset * cols;
float *row_output = output + row_offset * cols;
float *mask_ptr = mask + ((row_offset / (head * cols)) * cols);
float *bias_ptr = bias + ((row_offset % (head * cols)) * cols);
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
buf[i] = -1 * 10e9;
} else {
buf[i] = row_input[lane_id * cols_per_thread + i] * scale +
bias_ptr[lane_id * cols_per_thread + i];
}
}
float thread_max = -1 * CUDART_INF_F;
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
thread_max = max(thread_max, buf[i]);
}
float warp_max = WarpAllReduceMax(thread_max);
float thread_sum = 0.f;
#pragma unroll
for (int i = 0; i < cols_this_thread; ++i) {
buf[i] = __expf(buf[i] - warp_max);
thread_sum += buf[i];
}
float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
for (int i = 0; i < cols_this_thread; ++i) {
row_output[lane_id * cols_per_thread + i] = __fdividef(buf[i], warp_sum);
}
}
}
__global__ void fastfold_softmax_scale_mask_bias_bfp16(at::BFloat16 *input, at::BFloat16 *mask,
at::BFloat16 *bias, at::BFloat16 *output,
long long rows, long long cols, float scale, int head) {
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
int cols_per_thread = (cols + 31) / 32;
int cols_this_thread = cols_per_thread;
int last_y = (cols / cols_per_thread);
if (threadidx_y == last_y) {
cols_this_thread = cols - cols_per_thread * last_y;
}
else if (threadidx_y > last_y) {
} else if (threadidx_y > last_y) {
cols_this_thread = 0;
}
@@ -654,23 +388,23 @@ __global__ void fastfold_softmax_scale_mask_bias_bfp16(at::BFloat16 *input, at::
int lane_id = threadidx_y;
if (row_offset < rows) {
at::BFloat16 *row_input = input + row_offset * cols;
at::BFloat16 *row_output = output + row_offset * cols;
at::BFloat16 *mask_ptr = mask + ((row_offset / (head * cols)) * cols);
at::BFloat16 *bias_ptr = bias + ((row_offset % (head * cols)) * cols);
T *row_input = input + row_offset * cols;
T *row_output = output + row_offset * cols;
T *mask_ptr = mask + ((row_offset / (head * cols)) * cols);
T *bias_ptr = bias + ((row_offset % (head * cols)) * cols);
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
buf[i] = -1 * 10e9;
} else {
buf[i] = static_cast<float>(row_input[lane_id * cols_per_thread + i]) * scale;
buf[i] += static_cast<float>(bias_ptr[lane_id * cols_per_thread + i]);
buf[i] = static_cast<T>(row_input[lane_id * cols_per_thread + i]) * scale;
buf[i] += static_cast<T>(bias_ptr[lane_id * cols_per_thread + i]);
}
}
float thread_max = -1 * CUDART_INF_F;
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
thread_max = max(thread_max, buf[i]);
}
@@ -678,17 +412,17 @@ __global__ void fastfold_softmax_scale_mask_bias_bfp16(at::BFloat16 *input, at::
float warp_max = WarpAllReduceMax(thread_max);
float thread_sum = 0.f;
#pragma unroll
for (int i = 0; i < cols_this_thread; ++i) {
buf[i] = __expf(buf[i] - warp_max);
thread_sum += buf[i];
}
float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
for (int i = 0; i < cols_this_thread; ++i) {
row_output[lane_id * cols_per_thread + i] =
static_cast<at::BFloat16>(__fdividef(buf[i], warp_sum));
static_cast<T>(__fdividef(buf[i], warp_sum));
}
}
}
@@ -706,14 +440,18 @@ at::Tensor fused_scale_mask_bias_softmax_forward(at::Tensor input, at::Tensor ma
dim3 block(128);
if (input.dtype() == torch::kFloat32) {
fastfold_softmax_scale_mask_bias_fp32<<<grid, block>>>(
fastfold_softmax_scale_mask_bias<float><<<grid, block>>>(
(float *)input.data_ptr(), (float *)mask.data_ptr(), (float *)bias.data_ptr(),
(float *)output.data_ptr(), rows, cols, scale, head);
} else {
fastfold_softmax_scale_mask_bias_bfp16<<<grid, block>>>(
(at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
(at::BFloat16 *)bias.data_ptr(), (at::BFloat16 *)output.data_ptr(), rows, cols, scale,
head);
} else if (input.dtype() == torch::kFloat16) {
fastfold_softmax_scale_mask_bias<at::Half><<<grid, block>>>(
(at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(), (at::Half *)bias.data_ptr(),
(at::Half *)output.data_ptr(), rows, cols, scale, head);
} else if (input.dtype() == torch::kBFloat16) {
fastfold_softmax_scale_mask_bias<at::BFloat16>
<<<grid, block>>>((at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
(at::BFloat16 *)bias.data_ptr(), (at::BFloat16 *)output.data_ptr(),
rows, cols, scale, head);
}
return output;
@@ -732,11 +470,16 @@ at::Tensor fused_scale_mask_bias_softmax_backward(at::Tensor d_output, at::Tenso
dim3 block(128);
if (output.dtype() == torch::kFloat32) {
fastfold_softmax_scale_mask_grad_fp32<<<grid, block>>>(
fastfold_softmax_scale_mask_grad<float><<<grid, block>>>(
(float *)d_output.data_ptr(), (float *)output.data_ptr(),
(float *)grad_input.data_ptr(), (float *)mask.data_ptr(), rows, cols, scale, head);
} else {
fastfold_softmax_scale_mask_grad_bfp16<<<grid, block>>>(
} else if (output.dtype() == torch::kFloat16) {
fastfold_softmax_scale_mask_grad<at::Half>
<<<grid, block>>>((at::Half *)d_output.data_ptr(), (at::Half *)output.data_ptr(),
(at::Half *)grad_input.data_ptr(), (at::Half *)mask.data_ptr(), rows,
cols, scale, head);
} else if (output.dtype() == torch::kBFloat16) {
fastfold_softmax_scale_mask_grad<at::BFloat16><<<grid, block>>>(
(at::BFloat16 *)d_output.data_ptr(), (at::BFloat16 *)output.data_ptr(),
(at::BFloat16 *)grad_input.data_ptr(), (at::BFloat16 *)mask.data_ptr(), rows, cols,
scale, head);
......
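The binding layer is not part of this commit. The test below imports a Python-level softmax from fastfold.model.fastnn.kernel, which presumably wraps these entry points in an autograd.Function; the raw functions themselves would typically be exported with pybind11 roughly as follows (a sketch; the actual binding file is not shown in this diff):

// Hypothetical registration sketch for the functions defined above.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("softmax", &softmax, "fused softmax forward (CUDA)");
    m.def("softmax_gradient", &softmax_gradient, "fused softmax backward (CUDA)");
    m.def("fused_scale_mask_softmax_forward", &fused_scale_mask_softmax_forward);
    m.def("fused_scale_mask_softmax_backward", &fused_scale_mask_softmax_backward);
    m.def("fused_scale_mask_bias_softmax_forward", &fused_scale_mask_bias_softmax_forward);
    m.def("fused_scale_mask_bias_softmax_backward", &fused_scale_mask_bias_softmax_backward);
}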
import torch
from fastfold.model.fastnn.kernel import softmax
def test_softmax():
# [batch, dim]
test_shape = [[64, 64], [64, 128], [64, 129], [64, 1024]]
test_dtype = [torch.float32, torch.float16, torch.bfloat16]
test_device = torch.device("cuda")
tolerance_eps = {torch.float32: 10e-5, torch.float16: 10e-2, torch.bfloat16: 10e-2}
for shape in test_shape:
for dtype in test_dtype:
sample_input = torch.rand(shape).to(device=test_device,
dtype=dtype).requires_grad_(True)
sample_input_fastnn = torch.clone(sample_input.detach()).requires_grad_(True)
# Forward
torch_out = torch.nn.functional.softmax(sample_input, dim=-1)
fastnn_out = softmax(sample_input_fastnn)
forward_error = torch.max(torch.abs(torch_out - fastnn_out)).cpu().item()
assert forward_error < tolerance_eps[dtype], f"Error when {shape} {dtype}"
# Backward
out_grad = torch.rand_like(torch_out).requires_grad_(False)
torch_out.backward(out_grad)
fastnn_out.backward(out_grad)
backward_error = torch.max(torch.abs(sample_input.grad -
sample_input_fastnn.grad)).cpu().item()
assert backward_error < tolerance_eps[dtype], f"Error when {shape} {dtype}"