"tools/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "704b50e24a69eddea0aa94d3371d6773ccb48cb8"
Commit 259b6c87 authored by shenggan, committed by GitHub

fix overflow for softmax kernel (#20)

parent 92e6cf49
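The diff below widens `rows`, `cols`, and the derived `row_offset` from `int` to `long long` in both the C++ bindings and the CUDA kernels. When `rows * cols` exceeds INT_MAX (about 2.1e9), element offsets computed in 32-bit `int` wrap around and the kernels index the wrong memory; doing the arithmetic in 64 bits avoids this. A minimal host-side sketch of the failure mode follows; the sizes are hypothetical examples, not values taken from this commit.

// Sketch only: demonstrates the 32-bit offset overflow this commit guards against.
// The sizes below are hypothetical, not values from FastFold.
#include <cstdio>

int main() {
    int rows = 1000000;   // e.g. a large number of softmax rows
    int cols = 4096;      // columns per row

    // 32-bit arithmetic: (rows - 1) * cols = 4,095,995,904 > INT_MAX, so the
    // multiplication overflows before the result is ever widened.
    int bad_offset = (rows - 1) * cols;                      // undefined behavior; typically wraps negative

    // 64-bit arithmetic: widen an operand first, as the kernels now do via long long rows/cols.
    long long good_offset = (long long)(rows - 1) * cols;    // 4095995904, correct

    printf("int offset:       %d\n", bad_offset);
    printf("long long offset: %lld\n", good_offset);
    return 0;
}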
#include <torch/extension.h>
-at::Tensor softmax(at::Tensor input, int rows, int cols);
-at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor output, int rows, int cols);
+at::Tensor softmax(at::Tensor input, long long rows, long long cols);
+at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor output, long long rows, long long cols);
-at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, int rows, int cols,
+at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, long long rows, long long cols,
                                             float scale);
at::Tensor fused_scale_mask_softmax_backward(at::Tensor d_output, at::Tensor input, at::Tensor mask,
-                                             int rows, int cols, float scale);
+                                             long long rows, long long cols, float scale);
at::Tensor fused_scale_mask_bias_softmax_forward(at::Tensor input, at::Tensor mask, at::Tensor bias,
-                                                 int rows, int cols, float scale);
+                                                 long long rows, long long cols, float scale);
at::Tensor fused_scale_mask_bias_softmax_backward(at::Tensor d_output, at::Tensor input,
-                                                  at::Tensor mask, at::Tensor bias, int rows,
-                                                  int cols, float scale);
+                                                  at::Tensor mask, at::Tensor bias, long long rows,
+                                                  long long cols, float scale);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &softmax, "Softmax forward (CUDA)");
...
@@ -30,10 +30,10 @@ __inline__ __device__ float WarpAllReduceSum(float val) {
////////////////
-__global__ void fastfold_softmax_fp32(float *input, float *output, int rows, int cols) {
+__global__ void fastfold_softmax_fp32(float *input, float *output, long long rows, long long cols) {
    int threadidx_x = threadIdx.x / 32;
    int threadidx_y = threadIdx.x % 32;
-   int row_offset = blockIdx.x * 4 + threadidx_x;
+   long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
    int cols_per_thread = (cols + 31) / 32;
    int cols_this_thread = cols_per_thread;
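Note that the cast in the new `row_offset` line is applied to `blockIdx.x` before the multiplication. `blockIdx.x` is a 32-bit unsigned value, so `blockIdx.x * 4 + threadidx_x` would still be evaluated in 32-bit arithmetic and could wrap on very large grids even if the result were stored in a `long long`; casting the operand first makes the whole expression 64-bit. The same pattern recurs in every kernel below. A tiny host-side illustration with a hypothetical block index (not a value from this commit):

// Sketch only: why the cast must come before the multiply.
#include <cstdio>

int main() {
    unsigned int block = 1200000000u;           // hypothetical stand-in for blockIdx.x on a huge grid
    long long wrapped = block * 4u;             // multiply done in 32 bits, wraps to 505032704
    long long widened = (long long)block * 4;   // cast first: multiply done in 64 bits, 4800000000
    printf("wrapped: %lld, widened: %lld\n", wrapped, widened);
    return 0;
}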
@@ -83,11 +83,11 @@ __global__ void fastfold_softmax_fp32(float *input, float *output, int rows, int
    }
}
-__global__ void fastfold_softmax_bfp16(at::BFloat16 *input, at::BFloat16 *output, int rows,
-                                       int cols) {
+__global__ void fastfold_softmax_bfp16(at::BFloat16 *input, at::BFloat16 *output, long long rows,
+                                       long long cols) {
    int threadidx_x = threadIdx.x / 32;
    int threadidx_y = threadIdx.x % 32;
-   int row_offset = blockIdx.x * 4 + threadidx_x;
+   long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
    int cols_per_thread = (cols + 31) / 32;
    int cols_this_thread = cols_per_thread;
@@ -139,11 +139,11 @@ __global__ void fastfold_softmax_bfp16(at::BFloat16 *input, at::BFloat16 *output
    }
}
-__global__ void fastfold_softmax_grad_fp32(float *d_output, float *output, float *d_input, int rows,
-                                           int cols) {
+__global__ void fastfold_softmax_grad_fp32(float *d_output, float *output, float *d_input, long long rows,
+                                           long long cols) {
    int threadidx_x = threadIdx.x / 32;
    int threadidx_y = threadIdx.x % 32;
-   int row_offset = blockIdx.x * 4 + threadidx_x;
+   long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
    int cols_per_thread = (cols + 31) / 32;
    int cols_this_thread = cols_per_thread;
@@ -191,10 +191,10 @@ __global__ void fastfold_softmax_grad_fp32(float *d_output, float *output, float
}
__global__ void fastfold_softmax_grad_bfp16(at::BFloat16 *d_output, at::BFloat16 *output,
-                                            at::BFloat16 *d_input, int rows, int cols) {
+                                            at::BFloat16 *d_input, long long rows, long long cols) {
    int threadidx_x = threadIdx.x / 32;
    int threadidx_y = threadIdx.x % 32;
-   int row_offset = blockIdx.x * 4 + threadidx_x;
+   long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
    int cols_per_thread = (cols + 31) / 32;
    int cols_this_thread = cols_per_thread;
@@ -242,7 +242,7 @@ __global__ void fastfold_softmax_grad_bfp16(at::BFloat16 *d_output, at::BFloat16
    }
}
-at::Tensor softmax(at::Tensor input, int rows, int cols) {
+at::Tensor softmax(at::Tensor input, long long rows, long long cols) {
    CHECK_INPUT(input);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
@@ -262,7 +262,7 @@ at::Tensor softmax(at::Tensor input, int rows, int cols) {
    return output;
}
-at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor output, int rows, int cols) {
+at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor output, long long rows, long long cols) {
    CHECK_INPUT(output);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
    at::Tensor grad_input = at::empty_like(output);
@@ -285,11 +285,11 @@ at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor output, int rows, in
////////////////
-__global__ void fastfold_softmax_scale_mask_fp32(float *input, float *mask, float *output, int rows,
-                                                 int cols, float scale, int head) {
+__global__ void fastfold_softmax_scale_mask_fp32(float *input, float *mask, float *output, long long rows,
+                                                 long long cols, float scale, int head) {
    int threadidx_x = threadIdx.x / 32;
    int threadidx_y = threadIdx.x % 32;
-   int row_offset = blockIdx.x * 4 + threadidx_x;
+   long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
    int cols_per_thread = (cols + 31) / 32;
    int cols_this_thread = cols_per_thread;
@@ -344,11 +344,11 @@ __global__ void fastfold_softmax_scale_mask_fp32(float *input, float *mask, floa
}
__global__ void fastfold_softmax_scale_mask_bfp16(at::BFloat16 *input, at::BFloat16 *mask,
-                                                  at::BFloat16 *output, int rows, int cols,
+                                                  at::BFloat16 *output, long long rows, long long cols,
                                                   float scale, int head) {
    int threadidx_x = threadIdx.x / 32;
    int threadidx_y = threadIdx.x % 32;
-   int row_offset = blockIdx.x * 4 + threadidx_x;
+   long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
    int cols_per_thread = (cols + 31) / 32;
    int cols_this_thread = cols_per_thread;
@@ -403,7 +403,7 @@ __global__ void fastfold_softmax_scale_mask_bfp16(at::BFloat16 *input, at::BFloa
    }
}
-at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, int rows, int cols,
+at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, long long rows, long long cols,
                                             float scale) {
    CHECK_INPUT(input);
    CHECK_INPUT(mask);
@@ -428,11 +428,11 @@ at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, i
}
__global__ void fastfold_softmax_scale_mask_grad_fp32(float *d_output, float *output,
-                                                      float *d_input, float *mask, int rows,
-                                                      int cols, float scale, int head) {
+                                                      float *d_input, float *mask, long long rows,
+                                                      long long cols, float scale, int head) {
    int threadidx_x = threadIdx.x / 32;
    int threadidx_y = threadIdx.x % 32;
-   int row_offset = blockIdx.x * 4 + threadidx_x;
+   long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
    int cols_per_thread = (cols + 31) / 32;
    int cols_this_thread = cols_per_thread;
@@ -487,10 +487,10 @@ __global__ void fastfold_softmax_scale_mask_grad_fp32(float *d_output, float *ou
__global__ void fastfold_softmax_scale_mask_grad_bfp16(at::BFloat16 *d_output, at::BFloat16 *output,
                                                        at::BFloat16 *d_input, at::BFloat16 *mask,
-                                                       int rows, int cols, float scale, int head) {
+                                                       long long rows, long long cols, float scale, int head) {
    int threadidx_x = threadIdx.x / 32;
    int threadidx_y = threadIdx.x % 32;
-   int row_offset = blockIdx.x * 4 + threadidx_x;
+   long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
    int cols_per_thread = (cols + 31) / 32;
    int cols_this_thread = cols_per_thread;
@@ -544,7 +544,7 @@ __global__ void fastfold_softmax_scale_mask_grad_bfp16(at::BFloat16 *d_output, a
}
at::Tensor fused_scale_mask_softmax_backward(at::Tensor d_output, at::Tensor output,
-                                             at::Tensor mask, int rows, int cols, float scale) {
+                                             at::Tensor mask, long long rows, long long cols, float scale) {
    CHECK_INPUT(output);
    CHECK_INPUT(mask);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(mask));
@@ -571,11 +571,11 @@ at::Tensor fused_scale_mask_softmax_backward(at::Tensor d_output, at::Tensor out
////////////////
__global__ void fastfold_softmax_scale_mask_bias_fp32(float *input, float *mask, float *bias,
-                                                      float *output, int rows, int cols,
+                                                      float *output, long long rows, long long cols,
                                                       float scale, int head) {
    int threadidx_x = threadIdx.x / 32;
    int threadidx_y = threadIdx.x % 32;
-   int row_offset = blockIdx.x * 4 + threadidx_x;
+   long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
    int cols_per_thread = (cols + 31) / 32;
    int cols_this_thread = cols_per_thread;
@@ -633,10 +633,10 @@ __global__ void fastfold_softmax_scale_mask_bias_fp32(float *input, float *mask,
__global__ void fastfold_softmax_scale_mask_bias_bfp16(at::BFloat16 *input, at::BFloat16 *mask,
                                                        at::BFloat16 *bias, at::BFloat16 *output,
-                                                       int rows, int cols, float scale, int head) {
+                                                       long long rows, long long cols, float scale, int head) {
    int threadidx_x = threadIdx.x / 32;
    int threadidx_y = threadIdx.x % 32;
-   int row_offset = blockIdx.x * 4 + threadidx_x;
+   long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
    int cols_per_thread = (cols + 31) / 32;
    int cols_this_thread = cols_per_thread;
@@ -694,7 +694,7 @@ __global__ void fastfold_softmax_scale_mask_bias_bfp16(at::BFloat16 *input, at::
}
at::Tensor fused_scale_mask_bias_softmax_forward(at::Tensor input, at::Tensor mask, at::Tensor bias,
-                                                 int rows, int cols, float scale) {
+                                                 long long rows, long long cols, float scale) {
    CHECK_INPUT(input);
    CHECK_INPUT(mask);
    CHECK_INPUT(bias);
@@ -720,8 +720,8 @@ at::Tensor fused_scale_mask_bias_softmax_forward(at::Tensor input, at::Tensor ma
}
at::Tensor fused_scale_mask_bias_softmax_backward(at::Tensor d_output, at::Tensor output,
-                                                  at::Tensor mask, at::Tensor bias, int rows,
-                                                  int cols, float scale) {
+                                                  at::Tensor mask, at::Tensor bias, long long rows,
+                                                  long long cols, float scale) {
    CHECK_INPUT(output);
    CHECK_INPUT(mask);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(mask));
...