"segmentation/git@developer.sourcefind.cn:OpenDAS/dcnv3.git" did not exist on "7142c933b5dc64f33a88539e12108bdbcce11b5e"
Commit 143ba486 authored by Gustaf Ahdritz
Browse files

Refactor inplace operations, fix training

parent f1402490
......@@ -95,6 +95,7 @@ class InputEmbedder(nn.Module):
tf: torch.Tensor,
ri: torch.Tensor,
msa: torch.Tensor,
inplace_safe: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
......@@ -111,8 +112,6 @@ class InputEmbedder(nn.Module):
[*, N_res, N_res, C_z] pair embedding
"""
inplace_safe = not (self.training or torch.is_grad_enabled())
# [*, N_res, c_z]
tf_emb_i = self.linear_tf_z_i(tf)
tf_emb_j = self.linear_tf_z_j(tf)
......@@ -187,7 +186,7 @@ class RecyclingEmbedder(nn.Module):
m: torch.Tensor,
z: torch.Tensor,
x: torch.Tensor,
_inplace: bool = False,
inplace_safe: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
......@@ -205,13 +204,13 @@ class RecyclingEmbedder(nn.Module):
"""
# [*, N, C_m]
m_update = self.layer_norm_m(m)
if(_inplace):
if(inplace_safe):
m.copy_(m_update)
m_update = m
# [*, N, N, C_z]
z_update = self.layer_norm_z(z)
if(_inplace):
if(inplace_safe):
z.copy_(z_update)
z_update = z
......@@ -237,7 +236,7 @@ class RecyclingEmbedder(nn.Module):
# [*, N, N, C_z]
d = self.linear(d)
z_update = add(z_update, d, _inplace)
z_update = add(z_update, d, inplace_safe)
return m_update, z_update
......
......@@ -182,6 +182,7 @@ class EvoformerBlockCore(nn.Module):
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
use_lma: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
_attn_chunk_size: Optional[int] = None,
_offload_inference: bool = False,
......@@ -191,15 +192,12 @@ class EvoformerBlockCore(nn.Module):
# the original.
msa_trans_mask = msa_mask if _mask_trans else None
pair_trans_mask = pair_mask if _mask_trans else None
if(_attn_chunk_size is None):
_attn_chunk_size = chunk_size
m, z = input_tensors
# Need to dodge activation checkpoints
inplace_safe = not (self.training or torch.is_grad_enabled())
m = add(
m,
self.msa_transition(
......@@ -213,9 +211,9 @@ class EvoformerBlockCore(nn.Module):
input_tensors[1] = input_tensors[1].cpu()
torch.cuda.empty_cache()
m, z = input_tensors
opm = self.outer_product_mean(
m, mask=msa_mask, chunk_size=chunk_size, _inplace=inplace_safe
m, mask=msa_mask, chunk_size=chunk_size, inplace_safe=inplace_safe
)
if(_offload_inference and inplace_safe):
......@@ -230,7 +228,7 @@ class EvoformerBlockCore(nn.Module):
tmu_update = self.tri_mul_out(
z,
mask=pair_mask,
_inplace=inplace_safe,
inplace_safe=inplace_safe,
_add_with_inplace=True,
)
if(not inplace_safe):
......@@ -243,7 +241,7 @@ class EvoformerBlockCore(nn.Module):
tmu_update = self.tri_mul_in(
z,
mask=pair_mask,
_inplace=inplace_safe,
inplace_safe=inplace_safe,
_add_with_inplace=True,
)
if(not inplace_safe):
......@@ -259,7 +257,8 @@ class EvoformerBlockCore(nn.Module):
z,
mask=pair_mask,
chunk_size=_attn_chunk_size,
use_lma=use_lma
use_lma=use_lma,
inplace_safe=inplace_safe,
)
),
inplace=inplace_safe,
......@@ -277,6 +276,7 @@ class EvoformerBlockCore(nn.Module):
mask=pair_mask.transpose(-1, -2),
chunk_size=_attn_chunk_size,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
),
inplace=inplace_safe,
......@@ -355,20 +355,27 @@ class EvoformerBlock(nn.Module):
)
def forward(self,
input_tensors: Sequence[torch.Tensor],
m: Optional[torch.Tensor],
z: Optional[torch.Tensor],
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
use_lma: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
_attn_chunk_size: Optional[int] = None,
_offload_inference: bool = False,
_offloadable_inputs: Optional[Sequence[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
inplace_safe = not (self.training or torch.is_grad_enabled())
if(_attn_chunk_size is None):
_attn_chunk_size = chunk_size
if(_offload_inference and inplace_safe):
input_tensors = _offloadable_inputs
del _offloadable_inputs
else:
input_tensors = [m, z]
m, z = input_tensors
m = add(m,
......@@ -404,17 +411,13 @@ class EvoformerBlock(nn.Module):
pair_mask=pair_mask,
chunk_size=chunk_size,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
_attn_chunk_size=_attn_chunk_size,
_offload_inference=_offload_inference,
)
if(inplace_safe):
out = input_tensors
else:
out = [m, z]
return out
return m, z
class ExtraMSABlock(nn.Module):
......@@ -477,22 +480,29 @@ class ExtraMSABlock(nn.Module):
)
def forward(self,
input_tensors: Sequence[torch.Tensor],
m: Optional[torch.Tensor],
z: Optional[torch.Tensor],
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
use_lma: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
_attn_chunk_size: Optional[int] = None,
_offload_inference: bool = False,
_offloadable_inputs: Optional[Sequence[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if(_attn_chunk_size is None):
_attn_chunk_size = chunk_size
if(_offload_inference and inplace_safe):
input_tensors = _offloadable_inputs
del _offloadable_inputs
else:
input_tensors = [m, z]
m, z = input_tensors
inplace_safe = not (self.training or torch.is_grad_enabled())
# If function calls could speak...
m = add(m,
self.msa_dropout_layer(
self.msa_att_row(
......@@ -509,8 +519,11 @@ class ExtraMSABlock(nn.Module):
inplace=inplace_safe,
)
if(not inplace_safe):
input_tensors = [m, z]
del m, z
def fn(input_tensors):
m = add(input_tensors[0],
self.msa_att_col(
......@@ -523,7 +536,7 @@ class ExtraMSABlock(nn.Module):
)
if(not inplace_safe):
input_tensors [m, input_tensors[1]]
input_tensors = [m, input_tensors[1]]
del m
......@@ -533,6 +546,7 @@ class ExtraMSABlock(nn.Module):
pair_mask=pair_mask,
chunk_size=chunk_size,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
_attn_chunk_size=_attn_chunk_size,
_offload_inference=_offload_inference,
......@@ -647,15 +661,16 @@ class EvoformerStack(nn.Module):
if(tune_chunk_size):
self.chunk_size_tuner = ChunkSizeTuner()
def _forward_list(self,
input_tensors: Sequence[torch.Tensor],
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
def _prep_blocks(self,
m: torch.Tensor,
z: torch.Tensor,
chunk_size: int,
use_lma: bool = False,
_mask_trans: bool = True,
_offload_inference: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
use_lma: bool,
msa_mask: Optional[torch.Tensor],
pair_mask: Optional[torch.Tensor],
inplace_safe: bool,
_mask_trans: bool,
):
blocks = [
partial(
b,
......@@ -663,8 +678,8 @@ class EvoformerStack(nn.Module):
pair_mask=pair_mask,
chunk_size=chunk_size,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
_offload_inference=_offload_inference,
)
for b in self.blocks
]
......@@ -677,11 +692,11 @@ class EvoformerStack(nn.Module):
blocks = [partial(block_with_cache_clear, b) for b in blocks]
if(chunk_size is not None and self.chunk_size_tuner is not None):
print("evo")
assert(not self.training)
tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size(
representative_fn=blocks[0],
# We don't want to write in-place during chunk tuning runs
args=([t.clone() for t in input_tensors],),
args=(m.clone(), z.clone(),),
min_chunk_size=chunk_size,
)
blocks = [
......@@ -693,16 +708,43 @@ class EvoformerStack(nn.Module):
) for b in blocks
]
blocks_per_ckpt = self.blocks_per_ckpt
if(not torch.is_grad_enabled()):
blocks_per_ckpt = None
return blocks
m, z = checkpoint_blocks(
blocks,
args=input_tensors,
blocks_per_ckpt=blocks_per_ckpt,
)[0]
def _forward_offload(self,
input_tensors: Sequence[torch.Tensor],
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: int,
use_lma: bool = False,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert(not (self.training or torch.is_grad_enabled()))
blocks = self._prep_blocks(
# We are very careful not to create references to these tensors in
# this function
m=input_tensors[0],
z=input_tensors[1],
chunk_size=chunk_size,
use_lma=use_lma,
msa_mask=msa_mask,
pair_mask=pair_mask,
inplace_safe=True,
_mask_trans=_mask_trans,
)
for b in blocks:
m, z = b(
None,
None,
_offload_inference=True,
_offloadable_inputs=input_tensors,
)
input_tensors[0] = m
input_tensors[1] = z
del m, z
m, z = input_tensors
s = self.linear(m[..., 0, :, :])
return m, z, s
......@@ -714,6 +756,7 @@ class EvoformerStack(nn.Module):
pair_mask: torch.Tensor,
chunk_size: int,
use_lma: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
......@@ -738,15 +781,31 @@ class EvoformerStack(nn.Module):
s:
[*, N_res, C_s] single embedding (or None if extra MSA stack)
"""
return self._forward_list(
[m, z],
msa_mask=msa_mask,
pair_mask=pair_mask,
blocks = self._prep_blocks(
m=m,
z=z,
chunk_size=chunk_size,
use_lma=use_lma,
msa_mask=msa_mask,
pair_mask=pair_mask,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
)
blocks_per_ckpt = self.blocks_per_ckpt
if(not torch.is_grad_enabled()):
blocks_per_ckpt = None
m, z = checkpoint_blocks(
blocks,
args=(m, z),
blocks_per_ckpt=blocks_per_ckpt,
)
s = self.linear(m[..., 0, :, :])
return m, z, s
class ExtraMSAStack(nn.Module):
"""
......@@ -769,7 +828,6 @@ class ExtraMSAStack(nn.Module):
eps: float,
ckpt: bool,
clear_cache_between_blocks: bool = False,
chunk_msa_attn: bool = False,
tune_chunk_size: bool = False,
**kwargs,
):
......@@ -777,7 +835,6 @@ class ExtraMSAStack(nn.Module):
self.ckpt = ckpt
self.clear_cache_between_blocks = clear_cache_between_blocks
self.chunk_msa_attn = chunk_msa_attn
self.blocks = nn.ModuleList()
for _ in range(no_blocks):
block = ExtraMSABlock(
......@@ -794,7 +851,7 @@ class ExtraMSAStack(nn.Module):
pair_dropout=pair_dropout,
inf=inf,
eps=eps,
ckpt=ckpt if chunk_msa_attn else False,
ckpt=False,
)
self.blocks.append(block)
......@@ -810,6 +867,7 @@ class ExtraMSAStack(nn.Module):
use_lma: bool,
msa_mask: Optional[torch.Tensor],
pair_mask: Optional[torch.Tensor],
inplace_safe: bool,
_mask_trans: bool,
):
blocks = [
......@@ -819,6 +877,7 @@ class ExtraMSAStack(nn.Module):
pair_mask=pair_mask,
chunk_size=chunk_size,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
) for b in self.blocks
]
......@@ -831,10 +890,12 @@ class ExtraMSAStack(nn.Module):
blocks = [partial(clear_cache, b) for b in blocks]
if(chunk_size is not None and self.chunk_size_tuner is not None):
print("extra")
tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size(
representative_fn=blocks[0],
args=([m.clone(), z.clone()],),
# Tensors cloned to avoid getting written to in-place
# A corollary is that chunk size tuning should be disabled for
# large N, when z gets really big
args=(m.clone(), z.clone(),),
min_chunk_size=chunk_size,
)
blocks = [
......@@ -848,16 +909,15 @@ class ExtraMSAStack(nn.Module):
return blocks
def _forward_list(self,
def _forward_offload(self,
input_tensors: Sequence[torch.Tensor],
chunk_size: int,
use_lma: bool = False,
msa_mask: Optional[torch.Tensor] = None,
pair_mask: Optional[torch.Tensor] = None,
_mask_trans: bool = True,
_offload_inference: bool = False,
) -> torch.Tensor:
assert(not self.training)
assert(not (self.training or torch.is_grad_enabled()))
blocks = self._prep_blocks(
# We are very careful not to create references to these tensors in
# this function
......@@ -867,11 +927,17 @@ class ExtraMSAStack(nn.Module):
use_lma=use_lma,
msa_mask=msa_mask,
pair_mask=pair_mask,
inplace_safe=True,
_mask_trans=_mask_trans,
)
for b in blocks:
m, z = b(input_tensors, _offload_inference=_offload_inference)
m, z = b(
None,
None,
_offload_inference=True,
_offloadable_inputs=input_tensors,
)
input_tensors[0] = m
input_tensors[1] = z
del m, z
......@@ -885,6 +951,7 @@ class ExtraMSAStack(nn.Module):
use_lma: bool = False,
msa_mask: Optional[torch.Tensor] = None,
pair_mask: Optional[torch.Tensor] = None,
inplace_safe: bool = False,
_mask_trans: bool = True,
) -> torch.Tensor:
"""
......@@ -910,12 +977,13 @@ class ExtraMSAStack(nn.Module):
use_lma=use_lma,
msa_mask=msa_mask,
pair_mask=pair_mask,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
)
for b in blocks:
if(self.ckpt and torch.is_grad_enabled()):
m, z = checkpoint_fn(b, (m, z))
m, z = checkpoint_fn(b, m, z)
else:
m, z = b(m, z)
......
......@@ -107,19 +107,16 @@ class AlphaFold(nn.Module):
self.config["heads"],
)
def embed_templates(self, batch, z, pair_mask, templ_dim):
def embed_templates(self, batch, z, pair_mask, templ_dim, inplace_safe):
if(self.template_config.offload_templates):
return embed_templates_offload(
self, batch, z, pair_mask, templ_dim,
return embed_templates_offload(self,
batch, z, pair_mask, templ_dim, inplace_safe=inplace_safe,
)
elif(self.template_config.average_templates):
return embed_templates_average(
self, batch, z, pair_mask, templ_dim
return embed_templates_average(self,
batch, z, pair_mask, templ_dim, inplace_safe=inplace_safe,
)
inplace_safe = not (self.training or torch.is_grad_enabled())
# Embed the templates one at a time (with a poor man's vmap)
pair_embeds = []
n = z.shape[-2]
......@@ -168,6 +165,7 @@ class AlphaFold(nn.Module):
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
inplace_safe=inplace_safe,
_mask_trans=self.config._mask_trans,
)
del t_pair
......@@ -186,6 +184,11 @@ class AlphaFold(nn.Module):
t = t * (torch.sum(batch["template_mask"], dim=-1) > 0)
ret = {}
ret.update({"template_pair_embedding": t})
del t
if self.config.template.embed_angles:
template_angle_feat = build_template_angle_feat(
batch
......@@ -196,10 +199,6 @@ class AlphaFold(nn.Module):
ret["template_angle_embedding"] = a
ret.update({"template_pair_embedding": t})
del t
return ret
def iteration(self, feats, prevs, _recycle=True):
......@@ -218,6 +217,9 @@ class AlphaFold(nn.Module):
n = feats["target_feat"].shape[-2]
n_seq = feats["msa_feat"].shape[-3]
device = feats["target_feat"].device
# Controls whether the model uses in-place operations throughout
# The dual condition accounts for activation checkpoints
inplace_safe = not (self.training or torch.is_grad_enabled())
# Prep some features
......@@ -233,10 +235,11 @@ class AlphaFold(nn.Module):
feats["target_feat"],
feats["residue_index"],
feats["msa_feat"],
inplace_safe=inplace_safe,
)
# Unpack the recycling embeddings. Removing them from the list allows
# them to be freed further down in this function.
# them to be freed further down in this function, saving memory
m_1_prev, z_prev, x_prev = reversed([prevs.pop() for _ in range(3)])
# Initialize the recycling embeddings, if needs be
......@@ -263,6 +266,7 @@ class AlphaFold(nn.Module):
feats["aatype"], x_prev, None
).to(dtype=z.dtype)
# The recycling embedder is memory-intensive, so we offload first
if(self.globals.offload_inference and inplace_safe):
m = m.cpu()
z = z.cpu()
......@@ -273,7 +277,7 @@ class AlphaFold(nn.Module):
m_1_prev,
z_prev,
x_prev,
_inplace=inplace_safe,
inplace_safe=inplace_safe,
)
if(self.globals.offload_inference and inplace_safe):
......@@ -286,10 +290,13 @@ class AlphaFold(nn.Module):
# [*, N, N, C_z]
z = add(z, z_prev_emb, inplace=inplace_safe)
# Deletions like these become significant for inference with large N,
# where they free unused tensors and remove references to others such
# that they can be offloaded later
del m_1_prev, z_prev, x_prev, m_1_prev_emb, z_prev_emb
# Embed the templates + merge with MSA/pair embeddings
if self.config.template.enabled:
if self.config.template.enabled:
template_feats = {
k: v for k, v in feats.items() if k.startswith("template_")
}
......@@ -298,6 +305,7 @@ class AlphaFold(nn.Module):
z,
pair_mask.to(dtype=z.dtype),
no_batch_dims,
inplace_safe=inplace_safe,
)
# [*, N, N, C_z]
......@@ -306,7 +314,7 @@ class AlphaFold(nn.Module):
inplace_safe,
)
if self.config.template.embed_angles:
if "template_angle_embedding" in template_embeds:
# [*, S = S_c + S_t, N, C_m]
m = torch.cat(
[m, template_embeds["template_angle_embedding"]],
......@@ -324,39 +332,64 @@ class AlphaFold(nn.Module):
if self.config.extra_msa.enabled:
# [*, S_e, N, C_e]
a = self.extra_msa_embedder(build_extra_msa_feat(feats))
input_tensors = [a, z]
del a, z
# [*, N, N, C_z]
z = self.extra_msa_stack._forward_list(
if(self.globals.offload_inference):
# To allow the extra MSA stack (and later the evoformer) to
# offload its inputs, we remove all references to them here
input_tensors = [a, z]
del a, z
# [*, N, N, C_z]
z = self.extra_msa_stack._forward_offload(
input_tensors,
msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
pair_mask=pair_mask.to(dtype=m.dtype),
_mask_trans=self.config._mask_trans,
)
del input_tensors
else:
# [*, N, N, C_z]
z = self.extra_msa_stack(
a, z,
msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
pair_mask=pair_mask.to(dtype=m.dtype),
inplace_safe=inplace_safe,
_mask_trans=self.config._mask_trans,
)
# Run MSA + pair embeddings through the trunk of the network
# m: [*, S, N, C_m]
# z: [*, N, N, C_z]
# s: [*, N, C_s]
if(self.globals.offload_inference):
input_tensors = [m, z]
del m, z
m, z, s = self.evoformer._forward_offload(
input_tensors,
msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
msa_mask=msa_mask.to(dtype=input_tensors[0].dtype),
pair_mask=pair_mask.to(dtype=input_tensors[1].dtype),
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
pair_mask=pair_mask.to(dtype=m.dtype),
_mask_trans=self.config._mask_trans,
_offload_inference=self.globals.offload_inference,
)
del input_tensors
# Run MSA + pair embeddings through the trunk of the network
# m: [*, S, N, C_m]
# z: [*, N, N, C_z]
# s: [*, N, C_s]
input_tensors = [m, z]
del m, z
m, z, s = self.evoformer._forward_list(
input_tensors,
msa_mask=msa_mask.to(dtype=input_tensors[0].dtype),
pair_mask=pair_mask.to(dtype=input_tensors[1].dtype),
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
_mask_trans=self.config._mask_trans,
)
del input_tensors
else:
m, z, s = self.evoformer(
m,
z,
msa_mask=msa_mask.to(dtype=m.dtype),
pair_mask=pair_mask.to(dtype=z.dtype),
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
inplace_safe=inplace_safe,
_mask_trans=self.config._mask_trans,
)
outputs["msa"] = m[..., :n_seq, :, :]
outputs["pair"] = z
......@@ -369,6 +402,7 @@ class AlphaFold(nn.Module):
outputs,
feats["aatype"],
mask=feats["seq_mask"].to(dtype=s.dtype),
inplace_safe=inplace_safe,
_offload_inference=self.globals.offload_inference,
)
outputs["final_atom_positions"] = atom14_to_atom37(
......
......@@ -117,10 +117,9 @@ class MSAAttention(nn.Module):
def _prep_inputs(self,
m: torch.Tensor,
z: Optional[torch.Tensor],
mask: Optional[torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
_inplace_safe = not (self.training or torch.is_grad_enabled())
mask: Optional[torch.Tensor],
inplace_safe: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
n_seq, n_res = m.shape[-3:-1]
if mask is None:
# [*, N_seq, N_res]
......@@ -163,6 +162,7 @@ class MSAAttention(nn.Module):
mask: Optional[torch.Tensor],
chunk_logits: int,
checkpoint: bool,
inplace_safe: bool = False
) -> torch.Tensor:
"""
MSA attention with training-time chunking of the softmax computation.
......@@ -172,7 +172,9 @@ class MSAAttention(nn.Module):
MSA_DIM = -4
def _get_qkv(m, z):
m, mask_bias, z = self._prep_inputs(m, z, mask)
m, mask_bias, z = self._prep_inputs(
m, z, mask, inplace_safe=inplace_safe
)
q, k, v = self.mha._prep_qkv(m, m)
return m, q, k, v, mask_bias, z
......@@ -208,6 +210,7 @@ class MSAAttention(nn.Module):
chunk_size: Optional[int] = None,
use_memory_efficient_kernel: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
_chunk_logits: Optional[int] = None,
_checkpoint_chunks: Optional[bool] = None,
) -> torch.Tensor:
......@@ -229,10 +232,14 @@ class MSAAttention(nn.Module):
if(_chunk_logits is not None):
return self._chunked_msa_attn(
m=m, z=z, mask=mask,
chunk_logits=_chunk_logits, checkpoint=_checkpoint_chunks
chunk_logits=_chunk_logits,
checkpoint=_checkpoint_chunks,
inplace_safe=inplace_safe,
)
m, mask_bias, z = self._prep_inputs(m, z, mask)
m, mask_bias, z = self._prep_inputs(
m, z, mask, inplace_safe=inplace_safe
)
biases = [mask_bias]
if(z is not None):
......
......@@ -97,7 +97,7 @@ class OuterProductMean(nn.Module):
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
_inplace: bool = False,
inplace_safe: bool = False,
) -> torch.Tensor:
"""
Args:
......@@ -137,7 +137,7 @@ class OuterProductMean(nn.Module):
norm = norm + self.eps
# [*, N_res, N_res, C_z]
if(_inplace):
if(inplace_safe):
outer /= norm
else:
outer = outer / norm
......
......@@ -232,6 +232,7 @@ class InvariantPointAttention(nn.Module):
z: Optional[torch.Tensor],
r: Rigid,
mask: torch.Tensor,
inplace_safe: bool = False,
_offload_inference: bool = False,
_z_reference_list: Optional[Sequence[torch.Tensor]] = None,
) -> torch.Tensor:
......@@ -248,12 +249,11 @@ class InvariantPointAttention(nn.Module):
Returns:
[*, N_res, C_s] single representation update
"""
inplace_safe = not (self.training or torch.is_grad_enabled())
if(_offload_inference and inplace_safe):
z = _z_reference_list
else:
z = [z]
#######################################
# Generate scalar and point activations
#######################################
......@@ -619,6 +619,7 @@ class StructureModule(nn.Module):
evoformer_output_dict,
aatype,
mask=None,
inplace_safe=False,
_offload_inference=False,
):
"""
......@@ -674,6 +675,7 @@ class StructureModule(nn.Module):
z,
rigids,
mask,
inplace_safe=inplace_safe,
_offload_inference=_offload_inference,
_z_reference_list=z_reference_list
)
......
......@@ -201,8 +201,8 @@ class TemplatePairStackBlock(nn.Module):
mask: torch.Tensor,
chunk_size: Optional[int] = None,
use_lma: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
_inplace: bool = False,
_attn_chunk_size: Optional[int] = None,
):
if(_attn_chunk_size is None):
......@@ -214,6 +214,7 @@ class TemplatePairStackBlock(nn.Module):
single_templates_masks = [
m.unsqueeze(-3) for m in torch.unbind(mask, dim=-3)
]
for i in range(len(single_templates)):
single = single_templates[i]
single_mask = single_templates_masks[i]
......@@ -225,9 +226,10 @@ class TemplatePairStackBlock(nn.Module):
chunk_size=_attn_chunk_size,
mask=single_mask,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
),
_inplace,
inplace_safe,
)
single = add(single,
......@@ -237,18 +239,19 @@ class TemplatePairStackBlock(nn.Module):
chunk_size=_attn_chunk_size,
mask=single_mask,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
),
_inplace,
inplace_safe,
)
tmu_update = self.tri_mul_out(
single,
mask=single_mask,
_inplace=_inplace,
inplace_safe=inplace_safe,
_add_with_inplace=True,
)
if(not _inplace):
if(not inplace_safe):
single = single + self.dropout_row(tmu_update)
else:
single = tmu_update
......@@ -258,10 +261,10 @@ class TemplatePairStackBlock(nn.Module):
tmu_update = self.tri_mul_in(
single,
mask=single_mask,
_inplace=_inplace,
inplace_safe=inplace_safe,
_add_with_inplace=True,
)
if(not _inplace):
if(not inplace_safe):
single = single + self.dropout_row(tmu_update)
else:
single = tmu_update
......@@ -274,13 +277,13 @@ class TemplatePairStackBlock(nn.Module):
mask=single_mask if _mask_trans else None,
chunk_size=chunk_size,
),
_inplace,
inplace_safe,
)
if(not _inplace):
if(not inplace_safe):
single_templates[i] = single
if(not _inplace):
if(not inplace_safe):
z = torch.cat(single_templates, dim=-4)
return z
......@@ -352,6 +355,7 @@ class TemplatePairStack(nn.Module):
mask: torch.tensor,
chunk_size: int,
use_lma: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
):
"""
......@@ -374,13 +378,14 @@ class TemplatePairStack(nn.Module):
mask=mask,
chunk_size=chunk_size,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
_inplace=not (self.training or torch.is_grad_enabled()),
)
for b in self.blocks
]
if(chunk_size is not None and self.chunk_size_tuner is not None):
assert(not self.training)
tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size(
representative_fn=blocks[0],
args=(t.clone(),),
......@@ -411,6 +416,7 @@ def embed_templates_offload(
pair_mask,
templ_dim,
template_chunk_size=256,
inplace_safe=False,
):
"""
Args:
......@@ -435,8 +441,6 @@ def embed_templates_offload(
offloads the large template pair tensor to CPU. Slower but more frugal
with GPU memory than the original. Useful for long-sequence inference.
"""
inplace_safe = not (model.training or torch.is_grad_enabled())
# Embed the templates one at a time (with a poor man's vmap)
pair_embeds_cpu = []
n = z.shape[-2]
......@@ -519,6 +523,7 @@ def embed_templates_average(
pair_mask,
templ_dim,
templ_group_size=2,
inplace_safe=False,
):
"""
Args:
......@@ -547,8 +552,6 @@ def embed_templates_average(
embedding, while its low memory footprint allows the number of templates
to scale almost indefinitely.
"""
inplace_safe = not (model.training or torch.is_grad_enabled())
# Embed the templates one at a time (with a poor man's vmap)
n = z.shape[-2]
n_templ = batch["template_aatype"].shape[templ_dim]
......
......@@ -64,6 +64,7 @@ class TriangleAttention(nn.Module):
chunk_size: int,
use_memory_efficient_kernel: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
) -> torch.Tensor:
"triangle! triangle!"
mha_inputs = {
......@@ -72,8 +73,6 @@ class TriangleAttention(nn.Module):
"biases": biases,
}
inplace_safe = not (self.training or torch.is_grad_enabled())
return chunk_layer(
partial(
self.mha,
......@@ -92,6 +91,7 @@ class TriangleAttention(nn.Module):
chunk_size: Optional[int] = None,
use_memory_efficient_kernel: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
) -> torch.Tensor:
"""
Args:
......@@ -130,7 +130,8 @@ class TriangleAttention(nn.Module):
biases,
chunk_size,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_lma=use_lma
use_lma=use_lma,
inplace_safe=inplace_safe,
)
else:
x = self.mha(
......
......@@ -357,7 +357,7 @@ class TriangleMultiplicativeUpdate(nn.Module):
def forward(self,
z: torch.Tensor,
mask: Optional[torch.Tensor] = None,
_inplace: bool = False,
inplace_safe: bool = False,
_add_with_inplace: bool = False,
_inplace_chunk_size: Optional[int] = 256,
) -> torch.Tensor:
......@@ -370,7 +370,7 @@ class TriangleMultiplicativeUpdate(nn.Module):
Returns:
[*, N_res, N_res, C_z] output tensor
"""
if(_inplace):
if(inplace_safe):
x = self._inference_forward(
z,
mask,
......
......@@ -415,8 +415,6 @@ class ChunkSizeTuner:
# Otherwise, we can reuse the precomputed value
consistent = False
print(consistent)
if(not consistent):
self.cached_chunk_size = self._determine_favorable_chunk_size(
representative_fn,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment