"testing/git@developer.sourcefind.cn:OpenDAS/tilelang.git" did not exist on "3471904f682d2c8149a480e1c9bc8b3736fb6dd9"
Unverified commit ded582b2, authored by shenggan, committed by GitHub

remove scale in fused softmax kernel (#34)

parent ad7f0cb5
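
Summary of the change: the scale argument is removed from the fused softmax CUDA kernels and from their Python autograd wrappers (scale_mask_softmax becomes mask_softmax, scale_mask_bias_softmax becomes mask_bias_softmax); the query tensor is pre-scaled in SelfAttention instead. A rough before/after sketch of the call site, using only names that appear in this diff:

    # before: the scale factor was passed into the fused kernel
    logits = torch.matmul(q, k.transpose(-1, -2))
    weights = scale_mask_bias_softmax(logits, mask, bias.unsqueeze(1), self.scaling)

    # after: q is scaled up front, so the kernel only applies mask (and bias) and softmax
    q = q * self.scaling
    logits = torch.matmul(q, k.transpose(-1, -2))
    weights = mask_bias_softmax(logits, mask, bias.unsqueeze(1))
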
 from .jit.fused_ops import bias_dropout_add, bias_sigmod_ele, bias_ele_dropout_residual
 from .cuda_native.layer_norm import MixedFusedLayerNorm as LayerNorm
-from .cuda_native.softmax import softmax, scale_mask_softmax, scale_mask_bias_softmax
+from .cuda_native.softmax import softmax, mask_softmax, mask_bias_softmax
 
 __all__ = [
     "bias_dropout_add", "bias_sigmod_ele", "bias_ele_dropout_residual", "LayerNorm", "softmax",
-    "scale_mask_softmax", "scale_mask_bias_softmax"
+    "mask_softmax", "mask_bias_softmax"
 ]
\ No newline at end of file
@@ -3,28 +3,25 @@
 at::Tensor softmax(at::Tensor input, long long rows, long long cols);
 at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor output, long long rows, long long cols);
 
-at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, long long rows, long long cols,
-                                            float scale);
-at::Tensor fused_scale_mask_softmax_backward(at::Tensor d_output, at::Tensor input, at::Tensor mask,
-                                             long long rows, long long cols, float scale);
-at::Tensor fused_scale_mask_bias_softmax_forward(at::Tensor input, at::Tensor mask, at::Tensor bias,
-                                                 long long rows, long long cols, float scale);
-at::Tensor fused_scale_mask_bias_softmax_backward(at::Tensor d_output, at::Tensor input,
-                                                  at::Tensor mask, at::Tensor bias, long long rows,
-                                                  long long cols, float scale);
+at::Tensor fused_mask_softmax_forward(at::Tensor input, at::Tensor mask, long long rows,
+                                      long long cols);
+at::Tensor fused_mask_softmax_backward(at::Tensor d_output, at::Tensor input, at::Tensor mask,
+                                       long long rows, long long cols);
+at::Tensor fused_mask_bias_softmax_forward(at::Tensor input, at::Tensor mask, at::Tensor bias,
+                                           long long rows, long long cols);
+at::Tensor fused_mask_bias_softmax_backward(at::Tensor d_output, at::Tensor input, at::Tensor mask,
+                                            at::Tensor bias, long long rows, long long cols);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("forward", &softmax, "Softmax forward (CUDA)");
     m.def("backward", &softmax_gradient, "Softmax backward (CUDA)");
-    m.def("fused_scale_mask_softmax_forward", &fused_scale_mask_softmax_forward,
-          "Softmax forward (CUDA)");
-    m.def("fused_scale_mask_softmax_backward", &fused_scale_mask_softmax_backward,
-          "Softmax forward (CUDA)");
-    m.def("fused_scale_mask_bias_softmax_forward", &fused_scale_mask_bias_softmax_forward,
-          "Softmax forward (CUDA)");
-    m.def("fused_scale_mask_bias_softmax_backward", &fused_scale_mask_bias_softmax_backward,
-          "Softmax forward (CUDA)");
+    m.def("fused_mask_softmax_forward", &fused_mask_softmax_forward, "Softmax forward (CUDA)");
+    m.def("fused_mask_softmax_backward", &fused_mask_softmax_backward, "Softmax forward (CUDA)");
+    m.def("fused_mask_bias_softmax_forward", &fused_mask_bias_softmax_forward,
+          "Softmax forward (CUDA)");
+    m.def("fused_mask_bias_softmax_backward", &fused_mask_bias_softmax_backward,
+          "Softmax forward (CUDA)");
 }
\ No newline at end of file
@@ -31,18 +31,17 @@ class SoftmaxAffineFunction(torch.autograd.Function):
         return grad_input
 
 
-class FusedScaleMaskSoftmaxFunction(torch.autograd.Function):
+class FusedMaskSoftmaxFunction(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx, input, mask, scale):
+    def forward(ctx, input, mask):
         input_ = input.contiguous()
         mask_ = mask.contiguous()
         ctx.cols = input_.shape[-1]
         ctx.rows = reduce(mul, input.shape[:-1])
-        output = fastfold_softmax_cuda.fused_scale_mask_softmax_forward(
-            input_, mask_, ctx.rows, ctx.cols, scale)
+        output = fastfold_softmax_cuda.fused_mask_softmax_forward(
+            input_, mask_, ctx.rows, ctx.cols)
         ctx.save_for_backward(output, mask_)
-        ctx.scale = scale
 
         return output
@@ -52,25 +51,24 @@ class FusedScaleMaskSoftmaxFunction(torch.autograd.Function):
         output, mask_ = ctx.saved_tensors
         grad_input = None
-        grad_input = fastfold_softmax_cuda.fused_scale_mask_softmax_backward(
-            grad_output.contiguous(), output, mask_, ctx.rows, ctx.cols, ctx.scale)
+        grad_input = fastfold_softmax_cuda.fused_mask_softmax_backward(
+            grad_output.contiguous(), output, mask_, ctx.rows, ctx.cols)
 
         return grad_input.contiguous(), None, None
 
 
-class FusedScaleMaskBiasSoftmaxFunction(torch.autograd.Function):
+class FusedMaskBiasSoftmaxFunction(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx, input, mask, bias, scale):
+    def forward(ctx, input, mask, bias):
         input_ = input.contiguous()
         mask_ = mask.contiguous()
         bias_ = bias.contiguous()
         ctx.cols = input_.shape[-1]
         ctx.rows = reduce(mul, input.shape[:-1])
-        output = fastfold_softmax_cuda.fused_scale_mask_bias_softmax_forward(
-            input_, mask_, bias_, ctx.rows, ctx.cols, scale)
+        output = fastfold_softmax_cuda.fused_mask_bias_softmax_forward(
+            input_, mask_, bias_, ctx.rows, ctx.cols)
         ctx.save_for_backward(output, mask_, bias_)
-        ctx.scale = scale
 
         return output
@@ -80,8 +78,8 @@ class FusedScaleMaskBiasSoftmaxFunction(torch.autograd.Function):
         output, mask_, bias_ = ctx.saved_tensors
         grad_input = None
-        grad_input = fastfold_softmax_cuda.fused_scale_mask_bias_softmax_backward(
-            grad_output.contiguous(), output, mask_, bias_, ctx.rows, ctx.cols, ctx.scale)
+        grad_input = fastfold_softmax_cuda.fused_mask_bias_softmax_backward(
+            grad_output.contiguous(), output, mask_, bias_, ctx.rows, ctx.cols)
 
         grad_input = grad_input.contiguous()
@@ -91,5 +89,5 @@ class FusedScaleMaskBiasSoftmaxFunction(torch.autograd.Function):
 softmax = SoftmaxAffineFunction.apply
-scale_mask_softmax = FusedScaleMaskSoftmaxFunction.apply
-scale_mask_bias_softmax = FusedScaleMaskBiasSoftmaxFunction.apply
+mask_softmax = FusedMaskSoftmaxFunction.apply
+mask_bias_softmax = FusedMaskBiasSoftmaxFunction.apply
@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
 
-from fastfold.model.fastnn.kernel import scale_mask_softmax, scale_mask_bias_softmax
+from fastfold.model.fastnn.kernel import mask_softmax, mask_bias_softmax
 from fastfold.model.fastnn.kernel import LayerNorm
 from .initializer import glorot_uniform_af
@@ -160,26 +160,17 @@ class SelfAttention(nn.Module):
         qkv = self.to_qkv(in_data).chunk(3, dim=-1)
         q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv)
 
-        # q = self.to_q(in_data)
-        # k = self.to_k(in_data)
-        # v = self.to_k(in_data)
-        # q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), [q, k, v])
-
-        # q = q * self.scaling
+        q = q * self.scaling
 
         logits = torch.matmul(q, k.transpose(-1, -2))
 
-        # logits += mask
         if nonbatched_bias is not None:
             # logits += nonbatched_bias.unsqueeze(1)
             bias = gather_async_opp(*nonbatched_bias, dim=1)
             bias = rearrange(bias, 'b q k h -> b h q k')
-            weights = scale_mask_bias_softmax(logits, mask, bias.unsqueeze(1), self.scaling)
+            weights = mask_bias_softmax(logits, mask, bias.unsqueeze(1))
         else:
-            weights = scale_mask_softmax(logits, mask, self.scaling)
-        # weights = torch.softmax(logits, dim=-1)
-        # weights = softmax(logits)
+            weights = mask_softmax(logits, mask)
 
         weighted_avg = torch.matmul(weights, v)
         weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')
...
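
Design note: moving the scale onto q before the attention matmul is mathematically equivalent (up to floating-point rounding) to scaling the logits afterwards, since (scale * q) @ k.T == scale * (q @ k.T). A minimal sanity check of that identity in plain PyTorch; the shapes are arbitrary and chosen only for illustration, and this does not exercise the CUDA kernels:

    import torch

    q = torch.randn(2, 4, 16, 8)   # hypothetical (batch, heads, seq_len, head_dim)
    k = torch.randn(2, 4, 16, 8)
    scale = 8 ** -0.5              # 1 / sqrt(head_dim)

    pre_scaled = torch.matmul(q * scale, k.transpose(-1, -2))   # scale folded into q
    post_scaled = scale * torch.matmul(q, k.transpose(-1, -2))  # scale applied to logits
    assert torch.allclose(pre_scaled, post_scaled, atol=1e-6)
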