Unverified commit 27f6ab70 authored by shenggan, committed by GitHub

fixed bugs in softmax kernel and update unittest for softmax (#83)

* fixed bugs in softmax kernel and update unittest for softmax

* remove redundancy mask for softmax grad

* test both cuda/triton kernel in unittest
parent ea8bfcb0
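Note: both the CUDA and Triton backward paths touched by this commit compute the standard softmax gradient identity, d_x = (d_y - sum(y * d_y)) * y. A minimal PyTorch check of that identity (illustration only, not part of the commit):

```python
import torch

# Reference check of the softmax backward identity the kernels implement:
# d_x = (d_y - sum(y * d_y, dim=-1, keepdim=True)) * y
x = torch.randn(4, 128, dtype=torch.float64, requires_grad=True)
y = torch.softmax(x, dim=-1)
d_y = torch.randn_like(y)

y.backward(d_y)  # autograd reference for d_x
d_x_manual = (d_y - (y * d_y).sum(dim=-1, keepdim=True)) * y.detach()

assert torch.allclose(x.grad, d_x_manual, atol=1e-12)
```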
@@ -266,8 +266,8 @@ __global__ void fastfold_softmax_grad(T *d_output, T *output, T *d_input, long l
         cols_this_thread = 0;
     }
 
-    float y_buf[8];
-    float dy_buf[8];
+    float y_buf[32];
+    float dy_buf[32];
 
     int lane_id = threadidx_y;
@@ -280,23 +280,29 @@ __global__ void fastfold_softmax_grad(T *d_output, T *output, T *d_input, long l
 #pragma unroll
         for (int i = 0; i < cols_this_thread; i++) {
-            y_buf[i] = static_cast<T>(row_output[lane_id * cols_per_thread + i]);
-            dy_buf[i] = static_cast<T>(row_d_output[lane_id * cols_per_thread + i]);
+            if (lane_id * cols_per_thread + i < cols) {
+                y_buf[i] = static_cast<T>(row_output[lane_id * cols_per_thread + i]);
+                dy_buf[i] = static_cast<T>(row_d_output[lane_id * cols_per_thread + i]);
+            }
         }
 
         float thread_sum = 0.f;
 
 #pragma unroll
         for (int i = 0; i < cols_this_thread; i++) {
-            thread_sum += y_buf[i] * dy_buf[i];
+            if (lane_id * cols_per_thread + i < cols) {
+                thread_sum += y_buf[i] * dy_buf[i];
+            }
         }
 
         float warp_sum = WarpAllReduceSum(thread_sum);
 
 #pragma unroll
         for (int i = 0; i < cols_this_thread; ++i) {
-            row_d_input[lane_id * cols_per_thread + i] =
-                static_cast<T>((dy_buf[i] - warp_sum) * y_buf[i]);
+            if (lane_id * cols_per_thread + i < cols) {
+                row_d_input[lane_id * cols_per_thread + i] =
+                    static_cast<T>((dy_buf[i] - warp_sum) * y_buf[i]);
+            }
         }
     }
 }
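The new `lane_id * cols_per_thread + i < cols` guard drops any (lane, i) pair that would map past the end of a row. A small Python sketch of that indexing; `cols_per_thread` values here are made up for illustration, the kernel derives its own from the launch configuration:

```python
# Lane `lane_id` of a warp touches columns lane_id * cols_per_thread + i;
# the guard keeps only the indices that fall inside the row.
def guarded_columns(lane_id, cols_per_thread, cols):
    return [lane_id * cols_per_thread + i
            for i in range(cols_per_thread)
            if lane_id * cols_per_thread + i < cols]

assert guarded_columns(lane_id=23, cols_per_thread=3, cols=70) == [69]  # partial lane
assert guarded_columns(lane_id=24, cols_per_thread=3, cols=70) == []    # fully out of range
```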
@@ -346,10 +352,14 @@ __global__ void fastfold_softmax_mask(T *input, T *mask, T *output, long long ro
 #pragma unroll
         for (int i = 0; i < cols_per_thread; i++) {
-            if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
-                buf[i] = -1 * 1e9;
+            if (lane_id * cols_per_thread + i < cols) {
+                if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
+                    buf[i] = -1 * 1e9;
+                } else {
+                    buf[i] = static_cast<T>(row_input[lane_id * cols_per_thread + i]);
+                }
             } else {
-                buf[i] = static_cast<T>(row_input[lane_id * cols_per_thread + i]);
+                buf[i] = -1 * CUDART_INF_F;
             }
         }
@@ -371,7 +381,9 @@ __global__ void fastfold_softmax_mask(T *input, T *mask, T *output, long long ro
         float warp_sum = WarpAllReduceSum(thread_sum);
 
 #pragma unroll
         for (int i = 0; i < cols_per_thread; ++i) {
-            row_output[lane_id * cols_per_thread + i] = static_cast<T>(__fdividef(buf[i], warp_sum));
+            if (lane_id * cols_per_thread + i < cols) {
+                row_output[lane_id * cols_per_thread + i] = static_cast<T>(__fdividef(buf[i], warp_sum));
+            }
         }
     }
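For reference, a rough PyTorch picture of what `fastfold_softmax_mask` now computes per row: masked positions get a large negative score (the kernel's `-1 * 1e9`), while the out-of-range padding slots are filled with `-CUDART_INF_F` so they contribute exp(-inf) = 0 to the warp sum. Shapes below are made up for illustration:

```python
import torch

scores = torch.randn(4, 128)       # one attention row per warp, made-up shape
mask = torch.rand(4, 128) > 0.1    # 0 = masked out, nonzero = keep

# Masked entries get a large negative score so softmax drives them to ~0,
# mirroring `buf[i] = -1 * 1e9` in the kernel.
probs = torch.softmax(scores.masked_fill(~mask, -1e9), dim=-1)
```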
@@ -526,8 +538,8 @@ __global__ void fastfold_softmax_mask_grad(T *d_output, T *output, T *d_input, T
         cols_this_thread = 0;
     }
 
-    float y_buf[8];
-    float dy_buf[8];
+    float y_buf[32];
+    float dy_buf[32];
 
     int lane_id = threadidx_y;
@@ -541,26 +553,32 @@ __global__ void fastfold_softmax_mask_grad(T *d_output, T *output, T *d_input, T
 #pragma unroll
         for (int i = 0; i < cols_this_thread; i++) {
-            y_buf[i] = static_cast<T>(row_output[lane_id * cols_per_thread + i]);
-            dy_buf[i] = static_cast<T>(row_d_output[lane_id * cols_per_thread + i]);
+            if (lane_id * cols_per_thread + i < cols) {
+                y_buf[i] = static_cast<T>(row_output[lane_id * cols_per_thread + i]);
+                dy_buf[i] = static_cast<T>(row_d_output[lane_id * cols_per_thread + i]);
+            }
         }
 
         float thread_sum = 0.f;
 
 #pragma unroll
         for (int i = 0; i < cols_this_thread; i++) {
-            thread_sum += y_buf[i] * dy_buf[i];
+            if (lane_id * cols_per_thread + i < cols) {
+                thread_sum += y_buf[i] * dy_buf[i];
+            }
         }
 
         float warp_sum = WarpAllReduceSum(thread_sum);
 
 #pragma unroll
         for (int i = 0; i < cols_this_thread; ++i) {
-            if (mask_ptr[lane_id * cols_per_thread + i] != 0) {
-                row_d_input[lane_id * cols_per_thread + i] =
-                    static_cast<T>((dy_buf[i] - warp_sum) * y_buf[i]);
-            } else {
-                row_d_input[lane_id * cols_per_thread + i] = 0;
+            if (lane_id * cols_per_thread + i < cols) {
+                if (mask_ptr[lane_id * cols_per_thread + i] != 0) {
+                    row_d_input[lane_id * cols_per_thread + i] =
+                        static_cast<T>((dy_buf[i] - warp_sum) * y_buf[i]);
+                } else {
+                    row_d_input[lane_id * cols_per_thread + i] = 0;
+                }
             }
         }
     }
@@ -615,11 +633,15 @@ __global__ void fastfold_softmax_mask_bias(T *input, T *mask, T *bias, T *output
 #pragma unroll
         for (int i = 0; i < cols_per_thread; i++) {
-            if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
-                buf[i] = -1 * 10e9;
+            if (lane_id * cols_per_thread + i < cols) {
+                if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
+                    buf[i] = -1 * 10e9;
+                } else {
+                    buf[i] = static_cast<T>(row_input[lane_id * cols_per_thread + i]) +
+                             static_cast<T>(bias_ptr[lane_id * cols_per_thread + i]);
+                }
             } else {
-                buf[i] = static_cast<T>(row_input[lane_id * cols_per_thread + i]) +
-                         static_cast<T>(bias_ptr[lane_id * cols_per_thread + i]);
+                buf[i] = -1 * CUDART_INF_F;
             }
         }
@@ -641,7 +663,9 @@ __global__ void fastfold_softmax_mask_bias(T *input, T *mask, T *bias, T *output
         float warp_sum = WarpAllReduceSum(thread_sum);
 
 #pragma unroll
         for (int i = 0; i < cols_per_thread; ++i) {
-            row_output[lane_id * cols_per_thread + i] = static_cast<T>(__fdividef(buf[i], warp_sum));
+            if (lane_id * cols_per_thread + i < cols) {
+                row_output[lane_id * cols_per_thread + i] = static_cast<T>(__fdividef(buf[i], warp_sum));
+            }
         }
     }
...
@@ -45,8 +45,7 @@ class FusedSoftmaxFunc(torch.autograd.Function):
         grad_output = grad_output.contiguous()
         output, mask_ = ctx.saved_tensors
         if _triton_available:
-            grad_input = softmax_grad_triton_kernel_wrapper(grad_output, output, mask_, ctx.rows,
-                                                            ctx.cols)
+            grad_input = softmax_grad_triton_kernel_wrapper(grad_output, output, ctx.rows, ctx.cols)
         else:
             grad_input = softmax_grad_cuda_kernel_wrapper(grad_output, output, mask_, ctx.rows,
                                                           ctx.cols)
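Dropping `mask_` from the Triton backward call matches the "remove redundancy mask for softmax grad" bullet: wherever the forward pass forced a probability to (numerically) zero, `(d_y - row_sum) * y` is already zero, so the extra masking of the gradient was redundant. A quick numerical check of that reasoning (sketch only, shapes made up):

```python
import torch

torch.manual_seed(0)
mask = torch.rand(2, 8, 64) > 0.2                        # 0 = masked out
scores = torch.randn(2, 8, 64).masked_fill(~mask, -1e9)
y = torch.softmax(scores, dim=-1)                        # masked probs underflow to 0
d_y = torch.randn_like(y)

# Softmax backward identity with no explicit masking of the gradient.
d_x = (d_y - (y * d_y).sum(dim=-1, keepdim=True)) * y
assert d_x.masked_select(~mask).abs().max().item() < 1e-12
```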
...
@@ -27,10 +27,10 @@ def _softmax_core(input_ptrs, output_ptrs, mask_ptrs, bias_ptrs, col_offsets, n_
 @triton.jit
-def _softmax_grad_core(output_ptrs, d_output_ptrs, d_input_ptrs, mask_ptrs, col_offsets, n_cols,
-                       is_bf16: tl.constexpr, use_mask: tl.constexpr):
-    output_row = tl.load(output_ptrs, mask=col_offsets < n_cols, other=float("-inf"))
-    d_output_row = tl.load(d_output_ptrs, mask=col_offsets < n_cols, other=float("-inf"))
+def _softmax_grad_core(output_ptrs, d_output_ptrs, d_input_ptrs, col_offsets, n_cols,
+                       is_bf16: tl.constexpr):
+    output_row = tl.load(output_ptrs, mask=col_offsets < n_cols, other=float(0))
+    d_output_row = tl.load(d_output_ptrs, mask=col_offsets < n_cols, other=float(0))
     if is_bf16:
         output_row = output_row.to(tl.float32)
@@ -39,10 +39,6 @@ def _softmax_grad_core(output_ptrs, d_output_ptrs, d_input_ptrs, mask_ptrs, col_
     row_sum = tl.sum(output_row * d_output_row, axis=0)
     d_softmax_output = (d_output_row - row_sum) * output_row
 
-    if use_mask:
-        mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=float("-inf")).to(tl.float32)
-        d_softmax_output = tl.where(mask == 0, float(0), d_softmax_output)
-
     tl.store(d_input_ptrs, d_softmax_output, mask=col_offsets < n_cols)
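The switch from `other=float("-inf")` to `other=float(0)` is the key fix here: the padded lanes (where `col_offsets >= n_cols`, i.e. whenever `n_cols` is not a power of two) still participate in the `tl.sum` reduction above. With `-inf` padding the product `output_row * d_output_row` becomes `inf` and poisons `row_sum`; with `0` padding those lanes contribute nothing. A plain-Python illustration of the arithmetic:

```python
# Padding with -inf poisons the row_sum reduction; padding with 0 is harmless.
inf = float("inf")
print((-inf) * (-inf))  # inf  -> row_sum becomes inf, gradients turn into -inf/NaN
print(0.0 * 0.0)        # 0.0  -> padded lanes are a no-op in the reduction
```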
@@ -114,10 +110,9 @@ def softmax_mask_bias_kernel_two_rows(output_ptr, input_ptr, mask_ptr, bias_ptr,
 @triton.jit
-def softmax_mask_grad_kernel(d_output_ptr, output_ptr, d_input_ptr, mask_ptr, d_output_row_stride,
-                             output_row_stride, d_input_row_stride, n_cols, n_heads,
-                             BLOCK_SIZE: tl.constexpr, is_bf16: tl.constexpr,
-                             use_mask: tl.constexpr):
+def softmax_grad_kernel(d_output_ptr, output_ptr, d_input_ptr, d_output_row_stride,
+                        output_row_stride, d_input_row_stride, n_cols, BLOCK_SIZE: tl.constexpr,
+                        is_bf16: tl.constexpr):
     row_idx = tl.program_id(0).to(tl.int64)
     col_offsets = tl.arange(0, BLOCK_SIZE)
@@ -130,20 +125,13 @@ def softmax_mask_grad_kernel(d_output_ptr, output_ptr, d_input_ptr, mask_ptr, d_
     d_output_ptrs = d_output_row_ptr + col_offsets
     d_input_ptrs = d_input_row_ptr + col_offsets
 
-    mask_ptrs = output_ptrs  # place holder, not use if use_mask == False
-    if use_mask:
-        mask_row_ptr = mask_ptr + (row_idx // (n_heads * n_cols)) * n_cols
-        mask_ptrs = mask_row_ptr + col_offsets
-
-    _softmax_grad_core(output_ptrs, d_output_ptrs, d_input_ptrs, mask_ptrs, col_offsets, n_cols,
-                       is_bf16, use_mask)
+    _softmax_grad_core(output_ptrs, d_output_ptrs, d_input_ptrs, col_offsets, n_cols, is_bf16)
 
 
 @triton.jit
-def softmax_mask_grad_kernel_two_rows(d_output_ptr, output_ptr, d_input_ptr, mask_ptr,
-                                      d_output_row_stride, output_row_stride, d_input_row_stride,
-                                      n_cols, n_heads, BLOCK_SIZE: tl.constexpr,
-                                      is_bf16: tl.constexpr, use_mask: tl.constexpr):
+def softmax_grad_kernel_two_rows(d_output_ptr, output_ptr, d_input_ptr, d_output_row_stride,
+                                 output_row_stride, d_input_row_stride, n_cols,
+                                 BLOCK_SIZE: tl.constexpr, is_bf16: tl.constexpr):
     row_idx = tl.program_id(0).to(tl.int64)
     col_offsets = tl.arange(0, BLOCK_SIZE)
@@ -156,21 +144,10 @@ def softmax_mask_grad_kernel_two_rows(d_output_ptr, output_ptr, d_input_ptr, mas
     d_output_ptrs = d_output_row_ptr + col_offsets
     d_input_ptrs = d_input_row_ptr + col_offsets
 
-    mask_ptrs = output_ptrs  # place holder, not use if use_mask == False
-    if use_mask:
-        mask_row_ptr = mask_ptr + ((2 * row_idx) // (n_heads * n_cols)) * n_cols
-        mask_ptrs = mask_row_ptr + col_offsets
-
-    _softmax_grad_core(output_ptrs, d_output_ptrs, d_input_ptrs, mask_ptrs, col_offsets, n_cols,
-                       is_bf16, use_mask)
-
-    mask_ptrs = output_ptrs  # place holder, not use if use_mask == False
-    if use_mask:
-        mask_row_ptr = mask_ptr + ((2 * row_idx + 1) // (n_heads * n_cols)) * n_cols
-        mask_ptrs = mask_row_ptr + col_offsets
-
-    _softmax_grad_core(output_ptrs + n_cols, d_output_ptrs + n_cols, d_input_ptrs + n_cols,
-                       mask_ptrs, col_offsets, n_cols, is_bf16, use_mask)
+    _softmax_grad_core(output_ptrs, d_output_ptrs, d_input_ptrs, col_offsets, n_cols, is_bf16)
+
+    _softmax_grad_core(output_ptrs + n_cols, d_output_ptrs + n_cols, d_input_ptrs + n_cols,
+                       col_offsets, n_cols, is_bf16)
 
 
 def softmax_triton_kernel_wrapper(x, mask, bias, n_rows, n_cols):
@@ -209,9 +186,8 @@ def softmax_triton_kernel_wrapper(x, mask, bias, n_rows, n_cols):
     return y
 
 
-def softmax_grad_triton_kernel_wrapper(grad_output, output, mask, n_rows, n_cols):
+def softmax_grad_triton_kernel_wrapper(grad_output, output, n_rows, n_cols):
     grad_input = torch.empty_like(grad_output)
-    n_heads = output.shape[2]
 
     num_warps = 1
     BLOCK_SIZE = triton.next_power_of_2(n_cols)
@@ -223,25 +199,22 @@ def softmax_grad_triton_kernel_wrapper(grad_output, output, mask, n_rows, n_cols
         num_warps = 16
 
     is_bf16 = (output.dtype == torch.bfloat16)
 
-    _dispatch_kernel = softmax_mask_grad_kernel
+    _dispatch_kernel = softmax_grad_kernel
     _grid = (n_rows,)
     if n_cols <= 128 and n_rows % 2 == 0:
-        _dispatch_kernel = softmax_mask_grad_kernel_two_rows
+        _dispatch_kernel = softmax_grad_kernel_two_rows
         _grid = (n_rows // 2,)
 
     _dispatch_kernel[_grid](
         grad_output,
         output,
         grad_input,
-        mask,
         grad_output.stride(-2),
         output.stride(-2),
         grad_output.stride(-2),
         n_cols,
-        n_heads,
         num_warps=num_warps,
         BLOCK_SIZE=BLOCK_SIZE,
         is_bf16=is_bf16,
-        use_mask=(mask != None),
     )
 
     return grad_input
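The wrapper keeps the same dispatch rule as before, only with the renamed kernels: the two-rows variant is chosen when rows are short (`n_cols <= 128`) and the row count is even. A tiny sketch of that decision; the helper name and string return values here are made up for illustration:

```python
# Mirror of the dispatch rule above: (kernel name, launch grid).
def pick_grad_kernel(n_rows, n_cols):
    if n_cols <= 128 and n_rows % 2 == 0:
        return "softmax_grad_kernel_two_rows", (n_rows // 2,)
    return "softmax_grad_kernel", (n_rows,)

assert pick_grad_kernel(64, 128) == ("softmax_grad_kernel_two_rows", (32,))
assert pick_grad_kernel(65, 128) == ("softmax_grad_kernel", (65,))
```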
 import torch
 
+from fastfold.model.fastnn.kernel import fused_softmax
 from fastfold.model.fastnn.kernel import softmax
 
 
-def test_softmax():
-    # [batch, dim]
-    test_shape = [[64, 64], [64, 128], [64, 129], [64, 2000]]
+def _test_softmax_core():
+    batch_, chunk_, head_ = 1, 8, 4
+    test_seq_ = [31, 32, 128, 129, 256, 259, 512, 700, 1024]
     test_dtype = [torch.float32, torch.float16, torch.bfloat16]
     test_device = torch.device("cuda")
-    tolerance_eps = {torch.float32: 10e-4, torch.float16: 10e-2, torch.bfloat16: 10e-2}
+    tolerance_eps = {torch.float32: 1e-6, torch.float16: 2e-4, torch.bfloat16: 1e-3}
 
-    for shape in test_shape:
+    for seq_ in test_seq_:
         for dtype in test_dtype:
-            sample_input = torch.rand(shape).to(device=test_device,
-                                                dtype=dtype).requires_grad_(True)
+            sample_input = torch.rand(batch_, chunk_, head_, seq_,
+                                      seq_).to(device=test_device, dtype=dtype).requires_grad_(True)
+            sample_mask = torch.cuda.FloatTensor(batch_, chunk_, seq_).uniform_() > 0
+            sample_mask = sample_mask.to(device=test_device, dtype=dtype).requires_grad_(False)
+            sample_bias = torch.rand(batch_, 1, head_, seq_,
+                                     seq_).to(device=test_device, dtype=dtype).requires_grad_(True)
 
             sample_input_fastnn = torch.clone(sample_input.detach()).requires_grad_(True)
+            sample_mask_fastnn = torch.clone(sample_mask.detach()).requires_grad_(False)
+            sample_bias_fastnn = torch.clone(sample_bias.detach()).requires_grad_(True)
 
             # Forward
-            torch_out = torch.nn.functional.softmax(sample_input, dim=-1)
-            fastnn_out = softmax(sample_input_fastnn)
-            forward_error = torch.max(torch.abs(torch_out - fastnn_out)).cpu().item()
-            assert forward_error < tolerance_eps[dtype], f"Error when {shape} {dtype}"
+            sample_mask_torch = 1e9 * (sample_mask - 1)[:, :, None, None, :]
+            torch_out = torch.nn.functional.softmax(sample_input + sample_mask_torch + sample_bias,
+                                                    dim=-1)
+            fastnn_out = fused_softmax(sample_input_fastnn, sample_mask_fastnn, sample_bias_fastnn)
+
+            fwd_fastnn_error = torch.max(torch.abs(torch_out - fastnn_out)).cpu().item()
+            assert fwd_fastnn_error < tolerance_eps[
+                dtype], f"fastnn fwd kernel error when {seq_} {dtype}"
 
             # Backward
             out_grad = torch.rand_like(torch_out).requires_grad_(False)
             torch_out.backward(out_grad)
             fastnn_out.backward(out_grad)
 
-            backward_error = torch.max(torch.abs(sample_input.grad -
-                                                 sample_input_fastnn.grad)).cpu().item()
-            assert backward_error < tolerance_eps[dtype], f"Error when {shape} {dtype}"
+            grad_input_error = torch.max(torch.abs(sample_input.grad -
+                                                   sample_input_fastnn.grad)).cpu().item()
+            assert grad_input_error < tolerance_eps[
+                dtype], f"fastnn bwd kernel error when {seq_} {dtype}"
+            grad_bias_error = torch.max(torch.abs(sample_bias.grad -
+                                                  sample_bias_fastnn.grad)).cpu().item()
+            assert grad_bias_error < tolerance_eps[
+                dtype], f"fastnn bwd kernel error when {seq_} {dtype}"
+
+
+def test_softmax():
+    _test_softmax_core()
+    if softmax._triton_available:
+        softmax._triton_available = False
+        _test_softmax_core()
+
+
+if __name__ == "__main__":
+    test_softmax()