"vscode:/vscode.git/clone" did not exist on "9bafef34bde99a6184c4d4c5af8bd434244a5ed9"
Unverified Commit b3af1957 authored by oahzxl, committed by GitHub

Multimer supports chunk and inplace (#77)

parent b3d4fcca
@@ -13,12 +13,12 @@ dependencies:
   - requests==2.26.0
   - scipy==1.7.1
   - tqdm==4.62.2
-  - typing-extensions==3.10.0.2
+  - typing-extensions==4.3.0
   - einops
-  - colossalai
   - ray==2.0.0
   - pyarrow
   - pandas
+  - --find-links https://release.colossalai.org colossalai
   - --find-links https://download.pytorch.org/whl/cu113/torch_stable.html torch==1.11.1+cu113
   - --find-links https://download.pytorch.org/whl/cu113/torch_stable.html torchaudio==0.11.1+cu113
   - --find-links https://download.pytorch.org/whl/cu113/torch_stable.html torchvision==0.13.1
......
@@ -137,12 +137,17 @@ class EvoformerBlock(nn.Module):
             z = self.pair_stack.inplace(z, pair_mask)
             m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
         else:
-            z = self.communication(m, msa_mask, z)
-            z_ori = z
-            m, work = All_to_All_Async.apply(m, 1, 2)
-            z = self.pair_stack(z, pair_mask)
-            m = All_to_All_Async_Opp.apply(m, work, 1, 2)
-            m = self.msa_stack(m, z_ori, msa_mask)
+            # z = self.communication.inplace(m[0], msa_mask, z)
+            # z_ori = z[0].clone()
+            # m[0], work = All_to_All_Async.apply(m[0], 1, 2)
+            # z = self.pair_stack.inplace(z, pair_mask)
+            # m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
+            # m[0] = self.msa_stack(m[0], z_ori, msa_mask)
+            z = self.communication.inplace(m[0], msa_mask, z)
+            m[0], work = All_to_All_Async.apply(m[0], 1, 2)
+            m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
+            m[0] = self.msa_stack(m[0], z[0], msa_mask)
+            z = self.pair_stack.inplace(z, pair_mask)
         if self.last_block:
             m[0] = m[0].squeeze(0)
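Note on the hunk above: the multimer else-branch now follows the same convention as the existing inplace path, wrapping activations in one-element lists and overlapping the asynchronous all-to-all (All_to_All_Async / All_to_All_Async_Opp) with computation that does not depend on it. A minimal sketch of that overlap pattern using plain torch.distributed; the shapes and the overlapped_step helper are illustrative stand-ins, not FastFold's actual ops.

    # Sketch only: launch an async all-to-all, run independent compute, then wait.
    # Assumes torch.distributed is already initialized (e.g. NCCL backend, CUDA tensors).
    import torch
    import torch.distributed as dist

    def overlapped_step(m: torch.Tensor, z: torch.Tensor):
        m_out = torch.empty_like(m)
        # Non-blocking collective: returns a work handle instead of waiting.
        work = dist.all_to_all_single(m_out, m, async_op=True)
        # The pair update does not depend on the communicated MSA tensor,
        # so it can proceed while the collective is in flight.
        z = z + z.transpose(-2, -3)  # stand-in for the pair stack
        # Block only when the communicated tensor is actually needed.
        work.wait()
        return m_out, z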
@@ -288,12 +293,17 @@ class ExtraMSABlock(nn.Module):
             z = self.pair_stack.inplace(z, pair_mask)
             m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
         else:
-            z = self.communication(m, msa_mask, z)
-            z_ori = z
-            m, work = All_to_All_Async.apply(m, 1, 2)
-            z = self.pair_stack(z, pair_mask)
-            m = All_to_All_Async_Opp.apply(m, work, 1, 2)
-            m = self.msa_stack(m, z_ori, msa_mask)
+            # z = self.communication.inplace(m[0], msa_mask, z)
+            # z_ori = [z[0].clone()]
+            # m[0], work = All_to_All_Async.apply(m[0], 1, 2)
+            # z = self.pair_stack.inplace(z, pair_mask)
+            # m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
+            # m = self.msa_stack.inplace(m, z_ori, msa_mask)
+            z = self.communication.inplace(m[0], msa_mask, z)
+            m[0], work = All_to_All_Async.apply(m[0], 1, 2)
+            m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
+            m = self.msa_stack.inplace(m, z, msa_mask)
+            z = self.pair_stack.inplace(z, pair_mask)
         if self.last_block:
......
@@ -172,7 +172,7 @@ class ExtraMSAStack(nn.Module):
     def inplace(self, node, pair, node_mask):
         node_mask_row = scatter(node_mask, dim=1)
-        node = self.MSARowAttentionWithPairBias.inplace(node, pair, node_mask_row)
+        node = self.MSARowAttentionWithPairBias.inplace(node, pair[0], node_mask_row)
         node[0] = row_to_col(node[0])
         node_mask_col = scatter(node_mask, dim=2)
......
@@ -675,17 +675,18 @@ class ChunkTriangleAttentionStartingNode(nn.Module):
     def inplace(self, Z_raw, Z_mask):
         if CHUNK_SIZE == None:
-            Z = self.layernorm1(Z_raw)
+            Z = self.layernorm1(Z_raw[0])
             b = self.linear_b(Z)
             b, work = gather_async(b, dim=1)
             Z = self.attention(Z, Z_mask, (b, work))
             dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
-            return bias_dropout_add(Z,
+            Z_raw[0] = bias_dropout_add(Z,
                                     self.out_bias,
                                     dropout_mask,
-                                    Z_raw,
+                                    Z_raw[0],
                                     prob=self.p_drop,
                                     training=self.training)
+            return Z_raw
         chunk_size = CHUNK_SIZE
         para_dim = Z_raw[0].shape[1]
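The change above (and the matching ones in the other chunk modules below) switches the unchunked branch from returning a fresh tensor to writing the result back into Z_raw[0] and returning the list container, so callers keep operating on the same slot. A hedged sketch of that one-element-list convention; add_bias_inplace is a made-up stand-in, not FastFold's bias_dropout_add.

    # Sketch: pass activations as a one-element list so an "inplace" method can
    # overwrite the shared slot and hand the same container back to the caller.
    from typing import List
    import torch

    def add_bias_inplace(x_raw: List[torch.Tensor], bias: torch.Tensor) -> List[torch.Tensor]:
        x_raw[0] = x_raw[0] + bias   # write back into the shared slot ...
        return x_raw                 # ... and return the container, like `return Z_raw`

    z = [torch.zeros(2, 4)]
    z = add_bias_inplace(z, torch.ones(4))  # caller and callee still alias z[0]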
@@ -795,7 +796,7 @@ class ChunkMSARowAttentionWithPairBias(nn.Module):
         if CHUNK_SIZE == None:
             ## Input projections
-            M = self.layernormM(M_raw)
+            M = self.layernormM(M_raw[0])
             Z = self.layernormZ(Z)
             b = F.linear(Z, self.linear_b_weights)
             b, work = gather_async(b, dim=1)
@@ -803,15 +804,16 @@ class ChunkMSARowAttentionWithPairBias(nn.Module):
             # padding_bias = (1e9 * (M_mask - 1.))[:, :, None, None, :]
             M = self.attention(M, M_mask, (b, work))
             dropout_mask = torch.ones_like(M[:, 0:1, :, :], device=M.device, dtype=M.dtype)
-            return bias_dropout_add(M, self.out_bias, dropout_mask, M_raw, prob=self.p_drop, training=self.training)
+            M_raw[0] = bias_dropout_add(M, self.out_bias, dropout_mask, M_raw[0], prob=self.p_drop, training=self.training)
+            return M_raw
         chunk_size = CHUNK_SIZE
-        para_dim_z = Z[0].shape[1]
+        para_dim_z = Z.shape[1]
         para_dim_m = M_raw[0].shape[1]
         # z is big, but b is small. So we compute z in chunk to get b, and recompute z in chunk later instead of storing it
-        b = torch.empty((Z[0].shape[0], Z[0].shape[1], Z[0].shape[2], self.n_head), device=Z[0].device, dtype=Z[0].dtype)
+        b = torch.empty((Z.shape[0], Z.shape[1], Z.shape[2], self.n_head), device=Z.device, dtype=Z.dtype)
         for i in range(0, para_dim_z, chunk_size):
-            z = self.layernormZ(Z[0][:, i:i + chunk_size, :, :])
+            z = self.layernormZ(Z[:, i:i + chunk_size, :, :])
             b[:, i:i + chunk_size, :, :] = F.linear(z, self.linear_b_weights)
         b, work = gather_async(b, dim=1)
         b = gather_async_opp(b, work, dim=1)
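The inline comment in this hunk states the trade-off: Z is large but the projected bias b is small, so Z is layer-normalized and projected chunk by chunk instead of materializing the normalized copy at once. A rough, self-contained sketch of that pattern with made-up sizes; chunked_projection is illustrative, not the module's code.

    # Sketch: compute a small projection of a large tensor chunk by chunk
    # along dim 1 so the normalized copy of Z is never fully materialized.
    import torch
    import torch.nn.functional as F

    def chunked_projection(Z: torch.Tensor, weight: torch.Tensor,
                           norm: torch.nn.LayerNorm, chunk_size: int) -> torch.Tensor:
        n_head = weight.shape[0]
        b = torch.empty(*Z.shape[:-1], n_head, device=Z.device, dtype=Z.dtype)
        for i in range(0, Z.shape[1], chunk_size):
            z = norm(Z[:, i:i + chunk_size])               # only one chunk is normalized at a time
            b[:, i:i + chunk_size] = F.linear(z, weight)   # small output is kept, big z is discarded
        return b

    Z = torch.randn(1, 64, 64, 128)
    b = chunked_projection(Z, torch.randn(8, 128), torch.nn.LayerNorm(128), chunk_size=16)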
@@ -910,7 +912,7 @@ class ChunkTriangleAttentionEndingNode(nn.Module):
     def inplace(self, Z_raw, Z_mask):
         if CHUNK_SIZE == None:
-            Z = Z_raw.transpose(-2, -3)
+            Z = Z_raw[0].transpose(-2, -3)
             Z_mask = Z_mask.transpose(-1, -2)
             Z = self.layernorm1(Z)
@@ -919,12 +921,13 @@ class ChunkTriangleAttentionEndingNode(nn.Module):
             Z = self.attention(Z, Z_mask, (b, work))
             Z = Z.transpose(-2, -3)
             dropout_mask = torch.ones_like(Z[:, :, 0:1, :], device=Z.device, dtype=Z.dtype)
-            return bias_dropout_add(Z,
+            Z_raw[0] = bias_dropout_add(Z,
                                     self.out_bias,
                                     dropout_mask,
-                                    Z_raw,
+                                    Z_raw[0],
                                     prob=self.p_drop,
                                     training=self.training)
+            return Z_raw
         para_dim = Z_raw[0].shape[2]
         chunk_size = CHUNK_SIZE
......
@@ -88,9 +88,11 @@ class AlphaFold(nn.Module):
            **extra_msa_config["extra_msa_embedder"],
        )
        self.extra_msa_stack = ExtraMSAStack(
+           is_multimer=self.globals.is_multimer,
            **extra_msa_config["extra_msa_stack"],
        )
        self.evoformer = EvoformerStack(
+           is_multimer=self.globals.is_multimer,
            **config["evoformer_stack"],
        )
        self.structure_module = StructureModule(
@@ -269,6 +271,7 @@ class AlphaFold(nn.Module):
                no_batch_dims,
                chunk_size=self.globals.chunk_size,
                multichain_mask_2d=multichain_mask_2d,
+               inplace=self.globals.inplace
            )
            feats["template_torsion_angles_mask"] = (
                template_embeds["template_mask"]
@@ -302,12 +305,13 @@ class AlphaFold(nn.Module):
                    [feats["msa_mask"], torsion_angles_mask[..., 2]],
                    dim=-2
                )
+               del torsion_angles_mask
            else:
                msa_mask = torch.cat(
                    [feats["msa_mask"], template_embeds["template_mask"]],
                    dim=-2,
                )
-           del template_feats, template_embeds, torsion_angles_mask
+           del template_feats, template_embeds
        # Embed extra MSA features + merge with pairwise embeddings
        if self.config.extra_msa.enabled:
@@ -321,7 +325,7 @@ class AlphaFold(nn.Module):
            extra_msa_feat = self.extra_msa_embedder(extra_msa_feat)
            # [*, N, N, C_z]
-           if not self.globals.inplace or self.globals.is_multimer:
+           if not self.globals.inplace:
                z = self.extra_msa_stack(
                    extra_msa_feat,
                    z,
@@ -347,7 +351,7 @@ class AlphaFold(nn.Module):
        # m: [*, S, N, C_m]
        # z: [*, N, N, C_z]
        # s: [*, N, C_s]
-       if not self.globals.inplace or self.globals.is_multimer:
+       if not self.globals.inplace:
            m, z, s = self.evoformer(
                m,
                z,
......
@@ -254,18 +254,18 @@ class TemplateSingleEmbedderMultimer(nn.Module):
        template_mask = template_chi_mask[..., 0]
-       template_activations = self.template_single_embedder(
+       template_features = self.template_single_embedder(
            template_features
        )
-       template_activations = torch.nn.functional.relu(
-           template_activations
+       template_features = torch.nn.functional.relu(
+           template_features
        )
-       template_activations = self.template_projector(
-           template_activations,
+       template_features = self.template_projector(
+           template_features,
        )
        out["template_single_embedding"] = (
-           template_activations
+           template_features
        )
        out["template_mask"] = template_mask
@@ -296,6 +296,7 @@ class TemplateEmbedderMultimer(nn.Module):
        templ_dim,
        chunk_size,
        multichain_mask_2d,
+       inplace
    ):
        template_embeds = []
        n_templ = batch["template_aatype"].shape[templ_dim]
@@ -307,7 +308,6 @@ class TemplateEmbedderMultimer(nn.Module):
            )
            single_template_embeds = {}
-           act = 0.
            template_positions, pseudo_beta_mask = (
                single_template_feats["template_pseudo_beta"],
@@ -361,17 +361,27 @@ class TemplateEmbedderMultimer(nn.Module):
            template_embeds,
        )
-       # [*, S_t, N, N, C_z]
-       t = self.template_pair_stack(
-           template_embeds["template_pair_embedding"],
-           padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
-           chunk_size=chunk_size,
-           _mask_trans=False,
-       )
+       if not inplace:
+           # [*, S_t, N, N, C_z]
+           template_embeds["template_pair_embedding"] = self.template_pair_stack(
+               template_embeds["template_pair_embedding"],
+               padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
+               chunk_size=chunk_size,
+               _mask_trans=False,
+           )
+       else:
+           template_embeds["template_pair_embedding"] = [template_embeds["template_pair_embedding"]]
+           # [*, S_t, N, N, C_z]
+           template_embeds["template_pair_embedding"] = self.template_pair_stack.inplace(
+               template_embeds["template_pair_embedding"],
+               padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
+               chunk_size=chunk_size,
+               _mask_trans=False,
+           )[0].to(z.device)
        # [*, N, N, C_z]
-       t = torch.sum(t, dim=-4) / n_templ
-       t = torch.nn.functional.relu(t)
-       t = self.linear_t(t)
-       template_embeds["template_pair_embedding"] = t
+       template_embeds["template_pair_embedding"] = torch.sum(template_embeds["template_pair_embedding"], dim=-4) / n_templ
+       template_embeds["template_pair_embedding"] = torch.nn.functional.relu(template_embeds["template_pair_embedding"])
+       template_embeds["template_pair_embedding"] = self.linear_t(template_embeds["template_pair_embedding"])
        return template_embeds
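In the hunk above, the new else-branch wraps the template pair embedding in a list, runs the in-place pair stack, and unwraps the result with [0]. A small sketch of that dispatch, assuming the inplace variant follows the same list convention used elsewhere in this commit; stack and stack_inplace are placeholders, not the repo's API.

    # Sketch: dispatch between a standard module call and a list-based in-place
    # variant; `stack` and `stack_inplace` stand in for the template pair stack.
    import torch

    def run_pair_stack(t: torch.Tensor, mask: torch.Tensor, inplace: bool,
                       stack, stack_inplace) -> torch.Tensor:
        if not inplace:
            return stack(t, mask)
        t_list = [t]                           # wrap so the stack can mutate the slot
        return stack_inplace(t_list, mask)[0]  # unwrap, mirroring the `[0]` above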
@@ -229,6 +229,7 @@ class EvoformerBlock(nn.Module):
        pair_dropout: float,
        inf: float,
        eps: float,
+       is_multimer: bool,
    ):
        super(EvoformerBlock, self).__init__()
@@ -268,6 +269,7 @@ class EvoformerBlock(nn.Module):
            c_z,
            c_hidden_opm,
        )
+       self.is_multimer = is_multimer

    def forward(self,
                m: torch.Tensor,
@@ -315,6 +317,7 @@ class ExtraMSABlock(nn.Module):
        inf: float,
        eps: float,
        ckpt: bool,
+       is_multimer: bool,
    ):
        super(ExtraMSABlock, self).__init__()
@@ -351,6 +354,7 @@ class ExtraMSABlock(nn.Module):
            inf=inf,
            eps=eps,
        )
+       self.is_multimer = is_multimer

    def forward(self,
                m: torch.Tensor,
@@ -415,6 +419,7 @@ class EvoformerStack(nn.Module):
        inf: float,
        eps: float,
        clear_cache_between_blocks: bool = False,
+       is_multimer: bool = False,
        **kwargs,
    ):
        """
@@ -474,6 +479,7 @@ class EvoformerStack(nn.Module):
                pair_dropout=pair_dropout,
                inf=inf,
                eps=eps,
+               is_multimer=is_multimer,
            )
            self.blocks.append(block)
@@ -610,6 +616,7 @@ class ExtraMSAStack(nn.Module):
        eps: float,
        ckpt: bool,
        clear_cache_between_blocks: bool = False,
+       is_multimer: bool = False,
        **kwargs,
    ):
        super(ExtraMSAStack, self).__init__()
@@ -632,6 +639,7 @@ class ExtraMSAStack(nn.Module):
                inf=inf,
                eps=eps,
                ckpt=ckpt,
+               is_multimer=is_multimer,
            )
            self.blocks.append(block)
......
@@ -228,10 +228,12 @@ def inject_evoformer(model):
    for block_id, ori_block in enumerate(model.evoformer.blocks):
        c_m = ori_block.msa_att_row.c_in
        c_z = ori_block.msa_att_row.c_z
+       is_multimer = ori_block.is_multimer
        fastfold_block = EvoformerBlock(c_m=c_m,
                                        c_z=c_z,
                                        first_block=(block_id == 0),
-                                       last_block=(block_id == len(model.evoformer.blocks) - 1)
+                                       last_block=(block_id == len(model.evoformer.blocks) - 1),
+                                       is_multimer=is_multimer,
                                        )
        copy_evoformer_para(fastfold_block, ori_block)
@@ -249,11 +251,13 @@ def inject_extraMsaBlock(model):
    for block_id, ori_block in enumerate(model.extra_msa_stack.blocks):
        c_m = ori_block.msa_att_row.c_in
        c_z = ori_block.msa_att_row.c_z
+       is_multimer = ori_block.is_multimer
        new_model_block = ExtraMSABlock(
            c_m=c_m,
            c_z=c_z,
            first_block=(block_id == 0),
            last_block=(block_id == len(model.extra_msa_stack.blocks) - 1),
+           is_multimer=is_multimer
        )
        copy_extra_msa_para(new_model_block, ori_block)
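Both injection helpers now read is_multimer off the original OpenFold block and forward it to the FastFold replacement. A hedged sketch of that module-swapping pattern; make_fast_block is a placeholder factory, not the repo's API.

    # Sketch: replace each block in a ModuleList with an optimized equivalent,
    # carrying configuration flags over from the original module.
    import torch.nn as nn

    def swap_blocks(blocks: nn.ModuleList, make_fast_block) -> nn.ModuleList:
        new_blocks = []
        for block_id, ori_block in enumerate(blocks):
            fast = make_fast_block(
                first_block=(block_id == 0),
                last_block=(block_id == len(blocks) - 1),
                is_multimer=getattr(ori_block, "is_multimer", False),
            )
            new_blocks.append(fast)
        return nn.ModuleList(new_blocks)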
......
@@ -115,6 +115,7 @@ def inference_model(rank, world_size, result_q, batch, args):
    if args.chunk_size:
        config.globals.chunk_size = args.chunk_size
    config.globals.inplace = args.inplace
+   config.globals.is_multimer = args.model_preset == 'multimer'
    model = AlphaFold(config)
    import_jax_weights_(model, args.param_path, version=args.model_name)
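With this change, config.globals.is_multimer is derived from the CLI preset rather than toggled inside the model. A hedged sketch of how such a flag could be derived; the argument names mirror the diff, but the argparse wiring and preset choices here are assumptions, not the actual inference script.

    # Sketch: derive the multimer flag from the model preset on the command line.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_preset", choices=["monomer", "multimer"], default="monomer")
    parser.add_argument("--chunk_size", type=int, default=None)
    parser.add_argument("--inplace", action="store_true")
    args = parser.parse_args([])  # empty list so the sketch runs without a real CLI

    is_multimer = args.model_preset == "multimer"  # mirrors config.globals.is_multimer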
......