Commit b14e47f4 authored by zhuwenwen

Merge branch 'main' of https://github.com/hpcaitech/FastFold

parents 490cb6f5 05681304
from .options import _set_jit_fusion_options
_set_jit_fusion_options()
import torch
import torch.nn.functional as F
@torch.jit.script
def bias_sigmod_ele(y, bias, z):
return torch.sigmoid(y + bias) * z
# @torch.jit.script
def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, dropmask: torch.Tensor,
residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
out = (x + bias) * F.dropout(dropmask, p=prob, training=training)
out = residual + out
return out
@torch.jit.script
def bias_ele_dropout_residual(ab: torch.Tensor, b: torch.Tensor, g: torch.Tensor,
dropout_mask: torch.Tensor, Z_raw: torch.Tensor, prob: float,
training: bool) -> torch.Tensor:
return Z_raw + F.dropout(dropout_mask, p=prob, training=training) * (g * (ab + b))
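# --- Editor's illustrative sketch (not part of the original commit) ---
# A minimal reference check of the fused helper above, using CPU tensors and
# training=False so that dropout is the identity and the algebra can be verified directly.
def _demo_bias_dropout_add():
    x = torch.randn(1, 4, 8, 16)
    bias = torch.zeros(16)
    residual = torch.randn_like(x)
    # callers pass a mask broadcast over one axis; all-ones means "keep everything"
    dropmask = torch.ones_like(x[:, 0:1, :, :])
    out = bias_dropout_add(x, bias, dropmask, residual, prob=0.15, training=False)
    # with training=False, F.dropout is a no-op, so out == residual + (x + bias)
    assert torch.allclose(out, residual + x + bias)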
import torch
JIT_OPTIONS_SET = False
def _set_jit_fusion_options():
"""Set PyTorch JIT layer fusion options."""
global JIT_OPTIONS_SET
if not JIT_OPTIONS_SET:
# flags required to enable jit fusion kernels
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
# if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
# # nvfuser
# torch._C._jit_set_profiling_executor(True)
# torch._C._jit_set_profiling_mode(True)
# torch._C._jit_override_can_fuse_on_cpu(False)
# torch._C._jit_override_can_fuse_on_gpu(False)
# torch._C._jit_set_texpr_fuser_enabled(False)
# torch._C._jit_set_nvfuser_enabled(True)
# torch._C._debug_set_autodiff_subgraph_inlining(False)
# else:
# legacy pytorch fuser
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
JIT_OPTIONS_SET = True
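# --- Editor's note (illustrative) ---
# The guard above makes the call idempotent: only the first invocation per process
# touches the TorchScript fuser flags.
def _demo_set_jit_fusion_options():
    _set_jit_fusion_options()  # configures the legacy fuser
    _set_jit_fusion_options()  # no-op, JIT_OPTIONS_SET is already True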
# part of code modified from https://github.com/NVIDIA/apex
from .cuda_native.layer_norm import FusedLayerNormAffineFunction
import logging
import numbers
import torch
from torch.nn.parameter import Parameter
_triton_available = True
if _triton_available:
try:
from .triton.layer_norm import LayerNormTritonFunc
except ImportError:
logging.warning("Triton is not available, falling back to the old kernel.")
_triton_available = False
class FusedLayerNorm(torch.nn.Module):
def __init__(self, normalized_shape, eps=1e-5):
super(FusedLayerNorm, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
self.normalized_shape = torch.Size(normalized_shape)
self.eps = eps
self.weight = Parameter(torch.Tensor(*normalized_shape))
self.bias = Parameter(torch.Tensor(*normalized_shape))
self.reset_parameters()
def reset_parameters(self):
torch.nn.init.ones_(self.weight)
torch.nn.init.zeros_(self.bias)
def forward(self, input):
if len(input.shape) >= 3 and input.shape[-3] > 4000:
out = torch.empty_like(input)
# cap chunk_size at half the chunked dimension to maximize compute efficiency
chunk_size = min(4000 * 4000 // input.shape[-3], (input.shape[-3] + 1) // 2)
if len(input.shape) == 3:
for i in range(0, input.shape[-3], chunk_size):
out[i:i + chunk_size] = self.kernel_forward(input[i:i + chunk_size])
elif len(input.shape) == 4:
for j in range(input.shape[-4]):
for i in range(0, input.shape[-3], chunk_size):
out[j, i:i + chunk_size] = self.kernel_forward(input[j, i:i + chunk_size])
else:
raise RuntimeError(f"Shape {input.shape} not implemented for layernorm yet!")
return out
else:
return self.kernel_forward(input)
def kernel_forward(self, input):
if _triton_available:
return LayerNormTritonFunc.apply(input, self.normalized_shape, self.weight, self.bias,
self.eps)
else:
return FusedLayerNormAffineFunction.apply(input, self.weight, self.bias,
self.normalized_shape, self.eps)
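# --- Editor's illustrative sketch (assumes a CUDA device and a built Triton/CUDA kernel) ---
# FusedLayerNorm should match PyTorch's layer norm over the last dimension; the chunked
# path above only triggers when the third-from-last dimension exceeds 4000.
def _demo_fused_layer_norm():
    ln = FusedLayerNorm(64).cuda()
    x = torch.randn(2, 128, 64, device="cuda")
    y = ln(x)
    ref = torch.nn.functional.layer_norm(x, (64,), ln.weight, ln.bias, ln.eps)
    assert torch.allclose(y, ref, atol=1e-3)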
from functools import reduce
from operator import mul
import logging
import torch
_triton_available = True
if _triton_available:
try:
from .triton.softmax import softmax_triton_kernel_wrapper
from .triton.softmax import softmax_grad_triton_kernel_wrapper
except ImportError:
logging.warning("Triton is not available, falling back to the old kernel.")
_triton_available = False
from .cuda_native.softmax import softmax_cuda_kernel_wrapper
from .cuda_native.softmax import softmax_grad_cuda_kernel_wrapper
class FusedSoftmaxFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, input, mask=None, bias=None):
input_ = input.contiguous()
mask_, bias_ = None, None
ctx.use_bias = False
if mask is not None:
mask_ = mask.contiguous()
if bias is not None:
bias_ = bias.contiguous()
ctx.use_bias = True
ctx.cols = input_.shape[-1]
ctx.rows = reduce(mul, input.shape[:-1])
if _triton_available:
output = softmax_triton_kernel_wrapper(input_, mask_, bias_, ctx.rows, ctx.cols)
else:
output = softmax_cuda_kernel_wrapper(input_, mask_, bias_, ctx.rows, ctx.cols)
ctx.save_for_backward(output, mask_)
return output
@staticmethod
def backward(ctx, grad_output):
grad_output = grad_output.contiguous()
output, mask_ = ctx.saved_tensors
if _triton_available:
grad_input = softmax_grad_triton_kernel_wrapper(grad_output, output, ctx.rows, ctx.cols)
else:
grad_input = softmax_grad_cuda_kernel_wrapper(grad_output, output, mask_, ctx.rows,
ctx.cols)
grad_bias = None
if ctx.use_bias:
grad_bias = torch.sum(grad_input, dim=1, keepdim=True)
return grad_input, None, grad_bias
fused_softmax = FusedSoftmaxFunc.apply
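# --- Editor's illustrative sketch (assumes CUDA; shapes are an editor's assumption
# chosen to mirror how SelfAttention in ops.py calls this) ---
# fused_softmax normalizes over the last dimension, adding an optional bias and pushing
# masked (mask == 0) positions to a large negative value first. With an all-ones mask it
# should agree with a plain softmax:
def _demo_fused_softmax():
    logits = torch.randn(1, 2, 4, 16, 16, device="cuda")  # [b1, b2, heads, len_q, len_k]
    mask = torch.ones(1, 2, 16, device="cuda")            # [b1, b2, len_k], 1 = keep
    probs = fused_softmax(logits, mask, None)
    assert torch.allclose(probs, torch.softmax(logits, dim=-1), atol=1e-4)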
import math
import torch
from einops import rearrange
import triton
import triton.language as tl
# CREDITS: Initially inspired by the Triton tutorial
@triton.jit
def _attention_core(Q, K, V, mask, bias, sm_scale, TMP, Out, stride_qz, stride_qh, stride_qm,
stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh,
stride_vn, stride_vk, stride_oz, stride_oh, stride_om, stride_on, Z, H, N_CTX,
BATCH, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
use_mask: tl.constexpr, use_bias: tl.constexpr):
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
# Initialize pointers to Q, K, V
q_ptrs = Q + off_q
k_ptrs = K + off_k
v_ptrs = V + off_v
# Initialize pointers to bias, mask
if use_bias:
batch_2 = Z // BATCH
off_hz_bias = (off_hz // (batch_2 * H) * H) + (off_hz % H)
offs_base_bias = off_hz_bias * (N_CTX * N_CTX) + offs_m[:, None] * N_CTX + offs_n[None, :]
if use_mask:
off_hz_mask = (off_hz // H)
offs_base_mask = off_hz_mask * N_CTX
# initialize pointer to m and l
t_ptrs = TMP + off_hz * N_CTX + offs_m
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# load q: it will stay in SRAM throughout
q_load_mask = offs_m[:, None] < N_CTX
q = tl.load(q_ptrs, mask=q_load_mask, other=0.0)
# loop over k, v and update accumulator
for start_n in range(0, N_CTX, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
load_mask = (start_n + offs_n)[:, None] < N_CTX
# -- compute qk ----
k = tl.load(k_ptrs + start_n * stride_kn, mask=load_mask, other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k, trans_b=True)
qk *= sm_scale
qk = tl.where(offs_m[:, None] >= N_CTX, float("-1e20"), qk)
qk = tl.where((start_n + offs_n)[None, :] >= N_CTX, float("-1e20"), qk)
if use_bias:
bias_load_mask = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
bias_load_mask = tl.where(offs_m[:, None] >= N_CTX, 1., bias_load_mask)
bias_load_mask = tl.where((start_n + offs_n)[None, :] >= N_CTX, 1., bias_load_mask)
bias_data = tl.load(bias + offs_base_bias + start_n,
mask=(bias_load_mask == 0.),
other=0.)
qk += bias_data
if use_mask:
mask_data = tl.load(mask + offs_base_mask + offs_n + start_n,
mask=(start_n + offs_n) < N_CTX,
other=0.)
qk = tl.where(mask_data[None, :] == 0., float("-1e20"), qk)
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_i_new)
beta = tl.exp(m_ij - m_i_new)
l_i_new = alpha * l_i + beta * l_ij
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
tl.store(t_ptrs, acc_scale, mask=(offs_m < N_CTX))
acc_scale = tl.load(TMP + off_hz * N_CTX + start_m * BLOCK_M + tl.arange(0, BLOCK_M),
mask=(start_m * BLOCK_M + tl.arange(0, BLOCK_M) < N_CTX),
other=float(0.)) # BUG: have to store and immediately load
acc = acc * acc_scale[:, None]
# update acc
load_mask = (start_n + offs_n)[:, None] < N_CTX
v = tl.load(v_ptrs + start_n * stride_vn, mask=load_mask, other=0.)
p = p.to(Q.dtype.element_ty)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
# rematerialize offsets to save registers
start_m = tl.program_id(0)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
# write back l and m
# l_ptrs = L + off_hz * N_CTX + offs_m
# m_ptrs = M + off_hz * N_CTX + offs_m
# tl.store(l_ptrs, l_i)
# tl.store(m_ptrs, m_i)
# initialize pointers to output
offs_n = tl.arange(0, BLOCK_DMODEL)
off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
out_ptrs = Out + off_o
out_store_mask = offs_m[:, None] < N_CTX
tl.store(out_ptrs, acc, mask=out_store_mask)
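# --- Editor's note (illustrative) ---
# The loop above is the usual online-softmax recurrence: per query row it tracks a running
# max m_i and normalizer l_i, rescales the accumulator by l_i / l_i_new * exp(m_i - m_i_new)
# whenever a new K/V tile raises the max, then adds the freshly normalized p @ v tile.
# TMP is only a scratch buffer for the store/load workaround flagged in the BUG comment.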
def attention_core_triton_kernel_wrapper(q, k, v, mask, bias):
assert (q.dtype in [torch.float16,
torch.bfloat16]), "triton flash attention only supports float16/bfloat16 for now"
q_ori_size = list(q.size())
batch = q_ori_size[0]
if len(q_ori_size) == 5:
q = rearrange(q, 'b1 b2 h n d -> (b1 b2) h n d')
k = rearrange(k, 'b1 b2 h n d -> (b1 b2) h n d')
v = rearrange(v, 'b1 b2 h n d -> (b1 b2) h n d')
sm_scale = 1. / math.sqrt(q.size(-1))
# q *= sm_scale
BLOCK = 128
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
assert Lk in {16, 32, 64, 128}
o = torch.empty_like(q)
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
num_warps = 4 if Lk <= 64 else 8
_attention_core[grid](
q,
k,
v,
mask,
bias,
sm_scale,
tmp,
o,
q.stride(0),
q.stride(1),
q.stride(2),
q.stride(3),
k.stride(0),
k.stride(1),
k.stride(2),
k.stride(3),
v.stride(0),
v.stride(1),
v.stride(2),
v.stride(3),
o.stride(0),
o.stride(1),
o.stride(2),
o.stride(3),
q.shape[0],
q.shape[1],
q.shape[2],
batch,
BLOCK_M=BLOCK,
BLOCK_N=BLOCK,
BLOCK_DMODEL=Lk,
use_mask=(mask is not None),
use_bias=(bias is not None),
num_warps=num_warps,
num_stages=1,
)
if len(q_ori_size) == 5:
o = rearrange(o, '(b1 b2) h n d -> b1 b2 n (h d)', b1=batch)
return o
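# --- Editor's illustrative sketch (assumes a CUDA device; not part of the commit) ---
# Parity check of the wrapper against naive scaled-dot-product attention in fp16 on a
# small, unmasked, unbiased case:
def _demo_attention_core():
    q = torch.randn(2, 4, 64, 32, device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)
    out = attention_core_triton_kernel_wrapper(q, k, v, None, None)
    scale = 1.0 / math.sqrt(q.shape[-1])
    ref = torch.softmax(q.float() @ k.float().transpose(-1, -2) * scale, dim=-1) @ v.float()
    assert torch.allclose(out.float(), ref, atol=1e-2)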
import torch
import triton
import triton.language as tl
@triton.jit
def _layer_norm_fwd_fused(
Out,
A,
Weight,
Bias,
Mean,
Rstd,
stride,
N,
eps,
BLOCK_SIZE: tl.constexpr,
):
# position of elements processed by this program
row = tl.program_id(0)
Out += row * stride
A += row * stride
# compute mean
mean = 0
_mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
a = tl.load(A + cols, mask=cols < N, other=0.,).to(tl.float32)
_mean += a
mean = tl.sum(_mean, axis=0) / N
# compute variance
_var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
a = tl.load(A + cols, mask=cols < N, other=0.).to(tl.float32)
a = tl.where(cols < N, a - mean, 0.)
_var += a * a
var = tl.sum(_var, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
# write-back mean/rstd
tl.store(Mean + row, mean)
tl.store(Rstd + row, rstd)
# multiply by weight and add bias
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
mask = cols < N
weight = tl.load(Weight + cols, mask=mask)
bias = tl.load(Bias + cols, mask=mask)
a = tl.load(A + cols, mask=mask, other=0.).to(tl.float32)
a_hat = (a - mean) * rstd
out = a_hat * weight + bias
# # write-back
tl.store(Out + cols, out, mask=mask)
# Backward pass (DA + partial DW + partial DB)
@triton.jit
def _layer_norm_bwd_dx_fused(
_DA,
_DOut,
_A,
Weight,
Mean,
Rstd,
stride,
NumRows,
NumCols,
eps,
BLOCK_SIZE_N: tl.constexpr,
):
# position of elements processed by this program
pid = tl.program_id(0)
row = pid
A = _A + row * stride
DOut = _DOut + row * stride
DA = _DA + row * stride
mean = tl.load(Mean + row)
rstd = tl.load(Rstd + row)
# load data to SRAM
_mean1 = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)
_mean2 = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)
for off in range(0, NumCols, BLOCK_SIZE_N):
cols = off + tl.arange(0, BLOCK_SIZE_N)
mask = cols < NumCols
a = tl.load(A + cols, mask=mask, other=0).to(tl.float32)
dout = tl.load(DOut + cols, mask=mask, other=0).to(tl.float32)
weight = tl.load(Weight + cols, mask=mask, other=0).to(tl.float32)
a_hat = (a - mean) * rstd
wdout = weight * dout
_mean1 += a_hat * wdout
_mean2 += wdout
mean1 = tl.sum(_mean1, axis=0) / NumCols
mean2 = 0.
mean2 = tl.sum(_mean2, axis=0) / NumCols
for off in range(0, NumCols, BLOCK_SIZE_N):
cols = off + tl.arange(0, BLOCK_SIZE_N)
mask = cols < NumCols
a = tl.load(A + cols, mask=mask, other=0).to(tl.float32)
dout = tl.load(DOut + cols, mask=mask, other=0).to(tl.float32)
weight = tl.load(Weight + cols, mask=mask, other=0).to(tl.float32)
a_hat = (a - mean) * rstd
wdout = weight * dout
da = (wdout - (a_hat * mean1 + mean2)) * rstd
# write-back dx
tl.store(DA + cols, da, mask=mask)
# Backward pass (total DW + total DB)
@triton.jit
def _layer_norm_bwd_dwdb(
A,
DOut,
Mean,
Var,
DW,
DB,
M,
N,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
):
pid = tl.program_id(0)
cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
UNROLL: tl.constexpr = 4
for i in range(0, M, BLOCK_SIZE_M * UNROLL):
for j in range(UNROLL):
rows = i + j * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
mask = (rows[:, None] < M) & (cols[None, :] < N)
offs = rows[:, None] * N + cols[None, :]
a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
mean = tl.load(Mean + rows, mask=rows < M, other=0.)
rstd = tl.load(Var + rows, mask=rows < M, other=0.)
a_hat = (a - mean[:, None]) * rstd[:, None]
dw += dout * a_hat
db += dout
sum_dw = tl.sum(dw, axis=0)
sum_db = tl.sum(db, axis=0)
tl.store(DW + cols, sum_dw, mask=cols < N)
tl.store(DB + cols, sum_db, mask=cols < N)
class LayerNormTritonFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, a_raw, normalized_shape, weight, bias, eps):
# allocate output
a = a_raw.contiguous()
out = torch.empty_like(a)
# reshape input data into 2D tensor
a_arg = a.reshape(-1, a.shape[-1])
M, N = a_arg.shape
mean = torch.empty((M,), dtype=torch.float32, device="cuda")
rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // a.element_size()
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
BLOCK_SIZE = max(BLOCK_SIZE, 128)
BLOCK_SIZE = min(BLOCK_SIZE, 4096)
# heuristics for number of warps
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
_layer_norm_fwd_fused[(M,)](
out,
a_arg,
weight,
bias,
mean,
rstd,
a_arg.stride(0),
N,
eps,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
ctx.save_for_backward(
a,
weight,
bias,
mean,
rstd,
)
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.eps = eps
return out
@staticmethod
def backward(ctx, dout):
assert dout.is_contiguous()
a, weight, bias, mean, var = ctx.saved_tensors
# heuristics for amount of parallel reduction stream for DG/DB
N = weight.shape[0]
# allocate output
da = torch.empty_like(dout)
# enqueue kernel using forward pass heuristics
# also compute partial sums for DW and DB
x_arg = a.reshape(-1, a.shape[-1])
M, N = x_arg.shape
dweight = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device)
dbias = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device)
_layer_norm_bwd_dx_fused[(M,)](
da,
dout,
a,
weight,
mean,
var,
x_arg.stride(0),
M,
N,
ctx.eps,
BLOCK_SIZE_N=ctx.BLOCK_SIZE,
num_warps=ctx.num_warps,
)
if N > 10240:
BLOCK_SIZE_N = 128
BLOCK_SIZE_M = 32
num_warps = 4
elif N > 384:
BLOCK_SIZE_N = 16
BLOCK_SIZE_M = 16
num_warps = 8
else:
# maximize occupancy for small N
BLOCK_SIZE_N = 4
BLOCK_SIZE_M = 256
num_warps = 8
grid = lambda meta: [triton.cdiv(N, meta["BLOCK_SIZE_N"])]
_layer_norm_bwd_dwdb[grid](a,
dout,
mean,
var,
dweight,
dbias,
M,
N,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
num_warps=num_warps)
return (da, None, dweight, dbias, None)
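# --- Editor's illustrative sketch (assumes CUDA + Triton; not part of the commit) ---
# Forward/backward parity against PyTorch's reference layer norm:
def _demo_layer_norm_triton():
    x = torch.randn(32, 256, device="cuda", requires_grad=True)
    w = torch.ones(256, device="cuda", requires_grad=True)
    b = torch.zeros(256, device="cuda", requires_grad=True)
    y = LayerNormTritonFunc.apply(x, (256,), w, b, 1e-5)
    ref = torch.nn.functional.layer_norm(x, (256,), w, b, 1e-5)
    assert torch.allclose(y, ref, atol=1e-4)
    y.sum().backward()  # exercises _layer_norm_bwd_dx_fused and _layer_norm_bwd_dwdb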
import torch
import triton
import triton.language as tl
@triton.jit
def _softmax_core(input_ptrs, output_ptrs, mask_ptrs, bias_ptrs, col_offsets, n_cols,
use_mask: tl.constexpr, use_bias: tl.constexpr):
row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32)
if use_bias:
bias = tl.load(bias_ptrs, mask=col_offsets < n_cols, other=float("-inf")).to(tl.float32)
row += bias
if use_mask:
mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=float("-inf")).to(tl.float32)
row = tl.where(mask == 0, float("-1e20"), row)
row_minus_max = row - tl.max(row, axis=0)
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
@triton.jit
def _softmax_grad_core(output_ptrs, d_output_ptrs, d_input_ptrs, col_offsets, n_cols,
is_bf16: tl.constexpr):
output_row = tl.load(output_ptrs, mask=col_offsets < n_cols, other=float(0))
d_output_row = tl.load(d_output_ptrs, mask=col_offsets < n_cols, other=float(0))
if is_bf16:
output_row = output_row.to(tl.float32)
d_output_row = d_output_row.to(tl.float32)
row_sum = tl.sum(output_row * d_output_row, axis=0)
d_softmax_output = (d_output_row - row_sum) * output_row
tl.store(d_input_ptrs, d_softmax_output, mask=col_offsets < n_cols)
@triton.jit
def softmax_mask_bias_kernel(output_ptr, input_ptr, mask_ptr, bias_ptr, input_row_stride,
output_row_stride, n_cols, n_heads, BLOCK_SIZE: tl.constexpr,
use_mask: tl.constexpr, use_bias: tl.constexpr):
row_idx = tl.program_id(0).to(tl.int64)
col_offsets = tl.arange(0, BLOCK_SIZE)
input_row_ptr = input_ptr + row_idx * input_row_stride
output_row_ptr = output_ptr + row_idx * output_row_stride
input_ptrs = input_row_ptr + col_offsets
output_ptrs = output_row_ptr + col_offsets
mask_ptrs = input_ptrs  # placeholder, unused when use_mask is False
if use_mask:
mask_row_ptr = mask_ptr + (row_idx // (n_heads * n_cols)) * n_cols
mask_ptrs = mask_row_ptr + col_offsets
bias_ptrs = input_ptrs  # placeholder, unused when use_bias is False
if use_bias:
bias_row_ptr = bias_ptr + (row_idx % (n_heads * n_cols)) * n_cols
bias_ptrs = bias_row_ptr + col_offsets
_softmax_core(input_ptrs, output_ptrs, mask_ptrs, bias_ptrs, col_offsets, n_cols, use_mask,
use_bias)
@triton.jit
def softmax_mask_bias_kernel_two_rows(output_ptr, input_ptr, mask_ptr, bias_ptr, input_row_stride,
output_row_stride, n_cols, n_heads, BLOCK_SIZE: tl.constexpr,
use_mask: tl.constexpr, use_bias: tl.constexpr):
row_idx = tl.program_id(0).to(tl.int64)
col_offsets = tl.arange(0, BLOCK_SIZE)
input_row_ptr = input_ptr + 2 * row_idx * input_row_stride
output_row_ptr = output_ptr + 2 * row_idx * output_row_stride
input_ptrs = input_row_ptr + col_offsets
output_ptrs = output_row_ptr + col_offsets
mask_ptrs = input_ptrs  # placeholder, unused when use_mask is False
if use_mask:
mask_row_ptr = mask_ptr + ((2 * row_idx) // (n_heads * n_cols)) * n_cols
mask_ptrs = mask_row_ptr + col_offsets
bias_ptrs = input_ptrs  # placeholder, unused when use_bias is False
if use_bias:
bias_row_ptr = bias_ptr + ((2 * row_idx) % (n_heads * n_cols)) * n_cols
bias_ptrs = bias_row_ptr + col_offsets
_softmax_core(input_ptrs, output_ptrs, mask_ptrs, bias_ptrs, col_offsets, n_cols, use_mask,
use_bias)
mask_ptrs = input_ptrs  # placeholder, unused when use_mask is False
if use_mask:
mask_row_ptr = mask_ptr + ((2 * row_idx + 1) // (n_heads * n_cols)) * n_cols
mask_ptrs = mask_row_ptr + col_offsets
bias_ptrs = input_ptrs  # placeholder, unused when use_bias is False
if use_bias:
bias_row_ptr = bias_ptr + ((2 * row_idx + 1) % (n_heads * n_cols)) * n_cols
bias_ptrs = bias_row_ptr + col_offsets
_softmax_core(input_ptrs + n_cols, output_ptrs + n_cols, mask_ptrs, bias_ptrs, col_offsets,
n_cols, use_mask, use_bias)
@triton.jit
def softmax_grad_kernel(d_output_ptr, output_ptr, d_input_ptr, d_output_row_stride,
output_row_stride, d_input_row_stride, n_cols, BLOCK_SIZE: tl.constexpr,
is_bf16: tl.constexpr):
row_idx = tl.program_id(0).to(tl.int64)
col_offsets = tl.arange(0, BLOCK_SIZE)
output_row_ptr = output_ptr + row_idx * output_row_stride
d_output_row_ptr = d_output_ptr + row_idx * d_output_row_stride
d_input_row_ptr = d_input_ptr + row_idx * d_input_row_stride
output_ptrs = output_row_ptr + col_offsets
d_output_ptrs = d_output_row_ptr + col_offsets
d_input_ptrs = d_input_row_ptr + col_offsets
_softmax_grad_core(output_ptrs, d_output_ptrs, d_input_ptrs, col_offsets, n_cols, is_bf16)
@triton.jit
def softmax_grad_kernel_two_rows(d_output_ptr, output_ptr, d_input_ptr, d_output_row_stride,
output_row_stride, d_input_row_stride, n_cols,
BLOCK_SIZE: tl.constexpr, is_bf16: tl.constexpr):
row_idx = tl.program_id(0).to(tl.int64)
col_offsets = tl.arange(0, BLOCK_SIZE)
output_row_ptr = output_ptr + 2 * row_idx * output_row_stride
d_output_row_ptr = d_output_ptr + 2 * row_idx * d_output_row_stride
d_input_row_ptr = d_input_ptr + 2 * row_idx * d_input_row_stride
output_ptrs = output_row_ptr + col_offsets
d_output_ptrs = d_output_row_ptr + col_offsets
d_input_ptrs = d_input_row_ptr + col_offsets
_softmax_grad_core(output_ptrs, d_output_ptrs, d_input_ptrs, col_offsets, n_cols, is_bf16)
_softmax_grad_core(output_ptrs + n_cols, d_output_ptrs + n_cols, d_input_ptrs + n_cols,
col_offsets, n_cols, is_bf16)
def softmax_triton_kernel_wrapper(x, mask, bias, n_rows, n_cols):
y = torch.empty_like(x)
n_heads = x.shape[2]
num_warps = 1
BLOCK_SIZE = triton.next_power_of_2(n_cols)
if BLOCK_SIZE >= 1024:
num_warps = 4
if BLOCK_SIZE >= 2048:
num_warps = 8
if BLOCK_SIZE >= 4096:
num_warps = 16
_dispatch_kernel = softmax_mask_bias_kernel
_grid = (n_rows,)
if n_cols <= 128 and n_rows % 2 == 0:
_dispatch_kernel = softmax_mask_bias_kernel_two_rows
_grid = (n_rows // 2,)
_dispatch_kernel[_grid](
y,
x,
mask,
bias,
x.stride(-2),
y.stride(-2),
n_cols,
n_heads,
num_warps=num_warps,
BLOCK_SIZE=BLOCK_SIZE,
use_mask=(mask is not None),
use_bias=(bias is not None),
)
return y
def softmax_grad_triton_kernel_wrapper(grad_output, output, n_rows, n_cols):
grad_input = torch.empty_like(grad_output)
num_warps = 1
BLOCK_SIZE = triton.next_power_of_2(n_cols)
if BLOCK_SIZE >= 1024:
num_warps = 4
if BLOCK_SIZE >= 2048:
num_warps = 8
if BLOCK_SIZE >= 4096:
num_warps = 16
is_bf16 = (output.dtype == torch.bfloat16)
_dispatch_kernel = softmax_grad_kernel
_grid = (n_rows,)
if n_cols <= 128 and n_rows % 2 == 0:
_dispatch_kernel = softmax_grad_kernel_two_rows
_grid = (n_rows // 2,)
_dispatch_kernel[_grid](
grad_output,
output,
grad_input,
grad_output.stride(-2),
output.stride(-2),
grad_output.stride(-2),
n_cols,
num_warps=num_warps,
BLOCK_SIZE=BLOCK_SIZE,
is_bf16=is_bf16,
)
return grad_input
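# --- Editor's note and sketch (illustrative, assumes CUDA + Triton) ---
# The grad kernels implement the softmax Jacobian-vector product
#     dx = (dy - sum(y * dy, dim=-1, keepdim=True)) * y
# row by row. A small parity check against that closed form:
def _demo_softmax_grad():
    y = torch.softmax(torch.randn(2, 2, 4, 16, 16, device="cuda"), dim=-1)
    dy = torch.randn_like(y)
    n_rows, n_cols = y.numel() // y.shape[-1], y.shape[-1]
    dx = softmax_grad_triton_kernel_wrapper(dy, y, n_rows, n_cols)
    ref = (dy - (y * dy).sum(dim=-1, keepdim=True)) * y
    assert torch.allclose(dx, ref, atol=1e-4)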
# Copyright 2022 BioMap (Beijing) Intelligence Technology Limited
# Copyright 2022 HPC-AI Technology Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from fastfold.model.fastnn.kernel import LayerNorm, bias_dropout_add
from fastfold.model.fastnn.ops import (ChunkMSARowAttentionWithPairBias, ChunkTransition,
SelfAttention, GlobalAttention, Transition,
ChunkMSAColumnGlobalAttention, OutProductMean)
from fastfold.distributed import scatter, row_to_col
from fastfold.distributed.comm import gather, scatter, row_to_col, scatter
from fastfold.distributed.comm_async import gather_async, All_to_All_Async, All_to_All_Async_Opp
from fastfold.model.fastnn.triangle import PairCore
class MSARowAttentionWithPairBias(nn.Module):
def __init__(self, d_node, d_pair, c=32, n_head=8, p_drop=0.15):
super(MSARowAttentionWithPairBias, self).__init__()
self.d_node = d_node
self.d_pair = d_pair
self.c = c
self.n_head = n_head
self.p_drop = p_drop
self.layernormM = LayerNorm(d_node)
self.layernormZ = LayerNorm(d_pair)
_init_weights = torch.nn.init.normal_(torch.zeros([n_head, d_pair]),
std=1.0 / math.sqrt(d_pair))
self.linear_b_weights = nn.parameter.Parameter(data=_init_weights, requires_grad=True)
self.attention = SelfAttention(qkv_dim=d_node,
c=c,
n_head=n_head,
out_dim=d_node,
gating=True,
last_bias_fuse=True)
self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_node,)), requires_grad=True)
def forward(self, M_raw, Z, M_mask):
## Input projections
M = self.layernormM(M_raw)
Z = self.layernormZ(Z)
b = F.linear(Z, self.linear_b_weights)
b, work = gather_async(b, dim=1)
# b = rearrange(b, 'b q k h -> b h q k')
# padding_bias = (1e9 * (M_mask - 1.))[:, :, None, None, :]
M = self.attention(M, M_mask, (b, work))
dropout_mask = torch.ones_like(M[:, 0:1, :, :], device=M.device, dtype=M.dtype)
return bias_dropout_add(M, self.out_bias, dropout_mask, M_raw, prob=self.p_drop, training=self.training)
class MSAColumnAttention(nn.Module):
def __init__(self, d_node, c=32, n_head=8):
super(MSAColumnAttention, self).__init__()
self.d_node = d_node
self.c = c
self.n_head = n_head
self.layernormM = LayerNorm(d_node)
self.attention = SelfAttention(qkv_dim=d_node,
c=c,
n_head=n_head,
out_dim=d_node,
gating=True)
def forward(self, M_raw, M_mask):
M = M_raw.transpose(-2, -3)
M = self.layernormM(M)
M_mask = M_mask.transpose(-1, -2)
# padding_bias = (1e9 * (M_mask - 1.))[:, :, None, None, :]
M = self.attention(M, M_mask)
M = M.transpose(-2, -3)
return M_raw + M
class MSAColumnGlobalAttention(nn.Module):
def __init__(self, d_node, c=8, n_head=8):
super(MSAColumnGlobalAttention, self).__init__()
self.d_node = d_node
self.c = c
self.n_head = n_head
self.layernormM = LayerNorm(d_node)
self.global_attention = GlobalAttention(
qkv_dim=d_node, c=c, n_head=n_head, out_dim=d_node
)
def forward(self, M_raw, M_mask):
M = M_raw.transpose(-2, -3)
M = self.layernormM(M)
M_mask = M_mask.transpose(-1, -2)
M = self.global_attention(M, M_mask)
M = M.transpose(-2, -3)
return M_raw + M
class MSACore(nn.Module):
def __init__(self, d_node, d_pair, p_drop=0.15):
super(MSACore, self).__init__()
self.MSARowAttentionWithPairBias = MSARowAttentionWithPairBias(d_node=d_node,
d_pair=d_pair,
p_drop=p_drop)
self.MSAColumnAttention = MSAColumnAttention(d_node=d_node)
self.MSATransition = Transition(d=d_node)
def forward(self, node, pair, node_mask):
# split node in row-axis
node_mask_row = scatter(node_mask, dim=1)
node = self.MSARowAttentionWithPairBias(node, pair, node_mask_row)
node = row_to_col(node)
node_mask_col = scatter(node_mask, dim=2)
node = self.MSAColumnAttention(node, node_mask_col)
node = self.MSATransition(node)
return node
class ExtraMSACore(nn.Module):
def __init__(self, d_node, d_pair, p_drop=0.15):
super(ExtraMSACore, self).__init__()
self.MSARowAttentionWithPairBias = ChunkMSARowAttentionWithPairBias(
d_node=d_node, d_pair=d_pair, p_drop=p_drop, c=8
)
self.MSAColumnAttention = ChunkMSAColumnGlobalAttention(d_node=d_node, c=8)
self.MSATransition = ChunkTransition(d=d_node)
def forward(self, node, pair, node_mask):
node_mask_row = scatter(node_mask, dim=1)
node = self.MSARowAttentionWithPairBias(node, pair, node_mask_row)
node = row_to_col(node)
node_mask_col = scatter(node_mask, dim=2)
node = self.MSAColumnAttention(node, node_mask_col)
node = self.MSATransition(node)
return node
def inplace(self, node, pair, node_mask):
node_mask_row = scatter(node_mask, dim=1)
node = self.MSARowAttentionWithPairBias.inplace(node, pair[0], node_mask_row)
node[0] = row_to_col(node[0])
node_mask_col = scatter(node_mask, dim=2)
node = self.MSAColumnAttention.inplace(node, node_mask_col)
node = self.MSATransition.inplace(node)
return node
class ExtraMSABlock(nn.Module):
def __init__(
self, c_m: int, c_z: int, first_block: bool, last_block: bool, is_multimer=False
):
super(ExtraMSABlock, self).__init__()
self.first_block = first_block
self.last_block = last_block
self.msa_stack = ExtraMSACore(c_m, c_z, p_drop=0.15)
self.communication = OutProductMean(n_feat=c_m, n_feat_out=c_z, n_feat_proj=32)
self.pair_stack = PairCore(d_pair=c_z)
self.is_multimer = is_multimer
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
dap_size = gpc.get_world_size(ParallelMode.TENSOR)
seq_cnt = msa_mask.size(-2)
seq_len = pair_mask.size(-1)
seq_cnt_padding_size = (int(seq_cnt / dap_size) + 1) * dap_size - seq_cnt
seq_len_padding_size = (int(seq_len / dap_size) + 1) * dap_size - seq_len
if self.first_block:
m = m.unsqueeze(0)
z = z.unsqueeze(0)
m = torch.nn.functional.pad(
m, (0, 0, 0, seq_len_padding_size, 0, seq_cnt_padding_size)
)
z = torch.nn.functional.pad(
z, (0, 0, 0, seq_len_padding_size, 0, seq_len_padding_size)
)
m = scatter(m, dim=1) if not self.is_multimer else scatter(m, dim=2)
z = scatter(z, dim=1)
msa_mask = msa_mask.unsqueeze(0)
pair_mask = pair_mask.unsqueeze(0)
msa_mask = torch.nn.functional.pad(
msa_mask, (0, seq_len_padding_size, 0, seq_cnt_padding_size)
)
pair_mask = torch.nn.functional.pad(
pair_mask, (0, seq_len_padding_size, 0, seq_len_padding_size)
)
if not self.is_multimer:
m = self.msa_stack(m, z, msa_mask)
z = self.communication(m, msa_mask, z)
m, work = All_to_All_Async.apply(m, 1, 2)
z = self.pair_stack(z, pair_mask)
m = All_to_All_Async_Opp.apply(m, work, 1, 2)
else:
z = self.communication(m, msa_mask, z)
z_ori = z
m, work = All_to_All_Async.apply(m, 1, 2)
z = self.pair_stack(z, pair_mask)
m = All_to_All_Async_Opp.apply(m, work, 1, 2)
m = self.msa_stack(m, z_ori, msa_mask)
if self.last_block:
m = gather(m, dim=1) if not self.is_multimer else gather(m, dim=2)
z = gather(z, dim=1)
m = m[:, :-seq_cnt_padding_size, :-seq_len_padding_size, :]
z = z[:, :-seq_len_padding_size, :-seq_len_padding_size, :]
m = m.squeeze(0)
z = z.squeeze(0)
return m, z
def inplace(
self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
dap_size = gpc.get_world_size(ParallelMode.TENSOR)
seq_cnt = msa_mask.size(-2)
seq_len = pair_mask.size(-1)
seq_cnt_padding_size = (int(seq_cnt / dap_size) + 1) * dap_size - seq_cnt
seq_len_padding_size = (int(seq_len / dap_size) + 1) * dap_size - seq_len
if self.first_block:
m[0] = m[0].unsqueeze(0)
z[0] = z[0].unsqueeze(0)
m[0] = torch.nn.functional.pad(
m[0], (0, 0, 0, seq_len_padding_size, 0, seq_cnt_padding_size)
)
z[0] = torch.nn.functional.pad(
z[0], (0, 0, 0, seq_len_padding_size, 0, seq_len_padding_size)
)
m[0] = scatter(m[0], dim=1) if not self.is_multimer else scatter(m[0], dim=2)
z[0] = scatter(z[0], dim=1)
msa_mask = msa_mask.unsqueeze(0)
pair_mask = pair_mask.unsqueeze(0)
msa_mask = torch.nn.functional.pad(
msa_mask, (0, seq_len_padding_size, 0, seq_cnt_padding_size)
)
pair_mask = torch.nn.functional.pad(
pair_mask, (0, seq_len_padding_size, 0, seq_len_padding_size)
)
if not self.is_multimer:
m = self.msa_stack.inplace(m, z, msa_mask)
z = self.communication.inplace(m[0], msa_mask, z)
m[0], work = All_to_All_Async.apply(m[0], 1, 2)
z = self.pair_stack.inplace(z, pair_mask)
m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
else:
# z = self.communication.inplace(m[0], msa_mask, z)
# z_ori = [z[0].clone()]
# m[0], work = All_to_All_Async.apply(m[0], 1, 2)
# z = self.pair_stack.inplace(z, pair_mask)
# m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
# m = self.msa_stack.inplace(m, z_ori, msa_mask)
z = self.communication.inplace(m[0], msa_mask, z)
m[0], work = All_to_All_Async.apply(m[0], 1, 2)
m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
m = self.msa_stack.inplace(m, z, msa_mask)
z = self.pair_stack.inplace(z, pair_mask)
if self.last_block:
m[0] = gather(m[0], dim=1) if not self.is_multimer else gather(m[0], dim=2)
z[0] = gather(z[0], dim=1)
m[0] = m[0][:, :-seq_cnt_padding_size, :-seq_len_padding_size, :]
z[0] = z[0][:, :-seq_len_padding_size, :-seq_len_padding_size, :]
m[0] = m[0].squeeze(0)
z[0] = z[0].squeeze(0)
return m, z
class ExtraMSAStack(nn.Module):
"""
Implements Algorithm 18.
"""
def __init__(self,
c_m: int,
c_z: int,
no_blocks: int,
clear_cache_between_blocks: bool = False,
is_multimer: bool = False,
**kwargs,
):
super(ExtraMSAStack, self).__init__()
self.clear_cache_between_blocks = clear_cache_between_blocks
self.blocks = nn.ModuleList()
for block_id in range(no_blocks):
block = ExtraMSABlock(
c_m=c_m,
c_z=c_z,
first_block=(block_id == 0),
last_block=(block_id == no_blocks - 1),
is_multimer=is_multimer,
)
self.blocks.append(block)
def forward(self,
m: torch.Tensor,
z: torch.Tensor,
chunk_size: int,
msa_mask: Optional[torch.Tensor] = None,
pair_mask: Optional[torch.Tensor] = None,
_mask_trans: bool = True,
) -> torch.Tensor:
"""
Args:
m:
[*, N_extra, N_res, C_m] extra MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
msa_mask:
Optional [*, N_extra, N_res] MSA mask
pair_mask:
Optional [*, N_res, N_res] pair mask
Returns:
[*, N_res, N_res, C_z] pair update
"""
#checkpoint_fn = get_checkpoint_fn()
#blocks = [
# partial(b, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size, _chunk_logits=None) for b in self.blocks
#]
#def dodo(b, *args):
# torch.cuda.empty_cache()
# return b(*args)
#blocks = [partial(dodo, b) for b in blocks]
#for b in blocks:
# if(torch.is_grad_enabled()):
# m, z = checkpoint_fn(b, *(m, z))
# else:
# m, z = b(m, z)
for b in self.blocks:
m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
if(self.clear_cache_between_blocks):
torch.cuda.empty_cache()
return z
def inplace(self,
m: torch.Tensor,
z: torch.Tensor,
chunk_size: int,
msa_mask: Optional[torch.Tensor] = None,
pair_mask: Optional[torch.Tensor] = None,
_mask_trans: bool = True,
) -> torch.Tensor:
"""
Args:
m:
[*, N_extra, N_res, C_m] extra MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
msa_mask:
Optional [*, N_extra, N_res] MSA mask
pair_mask:
Optional [*, N_res, N_res] pair mask
Returns:
[*, N_res, N_res, C_z] pair update
"""
#checkpoint_fn = get_checkpoint_fn()
#blocks = [
# partial(b, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size, _chunk_logits=None) for b in self.blocks
#]
#def dodo(b, *args):
# torch.cuda.empty_cache()
# return b(*args)
#blocks = [partial(dodo, b) for b in blocks]
#for b in blocks:
# if(torch.is_grad_enabled()):
# m, z = checkpoint_fn(b, *(m, z))
# else:
# m, z = b(m, z)
for b in self.blocks:
m, z = b.inplace(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
if(self.clear_cache_between_blocks):
torch.cuda.empty_cache()
return z
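# --- Editor's note (illustrative) ---
# The blocks above pad N_seq and N_res up to a multiple of the tensor-parallel world size
# before scattering, e.g. with dap_size = 4 and seq_len = 130:
#     (130 // 4 + 1) * 4 - 130 = 2 padding residues.
# A length that is already divisible is still padded by a full dap_size (160 -> 4), and the
# final gather in the last block slices the padding off again.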
# Copyright 2022 BioMap (Beijing) Intelligence Technology Limited
# Copyright 2022 HPC-AI Technology Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from einops import rearrange
from typing import Tuple
from fastfold.model.fastnn.kernel import LayerNorm
from fastfold.model.fastnn.kernel import fused_softmax
from .initializer import glorot_uniform_af
from fastfold.model.fastnn.kernel import bias_sigmod_ele, bias_ele_dropout_residual, bias_dropout_add
from fastfold.distributed import gather, scatter
from fastfold.distributed.comm_async import gather_async, gather_async_opp, get_world_size, get_rank, broadcast_sync, broadcast_async, broadcast_async_opp
CHUNK_SIZE = None
DEBUG = False
def set_chunk_size(chunk_size):
global CHUNK_SIZE
CHUNK_SIZE = chunk_size
def get_chunk_size():
global CHUNK_SIZE
return CHUNK_SIZE
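# --- Editor's note (illustrative) ---
# Callers are expected to configure chunking globally before inference, e.g.
#     set_chunk_size(64)
# while leaving CHUNK_SIZE as None makes the modules below take their unchunked code paths.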
class DropoutRowwise(nn.Module):
def __init__(self, p):
super(DropoutRowwise, self).__init__()
self.p = p
self.dropout = nn.Dropout(p=p)
def forward(self, x):
dropout_mask = torch.ones_like(x[:, 0:1, :, :])
dropout_mask = self.dropout(dropout_mask)
return dropout_mask * x
class DropoutColumnwise(nn.Module):
def __init__(self, p):
super(DropoutColumnwise, self).__init__()
self.p = p
self.dropout = nn.Dropout(p=p)
def forward(self, x):
dropout_mask = torch.ones_like(x[:, :, 0:1, :])
dropout_mask = self.dropout(dropout_mask)
return dropout_mask * x
class Transition(nn.Module):
def __init__(self, d, n=4):
super(Transition, self).__init__()
self.norm = LayerNorm(d)
self.linear1 = Linear(d, n * d, initializer='relu')
self.linear2 = Linear(n * d, d, initializer='zeros')
def forward(self, src):
x = self.norm(src)
x = self.linear2(F.relu(self.linear1(x)))
return src + x
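# --- Editor's illustrative sketch (assumes CUDA + the fused LayerNorm kernel) ---
# Transition is a residual two-layer MLP; because its output projection is zero-initialized,
# it starts out as an identity map:
def _demo_transition():
    trans = Transition(d=64).cuda()
    x = torch.randn(1, 8, 16, 64, device="cuda")
    out = trans(x)
    assert out.shape == x.shape
    assert torch.allclose(out, x)  # linear2 weight and bias are zero at init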
class ChunkTransition(nn.Module):
def __init__(self, d, n=4):
super(ChunkTransition, self).__init__()
self.norm = LayerNorm(d)
self.linear1 = Linear(d, n * d, initializer='relu')
self.linear2 = Linear(n * d, d, initializer='zeros')
def forward(self, src):
if CHUNK_SIZE == None:
out = self.norm(src)
out = self.linear2(F.relu(self.linear1(out)))
else:
chunk_size = CHUNK_SIZE * 48
para_dim = src.shape[1]
out = torch.empty_like(src)
for ax in range(0, para_dim, chunk_size):
if DEBUG and ax > 10:
break
x = self.norm(src[:, ax:ax + chunk_size, :, :])
x = self.linear2(F.relu(self.linear1(x)))
out[:, ax:ax + chunk_size, :, :] = x
out.add_(src)
return out
def inplace(self, src):
para_dim = src[0].shape[1]
if CHUNK_SIZE == None:
chunk_size = para_dim
else:
chunk_size = CHUNK_SIZE * 48
for ax in range(0, para_dim, chunk_size):
if DEBUG and ax > 10:
break
x = self.norm(src[0][:, ax:ax + chunk_size, :, :])
x = self.linear2(F.relu(self.linear1(x)))
src[0][:, ax:ax + chunk_size, :, :] += x
return src
class OutProductMean(nn.Module):
def __init__(self, n_feat=64, n_feat_out=128, n_feat_proj=32):
super(OutProductMean, self).__init__()
self.layernormM = LayerNorm(n_feat)
self.linear_a = Linear(n_feat, n_feat_proj)
self.linear_b = Linear(n_feat, n_feat_proj)
self.o_linear = Linear(n_feat_proj * n_feat_proj,
n_feat_out,
initializer='zero',
use_bias=True)
self.n_feat_proj = n_feat_proj
def forward(self, M, M_mask, Z_raw):
Z = torch.empty_like(Z_raw)
M = self.layernormM(M)
right_act = self.linear_b(M)
right_act_all, work = gather_async(right_act, dim=2)
# right_act_all = gather(right_act, dim=2)
left_act = self.linear_a(M)
M_mask = M_mask.unsqueeze(-1)
M_mask_col = scatter(M_mask, dim=2)
left_act = M_mask_col * left_act
norm = torch.einsum('bsid,bsjd->bijd', M_mask_col, M_mask) + 1e-3
right_act_all = gather_async_opp(right_act_all, work, dim=2)
right_act_all = M_mask * right_act_all
if CHUNK_SIZE == None:
out = torch.einsum('bsid, bsje->bijde', left_act, right_act_all)
out = rearrange(out, 'b i j d e -> b i j (d e)')
out = self.o_linear(out)
Z = out / norm
else:
para_dim = left_act.shape[2]
chunk_size = CHUNK_SIZE
for ax in range(0, para_dim, chunk_size):
left_act_part = left_act[:, :, ax:ax + chunk_size, :]
O = torch.einsum('bsid,bsje->bijde', left_act_part, right_act_all)
O = rearrange(O, 'b i j d e -> b i j (d e)')
O = self.o_linear(O)
norm0 = norm[:, ax:ax + chunk_size, :, :]
Z[:, ax:ax + chunk_size, :, :] = O / norm0
return Z + Z_raw
def inplace(self, M, M_mask, Z_raw):
chunk_size = CHUNK_SIZE
if len(M.shape) == 4:
para_dim = M.shape[1]
left_act = torch.empty((M.shape[0], M.shape[1], M.shape[2], self.n_feat_proj), dtype=M.dtype, device=M.device)
right_act = torch.empty((M.shape[0], M.shape[1], M.shape[2], self.n_feat_proj), dtype=M.dtype, device=M.device)
if CHUNK_SIZE == None:
chunk_size = para_dim
else:
chunk_size = chunk_size * 32
for ax in range(0, para_dim, chunk_size):
m = self.layernormM(M[:, ax:ax + chunk_size, :, :])
right_act[:, ax:ax + chunk_size, :, :] = self.linear_b(m)
left_act[:, ax:ax + chunk_size, :, :] = self.linear_a(m)
else:
para_dim = M.shape[0]
left_act = torch.empty((M.shape[0], M.shape[1], self.n_feat_proj), dtype=M.dtype, device=M.device)
right_act = torch.empty((M.shape[0], M.shape[1], self.n_feat_proj), dtype=M.dtype, device=M.device)
if CHUNK_SIZE == None:
chunk_size = para_dim
else:
chunk_size = chunk_size * 32
for ax in range(0, para_dim, chunk_size):
m = self.layernormM(M[ax:ax + chunk_size, :, :])
right_act[ax:ax + chunk_size, :, :] = self.linear_b(m)
left_act[ax:ax + chunk_size, :, :] = self.linear_a(m)
right_act_all, work = gather_async(right_act, dim=2)
# right_act_all = gather(right_act, dim=2)
M_mask = M_mask.unsqueeze(-1)
M_mask_col = scatter(M_mask, dim=2)
left_act = M_mask_col * left_act
norm = torch.einsum('bsid,bsjd->bijd', M_mask_col, M_mask) + 1e-3
right_act_all = gather_async_opp(right_act_all, work, dim=2)
right_act_all = M_mask * right_act_all
para_dim = left_act.shape[2]
chunk_size = CHUNK_SIZE
if CHUNK_SIZE == None:
chunk_size = para_dim
for ax in range(0, para_dim, chunk_size):
left_act_part = left_act[:, :, ax:ax + chunk_size, :]
O = torch.einsum('bsid,bsje->bijde', left_act_part, right_act_all)
O = rearrange(O, 'b i j d e -> b i j (d e)')
O = self.o_linear(O)
norm0 = norm[:, ax:ax + chunk_size, :, :]
Z_raw[0][:, ax:ax + chunk_size, :, :] += O / norm0
return Z_raw
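# --- Editor's note (illustrative) ---
# Ignoring the distributed gather/scatter, OutProductMean implements AlphaFold's outer
# product mean: for left = linear_a(M) and right = linear_b(M) it computes, per batch,
#     Z[i, j, :] += o_linear( flatten_de( sum_s left[s, i, d] * right[s, j, e] ) ) / norm[i, j]
# where norm counts, per (i, j), how many sequence positions are valid under M_mask.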
class Linear(nn.Linear):
"""
A Linear layer with built-in nonstandard initializations. Called just
like torch.nn.Linear.
Implements the initializers in 1.11.4, plus some additional ones found
in the code.
"""
def __init__(
self,
feature_in: int,
feature_out: int,
initializer: str = 'linear',
use_bias: bool = True,
bias_init: float = 0.,
):
super(Linear, self).__init__(feature_in, feature_out, bias=use_bias)
self.use_bias = use_bias
if initializer == 'linear':
glorot_uniform_af(self.weight, gain=1.0)
elif initializer == 'relu':
glorot_uniform_af(self.weight, gain=2.0)
elif initializer == 'zeros':
nn.init.zeros_(self.weight)
if self.use_bias:
with torch.no_grad():
self.bias.fill_(bias_init)
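# --- Editor's note (illustrative) ---
# initializer='zeros' is used for the final projection of residual blocks (e.g.
# Transition.linear2), so each block starts out as the identity and training perturbs it
# gradually, mirroring AlphaFold's zero-init convention for final layers.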
class SelfAttention(nn.Module):
"""
Multi-Head SelfAttention dealing with [batch_size1, batch_size2, len, dim] tensors
"""
def __init__(self, qkv_dim, c, n_head, out_dim, gating=True, last_bias_fuse=False):
super(SelfAttention, self).__init__()
self.qkv_dim = qkv_dim
self.c = c
self.n_head = n_head
self.out_dim = out_dim
self.gating = gating
self.last_bias_fuse = last_bias_fuse
self.scaling = self.c**(-0.5)
self.to_qkv = Linear(qkv_dim, 3 * n_head * c, initializer='linear', use_bias=False)
# self.to_q = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
# self.to_k = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
# self.to_v = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
if gating:
self.gating_bias = nn.parameter.Parameter(data=torch.ones((n_head * c,)))
self.gating_linear = Linear(qkv_dim, n_head * c, initializer='zero', use_bias=False)
self.o_linear = Linear(n_head * c,
out_dim,
initializer='zero',
use_bias=(not last_bias_fuse))
def forward(self, in_data, mask, nonbatched_bias=None):
"""
:param in_data: [batch_size1, batch_size2, len_qkv, qkv_dim]
:param mask: None or [batch_size1, batch_size2, len_kv]
:param nonbatched_bias: None or [batch_size1, n_head, len_q, len_kv]
"""
if nonbatched_bias is not None:
if nonbatched_bias[-1] == -1:
bias = nonbatched_bias[0]
else:
# logits += nonbatched_bias.unsqueeze(1)
bias = gather_async_opp(*nonbatched_bias, dim=1)
bias = rearrange(bias, 'b q k h -> b h q k')
if CHUNK_SIZE == None:
qkv = self.to_qkv(in_data).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv)
q = q * self.scaling
logits = torch.matmul(q, k.transpose(-1, -2))
if nonbatched_bias is not None:
weights = fused_softmax(logits, mask, bias.unsqueeze(1))
else:
weights = fused_softmax(logits, mask)
weighted_avg = torch.matmul(weights, v)
weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')
if self.gating:
gate_values = self.gating_linear(in_data)
weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg)
output = self.o_linear(weighted_avg)
else:
para_dim = in_data.shape[1]
chunk_size = CHUNK_SIZE
output = []
for ax in range(0, para_dim, chunk_size):
in_data_part = in_data[:, ax:ax + chunk_size, :, :]
mask_part = mask[:, ax:ax + chunk_size, :]
qkv = self.to_qkv(in_data_part).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv)
q = q * self.scaling
logits = torch.matmul(q, k.transpose(-1, -2))
if nonbatched_bias is not None:
# logits += bias.unsqueeze(1)
# logits += (1e9 * (mask_part - 1))[..., :, None, None, :]
# weights = torch.nn.functional.softmax(logits, -1)
weights = fused_softmax(logits, mask_part, bias.unsqueeze(1))
else:
# logits += (1e9 * (mask_part - 1))[..., :, None, None, :]
# weights = torch.nn.functional.softmax(logits, -1)
weights = fused_softmax(logits, mask_part)
weighted_avg = torch.matmul(weights, v)
weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')
if self.gating:
gate_values = self.gating_linear(in_data_part)
weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg)
output.append(self.o_linear(weighted_avg))
output = torch.cat(output, dim=1)
return output
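# --- Editor's illustrative sketch (assumes CUDA + the fused kernels, no tensor parallelism,
# and CHUNK_SIZE left as None) ---
# With nonbatched_bias=None, SelfAttention needs no communication and can be smoke-tested
# stand-alone:
def _demo_self_attention():
    attn = SelfAttention(qkv_dim=64, c=16, n_head=4, out_dim=64, gating=True).cuda()
    x = torch.randn(1, 8, 32, 64, device="cuda")   # [b1, b2, len, qkv_dim]
    mask = torch.ones(1, 8, 32, device="cuda")     # [b1, b2, len], 1 = attend
    out = attn(x, mask)
    assert out.shape == x.shape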
def permute_final_dims(tensor, inds):
zero_index = -1 * len(inds)
first_inds = list(range(len(tensor.shape[:zero_index])))
return tensor.permute(first_inds + [zero_index + i for i in inds])
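# Editor's note: e.g. permute_final_dims(t, (2, 0, 1)) reorders the last three axes
# [..., a, b, c] -> [..., c, a, b] while leaving leading dimensions untouched.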
class AsyncChunkTriangleMultiplicationOutgoing(nn.Module):
def __init__(self, d_pair, p_drop, c=128):
super(AsyncChunkTriangleMultiplicationOutgoing, self).__init__()
self.d_pair = d_pair
self.c = c
self.layernorm1 = LayerNorm(d_pair)
self.left_right_projection = Linear(d_pair, 2 * c)
self.left_right_gate = Linear(d_pair, 2 * c, initializer='zeros', bias_init=1.)
self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.)
self.layernorm2 = LayerNorm(c)
self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False)
self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
self.p_drop = p_drop
def forward(self, Z_raw, Z_mask_row):
if CHUNK_SIZE == None:
Z = self.layernorm1(Z_raw)
left_right_proj_act = self.left_right_projection(Z)
left_right_proj_act = Z_mask_row.unsqueeze(-1) * left_right_proj_act
left_right_proj_act *= torch.sigmoid(self.left_right_gate(Z))
left_proj_act, right_proj_act = left_right_proj_act.chunk(2, dim=-1)
right_proj_act, work = gather_async(right_proj_act.contiguous(), dim=1)
g = torch.sigmoid(self.output_gate(Z))
left_proj_act = permute_final_dims(left_proj_act, (2, 0, 1))
right_proj_act = gather_async_opp(right_proj_act, work, dim=1)
p = torch.matmul(left_proj_act, permute_final_dims(right_proj_act, (2, 1, 0)),)
ab = permute_final_dims(p, (1, 2, 0))
ab = self.output_projection(self.layernorm2(ab))
dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
return bias_ele_dropout_residual(ab,
self.output_bias,
g,
dropout_mask,
Z_raw,
prob=self.p_drop,
training=self.training)
para_dim = Z_raw.shape[1]
chunk_size = CHUNK_SIZE * 32
world_size = get_world_size()
rank = get_rank()
output = torch.empty_like(Z_raw)
for i in range(0, para_dim, chunk_size):
if DEBUG and i > 10:
break
zi = Z_raw[:, i:i + chunk_size, :, :]
zi = self.layernorm1(zi)
gi = torch.sigmoid(self.left_right_gate(zi))
i_left_right_proj_act = self.left_right_projection(zi)
i_left_right_proj_act = Z_mask_row[:, i:i + chunk_size, :].unsqueeze(-1) * i_left_right_proj_act
i_left_right_proj_act *= gi
left_proj_act, _ = i_left_right_proj_act.chunk(2, dim=-1)
left_proj_act = permute_final_dims(left_proj_act, (2, 0, 1))
for j in range(0, para_dim, chunk_size):
zj = Z_raw[:, j:j + chunk_size, :, :]
zj = self.layernorm1(zj)
gj = torch.sigmoid(self.left_right_gate(zj))
j_left_right_proj_act = self.left_right_projection(zj)
j_left_right_proj_act = Z_mask_row[:, j:j + chunk_size, :].unsqueeze(-1) * j_left_right_proj_act
j_left_right_proj_act *= gj
_, right_proj_act = j_left_right_proj_act.chunk(2, dim=-1)
right_proj_act = right_proj_act.contiguous()
work = None
right_proj_act_tmp = torch.empty_like(right_proj_act)
for k in range(0, world_size):
if world_size > 1:
if work:
broadcast_async_opp(work) # collect last broadcast
if k != rank:
right_proj_act_rec = right_proj_act_tmp.clone()
else: # init first broadcast
if k == rank:
broadcast_sync(k, right_proj_act, host=True)
else:
right_proj_act_tmp = broadcast_sync(k, right_proj_act, host=False)
right_proj_act_rec = right_proj_act_tmp.clone()
if k + 1 != world_size: # launch next broadcast
if k + 1 == rank:
work = broadcast_async(k + 1, right_proj_act, host=True)
else:
work = broadcast_async(k + 1, right_proj_act_tmp, host=False)
if k == rank: # broadcast self right_proj_act
p = torch.matmul(
left_proj_act,
permute_final_dims(right_proj_act, (2, 1, 0)),
)
p = permute_final_dims(p, (1, 2, 0))
j_global = para_dim * k + j
output[:, i:i + chunk_size, j_global:min(j_global + chunk_size, para_dim * (k + 1)), :] = p
else: # receive others broadcast
p = torch.matmul(
left_proj_act,
permute_final_dims(right_proj_act_rec, (2, 1, 0)),
)
p = permute_final_dims(p, (1, 2, 0))
j_global = para_dim * k + j
output[:, i:i + chunk_size, j_global:min(j_global + chunk_size, para_dim * (k + 1)), :] = p
dropout_mask = torch.ones_like(Z_raw[:, 0:1, :, :], device=Z_raw.device, dtype=Z_raw.dtype)
for i in range(0, Z_raw.shape[1], chunk_size):
z_raw = Z_raw[:, i:i + chunk_size, :, :]
g = torch.sigmoid(self.output_gate(self.layernorm1(z_raw)))
z = output[:, i:i + chunk_size, :, :]
z = self.output_projection(self.layernorm2(z))
z = bias_ele_dropout_residual(z,
self.output_bias,
g,
dropout_mask,
z_raw,
prob=self.p_drop,
training=self.training)
output[:, i:i + chunk_size, :, :] = z
return output
class AsyncChunkTriangleMultiplicationIncoming(nn.Module):
def __init__(self, d_pair, p_drop, c=128):
super(AsyncChunkTriangleMultiplicationIncoming, self).__init__()
self.d_pair = d_pair
self.c = c
self.layernorm1 = LayerNorm(d_pair)
self.left_right_projection = Linear(d_pair, 2 * c)
self.left_right_gate = Linear(d_pair, 2 * c, initializer='zeros', bias_init=1.)
self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.)
self.layernorm2 = LayerNorm(c)
self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False)
self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
self.p_drop = p_drop
def forward(self, Z_raw, Z_mask_col):
if CHUNK_SIZE == None:
Z = self.layernorm1(Z_raw)
left_right_proj_act = self.left_right_projection(Z)
left_right_proj_act = Z_mask_col.unsqueeze(-1) * left_right_proj_act
left_right_proj_act *= torch.sigmoid(self.left_right_gate(Z))
left_proj_act, right_proj_act = left_right_proj_act.chunk(2, dim=-1)
left_proj_act, work = gather_async(left_proj_act.contiguous(), dim=2)
g = torch.sigmoid(self.output_gate(Z))
right_proj_act = permute_final_dims(right_proj_act, (2, 0, 1))
left_proj_act = gather_async_opp(left_proj_act, work, dim=2)
p = torch.matmul(permute_final_dims(left_proj_act, (2, 1, 0)), right_proj_act)
ab = permute_final_dims(p, (1, 2, 0))
ab = self.output_projection(self.layernorm2(ab))
dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
return bias_ele_dropout_residual(ab,
self.output_bias,
g,
dropout_mask,
Z_raw,
prob=self.p_drop,
training=self.training)
para_dim = Z_raw.shape[2]
chunk_size = CHUNK_SIZE * 32
world_size = get_world_size()
rank = get_rank()
output = torch.empty_like(Z_raw)
for i in range(0, para_dim, chunk_size):
if DEBUG and i > 10:
break
zi = Z_raw[:, :, i:i + chunk_size, :]
zi = self.layernorm1(zi)
gi = torch.sigmoid(self.left_right_gate(zi))
i_left_right_proj_act = self.left_right_projection(zi)
i_left_right_proj_act = Z_mask_col[:, :, i:i + chunk_size].unsqueeze(-1) * i_left_right_proj_act
i_left_right_proj_act *= gi
_, right_proj_act = i_left_right_proj_act.chunk(2, dim=-1)
right_proj_act = permute_final_dims(right_proj_act, (2, 0, 1))
for j in range(0, para_dim, chunk_size):
zj = Z_raw[:, :, j:j + chunk_size, :]
zj = self.layernorm1(zj)
gj = torch.sigmoid(self.left_right_gate(zj))
j_left_right_proj_act = self.left_right_projection(zj)
j_left_right_proj_act = Z_mask_col[:, :, j:j + chunk_size].unsqueeze(-1) * j_left_right_proj_act
j_left_right_proj_act *= gj
left_proj_act, _ = j_left_right_proj_act.chunk(2, dim=-1)
left_proj_act = left_proj_act.contiguous()
work = None
left_proj_act_tmp = torch.empty_like(left_proj_act)
for k in range(0, world_size):
if world_size > 1:
if work:
broadcast_async_opp(work) # collect last broadcast
if k != rank:
left_proj_act_rec = left_proj_act_tmp.clone()
else: # init first broadcast
if k == rank:
broadcast_sync(k, left_proj_act, host=True)
else:
left_proj_act_tmp = broadcast_sync(k, left_proj_act, host=False)
left_proj_act_rec = left_proj_act_tmp.clone()
if k + 1 != world_size: # launch next broadcast
if k + 1 == rank:
work = broadcast_async(k + 1, left_proj_act, host=True)
else:
work = broadcast_async(k + 1, left_proj_act_tmp, host=False)
if k == rank: # broadcast self proj_act
# left: [seq,chunkj,dim] => [dim,chunkj,seq]
# right: [seq,chunki,dim] => [dim,seq,chunki]
# p: [dim,chunkj,chunki] => [chunkj,chunki,dim]
p = torch.matmul(
permute_final_dims(left_proj_act, (2, 1, 0)),
right_proj_act
)
p = permute_final_dims(p, (1, 2, 0))
j_global = para_dim * k + j
output[:, j_global:min(j_global + chunk_size, para_dim * (k + 1)), i:i + chunk_size, :] = p
else: # receive others broadcast
p = torch.matmul(
permute_final_dims(left_proj_act_rec, (2, 1, 0)),
right_proj_act
)
p = permute_final_dims(p, (1, 2, 0))
j_global = para_dim * k + j
output[:, j_global:min(j_global + chunk_size, para_dim * (k + 1)), i:i + chunk_size, :] = p
dropout_mask = torch.ones_like(Z_raw[:, 0:1, :, :], device=Z_raw.device, dtype=Z_raw.dtype)
for i in range(0, Z_raw.shape[1], chunk_size):
z_raw = Z_raw[:, i:i + chunk_size, :, :]
g = torch.sigmoid(self.output_gate(self.layernorm1(z_raw)))
z = output[:, i:i + chunk_size, :, :]
z = self.output_projection(self.layernorm2(z))
z = bias_ele_dropout_residual(z,
self.output_bias,
g,
dropout_mask,
z_raw,
prob=self.p_drop,
training=self.training)
output[:, i:i + chunk_size, :, :] = z
return output
class ChunkTriangleAttentionStartingNode(nn.Module):
def __init__(self, d_pair, p_drop, c=32, n_head=4):
super(ChunkTriangleAttentionStartingNode, self).__init__()
self.d_pair = d_pair
self.c = c
self.n_head = n_head
self.p_drop = p_drop
self.layernorm1 = LayerNorm(d_pair)
# _init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]),
# std=1.0 / math.sqrt(d_pair))
# self.linear_b_weights = nn.parameter.Parameter(data=_init_weights)
self.linear_b = Linear(d_pair, n_head, initializer='linear', use_bias=False)
self.attention = SelfAttention(qkv_dim=d_pair,
c=c,
n_head=n_head,
out_dim=d_pair,
gating=True,
last_bias_fuse=True)
self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
def forward(self, Z_raw, Z_mask):
if CHUNK_SIZE == None:
Z = self.layernorm1(Z_raw)
b = self.linear_b(Z)
b, work = gather_async(b, dim=1)
Z = self.attention(Z, Z_mask, (b, work))
dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
return bias_dropout_add(Z,
self.out_bias,
dropout_mask,
Z_raw,
prob=self.p_drop,
training=self.training)
chunk_size = CHUNK_SIZE
para_dim = Z_raw.shape[1]
# z is large but b is small, so we compute z chunk by chunk just to obtain b, then recompute z in chunks later instead of keeping it in memory
b = torch.empty((Z_raw.shape[0], Z_raw.shape[1], Z_raw.shape[2], self.n_head), device=Z_raw.device, dtype=Z_raw.dtype)
for i in range(0, para_dim, chunk_size):
z = self.layernorm1(Z_raw[:, i:i + chunk_size, :, :])
b[:, i:i + chunk_size, :, :] = self.linear_b(z)
b, work = gather_async(b, dim=1)
b = gather_async_opp(b, work, dim=1)
b = rearrange(b, 'b q k h -> b h q k')
output = torch.empty_like(Z_raw)
dropout_mask = torch.ones_like(z[:, 0:1, :, :], device=z.device, dtype=z.dtype)
for i in range(0, para_dim, chunk_size):
if DEBUG and i > 10:
break
z_raw = Z_raw[:, i:i + chunk_size, :, :]
z = self.layernorm1(z_raw)
z_mask = Z_mask[:, i:i + chunk_size, :]
z = self.attention(z, z_mask, (b, -1))
z = bias_dropout_add(z,
self.out_bias,
dropout_mask,
z_raw,
prob=self.p_drop,
training=self.training)
output[:, i:i + chunk_size, :, :] = z
return output
def inplace(self, Z_raw, Z_mask):
        if CHUNK_SIZE is None:
Z = self.layernorm1(Z_raw[0])
b = self.linear_b(Z)
b, work = gather_async(b, dim=1)
Z = self.attention(Z, Z_mask, (b, work))
dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
Z_raw[0] = bias_dropout_add(Z,
self.out_bias,
dropout_mask,
Z_raw[0],
prob=self.p_drop,
training=self.training)
return Z_raw
chunk_size = CHUNK_SIZE
para_dim = Z_raw[0].shape[1]
        # Z is large but b is small, so we compute b from Z chunk by chunk instead of materializing the full normalized Z.
b = torch.empty((Z_raw[0].shape[0], Z_raw[0].shape[1], Z_raw[0].shape[2], self.n_head), device=Z_raw[0].device, dtype=Z_raw[0].dtype)
for i in range(0, para_dim, chunk_size):
z = self.layernorm1(Z_raw[0][:, i:i + chunk_size, :, :])
b[:, i:i + chunk_size, :, :] = self.linear_b(z)
b, work = gather_async(b, dim=1)
b = gather_async_opp(b, work, dim=1)
b = rearrange(b, 'b q k h -> b h q k')
# output = torch.empty_like(Z_raw)
dropout_mask = torch.ones_like(z[:, 0:1, :, :], device=z.device, dtype=z.dtype)
for i in range(0, para_dim, chunk_size):
if DEBUG and i > 10:
break
z_raw = Z_raw[0][:, i:i + chunk_size, :, :]
z = self.layernorm1(z_raw)
z_mask = Z_mask[:, i:i + chunk_size, :]
z = self.attention(z, z_mask, (b, -1))
z = bias_dropout_add(z,
self.out_bias,
dropout_mask,
z_raw,
prob=self.p_drop,
training=self.training)
Z_raw[0][:, i:i + chunk_size, :, :] = z
return Z_raw
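# ---------------------------------------------------------------------------
# Illustrative sketch (not part of FastFold): the "Z is large but b is small"
# comments above refer to deriving the small per-head bias b from the pair
# tensor one chunk at a time, so the normalized copy of Z is never held in
# full. A minimal, hypothetical helper showing only that pattern
# (`layernorm` / `linear_b` stand in for nn.LayerNorm / nn.Linear modules):
def _chunked_bias_sketch(layernorm, linear_b, Z_raw, chunk_size):
    import torch
    out_features = linear_b.weight.shape[0]
    b = torch.empty((*Z_raw.shape[:-1], out_features),
                    device=Z_raw.device, dtype=Z_raw.dtype)
    for i in range(0, Z_raw.shape[1], chunk_size):
        z = layernorm(Z_raw[:, i:i + chunk_size])    # normalized chunk only
        b[:, i:i + chunk_size] = linear_b(z)         # small [..., n_head] output
    return b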
class ChunkMSARowAttentionWithPairBias(nn.Module):
def __init__(self, d_node, d_pair, c=32, n_head=8, p_drop=0.15):
super(ChunkMSARowAttentionWithPairBias, self).__init__()
self.d_node = d_node
self.d_pair = d_pair
self.c = c
self.n_head = n_head
self.p_drop = p_drop
self.layernormM = LayerNorm(d_node)
self.layernormZ = LayerNorm(d_pair)
_init_weights = torch.nn.init.normal_(torch.zeros([n_head, d_pair]),
std=1.0 / math.sqrt(d_pair))
self.linear_b_weights = nn.parameter.Parameter(data=_init_weights, requires_grad=True)
self.attention = SelfAttention(qkv_dim=d_node,
c=c,
n_head=n_head,
out_dim=d_node,
gating=True,
last_bias_fuse=True)
self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_node,)), requires_grad=True)
def forward(self, M_raw, Z, M_mask):
        if CHUNK_SIZE is None:
## Input projections
M = self.layernormM(M_raw)
Z = self.layernormZ(Z)
b = F.linear(Z, self.linear_b_weights)
b, work = gather_async(b, dim=1)
# b = rearrange(b, 'b q k h -> b h q k')
# padding_bias = (1e9 * (M_mask - 1.))[:, :, None, None, :]
M = self.attention(M, M_mask, (b, work))
dropout_mask = torch.ones_like(M[:, 0:1, :, :], device=M.device, dtype=M.dtype)
return bias_dropout_add(M, self.out_bias, dropout_mask, M_raw, prob=self.p_drop, training=self.training)
chunk_size = CHUNK_SIZE
para_dim_z = Z.shape[1]
para_dim_m = M_raw.shape[1]
        # Z is large but b is small, so we compute b from Z chunk by chunk instead of materializing the full normalized Z.
b = torch.empty((Z.shape[0], Z.shape[1], Z.shape[2], self.n_head), device=Z.device, dtype=Z.dtype)
for i in range(0, para_dim_z, chunk_size):
z = self.layernormZ(Z[:, i:i + chunk_size, :, :])
b[:, i:i + chunk_size, :, :] = F.linear(z, self.linear_b_weights)
b, work = gather_async(b, dim=1)
b = gather_async_opp(b, work, dim=1)
b = rearrange(b, 'b q k h -> b h q k')
output = torch.empty_like(M_raw)
dropout_mask = torch.ones_like(M_raw[:, 0:1, :, :], device=M_raw.device, dtype=M_raw.dtype)
for i in range(0, para_dim_m, chunk_size):
if DEBUG and i > 10:
break
m_raw = M_raw[:, i:i + chunk_size, :, :]
m = self.layernormM(m_raw)
m_mask = M_mask[:, i:i + chunk_size, :]
m = self.attention(m, m_mask, (b, -1))
m = bias_dropout_add(m,
self.out_bias,
dropout_mask,
m_raw,
prob=self.p_drop,
training=self.training)
output[:, i:i + chunk_size, :, :] = m
return output
def inplace(self, M_raw, Z, M_mask):
        if CHUNK_SIZE is None:
## Input projections
M = self.layernormM(M_raw[0])
Z = self.layernormZ(Z)
b = F.linear(Z, self.linear_b_weights)
b, work = gather_async(b, dim=1)
# b = rearrange(b, 'b q k h -> b h q k')
# padding_bias = (1e9 * (M_mask - 1.))[:, :, None, None, :]
M = self.attention(M, M_mask, (b, work))
dropout_mask = torch.ones_like(M[:, 0:1, :, :], device=M.device, dtype=M.dtype)
M_raw[0] = bias_dropout_add(M, self.out_bias, dropout_mask, M_raw[0], prob=self.p_drop, training=self.training)
return M_raw
chunk_size = CHUNK_SIZE
para_dim_z = Z.shape[1]
para_dim_m = M_raw[0].shape[1]
        # Z is large but b is small, so we compute b from Z chunk by chunk instead of materializing the full normalized Z.
b = torch.empty((Z.shape[0], Z.shape[1], Z.shape[2], self.n_head), device=Z.device, dtype=Z.dtype)
for i in range(0, para_dim_z, chunk_size):
z = self.layernormZ(Z[:, i:i + chunk_size, :, :])
b[:, i:i + chunk_size, :, :] = F.linear(z, self.linear_b_weights)
b, work = gather_async(b, dim=1)
b = gather_async_opp(b, work, dim=1)
b = rearrange(b, 'b q k h -> b h q k')
dropout_mask = torch.ones_like(M_raw[0][:, 0:1, :, :], device=M_raw[0].device, dtype=M_raw[0].dtype)
for i in range(0, para_dim_m, chunk_size):
if DEBUG and i > 10:
break
m_raw = M_raw[0][:, i:i + chunk_size, :, :]
m = self.layernormM(m_raw)
m_mask = M_mask[:, i:i + chunk_size, :]
m = self.attention(m, m_mask, (b, -1))
m = bias_dropout_add(m,
self.out_bias,
dropout_mask,
m_raw,
prob=self.p_drop,
training=self.training)
M_raw[0][:, i:i + chunk_size, :, :] = m
return M_raw
class ChunkTriangleAttentionEndingNode(nn.Module):
def __init__(self, d_pair, p_drop, c=32, n_head=4):
super(ChunkTriangleAttentionEndingNode, self).__init__()
self.d_pair = d_pair
self.c = c
self.n_head = n_head
self.p_drop = p_drop
self.layernorm1 = LayerNorm(d_pair)
self.linear_b = Linear(d_pair, n_head, initializer='linear', use_bias=False)
self.attention = SelfAttention(qkv_dim=d_pair,
c=c,
n_head=n_head,
out_dim=d_pair,
gating=True,
last_bias_fuse=True)
self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
def forward(self, Z_raw, Z_mask):
        if CHUNK_SIZE is None:
Z = Z_raw.transpose(-2, -3)
Z_mask = Z_mask.transpose(-1, -2)
Z = self.layernorm1(Z)
b = self.linear_b(Z)
b, work = gather_async(b, dim=1)
Z = self.attention(Z, Z_mask, (b, work))
Z = Z.transpose(-2, -3)
dropout_mask = torch.ones_like(Z[:, :, 0:1, :], device=Z.device, dtype=Z.dtype)
return bias_dropout_add(Z,
self.out_bias,
dropout_mask,
Z_raw,
prob=self.p_drop,
training=self.training)
para_dim = Z_raw.shape[2]
chunk_size = CHUNK_SIZE
        # Z is large but b is small, so we compute b from Z chunk by chunk instead of materializing the full normalized Z.
b = torch.empty((Z_raw.shape[0], Z_raw.shape[2], Z_raw.shape[1], self.n_head), device=Z_raw.device, dtype=Z_raw.dtype)
for i in range(0, para_dim, chunk_size):
z = Z_raw[:, :, i:i + chunk_size, :].transpose(-2, -3)
z = self.layernorm1(z)
b[:, i:i + chunk_size, :, :] = self.linear_b(z)
b, work = gather_async(b, dim=1)
b = gather_async_opp(b, work, dim=1)
b = rearrange(b, 'b q k h -> b h q k')
output = torch.empty_like(Z_raw)
        dropout_mask = torch.ones_like(Z_raw[:, :, 0:1, :], device=Z_raw.device, dtype=Z_raw.dtype)
for i in range(0, para_dim, chunk_size):
if DEBUG and i > 10:
break
z_raw = Z_raw[:, :, i:i + chunk_size, :]
z = self.layernorm1(z_raw.transpose(-2, -3))
z_mask = Z_mask[:, :, i:i + chunk_size].transpose(-1, -2)
z = self.attention(z, z_mask, (b, -1)).transpose(-2, -3)
z = bias_dropout_add(z,
self.out_bias,
dropout_mask,
z_raw,
prob=self.p_drop,
training=self.training)
output[:, :, i:i + chunk_size, :] = z
return output
def inplace(self, Z_raw, Z_mask):
        if CHUNK_SIZE is None:
Z = Z_raw[0].transpose(-2, -3)
Z_mask = Z_mask.transpose(-1, -2)
Z = self.layernorm1(Z)
b = self.linear_b(Z)
b, work = gather_async(b, dim=1)
Z = self.attention(Z, Z_mask, (b, work))
Z = Z.transpose(-2, -3)
dropout_mask = torch.ones_like(Z[:, :, 0:1, :], device=Z.device, dtype=Z.dtype)
Z_raw[0] = bias_dropout_add(Z,
self.out_bias,
dropout_mask,
Z_raw[0],
prob=self.p_drop,
training=self.training)
return Z_raw
para_dim = Z_raw[0].shape[2]
chunk_size = CHUNK_SIZE
        # Z is large but b is small, so we compute b from Z chunk by chunk instead of materializing the full normalized Z.
b = torch.empty((Z_raw[0].shape[0], Z_raw[0].shape[2], Z_raw[0].shape[1], self.n_head), device=Z_raw[0].device, dtype=Z_raw[0].dtype)
for i in range(0, para_dim, chunk_size):
z = Z_raw[0][:, :, i:i + chunk_size, :].transpose(-2, -3)
z = self.layernorm1(z)
b[:, i:i + chunk_size, :, :] = self.linear_b(z)
b, work = gather_async(b, dim=1)
b = gather_async_opp(b, work, dim=1)
b = rearrange(b, 'b q k h -> b h q k')
        dropout_mask = torch.ones_like(Z_raw[0][:, :, 0:1, :], device=Z_raw[0].device, dtype=Z_raw[0].dtype)
for i in range(0, para_dim, chunk_size):
if DEBUG and i > 10:
break
z_raw = Z_raw[0][:, :, i:i + chunk_size, :]
z = self.layernorm1(z_raw.transpose(-2, -3))
z_mask = Z_mask[:, :, i:i + chunk_size].transpose(-1, -2)
z = self.attention(z, z_mask, (b, -1)).transpose(-2, -3)
z = bias_dropout_add(z,
self.out_bias,
dropout_mask,
z_raw,
prob=self.p_drop,
training=self.training)
Z_raw[0][:, :, i:i + chunk_size, :] = z
return Z_raw
class ChunkMSAColumnGlobalAttention(nn.Module):
def __init__(self, d_node, c=8, n_head=8):
super(ChunkMSAColumnGlobalAttention, self).__init__()
self.d_node = d_node
self.c = c
self.n_head = n_head
self.layernormM = LayerNorm(d_node)
self.global_attention = GlobalAttention(
qkv_dim=d_node, c=c, n_head=n_head, out_dim=d_node
)
def forward(self, M_raw, M_mask):
if CHUNK_SIZE is None:
m = self.layernormM(M_raw.transpose(-2, -3))
m = self.global_attention(m, M_mask.transpose(-1, -2))
m = m.transpose(-2, -3)
M_raw = M_raw + m
else:
chunk_size = CHUNK_SIZE
para_dim = M_raw.shape[2]
for i in range(0, para_dim, chunk_size):
m = M_raw[:, :, i:i + chunk_size, :].transpose(-2, -3)
m = self.layernormM(m)
m_mask = M_mask[:, :, i:i + chunk_size].transpose(-1, -2)
m = self.global_attention(m, m_mask)
m = m.transpose(-2, -3)
M_raw[:, :, i:i + chunk_size, :] += m
return M_raw
def inplace(self, M_raw, M_mask):
para_dim = M_raw[0].shape[2]
if CHUNK_SIZE is None:
chunk_size = para_dim
else:
chunk_size = CHUNK_SIZE
for i in range(0, para_dim, chunk_size):
if DEBUG and i > 10:
break
m = M_raw[0][:, :, i:i + chunk_size, :].transpose(-2, -3)
m = self.layernormM(m)
m_mask = M_mask[:, :, i:i + chunk_size].transpose(-1, -2)
m = self.global_attention(m, m_mask)
m = m.transpose(-2, -3)
M_raw[0][:, :, i:i + chunk_size, :] += m
return M_raw
class RecyclingEmbedder(nn.Module):
"""
Embeds the output of an iteration of the model for recycling.
Implements Algorithm 32.
"""
def __init__(
self,
c_m: int,
c_z: int,
min_bin: float,
max_bin: float,
no_bins: int,
inf: float = 1e8,
**kwargs,
):
"""
Args:
c_m:
MSA channel dimension
c_z:
Pair embedding channel dimension
min_bin:
Smallest distogram bin (Angstroms)
max_bin:
Largest distogram bin (Angstroms)
no_bins:
Number of distogram bins
"""
super(RecyclingEmbedder, self).__init__()
self.c_m = c_m
self.c_z = c_z
self.min_bin = min_bin
self.max_bin = max_bin
self.no_bins = no_bins
self.inf = inf
self.linear = Linear(self.no_bins, self.c_z)
self.layer_norm_m = LayerNorm(self.c_m)
self.layer_norm_z = LayerNorm(self.c_z)
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
x: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
m:
First row of the MSA embedding. [*, N_res, C_m]
z:
[*, N_res, N_res, C_z] pair embedding
x:
[*, N_res, 3] predicted C_beta coordinates
Returns:
m:
[*, N_res, C_m] MSA embedding update
z:
[*, N_res, N_res, C_z] pair embedding update
"""
bins = torch.linspace(
self.min_bin,
self.max_bin,
self.no_bins,
dtype=x.dtype,
device=x.device,
requires_grad=False,
)
# [*, N, C_m]
m_update = self.layer_norm_m(m)
# This squared method might become problematic in FP16 mode.
# I'm using it because my homegrown method had a stubborn discrepancy I
# couldn't find in time.
squared_bins = bins ** 2
upper = torch.cat(
[squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1
)
d = torch.sum(
(x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdims=True
)
# [*, N, N, no_bins]
d = ((d > squared_bins) * (d < upper)).type(x.dtype)
        if CHUNK_SIZE is None:
d = self.linear(d)
z = d + self.layer_norm_z(z)
else:
chunk_size = CHUNK_SIZE * 48
para_dim = d.shape[1]
for i in range(0, para_dim, chunk_size):
di = self.linear(d[i:i + chunk_size, :, :])
z[i:i + chunk_size, :, :] = di + self.layer_norm_z(z[i:i + chunk_size, :, :])
return m_update, z
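# ---------------------------------------------------------------------------
# Illustrative sketch (not part of FastFold): the distogram computed above
# one-hot encodes squared pairwise distances into half-open squared bins plus
# an overflow bin capped at `inf`. A standalone toy version of the same
# binning (the default bin values here are assumptions, not FastFold config):
def _distogram_one_hot_sketch(x, min_bin=3.25, max_bin=20.75, no_bins=15, inf=1e8):
    import torch
    bins = torch.linspace(min_bin, max_bin, no_bins, dtype=x.dtype, device=x.device)
    squared_bins = bins ** 2
    upper = torch.cat([squared_bins[1:], squared_bins.new_tensor([inf])], dim=-1)
    # [*, N, N, 1] squared pairwise distances of the [*, N, 3] coordinates
    d = torch.sum((x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdim=True)
    # [*, N, N, no_bins]: at most one bin is hot per pair
    # (all zeros when d falls below the first bin, e.g. on the diagonal)
    return ((d > squared_bins) * (d < upper)).type(x.dtype)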
class GlobalAttention(nn.Module):
"""
    Multi-head global self-attention over [batch_size1, batch_size2, len, dim] tensors, using a single mask-averaged query per row
"""
def __init__(self, qkv_dim, c, n_head, out_dim):
super(GlobalAttention, self).__init__()
self.qkv_dim = qkv_dim
self.c = c
self.n_head = n_head
self.out_dim = out_dim
self.scaling = self.c ** (-0.5)
self.eps = 1e-10
self.inf = 1e9
self.to_q = Linear(qkv_dim, c * self.n_head, use_bias=False)
self.to_kv = Linear(qkv_dim, 2 * c, initializer="linear", use_bias=False)
self.gating_bias = nn.parameter.Parameter(data=torch.ones((n_head * c,)))
self.gating_linear = Linear(
qkv_dim, n_head * c, initializer="zero", use_bias=False
)
self.o_linear = Linear(n_head * c, out_dim, initializer="zero")
def forward(self, m, mask):
        if CHUNK_SIZE is None:
q = torch.sum(m * mask.unsqueeze(-1), dim=-2) / (
torch.sum(mask, dim=-1)[..., None] + self.eps
)
q = q * self.scaling
q = self.to_q(q)
q = q.view(q.shape[:-1] + (self.n_head, -1))
k, v = self.to_kv(m).chunk(2, dim=-1)
logits = torch.matmul(q, k.transpose(-1, -2))
weights = fused_softmax(logits, mask)
weighted_avg = torch.matmul(weights, v)
weighted_avg = rearrange(weighted_avg, "b1 b2 h d -> b1 b2 (h d)")
gate_values = self.gating_linear(m)
weighted_avg = bias_sigmod_ele(
gate_values, self.gating_bias, weighted_avg.unsqueeze(-2)
)
m = self.o_linear(weighted_avg)
else:
para_dim = m.shape[1]
chunk_size = CHUNK_SIZE
output = []
for ax in range(0, para_dim, chunk_size):
m_part = m[:, ax : ax + chunk_size, :, :]
mask_part = mask[:, ax : ax + chunk_size, :]
q = torch.sum(m_part * mask_part.unsqueeze(-1), dim=-2) / (
torch.sum(mask_part, dim=-1)[..., None] + self.eps
)
q = q * self.scaling
q = self.to_q(q)
q = q.view(q.shape[:-1] + (self.n_head, -1))
k, v = self.to_kv(m_part).chunk(2, dim=-1)
logits = torch.matmul(q, k.transpose(-1, -2))
weights = fused_softmax(logits, mask_part)
weighted_avg = torch.matmul(weights, v)
weighted_avg = rearrange(weighted_avg, "b1 b2 h d -> b1 b2 (h d)")
gate_values = self.gating_linear(m_part)
weighted_avg = bias_sigmod_ele(
gate_values, self.gating_bias, weighted_avg.unsqueeze(-2)
)
output.append(self.o_linear(weighted_avg))
m = torch.cat(output, dim=1)
return m
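# ---------------------------------------------------------------------------
# Illustrative sketch (not part of FastFold): GlobalAttention collapses the
# query dimension by taking a mask-weighted mean over the sequence axis, so a
# single query per row attends to all keys. The masked mean on its own, as a
# hypothetical standalone helper:
def _masked_mean_query_sketch(m, mask, eps=1e-10):
    import torch
    # m:    [..., len, dim] values
    # mask: [..., len] with 1.0 for valid positions and 0.0 for padding
    return torch.sum(m * mask.unsqueeze(-1), dim=-2) / (
        torch.sum(mask, dim=-1)[..., None] + eps
    )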
class InputEmbedder(nn.Module):
"""
Embeds a subset of the input features.
Implements Algorithms 3 (InputEmbedder) and 4 (relpos).
"""
def __init__(
self,
tf_dim: int,
msa_dim: int,
c_z: int,
c_m: int,
relpos_k: int,
**kwargs,
):
"""
Args:
tf_dim:
Final dimension of the target features
msa_dim:
Final dimension of the MSA features
c_z:
Pair embedding dimension
c_m:
MSA embedding dimension
relpos_k:
Window size used in relative positional encoding
"""
super(InputEmbedder, self).__init__()
self.tf_dim = tf_dim
self.msa_dim = msa_dim
self.c_z = c_z
self.c_m = c_m
self.linear_tf_z_i = Linear(tf_dim, c_z)
self.linear_tf_z_j = Linear(tf_dim, c_z)
self.linear_tf_m = Linear(tf_dim, c_m)
self.linear_msa_m = Linear(msa_dim, c_m)
# RPE stuff
self.relpos_k = relpos_k
self.no_bins = 2 * relpos_k + 1
self.linear_relpos = Linear(self.no_bins, c_z)
def forward(
self,
tf: torch.Tensor,
ri: torch.Tensor,
msa: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
tf:
"target_feat" features of shape [*, N_res, tf_dim]
ri:
"residue_index" features of shape [*, N_res]
msa:
"msa_feat" features of shape [*, N_clust, N_res, msa_dim]
Returns:
msa_emb:
[*, N_clust, N_res, C_m] MSA embedding
pair_emb:
[*, N_res, N_res, C_z] pair embedding
"""
# [*, N_res, c_z]
tf_emb_i = self.linear_tf_z_i(tf)
tf_emb_j = self.linear_tf_z_j(tf)
# [*, N_res, N_res, c_z]
ri = ri.type(tf_emb_i.dtype)
d = ri[..., None] - ri[..., None, :]
boundaries = torch.arange(
start=-self.relpos_k, end=self.relpos_k + 1, device=d.device
)
reshaped_bins = boundaries.view(((1,) * len(d.shape)) + (len(boundaries),))
pair_emb = d[..., None] - reshaped_bins
pair_emb = torch.argmin(torch.abs(pair_emb), dim=-1)
pair_emb = nn.functional.one_hot(pair_emb, num_classes=len(boundaries)).float().type(ri.dtype)
pair_emb = self.linear_relpos(pair_emb)
pair_emb += tf_emb_i[..., None, :]
pair_emb += tf_emb_j[..., None, :, :]
# [*, N_clust, N_res, c_m]
n_clust = msa.shape[-3]
tf_m = (
self.linear_tf_m(tf)
.unsqueeze(-3)
.expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))
)
msa_emb = self.linear_msa_m(msa) + tf_m
return msa_emb, pair_emb
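# ---------------------------------------------------------------------------
# Illustrative sketch (not part of FastFold): relpos (Algorithm 4) clips the
# residue-index offset i - j to [-relpos_k, relpos_k] and one-hot encodes it
# into 2 * relpos_k + 1 bins; the argmin over |offset - boundary| used above
# is equivalent to clamping for integer offsets. A toy, hypothetical version:
def _relpos_one_hot_sketch(ri, relpos_k):
    import torch
    d = ri[..., None] - ri[..., None, :]                      # [*, N, N] offsets
    boundaries = torch.arange(-relpos_k, relpos_k + 1, device=d.device)
    idx = torch.argmin(torch.abs(d[..., None] - boundaries), dim=-1)
    return torch.nn.functional.one_hot(idx, num_classes=2 * relpos_k + 1)

# e.g. _relpos_one_hot_sketch(torch.arange(5), relpos_k=2) has shape [5, 5, 5]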
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Optional, List
import torch
import torch.nn as nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from fastfold.model.nn.primitives import Attention
from fastfold.utils.checkpointing import checkpoint_blocks
from fastfold.utils.tensor_utils import chunk_layer, permute_final_dims
from fastfold.model.fastnn.ops import (ChunkTransition, LayerNorm,
ChunkTriangleAttentionStartingNode, ChunkTriangleAttentionEndingNode,
AsyncChunkTriangleMultiplicationOutgoing, AsyncChunkTriangleMultiplicationIncoming)
from fastfold.distributed.comm import gather, scatter, col_to_row, row_to_col, scatter
class TemplatePointwiseAttention(nn.Module):
"""
Implements Algorithm 17.
"""
def __init__(self, c_t, c_z, c_hidden, no_heads, inf, **kwargs):
"""
Args:
c_t:
Template embedding channel dimension
c_z:
Pair embedding channel dimension
c_hidden:
Hidden channel dimension
"""
super(TemplatePointwiseAttention, self).__init__()
self.c_t = c_t
self.c_z = c_z
self.c_hidden = c_hidden
self.no_heads = no_heads
self.inf = inf
self.mha = Attention(
self.c_z,
self.c_t,
self.c_t,
self.c_hidden,
self.no_heads,
gating=False
)
def _chunk(self,
z: torch.Tensor,
t: torch.Tensor,
biases: List[torch.Tensor],
chunk_size: int,
) -> torch.Tensor:
mha_inputs = {
"q_x": z,
"kv_x": t,
"biases": biases,
}
return chunk_layer(
self.mha,
mha_inputs,
chunk_size=chunk_size,
no_batch_dims=len(z.shape[:-2]),
)
def forward(self,
t: torch.Tensor,
z: torch.Tensor,
template_mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None
) -> torch.Tensor:
"""
Args:
t:
[*, N_templ, N_res, N_res, C_t] template embedding
z:
                [*, N_res, N_res, C_z] pair embedding
template_mask:
[*, N_templ] template mask
Returns:
[*, N_res, N_res, C_z] pair embedding update
"""
if template_mask is None:
template_mask = t.new_ones(t.shape[:-3])
bias = self.inf * (template_mask[..., None, None, None, None, :] - 1)
# [*, N_res, N_res, 1, C_z]
z = z.unsqueeze(-2)
# [*, N_res, N_res, N_temp, C_t]
t = permute_final_dims(t, (1, 2, 0, 3))
# [*, N_res, N_res, 1, C_z]
biases = [bias]
if chunk_size is not None:
out = torch.empty_like(z)
mask = torch.sum(template_mask.to(z.device)) > 0
for t0 in range(t.shape[0]):
for t1 in range(0, t.shape[1], chunk_size):
tt = t[t0, t1:t1 + chunk_size, :].unsqueeze(0)
tt = tt.to(z.device)
out[t0, t1:t1 + chunk_size, :] = self.mha(
q_x=z[t0, t1:t1 + chunk_size, :].unsqueeze(0),
kv_x=tt,
biases=biases
).squeeze(0) * mask
else:
out = self.mha(q_x=z, kv_x=t, biases=biases)
# [*, N_res, N_res, C_z]
out = out * (torch.sum(template_mask) > 0)
out = out.squeeze(-2)
return out
def inplace(self,
t: torch.Tensor,
z: torch.Tensor,
template_mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None
) -> torch.Tensor:
"""
Args:
t:
[*, N_templ, N_res, N_res, C_t] template embedding
z:
                [*, N_res, N_res, C_z] pair embedding
template_mask:
[*, N_templ] template mask
Returns:
[*, N_res, N_res, C_z] pair embedding update
"""
if template_mask is None:
template_mask = t.new_ones(t.shape[:-3])
bias = self.inf * (template_mask[..., None, None, None, None, :] - 1)
# [*, N_res, N_res, 1, C_z]
z = z.unsqueeze(-2)
# [*, N_res, N_res, N_temp, C_t]
t = permute_final_dims(t, (1, 2, 0, 3))
# [*, N_res, N_res, 1, C_z]
biases = [bias]
if chunk_size is not None:
mask = torch.sum(template_mask.to(z.device)) > 0
for t0 in range(t.shape[0]):
for t1 in range(0, t.shape[1], chunk_size):
tt = t[t0, t1:t1 + chunk_size, :].unsqueeze(0)
tt = tt.to(z.device)
z[t0, t1:t1 + chunk_size, :] += self.mha(
q_x=z[t0, t1:t1 + chunk_size, :].unsqueeze(0),
kv_x=tt,
biases=biases
).squeeze(0) * mask
else:
t = self.mha(q_x=z, kv_x=t, biases=biases) * (torch.sum(template_mask) > 0)
# [*, N_res, N_res, C_z]
z += t
z = z.squeeze(-2)
return z
class TemplatePairBlock(nn.Module):
def __init__(
self,
c_t: int,
c_hidden_tri_att: int,
c_hidden_tri_mul: int,
no_heads: int,
pair_transition_n: int,
dropout_rate: float,
inf: float,
first_block: bool,
last_block: bool,
**kwargs,
):
super(TemplatePairBlock, self).__init__()
self.first_block = first_block
self.last_block = last_block
self.c_t = c_t
self.c_hidden_tri_att = c_hidden_tri_att
self.c_hidden_tri_mul = c_hidden_tri_mul
self.n_head = no_heads
self.p_drop = dropout_rate
self.hidden_c = int(c_t / self.n_head)
self.TriangleMultiplicationOutgoing = AsyncChunkTriangleMultiplicationOutgoing(
self.c_t, p_drop=self.p_drop, c=self.c_hidden_tri_mul
)
self.TriangleMultiplicationIncoming = AsyncChunkTriangleMultiplicationIncoming(
self.c_t, p_drop=self.p_drop, c=self.c_hidden_tri_mul
)
self.TriangleAttentionStartingNode = ChunkTriangleAttentionStartingNode(
self.c_t, p_drop=self.p_drop, c=self.c_hidden_tri_att, n_head=self.n_head
)
self.TriangleAttentionEndingNode = ChunkTriangleAttentionEndingNode(
self.c_t, p_drop=self.p_drop, c=self.c_hidden_tri_att, n_head=self.n_head
)
self.PairTransition = ChunkTransition(d=self.c_t, n=pair_transition_n)
def forward(
self,
z: torch.Tensor,
mask: torch.Tensor,
chunk_size: Optional[int] = None,
_mask_trans: bool = True,
):
dap_size = gpc.get_world_size(ParallelMode.TENSOR)
seq_length = mask.size(-1)
padding_size = (int(seq_length / dap_size) + 1) * dap_size - seq_length
if self.first_block:
z = torch.nn.functional.pad(z, (0, 0, 0, padding_size, 0, padding_size))
z = scatter(z, dim=1)
mask = torch.nn.functional.pad(mask, (0, padding_size, 0, padding_size))
single_mask_row = scatter(mask, dim=1)
single_mask_col = scatter(mask, dim=2)
z = self.TriangleAttentionStartingNode(z, single_mask_row)
z = row_to_col(z)
z = self.TriangleAttentionEndingNode(z, single_mask_col)
z = col_to_row(z)
z = self.TriangleMultiplicationOutgoing(z, single_mask_row)
z = row_to_col(z)
z = self.TriangleMultiplicationIncoming(z, single_mask_col)
z = self.PairTransition(z)
z = col_to_row(z)
# z = torch.cat(single_templates, dim=-4)
if self.last_block:
z = gather(z, dim=1)
z = z[:, :-padding_size, :-padding_size, :]
return z
def inplace(
self,
z: torch.Tensor,
mask: torch.Tensor,
chunk_size: Optional[int] = None,
_mask_trans: bool = True,
):
dap_size = gpc.get_world_size(ParallelMode.TENSOR)
seq_length = mask.size(-1)
padding_size = (int(seq_length / dap_size) + 1) * dap_size - seq_length
if self.first_block:
z[0] = torch.nn.functional.pad(z[0], (0, 0, 0, padding_size, 0, padding_size))
z[0] = scatter(z[0], dim=1)
mask = torch.nn.functional.pad(mask, (0, padding_size, 0, padding_size))
single_mask_row = scatter(mask, dim=1)
single_mask_col = scatter(mask, dim=2)
z = self.TriangleAttentionStartingNode.inplace(z, single_mask_row)
z[0] = row_to_col(z[0])
z = self.TriangleAttentionEndingNode.inplace(z, single_mask_col)
z[0] = col_to_row(z[0])
z[0] = self.TriangleMultiplicationOutgoing(z[0], single_mask_row)
z[0] = row_to_col(z[0])
z[0] = self.TriangleMultiplicationIncoming(z[0], single_mask_col)
z = self.PairTransition.inplace(z)
z[0] = col_to_row(z[0])
# z = torch.cat(single_templates, dim=-4)
if self.last_block:
z[0] = gather(z[0], dim=1)
z[0] = z[0][:, :-padding_size, :-padding_size, :]
return z
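# ---------------------------------------------------------------------------
# Note on the padding above (illustrative, not part of FastFold): the block
# pads the sequence length up to a multiple of the tensor-parallel world size
# and, because of the `+ 1`, always adds at least one extra position even when
# the length is already divisible, so `padding_size` is never 0 and the final
# `[:-padding_size]` slice in the last block stays well-defined.
def _dap_padding_size_sketch(seq_length, dap_size):
    return (seq_length // dap_size + 1) * dap_size - seq_length

# e.g. _dap_padding_size_sketch(250, 4) == 2 and _dap_padding_size_sketch(256, 4) == 4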
class TemplatePairStack(nn.Module):
"""
Implements Algorithm 16.
"""
def __init__(
self,
c_t,
c_hidden_tri_att,
c_hidden_tri_mul,
no_blocks,
no_heads,
pair_transition_n,
dropout_rate,
blocks_per_ckpt,
inf=1e9,
**kwargs,
):
"""
Args:
c_t:
Template embedding channel dimension
c_hidden_tri_att:
Per-head hidden dimension for triangular attention
            c_hidden_tri_mul:
Hidden dimension for triangular multiplication
no_blocks:
Number of blocks in the stack
pair_transition_n:
Scale of pair transition (Alg. 15) hidden dimension
dropout_rate:
Dropout rate used throughout the stack
blocks_per_ckpt:
Number of blocks per activation checkpoint. None disables
activation checkpointing
"""
super(TemplatePairStack, self).__init__()
self.blocks_per_ckpt = blocks_per_ckpt
self.blocks = nn.ModuleList()
for block_id in range(no_blocks):
block = TemplatePairBlock(
c_t=c_t,
c_hidden_tri_att=c_hidden_tri_att,
c_hidden_tri_mul=c_hidden_tri_mul,
no_heads=no_heads,
pair_transition_n=pair_transition_n,
dropout_rate=dropout_rate,
inf=inf,
first_block=(block_id == 0),
last_block=(block_id == no_blocks - 1),
)
self.blocks.append(block)
self.layer_norm = LayerNorm(c_t)
def forward(
self,
        t: torch.Tensor,
        mask: torch.Tensor,
chunk_size: int,
_mask_trans: bool = True,
):
"""
Args:
t:
[*, N_templ, N_res, N_res, C_t] template embedding
mask:
[*, N_templ, N_res, N_res] mask
Returns:
[*, N_templ, N_res, N_res, C_t] template embedding update
"""
if(mask.shape[-3] == 1):
expand_idx = list(mask.shape)
expand_idx[-3] = t.shape[-4]
mask = mask.expand(*expand_idx)
t, = checkpoint_blocks(
blocks=[
partial(
b,
mask=mask,
chunk_size=chunk_size,
_mask_trans=_mask_trans,
)
for b in self.blocks
],
args=(t,),
blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
)
if not self.training:
for i in range(0, t.shape[0]):
t[i] = self.layer_norm(t[i])
else:
t = self.layer_norm(t)
return t
def inplace(
self,
        t: torch.Tensor,
        mask: torch.Tensor,
chunk_size: int,
_mask_trans: bool = True,
):
"""
Args:
t:
[*, N_templ, N_res, N_res, C_t] template embedding
mask:
[*, N_templ, N_res, N_res] mask
Returns:
[*, N_templ, N_res, N_res, C_t] template embedding update
"""
if(mask.shape[-3] == 1):
expand_idx = list(mask.shape)
expand_idx[-3] = t[0].shape[-4]
mask = mask.expand(*expand_idx)
t, = checkpoint_blocks(
blocks=[
partial(
b.inplace,
mask=mask,
chunk_size=chunk_size,
_mask_trans=_mask_trans,
)
for b in self.blocks
],
args=(t,),
blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
)
for i in range(0, t[0].shape[0]):
t[0][i] = self.layer_norm(t[0][i].to(mask.device)).to(t[0].device)
return t
from fastfold.distributed.comm_async import gather_async
import torch
import torch.nn as nn
from fastfold.model.fastnn.kernel import LayerNorm
from fastfold.distributed.comm import col_to_row, row_to_col, scatter
from fastfold.model.fastnn.kernel import bias_dropout_add, bias_ele_dropout_residual
from fastfold.model.fastnn.ops import (Linear, SelfAttention, ChunkTransition,
ChunkTriangleAttentionStartingNode,
ChunkTriangleAttentionEndingNode,
AsyncChunkTriangleMultiplicationOutgoing,
AsyncChunkTriangleMultiplicationIncoming)
from fastfold.distributed.comm_async import gather_async_opp, gather_async
def permute_final_dims(tensor, inds):
zero_index = -1 * len(inds)
first_inds = list(range(len(tensor.shape[:zero_index])))
return tensor.permute(first_inds + [zero_index + i for i in inds])
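# ---------------------------------------------------------------------------
# Tiny usage example (not part of FastFold): permute_final_dims reorders only
# the trailing dimensions and leaves any leading batch dimensions untouched.
def _permute_final_dims_example():
    import torch
    x = torch.randn(2, 5, 7, 3)                    # [B, I, K, D]
    return permute_final_dims(x, (2, 0, 1)).shape  # torch.Size([2, 3, 5, 7])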
class TriangleMultiplicationOutgoing(nn.Module):
def __init__(self, d_pair, p_drop, c=128):
super(TriangleMultiplicationOutgoing, self).__init__()
self.d_pair = d_pair
self.c = c
self.layernorm1 = LayerNorm(d_pair)
self.left_right_projection = Linear(d_pair, 2 * c)
self.left_right_gate = Linear(d_pair, 2 * c, initializer='zeros', bias_init=1.)
self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.)
self.layernorm2 = LayerNorm(c)
self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False)
self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
self.p_drop = p_drop
def forward(self, Z_raw, Z_mask_row):
Z = self.layernorm1(Z_raw)
left_right_proj_act = self.left_right_projection(Z)
left_right_proj_act = Z_mask_row.unsqueeze(-1) * left_right_proj_act
left_right_proj_act *= torch.sigmoid(self.left_right_gate(Z))
left_proj_act, right_proj_act = left_right_proj_act.chunk(2, dim=-1)
# right_proj_act = gather(right_proj_act.contiguous(), dim=1)
right_proj_act, work = gather_async(right_proj_act.contiguous(), dim=1)
g = torch.sigmoid(self.output_gate(Z))
left_proj_act = permute_final_dims(left_proj_act, (2, 0, 1))
right_proj_act = gather_async_opp(right_proj_act, work, dim=1)
p = torch.matmul(
left_proj_act,
permute_final_dims(right_proj_act, (2, 1, 0)),
)
ab = permute_final_dims(p, (1, 2, 0))
# ab = torch.einsum('bikd,bjkd->bijd', left_proj_act, right_proj_act)
ab = self.output_projection(self.layernorm2(ab))
dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
return bias_ele_dropout_residual(ab,
self.output_bias,
g,
dropout_mask,
Z_raw,
prob=self.p_drop,
training=self.training)
class TriangleMultiplicationIncoming(nn.Module):
def __init__(self, d_pair, p_drop, c=128):
super(TriangleMultiplicationIncoming, self).__init__()
self.d_pair = d_pair
self.c = c
self.layernorm1 = LayerNorm(d_pair)
self.left_right_projection = Linear(d_pair, 2 * c)
self.left_right_gate = Linear(d_pair, 2 * c, initializer='zeros', bias_init=1.)
self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.)
self.layernorm2 = LayerNorm(c)
self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False)
self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
self.p_drop = p_drop
def forward(self, Z_raw, Z_mask_col):
Z = self.layernorm1(Z_raw)
left_right_proj_act = self.left_right_projection(Z)
left_right_proj_act = Z_mask_col.unsqueeze(-1) * left_right_proj_act
left_right_proj_act *= torch.sigmoid(self.left_right_gate(Z))
left_proj_act, right_proj_act = left_right_proj_act.chunk(2, dim=-1)
left_proj_act, work = gather_async(left_proj_act.contiguous(), dim=2)
g = torch.sigmoid(self.output_gate(Z))
right_proj_act = permute_final_dims(right_proj_act, (2, 0, 1))
left_proj_act = gather_async_opp(left_proj_act, work, dim=2)
p = torch.matmul(permute_final_dims(left_proj_act, (2, 1, 0)), right_proj_act)
ab = permute_final_dims(p, (1, 2, 0))
# ab = torch.einsum('bkid,bkjd->bijd', left_proj_act, right_proj_act)
ab = self.output_projection(self.layernorm2(ab))
dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
return bias_ele_dropout_residual(ab,
self.output_bias,
g,
dropout_mask,
Z_raw,
prob=self.p_drop,
training=self.training)
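# ---------------------------------------------------------------------------
# Illustrative check (not part of FastFold): the permute + matmul pairs in the
# two triangle-multiplication modules above are flattened forms of the einsums
# kept in their comments. A standalone verification on random tensors:
def _triangle_matmul_equivalence_check():
    import torch
    left = torch.randn(2, 5, 5, 4)
    right = torch.randn(2, 5, 5, 4)
    # Outgoing: 'bikd,bjkd->bijd'
    p = torch.matmul(
        permute_final_dims(left, (2, 0, 1)),     # [b, d, i, k]
        permute_final_dims(right, (2, 1, 0)),    # [b, d, k, j]
    )
    out = permute_final_dims(p, (1, 2, 0))       # [b, i, j, d]
    ref = torch.einsum('bikd,bjkd->bijd', left, right)
    assert torch.allclose(out, ref, atol=1e-5)
    # Incoming: 'bkid,bkjd->bijd'
    p = torch.matmul(
        permute_final_dims(left, (2, 1, 0)),     # [b, d, i, k]
        permute_final_dims(right, (2, 0, 1)),    # [b, d, k, j]
    )
    out = permute_final_dims(p, (1, 2, 0))       # [b, i, j, d]
    ref = torch.einsum('bkid,bkjd->bijd', left, right)
    assert torch.allclose(out, ref, atol=1e-5)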
class TriangleAttentionStartingNode(nn.Module):
def __init__(self, d_pair, p_drop, c=32, n_head=4):
super(TriangleAttentionStartingNode, self).__init__()
self.d_pair = d_pair
self.c = c
self.n_head = n_head
self.p_drop = p_drop
self.layernorm1 = LayerNorm(d_pair)
# _init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]),
# std=1.0 / math.sqrt(d_pair))
# self.linear_b_weights = nn.parameter.Parameter(data=_init_weights)
self.linear_b = Linear(d_pair, n_head, initializer='linear', use_bias=False)
self.attention = SelfAttention(qkv_dim=d_pair,
c=c,
n_head=n_head,
out_dim=d_pair,
gating=True,
last_bias_fuse=True)
self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
def forward(self, Z_raw, Z_mask):
Z = self.layernorm1(Z_raw)
b = self.linear_b(Z)
# b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights)
b, work = gather_async(b, dim=1)
# b = rearrange(b, 'b q k h -> b h q k')
# padding_bias = (1e9 * (Z_mask - 1.))[:, :, None, None, :]
Z = self.attention(Z, Z_mask, (b, work))
dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
return bias_dropout_add(Z,
self.out_bias,
dropout_mask,
Z_raw,
prob=self.p_drop,
training=self.training)
class TriangleAttentionEndingNode(nn.Module):
def __init__(self, d_pair, p_drop, c=32, n_head=4):
super(TriangleAttentionEndingNode, self).__init__()
self.d_pair = d_pair
self.c = c
self.n_head = n_head
self.p_drop = p_drop
self.layernorm1 = LayerNorm(d_pair)
# _init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]),
# std=1.0 / math.sqrt(d_pair))
# self.linear_b_weights = nn.parameter.Parameter(data=_init_weights)
self.linear_b = Linear(d_pair, n_head, initializer='linear', use_bias=False)
self.attention = SelfAttention(qkv_dim=d_pair,
c=c,
n_head=n_head,
out_dim=d_pair,
gating=True,
last_bias_fuse=True)
self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
def forward(self, Z_raw, Z_mask):
Z = Z_raw.transpose(-2, -3)
Z_mask = Z_mask.transpose(-1, -2)
Z = self.layernorm1(Z)
b = self.linear_b(Z)
# b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights)
b, work = gather_async(b, dim=1)
# b = rearrange(b, 'b q k h -> b h q k')
# padding_bias = (1e9 * (Z_mask - 1.))[:, :, None, None, :]
Z = self.attention(Z, Z_mask, (b, work))
Z = Z.transpose(-2, -3)
dropout_mask = torch.ones_like(Z[:, :, 0:1, :], device=Z.device, dtype=Z.dtype)
return bias_dropout_add(Z,
self.out_bias,
dropout_mask,
Z_raw,
prob=self.p_drop,
training=self.training)
class PairCore(nn.Module):
def __init__(self, d_pair, p_drop=0.25):
super(PairCore, self).__init__()
self.d_pair = d_pair
self.n_head = 4
self.hidden_c = int(d_pair / self.n_head)
self.TriangleMultiplicationOutgoing = AsyncChunkTriangleMultiplicationOutgoing(d_pair,
p_drop=p_drop,
c=d_pair)
self.TriangleMultiplicationIncoming = AsyncChunkTriangleMultiplicationIncoming(d_pair,
p_drop=p_drop,
c=d_pair)
self.TriangleAttentionStartingNode = ChunkTriangleAttentionStartingNode(d_pair,
p_drop=p_drop,
c=self.hidden_c,
n_head=self.n_head)
self.TriangleAttentionEndingNode = ChunkTriangleAttentionEndingNode(d_pair,
p_drop=p_drop,
c=self.hidden_c,
n_head=self.n_head)
self.PairTransition = ChunkTransition(d=d_pair)
def forward(self, pair, pair_mask):
pair_mask_row = scatter(pair_mask, dim=1)
pair_mask_col = scatter(pair_mask, dim=2)
pair = self.TriangleMultiplicationOutgoing(pair, pair_mask_row)
pair = row_to_col(pair)
pair = self.TriangleMultiplicationIncoming(pair, pair_mask_col)
pair = col_to_row(pair)
pair = self.TriangleAttentionStartingNode(pair, pair_mask_row)
pair = row_to_col(pair)
pair = self.TriangleAttentionEndingNode(pair, pair_mask_col)
pair = self.PairTransition(pair)
pair = col_to_row(pair)
return pair
def inplace(self, pair, pair_mask):
pair_mask_row = scatter(pair_mask, dim=1)
pair_mask_col = scatter(pair_mask, dim=2)
pair[0] = self.TriangleMultiplicationOutgoing(pair[0], pair_mask_row)
pair[0] = row_to_col(pair[0])
pair[0] = self.TriangleMultiplicationIncoming(pair[0], pair_mask_col)
pair[0] = col_to_row(pair[0])
pair = self.TriangleAttentionStartingNode.inplace(pair, pair_mask_row)
pair[0] = row_to_col(pair[0])
pair = self.TriangleAttentionEndingNode.inplace(pair, pair_mask_col)
pair = self.PairTransition.inplace(pair)
pair[0] = col_to_row(pair[0])
return pair
from .alphafold import AlphaFold
from .lr_scheduler import AlphaFoldLRScheduler
from .loss import AlphaFoldLoss
__all__ = ["AlphaFold", "AlphaFoldLRScheduler", "AlphaFoldLoss"]
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import torch
import torch.nn as nn
from fastfold.data import data_transforms_multimer
from fastfold.utils.feats import (
pseudo_beta_fn,
build_extra_msa_feat,
build_template_angle_feat,
build_template_pair_feat,
atom14_to_atom37,
)
from fastfold.model.nn.embedders import (
InputEmbedder,
RecyclingEmbedder,
TemplateEmbedder,
ExtraMSAEmbedder,
)
from fastfold.model.nn.embedders_multimer import TemplateEmbedderMultimer, InputEmbedderMultimer
from fastfold.model.nn.evoformer import EvoformerStack, ExtraMSAStack
from fastfold.model.nn.heads import AuxiliaryHeads
import fastfold.common.residue_constants as residue_constants
from fastfold.model.nn.structure_module import StructureModule
from fastfold.utils.tensor_utils import (
dict_multimap,
tensor_tree_map,
)
import fastfold.habana as habana
class AlphaFold(nn.Module):
"""
Alphafold 2.
Implements Algorithm 2 (but with training).
"""
def __init__(self, config):
"""
Args:
config:
A dict-like config object (like the one in config.py)
"""
super(AlphaFold, self).__init__()
self.globals = config.globals
config = config.model
template_config = config.template
extra_msa_config = config.extra_msa
# Main trunk + structure module
if self.globals.is_multimer:
self.input_embedder = InputEmbedderMultimer(
**config["input_embedder"],
)
self.template_embedder = TemplateEmbedderMultimer(
template_config,
)
else:
self.input_embedder = InputEmbedder(
**config["input_embedder"],
)
self.template_embedder = TemplateEmbedder(
template_config,
)
self.recycling_embedder = RecyclingEmbedder(
**config["recycling_embedder"],
)
self.extra_msa_embedder = ExtraMSAEmbedder(
**extra_msa_config["extra_msa_embedder"],
)
self.extra_msa_stack = ExtraMSAStack(
is_multimer=self.globals.is_multimer,
**extra_msa_config["extra_msa_stack"],
)
self.evoformer = EvoformerStack(
is_multimer=self.globals.is_multimer,
**config["evoformer_stack"],
)
self.structure_module = StructureModule(
is_multimer=self.globals.is_multimer,
**config["structure_module"],
)
self.aux_heads = AuxiliaryHeads(
config["heads"],
)
self.config = config
def embed_templates(self, batch, z, pair_mask, templ_dim):
# Embed the templates one at a time (with a poor man's vmap)
template_embeds = []
n_templ = batch["template_aatype"].shape[templ_dim]
for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map(
lambda t: torch.index_select(t, templ_dim, idx),
batch,
)
single_template_embeds = {}
if self.config.template.embed_angles:
template_angle_feat = build_template_angle_feat(
single_template_feats,
)
# [*, S_t, N, C_m]
a = self.template_angle_embedder(template_angle_feat)
single_template_embeds["angle"] = a
# [*, S_t, N, N, C_t]
t = build_template_pair_feat(
single_template_feats,
use_unit_vector=self.config.template.use_unit_vector,
inf=self.config.template.inf,
eps=self.config.template.eps,
**self.config.template.distogram,
).to(z.dtype)
t = self.template_pair_embedder(t)
single_template_embeds.update({"pair": t})
template_embeds.append(single_template_embeds)
template_embeds = dict_multimap(
partial(torch.cat, dim=templ_dim),
template_embeds,
)
# [*, S_t, N, N, C_z]
t = self.template_pair_stack(
template_embeds["pair"],
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=self.globals.chunk_size,
_mask_trans=self.config._mask_trans,
)
# [*, N, N, C_z]
t = self.template_pointwise_att(
t,
z,
template_mask=batch["template_mask"].to(dtype=z.dtype),
chunk_size=self.globals.chunk_size,
)
t = t * (torch.sum(batch["template_mask"]) > 0)
ret = {}
if self.config.template.embed_angles:
ret["template_angle_embedding"] = template_embeds["angle"]
ret.update({"template_pair_embedding": t})
return ret
def iteration(self, feats, m_1_prev, z_prev, x_prev, _recycle=True):
# Primary output dictionary
outputs = {}
if habana.is_habana():
from habana.hpuhelper import hpu_perf
perf = hpu_perf("iteration", sync=False)
dtype = next(self.parameters()).dtype
for k in feats:
if(feats[k].dtype == torch.float32):
feats[k] = feats[k].to(dtype=dtype)
# Grab some data about the input
batch_dims = feats["target_feat"].shape[:-2]
no_batch_dims = len(batch_dims)
n = feats["target_feat"].shape[-2]
n_seq = feats["msa_feat"].shape[-3]
device = feats["target_feat"].device
# Prep some features
seq_mask = feats["seq_mask"]
pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
msa_mask = feats["msa_mask"]
if habana.is_habana():
perf.checkahead("1: Initialize the MSA and pair representations")
# m: [*, S_c, N, C_m]
# z: [*, N, N, C_z]
m, z = (
self.input_embedder(
feats["target_feat"],
feats["residue_index"],
feats["msa_feat"],
)
if not self.globals.is_multimer
else self.input_embedder(feats)
)
# Initialize the recycling embeddings, if needs be
if None in [m_1_prev, z_prev, x_prev]:
# [*, N, C_m]
m_1_prev = m.new_zeros(
(*batch_dims, n, self.config.input_embedder.c_m),
requires_grad=False,
)
# [*, N, N, C_z]
z_prev = z.new_zeros(
(*batch_dims, n, n, self.config.input_embedder.c_z),
requires_grad=False,
)
# [*, N, 3]
x_prev = z.new_zeros(
(*batch_dims, n, residue_constants.atom_type_num, 3),
requires_grad=False,
)
x_prev, _ = pseudo_beta_fn(feats["aatype"], x_prev, None)
x_prev = x_prev.to(dtype=z.dtype)
# m_1_prev_emb: [*, N, C_m]
# z_prev_emb: [*, N, N, C_z]
m_1_prev, z_prev = self.recycling_embedder(
m_1_prev,
z_prev,
x_prev,
)
# If the number of recycling iterations is 0, skip recycling
# altogether. We zero them this way instead of computing them
# conditionally to avoid leaving parameters unused, which has annoying
# implications for DDP training.
if(not _recycle):
m_1_prev *= 0
z_prev *= 0
# [*, S_c, N, C_m]
m[..., 0, :, :] += m_1_prev
# [*, N, N, C_z]
z += z_prev
# Possibly prevents memory fragmentation
del m_1_prev, z_prev, x_prev
if habana.is_habana():
perf.checkahead("2: Embed the templates + merge with MSA/pair embeddings")
if self.config.template.enabled:
template_feats = {
k: v for k, v in feats.items() if k.startswith("template_")
}
if self.globals.is_multimer:
asym_id = feats["asym_id"]
multichain_mask_2d = asym_id[..., None] == asym_id[..., None, :]
template_embeds = self.template_embedder(
template_feats,
z,
pair_mask.to(dtype=z.dtype),
no_batch_dims,
chunk_size=self.globals.chunk_size,
multichain_mask_2d=multichain_mask_2d,
inplace=self.globals.inplace
)
feats["template_torsion_angles_mask"] = (
template_embeds["template_mask"]
)
# [*, N, N, C_z]
z = z + template_embeds["template_pair_embedding"]
else:
if self.globals.inplace:
template_embeds = self.template_embedder(
template_feats,
z,
pair_mask.to(dtype=z.dtype),
no_batch_dims,
self.globals.chunk_size,
inplace=self.globals.inplace
)
z = template_embeds["template_pair_embedding"]
else:
template_embeds = self.template_embedder(
template_feats,
z,
pair_mask.to(dtype=z.dtype),
no_batch_dims,
self.globals.chunk_size,
)
z = z + template_embeds["template_pair_embedding"]
if(
self.config.template.embed_angles or
(self.globals.is_multimer and self.config.template.enabled)
):
# [*, S = S_c + S_t, N, C_m]
m = torch.cat(
[m, template_embeds["template_single_embedding"]],
dim=-3
)
# [*, S, N]
if(not self.globals.is_multimer):
torsion_angles_mask = feats["template_torsion_angles_mask"]
msa_mask = torch.cat(
[feats["msa_mask"], torsion_angles_mask[..., 2]],
dim=-2
)
del torsion_angles_mask
else:
msa_mask = torch.cat(
[feats["msa_mask"], template_embeds["template_mask"]],
dim=-2,
)
del template_feats, template_embeds
if habana.is_habana():
perf.checkahead("3: Embed extra MSA features + merge with pairwise embeddings")
if self.config.extra_msa.enabled:
if(self.globals.is_multimer):
extra_msa_fn = data_transforms_multimer.build_extra_msa_feat
else:
extra_msa_fn = build_extra_msa_feat
# [*, S_e, N, C_e]
extra_msa_feat = extra_msa_fn(feats)
extra_msa_feat = self.extra_msa_embedder(extra_msa_feat)
# [*, N, N, C_z]
if not self.globals.inplace:
z = self.extra_msa_stack(
extra_msa_feat,
z,
msa_mask=feats["extra_msa_mask"].to(dtype=extra_msa_feat.dtype),
chunk_size=self.globals.chunk_size,
pair_mask=pair_mask.to(dtype=z.dtype),
_mask_trans=self.config._mask_trans,
)
else:
extra_msa_feat = [extra_msa_feat]
z = [z]
z = self.extra_msa_stack.inplace(
extra_msa_feat,
z,
msa_mask=feats["extra_msa_mask"].to(dtype=extra_msa_feat[0].dtype),
chunk_size=self.globals.chunk_size,
pair_mask=pair_mask.to(dtype=z[0].dtype),
_mask_trans=self.config._mask_trans,
)[0]
del extra_msa_feat, extra_msa_fn
if habana.is_habana():
perf.checkahead("4: Run MSA + pair embeddings through the trunk of the network")
# m: [*, S, N, C_m]
# z: [*, N, N, C_z]
# s: [*, N, C_s]
if not self.globals.inplace:
m, z, s = self.evoformer(
m,
z,
msa_mask=msa_mask.to(dtype=m.dtype),
pair_mask=pair_mask.to(dtype=z.dtype),
chunk_size=self.globals.chunk_size,
_mask_trans=self.config._mask_trans,
)
else:
m = [m]
z = [z]
m, z, s = self.evoformer.inplace(
m,
z,
msa_mask=msa_mask.to(dtype=m[0].dtype),
pair_mask=pair_mask.to(dtype=z[0].dtype),
chunk_size=self.globals.chunk_size,
_mask_trans=self.config._mask_trans,
)
m = m[0]
z = z[0]
outputs["msa"] = m[..., :n_seq, :, :]
outputs["pair"] = z
outputs["single"] = s
if habana.is_habana():
perf.checkahead("5: Predict 3D structure")
outputs["sm"] = self.structure_module(
s,
z,
feats["aatype"],
mask=feats["seq_mask"].to(dtype=s.dtype),
)
outputs["final_atom_positions"] = atom14_to_atom37(
outputs["sm"]["positions"][-1], feats
)
outputs["final_atom_mask"] = feats["atom37_atom_exists"]
outputs["final_affine_tensor"] = outputs["sm"]["frames"][-1]
# Save embeddings for use during the next recycling iteration
# [*, N, C_m]
m_1_prev = m[..., 0, :, :]
# [*, N, N, C_z]
z_prev = z
# [*, N, 3]
x_prev = outputs["final_atom_positions"]
if habana.is_habana():
perf.checkahead("6: stop iteration")
return outputs, m_1_prev, z_prev, x_prev
def _disable_activation_checkpointing(self):
self.template_embedder.template_pair_stack.blocks_per_ckpt = None
self.evoformer.blocks_per_ckpt = None
for b in self.extra_msa_stack.blocks:
b.ckpt = False
def _enable_activation_checkpointing(self):
self.template_embedder.template_pair_stack.blocks_per_ckpt = (
self.config.template.template_pair_stack.blocks_per_ckpt
)
self.evoformer.blocks_per_ckpt = (
self.config.evoformer_stack.blocks_per_ckpt
)
for b in self.extra_msa_stack.blocks:
b.ckpt = self.config.extra_msa.extra_msa_stack.ckpt
def forward(self, batch):
"""
Args:
batch:
Dictionary of arguments outlined in Algorithm 2. Keys must
include the official names of the features in the
supplement subsection 1.2.9.
The final dimension of each input must have length equal to
the number of recycling iterations.
Features (without the recycling dimension):
"aatype" ([*, N_res]):
Contrary to the supplement, this tensor of residue
indices is not one-hot.
"target_feat" ([*, N_res, C_tf])
One-hot encoding of the target sequence. C_tf is
config.model.input_embedder.tf_dim.
"residue_index" ([*, N_res])
Tensor whose final dimension consists of
consecutive indices from 0 to N_res.
"msa_feat" ([*, N_seq, N_res, C_msa])
MSA features, constructed as in the supplement.
C_msa is config.model.input_embedder.msa_dim.
"seq_mask" ([*, N_res])
1-D sequence mask
"msa_mask" ([*, N_seq, N_res])
MSA mask
"pair_mask" ([*, N_res, N_res])
2-D pair mask
"extra_msa_mask" ([*, N_extra, N_res])
Extra MSA mask
"template_mask" ([*, N_templ])
Template mask (on the level of templates, not
residues)
"template_aatype" ([*, N_templ, N_res])
Tensor of template residue indices (indices greater
than 19 are clamped to 20 (Unknown))
"template_all_atom_positions"
([*, N_templ, N_res, 37, 3])
Template atom coordinates in atom37 format
"template_all_atom_mask" ([*, N_templ, N_res, 37])
Template atom coordinate mask
"template_pseudo_beta" ([*, N_templ, N_res, 3])
Positions of template carbon "pseudo-beta" atoms
                        (i.e. C_beta for all residues but glycine, for
                        which C_alpha is used instead)
"template_pseudo_beta_mask" ([*, N_templ, N_res])
Pseudo-beta mask
"""
# Initialize recycling embeddings
m_1_prev, z_prev, x_prev = None, None, None
# Disable activation checkpointing for the first few recycling iters
is_grad_enabled = torch.is_grad_enabled()
self._disable_activation_checkpointing()
# Main recycling loop
num_iters = batch["aatype"].shape[-1]
for cycle_no in range(num_iters):
if habana.is_habana():
from habana.hpuhelper import hpu_perf
perf = hpu_perf(f"cycle {cycle_no+1}/{num_iters}")
# Select the features for the current recycling cycle
fetch_cur_batch = lambda t: t[..., cycle_no]
feats = tensor_tree_map(fetch_cur_batch, batch)
# Enable grad iff we're training and it's the final recycling layer
is_final_iter = cycle_no == (num_iters - 1)
with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
if is_final_iter:
self._enable_activation_checkpointing()
# Sidestep AMP bug (PyTorch issue #65766)
if torch.is_autocast_enabled():
torch.clear_autocast_cache()
# Run the next iteration of the model
outputs, m_1_prev, z_prev, x_prev = self.iteration(
feats,
m_1_prev,
z_prev,
x_prev,
_recycle=(num_iters > 1)
)
if habana.is_habana():
perf.checknow("cycle finish")
# Run auxiliary heads
outputs.update(self.aux_heads(outputs))
return outputs
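# ---------------------------------------------------------------------------
# Illustrative sketch (not part of FastFold): the recycling loop above runs
# every cycle with gradients disabled except the final one, so only the last
# pass builds an autograd graph. The control-flow skeleton, with a
# hypothetical `step` callable standing in for self.iteration:
def _recycling_skeleton_sketch(step, num_iters, training=True):
    import torch
    prev = None
    outputs = None
    for cycle_no in range(num_iters):
        is_final_iter = cycle_no == (num_iters - 1)
        with torch.set_grad_enabled(training and is_final_iter):
            outputs, prev = step(prev)
    return outputs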
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import logging
import ml_collections
import numpy as np
import torch
import torch.nn as nn
from torch.distributions.bernoulli import Bernoulli
from typing import Dict, Optional, Tuple
import fastfold.habana as habana
from fastfold.common import residue_constants
from fastfold.utils import feats
from fastfold.utils.rigid_utils import Rotation, Rigid
from fastfold.utils.tensor_utils import (
tree_map,
tensor_tree_map,
masked_mean,
permute_final_dims,
batched_gather,
)
def softmax_cross_entropy(logits, labels):
loss = -1 * torch.sum(
labels * torch.nn.functional.log_softmax(logits, dim=-1),
dim=-1,
)
return loss
def sigmoid_cross_entropy(logits, labels):
log_p = torch.log(torch.sigmoid(logits))
log_not_p = torch.log(torch.sigmoid(-logits))
loss = -labels * log_p - (1 - labels) * log_not_p
return loss
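# ---------------------------------------------------------------------------
# Note (illustrative, not part of FastFold): torch.log(torch.sigmoid(x))
# underflows to -inf for large negative x. A hypothetical drop-in variant of
# the same quantity using logsigmoid is numerically more stable:
def _sigmoid_cross_entropy_stable(logits, labels):
    import torch
    log_p = torch.nn.functional.logsigmoid(logits)
    log_not_p = torch.nn.functional.logsigmoid(-logits)
    return -labels * log_p - (1 - labels) * log_not_p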
def torsion_angle_loss(
a, # [*, N, 7, 2]
a_gt, # [*, N, 7, 2]
a_alt_gt, # [*, N, 7, 2]
):
# [*, N, 7]
norm = torch.norm(a, dim=-1)
# [*, N, 7, 2]
a = a / norm.unsqueeze(-1)
# [*, N, 7]
diff_norm_gt = torch.norm(a - a_gt, dim=-1)
diff_norm_alt_gt = torch.norm(a - a_alt_gt, dim=-1)
min_diff = torch.minimum(diff_norm_gt ** 2, diff_norm_alt_gt ** 2)
# [*]
l_torsion = torch.mean(min_diff, dim=(-1, -2))
l_angle_norm = torch.mean(torch.abs(norm - 1), dim=(-1, -2))
an_weight = 0.02
return l_torsion + an_weight * l_angle_norm
def compute_fape(
pred_frames: Rigid,
target_frames: Rigid,
frames_mask: torch.Tensor,
pred_positions: torch.Tensor,
target_positions: torch.Tensor,
positions_mask: torch.Tensor,
length_scale: float,
l1_clamp_distance: Optional[float] = None,
eps=1e-8,
) -> torch.Tensor:
"""
Computes FAPE loss.
Args:
pred_frames:
[*, N_frames] Rigid object of predicted frames
target_frames:
[*, N_frames] Rigid object of ground truth frames
frames_mask:
[*, N_frames] binary mask for the frames
pred_positions:
[*, N_pts, 3] predicted atom positions
target_positions:
[*, N_pts, 3] ground truth positions
positions_mask:
[*, N_pts] positions mask
length_scale:
Length scale by which the loss is divided
l1_clamp_distance:
Cutoff above which distance errors are disregarded
eps:
Small value used to regularize denominators
Returns:
[*] loss tensor
"""
# [*, N_frames, N_pts, 3]
local_pred_pos = pred_frames.invert()[..., None].apply(
pred_positions[..., None, :, :],
)
local_target_pos = target_frames.invert()[..., None].apply(
target_positions[..., None, :, :],
)
error_dist = torch.sqrt(
torch.sum((local_pred_pos - local_target_pos) ** 2, dim=-1) + eps
)
if l1_clamp_distance is not None:
error_dist = torch.clamp(error_dist, min=0, max=l1_clamp_distance)
normed_error = error_dist / length_scale
normed_error = normed_error * frames_mask[..., None]
normed_error = normed_error * positions_mask[..., None, :]
# FP16-friendly averaging. Roughly equivalent to:
#
# norm_factor = (
# torch.sum(frames_mask, dim=-1) *
# torch.sum(positions_mask, dim=-1)
# )
# normed_error = torch.sum(normed_error, dim=(-1, -2)) / (eps + norm_factor)
#
# ("roughly" because eps is necessarily duplicated in the latter)
normed_error = torch.sum(normed_error, dim=-1)
normed_error = (
normed_error / (eps + torch.sum(frames_mask, dim=-1))[..., None]
)
normed_error = torch.sum(normed_error, dim=-1)
normed_error = normed_error / (eps + torch.sum(positions_mask, dim=-1))
return normed_error
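# ---------------------------------------------------------------------------
# Illustrative check (not part of FastFold): the "FP16-friendly averaging"
# comment above normalizes over points first and frames second; with eps
# omitted this matches the joint normalization exactly, which is why the two
# are only "roughly" equivalent when eps is duplicated.
def _fape_normalization_check():
    import torch
    err = torch.rand(3, 4, 6)                  # [*, N_frames, N_pts] masked errors
    frames_mask = torch.ones(3, 4)
    positions_mask = torch.ones(3, 6)
    # Joint normalization
    joint = err.sum(dim=(-1, -2)) / (
        frames_mask.sum(dim=-1) * positions_mask.sum(dim=-1)
    )
    # Sequential normalization as used above (eps left out here)
    seq = err.sum(dim=-1) / frames_mask.sum(dim=-1)[..., None]
    seq = seq.sum(dim=-1) / positions_mask.sum(dim=-1)
    assert torch.allclose(joint, seq, atol=1e-6)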
def backbone_loss(
backbone_rigid_tensor: torch.Tensor,
backbone_rigid_mask: torch.Tensor,
traj: torch.Tensor,
use_clamped_fape: Optional[torch.Tensor] = None,
clamp_distance: float = 10.0,
loss_unit_distance: float = 10.0,
eps: float = 1e-4,
**kwargs,
) -> torch.Tensor:
pred_aff = Rigid.from_tensor_7(traj)
pred_aff = Rigid(
Rotation(rot_mats=pred_aff.get_rots().get_rot_mats(), quats=None),
pred_aff.get_trans(),
)
# DISCREPANCY: DeepMind somehow gets a hold of a tensor_7 version of
# backbone tensor, normalizes it, and then turns it back to a rotation
# matrix. To avoid a potentially numerically unstable rotation matrix
# to quaternion conversion, we just use the original rotation matrix
# outright. This one hasn't been composed a bunch of times, though, so
# it might be fine.
gt_aff = Rigid.from_tensor_4x4(backbone_rigid_tensor)
fape_loss = compute_fape(
pred_aff,
gt_aff[None],
backbone_rigid_mask[None],
pred_aff.get_trans(),
gt_aff[None].get_trans(),
backbone_rigid_mask[None],
l1_clamp_distance=clamp_distance,
length_scale=loss_unit_distance,
eps=eps,
)
if use_clamped_fape is not None:
unclamped_fape_loss = compute_fape(
pred_aff,
gt_aff[None],
backbone_rigid_mask[None],
pred_aff.get_trans(),
gt_aff[None].get_trans(),
backbone_rigid_mask[None],
l1_clamp_distance=None,
length_scale=loss_unit_distance,
eps=eps,
)
fape_loss = fape_loss * use_clamped_fape + unclamped_fape_loss * (
1 - use_clamped_fape
)
# Average over the batch dimension
fape_loss = torch.mean(fape_loss)
return fape_loss
def sidechain_loss(
sidechain_frames: torch.Tensor,
sidechain_atom_pos: torch.Tensor,
rigidgroups_gt_frames: torch.Tensor,
rigidgroups_alt_gt_frames: torch.Tensor,
rigidgroups_gt_exists: torch.Tensor,
renamed_atom14_gt_positions: torch.Tensor,
renamed_atom14_gt_exists: torch.Tensor,
alt_naming_is_better: torch.Tensor,
clamp_distance: float = 10.0,
length_scale: float = 10.0,
eps: float = 1e-4,
**kwargs,
) -> torch.Tensor:
renamed_gt_frames = (
1.0 - alt_naming_is_better[..., None, None, None]
) * rigidgroups_gt_frames + alt_naming_is_better[
..., None, None, None
] * rigidgroups_alt_gt_frames
# Steamroll the inputs
sidechain_frames = sidechain_frames[-1]
batch_dims = sidechain_frames.shape[:-4]
sidechain_frames = sidechain_frames.view(*batch_dims, -1, 4, 4)
sidechain_frames = Rigid.from_tensor_4x4(sidechain_frames)
renamed_gt_frames = renamed_gt_frames.view(*batch_dims, -1, 4, 4)
renamed_gt_frames = Rigid.from_tensor_4x4(renamed_gt_frames)
rigidgroups_gt_exists = rigidgroups_gt_exists.reshape(*batch_dims, -1)
sidechain_atom_pos = sidechain_atom_pos[-1]
sidechain_atom_pos = sidechain_atom_pos.view(*batch_dims, -1, 3)
renamed_atom14_gt_positions = renamed_atom14_gt_positions.view(
*batch_dims, -1, 3
)
renamed_atom14_gt_exists = renamed_atom14_gt_exists.view(*batch_dims, -1)
fape = compute_fape(
sidechain_frames,
renamed_gt_frames,
rigidgroups_gt_exists,
sidechain_atom_pos,
renamed_atom14_gt_positions,
renamed_atom14_gt_exists,
l1_clamp_distance=clamp_distance,
length_scale=length_scale,
eps=eps,
)
return fape
def fape_loss(
out: Dict[str, torch.Tensor],
batch: Dict[str, torch.Tensor],
config: ml_collections.ConfigDict,
) -> torch.Tensor:
bb_loss = backbone_loss(
traj=out["sm"]["frames"],
**{**batch, **config.backbone},
)
sc_loss = sidechain_loss(
out["sm"]["sidechain_frames"],
out["sm"]["positions"],
**{**batch, **config.sidechain},
)
loss = config.backbone.weight * bb_loss + config.sidechain.weight * sc_loss
# Average over the batch dimension
loss = torch.mean(loss)
return loss
def supervised_chi_loss(
angles_sin_cos: torch.Tensor,
unnormalized_angles_sin_cos: torch.Tensor,
aatype: torch.Tensor,
seq_mask: torch.Tensor,
chi_mask: torch.Tensor,
chi_angles_sin_cos: torch.Tensor,
chi_weight: float,
angle_norm_weight: float,
eps=1e-6,
**kwargs,
) -> torch.Tensor:
"""
Implements Algorithm 27 (torsionAngleLoss)
Args:
angles_sin_cos:
[*, N, 7, 2] predicted angles
unnormalized_angles_sin_cos:
The same angles, but unnormalized
aatype:
            [*, N] residue type (amino acid) indices
seq_mask:
[*, N] sequence mask
chi_mask:
[*, N, 7] angle mask
chi_angles_sin_cos:
[*, N, 7, 2] ground truth angles
chi_weight:
Weight for the angle component of the loss
angle_norm_weight:
Weight for the normalization component of the loss
Returns:
[*] loss tensor
"""
pred_angles = angles_sin_cos[..., 3:, :]
residue_type_one_hot = torch.nn.functional.one_hot(
aatype,
residue_constants.restype_num + 1,
)
chi_pi_periodic = torch.einsum(
"...ij,jk->ik",
residue_type_one_hot.type(angles_sin_cos.dtype),
angles_sin_cos.new_tensor(residue_constants.chi_pi_periodic),
)
true_chi = chi_angles_sin_cos[None]
shifted_mask = (1 - 2 * chi_pi_periodic).unsqueeze(-1)
true_chi_shifted = shifted_mask * true_chi
sq_chi_error = torch.sum((true_chi - pred_angles) ** 2, dim=-1)
sq_chi_error_shifted = torch.sum(
(true_chi_shifted - pred_angles) ** 2, dim=-1
)
sq_chi_error = torch.minimum(sq_chi_error, sq_chi_error_shifted)
    # The ol' switcheroo: move the leading structure-module iteration dimension
    # behind the batch dims so it is averaged together with the residue and
    # angle dimensions by masked_mean below
sq_chi_error = sq_chi_error.permute(
*range(len(sq_chi_error.shape))[1:-2], 0, -2, -1
)
sq_chi_loss = masked_mean(
chi_mask[..., None, :, :], sq_chi_error, dim=(-1, -2, -3)
)
loss = chi_weight * sq_chi_loss
angle_norm = torch.sqrt(
torch.sum(unnormalized_angles_sin_cos ** 2, dim=-1) + eps
)
norm_error = torch.abs(angle_norm - 1.0)
norm_error = norm_error.permute(
*range(len(norm_error.shape))[1:-2], 0, -2, -1
)
angle_norm_loss = masked_mean(
seq_mask[..., None, :, None], norm_error, dim=(-1, -2, -3)
)
loss = loss + angle_norm_weight * angle_norm_loss
# Average over the batch dimension
loss = torch.mean(loss)
return loss
def compute_plddt(logits: torch.Tensor) -> torch.Tensor:
num_bins = logits.shape[-1]
bin_width = 1.0 / num_bins
bounds = torch.arange(
start=0.5 * bin_width, end=1.0, step=bin_width, device=logits.device
)
probs = torch.nn.functional.softmax(logits, dim=-1)
pred_lddt_ca = torch.sum(
probs * bounds.view(*((1,) * len(probs.shape[:-1])), *bounds.shape),
dim=-1,
)
return pred_lddt_ca * 100
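# Illustrative usage sketch (not part of the original module): compute_plddt
# turns per-residue bin logits into a 0-100 confidence score by taking the
# expected bin center under the softmax distribution. Shapes are hypothetical.
def _example_compute_plddt():
    logits = torch.randn(2, 256, 50)   # [batch, N_res, no_bins]
    plddt = compute_plddt(logits)      # [batch, N_res], values in [0, 100]
    return plddt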
def lddt(
all_atom_pred_pos: torch.Tensor,
all_atom_positions: torch.Tensor,
all_atom_mask: torch.Tensor,
cutoff: float = 15.0,
eps: float = 1e-10,
per_residue: bool = True,
) -> torch.Tensor:
n = all_atom_mask.shape[-2]
dmat_true = torch.sqrt(
eps
+ torch.sum(
(
all_atom_positions[..., None, :]
- all_atom_positions[..., None, :, :]
)
** 2,
dim=-1,
)
)
dmat_pred = torch.sqrt(
eps
+ torch.sum(
(
all_atom_pred_pos[..., None, :]
- all_atom_pred_pos[..., None, :, :]
)
** 2,
dim=-1,
)
)
dists_to_score = (
(dmat_true < cutoff)
* all_atom_mask
* permute_final_dims(all_atom_mask, (1, 0))
* (1.0 - torch.eye(n, device=all_atom_mask.device))
)
dist_l1 = torch.abs(dmat_true - dmat_pred)
score = (
(dist_l1 < 0.5).type(dist_l1.dtype)
+ (dist_l1 < 1.0).type(dist_l1.dtype)
+ (dist_l1 < 2.0).type(dist_l1.dtype)
+ (dist_l1 < 4.0).type(dist_l1.dtype)
)
score = score * 0.25
dims = (-1,) if per_residue else (-2, -1)
norm = 1.0 / (eps + torch.sum(dists_to_score, dim=dims))
score = norm * (eps + torch.sum(dists_to_score * score, dim=dims))
return score
def lddt_ca(
all_atom_pred_pos: torch.Tensor,
all_atom_positions: torch.Tensor,
all_atom_mask: torch.Tensor,
cutoff: float = 15.0,
eps: float = 1e-10,
per_residue: bool = True,
) -> torch.Tensor:
ca_pos = residue_constants.atom_order["CA"]
all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
all_atom_positions = all_atom_positions[..., ca_pos, :]
all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)] # keep dim
return lddt(
all_atom_pred_pos,
all_atom_positions,
all_atom_mask,
cutoff=cutoff,
eps=eps,
per_residue=per_residue,
)
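# Illustrative usage sketch (not part of the original module): lddt_ca scores
# only the C-alpha atoms and returns a per-residue lDDT in [0, 1]. The inputs
# are random placeholders in atom37 layout.
def _example_lddt_ca():
    n_res = 64
    pred = torch.randn(1, n_res, 37, 3)
    gt = torch.randn(1, n_res, 37, 3)
    mask = torch.ones(1, n_res, 37)
    return lddt_ca(pred, gt, mask)     # [1, n_res]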
def lddt_loss(
logits: torch.Tensor,
all_atom_pred_pos: torch.Tensor,
all_atom_positions: torch.Tensor,
all_atom_mask: torch.Tensor,
resolution: torch.Tensor,
cutoff: float = 15.0,
no_bins: int = 50,
min_resolution: float = 0.1,
max_resolution: float = 3.0,
eps: float = 1e-10,
**kwargs,
) -> torch.Tensor:
ca_pos = residue_constants.atom_order["CA"]
all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
all_atom_positions = all_atom_positions[..., ca_pos, :]
all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)] # keep dim
score = lddt(
all_atom_pred_pos,
all_atom_positions,
all_atom_mask,
cutoff=cutoff,
eps=eps
)
score = score.detach()
if habana.is_habana():
bin_index = torch.floor(score * no_bins)
bin_index = torch.clamp(bin_index, max=(no_bins - 1)).float().long()
else:
bin_index = torch.floor(score * no_bins).long()
bin_index = torch.clamp(bin_index, max=(no_bins - 1))
lddt_ca_one_hot = torch.nn.functional.one_hot(
bin_index, num_classes=no_bins
)
errors = softmax_cross_entropy(logits, lddt_ca_one_hot)
all_atom_mask = all_atom_mask.squeeze(-1)
loss = torch.sum(errors * all_atom_mask, dim=-1) / (
eps + torch.sum(all_atom_mask, dim=-1)
)
loss = loss * (
(resolution >= min_resolution) & (resolution <= max_resolution)
)
# Average over the batch dimension
loss = torch.mean(loss)
return loss
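# Illustrative usage sketch (not part of the original module): lddt_loss bins
# the detached per-residue lDDT into no_bins classes and trains the lDDT head
# logits against them with a masked cross-entropy, gated by resolution. All
# inputs are random placeholders; no_bins must match the logits' last dim.
def _example_lddt_loss():
    n_res = 32
    logits = torch.randn(1, n_res, 50)
    pred_pos = torch.randn(1, n_res, 37, 3)
    gt_pos = torch.randn(1, n_res, 37, 3)
    atom_mask = torch.ones(1, n_res, 37)
    resolution = torch.tensor([2.0])
    return lddt_loss(logits, pred_pos, gt_pos, atom_mask, resolution)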
def distogram_loss(
logits,
pseudo_beta,
pseudo_beta_mask,
min_bin=2.3125,
max_bin=21.6875,
no_bins=64,
eps=1e-6,
**kwargs,
):
boundaries = torch.linspace(
min_bin,
max_bin,
no_bins - 1,
device=logits.device,
)
boundaries = boundaries ** 2
dists = torch.sum(
(pseudo_beta[..., None, :] - pseudo_beta[..., None, :, :]) ** 2,
dim=-1,
keepdims=True,
)
true_bins = torch.sum(dists > boundaries, dim=-1)
errors = softmax_cross_entropy(
logits,
torch.nn.functional.one_hot(true_bins, no_bins),
)
square_mask = pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :]
# FP16-friendly sum. Equivalent to:
# mean = (torch.sum(errors * square_mask, dim=(-1, -2)) /
# (eps + torch.sum(square_mask, dim=(-1, -2))))
denom = eps + torch.sum(square_mask, dim=(-1, -2))
mean = errors * square_mask
mean = torch.sum(mean, dim=-1)
mean = mean / denom[..., None]
mean = torch.sum(mean, dim=-1)
# Average over the batch dimensions
mean = torch.mean(mean)
return mean
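# Illustrative usage sketch (not part of the original module): distogram_loss
# discretizes squared pseudo-beta distances into no_bins bins and applies a
# masked cross-entropy against the distogram head logits. Inputs are random
# placeholders.
def _example_distogram_loss():
    n_res = 32
    logits = torch.randn(1, n_res, n_res, 64)
    pseudo_beta = torch.randn(1, n_res, 3)
    pseudo_beta_mask = torch.ones(1, n_res)
    return distogram_loss(logits, pseudo_beta, pseudo_beta_mask)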
def _calculate_bin_centers(boundaries: torch.Tensor):
step = boundaries[1] - boundaries[0]
bin_centers = boundaries + step / 2
bin_centers = torch.cat(
[bin_centers, (bin_centers[-1] + step).unsqueeze(-1)], dim=0
)
return bin_centers
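# Worked example (not part of the original module): for boundaries [0, 1, 2]
# the centers are [0.5, 1.5, 2.5, 3.5] -- one center per boundary plus an
# extra center appended for the final, open-ended bin.
def _example_bin_centers():
    return _calculate_bin_centers(torch.tensor([0.0, 1.0, 2.0]))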
def _calculate_expected_aligned_error(
alignment_confidence_breaks: torch.Tensor,
aligned_distance_error_probs: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
bin_centers = _calculate_bin_centers(alignment_confidence_breaks)
return (
torch.sum(aligned_distance_error_probs * bin_centers, dim=-1),
bin_centers[-1],
)
def compute_predicted_aligned_error(
logits: torch.Tensor,
max_bin: int = 31,
no_bins: int = 64,
**kwargs,
) -> Dict[str, torch.Tensor]:
"""Computes aligned confidence metrics from logits.
Args:
logits: [*, num_res, num_res, num_bins] the logits output from
PredictedAlignedErrorHead.
max_bin: Maximum bin value
no_bins: Number of bins
Returns:
aligned_confidence_probs: [*, num_res, num_res, num_bins] the predicted
aligned error probabilities over bins for each residue pair.
predicted_aligned_error: [*, num_res, num_res] the expected aligned distance
error for each pair of residues.
max_predicted_aligned_error: [*] the maximum predicted error possible.
"""
boundaries = torch.linspace(
0, max_bin, steps=(no_bins - 1), device=logits.device
)
aligned_confidence_probs = torch.nn.functional.softmax(logits, dim=-1)
(
predicted_aligned_error,
max_predicted_aligned_error,
) = _calculate_expected_aligned_error(
alignment_confidence_breaks=boundaries,
aligned_distance_error_probs=aligned_confidence_probs,
)
return {
"aligned_confidence_probs": aligned_confidence_probs,
"predicted_aligned_error": predicted_aligned_error,
"max_predicted_aligned_error": max_predicted_aligned_error,
}
def compute_tm(
logits: torch.Tensor,
residue_weights: Optional[torch.Tensor] = None,
max_bin: int = 31,
no_bins: int = 64,
eps: float = 1e-8,
**kwargs,
) -> torch.Tensor:
if residue_weights is None:
residue_weights = logits.new_ones(logits.shape[-2])
boundaries = torch.linspace(
0, max_bin, steps=(no_bins - 1), device=logits.device
)
bin_centers = _calculate_bin_centers(boundaries)
n = logits.shape[-2]
clipped_n = max(n, 19)
d0 = 1.24 * (clipped_n - 15) ** (1.0 / 3) - 1.8
probs = torch.nn.functional.softmax(logits, dim=-1)
tm_per_bin = 1.0 / (1 + (bin_centers ** 2) / (d0 ** 2))
predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1)
normed_residue_mask = residue_weights / (eps + residue_weights.sum())
per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1)
weighted = per_alignment * residue_weights
argmax = (weighted == torch.max(weighted)).nonzero()[0]
return per_alignment[tuple(argmax)]
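# Illustrative usage sketch (not part of the original module): both confidence
# metrics consume the same predicted-aligned-error head logits.
# compute_predicted_aligned_error returns per-pair expected errors, while
# compute_tm reduces them to a single predicted TM-score. Logits are random
# placeholders.
def _example_confidence_from_pae_logits():
    n_res = 32
    logits = torch.randn(n_res, n_res, 64)
    pae = compute_predicted_aligned_error(logits)
    ptm = compute_tm(logits)
    return pae["predicted_aligned_error"], ptm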
def tm_loss(
logits,
final_affine_tensor,
backbone_rigid_tensor,
backbone_rigid_mask,
resolution,
max_bin=31,
no_bins=64,
min_resolution: float = 0.1,
max_resolution: float = 3.0,
eps=1e-8,
**kwargs,
):
pred_affine = Rigid.from_tensor_7(final_affine_tensor)
backbone_rigid = Rigid.from_tensor_4x4(backbone_rigid_tensor)
def _points(affine):
pts = affine.get_trans()[..., None, :, :]
return affine.invert()[..., None].apply(pts)
sq_diff = torch.sum(
(_points(pred_affine) - _points(backbone_rigid)) ** 2, dim=-1
)
sq_diff = sq_diff.detach()
boundaries = torch.linspace(
0, max_bin, steps=(no_bins - 1), device=logits.device
)
boundaries = boundaries ** 2
true_bins = torch.sum(sq_diff[..., None] > boundaries, dim=-1)
errors = softmax_cross_entropy(
logits, torch.nn.functional.one_hot(true_bins, no_bins)
)
square_mask = (
backbone_rigid_mask[..., None] * backbone_rigid_mask[..., None, :]
)
loss = torch.sum(errors * square_mask, dim=-1)
scale = 0.5 # hack to help FP16 training along
denom = eps + torch.sum(scale * square_mask, dim=(-1, -2))
loss = loss / denom[..., None]
loss = torch.sum(loss, dim=-1)
loss = loss * scale
loss = loss * (
(resolution >= min_resolution) & (resolution <= max_resolution)
)
# Average over the loss dimension
loss = torch.mean(loss)
return loss
def between_residue_bond_loss(
pred_atom_positions: torch.Tensor, # (*, N, 37/14, 3)
pred_atom_mask: torch.Tensor, # (*, N, 37/14)
residue_index: torch.Tensor, # (*, N)
aatype: torch.Tensor, # (*, N)
tolerance_factor_soft=12.0,
tolerance_factor_hard=12.0,
eps=1e-6,
) -> Dict[str, torch.Tensor]:
"""Flat-bottom loss to penalize structural violations between residues.
This is a loss penalizing any violation of the geometry around the peptide
bond between consecutive amino acids. This loss corresponds to
Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 44, 45.
Args:
pred_atom_positions: Atom positions in atom37/14 representation
pred_atom_mask: Atom mask in atom37/14 representation
        residue_index: Residue index for the given amino acid; assumed to be
            monotonically increasing.
aatype: Amino acid type of given residue
tolerance_factor_soft: soft tolerance factor measured in standard deviations
of pdb distributions
tolerance_factor_hard: hard tolerance factor measured in standard deviations
of pdb distributions
Returns:
Dict containing:
* 'c_n_loss_mean': Loss for peptide bond length violations
* 'ca_c_n_loss_mean': Loss for violations of bond angle around C spanned
by CA, C, N
* 'c_n_ca_loss_mean': Loss for violations of bond angle around N spanned
by C, N, CA
* 'per_residue_loss_sum': sum of all losses for each residue
* 'per_residue_violation_mask': mask denoting all residues with violation
present.
"""
# Get the positions of the relevant backbone atoms.
this_ca_pos = pred_atom_positions[..., :-1, 1, :]
this_ca_mask = pred_atom_mask[..., :-1, 1]
this_c_pos = pred_atom_positions[..., :-1, 2, :]
this_c_mask = pred_atom_mask[..., :-1, 2]
next_n_pos = pred_atom_positions[..., 1:, 0, :]
next_n_mask = pred_atom_mask[..., 1:, 0]
next_ca_pos = pred_atom_positions[..., 1:, 1, :]
next_ca_mask = pred_atom_mask[..., 1:, 1]
has_no_gap_mask = (residue_index[..., 1:] - residue_index[..., :-1]) == 1.0
# Compute loss for the C--N bond.
c_n_bond_length = torch.sqrt(
eps + torch.sum((this_c_pos - next_n_pos) ** 2, dim=-1)
)
# The C-N bond to proline has slightly different length because of the ring.
next_is_proline = aatype[..., 1:] == residue_constants.resname_to_idx["PRO"]
gt_length = (
~next_is_proline
) * residue_constants.between_res_bond_length_c_n[
0
] + next_is_proline * residue_constants.between_res_bond_length_c_n[
1
]
gt_stddev = (
~next_is_proline
) * residue_constants.between_res_bond_length_stddev_c_n[
0
] + next_is_proline * residue_constants.between_res_bond_length_stddev_c_n[
1
]
c_n_bond_length_error = torch.sqrt(eps + (c_n_bond_length - gt_length) ** 2)
c_n_loss_per_residue = torch.nn.functional.relu(
c_n_bond_length_error - tolerance_factor_soft * gt_stddev
)
mask = this_c_mask * next_n_mask * has_no_gap_mask
c_n_loss = torch.sum(mask * c_n_loss_per_residue, dim=-1) / (
torch.sum(mask, dim=-1) + eps
)
c_n_violation_mask = mask * (
c_n_bond_length_error > (tolerance_factor_hard * gt_stddev)
)
# Compute loss for the angles.
ca_c_bond_length = torch.sqrt(
eps + torch.sum((this_ca_pos - this_c_pos) ** 2, dim=-1)
)
n_ca_bond_length = torch.sqrt(
eps + torch.sum((next_n_pos - next_ca_pos) ** 2, dim=-1)
)
c_ca_unit_vec = (this_ca_pos - this_c_pos) / ca_c_bond_length[..., None]
c_n_unit_vec = (next_n_pos - this_c_pos) / c_n_bond_length[..., None]
n_ca_unit_vec = (next_ca_pos - next_n_pos) / n_ca_bond_length[..., None]
ca_c_n_cos_angle = torch.sum(c_ca_unit_vec * c_n_unit_vec, dim=-1)
gt_angle = residue_constants.between_res_cos_angles_ca_c_n[0]
gt_stddev = residue_constants.between_res_bond_length_stddev_c_n[0]
ca_c_n_cos_angle_error = torch.sqrt(
eps + (ca_c_n_cos_angle - gt_angle) ** 2
)
ca_c_n_loss_per_residue = torch.nn.functional.relu(
ca_c_n_cos_angle_error - tolerance_factor_soft * gt_stddev
)
mask = this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask
ca_c_n_loss = torch.sum(mask * ca_c_n_loss_per_residue, dim=-1) / (
torch.sum(mask, dim=-1) + eps
)
ca_c_n_violation_mask = mask * (
ca_c_n_cos_angle_error > (tolerance_factor_hard * gt_stddev)
)
c_n_ca_cos_angle = torch.sum((-c_n_unit_vec) * n_ca_unit_vec, dim=-1)
gt_angle = residue_constants.between_res_cos_angles_c_n_ca[0]
gt_stddev = residue_constants.between_res_cos_angles_c_n_ca[1]
c_n_ca_cos_angle_error = torch.sqrt(
eps + torch.square(c_n_ca_cos_angle - gt_angle)
)
c_n_ca_loss_per_residue = torch.nn.functional.relu(
c_n_ca_cos_angle_error - tolerance_factor_soft * gt_stddev
)
mask = this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask
c_n_ca_loss = torch.sum(mask * c_n_ca_loss_per_residue, dim=-1) / (
torch.sum(mask, dim=-1) + eps
)
c_n_ca_violation_mask = mask * (
c_n_ca_cos_angle_error > (tolerance_factor_hard * gt_stddev)
)
# Compute a per residue loss (equally distribute the loss to both
# neighbouring residues).
per_residue_loss_sum = (
c_n_loss_per_residue + ca_c_n_loss_per_residue + c_n_ca_loss_per_residue
)
per_residue_loss_sum = 0.5 * (
torch.nn.functional.pad(per_residue_loss_sum, (0, 1))
+ torch.nn.functional.pad(per_residue_loss_sum, (1, 0))
)
# Compute hard violations.
violation_mask = torch.max(
torch.stack(
[c_n_violation_mask, ca_c_n_violation_mask, c_n_ca_violation_mask],
dim=-2,
),
dim=-2,
)[0]
violation_mask = torch.maximum(
torch.nn.functional.pad(violation_mask, (0, 1)),
torch.nn.functional.pad(violation_mask, (1, 0)),
)
return {
"c_n_loss_mean": c_n_loss,
"ca_c_n_loss_mean": ca_c_n_loss,
"c_n_ca_loss_mean": c_n_ca_loss,
"per_residue_loss_sum": per_residue_loss_sum,
"per_residue_violation_mask": violation_mask,
}
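# Illustrative usage sketch (not part of the original module): the bond loss
# only needs atom positions/masks, residue indices and residue types. With
# random coordinates most peptide-bond geometries fall outside the tolerances,
# so the returned means are typically non-zero.
def _example_between_residue_bond_loss():
    n_res = 16
    pos = torch.randn(1, n_res, 14, 3)
    mask = torch.ones(1, n_res, 14)
    residue_index = torch.arange(n_res)[None]
    aatype = torch.zeros(1, n_res, dtype=torch.long)
    out = between_residue_bond_loss(pos, mask, residue_index, aatype)
    return out["c_n_loss_mean"], out["per_residue_violation_mask"]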
def between_residue_clash_loss(
atom14_pred_positions: torch.Tensor,
atom14_atom_exists: torch.Tensor,
atom14_atom_radius: torch.Tensor,
residue_index: torch.Tensor,
overlap_tolerance_soft=1.5,
overlap_tolerance_hard=1.5,
eps=1e-10,
) -> Dict[str, torch.Tensor]:
"""Loss to penalize steric clashes between residues.
This is a loss penalizing any steric clashes due to non bonded atoms in
different peptides coming too close. This loss corresponds to the part with
different residues of
Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 46.
Args:
atom14_pred_positions: Predicted positions of atoms in
global prediction frame
atom14_atom_exists: Mask denoting whether atom at positions exists for given
amino acid type
atom14_atom_radius: Van der Waals radius for each atom.
residue_index: Residue index for given amino acid.
overlap_tolerance_soft: Soft tolerance factor.
overlap_tolerance_hard: Hard tolerance factor.
Returns:
Dict containing:
* 'mean_loss': average clash loss
* 'per_atom_loss_sum': sum of all clash losses per atom, shape (N, 14)
* 'per_atom_clash_mask': mask whether atom clashes with any other atom
shape (N, 14)
"""
fp_type = atom14_pred_positions.dtype
# Create the distance matrix.
# (N, N, 14, 14)
dists = torch.sqrt(
eps
+ torch.sum(
(
atom14_pred_positions[..., :, None, :, None, :]
- atom14_pred_positions[..., None, :, None, :, :]
)
** 2,
dim=-1,
)
)
# Create the mask for valid distances.
# shape (N, N, 14, 14)
dists_mask = (
atom14_atom_exists[..., :, None, :, None]
* atom14_atom_exists[..., None, :, None, :]
).type(fp_type)
# Mask out all the duplicate entries in the lower triangular matrix.
# Also mask out the diagonal (atom-pairs from the same residue) -- these atoms
# are handled separately.
dists_mask = dists_mask * (
residue_index[..., :, None, None, None]
< residue_index[..., None, :, None, None]
)
# Backbone C--N bond between subsequent residues is no clash.
if habana.is_habana():
c_one_hot = torch.tensor([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device=residue_index.device)
else:
c_one_hot = torch.nn.functional.one_hot(
residue_index.new_tensor(2), num_classes=14
)
c_one_hot = c_one_hot.reshape(
*((1,) * len(residue_index.shape[:-1])), *c_one_hot.shape
)
c_one_hot = c_one_hot.type(fp_type)
if habana.is_habana():
n_one_hot = torch.tensor([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device=residue_index.device)
else:
n_one_hot = torch.nn.functional.one_hot(
residue_index.new_tensor(0), num_classes=14
)
n_one_hot = n_one_hot.reshape(
*((1,) * len(residue_index.shape[:-1])), *n_one_hot.shape
)
n_one_hot = n_one_hot.type(fp_type)
neighbour_mask = (
residue_index[..., :, None, None, None] + 1
) == residue_index[..., None, :, None, None]
c_n_bonds = (
neighbour_mask
* c_one_hot[..., None, None, :, None]
* n_one_hot[..., None, None, None, :]
)
dists_mask = dists_mask * (1.0 - c_n_bonds)
# Disulfide bridge between two cysteines is no clash.
cys = residue_constants.restype_name_to_atom14_names["CYS"]
cys_sg_idx = cys.index("SG")
cys_sg_idx = residue_index.new_tensor(cys_sg_idx)
cys_sg_idx = cys_sg_idx.reshape(
*((1,) * len(residue_index.shape[:-1])), 1
).squeeze(-1)
if habana.is_habana():
cys_sg_one_hot = torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], device=n_one_hot.device)
else:
cys_sg_one_hot = torch.nn.functional.one_hot(cys_sg_idx, num_classes=14)
disulfide_bonds = (
cys_sg_one_hot[..., None, None, :, None]
* cys_sg_one_hot[..., None, None, None, :]
)
dists_mask = dists_mask * (1.0 - disulfide_bonds)
# Compute the lower bound for the allowed distances.
# shape (N, N, 14, 14)
dists_lower_bound = dists_mask * (
atom14_atom_radius[..., :, None, :, None]
+ atom14_atom_radius[..., None, :, None, :]
)
# Compute the error.
# shape (N, N, 14, 14)
dists_to_low_error = dists_mask * torch.nn.functional.relu(
dists_lower_bound - overlap_tolerance_soft - dists
)
# Compute the mean loss.
# shape ()
mean_loss = torch.sum(dists_to_low_error) / (1e-6 + torch.sum(dists_mask))
# Compute the per atom loss sum.
# shape (N, 14)
    per_atom_loss_sum = torch.sum(dists_to_low_error, dim=(-4, -2)) + torch.sum(
        dists_to_low_error, dim=(-3, -1)
    )
# Compute the hard clash mask.
# shape (N, N, 14, 14)
clash_mask = dists_mask * (
dists < (dists_lower_bound - overlap_tolerance_hard)
)
# Compute the per atom clash.
# shape (N, 14)
    per_atom_clash_mask = torch.maximum(
        torch.amax(clash_mask, dim=(-4, -2)),
        torch.amax(clash_mask, dim=(-3, -1)),
    )
return {
"mean_loss": mean_loss, # shape ()
"per_atom_loss_sum": per_atom_loss_sum, # shape (N, 14)
"per_atom_clash_mask": per_atom_clash_mask, # shape (N, 14)
}
def within_residue_violations(
atom14_pred_positions: torch.Tensor,
atom14_atom_exists: torch.Tensor,
atom14_dists_lower_bound: torch.Tensor,
atom14_dists_upper_bound: torch.Tensor,
tighten_bounds_for_loss=0.0,
eps=1e-10,
) -> Dict[str, torch.Tensor]:
"""Loss to penalize steric clashes within residues.
This is a loss penalizing any steric violations or clashes of non-bonded atoms
in a given peptide. This loss corresponds to the part with
the same residues of
Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 46.
Args:
atom14_pred_positions ([*, N, 14, 3]):
Predicted positions of atoms in global prediction frame.
atom14_atom_exists ([*, N, 14]):
Mask denoting whether atom at positions exists for given
amino acid type
        atom14_dists_lower_bound ([*, N, 14, 14]):
            Lower bound on allowed inter-atom distances.
        atom14_dists_upper_bound ([*, N, 14, 14]):
            Upper bound on allowed inter-atom distances
        tighten_bounds_for_loss ([*, N]):
            Extra factor to tighten loss
    Returns:
        Dict containing:
            * 'per_atom_loss_sum' ([*, N, 14]):
                sum of all clash losses per atom
            * 'per_atom_violations' ([*, N, 14]):
                mask denoting whether each atom violates its distance bounds
"""
# Compute the mask for each residue.
dists_masks = 1.0 - torch.eye(14, device=atom14_atom_exists.device)[None]
dists_masks = dists_masks.reshape(
*((1,) * len(atom14_atom_exists.shape[:-2])), *dists_masks.shape
)
dists_masks = (
atom14_atom_exists[..., :, :, None]
* atom14_atom_exists[..., :, None, :]
* dists_masks
)
# Distance matrix
dists = torch.sqrt(
eps
+ torch.sum(
(
atom14_pred_positions[..., :, :, None, :]
- atom14_pred_positions[..., :, None, :, :]
)
** 2,
dim=-1,
)
)
# Compute the loss.
dists_to_low_error = torch.nn.functional.relu(
atom14_dists_lower_bound + tighten_bounds_for_loss - dists
)
dists_to_high_error = torch.nn.functional.relu(
dists - (atom14_dists_upper_bound - tighten_bounds_for_loss)
)
loss = dists_masks * (dists_to_low_error + dists_to_high_error)
# Compute the per atom loss sum.
per_atom_loss_sum = torch.sum(loss, dim=-2) + torch.sum(loss, dim=-1)
# Compute the violations mask.
violations = dists_masks * (
(dists < atom14_dists_lower_bound) | (dists > atom14_dists_upper_bound)
)
# Compute the per atom violations.
    per_atom_violations = torch.maximum(
        torch.max(violations, dim=-2)[0], torch.max(violations, dim=-1)[0]
    )
return {
"per_atom_loss_sum": per_atom_loss_sum,
"per_atom_violations": per_atom_violations,
}
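# Illustrative usage sketch (not part of the original module): with a zero
# lower bound and a very large upper bound nothing violates, so both outputs
# are all zeros. The bound tensors are hypothetical placeholders; in the real
# pipeline they come from residue_constants.make_atom14_dists_bounds.
def _example_within_residue_violations():
    n_res = 8
    pos = torch.randn(1, n_res, 14, 3)
    exists = torch.ones(1, n_res, 14)
    lower = torch.zeros(1, n_res, 14, 14)
    upper = torch.full((1, n_res, 14, 14), 1e4)
    out = within_residue_violations(pos, exists, lower, upper)
    return out["per_atom_loss_sum"], out["per_atom_violations"]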
def find_structural_violations(
batch: Dict[str, torch.Tensor],
atom14_pred_positions: torch.Tensor,
violation_tolerance_factor: float,
clash_overlap_tolerance: float,
**kwargs,
) -> Dict[str, torch.Tensor]:
"""Computes several checks for structural violations."""
# Compute between residue backbone violations of bonds and angles.
connection_violations = between_residue_bond_loss(
pred_atom_positions=atom14_pred_positions,
pred_atom_mask=batch["atom14_atom_exists"],
residue_index=batch["residue_index"],
aatype=batch["aatype"],
tolerance_factor_soft=violation_tolerance_factor,
tolerance_factor_hard=violation_tolerance_factor,
)
# Compute the Van der Waals radius for every atom
# (the first letter of the atom name is the element type).
# Shape: (N, 14).
atomtype_radius = [
residue_constants.van_der_waals_radius[name[0]]
for name in residue_constants.atom_types
]
atomtype_radius = atom14_pred_positions.new_tensor(atomtype_radius)
atom14_atom_radius = (
batch["atom14_atom_exists"]
* atomtype_radius[batch["residx_atom14_to_atom37"]]
)
# Compute the between residue clash loss.
between_residue_clashes = between_residue_clash_loss(
atom14_pred_positions=atom14_pred_positions,
atom14_atom_exists=batch["atom14_atom_exists"],
atom14_atom_radius=atom14_atom_radius,
residue_index=batch["residue_index"],
overlap_tolerance_soft=clash_overlap_tolerance,
overlap_tolerance_hard=clash_overlap_tolerance,
)
# Compute all within-residue violations (clashes,
# bond length and angle violations).
restype_atom14_bounds = residue_constants.make_atom14_dists_bounds(
overlap_tolerance=clash_overlap_tolerance,
bond_length_tolerance_factor=violation_tolerance_factor,
)
atom14_atom_exists = batch["atom14_atom_exists"]
atom14_dists_lower_bound = atom14_pred_positions.new_tensor(
restype_atom14_bounds["lower_bound"]
)[batch["aatype"]]
atom14_dists_upper_bound = atom14_pred_positions.new_tensor(
restype_atom14_bounds["upper_bound"]
)[batch["aatype"]]
residue_violations = within_residue_violations(
atom14_pred_positions=atom14_pred_positions,
atom14_atom_exists=batch["atom14_atom_exists"],
atom14_dists_lower_bound=atom14_dists_lower_bound,
atom14_dists_upper_bound=atom14_dists_upper_bound,
tighten_bounds_for_loss=0.0,
)
# Combine them to a single per-residue violation mask (used later for LDDT).
per_residue_violations_mask = torch.max(
torch.stack(
[
connection_violations["per_residue_violation_mask"],
torch.max(
between_residue_clashes["per_atom_clash_mask"], dim=-1
)[0],
torch.max(residue_violations["per_atom_violations"], dim=-1)[0],
],
dim=-1,
),
dim=-1,
)[0]
return {
"between_residues": {
"bonds_c_n_loss_mean": connection_violations["c_n_loss_mean"], # ()
"angles_ca_c_n_loss_mean": connection_violations[
"ca_c_n_loss_mean"
], # ()
"angles_c_n_ca_loss_mean": connection_violations[
"c_n_ca_loss_mean"
], # ()
"connections_per_residue_loss_sum": connection_violations[
"per_residue_loss_sum"
], # (N)
"connections_per_residue_violation_mask": connection_violations[
"per_residue_violation_mask"
], # (N)
"clashes_mean_loss": between_residue_clashes["mean_loss"], # ()
"clashes_per_atom_loss_sum": between_residue_clashes[
"per_atom_loss_sum"
], # (N, 14)
"clashes_per_atom_clash_mask": between_residue_clashes[
"per_atom_clash_mask"
], # (N, 14)
},
"within_residues": {
"per_atom_loss_sum": residue_violations[
"per_atom_loss_sum"
], # (N, 14)
"per_atom_violations": residue_violations[
"per_atom_violations"
], # (N, 14),
},
"total_per_residue_violations_mask": per_residue_violations_mask, # (N)
}
def find_structural_violations_np(
batch: Dict[str, np.ndarray],
atom14_pred_positions: np.ndarray,
config: ml_collections.ConfigDict,
) -> Dict[str, np.ndarray]:
to_tensor = lambda x: torch.tensor(x)
batch = tree_map(to_tensor, batch, np.ndarray)
atom14_pred_positions = to_tensor(atom14_pred_positions)
out = find_structural_violations(batch, atom14_pred_positions, **config)
to_np = lambda x: np.array(x)
np_out = tensor_tree_map(to_np, out)
return np_out
def extreme_ca_ca_distance_violations(
pred_atom_positions: torch.Tensor, # (N, 37(14), 3)
pred_atom_mask: torch.Tensor, # (N, 37(14))
residue_index: torch.Tensor, # (N)
max_angstrom_tolerance=1.5,
eps=1e-6,
) -> torch.Tensor:
"""Counts residues whose Ca is a large distance from its neighbour.
Measures the fraction of CA-CA pairs between consecutive amino acids that are
more than 'max_angstrom_tolerance' apart.
Args:
pred_atom_positions: Atom positions in atom37/14 representation
pred_atom_mask: Atom mask in atom37/14 representation
        residue_index: Residue index for the given amino acid; assumed to be
            monotonically increasing.
max_angstrom_tolerance: Maximum distance allowed to not count as violation.
Returns:
Fraction of consecutive CA-CA pairs with violation.
"""
this_ca_pos = pred_atom_positions[..., :-1, 1, :]
this_ca_mask = pred_atom_mask[..., :-1, 1]
next_ca_pos = pred_atom_positions[..., 1:, 1, :]
next_ca_mask = pred_atom_mask[..., 1:, 1]
has_no_gap_mask = (residue_index[..., 1:] - residue_index[..., :-1]) == 1.0
ca_ca_distance = torch.sqrt(
eps + torch.sum((this_ca_pos - next_ca_pos) ** 2, dim=-1)
)
violations = (
ca_ca_distance - residue_constants.ca_ca
) > max_angstrom_tolerance
mask = this_ca_mask * next_ca_mask * has_no_gap_mask
mean = masked_mean(mask, violations, -1)
return mean
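# Illustrative usage sketch (not part of the original module): returns the
# fraction of consecutive CA-CA pairs that are more than max_angstrom_tolerance
# beyond the ideal bond length. Inputs are random placeholders in atom37 layout.
def _example_extreme_ca_ca_violations():
    n_res = 16
    pred_pos = torch.randn(1, n_res, 37, 3)
    pred_mask = torch.ones(1, n_res, 37)
    residue_index = torch.arange(n_res)[None]
    return extreme_ca_ca_distance_violations(pred_pos, pred_mask, residue_index)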
def compute_violation_metrics(
batch: Dict[str, torch.Tensor],
atom14_pred_positions: torch.Tensor, # (N, 14, 3)
violations: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
"""Compute several metrics to assess the structural violations."""
ret = {}
extreme_ca_ca_violations = extreme_ca_ca_distance_violations(
pred_atom_positions=atom14_pred_positions,
pred_atom_mask=batch["atom14_atom_exists"],
residue_index=batch["residue_index"],
)
ret["violations_extreme_ca_ca_distance"] = extreme_ca_ca_violations
ret["violations_between_residue_bond"] = masked_mean(
batch["seq_mask"],
violations["between_residues"][
"connections_per_residue_violation_mask"
],
dim=-1,
)
ret["violations_between_residue_clash"] = masked_mean(
mask=batch["seq_mask"],
value=torch.max(
violations["between_residues"]["clashes_per_atom_clash_mask"],
dim=-1,
)[0],
dim=-1,
)
ret["violations_within_residue"] = masked_mean(
mask=batch["seq_mask"],
value=torch.max(
violations["within_residues"]["per_atom_violations"], dim=-1
)[0],
dim=-1,
)
ret["violations_per_residue"] = masked_mean(
mask=batch["seq_mask"],
value=violations["total_per_residue_violations_mask"],
dim=-1,
)
return ret
def compute_violation_metrics_np(
batch: Dict[str, np.ndarray],
atom14_pred_positions: np.ndarray,
violations: Dict[str, np.ndarray],
) -> Dict[str, np.ndarray]:
to_tensor = lambda x: torch.tensor(x)
batch = tree_map(to_tensor, batch, np.ndarray)
atom14_pred_positions = to_tensor(atom14_pred_positions)
violations = tree_map(to_tensor, violations, np.ndarray)
out = compute_violation_metrics(batch, atom14_pred_positions, violations)
to_np = lambda x: np.array(x)
return tree_map(to_np, out, torch.Tensor)
def violation_loss(
violations: Dict[str, torch.Tensor],
atom14_atom_exists: torch.Tensor,
eps=1e-6,
**kwargs,
) -> torch.Tensor:
num_atoms = torch.sum(atom14_atom_exists)
l_clash = torch.sum(
violations["between_residues"]["clashes_per_atom_loss_sum"]
+ violations["within_residues"]["per_atom_loss_sum"]
)
l_clash = l_clash / (eps + num_atoms)
loss = (
violations["between_residues"]["bonds_c_n_loss_mean"]
+ violations["between_residues"]["angles_ca_c_n_loss_mean"]
+ violations["between_residues"]["angles_c_n_ca_loss_mean"]
+ l_clash
)
return loss
def compute_renamed_ground_truth(
batch: Dict[str, torch.Tensor],
atom14_pred_positions: torch.Tensor,
eps=1e-10,
) -> Dict[str, torch.Tensor]:
"""
Find optimal renaming of ground truth based on the predicted positions.
Alg. 26 "renameSymmetricGroundTruthAtoms"
This renamed ground truth is then used for all losses,
such that each loss moves the atoms in the same direction.
Args:
batch: Dictionary containing:
* atom14_gt_positions: Ground truth positions.
* atom14_alt_gt_positions: Ground truth positions with renaming swaps.
* atom14_atom_is_ambiguous: 1.0 for atoms that are affected by
renaming swaps.
* atom14_gt_exists: Mask for which atoms exist in ground truth.
* atom14_alt_gt_exists: Mask for which atoms exist in ground truth
after renaming.
* atom14_atom_exists: Mask for whether each atom is part of the given
amino acid type.
        atom14_pred_positions: Predicted atom positions in the global frame,
            of shape [*, N, 14, 3].
Returns:
Dictionary containing:
alt_naming_is_better: Array with 1.0 where alternative swap is better.
renamed_atom14_gt_positions: Array of optimal ground truth positions
after renaming swaps are performed.
renamed_atom14_gt_exists: Mask after renaming swap is performed.
"""
pred_dists = torch.sqrt(
eps
+ torch.sum(
(
atom14_pred_positions[..., None, :, None, :]
- atom14_pred_positions[..., None, :, None, :, :]
)
** 2,
dim=-1,
)
)
atom14_gt_positions = batch["atom14_gt_positions"]
gt_dists = torch.sqrt(
eps
+ torch.sum(
(
atom14_gt_positions[..., None, :, None, :]
- atom14_gt_positions[..., None, :, None, :, :]
)
** 2,
dim=-1,
)
)
atom14_alt_gt_positions = batch["atom14_alt_gt_positions"]
alt_gt_dists = torch.sqrt(
eps
+ torch.sum(
(
atom14_alt_gt_positions[..., None, :, None, :]
- atom14_alt_gt_positions[..., None, :, None, :, :]
)
** 2,
dim=-1,
)
)
lddt = torch.sqrt(eps + (pred_dists - gt_dists) ** 2)
alt_lddt = torch.sqrt(eps + (pred_dists - alt_gt_dists) ** 2)
atom14_gt_exists = batch["atom14_gt_exists"]
atom14_atom_is_ambiguous = batch["atom14_atom_is_ambiguous"]
mask = (
atom14_gt_exists[..., None, :, None]
* atom14_atom_is_ambiguous[..., None, :, None]
* atom14_gt_exists[..., None, :, None, :]
* (1.0 - atom14_atom_is_ambiguous[..., None, :, None, :])
)
per_res_lddt = torch.sum(mask * lddt, dim=(-1, -2, -3))
alt_per_res_lddt = torch.sum(mask * alt_lddt, dim=(-1, -2, -3))
fp_type = atom14_pred_positions.dtype
alt_naming_is_better = (alt_per_res_lddt < per_res_lddt).type(fp_type)
renamed_atom14_gt_positions = (
1.0 - alt_naming_is_better[..., None, None]
) * atom14_gt_positions + alt_naming_is_better[
..., None, None
] * atom14_alt_gt_positions
renamed_atom14_gt_mask = (
1.0 - alt_naming_is_better[..., None]
) * atom14_gt_exists + alt_naming_is_better[..., None] * batch[
"atom14_alt_gt_exists"
]
return {
"alt_naming_is_better": alt_naming_is_better,
"renamed_atom14_gt_positions": renamed_atom14_gt_positions,
"renamed_atom14_gt_exists": renamed_atom14_gt_mask,
}
def experimentally_resolved_loss(
logits: torch.Tensor,
atom37_atom_exists: torch.Tensor,
all_atom_mask: torch.Tensor,
resolution: torch.Tensor,
min_resolution: float,
max_resolution: float,
eps: float = 1e-8,
**kwargs,
) -> torch.Tensor:
errors = sigmoid_cross_entropy(logits, all_atom_mask)
loss = torch.sum(errors * atom37_atom_exists, dim=-1)
loss = loss / (eps + torch.sum(atom37_atom_exists, dim=(-1, -2)))
loss = torch.sum(loss, dim=-1)
loss = loss * (
(resolution >= min_resolution) & (resolution <= max_resolution)
)
loss = torch.mean(loss)
return loss
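# Illustrative usage sketch (not part of the original module): the
# experimentally-resolved head predicts, per atom, whether it is resolved in
# the experimental structure; the loss is a masked sigmoid cross-entropy gated
# by the structure's resolution. Inputs are random placeholders.
def _example_experimentally_resolved_loss():
    n_res = 16
    logits = torch.randn(1, n_res, 37)
    atom37_atom_exists = torch.ones(1, n_res, 37)
    all_atom_mask = torch.ones(1, n_res, 37)
    resolution = torch.tensor([2.0])
    return experimentally_resolved_loss(
        logits, atom37_atom_exists, all_atom_mask, resolution,
        min_resolution=0.1, max_resolution=3.0,
    )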
def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
"""
Computes BERT-style masked MSA loss. Implements subsection 1.9.9.
Args:
logits: [*, N_seq, N_res, 23] predicted residue distribution
true_msa: [*, N_seq, N_res] true MSA
bert_mask: [*, N_seq, N_res] MSA mask
Returns:
Masked MSA loss
"""
errors = softmax_cross_entropy(
logits, torch.nn.functional.one_hot(true_msa, num_classes=23)
)
# FP16-friendly averaging. Equivalent to:
# loss = (
# torch.sum(errors * bert_mask, dim=(-1, -2)) /
# (eps + torch.sum(bert_mask, dim=(-1, -2)))
# )
loss = errors * bert_mask
loss = torch.sum(loss, dim=-1)
scale = 0.5
denom = eps + torch.sum(scale * bert_mask, dim=(-1, -2))
loss = loss / denom[..., None]
loss = torch.sum(loss, dim=-1)
loss = loss * scale
loss = torch.mean(loss)
return loss
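# Illustrative usage sketch (not part of the original module): a cross-entropy
# over the 23 MSA token classes, averaged over the positions selected by
# bert_mask. Inputs are random placeholders.
def _example_masked_msa_loss():
    n_seq, n_res = 8, 32
    logits = torch.randn(1, n_seq, n_res, 23)
    true_msa = torch.randint(0, 23, (1, n_seq, n_res))
    bert_mask = (torch.rand(1, n_seq, n_res) < 0.15).float()
    return masked_msa_loss(logits, true_msa, bert_mask)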
def compute_drmsd(structure_1, structure_2, mask=None):
if(mask is not None):
structure_1 = structure_1 * mask[..., None]
structure_2 = structure_2 * mask[..., None]
d1 = structure_1[..., :, None, :] - structure_1[..., None, :, :]
d2 = structure_2[..., :, None, :] - structure_2[..., None, :, :]
d1 = d1 ** 2
d2 = d2 ** 2
d1 = torch.sqrt(torch.sum(d1, dim=-1))
d2 = torch.sqrt(torch.sum(d2, dim=-1))
drmsd = d1 - d2
drmsd = drmsd ** 2
drmsd = torch.sum(drmsd, dim=(-1, -2))
n = d1.shape[-1] if mask is None else torch.sum(mask, dim=-1)
drmsd = drmsd * (1 / (n * (n - 1))) if n > 1 else (drmsd * 0.)
drmsd = torch.sqrt(drmsd)
return drmsd
def compute_drmsd_np(structure_1, structure_2, mask=None):
structure_1 = torch.tensor(structure_1)
structure_2 = torch.tensor(structure_2)
if(mask is not None):
mask = torch.tensor(mask)
return compute_drmsd(structure_1, structure_2, mask)
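# Illustrative usage sketch (not part of the original module): dRMSD compares
# the two intramolecular distance matrices, so it is invariant to global
# rotation and translation of either structure.
def _example_compute_drmsd():
    n_res = 16
    s1 = torch.randn(n_res, 3)
    s2 = s1 + 0.1 * torch.randn(n_res, 3)
    return compute_drmsd(s1, s2)   # scalar; small because s2 is a perturbation of s1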
class AlphaFoldLoss(nn.Module):
"""Aggregation of the various losses described in the supplement"""
def __init__(self, config):
super(AlphaFoldLoss, self).__init__()
self.config = config
def forward(self, out, batch, _return_breakdown=False):
if "violation" not in out.keys():
out["violation"] = find_structural_violations(
batch,
out["sm"]["positions"][-1],
**self.config.violation,
)
if "renamed_atom14_gt_positions" not in out.keys():
batch.update(
compute_renamed_ground_truth(
batch,
out["sm"]["positions"][-1],
)
)
loss_fns = {
"distogram": lambda: distogram_loss(
logits=out["distogram_logits"],
**{**batch, **self.config.distogram},
),
"experimentally_resolved": lambda: experimentally_resolved_loss(
logits=out["experimentally_resolved_logits"],
**{**batch, **self.config.experimentally_resolved},
),
"fape": lambda: fape_loss(
out,
batch,
self.config.fape,
),
"lddt": lambda: lddt_loss(
logits=out["lddt_logits"],
all_atom_pred_pos=out["final_atom_positions"],
**{**batch, **self.config.lddt},
),
"masked_msa": lambda: masked_msa_loss(
logits=out["masked_msa_logits"],
**{**batch, **self.config.masked_msa},
),
"supervised_chi": lambda: supervised_chi_loss(
out["sm"]["angles"],
out["sm"]["unnormalized_angles"],
**{**batch, **self.config.supervised_chi},
),
            # Habana: TODO: the violation loss below is removed on Habana
            # devices (see the deletion just after this dict) to work around an
            # error in HMP (Habana Mixed Precision)
"violation": lambda: violation_loss(
out["violation"],
**batch,
),
}
if habana.is_habana():
del loss_fns["violation"]
if(self.config.tm.enabled):
loss_fns["tm"] = lambda: tm_loss(
logits=out["tm_logits"],
**{**batch, **out, **self.config.tm},
)
cum_loss = 0.
losses = {}
for loss_name, loss_fn in loss_fns.items():
weight = self.config[loss_name].weight
loss = loss_fn()
if(torch.isnan(loss) or torch.isinf(loss)):
                logging.warning(f"{loss_name} loss is NaN or Inf. Skipping...")
loss = loss.new_tensor(0., requires_grad=True)
cum_loss = cum_loss + weight * loss
losses[loss_name] = loss.detach().clone()
losses["unscaled_loss"] = cum_loss.detach().clone()
# Scale the loss by the square root of the minimum of the crop size and
# the (average) sequence length. See subsection 1.9.
seq_len = torch.mean(batch["seq_length"].float())
crop_len = batch["aatype"].shape[-1]
        # clamp keeps the result a Tensor, so torch.sqrt is valid even when
        # crop_len < seq_len
        cum_loss = cum_loss * torch.sqrt(torch.clamp(seq_len, max=crop_len))
losses["loss"] = cum_loss.detach().clone()
if(not _return_breakdown):
return cum_loss
return cum_loss, losses
# Copyright 2022 HPC-AI Tech Inc
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
class AlphaFoldLRScheduler(torch.optim.lr_scheduler._LRScheduler):
""" Implements the learning rate schedule defined in the AlphaFold 2
supplement. A linear warmup is followed by a plateau at the maximum
learning rate and then exponential decay.
Note that the initial learning rate of the optimizer in question is
ignored; use this class' base_lr parameter to specify the starting
point of the warmup.
"""
def __init__(self,
optimizer,
last_epoch: int = -1,
verbose: bool = False,
base_lr: float = 0.,
max_lr: float = 0.001,
warmup_no_steps: int = 1000,
start_decay_after_n_steps: int = 50000,
decay_every_n_steps: int = 50000,
decay_factor: float = 0.95,
):
step_counts = {
"warmup_no_steps": warmup_no_steps,
"start_decay_after_n_steps": start_decay_after_n_steps,
}
for k,v in step_counts.items():
if(v < 0):
raise ValueError(f"{k} must be nonnegative")
if(warmup_no_steps > start_decay_after_n_steps):
raise ValueError(
"warmup_no_steps must not exceed start_decay_after_n_steps"
)
self.optimizer = optimizer
self.last_epoch = last_epoch
self.verbose = verbose
self.base_lr = base_lr
self.max_lr = max_lr
self.warmup_no_steps = warmup_no_steps
self.start_decay_after_n_steps = start_decay_after_n_steps
self.decay_every_n_steps = decay_every_n_steps
self.decay_factor = decay_factor
super(AlphaFoldLRScheduler, self).__init__(
optimizer,
last_epoch=last_epoch,
verbose=verbose,
)
def state_dict(self):
state_dict = {
k:v for k,v in self.__dict__.items() if k not in ["optimizer"]
}
return state_dict
def load_state_dict(self, state_dict):
self.__dict__.update(state_dict)
def get_lr(self):
if(not self._get_lr_called_within_step):
raise RuntimeError(
"To get the last learning rate computed by the scheduler, use "
"get_last_lr()"
)
step_no = self.last_epoch
if(step_no <= self.warmup_no_steps):
lr = self.base_lr + (step_no / self.warmup_no_steps) * self.max_lr
elif(step_no > self.start_decay_after_n_steps):
steps_since_decay = step_no - self.start_decay_after_n_steps
exp = (steps_since_decay // self.decay_every_n_steps) + 1
lr = self.max_lr * (self.decay_factor ** exp)
else: # plateau
lr = self.max_lr
return [lr for group in self.optimizer.param_groups]
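# Illustrative sketch (not part of the original module): drive the scheduler
# for a few steps with a throwaway optimizer to see the linear warmup. The tiny
# Linear model is a hypothetical placeholder; the optimizer's own lr is ignored
# by the scheduler.
def _example_lr_schedule():
    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.Adam(model.parameters(), lr=1.0)
    scheduler = AlphaFoldLRScheduler(optimizer, max_lr=0.001, warmup_no_steps=1000)
    lrs = []
    for _ in range(5):
        optimizer.step()
        scheduler.step()
        lrs.append(scheduler.get_last_lr()[0])
    # After k warmup steps the lr is base_lr + (k / warmup_no_steps) * max_lr.
    return lrs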
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from functools import partialmethod
from typing import Union, List
class Dropout(nn.Module):
"""
Implementation of dropout with the ability to share the dropout mask
along a particular dimension.
If not in training mode, this module computes the identity function.
"""
def __init__(self, r: float, batch_dim: Union[int, List[int]]):
"""
Args:
r:
Dropout rate
batch_dim:
Dimension(s) along which the dropout mask is shared
"""
super(Dropout, self).__init__()
self.r = r
        if isinstance(batch_dim, int):
batch_dim = [batch_dim]
self.batch_dim = batch_dim
self.dropout = nn.Dropout(self.r)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x:
Tensor to which dropout is applied. Can have any shape
compatible with self.batch_dim
"""
shape = list(x.shape)
if self.batch_dim is not None:
for bd in self.batch_dim:
shape[bd] = 1
mask = x.new_ones(shape)
mask = self.dropout(mask)
x = x * mask
return x
class DropoutRowwise(Dropout):
"""
Convenience class for rowwise dropout as described in subsection
1.11.6.
"""
__init__ = partialmethod(Dropout.__init__, batch_dim=-3)
class DropoutColumnwise(Dropout):
"""
Convenience class for columnwise dropout as described in subsection
1.11.6.
"""
__init__ = partialmethod(Dropout.__init__, batch_dim=-2)
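# Illustrative sketch (not part of the original module): with batch_dim=-3 the
# dropout mask has size 1 along dim -3 and is broadcast across it, so the same
# entries are zeroed in every slice along that dimension; surviving entries are
# scaled by 1 / (1 - r) by the underlying nn.Dropout.
def _example_shared_dropout():
    z = torch.ones(1, 8, 8, 4)       # e.g. a [batch, N_res, N_res, C_z] pair activation
    drop = DropoutRowwise(r=0.25)
    drop.train()
    out = drop(z)
    # Because z is constant here, out[:, i] is identical for every i along dim -3.
    return out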
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from typing import Tuple, Dict
from functools import partial
from fastfold.utils import all_atom_multimer
from fastfold.utils.feats import (
build_template_angle_feat,
build_template_pair_feat,
)
from fastfold.model.nn.primitives import Linear, LayerNorm
from fastfold.utils.tensor_utils import one_hot
from fastfold.model.nn.template import (
TemplatePairStack,
TemplatePointwiseAttention,
)
from fastfold.utils import geometry
from fastfold.utils.tensor_utils import one_hot, tensor_tree_map, dict_multimap
class InputEmbedder(nn.Module):
"""
Embeds a subset of the input features.
Implements Algorithms 3 (InputEmbedder) and 4 (relpos).
"""
def __init__(
self,
tf_dim: int,
msa_dim: int,
c_z: int,
c_m: int,
relpos_k: int,
**kwargs,
):
"""
Args:
tf_dim:
Final dimension of the target features
msa_dim:
Final dimension of the MSA features
c_z:
Pair embedding dimension
c_m:
MSA embedding dimension
relpos_k:
Window size used in relative positional encoding
"""
super(InputEmbedder, self).__init__()
self.tf_dim = tf_dim
self.msa_dim = msa_dim
self.c_z = c_z
self.c_m = c_m
self.linear_tf_z_i = Linear(tf_dim, c_z)
self.linear_tf_z_j = Linear(tf_dim, c_z)
self.linear_tf_m = Linear(tf_dim, c_m)
self.linear_msa_m = Linear(msa_dim, c_m)
# RPE stuff
self.relpos_k = relpos_k
self.no_bins = 2 * relpos_k + 1
self.linear_relpos = Linear(self.no_bins, c_z)
def relpos(self, ri: torch.Tensor):
"""
Computes relative positional encodings
Implements Algorithm 4.
Args:
ri:
"residue_index" features of shape [*, N]
"""
d = ri[..., None] - ri[..., None, :]
boundaries = torch.arange(
start=-self.relpos_k, end=self.relpos_k + 1, device=d.device
)
oh = one_hot(d, boundaries).type(ri.dtype)
return self.linear_relpos(oh)
def forward(
self,
tf: torch.Tensor,
ri: torch.Tensor,
msa: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
tf:
"target_feat" features of shape [*, N_res, tf_dim]
ri:
"residue_index" features of shape [*, N_res]
msa:
"msa_feat" features of shape [*, N_clust, N_res, msa_dim]
Returns:
msa_emb:
[*, N_clust, N_res, C_m] MSA embedding
pair_emb:
[*, N_res, N_res, C_z] pair embedding
"""
# [*, N_res, c_z]
tf_emb_i = self.linear_tf_z_i(tf)
tf_emb_j = self.linear_tf_z_j(tf)
# [*, N_res, N_res, c_z]
pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
pair_emb = pair_emb + self.relpos(ri.type(pair_emb.dtype))
# [*, N_clust, N_res, c_m]
n_clust = msa.shape[-3]
tf_m = (
self.linear_tf_m(tf)
.unsqueeze(-3)
.expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))
)
msa_emb = self.linear_msa_m(msa) + tf_m
return msa_emb, pair_emb
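# Illustrative usage sketch (not part of the original module): the dimensions
# below are hypothetical placeholders (the real values come from the model
# config). The embedder maps target features, residue indices and MSA features
# to the initial MSA and pair representations.
def _example_input_embedder():
    embedder = InputEmbedder(tf_dim=22, msa_dim=49, c_z=128, c_m=256, relpos_k=32)
    n_res, n_clust = 16, 4
    tf = torch.randn(1, n_res, 22)
    ri = torch.arange(n_res)[None].float()
    msa = torch.randn(1, n_clust, n_res, 49)
    msa_emb, pair_emb = embedder(tf, ri, msa)
    return msa_emb.shape, pair_emb.shape   # [1, 4, 16, 256], [1, 16, 16, 128]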
class RecyclingEmbedder(nn.Module):
"""
Embeds the output of an iteration of the model for recycling.
Implements Algorithm 32.
"""
def __init__(
self,
c_m: int,
c_z: int,
min_bin: float,
max_bin: float,
no_bins: int,
inf: float = 1e8,
**kwargs,
):
"""
Args:
c_m:
MSA channel dimension
c_z:
Pair embedding channel dimension
min_bin:
Smallest distogram bin (Angstroms)
max_bin:
Largest distogram bin (Angstroms)
no_bins:
Number of distogram bins
"""
super(RecyclingEmbedder, self).__init__()
self.c_m = c_m
self.c_z = c_z
self.min_bin = min_bin
self.max_bin = max_bin
self.no_bins = no_bins
self.inf = inf
self.linear = Linear(self.no_bins, self.c_z)
self.layer_norm_m = LayerNorm(self.c_m)
self.layer_norm_z = LayerNorm(self.c_z)
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
x: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
m:
First row of the MSA embedding. [*, N_res, C_m]
z:
[*, N_res, N_res, C_z] pair embedding
x:
[*, N_res, 3] predicted C_beta coordinates
Returns:
m:
[*, N_res, C_m] MSA embedding update
z:
[*, N_res, N_res, C_z] pair embedding update
"""
bins = torch.linspace(
self.min_bin,
self.max_bin,
self.no_bins,
dtype=x.dtype,
device=x.device,
requires_grad=False,
)
# [*, N, C_m]
m_update = self.layer_norm_m(m)
# This squared method might become problematic in FP16 mode.
# I'm using it because my homegrown method had a stubborn discrepancy I
# couldn't find in time.
squared_bins = bins ** 2
upper = torch.cat(
[squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1
)
d = torch.sum(
(x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdims=True
)
# [*, N, N, no_bins]
d = ((d > squared_bins) * (d < upper)).type(x.dtype)
# [*, N, N, C_z]
d = self.linear(d)
z_update = d + self.layer_norm_z(z)
return m_update, z_update
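# Illustrative usage sketch (not part of the original module): the recycling
# embedder consumes the previous iteration's first MSA row, pair representation
# and C-beta coordinates and returns additive updates for the next iteration.
# The dimensions and bin settings below are hypothetical placeholders.
def _example_recycling_embedder():
    emb = RecyclingEmbedder(c_m=256, c_z=128, min_bin=3.25, max_bin=20.75, no_bins=15)
    n_res = 16
    m_1 = torch.randn(1, n_res, 256)
    z = torch.randn(1, n_res, n_res, 128)
    x = torch.randn(1, n_res, 3)
    m_update, z_update = emb(m_1, z, x)
    return m_update.shape, z_update.shape   # [1, 16, 256], [1, 16, 16, 128]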
class TemplateEmbedder(nn.Module):
def __init__(self, config):
super(TemplateEmbedder, self).__init__()
self.config = config
self.template_angle_embedder = TemplateAngleEmbedder(
**config["template_angle_embedder"],
)
self.template_pair_embedder = TemplatePairEmbedder(
**config["template_pair_embedder"],
)
self.template_pair_stack = TemplatePairStack(
**config["template_pair_stack"],
)
self.template_pointwise_att = TemplatePointwiseAttention(
**config["template_pointwise_attention"],
)
def forward(self,
batch,
z,
pair_mask,
templ_dim,
chunk_size,
_mask_trans=True
):
# Embed the templates one at a time (with a poor man's vmap)
template_embeds = []
n_templ = batch["template_aatype"].shape[templ_dim]
for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map(
lambda t: torch.index_select(t, templ_dim, idx),
batch,
)
single_template_embeds = {}
if self.config.embed_angles:
template_angle_feat = build_template_angle_feat(
single_template_feats,
)
# [*, S_t, N, C_m]
a = self.template_angle_embedder(template_angle_feat)
single_template_embeds["angle"] = a
# [*, S_t, N, N, C_t]
t = build_template_pair_feat(
single_template_feats,
use_unit_vector=self.config.use_unit_vector,
inf=self.config.inf,
eps=self.config.eps,
**self.config.distogram,
).to(z.dtype)
t = self.template_pair_embedder(t)
single_template_embeds.update({"pair": t})
template_embeds.append(single_template_embeds)
template_embeds = dict_multimap(
partial(torch.cat, dim=templ_dim),
template_embeds,
)
# [*, S_t, N, N, C_z]
t = self.template_pair_stack(
template_embeds["pair"],
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=chunk_size,
_mask_trans=_mask_trans,
)
# [*, N, N, C_z]
t = self.template_pointwise_att(
t,
z,
template_mask=batch["template_mask"].to(dtype=z.dtype),
chunk_size=chunk_size,
)
t = t * (torch.sum(batch["template_mask"]) > 0)
ret = {}
if self.config.embed_angles:
ret["template_single_embedding"] = template_embeds["angle"]
ret.update({"template_pair_embedding": t})
return ret
class TemplateAngleEmbedder(nn.Module):
"""
Embeds the "template_angle_feat" feature.
Implements Algorithm 2, line 7.
"""
def __init__(
self,
c_in: int,
c_out: int,
**kwargs,
):
"""
Args:
c_in:
Final dimension of "template_angle_feat"
c_out:
Output channel dimension
"""
super(TemplateAngleEmbedder, self).__init__()
self.c_out = c_out
self.c_in = c_in
self.linear_1 = Linear(self.c_in, self.c_out, init="relu")
self.relu = nn.ReLU()
self.linear_2 = Linear(self.c_out, self.c_out, init="relu")
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: [*, N_templ, N_res, c_in] "template_angle_feat" features
Returns:
x: [*, N_templ, N_res, C_out] embedding
"""
x = self.linear_1(x.to(dtype=self.linear_1.weight.dtype))
x = self.relu(x)
x = self.linear_2(x)
return x
class TemplatePairEmbedder(nn.Module):
"""
Embeds "template_pair_feat" features.
Implements Algorithm 2, line 9.
"""
def __init__(
self,
c_in: int,
c_out: int,
**kwargs,
):
"""
Args:
            c_in:
                Input channel dimension
c_out:
Output channel dimension
"""
super(TemplatePairEmbedder, self).__init__()
self.c_in = c_in
self.c_out = c_out
# Despite there being no relu nearby, the source uses that initializer
self.linear = Linear(self.c_in, self.c_out, init="relu")
def forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
"""
Args:
x:
[*, C_in] input tensor
Returns:
[*, C_out] output tensor
"""
x = self.linear(x)
return x
class ExtraMSAEmbedder(nn.Module):
"""
Embeds unclustered MSA sequences.
Implements Algorithm 2, line 15
"""
def __init__(
self,
c_in: int,
c_out: int,
**kwargs,
):
"""
Args:
c_in:
Input channel dimension
c_out:
Output channel dimension
"""
super(ExtraMSAEmbedder, self).__init__()
self.c_in = c_in
self.c_out = c_out
self.linear = Linear(self.c_in, self.c_out)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x:
[*, N_extra_seq, N_res, C_in] "extra_msa_feat" features
Returns:
[*, N_extra_seq, N_res, C_out] embedding
"""
x = self.linear(x.to(dtype=self.linear.weight.dtype))
return x