"deploy/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "403344e53492bef1c5ba844912b80533e2fffcd7"
Commit e310bba4 authored by Gustaf Ahdritz
Browse files

Improve memory management in extra msa stack

parent cdadff32
...@@ -352,20 +352,30 @@ class ExtraMSABlock(nn.Module): ...@@ -352,20 +352,30 @@ class ExtraMSABlock(nn.Module):
chunk_size: Optional[int] = None, chunk_size: Optional[int] = None,
_chunk_logits: Optional[int] = 1024, _chunk_logits: Optional[int] = 1024,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
m = m + self.msa_dropout_layer( def add(m1, m2):
# The first operation in a checkpoint can't be in-place, but it's
# nice to have in-place addition during inference. Thus...
if(torch.is_grad_enabled()):
m1 = m1 + m2
else:
m1 += m2
return m1
m = add(m, self.msa_dropout_layer(
self.msa_att_row( self.msa_att_row(
m.clone(), m.clone() if torch.is_grad_enabled() else m,
z=z.clone(), z=z.clone() if torch.is_grad_enabled() else z,
mask=msa_mask, mask=msa_mask,
chunk_size=chunk_size, chunk_size=chunk_size,
_chunk_logits=_chunk_logits if torch.is_grad_enabled() else None, _chunk_logits=_chunk_logits if torch.is_grad_enabled() else None,
_checkpoint_chunks= _checkpoint_chunks=
self.ckpt if torch.is_grad_enabled() else False, self.ckpt if torch.is_grad_enabled() else False,
) )
) ))
def fn(m, z): def fn(m, z):
m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size) m = add(m, self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size))
m, z = self.core( m, z = self.core(
m, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size m, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment