Commit 143ba486 authored by Gustaf Ahdritz

Refactor inplace operations, fix training

parent f1402490
@@ -95,6 +95,7 @@ class InputEmbedder(nn.Module):
         tf: torch.Tensor,
         ri: torch.Tensor,
         msa: torch.Tensor,
+        inplace_safe: bool = False,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -111,8 +112,6 @@ class InputEmbedder(nn.Module):
             [*, N_res, N_res, C_z] pair embedding
         """
-        inplace_safe = not (self.training or torch.is_grad_enabled())
-
         # [*, N_res, c_z]
         tf_emb_i = self.linear_tf_z_i(tf)
         tf_emb_j = self.linear_tf_z_j(tf)
@@ -187,7 +186,7 @@ class RecyclingEmbedder(nn.Module):
         m: torch.Tensor,
         z: torch.Tensor,
         x: torch.Tensor,
-        _inplace: bool = False,
+        inplace_safe: bool = False,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -205,13 +204,13 @@
         """
         # [*, N, C_m]
         m_update = self.layer_norm_m(m)
-        if(_inplace):
+        if(inplace_safe):
             m.copy_(m_update)
             m_update = m

         # [*, N, N, C_z]
         z_update = self.layer_norm_z(z)
-        if(_inplace):
+        if(inplace_safe):
             z.copy_(z_update)
             z_update = z
@@ -237,7 +236,7 @@ class RecyclingEmbedder(nn.Module):
         # [*, N, N, C_z]
         d = self.linear(d)

-        z_update = add(z_update, d, _inplace)
+        z_update = add(z_update, d, inplace_safe)

         return m_update, z_update
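The `add` helper that `inplace_safe` is threaded into is assumed, from the way it is called throughout this diff, to behave roughly like this sketch:

```python
import torch

def add(m1: torch.Tensor, m2: torch.Tensor, inplace: bool) -> torch.Tensor:
    # Out-of-place: allocates a fresh tensor, safe while autograd records.
    # In-place: reuses m1's storage, legal only when no graph is being
    # built, which is what the inplace_safe flag certifies.
    if inplace:
        m1 += m2
    else:
        m1 = m1 + m2
    return m1
```

The same reasoning explains the `m.copy_(m_update)` branch above: during inference the LayerNorm output can be written back into the recycled buffer instead of allocating a second [*, N, C_m] tensor.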
...
@@ -182,6 +182,7 @@ class EvoformerBlockCore(nn.Module):
         pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
         use_lma: bool = False,
+        inplace_safe: bool = False,
         _mask_trans: bool = True,
         _attn_chunk_size: Optional[int] = None,
         _offload_inference: bool = False,
@@ -191,15 +192,12 @@ class EvoformerBlockCore(nn.Module):
         # the original.
         msa_trans_mask = msa_mask if _mask_trans else None
         pair_trans_mask = pair_mask if _mask_trans else None

         if(_attn_chunk_size is None):
             _attn_chunk_size = chunk_size

         m, z = input_tensors
-
-        # Need to dodge activation checkpoints
-        inplace_safe = not (self.training or torch.is_grad_enabled())

         m = add(
             m,
             self.msa_transition(
@@ -213,9 +211,9 @@ class EvoformerBlockCore(nn.Module):
             input_tensors[1] = input_tensors[1].cpu()
             torch.cuda.empty_cache()
             m, z = input_tensors

         opm = self.outer_product_mean(
-            m, mask=msa_mask, chunk_size=chunk_size, _inplace=inplace_safe
+            m, mask=msa_mask, chunk_size=chunk_size, inplace_safe=inplace_safe
         )

         if(_offload_inference and inplace_safe):
@@ -230,7 +228,7 @@ class EvoformerBlockCore(nn.Module):
         tmu_update = self.tri_mul_out(
             z,
             mask=pair_mask,
-            _inplace=inplace_safe,
+            inplace_safe=inplace_safe,
             _add_with_inplace=True,
         )
         if(not inplace_safe):
@@ -243,7 +241,7 @@ class EvoformerBlockCore(nn.Module):
         tmu_update = self.tri_mul_in(
             z,
             mask=pair_mask,
-            _inplace=inplace_safe,
+            inplace_safe=inplace_safe,
             _add_with_inplace=True,
         )
         if(not inplace_safe):
@@ -259,7 +257,8 @@ class EvoformerBlockCore(nn.Module):
                     z,
                     mask=pair_mask,
                     chunk_size=_attn_chunk_size,
-                    use_lma=use_lma
+                    use_lma=use_lma,
+                    inplace_safe=inplace_safe,
                 )
             ),
             inplace=inplace_safe,
@@ -277,6 +276,7 @@ class EvoformerBlockCore(nn.Module):
                     mask=pair_mask.transpose(-1, -2),
                     chunk_size=_attn_chunk_size,
                     use_lma=use_lma,
+                    inplace_safe=inplace_safe,
                 )
             ),
             inplace=inplace_safe,
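Why the flag now arrives as an argument instead of being derived per module: an in-place write to a tensor that autograd saved for backward only fails later, at `backward()`. A minimal reproduction of the failure the out-of-place branches avoid:

```python
import torch

x = torch.randn(4, requires_grad=True)
y = torch.exp(x)   # autograd saves exp's output to compute its gradient
y += 1             # the in-place write bumps y's version counter

try:
    y.sum().backward()
except RuntimeError as err:
    # "one of the variables needed for gradient computation has been
    # modified by an inplace operation"
    print(err)
```

Deriving the flag locally from `not (self.training or torch.is_grad_enabled())`, as the removed lines did, also misfires inside activation checkpoints; see the note on the dual condition in the AlphaFold hunks below.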
@@ -355,20 +355,27 @@ class EvoformerBlock(nn.Module):
         )

     def forward(self,
-        input_tensors: Sequence[torch.Tensor],
+        m: Optional[torch.Tensor],
+        z: Optional[torch.Tensor],
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
         use_lma: bool = False,
+        inplace_safe: bool = False,
         _mask_trans: bool = True,
         _attn_chunk_size: Optional[int] = None,
         _offload_inference: bool = False,
+        _offloadable_inputs: Optional[Sequence[torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        inplace_safe = not (self.training or torch.is_grad_enabled())
-
         if(_attn_chunk_size is None):
             _attn_chunk_size = chunk_size

+        if(_offload_inference and inplace_safe):
+            input_tensors = _offloadable_inputs
+            del _offloadable_inputs
+        else:
+            input_tensors = [m, z]
+
         m, z = input_tensors

         m = add(m,
@@ -404,17 +411,13 @@ class EvoformerBlock(nn.Module):
             pair_mask=pair_mask,
             chunk_size=chunk_size,
             use_lma=use_lma,
+            inplace_safe=inplace_safe,
             _mask_trans=_mask_trans,
             _attn_chunk_size=_attn_chunk_size,
             _offload_inference=_offload_inference,
         )

-        if(inplace_safe):
-            out = input_tensors
-        else:
-            out = [m, z]
-
-        return out
+        return m, z


 class ExtraMSABlock(nn.Module):
@@ -477,22 +480,29 @@ class ExtraMSABlock(nn.Module):
         )

     def forward(self,
-        input_tensors: Sequence[torch.Tensor],
+        m: Optional[torch.Tensor],
+        z: Optional[torch.Tensor],
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
         use_lma: bool = False,
+        inplace_safe: bool = False,
         _mask_trans: bool = True,
         _attn_chunk_size: Optional[int] = None,
         _offload_inference: bool = False,
+        _offloadable_inputs: Optional[Sequence[torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         if(_attn_chunk_size is None):
             _attn_chunk_size = chunk_size

+        if(_offload_inference and inplace_safe):
+            input_tensors = _offloadable_inputs
+            del _offloadable_inputs
+        else:
+            input_tensors = [m, z]
+
         m, z = input_tensors
-
-        inplace_safe = not (self.training or torch.is_grad_enabled())
-
-        # If function calls could speak...

         m = add(m,
             self.msa_dropout_layer(
                 self.msa_att_row(
@@ -509,8 +519,11 @@ class ExtraMSABlock(nn.Module):
             inplace=inplace_safe,
         )

+        if(not inplace_safe):
+            input_tensors = [m, z]
+
         del m, z

         def fn(input_tensors):
             m = add(input_tensors[0],
                 self.msa_att_col(
@@ -523,7 +536,7 @@ class ExtraMSABlock(nn.Module):
             )

             if(not inplace_safe):
-                input_tensors [m, input_tensors[1]]
+                input_tensors = [m, input_tensors[1]]

             del m
@@ -533,6 +546,7 @@ class ExtraMSABlock(nn.Module):
                 pair_mask=pair_mask,
                 chunk_size=chunk_size,
                 use_lma=use_lma,
+                inplace_safe=inplace_safe,
                 _mask_trans=_mask_trans,
                 _attn_chunk_size=_attn_chunk_size,
                 _offload_inference=_offload_inference,
@@ -647,15 +661,16 @@ class EvoformerStack(nn.Module):
         if(tune_chunk_size):
             self.chunk_size_tuner = ChunkSizeTuner()

-    def _forward_list(self,
-        input_tensors: Sequence[torch.Tensor],
-        msa_mask: torch.Tensor,
-        pair_mask: torch.Tensor,
-        chunk_size: int,
-        use_lma: bool = False,
-        _mask_trans: bool = True,
-        _offload_inference: bool = False,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def _prep_blocks(self,
+        m: torch.Tensor,
+        z: torch.Tensor,
+        chunk_size: int,
+        use_lma: bool,
+        msa_mask: Optional[torch.Tensor],
+        pair_mask: Optional[torch.Tensor],
+        inplace_safe: bool,
+        _mask_trans: bool,
+    ):
         blocks = [
             partial(
                 b,
@@ -663,8 +678,8 @@ class EvoformerStack(nn.Module):
                 pair_mask=pair_mask,
                 chunk_size=chunk_size,
                 use_lma=use_lma,
+                inplace_safe=inplace_safe,
                 _mask_trans=_mask_trans,
-                _offload_inference=_offload_inference,
             )
             for b in self.blocks
         ]
@@ -677,11 +692,11 @@ class EvoformerStack(nn.Module):
             blocks = [partial(block_with_cache_clear, b) for b in blocks]

         if(chunk_size is not None and self.chunk_size_tuner is not None):
-            print("evo")
+            assert(not self.training)
             tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size(
                 representative_fn=blocks[0],
                 # We don't want to write in-place during chunk tuning runs
-                args=([t.clone() for t in input_tensors],),
+                args=(m.clone(), z.clone(),),
                 min_chunk_size=chunk_size,
             )
             blocks = [
@@ -693,16 +708,43 @@ class EvoformerStack(nn.Module):
             ) for b in blocks
         ]

-        blocks_per_ckpt = self.blocks_per_ckpt
-        if(not torch.is_grad_enabled()):
-            blocks_per_ckpt = None
-
-        m, z = checkpoint_blocks(
-            blocks,
-            args=input_tensors,
-            blocks_per_ckpt=blocks_per_ckpt,
-        )[0]
+        return blocks
+
+    def _forward_offload(self,
+        input_tensors: Sequence[torch.Tensor],
+        msa_mask: torch.Tensor,
+        pair_mask: torch.Tensor,
+        chunk_size: int,
+        use_lma: bool = False,
+        _mask_trans: bool = True,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        assert(not (self.training or torch.is_grad_enabled()))
+        blocks = self._prep_blocks(
+            # We are very careful not to create references to these tensors in
+            # this function
+            m=input_tensors[0],
+            z=input_tensors[1],
+            chunk_size=chunk_size,
+            use_lma=use_lma,
+            msa_mask=msa_mask,
+            pair_mask=pair_mask,
+            inplace_safe=True,
+            _mask_trans=_mask_trans,
+        )
+
+        for b in blocks:
+            m, z = b(
+                None,
+                None,
+                _offload_inference=True,
+                _offloadable_inputs=input_tensors,
+            )
+            input_tensors[0] = m
+            input_tensors[1] = z
+            del m, z
+
+        m, z = input_tensors

         s = self.linear(m[..., 0, :, :])

         return m, z, s
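The comments in `_forward_offload` describe a strict ownership contract: the caller hands over a mutable list and keeps no other references, so each block can move the large pair tensor to CPU and back without a stale reference pinning GPU memory. A hedged sketch of that round trip (`offload_round_trip` is illustrative, not an OpenFold function):

```python
import torch

def offload_round_trip(tensors):
    # tensors is a mutable list [m, z]; the caller must hold no other
    # references, or the device copy of z can never actually be freed
    tensors[1] = tensors[1].cpu()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()    # hand the freed block back to the allocator
    tensors[0] = tensors[0] * 2     # stand-in for the MSA-side work
    tensors[1] = tensors[1].to(tensors[0].device)
    return tensors

m, z = torch.randn(8, 16), torch.randn(16, 16)
inputs = [m, z]
del m, z                            # mirrors the del pattern used in the diff
inputs = offload_round_trip(inputs)
```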
@@ -714,6 +756,7 @@ class EvoformerStack(nn.Module):
         pair_mask: torch.Tensor,
         chunk_size: int,
         use_lma: bool = False,
+        inplace_safe: bool = False,
         _mask_trans: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
@@ -738,15 +781,31 @@ class EvoformerStack(nn.Module):
             s:
                 [*, N_res, C_s] single embedding (or None if extra MSA stack)
         """
-        return self._forward_list(
-            [m, z],
-            msa_mask=msa_mask,
-            pair_mask=pair_mask,
+        blocks = self._prep_blocks(
+            m=m,
+            z=z,
             chunk_size=chunk_size,
             use_lma=use_lma,
+            msa_mask=msa_mask,
+            pair_mask=pair_mask,
+            inplace_safe=inplace_safe,
             _mask_trans=_mask_trans,
         )

+        blocks_per_ckpt = self.blocks_per_ckpt
+        if(not torch.is_grad_enabled()):
+            blocks_per_ckpt = None
+
+        m, z = checkpoint_blocks(
+            blocks,
+            args=(m, z),
+            blocks_per_ckpt=blocks_per_ckpt,
+        )
+
+        s = self.linear(m[..., 0, :, :])
+
+        return m, z, s
+

 class ExtraMSAStack(nn.Module):
     """
@@ -769,7 +828,6 @@ class ExtraMSAStack(nn.Module):
         eps: float,
         ckpt: bool,
         clear_cache_between_blocks: bool = False,
-        chunk_msa_attn: bool = False,
         tune_chunk_size: bool = False,
         **kwargs,
     ):
@@ -777,7 +835,6 @@ class ExtraMSAStack(nn.Module):
         self.ckpt = ckpt
         self.clear_cache_between_blocks = clear_cache_between_blocks
-        self.chunk_msa_attn = chunk_msa_attn

         self.blocks = nn.ModuleList()
         for _ in range(no_blocks):
             block = ExtraMSABlock(
@@ -794,7 +851,7 @@ class ExtraMSAStack(nn.Module):
                 pair_dropout=pair_dropout,
                 inf=inf,
                 eps=eps,
-                ckpt=ckpt if chunk_msa_attn else False,
+                ckpt=False,
             )
             self.blocks.append(block)
@@ -810,6 +867,7 @@ class ExtraMSAStack(nn.Module):
         use_lma: bool,
         msa_mask: Optional[torch.Tensor],
         pair_mask: Optional[torch.Tensor],
+        inplace_safe: bool,
         _mask_trans: bool,
     ):
         blocks = [
@@ -819,6 +877,7 @@ class ExtraMSAStack(nn.Module):
                 pair_mask=pair_mask,
                 chunk_size=chunk_size,
                 use_lma=use_lma,
+                inplace_safe=inplace_safe,
                 _mask_trans=_mask_trans,
             ) for b in self.blocks
         ]
@@ -831,10 +890,12 @@ class ExtraMSAStack(nn.Module):
             blocks = [partial(clear_cache, b) for b in blocks]

         if(chunk_size is not None and self.chunk_size_tuner is not None):
-            print("extra")
             tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size(
                 representative_fn=blocks[0],
-                args=([m.clone(), z.clone()],),
+                # Tensors cloned to avoid getting written to in-place
+                # A corollary is that chunk size tuning should be disabled for
+                # large N, when z gets really big
+                args=(m.clone(), z.clone(),),
                 min_chunk_size=chunk_size,
             )
             blocks = [
@@ -848,16 +909,15 @@ class ExtraMSAStack(nn.Module):
         return blocks

-    def _forward_list(self,
+    def _forward_offload(self,
         input_tensors: Sequence[torch.Tensor],
         chunk_size: int,
         use_lma: bool = False,
         msa_mask: Optional[torch.Tensor] = None,
         pair_mask: Optional[torch.Tensor] = None,
         _mask_trans: bool = True,
-        _offload_inference: bool = False,
     ) -> torch.Tensor:
-        assert(not self.training)
+        assert(not (self.training or torch.is_grad_enabled()))
         blocks = self._prep_blocks(
             # We are very careful not to create references to these tensors in
             # this function
@@ -867,11 +927,17 @@ class ExtraMSAStack(nn.Module):
             use_lma=use_lma,
             msa_mask=msa_mask,
             pair_mask=pair_mask,
+            inplace_safe=True,
             _mask_trans=_mask_trans,
         )

         for b in blocks:
-            m, z = b(input_tensors, _offload_inference=_offload_inference)
+            m, z = b(
+                None,
+                None,
+                _offload_inference=True,
+                _offloadable_inputs=input_tensors,
+            )
             input_tensors[0] = m
             input_tensors[1] = z
             del m, z
@@ -885,6 +951,7 @@ class ExtraMSAStack(nn.Module):
         use_lma: bool = False,
         msa_mask: Optional[torch.Tensor] = None,
         pair_mask: Optional[torch.Tensor] = None,
+        inplace_safe: bool = False,
         _mask_trans: bool = True,
     ) -> torch.Tensor:
         """
@@ -910,12 +977,13 @@ class ExtraMSAStack(nn.Module):
             use_lma=use_lma,
             msa_mask=msa_mask,
             pair_mask=pair_mask,
+            inplace_safe=inplace_safe,
             _mask_trans=_mask_trans,
         )

         for b in blocks:
             if(self.ckpt and torch.is_grad_enabled()):
-                m, z = checkpoint_fn(b, (m, z))
+                m, z = checkpoint_fn(b, m, z)
             else:
                 m, z = b(m, z)
...
@@ -107,19 +107,16 @@ class AlphaFold(nn.Module):
             self.config["heads"],
         )

-    def embed_templates(self, batch, z, pair_mask, templ_dim):
+    def embed_templates(self, batch, z, pair_mask, templ_dim, inplace_safe):
         if(self.template_config.offload_templates):
-            return embed_templates_offload(
-                self, batch, z, pair_mask, templ_dim,
+            return embed_templates_offload(self,
+                batch, z, pair_mask, templ_dim, inplace_safe=inplace_safe,
             )
         elif(self.template_config.average_templates):
-            return embed_templates_average(
-                self, batch, z, pair_mask, templ_dim
+            return embed_templates_average(self,
+                batch, z, pair_mask, templ_dim, inplace_safe=inplace_safe,
             )

-        inplace_safe = not (self.training or torch.is_grad_enabled())
-
         # Embed the templates one at a time (with a poor man's vmap)
         pair_embeds = []
         n = z.shape[-2]
@@ -168,6 +165,7 @@ class AlphaFold(nn.Module):
             pair_mask.unsqueeze(-3).to(dtype=z.dtype),
             chunk_size=self.globals.chunk_size,
             use_lma=self.globals.use_lma,
+            inplace_safe=inplace_safe,
             _mask_trans=self.config._mask_trans,
         )

         del t_pair
@@ -186,6 +184,11 @@ class AlphaFold(nn.Module):
         t = t * (torch.sum(batch["template_mask"], dim=-1) > 0)

         ret = {}
+
+        ret.update({"template_pair_embedding": t})
+
+        del t
+
         if self.config.template.embed_angles:
             template_angle_feat = build_template_angle_feat(
                 batch
@@ -196,10 +199,6 @@ class AlphaFold(nn.Module):
             ret["template_angle_embedding"] = a

-        ret.update({"template_pair_embedding": t})
-
-        del t
-
         return ret

     def iteration(self, feats, prevs, _recycle=True):
@@ -218,6 +217,9 @@ class AlphaFold(nn.Module):
         n = feats["target_feat"].shape[-2]
         n_seq = feats["msa_feat"].shape[-3]
         device = feats["target_feat"].device
+
+        # Controls whether the model uses in-place operations throughout
+        # The dual condition accounts for activation checkpoints
         inplace_safe = not (self.training or torch.is_grad_enabled())

         # Prep some features
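The "dual condition" comment deserves unpacking: the reentrant `torch.utils.checkpoint` runs its forward pass under `no_grad`, so inside a checkpointed block `torch.is_grad_enabled()` is `False` even mid-training. Checking `self.training` as well keeps in-place ops out of that path, where they would corrupt activations that must be recomputed exactly. A small demonstration:

```python
import torch
from torch.utils.checkpoint import checkpoint

def block(x):
    print("grad enabled inside block:", torch.is_grad_enabled())
    return x.exp()

x = torch.randn(4, requires_grad=True)
y = checkpoint(block, x)  # prints False: the forward runs under no_grad
y.sum().backward()        # prints True: the block is re-run to get gradients
```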
@@ -233,10 +235,11 @@ class AlphaFold(nn.Module):
             feats["target_feat"],
             feats["residue_index"],
             feats["msa_feat"],
+            inplace_safe=inplace_safe,
         )

         # Unpack the recycling embeddings. Removing them from the list allows
-        # them to be freed further down in this function.
+        # them to be freed further down in this function, saving memory
         m_1_prev, z_prev, x_prev = reversed([prevs.pop() for _ in range(3)])

         # Initialize the recycling embeddings, if needs be
@@ -263,6 +266,7 @@ class AlphaFold(nn.Module):
             feats["aatype"], x_prev, None
         ).to(dtype=z.dtype)

+        # The recycling embedder is memory-intensive, so we offload first
         if(self.globals.offload_inference and inplace_safe):
             m = m.cpu()
             z = z.cpu()
@@ -273,7 +277,7 @@ class AlphaFold(nn.Module):
             m_1_prev,
             z_prev,
             x_prev,
-            _inplace=inplace_safe,
+            inplace_safe=inplace_safe,
         )

         if(self.globals.offload_inference and inplace_safe):
@@ -286,10 +290,13 @@ class AlphaFold(nn.Module):
         # [*, N, N, C_z]
         z = add(z, z_prev_emb, inplace=inplace_safe)

+        # Deletions like these become significant for inference with large N,
+        # where they free unused tensors and remove references to others such
+        # that they can be offloaded later
         del m_1_prev, z_prev, x_prev, m_1_prev_emb, z_prev_emb

         # Embed the templates + merge with MSA/pair embeddings
         if self.config.template.enabled:
             template_feats = {
                 k: v for k, v in feats.items() if k.startswith("template_")
             }
@@ -298,6 +305,7 @@ class AlphaFold(nn.Module):
                 z,
                 pair_mask.to(dtype=z.dtype),
                 no_batch_dims,
+                inplace_safe=inplace_safe,
             )

             # [*, N, N, C_z]
@@ -306,7 +314,7 @@ class AlphaFold(nn.Module):
                 inplace_safe,
             )

-            if self.config.template.embed_angles:
+            if "template_angle_embedding" in template_embeds:
                 # [*, S = S_c + S_t, N, C_m]
                 m = torch.cat(
                     [m, template_embeds["template_angle_embedding"]],
@@ -324,39 +332,64 @@ class AlphaFold(nn.Module):
         if self.config.extra_msa.enabled:
             # [*, S_e, N, C_e]
             a = self.extra_msa_embedder(build_extra_msa_feat(feats))

-            input_tensors = [a, z]
-            del a, z
-
-            # [*, N, N, C_z]
-            z = self.extra_msa_stack._forward_list(
-                input_tensors,
-                msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
-                chunk_size=self.globals.chunk_size,
-                use_lma=self.globals.use_lma,
-                pair_mask=pair_mask.to(dtype=m.dtype),
-                _mask_trans=self.config._mask_trans,
-                _offload_inference=self.globals.offload_inference,
-            )
-
-            del input_tensors
+            if(self.globals.offload_inference):
+                # To allow the extra MSA stack (and later the evoformer) to
+                # offload its inputs, we remove all references to them here
+                input_tensors = [a, z]
+                del a, z
+
+                # [*, N, N, C_z]
+                z = self.extra_msa_stack._forward_offload(
+                    input_tensors,
+                    msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
+                    chunk_size=self.globals.chunk_size,
+                    use_lma=self.globals.use_lma,
+                    pair_mask=pair_mask.to(dtype=m.dtype),
+                    _mask_trans=self.config._mask_trans,
+                )
+
+                del input_tensors
+            else:
+                # [*, N, N, C_z]
+                z = self.extra_msa_stack(
+                    a, z,
+                    msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
+                    chunk_size=self.globals.chunk_size,
+                    use_lma=self.globals.use_lma,
+                    pair_mask=pair_mask.to(dtype=m.dtype),
+                    inplace_safe=inplace_safe,
+                    _mask_trans=self.config._mask_trans,
+                )

         # Run MSA + pair embeddings through the trunk of the network
         # m: [*, S, N, C_m]
         # z: [*, N, N, C_z]
         # s: [*, N, C_s]
-        input_tensors = [m, z]
-        del m, z
-        m, z, s = self.evoformer._forward_list(
-            input_tensors,
-            msa_mask=msa_mask.to(dtype=input_tensors[0].dtype),
-            pair_mask=pair_mask.to(dtype=input_tensors[1].dtype),
-            chunk_size=self.globals.chunk_size,
-            use_lma=self.globals.use_lma,
-            _mask_trans=self.config._mask_trans,
-        )
-        del input_tensors
+        if(self.globals.offload_inference):
+            input_tensors = [m, z]
+            del m, z
+            m, z, s = self.evoformer._forward_offload(
+                input_tensors,
+                msa_mask=msa_mask.to(dtype=input_tensors[0].dtype),
+                pair_mask=pair_mask.to(dtype=input_tensors[1].dtype),
+                chunk_size=self.globals.chunk_size,
+                use_lma=self.globals.use_lma,
+                _mask_trans=self.config._mask_trans,
+            )
+            del input_tensors
+        else:
+            m, z, s = self.evoformer(
+                m,
+                z,
+                msa_mask=msa_mask.to(dtype=m.dtype),
+                pair_mask=pair_mask.to(dtype=z.dtype),
+                chunk_size=self.globals.chunk_size,
+                use_lma=self.globals.use_lma,
+                inplace_safe=inplace_safe,
+                _mask_trans=self.config._mask_trans,
+            )

         outputs["msa"] = m[..., :n_seq, :, :]
         outputs["pair"] = z
@@ -369,6 +402,7 @@ class AlphaFold(nn.Module):
             outputs,
             feats["aatype"],
             mask=feats["seq_mask"].to(dtype=s.dtype),
+            inplace_safe=inplace_safe,
             _offload_inference=self.globals.offload_inference,
         )
         outputs["final_atom_positions"] = atom14_to_atom37(
...
@@ -117,10 +117,9 @@ class MSAAttention(nn.Module):
     def _prep_inputs(self,
         m: torch.Tensor,
         z: Optional[torch.Tensor],
-        mask: Optional[torch.Tensor]
+        mask: Optional[torch.Tensor],
+        inplace_safe: bool = False,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        _inplace_safe = not (self.training or torch.is_grad_enabled())
-
         n_seq, n_res = m.shape[-3:-1]
         if mask is None:
             # [*, N_seq, N_res]
@@ -163,6 +162,7 @@ class MSAAttention(nn.Module):
         mask: Optional[torch.Tensor],
         chunk_logits: int,
         checkpoint: bool,
+        inplace_safe: bool = False,
     ) -> torch.Tensor:
         """
         MSA attention with training-time chunking of the softmax computation.
@@ -172,7 +172,9 @@ class MSAAttention(nn.Module):
         MSA_DIM = -4

         def _get_qkv(m, z):
-            m, mask_bias, z = self._prep_inputs(m, z, mask)
+            m, mask_bias, z = self._prep_inputs(
+                m, z, mask, inplace_safe=inplace_safe
+            )
             q, k, v = self.mha._prep_qkv(m, m)
             return m, q, k, v, mask_bias, z
@@ -208,6 +210,7 @@ class MSAAttention(nn.Module):
         chunk_size: Optional[int] = None,
         use_memory_efficient_kernel: bool = False,
         use_lma: bool = False,
+        inplace_safe: bool = False,
         _chunk_logits: Optional[int] = None,
         _checkpoint_chunks: Optional[bool] = None,
     ) -> torch.Tensor:
@@ -229,10 +232,14 @@ class MSAAttention(nn.Module):
         if(_chunk_logits is not None):
             return self._chunked_msa_attn(
                 m=m, z=z, mask=mask,
-                chunk_logits=_chunk_logits, checkpoint=_checkpoint_chunks
+                chunk_logits=_chunk_logits,
+                checkpoint=_checkpoint_chunks,
+                inplace_safe=inplace_safe,
             )

-        m, mask_bias, z = self._prep_inputs(m, z, mask)
+        m, mask_bias, z = self._prep_inputs(
+            m, z, mask, inplace_safe=inplace_safe
+        )

         biases = [mask_bias]
         if(z is not None):
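The chunking these signatures keep referring to amounts to mapping the layer over slices of one dimension and concatenating the results, trading wall-clock time for peak activation memory. A generic sketch of the idea (a hypothetical helper, much simpler than OpenFold's `chunk_layer`):

```python
import torch

def chunked(fn, x, chunk_size, dim=0):
    # Apply fn to slices along `dim` so only one slice's activations are live
    slices = torch.split(x, chunk_size, dim=dim)
    return torch.cat([fn(s) for s in slices], dim=dim)

softmax = torch.nn.Softmax(dim=-1)
logits = torch.randn(512, 64, 64)
out = chunked(softmax, logits, chunk_size=128)
assert torch.allclose(out, softmax(logits))
```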
...
@@ -97,7 +97,7 @@ class OuterProductMean(nn.Module):
         m: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
         chunk_size: Optional[int] = None,
-        _inplace: bool = False,
+        inplace_safe: bool = False,
     ) -> torch.Tensor:
         """
         Args:
@@ -137,7 +137,7 @@ class OuterProductMean(nn.Module):
         norm = norm + self.eps

         # [*, N_res, N_res, C_z]
-        if(_inplace):
+        if(inplace_safe):
             outer /= norm
         else:
             outer = outer / norm
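The renamed `inplace_safe` branch here is purely a memory optimization: for a [*, N_res, N_res, C_z] tensor, the out-of-place division materializes a second full-size copy. A quick check that the in-place form reuses storage:

```python
import torch

outer = torch.randn(256, 256, 128)  # scale N_res up and this tensor dominates
norm = torch.full((256, 256, 1), 2.0)

ptr = outer.data_ptr()
outer /= norm                       # same storage, no second allocation
assert outer.data_ptr() == ptr

fresh = outer / norm                # the autograd-friendly path allocates
assert fresh.data_ptr() != ptr
```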
...
@@ -232,6 +232,7 @@ class InvariantPointAttention(nn.Module):
         z: Optional[torch.Tensor],
         r: Rigid,
         mask: torch.Tensor,
+        inplace_safe: bool = False,
         _offload_inference: bool = False,
         _z_reference_list: Optional[Sequence[torch.Tensor]] = None,
     ) -> torch.Tensor:
@@ -248,12 +249,11 @@ class InvariantPointAttention(nn.Module):
         Returns:
             [*, N_res, C_s] single representation update
         """
-        inplace_safe = not (self.training or torch.is_grad_enabled())
-
         if(_offload_inference and inplace_safe):
             z = _z_reference_list
         else:
             z = [z]

         #######################################
         # Generate scalar and point activations
         #######################################
@@ -619,6 +619,7 @@ class StructureModule(nn.Module):
         evoformer_output_dict,
         aatype,
         mask=None,
+        inplace_safe=False,
         _offload_inference=False,
     ):
         """
@@ -674,6 +675,7 @@ class StructureModule(nn.Module):
                 z,
                 rigids,
                 mask,
+                inplace_safe=inplace_safe,
                 _offload_inference=_offload_inference,
                 _z_reference_list=z_reference_list
             )
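The `_z_reference_list` indirection follows the same ownership discipline as the stack-level offloading: the caller wraps `z` in a single-element list and deletes its own handle, so the only live reference is one that can be swapped for a CPU copy. A sketch of the idea, with illustrative names:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

z = torch.randn(64, 64, 16, device=device)
z_ref = [z]
del z                          # no second reference pinning the device copy

z_ref[0] = z_ref[0].cpu()      # offload the full pair tensor
chunk = z_ref[0][..., :8, :].to(device)  # fetch back only the slice in use
```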
...
@@ -201,8 +201,8 @@ class TemplatePairStackBlock(nn.Module):
         mask: torch.Tensor,
         chunk_size: Optional[int] = None,
         use_lma: bool = False,
+        inplace_safe: bool = False,
         _mask_trans: bool = True,
-        _inplace: bool = False,
         _attn_chunk_size: Optional[int] = None,
     ):
         if(_attn_chunk_size is None):
@@ -214,6 +214,7 @@ class TemplatePairStackBlock(nn.Module):
         single_templates_masks = [
             m.unsqueeze(-3) for m in torch.unbind(mask, dim=-3)
         ]
+
         for i in range(len(single_templates)):
             single = single_templates[i]
             single_mask = single_templates_masks[i]
@@ -225,9 +226,10 @@ class TemplatePairStackBlock(nn.Module):
                         chunk_size=_attn_chunk_size,
                         mask=single_mask,
                         use_lma=use_lma,
+                        inplace_safe=inplace_safe,
                     )
                 ),
-                _inplace,
+                inplace_safe,
             )

             single = add(single,
@@ -237,18 +239,19 @@ class TemplatePairStackBlock(nn.Module):
                         chunk_size=_attn_chunk_size,
                         mask=single_mask,
                         use_lma=use_lma,
+                        inplace_safe=inplace_safe,
                     )
                 ),
-                _inplace,
+                inplace_safe,
             )

             tmu_update = self.tri_mul_out(
                 single,
                 mask=single_mask,
-                _inplace=_inplace,
+                inplace_safe=inplace_safe,
                 _add_with_inplace=True,
             )
-            if(not _inplace):
+            if(not inplace_safe):
                 single = single + self.dropout_row(tmu_update)
             else:
                 single = tmu_update
@@ -258,10 +261,10 @@ class TemplatePairStackBlock(nn.Module):
             tmu_update = self.tri_mul_in(
                 single,
                 mask=single_mask,
-                _inplace=_inplace,
+                inplace_safe=inplace_safe,
                 _add_with_inplace=True,
             )
-            if(not _inplace):
+            if(not inplace_safe):
                 single = single + self.dropout_row(tmu_update)
             else:
                 single = tmu_update
@@ -274,13 +277,13 @@ class TemplatePairStackBlock(nn.Module):
                     mask=single_mask if _mask_trans else None,
                     chunk_size=chunk_size,
                 ),
-                _inplace,
+                inplace_safe,
             )

-            if(not _inplace):
+            if(not inplace_safe):
                 single_templates[i] = single

-        if(not _inplace):
+        if(not inplace_safe):
             z = torch.cat(single_templates, dim=-4)

         return z
@@ -352,6 +355,7 @@ class TemplatePairStack(nn.Module):
         mask: torch.tensor,
         chunk_size: int,
         use_lma: bool = False,
+        inplace_safe: bool = False,
         _mask_trans: bool = True,
     ):
         """
@@ -374,13 +378,14 @@ class TemplatePairStack(nn.Module):
                 mask=mask,
                 chunk_size=chunk_size,
                 use_lma=use_lma,
+                inplace_safe=inplace_safe,
                 _mask_trans=_mask_trans,
-                _inplace=not (self.training or torch.is_grad_enabled()),
             )
             for b in self.blocks
         ]

         if(chunk_size is not None and self.chunk_size_tuner is not None):
+            assert(not self.training)
             tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size(
                 representative_fn=blocks[0],
                 args=(t.clone(),),
@@ -411,6 +416,7 @@ def embed_templates_offload(
     pair_mask,
     templ_dim,
     template_chunk_size=256,
+    inplace_safe=False,
 ):
     """
     Args:
@@ -435,8 +441,6 @@ def embed_templates_offload(
     offloads the large template pair tensor to CPU. Slower but more frugal
     with GPU memory than the original. Useful for long-sequence inference.
     """
-    inplace_safe = not (model.training or torch.is_grad_enabled())
-
     # Embed the templates one at a time (with a poor man's vmap)
     pair_embeds_cpu = []
     n = z.shape[-2]
@@ -519,6 +523,7 @@ def embed_templates_average(
     pair_mask,
     templ_dim,
     templ_group_size=2,
+    inplace_safe=False,
 ):
     """
     Args:
@@ -547,8 +552,6 @@ def embed_templates_average(
     embedding, while its low memory footprint allows the number of templates
     to scale almost indefinitely.
     """
-    inplace_safe = not (model.training or torch.is_grad_enabled())
-
     # Embed the templates one at a time (with a poor man's vmap)
     n = z.shape[-2]
     n_templ = batch["template_aatype"].shape[templ_dim]
...
@@ -64,6 +64,7 @@ class TriangleAttention(nn.Module):
         chunk_size: int,
         use_memory_efficient_kernel: bool = False,
         use_lma: bool = False,
+        inplace_safe: bool = False,
     ) -> torch.Tensor:
         "triangle! triangle!"
         mha_inputs = {
@@ -72,8 +73,6 @@ class TriangleAttention(nn.Module):
             "biases": biases,
         }

-        inplace_safe = not (self.training or torch.is_grad_enabled())
-
         return chunk_layer(
             partial(
                 self.mha,
@@ -92,6 +91,7 @@ class TriangleAttention(nn.Module):
         chunk_size: Optional[int] = None,
         use_memory_efficient_kernel: bool = False,
         use_lma: bool = False,
+        inplace_safe: bool = False,
     ) -> torch.Tensor:
         """
         Args:
@@ -130,7 +130,8 @@ class TriangleAttention(nn.Module):
                 biases,
                 chunk_size,
                 use_memory_efficient_kernel=use_memory_efficient_kernel,
-                use_lma=use_lma
+                use_lma=use_lma,
+                inplace_safe=inplace_safe,
             )
         else:
             x = self.mha(
...
@@ -357,7 +357,7 @@ class TriangleMultiplicativeUpdate(nn.Module):
     def forward(self,
         z: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
-        _inplace: bool = False,
+        inplace_safe: bool = False,
         _add_with_inplace: bool = False,
         _inplace_chunk_size: Optional[int] = 256,
     ) -> torch.Tensor:
@@ -370,7 +370,7 @@ class TriangleMultiplicativeUpdate(nn.Module):
         Returns:
             [*, N_res, N_res, C_z] output tensor
         """
-        if(_inplace):
+        if(inplace_safe):
             x = self._inference_forward(
                 z,
                 mask,
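`_inference_forward` with `_inplace_chunk_size` combines the chunking and in-place themes: the update is computed a block of rows at a time and written straight into `z`'s storage, so a second full-size pair tensor never exists. The flavor of it, heavily simplified and hypothetical (a real triangle update needs more than one row-block of context, which the actual implementation arranges for):

```python
import torch

def chunked_inplace_update(z, fn, chunk_size=256):
    # Write each updated row-block straight back into z's storage
    n = z.shape[-3]
    for i in range(0, n, chunk_size):
        z[..., i:i + chunk_size, :, :] = fn(z[..., i:i + chunk_size, :, :])
    return z

z = torch.randn(512, 512, 8)
z = chunked_inplace_update(z, torch.relu, chunk_size=128)
```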
...
@@ -415,8 +415,6 @@ class ChunkSizeTuner:
                 # Otherwise, we can reuse the precomputed value
                 consistent = False

-        print(consistent)
-
         if(not consistent):
             self.cached_chunk_size = self._determine_favorable_chunk_size(
                 representative_fn,
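Closing the loop on the `m.clone(), z.clone()` arguments passed to `tune_chunk_size` earlier in the diff: the tuner times a representative block at increasing chunk sizes, and since that block may now run with `inplace_safe=True`, un-cloned inputs would be mutated during the probe runs. A hypothetical sketch of such a probe loop, not the real `ChunkSizeTuner` API:

```python
import torch

def probe_chunk_sizes(representative_fn, args, candidates=(64, 128, 256, 512)):
    best = candidates[0]
    for c in candidates:
        try:
            # Fresh clones per run: the block may write to its inputs in-place
            representative_fn(*[a.clone() for a in args], chunk_size=c)
            best = c
        except RuntimeError:  # typically CUDA OOM at the larger sizes
            break
    return best
```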
...