Commit 2faff451 authored by Gustaf Ahdritz

Make template code use less memory during inference

parent 407d9924
@@ -231,8 +231,7 @@ config = mlc.ConfigDict(
         # Recurring FieldReferences that can be changed globally here
         "globals": {
             "blocks_per_ckpt": blocks_per_ckpt,
-            "train_chunk_size": None,
-            "eval_chunk_size": chunk_size,
+            "chunk_size": chunk_size,
             "c_z": c_z,
             "c_m": c_m,
             "c_t": c_t,
...
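The two fields collapse into one because every consumer reads the same FieldReference. A minimal sketch of that mechanism (toy keys, not the real config tree):

```python
import ml_collections as mlc

# Toy config: both entries hold the *same* FieldReference, so one
# assignment updates every place the value is consumed.
chunk_size = mlc.FieldReference(4, field_type=int)
config = mlc.ConfigDict({
    "globals": {"chunk_size": chunk_size},
    "model": {"chunk_size": chunk_size},
})

config.globals.chunk_size = 128  # flips both entries at once
assert config.model.chunk_size == 128
```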
@@ -210,7 +210,9 @@ class AlignmentRunner:
         )
         self.hhsearch_pdb70_runner = hhsearch.HHSearch(
-            binary_path=hhsearch_binary_path, databases=[pdb70_database_path]
+            binary_path=hhsearch_binary_path,
+            databases=[pdb70_database_path],
+            n_cpu=no_cpus,
         )

         self.uniref_max_hits = uniref_max_hits
         self.mgnify_max_hits = mgnify_max_hits
...
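For context, a usage sketch of the constructor as it now reads; the import path and file paths are assumptions, and only the three keyword arguments are taken from the hunk above:

```python
from openfold.data.tools import hhsearch  # import path assumed

# Hypothetical paths; binary_path, databases, and n_cpu are the
# arguments visible in the diff.
pdb70_runner = hhsearch.HHSearch(
    binary_path="/usr/bin/hhsearch",
    databases=["/data/pdb70/pdb70"],
    n_cpu=8,
)
```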
@@ -106,7 +106,7 @@ class AlphaFold(nn.Module):
         self.config = config

-    def embed_templates(self, batch, z, pair_mask, templ_dim, chunk_size):
+    def embed_templates(self, batch, z, pair_mask, templ_dim):
         # Embed the templates one at a time (with a poor man's vmap)
         template_embeds = []
         n_templ = batch["template_aatype"].shape[templ_dim]
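The "poor man's vmap" comment refers to looping over the template dimension and re-stacking the per-template results. A self-contained sketch of the pattern (generic fn, toy shapes):

```python
import torch

def poor_mans_vmap(fn, x, dim):
    # Apply fn to each slice along dim, then re-stack: peak memory is
    # dominated by a single slice instead of the whole batch.
    slices = [fn(x.select(dim, i)) for i in range(x.shape[dim])]
    return torch.stack(slices, dim=dim)

# Usage: process 4 "templates" independently
x = torch.randn(4, 8, 8)
y = poor_mans_vmap(lambda s: s ** 2, x, dim=0)
assert torch.equal(y, x ** 2)
```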
@@ -146,18 +146,20 @@ class AlphaFold(nn.Module):
             template_embeds,
         )

-        # [*, N, N, C_z]
+        # [*, S_t, N, N, C_z]
         t = self.template_pair_stack(
             template_embeds["pair"],
             pair_mask.unsqueeze(-3),
-            chunk_size=chunk_size,
+            chunk_size=self.globals.chunk_size,
             _mask_trans=self.config._mask_trans,
         )

+        # [*, N, N, C_z]
         t = self.template_pointwise_att(
             t,
             z,
             template_mask=batch["template_mask"],
-            chunk_size=chunk_size,
+            chunk_size=self.globals.chunk_size,
         )
         t = t * (torch.sum(batch["template_mask"]) > 0)
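The final line above gates the whole template update on whether any template exists. A minimal demonstration of the broadcasting trick (toy shapes):

```python
import torch

# (sum > 0) yields a 0-dim bool tensor that broadcasts as 0.0 or 1.0,
# zeroing the update when no template is valid.
template_mask = torch.zeros(4)          # no valid templates
t = torch.randn(16, 16, 8)              # toy template update
t = t * (torch.sum(template_mask) > 0)  # multiplies by False == 0.0
assert not t.any()
```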
@@ -170,12 +172,6 @@ class AlphaFold(nn.Module):
         return ret

     def iteration(self, feats, m_1_prev, z_prev, x_prev, _recycle=True):
-        # Establish constants
-        chunk_size = (
-            self.globals.train_chunk_size
-            if self.training else self.globals.eval_chunk_size
-        )
-
         # Primary output dictionary
         outputs = {}
@@ -251,7 +247,6 @@ class AlphaFold(nn.Module):
             z,
             pair_mask,
             no_batch_dims,
-            chunk_size,
         )

         # [*, N, N, C_z]
@@ -281,7 +276,7 @@ class AlphaFold(nn.Module):
             a,
             z,
             msa_mask=feats["extra_msa_mask"],
-            chunk_size=chunk_size,
+            chunk_size=self.globals.chunk_size,
             pair_mask=pair_mask,
             _mask_trans=self.config._mask_trans,
         )
@@ -295,7 +290,7 @@ class AlphaFold(nn.Module):
             z,
             msa_mask=msa_mask,
             pair_mask=pair_mask,
-            chunk_size=chunk_size,
+            chunk_size=self.globals.chunk_size,
             _mask_trans=self.config._mask_trans,
         )
...
@@ -174,17 +174,43 @@ class TemplatePairStackBlock(nn.Module):
         )

     def forward(self, z, mask, chunk_size, _mask_trans=True):
-        z = z + self.dropout_row(
-            self.tri_att_start(z, chunk_size=chunk_size, mask=mask)
-        )
-        z = z + self.dropout_col(
-            self.tri_att_end(z, chunk_size=chunk_size, mask=mask)
-        )
-        z = z + self.dropout_row(self.tri_mul_out(z, mask=mask))
-        z = z + self.dropout_row(self.tri_mul_in(z, mask=mask))
-        z = z + self.pair_transition(
-            z, chunk_size=chunk_size, mask=mask if _mask_trans else None
-        )
+        for templ_idx in range(z.shape[-4]):
+            # Select a single template at a time
+            single = z[..., templ_idx:templ_idx+1, :, :, :]
+            single_mask = mask[..., templ_idx:templ_idx+1, :, :]
+            single = single + self.dropout_row(
+                self.tri_att_start(
+                    single,
+                    chunk_size=chunk_size,
+                    mask=single_mask
+                )
+            )
+            single = single + self.dropout_col(
+                self.tri_att_end(
+                    single,
+                    chunk_size=chunk_size,
+                    mask=single_mask
+                )
+            )
+            single = single + self.dropout_row(
+                self.tri_mul_out(
+                    single,
+                    mask=single_mask
+                )
+            )
+            single = single + self.dropout_row(
+                self.tri_mul_in(
+                    single,
+                    mask=single_mask
+                )
+            )
+            single = single + self.pair_transition(
+                single,
+                chunk_size=chunk_size,
+                mask=single_mask if _mask_trans else None
+            )
+            z[..., templ_idx:templ_idx+1, :, :, :] = single

         return z
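The rewritten forward is an instance of a general slice-and-writeback loop: `templ_idx:templ_idx+1` keeps the template dim so ranks still match, and writing each result back into z means only one template's activations are live at a time. A generic sketch of the pattern (hypothetical helper name, toy fn):

```python
import torch

def map_slices_inplace(fn, z, dim=-4):
    # Run fn on a width-1 slice along dim and write the result back,
    # so peak activation memory tracks a single slice.
    for idx in range(z.shape[dim]):
        sl = [slice(None)] * z.dim()
        sl[dim] = slice(idx, idx + 1)   # keeps the dim, unlike plain indexing
        z[tuple(sl)] = fn(z[tuple(sl)])
    return z

# Usage: double each of 3 "templates" one at a time
z = torch.ones(3, 4, 4, 2)
map_slices_inplace(lambda s: s * 2, z)
assert torch.equal(z, torch.full((3, 4, 4, 2), 2.0))
```

Note that the in-place writeback assumes no autograd graph is being built through z, which matches the inference focus of this commit.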
@@ -254,12 +280,17 @@ class TemplatePairStack(nn.Module):
         """
         Args:
             t:
-                [*, N_res, N_res, C_t] template embedding
+                [*, N_templ, N_res, N_res, C_t] template embedding
             mask:
-                [*, N_res, N_res] mask
+                [*, N_templ, N_res, N_res] mask
         Returns:
-            [*, N_res, N_res, C_t] template embedding update
+            [*, N_templ, N_res, N_res, C_t] template embedding update
         """
+        if(mask.shape[-3] == 1):
+            expand_idx = list(mask.shape)
+            expand_idx[-3] = t.shape[-4]
+            mask = mask.expand(*expand_idx)
+
         (t,) = checkpoint_blocks(
             blocks=[
                 partial(
...
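The new guard broadcasts a singleton template dim in the mask up to N_templ. `expand` returns a zero-stride view, so the bigger mask costs no extra memory; a quick check (toy sizes):

```python
import torch

mask = torch.ones(1, 8, 8)        # [N_templ=1, N, N]
expand_idx = list(mask.shape)
expand_idx[-3] = 4                # toy N_templ = 4
mask = mask.expand(*expand_idx)   # [4, 8, 8] view, no copy
assert mask.shape == (4, 8, 8)
assert mask.stride(0) == 0        # zero stride: the data was not duplicated
```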
@@ -133,8 +133,8 @@ class TestTemplatePairStack(unittest.TestCase):
         model = compare_utils.get_global_pretrained_openfold()
         out_repro = model.template_pair_stack(
-            torch.as_tensor(pair_act).cuda(),
-            torch.as_tensor(pair_mask).cuda(),
+            torch.as_tensor(pair_act).unsqueeze(-4).cuda(),
+            torch.as_tensor(pair_mask).unsqueeze(-3).cuda(),
             chunk_size=None,
             _mask_trans=False,
         ).cpu()
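The stack now expects an explicit template dimension, so the test adds a singleton one. A quick shape check of what the two unsqueeze calls do (toy sizes):

```python
import torch

pair_act = torch.randn(16, 16, 64)   # [N, N, C_t]
pair_mask = torch.ones(16, 16)       # [N, N]
assert pair_act.unsqueeze(-4).shape == (1, 16, 16, 64)  # [N_templ=1, N, N, C_t]
assert pair_mask.unsqueeze(-3).shape == (1, 16, 16)     # [N_templ=1, N, N]
```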
@@ -182,7 +182,6 @@ class Template(unittest.TestCase):
             torch.as_tensor(pair_act).cuda(),
             torch.as_tensor(pair_mask).cuda(),
             templ_dim=0,
-            chunk_size=None,
         )
         out_repro = out_repro["template_pair_embedding"]
         out_repro = out_repro.cpu()
...