Unverified commit ded582b2 authored by shenggan, committed by GitHub

remove scale in fused softmax kernel (#34)

parent ad7f0cb5
from .jit.fused_ops import bias_dropout_add, bias_sigmod_ele, bias_ele_dropout_residual
from .cuda_native.layer_norm import MixedFusedLayerNorm as LayerNorm
from .cuda_native.softmax import softmax, scale_mask_softmax, scale_mask_bias_softmax
from .cuda_native.softmax import softmax, mask_softmax, mask_bias_softmax
__all__ = [
"bias_dropout_add", "bias_sigmod_ele", "bias_ele_dropout_residual", "LayerNorm", "softmax",
"scale_mask_softmax", "scale_mask_bias_softmax"
"mask_softmax", "mask_bias_softmax"
]
\ No newline at end of file
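
For reference, the exported kernel API after this change — a sketch based on the __all__ list above; the fused softmax functions no longer take a scale argument, so any scaling is applied to the logits before the call:

    # Renamed exports after this commit; `scale_mask_softmax` and
    # `scale_mask_bias_softmax` are gone, and callers pre-scale the logits
    # themselves (e.g. q = q * scaling before computing attention logits).
    from fastfold.model.fastnn.kernel import (
        LayerNorm,
        bias_dropout_add,
        bias_ele_dropout_residual,
        bias_sigmod_ele,
        mask_bias_softmax,
        mask_softmax,
        softmax,
    )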
@@ -3,28 +3,25 @@
at::Tensor softmax(at::Tensor input, long long rows, long long cols);
at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor output, long long rows, long long cols);
at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, long long rows, long long cols,
float scale);
at::Tensor fused_scale_mask_softmax_backward(at::Tensor d_output, at::Tensor input, at::Tensor mask,
long long rows, long long cols, float scale);
at::Tensor fused_mask_softmax_forward(at::Tensor input, at::Tensor mask, long long rows,
long long cols);
at::Tensor fused_mask_softmax_backward(at::Tensor d_output, at::Tensor input, at::Tensor mask,
long long rows, long long cols);
at::Tensor fused_scale_mask_bias_softmax_forward(at::Tensor input, at::Tensor mask, at::Tensor bias,
long long rows, long long cols, float scale);
at::Tensor fused_scale_mask_bias_softmax_backward(at::Tensor d_output, at::Tensor input,
at::Tensor mask, at::Tensor bias, long long rows,
long long cols, float scale);
at::Tensor fused_mask_bias_softmax_forward(at::Tensor input, at::Tensor mask, at::Tensor bias,
long long rows, long long cols);
at::Tensor fused_mask_bias_softmax_backward(at::Tensor d_output, at::Tensor input, at::Tensor mask,
at::Tensor bias, long long rows, long long cols);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &softmax, "Softmax forward (CUDA)");
m.def("backward", &softmax_gradient, "Softmax backward (CUDA)");
m.def("fused_scale_mask_softmax_forward", &fused_scale_mask_softmax_forward,
"Softmax forward (CUDA)");
m.def("fused_scale_mask_softmax_backward", &fused_scale_mask_softmax_backward,
"Softmax forward (CUDA)");
m.def("fused_mask_softmax_forward", &fused_mask_softmax_forward, "Softmax forward (CUDA)");
m.def("fused_mask_softmax_backward", &fused_mask_softmax_backward, "Softmax forward (CUDA)");
m.def("fused_scale_mask_bias_softmax_forward", &fused_scale_mask_bias_softmax_forward,
m.def("fused_mask_bias_softmax_forward", &fused_mask_bias_softmax_forward,
"Softmax forward (CUDA)");
m.def("fused_scale_mask_bias_softmax_backward", &fused_scale_mask_bias_softmax_backward,
m.def("fused_mask_bias_softmax_backward", &fused_mask_bias_softmax_backward,
"Softmax forward (CUDA)");
}
\ No newline at end of file
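
The bindings above are driven from Python by viewing the input as a (rows, cols) matrix, with rows equal to the product of all leading dimensions, exactly as the autograd wrapper further down does. A minimal sketch, assuming the extension module is importable as fastfold_softmax_cuda:

    # Sketch of calling the raw extension (the autograd wrappers below do this
    # internally); input and mask must be contiguous CUDA tensors.
    from functools import reduce
    from operator import mul

    import fastfold_softmax_cuda

    def run_mask_softmax(x, mask):
        x = x.contiguous()
        rows = reduce(mul, x.shape[:-1])
        cols = x.shape[-1]
        return fastfold_softmax_cuda.fused_mask_softmax_forward(
            x, mask.contiguous(), rows, cols)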
@@ -330,8 +330,8 @@ at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor output, long long ro
////////////////
template <typename T, int cols_per_thread>
__global__ void fastfold_softmax_scale_mask(T *input, T *mask, T *output, long long rows,
long long cols, float scale, int head) {
__global__ void fastfold_softmax_mask(T *input, T *mask, T *output, long long rows, long long cols,
int head) {
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
@@ -349,7 +349,7 @@ __global__ void fastfold_softmax_scale_mask(T *input, T *mask, T *output, long l
if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
buf[i] = -1 * 1e9;
} else {
buf[i] = static_cast<T>(row_input[lane_id * cols_per_thread + i]) * scale;
buf[i] = static_cast<T>(row_input[lane_id * cols_per_thread + i]);
}
}
@@ -376,8 +376,8 @@ __global__ void fastfold_softmax_scale_mask(T *input, T *mask, T *output, long l
}
template <typename T, int block_size>
__global__ void fastfold_softmax_scale_mask_sm(T *input, T *mask, T *output, long long rows,
long long cols, float scale, int head) {
__global__ void fastfold_softmax_mask_sm(T *input, T *mask, T *output, long long rows,
long long cols, int head) {
extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[];
auto *buf = reinterpret_cast<float *>(shared_buf);
const int tid = threadIdx.x;
@@ -389,7 +389,7 @@ __global__ void fastfold_softmax_scale_mask_sm(T *input, T *mask, T *output, lon
if (mask_ptr[id] == 0) {
buf[id] = -1 * 1e9;
} else {
buf[id] = input[row * cols + id] * scale;
buf[id] = input[row * cols + id];
}
thread_max = max(thread_max, buf[id]);
}
@@ -410,8 +410,8 @@ __global__ void fastfold_softmax_scale_mask_sm(T *input, T *mask, T *output, lon
}
}
at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, long long rows,
long long cols, float scale) {
at::Tensor fused_mask_softmax_forward(at::Tensor input, at::Tensor mask, long long rows,
long long cols) {
CHECK_INPUT(input);
CHECK_INPUT(mask);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
@@ -423,33 +423,33 @@ at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, l
if (cols <= 32) {
if (input.dtype() == torch::kFloat32) {
fastfold_softmax_scale_mask<float, 1>
fastfold_softmax_mask<float, 1>
<<<grid, block>>>((float *)input.data_ptr(), (float *)mask.data_ptr(),
(float *)input.data_ptr(), rows, cols, scale, head);
(float *)input.data_ptr(), rows, cols, head);
} else if (input.dtype() == torch::kFloat16) {
fastfold_softmax_scale_mask<at::Half, 1>
fastfold_softmax_mask<at::Half, 1>
<<<grid, block>>>((at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(),
(at::Half *)input.data_ptr(), rows, cols, scale, head);
(at::Half *)input.data_ptr(), rows, cols, head);
} else if (input.dtype() == torch::kBFloat16) {
fastfold_softmax_scale_mask<at::BFloat16, 1>
fastfold_softmax_mask<at::BFloat16, 1>
<<<grid, block>>>((at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
(at::BFloat16 *)input.data_ptr(), rows, cols, scale, head);
(at::BFloat16 *)input.data_ptr(), rows, cols, head);
}
}
#define COLS_CASE(col_per_thread) \
else if (cols <= col_per_thread * 32) { \
if (input.dtype() == torch::kFloat32) { \
fastfold_softmax_scale_mask<float, col_per_thread> \
fastfold_softmax_mask<float, col_per_thread> \
<<<grid, block>>>((float *)input.data_ptr(), (float *)mask.data_ptr(), \
(float *)input.data_ptr(), rows, cols, scale, head); \
(float *)input.data_ptr(), rows, cols, head); \
} else if (input.dtype() == torch::kFloat16) { \
fastfold_softmax_scale_mask<at::Half, col_per_thread> \
fastfold_softmax_mask<at::Half, col_per_thread> \
<<<grid, block>>>((at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(), \
(at::Half *)input.data_ptr(), rows, cols, scale, head); \
(at::Half *)input.data_ptr(), rows, cols, head); \
} else if (input.dtype() == torch::kBFloat16) { \
fastfold_softmax_scale_mask<at::BFloat16, col_per_thread><<<grid, block>>>( \
fastfold_softmax_mask<at::BFloat16, col_per_thread><<<grid, block>>>( \
(at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(), \
(at::BFloat16 *)input.data_ptr(), rows, cols, scale, head); \
(at::BFloat16 *)input.data_ptr(), rows, cols, head); \
} \
}
COLS_CASE(2)
@@ -493,26 +493,25 @@ at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, l
const size_t smem = cols * sizeof(float);
if (input.dtype() == torch::kFloat32) {
fastfold_softmax_scale_mask_sm<float, 128>
fastfold_softmax_mask_sm<float, 128>
<<<grid, block, smem>>>((float *)input.data_ptr(), (float *)mask.data_ptr(),
(float *)input.data_ptr(), rows, cols, scale, head);
(float *)input.data_ptr(), rows, cols, head);
} else if (input.dtype() == torch::kFloat16) {
fastfold_softmax_scale_mask_sm<at::Half, 128>
fastfold_softmax_mask_sm<at::Half, 128>
<<<grid, block, smem>>>((at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(),
(at::Half *)input.data_ptr(), rows, cols, scale, head);
(at::Half *)input.data_ptr(), rows, cols, head);
} else if (input.dtype() == torch::kBFloat16) {
fastfold_softmax_scale_mask_sm<at::BFloat16, 128><<<grid, block, smem>>>(
fastfold_softmax_mask_sm<at::BFloat16, 128><<<grid, block, smem>>>(
(at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
(at::BFloat16 *)input.data_ptr(), rows, cols, scale, head);
(at::BFloat16 *)input.data_ptr(), rows, cols, head);
}
}
return input;
}
template <typename T>
__global__ void fastfold_softmax_scale_mask_grad(T *d_output, T *output, T *d_input, T *mask,
long long rows, long long cols, float scale,
int head) {
__global__ void fastfold_softmax_mask_grad(T *d_output, T *output, T *d_input, T *mask,
long long rows, long long cols, int head) {
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
@@ -559,7 +558,7 @@ __global__ void fastfold_softmax_scale_mask_grad(T *d_output, T *output, T *d_in
for (int i = 0; i < cols_this_thread; ++i) {
if (mask_ptr[lane_id * cols_per_thread + i] != 0) {
row_d_input[lane_id * cols_per_thread + i] =
static_cast<T>(scale * ((dy_buf[i] - warp_sum) * y_buf[i]));
static_cast<T>((dy_buf[i] - warp_sum) * y_buf[i]);
} else {
row_d_input[lane_id * cols_per_thread + i] = 0;
}
@@ -567,9 +566,8 @@ __global__ void fastfold_softmax_scale_mask_grad(T *d_output, T *output, T *d_in
}
}
at::Tensor fused_scale_mask_softmax_backward(at::Tensor d_output, at::Tensor output,
at::Tensor mask, long long rows, long long cols,
float scale) {
at::Tensor fused_mask_softmax_backward(at::Tensor d_output, at::Tensor output, at::Tensor mask,
long long rows, long long cols) {
CHECK_INPUT(output);
CHECK_INPUT(mask);
const at::cuda::OptionalCUDAGuard device_guard(device_of(mask));
@@ -580,19 +578,18 @@ at::Tensor fused_scale_mask_softmax_backward(at::Tensor d_output, at::Tensor out
dim3 block(128);
if (output.dtype() == torch::kFloat32) {
fastfold_softmax_scale_mask_grad<float><<<grid, block>>>(
fastfold_softmax_mask_grad<float><<<grid, block>>>(
(float *)d_output.data_ptr(), (float *)output.data_ptr(),
(float *)grad_input.data_ptr(), (float *)mask.data_ptr(), rows, cols, scale, head);
(float *)grad_input.data_ptr(), (float *)mask.data_ptr(), rows, cols, head);
} else if (output.dtype() == torch::kFloat16) {
fastfold_softmax_scale_mask_grad<at::Half>
<<<grid, block>>>((at::Half *)d_output.data_ptr(), (at::Half *)output.data_ptr(),
(at::Half *)grad_input.data_ptr(), (at::Half *)mask.data_ptr(), rows,
cols, scale, head);
fastfold_softmax_mask_grad<at::Half><<<grid, block>>>(
(at::Half *)d_output.data_ptr(), (at::Half *)output.data_ptr(),
(at::Half *)grad_input.data_ptr(), (at::Half *)mask.data_ptr(), rows, cols, head);
} else if (output.dtype() == torch::kBFloat16) {
fastfold_softmax_scale_mask_grad<at::BFloat16><<<grid, block>>>(
fastfold_softmax_mask_grad<at::BFloat16><<<grid, block>>>(
(at::BFloat16 *)d_output.data_ptr(), (at::BFloat16 *)output.data_ptr(),
(at::BFloat16 *)grad_input.data_ptr(), (at::BFloat16 *)mask.data_ptr(), rows, cols,
scale, head);
head);
}
return grad_input;
@@ -601,9 +598,8 @@ at::Tensor fused_scale_mask_softmax_backward(at::Tensor d_output, at::Tensor out
////////////////
template <typename T, int cols_per_thread>
__global__ void fastfold_softmax_scale_mask_bias(T *input, T *mask, T *bias, T *output,
long long rows, long long cols, float scale,
int head) {
__global__ void fastfold_softmax_mask_bias(T *input, T *mask, T *bias, T *output, long long rows,
long long cols, int head) {
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
@@ -622,8 +618,8 @@ __global__ void fastfold_softmax_scale_mask_bias(T *input, T *mask, T *bias, T *
if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
buf[i] = -1 * 10e9;
} else {
buf[i] = static_cast<T>(row_input[lane_id * cols_per_thread + i]) * scale;
buf[i] += static_cast<T>(bias_ptr[lane_id * cols_per_thread + i]);
buf[i] = static_cast<T>(row_input[lane_id * cols_per_thread + i]) +
static_cast<T>(bias_ptr[lane_id * cols_per_thread + i]);
}
}
@@ -650,9 +646,8 @@ __global__ void fastfold_softmax_scale_mask_bias(T *input, T *mask, T *bias, T *
}
template <typename T, int block_size>
__global__ void fastfold_softmax_scale_mask_bias_sm(T *input, T *mask, T *bias, T *output,
long long rows, long long cols, float scale,
int head) {
__global__ void fastfold_softmax_mask_bias_sm(T *input, T *mask, T *bias, T *output, long long rows,
long long cols, int head) {
extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[];
auto *buf = reinterpret_cast<float *>(shared_buf);
const int tid = threadIdx.x;
@@ -665,7 +660,7 @@ __global__ void fastfold_softmax_scale_mask_bias_sm(T *input, T *mask, T *bias,
if (mask_ptr[id] == 0) {
buf[id] = -1 * 1e9;
} else {
buf[id] = input[row * cols + id] * scale + bias_ptr[id];
buf[id] = input[row * cols + id] + bias_ptr[id];
}
thread_max = max(thread_max, buf[id]);
}
@@ -686,8 +681,8 @@ __global__ void fastfold_softmax_scale_mask_bias_sm(T *input, T *mask, T *bias,
}
}
at::Tensor fused_scale_mask_bias_softmax_forward(at::Tensor input, at::Tensor mask, at::Tensor bias,
long long rows, long long cols, float scale) {
at::Tensor fused_mask_bias_softmax_forward(at::Tensor input, at::Tensor mask, at::Tensor bias,
long long rows, long long cols) {
CHECK_INPUT(input);
CHECK_INPUT(mask);
CHECK_INPUT(bias);
@@ -700,37 +695,36 @@ at::Tensor fused_scale_mask_bias_softmax_forward(at::Tensor input, at::Tensor ma
if (cols <= 32) {
if (input.dtype() == torch::kFloat32) {
fastfold_softmax_scale_mask_bias<float, 1><<<grid, block>>>(
fastfold_softmax_mask_bias<float, 1><<<grid, block>>>(
(float *)input.data_ptr(), (float *)mask.data_ptr(), (float *)bias.data_ptr(),
(float *)input.data_ptr(), rows, cols, scale, head);
(float *)input.data_ptr(), rows, cols, head);
} else if (input.dtype() == torch::kFloat16) {
fastfold_softmax_scale_mask_bias<at::Half, 1><<<grid, block>>>(
fastfold_softmax_mask_bias<at::Half, 1><<<grid, block>>>(
(at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(),
(at::Half *)bias.data_ptr(), (at::Half *)input.data_ptr(), rows, cols, scale, head);
(at::Half *)bias.data_ptr(), (at::Half *)input.data_ptr(), rows, cols, head);
} else if (input.dtype() == torch::kBFloat16) {
fastfold_softmax_scale_mask_bias<at::BFloat16, 1>
fastfold_softmax_mask_bias<at::BFloat16, 1>
<<<grid, block>>>((at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
(at::BFloat16 *)bias.data_ptr(), (at::BFloat16 *)input.data_ptr(),
rows, cols, scale, head);
rows, cols, head);
}
}
#define COLS_CASE(col_per_thread) \
else if (cols <= col_per_thread * 32) { \
if (input.dtype() == torch::kFloat32) { \
fastfold_softmax_scale_mask_bias<float, col_per_thread><<<grid, block>>>( \
(float *)input.data_ptr(), (float *)mask.data_ptr(), (float *)bias.data_ptr(), \
(float *)input.data_ptr(), rows, cols, scale, head); \
} else if (input.dtype() == torch::kFloat16) { \
fastfold_softmax_scale_mask_bias<at::Half, col_per_thread> \
<<<grid, block>>>((at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(), \
(at::Half *)bias.data_ptr(), (at::Half *)input.data_ptr(), rows, \
cols, scale, head); \
} else if (input.dtype() == torch::kBFloat16) { \
fastfold_softmax_scale_mask_bias<at::BFloat16, col_per_thread><<<grid, block>>>( \
(at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(), \
(at::BFloat16 *)bias.data_ptr(), (at::BFloat16 *)input.data_ptr(), rows, cols, \
scale, head); \
} \
#define COLS_CASE(col_per_thread) \
else if (cols <= col_per_thread * 32) { \
if (input.dtype() == torch::kFloat32) { \
fastfold_softmax_mask_bias<float, col_per_thread><<<grid, block>>>( \
(float *)input.data_ptr(), (float *)mask.data_ptr(), (float *)bias.data_ptr(), \
(float *)input.data_ptr(), rows, cols, head); \
} else if (input.dtype() == torch::kFloat16) { \
fastfold_softmax_mask_bias<at::Half, col_per_thread><<<grid, block>>>( \
(at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(), \
(at::Half *)bias.data_ptr(), (at::Half *)input.data_ptr(), rows, cols, head); \
} else if (input.dtype() == torch::kBFloat16) { \
fastfold_softmax_mask_bias<at::BFloat16, col_per_thread><<<grid, block>>>( \
(at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(), \
(at::BFloat16 *)bias.data_ptr(), (at::BFloat16 *)input.data_ptr(), rows, cols, \
head); \
} \
}
COLS_CASE(2)
COLS_CASE(3)
@@ -773,27 +767,26 @@ at::Tensor fused_scale_mask_bias_softmax_forward(at::Tensor input, at::Tensor ma
const size_t smem = cols * sizeof(float);
if (input.dtype() == torch::kFloat32) {
fastfold_softmax_scale_mask_bias_sm<float, 128><<<grid, block, smem>>>(
fastfold_softmax_mask_bias_sm<float, 128><<<grid, block, smem>>>(
(float *)input.data_ptr(), (float *)mask.data_ptr(), (float *)bias.data_ptr(),
(float *)input.data_ptr(), rows, cols, scale, head);
(float *)input.data_ptr(), rows, cols, head);
} else if (input.dtype() == torch::kFloat16) {
fastfold_softmax_scale_mask_bias_sm<at::Half, 128><<<grid, block, smem>>>(
fastfold_softmax_mask_bias_sm<at::Half, 128><<<grid, block, smem>>>(
(at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(),
(at::Half *)bias.data_ptr(), (at::Half *)input.data_ptr(), rows, cols, scale, head);
(at::Half *)bias.data_ptr(), (at::Half *)input.data_ptr(), rows, cols, head);
} else if (input.dtype() == torch::kBFloat16) {
fastfold_softmax_scale_mask_bias_sm<at::BFloat16, 128><<<grid, block, smem>>>(
fastfold_softmax_mask_bias_sm<at::BFloat16, 128><<<grid, block, smem>>>(
(at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
(at::BFloat16 *)bias.data_ptr(), (at::BFloat16 *)input.data_ptr(), rows, cols,
scale, head);
head);
}
}
return input;
}
at::Tensor fused_scale_mask_bias_softmax_backward(at::Tensor d_output, at::Tensor output,
at::Tensor mask, at::Tensor bias, long long rows,
long long cols, float scale) {
at::Tensor fused_mask_bias_softmax_backward(at::Tensor d_output, at::Tensor output, at::Tensor mask,
at::Tensor bias, long long rows, long long cols) {
CHECK_INPUT(output);
CHECK_INPUT(mask);
const at::cuda::OptionalCUDAGuard device_guard(device_of(mask));
@@ -804,19 +797,18 @@ at::Tensor fused_scale_mask_bias_softmax_backward(at::Tensor d_output, at::Tenso
dim3 block(128);
if (output.dtype() == torch::kFloat32) {
fastfold_softmax_scale_mask_grad<float><<<grid, block>>>(
fastfold_softmax_mask_grad<float><<<grid, block>>>(
(float *)d_output.data_ptr(), (float *)output.data_ptr(),
(float *)grad_input.data_ptr(), (float *)mask.data_ptr(), rows, cols, scale, head);
(float *)grad_input.data_ptr(), (float *)mask.data_ptr(), rows, cols, head);
} else if (output.dtype() == torch::kFloat16) {
fastfold_softmax_scale_mask_grad<at::Half>
<<<grid, block>>>((at::Half *)d_output.data_ptr(), (at::Half *)output.data_ptr(),
(at::Half *)grad_input.data_ptr(), (at::Half *)mask.data_ptr(), rows,
cols, scale, head);
fastfold_softmax_mask_grad<at::Half><<<grid, block>>>(
(at::Half *)d_output.data_ptr(), (at::Half *)output.data_ptr(),
(at::Half *)grad_input.data_ptr(), (at::Half *)mask.data_ptr(), rows, cols, head);
} else if (output.dtype() == torch::kBFloat16) {
fastfold_softmax_scale_mask_grad<at::BFloat16><<<grid, block>>>(
fastfold_softmax_mask_grad<at::BFloat16><<<grid, block>>>(
(at::BFloat16 *)d_output.data_ptr(), (at::BFloat16 *)output.data_ptr(),
(at::BFloat16 *)grad_input.data_ptr(), (at::BFloat16 *)mask.data_ptr(), rows, cols,
scale, head);
head);
}
return grad_input;
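
A pure-PyTorch sketch of the scale-free semantics the kernels above implement, useful as a reference when checking this change (shapes and mask broadcasting are illustrative, not a documented contract): positions where the mask is 0 are forced to a large negative logit before the softmax, and the backward pass no longer multiplies the gradient by scale.

    import torch

    def mask_softmax_ref(x, mask):
        # mask == 0 -> large negative logit (-1e9), matching the kernel's masking branch.
        return torch.softmax(x.masked_fill(mask == 0, -1e9), dim=-1)

    def mask_softmax_grad_ref(dy, y, mask):
        # d_input = (dy - sum(dy * y)) * y, zeroed where the mask is 0;
        # the former multiplication by `scale` has been removed.
        row_sum = (dy * y).sum(dim=-1, keepdim=True)
        return ((dy - row_sum) * y).masked_fill(mask == 0, 0.0)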
@@ -31,18 +31,17 @@ class SoftmaxAffineFunction(torch.autograd.Function):
return grad_input
class FusedScaleMaskSoftmaxFunction(torch.autograd.Function):
class FusedMaskSoftmaxFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, mask, scale):
def forward(ctx, input, mask):
input_ = input.contiguous()
mask_ = mask.contiguous()
ctx.cols = input_.shape[-1]
ctx.rows = reduce(mul, input.shape[:-1])
output = fastfold_softmax_cuda.fused_scale_mask_softmax_forward(
input_, mask_, ctx.rows, ctx.cols, scale)
output = fastfold_softmax_cuda.fused_mask_softmax_forward(
input_, mask_, ctx.rows, ctx.cols)
ctx.save_for_backward(output, mask_)
ctx.scale = scale
return output
@@ -52,25 +51,24 @@ class FusedScaleMaskSoftmaxFunction(torch.autograd.Function):
output, mask_ = ctx.saved_tensors
grad_input = None
grad_input = fastfold_softmax_cuda.fused_scale_mask_softmax_backward(
grad_output.contiguous(), output, mask_, ctx.rows, ctx.cols, ctx.scale)
grad_input = fastfold_softmax_cuda.fused_mask_softmax_backward(
grad_output.contiguous(), output, mask_, ctx.rows, ctx.cols)
return grad_input.contiguous(), None, None
class FusedScaleMaskBiasSoftmaxFunction(torch.autograd.Function):
class FusedMaskBiasSoftmaxFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, mask, bias, scale):
def forward(ctx, input, mask, bias):
input_ = input.contiguous()
mask_ = mask.contiguous()
bias_ = bias.contiguous()
ctx.cols = input_.shape[-1]
ctx.rows = reduce(mul, input.shape[:-1])
output = fastfold_softmax_cuda.fused_scale_mask_bias_softmax_forward(
input_, mask_, bias_, ctx.rows, ctx.cols, scale)
output = fastfold_softmax_cuda.fused_mask_bias_softmax_forward(
input_, mask_, bias_, ctx.rows, ctx.cols)
ctx.save_for_backward(output, mask_, bias_)
ctx.scale = scale
return output
@@ -80,8 +78,8 @@ class FusedScaleMaskBiasSoftmaxFunction(torch.autograd.Function):
output, mask_, bias_ = ctx.saved_tensors
grad_input = None
grad_input = fastfold_softmax_cuda.fused_scale_mask_bias_softmax_backward(
grad_output.contiguous(), output, mask_, bias_, ctx.rows, ctx.cols, ctx.scale)
grad_input = fastfold_softmax_cuda.fused_mask_bias_softmax_backward(
grad_output.contiguous(), output, mask_, bias_, ctx.rows, ctx.cols)
grad_input = grad_input.contiguous()
@@ -91,5 +89,5 @@ class FusedScaleMaskBiasSoftmaxFunction(torch.autograd.Function):
softmax = SoftmaxAffineFunction.apply
scale_mask_softmax = FusedScaleMaskSoftmaxFunction.apply
scale_mask_bias_softmax = FusedScaleMaskBiasSoftmaxFunction.apply
mask_softmax = FusedMaskSoftmaxFunction.apply
mask_bias_softmax = FusedMaskBiasSoftmaxFunction.apply
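
Callers of the removed scale_* entry points can reproduce the old behaviour by pre-scaling the logits themselves. A hypothetical compatibility shim (not part of this commit) that is forward- and gradient-equivalent, since the extra multiply is handled by regular autograd:

    from fastfold.model.fastnn.kernel import mask_bias_softmax, mask_softmax

    def scale_mask_softmax_compat(logits, mask, scale):
        # old: scale_mask_softmax(logits, mask, scale)
        return mask_softmax(logits * scale, mask)

    def scale_mask_bias_softmax_compat(logits, mask, bias, scale):
        # old: scale_mask_bias_softmax(logits, mask, bias, scale)
        return mask_bias_softmax(logits * scale, mask, bias)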
@@ -2,7 +2,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from fastfold.model.fastnn.kernel import scale_mask_softmax, scale_mask_bias_softmax
from fastfold.model.fastnn.kernel import mask_softmax, mask_bias_softmax
from fastfold.model.fastnn.kernel import LayerNorm
from .initializer import glorot_uniform_af
@@ -160,26 +160,17 @@ class SelfAttention(nn.Module):
qkv = self.to_qkv(in_data).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv)
# q = self.to_q(in_data)
# k = self.to_k(in_data)
# v = self.to_k(in_data)
# q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), [q, k, v])
# q = q * self.scaling
q = q * self.scaling
logits = torch.matmul(q, k.transpose(-1, -2))
# logits += mask
if nonbatched_bias is not None:
# logits += nonbatched_bias.unsqueeze(1)
bias = gather_async_opp(*nonbatched_bias, dim=1)
bias = rearrange(bias, 'b q k h -> b h q k')
weights = scale_mask_bias_softmax(logits, mask, bias.unsqueeze(1), self.scaling)
weights = mask_bias_softmax(logits, mask, bias.unsqueeze(1))
else:
weights = scale_mask_softmax(logits, mask, self.scaling)
# weights = torch.softmax(logits, dim=-1)
# weights = softmax(logits)
weights = mask_softmax(logits, mask)
weighted_avg = torch.matmul(weights, v)
weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')
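
For comparison, an unfused sketch of the attention path after this change (mirroring the SelfAttention diff above; attention_ref is an illustrative name): the per-head scaling (self.scaling in the module) is applied to q up front, and the fused softmax only handles masking and the optional nonbatched bias.

    import torch

    def attention_ref(q, k, v, mask, bias=None, scaling=1.0):
        # q, k, v: (..., heads, n, d); mask is 0 where attention is disallowed.
        q = q * scaling
        logits = torch.matmul(q, k.transpose(-1, -2))
        if bias is not None:
            logits = logits + bias
        logits = logits.masked_fill(mask == 0, -1e9)
        weights = torch.softmax(logits, dim=-1)
        return torch.matmul(weights, v)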