Commit b14e47f4 authored by zhuwenwen

Merge branch 'main' of https://github.com/hpcaitech/FastFold

parents 490cb6f5 05681304
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn import LayerNorm
from fastfold.habana.distributed import gather, scatter
from .initializer import glorot_uniform_af
from .kernel import bias_sigmod_ele
from fastfold.habana.fastnn.custom_op import fused_softmax, fused_softmax_bias
CHUNK_SIZE = None
DEBUG = False
def set_chunk_size(chunk_size):
global CHUNK_SIZE
CHUNK_SIZE = chunk_size
def get_chunk_size():
global CHUNK_SIZE
return CHUNK_SIZE
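# Illustrative sketch (not part of the original module): how the global chunk size is meant to
# be consumed by the chunked loops below. The helper name is hypothetical.
def _example_chunked_pass(x: torch.Tensor) -> torch.Tensor:
    # Slice the second axis into CHUNK_SIZE-sized pieces; when chunking is disabled
    # (CHUNK_SIZE is None) the whole axis is processed in a single pass.
    para_dim = x.shape[1]
    chunk_size = get_chunk_size()
    if chunk_size is None:
        chunk_size = para_dim
    out = []
    for ax in range(0, para_dim, chunk_size):
        out.append(x[:, ax:ax + chunk_size])
    return torch.cat(out, dim=1)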
class DropoutRowwise(nn.Module):
def __init__(self, p):
super(DropoutRowwise, self).__init__()
self.p = p
self.dropout = nn.Dropout(p=p)
def forward(self, x):
dropout_mask = torch.ones_like(x[:, 0:1, :, :])
dropout_mask = self.dropout(dropout_mask)
return dropout_mask * x
class DropoutColumnwise(nn.Module):
def __init__(self, p):
super(DropoutColumnwise, self).__init__()
self.p = p
self.dropout = nn.Dropout(p=p)
def forward(self, x):
dropout_mask = torch.ones_like(x[:, :, 0:1, :])
dropout_mask = self.dropout(dropout_mask)
return dropout_mask * x
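# Note: both dropout variants draw a single Bernoulli mask per row (DropoutRowwise) or per
# column (DropoutColumnwise) and broadcast it over the remaining axis, so whole rows/columns of
# the input are zeroed together, matching AlphaFold's shared-axis dropout.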
class Transition(nn.Module):
def __init__(self, d, n=4):
super(Transition, self).__init__()
self.norm = LayerNorm(d)
self.linear1 = Linear(d, n * d, initializer='relu')
self.linear2 = Linear(n * d, d, initializer='zeros')
def forward(self, src):
x = self.norm(src)
x = self.linear2(F.relu(self.linear1(x)))
return src + x
class OutProductMean(nn.Module):
def __init__(self, n_feat=64, n_feat_out=128, n_feat_proj=32):
super(OutProductMean, self).__init__()
self.layernormM = LayerNorm(n_feat)
self.linear_a = Linear(n_feat, n_feat_proj)
self.linear_b = Linear(n_feat, n_feat_proj)
self.o_linear = Linear(n_feat_proj * n_feat_proj,
n_feat_out,
initializer='zero',
use_bias=True)
def forward(self, M, M_mask, Z_raw):
Z = torch.empty_like(Z_raw)
M = self.layernormM(M)
left_act = self.linear_a(M)
right_act = self.linear_b(M)
right_act_all = gather(right_act, dim=2)
M_mask = M_mask.unsqueeze(-1)
M_mask_col = scatter(M_mask, dim=2)
left_act = M_mask_col * left_act
right_act_all = M_mask * right_act_all
norm = torch.einsum('...ab,...ad->...bd',
M_mask_col.squeeze(-1).squeeze(0),
M_mask.squeeze(-1).squeeze(0)).unsqueeze(-1).unsqueeze(0)
para_dim = left_act.shape[2]
chunk_size = CHUNK_SIZE
if CHUNK_SIZE is None:
chunk_size = para_dim
out = []
for ax in range(0, para_dim, chunk_size):
left_act_part = left_act[:, :, ax:ax + chunk_size, :]
# O = torch.einsum('sid,sje->ijde', left_act_part.squeeze(0), right_act_all.squeeze(0))
# O = rearrange(O, 'i j d e -> i j (d e)')
left_shape = left_act_part.shape
right_shape = right_act_all.shape
left_act_part = left_act_part.reshape(left_shape[0], left_shape[1], left_shape[2]*left_shape[3])
right_act_all = right_act_all.reshape(right_shape[0], right_shape[1], right_shape[2]*right_shape[3])
# O = torch.einsum('...ab,...ad->...bd', left_act_part.squeeze(0), right_act_all.squeeze(0))
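# Equivalent to O = einsum('sid,sje->ijde', left, right): flatten to (s, i*d) and (s, j*e),
# contract over the sequence axis s with a single matmul, then restore (i, j, d, e) and flatten
# the channel pair before the output projection.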
O = torch.matmul(left_act_part.squeeze(0).transpose(1, 0), right_act_all.squeeze(0))
O = O.reshape(left_shape[2], left_shape[3], right_shape[2], right_shape[3]).transpose(-2, -3)
O = O.reshape(O.shape[0], O.shape[1], O.shape[2]*O.shape[3])
O = O.unsqueeze(0)
out.append(self.o_linear(O))
Z = torch.cat(out, dim=1)
Z /= (1e-3 + norm)
return Z + Z_raw
class Linear(nn.Linear):
"""
A Linear layer with built-in nonstandard initializations. Called just
like torch.nn.Linear.
Implements the initializers in Section 1.11.4 of the AlphaFold 2 supplementary information,
plus some additional ones found in the code.
"""
def __init__(
self,
feature_in: int,
feature_out: int,
initializer: str = 'linear',
use_bias: bool = True,
bias_init: float = 0.,
):
super(Linear, self).__init__(feature_in, feature_out, bias=use_bias)
self.use_bias = use_bias
if initializer == 'linear':
glorot_uniform_af(self.weight, gain=1.0)
elif initializer == 'relu':
glorot_uniform_af(self.weight, gain=2.0)
elif initializer == 'zeros':
nn.init.zeros_(self.weight)
if self.use_bias:
with torch.no_grad():
self.bias.fill_(bias_init)
class SelfAttention(nn.Module):
"""
Multi-Head SelfAttention dealing with [batch_size1, batch_size2, len, dim] tensors
"""
def __init__(self, qkv_dim, c, n_head, out_dim, gating=True, last_bias_fuse=False):
super(SelfAttention, self).__init__()
self.qkv_dim = qkv_dim
self.c = c
self.n_head = n_head
self.out_dim = out_dim
self.gating = gating
self.last_bias_fuse = last_bias_fuse
self.scaling = self.c**(-0.5)
self.to_qkv = Linear(qkv_dim, 3 * n_head * c, initializer='linear', use_bias=False)
# self.to_q = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
# self.to_k = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
# self.to_v = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
if gating:
self.gating_bias = nn.parameter.Parameter(data=torch.ones((n_head * c,)))
self.gating_linear = Linear(qkv_dim, n_head * c, initializer='zero', use_bias=False)
self.o_linear = Linear(n_head * c,
out_dim,
initializer='zero',
use_bias=(not last_bias_fuse))
def forward(self, in_data, mask, nonbatched_bias=None):
"""
:param in_data: [batch_size1, batch_size2, len_qkv, qkv_dim]
:param mask: [batch_size1, batch_size2, len_kv]; 1 for valid positions, 0 for padding
:param nonbatched_bias: None or [batch_size1, n_head, len_q, len_kv]
"""
para_dim = in_data.shape[1]
chunk_size = CHUNK_SIZE
if CHUNK_SIZE is None:
chunk_size = para_dim
output = []
for ax in range(0, para_dim, chunk_size):
in_data_part = in_data[:, ax:ax + chunk_size, :, :]
mask_part = mask[:, ax:ax + chunk_size, :]
qkv = self.to_qkv(in_data_part).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv)
# q = self.to_q(in_data_part)
# k = self.to_k(in_data_part)
# v = self.to_v(in_data_part)
# q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head),
# [q, k, v])
q = q * self.scaling
logits = torch.matmul(q, k.transpose(-1, -2))
# logits += (1e9 * (mask_part - 1))[..., :, None, None, :]
# if nonbatched_bias is not None:
# logits += nonbatched_bias.unsqueeze(1)
# weights = torch.softmax(logits, dim=-1)
mask00 = (1e9 * (mask_part - 1))[..., :, None, None, :]
if nonbatched_bias is not None:
weights = fused_softmax_bias(logits, mask00, nonbatched_bias.unsqueeze(1), -1)
else:
weights = fused_softmax(logits, mask00, -1)
weighted_avg = torch.matmul(weights, v)
weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')
if self.gating:
gate_values = self.gating_linear(in_data_part)
weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg)
output.append(self.o_linear(weighted_avg))
output = torch.cat(output, dim=1)
return output
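# Shape sketch for SelfAttention (illustrative, the sizes below are hypothetical):
#   attn = SelfAttention(qkv_dim=256, c=32, n_head=8, out_dim=256)
#   out = attn(in_data, mask)   # in_data: [b1, b2, L, 256], mask: [b1, b2, L]
#   # out: [b1, b2, L, 256]; the loop over `ax` only chunks the b2 axis to bound peak memory.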
class GlobalAttention(nn.Module):
"""
Multi-head global attention over [batch_size1, batch_size2, len, dim] tensors; the query is a
mask-weighted mean over the sequence axis.
"""
def __init__(self, qkv_dim, c, n_head, out_dim):
super(GlobalAttention, self).__init__()
self.qkv_dim = qkv_dim
self.c = c
self.n_head = n_head
self.out_dim = out_dim
self.scaling = self.c**(-0.5)
self.eps = 1e-10
self.inf = 1e9
self.to_q = Linear(qkv_dim, c * self.n_head, use_bias=False)
self.to_kv = Linear(qkv_dim, 2 * c, initializer="linear", use_bias=False)
self.gating_bias = nn.parameter.Parameter(data=torch.ones((n_head * c,)))
self.gating_linear = Linear(qkv_dim, n_head * c, initializer="zero", use_bias=False)
self.o_linear = Linear(n_head * c, out_dim, initializer="zero")
def forward(self, m, mask):
para_dim = m.shape[1]
chunk_size = CHUNK_SIZE
if CHUNK_SIZE is None:
chunk_size = para_dim
output = []
for ax in range(0, para_dim, chunk_size):
m_part = m[:, ax:ax + chunk_size, :, :]
mask_part = mask[:, ax:ax + chunk_size, :]
q = torch.sum(m_part * mask_part.unsqueeze(-1),
dim=-2) / (torch.sum(mask_part, dim=-1)[..., None] + self.eps)
q = self.to_q(q)
q = q.view(q.shape[:-1] + (self.n_head, -1))
k, v = self.to_kv(m_part).chunk(2, dim=-1)
logits = torch.matmul(q, k.transpose(-1, -2))
# logits += (1e9 * (mask_part - 1))[..., :, None, None, :]
weights = torch.softmax(logits, dim=-1)
weighted_avg = torch.matmul(weights, v)
weighted_avg = rearrange(weighted_avg, "b1 b2 h d -> b1 b2 (h d)")
gate_values = self.gating_linear(m_part)
weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias,
weighted_avg.unsqueeze(-2))
output.append(self.o_linear(weighted_avg))
m = torch.cat(output, dim=1)
return m
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn import LayerNorm
from fastfold.habana.distributed import col_to_row, gather, row_to_col, scatter
from .kernel import bias_dropout_add, bias_ele_dropout_residual
from .ops import Linear, SelfAttention, Transition
def permute_final_dims(tensor, inds):
zero_index = -1 * len(inds)
first_inds = list(range(len(tensor.shape[:zero_index])))
return tensor.permute(first_inds + [zero_index + i for i in inds])
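# Example: for a tensor of shape [B, I, J, D], permute_final_dims(t, (2, 0, 1)) rearranges the
# last three axes (I, J, D) into (D, I, J), giving [B, D, I, J]; only the trailing len(inds)
# axes are permuted and any leading batch dimensions are left in place.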
class TriangleMultiplicationOutgoing(nn.Module):
def __init__(self, d_pair, p_drop, c=128):
super(TriangleMultiplicationOutgoing, self).__init__()
self.d_pair = d_pair
self.c = c
self.layernorm1 = LayerNorm(d_pair)
self.left_projection = Linear(d_pair, c)
self.right_projection = Linear(d_pair, c)
self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.)
self.layernorm2 = LayerNorm(c)
self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False)
self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
self.p_drop = p_drop
def forward(self, Z_raw, Z_mask):
Z = self.layernorm1(Z_raw)
left_proj_act = self.left_projection(Z)
right_proj_act = self.right_projection(Z)
left_proj_act = Z_mask.unsqueeze(-1) * left_proj_act
right_proj_act = Z_mask.unsqueeze(-1) * right_proj_act
left_proj_act *= torch.sigmoid(self.left_gate(Z))
right_proj_act *= torch.sigmoid(self.right_gate(Z))
right_proj_act = gather(right_proj_act.contiguous(), dim=1)
g = torch.sigmoid(self.output_gate(Z))
p = torch.matmul(
permute_final_dims(left_proj_act, (2, 0, 1)),
permute_final_dims(right_proj_act, (2, 1, 0)),
)
ab = permute_final_dims(p, (1, 2, 0))
# ab = torch.einsum('bikd,bjkd->bijd', left_proj_act, right_proj_act)
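# i.e. ab[b, i, j, d] = sum_k left[b, i, k, d] * right[b, j, k, d]: the permutes move the
# channel axis to the front so a single batched matmul performs the k-contraction.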
ab = self.output_projection(self.layernorm2(ab))
dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
return bias_ele_dropout_residual(ab,
self.output_bias,
g,
dropout_mask,
Z_raw,
prob=self.p_drop,
training=self.training)
class TriangleMultiplicationIncoming(nn.Module):
def __init__(self, d_pair, p_drop, c=128):
super(TriangleMultiplicationIncoming, self).__init__()
self.d_pair = d_pair
self.c = c
self.layernorm1 = LayerNorm(d_pair)
self.left_projection = Linear(d_pair, c)
self.right_projection = Linear(d_pair, c)
self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.)
self.layernorm2 = LayerNorm(c)
self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False)
self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
self.p_drop = p_drop
def forward(self, Z_raw, Z_mask):
Z = self.layernorm1(Z_raw)
left_proj_act = self.left_projection(Z)
right_proj_act = self.right_projection(Z)
left_proj_act = Z_mask.unsqueeze(-1) * left_proj_act
right_proj_act = Z_mask.unsqueeze(-1) * right_proj_act
left_proj_act *= torch.sigmoid(self.left_gate(Z))
right_proj_act *= torch.sigmoid(self.right_gate(Z))
left_proj_act = gather(left_proj_act.contiguous(), dim=2)
g = torch.sigmoid(self.output_gate(Z))
p = torch.matmul(
permute_final_dims(left_proj_act, (2, 1, 0)),
permute_final_dims(right_proj_act, (2, 0, 1)),
)
ab = permute_final_dims(p, (1, 2, 0))
# ab = torch.einsum('bkid,bkjd->bijd', left_proj_act, right_proj_act)
ab = self.output_projection(self.layernorm2(ab))
dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
return bias_ele_dropout_residual(ab,
self.output_bias,
g,
dropout_mask,
Z_raw,
prob=self.p_drop,
training=self.training)
class TriangleAttentionStartingNode(nn.Module):
def __init__(self, d_pair, p_drop, c=32, n_head=4):
super(TriangleAttentionStartingNode, self).__init__()
self.d_pair = d_pair
self.c = c
self.n_head = n_head
self.p_drop = p_drop
self.layernorm1 = LayerNorm(d_pair)
_init_weights = torch.nn.init.normal_(torch.zeros([n_head, d_pair]),
std=1.0 / math.sqrt(d_pair))
self.linear_b_weights = nn.parameter.Parameter(data=_init_weights)
self.attention = SelfAttention(qkv_dim=d_pair,
c=c,
n_head=n_head,
out_dim=d_pair,
gating=True,
last_bias_fuse=True)
self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
def forward(self, Z_raw, Z_mask):
Z = self.layernorm1(Z_raw)
b = F.linear(Z, self.linear_b_weights)
b = gather(b, dim=1)
b = rearrange(b, 'b q k h -> b h q k')
Z = self.attention(Z, Z_mask, b)
dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
return bias_dropout_add(Z,
self.out_bias,
dropout_mask,
Z_raw,
prob=self.p_drop,
training=self.training)
class TriangleAttentionEndingNode(nn.Module):
def __init__(self, d_pair, p_drop, c=32, n_head=4):
super(TriangleAttentionEndingNode, self).__init__()
self.d_pair = d_pair
self.c = c
self.n_head = n_head
self.p_drop = p_drop
self.layernorm1 = LayerNorm(d_pair)
_init_weights = torch.nn.init.normal_(torch.zeros([n_head, d_pair]),
std=1.0 / math.sqrt(d_pair))
self.linear_b_weights = nn.parameter.Parameter(data=_init_weights)
self.attention = SelfAttention(qkv_dim=d_pair,
c=c,
n_head=n_head,
out_dim=d_pair,
gating=True,
last_bias_fuse=True)
self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
def forward(self, Z_raw, Z_mask):
Z = Z_raw.transpose(-2, -3)
Z_mask = Z_mask.transpose(-1, -2)
Z = self.layernorm1(Z)
b = F.linear(Z, self.linear_b_weights)
b = gather(b, dim=1)
b = rearrange(b, 'b q k h -> b h q k')
Z = self.attention(Z, Z_mask, b)
Z = Z.transpose(-2, -3)
dropout_mask = torch.ones_like(Z[:, :, 0:1, :], device=Z.device, dtype=Z.dtype)
return bias_dropout_add(Z,
self.out_bias,
dropout_mask,
Z_raw,
prob=self.p_drop,
training=self.training)
class PairStack(nn.Module):
def __init__(self, d_pair, p_drop=0.25):
super(PairStack, self).__init__()
self.TriangleMultiplicationOutgoing = TriangleMultiplicationOutgoing(d_pair, p_drop=p_drop)
self.TriangleMultiplicationIncoming = TriangleMultiplicationIncoming(d_pair, p_drop=p_drop)
self.TriangleAttentionStartingNode = TriangleAttentionStartingNode(d_pair, p_drop=p_drop)
self.TriangleAttentionEndingNode = TriangleAttentionEndingNode(d_pair, p_drop=p_drop)
self.PairTransition = Transition(d=d_pair)
def forward(self, pair, pair_mask):
pair_mask_row = scatter(pair_mask, dim=1)
pair_mask_col = scatter(pair_mask, dim=2)
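# The pair tensor alternates between row-sharded and column-sharded layouts across the
# tensor-parallel group: row_to_col / col_to_row are the all-to-all layout switches between the
# row-wise and column-wise triangle updates.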
pair = self.TriangleMultiplicationOutgoing(pair, pair_mask_row)
pair = row_to_col(pair)
pair = self.TriangleMultiplicationIncoming(pair, pair_mask_col)
pair = col_to_row(pair)
pair = self.TriangleAttentionStartingNode(pair, pair_mask_row)
pair = row_to_col(pair)
pair = self.TriangleAttentionEndingNode(pair, pair_mask_col)
pair = self.PairTransition(pair)
pair = col_to_row(pair)
return pair
# Copyright 2022 BioMap (Beijing) Intelligence Technology Limited
# Copyright 2022 HPC-AI Technology Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from fastfold.habana.fastnn import EvoformerStack, ExtraMSAStack
#from fastfold.model.fastnn.embedders import TemplateEmbedder
#from fastfold.model.fastnn.embedders_multimer import TemplateEmbedderMultimer
#from fastfold.model.fastnn.ops import RecyclingEmbedder, InputEmbedder
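# Note: with the embedder imports above commented out, inject_template() and inject_embedder()
# below would raise NameError if called; inject_habana() at the bottom of this file keeps them
# disabled accordingly.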
def copy_layernorm(model_fast, model_ori):
model_fast.weight.copy_(model_ori.weight)
model_fast.bias.copy_(model_ori.bias)
def copy_linear(model_fast, model_ori):
model_fast.weight.copy_(model_ori.weight)
if model_fast.use_bias:
model_fast.bias.copy_(model_ori.bias)
def copy_native_linear(model_fast, model_ori):
model_fast.weight.copy_(model_ori.weight)
try:
model_fast.bias.copy_(model_ori.bias)
except:
pass
def copy_kv_linear(model_fast, ori_k, ori_v):
model_fast.weight.copy_(torch.cat((ori_k.weight, ori_v.weight), dim=0))
def copy_qkv_linear(model_fast, ori_q, ori_k, ori_v):
model_fast.weight.copy_(torch.cat((ori_q.weight, ori_k.weight, ori_v.weight), dim=0))
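# copy_qkv_linear concatenates the separate q/k/v weights row-wise, so one matmul through the
# fused to_qkv followed by chunk(3, dim=-1) in SelfAttention reproduces the three original
# projections.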
def copy_attention(model_fast, model_ori):
copy_qkv_linear(model_fast.to_qkv, model_ori.linear_q, model_ori.linear_k, model_ori.linear_v)
copy_linear(model_fast.gating_linear, model_ori.linear_g)
copy_linear(model_fast.o_linear, model_ori.linear_o)
try:
model_fast.gating_bias.copy_(model_ori.linear_g.bias)
except:
print("no gating_bias to copy")
def copy_left_right(model_fast, ori_left, ori_right):
model_fast.weight.copy_(torch.cat((ori_left.weight, ori_right.weight), dim=0))
model_fast.bias.copy_(torch.cat((ori_left.bias, ori_right.bias), dim=0))
def copy_transition(model_fast, model_ori):
copy_layernorm(model_fast.norm, model_ori.layer_norm)
copy_linear(model_fast.linear1, model_ori.linear_1)
copy_linear(model_fast.linear2, model_ori.linear_2)
def copy_triangle(model_fast, model_ori):
copy_layernorm(model_fast.layernorm1, model_ori.layer_norm_in)
copy_layernorm(model_fast.layernorm2, model_ori.layer_norm_out)
copy_linear(model_fast.output_gate, model_ori.linear_g)
copy_linear(model_fast.output_projection, model_ori.linear_z)
model_fast.output_bias.copy_(model_ori.linear_z.bias)
copy_linear(model_fast.left_projection, model_ori.linear_a_p)
copy_linear(model_fast.right_projection, model_ori.linear_b_p)
copy_linear(model_fast.left_gate, model_ori.linear_a_g)
copy_linear(model_fast.right_gate, model_ori.linear_b_g)
def copy_triangle_att(model_fast, model_ori):
copy_layernorm(model_fast.layernorm1, model_ori.layer_norm)
model_fast.linear_b_weights = model_ori.linear.weight
copy_attention(model_fast.attention, model_ori.mha)
model_fast.out_bias.copy_(model_ori.mha.linear_o.bias)
def copy_native_att(model_fast, model_ori):
copy_native_linear(model_fast.linear_q, model_ori.linear_q)
copy_native_linear(model_fast.linear_k, model_ori.linear_k)
copy_native_linear(model_fast.linear_v, model_ori.linear_v)
copy_native_linear(model_fast.linear_o, model_ori.linear_o)
if model_ori.gating:
copy_native_linear(model_fast.linear_g, model_ori.linear_g)
def copy_evoformer_para(block_fast, block_ori):
# msa_stack
# MSARowAttentionWithPairBias
copy_layernorm(block_fast.msa.MSARowAttentionWithPairBias.layernormM,
block_ori.msa_att_row.layer_norm_m)
copy_layernorm(block_fast.msa.MSARowAttentionWithPairBias.layernormZ,
block_ori.msa_att_row.layer_norm_z)
copy_attention(block_fast.msa.MSARowAttentionWithPairBias.attention, block_ori.msa_att_row.mha)
block_fast.msa.MSARowAttentionWithPairBias.linear_b_weights.copy_(
block_ori.msa_att_row.linear_z.weight)
block_fast.msa.MSARowAttentionWithPairBias.out_bias.copy_(
block_ori.msa_att_row.mha.linear_o.bias)
# MSAColumnAttention
copy_layernorm(block_fast.msa.MSAColumnAttention.layernormM,
block_ori.msa_att_col._msa_att.layer_norm_m)
copy_attention(block_fast.msa.MSAColumnAttention.attention, block_ori.msa_att_col._msa_att.mha)
# MSATransition
copy_transition(block_fast.msa.MSATransition, block_ori.core.msa_transition)
# communication
copy_layernorm(block_fast.communication.layernormM,
block_ori.core.outer_product_mean.layer_norm)
copy_linear(block_fast.communication.linear_a, block_ori.core.outer_product_mean.linear_1)
copy_linear(block_fast.communication.linear_b, block_ori.core.outer_product_mean.linear_2)
copy_linear(block_fast.communication.o_linear, block_ori.core.outer_product_mean.linear_out)
# pair_stack
# TriangleMultiplicationOutgoing
copy_triangle(block_fast.pair.TriangleMultiplicationOutgoing, block_ori.core.tri_mul_out)
# TriangleMultiplicationIncoming
copy_triangle(block_fast.pair.TriangleMultiplicationIncoming, block_ori.core.tri_mul_in)
# TriangleAttentionStartingNode
copy_triangle_att(block_fast.pair.TriangleAttentionStartingNode, block_ori.core.tri_att_start)
copy_triangle_att(block_fast.pair.TriangleAttentionEndingNode, block_ori.core.tri_att_end)
copy_transition(block_fast.pair.PairTransition, block_ori.core.pair_transition)
def copy_global_attention(model_fast, model_ori):
copy_linear(model_fast.to_q, model_ori.linear_q)
copy_kv_linear(model_fast.to_kv, model_ori.linear_k, model_ori.linear_v)
copy_linear(model_fast.gating_linear, model_ori.linear_g)
copy_linear(model_fast.o_linear, model_ori.linear_o)
try:
model_fast.gating_bias.copy_(model_ori.linear_g.bias)
except:
print("no gating_bias to copy")
def copy_extra_msa_para(block_fast, block_ori):
# msa_stack
# MSARowAttentionWithPairBias
copy_layernorm(
block_fast.msa_stack.MSARowAttentionWithPairBias.layernormM,
block_ori.msa_att_row.layer_norm_m,
)
copy_layernorm(
block_fast.msa_stack.MSARowAttentionWithPairBias.layernormZ,
block_ori.msa_att_row.layer_norm_z,
)
copy_attention(
block_fast.msa_stack.MSARowAttentionWithPairBias.attention,
block_ori.msa_att_row.mha,
)
block_fast.msa_stack.MSARowAttentionWithPairBias.linear_b_weights.copy_(
block_ori.msa_att_row.linear_z.weight)
block_fast.msa_stack.MSARowAttentionWithPairBias.out_bias.copy_(
block_ori.msa_att_row.mha.linear_o.bias)
# MSAColumnAttention
copy_layernorm(
block_fast.msa_stack.MSAColumnAttention.layernormM,
block_ori.msa_att_col.layer_norm_m,
)
copy_global_attention(
block_fast.msa_stack.MSAColumnAttention.global_attention,
block_ori.msa_att_col.global_attention,
)
# MSATransition
copy_transition(block_fast.msa_stack.MSATransition, block_ori.core.msa_transition)
# communication
comm_model = (
block_ori.core.
outer_product_mean # if not block_ori.is_multimer else block_ori.outer_product_mean
)
copy_layernorm(block_fast.communication.layernormM, comm_model.layer_norm)
copy_linear(block_fast.communication.linear_a, comm_model.linear_1)
copy_linear(block_fast.communication.linear_b, comm_model.linear_2)
copy_linear(block_fast.communication.o_linear, comm_model.linear_out)
# pair_stack
# TriangleMultiplicationOutgoing
copy_triangle(block_fast.pair_stack.TriangleMultiplicationOutgoing, block_ori.core.tri_mul_out)
# TriangleMultiplicationIncoming
copy_triangle(block_fast.pair_stack.TriangleMultiplicationIncoming, block_ori.core.tri_mul_in)
# TriangleAttentionStartingNode
copy_triangle_att(
block_fast.pair_stack.TriangleAttentionStartingNode,
block_ori.core.tri_att_start,
)
copy_triangle_att(block_fast.pair_stack.TriangleAttentionEndingNode, block_ori.core.tri_att_end)
copy_transition(block_fast.pair_stack.PairTransition, block_ori.core.pair_transition)
def copy_template_pair_stack_para(block_fast, block_ori):
# TriangleMultiplicationOutgoing
copy_triangle(block_fast.TriangleMultiplicationOutgoing, block_ori.tri_mul_out)
# TriangleMultiplicationIncoming
copy_triangle(block_fast.TriangleMultiplicationIncoming, block_ori.tri_mul_in)
# TriangleAttentionStartingNode
copy_triangle_att(block_fast.TriangleAttentionStartingNode, block_ori.tri_att_start)
copy_triangle_att(block_fast.TriangleAttentionEndingNode, block_ori.tri_att_end)
copy_transition(block_fast.PairTransition, block_ori.pair_transition)
def copy_template_pair_block_para(fast_module, target_module):
with torch.no_grad():
for ori_block, fast_block in zip(target_module.blocks, fast_module.blocks):
copy_template_pair_stack_para(fast_block, ori_block)
if not ori_block.training:
fast_block.eval()
def copy_template_para(block_fast, block_ori):
# TemplateAngleEmbedder
copy_linear(block_fast.template_angle_embedder.linear_1,
block_ori.template_angle_embedder.linear_1)
copy_linear(block_fast.template_angle_embedder.linear_2,
block_ori.template_angle_embedder.linear_2)
# TemplatePairEmbedder
copy_linear(block_fast.template_pair_embedder.linear, block_ori.template_pair_embedder.linear)
# TemplatePairStack
copy_template_pair_block_para(block_fast.template_pair_stack, block_ori.template_pair_stack)
copy_layernorm(block_fast.template_pair_stack.layer_norm,
block_ori.template_pair_stack.layer_norm)
# TemplatePointwiseAttention
copy_native_att(block_fast.template_pointwise_att.mha, block_ori.template_pointwise_att.mha)
def copy_template_multimer_para(block_fast, block_ori):
# TemplatePairEmbedderMultimer
copy_linear(block_fast.template_pair_embedder.dgram_linear,
block_ori.template_pair_embedder.dgram_linear)
copy_linear(block_fast.template_pair_embedder.aatype_linear_1,
block_ori.template_pair_embedder.aatype_linear_1)
copy_linear(block_fast.template_pair_embedder.aatype_linear_2,
block_ori.template_pair_embedder.aatype_linear_2)
copy_layernorm(block_fast.template_pair_embedder.query_embedding_layer_norm,
block_ori.template_pair_embedder.query_embedding_layer_norm)
copy_linear(block_fast.template_pair_embedder.query_embedding_linear,
block_ori.template_pair_embedder.query_embedding_linear)
copy_linear(block_fast.template_pair_embedder.pseudo_beta_mask_linear,
block_ori.template_pair_embedder.pseudo_beta_mask_linear)
copy_linear(block_fast.template_pair_embedder.x_linear,
block_ori.template_pair_embedder.x_linear)
copy_linear(block_fast.template_pair_embedder.y_linear,
block_ori.template_pair_embedder.y_linear)
copy_linear(block_fast.template_pair_embedder.z_linear,
block_ori.template_pair_embedder.z_linear)
copy_linear(block_fast.template_pair_embedder.backbone_mask_linear,
block_ori.template_pair_embedder.backbone_mask_linear)
# TemplateSingleEmbedderMultimer
copy_linear(block_fast.template_single_embedder.template_single_embedder,
block_ori.template_single_embedder.template_single_embedder)
copy_linear(block_fast.template_single_embedder.template_projector,
block_ori.template_single_embedder.template_projector)
# TemplatePairStack
copy_template_pair_block_para(block_fast.template_pair_stack, block_ori.template_pair_stack)
copy_layernorm(block_fast.template_pair_stack.layer_norm,
block_ori.template_pair_stack.layer_norm)
# linear_t
copy_linear(block_fast.linear_t, block_ori.linear_t)
def inject_evoformer(model):
with torch.no_grad():
target_module = model.evoformer
fast_module = EvoformerStack(
c_m=target_module.blocks[0].msa_att_row.c_in,
c_z=target_module.blocks[0].msa_att_row.c_z,
c_s=target_module.linear.out_features,
no_blocks=len(target_module.blocks),
blocks_per_ckpt=target_module.blocks_per_ckpt,
clear_cache_between_blocks=target_module.clear_cache_between_blocks,
is_multimer=target_module.blocks[0].is_multimer,
)
for target_block, fast_block in zip(target_module.blocks, fast_module.blocks):
copy_evoformer_para(fast_block, target_block)
if not target_module.training:
fast_module.eval()
copy_linear(fast_module.linear, target_module.linear)
model.evoformer = fast_module
def inject_extramsa(model):
with torch.no_grad():
target_module = model.extra_msa_stack
fast_module = ExtraMSAStack(
c_m=target_module.blocks[0].msa_att_row.c_in,
c_z=target_module.blocks[0].msa_att_row.c_z,
no_blocks=len(target_module.blocks),
blocks_per_ckpt=1,
clear_cache_between_blocks=target_module.clear_cache_between_blocks,
is_multimer=target_module.blocks[0].is_multimer,
)
for target_block, fast_block in zip(target_module.blocks, fast_module.blocks):
copy_extra_msa_para(fast_block, target_block)
if not target_module.training:
fast_module.eval()
model.extra_msa_stack = fast_module
def inject_template(model):
with torch.no_grad():
if model.evoformer.blocks[0].is_multimer:
target_module = model.template_embedder
fast_module = TemplateEmbedderMultimer(config=model.template_embedder.config)
copy_template_multimer_para(fast_module, target_module)
if not target_module.training:
fast_module.eval()
model.template_embedder = fast_module
else:
target_module = model.template_embedder
fast_module = TemplateEmbedder(config=model.template_embedder.config)
copy_template_para(fast_module, target_module)
if not target_module.training:
fast_module.eval()
model.template_embedder = fast_module
def inject_embedder(model):
if model.evoformer.blocks[0].is_multimer:
return
# recycle embedder
with torch.no_grad():
target_module = model.recycling_embedder
fast_module = RecyclingEmbedder(c_m=target_module.c_m,
c_z=target_module.c_z,
min_bin=target_module.min_bin,
max_bin=target_module.max_bin,
no_bins=target_module.no_bins,
inf=target_module.inf)
copy_native_linear(fast_module.linear, target_module.linear)
copy_layernorm(fast_module.layer_norm_m, target_module.layer_norm_m)
copy_layernorm(fast_module.layer_norm_z, target_module.layer_norm_z)
if not target_module.training:
fast_module.eval()
model.recycling_embedder = fast_module
# input embedder
with torch.no_grad():
target_module = model.input_embedder
fast_module = InputEmbedder(
tf_dim=target_module.tf_dim,
msa_dim=target_module.msa_dim,
c_z=target_module.c_z,
c_m=target_module.c_m,
relpos_k=target_module.relpos_k,
)
copy_linear(fast_module.linear_tf_z_i, target_module.linear_tf_z_i)
copy_linear(fast_module.linear_tf_z_j, target_module.linear_tf_z_j)
copy_linear(fast_module.linear_tf_m, target_module.linear_tf_m)
copy_linear(fast_module.linear_msa_m, target_module.linear_msa_m)
copy_linear(fast_module.linear_relpos, target_module.linear_relpos)
if not target_module.training:
fast_module.eval()
model.input_embedder = fast_module
def inject_habana(model):
inject_evoformer(model)
inject_extramsa(model)
#inject_template(model)
#inject_embedder(model)
return model
from .msa import MSACore, ExtraMSACore, ExtraMSABlock, ExtraMSAStack
from .ops import OutProductMean, set_chunk_size
from .triangle import PairCore
from .evoformer import Evoformer, EvoformerStack
from .template import TemplatePairBlock, TemplatePairStack
__all__ = [
'MSACore', 'OutProductMean', 'PairCore', 'set_chunk_size',
'TemplatePairBlock', 'TemplatePairStack',
'ExtraMSACore', 'ExtraMSABlock', 'ExtraMSAStack',
'Evoformer', 'EvoformerStack',
]
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from typing import Tuple
from functools import partial
from fastfold.utils.feats import (
build_template_angle_feat,
build_template_pair_feat,
)
from fastfold.model.fastnn.ops import Linear
from fastfold.model.fastnn.template import (
TemplatePairStack,
TemplatePointwiseAttention,
)
from fastfold.utils.tensor_utils import one_hot, tensor_tree_map, dict_multimap
class InputEmbedder(nn.Module):
"""
Embeds a subset of the input features.
Implements Algorithms 3 (InputEmbedder) and 4 (relpos).
"""
def __init__(
self,
tf_dim: int,
msa_dim: int,
c_z: int,
c_m: int,
relpos_k: int,
**kwargs,
):
"""
Args:
tf_dim:
Final dimension of the target features
msa_dim:
Final dimension of the MSA features
c_z:
Pair embedding dimension
c_m:
MSA embedding dimension
relpos_k:
Window size used in relative positional encoding
"""
super(InputEmbedder, self).__init__()
self.tf_dim = tf_dim
self.msa_dim = msa_dim
self.c_z = c_z
self.c_m = c_m
self.linear_tf_z_i = Linear(tf_dim, c_z)
self.linear_tf_z_j = Linear(tf_dim, c_z)
self.linear_tf_m = Linear(tf_dim, c_m)
self.linear_msa_m = Linear(msa_dim, c_m)
# RPE stuff
self.relpos_k = relpos_k
self.no_bins = 2 * relpos_k + 1
self.linear_relpos = Linear(self.no_bins, c_z)
def relpos(self, ri: torch.Tensor):
"""
Computes relative positional encodings
Implements Algorithm 4.
Args:
ri:
"residue_index" features of shape [*, N]
"""
d = ri[..., None] - ri[..., None, :]
boundaries = torch.arange(
start=-self.relpos_k, end=self.relpos_k + 1, device=d.device
)
oh = one_hot(d, boundaries).type(ri.dtype)
return self.linear_relpos(oh)
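# relpos example: with relpos_k = 32 the boundaries span [-32, 32], i.e. 2 * 32 + 1 = 65 bins;
# the one_hot helper assigns each offset to its nearest boundary, so offsets beyond the window
# share the two end bins before being projected to c_z channels.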
def forward(
self,
tf: torch.Tensor,
ri: torch.Tensor,
msa: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
tf:
"target_feat" features of shape [*, N_res, tf_dim]
ri:
"residue_index" features of shape [*, N_res]
msa:
"msa_feat" features of shape [*, N_clust, N_res, msa_dim]
Returns:
msa_emb:
[*, N_clust, N_res, C_m] MSA embedding
pair_emb:
[*, N_res, N_res, C_z] pair embedding
"""
# [*, N_res, c_z]
tf_emb_i = self.linear_tf_z_i(tf)
tf_emb_j = self.linear_tf_z_j(tf)
# [*, N_res, N_res, c_z]
pair_emb = self.relpos(ri.type(tf_emb_i.dtype))
pair_emb += tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
# [*, N_clust, N_res, c_m]
n_clust = msa.shape[-3]
tf_m = (
self.linear_tf_m(tf)
.unsqueeze(-3)
.expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))
)
msa_emb = self.linear_msa_m(msa) + tf_m
return msa_emb, pair_emb
class TemplateEmbedder(nn.Module):
def __init__(self, config):
super(TemplateEmbedder, self).__init__()
self.config = config
self.template_angle_embedder = TemplateAngleEmbedder(
**config["template_angle_embedder"],
)
self.template_pair_embedder = TemplatePairEmbedder(
**config["template_pair_embedder"],
)
self.template_pair_stack = TemplatePairStack(
**config["template_pair_stack"],
)
self.template_pointwise_att = TemplatePointwiseAttention(
**config["template_pointwise_attention"],
)
def forward(self,
batch,
z,
pair_mask,
templ_dim,
chunk_size,
_mask_trans=True,
inplace=False
):
# Embed the templates one at a time (with a poor man's vmap)
template_embeds = []
n_templ = batch["template_aatype"].shape[templ_dim]
if isinstance(chunk_size, int) and 1 <= chunk_size <= 4:
t = torch.empty((n_templ, z.shape[0], z.shape[1], 64), dtype=z.dtype, device='cpu')
else:
t = torch.empty((n_templ, z.shape[0], z.shape[1], 64), dtype=z.dtype, device=z.device)
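# For small chunk sizes (low-memory inference) the per-template pair activations above are
# staged on CPU to save device memory; otherwise they stay on z's device. The hard-coded 64
# corresponds to the template pair channel dimension (c_t = 64 in the default config).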
for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map(
lambda t: torch.index_select(t, templ_dim, idx),
batch,
)
single_template_embeds = {}
if self.config.embed_angles:
template_angle_feat = build_template_angle_feat(
single_template_feats,
)
# [*, S_t, N, C_m]
a = self.template_angle_embedder(template_angle_feat)
single_template_embeds["angle"] = a
# [*, S_t, N, N, C_t]
tt = build_template_pair_feat(
single_template_feats,
use_unit_vector=self.config.use_unit_vector,
inf=self.config.inf,
chunk=chunk_size,
eps=self.config.eps,
**self.config.distogram,
).to(z.dtype).to(z.device)
tt = self.template_pair_embedder(tt)
# single_template_embeds.update({"pair": t})
template_embeds.append(single_template_embeds)
# [*, S_t, N, N, C_z]
if inplace:
tt = [tt]
t[i] = self.template_pair_stack.inplace(
tt,
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=chunk_size,
_mask_trans=_mask_trans,
)[0].to(t.device)
else:
t[i] = self.template_pair_stack(
tt,
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=chunk_size,
_mask_trans=_mask_trans,
).to(t.device)
del tt, single_template_feats
template_embeds = dict_multimap(
partial(torch.cat, dim=templ_dim),
template_embeds,
)
# [*, N, N, C_z]
if inplace:
z = self.template_pointwise_att.inplace(
t,
z,
template_mask=batch["template_mask"].to(dtype=z.dtype),
chunk_size=chunk_size * 256 if chunk_size is not None else chunk_size,
)
else:
z = self.template_pointwise_att(
t,
z,
template_mask=batch["template_mask"].to(dtype=z.dtype),
chunk_size=chunk_size * 256 if chunk_size is not None else chunk_size,
)
ret = {}
ret["template_pair_embedding"] = z
if self.config.embed_angles:
ret["template_single_embedding"] = template_embeds["angle"]
return ret
class TemplateAngleEmbedder(nn.Module):
"""
Embeds the "template_angle_feat" feature.
Implements Algorithm 2, line 7.
"""
def __init__(
self,
c_in: int,
c_out: int,
**kwargs,
):
"""
Args:
c_in:
Final dimension of "template_angle_feat"
c_out:
Output channel dimension
"""
super(TemplateAngleEmbedder, self).__init__()
self.c_out = c_out
self.c_in = c_in
self.linear_1 = Linear(self.c_in, self.c_out, initializer="relu")
self.relu = nn.ReLU()
self.linear_2 = Linear(self.c_out, self.c_out, initializer="relu")
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: [*, N_templ, N_res, c_in] "template_angle_feat" features
Returns:
x: [*, N_templ, N_res, C_out] embedding
"""
x = self.linear_1(x)
x = self.relu(x)
x = self.linear_2(x)
return x
class TemplatePairEmbedder(nn.Module):
"""
Embeds "template_pair_feat" features.
Implements Algorithm 2, line 9.
"""
def __init__(
self,
c_in: int,
c_out: int,
**kwargs,
):
"""
Args:
c_in:
c_out:
Output channel dimension
"""
super(TemplatePairEmbedder, self).__init__()
self.c_in = c_in
self.c_out = c_out
# Despite there being no relu nearby, the source uses that initializer
self.linear = Linear(self.c_in, self.c_out, initializer="relu")
def forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
"""
Args:
x:
[*, C_in] input tensor
Returns:
[*, C_out] output tensor
"""
x = self.linear(x)
return x
class ExtraMSAEmbedder(nn.Module):
"""
Embeds unclustered MSA sequences.
Implements Algorithm 2, line 15
"""
def __init__(
self,
c_in: int,
c_out: int,
**kwargs,
):
"""
Args:
c_in:
Input channel dimension
c_out:
Output channel dimension
"""
super(ExtraMSAEmbedder, self).__init__()
self.c_in = c_in
self.c_out = c_out
self.linear = Linear(self.c_in, self.c_out)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x:
[*, N_extra_seq, N_res, C_in] "extra_msa_feat" features
Returns:
[*, N_extra_seq, N_res, C_out] embedding
"""
x = self.linear(x)
return x
from functools import partial
import torch
import torch.nn as nn
from typing import Tuple, Dict
from fastfold.utils import all_atom_multimer
from fastfold.utils.feats import dgram_from_positions
from fastfold.model.fastnn.ops import Linear, LayerNorm
from fastfold.model.fastnn.template import (
TemplatePairStack,
TemplatePointwiseAttention,
)
from fastfold.utils import geometry
from fastfold.utils.tensor_utils import one_hot, tensor_tree_map, dict_multimap
class InputEmbedderMultimer(nn.Module):
"""
Embeds a subset of the input features.
Implements Algorithms 3 (InputEmbedder) and 4 (relpos).
"""
def __init__(
self,
tf_dim: int,
msa_dim: int,
c_z: int,
c_m: int,
max_relative_idx: int,
use_chain_relative: bool,
max_relative_chain: int,
**kwargs,
):
"""
Args:
tf_dim:
Final dimension of the target features
msa_dim:
Final dimension of the MSA features
c_z:
Pair embedding dimension
c_m:
MSA embedding dimension
max_relative_idx:
Window size used in relative positional encoding
use_chain_relative:
Whether to add chain-relative positional features
max_relative_chain:
Maximum relative chain offset used in the chain-relative features
"""
super(InputEmbedderMultimer, self).__init__()
self.tf_dim = tf_dim
self.msa_dim = msa_dim
self.c_z = c_z
self.c_m = c_m
self.linear_tf_z_i = Linear(tf_dim, c_z)
self.linear_tf_z_j = Linear(tf_dim, c_z)
self.linear_tf_m = Linear(tf_dim, c_m)
self.linear_msa_m = Linear(msa_dim, c_m)
# RPE stuff
self.max_relative_idx = max_relative_idx
self.use_chain_relative = use_chain_relative
self.max_relative_chain = max_relative_chain
if self.use_chain_relative:
self.no_bins = 2 * max_relative_idx + 2 + 1 + 2 * max_relative_chain + 2
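# = (2 * max_relative_idx + 2) residue-offset bins (including the "different chain" bucket)
#   + 1 same-entity flag + (2 * max_relative_chain + 2) clipped relative-chain bins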
else:
self.no_bins = 2 * max_relative_idx + 1
self.linear_relpos = Linear(self.no_bins, c_z)
def relpos(self, batch: Dict[str, torch.Tensor]):
pos = batch["residue_index"]
asym_id = batch["asym_id"]
asym_id_same = asym_id[..., None] == asym_id[..., None, :]
offset = pos[..., None] - pos[..., None, :]
clipped_offset = torch.clamp(
offset + self.max_relative_idx, 0, 2 * self.max_relative_idx
)
rel_feats = []
if self.use_chain_relative:
final_offset = torch.where(
asym_id_same,
clipped_offset,
(2 * self.max_relative_idx + 1) * torch.ones_like(clipped_offset),
)
rel_pos = torch.nn.functional.one_hot(
final_offset,
2 * self.max_relative_idx + 2,
)
rel_feats.append(rel_pos)
entity_id = batch["entity_id"]
entity_id_same = entity_id[..., None] == entity_id[..., None, :]
rel_feats.append(entity_id_same[..., None])
sym_id = batch["sym_id"]
rel_sym_id = sym_id[..., None] - sym_id[..., None, :]
max_rel_chain = self.max_relative_chain
clipped_rel_chain = torch.clamp(
rel_sym_id + max_rel_chain,
0,
2 * max_rel_chain,
)
final_rel_chain = torch.where(
entity_id_same,
clipped_rel_chain,
(2 * max_rel_chain + 1) * torch.ones_like(clipped_rel_chain),
)
rel_chain = torch.nn.functional.one_hot(
final_rel_chain.long(),
2 * max_rel_chain + 2,
)
rel_feats.append(rel_chain)
else:
rel_pos = torch.nn.functional.one_hot(
clipped_offset,
2 * self.max_relative_idx + 1,
)
rel_feats.append(rel_pos)
rel_feat = torch.cat(rel_feats, dim=-1).to(self.linear_relpos.weight.dtype)
return self.linear_relpos(rel_feat)
def forward(
self, batch: Dict[str, torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
tf = batch["target_feat"]
msa = batch["msa_feat"]
# [*, N_res, c_z]
tf_emb_i = self.linear_tf_z_i(tf)
tf_emb_j = self.linear_tf_z_j(tf)
# [*, N_res, N_res, c_z]
pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
pair_emb = pair_emb + self.relpos(batch)
# [*, N_clust, N_res, c_m]
n_clust = msa.shape[-3]
tf_m = (
self.linear_tf_m(tf)
.unsqueeze(-3)
.expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))
)
msa_emb = self.linear_msa_m(msa) + tf_m
return msa_emb, pair_emb
class TemplatePairEmbedderMultimer(nn.Module):
def __init__(self,
c_z: int,
c_out: int,
c_dgram: int,
c_aatype: int,
):
super().__init__()
self.dgram_linear = Linear(c_dgram, c_out)
self.aatype_linear_1 = Linear(c_aatype, c_out)
self.aatype_linear_2 = Linear(c_aatype, c_out)
self.query_embedding_layer_norm = LayerNorm(c_z)
self.query_embedding_linear = Linear(c_z, c_out)
self.pseudo_beta_mask_linear = Linear(1, c_out)
self.x_linear = Linear(1, c_out)
self.y_linear = Linear(1, c_out)
self.z_linear = Linear(1, c_out)
self.backbone_mask_linear = Linear(1, c_out)
def forward(self,
template_dgram: torch.Tensor,
aatype_one_hot: torch.Tensor,
query_embedding: torch.Tensor,
pseudo_beta_mask: torch.Tensor,
backbone_mask: torch.Tensor,
multichain_mask_2d: torch.Tensor,
unit_vector: geometry.Vec3Array,
) -> torch.Tensor:
act = 0.
pseudo_beta_mask_2d = (
pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :]
)
pseudo_beta_mask_2d *= multichain_mask_2d
template_dgram *= pseudo_beta_mask_2d[..., None]
act += self.dgram_linear(template_dgram)
act += self.pseudo_beta_mask_linear(pseudo_beta_mask_2d[..., None])
aatype_one_hot = aatype_one_hot.to(template_dgram.dtype)
act += self.aatype_linear_1(aatype_one_hot[..., None, :, :])
act += self.aatype_linear_2(aatype_one_hot[..., None, :])
backbone_mask_2d = (
backbone_mask[..., None] * backbone_mask[..., None, :]
)
backbone_mask_2d *= multichain_mask_2d
x, y, z = [coord * backbone_mask_2d for coord in unit_vector]
act += self.x_linear(x[..., None])
act += self.y_linear(y[..., None])
act += self.z_linear(z[..., None])
act += self.backbone_mask_linear(backbone_mask_2d[..., None])
query_embedding = self.query_embedding_layer_norm(query_embedding)
act += self.query_embedding_linear(query_embedding)
return act
class TemplateSingleEmbedderMultimer(nn.Module):
def __init__(self,
c_in: int,
c_m: int,
):
super().__init__()
self.template_single_embedder = Linear(c_in, c_m)
self.template_projector = Linear(c_m, c_m)
def forward(self,
batch,
atom_pos,
aatype_one_hot,
):
out = {}
template_chi_angles, template_chi_mask = (
all_atom_multimer.compute_chi_angles(
atom_pos,
batch["template_all_atom_mask"],
batch["template_aatype"],
)
)
template_features = torch.cat(
[
aatype_one_hot,
torch.sin(template_chi_angles) * template_chi_mask,
torch.cos(template_chi_angles) * template_chi_mask,
template_chi_mask,
],
dim=-1,
)
template_mask = template_chi_mask[..., 0]
template_features = self.template_single_embedder(
template_features
)
template_features = torch.nn.functional.relu(
template_features
)
template_features = self.template_projector(
template_features,
)
out["template_single_embedding"] = (
template_features
)
out["template_mask"] = template_mask
return out
class TemplateEmbedderMultimer(nn.Module):
def __init__(self, config):
super(TemplateEmbedderMultimer, self).__init__()
self.config = config
self.template_pair_embedder = TemplatePairEmbedderMultimer(
**config["template_pair_embedder"],
)
self.template_single_embedder = TemplateSingleEmbedderMultimer(
**config["template_single_embedder"],
)
self.template_pair_stack = TemplatePairStack(
**config["template_pair_stack"],
)
self.linear_t = Linear(config.c_t, config.c_z)
def forward(self,
batch,
z,
padding_mask_2d,
templ_dim,
chunk_size,
multichain_mask_2d,
inplace
):
template_embeds = []
n_templ = batch["template_aatype"].shape[templ_dim]
template_pair_embeddings = torch.zeros((z.shape[0], z.shape[1], 64), dtype=z.dtype, device=z.device)
for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map(
lambda t: torch.index_select(t, templ_dim, idx),
batch,
)
single_template_embeds = {}
template_positions, pseudo_beta_mask = (
single_template_feats["template_pseudo_beta"],
single_template_feats["template_pseudo_beta_mask"],
)
template_dgram = dgram_from_positions(
template_positions,
inf=self.config.inf,
**self.config.distogram,
)
aatype_one_hot = torch.nn.functional.one_hot(
single_template_feats["template_aatype"], 22,
)
raw_atom_pos = single_template_feats["template_all_atom_positions"]
atom_pos = geometry.Vec3Array.from_array(raw_atom_pos)
rigid, backbone_mask = all_atom_multimer.make_backbone_affine(
atom_pos,
single_template_feats["template_all_atom_mask"],
single_template_feats["template_aatype"],
)
points = rigid.translation
rigid_vec = rigid[..., None].inverse().apply_to_point(points)
unit_vector = rigid_vec.normalized()
pair_embedding = self.template_pair_embedder(
template_dgram,
aatype_one_hot,
z,
pseudo_beta_mask,
backbone_mask,
multichain_mask_2d,
unit_vector,
)
if not inplace:
# [*, S_t, N, N, C_z]
template_pair_embeddings = template_pair_embeddings + self.template_pair_stack(
pair_embedding,
padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=chunk_size,
_mask_trans=False,
).squeeze(0)
else:
# [*, S_t, N, N, C_z]
template_pair_embeddings += self.template_pair_stack.inplace(
[pair_embedding],
padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=chunk_size,
_mask_trans=False,
)[0].squeeze(0)
single_template_embeds.update(
self.template_single_embedder(
single_template_feats,
atom_pos,
aatype_one_hot,
)
)
template_embeds.append(single_template_embeds)
template_embeds = dict_multimap(
partial(torch.cat, dim=templ_dim),
template_embeds,
)
# [*, N, N, C_z]
template_pair_embeddings = template_pair_embeddings / n_templ
template_pair_embeddings = torch.nn.functional.relu(template_pair_embeddings)
template_pair_embeddings = self.linear_t(template_pair_embeddings)
template_embeds["template_pair_embedding"] = template_pair_embeddings
return template_embeds
from typing import Optional, Tuple
from functools import partial
import torch
import torch.nn as nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from fastfold.model.fastnn import MSACore, OutProductMean, PairCore
from fastfold.model.fastnn.ops import Linear
from fastfold.distributed.comm import gather, scatter, col_to_row
from fastfold.distributed.comm_async import All_to_All_Async, All_to_All_Async_Opp
from fastfold.utils.checkpointing import checkpoint_blocks
class Evoformer(nn.Module):
def __init__(self, c_m: int, c_z: int, first_block: bool, last_block: bool, is_multimer: bool=False):
super(Evoformer, self).__init__()
self.first_block = first_block
self.last_block = last_block
self.msa = MSACore(c_m, c_z, p_drop=0.15)
self.communication = OutProductMean(n_feat=c_m, n_feat_out=c_z, n_feat_proj=32)
self.pair = PairCore(d_pair=c_z)
self.is_multimer = is_multimer
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
dap_size = gpc.get_world_size(ParallelMode.TENSOR)
seq_length = pair_mask.size(-1)
padding_size = (int(seq_length / dap_size) + 1) * dap_size - seq_length
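# Pad the residue axis so it divides evenly across the tensor-parallel (DAP) ranks; the formula
# always pads by at least one position, which is trimmed again in the last block below.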
if self.first_block:
m = m.unsqueeze(0)
z = z.unsqueeze(0)
m = torch.nn.functional.pad(m, (0, 0, 0, padding_size))
z = torch.nn.functional.pad(z, (0, 0, 0, padding_size, 0, padding_size))
if self.is_multimer:
m = scatter(m, dim=2)
else:
m = scatter(m, dim=1)
z = scatter(z, dim=1)
msa_mask = msa_mask.unsqueeze(0)
pair_mask = pair_mask.unsqueeze(0)
msa_mask = torch.nn.functional.pad(msa_mask, (0, padding_size))
pair_mask = torch.nn.functional.pad(pair_mask, (0, padding_size, 0, padding_size))
if not self.is_multimer:
m = self.msa(m, z, msa_mask)
z = self.communication(m, msa_mask, z)
m, work = All_to_All_Async.apply(m, 1, 2)
z = self.pair(z, pair_mask)
m = All_to_All_Async_Opp.apply(m, work, 1, 2)
else:
z = self.communication(m, msa_mask, z)
z_ori = z
m, work = All_to_All_Async.apply(m, 1, 2)
z = self.pair(z, pair_mask)
m = All_to_All_Async_Opp.apply(m, work, 1, 2)
m = self.msa(m, z_ori, msa_mask)
if self.last_block:
m = m.squeeze(0)
z = z.squeeze(0)
if self.is_multimer:
m = gather(m, dim=1)
else:
m = gather(m, dim=0)
z = gather(z, dim=0)
m = m[:, :-padding_size, :]
z = z[:-padding_size, :-padding_size, :]
return m, z
def inplace(
self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
dap_size = gpc.get_world_size(ParallelMode.TENSOR)
seq_length = pair_mask.size(-1)
padding_size = (int(seq_length / dap_size) + 1) * dap_size - seq_length
if self.first_block:
m[0] = m[0].unsqueeze(0)
z[0] = z[0].unsqueeze(0)
m[0] = torch.nn.functional.pad(m[0], (0, 0, 0, padding_size))
z[0] = torch.nn.functional.pad(z[0], (0, 0, 0, padding_size, 0, padding_size))
if self.is_multimer:
m[0] = scatter(m[0], dim=2)
else:
m[0] = scatter(m[0], dim=1)
z[0] = scatter(z[0], dim=1)
msa_mask = msa_mask.unsqueeze(0)
pair_mask = pair_mask.unsqueeze(0)
msa_mask = torch.nn.functional.pad(msa_mask, (0, padding_size))
pair_mask = torch.nn.functional.pad(pair_mask, (0, padding_size, 0, padding_size))
if not self.is_multimer:
m[0] = self.msa(m[0], z[0], msa_mask)
z = self.communication.inplace(m[0], msa_mask, z)
m[0], work = All_to_All_Async.apply(m[0], 1, 2)
z = self.pair.inplace(z, pair_mask)
m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
else:
z = self.communication.inplace(m[0], msa_mask, z)
m[0] = col_to_row(m[0])
m[0] = self.msa(m[0], z[0], msa_mask)
z = self.pair.inplace(z, pair_mask)
if self.last_block:
m[0] = m[0].squeeze(0)
z[0] = z[0].squeeze(0)
if self.is_multimer:
m[0] = gather(m[0], dim=1)
else:
m[0] = gather(m[0], dim=0)
z[0] = gather(z[0], dim=0)
m[0] = m[0][:, :-padding_size, :]
z[0] = z[0][:-padding_size, :-padding_size, :]
return m, z
class EvoformerStack(nn.Module):
"""
Main Evoformer trunk.
Implements Algorithm 6.
"""
def __init__(
self,
c_m: int,
c_z: int,
c_s: int,
no_blocks: int,
blocks_per_ckpt: int,
clear_cache_between_blocks: bool = False,
is_multimer: bool = False,
**kwargs,
):
"""
Args:
c_m:
MSA channel dimension
c_z:
Pair channel dimension
c_s:
Channel dimension of the output "single" embedding
no_blocks:
Number of Evoformer blocks in the stack
blocks_per_ckpt:
Number of Evoformer blocks in each activation checkpoint
clear_cache_between_blocks:
Whether to clear CUDA's GPU memory cache between blocks of the
stack. Slows down each block but can reduce fragmentation
is_multimer:
Whether the stack is configured for the multimer model
"""
super(EvoformerStack, self).__init__()
self.blocks_per_ckpt = blocks_per_ckpt
self.clear_cache_between_blocks = clear_cache_between_blocks
self.blocks = nn.ModuleList()
for block_id in range(no_blocks):
block = Evoformer(
c_m=c_m,
c_z=c_z,
first_block=(block_id == 0),
last_block=(block_id == no_blocks - 1),
is_multimer=is_multimer,
)
self.blocks.append(block)
self.linear = Linear(c_m, c_s)
def forward(self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: int,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""
Args:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
msa_mask:
[*, N_seq, N_res] MSA mask
pair_mask:
[*, N_res, N_res] pair mask
Returns:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
s:
[*, N_res, C_s] single embedding (or None if extra MSA stack)
"""
blocks = [
partial(
b,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
_mask_trans=_mask_trans,
)
for b in self.blocks
]
if(self.clear_cache_between_blocks):
def block_with_cache_clear(block, *args):
torch.cuda.empty_cache()
return block(*args)
blocks = [partial(block_with_cache_clear, b) for b in blocks]
m, z = checkpoint_blocks(
blocks,
args=(m, z),
blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
)
s = self.linear(m[..., 0, :, :])
return m, z, s
def inplace(self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: int,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""
Args:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
msa_mask:
[*, N_seq, N_res] MSA mask
pair_mask:
[*, N_res, N_res] pair mask
Returns:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
s:
[*, N_res, C_s] single embedding (or None if extra MSA stack)
"""
blocks = [
partial(
b.inplace,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
_mask_trans=_mask_trans,
)
for b in self.blocks
]
if(self.clear_cache_between_blocks):
def block_with_cache_clear(block, *args):
torch.cuda.empty_cache()
return block(*args)
blocks = [partial(block_with_cache_clear, b) for b in blocks]
m, z = checkpoint_blocks(
blocks,
args=(m, z),
blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
)
s = self.linear(m[0][..., 0, :, :])
return m, z, s
import math
import numpy as np
import torch.nn as nn
def glorot_uniform_af(x, gain=1.0):
"""
Initialize a tensor like PyTorch's xavier_uniform_, but with the fan dimensions read the way
JAX orders them:
In PyTorch:
[feature_out, feature_in, n_head ...]
In Jax:
[... n_head, feature_in, feature_out]
The original AlphaFold 2 code additionally applies the JAX-style initializer to tensors shaped
[feature_in, n_head, feature_out]; this function keeps that behaviour, so the last two
dimensions are always treated as (fan_in, fan_out).
"""
fan_in, fan_out = x.shape[-2:]
if len(x.shape) > 2:
receptive_field_size = np.prod(x.shape[:-2])
fan_in *= receptive_field_size
fan_out *= receptive_field_size
std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
dev = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
nn.init.uniform_(x, -dev, dev)
return x
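# Example: for a weight of shape [feature_in, n_head, feature_out] = [256, 8, 32], the fans are
# read from the last two dims (8, 32) and both are scaled by the product of the leading dims
# (256), giving fan_in = 8 * 256 and fan_out = 32 * 256 before the uniform bound is computed.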
from .jit.fused_ops import bias_dropout_add, bias_sigmod_ele, bias_ele_dropout_residual
from .layer_norm import FusedLayerNorm as LayerNorm
from .softmax import fused_softmax
from .attention_core import fused_attention_core
__all__ = [
"bias_dropout_add",
"bias_sigmod_ele",
"bias_ele_dropout_residual",
"LayerNorm",
"fused_softmax",
"fused_attention_core",
]
import math
import logging
import torch
from einops import rearrange
_triton_available = True
if _triton_available:
try:
from .triton.attention_core import attention_core_triton_kernel_wrapper
except ImportError:
logging.warning("Triton is not available, fallback to old kernel.")
_triton_available = False
def _torch_attention_core(q, k, v, mask, bias):
scaling = 1. / math.sqrt(q.size(-1))
q = q * scaling
logits = torch.matmul(q, k.transpose(-1, -2))
logits += bias
logits += (1e20 * (mask - 1))[..., :, None, None, :]
weights = torch.nn.functional.softmax(logits.float(), -1).to(dtype=q.dtype)
weighted_avg = torch.matmul(weights, v)
weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')
return weighted_avg
class FusedAttenionCoreFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, mask=None, bias=None):
if _triton_available:
o = attention_core_triton_kernel_wrapper(q, k, v, mask, bias)
else:
o = _torch_attention_core(q, k, v, mask, bias)
# ctx.save_for_backward(q, k, v, o, L, m, mask, bias)
# ctx.BLOCK = BLOCK
# ctx.grid = grid
# ctx.sm_scale = sm_scale
# ctx.BLOCK_DMODEL = Lk
return o
fused_attention_core = FusedAttenionCoreFunc.apply
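# NOTE: usage sketch added for illustration; not part of the original module. Shapes follow
# the PyTorch fallback above: q, k, v are [b1, b2, heads, n, d]; mask is [b1, b2, n] and is
# broadcast over heads and query positions via (1e20 * (mask - 1))[..., :, None, None, :];
# bias must broadcast to the [b1, b2, heads, n, n] logits. The output is [b1, b2, n, heads * d].
#
#   b1, b2, h, n, d = 1, 4, 8, 64, 16
#   q = torch.rand(b1, b2, h, n, d)
#   k = torch.rand(b1, b2, h, n, d)
#   v = torch.rand(b1, b2, h, n, d)
#   mask = torch.ones(b1, b2, n)
#   bias = torch.zeros(b1, b2, h, n, n)
#   out = fused_attention_core(q, k, v, mask, bias)   # [1, 4, 64, 128]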
// modified from https://github.com/NVIDIA/apex/blob/master/csrc/compat.h
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cassert>
#include <vector>
#include "compat.h"
void compute_n1_n2(at::Tensor input, at::IntArrayRef normalized_shape, int& n1, int& n2) {
int idiff = input.ndimension() - normalized_shape.size();
n2 = 1;
for (int i = 0; i < (int)normalized_shape.size(); ++i) {
assert(input.sizes()[i + idiff] == normalized_shape[i]);
n2 *= normalized_shape[i];
}
n1 = 1;
for (int i = 0; i < idiff; ++i) {
n1 *= input.sizes()[i];
}
}
void check_args(at::IntArrayRef normalized_shape, at::Tensor gamma, at::Tensor beta) {
TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
}
void check_args(at::Tensor input, at::IntArrayRef normalized_shape, int& n1, int& n2) {
int64_t normalized_ndim = normalized_shape.size();
if (normalized_ndim < 1) {
std::stringstream ss;
ss << "Expected normalized_shape to be at least 1-dimensional, i.e., "
<< "containing at least one element, but got normalized_shape=" << normalized_shape;
throw std::runtime_error(ss.str());
}
auto input_shape = input.sizes();
auto input_ndim = input.dim();
if (input_ndim < normalized_ndim ||
!input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) {
std::stringstream ss;
ss << "Given normalized_shape=" << normalized_shape << ", expected input with shape [*";
for (auto size : normalized_shape) {
ss << ", " << size;
}
ss << "], but got input of size" << input_shape;
throw std::runtime_error(ss.str());
}
compute_n1_n2(input, normalized_shape, n1, n2);
}
void check_args(at::Tensor input, at::IntArrayRef normalized_shape, at::Tensor gamma,
at::Tensor beta, int& n1, int& n2) {
check_args(input, normalized_shape, n1, n2);
check_args(normalized_shape, gamma, beta);
}
void cuda_layer_norm(at::Tensor* output, at::Tensor* mean, at::Tensor* invvar, at::Tensor* input,
int n1, int n2, at::IntArrayRef normalized_shape, at::Tensor* gamma,
at::Tensor* beta, double epsilon);
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
std::vector<at::Tensor> layer_norm_affine(at::Tensor input, at::IntArrayRef normalized_shape,
at::Tensor gamma, at::Tensor beta, double epsilon) {
CHECK_INPUT(input);
CHECK_INPUT(gamma);
CHECK_INPUT(beta);
int n1, n2;
check_args(input, normalized_shape, gamma, beta, n1, n2);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
at::Tensor output = at::empty_like(input, gamma.options().dtype(gamma.scalar_type()));
at::Tensor mean = at::empty({n1}, input.options().dtype(at::ScalarType::Float));
at::Tensor invvar = at::empty_like(mean);
cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2, normalized_shape, &gamma, &beta,
epsilon);
return {output, mean, invvar};
}
void cuda_layer_norm_gradient(at::Tensor* dout, at::Tensor* mean, at::Tensor* invvar,
at::Tensor* input, int n1, int n2, at::IntArrayRef normalized_shape,
at::Tensor* gamma, at::Tensor* beta, double epsilon,
at::Tensor* grad_input, at::Tensor* grad_gamma,
at::Tensor* grad_beta);
std::vector<at::Tensor> layer_norm_gradient_affine(at::Tensor dout, at::Tensor mean,
at::Tensor invvar, at::Tensor input,
at::IntArrayRef normalized_shape,
at::Tensor gamma, at::Tensor beta,
double epsilon) {
CHECK_INPUT(dout);
CHECK_INPUT(mean);
CHECK_INPUT(invvar);
CHECK_INPUT(input);
CHECK_INPUT(gamma);
CHECK_INPUT(beta);
int n1, n2;
check_args(input, normalized_shape, gamma, beta, n1, n2);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
at::Tensor grad_input = at::empty_like(input);
at::Tensor grad_gamma = at::empty_like(gamma);
at::Tensor grad_beta = at::empty_like(beta);
cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2, normalized_shape, &gamma, &beta,
epsilon, &grad_input, &grad_gamma, &grad_beta);
return {grad_input, grad_gamma, grad_beta};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)");
m.def("backward_affine", &layer_norm_gradient_affine, "LayerNorm backward (CUDA)");
}
// part of code modified from https://github.com/NVIDIA/apex
#include <cooperative_groups.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <THC/THCDeviceUtils.cuh>
#include "ATen/ATen.h"
#include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDAContext.h"
#include "compat.h"
#include "type_shim.h"
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
inline __device__ void WelfordOnline(float val, float* mean, float* m2, float* count) {
*count += 1;
float delta1 = val - *mean;
*mean += delta1 / (*count);
float delta2 = val - *mean;
*m2 += delta1 * delta2;
}
inline __device__ void WelfordOnline(float b_mean, float b_m2, float b_count, float* mean,
float* m2, float* count) {
if (b_count == 0) {
return;
}
float new_count = *count + b_count;
float nb_n = b_count / new_count;
float delta = b_mean - *mean;
*mean += delta * nb_n;
*m2 += b_m2 + delta * delta * (*count) * nb_n;
*count = new_count;
}
__inline__ __device__ void WelfordWarpAllReduce(float thread_mean, float thread_m2,
float thread_count, float* mean, float* m2,
float* count) {
*mean = thread_mean;
*m2 = thread_m2;
*count = thread_count;
for (int mask = 1; mask < 32; mask *= 2) {
float b_mean = __shfl_down_sync(0xffffffff, *mean, mask);
float b_m2 = __shfl_down_sync(0xffffffff, *m2, mask);
float b_count = __shfl_down_sync(0xffffffff, *count, mask);
WelfordOnline(b_mean, b_m2, b_count, mean, m2, count);
}
*mean = __shfl_sync(0xffffffff, *mean, 0, 32);
*m2 = __shfl_sync(0xffffffff, *m2, 0, 32);
*count = __shfl_sync(0xffffffff, *count, 0, 32);
}
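// LayerNorm forward: one warp normalizes one row. With blockDim.x = 128, each block covers
// rows [blockIdx.x * 4, blockIdx.x * 4 + 3]; each lane holds (cols + 31) / 32 contiguous
// elements in registers (which assumes cols <= 1024 so the 32-float buffer suffices), and the
// row mean/variance come from the Welford warp reduction above.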
template <typename T>
__global__ void fastfold_layernorm(T* input, T* output, T* gamma, T* beta, float* mean,
float* invvar, int rows, int cols, double epsilon) {
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
int row_offset = blockIdx.x * 4 + threadidx_x;
int cols_per_thread = (cols + 31) / 32;
int cols_this_thread = cols_per_thread;
int last_y = (cols / cols_per_thread);
if (threadidx_y == last_y) {
cols_this_thread = cols - cols_per_thread * last_y;
} else if (threadidx_y > last_y) {
cols_this_thread = 0;
}
int lane_id = threadidx_y;
if (row_offset < rows) {
float buf[32];
float thread_mean = 0.f;
float thread_m2 = 0.f;
float thread_count = 0.f;
float warp_mean;
float warp_m2;
float warp_count;
T* row_input = input + row_offset * cols;
T* row_output = output + row_offset * cols;
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
buf[i] = static_cast<float>(row_input[lane_id * cols_per_thread + i]);
}
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
WelfordOnline(buf[i], &thread_mean, &thread_m2, &thread_count);
}
WelfordWarpAllReduce(thread_mean, thread_m2, thread_count, &warp_mean, &warp_m2,
&warp_count);
float row_mean = warp_mean;
float row_variance = max(warp_m2 / warp_count, 0.f);
float row_inv_var = rsqrt(row_variance + epsilon);
if (lane_id == 0) {
mean[row_offset] = row_mean;
invvar[row_offset] = row_inv_var;
}
#pragma unroll
for (int i = 0; i < cols_this_thread; ++i) {
buf[i] = (buf[i] - row_mean) * row_inv_var;
}
#pragma unroll
for (int i = 0; i < cols_this_thread; ++i) {
row_output[lane_id * cols_per_thread + i] =
static_cast<T>(buf[i]) * gamma[lane_id * cols_per_thread + i] +
beta[lane_id * cols_per_thread + i];
}
}
}
void cuda_layer_norm(at::Tensor* output, at::Tensor* mean, at::Tensor* invvar, at::Tensor* input,
int rows, int cols, at::IntArrayRef normalized_shape, at::Tensor* gamma,
at::Tensor* beta, double epsilon) {
int grid = (rows + 3) / 4;
dim3 block(128);
if (output->dtype() == torch::kFloat32) {
fastfold_layernorm<float><<<grid, block>>>(
(float*)input->data_ptr(), (float*)output->data_ptr(), (float*)gamma->data_ptr(),
(float*)beta->data_ptr(), (float*)mean->data_ptr(), (float*)invvar->data_ptr(), rows,
cols, epsilon);
} else if (output->dtype() == torch::kFloat16) {
fastfold_layernorm<at::Half><<<grid, block>>>(
(at::Half*)input->data_ptr(), (at::Half*)output->data_ptr(),
(at::Half*)gamma->data_ptr(), (at::Half*)beta->data_ptr(), (float*)mean->data_ptr(),
(float*)invvar->data_ptr(), rows, cols, epsilon);
} else if (output->dtype() == torch::kBFloat16) {
fastfold_layernorm<at::BFloat16><<<grid, block>>>(
(at::BFloat16*)input->data_ptr(), (at::BFloat16*)output->data_ptr(),
(at::BFloat16*)gamma->data_ptr(), (at::BFloat16*)beta->data_ptr(),
(float*)mean->data_ptr(), (float*)invvar->data_ptr(), rows, cols, epsilon);
}
}
template <typename T>
struct SharedMemory;
template <>
struct SharedMemory<float> {
__device__ float* getPointer() {
extern __shared__ float s_float[];
return s_float;
}
};
template <typename T, typename U, typename V>
__device__ void cuLoadWriteStridedInputs(const int i1_block, const int thr_load_row_off,
const int thr_load_col_off, const int i2_off,
const int row_stride, U* warp_buf1, U* warp_buf2,
const T* input, const V* dout, const int i1_end,
const int n2, const U* __restrict__ mean,
const U* __restrict__ invvar) {
int i1 = i1_block + thr_load_row_off;
if (i1 < i1_end) {
U curr_mean = mean[i1];
U curr_invvar = invvar[i1];
for (int k = 0; k < blockDim.y; ++k) {
int i2 = i2_off + k;
int load_idx = i1 * n2 + i2;
int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
if (i2 < n2) {
U curr_input = static_cast<U>(input[load_idx]);
U curr_dout = static_cast<U>(dout[load_idx]);
warp_buf1[write_idx] = curr_dout;
warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar;
} else {
warp_buf1[write_idx] = U(0);
warp_buf2[write_idx] = U(0);
}
}
} else {
for (int k = 0; k < blockDim.y; ++k) {
int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
warp_buf1[write_idx] = U(0);
warp_buf2[write_idx] = U(0);
}
}
}
template <typename T, typename U, typename V>
__device__ void cuLoadAddStridedInputs(const int i1_block, const int thr_load_row_off,
const int thr_load_col_off, const int i2_off,
const int row_stride, U* warp_buf1, U* warp_buf2,
const T* input, const V* dout, const int i1_end,
const int n2, const U* __restrict__ mean,
const U* __restrict__ invvar) {
int i1 = i1_block + thr_load_row_off;
if (i1 < i1_end) {
U curr_mean = mean[i1];
U curr_invvar = invvar[i1];
for (int k = 0; k < blockDim.y; ++k) {
int i2 = i2_off + k;
int load_idx = i1 * n2 + i2;
int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
if (i2 < n2) {
U curr_input = static_cast<U>(input[load_idx]);
U curr_dout = static_cast<U>(dout[load_idx]);
warp_buf1[write_idx] += curr_dout;
warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar;
}
}
}
}
template <typename T, typename U, typename V>
__global__ void cuComputePartGradGammaBeta(const V* __restrict__ dout, const T* __restrict__ input,
const int n1, const int n2, const U* __restrict__ mean,
const U* __restrict__ invvar, U epsilon,
U* part_grad_gamma, U* part_grad_beta) {
const int numsegs_n1 = (n1 + blockDim.y * blockDim.y - 1) / (blockDim.y * blockDim.y);
const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
const int i1_beg = blockIdx.y * segs_per_block * blockDim.y * blockDim.y;
const int i1_beg_plus_one = (blockIdx.y + 1) * segs_per_block * blockDim.y * blockDim.y;
const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1;
const int row_stride = blockDim.x + 1;
const int thr_load_col_off = (threadIdx.x * blockDim.y) & (blockDim.x - 1);
const int thr_load_row_off = (threadIdx.x * blockDim.y) / blockDim.x + threadIdx.y * blockDim.y;
const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;
SharedMemory<U> shared;
U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * blockDim.y +
// (blockDim.y - 1)*(blockDim.x/blockDim.y) elements
U* warp_buf1 = (U*)buf;
U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;
// compute partial sums from strided inputs
// do this to increase number of loads in flight
cuLoadWriteStridedInputs(i1_beg, thr_load_row_off, thr_load_col_off, i2_off, row_stride,
warp_buf1, warp_buf2, input, dout, i1_end, n2, mean, invvar);
for (int i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end;
i1_block += blockDim.y * blockDim.y) {
cuLoadAddStridedInputs(i1_block, thr_load_row_off, thr_load_col_off, i2_off, row_stride,
warp_buf1, warp_buf2, input, dout, i1_end, n2, mean, invvar);
}
__syncthreads();
// inter-warp reductions
// sum within each warp
U acc1 = U(0);
U acc2 = U(0);
for (int k = 0; k < blockDim.y; ++k) {
int row1 = threadIdx.y + k * blockDim.y;
int idx1 = row1 * row_stride + threadIdx.x;
acc1 += warp_buf1[idx1];
acc2 += warp_buf2[idx1];
}
warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1;
warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2;
__syncthreads();
// sum all warps
for (int offset = blockDim.y / 2; offset > 1; offset /= 2) {
if (threadIdx.y < offset) {
int row1 = threadIdx.y;
int row2 = threadIdx.y + offset;
int idx1 = row1 * row_stride + threadIdx.x;
int idx2 = row2 * row_stride + threadIdx.x;
warp_buf1[idx1] += warp_buf1[idx2];
warp_buf2[idx1] += warp_buf2[idx2];
}
__syncthreads();
}
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
if (threadIdx.y == 0 && i2 < n2) {
int row1 = threadIdx.y;
int row2 = threadIdx.y + 1;
int idx1 = row1 * row_stride + threadIdx.x;
int idx2 = row2 * row_stride + threadIdx.x;
part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2];
part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2];
}
}
template <typename U, typename V>
__global__ void cuComputeGradGammaBeta(const U* part_grad_gamma, const U* part_grad_beta,
const int part_size, const int n1, const int n2,
V* grad_gamma, V* grad_beta) {
// sum partial gradients for gamma and beta
SharedMemory<U> shared;
U* buf = shared.getPointer();
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
if (i2 < n2) {
// each warp does sequential reductions until reduced part_size is num_warps
int num_warp_reductions = part_size / blockDim.y;
U sum_gamma = U(0);
U sum_beta = U(0);
const U* part_grad_gamma_ptr =
part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) {
sum_gamma += part_grad_gamma_ptr[warp_offset * n2];
sum_beta += part_grad_beta_ptr[warp_offset * n2];
}
// inter-warp reductions
const int nbsize3 = blockDim.x * blockDim.y / 2;
for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) {
// top half write to shared memory
if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
buf[write_idx] = sum_gamma;
buf[write_idx + nbsize3] = sum_beta;
}
__syncthreads();
// bottom half sums
if (threadIdx.y < offset) {
const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
sum_gamma += buf[read_idx];
sum_beta += buf[read_idx + nbsize3];
}
__syncthreads();
}
// write out fully summed gradients
if (threadIdx.y == 0) {
grad_gamma[i2] = sum_gamma;
grad_beta[i2] = sum_beta;
}
}
}
template <typename T, typename U, typename V>
__global__ void cuComputeGradInput(const V* __restrict__ dout, const T* __restrict__ input,
const int n1, const int n2, const U* __restrict__ mean,
const U* __restrict__ invvar, U epsilon, const V* gamma,
T* grad_input) {
for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
U sum_loss1 = U(0);
U sum_loss2 = U(0);
const U c_mean = mean[i1];
const U c_invvar = invvar[i1];
const T* k_input = input + i1 * n2;
const V* k_dout = dout + i1 * n2;
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
if (gamma != NULL) {
int l = 4 * thrx;
for (; l + 3 < n2; l += 4 * numx) {
for (int k = 0; k < 4; ++k) {
const U c_h = static_cast<U>(k_input[l + k]);
const U c_loss = static_cast<U>(k_dout[l + k]);
sum_loss1 += c_loss * gamma[l + k];
sum_loss2 += c_loss * gamma[l + k] * (c_h - c_mean) * c_invvar;
}
}
for (; l < n2; ++l) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
sum_loss1 += c_loss * gamma[l];
sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
}
} else {
int l = 4 * thrx;
for (; l + 3 < n2; l += 4 * numx) {
for (int k = 0; k < 4; ++k) {
const U c_h = static_cast<U>(k_input[l + k]);
const U c_loss = static_cast<U>(k_dout[l + k]);
sum_loss1 += c_loss;
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
}
}
for (; l < n2; ++l) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
sum_loss1 += c_loss;
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
}
}
// intra-warp reductions
for (int mask = blockDim.x / 2; mask > 0; mask /= 2) {
sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);
sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);
}
// inter-warp reductions
if (blockDim.y > 1) {
SharedMemory<U> shared;
U* buf = shared.getPointer();
for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
// upper half of warps write to shared
if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
buf[2 * wrt_i] = sum_loss1;
buf[2 * wrt_i + 1] = sum_loss2;
}
__syncthreads();
// lower half merges
if (threadIdx.y < offset) {
const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
sum_loss1 += buf[2 * read_i];
sum_loss2 += buf[2 * read_i + 1];
}
__syncthreads();
}
if (threadIdx.y == 0) {
buf[2 * threadIdx.x] = sum_loss1;
buf[2 * threadIdx.x + 1] = sum_loss2;
}
__syncthreads();
if (threadIdx.y != 0) {
sum_loss1 = buf[2 * threadIdx.x];
sum_loss2 = buf[2 * threadIdx.x + 1];
}
}
// all threads now have the two sums over l
U fH = (U)n2;
U term1 = (U(1) / fH) * c_invvar;
T* k_grad_input = grad_input + i1 * n2;
if (gamma != NULL) {
for (int l = thrx; l < n2; l += numx) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
U f_grad_input = fH * c_loss * gamma[l];
f_grad_input -= sum_loss1;
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
f_grad_input *= term1;
k_grad_input[l] = static_cast<T>(f_grad_input);
}
} else {
for (int l = thrx; l < n2; l += numx) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
U f_grad_input = fH * c_loss;
f_grad_input -= sum_loss1;
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
f_grad_input *= term1;
k_grad_input[l] = static_cast<T>(f_grad_input);
}
}
}
}
template <typename T, typename U, typename V>
void HostLayerNormGradient(const V* dout, const U* mean, const U* invvar, at::Tensor* input, int n1,
int n2, const V* gamma, const V* beta, double epsilon, T* grad_input,
V* grad_gamma, V* grad_beta) {
auto stream = at::cuda::getCurrentCUDAStream().stream();
if (gamma != NULL && beta != NULL) {
// compute grad_gamma(j) and grad_beta(j)
const int part_size = 16;
const dim3 threads2(32, 4, 1);
const dim3 blocks2((n2 + threads2.x - 1) / threads2.x, part_size, 1);
const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
const int nshared2_b = threads2.x * threads2.y * sizeof(U);
const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
at::Tensor part_grad_gamma =
at::empty({part_size, n2}, input->options().dtype(at::ScalarType::Float));
at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
dout, input->DATA_PTR<T>(), n1, n2, mean, invvar, U(epsilon),
part_grad_gamma.DATA_PTR<U>(), part_grad_beta.DATA_PTR<U>());
const dim3 threads3(32, 8, 1);
const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1);
const int nshared3 = threads3.x * threads3.y * sizeof(U);
cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(
part_grad_gamma.DATA_PTR<U>(), part_grad_beta.DATA_PTR<U>(), part_size, n1, n2,
grad_gamma, grad_beta);
}
// compute grad_input
const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
const dim3 threads1(32, 4, 1);
int nshared = threads1.y > 1 ? threads1.y * threads1.x * sizeof(U) : 0;
cuComputeGradInput<<<blocks1, threads1, nshared, stream>>>(
dout, input->DATA_PTR<T>(), n1, n2, mean, invvar, U(epsilon), gamma, grad_input);
}
void cuda_layer_norm_gradient(at::Tensor* dout, at::Tensor* mean, at::Tensor* invvar,
at::Tensor* input, int n1, int n2, at::IntArrayRef normalized_shape,
at::Tensor* gamma, at::Tensor* beta, double epsilon,
at::Tensor* grad_input, at::Tensor* grad_gamma,
at::Tensor* grad_beta) {
using namespace at;
DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
input->scalar_type(), gamma->scalar_type(), "cuda_layer_norm_gradient_kernel",
HostLayerNormGradient(dout->DATA_PTR<scalar_t_out>(), mean->DATA_PTR<float>(),
invvar->DATA_PTR<float>(), input, n1, n2,
// TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
// if gamma Tensor is NULL on input.
gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
gamma != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL, epsilon,
grad_input->DATA_PTR<scalar_t_in>(),
gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL,
gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL);)
}
#include <torch/extension.h>
at::Tensor softmax(at::Tensor input, long long rows, long long cols);
at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor output, long long rows, long long cols);
at::Tensor fused_mask_softmax_forward(at::Tensor input, at::Tensor mask, long long rows,
long long cols);
at::Tensor fused_mask_softmax_backward(at::Tensor d_output, at::Tensor input, at::Tensor mask,
long long rows, long long cols);
at::Tensor fused_mask_bias_softmax_forward(at::Tensor input, at::Tensor mask, at::Tensor bias,
long long rows, long long cols);
at::Tensor fused_mask_bias_softmax_backward(at::Tensor d_output, at::Tensor input, at::Tensor mask,
at::Tensor bias, long long rows, long long cols);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &softmax, "Softmax forward (CUDA)");
m.def("backward", &softmax_gradient, "Softmax backward (CUDA)");
m.def("fused_mask_softmax_forward", &fused_mask_softmax_forward, "Softmax forward (CUDA)");
m.def("fused_mask_softmax_backward", &fused_mask_softmax_backward, "Softmax forward (CUDA)");
m.def("fused_mask_bias_softmax_forward", &fused_mask_bias_softmax_forward,
"Softmax forward (CUDA)");
m.def("fused_mask_bias_softmax_backward", &fused_mask_bias_softmax_backward,
"Softmax forward (CUDA)");
}
#include <c10/cuda/CUDAGuard.h>
#include <math_constants.h>
#include <torch/extension.h>
#include <cub/cub.cuh>
#include <iostream>
#include "ATen/ATen.h"
#include "ATen/cuda/CUDAContext.h"
#include "compat.h"
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
__inline__ __device__ float WarpAllReduceMax(float val) {
for (int mask = 1; mask < 32; mask *= 2) {
val = max(val, __shfl_xor_sync(0xffffffff, val, mask));
}
return val;
}
__inline__ __device__ float WarpAllReduceSum(float val) {
for (int mask = 1; mask < 32; mask *= 2) {
val += __shfl_xor_sync(0xffffffff, val, mask);
}
return val;
}
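// Choose a 1-D grid size for grid-stride kernels: enough blocks for `waves` waves of resident
// threads across all SMs (sm_count * max_threads_per_sm / block_size * waves), capped at
// max_blocks and never below 1.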
inline cudaError_t GetNumBlocks(int64_t block_size, int64_t max_blocks, int64_t waves,
int *num_blocks) {
int dev;
{
cudaError_t err = cudaGetDevice(&dev);
if (err != cudaSuccess) {
return err;
}
}
int sm_count;
{
cudaError_t err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
if (err != cudaSuccess) {
return err;
}
}
int tpm;
{
cudaError_t err = cudaDeviceGetAttribute(&tpm, cudaDevAttrMaxThreadsPerMultiProcessor, dev);
if (err != cudaSuccess) {
return err;
}
}
*num_blocks =
std::max<int>(1, std::min<int64_t>(max_blocks, sm_count * tpm / block_size * waves));
return cudaSuccess;
}
template <typename T>
struct SumOp {
__device__ __forceinline__ T operator()(const T &a, const T &b) const { return a + b; }
};
template <typename T>
struct MaxOp {
__device__ __forceinline__ T operator()(const T &a, const T &b) const { return max(a, b); }
};
template <template <typename> class ReductionOp, typename T, int block_size>
__inline__ __device__ T BlockAllReduce(T val) {
typedef cub::BlockReduce<T, block_size> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T result_broadcast;
T result = BlockReduce(temp_storage).Reduce(val, ReductionOp<T>());
if (threadIdx.x == 0) {
result_broadcast = result;
}
__syncthreads();
return result_broadcast;
}
////////////////
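// Row-wise softmax with one warp per row (four warps per 128-thread block). Each lane keeps
// cols_per_thread elements in registers and the warp-shuffle reductions above produce the row
// max and row sum; the *_sm variants below handle rows wider than 32 * 32 columns in shared memory.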
template <typename T, int cols_per_thread>
__global__ void fastfold_softmax(T *input, T *output, long long rows, long long cols) {
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
float buf[cols_per_thread];
int lane_id = threadidx_y;
if (row_offset < rows) {
T *row_input = input + row_offset * cols;
T *row_output = output + row_offset * cols;
float thread_max = -1 * CUDART_INF_F;
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
if (lane_id * cols_per_thread + i < cols) {
buf[i] = static_cast<T>(row_input[lane_id * cols_per_thread + i]);
} else {
buf[i] = -1 * CUDART_INF_F;
}
}
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
thread_max = max(thread_max, buf[i]);
}
float warp_max = WarpAllReduceMax(thread_max);
float thread_sum = 0.f;
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
buf[i] = __expf(buf[i] - warp_max);
thread_sum += buf[i];
}
float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
if (lane_id * cols_per_thread + i < cols) {
row_output[lane_id * cols_per_thread + i] =
static_cast<T>(__fdividef(buf[i], warp_sum));
}
}
}
}
template <typename T, int block_size>
__global__ void fastfold_softmax_sm(T *input, T *output, long long rows, long long cols) {
extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[];
auto *buf = reinterpret_cast<float *>(shared_buf);
const int tid = threadIdx.x;
for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {
float thread_max = -1 * CUDART_INF_F;
for (int id = tid; id < cols; id += block_size) {
buf[id] = static_cast<T>(input[row * cols + id]);
thread_max = max(thread_max, buf[id]);
}
const float row_max = BlockAllReduce<MaxOp, float, block_size>(thread_max);
float thread_sum = 0;
for (int id = tid; id < cols; id += block_size) {
buf[id] = __expf(buf[id] - row_max);
thread_sum += buf[id];
}
const float row_sum = BlockAllReduce<SumOp, float, block_size>(thread_sum);
for (int id = tid; id < cols; id += block_size) {
output[row * cols + id] = static_cast<T>(buf[id] / row_sum);
}
}
}
at::Tensor softmax(at::Tensor input, long long rows, long long cols) {
CHECK_INPUT(input);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
at::Tensor output = at::empty_like(input);
int grid = (rows + 3) / 4;
dim3 block(128);
if (cols <= 32) {
if (input.dtype() == torch::kFloat32) {
fastfold_softmax<float, 1><<<grid, block>>>((float *)input.data_ptr(),
(float *)output.data_ptr(), rows, cols);
} else if (input.dtype() == torch::kFloat16) {
fastfold_softmax<at::Half, 1><<<grid, block>>>(
(at::Half *)input.data_ptr(), (at::Half *)output.data_ptr(), rows, cols);
} else if (input.dtype() == torch::kBFloat16) {
fastfold_softmax<at::BFloat16, 1><<<grid, block>>>(
(at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)output.data_ptr(), rows, cols);
}
}
#define COLS_CASE(col_per_thread) \
else if (cols <= col_per_thread * 32) { \
if (input.dtype() == torch::kFloat32) { \
fastfold_softmax<float, col_per_thread><<<grid, block>>>( \
(float *)input.data_ptr(), (float *)output.data_ptr(), rows, cols); \
} else if (input.dtype() == torch::kFloat16) { \
fastfold_softmax<at::Half, col_per_thread><<<grid, block>>>( \
(at::Half *)input.data_ptr(), (at::Half *)output.data_ptr(), rows, cols); \
} else if (input.dtype() == torch::kBFloat16) { \
fastfold_softmax<at::BFloat16, col_per_thread><<<grid, block>>>( \
(at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)output.data_ptr(), rows, cols); \
} \
}
COLS_CASE(2)
COLS_CASE(3)
COLS_CASE(4)
COLS_CASE(5)
COLS_CASE(6)
COLS_CASE(7)
COLS_CASE(8)
COLS_CASE(9)
COLS_CASE(10)
COLS_CASE(11)
COLS_CASE(12)
COLS_CASE(13)
COLS_CASE(14)
COLS_CASE(15)
COLS_CASE(16)
COLS_CASE(17)
COLS_CASE(18)
COLS_CASE(19)
COLS_CASE(20)
COLS_CASE(21)
COLS_CASE(22)
COLS_CASE(23)
COLS_CASE(24)
COLS_CASE(25)
COLS_CASE(26)
COLS_CASE(27)
COLS_CASE(28)
COLS_CASE(29)
COLS_CASE(30)
COLS_CASE(31)
COLS_CASE(32)
#undef COLS_CASE
else {
int grid_dim;
constexpr int waves = 32;
GetNumBlocks(128, rows, waves, &grid_dim);
dim3 block(128);
const size_t smem = cols * sizeof(float);
if (input.dtype() == torch::kFloat32) {
fastfold_softmax_sm<float, 128><<<grid_dim, block, smem>>>(
(float *)input.data_ptr(), (float *)output.data_ptr(), rows, cols);
} else if (input.dtype() == torch::kFloat16) {
fastfold_softmax_sm<at::Half, 128><<<grid_dim, block, smem>>>(
(at::Half *)input.data_ptr(), (at::Half *)output.data_ptr(), rows, cols);
} else if (input.dtype() == torch::kBFloat16) {
fastfold_softmax_sm<at::BFloat16, 128><<<grid_dim, block, smem>>>(
(at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)output.data_ptr(), rows, cols);
}
}
return output;
}
template <typename T>
__global__ void fastfold_softmax_grad(T *d_output, T *output, T *d_input, long long rows,
long long cols) {
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
int cols_per_thread = (cols + 31) / 32;
int cols_this_thread = cols_per_thread;
int last_y = (cols / cols_per_thread);
if (threadidx_y == last_y) {
cols_this_thread = cols - cols_per_thread * last_y;
} else if (threadidx_y > last_y) {
cols_this_thread = 0;
}
float y_buf[32];
float dy_buf[32];
int lane_id = threadidx_y;
if (row_offset < rows) {
T *row_d_output = d_output + row_offset * cols;
T *row_output = output + row_offset * cols;
T *row_d_input = d_input + row_offset * cols;
float thread_max = -1 * CUDART_INF_F;
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
if (lane_id * cols_per_thread + i < cols) {
y_buf[i] = static_cast<T>(row_output[lane_id * cols_per_thread + i]);
dy_buf[i] = static_cast<T>(row_d_output[lane_id * cols_per_thread + i]);
}
}
float thread_sum = 0.f;
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
if (lane_id * cols_per_thread + i < cols) {
thread_sum += y_buf[i] * dy_buf[i];
}
}
float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
for (int i = 0; i < cols_this_thread; ++i) {
if (lane_id * cols_per_thread + i < cols) {
row_d_input[lane_id * cols_per_thread + i] =
static_cast<T>((dy_buf[i] - warp_sum) * y_buf[i]);
}
}
}
}
at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor output, long long rows,
long long cols) {
CHECK_INPUT(output);
const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
at::Tensor grad_input = at::empty_like(output);
int grid = (rows + 3) / 4;
dim3 block(128);
if (output.dtype() == torch::kFloat32) {
fastfold_softmax_grad<float><<<grid, block>>>((float *)d_output.data_ptr(),
(float *)output.data_ptr(),
(float *)grad_input.data_ptr(), rows, cols);
} else if (output.dtype() == torch::kFloat16) {
fastfold_softmax_grad<at::Half>
<<<grid, block>>>((at::Half *)d_output.data_ptr(), (at::Half *)output.data_ptr(),
(at::Half *)grad_input.data_ptr(), rows, cols);
} else if (output.dtype() == torch::kBFloat16) {
fastfold_softmax_grad<at::BFloat16><<<grid, block>>>(
(at::BFloat16 *)d_output.data_ptr(), (at::BFloat16 *)output.data_ptr(),
(at::BFloat16 *)grad_input.data_ptr(), rows, cols);
}
return grad_input;
}
////////////////
template <typename T, int cols_per_thread>
__global__ void fastfold_softmax_mask(T *input, T *mask, T *output, long long rows, long long cols,
int head) {
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
    long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
    if (row_offset >= rows) {
        return;  // guard the tail block when rows is not a multiple of 4 (row_offset is warp-uniform)
    }
    float buf[cols_per_thread];
    int lane_id = threadidx_y;
T *row_input = input + row_offset * cols;
T *row_output = output + row_offset * cols;
T *mask_ptr = mask + ((row_offset / (head * cols)) * cols);
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
if (lane_id * cols_per_thread + i < cols) {
if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
buf[i] = -1 * 1e9;
} else {
buf[i] = static_cast<T>(row_input[lane_id * cols_per_thread + i]);
}
} else {
buf[i] = -1 * CUDART_INF_F;
}
}
float thread_max = -1 * CUDART_INF_F;
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
thread_max = max(thread_max, buf[i]);
}
float warp_max = WarpAllReduceMax(thread_max);
float thread_sum = 0.f;
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
buf[i] = __expf(buf[i] - warp_max);
thread_sum += buf[i];
}
float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
if (lane_id * cols_per_thread + i < cols) {
row_output[lane_id * cols_per_thread + i] = static_cast<T>(__fdividef(buf[i], warp_sum));
}
}
}
template <typename T, int block_size>
__global__ void fastfold_softmax_mask_sm(T *input, T *mask, T *output, long long rows,
long long cols, int head) {
extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[];
auto *buf = reinterpret_cast<float *>(shared_buf);
const int tid = threadIdx.x;
for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {
T *mask_ptr = mask + ((row / (head * cols)) * cols);
float thread_max = -1 * CUDART_INF_F;
for (int id = tid; id < cols; id += block_size) {
if (mask_ptr[id] == 0) {
buf[id] = -1 * 1e9;
} else {
buf[id] = input[row * cols + id];
}
thread_max = max(thread_max, buf[id]);
}
const float row_max = BlockAllReduce<MaxOp, float, block_size>(thread_max);
float thread_sum = 0;
for (int id = tid; id < cols; id += block_size) {
buf[id] = __expf(buf[id] - row_max);
thread_sum += buf[id];
}
const float row_sum = BlockAllReduce<SumOp, float, block_size>(thread_sum);
for (int id = tid; id < cols; id += block_size) {
output[row * cols + id] = buf[id] / row_sum;
}
}
}
at::Tensor fused_mask_softmax_forward(at::Tensor input, at::Tensor mask, long long rows,
long long cols) {
CHECK_INPUT(input);
CHECK_INPUT(mask);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
int head = input.sizes()[2];
// at::Tensor output = at::empty_like(input);
int grid = (rows + 3) / 4;
dim3 block(128);
if (cols <= 32) {
if (input.dtype() == torch::kFloat32) {
fastfold_softmax_mask<float, 1>
<<<grid, block>>>((float *)input.data_ptr(), (float *)mask.data_ptr(),
(float *)input.data_ptr(), rows, cols, head);
} else if (input.dtype() == torch::kFloat16) {
fastfold_softmax_mask<at::Half, 1>
<<<grid, block>>>((at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(),
(at::Half *)input.data_ptr(), rows, cols, head);
} else if (input.dtype() == torch::kBFloat16) {
fastfold_softmax_mask<at::BFloat16, 1>
<<<grid, block>>>((at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
(at::BFloat16 *)input.data_ptr(), rows, cols, head);
}
}
#define COLS_CASE(col_per_thread) \
else if (cols <= col_per_thread * 32) { \
if (input.dtype() == torch::kFloat32) { \
fastfold_softmax_mask<float, col_per_thread> \
<<<grid, block>>>((float *)input.data_ptr(), (float *)mask.data_ptr(), \
(float *)input.data_ptr(), rows, cols, head); \
} else if (input.dtype() == torch::kFloat16) { \
fastfold_softmax_mask<at::Half, col_per_thread> \
<<<grid, block>>>((at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(), \
(at::Half *)input.data_ptr(), rows, cols, head); \
} else if (input.dtype() == torch::kBFloat16) { \
fastfold_softmax_mask<at::BFloat16, col_per_thread><<<grid, block>>>( \
(at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(), \
(at::BFloat16 *)input.data_ptr(), rows, cols, head); \
} \
}
COLS_CASE(2)
COLS_CASE(3)
COLS_CASE(4)
COLS_CASE(5)
COLS_CASE(6)
COLS_CASE(7)
COLS_CASE(8)
COLS_CASE(9)
COLS_CASE(10)
COLS_CASE(11)
COLS_CASE(12)
COLS_CASE(13)
COLS_CASE(14)
COLS_CASE(15)
COLS_CASE(16)
COLS_CASE(17)
COLS_CASE(18)
COLS_CASE(19)
COLS_CASE(20)
COLS_CASE(21)
COLS_CASE(22)
COLS_CASE(23)
COLS_CASE(24)
COLS_CASE(25)
COLS_CASE(26)
COLS_CASE(27)
COLS_CASE(28)
COLS_CASE(29)
COLS_CASE(30)
COLS_CASE(31)
COLS_CASE(32)
#undef COLS_CASE
else {
int grid_dim;
constexpr int waves = 32;
GetNumBlocks(128, rows, waves, &grid_dim);
dim3 block(128);
const size_t smem = cols * sizeof(float);
if (input.dtype() == torch::kFloat32) {
fastfold_softmax_mask_sm<float, 128>
<<<grid, block, smem>>>((float *)input.data_ptr(), (float *)mask.data_ptr(),
(float *)input.data_ptr(), rows, cols, head);
} else if (input.dtype() == torch::kFloat16) {
fastfold_softmax_mask_sm<at::Half, 128>
<<<grid, block, smem>>>((at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(),
(at::Half *)input.data_ptr(), rows, cols, head);
} else if (input.dtype() == torch::kBFloat16) {
fastfold_softmax_mask_sm<at::BFloat16, 128><<<grid, block, smem>>>(
(at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
(at::BFloat16 *)input.data_ptr(), rows, cols, head);
}
}
return input;
}
template <typename T>
__global__ void fastfold_softmax_mask_grad(T *d_output, T *output, T *d_input, T *mask,
long long rows, long long cols, int head) {
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
int cols_per_thread = (cols + 31) / 32;
int cols_this_thread = cols_per_thread;
int last_y = (cols / cols_per_thread);
if (threadidx_y == last_y) {
cols_this_thread = cols - cols_per_thread * last_y;
} else if (threadidx_y > last_y) {
cols_this_thread = 0;
}
float y_buf[32];
float dy_buf[32];
int lane_id = threadidx_y;
if (row_offset < rows) {
T *row_d_output = d_output + row_offset * cols;
T *row_output = output + row_offset * cols;
T *row_d_input = d_input + row_offset * cols;
T *mask_ptr = mask + ((row_offset / (head * cols)) * cols);
float thread_max = -1 * CUDART_INF_F;
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
if (lane_id * cols_per_thread + i < cols) {
y_buf[i] = static_cast<T>(row_output[lane_id * cols_per_thread + i]);
dy_buf[i] = static_cast<T>(row_d_output[lane_id * cols_per_thread + i]);
}
}
float thread_sum = 0.f;
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
if (lane_id * cols_per_thread + i < cols) {
thread_sum += y_buf[i] * dy_buf[i];
}
}
float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
for (int i = 0; i < cols_this_thread; ++i) {
if (lane_id * cols_per_thread + i < cols) {
if (mask_ptr[lane_id * cols_per_thread + i] != 0) {
row_d_input[lane_id * cols_per_thread + i] =
static_cast<T>((dy_buf[i] - warp_sum) * y_buf[i]);
} else {
row_d_input[lane_id * cols_per_thread + i] = 0;
}
}
}
}
}
at::Tensor fused_mask_softmax_backward(at::Tensor d_output, at::Tensor output, at::Tensor mask,
long long rows, long long cols) {
CHECK_INPUT(output);
CHECK_INPUT(mask);
const at::cuda::OptionalCUDAGuard device_guard(device_of(mask));
int head = output.sizes()[2];
at::Tensor grad_input = at::empty_like(output);
int grid = (rows + 3) / 4;
dim3 block(128);
if (output.dtype() == torch::kFloat32) {
fastfold_softmax_mask_grad<float><<<grid, block>>>(
(float *)d_output.data_ptr(), (float *)output.data_ptr(),
(float *)grad_input.data_ptr(), (float *)mask.data_ptr(), rows, cols, head);
} else if (output.dtype() == torch::kFloat16) {
fastfold_softmax_mask_grad<at::Half><<<grid, block>>>(
(at::Half *)d_output.data_ptr(), (at::Half *)output.data_ptr(),
(at::Half *)grad_input.data_ptr(), (at::Half *)mask.data_ptr(), rows, cols, head);
} else if (output.dtype() == torch::kBFloat16) {
fastfold_softmax_mask_grad<at::BFloat16><<<grid, block>>>(
(at::BFloat16 *)d_output.data_ptr(), (at::BFloat16 *)output.data_ptr(),
(at::BFloat16 *)grad_input.data_ptr(), (at::BFloat16 *)mask.data_ptr(), rows, cols,
head);
}
return grad_input;
}
////////////////
template <typename T, int cols_per_thread>
__global__ void fastfold_softmax_mask_bias(T *input, T *mask, T *bias, T *output, long long rows,
long long cols, int head) {
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
    long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
    if (row_offset >= rows) {
        return;  // guard the tail block when rows is not a multiple of 4 (row_offset is warp-uniform)
    }
    float buf[cols_per_thread];
    int lane_id = threadidx_y;
T *row_input = input + row_offset * cols;
T *row_output = output + row_offset * cols;
T *mask_ptr = mask + ((row_offset / (head * cols)) * cols);
T *bias_ptr = bias + ((row_offset % (head * cols)) * cols);
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
if (lane_id * cols_per_thread + i < cols) {
if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
buf[i] = -1 * 10e9;
} else {
buf[i] = static_cast<T>(row_input[lane_id * cols_per_thread + i]) +
static_cast<T>(bias_ptr[lane_id * cols_per_thread + i]);
}
} else {
buf[i] = -1 * CUDART_INF_F;
}
}
float thread_max = -1 * CUDART_INF_F;
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
thread_max = max(thread_max, buf[i]);
}
float warp_max = WarpAllReduceMax(thread_max);
float thread_sum = 0.f;
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
buf[i] = __expf(buf[i] - warp_max);
thread_sum += buf[i];
}
float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
if (lane_id * cols_per_thread + i < cols) {
row_output[lane_id * cols_per_thread + i] = static_cast<T>(__fdividef(buf[i], warp_sum));
}
}
}
template <typename T, int block_size>
__global__ void fastfold_softmax_mask_bias_sm(T *input, T *mask, T *bias, T *output, long long rows,
long long cols, int head) {
extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[];
auto *buf = reinterpret_cast<float *>(shared_buf);
const int tid = threadIdx.x;
for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {
T *mask_ptr = mask + ((row / (head * cols)) * cols);
T *bias_ptr = bias + ((row % (head * cols)) * cols);
float thread_max = -1 * CUDART_INF_F;
for (int id = tid; id < cols; id += block_size) {
if (mask_ptr[id] == 0) {
buf[id] = -1 * 1e9;
} else {
buf[id] = input[row * cols + id] + bias_ptr[id];
}
thread_max = max(thread_max, buf[id]);
}
const float row_max = BlockAllReduce<MaxOp, float, block_size>(thread_max);
float thread_sum = 0;
for (int id = tid; id < cols; id += block_size) {
buf[id] = __expf(buf[id] - row_max);
thread_sum += buf[id];
}
const float row_sum = BlockAllReduce<SumOp, float, block_size>(thread_sum);
for (int id = tid; id < cols; id += block_size) {
output[row * cols + id] = buf[id] / row_sum;
}
}
}
at::Tensor fused_mask_bias_softmax_forward(at::Tensor input, at::Tensor mask, at::Tensor bias,
long long rows, long long cols) {
CHECK_INPUT(input);
CHECK_INPUT(mask);
CHECK_INPUT(bias);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
int head = input.sizes()[2];
// at::Tensor output = at::empty_like(input);
int grid = (rows + 3) / 4;
dim3 block(128);
if (cols <= 32) {
if (input.dtype() == torch::kFloat32) {
fastfold_softmax_mask_bias<float, 1><<<grid, block>>>(
(float *)input.data_ptr(), (float *)mask.data_ptr(), (float *)bias.data_ptr(),
(float *)input.data_ptr(), rows, cols, head);
} else if (input.dtype() == torch::kFloat16) {
fastfold_softmax_mask_bias<at::Half, 1><<<grid, block>>>(
(at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(),
(at::Half *)bias.data_ptr(), (at::Half *)input.data_ptr(), rows, cols, head);
} else if (input.dtype() == torch::kBFloat16) {
fastfold_softmax_mask_bias<at::BFloat16, 1>
<<<grid, block>>>((at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
(at::BFloat16 *)bias.data_ptr(), (at::BFloat16 *)input.data_ptr(),
rows, cols, head);
}
}
#define COLS_CASE(col_per_thread) \
else if (cols <= col_per_thread * 32) { \
if (input.dtype() == torch::kFloat32) { \
fastfold_softmax_mask_bias<float, col_per_thread><<<grid, block>>>( \
(float *)input.data_ptr(), (float *)mask.data_ptr(), (float *)bias.data_ptr(), \
(float *)input.data_ptr(), rows, cols, head); \
} else if (input.dtype() == torch::kFloat16) { \
fastfold_softmax_mask_bias<at::Half, col_per_thread><<<grid, block>>>( \
(at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(), \
(at::Half *)bias.data_ptr(), (at::Half *)input.data_ptr(), rows, cols, head); \
} else if (input.dtype() == torch::kBFloat16) { \
fastfold_softmax_mask_bias<at::BFloat16, col_per_thread><<<grid, block>>>( \
(at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(), \
(at::BFloat16 *)bias.data_ptr(), (at::BFloat16 *)input.data_ptr(), rows, cols, \
head); \
} \
}
COLS_CASE(2)
COLS_CASE(3)
COLS_CASE(4)
COLS_CASE(5)
COLS_CASE(6)
COLS_CASE(7)
COLS_CASE(8)
COLS_CASE(9)
COLS_CASE(10)
COLS_CASE(11)
COLS_CASE(12)
COLS_CASE(13)
COLS_CASE(14)
COLS_CASE(15)
COLS_CASE(16)
COLS_CASE(17)
COLS_CASE(18)
COLS_CASE(19)
COLS_CASE(20)
COLS_CASE(21)
COLS_CASE(22)
COLS_CASE(23)
COLS_CASE(24)
COLS_CASE(25)
COLS_CASE(26)
COLS_CASE(27)
COLS_CASE(28)
COLS_CASE(29)
COLS_CASE(30)
COLS_CASE(31)
COLS_CASE(32)
#undef COLS_CASE
else {
int grid_dim;
constexpr int waves = 32;
GetNumBlocks(128, rows, waves, &grid_dim);
dim3 block(128);
const size_t smem = cols * sizeof(float);
if (input.dtype() == torch::kFloat32) {
fastfold_softmax_mask_bias_sm<float, 128><<<grid, block, smem>>>(
(float *)input.data_ptr(), (float *)mask.data_ptr(), (float *)bias.data_ptr(),
(float *)input.data_ptr(), rows, cols, head);
} else if (input.dtype() == torch::kFloat16) {
fastfold_softmax_mask_bias_sm<at::Half, 128><<<grid, block, smem>>>(
(at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(),
(at::Half *)bias.data_ptr(), (at::Half *)input.data_ptr(), rows, cols, head);
} else if (input.dtype() == torch::kBFloat16) {
fastfold_softmax_mask_bias_sm<at::BFloat16, 128><<<grid, block, smem>>>(
(at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
(at::BFloat16 *)bias.data_ptr(), (at::BFloat16 *)input.data_ptr(), rows, cols,
head);
}
}
return input;
}
at::Tensor fused_mask_bias_softmax_backward(at::Tensor d_output, at::Tensor output, at::Tensor mask,
at::Tensor bias, long long rows, long long cols) {
CHECK_INPUT(output);
CHECK_INPUT(mask);
const at::cuda::OptionalCUDAGuard device_guard(device_of(mask));
int head = output.sizes()[2];
at::Tensor grad_input = at::empty_like(output);
int grid = (rows + 3) / 4;
dim3 block(128);
if (output.dtype() == torch::kFloat32) {
fastfold_softmax_mask_grad<float><<<grid, block>>>(
(float *)d_output.data_ptr(), (float *)output.data_ptr(),
(float *)grad_input.data_ptr(), (float *)mask.data_ptr(), rows, cols, head);
} else if (output.dtype() == torch::kFloat16) {
fastfold_softmax_mask_grad<at::Half><<<grid, block>>>(
(at::Half *)d_output.data_ptr(), (at::Half *)output.data_ptr(),
(at::Half *)grad_input.data_ptr(), (at::Half *)mask.data_ptr(), rows, cols, head);
} else if (output.dtype() == torch::kBFloat16) {
fastfold_softmax_mask_grad<at::BFloat16><<<grid, block>>>(
(at::BFloat16 *)d_output.data_ptr(), (at::BFloat16 *)output.data_ptr(),
(at::BFloat16 *)grad_input.data_ptr(), (at::BFloat16 *)mask.data_ptr(), rows, cols,
head);
}
return grad_input;
}
// modified from https://github.com/NVIDIA/apex
#include <ATen/ATen.h>
#include "compat.h"
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Half: { \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch (TYPEIN) { \
case at::ScalarType::Float: { \
using scalar_t_in = float; \
switch (TYPEOUT) { \
case at::ScalarType::Float: { \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_in = at::Half; \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_in = at::BFloat16; \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
}
// Forward/backward compatibility hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
// const at::Type& payload;
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Byte: { \
using scalar_t_##LEVEL = uint8_t; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
template <typename T>
__device__ __forceinline__ T
reduce_block_into_lanes(T *x, T val, int lanes = 1,
bool share_result = false) // lanes is intended to be <= 32.
{
int tid = threadIdx.x + threadIdx.y * blockDim.x;
int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
if (blockSize >= 64) {
x[tid] = val;
__syncthreads();
}
#pragma unroll
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
if (tid < i) x[tid] = x[tid] + x[tid + i];
__syncthreads();
}
T final;
if (tid < 32) {
if (blockSize >= 64)
final = x[tid] + x[tid + 32];
else
final = val;
// __SYNCWARP();
#pragma unroll
for (int i = 16; i >= lanes; i >>= 1)
final = final + __shfl_down_sync(0xffffffff, final, i);
}
if (share_result) {
if (tid < lanes) x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}
template <typename T>
__device__ __forceinline__ T
reduce_block_into_lanes_max_op(T *x, T val, int lanes = 1,
bool share_result = false) // lanes is intended to be <= 32.
{
int tid = threadIdx.x + threadIdx.y * blockDim.x;
int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
if (blockSize >= 64) {
x[tid] = val;
__syncthreads();
}
#pragma unroll
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
__syncthreads();
}
T final;
if (tid < 32) {
if (blockSize >= 64)
final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
else
final = val;
// __SYNCWARP();
#pragma unroll
for (int i = 16; i >= lanes; i >>= 1)
final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
}
if (share_result) {
if (tid < lanes) x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}
import importlib
import torch
fastfold_layer_norm_cuda = importlib.import_module("fastfold_layer_norm_cuda")
class FusedLayerNormAffineFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weight, bias, normalized_shape, eps):
ctx.normalized_shape = normalized_shape
ctx.eps = eps
input_ = input.contiguous()
weight_ = weight.contiguous()
bias_ = bias.contiguous()
output, mean, invvar = fastfold_layer_norm_cuda.forward_affine(
input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
return output
@staticmethod
def backward(ctx, grad_output):
input_, weight_, bias_, mean, invvar = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
grad_input, grad_weight, grad_bias \
= fastfold_layer_norm_cuda.backward_affine(
grad_output.contiguous(), mean, invvar,
input_, ctx.normalized_shape,
weight_, bias_, ctx.eps)
return grad_input, grad_weight, grad_bias, None, None
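# NOTE: usage sketch added for illustration; not part of the original module
# (assumes `torch` is imported, the fastfold_layer_norm_cuda extension is built,
# and a CUDA device is available).
#
#   hidden = 256
#   x = torch.rand(8, 128, hidden, device="cuda", requires_grad=True)
#   weight = torch.ones(hidden, device="cuda", requires_grad=True)
#   bias = torch.zeros(hidden, device="cuda", requires_grad=True)
#   y = FusedLayerNormAffineFunction.apply(x, weight, bias, (hidden,), 1e-5)
#   y.sum().backward()   # gradients flow through backward_affine above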
import importlib
fastfold_softmax_cuda = importlib.import_module("fastfold_softmax_cuda")
def softmax_cuda_kernel_wrapper(input_, mask_, bias_, rows, cols):
if bias_ is not None:
output = fastfold_softmax_cuda.fused_mask_bias_softmax_forward(input_, mask_, bias_, rows, cols)
elif mask_ is not None:
output = fastfold_softmax_cuda.fused_mask_softmax_forward(input_, mask_, rows, cols)
else:
output = fastfold_softmax_cuda.forward(input_, rows, cols)
return output
def softmax_grad_cuda_kernel_wrapper(grad_output, output, mask_, rows, cols):
if mask_ is not None:
grad_input = fastfold_softmax_cuda.fused_mask_softmax_backward(grad_output, output, mask_, rows, cols)
else:
grad_input = fastfold_softmax_cuda.backward(grad_output, output, rows, cols)
return grad_input
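# NOTE: usage sketch added for illustration; not part of the original module
# (assumes `torch` is imported and the fastfold_softmax_cuda extension is built).
# The wrappers flatten attention logits to a (rows, cols) view before the CUDA kernels run.
# The fused mask kernel reads the mask at offset (row // (head * cols)) * cols and the fused
# bias kernel reads the bias at offset (row % (head * cols)) * cols, so with contiguous logits
# of shape [batch, chunk, head, n, n] the mask is expected as [batch, chunk, n] in the same dtype:
#
#   batch, chunk, head, n = 1, 4, 8, 64
#   logits = torch.rand(batch, chunk, head, n, n, device="cuda")
#   mask = torch.ones(batch, chunk, n, device="cuda")
#   rows, cols = logits.numel() // n, n
#   probs = softmax_cuda_kernel_wrapper(logits, mask, None, rows, cols)   # written in place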