Unverified commit 3b046699, authored by shenggan and committed by GitHub

Triton kernel (#78)

* add triton kernel

* add softmax kernel

* fix bugs in triton kernel

* remove eviction_policy

* refactoring triton kernel for softmax
parent b3af1957
@@ -15,8 +15,8 @@ class Evoformer(nn.Module):

     def forward(self, node, pair, node_mask, pair_mask):
         node = self.msa_stack(node, pair, node_mask)
-        pair = pair + self.communication(node, node_mask)
+        pair = self.communication(node, node_mask, pair)
         node, work = All_to_All_Async.apply(node, 1, 2)
         pair = self.pair_stack(pair, pair_mask)
         node = All_to_All_Async_Opp.apply(node, work, 1, 2)
         return node, pair
\ No newline at end of file

 from .jit.fused_ops import bias_dropout_add, bias_sigmod_ele, bias_ele_dropout_residual
-from .cuda_native.layer_norm import MixedFusedLayerNorm as LayerNorm
-from .cuda_native.softmax import softmax, mask_softmax, mask_bias_softmax
+from .layer_norm import FusedLayerNorm as LayerNorm
+from .softmax import fused_softmax

 __all__ = [
-    "bias_dropout_add", "bias_sigmod_ele", "bias_ele_dropout_residual", "LayerNorm", "softmax",
-    "mask_softmax", "mask_bias_softmax"
+    "bias_dropout_add",
+    "bias_sigmod_ele",
+    "bias_ele_dropout_residual",
+    "LayerNorm",
+    "fused_softmax",
 ]
\ No newline at end of file
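# Usage sketch (illustrative, not part of this commit): downstream modules now import the
# fused kernels from the kernel package root instead of .cuda_native, matching how the
# Evoformer modules import them later in this diff. Requires a CUDA build of fastfold;
# shapes below are assumptions.
import torch
from fastfold.model.fastnn.kernel import LayerNorm, fused_softmax

norm = LayerNorm(64)                    # drop-in replacement for torch.nn.LayerNorm(64)
logits = torch.randn(1, 8, 4, 32, 32, device="cuda")
mask = torch.ones(1, 8, 1, 1, 32, device="cuda")
weights = fused_softmax(logits, mask)   # replaces the old mask_softmax(logits, mask)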
 import importlib
-import numbers

 import torch
-from torch.nn import init
-from torch.nn.parameter import Parameter

-global fastfold_layer_norm_cuda
-fastfold_layer_norm_cuda = None
+fastfold_layer_norm_cuda = importlib.import_module("fastfold_layer_norm_cuda")


 class FusedLayerNormAffineFunction(torch.autograd.Function):
@@ -37,34 +33,3 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
                                                              weight_, bias_, ctx.eps)
         return grad_input, grad_weight, grad_bias, None, None
-
-
-class MixedFusedLayerNorm(torch.nn.Module):
-
-    def __init__(self, normalized_shape, eps=1e-5):
-        super(MixedFusedLayerNorm, self).__init__()
-
-        global fastfold_layer_norm_cuda
-        if fastfold_layer_norm_cuda is None:
-            try:
-                fastfold_layer_norm_cuda = importlib.import_module("fastfold_layer_norm_cuda")
-            except ImportError:
-                raise RuntimeError('MixedFusedLayerNorm requires cuda extensions')
-
-        if isinstance(normalized_shape, numbers.Integral):
-            normalized_shape = (normalized_shape,)
-        self.normalized_shape = torch.Size(normalized_shape)
-        self.eps = eps
-        self.weight = Parameter(torch.Tensor(*normalized_shape))
-        self.bias = Parameter(torch.Tensor(*normalized_shape))
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        init.ones_(self.weight)
-        init.zeros_(self.bias)
-
-    def forward(self, input):
-        return FusedLayerNormAffineFunction.apply(input, self.weight, self.bias,
-                                                  self.normalized_shape, self.eps)
 import importlib
-from functools import reduce
-from operator import mul
-
-import torch

 fastfold_softmax_cuda = importlib.import_module("fastfold_softmax_cuda")


-class SoftmaxAffineFunction(torch.autograd.Function):
-
-    @staticmethod
-    def forward(ctx, input):
-        input_ = input.contiguous()
-        ctx.cols = input_.shape[-1]
-        ctx.rows = reduce(mul, input.shape[:-1])
-        output = fastfold_softmax_cuda.forward(input_, ctx.rows, ctx.cols)
-        ctx.save_for_backward(output)
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        output = ctx.saved_tensors[0]
-        grad_input = None
-        grad_input = fastfold_softmax_cuda.backward(grad_output.contiguous(), output,
-                                                    ctx.rows, ctx.cols)
-        return grad_input
-
-
-class FusedMaskSoftmaxFunction(torch.autograd.Function):
-
-    @staticmethod
-    def forward(ctx, input, mask):
-        input_ = input.contiguous()
-        mask_ = mask.contiguous()
-        ctx.cols = input_.shape[-1]
-        ctx.rows = reduce(mul, input.shape[:-1])
-        output = fastfold_softmax_cuda.fused_mask_softmax_forward(
-            input_, mask_, ctx.rows, ctx.cols)
-        ctx.save_for_backward(output, mask_)
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        output, mask_ = ctx.saved_tensors
-        grad_input = None
-        grad_input = fastfold_softmax_cuda.fused_mask_softmax_backward(
-            grad_output.contiguous(), output, mask_, ctx.rows, ctx.cols)
-        return grad_input.contiguous(), None
-
-
-class FusedMaskBiasSoftmaxFunction(torch.autograd.Function):
-
-    @staticmethod
-    def forward(ctx, input, mask, bias):
-        input_ = input.contiguous()
-        mask_ = mask.contiguous()
-        bias_ = bias.contiguous()
-        ctx.cols = input_.shape[-1]
-        ctx.rows = reduce(mul, input.shape[:-1])
-        output = fastfold_softmax_cuda.fused_mask_bias_softmax_forward(
-            input_, mask_, bias_, ctx.rows, ctx.cols)
-        ctx.save_for_backward(output, mask_, bias_)
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        output, mask_, bias_ = ctx.saved_tensors
-        grad_input = None
-        grad_input = fastfold_softmax_cuda.fused_mask_bias_softmax_backward(
-            grad_output.contiguous(), output, mask_, bias_, ctx.rows, ctx.cols)
-        grad_input = grad_input.contiguous()
-        grad_bias = torch.sum(grad_input, dim=1, keepdim=True)
-        return grad_input.contiguous(), None, grad_bias
-
-
-softmax = SoftmaxAffineFunction.apply
-mask_softmax = FusedMaskSoftmaxFunction.apply
-mask_bias_softmax = FusedMaskBiasSoftmaxFunction.apply
+def softmax_cuda_kernel_wrapper(input_, mask_, bias_, rows, cols):
+    if bias_ is not None:
+        output = fastfold_softmax_cuda.fused_mask_bias_softmax_forward(input_, mask_, bias_, rows, cols)
+    elif mask_ is not None:
+        output = fastfold_softmax_cuda.fused_mask_softmax_forward(input_, mask_, rows, cols)
+    else:
+        output = fastfold_softmax_cuda.forward(input_, rows, cols)
+    return output
+
+
+def softmax_grad_cuda_kernel_wrapper(grad_output, output, mask_, rows, cols):
+    if mask_ is not None:
+        grad_input = fastfold_softmax_cuda.fused_mask_softmax_backward(grad_output, output, mask_, rows, cols)
+    else:
+        grad_input = fastfold_softmax_cuda.backward(grad_output, output, rows, cols)
+    return grad_input
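# Reference semantics (hedged sketch, not part of this commit): reading the Triton kernel
# later in this diff, the fused forward computes a row-wise softmax of (input + bias), with
# positions where mask == 0 pushed to -1e20 before the softmax. A plain PyTorch equivalent,
# useful for testing the fused paths:
import torch

def reference_fused_softmax(x, mask=None, bias=None):
    if bias is not None:
        x = x + bias
    if mask is not None:
        x = torch.where(mask == 0, torch.full_like(x, -1e20), x)
    return torch.softmax(x, dim=-1)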
import numbers
import logging
import torch
from torch.nn.parameter import Parameter
_triton_available = True
if _triton_available:
try:
from .triton.layer_norm import LayerNormTritonFunc
except ImportError:
logging.warning("Triton is not available, fallback to old kernel.")
_triton_available = False
from .cuda_native.layer_norm import FusedLayerNormAffineFunction
class FusedLayerNorm(torch.nn.Module):
def __init__(self, normalized_shape, eps=1e-5):
super(FusedLayerNorm, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
self.normalized_shape = torch.Size(normalized_shape)
self.eps = eps
self.weight = Parameter(torch.Tensor(*normalized_shape))
self.bias = Parameter(torch.Tensor(*normalized_shape))
self.reset_parameters()
def reset_parameters(self):
torch.nn.init.ones_(self.weight)
torch.nn.init.zeros_(self.bias)
def forward(self, input):
if _triton_available:
return LayerNormTritonFunc.apply(input, self.normalized_shape, self.weight, self.bias,
self.eps)
else:
return FusedLayerNormAffineFunction.apply(input, self.weight, self.bias,
self.normalized_shape, self.eps)
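# Sanity-check sketch (illustrative, not part of this commit): FusedLayerNorm is meant to be a
# drop-in for torch.nn.LayerNorm, so its forward should match F.layer_norm up to numerical
# tolerance. Shapes and tolerances are assumptions for the test only.
import torch.nn.functional as F

ln = FusedLayerNorm(256).cuda()
x = torch.randn(2, 128, 256, device="cuda")
ref = F.layer_norm(x, (256,), ln.weight, ln.bias, eps=1e-5)
print(torch.allclose(ln(x), ref, atol=1e-4))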
from functools import reduce
from operator import mul
import logging
import torch
_triton_available = True
if _triton_available:
try:
from .triton.softmax import softmax_triton_kernel_wrapper
from .triton.softmax import softmax_grad_triton_kernel_wrapper
except ImportError:
logging.warning("Triton is not available, fallback to old kernel.")
_triton_available = False
from .cuda_native.softmax import softmax_cuda_kernel_wrapper
from .cuda_native.softmax import softmax_grad_cuda_kernel_wrapper
class FusedSoftmaxFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, input, mask=None, bias=None):
input_ = input.contiguous()
mask_, bias_ = None, None
ctx.use_bias = False
if mask is not None:
mask_ = mask.contiguous()
if bias is not None:
bias_ = bias.contiguous()
ctx.use_bias = True
ctx.cols = input_.shape[-1]
ctx.rows = reduce(mul, input.shape[:-1])
if _triton_available:
output = softmax_triton_kernel_wrapper(input_, mask_, bias_, ctx.rows, ctx.cols)
else:
output = softmax_cuda_kernel_wrapper(input_, mask_, bias_, ctx.rows, ctx.cols)
ctx.save_for_backward(output, mask_)
return output
@staticmethod
def backward(ctx, grad_output):
grad_output = grad_output.contiguous()
output, mask_ = ctx.saved_tensors
if _triton_available:
grad_input = softmax_grad_triton_kernel_wrapper(grad_output, output, mask_, ctx.rows,
ctx.cols)
else:
grad_input = softmax_grad_cuda_kernel_wrapper(grad_output, output, mask_, ctx.rows,
ctx.cols)
grad_bias = None
if ctx.use_bias:
grad_bias = torch.sum(grad_input, dim=1, keepdim=True)
return grad_input, None, grad_bias
fused_softmax = FusedSoftmaxFunc.apply
\ No newline at end of file
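# Usage sketch (illustrative, not part of this commit): fused_softmax takes optional mask and
# bias arguments and is differentiable. grad_bias is reduced over dim 1 with keepdim=True,
# which matches the bias.unsqueeze(1) broadcast used by SelfAttention later in this diff.
# Shapes are assumptions.
import torch

logits = torch.randn(1, 8, 4, 32, 32, device="cuda", requires_grad=True)
mask = torch.ones(1, 8, 1, 1, 32, device="cuda")
bias = torch.randn(1, 1, 4, 32, 32, device="cuda", requires_grad=True)

weights = fused_softmax(logits, mask, bias)
weights.sum().backward()
print(logits.grad.shape, bias.grad.shape)  # bias.grad keeps a size-1 dim 1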
import torch
import triton
import triton.language as tl
@triton.jit
def _layer_norm_fwd_fused(
Out,
A,
Weight,
Bias,
Mean,
Rstd,
stride,
N,
eps,
BLOCK_SIZE: tl.constexpr,
):
# position of elements processed by this program
row = tl.program_id(0)
Out += row * stride
A += row * stride
# compute mean
mean = 0
_mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
        a = tl.load(A + cols, mask=cols < N, other=0.).to(tl.float32)
_mean += a
mean = tl.sum(_mean, axis=0) / N
# compute variance
_var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
a = tl.load(A + cols, mask=cols < N, other=0.).to(tl.float32)
a = tl.where(cols < N, a - mean, 0.)
_var += a * a
var = tl.sum(_var, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
# write-back mean/rstd
tl.store(Mean + row, mean)
tl.store(Rstd + row, rstd)
# multiply by weight and add bias
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
mask = cols < N
weight = tl.load(Weight + cols, mask=mask)
bias = tl.load(Bias + cols, mask=mask)
a = tl.load(A + cols, mask=mask, other=0.).to(tl.float32)
a_hat = (a - mean) * rstd
out = a_hat * weight + bias
        # write-back
tl.store(Out + cols, out, mask=mask)
# Backward pass (DA + partial DW + partial DB)
@triton.jit
def _layer_norm_bwd_dx_fused(
_DA,
_DOut,
_A,
Weight,
Mean,
Rstd,
stride,
NumRows,
NumCols,
eps,
BLOCK_SIZE_N: tl.constexpr,
):
# position of elements processed by this program
pid = tl.program_id(0)
row = pid
A = _A + row * stride
DOut = _DOut + row * stride
DA = _DA + row * stride
mean = tl.load(Mean + row)
rstd = tl.load(Rstd + row)
# load data to SRAM
_mean1 = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)
_mean2 = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)
for off in range(0, NumCols, BLOCK_SIZE_N):
cols = off + tl.arange(0, BLOCK_SIZE_N)
mask = cols < NumCols
a = tl.load(A + cols, mask=mask, other=0).to(tl.float32)
dout = tl.load(DOut + cols, mask=mask, other=0).to(tl.float32)
weight = tl.load(Weight + cols, mask=mask, other=0).to(tl.float32)
a_hat = (a - mean) * rstd
wdout = weight * dout
_mean1 += a_hat * wdout
_mean2 += wdout
mean1 = tl.sum(_mean1, axis=0) / NumCols
mean2 = 0.
mean2 = tl.sum(_mean2, axis=0) / NumCols
for off in range(0, NumCols, BLOCK_SIZE_N):
cols = off + tl.arange(0, BLOCK_SIZE_N)
mask = cols < NumCols
a = tl.load(A + cols, mask=mask, other=0).to(tl.float32)
dout = tl.load(DOut + cols, mask=mask, other=0).to(tl.float32)
weight = tl.load(Weight + cols, mask=mask, other=0).to(tl.float32)
a_hat = (a - mean) * rstd
wdout = weight * dout
da = (wdout - (a_hat * mean1 + mean2)) * rstd
# write-back dx
tl.store(DA + cols, da, mask=mask)
# Backward pass (total DW + total DB)
@triton.jit
def _layer_norm_bwd_dwdb(
A,
DOut,
Mean,
Var,
DW,
DB,
M,
N,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
):
pid = tl.program_id(0)
cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
UNROLL: tl.constexpr = 4
for i in range(0, M, BLOCK_SIZE_M * UNROLL):
for j in range(UNROLL):
rows = i + j * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
mask = (rows[:, None] < M) & (cols[None, :] < N)
offs = rows[:, None] * N + cols[None, :]
a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
mean = tl.load(Mean + rows, mask=rows < M, other=0.)
rstd = tl.load(Var + rows, mask=rows < M, other=0.)
a_hat = (a - mean[:, None]) * rstd[:, None]
dw += dout * a_hat
db += dout
sum_dw = tl.sum(dw, axis=0)
sum_db = tl.sum(db, axis=0)
tl.store(DW + cols, sum_dw, mask=cols < N)
tl.store(DB + cols, sum_db, mask=cols < N)
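# Note added for clarity (not part of this commit): the two backward kernels above implement
# the standard layer-norm gradients. With a_hat_i = (a_i - mu) * rstd and g_i = w_i * dout_i:
#
#   dL/da_i = rstd * ( g_i - mean(g) - a_hat_i * mean(g * a_hat) )
#   dL/dw_j = sum over rows of dout_j * a_hat_j
#   dL/db_j = sum over rows of dout_j
#
# mean1 and mean2 in _layer_norm_bwd_dx_fused are mean(g * a_hat) and mean(g), respectively.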
class LayerNormTritonFunc(torch.autograd.Function):

    @staticmethod
    def forward(ctx, a_raw, normalized_shape, weight, bias, eps):
# allocate output
a = a_raw.contiguous()
out = torch.empty_like(a)
# reshape input data into 2D tensor
a_arg = a.reshape(-1, a.shape[-1])
M, N = a_arg.shape
mean = torch.empty((M,), dtype=torch.float32, device="cuda")
rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // a.element_size()
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
BLOCK_SIZE = max(BLOCK_SIZE, 128)
BLOCK_SIZE = min(BLOCK_SIZE, 4096)
# heuristics for number of warps
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
_layer_norm_fwd_fused[(M,)](
out,
a_arg,
weight,
bias,
mean,
rstd,
a_arg.stride(0),
N,
eps,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
ctx.save_for_backward(
a,
weight,
bias,
mean,
rstd,
)
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.eps = eps
return out
@staticmethod
def backward(ctx, dout):
assert dout.is_contiguous()
a, weight, bias, mean, var = ctx.saved_tensors
# heuristics for amount of parallel reduction stream for DG/DB
N = weight.shape[0]
# allocate output
da = torch.empty_like(dout)
# enqueue kernel using forward pass heuristics
# also compute partial sums for DW and DB
x_arg = a.reshape(-1, a.shape[-1])
M, N = x_arg.shape
dweight = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device)
dbias = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device)
_layer_norm_bwd_dx_fused[(M,)](
da,
dout,
a,
weight,
mean,
var,
x_arg.stride(0),
M,
N,
ctx.eps,
BLOCK_SIZE_N=ctx.BLOCK_SIZE,
num_warps=ctx.num_warps,
)
if N > 10240:
BLOCK_SIZE_N = 128
BLOCK_SIZE_M = 32
num_warps = 4
        elif N > 384:  # elif, so the large-N settings above are not overwritten
BLOCK_SIZE_N = 16
BLOCK_SIZE_M = 16
num_warps = 8
else:
# maximize occupancy for small N
BLOCK_SIZE_N = 4
BLOCK_SIZE_M = 256
num_warps = 8
grid = lambda meta: [triton.cdiv(N, meta["BLOCK_SIZE_N"])]
_layer_norm_bwd_dwdb[grid](a,
dout,
mean,
var,
dweight,
dbias,
M,
N,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
num_warps=num_warps)
return (da, None, dweight, dbias, None)
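# Backward sanity-check sketch (illustrative, not part of this commit): gradients from
# LayerNormTritonFunc should agree with PyTorch's reference layer norm. Shapes, tolerances,
# and the explicit upstream-gradient tensor are assumptions for the test only.
import torch.nn.functional as F

N = 256
x = torch.randn(32, N, device="cuda", requires_grad=True)
w = torch.ones(N, device="cuda", requires_grad=True)
b = torch.zeros(N, device="cuda", requires_grad=True)
g = torch.randn(32, N, device="cuda")  # contiguous upstream gradient

out = LayerNormTritonFunc.apply(x, (N,), w, b, 1e-5)
out.backward(g)
gx, gw, gb = x.grad.clone(), w.grad.clone(), b.grad.clone()

x.grad = w.grad = b.grad = None
F.layer_norm(x, (N,), w, b, 1e-5).backward(g)
print(torch.allclose(gx, x.grad, atol=1e-4),
      torch.allclose(gw, w.grad, atol=1e-3),
      torch.allclose(gb, b.grad, atol=1e-3))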
import torch
import triton
import triton.language as tl
@triton.jit
def _softmax_core(input_ptrs, output_ptrs, mask_ptrs, bias_ptrs, col_offsets, n_cols,
use_mask: tl.constexpr, use_bias: tl.constexpr):
row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32)
if use_bias:
bias = tl.load(bias_ptrs, mask=col_offsets < n_cols, other=float("-inf")).to(tl.float32)
row += bias
if use_mask:
mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=float("-inf")).to(tl.float32)
row = tl.where(mask == 0, float("-1e20"), row)
row_minus_max = row - tl.max(row, axis=0)
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
@triton.jit
def _softmax_grad_core(output_ptrs, d_output_ptrs, d_input_ptrs, mask_ptrs, col_offsets, n_cols,
is_bf16: tl.constexpr, use_mask: tl.constexpr):
    # fill padded columns with 0 (not -inf) so they do not corrupt the row_sum reduction below
    output_row = tl.load(output_ptrs, mask=col_offsets < n_cols, other=0.)
    d_output_row = tl.load(d_output_ptrs, mask=col_offsets < n_cols, other=0.)
if is_bf16:
output_row = output_row.to(tl.float32)
d_output_row = d_output_row.to(tl.float32)
row_sum = tl.sum(output_row * d_output_row, axis=0)
d_softmax_output = (d_output_row - row_sum) * output_row
if use_mask:
mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=float("-inf")).to(tl.float32)
d_softmax_output = tl.where(mask == 0, float(0), d_softmax_output)
tl.store(d_input_ptrs, d_softmax_output, mask=col_offsets < n_cols)
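# Note added for clarity (not part of this commit): this is the usual softmax Jacobian-vector
# product. For y = softmax(x) and upstream gradient dy,
#
#   dL/dx_i = y_i * (dy_i - sum_j y_j * dy_j)
#
# Masked columns (mask == 0) are zeroed afterwards because the forward pass clamps them to
# -1e20, so they should receive no gradient.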
@triton.jit
def softmax_mask_bias_kernel(output_ptr, input_ptr, mask_ptr, bias_ptr, input_row_stride,
output_row_stride, n_cols, n_heads, BLOCK_SIZE: tl.constexpr,
use_mask: tl.constexpr, use_bias: tl.constexpr):
row_idx = tl.program_id(0)
col_offsets = tl.arange(0, BLOCK_SIZE)
input_row_ptr = input_ptr + row_idx * input_row_stride
output_row_ptr = output_ptr + row_idx * output_row_stride
input_ptrs = input_row_ptr + col_offsets
output_ptrs = output_row_ptr + col_offsets
    mask_ptrs = input_ptrs  # placeholder, unused when use_mask == False
if use_mask:
mask_row_ptr = mask_ptr + (row_idx // (n_heads * n_cols)) * n_cols
mask_ptrs = mask_row_ptr + col_offsets
    bias_ptrs = input_ptrs  # placeholder, unused when use_bias == False
if use_bias:
bias_row_ptr = bias_ptr + (row_idx % (n_heads * n_cols)) * n_cols
bias_ptrs = bias_row_ptr + col_offsets
_softmax_core(input_ptrs, output_ptrs, mask_ptrs, bias_ptrs, col_offsets, n_cols, use_mask,
use_bias)
@triton.jit
def softmax_mask_bias_kernel_two_rows(output_ptr, input_ptr, mask_ptr, bias_ptr, input_row_stride,
output_row_stride, n_cols, n_heads, BLOCK_SIZE: tl.constexpr,
use_mask: tl.constexpr, use_bias: tl.constexpr):
row_idx = tl.program_id(0)
col_offsets = tl.arange(0, BLOCK_SIZE)
input_row_ptr = input_ptr + 2 * row_idx * input_row_stride
output_row_ptr = output_ptr + 2 * row_idx * output_row_stride
input_ptrs = input_row_ptr + col_offsets
output_ptrs = output_row_ptr + col_offsets
    mask_ptrs = input_ptrs  # placeholder, unused when use_mask == False
if use_mask:
mask_row_ptr = mask_ptr + ((2 * row_idx) // (n_heads * n_cols)) * n_cols
mask_ptrs = mask_row_ptr + col_offsets
    bias_ptrs = input_ptrs  # placeholder, unused when use_bias == False
if use_bias:
bias_row_ptr = bias_ptr + ((2 * row_idx) % (n_heads * n_cols)) * n_cols
bias_ptrs = bias_row_ptr + col_offsets
_softmax_core(input_ptrs, output_ptrs, mask_ptrs, bias_ptrs, col_offsets, n_cols, use_mask,
use_bias)
    mask_ptrs = input_ptrs  # placeholder, unused when use_mask == False
if use_mask:
mask_row_ptr = mask_ptr + ((2 * row_idx + 1) // (n_heads * n_cols)) * n_cols
mask_ptrs = mask_row_ptr + col_offsets
    bias_ptrs = input_ptrs  # placeholder, unused when use_bias == False
if use_bias:
bias_row_ptr = bias_ptr + ((2 * row_idx + 1) % (n_heads * n_cols)) * n_cols
bias_ptrs = bias_row_ptr + col_offsets
_softmax_core(input_ptrs + n_cols, output_ptrs + n_cols, mask_ptrs, bias_ptrs, col_offsets,
n_cols, use_mask, use_bias)
@triton.jit
def softmax_mask_grad_kernel(d_output_ptr, output_ptr, d_input_ptr, mask_ptr, d_output_row_stride,
output_row_stride, d_input_row_stride, n_cols, n_heads,
BLOCK_SIZE: tl.constexpr, is_bf16: tl.constexpr,
use_mask: tl.constexpr):
row_idx = tl.program_id(0)
col_offsets = tl.arange(0, BLOCK_SIZE)
output_row_ptr = output_ptr + row_idx * output_row_stride
d_output_row_ptr = d_output_ptr + row_idx * d_output_row_stride
d_input_row_ptr = d_input_ptr + row_idx * d_input_row_stride
output_ptrs = output_row_ptr + col_offsets
d_output_ptrs = d_output_row_ptr + col_offsets
d_input_ptrs = d_input_row_ptr + col_offsets
    mask_ptrs = output_ptrs  # placeholder, unused when use_mask == False
if use_mask:
mask_row_ptr = mask_ptr + (row_idx // (n_heads * n_cols)) * n_cols
mask_ptrs = mask_row_ptr + col_offsets
_softmax_grad_core(output_ptrs, d_output_ptrs, d_input_ptrs, mask_ptrs, col_offsets, n_cols,
is_bf16, use_mask)
@triton.jit
def softmax_mask_grad_kernel_two_rows(d_output_ptr, output_ptr, d_input_ptr, mask_ptr,
d_output_row_stride, output_row_stride, d_input_row_stride,
n_cols, n_heads, BLOCK_SIZE: tl.constexpr,
is_bf16: tl.constexpr, use_mask: tl.constexpr):
row_idx = tl.program_id(0)
col_offsets = tl.arange(0, BLOCK_SIZE)
output_row_ptr = output_ptr + 2 * row_idx * output_row_stride
d_output_row_ptr = d_output_ptr + 2 * row_idx * d_output_row_stride
d_input_row_ptr = d_input_ptr + 2 * row_idx * d_input_row_stride
output_ptrs = output_row_ptr + col_offsets
d_output_ptrs = d_output_row_ptr + col_offsets
d_input_ptrs = d_input_row_ptr + col_offsets
    mask_ptrs = output_ptrs  # placeholder, unused when use_mask == False
if use_mask:
mask_row_ptr = mask_ptr + ((2 * row_idx) // (n_heads * n_cols)) * n_cols
mask_ptrs = mask_row_ptr + col_offsets
_softmax_grad_core(output_ptrs, d_output_ptrs, d_input_ptrs, mask_ptrs, col_offsets, n_cols,
is_bf16, use_mask)
    mask_ptrs = output_ptrs  # placeholder, unused when use_mask == False
if use_mask:
mask_row_ptr = mask_ptr + ((2 * row_idx + 1) // (n_heads * n_cols)) * n_cols
mask_ptrs = mask_row_ptr + col_offsets
_softmax_grad_core(output_ptrs + n_cols, d_output_ptrs + n_cols, d_input_ptrs + n_cols,
mask_ptrs, col_offsets, n_cols, is_bf16, use_mask)
def softmax_triton_kernel_wrapper(x, mask, bias, n_rows, n_cols):
y = torch.empty_like(x)
n_heads = x.shape[2]
num_warps = 1
BLOCK_SIZE = triton.next_power_of_2(n_cols)
if BLOCK_SIZE >= 1024:
num_warps = 4
if BLOCK_SIZE >= 2048:
num_warps = 8
if BLOCK_SIZE >= 4096:
num_warps = 16
_dispatch_kernel = softmax_mask_bias_kernel
_grid = (n_rows,)
if n_cols <= 128 and n_rows % 2 == 0:
_dispatch_kernel = softmax_mask_bias_kernel_two_rows
_grid = (n_rows // 2,)
_dispatch_kernel[_grid](
y,
x,
mask,
bias,
x.stride(-2),
y.stride(-2),
n_cols,
n_heads,
num_warps=num_warps,
BLOCK_SIZE=BLOCK_SIZE,
        use_mask=(mask is not None),
        use_bias=(bias is not None),
)
return y
def softmax_grad_triton_kernel_wrapper(grad_output, output, mask, n_rows, n_cols):
grad_input = torch.empty_like(grad_output)
n_heads = output.shape[2]
num_warps = 1
BLOCK_SIZE = triton.next_power_of_2(n_cols)
if BLOCK_SIZE >= 1024:
num_warps = 4
if BLOCK_SIZE >= 2048:
num_warps = 8
if BLOCK_SIZE >= 4096:
num_warps = 16
is_bf16 = (output.dtype == torch.bfloat16)
_dispatch_kernel = softmax_mask_grad_kernel
_grid = (n_rows,)
if n_cols <= 128 and n_rows % 2 == 0:
_dispatch_kernel = softmax_mask_grad_kernel_two_rows
_grid = (n_rows // 2,)
_dispatch_kernel[_grid](
grad_output,
output,
grad_input,
mask,
grad_output.stride(-2),
output.stride(-2),
grad_output.stride(-2),
n_cols,
n_heads,
num_warps=num_warps,
BLOCK_SIZE=BLOCK_SIZE,
is_bf16=is_bf16,
        use_mask=(mask is not None),
)
return grad_input
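# Sanity-check sketch (illustrative, not part of this commit): with mask=None and bias=None the
# wrapper reduces to a plain row softmax and can be compared against torch.softmax. n_rows and
# n_cols follow the convention used by FusedSoftmaxFunc (rows = product of all dims but the
# last, cols = last dim); shapes here are assumptions.
x = torch.randn(1, 2, 4, 64, 64, device="cuda")
n_cols = x.shape[-1]
n_rows = x.numel() // n_cols
y = softmax_triton_kernel_wrapper(x, None, None, n_rows, n_cols)
print(torch.allclose(y, torch.softmax(x, dim=-1), atol=1e-5))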
@@ -17,8 +17,7 @@ import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-# from fastfold.model.fastnn.kernel import LayerNorm
-from torch.nn import LayerNorm
+from fastfold.model.fastnn.kernel import LayerNorm
 from fastfold.model.fastnn.ops import ChunkMSARowAttentionWithPairBias, ChunkTransition, SelfAttention, GlobalAttention, Transition, ChunkMSAColumnGlobalAttention
 from fastfold.model.fastnn.kernel import bias_dropout_add

@@ -18,8 +18,8 @@ import torch.nn.functional as F
 import math
 from einops import rearrange
 from typing import Tuple
-from fastfold.model.fastnn.kernel import mask_softmax, mask_bias_softmax
 from fastfold.model.fastnn.kernel import LayerNorm
+from fastfold.model.fastnn.kernel import fused_softmax
 from .initializer import glorot_uniform_af
@@ -140,6 +140,7 @@ class OutProductMean(nn.Module):
         self.n_feat_proj = n_feat_proj

     def forward(self, M, M_mask, Z_raw):
+        Z = torch.empty_like(Z_raw)
         M = self.layernormM(M)
         right_act = self.linear_b(M)
         right_act_all, work = gather_async(right_act, dim=2)
@@ -165,9 +166,9 @@
             O = rearrange(O, 'b i j d e -> b i j (d e)')
             O = self.o_linear(O)
             norm0 = norm[:, ax:ax + chunk_size, :, :]
-            Z_raw[:, ax:ax + chunk_size, :, :] += O / norm0
+            Z[:, ax:ax + chunk_size, :, :] += O / norm0

-        return Z_raw
+        return Z

     def inplace(self, M, M_mask, Z_raw):
@@ -317,9 +318,9 @@ class SelfAttention(nn.Module):
         logits = torch.matmul(q, k.transpose(-1, -2))

         if nonbatched_bias is not None:
-            weights = mask_bias_softmax(logits, mask_part, bias.unsqueeze(1))
+            weights = fused_softmax(logits, mask_part, bias.unsqueeze(1))
         else:
-            weights = mask_softmax(logits, mask)
+            weights = fused_softmax(logits, mask_part)

         weighted_avg = torch.matmul(weights, v)
         weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')
@@ -1168,7 +1169,7 @@ class GlobalAttention(nn.Module):
         logits = torch.matmul(q, k.transpose(-1, -2))

-        weights = mask_softmax(logits, mask_part)
+        weights = fused_softmax(logits, mask_part)

         weighted_avg = torch.matmul(weights, v)
         weighted_avg = rearrange(weighted_avg, "b1 b2 h d -> b1 b2 (h d)")