"git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "e2d887206003b1909797672fd86dbb3ab1f1d3f9"
Unverified commit 259b6c87, authored by shenggan, committed by GitHub

fix overflow for softmax kernel (#20)

parent 92e6cf49
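The patch widens every rows/cols kernel parameter, and the per-warp row offset derived from them, from int to long long. The reason, as a sketch of the failure mode rather than text from the patch: these kernels index their buffers with a linear offset on the order of row * cols, and once a tensor holds more than 2^31 - 1 elements a 32-bit product wraps, producing negative offsets and out-of-bounds accesses. The hypothetical snippet below only illustrates that arithmetic; the names and sizes in it are made up for the example.

// Illustration only -- not part of the patch. Shows how a 32-bit row*cols
// offset wraps once the element count passes INT_MAX, and how the widened
// 64-bit arithmetic the patch switches to does not.
#include <cstdio>

int main() {
    int rows = 70000;   // hypothetical softmax row count
    int cols = 40000;   // hypothetical softmax row length
    // 70000 * 40000 = 2,800,000,000 > INT_MAX (2,147,483,647)
    int bad = rows * cols;                    // 32-bit overflow: undefined behavior, in practice wraps negative
    long long good = (long long)rows * cols;  // widened before the multiply: 2800000000
    printf("int offset:       %d\n", bad);
    printf("long long offset: %lld\n", good);
    return 0;
}

The diff below applies the corresponding widening to the binding declarations and to each kernel and host wrapper.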
 #include <torch/extension.h>
-at::Tensor softmax(at::Tensor input, int rows, int cols);
+at::Tensor softmax(at::Tensor input, long long rows, long long cols);
-at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor output, int rows, int cols);
+at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor output, long long rows, long long cols);
-at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, int rows, int cols,
+at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, long long rows, long long cols,
                                             float scale);
 at::Tensor fused_scale_mask_softmax_backward(at::Tensor d_output, at::Tensor input, at::Tensor mask,
-                                             int rows, int cols, float scale);
+                                             long long rows, long long cols, float scale);
 at::Tensor fused_scale_mask_bias_softmax_forward(at::Tensor input, at::Tensor mask, at::Tensor bias,
-                                                 int rows, int cols, float scale);
+                                                 long long rows, long long cols, float scale);
 at::Tensor fused_scale_mask_bias_softmax_backward(at::Tensor d_output, at::Tensor input,
-                                                  at::Tensor mask, at::Tensor bias, int rows,
-                                                  int cols, float scale);
+                                                  at::Tensor mask, at::Tensor bias, long long rows,
+                                                  long long cols, float scale);
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("forward", &softmax, "Softmax forward (CUDA)");
...
@@ -30,10 +30,10 @@ __inline__ __device__ float WarpAllReduceSum(float val) {
 ////////////////
-__global__ void fastfold_softmax_fp32(float *input, float *output, int rows, int cols) {
+__global__ void fastfold_softmax_fp32(float *input, float *output, long long rows, long long cols) {
     int threadidx_x = threadIdx.x / 32;
     int threadidx_y = threadIdx.x % 32;
-    int row_offset = blockIdx.x * 4 + threadidx_x;
+    long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
     int cols_per_thread = (cols + 31) / 32;
     int cols_this_thread = cols_per_thread;
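A detail worth noting in the kernel change above: blockIdx.x is an unsigned 32-bit value, so the cast is applied to it before the multiply. (long long)blockIdx.x * 4 performs the whole computation in 64-bit arithmetic, whereas (long long)(blockIdx.x * 4) would only widen a product that could already have wrapped. For realistic grid sizes the difference rarely bites, but cast-before-multiply is the safe pattern; the standalone snippet below uses a deliberately extreme, hypothetical block index just to make the wrap visible.

// Illustration only: cast-before-multiply vs. cast-after-multiply.
#include <cstdio>

int main() {
    unsigned int block_idx = 1200000000u;  // hypothetical, deliberately huge grid index
    int threadidx_x = 3;

    long long after  = (long long)(block_idx * 4) + threadidx_x;  // 32-bit multiply wraps first
    long long before = (long long)block_idx * 4 + threadidx_x;    // 64-bit multiply, no wrap

    printf("cast after multiply:  %lld\n", after);   // 505032707 (wrapped)
    printf("cast before multiply: %lld\n", before);  // 4800000003
    return 0;
}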
@@ -83,11 +83,11 @@ __global__ void fastfold_softmax_fp32(float *input, float *output, int rows, int
     }
 }
-__global__ void fastfold_softmax_bfp16(at::BFloat16 *input, at::BFloat16 *output, int rows,
-                                       int cols) {
+__global__ void fastfold_softmax_bfp16(at::BFloat16 *input, at::BFloat16 *output, long long rows,
+                                       long long cols) {
     int threadidx_x = threadIdx.x / 32;
     int threadidx_y = threadIdx.x % 32;
-    int row_offset = blockIdx.x * 4 + threadidx_x;
+    long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
     int cols_per_thread = (cols + 31) / 32;
     int cols_this_thread = cols_per_thread;
@@ -139,11 +139,11 @@ __global__ void fastfold_softmax_bfp16(at::BFloat16 *input, at::BFloat16 *output
     }
 }
-__global__ void fastfold_softmax_grad_fp32(float *d_output, float *output, float *d_input, int rows,
-                                           int cols) {
+__global__ void fastfold_softmax_grad_fp32(float *d_output, float *output, float *d_input, long long rows,
+                                           long long cols) {
     int threadidx_x = threadIdx.x / 32;
     int threadidx_y = threadIdx.x % 32;
-    int row_offset = blockIdx.x * 4 + threadidx_x;
+    long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
     int cols_per_thread = (cols + 31) / 32;
     int cols_this_thread = cols_per_thread;
@@ -191,10 +191,10 @@ __global__ void fastfold_softmax_grad_fp32(float *d_output, float *output, float
 }
 __global__ void fastfold_softmax_grad_bfp16(at::BFloat16 *d_output, at::BFloat16 *output,
-                                            at::BFloat16 *d_input, int rows, int cols) {
+                                            at::BFloat16 *d_input, long long rows, long long cols) {
     int threadidx_x = threadIdx.x / 32;
     int threadidx_y = threadIdx.x % 32;
-    int row_offset = blockIdx.x * 4 + threadidx_x;
+    long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
     int cols_per_thread = (cols + 31) / 32;
     int cols_this_thread = cols_per_thread;
@@ -242,7 +242,7 @@ __global__ void fastfold_softmax_grad_bfp16(at::BFloat16 *d_output, at::BFloat16
     }
 }
-at::Tensor softmax(at::Tensor input, int rows, int cols) {
+at::Tensor softmax(at::Tensor input, long long rows, long long cols) {
     CHECK_INPUT(input);
     const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
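The launch configuration and the kernel bodies are collapsed in this view, so only the signatures appear here. Going by the indexing shown above (threadIdx.x / 32 picks one of four warps, blockIdx.x * 4 + threadidx_x picks the row), a host-side grid computation consistent with it would look roughly like the sketch below; the variable names and values are assumptions, not code from the patch. Only the element offsets need 64 bits: with four rows per block, the block count still has to fit the 32-bit gridDim.x limit, which this sketch checks explicitly.

// Hedged sketch of a grid-size calculation consistent with the kernels' layout
// (4 warps of 32 threads per block, one softmax row per warp). Not from the patch.
#include <cstdio>

int main() {
    long long rows = 3000000000LL;             // hypothetical row count
    long long blocks = (rows + 3) / 4;         // 4 rows per block
    const long long kMaxGridX = 2147483647LL;  // CUDA limit on gridDim.x

    if (blocks > kMaxGridX) {
        printf("too many rows for a 1-D grid: %lld blocks needed\n", blocks);
    } else {
        printf("launch <<<%lld, 128>>>: %lld blocks of 128 threads\n", blocks, blocks);
    }
    return 0;
}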
@@ -262,7 +262,7 @@ at::Tensor softmax(at::Tensor input, int rows, int cols) {
     return output;
 }
-at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor output, int rows, int cols) {
+at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor output, long long rows, long long cols) {
     CHECK_INPUT(output);
     const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
     at::Tensor grad_input = at::empty_like(output);
@@ -285,11 +285,11 @@ at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor output, int rows, in
 ////////////////
-__global__ void fastfold_softmax_scale_mask_fp32(float *input, float *mask, float *output, int rows,
-                                                 int cols, float scale, int head) {
+__global__ void fastfold_softmax_scale_mask_fp32(float *input, float *mask, float *output, long long rows,
+                                                 long long cols, float scale, int head) {
     int threadidx_x = threadIdx.x / 32;
     int threadidx_y = threadIdx.x % 32;
-    int row_offset = blockIdx.x * 4 + threadidx_x;
+    long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
     int cols_per_thread = (cols + 31) / 32;
     int cols_this_thread = cols_per_thread;
@@ -344,11 +344,11 @@ __global__ void fastfold_softmax_scale_mask_fp32(float *input, float *mask, floa
 }
 __global__ void fastfold_softmax_scale_mask_bfp16(at::BFloat16 *input, at::BFloat16 *mask,
-                                                  at::BFloat16 *output, int rows, int cols,
+                                                  at::BFloat16 *output, long long rows, long long cols,
                                                   float scale, int head) {
     int threadidx_x = threadIdx.x / 32;
     int threadidx_y = threadIdx.x % 32;
-    int row_offset = blockIdx.x * 4 + threadidx_x;
+    long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
     int cols_per_thread = (cols + 31) / 32;
     int cols_this_thread = cols_per_thread;
@@ -403,7 +403,7 @@ __global__ void fastfold_softmax_scale_mask_bfp16(at::BFloat16 *input, at::BFloa
     }
 }
-at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, int rows, int cols,
+at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, long long rows, long long cols,
                                             float scale) {
     CHECK_INPUT(input);
     CHECK_INPUT(mask);
@@ -428,11 +428,11 @@ at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, i
 }
 __global__ void fastfold_softmax_scale_mask_grad_fp32(float *d_output, float *output,
-                                                      float *d_input, float *mask, int rows,
-                                                      int cols, float scale, int head) {
+                                                      float *d_input, float *mask, long long rows,
+                                                      long long cols, float scale, int head) {
     int threadidx_x = threadIdx.x / 32;
     int threadidx_y = threadIdx.x % 32;
-    int row_offset = blockIdx.x * 4 + threadidx_x;
+    long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
     int cols_per_thread = (cols + 31) / 32;
     int cols_this_thread = cols_per_thread;
@@ -487,10 +487,10 @@ __global__ void fastfold_softmax_scale_mask_grad_fp32(float *d_output, float *ou
 __global__ void fastfold_softmax_scale_mask_grad_bfp16(at::BFloat16 *d_output, at::BFloat16 *output,
                                                        at::BFloat16 *d_input, at::BFloat16 *mask,
-                                                       int rows, int cols, float scale, int head) {
+                                                       long long rows, long long cols, float scale, int head) {
     int threadidx_x = threadIdx.x / 32;
     int threadidx_y = threadIdx.x % 32;
-    int row_offset = blockIdx.x * 4 + threadidx_x;
+    long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
     int cols_per_thread = (cols + 31) / 32;
     int cols_this_thread = cols_per_thread;
@@ -544,7 +544,7 @@ __global__ void fastfold_softmax_scale_mask_grad_bfp16(at::BFloat16 *d_output, a
 }
 at::Tensor fused_scale_mask_softmax_backward(at::Tensor d_output, at::Tensor output,
-                                             at::Tensor mask, int rows, int cols, float scale) {
+                                             at::Tensor mask, long long rows, long long cols, float scale) {
     CHECK_INPUT(output);
     CHECK_INPUT(mask);
     const at::cuda::OptionalCUDAGuard device_guard(device_of(mask));
@@ -571,11 +571,11 @@ at::Tensor fused_scale_mask_softmax_backward(at::Tensor d_output, at::Tensor out
 ////////////////
 __global__ void fastfold_softmax_scale_mask_bias_fp32(float *input, float *mask, float *bias,
-                                                      float *output, int rows, int cols,
+                                                      float *output, long long rows, long long cols,
                                                       float scale, int head) {
     int threadidx_x = threadIdx.x / 32;
     int threadidx_y = threadIdx.x % 32;
-    int row_offset = blockIdx.x * 4 + threadidx_x;
+    long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
     int cols_per_thread = (cols + 31) / 32;
     int cols_this_thread = cols_per_thread;
@@ -633,10 +633,10 @@ __global__ void fastfold_softmax_scale_mask_bias_fp32(float *input, float *mask,
 __global__ void fastfold_softmax_scale_mask_bias_bfp16(at::BFloat16 *input, at::BFloat16 *mask,
                                                        at::BFloat16 *bias, at::BFloat16 *output,
-                                                       int rows, int cols, float scale, int head) {
+                                                       long long rows, long long cols, float scale, int head) {
     int threadidx_x = threadIdx.x / 32;
     int threadidx_y = threadIdx.x % 32;
-    int row_offset = blockIdx.x * 4 + threadidx_x;
+    long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
     int cols_per_thread = (cols + 31) / 32;
     int cols_this_thread = cols_per_thread;
@@ -694,7 +694,7 @@ __global__ void fastfold_softmax_scale_mask_bias_bfp16(at::BFloat16 *input, at::
 }
 at::Tensor fused_scale_mask_bias_softmax_forward(at::Tensor input, at::Tensor mask, at::Tensor bias,
-                                                 int rows, int cols, float scale) {
+                                                 long long rows, long long cols, float scale) {
     CHECK_INPUT(input);
     CHECK_INPUT(mask);
     CHECK_INPUT(bias);
@@ -720,8 +720,8 @@ at::Tensor fused_scale_mask_bias_softmax_forward(at::Tensor input, at::Tensor ma
 }
 at::Tensor fused_scale_mask_bias_softmax_backward(at::Tensor d_output, at::Tensor output,
-                                                  at::Tensor mask, at::Tensor bias, int rows,
-                                                  int cols, float scale) {
+                                                  at::Tensor mask, at::Tensor bias, long long rows,
+                                                  long long cols, float scale) {
     CHECK_INPUT(output);
     CHECK_INPUT(mask);
     const at::cuda::OptionalCUDAGuard device_guard(device_of(mask));
...