Unverified Commit a80d5263 authored by shenggan's avatar shenggan Committed by GitHub
Browse files

support alphafold v2.3 param (#128)

parent c3436dd1
...@@ -575,7 +575,7 @@ multimer_model_config_update = { ...@@ -575,7 +575,7 @@ multimer_model_config_update = {
"tm": { "tm": {
"c_z": c_z, "c_z": c_z,
"no_bins": aux_distogram_bins, "no_bins": aux_distogram_bins,
"enabled": tm_enabled, "enabled": True,
}, },
"masked_msa": { "masked_msa": {
"c_m": c_m, "c_m": c_m,
......
...@@ -22,6 +22,16 @@ import torch.nn as nn ...@@ -22,6 +22,16 @@ import torch.nn as nn
from fastfold.model.nn.primitives import Linear, LayerNorm from fastfold.model.nn.primitives import Linear, LayerNorm
from fastfold.utils.tensor_utils import permute_final_dims from fastfold.utils.tensor_utils import permute_final_dims
# Process-wide switch selecting the fused triangle-multiplication layout.
# TriangleMultiplicativeUpdate consults it both where its linear layers are
# created and where they are applied, so set it before constructing the model.
_FUSED_TRIANGLE_MULTIPLICATION = False


def set_fused_triangle_multiplication(enabled: bool = True) -> None:
    """Turn the fused triangle-multiplication layout on (default) or off.

    Args:
        enabled: New value of the module-level switch. Defaults to ``True``
            so existing zero-argument calls keep enabling the fused layout.
    """
    global _FUSED_TRIANGLE_MULTIPLICATION
    _FUSED_TRIANGLE_MULTIPLICATION = bool(enabled)


def is_fused_triangle_multiplication() -> bool:
    """Return whether the fused triangle-multiplication layout is enabled."""
    # Reading a module global needs no ``global`` declaration.
    return _FUSED_TRIANGLE_MULTIPLICATION
class TriangleMultiplicativeUpdate(nn.Module): class TriangleMultiplicativeUpdate(nn.Module):
""" """
...@@ -40,6 +50,11 @@ class TriangleMultiplicativeUpdate(nn.Module): ...@@ -40,6 +50,11 @@ class TriangleMultiplicativeUpdate(nn.Module):
self.c_hidden = c_hidden self.c_hidden = c_hidden
self._outgoing = _outgoing self._outgoing = _outgoing
if _FUSED_TRIANGLE_MULTIPLICATION:
self.linear_p = Linear(self.c_z, 2 * self.c_hidden)
self.linear_g = Linear(self.c_z, 2 * self.c_hidden, init="gating")
self.linear_gate = Linear(self.c_z, self.c_z, init="gating")
else:
self.linear_a_p = Linear(self.c_z, self.c_hidden) self.linear_a_p = Linear(self.c_z, self.c_hidden)
self.linear_a_g = Linear(self.c_z, self.c_hidden, init="gating") self.linear_a_g = Linear(self.c_z, self.c_hidden, init="gating")
self.linear_b_p = Linear(self.c_z, self.c_hidden) self.linear_b_p = Linear(self.c_z, self.c_hidden)
...@@ -77,13 +92,24 @@ class TriangleMultiplicativeUpdate(nn.Module): ...@@ -77,13 +92,24 @@ class TriangleMultiplicativeUpdate(nn.Module):
mask = mask.unsqueeze(-1) mask = mask.unsqueeze(-1)
z = self.layer_norm_in(z) z = self.layer_norm_in(z)
if _FUSED_TRIANGLE_MULTIPLICATION:
    # One projection and one gate each produce 2 * c_hidden channels; the
    # masked, gated result is split into the two operands afterwards.
    a = self.linear_p(z) * mask
    # Multiply by the gate instead of overwriting `a`: assigning the sigmoid
    # alone would discard the masked projection computed just above.
    a = a * self.sigmoid(self.linear_g(z))
    a, b = a.chunk(2, dim=-1)
else:
a = self.linear_a_p(z) * self.sigmoid(self.linear_a_g(z)) a = self.linear_a_p(z) * self.sigmoid(self.linear_a_g(z))
a = a * mask a = a * mask
b = self.linear_b_p(z) * self.sigmoid(self.linear_b_g(z)) b = self.linear_b_p(z) * self.sigmoid(self.linear_b_g(z))
b = b * mask b = b * mask
x = self._combine_projections(a, b) x = self._combine_projections(a, b)
x = self.layer_norm_out(x) x = self.layer_norm_out(x)
x = self.linear_z(x) x = self.linear_z(x)
if _FUSED_TRIANGLE_MULTIPLICATION:
g = self.sigmoid(self.linear_gate(z))
else:
g = self.sigmoid(self.linear_g(z)) g = self.sigmoid(self.linear_g(z))
z = x * g z = x * g
......
...@@ -20,6 +20,7 @@ import numpy as np ...@@ -20,6 +20,7 @@ import numpy as np
import torch import torch
from typing import Union, List from typing import Union, List
from fastfold.model.nn.triangular_multiplicative_update import is_fused_triangle_multiplication
_NPZ_KEY_PREFIX = "alphafold/alphafold_iteration/" _NPZ_KEY_PREFIX = "alphafold/alphafold_iteration/"
...@@ -188,6 +189,29 @@ def get_translation_dict(model, version): ...@@ -188,6 +189,29 @@ def get_translation_dict(model, version):
"attention": AttentionGatedParams(tri_att.mha), "attention": AttentionGatedParams(tri_att.mha),
} }
if is_fused_triangle_multiplication():
def TriMulOutParams(tri_mul):
    """Build the translation entries for ``tri_mul`` under the fused layout."""
    return dict(
        left_norm_input=LayerNormParams(tri_mul.layer_norm_in),
        projection=LinearParams(tri_mul.linear_p),
        gate=LinearParams(tri_mul.linear_g),
        center_norm=LayerNormParams(tri_mul.layer_norm_out),
        output_projection=LinearParams(tri_mul.linear_z),
        gating_linear=LinearParams(tri_mul.linear_gate),
    )
# see commit b88f8da on the Alphafold repo
# Alphafold swaps the pseudocode's a and b between the incoming/outgoing
# iterations of triangle multiplication, which is confusing and not
# reproduced in our implementation.
def TriMulInParams(tri_mul):
    """Build the translation entries for ``tri_mul`` under the fused layout."""
    entries = {}
    # Entries are inserted in the same order as the original dict literal.
    entries["left_norm_input"] = LayerNormParams(tri_mul.layer_norm_in)
    entries["projection"] = LinearParams(tri_mul.linear_p)
    entries["gate"] = LinearParams(tri_mul.linear_g)
    entries["center_norm"] = LayerNormParams(tri_mul.layer_norm_out)
    entries["output_projection"] = LinearParams(tri_mul.linear_z)
    entries["gating_linear"] = LinearParams(tri_mul.linear_gate)
    return entries
else:
TriMulOutParams = lambda tri_mul: { TriMulOutParams = lambda tri_mul: {
"layer_norm_input": LayerNormParams(tri_mul.layer_norm_in), "layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),
"left_projection": LinearParams(tri_mul.linear_a_p), "left_projection": LinearParams(tri_mul.linear_a_p),
...@@ -553,7 +577,7 @@ def get_translation_dict(model, version): ...@@ -553,7 +577,7 @@ def get_translation_dict(model, version):
if "template_" in k: if "template_" in k:
evo_dict.pop(k) evo_dict.pop(k)
if "_ptm" in version: if "_ptm" in version or is_multimer:
translations["predicted_aligned_error_head"] = { translations["predicted_aligned_error_head"] = {
"logits": LinearParams(model.aux_heads.tm.linear) "logits": LinearParams(model.aux_heads.tm.linear)
} }
......
...@@ -18,6 +18,7 @@ from fastfold.model.fastnn import EvoformerStack, ExtraMSAStack ...@@ -18,6 +18,7 @@ from fastfold.model.fastnn import EvoformerStack, ExtraMSAStack
from fastfold.model.fastnn.embedders import TemplateEmbedder from fastfold.model.fastnn.embedders import TemplateEmbedder
from fastfold.model.fastnn.embedders_multimer import TemplateEmbedderMultimer from fastfold.model.fastnn.embedders_multimer import TemplateEmbedderMultimer
from fastfold.model.fastnn.ops import RecyclingEmbedder, InputEmbedder from fastfold.model.fastnn.ops import RecyclingEmbedder, InputEmbedder
from fastfold.model.nn.triangular_multiplicative_update import is_fused_triangle_multiplication
def copy_layernorm(model_fast, model_ori): def copy_layernorm(model_fast, model_ori):
...@@ -72,12 +73,17 @@ def copy_transition(model_fast, model_ori): ...@@ -72,12 +73,17 @@ def copy_transition(model_fast, model_ori):
def copy_triangle(model_fast, model_ori): def copy_triangle(model_fast, model_ori):
copy_layernorm(model_fast.layernorm1, model_ori.layer_norm_in) copy_layernorm(model_fast.layernorm1, model_ori.layer_norm_in)
copy_layernorm(model_fast.layernorm2, model_ori.layer_norm_out) copy_layernorm(model_fast.layernorm2, model_ori.layer_norm_out)
copy_linear(model_fast.output_gate, model_ori.linear_g)
copy_linear(model_fast.output_projection, model_ori.linear_z) copy_linear(model_fast.output_projection, model_ori.linear_z)
model_fast.output_bias.copy_(model_ori.linear_z.bias) model_fast.output_bias.copy_(model_ori.linear_z.bias)
if is_fused_triangle_multiplication():
copy_linear(model_fast.output_gate, model_ori.linear_gate)
copy_linear(model_fast.left_right_projection, model_ori.linear_p)
copy_linear(model_fast.left_right_gate, model_ori.linear_g)
else:
copy_linear(model_fast.output_gate, model_ori.linear_g)
copy_left_right(model_fast.left_right_projection, model_ori.linear_a_p, model_ori.linear_b_p) copy_left_right(model_fast.left_right_projection, model_ori.linear_a_p, model_ori.linear_b_p)
copy_left_right(model_fast.left_right_gate, model_ori.linear_a_g, model_ori.linear_b_g) copy_left_right(model_fast.left_right_gate, model_ori.linear_a_g, model_ori.linear_b_g)
......
...@@ -34,6 +34,7 @@ import fastfold.relax.relax as relax ...@@ -34,6 +34,7 @@ import fastfold.relax.relax as relax
from fastfold.common import protein, residue_constants from fastfold.common import protein, residue_constants
from fastfold.config import model_config from fastfold.config import model_config
from fastfold.model.fastnn import set_chunk_size from fastfold.model.fastnn import set_chunk_size
from fastfold.model.nn.triangular_multiplicative_update import set_fused_triangle_multiplication
from fastfold.data import data_pipeline, feature_pipeline, templates from fastfold.data import data_pipeline, feature_pipeline, templates
from fastfold.data.tools import hhsearch, hmmsearch from fastfold.data.tools import hhsearch, hmmsearch
from fastfold.workflow.template import FastFoldDataWorkFlow, FastFoldMultimerDataWorkFlow from fastfold.workflow.template import FastFoldDataWorkFlow, FastFoldMultimerDataWorkFlow
...@@ -117,6 +118,10 @@ def inference_model(rank, world_size, result_q, batch, args): ...@@ -117,6 +118,10 @@ def inference_model(rank, world_size, result_q, batch, args):
config = model_config(args.model_name) config = model_config(args.model_name)
if args.chunk_size: if args.chunk_size:
config.globals.chunk_size = args.chunk_size config.globals.chunk_size = args.chunk_size
if "v3" in args.param_path:
set_fused_triangle_multiplication()
config.globals.inplace = args.inplace config.globals.inplace = args.inplace
config.globals.is_multimer = args.model_preset == 'multimer' config.globals.is_multimer = args.model_preset == 'multimer'
model = AlphaFold(config) model = AlphaFold(config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment