Unverified commit a80d5263, authored by shenggan, committed by GitHub

support alphafold v2.3 param (#128)

parent c3436dd1
@@ -575,7 +575,7 @@ multimer_model_config_update = {
         "tm": {
             "c_z": c_z,
             "no_bins": aux_distogram_bins,
-            "enabled": tm_enabled,
+            "enabled": True,
         },
         "masked_msa": {
             "c_m": c_m,
@@ -22,6 +22,16 @@ import torch.nn as nn
 from fastfold.model.nn.primitives import Linear, LayerNorm
 from fastfold.utils.tensor_utils import permute_final_dims
 
+_FUSED_TRIANGLE_MULTIPLICATION = False
+
+def set_fused_triangle_multiplication():
+    global _FUSED_TRIANGLE_MULTIPLICATION
+    _FUSED_TRIANGLE_MULTIPLICATION = True
+
+def is_fused_triangle_multiplication():
+    global _FUSED_TRIANGLE_MULTIPLICATION
+    return _FUSED_TRIANGLE_MULTIPLICATION
+
 
 class TriangleMultiplicativeUpdate(nn.Module):
     """
@@ -40,6 +50,11 @@ class TriangleMultiplicativeUpdate(nn.Module):
         self.c_hidden = c_hidden
         self._outgoing = _outgoing
 
-        self.linear_a_p = Linear(self.c_z, self.c_hidden)
-        self.linear_a_g = Linear(self.c_z, self.c_hidden, init="gating")
-        self.linear_b_p = Linear(self.c_z, self.c_hidden)
+        if _FUSED_TRIANGLE_MULTIPLICATION:
+            self.linear_p = Linear(self.c_z, 2 * self.c_hidden)
+            self.linear_g = Linear(self.c_z, 2 * self.c_hidden, init="gating")
+            self.linear_gate = Linear(self.c_z, self.c_z, init="gating")
+        else:
+            self.linear_a_p = Linear(self.c_z, self.c_hidden)
+            self.linear_a_g = Linear(self.c_z, self.c_hidden, init="gating")
+            self.linear_b_p = Linear(self.c_z, self.c_hidden)
@@ -77,13 +92,24 @@ class TriangleMultiplicativeUpdate(nn.Module):
         mask = mask.unsqueeze(-1)
 
         z = self.layer_norm_in(z)
-        a = self.linear_a_p(z) * self.sigmoid(self.linear_a_g(z))
-        a = a * mask
-        b = self.linear_b_p(z) * self.sigmoid(self.linear_b_g(z))
-        b = b * mask
+        if _FUSED_TRIANGLE_MULTIPLICATION:
+            a = self.linear_p(z) * mask
+            a = a * self.sigmoid(self.linear_g(z))
+            a, b = a.chunk(2, dim=-1)
+        else:
+            a = self.linear_a_p(z) * self.sigmoid(self.linear_a_g(z))
+            a = a * mask
+            b = self.linear_b_p(z) * self.sigmoid(self.linear_b_g(z))
+            b = b * mask
         x = self._combine_projections(a, b)
         x = self.layer_norm_out(x)
         x = self.linear_z(x)
-        g = self.sigmoid(self.linear_g(z))
+        if _FUSED_TRIANGLE_MULTIPLICATION:
+            g = self.sigmoid(self.linear_gate(z))
+        else:
+            g = self.sigmoid(self.linear_g(z))
         z = x * g
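
The fused path is numerically equivalent to the separate a/b projections: one Linear of width 2 * c_hidden followed by chunk(2, dim=-1) matches two c_hidden-wide Linears whose weights are stacked along the output dimension, and the mask/gate multiplications commute. A standalone sketch with toy shapes and hypothetical names:

# Sketch of the equivalence (toy tensors, hypothetical names, not FastFold code).
import torch
import torch.nn as nn

c_z, c_hidden = 8, 4
z = torch.randn(2, 2, c_z)

linear_a = nn.Linear(c_z, c_hidden)
linear_b = nn.Linear(c_z, c_hidden)

# Build a fused projection by stacking the a/b weights along the output dim.
fused = nn.Linear(c_z, 2 * c_hidden)
with torch.no_grad():
    fused.weight.copy_(torch.cat([linear_a.weight, linear_b.weight], dim=0))
    fused.bias.copy_(torch.cat([linear_a.bias, linear_b.bias], dim=0))

# Chunking the fused output recovers the two separate projections.
a_fused, b_fused = fused(z).chunk(2, dim=-1)
assert torch.allclose(a_fused, linear_a(z), atol=1e-6)
assert torch.allclose(b_fused, linear_b(z), atol=1e-6)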
@@ -20,6 +20,7 @@ import numpy as np
 import torch
 from typing import Union, List
 
+from fastfold.model.nn.triangular_multiplicative_update import is_fused_triangle_multiplication
 
 _NPZ_KEY_PREFIX = "alphafold/alphafold_iteration/"
@@ -188,6 +189,29 @@ def get_translation_dict(model, version):
             "attention": AttentionGatedParams(tri_att.mha),
         }
 
+    if is_fused_triangle_multiplication():
+        TriMulOutParams = lambda tri_mul: {
+            "left_norm_input": LayerNormParams(tri_mul.layer_norm_in),
+            "projection": LinearParams(tri_mul.linear_p),
+            "gate": LinearParams(tri_mul.linear_g),
+            "center_norm": LayerNormParams(tri_mul.layer_norm_out),
+            "output_projection": LinearParams(tri_mul.linear_z),
+            "gating_linear": LinearParams(tri_mul.linear_gate),
+        }
+
+        # see commit b88f8da on the AlphaFold repo
+        # AlphaFold swaps the pseudocode's a and b between the incoming/outgoing
+        # iterations of triangle multiplication, which is confusing and not
+        # reproduced in our implementation.
+        TriMulInParams = lambda tri_mul: {
+            "left_norm_input": LayerNormParams(tri_mul.layer_norm_in),
+            "projection": LinearParams(tri_mul.linear_p),
+            "gate": LinearParams(tri_mul.linear_g),
+            "center_norm": LayerNormParams(tri_mul.layer_norm_out),
+            "output_projection": LinearParams(tri_mul.linear_z),
+            "gating_linear": LinearParams(tri_mul.linear_gate),
+        }
+    else:
+        TriMulOutParams = lambda tri_mul: {
+            "layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),
+            "left_projection": LinearParams(tri_mul.linear_a_p),
@@ -553,7 +577,7 @@ def get_translation_dict(model, version):
         if "template_" in k:
             evo_dict.pop(k)
 
-    if "_ptm" in version:
+    if "_ptm" in version or is_multimer:
         translations["predicted_aligned_error_head"] = {
            "logits": LinearParams(model.aux_heads.tm.linear)
         }
@@ -18,6 +18,7 @@ from fastfold.model.fastnn import EvoformerStack, ExtraMSAStack
 from fastfold.model.fastnn.embedders import TemplateEmbedder
 from fastfold.model.fastnn.embedders_multimer import TemplateEmbedderMultimer
 from fastfold.model.fastnn.ops import RecyclingEmbedder, InputEmbedder
+from fastfold.model.nn.triangular_multiplicative_update import is_fused_triangle_multiplication
 
 
 def copy_layernorm(model_fast, model_ori):
@@ -72,12 +73,17 @@ def copy_transition(model_fast, model_ori):
 def copy_triangle(model_fast, model_ori):
     copy_layernorm(model_fast.layernorm1, model_ori.layer_norm_in)
     copy_layernorm(model_fast.layernorm2, model_ori.layer_norm_out)
-    copy_linear(model_fast.output_gate, model_ori.linear_g)
     copy_linear(model_fast.output_projection, model_ori.linear_z)
     model_fast.output_bias.copy_(model_ori.linear_z.bias)
 
-    copy_left_right(model_fast.left_right_projection, model_ori.linear_a_p, model_ori.linear_b_p)
-    copy_left_right(model_fast.left_right_gate, model_ori.linear_a_g, model_ori.linear_b_g)
+    if is_fused_triangle_multiplication():
+        copy_linear(model_fast.output_gate, model_ori.linear_gate)
+        copy_linear(model_fast.left_right_projection, model_ori.linear_p)
+        copy_linear(model_fast.left_right_gate, model_ori.linear_g)
+    else:
+        copy_linear(model_fast.output_gate, model_ori.linear_g)
+        copy_left_right(model_fast.left_right_projection, model_ori.linear_a_p, model_ori.linear_b_p)
+        copy_left_right(model_fast.left_right_gate, model_ori.linear_a_g, model_ori.linear_b_g)
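
In the non-fused branch, `copy_left_right` presumably packs the two source Linears into FastFold's single left/right projection by concatenating along the output dimension. A hedged sketch of what such a helper could look like (assumption only; the real helper lives in this file and may lay weights out differently):

# Hypothetical sketch of copy_left_right, not FastFold's actual implementation.
import torch

def copy_left_right(dst_linear, src_a, src_b):
    # Stack the a/b weights along the output dimension so that a single
    # matmul produces both projections, mirroring the fused v2.3 layout.
    with torch.no_grad():
        dst_linear.weight.copy_(torch.cat([src_a.weight, src_b.weight], dim=0))
        dst_linear.bias.copy_(torch.cat([src_a.bias, src_b.bias], dim=0))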
@@ -34,6 +34,7 @@ import fastfold.relax.relax as relax
 from fastfold.common import protein, residue_constants
 from fastfold.config import model_config
 from fastfold.model.fastnn import set_chunk_size
+from fastfold.model.nn.triangular_multiplicative_update import set_fused_triangle_multiplication
 from fastfold.data import data_pipeline, feature_pipeline, templates
 from fastfold.data.tools import hhsearch, hmmsearch
 from fastfold.workflow.template import FastFoldDataWorkFlow, FastFoldMultimerDataWorkFlow
@@ -117,6 +118,10 @@ def inference_model(rank, world_size, result_q, batch, args):
     config = model_config(args.model_name)
     if args.chunk_size:
         config.globals.chunk_size = args.chunk_size
+
+    if "v3" in args.param_path:
+        set_fused_triangle_multiplication()
+
     config.globals.inplace = args.inplace
     config.globals.is_multimer = args.model_preset == 'multimer'
     model = AlphaFold(config)
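
The "v3" check keys off AlphaFold v2.3 parameter file names such as params_model_1_multimer_v3.npz. A minimal sketch of the same dispatch outside the runner, with a hypothetical path:

# Sketch of the dispatch above; param_path is a hypothetical example.
from fastfold.model.nn.triangular_multiplicative_update import set_fused_triangle_multiplication

param_path = "params/params_model_1_multimer_v3.npz"  # hypothetical path
if "v3" in param_path:
    # v2.3 checkpoints ship fused triangle-multiplication weights,
    # so the fused module layout must be enabled before model construction.
    set_fused_triangle_multiplication()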