Commit 722a5e01 authored by Gustaf Ahdritz

Improve ease of use of LMA

parent 237e26c4
@@ -80,6 +80,7 @@ def model_config(name, train=False, low_prec=False):
     if train:
         c.globals.blocks_per_ckpt = 1
         c.globals.chunk_size = None
+        c.globals.use_lma = False
     if low_prec:
         c.globals.eps = 1e-4
@@ -269,6 +270,7 @@ config = mlc.ConfigDict(
         "globals": {
             "blocks_per_ckpt": blocks_per_ckpt,
             "chunk_size": chunk_size,
+            "use_lma": False,
             "c_z": c_z,
             "c_m": c_m,
             "c_t": c_t,
...
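For context: the new `use_lma` flag is read off `config.globals` by the model at inference time (see the `AlphaFold` hunks further down). A minimal usage sketch, assuming the `model_config` entry point shown above; the preset name is illustrative:

```python
# Sketch: enabling low-memory attention (LMA) at inference time via the new
# config global. The "model_1" preset name is an illustrative assumption.
from openfold.config import model_config

config = model_config("model_1")    # use_lma defaults to False
config.globals.use_lma = True       # opt in to low-memory attention
config.globals.chunk_size = 4       # LMA composes with inference-time chunking
```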
@@ -183,6 +183,7 @@ class EvoformerBlockCore(nn.Module):
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
+        use_lma: bool = False,
         _mask_trans: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # DeepMind doesn't mask these transitions in the source, so _mask_trans
@@ -192,21 +193,31 @@ class EvoformerBlockCore(nn.Module):
         pair_trans_mask = pair_mask if _mask_trans else None
         m = m + self.msa_transition(
-            m, mask=msa_trans_mask, chunk_size=chunk_size
+            m, mask=msa_trans_mask, chunk_size=chunk_size,
         )
         z = z + self.outer_product_mean(
-            m, mask=msa_mask, chunk_size=chunk_size
+            m, mask=msa_mask, chunk_size=chunk_size,
         )
         z = z + self.ps_dropout_row_layer(self.tri_mul_out(z, mask=pair_mask))
         z = z + self.ps_dropout_row_layer(self.tri_mul_in(z, mask=pair_mask))
         z = z + self.ps_dropout_row_layer(
-            self.tri_att_start(z, mask=pair_mask, chunk_size=chunk_size)
+            self.tri_att_start(
+                z,
+                mask=pair_mask,
+                chunk_size=chunk_size,
+                use_lma=use_lma
+            )
         )
         z = z + self.ps_dropout_col_layer(
-            self.tri_att_end(z, mask=pair_mask, chunk_size=chunk_size)
+            self.tri_att_end(
+                z,
+                mask=pair_mask,
+                chunk_size=chunk_size,
+                use_lma=use_lma,
+            )
         )
         z = z + self.pair_transition(
-            z, mask=pair_trans_mask, chunk_size=chunk_size
+            z, mask=pair_trans_mask, chunk_size=chunk_size,
         )
         return m, z
@@ -267,18 +278,31 @@ class EvoformerBlock(nn.Module):
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
+        use_lma: bool = False,
         _mask_trans: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         m = m + self.msa_dropout_layer(
-            self.msa_att_row(m, z=z, mask=msa_mask, chunk_size=chunk_size)
+            self.msa_att_row(
+                m,
+                z=z,
+                mask=msa_mask,
+                chunk_size=chunk_size,
+                use_lma=use_lma,
+            )
         )
-        m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
+        m = m + self.msa_att_col(
+            m,
+            mask=msa_mask,
+            chunk_size=chunk_size,
+            use_lma=use_lma,
+        )
         m, z = self.core(
             m,
             z,
             msa_mask=msa_mask,
             pair_mask=pair_mask,
             chunk_size=chunk_size,
+            use_lma=use_lma,
             _mask_trans=_mask_trans,
         )
@@ -350,7 +374,9 @@ class ExtraMSABlock(nn.Module):
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
+        use_lma: bool = False,
         _chunk_logits: Optional[int] = 1024,
+        _mask_trans: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         def add(m1, m2):
             # The first operation in a checkpoint can't be in-place, but it's
@@ -368,7 +394,8 @@ class ExtraMSABlock(nn.Module):
             z=z.clone() if torch.is_grad_enabled() else z,
             mask=msa_mask,
             chunk_size=chunk_size,
-            use_memory_efficient_kernel=not _chunk_logits,
+            use_lma=use_lma,
+            use_memory_efficient_kernel=not _chunk_logits and not use_lma,
             _chunk_logits=_chunk_logits if torch.is_grad_enabled() else None,
             _checkpoint_chunks=
                 self.ckpt if torch.is_grad_enabled() else False,
@@ -376,9 +403,23 @@ class ExtraMSABlock(nn.Module):
         ))
         def fn(m, z):
-            m = add(m, self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size))
+            m = add(
+                m,
+                self.msa_att_col(
+                    m,
+                    mask=msa_mask,
+                    chunk_size=chunk_size,
+                    use_lma=use_lma,
+                )
+            )
             m, z = self.core(
-                m, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size
+                m,
+                z,
+                msa_mask=msa_mask,
+                pair_mask=pair_mask,
+                chunk_size=chunk_size,
+                use_lma=use_lma,
+                _mask_trans=_mask_trans,
             )
             return m, z
@@ -488,6 +529,7 @@ class EvoformerStack(nn.Module):
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: int,
+        use_lma: bool = False,
         _mask_trans: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
         """
@@ -500,6 +542,8 @@ class EvoformerStack(nn.Module):
                 [*, N_seq, N_res] MSA mask
             pair_mask:
                 [*, N_res, N_res] pair mask
+            chunk_size: Inference-time subbatch size
+            use_lma: Whether to use low-memory attention during inference
         Returns:
             m:
                 [*, N_seq, N_res, C_m] MSA embedding
@@ -514,6 +558,7 @@ class EvoformerStack(nn.Module):
                 msa_mask=msa_mask,
                 pair_mask=pair_mask,
                 chunk_size=chunk_size,
+                use_lma=use_lma,
                 _mask_trans=_mask_trans,
             )
             for b in self.blocks
@@ -591,6 +636,7 @@ class ExtraMSAStack(nn.Module):
         m: torch.Tensor,
         z: torch.Tensor,
         chunk_size: int,
+        use_lma: bool = False,
         msa_mask: Optional[torch.Tensor] = None,
         pair_mask: Optional[torch.Tensor] = None,
         _mask_trans: bool = True,
@@ -601,6 +647,8 @@ class ExtraMSAStack(nn.Module):
                 [*, N_extra, N_res, C_m] extra MSA embedding
             z:
                 [*, N_res, N_res, C_z] pair embedding
+            chunk_size: Inference-time subbatch size for Evoformer modules
+            use_lma: Whether to use low-memory attention during inference
             msa_mask:
                 Optional [*, N_extra, N_res] MSA mask
             pair_mask:
@@ -616,7 +664,9 @@ class ExtraMSAStack(nn.Module):
                 msa_mask=msa_mask,
                 pair_mask=pair_mask,
                 chunk_size=chunk_size,
-                _chunk_logits=None
+                use_lma=use_lma,
+                _chunk_logits=None,
+                _mask_trans=_mask_trans,
             ) for b in self.blocks
         ]
@@ -634,7 +684,15 @@ class ExtraMSAStack(nn.Module):
                 m, z = b(m, z)
         else:
             for b in self.blocks:
-                m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
+                m, z = b(
+                    m,
+                    z,
+                    msa_mask,
+                    pair_mask,
+                    chunk_size=chunk_size,
+                    use_lma=use_lma,
+                    _mask_trans=_mask_trans
+                )
                 if(self.clear_cache_between_blocks):
                     torch.cuda.empty_cache()
...
@@ -152,6 +152,7 @@ class AlphaFold(nn.Module):
                 template_embeds["pair"],
                 pair_mask.unsqueeze(-3).to(dtype=z.dtype),
                 chunk_size=self.globals.chunk_size,
+                use_lma=self.globals.use_lma,
                 _mask_trans=self.config._mask_trans,
             )
@@ -161,6 +162,7 @@ class AlphaFold(nn.Module):
                 z,
                 template_mask=batch["template_mask"].to(dtype=z.dtype),
                 chunk_size=self.globals.chunk_size,
+                use_lma=self.globals.use_lma,
             )
             t = t * (torch.sum(batch["template_mask"]) > 0)
@@ -294,6 +296,7 @@ class AlphaFold(nn.Module):
                 z,
                 msa_mask=feats["extra_msa_mask"].to(dtype=a.dtype),
                 chunk_size=self.globals.chunk_size,
+                use_lma=self.globals.use_lma,
                 pair_mask=pair_mask.to(dtype=z.dtype),
                 _mask_trans=self.config._mask_trans,
             )
@@ -308,6 +311,7 @@ class AlphaFold(nn.Module):
             msa_mask=msa_mask.to(dtype=m.dtype),
             pair_mask=pair_mask.to(dtype=z.dtype),
             chunk_size=self.globals.chunk_size,
+            use_lma=self.globals.use_lma,
             _mask_trans=self.config._mask_trans,
         )
...
@@ -90,12 +90,14 @@ class MSAAttention(nn.Module):
     def _chunk(self,
         m: torch.Tensor,
         biases: List[torch.Tensor],
-        use_memory_efficient_kernel: bool,
         chunk_size: int,
+        use_memory_efficient_kernel: bool,
+        use_lma: bool,
     ) -> torch.Tensor:
         mha = partial(
             self.mha,
-            use_memory_efficient_kernel=use_memory_efficient_kernel
+            use_memory_efficient_kernel=use_memory_efficient_kernel,
+            use_lma=use_lma,
         )
         return chunk_layer(
             mha,
@@ -193,6 +195,7 @@ class MSAAttention(nn.Module):
         mask: Optional[torch.Tensor] = None,
         chunk_size: Optional[int] = None,
         use_memory_efficient_kernel: bool = False,
+        use_lma: bool = False,
         _chunk_logits: Optional[int] = None,
         _checkpoint_chunks: Optional[bool] = None,
     ) -> torch.Tensor:
@@ -224,13 +227,20 @@ class MSAAttention(nn.Module):
             biases.append(z)
         if chunk_size is not None:
-            m = self._chunk(m, biases, use_memory_efficient_kernel, chunk_size)
+            m = self._chunk(
+                m,
+                biases,
+                chunk_size,
+                use_memory_efficient_kernel=use_memory_efficient_kernel,
+                use_lma=use_lma,
+            )
         else:
             m = self.mha(
                 q_x=m,
                 kv_x=m,
                 biases=biases,
                 use_memory_efficient_kernel=use_memory_efficient_kernel,
+                use_lma=use_lma,
             )
         return m
@@ -305,7 +315,7 @@ class MSAColumnAttention(nn.Module):
         m: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
         chunk_size: Optional[int] = None,
-        use_memory_efficient_kernel: bool = False,
+        use_lma: bool = False,
     ) -> torch.Tensor:
         """
         Args:
@@ -323,7 +333,7 @@ class MSAColumnAttention(nn.Module):
         if mask is not None:
             mask = mask.transpose(-1, -2)
-        m = self._msa_att(m, mask=mask, chunk_size=chunk_size)
+        m = self._msa_att(m, mask=mask, chunk_size=chunk_size, use_lma=use_lma)
         # [*, N_seq, N_res, C_in]
         m = m.transpose(-2, -3)
@@ -360,13 +370,14 @@ class MSAColumnGlobalAttention(nn.Module):
         m: torch.Tensor,
         mask: torch.Tensor,
         chunk_size: int,
+        use_lma: bool = False,
     ) -> torch.Tensor:
         mha_input = {
             "m": m,
             "mask": mask,
         }
         return chunk_layer(
-            self.global_attention,
+            partial(self.global_attention, use_lma=use_lma),
             mha_input,
             chunk_size=chunk_size,
             no_batch_dims=len(m.shape[:-2]),
@@ -377,6 +388,7 @@ class MSAColumnGlobalAttention(nn.Module):
         m: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
         chunk_size: Optional[int] = None,
+        use_lma: bool = False,
     ) -> torch.Tensor:
         n_seq, n_res, c_in = m.shape[-3:]
@@ -396,9 +408,9 @@ class MSAColumnGlobalAttention(nn.Module):
         m = self.layer_norm_m(m)
         if chunk_size is not None:
-            m = self._chunk(m, mask, chunk_size)
+            m = self._chunk(m, mask, chunk_size, use_lma=use_lma)
         else:
-            m = self.global_attention(m=m, mask=mask)
+            m = self.global_attention(m=m, mask=mask, use_lma=use_lma)
         # [*, N_seq, N_res, C_in]
         m = m.transpose(-2, -3)
...
@@ -31,6 +31,10 @@ from openfold.utils.tensor_utils import (
 )
+DEFAULT_LMA_Q_CHUNK_SIZE=1024
+DEFAULT_LMA_KV_CHUNK_SIZE=4096
 def _prod(nums):
     out = 1
     for n in nums:
@@ -403,8 +407,8 @@ class Attention(nn.Module):
         biases: Optional[List[torch.Tensor]] = None,
         use_memory_efficient_kernel: bool = False,
         use_lma: bool = False,
-        q_chunk_size: Optional[int] = None,
-        kv_chunk_size: Optional[int] = None,
+        q_chunk_size: int = DEFAULT_LMA_Q_CHUNK_SIZE,
+        kv_chunk_size: int = DEFAULT_LMA_KV_CHUNK_SIZE,
     ) -> torch.Tensor:
         """
         Args:
@@ -460,6 +464,7 @@ class Attention(nn.Module):
                 for b in biases
             ]
             o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size)
+            o = o.transpose(-2, -3)
         else:
             o = _attention(q, k, v, biases)
             o = o.transpose(-2, -3)
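With the new defaults, callers that enable LMA no longer have to supply chunk sizes. A hedged usage sketch for `Attention`, assuming the constructor argument order used in the test at the bottom of this diff (`c_q, c_k, c_v, c_hidden, no_heads`) and a bias of the same `[no_heads, 1, n]` shape:

```python
# Sketch: use_lma=True now falls back to DEFAULT_LMA_Q_CHUNK_SIZE /
# DEFAULT_LMA_KV_CHUNK_SIZE unless chunk sizes are overridden per call.
import torch
from openfold.model.primitives import Attention

c_hidden, no_heads, n = 32, 4, 256
attn = Attention(c_hidden, c_hidden, c_hidden, c_hidden, no_heads)
x = torch.rand(1, n, c_hidden)
bias = [torch.rand(no_heads, 1, n)]

out_lma = attn(q_x=x, kv_x=x, biases=bias, use_lma=True)    # default chunks
out_lma_small = attn(q_x=x, kv_x=x, biases=bias, use_lma=True,
                     q_chunk_size=64, kv_chunk_size=128)     # per-call override
out_ref = attn(q_x=x, kv_x=x, biases=bias)                   # standard attention
```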
@@ -494,7 +499,11 @@ class GlobalAttention(nn.Module):
         self.sigmoid = nn.Sigmoid()
-    def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+    def forward(self,
+        m: torch.Tensor,
+        mask: torch.Tensor,
+        use_lma: bool = False,
+    ) -> torch.Tensor:
         # [*, N_res, C_in]
         q = torch.sum(m * mask.unsqueeze(-1), dim=-2) / (
             torch.sum(mask, dim=-1)[..., None] + self.eps
@@ -511,12 +520,13 @@ class GlobalAttention(nn.Module):
         k = self.linear_k(m)
         v = self.linear_v(m)
+        bias = (self.inf * (mask - 1))[..., :, None, :]
+        if(not use_lma):
             # [*, N_res, H, N_seq]
             a = torch.matmul(
                 q,
                 k.transpose(-1, -2),  # [*, N_res, C_hidden, N_seq]
             )
-            bias = (self.inf * (mask - 1))[..., :, None, :]
             a += bias
             a = softmax_no_cast(a)
@@ -525,6 +535,15 @@ class GlobalAttention(nn.Module):
                 a,
                 v,
             )
+        else:
+            o = _lma(
+                q,
+                k,
+                v,
+                [bias],
+                DEFAULT_LMA_Q_CHUNK_SIZE,
+                DEFAULT_LMA_KV_CHUNK_SIZE
+            )
         # [*, N_res, N_seq, C_hidden]
         g = self.sigmoid(self.linear_g(m))
@@ -552,12 +571,12 @@ def _lma(
     q_chunk_size: int,
     kv_chunk_size: int,
 ):
-    no_q, no_kv = q.shape[-3], k.shape[-3]
+    no_q, no_kv = q.shape[-2], k.shape[-2]
-    # [*, Q, H, C_hidden]
+    # [*, H, Q, C_hidden]
     o = q.new_zeros(q.shape)
     for q_s in range(0, no_q, q_chunk_size):
-        q_chunk = q[..., q_s: q_s + q_chunk_size, :, :]
+        q_chunk = q[..., q_s: q_s + q_chunk_size, :]
         large_bias_chunks = [
             b[..., q_s: q_s + q_chunk_size, :] for b in biases
         ]
@@ -566,24 +585,22 @@ def _lma(
         weights = []
         values = []
         for kv_s in range(0, no_kv, kv_chunk_size):
-            k_chunk = k[..., kv_s: kv_s + kv_chunk_size, :, :]
-            v_chunk = v[..., kv_s: kv_s + kv_chunk_size, :, :]
+            k_chunk = k[..., kv_s: kv_s + kv_chunk_size, :]
+            v_chunk = v[..., kv_s: kv_s + kv_chunk_size, :]
             small_bias_chunks = [
                 b[..., kv_s: kv_s + kv_chunk_size] for b in large_bias_chunks
             ]
             a = torch.einsum(
-                "...qhd,...khd->...hqk", q_chunk, k_chunk,
+                "...hqd,...hkd->...hqk", q_chunk, k_chunk,
             )
             for b in small_bias_chunks:
                 a += b
-            a = a.transpose(-2, -3)
             max_a = torch.max(a, dim=-1, keepdim=True)[0]
             exp_a = torch.exp(a - max_a)
-            exp_v = torch.einsum("...vhf,...qhv->...qhf", v_chunk, exp_a)
+            exp_v = torch.einsum("...hvf,...hqv->...hqf", v_chunk, exp_a)
             maxes.append(max_a.detach().squeeze(-1))
             weights.append(torch.sum(exp_a, dim=-1))
@@ -595,14 +612,14 @@ def _lma(
         global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0]
         max_diffs = torch.exp(chunk_max - global_max)
-        chunk_values *= max_diffs.unsqueeze(-1)
-        chunk_weights *= max_diffs
+        chunk_values = chunk_values * max_diffs.unsqueeze(-1)
+        chunk_weights = chunk_weights * max_diffs
         all_values = torch.sum(chunk_values, dim=-4)
         all_weights = torch.sum(chunk_weights.unsqueeze(-1), dim=-4)
         q_chunk_out = all_values / all_weights
-        o[..., q_s: q_s + q_chunk_size, :, :] = q_chunk_out
+        o[..., q_s: q_s + q_chunk_size, :] = q_chunk_out
     return o
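The rewritten `_lma` now works directly on the [*, H, N, C_hidden] layout its caller produces (hence the removed per-chunk transpose and the extra transpose added in `Attention.forward` above), accumulating a numerically stable softmax over key/value chunks. Below is a self-contained sketch of the same chunking idea, written with a running (online) update instead of the stack-then-renormalize form used by `_lma`; names and chunk sizes are illustrative, not OpenFold API:

```python
# Illustrative chunked attention with an online softmax update. q is assumed
# to be pre-scaled by 1/sqrt(c_hidden), as it is before _lma is called.
import torch

def chunked_attention(q, k, v, bias, q_chunk=128, kv_chunk=256):
    # q, k, v: [*, H, N, C]; bias broadcastable to [*, H, N_q, N_kv]
    no_q, no_kv = q.shape[-2], k.shape[-2]
    out = torch.zeros_like(q)
    for qs in range(0, no_q, q_chunk):
        qc = q[..., qs:qs + q_chunk, :]
        bc = bias[..., qs:qs + q_chunk, :]
        running_max = qc.new_full(qc.shape[:-1] + (1,), float("-inf"))
        denom = torch.zeros_like(running_max)    # running softmax denominator
        acc = torch.zeros_like(qc)               # running weighted values
        for ks in range(0, no_kv, kv_chunk):
            kc = k[..., ks:ks + kv_chunk, :]
            vc = v[..., ks:ks + kv_chunk, :]
            s = torch.einsum("...qd,...kd->...qk", qc, kc)
            s = s + bc[..., ks:ks + kv_chunk]
            new_max = torch.maximum(running_max, s.max(dim=-1, keepdim=True)[0])
            p = torch.exp(s - new_max)                   # unnormalized weights
            rescale = torch.exp(running_max - new_max)   # rescale old partials
            denom = denom * rescale + p.sum(dim=-1, keepdim=True)
            acc = acc * rescale + torch.einsum("...qk,...kd->...qd", p, vc)
            running_max = new_max
        out[..., qs:qs + q_chunk, :] = acc / denom
    return out

# Matches dense attention up to numerical error:
q = torch.rand(2, 4, 300, 16)
k = torch.rand(2, 4, 512, 16)
v = torch.rand(2, 4, 512, 16)
bias = torch.zeros(2, 4, 300, 512)
ref = torch.softmax(torch.einsum("...qd,...kd->...qk", q, k) + bias, -1) @ v
assert torch.allclose(chunked_attention(q, k, v, bias), ref, atol=1e-5)
```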
@@ -77,6 +77,7 @@ class TemplatePointwiseAttention(nn.Module):
         t: torch.Tensor,
         biases: List[torch.Tensor],
         chunk_size: int,
+        use_lma: bool = False,
     ) -> torch.Tensor:
         mha_inputs = {
             "q_x": z,
@@ -84,7 +85,7 @@ class TemplatePointwiseAttention(nn.Module):
             "biases": biases,
         }
         return chunk_layer(
-            self.mha,
+            partial(self.mha, use_lma=use_lma),
             mha_inputs,
             chunk_size=chunk_size,
             no_batch_dims=len(z.shape[:-2]),
@@ -95,7 +96,8 @@ class TemplatePointwiseAttention(nn.Module):
         t: torch.Tensor,
         z: torch.Tensor,
         template_mask: Optional[torch.Tensor] = None,
-        chunk_size: Optional[int] = None
+        chunk_size: Optional[int] = None,
+        use_lma: bool = False,
     ) -> torch.Tensor:
         """
         Args:
@@ -122,9 +124,9 @@ class TemplatePointwiseAttention(nn.Module):
         # [*, N_res, N_res, 1, C_z]
         biases = [bias]
         if chunk_size is not None:
-            z = self._chunk(z, t, biases, chunk_size)
+            z = self._chunk(z, t, biases, chunk_size, use_lma=use_lma)
         else:
-            z = self.mha(q_x=z, kv_x=t, biases=biases)
+            z = self.mha(q_x=z, kv_x=t, biases=biases, use_lma=use_lma)
         # [*, N_res, N_res, C_z]
         z = z.squeeze(-2)
@@ -188,6 +190,7 @@ class TemplatePairStackBlock(nn.Module):
         z: torch.Tensor,
         mask: torch.Tensor,
         chunk_size: Optional[int] = None,
+        use_lma: bool = False,
         _mask_trans: bool = True
     ):
         single_templates = [
@@ -204,14 +207,16 @@ class TemplatePairStackBlock(nn.Module):
                 self.tri_att_start(
                     single,
                     chunk_size=chunk_size,
-                    mask=single_mask
+                    mask=single_mask,
+                    use_lma=use_lma,
                 )
             )
             single = single + self.dropout_col(
                 self.tri_att_end(
                     single,
                     chunk_size=chunk_size,
-                    mask=single_mask
+                    mask=single_mask,
+                    use_lma=use_lma,
                 )
             )
             single = single + self.dropout_row(
@@ -298,6 +303,7 @@ class TemplatePairStack(nn.Module):
         t: torch.tensor,
         mask: torch.tensor,
         chunk_size: int,
+        use_lma: bool = False,
         _mask_trans: bool = True,
     ):
         """
@@ -320,6 +326,7 @@ class TemplatePairStack(nn.Module):
                 b,
                 mask=mask,
                 chunk_size=chunk_size,
+                use_lma=use_lma,
                 _mask_trans=_mask_trans,
             )
             for b in self.blocks
...
@@ -62,6 +62,7 @@ class TriangleAttention(nn.Module):
         x: torch.Tensor,
         biases: List[torch.Tensor],
         chunk_size: int,
+        use_lma: bool = False,
     ) -> torch.Tensor:
         mha_inputs = {
             "q_x": x,
@@ -69,7 +70,7 @@ class TriangleAttention(nn.Module):
             "biases": biases,
         }
         return chunk_layer(
-            partial(self.mha),
+            partial(self.mha, use_lma=use_lma),
             mha_inputs,
             chunk_size=chunk_size,
             no_batch_dims=len(x.shape[:-2]),
@@ -78,7 +79,8 @@ class TriangleAttention(nn.Module):
     def forward(self,
         x: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
-        chunk_size: Optional[int] = None
+        chunk_size: Optional[int] = None,
+        use_lma: bool = False,
     ) -> torch.Tensor:
         """
         Args:
@@ -113,9 +115,9 @@ class TriangleAttention(nn.Module):
         biases = [mask_bias, triangle_bias]
         if chunk_size is not None:
-            x = self._chunk(x, biases, chunk_size)
+            x = self._chunk(x, biases, chunk_size, use_lma=use_lma)
         else:
-            x = self.mha(q_x=x, kv_x=x, biases=biases)
+            x = self.mha(q_x=x, kv_x=x, biases=biases, use_lma=use_lma)
         if not self.starting:
             x = x.transpose(-2, -3)
...
@@ -15,6 +15,7 @@
 import argparse
 from datetime import date
+import gc
 import logging
 import numpy as np
 import os
@@ -76,8 +77,9 @@ def main(args):
     else:
         alignment_dir = args.use_precomputed_alignments
+    for fasta_file in os.listdir(args.fasta_dir):
         # Gather input sequences
-    with open(args.fasta_path, "r") as fp:
+        with open(os.path.join(args.fasta_dir, fasta_file), "r") as fp:
             data = fp.read()
         lines = [
@@ -86,8 +88,10 @@ def main(args):
         ][1:]
         tags, seqs = lines[::2], lines[1::2]
-    for tag, seq in zip(tags, seqs):
-        fasta_path = os.path.join(args.output_dir, "tmp.fasta")
+        assert len(seqs) == 1, "Input FASTAs may only contain one sequence"
+        tag, seq = tags[0], seqs[0]
+        fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
         with open(fasta_path, "w") as fp:
             fp.write(f">{tag}\n{seq}")
@@ -160,6 +164,7 @@ def main(args):
         with open(unrelaxed_output_path, 'w') as f:
             f.write(protein.to_pdb(unrelaxed_protein))
+        if(not args.skip_relaxation):
             amber_relaxer = relax.AmberRelaxation(
                 use_gpu=(args.model_device != "cpu"),
                 **config.relax,
@@ -193,7 +198,8 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "fasta_path", type=str,
+        "fasta_dir", type=str,
+        help="Path to directory containing FASTA files, one sequence per file"
     )
     parser.add_argument(
         "template_mmcif_dir", type=str,
@@ -224,7 +230,7 @@ if __name__ == "__main__":
             openfold/resources/params"""
     )
     parser.add_argument(
-        "--save_outputs", type=bool, default=False,
+        "--save_outputs", action="store_true", default=False,
        help="Whether to save all model outputs, including embeddings, etc."
     )
     parser.add_argument(
@@ -232,11 +238,14 @@ if __name__ == "__main__":
         help="""Number of CPUs with which to run alignment tools"""
     )
     parser.add_argument(
-        '--preset', type=str, default='full_dbs',
+        "--preset", type=str, default='full_dbs',
         choices=('reduced_dbs', 'full_dbs')
     )
     parser.add_argument(
-        '--data_random_seed', type=str, default=None
+        "--data_random_seed", type=str, default=None
+    )
+    parser.add_argument(
+        "--skip_relaxation", action="store_true", default=False,
     )
     add_data_args(parser)
     args = parser.parse_args()
...
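The script now takes a directory of FASTA files, each containing exactly one sequence. A hedged helper sketch (not part of this commit; the file-naming scheme is an arbitrary choice) for splitting a multi-sequence FASTA into that layout:

```python
# Split a multi-sequence FASTA into one-sequence-per-file form, matching the
# one-sequence-per-file contract asserted in the hunks above.
import os

def split_fasta(fasta_path: str, out_dir: str) -> None:
    os.makedirs(out_dir, exist_ok=True)
    with open(fasta_path) as fp:
        entries = [e for e in fp.read().split(">") if e.strip()]
    for entry in entries:
        tag, *seq_lines = entry.strip().splitlines()
        tag = tag.split()[0]
        with open(os.path.join(out_dir, f"{tag}.fasta"), "w") as out:
            out.write(f">{tag}\n{''.join(seq_lines)}\n")
```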
@@ -18,7 +18,6 @@ import unittest
 from openfold.model.primitives import (
     Attention,
-    LowMemoryAttention,
 )
 from tests.config import consts
@@ -31,8 +30,7 @@ class TestLMA(unittest.TestCase):
         no_heads = 4
         q = torch.rand(batch_size, n, c_hidden).cuda()
-        k = torch.rand(batch_size, n, c_hidden).cuda()
-        v = torch.rand(batch_size, n, c_hidden).cuda()
+        kv = torch.rand(batch_size, n, c_hidden).cuda()
         bias = [torch.rand(no_heads, 1, n)]
         bias = [b.cuda() for b in bias]
@@ -40,28 +38,13 @@ class TestLMA(unittest.TestCase):
         gating_fill = torch.rand(c_hidden * no_heads, c_hidden)
         o_fill = torch.rand(c_hidden, c_hidden * no_heads)
-        lma = LowMemoryAttention(
-            c_hidden, c_hidden, c_hidden, c_hidden, no_heads
-        ).cuda()
         a = Attention(
             c_hidden, c_hidden, c_hidden, c_hidden, no_heads
         ).cuda()
         with torch.no_grad():
-            for n, p in lma.named_parameters():
-                attrs = n.split('.')
-                param = a
-                for attr in attrs:
-                    param = getattr(param, attr)
-                param.copy_(p)
-            for m in [lma, a]:
-                m.linear_g.weight.copy_(gating_fill)
-                m.linear_o.weight.copy_(o_fill)
-        with torch.no_grad():
-            l = lma(q, k, v, 1024, 4096, biases=bias)
-            real = a(q, k, v, biases=bias)
+            l = a(q, kv, biases=bias, use_lma=True)
+            real = a(q, kv, biases=bias)
         self.assertTrue(torch.max(torch.abs(l - real)) < consts.eps)
...