Commit e310bba4 authored by Gustaf Ahdritz

Improve memory management in extra msa stack

parent cdadff32
@@ -352,20 +352,30 @@ class ExtraMSABlock(nn.Module):
         chunk_size: Optional[int] = None,
         _chunk_logits: Optional[int] = 1024,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        m = m + self.msa_dropout_layer(
+        def add(m1, m2):
+            # The first operation in a checkpoint can't be in-place, but it's
+            # nice to have in-place addition during inference. Thus...
+            if(torch.is_grad_enabled()):
+                m1 = m1 + m2
+            else:
+                m1 += m2
+            return m1
+
+        m = add(m, self.msa_dropout_layer(
             self.msa_att_row(
-                m.clone(),
-                z=z.clone(),
+                m.clone() if torch.is_grad_enabled() else m,
+                z=z.clone() if torch.is_grad_enabled() else z,
                 mask=msa_mask,
                 chunk_size=chunk_size,
                 _chunk_logits=_chunk_logits if torch.is_grad_enabled() else None,
+                _checkpoint_chunks=
+                    self.ckpt if torch.is_grad_enabled() else False,
             )
-        )
+        ))

         def fn(m, z):
-            m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
+            m = add(m, self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size))
             m, z = self.core(
                 m, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size
             )
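The key change is the `add` helper: while gradients are being tracked, the residual update must stay out-of-place so that checkpointing can replay the block from unmodified inputs, but during inference an in-place `+=` avoids allocating a second copy of the large extra-MSA activation. A minimal standalone sketch of the pattern, outside any OpenFold machinery (tensor shapes are arbitrary, chosen for illustration):

    import torch

    def add(m1: torch.Tensor, m2: torch.Tensor) -> torch.Tensor:
        # Out-of-place when autograd is tracking: checkpointed code is
        # replayed from its saved inputs, so those inputs must not be
        # mutated. In-place during inference: reuses m1's storage,
        # saving one allocation of the full activation.
        if torch.is_grad_enabled():
            m1 = m1 + m2
        else:
            m1 += m2
        return m1

    # Training mode: a fresh tensor is returned.
    m = torch.randn(4, 8, requires_grad=True)
    out = add(m, torch.ones(4, 8))
    assert out.data_ptr() != m.data_ptr()

    # Inference mode: the addition lands in m's own storage.
    with torch.no_grad():
        m = torch.randn(4, 8)
        before = m.data_ptr()
        m = add(m, torch.ones(4, 8))
        assert m.data_ptr() == before  # no new allocation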
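Relatedly, the `.clone()` guards on `m` and `z` are now applied only under grad mode, for the same reason the in-code comment gives: a checkpoint stashes its tensor inputs so it can replay the function during backward, so tensors handed to a checkpointed block should not be mutated in-place afterward. A hedged sketch of that calling pattern using torch.utils.checkpoint (the `block` function is a hypothetical stand-in, not the real row attention):

    import torch
    from torch.utils.checkpoint import checkpoint

    def block(m, z):
        # Hypothetical stand-in for a checkpointed attention update.
        return m + z.mean(dim=-1)

    m = torch.randn(8, 16, requires_grad=True)
    z = torch.randn(8, 16, 16, requires_grad=True)

    # Cloning decouples the checkpoint's saved inputs from the activations
    # the surrounding graph still holds; skipping the clone at inference
    # time, as the diff does, saves that copy when no replay will happen.
    m_in = m.clone() if torch.is_grad_enabled() else m
    z_in = z.clone() if torch.is_grad_enabled() else z
    out = checkpoint(block, m_in, z_in)
    out.sum().backward()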