Unverified commit 1efccb6c authored by LuGY, committed by GitHub

add faster extraMSA and templateStack (#53)

* add faster extraMSA and templateStack

* add license
parent 369f3e70
from .msa import MSAStack
from .msa import MSAStack, ExtraMSAStack
from .ops import OutProductMean, set_chunk_size
from .triangle import PairStack
from .evoformer import Evoformer
from .blocks import EvoformerBlock, ExtraMSABlock, TemplatePairStackBlock
__all__ = ['MSAStack', 'OutProductMean', 'PairStack', 'Evoformer', 'set_chunk_size']
__all__ = ['MSAStack', 'ExtraMSAStack', 'OutProductMean', 'PairStack', 'Evoformer',
'set_chunk_size', 'EvoformerBlock', 'ExtraMSABlock', 'TemplatePairStackBlock']
# Copyright 2022 BioMap (Beijing) Intelligence Technology Limited
# Copyright 2022 HPC-AI Technology Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import torch
import torch.nn as nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from fastfold.model.fastnn import MSAStack, OutProductMean, PairStack, ExtraMSAStack
from fastfold.model.fastnn.ops import Transition
from fastfold.model.fastnn.triangle import TriangleAttentionEndingNode, TriangleAttentionStartingNode, \
TriangleMultiplicationIncoming, TriangleMultiplicationOutgoing
from fastfold.distributed.comm import gather, scatter
from fastfold.distributed.comm import col_to_row, row_to_col, scatter
from fastfold.distributed.comm_async import All_to_All_Async, All_to_All_Async_Opp
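# EvoformerBlock couples MSAStack, OutProductMean ("communication") and PairStack and
# distributes the work over the tensor-parallel (DAP) group: the first block pads the
# residue dimension and scatters m/z across ranks, the last block gathers them back
# and strips the padding.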
class EvoformerBlock(nn.Module):
def __init__(self, c_m: int, c_z: int, first_block: bool, last_block: bool):
super(EvoformerBlock, self).__init__()
self.first_block = first_block
self.last_block = last_block
self.msa_stack = MSAStack(c_m, c_z, p_drop=0.15)
self.communication = OutProductMean(n_feat=c_m, n_feat_out=c_z, n_feat_proj=32)
self.pair_stack = PairStack(d_pair=c_z)
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
dap_size = gpc.get_world_size(ParallelMode.TENSOR)
seq_length = pair_mask.size(-1)
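# round the residue length up to the next multiple of the DAP world size so the padded
# tensors can be scattered evenly across ranks (always pads by at least one residue)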
padding_size = (int(seq_length / dap_size) + 1) * dap_size - seq_length
if self.first_block:
m = m.unsqueeze(0)
z = z.unsqueeze(0)
m = torch.nn.functional.pad(m, (0, 0, 0, padding_size))
z = torch.nn.functional.pad(z, (0, 0, 0, padding_size, 0, padding_size))
m = scatter(m, dim=1)
z = scatter(z, dim=1)
msa_mask = msa_mask.unsqueeze(0)
pair_mask = pair_mask.unsqueeze(0)
msa_mask = torch.nn.functional.pad(msa_mask, (0, padding_size))
pair_mask = torch.nn.functional.pad(pair_mask, (0, padding_size, 0, padding_size))
m = self.msa_stack(m, z, msa_mask)
z = z + self.communication(m, msa_mask)
m, work = All_to_All_Async.apply(m, 1, 2)
z = self.pair_stack(z, pair_mask)
m = All_to_All_Async_Opp.apply(m, work, 1, 2)
if self.last_block:
m = m.squeeze(0)
z = z.squeeze(0)
m = gather(m, dim=0)
z = gather(z, dim=0)
m = m[:, :-padding_size, :]
z = z[:-padding_size, :-padding_size, :]
return m, z
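# ExtraMSABlock mirrors EvoformerBlock but uses ExtraMSAStack (global column attention)
# and pads both the extra-MSA sequence count and the residue length. In multimer mode the
# outer-product-mean and pair stack run before the MSA stack, and the MSA tensor is
# scattered along the residue dimension instead of the sequence dimension.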
class ExtraMSABlock(nn.Module):
def __init__(
self, c_m: int, c_z: int, first_block: bool, last_block: bool, is_multimer=False
):
super(ExtraMSABlock, self).__init__()
self.first_block = first_block
self.last_block = last_block
self.msa_stack = ExtraMSAStack(c_m, c_z, p_drop=0.15)
self.communication = OutProductMean(n_feat=c_m, n_feat_out=c_z, n_feat_proj=32)
self.pair_stack = PairStack(d_pair=c_z)
self.is_multimer = is_multimer
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
dap_size = gpc.get_world_size(ParallelMode.TENSOR)
seq_cnt = msa_mask.size(-2)
seq_len = pair_mask.size(-1)
seq_cnt_padding_size = (int(seq_cnt / dap_size) + 1) * dap_size - seq_cnt
seq_len_padding_size = (int(seq_len / dap_size) + 1) * dap_size - seq_len
if self.first_block:
m = m.unsqueeze(0)
z = z.unsqueeze(0)
m = torch.nn.functional.pad(
m, (0, 0, 0, seq_len_padding_size, 0, seq_cnt_padding_size)
)
z = torch.nn.functional.pad(
z, (0, 0, 0, seq_len_padding_size, 0, seq_len_padding_size)
)
m = scatter(m, dim=1) if not self.is_multimer else scatter(m, dim=2)
z = scatter(z, dim=1)
msa_mask = msa_mask.unsqueeze(0)
pair_mask = pair_mask.unsqueeze(0)
msa_mask = torch.nn.functional.pad(
msa_mask, (0, seq_len_padding_size, 0, seq_cnt_padding_size)
)
pair_mask = torch.nn.functional.pad(
pair_mask, (0, seq_len_padding_size, 0, seq_len_padding_size)
)
if not self.is_multimer:
m = self.msa_stack(m, z, msa_mask)
z = z + self.communication(m, msa_mask)
m, work = All_to_All_Async.apply(m, 1, 2)
z = self.pair_stack(z, pair_mask)
m = All_to_All_Async_Opp.apply(m, work, 1, 2)
else:
z = z + self.communication(m, msa_mask)
z_ori = z
m, work = All_to_All_Async.apply(m, 1, 2)
z = self.pair_stack(z, pair_mask)
m = All_to_All_Async_Opp.apply(m, work, 1, 2)
m = self.msa_stack(m, z_ori, msa_mask)
if self.last_block:
m = gather(m, dim=1) if not self.is_multimer else gather(m, dim=2)
z = gather(z, dim=1)
m = m[:, :-seq_cnt_padding_size, :-seq_len_padding_size, :]
z = z[:, :-seq_len_padding_size, :-seq_len_padding_size, :]
m = m.squeeze(0)
z = z.squeeze(0)
return m, z
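# TemplatePairStackBlock applies the triangle multiplications/attentions and the pair
# transition to each template independently: the stacked template pair tensor is unbound
# along the template dimension, each slice is processed with row/column redistribution
# between the row- and column-wise operators, and the results are concatenated back.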
class TemplatePairStackBlock(nn.Module):
def __init__(
self,
c_t: int,
c_hidden_tri_att: int,
c_hidden_tri_mul: int,
no_heads: int,
pair_transition_n: int,
dropout_rate: float,
inf: float,
first_block: bool,
last_block: bool,
**kwargs,
):
super(TemplatePairStackBlock, self).__init__()
self.first_block = first_block
self.last_block = last_block
self.c_t = c_t
self.c_hidden_tri_att = c_hidden_tri_att
self.c_hidden_tri_mul = c_hidden_tri_mul
self.n_head = no_heads
self.p_drop = dropout_rate
self.hidden_c = int(c_t / self.n_head)
self.TriangleMultiplicationOutgoing = TriangleMultiplicationOutgoing(
self.c_t, p_drop=self.p_drop, c=self.c_hidden_tri_mul
)
self.TriangleMultiplicationIncoming = TriangleMultiplicationIncoming(
self.c_t, p_drop=self.p_drop, c=self.c_hidden_tri_mul
)
self.TriangleAttentionStartingNode = TriangleAttentionStartingNode(
self.c_t, p_drop=self.p_drop, c=self.c_hidden_tri_att, n_head=self.n_head
)
self.TriangleAttentionEndingNode = TriangleAttentionEndingNode(
self.c_t, p_drop=self.p_drop, c=self.c_hidden_tri_att, n_head=self.n_head
)
self.PairTransition = Transition(d=self.c_t, n=pair_transition_n)
def forward(
self,
z: torch.Tensor,
mask: torch.Tensor,
chunk_size: Optional[int] = None,
_mask_trans: bool = True,
):
dap_size = gpc.get_world_size(ParallelMode.TENSOR)
seq_length = mask.size(-1)
padding_size = (int(seq_length / dap_size) + 1) * dap_size - seq_length
if self.first_block:
z = torch.nn.functional.pad(z, (0, 0, 0, padding_size, 0, padding_size))
z = scatter(z, dim=1)
mask = torch.nn.functional.pad(mask, (0, padding_size, 0, padding_size))
single_templates = [t.unsqueeze(-4) for t in torch.unbind(z, dim=-4)]
single_templates_masks = [m.unsqueeze(-3) for m in torch.unbind(mask, dim=-3)]
for i in range(len(single_templates)):
single = single_templates[i]
single_mask = single_templates_masks[i]
single_mask_row = scatter(single_mask, dim=1)
single_mask_col = scatter(single_mask, dim=2)
single = self.TriangleMultiplicationOutgoing(single, single_mask_row)
single = row_to_col(single)
single = self.TriangleMultiplicationIncoming(single, single_mask_col)
single = col_to_row(single)
single = self.TriangleAttentionStartingNode(single, single_mask_row)
single = row_to_col(single)
single = self.TriangleAttentionEndingNode(single, single_mask_col)
single = self.PairTransition(single)
single = col_to_row(single)
single_templates[i] = single
z = torch.cat(single_templates, dim=-4)
if self.last_block:
z = gather(z, dim=1)
z = z[:, :-padding_size, :-padding_size, :]
return z
# Copyright 2022 BioMap (Beijing) Intelligence Technology Limited
# Copyright 2022 HPC-AI Technology Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
@@ -5,7 +19,7 @@ import torch.nn as nn
import torch.nn.functional as F
from fastfold.model.fastnn.kernel import LayerNorm
from fastfold.model.fastnn.ops import Transition, SelfAttention
from fastfold.model.fastnn.ops import Transition, SelfAttention, GlobalAttention
from fastfold.model.fastnn.kernel import bias_dropout_add
from fastfold.distributed import scatter, row_to_col
from fastfold.distributed.comm_async import gather_async
@@ -81,6 +95,31 @@ class MSAColumnAttention(nn.Module):
return M_raw + M
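# Column-wise global attention for the extra MSA stack: the MSA is transposed so that
# columns become rows, layer-normalised, attended with GlobalAttention, transposed back,
# and added to the input residually.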
class MSAColumnGlobalAttention(nn.Module):
def __init__(self, d_node, c=8, n_head=8):
super(MSAColumnGlobalAttention, self).__init__()
self.d_node = d_node
self.c = c
self.n_head = n_head
self.layernormM = LayerNorm(d_node)
self.global_attention = GlobalAttention(
qkv_dim=d_node, c=c, n_head=n_head, out_dim=d_node
)
def forward(self, M_raw, M_mask):
M = M_raw.transpose(-2, -3)
M = self.layernormM(M)
M_mask = M_mask.transpose(-1, -2)
M = self.global_attention(M, M_mask)
M = M.transpose(-2, -3)
return M_raw + M
class MSAStack(nn.Module):
def __init__(self, d_node, d_pair, p_drop=0.15):
@@ -105,3 +144,27 @@ class MSAStack(nn.Module):
node = self.MSATransition(node)
return node
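# ExtraMSAStack is MSAStack with the standard column attention replaced by
# MSAColumnGlobalAttention; the node mask is scattered row- and column-wise to match
# the DAP layout before each attention call.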
class ExtraMSAStack(nn.Module):
def __init__(self, d_node, d_pair, p_drop=0.15):
super(ExtraMSAStack, self).__init__()
self.MSARowAttentionWithPairBias = MSARowAttentionWithPairBias(
d_node=d_node, d_pair=d_pair, p_drop=p_drop, c=8
)
self.MSAColumnAttention = MSAColumnGlobalAttention(d_node=d_node, c=8)
self.MSATransition = Transition(d=d_node)
def forward(self, node, pair, node_mask):
node_mask_row = scatter(node_mask, dim=1)
node = self.MSARowAttentionWithPairBias(node, pair, node_mask_row)
node = row_to_col(node)
node_mask_col = scatter(node_mask, dim=2)
node = self.MSAColumnAttention(node, node_mask_col)
node = self.MSATransition(node)
return node
# Copyright 2022 BioMap (Beijing) Intelligence Technology Limited
# Copyright 2022 HPC-AI Technology Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -219,3 +233,71 @@ class SelfAttention(nn.Module):
output = torch.cat(output, dim=1)
return output
class GlobalAttention(nn.Module):
"""
Multi-Head SelfAttention dealing with [batch_size1, batch_size2, len, dim] tensors
"""
def __init__(self, qkv_dim, c, n_head, out_dim):
super(GlobalAttention, self).__init__()
self.qkv_dim = qkv_dim
self.c = c
self.n_head = n_head
self.out_dim = out_dim
self.scaling = self.c ** (-0.5)
self.eps = 1e-10
self.inf = 1e9
self.to_q = Linear(qkv_dim, c * self.n_head, use_bias=False)
self.to_kv = Linear(qkv_dim, 2 * c, initializer="linear", use_bias=False)
self.gating_bias = nn.parameter.Parameter(data=torch.ones((n_head * c,)))
self.gating_linear = Linear(
qkv_dim, n_head * c, initializer="zero", use_bias=False
)
self.o_linear = Linear(n_head * c, out_dim, initializer="zero")
def forward(self, m, mask):
para_dim = m.shape[1]
chunk_size = CHUNK_SIZE
if CHUNK_SIZE is None:
chunk_size = para_dim
output = []
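# process dim 1 in chunks (bounded by set_chunk_size) to limit peak memory: each chunk
# builds a single mask-weighted mean query per head, attends it against single-headed
# k/v shared across the heads, and gates the result with a sigmoid before the output projection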
for ax in range(0, para_dim, chunk_size):
m_part = m[:, ax : ax + chunk_size, :, :]
mask_part = mask[:, ax : ax + chunk_size, :]
q = torch.sum(m_part * mask_part.unsqueeze(-1), dim=-2) / (
torch.sum(mask_part, dim=-1)[..., None] + self.eps
)
q = self.to_q(q)
q = q.view(q.shape[:-1] + (self.n_head, -1))
k, v = self.to_kv(m_part).chunk(2, dim=-1)
logits = torch.matmul(q, k.transpose(-1, -2))
weights = mask_softmax(logits, mask_part)
weighted_avg = torch.matmul(weights, v)
weighted_avg = rearrange(weighted_avg, "b1 b2 h d -> b1 b2 (h d)")
gate_values = self.gating_linear(m_part)
weighted_avg = bias_sigmod_ele(
gate_values, self.gating_bias, weighted_avg.unsqueeze(-2)
)
output.append(self.o_linear(weighted_avg))
m = torch.cat(output, dim=1)
return m
from typing import Tuple, Optional
# Copyright 2022 BioMap (Beijing) Intelligence Technology Limited
# Copyright 2022 HPC-AI Technology Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from fastfold.model.fastnn import MSAStack, OutProductMean, PairStack
from fastfold.distributed.comm_async import All_to_All_Async, All_to_All_Async_Opp
from fastfold.distributed.comm import gather, scatter
class EvoformerBlock(nn.Module):
def __init__(self, c_m: int, c_z: int, first_block: bool, last_block: bool):
super(EvoformerBlock, self).__init__()
self.first_block = first_block
self.last_block = last_block
self.msa_stack = MSAStack(c_m, c_z, p_drop=0.15)
self.communication = OutProductMean(n_feat=c_m, n_feat_out=c_z, n_feat_proj=32)
self.pair_stack = PairStack(d_pair=c_z)
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
dap_size = gpc.get_world_size(ParallelMode.TENSOR)
seq_length = pair_mask.size(-1)
padding_size = (int(seq_length / dap_size) + 1) * dap_size - seq_length
if self.first_block:
m = m.unsqueeze(0)
z = z.unsqueeze(0)
m = torch.nn.functional.pad(m, (0, 0, 0, padding_size))
z = torch.nn.functional.pad(z, (0, 0, 0, padding_size, 0, padding_size))
m = scatter(m, dim=1)
z = scatter(z, dim=1)
msa_mask = msa_mask.unsqueeze(0)
pair_mask = pair_mask.unsqueeze(0)
msa_mask = torch.nn.functional.pad(msa_mask, (0, padding_size))
pair_mask = torch.nn.functional.pad(pair_mask, (0, padding_size, 0, padding_size))
m = self.msa_stack(m, z, msa_mask)
z = z + self.communication(m, msa_mask)
m, work = All_to_All_Async.apply(m, 1, 2)
z = self.pair_stack(z, pair_mask)
m = All_to_All_Async_Opp.apply(m, work, 1, 2)
if self.last_block:
m = m.squeeze(0)
z = z.squeeze(0)
m = gather(m, dim=0)
z = gather(z, dim=0)
m = m[:, :-padding_size, :]
z = z[:-padding_size, :-padding_size, :]
return m, z
from fastfold.model.fastnn import EvoformerBlock, ExtraMSABlock, TemplatePairStackBlock
def copy_layernorm(model_fast, model_ori):
model_fast.weight.copy_(model_ori.weight)
model_fast.bias.copy_(model_ori.bias)
@@ -85,6 +27,10 @@ def copy_linear(model_fast, model_ori):
model_fast.bias.copy_(model_ori.bias)
def copy_kv_linear(model_fast, ori_k, ori_v):
model_fast.weight.copy_(torch.cat((ori_k.weight, ori_v.weight), dim=0))
def copy_qkv_linear(model_fast, ori_q, ori_k, ori_v):
model_fast.weight.copy_(torch.cat((ori_q.weight, ori_k.weight, ori_v.weight), dim=0))
@@ -131,7 +77,7 @@ def copy_triangle_att(model_fast, model_ori):
model_fast.out_bias.copy_(model_ori.mha.linear_o.bias)
def copy_para(block_fast, block_ori):
def copy_evoformer_para(block_fast, block_ori):
# msa_stack
# MSARowAttentionWithPairBias
copy_layernorm(block_fast.msa_stack.MSARowAttentionWithPairBias.layernormM,
@@ -179,7 +125,104 @@ def copy_para(block_fast, block_ori):
copy_transition(block_fast.pair_stack.PairTransition, block_ori.core.pair_transition)
def inject_fastnn(model):
def copy_global_attention(model_fast, model_ori):
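# map the original module's separate q / k / v / gating / output linears onto the fastnn
# GlobalAttention layout (fused key-value projection, standalone gating bias parameter)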
copy_linear(model_fast.to_q, model_ori.linear_q)
copy_kv_linear(model_fast.to_kv, model_ori.linear_k, model_ori.linear_v)
copy_linear(model_fast.gating_linear, model_ori.linear_g)
copy_linear(model_fast.o_linear, model_ori.linear_o)
try:
model_fast.gating_bias.copy_(model_ori.linear_g.bias)
except Exception:
print("no gating_bias to copy")
def copy_extra_msa_para(block_fast, block_ori):
# msa_stack
# MSARowAttentionWithPairBias
copy_layernorm(
block_fast.msa_stack.MSARowAttentionWithPairBias.layernormM,
block_ori.msa_att_row.layer_norm_m,
)
copy_layernorm(
block_fast.msa_stack.MSARowAttentionWithPairBias.layernormZ,
block_ori.msa_att_row.layer_norm_z,
)
copy_attention(
block_fast.msa_stack.MSARowAttentionWithPairBias.attention,
block_ori.msa_att_row.mha,
)
block_fast.msa_stack.MSARowAttentionWithPairBias.linear_b_weights.copy_(
block_ori.msa_att_row.linear_z.weight
)
block_fast.msa_stack.MSARowAttentionWithPairBias.out_bias.copy_(
block_ori.msa_att_row.mha.linear_o.bias
)
# MSAColumnAttention
copy_layernorm(
block_fast.msa_stack.MSAColumnAttention.layernormM,
block_ori.msa_att_col.layer_norm_m,
)
copy_global_attention(
block_fast.msa_stack.MSAColumnAttention.global_attention,
block_ori.msa_att_col.global_attention,
)
# MSATransition
copy_transition(block_fast.msa_stack.MSATransition, block_ori.core.msa_transition)
# communication
comm_model = (
block_ori.core.outer_product_mean  # if not block_ori.is_multimer else block_ori.outer_product_mean
)
copy_layernorm(block_fast.communication.layernormM, comm_model.layer_norm)
copy_linear(block_fast.communication.linear_a, comm_model.linear_1)
copy_linear(block_fast.communication.linear_b, comm_model.linear_2)
copy_linear(block_fast.communication.o_linear, comm_model.linear_out)
# pair_stack
# TriangleMultiplicationOutgoing
copy_triangle(
block_fast.pair_stack.TriangleMultiplicationOutgoing, block_ori.core.tri_mul_out
)
# TriangleMultiplicationIncoming
copy_triangle(
block_fast.pair_stack.TriangleMultiplicationIncoming, block_ori.core.tri_mul_in
)
# TriangleAttentionStartingNode
copy_triangle_att(
block_fast.pair_stack.TriangleAttentionStartingNode,
block_ori.core.tri_att_start,
)
copy_triangle_att(
block_fast.pair_stack.TriangleAttentionEndingNode, block_ori.core.tri_att_end
)
copy_transition(
block_fast.pair_stack.PairTransition, block_ori.core.pair_transition
)
def copy_template_pair_stack_para(block_fast, block_ori):
# TriangleMultiplicationOutgoing
copy_triangle(block_fast.TriangleMultiplicationOutgoing, block_ori.tri_mul_out)
# TriangleMultiplicationIncoming
copy_triangle(block_fast.TriangleMultiplicationIncoming, block_ori.tri_mul_in)
# TriangleAttentionStartingNode
copy_triangle_att(block_fast.TriangleAttentionStartingNode, block_ori.tri_att_start)
copy_triangle_att(block_fast.TriangleAttentionEndingNode, block_ori.tri_att_end)
copy_transition(block_fast.PairTransition, block_ori.pair_transition)
def inject_evoformer(model):
with torch.no_grad():
fastfold_blocks = nn.ModuleList()
for block_id, ori_block in enumerate(model.evoformer.blocks):
@@ -188,13 +231,74 @@ def inject_fastnn(model):
fastfold_block = EvoformerBlock(c_m=c_m,
c_z=c_z,
first_block=(block_id == 0),
last_block=(block_id == len(model.evoformer.blocks) -
1))
last_block=(block_id == len(model.evoformer.blocks) - 1)
)
copy_para(fastfold_block, ori_block)
copy_evoformer_para(fastfold_block, ori_block)
fastfold_blocks.append(fastfold_block)
model.evoformer.blocks = fastfold_blocks
return model
def inject_extraMsaBlock(model):
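# rebuild every extra-MSA block as a fastnn ExtraMSABlock, copy its parameters, and keep
# eval mode if the original block was not in training mode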
with torch.no_grad():
new_model_blocks = nn.ModuleList()
for block_id, ori_block in enumerate(model.extra_msa_stack.blocks):
c_m = ori_block.msa_att_row.c_in
c_z = ori_block.msa_att_row.c_z
new_model_block = ExtraMSABlock(
c_m=c_m,
c_z=c_z,
first_block=(block_id == 0),
last_block=(block_id == len(model.extra_msa_stack.blocks) - 1),
)
copy_extra_msa_para(new_model_block, ori_block)
if not ori_block.training:
new_model_block.eval()
new_model_blocks.append(new_model_block)
model.extra_msa_stack.blocks = new_model_blocks
def inject_templatePairBlock(model):
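# rebuild every template-pair block as a fastnn TemplatePairStackBlock, reading the
# hyper-parameters (c_t, hidden sizes, head count, transition factor, dropout, inf)
# directly from the original block before copying its parameters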
with torch.no_grad():
target_module = model.template_pair_stack.blocks
fastfold_blocks = nn.ModuleList()
for block_id, ori_block in enumerate(target_module):
c_t = ori_block.c_t
c_hidden_tri_att = ori_block.c_hidden_tri_att
c_hidden_tri_mul = ori_block.c_hidden_tri_mul
no_heads = ori_block.no_heads
pair_transition_n = ori_block.pair_transition_n
dropout_rate = ori_block.dropout_rate
inf = ori_block.inf
fastfold_block = TemplatePairStackBlock(
c_t=c_t,
c_hidden_tri_att=c_hidden_tri_att,
c_hidden_tri_mul=c_hidden_tri_mul,
no_heads=no_heads,
pair_transition_n=pair_transition_n,
dropout_rate=dropout_rate,
inf=inf,
first_block=(block_id == 0),
last_block=(block_id == len(target_module) - 1),
)
copy_template_pair_stack_para(fastfold_block, ori_block)
if not ori_block.training:
fastfold_block.eval()
fastfold_blocks.append(fastfold_block)
model.template_pair_stack.blocks = fastfold_blocks
def inject_fastnn(model):
inject_evoformer(model)
inject_extraMsaBlock(model)
inject_templatePairBlock(model)
return model
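# Example usage (a minimal sketch, not part of this commit; it assumes an OpenFold-style
# model exposing the `evoformer`, `extra_msa_stack` and `template_pair_stack` attributes
# accessed above, and the chunk size of 512 is only illustrative):
#
#   from fastfold.model.fastnn import set_chunk_size
#
#   model = inject_fastnn(model)   # swap in the fastnn Evoformer / ExtraMSA / TemplatePair blocks
#   set_chunk_size(512)            # optional: chunk attention-style ops to bound peak memory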