Unverified commit b3af1957 authored by oahzxl, committed by GitHub

Multimer supports chunk and inplace (#77)

parent b3d4fcca
@@ -13,12 +13,12 @@ dependencies:
- requests==2.26.0
- scipy==1.7.1
- tqdm==4.62.2
- typing-extensions==3.10.0.2
- typing-extensions==4.3.0
- einops
- colossalai
- ray==2.0.0
- pyarrow
- pandas
- --find-links https://release.colossalai.org colossalai
- --find-links https://download.pytorch.org/whl/cu113/torch_stable.html torch==1.11.1+cu113
- --find-links https://download.pytorch.org/whl/cu113/torch_stable.html torchaudio==0.11.1+cu113
- --find-links https://download.pytorch.org/whl/cu113/torch_stable.html torchvision==0.13.1
@@ -137,12 +137,17 @@ class EvoformerBlock(nn.Module):
z = self.pair_stack.inplace(z, pair_mask)
m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
else:
z = self.communication(m, msa_mask, z)
z_ori = z
m, work = All_to_All_Async.apply(m, 1, 2)
z = self.pair_stack(z, pair_mask)
m = All_to_All_Async_Opp.apply(m, work, 1, 2)
m = self.msa_stack(m, z_ori, msa_mask)
# z = self.communication.inplace(m[0], msa_mask, z)
# z_ori = z[0].clone()
# m[0], work = All_to_All_Async.apply(m[0], 1, 2)
# z = self.pair_stack.inplace(z, pair_mask)
# m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
# m[0] = self.msa_stack(m[0], z_ori, msa_mask)
z = self.communication.inplace(m[0], msa_mask, z)
m[0], work = All_to_All_Async.apply(m[0], 1, 2)
m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
m[0] = self.msa_stack(m[0], z[0], msa_mask)
z = self.pair_stack.inplace(z, pair_mask)
if self.last_block:
m[0] = m[0].squeeze(0)
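The branch visible at the top of this hunk applies All_to_All_Async_Opp only after the pair stack has run, so the inter-rank exchange of the MSA tensor proceeds while the pair representation is being computed. A minimal sketch of that overlap is given below, assuming an initialized torch.distributed process group; All_to_All_Async and All_to_All_Async_Opp are FastFold's autograd wrappers and are not reproduced here, and pair_stack is an illustrative stand-in.

import torch
import torch.distributed as dist

def overlapped_step(m, z, pair_stack):
    # Launch the all-to-all on the MSA tensor without blocking (a sketch of the
    # role played by All_to_All_Async; the real class is an autograd Function).
    m_out = torch.empty_like(m)
    work = dist.all_to_all_single(m_out, m.contiguous(), async_op=True)
    # Compute on the pair representation while the transfer is in flight.
    z = pair_stack(z)
    # Only wait for the exchange when the MSA tensor is needed again
    # (the role played by All_to_All_Async_Opp above).
    work.wait()
    return m_out, z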
@@ -288,12 +293,17 @@ class ExtraMSABlock(nn.Module):
z = self.pair_stack.inplace(z, pair_mask)
m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
else:
z = self.communication(m, msa_mask, z)
z_ori = z
m, work = All_to_All_Async.apply(m, 1, 2)
z = self.pair_stack(z, pair_mask)
m = All_to_All_Async_Opp.apply(m, work, 1, 2)
m = self.msa_stack(m, z_ori, msa_mask)
# z = self.communication.inplace(m[0], msa_mask, z)
# z_ori = [z[0].clone()]
# m[0], work = All_to_All_Async.apply(m[0], 1, 2)
# z = self.pair_stack.inplace(z, pair_mask)
# m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
# m = self.msa_stack.inplace(m, z_ori, msa_mask)
z = self.communication.inplace(m[0], msa_mask, z)
m[0], work = All_to_All_Async.apply(m[0], 1, 2)
m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
m = self.msa_stack.inplace(m, z, msa_mask)
z = self.pair_stack.inplace(z, pair_mask)
if self.last_block:
@@ -172,7 +172,7 @@ class ExtraMSAStack(nn.Module):
def inplace(self, node, pair, node_mask):
node_mask_row = scatter(node_mask, dim=1)
node = self.MSARowAttentionWithPairBias.inplace(node, pair, node_mask_row)
node = self.MSARowAttentionWithPairBias.inplace(node, pair[0], node_mask_row)
node[0] = row_to_col(node[0])
node_mask_col = scatter(node_mask, dim=2)
@@ -675,17 +675,18 @@ class ChunkTriangleAttentionStartingNode(nn.Module):
def inplace(self, Z_raw, Z_mask):
if CHUNK_SIZE == None:
Z = self.layernorm1(Z_raw)
Z = self.layernorm1(Z_raw[0])
b = self.linear_b(Z)
b, work = gather_async(b, dim=1)
Z = self.attention(Z, Z_mask, (b, work))
dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
return bias_dropout_add(Z,
Z_raw[0] = bias_dropout_add(Z,
self.out_bias,
dropout_mask,
Z_raw,
Z_raw[0],
prob=self.p_drop,
training=self.training)
return Z_raw
chunk_size = CHUNK_SIZE
para_dim = Z_raw[0].shape[1]
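Throughout this commit the .inplace variants receive the activation wrapped in a single-element list (Z_raw here, m and z in the Evoformer blocks) and write the result back into element 0, so each block updates the caller's buffer instead of allocating another full-size tensor. A toy sketch of that convention follows; the class and function names are illustrative, not the repository's API.

import torch

class ToyBlock(torch.nn.Module):
    def forward(self, z: torch.Tensor) -> torch.Tensor:
        return z + 1.0              # out-of-place: allocates a new tensor

    def inplace(self, z: list) -> list:
        z[0] += 1.0                 # mutate the shared element instead
        return z                    # hand the same list to the next block

def run_stack(blocks, z: torch.Tensor, use_inplace: bool) -> torch.Tensor:
    if not use_inplace:
        for block in blocks:
            z = block(z)
        return z
    z = [z]                         # wrap once at the boundary ...
    for block in blocks:
        z = block.inplace(z)
    return z[0]                     # ... unwrap after the last block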
@@ -795,7 +796,7 @@ class ChunkMSARowAttentionWithPairBias(nn.Module):
if CHUNK_SIZE == None:
## Input projections
M = self.layernormM(M_raw)
M = self.layernormM(M_raw[0])
Z = self.layernormZ(Z)
b = F.linear(Z, self.linear_b_weights)
b, work = gather_async(b, dim=1)
@@ -803,15 +804,16 @@ class ChunkMSARowAttentionWithPairBias(nn.Module):
# padding_bias = (1e9 * (M_mask - 1.))[:, :, None, None, :]
M = self.attention(M, M_mask, (b, work))
dropout_mask = torch.ones_like(M[:, 0:1, :, :], device=M.device, dtype=M.dtype)
return bias_dropout_add(M, self.out_bias, dropout_mask, M_raw, prob=self.p_drop, training=self.training)
M_raw[0] = bias_dropout_add(M, self.out_bias, dropout_mask, M_raw[0], prob=self.p_drop, training=self.training)
return M_raw
chunk_size = CHUNK_SIZE
para_dim_z = Z[0].shape[1]
para_dim_z = Z.shape[1]
para_dim_m = M_raw[0].shape[1]
# z is big, but b is small. So we compute z in chunk to get b, and recompute z in chunk later instead of storing it
b = torch.empty((Z[0].shape[0], Z[0].shape[1], Z[0].shape[2], self.n_head), device=Z[0].device, dtype=Z[0].dtype)
b = torch.empty((Z.shape[0], Z.shape[1], Z.shape[2], self.n_head), device=Z.device, dtype=Z.dtype)
for i in range(0, para_dim_z, chunk_size):
z = self.layernormZ(Z[0][:, i:i + chunk_size, :, :])
z = self.layernormZ(Z[:, i:i + chunk_size, :, :])
b[:, i:i + chunk_size, :, :] = F.linear(z, self.linear_b_weights)
b, work = gather_async(b, dim=1)
b = gather_async_opp(b, work, dim=1)
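The comment above states the memory trade-off behind this loop: the pair activation Z is large while the per-head bias b is small, so Z is normalized and projected one slab at a time and only b is materialized in full. A minimal sketch with illustrative names (layernorm_z, linear_b); shapes follow the [B, N, N, C_z] layout used in the diff.

import torch

def chunked_pair_bias(Z, layernorm_z, linear_b, n_head, chunk_size):
    # Z: [B, N, N, C_z]  ->  b: [B, N, N, n_head]; the normalized chunks of Z are
    # discarded, so peak memory is roughly one chunk plus the small bias tensor.
    b = torch.empty(*Z.shape[:-1], n_head, device=Z.device, dtype=Z.dtype)
    for i in range(0, Z.shape[1], chunk_size):
        z_chunk = layernorm_z(Z[:, i:i + chunk_size])
        b[:, i:i + chunk_size] = linear_b(z_chunk)
    return b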
@@ -910,7 +912,7 @@ class ChunkTriangleAttentionEndingNode(nn.Module):
def inplace(self, Z_raw, Z_mask):
if CHUNK_SIZE == None:
Z = Z_raw.transpose(-2, -3)
Z = Z_raw[0].transpose(-2, -3)
Z_mask = Z_mask.transpose(-1, -2)
Z = self.layernorm1(Z)
@@ -919,12 +921,13 @@ class ChunkTriangleAttentionEndingNode(nn.Module):
Z = self.attention(Z, Z_mask, (b, work))
Z = Z.transpose(-2, -3)
dropout_mask = torch.ones_like(Z[:, :, 0:1, :], device=Z.device, dtype=Z.dtype)
return bias_dropout_add(Z,
Z_raw[0] = bias_dropout_add(Z,
self.out_bias,
dropout_mask,
Z_raw,
Z_raw[0],
prob=self.p_drop,
training=self.training)
return Z_raw
para_dim = Z_raw[0].shape[2]
chunk_size = CHUNK_SIZE
@@ -88,9 +88,11 @@ class AlphaFold(nn.Module):
**extra_msa_config["extra_msa_embedder"],
)
self.extra_msa_stack = ExtraMSAStack(
is_multimer=self.globals.is_multimer,
**extra_msa_config["extra_msa_stack"],
)
self.evoformer = EvoformerStack(
is_multimer=self.globals.is_multimer,
**config["evoformer_stack"],
)
self.structure_module = StructureModule(
@@ -269,6 +271,7 @@ class AlphaFold(nn.Module):
no_batch_dims,
chunk_size=self.globals.chunk_size,
multichain_mask_2d=multichain_mask_2d,
inplace=self.globals.inplace
)
feats["template_torsion_angles_mask"] = (
template_embeds["template_mask"]
@@ -302,12 +305,13 @@ class AlphaFold(nn.Module):
[feats["msa_mask"], torsion_angles_mask[..., 2]],
dim=-2
)
del torsion_angles_mask
else:
msa_mask = torch.cat(
[feats["msa_mask"], template_embeds["template_mask"]],
dim=-2,
)
del template_feats, template_embeds, torsion_angles_mask
del template_feats, template_embeds
# Embed extra MSA features + merge with pairwise embeddings
if self.config.extra_msa.enabled:
@@ -321,7 +325,7 @@ class AlphaFold(nn.Module):
extra_msa_feat = self.extra_msa_embedder(extra_msa_feat)
# [*, N, N, C_z]
if not self.globals.inplace or self.globals.is_multimer:
if not self.globals.inplace:
z = self.extra_msa_stack(
extra_msa_feat,
z,
@@ -347,7 +351,7 @@ class AlphaFold(nn.Module):
# m: [*, S, N, C_m]
# z: [*, N, N, C_z]
# s: [*, N, C_s]
if not self.globals.inplace or self.globals.is_multimer:
if not self.globals.inplace:
m, z, s = self.evoformer(
m,
z,
@@ -254,18 +254,18 @@ class TemplateSingleEmbedderMultimer(nn.Module):
template_mask = template_chi_mask[..., 0]
template_activations = self.template_single_embedder(
template_features = self.template_single_embedder(
template_features
)
template_activations = torch.nn.functional.relu(
template_activations
template_features = torch.nn.functional.relu(
template_features
)
template_activations = self.template_projector(
template_activations,
template_features = self.template_projector(
template_features,
)
out["template_single_embedding"] = (
template_activations
template_features
)
out["template_mask"] = template_mask
@@ -296,6 +296,7 @@ class TemplateEmbedderMultimer(nn.Module):
templ_dim,
chunk_size,
multichain_mask_2d,
inplace
):
template_embeds = []
n_templ = batch["template_aatype"].shape[templ_dim]
@@ -307,7 +308,6 @@ class TemplateEmbedderMultimer(nn.Module):
)
single_template_embeds = {}
act = 0.
template_positions, pseudo_beta_mask = (
single_template_feats["template_pseudo_beta"],
@@ -361,17 +361,27 @@ class TemplateEmbedderMultimer(nn.Module):
template_embeds,
)
# [*, S_t, N, N, C_z]
t = self.template_pair_stack(
template_embeds["template_pair_embedding"],
padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=chunk_size,
_mask_trans=False,
)
if not inplace:
# [*, S_t, N, N, C_z]
template_embeds["template_pair_embedding"] = self.template_pair_stack(
template_embeds["template_pair_embedding"],
padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=chunk_size,
_mask_trans=False,
)
else:
template_embeds["template_pair_embedding"] = [template_embeds["template_pair_embedding"]]
# [*, S_t, N, N, C_z]
template_embeds["template_pair_embedding"] = self.template_pair_stack.inplace(
template_embeds["template_pair_embedding"],
padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=chunk_size,
_mask_trans=False,
)[0].to(z.device)
# [*, N, N, C_z]
t = torch.sum(t, dim=-4) / n_templ
t = torch.nn.functional.relu(t)
t = self.linear_t(t)
template_embeds["template_pair_embedding"] = t
template_embeds["template_pair_embedding"] = torch.sum(template_embeds["template_pair_embedding"], dim=-4) / n_templ
template_embeds["template_pair_embedding"] = torch.nn.functional.relu(template_embeds["template_pair_embedding"])
template_embeds["template_pair_embedding"] = self.linear_t(template_embeds["template_pair_embedding"])
return template_embeds
@@ -229,6 +229,7 @@ class EvoformerBlock(nn.Module):
pair_dropout: float,
inf: float,
eps: float,
is_multimer: bool,
):
super(EvoformerBlock, self).__init__()
@@ -268,6 +269,7 @@ class EvoformerBlock(nn.Module):
c_z,
c_hidden_opm,
)
self.is_multimer = is_multimer
def forward(self,
m: torch.Tensor,
@@ -315,6 +317,7 @@ class ExtraMSABlock(nn.Module):
inf: float,
eps: float,
ckpt: bool,
is_multimer: bool,
):
super(ExtraMSABlock, self).__init__()
@@ -351,6 +354,7 @@ class ExtraMSABlock(nn.Module):
inf=inf,
eps=eps,
)
self.is_multimer = is_multimer
def forward(self,
m: torch.Tensor,
@@ -415,6 +419,7 @@ class EvoformerStack(nn.Module):
inf: float,
eps: float,
clear_cache_between_blocks: bool = False,
is_multimer: bool = False,
**kwargs,
):
"""
@@ -474,6 +479,7 @@ class EvoformerStack(nn.Module):
pair_dropout=pair_dropout,
inf=inf,
eps=eps,
is_multimer=is_multimer,
)
self.blocks.append(block)
@@ -610,6 +616,7 @@ class ExtraMSAStack(nn.Module):
eps: float,
ckpt: bool,
clear_cache_between_blocks: bool = False,
is_multimer: bool = False,
**kwargs,
):
super(ExtraMSAStack, self).__init__()
@@ -632,6 +639,7 @@ class ExtraMSAStack(nn.Module):
inf=inf,
eps=eps,
ckpt=ckpt,
is_multimer=is_multimer,
)
self.blocks.append(block)
@@ -228,10 +228,12 @@ def inject_evoformer(model):
for block_id, ori_block in enumerate(model.evoformer.blocks):
c_m = ori_block.msa_att_row.c_in
c_z = ori_block.msa_att_row.c_z
is_multimer = ori_block.is_multimer
fastfold_block = EvoformerBlock(c_m=c_m,
c_z=c_z,
first_block=(block_id == 0),
last_block=(block_id == len(model.evoformer.blocks) - 1)
last_block=(block_id == len(model.evoformer.blocks) - 1),
is_multimer=is_multimer,
)
copy_evoformer_para(fastfold_block, ori_block)
@@ -249,11 +251,13 @@ def inject_extraMsaBlock(model):
for block_id, ori_block in enumerate(model.extra_msa_stack.blocks):
c_m = ori_block.msa_att_row.c_in
c_z = ori_block.msa_att_row.c_z
is_multimer = ori_block.is_multimer
new_model_block = ExtraMSABlock(
c_m=c_m,
c_z=c_z,
first_block=(block_id == 0),
last_block=(block_id == len(model.extra_msa_stack.blocks) - 1),
is_multimer=is_multimer
)
copy_extra_msa_para(new_model_block, ori_block)
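Both injection helpers now read is_multimer off each original block before building its FastFold replacement. The general swap pattern is sketched below; build_block and copy_params are placeholders standing in for EvoformerBlock/ExtraMSABlock and their copy_*_para helpers, not actual repository functions.

import torch

def inject_blocks(stack: torch.nn.ModuleList, build_block, copy_params) -> None:
    for block_id, ori_block in enumerate(stack):
        new_block = build_block(
            c_m=ori_block.msa_att_row.c_in,
            c_z=ori_block.msa_att_row.c_z,
            first_block=(block_id == 0),
            last_block=(block_id == len(stack) - 1),
            is_multimer=ori_block.is_multimer,   # flag propagated by this commit
        )
        copy_params(new_block, ori_block)        # reuse the pretrained weights
        stack[block_id] = new_block              # swap the block into the stack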
@@ -115,6 +115,7 @@ def inference_model(rank, world_size, result_q, batch, args):
if args.chunk_size:
config.globals.chunk_size = args.chunk_size
config.globals.inplace = args.inplace
config.globals.is_multimer = args.model_preset == 'multimer'
model = AlphaFold(config)
import_jax_weights_(model, args.param_path, version=args.model_name)
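With is_multimer now derived from the model preset, the inference runner forwards all three switches into the config before the model is built, which is what lets the multimer path share the chunk and inplace code paths above. The sketch below mirrors those lines; the config object and argument names come from the diff, everything around them is assumed.

def configure_globals(config, args) -> None:
    # Mirrors the settings applied in inference_model above (sketch only).
    if args.chunk_size:                       # a falsy value leaves chunking off
        config.globals.chunk_size = args.chunk_size
    config.globals.inplace = args.inplace     # enable the .inplace code paths
    # Multimer is no longer special-cased; it follows the same globals.
    config.globals.is_multimer = args.model_preset == 'multimer'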