Unverified commit 3b046699 authored by shenggan, committed by GitHub

Triton kernel (#78)

* add triton kernel

* add softmax kernel

* fix bugs in triton kernel

* remove eviction_policy

* refactoring triton kernel for softmax
parent b3af1957
@@ -15,8 +15,8 @@ class Evoformer(nn.Module):
def forward(self, node, pair, node_mask, pair_mask):
node = self.msa_stack(node, pair, node_mask)
pair = pair + self.communication(node, node_mask)
pair = self.communication(node, node_mask, pair)
node, work = All_to_All_Async.apply(node, 1, 2)
pair = self.pair_stack(pair, pair_mask)
node = All_to_All_Async_Opp.apply(node, work, 1, 2)
return node, pair
\ No newline at end of file
return node, pair
from .jit.fused_ops import bias_dropout_add, bias_sigmod_ele, bias_ele_dropout_residual
from .cuda_native.layer_norm import MixedFusedLayerNorm as LayerNorm
from .cuda_native.softmax import softmax, mask_softmax, mask_bias_softmax
from .layer_norm import FusedLayerNorm as LayerNorm
from .softmax import fused_softmax
__all__ = [
"bias_dropout_add", "bias_sigmod_ele", "bias_ele_dropout_residual", "LayerNorm", "softmax",
"mask_softmax", "mask_bias_softmax"
"bias_dropout_add",
"bias_sigmod_ele",
"bias_ele_dropout_residual",
"LayerNorm",
"fused_softmax",
]
\ No newline at end of file
import importlib
import numbers
import torch
from torch.nn import init
from torch.nn.parameter import Parameter
global fastfold_layer_norm_cuda
fastfold_layer_norm_cuda = None
fastfold_layer_norm_cuda = importlib.import_module("fastfold_layer_norm_cuda")
class FusedLayerNormAffineFunction(torch.autograd.Function):
@@ -37,34 +33,3 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
weight_, bias_, ctx.eps)
return grad_input, grad_weight, grad_bias, None, None
class MixedFusedLayerNorm(torch.nn.Module):
def __init__(self, normalized_shape, eps=1e-5):
super(MixedFusedLayerNorm, self).__init__()
global fastfold_layer_norm_cuda
if fastfold_layer_norm_cuda is None:
try:
fastfold_layer_norm_cuda = importlib.import_module("fastfold_layer_norm_cuda")
except ImportError:
raise RuntimeError('MixedFusedLayerNorm requires cuda extensions')
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
self.normalized_shape = torch.Size(normalized_shape)
self.eps = eps
self.weight = Parameter(torch.Tensor(*normalized_shape))
self.bias = Parameter(torch.Tensor(*normalized_shape))
self.reset_parameters()
def reset_parameters(self):
init.ones_(self.weight)
init.zeros_(self.bias)
def forward(self, input):
return FusedLayerNormAffineFunction.apply(input, self.weight, self.bias,
self.normalized_shape, self.eps)
import importlib
from functools import reduce
from operator import mul
import torch
fastfold_softmax_cuda = importlib.import_module("fastfold_softmax_cuda")
class SoftmaxAffineFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
input_ = input.contiguous()
ctx.cols = input_.shape[-1]
ctx.rows = reduce(mul, input.shape[:-1])
output = fastfold_softmax_cuda.forward(input_, ctx.rows, ctx.cols)
ctx.save_for_backward(output)
return output
@staticmethod
def backward(ctx, grad_output):
output = ctx.saved_tensors[0]
grad_input = None
grad_input = fastfold_softmax_cuda.backward(grad_output.contiguous(), output,
ctx.rows, ctx.cols)
return grad_input
class FusedMaskSoftmaxFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, mask):
input_ = input.contiguous()
mask_ = mask.contiguous()
ctx.cols = input_.shape[-1]
ctx.rows = reduce(mul, input.shape[:-1])
output = fastfold_softmax_cuda.fused_mask_softmax_forward(
input_, mask_, ctx.rows, ctx.cols)
ctx.save_for_backward(output, mask_)
return output
@staticmethod
def backward(ctx, grad_output):
output, mask_ = ctx.saved_tensors
grad_input = None
grad_input = fastfold_softmax_cuda.fused_mask_softmax_backward(
grad_output.contiguous(), output, mask_, ctx.rows, ctx.cols)
return grad_input.contiguous(), None
class FusedMaskBiasSoftmaxFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, mask, bias):
input_ = input.contiguous()
mask_ = mask.contiguous()
bias_ = bias.contiguous()
ctx.cols = input_.shape[-1]
ctx.rows = reduce(mul, input.shape[:-1])
output = fastfold_softmax_cuda.fused_mask_bias_softmax_forward(
input_, mask_, bias_, ctx.rows, ctx.cols)
ctx.save_for_backward(output, mask_, bias_)
return output
@staticmethod
def backward(ctx, grad_output):
output, mask_, bias_ = ctx.saved_tensors
grad_input = None
grad_input = fastfold_softmax_cuda.fused_mask_bias_softmax_backward(
grad_output.contiguous(), output, mask_, bias_, ctx.rows, ctx.cols)
grad_input = grad_input.contiguous()
grad_bias = torch.sum(grad_input, dim=1, keepdim=True)
return grad_input.contiguous(), None, grad_bias
softmax = SoftmaxAffineFunction.apply
mask_softmax = FusedMaskSoftmaxFunction.apply
mask_bias_softmax = FusedMaskBiasSoftmaxFunction.apply
def softmax_cuda_kernel_wrapper(input_, mask_, bias_, rows, cols):
if bias_ is not None:
output = fastfold_softmax_cuda.fused_mask_bias_softmax_forward(input_, mask_, bias_, rows, cols)
elif mask_ is not None:
output = fastfold_softmax_cuda.fused_mask_softmax_forward(input_, mask_, rows, cols)
else:
output = fastfold_softmax_cuda.forward(input_, rows, cols)
return output
def softmax_grad_cuda_kernel_wrapper(grad_output, output, mask_, rows, cols):
if mask_ is not None:
grad_input = fastfold_softmax_cuda.fused_mask_softmax_backward(grad_output, output, mask_, rows, cols)
else:
grad_input = fastfold_softmax_cuda.backward(grad_output, output, rows, cols)
return grad_input
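# Note: these CUDA wrappers take the same (input, mask, bias, rows, cols) and
# (grad_output, output, mask, rows, cols) arguments as softmax_triton_kernel_wrapper /
# softmax_grad_triton_kernel_wrapper, so FusedSoftmaxFunc can dispatch to either
# backend without special-casing the argument handling.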
import numbers
import logging
import torch
from torch.nn.parameter import Parameter
_triton_available = True
if _triton_available:
try:
from .triton.layer_norm import LayerNormTritonFunc
except ImportError:
logging.warning("Triton is not available; falling back to the CUDA kernel.")
_triton_available = False
from .cuda_native.layer_norm import FusedLayerNormAffineFunction
class FusedLayerNorm(torch.nn.Module):
def __init__(self, normalized_shape, eps=1e-5):
super(FusedLayerNorm, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
self.normalized_shape = torch.Size(normalized_shape)
self.eps = eps
self.weight = Parameter(torch.Tensor(*normalized_shape))
self.bias = Parameter(torch.Tensor(*normalized_shape))
self.reset_parameters()
def reset_parameters(self):
torch.nn.init.ones_(self.weight)
torch.nn.init.zeros_(self.bias)
def forward(self, input):
if _triton_available:
return LayerNormTritonFunc.apply(input, self.normalized_shape, self.weight, self.bias,
self.eps)
else:
return FusedLayerNormAffineFunction.apply(input, self.weight, self.bias,
self.normalized_shape, self.eps)
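# Minimal usage sketch (illustrative only; assumes a CUDA device):
#
#   ln = FusedLayerNorm(256).cuda()
#   x = torch.randn(4, 128, 256, device="cuda", requires_grad=True)
#   y = ln(x)            # Triton path when available, otherwise the CUDA extension
#   y.sum().backward()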
from functools import reduce
from operator import mul
import logging
import torch
_triton_available = True
if _triton_available:
try:
from .triton.softmax import softmax_triton_kernel_wrapper
from .triton.softmax import softmax_grad_triton_kernel_wrapper
except ImportError:
logging.warning("Triton is not available; falling back to the CUDA kernel.")
_triton_available = False
from .cuda_native.softmax import softmax_cuda_kernel_wrapper
from .cuda_native.softmax import softmax_grad_cuda_kernel_wrapper
class FusedSoftmaxFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, input, mask=None, bias=None):
input_ = input.contiguous()
mask_, bias_ = None, None
ctx.use_bias = False
if mask is not None:
mask_ = mask.contiguous()
if bias is not None:
bias_ = bias.contiguous()
ctx.use_bias = True
ctx.cols = input_.shape[-1]
ctx.rows = reduce(mul, input.shape[:-1])
if _triton_available:
output = softmax_triton_kernel_wrapper(input_, mask_, bias_, ctx.rows, ctx.cols)
else:
output = softmax_cuda_kernel_wrapper(input_, mask_, bias_, ctx.rows, ctx.cols)
ctx.save_for_backward(output, mask_)
return output
@staticmethod
def backward(ctx, grad_output):
grad_output = grad_output.contiguous()
output, mask_ = ctx.saved_tensors
if _triton_available:
grad_input = softmax_grad_triton_kernel_wrapper(grad_output, output, mask_, ctx.rows,
ctx.cols)
else:
grad_input = softmax_grad_cuda_kernel_wrapper(grad_output, output, mask_, ctx.rows,
ctx.cols)
grad_bias = None
if ctx.use_bias:
grad_bias = torch.sum(grad_input, dim=1, keepdim=True)
return grad_input, None, grad_bias
fused_softmax = FusedSoftmaxFunc.apply
\ No newline at end of file
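# Usage sketch for fused_softmax, mirroring how SelfAttention calls it later in
# this diff (shapes are illustrative assumptions):
#
#   logits = torch.randn(1, 4, 8, 128, 128, device="cuda")   # (b1, b2, heads, n, n)
#   mask = torch.ones(1, 4, 1, 1, 128, device="cuda")        # per-row key mask
#   bias = torch.randn(1, 1, 8, 128, 128, device="cuda")     # pair bias, shared over batch
#   probs = fused_softmax(logits, mask, bias)                 # mask + bias path
#   probs = fused_softmax(logits, mask)                       # mask-only path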
import torch
import triton
import triton.language as tl
@triton.jit
def _layer_norm_fwd_fused(
Out,
A,
Weight,
Bias,
Mean,
Rstd,
stride,
N,
eps,
BLOCK_SIZE: tl.constexpr,
):
# position of elements processed by this program
row = tl.program_id(0)
Out += row * stride
A += row * stride
# compute mean
mean = 0
_mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
a = tl.load(A + cols, mask=cols < N, other=0.).to(tl.float32)
_mean += a
mean = tl.sum(_mean, axis=0) / N
# compute variance
_var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
a = tl.load(A + cols, mask=cols < N, other=0.).to(tl.float32)
a = tl.where(cols < N, a - mean, 0.)
_var += a * a
var = tl.sum(_var, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
# write-back mean/rstd
tl.store(Mean + row, mean)
tl.store(Rstd + row, rstd)
# multiply by weight and add bias
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
mask = cols < N
weight = tl.load(Weight + cols, mask=mask)
bias = tl.load(Bias + cols, mask=mask)
a = tl.load(A + cols, mask=mask, other=0.).to(tl.float32)
a_hat = (a - mean) * rstd
out = a_hat * weight + bias
# write-back
tl.store(Out + cols, out, mask=mask)
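# The forward kernel handles one row per program and makes three streaming
# passes over the feature dimension: accumulate the mean, accumulate the
# variance of the centered values, then write a_hat * weight + bias with
# a_hat = (a - mean) * rstd. mean and rstd are stored per row so the backward
# kernels can rebuild a_hat without redoing these reductions.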
# Backward pass (DA + partial DW + partial DB)
@triton.jit
def _layer_norm_bwd_dx_fused(
_DA,
_DOut,
_A,
Weight,
Mean,
Rstd,
stride,
NumRows,
NumCols,
eps,
BLOCK_SIZE_N: tl.constexpr,
):
# position of elements processed by this program
pid = tl.program_id(0)
row = pid
A = _A + row * stride
DOut = _DOut + row * stride
DA = _DA + row * stride
mean = tl.load(Mean + row)
rstd = tl.load(Rstd + row)
# load data to SRAM
_mean1 = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)
_mean2 = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)
for off in range(0, NumCols, BLOCK_SIZE_N):
cols = off + tl.arange(0, BLOCK_SIZE_N)
mask = cols < NumCols
a = tl.load(A + cols, mask=mask, other=0).to(tl.float32)
dout = tl.load(DOut + cols, mask=mask, other=0).to(tl.float32)
weight = tl.load(Weight + cols, mask=mask, other=0).to(tl.float32)
a_hat = (a - mean) * rstd
wdout = weight * dout
_mean1 += a_hat * wdout
_mean2 += wdout
mean1 = tl.sum(_mean1, axis=0) / NumCols
mean2 = 0.
mean2 = tl.sum(_mean2, axis=0) / NumCols
for off in range(0, NumCols, BLOCK_SIZE_N):
cols = off + tl.arange(0, BLOCK_SIZE_N)
mask = cols < NumCols
a = tl.load(A + cols, mask=mask, other=0).to(tl.float32)
dout = tl.load(DOut + cols, mask=mask, other=0).to(tl.float32)
weight = tl.load(Weight + cols, mask=mask, other=0).to(tl.float32)
a_hat = (a - mean) * rstd
wdout = weight * dout
da = (wdout - (a_hat * mean1 + mean2)) * rstd
# write-back dx
tl.store(DA + cols, da, mask=mask)
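# dA follows the standard layer-norm gradient: with a_hat = (a - mean) * rstd
# and wdout = weight * dout,
#   dA = (wdout - (a_hat * mean(a_hat * wdout) + mean(wdout))) * rstd,
# where both means run over the normalized dimension; _mean1 and _mean2 above
# are exactly those two reductions.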
# Backward pass (total DW + total DB)
@triton.jit
def _layer_norm_bwd_dwdb(
A,
DOut,
Mean,
Var,
DW,
DB,
M,
N,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
):
pid = tl.program_id(0)
cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
UNROLL: tl.constexpr = 4
for i in range(0, M, BLOCK_SIZE_M * UNROLL):
for j in range(UNROLL):
rows = i + j * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
mask = (rows[:, None] < M) & (cols[None, :] < N)
offs = rows[:, None] * N + cols[None, :]
a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
mean = tl.load(Mean + rows, mask=rows < M, other=0.)
rstd = tl.load(Var + rows, mask=rows < M, other=0.)
a_hat = (a - mean[:, None]) * rstd[:, None]
dw += dout * a_hat
db += dout
sum_dw = tl.sum(dw, axis=0)
sum_db = tl.sum(db, axis=0)
tl.store(DW + cols, sum_dw, mask=cols < N)
tl.store(DB + cols, sum_db, mask=cols < N)
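# Each program owns one block of columns and accumulates partial reductions of
# dW = sum_rows(dout * a_hat) and dB = sum_rows(dout) over all M rows (four
# row-blocks per loop iteration via UNROLL), then collapses the row axis before
# storing the final per-column sums.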
class LayerNormTritonFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, a_raw, normalized_shape, weight, bias, eps):
# allocate output
a = a_raw.contiguous()
out = torch.empty_like(a)
# reshape input data into 2D tensor
a_arg = a.reshape(-1, a.shape[-1])
M, N = a_arg.shape
mean = torch.empty((M,), dtype=torch.float32, device="cuda")
rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // a.element_size()
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
BLOCK_SIZE = max(BLOCK_SIZE, 128)
BLOCK_SIZE = min(BLOCK_SIZE, 4096)
# heuristics for number of warps
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
_layer_norm_fwd_fused[(M,)](
out,
a_arg,
weight,
bias,
mean,
rstd,
a_arg.stride(0),
N,
eps,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
ctx.save_for_backward(
a,
weight,
bias,
mean,
rstd,
)
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.eps = eps
return out
@staticmethod
def backward(ctx, dout):
assert dout.is_contiguous()
a, weight, bias, mean, var = ctx.saved_tensors
# heuristics for amount of parallel reduction stream for DG/DB
N = weight.shape[0]
# allocate output
da = torch.empty_like(dout)
# enqueue kernel using forward pass heuristics
# also compute partial sums for DW and DB
x_arg = a.reshape(-1, a.shape[-1])
M, N = x_arg.shape
dweight = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device)
dbias = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device)
_layer_norm_bwd_dx_fused[(M,)](
da,
dout,
a,
weight,
mean,
var,
x_arg.stride(0),
M,
N,
ctx.eps,
BLOCK_SIZE_N=ctx.BLOCK_SIZE,
num_warps=ctx.num_warps,
)
if N > 10240:
BLOCK_SIZE_N = 128
BLOCK_SIZE_M = 32
num_warps = 4
elif N > 384:
BLOCK_SIZE_N = 16
BLOCK_SIZE_M = 16
num_warps = 8
else:
# maximize occupancy for small N
BLOCK_SIZE_N = 4
BLOCK_SIZE_M = 256
num_warps = 8
grid = lambda meta: [triton.cdiv(N, meta["BLOCK_SIZE_N"])]
_layer_norm_bwd_dwdb[grid](a,
dout,
mean,
var,
dweight,
dbias,
M,
N,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
num_warps=num_warps)
return (da, None, dweight, dbias, None)
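# A minimal numerical check of the Triton path against PyTorch's reference
# layer norm (sketch only; assumes a CUDA device with Triton installed):
#
#   x = torch.randn(32, 256, device="cuda", requires_grad=True)
#   w = torch.ones(256, device="cuda", requires_grad=True)
#   b = torch.zeros(256, device="cuda", requires_grad=True)
#   y = LayerNormTritonFunc.apply(x, (256,), w, b, 1e-5)
#   y_ref = torch.nn.functional.layer_norm(x, (256,), w, b, 1e-5)
#   torch.testing.assert_close(y, y_ref, rtol=1e-3, atol=1e-3)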
import torch
import triton
import triton.language as tl
@triton.jit
def _softmax_core(input_ptrs, output_ptrs, mask_ptrs, bias_ptrs, col_offsets, n_cols,
use_mask: tl.constexpr, use_bias: tl.constexpr):
row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32)
if use_bias:
bias = tl.load(bias_ptrs, mask=col_offsets < n_cols, other=float("-inf")).to(tl.float32)
row += bias
if use_mask:
mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=float("-inf")).to(tl.float32)
row = tl.where(mask == 0, float("-1e20"), row)
row_minus_max = row - tl.max(row, axis=0)
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
@triton.jit
def _softmax_grad_core(output_ptrs, d_output_ptrs, d_input_ptrs, mask_ptrs, col_offsets, n_cols,
is_bf16: tl.constexpr, use_mask: tl.constexpr):
# load padded lanes as 0. so they do not pollute the row_sum reduction below
output_row = tl.load(output_ptrs, mask=col_offsets < n_cols, other=0.)
d_output_row = tl.load(d_output_ptrs, mask=col_offsets < n_cols, other=0.)
if is_bf16:
output_row = output_row.to(tl.float32)
d_output_row = d_output_row.to(tl.float32)
row_sum = tl.sum(output_row * d_output_row, axis=0)
d_softmax_output = (d_output_row - row_sum) * output_row
if use_mask:
mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=float("-inf")).to(tl.float32)
d_softmax_output = tl.where(mask == 0, float(0), d_softmax_output)
tl.store(d_input_ptrs, d_softmax_output, mask=col_offsets < n_cols)
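# This is the softmax Jacobian-vector product
#   d_input = (d_output - sum(output * d_output)) * output,
# with masked-out positions (mask == 0) forced to zero so no gradient flows
# through them.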
@triton.jit
def softmax_mask_bias_kernel(output_ptr, input_ptr, mask_ptr, bias_ptr, input_row_stride,
output_row_stride, n_cols, n_heads, BLOCK_SIZE: tl.constexpr,
use_mask: tl.constexpr, use_bias: tl.constexpr):
row_idx = tl.program_id(0)
col_offsets = tl.arange(0, BLOCK_SIZE)
input_row_ptr = input_ptr + row_idx * input_row_stride
output_row_ptr = output_ptr + row_idx * output_row_stride
input_ptrs = input_row_ptr + col_offsets
output_ptrs = output_row_ptr + col_offsets
mask_ptrs = input_ptrs # placeholder, unused when use_mask is False
if use_mask:
mask_row_ptr = mask_ptr + (row_idx // (n_heads * n_cols)) * n_cols
mask_ptrs = mask_row_ptr + col_offsets
bias_ptrs = input_ptrs # placeholder, unused when use_bias is False
if use_bias:
bias_row_ptr = bias_ptr + (row_idx % (n_heads * n_cols)) * n_cols
bias_ptrs = bias_row_ptr + col_offsets
_softmax_core(input_ptrs, output_ptrs, mask_ptrs, bias_ptrs, col_offsets, n_cols, use_mask,
use_bias)
@triton.jit
def softmax_mask_bias_kernel_two_rows(output_ptr, input_ptr, mask_ptr, bias_ptr, input_row_stride,
output_row_stride, n_cols, n_heads, BLOCK_SIZE: tl.constexpr,
use_mask: tl.constexpr, use_bias: tl.constexpr):
row_idx = tl.program_id(0)
col_offsets = tl.arange(0, BLOCK_SIZE)
input_row_ptr = input_ptr + 2 * row_idx * input_row_stride
output_row_ptr = output_ptr + 2 * row_idx * output_row_stride
input_ptrs = input_row_ptr + col_offsets
output_ptrs = output_row_ptr + col_offsets
mask_ptrs = input_ptrs # placeholder, unused when use_mask is False
if use_mask:
mask_row_ptr = mask_ptr + ((2 * row_idx) // (n_heads * n_cols)) * n_cols
mask_ptrs = mask_row_ptr + col_offsets
bias_ptrs = input_ptrs # placeholder, unused when use_bias is False
if use_bias:
bias_row_ptr = bias_ptr + ((2 * row_idx) % (n_heads * n_cols)) * n_cols
bias_ptrs = bias_row_ptr + col_offsets
_softmax_core(input_ptrs, output_ptrs, mask_ptrs, bias_ptrs, col_offsets, n_cols, use_mask,
use_bias)
mask_ptrs = input_ptrs # placeholder, unused when use_mask is False
if use_mask:
mask_row_ptr = mask_ptr + ((2 * row_idx + 1) // (n_heads * n_cols)) * n_cols
mask_ptrs = mask_row_ptr + col_offsets
bias_ptrs = input_ptrs # placeholder, unused when use_bias is False
if use_bias:
bias_row_ptr = bias_ptr + ((2 * row_idx + 1) % (n_heads * n_cols)) * n_cols
bias_ptrs = bias_row_ptr + col_offsets
_softmax_core(input_ptrs + n_cols, output_ptrs + n_cols, mask_ptrs, bias_ptrs, col_offsets,
n_cols, use_mask, use_bias)
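# The *_two_rows variants process rows 2 * row_idx and 2 * row_idx + 1 in one
# program; the wrappers below only dispatch to them for short rows
# (n_cols <= 128 with an even row count), where a single row per program would
# leave the block underutilized.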
@triton.jit
def softmax_mask_grad_kernel(d_output_ptr, output_ptr, d_input_ptr, mask_ptr, d_output_row_stride,
output_row_stride, d_input_row_stride, n_cols, n_heads,
BLOCK_SIZE: tl.constexpr, is_bf16: tl.constexpr,
use_mask: tl.constexpr):
row_idx = tl.program_id(0)
col_offsets = tl.arange(0, BLOCK_SIZE)
output_row_ptr = output_ptr + row_idx * output_row_stride
d_output_row_ptr = d_output_ptr + row_idx * d_output_row_stride
d_input_row_ptr = d_input_ptr + row_idx * d_input_row_stride
output_ptrs = output_row_ptr + col_offsets
d_output_ptrs = d_output_row_ptr + col_offsets
d_input_ptrs = d_input_row_ptr + col_offsets
mask_ptrs = output_ptrs # placeholder, unused when use_mask is False
if use_mask:
mask_row_ptr = mask_ptr + (row_idx // (n_heads * n_cols)) * n_cols
mask_ptrs = mask_row_ptr + col_offsets
_softmax_grad_core(output_ptrs, d_output_ptrs, d_input_ptrs, mask_ptrs, col_offsets, n_cols,
is_bf16, use_mask)
@triton.jit
def softmax_mask_grad_kernel_two_rows(d_output_ptr, output_ptr, d_input_ptr, mask_ptr,
d_output_row_stride, output_row_stride, d_input_row_stride,
n_cols, n_heads, BLOCK_SIZE: tl.constexpr,
is_bf16: tl.constexpr, use_mask: tl.constexpr):
row_idx = tl.program_id(0)
col_offsets = tl.arange(0, BLOCK_SIZE)
output_row_ptr = output_ptr + 2 * row_idx * output_row_stride
d_output_row_ptr = d_output_ptr + 2 * row_idx * d_output_row_stride
d_input_row_ptr = d_input_ptr + 2 * row_idx * d_input_row_stride
output_ptrs = output_row_ptr + col_offsets
d_output_ptrs = d_output_row_ptr + col_offsets
d_input_ptrs = d_input_row_ptr + col_offsets
mask_ptrs = output_ptrs # placeholder, unused when use_mask is False
if use_mask:
mask_row_ptr = mask_ptr + ((2 * row_idx) // (n_heads * n_cols)) * n_cols
mask_ptrs = mask_row_ptr + col_offsets
_softmax_grad_core(output_ptrs, d_output_ptrs, d_input_ptrs, mask_ptrs, col_offsets, n_cols,
is_bf16, use_mask)
mask_ptrs = output_ptrs # placeholder, unused when use_mask is False
if use_mask:
mask_row_ptr = mask_ptr + ((2 * row_idx + 1) // (n_heads * n_cols)) * n_cols
mask_ptrs = mask_row_ptr + col_offsets
_softmax_grad_core(output_ptrs + n_cols, d_output_ptrs + n_cols, d_input_ptrs + n_cols,
mask_ptrs, col_offsets, n_cols, is_bf16, use_mask)
def softmax_triton_kernel_wrapper(x, mask, bias, n_rows, n_cols):
y = torch.empty_like(x)
n_heads = x.shape[2]
num_warps = 1
BLOCK_SIZE = triton.next_power_of_2(n_cols)
if BLOCK_SIZE >= 1024:
num_warps = 4
if BLOCK_SIZE >= 2048:
num_warps = 8
if BLOCK_SIZE >= 4096:
num_warps = 16
_dispatch_kernel = softmax_mask_bias_kernel
_grid = (n_rows,)
if n_cols <= 128 and n_rows % 2 == 0:
_dispatch_kernel = softmax_mask_bias_kernel_two_rows
_grid = (n_rows // 2,)
_dispatch_kernel[_grid](
y,
x,
mask,
bias,
x.stride(-2),
y.stride(-2),
n_cols,
n_heads,
num_warps=num_warps,
BLOCK_SIZE=BLOCK_SIZE,
use_mask=(mask is not None),
use_bias=(bias is not None),
)
return y
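# Layout implied by the index arithmetic above: x is treated as n_rows rows of
# n_cols with n_heads = x.shape[2]; the mask advances once per n_heads * n_cols
# rows (shared across heads and query rows of one batch element), while the
# bias repeats with period n_heads * n_cols (shared across the leading batch
# dimensions), which is consistent with how SelfAttention passes a non-batched
# bias further down in this diff.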
def softmax_grad_triton_kernel_wrapper(grad_output, output, mask, n_rows, n_cols):
grad_input = torch.empty_like(grad_output)
n_heads = output.shape[2]
num_warps = 1
BLOCK_SIZE = triton.next_power_of_2(n_cols)
if BLOCK_SIZE >= 1024:
num_warps = 4
if BLOCK_SIZE >= 2048:
num_warps = 8
if BLOCK_SIZE >= 4096:
num_warps = 16
is_bf16 = (output.dtype == torch.bfloat16)
_dispatch_kernel = softmax_mask_grad_kernel
_grid = (n_rows,)
if n_cols <= 128 and n_rows % 2 == 0:
_dispatch_kernel = softmax_mask_grad_kernel_two_rows
_grid = (n_rows // 2,)
_dispatch_kernel[_grid](
grad_output,
output,
grad_input,
mask,
grad_output.stride(-2),
output.stride(-2),
grad_output.stride(-2),
n_cols,
n_heads,
num_warps=num_warps,
BLOCK_SIZE=BLOCK_SIZE,
is_bf16=is_bf16,
use_mask=(mask is not None),
)
return grad_input
@@ -17,8 +17,7 @@ import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# from fastfold.model.fastnn.kernel import LayerNorm
from torch.nn import LayerNorm
from fastfold.model.fastnn.kernel import LayerNorm
from fastfold.model.fastnn.ops import ChunkMSARowAttentionWithPairBias, ChunkTransition, SelfAttention, GlobalAttention, Transition, ChunkMSAColumnGlobalAttention
from fastfold.model.fastnn.kernel import bias_dropout_add
......
@@ -18,8 +18,8 @@ import torch.nn.functional as F
import math
from einops import rearrange
from typing import Tuple
from fastfold.model.fastnn.kernel import mask_softmax, mask_bias_softmax
from fastfold.model.fastnn.kernel import LayerNorm
from fastfold.model.fastnn.kernel import fused_softmax
from .initializer import glorot_uniform_af
@@ -140,6 +140,7 @@ class OutProductMean(nn.Module):
self.n_feat_proj = n_feat_proj
def forward(self, M, M_mask, Z_raw):
Z = torch.empty_like(Z_raw)
M = self.layernormM(M)
right_act = self.linear_b(M)
right_act_all, work = gather_async(right_act, dim=2)
@@ -165,9 +166,9 @@
O = rearrange(O, 'b i j d e -> b i j (d e)')
O = self.o_linear(O)
norm0 = norm[:, ax:ax + chunk_size, :, :]
Z_raw[:, ax:ax + chunk_size, :, :] += O / norm0
Z[:, ax:ax + chunk_size, :, :] += O / norm0
return Z_raw
return Z
def inplace(self, M, M_mask, Z_raw):
@@ -317,9 +318,9 @@ class SelfAttention(nn.Module):
logits = torch.matmul(q, k.transpose(-1, -2))
if nonbatched_bias is not None:
weights = mask_bias_softmax(logits, mask_part, bias.unsqueeze(1))
weights = fused_softmax(logits, mask_part, bias.unsqueeze(1))
else:
weights = mask_softmax(logits, mask)
weights = fused_softmax(logits, mask_part)
weighted_avg = torch.matmul(weights, v)
weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')
@@ -1168,7 +1169,7 @@ class GlobalAttention(nn.Module):
logits = torch.matmul(q, k.transpose(-1, -2))
weights = mask_softmax(logits, mask_part)
weights = fused_softmax(logits, mask_part)
weighted_avg = torch.matmul(weights, v)
weighted_avg = rearrange(weighted_avg, "b1 b2 h d -> b1 b2 (h d)")
......