Commit e5d72d41 authored by zhuww

Infer longer sequences by reducing peak GPU memory during inference
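The central memory trick in this commit is gathering sharded activations piece by piece instead of in a single all_gather call. Below is a minimal, self-contained sketch of that idea (it assumes an already-initialized torch.distributed process group; `chunked_all_gather`, `world_size`, and `group` are illustrative names, not identifiers from this repository):

import torch
import torch.distributed as dist


def chunked_all_gather(shard, world_size, chunks=4, dim=0, group=None):
    # Pre-allocate the full gathered tensor once; every piece is written
    # directly into a view of it, so no second full-size buffer is needed.
    out_shape = list(shard.shape)
    out_shape[dim] *= world_size
    output = torch.empty(out_shape, dtype=shard.dtype, device=shard.device)

    # One block per rank, each block split into `chunks` pieces
    # (assumes the gathered dimension divides evenly).
    per_rank = output.chunk(world_size, dim=dim)
    views = [p for block in per_rank for p in block.chunk(chunks, dim=dim)]
    local_pieces = shard.chunk(chunks, dim=dim)

    # Move one piece at a time so the transient communication buffers stay
    # small; rank j's i-th piece lives at index j * chunks + i.
    for i in range(chunks):
        dist.all_gather(
            [views[j * chunks + i] for j in range(world_size)],
            local_pieces[i].contiguous(),
            group=group,
        )
        torch.cuda.empty_cache()
    return output

In the commit this pattern is implemented by _chunk_gather and exposed through gather(..., chunks=...).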

parent 52919c63
......@@ -27,14 +27,19 @@ def _reduce(tensor: Tensor) -> Tensor:
return tensor
def _split(tensor: Tensor, dim: int = -1) -> Tensor:
def _split(tensor: Tensor, dim: int = -1, drop_unused=False) -> Tensor:
if gpc.get_world_size(ParallelMode.TENSOR) == 1:
return tensor
split_size = divide(tensor.shape[dim], gpc.get_world_size(ParallelMode.TENSOR))
tensor_list = torch.split(tensor, split_size, dim=dim)
output = tensor_list[gpc.get_local_rank(ParallelMode.TENSOR)].contiguous()
rank = gpc.get_local_rank(ParallelMode.TENSOR)
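# drop_unused clones the local shard so it no longer shares storage with the
# full tensor, allowing the caller to free the original immediately.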
if not drop_unused:
output = tensor_list[rank].contiguous()
else:
output = tensor_list[rank].contiguous().clone()
return output
......@@ -65,6 +70,55 @@ def _gather(tensor: Tensor, dim: int = -1) -> Tensor:
return output
def _chunk_gather(tensor: Tensor, dim=-1, chunks=1) -> Tensor:
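# All-gather `tensor` along `dim` across the tensor-parallel group, moving
# `chunks` sub-pieces at a time into a pre-allocated output so that peak
# memory during the gather stays bounded.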
if gpc.get_world_size(ParallelMode.TENSOR) == 1:
return tensor
if dim == 1 and tensor.shape[0] == 1:
output_shape = list(tensor.shape)
output_shape[1] *= gpc.get_world_size(ParallelMode.TENSOR)
output = torch.empty(output_shape, dtype=tensor.dtype, device=tensor.device)
world_list = output.chunk(gpc.get_world_size(ParallelMode.TENSOR), dim=1)
tensor_list = []
for t in world_list:
tensor_list.extend(t.chunk(chunks, dim=1))
chunk_tensor = tensor.chunk(chunks, dim=1)
for i in range(chunks):
# rank j's i-th sub-chunk sits at index j * chunks + i
_chunk_list = [tensor_list[j * chunks + i] for j in range(gpc.get_world_size(ParallelMode.TENSOR))]
_chunk_tensor = chunk_tensor[i]
dist.all_gather(list(_chunk_list),
_chunk_tensor,
group=gpc.get_group(ParallelMode.TENSOR),
async_op=False)
torch.cuda.empty_cache()
else:
output_shape = list(tensor.shape)
output_shape[0] *= gpc.get_world_size(ParallelMode.TENSOR)
output = torch.empty(output_shape, dtype=tensor.dtype, device=tensor.device)
world_list = output.chunk(gpc.get_world_size(ParallelMode.TENSOR), dim=0)
tensor_list = []
for t in world_list:
tensor_list.extend(t.chunk(chunks, dim=0))
chunk_tensor = tensor.chunk(chunks, dim=0)
for i in range(chunks):
# rank j's i-th sub-chunk sits at index j * chunks + i
_chunk_list = [tensor_list[j * chunks + i] for j in range(gpc.get_world_size(ParallelMode.TENSOR))]
_chunk_tensor = chunk_tensor[i]
dist.all_gather(list(_chunk_list),
_chunk_tensor,
group=gpc.get_group(ParallelMode.TENSOR),
async_op=False)
torch.cuda.empty_cache()
return output
def copy(input: Tensor) -> Tensor:
if torch.is_grad_enabled() and input.requires_grad:
input = Copy.apply(input)
......@@ -122,11 +176,14 @@ class Reduce(torch.autograd.Function):
return grad_output
def gather(input: Tensor, dim: int = -1) -> Tensor:
def gather(input: Tensor, dim: int = -1, chunks: int = None) -> Tensor:
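# When `chunks` is set, the no-grad path uses the chunked all-gather to bound
# peak memory; the autograd path (Gather.apply) ignores it.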
if torch.is_grad_enabled() and input.requires_grad:
input = Gather.apply(input, dim)
else:
input = _gather(input, dim=dim)
if chunks is None:
input = _gather(input, dim=dim)
else:
input = _chunk_gather(input, dim=dim, chunks=chunks)
return input
......
......@@ -224,6 +224,7 @@ class TemplateEmbedder(nn.Module):
).to(t.device)
del tt, single_template_feats
torch.cuda.empty_cache()
template_embeds = dict_multimap(
partial(torch.cat, dim=templ_dim),
......@@ -245,6 +246,7 @@ class TemplateEmbedder(nn.Module):
template_mask=batch["template_mask"].to(dtype=z.dtype),
chunk_size=chunk_size * 256 if chunk_size is not None else chunk_size,
)
torch.cuda.empty_cache()
ret = {}
ret["template_pair_embedding"] = z
......
......@@ -105,9 +105,13 @@ class Evoformer(nn.Module):
m[0] = torch.nn.functional.pad(m[0], (0, 0, 0, padding_size))
z[0] = torch.nn.functional.pad(z[0], (0, 0, 0, padding_size, 0, padding_size))
torch.cuda.empty_cache()
m[0] = scatter(m[0], dim=1)
z[0] = scatter(z[0], dim=1)
m[0] = scatter(m[0], dim=1, drop_unused=True)
torch.cuda.empty_cache()
z[0] = scatter(z[0], dim=1, drop_unused=True)
torch.cuda.empty_cache()
msa_mask = msa_mask.unsqueeze(0)
pair_mask = pair_mask.unsqueeze(0)
......@@ -137,9 +141,10 @@ class Evoformer(nn.Module):
if self.last_block:
m[0] = m[0].squeeze(0)
z[0] = z[0].squeeze(0)
torch.cuda.empty_cache()
m[0] = gather(m[0], dim=0)
z[0] = gather(z[0], dim=0)
m[0] = gather(m[0], dim=0, chunks=4)
z[0] = gather(z[0], dim=0, chunks=4)
m[0] = m[0][:, :-padding_size, :]
z[0] = z[0][:-padding_size, :-padding_size, :]
......
......@@ -217,7 +217,7 @@ class ExtraMSABlock(nn.Module):
seq_len = pair_mask.size(-1)
seq_cnt_padding_size = (int(seq_cnt / dap_size) + 1) * dap_size - seq_cnt
seq_len_padding_size = (int(seq_len / dap_size) + 1) * dap_size - seq_len
if self.first_block:
m = m.unsqueeze(0)
z = z.unsqueeze(0)
......@@ -225,6 +225,7 @@ class ExtraMSABlock(nn.Module):
m = torch.nn.functional.pad(
m, (0, 0, 0, seq_len_padding_size, 0, seq_cnt_padding_size)
)
z = torch.nn.functional.pad(
z, (0, 0, 0, seq_len_padding_size, 0, seq_len_padding_size)
)
......@@ -286,6 +287,8 @@ class ExtraMSABlock(nn.Module):
seq_len = pair_mask.size(-1)
seq_cnt_padding_size = (int(seq_cnt / dap_size) + 1) * dap_size - seq_cnt
seq_len_padding_size = (int(seq_len / dap_size) + 1) * dap_size - seq_len
torch.cuda.empty_cache()
if self.first_block:
m[0] = m[0].unsqueeze(0)
......@@ -294,12 +297,16 @@ class ExtraMSABlock(nn.Module):
m[0] = torch.nn.functional.pad(
m[0], (0, 0, 0, seq_len_padding_size, 0, seq_cnt_padding_size)
)
torch.cuda.empty_cache()
z[0] = torch.nn.functional.pad(
z[0], (0, 0, 0, seq_len_padding_size, 0, seq_len_padding_size)
)
torch.cuda.empty_cache()
m[0] = scatter(m[0], dim=1) if not self.is_multimer else scatter(m[0], dim=2)
z[0] = scatter(z[0], dim=1)
m[0] = scatter(m[0], dim=1, drop_unused=True) if not self.is_multimer else scatter(m[0], dim=2)
torch.cuda.empty_cache()
z[0] = scatter(z[0], dim=1, drop_unused=True)
torch.cuda.empty_cache()
msa_mask = msa_mask.unsqueeze(0)
pair_mask = pair_mask.unsqueeze(0)
......@@ -332,8 +339,9 @@ class ExtraMSABlock(nn.Module):
if self.last_block:
m[0] = gather(m[0], dim=1) if not self.is_multimer else gather(m[0], dim=2)
z[0] = gather(z[0], dim=1)
m[0] = gather(m[0], dim=1, chunks=4) if not self.is_multimer else gather(m[0], dim=2)
torch.cuda.empty_cache()
z[0] = gather(z[0], dim=1, chunks=4)
m[0] = m[0][:, :-seq_cnt_padding_size, :-seq_len_padding_size, :]
z[0] = z[0][:, :-seq_len_padding_size, :-seq_len_padding_size, :]
......
......@@ -207,7 +207,9 @@ class OutProductMean(nn.Module):
norm = torch.einsum('bsid,bsjd->bijd', M_mask_col, M_mask) + 1e-3
right_act_all = gather_async_opp(right_act_all, work, dim=2)
torch.cuda.empty_cache()
right_act_all = M_mask * right_act_all
torch.cuda.empty_cache()
para_dim = left_act.shape[2]
chunk_size = CHUNK_SIZE
......@@ -1272,11 +1274,14 @@ class InputEmbedder(nn.Module):
reshaped_bins = boundaries.view(((1,) * len(d.shape)) + (len(boundaries),))
pair_emb = d[..., None] - reshaped_bins
del d, reshaped_bins
torch.cuda.empty_cache()
pair_emb = torch.argmin(torch.abs(pair_emb), dim=-1)
pair_emb = nn.functional.one_hot(pair_emb, num_classes=len(boundaries)).float().type(ri.dtype)
pair_emb = self.linear_relpos(pair_emb)
pair_emb += tf_emb_i[..., None, :]
pair_emb += tf_emb_j[..., None, :, :]
torch.cuda.empty_cache()
# [*, N_clust, N_res, c_m]
n_clust = msa.shape[-3]
......
......@@ -282,9 +282,11 @@ class TemplatePairBlock(nn.Module):
seq_length = mask.size(-1)
padding_size = (int(seq_length / dap_size) + 1) * dap_size - seq_length
torch.cuda.empty_cache()
if self.first_block:
z[0] = torch.nn.functional.pad(z[0], (0, 0, 0, padding_size, 0, padding_size))
z[0] = scatter(z[0], dim=1)
z[0] = scatter(z[0], dim=1, drop_unused=True)
torch.cuda.empty_cache()
mask = torch.nn.functional.pad(mask, (0, padding_size, 0, padding_size))
......@@ -294,8 +296,8 @@ class TemplatePairBlock(nn.Module):
single = z[0][i].unsqueeze(-4).to(mask.device)
single_mask = mask[i].unsqueeze(-3)
single_mask_row = scatter(single_mask, dim=1)
single_mask_col = scatter(single_mask, dim=2)
single_mask_row = scatter(single_mask, dim=1, drop_unused=True)
single_mask_col = scatter(single_mask, dim=2, drop_unused=True)
single = self.TriangleAttentionStartingNode(single, single_mask_row)
single = row_to_col(single)
......@@ -307,6 +309,7 @@ class TemplatePairBlock(nn.Module):
single = self.PairTransition(single)
single = col_to_row(single)
z[0][i] = single.to(z[0].device)
# z = torch.cat(single_templates, dim=-4)
if self.last_block:
......
......@@ -169,7 +169,8 @@ class AlphaFold(nn.Module):
return ret
def iteration(self, feats, m_1_prev, z_prev, x_prev, _recycle=True):
def iteration(self, feats, prevs, _recycle=True, no_iter=0):
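# `prevs` carries (m_1_prev, z_prev, x_prev) from the previous recycling pass;
# `no_iter` is the recycling-iteration index. The final iteration is
# hard-coded as 3 below, i.e. exactly four recycling iterations are assumed.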
torch.cuda.empty_cache()
# Primary output dictionary
outputs = {}
......@@ -203,7 +204,10 @@ class AlphaFold(nn.Module):
if not self.globals.is_multimer
else self.input_embedder(feats)
)
torch.cuda.empty_cache()
m_1_prev, z_prev, x_prev = reversed([prevs.pop() for _ in range(3)])
# Initialize the recycling embeddings, if needs be
if None in [m_1_prev, z_prev, x_prev]:
# [*, N, C_m]
......@@ -251,6 +255,8 @@ class AlphaFold(nn.Module):
# Possibly prevents memory fragmentation
del m_1_prev, z_prev, x_prev
torch.cuda.empty_cache()
# Embed the templates + merge with MSA/pair embeddings
if self.config.template.enabled:
......@@ -330,6 +336,8 @@ class AlphaFold(nn.Module):
# [*, S_e, N, C_e]
extra_msa_feat = extra_msa_fn(feats)
extra_msa_feat = self.extra_msa_embedder(extra_msa_feat)
torch.cuda.empty_cache()
# [*, N, N, C_z]
if not self.globals.inplace:
......@@ -353,6 +361,8 @@ class AlphaFold(nn.Module):
_mask_trans=self.config._mask_trans,
)[0]
del extra_msa_feat, extra_msa_fn
torch.cuda.empty_cache()
# Run MSA + pair embeddings through the trunk of the network
# m: [*, S, N, C_m]
......@@ -380,36 +390,49 @@ class AlphaFold(nn.Module):
)
m = m[0]
z = z[0]
torch.cuda.empty_cache()
outputs["msa"] = m[..., :n_seq, :, :]
outputs["pair"] = z
outputs["single"] = s
if no_iter == 3:
outputs["msa"] = m[..., :n_seq, :, :]
outputs["pair"] = z
outputs["single"] = s
# Predict 3D structure
outputs["sm"] = self.structure_module(
z = [z]
outputs_sm = self.structure_module(
s,
z,
feats["aatype"],
mask=feats["seq_mask"].to(dtype=s.dtype),
)
outputs["final_atom_positions"] = atom14_to_atom37(
outputs["sm"]["positions"][-1], feats
)
outputs["final_atom_mask"] = feats["atom37_atom_exists"]
outputs["final_affine_tensor"] = outputs["sm"]["frames"][-1]
# Save embeddings for use during the next recycling iteration
# [*, N, C_m]
m_1_prev = m[..., 0, :, :]
torch.cuda.empty_cache()
if no_iter == 3:
m_1_prev, z_prev, x_prev = None, None, None
outputs["sm"] = outputs_sm
outputs["final_atom_positions"] = atom14_to_atom37(
outputs["sm"]["positions"][-1], feats
)
outputs["final_atom_mask"] = feats["atom37_atom_exists"]
outputs["final_affine_tensor"] = outputs["sm"]["frames"][-1]
else:
# Save embeddings for use during the next recycling iteration
# [*, N, N, C_z]
z_prev = z
# [*, N, C_m]
m_1_prev = m[..., 0, :, :]
# [*, N, 3]
x_prev = outputs["final_atom_positions"]
# [*, N, N, C_z]
z_prev = z
return outputs, m_1_prev, z_prev, x_prev
# [*, N, 3]
x_prev = outputs["final_atom_positions"]
if no_iter != 3:
return None, m_1_prev, z_prev, x_prev
else:
return outputs, m_1_prev, z_prev, x_prev
def _disable_activation_checkpointing(self):
self.template_embedder.template_pair_stack.blocks_per_ckpt = None
......@@ -482,6 +505,7 @@ class AlphaFold(nn.Module):
"""
# Initialize recycling embeddings
m_1_prev, z_prev, x_prev = None, None, None
prevs = [m_1_prev, z_prev, x_prev]
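# Bundle the recycling tensors in a list so iteration() can pop and free them.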
# Disable activation checkpointing for the first few recycling iters
is_grad_enabled = torch.is_grad_enabled()
......@@ -506,11 +530,14 @@ class AlphaFold(nn.Module):
# Run the next iteration of the model
outputs, m_1_prev, z_prev, x_prev = self.iteration(
feats,
m_1_prev,
z_prev,
x_prev,
_recycle=(num_iters > 1)
prevs,
_recycle=(num_iters > 1),
no_iter=cycle_no
)
if cycle_no != 3:
prevs = [m_1_prev, z_prev, x_prev]
del m_1_prev, z_prev, x_prev
# Run auxiliary heads
outputs.update(self.aux_heads(outputs))
......
......@@ -389,6 +389,8 @@ class InvariantPointAttention(nn.Module):
)
a *= math.sqrt(1.0 / (3 * self.c_hidden))
a += math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1))
torch.cuda.empty_cache()
if self.is_multimer:
# [*, N_res, N_res, H, P_q, 3]
......@@ -396,11 +398,27 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, N_res, H, P_q]
pt_att = sum([c**2 for c in pt_att])
else:
# [*, N_res, N_res, H, P_q, 3]
pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
pt_att = pt_att**2
# [*, N_res, N_res, H, P_q]
pt_att = sum(torch.unbind(pt_att, dim=-1))
# # [*, N_res, N_res, H, P_q, 3]
# pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
# pt_att = pt_att**2
# # [*, N_res, N_res, H, P_q]
# pt_att = sum(torch.unbind(pt_att, dim=-1))
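# Chunk the query points along the leading residue dimension and accumulate
# the squared distances piece by piece, so the full difference tensor
# ([*, N_res, N_res, H, P_q, 3]) is never materialized at once; 10 chunks is
# a heuristic choice.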
_chunks = 10
_ks = k_pts.unsqueeze(-5)
_qlist = torch.chunk(q_pts.unsqueeze(-4), chunks=_chunks, dim=0)
pt_att = None
for _i in range(_chunks):
_pt = _qlist[_i] - _ks
_pt = _pt**2
_pt = sum(torch.unbind(_pt, dim=-1))
if _i == 0:
pt_att = _pt
else:
pt_att = torch.cat([pt_att, _pt], dim=0)
torch.cuda.empty_cache()
head_weights = self.softplus(self.head_weights).view(
*((1,) * len(pt_att.shape[:-2]) + (-1, 1))
......@@ -430,6 +448,8 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H * C_hidden]
o = flatten_final_dims(o, 2)
torch.cuda.empty_cache()
# As DeepMind explains, this manual matmul ensures that the operation
# happens in float32.
......@@ -447,14 +467,34 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H * P_v]
o_pt_norm = o_pt.norm(self.eps)
else:
# [*, H, 3, N_res, P_v]
o_pt = torch.sum(
(
a[..., None, :, :, None]
* permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
),
dim=-2,
)
# Chunk the attention-weighted point aggregation along the leading dimension
# so the full [*, H, 3, N_res, N_res, P_v] product is never resident at once.
permuted_pts = permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
size0 = permuted_pts.size()[0]
a_size0 = a.size()[0]
if size0 == 1 or size0 != a_size0:
# # [*, H, 3, N_res, P_v]
o_pt = torch.sum(
(
a[..., None, :, :, None]
* permuted_pts
),
dim=-2,
)
else:
# Process one slice of the leading dimension at a time and reduce it
# immediately, instead of forming the whole product and summing afterwards.
a_lists = torch.chunk(a[..., None, :, :, None], size0, dim=0)
permuted_pts_lists = torch.chunk(permuted_pts, size0, dim=0)
_c = None
for i in range(size0):
_d = torch.sum(a_lists[i] * permuted_pts_lists[i], dim=-2)
if i == 0:
_c = _d
else:
_c = torch.cat([_c, _d], dim=0)
o_pt = _c
del permuted_pts
# [*, N_res, H, P_v, 3]
o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
......@@ -471,8 +511,11 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H, C_z]
o_pair = torch.matmul(a.transpose(-2, -3), z.to(dtype=a.dtype))
torch.cuda.empty_cache()
# [*, N_res, H * C_z]
o_pair = flatten_final_dims(o_pair, 2)
del a
# [*, N_res, C_s]
if self.is_multimer:
......@@ -485,6 +528,7 @@ class InvariantPointAttention(nn.Module):
(o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1
).to(dtype=z.dtype)
)
torch.cuda.empty_cache()
return s
......@@ -720,10 +764,16 @@ class StructureModule(nn.Module):
# [*, N, N, C_z]
z = self.layer_norm_z(z)
# inplace z
# z[0] = z[0].contiguous()
# torch.cuda.empty_cache()
# z[0] = self.layer_norm_z(z[0])
# [*, N, C_s]
s_initial = s
s = self.linear_in(s)
torch.cuda.empty_cache()
# [*, N]
rigids = Rigid.identity(
......@@ -737,9 +787,12 @@ class StructureModule(nn.Module):
for i in range(self.no_blocks):
# [*, N, C_s]
s = s + self.ipa(s, z, rigids, mask)
del z
s = self.ipa_dropout(s)
torch.cuda.empty_cache()
s = self.layer_norm_ipa(s)
s = self.transition(s)
torch.cuda.empty_cache()
# [*, N]
rigids = rigids.compose_q_update_vec(self.bb_update(s))
......
......@@ -434,21 +434,21 @@ def inference_monomer_model(args):
with open(unrelaxed_output_path, 'w') as f:
f.write(protein.to_pdb(unrelaxed_protein))
amber_relaxer = relax.AmberRelaxation(
use_gpu=True,
**config.relax,
)
# Relax the prediction.
t = time.perf_counter()
relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
print(f"Relaxation time: {time.perf_counter() - t}")
# Save the relaxed PDB.
relaxed_output_path = os.path.join(args.output_dir,
f'{tag}_{args.model_name}_relaxed.pdb')
with open(relaxed_output_path, 'w') as f:
f.write(relaxed_pdb_str)
# amber_relaxer = relax.AmberRelaxation(
# use_gpu=True,
# **config.relax,
# )
# # Relax the prediction.
# t = time.perf_counter()
# relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
# print(f"Relaxation time: {time.perf_counter() - t}")
# # Save the relaxed PDB.
# relaxed_output_path = os.path.join(args.output_dir,
# f'{tag}_{args.model_name}_relaxed.pdb')
# with open(relaxed_output_path, 'w') as f:
# f.write(relaxed_pdb_str)
if __name__ == "__main__":
......
......@@ -14,4 +14,6 @@ python inference.py target.fasta data/pdb_mmcif/mmcif_files \
--jackhmmer_binary_path `which jackhmmer` \
--hhblits_binary_path `which hhblits` \
--hhsearch_binary_path `which hhsearch` \
--kalign_binary_path `which kalign`
--kalign_binary_path `which kalign` \
--chunk_size 4 \
--inplace