Commit 722a5e01 authored by Gustaf Ahdritz

Improve ease of use of LMA (low-memory attention)

parent 237e26c4
@@ -80,6 +80,7 @@ def model_config(name, train=False, low_prec=False):
if train:
c.globals.blocks_per_ckpt = 1
c.globals.chunk_size = None
+ c.globals.use_lma = False
if low_prec:
c.globals.eps = 1e-4
@@ -269,6 +270,7 @@ config = mlc.ConfigDict(
"globals": {
"blocks_per_ckpt": blocks_per_ckpt,
"chunk_size": chunk_size,
+ "use_lma": False,
"c_z": c_z,
"c_m": c_m,
"c_t": c_t,
......
@@ -183,6 +183,7 @@ class EvoformerBlockCore(nn.Module):
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
+ use_lma: bool = False,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
# DeepMind doesn't mask these transitions in the source, so _mask_trans
@@ -192,21 +193,31 @@
pair_trans_mask = pair_mask if _mask_trans else None
m = m + self.msa_transition(
- m, mask=msa_trans_mask, chunk_size=chunk_size
+ m, mask=msa_trans_mask, chunk_size=chunk_size,
)
z = z + self.outer_product_mean(
- m, mask=msa_mask, chunk_size=chunk_size
+ m, mask=msa_mask, chunk_size=chunk_size,
)
z = z + self.ps_dropout_row_layer(self.tri_mul_out(z, mask=pair_mask))
z = z + self.ps_dropout_row_layer(self.tri_mul_in(z, mask=pair_mask))
z = z + self.ps_dropout_row_layer(
- self.tri_att_start(z, mask=pair_mask, chunk_size=chunk_size)
+ self.tri_att_start(
+ z,
+ mask=pair_mask,
+ chunk_size=chunk_size,
+ use_lma=use_lma
+ )
)
z = z + self.ps_dropout_col_layer(
- self.tri_att_end(z, mask=pair_mask, chunk_size=chunk_size)
+ self.tri_att_end(
+ z,
+ mask=pair_mask,
+ chunk_size=chunk_size,
+ use_lma=use_lma,
+ )
)
z = z + self.pair_transition(
- z, mask=pair_trans_mask, chunk_size=chunk_size
+ z, mask=pair_trans_mask, chunk_size=chunk_size,
)
return m, z
@@ -267,18 +278,31 @@ class EvoformerBlock(nn.Module):
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
+ use_lma: bool = False,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
m = m + self.msa_dropout_layer(
- self.msa_att_row(m, z=z, mask=msa_mask, chunk_size=chunk_size)
+ self.msa_att_row(
+ m,
+ z=z,
+ mask=msa_mask,
+ chunk_size=chunk_size,
+ use_lma=use_lma,
+ )
)
- m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
+ m = m + self.msa_att_col(
+ m,
+ mask=msa_mask,
+ chunk_size=chunk_size,
+ use_lma=use_lma,
+ )
m, z = self.core(
m,
z,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
+ use_lma=use_lma,
_mask_trans=_mask_trans,
)
@@ -350,7 +374,9 @@ class ExtraMSABlock(nn.Module):
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
+ use_lma: bool = False,
_chunk_logits: Optional[int] = 1024,
+ _mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
def add(m1, m2):
# The first operation in a checkpoint can't be in-place, but it's
@@ -368,7 +394,8 @@
z=z.clone() if torch.is_grad_enabled() else z,
mask=msa_mask,
chunk_size=chunk_size,
- use_memory_efficient_kernel=not _chunk_logits,
+ use_lma=use_lma,
+ use_memory_efficient_kernel=not _chunk_logits and not use_lma,
_chunk_logits=_chunk_logits if torch.is_grad_enabled() else None,
_checkpoint_chunks=
self.ckpt if torch.is_grad_enabled() else False,
@@ -376,9 +403,23 @@
))
def fn(m, z):
- m = add(m, self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size))
+ m = add(
+ m,
+ self.msa_att_col(
+ m,
+ mask=msa_mask,
+ chunk_size=chunk_size,
+ use_lma=use_lma,
+ )
+ )
m, z = self.core(
- m, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size
+ m,
+ z,
+ msa_mask=msa_mask,
+ pair_mask=pair_mask,
+ chunk_size=chunk_size,
+ use_lma=use_lma,
+ _mask_trans=_mask_trans,
)
return m, z
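The expression `use_memory_efficient_kernel=not _chunk_logits and not use_lma` above gives LMA precedence: the fused kernel runs only when logits are not being chunked for checkpointing and LMA has not been requested. A hypothetical restatement of that selection logic (this helper does not exist in the codebase):

def _attention_path(_chunk_logits, use_lma):
    # Mirrors the flag logic in ExtraMSABlock above.
    if use_lma:
        return "lma"
    if not _chunk_logits:
        return "memory_efficient_kernel"
    return "standard_with_chunked_logits"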
@@ -488,6 +529,7 @@ class EvoformerStack(nn.Module):
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: int,
+ use_lma: bool = False,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""
@@ -500,6 +542,8 @@
[*, N_seq, N_res] MSA mask
pair_mask:
[*, N_res, N_res] pair mask
+ chunk_size: Inference-time subbatch size
+ use_lma: Whether to use low-memory attention during inference
Returns:
m:
[*, N_seq, N_res, C_m] MSA embedding
@@ -514,6 +558,7 @@
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
+ use_lma=use_lma,
_mask_trans=_mask_trans,
)
for b in self.blocks
@@ -591,6 +636,7 @@ class ExtraMSAStack(nn.Module):
m: torch.Tensor,
z: torch.Tensor,
chunk_size: int,
+ use_lma: bool = False,
msa_mask: Optional[torch.Tensor] = None,
pair_mask: Optional[torch.Tensor] = None,
_mask_trans: bool = True,
@@ -601,6 +647,8 @@
[*, N_extra, N_res, C_m] extra MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
+ chunk_size: Inference-time subbatch size for Evoformer modules
+ use_lma: Whether to use low-memory attention during inference
msa_mask:
Optional [*, N_extra, N_res] MSA mask
pair_mask:
@@ -616,7 +664,9 @@
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
- _chunk_logits=None
+ use_lma=use_lma,
+ _chunk_logits=None,
+ _mask_trans=_mask_trans,
) for b in self.blocks
]
@@ -634,7 +684,15 @@
m, z = b(m, z)
else:
for b in self.blocks:
- m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
+ m, z = b(
+ m,
+ z,
+ msa_mask,
+ pair_mask,
+ chunk_size=chunk_size,
+ use_lma=use_lma,
+ _mask_trans=_mask_trans
+ )
if(self.clear_cache_between_blocks):
torch.cuda.empty_cache()
......
@@ -152,6 +152,7 @@ class AlphaFold(nn.Module):
template_embeds["pair"],
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=self.globals.chunk_size,
+ use_lma=self.globals.use_lma,
_mask_trans=self.config._mask_trans,
)
@@ -161,6 +162,7 @@
z,
template_mask=batch["template_mask"].to(dtype=z.dtype),
chunk_size=self.globals.chunk_size,
+ use_lma=self.globals.use_lma,
)
t = t * (torch.sum(batch["template_mask"]) > 0)
@@ -294,6 +296,7 @@ class AlphaFold(nn.Module):
z,
msa_mask=feats["extra_msa_mask"].to(dtype=a.dtype),
chunk_size=self.globals.chunk_size,
+ use_lma=self.globals.use_lma,
pair_mask=pair_mask.to(dtype=z.dtype),
_mask_trans=self.config._mask_trans,
)
@@ -308,6 +311,7 @@
msa_mask=msa_mask.to(dtype=m.dtype),
pair_mask=pair_mask.to(dtype=z.dtype),
chunk_size=self.globals.chunk_size,
+ use_lma=self.globals.use_lma,
_mask_trans=self.config._mask_trans,
)
......
@@ -90,12 +90,14 @@ class MSAAttention(nn.Module):
def _chunk(self,
m: torch.Tensor,
biases: List[torch.Tensor],
- use_memory_efficient_kernel: bool,
chunk_size: int,
+ use_memory_efficient_kernel: bool,
+ use_lma: bool,
) -> torch.Tensor:
mha = partial(
self.mha,
- use_memory_efficient_kernel=use_memory_efficient_kernel
+ use_memory_efficient_kernel=use_memory_efficient_kernel,
+ use_lma=use_lma,
)
return chunk_layer(
mha,
@@ -193,6 +195,7 @@ class MSAAttention(nn.Module):
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
use_memory_efficient_kernel: bool = False,
+ use_lma: bool = False,
_chunk_logits: Optional[int] = None,
_checkpoint_chunks: Optional[bool] = None,
) -> torch.Tensor:
@@ -224,13 +227,20 @@
biases.append(z)
if chunk_size is not None:
- m = self._chunk(m, biases, use_memory_efficient_kernel, chunk_size)
+ m = self._chunk(
+ m,
+ biases,
+ chunk_size,
+ use_memory_efficient_kernel=use_memory_efficient_kernel,
+ use_lma=use_lma,
+ )
else:
m = self.mha(
q_x=m,
kv_x=m,
biases=biases,
use_memory_efficient_kernel=use_memory_efficient_kernel,
+ use_lma=use_lma,
)
return m
@@ -305,7 +315,7 @@ class MSAColumnAttention(nn.Module):
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
- use_memory_efficient_kernel: bool = False,
+ use_lma: bool = False,
) -> torch.Tensor:
"""
Args:
@@ -323,7 +333,7 @@
if mask is not None:
mask = mask.transpose(-1, -2)
- m = self._msa_att(m, mask=mask, chunk_size=chunk_size)
+ m = self._msa_att(m, mask=mask, chunk_size=chunk_size, use_lma=use_lma)
# [*, N_seq, N_res, C_in]
m = m.transpose(-2, -3)
@@ -360,13 +370,14 @@ class MSAColumnGlobalAttention(nn.Module):
m: torch.Tensor,
mask: torch.Tensor,
chunk_size: int,
+ use_lma: bool = False,
) -> torch.Tensor:
mha_input = {
"m": m,
"mask": mask,
}
return chunk_layer(
- self.global_attention,
+ partial(self.global_attention, use_lma=use_lma),
mha_input,
chunk_size=chunk_size,
no_batch_dims=len(m.shape[:-2]),
@@ -377,6 +388,7 @@
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
+ use_lma: bool = False,
) -> torch.Tensor:
n_seq, n_res, c_in = m.shape[-3:]
@@ -396,9 +408,9 @@
m = self.layer_norm_m(m)
if chunk_size is not None:
- m = self._chunk(m, mask, chunk_size)
+ m = self._chunk(m, mask, chunk_size, use_lma=use_lma)
else:
- m = self.global_attention(m=m, mask=mask)
+ m = self.global_attention(m=m, mask=mask, use_lma=use_lma)
# [*, N_seq, N_res, C_in]
m = m.transpose(-2, -3)
......
@@ -31,6 +31,10 @@ from openfold.utils.tensor_utils import (
)
+ DEFAULT_LMA_Q_CHUNK_SIZE=1024
+ DEFAULT_LMA_KV_CHUNK_SIZE=4096
def _prod(nums):
out = 1
for n in nums:
@@ -403,8 +407,8 @@ class Attention(nn.Module):
biases: Optional[List[torch.Tensor]] = None,
use_memory_efficient_kernel: bool = False,
use_lma: bool = False,
- q_chunk_size: Optional[int] = None,
- kv_chunk_size: Optional[int] = None,
+ q_chunk_size: int = DEFAULT_LMA_Q_CHUNK_SIZE,
+ kv_chunk_size: int = DEFAULT_LMA_KV_CHUNK_SIZE,
) -> torch.Tensor:
"""
Args:
@@ -460,6 +464,7 @@
for b in biases
]
o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size)
+ o = o.transpose(-2, -3)
else:
o = _attention(q, k, v, biases)
o = o.transpose(-2, -3)
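The transpose added to the LMA branch mirrors the standard path: with the head-first layout the updated `_lma` expects, its output comes back as [*, H, Q, C_hidden] and must be returned to [*, Q, H, C_hidden] before gating and the output projection. A quick shape check with illustrative sizes:

import torch

o = torch.zeros(2, 8, 100, 16)  # [*, H, Q, C_hidden], as _lma now returns
o = o.transpose(-2, -3)         # -> [*, Q, H, C_hidden]
assert o.shape == (2, 100, 8, 16)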
@@ -494,7 +499,11 @@ class GlobalAttention(nn.Module):
self.sigmoid = nn.Sigmoid()
- def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+ def forward(self,
+ m: torch.Tensor,
+ mask: torch.Tensor,
+ use_lma: bool = False,
+ ) -> torch.Tensor:
# [*, N_res, C_in]
q = torch.sum(m * mask.unsqueeze(-1), dim=-2) / (
torch.sum(mask, dim=-1)[..., None] + self.eps
@@ -511,20 +520,30 @@
k = self.linear_k(m)
v = self.linear_v(m)
- # [*, N_res, H, N_seq]
- a = torch.matmul(
- q,
- k.transpose(-1, -2), # [*, N_res, C_hidden, N_seq]
- )
bias = (self.inf * (mask - 1))[..., :, None, :]
- a += bias
- a = softmax_no_cast(a)
+ if(not use_lma):
+ # [*, N_res, H, N_seq]
+ a = torch.matmul(
+ q,
+ k.transpose(-1, -2), # [*, N_res, C_hidden, N_seq]
+ )
+ a += bias
+ a = softmax_no_cast(a)
- # [*, N_res, H, C_hidden]
- o = torch.matmul(
- a,
- v,
- )
+ # [*, N_res, H, C_hidden]
+ o = torch.matmul(
+ a,
+ v,
+ )
+ else:
+ o = _lma(
+ q,
+ k,
+ v,
+ [bias],
+ DEFAULT_LMA_Q_CHUNK_SIZE,
+ DEFAULT_LMA_KV_CHUNK_SIZE
+ )
# [*, N_res, N_seq, C_hidden]
g = self.sigmoid(self.linear_g(m))
@@ -552,12 +571,12 @@ def _lma(
q_chunk_size: int,
kv_chunk_size: int,
):
- no_q, no_kv = q.shape[-3], k.shape[-3]
+ no_q, no_kv = q.shape[-2], k.shape[-2]
- # [*, Q, H, C_hidden]
+ # [*, H, Q, C_hidden]
o = q.new_zeros(q.shape)
for q_s in range(0, no_q, q_chunk_size):
- q_chunk = q[..., q_s: q_s + q_chunk_size, :, :]
+ q_chunk = q[..., q_s: q_s + q_chunk_size, :]
large_bias_chunks = [
b[..., q_s: q_s + q_chunk_size, :] for b in biases
]
@@ -566,24 +585,22 @@
weights = []
values = []
for kv_s in range(0, no_kv, kv_chunk_size):
- k_chunk = k[..., kv_s: kv_s + kv_chunk_size, :, :]
- v_chunk = v[..., kv_s: kv_s + kv_chunk_size, :, :]
+ k_chunk = k[..., kv_s: kv_s + kv_chunk_size, :]
+ v_chunk = v[..., kv_s: kv_s + kv_chunk_size, :]
small_bias_chunks = [
b[..., kv_s: kv_s + kv_chunk_size] for b in large_bias_chunks
]
a = torch.einsum(
"...qhd,...khd->...hqk", q_chunk, k_chunk,
"...hqd,...hkd->...hqk", q_chunk, k_chunk,
)
for b in small_bias_chunks:
a += b
- a = a.transpose(-2, -3)
max_a = torch.max(a, dim=-1, keepdim=True)[0]
exp_a = torch.exp(a - max_a)
- exp_v = torch.einsum("...vhf,...qhv->...qhf", v_chunk, exp_a)
+ exp_v = torch.einsum("...hvf,...hqv->...hqf", v_chunk, exp_a)
maxes.append(max_a.detach().squeeze(-1))
weights.append(torch.sum(exp_a, dim=-1))
@@ -595,14 +612,14 @@
global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0]
max_diffs = torch.exp(chunk_max - global_max)
- chunk_values *= max_diffs.unsqueeze(-1)
- chunk_weights *= max_diffs
+ chunk_values = chunk_values * max_diffs.unsqueeze(-1)
+ chunk_weights = chunk_weights * max_diffs
all_values = torch.sum(chunk_values, dim=-4)
all_weights = torch.sum(chunk_weights.unsqueeze(-1), dim=-4)
q_chunk_out = all_values / all_weights
- o[..., q_s: q_s + q_chunk_size, :, :] = q_chunk_out
+ o[..., q_s: q_s + q_chunk_size, :] = q_chunk_out
return o
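For reference, the scheme `_lma` implements, restated standalone: queries are processed in chunks of `q_chunk_size` and keys/values in chunks of `kv_chunk_size`, so only a q_chunk x kv_chunk block of logits is materialized at a time (with the defaults, a 1024 x 4096 fp32 block is 16 MiB per batch-head slice), and each key/value chunk's softmax statistics are renormalized against a running max. A self-contained sketch in the new head-first layout, with toy chunk sizes so it runs on small inputs (extra leading batch dims and the detach of the maxes are omitted for brevity):

import torch

def naive_attention(q, k, v, bias):
    # q, k, v: [H, N, D]; bias: [H, N, N]
    a = torch.einsum("hqd,hkd->hqk", q, k) + bias
    return torch.einsum("hqk,hkd->hqd", torch.softmax(a, dim=-1), v)

def chunked_attention(q, k, v, bias, q_chunk=4, kv_chunk=8):
    out = torch.zeros_like(q)
    n_q, n_kv = q.shape[-2], k.shape[-2]
    for qs in range(0, n_q, q_chunk):
        qc = q[:, qs:qs + q_chunk]
        maxes, weights, values = [], [], []
        for ks in range(0, n_kv, kv_chunk):
            a = torch.einsum("hqd,hkd->hqk", qc, k[:, ks:ks + kv_chunk])
            a = a + bias[:, qs:qs + q_chunk, ks:ks + kv_chunk]
            m = a.max(dim=-1, keepdim=True)[0]
            e = torch.exp(a - m)  # local softmax numerator
            maxes.append(m.squeeze(-1))
            weights.append(e.sum(dim=-1))
            values.append(torch.einsum("hqk,hkd->hqd", e, v[:, ks:ks + kv_chunk]))
        # Renormalize each KV chunk's statistics against the global max,
        # as the chunk_max/max_diffs logic in _lma does.
        m_all = torch.stack(maxes)               # [n_kv_chunks, H, q_chunk]
        g = m_all.max(dim=0, keepdim=True)[0]
        scale = torch.exp(m_all - g)
        v_sum = (torch.stack(values) * scale[..., None]).sum(dim=0)
        w_sum = (torch.stack(weights) * scale).sum(dim=0)
        out[:, qs:qs + q_chunk] = v_sum / w_sum[..., None]
    return out

h, n, d = 2, 16, 8
q, k, v = torch.randn(h, n, d), torch.randn(h, n, d), torch.randn(h, n, d)
bias = torch.randn(h, n, n)
assert torch.allclose(
    naive_attention(q, k, v, bias), chunked_attention(q, k, v, bias), atol=1e-5
)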
@@ -77,6 +77,7 @@ class TemplatePointwiseAttention(nn.Module):
t: torch.Tensor,
biases: List[torch.Tensor],
chunk_size: int,
+ use_lma: bool = False,
) -> torch.Tensor:
mha_inputs = {
"q_x": z,
@@ -84,7 +85,7 @@
"biases": biases,
}
return chunk_layer(
- self.mha,
+ partial(self.mha, use_lma=use_lma),
mha_inputs,
chunk_size=chunk_size,
no_batch_dims=len(z.shape[:-2]),
@@ -95,7 +96,8 @@ class TemplatePointwiseAttention(nn.Module):
t: torch.Tensor,
z: torch.Tensor,
template_mask: Optional[torch.Tensor] = None,
- chunk_size: Optional[int] = None
+ chunk_size: Optional[int] = None,
+ use_lma: bool = False,
) -> torch.Tensor:
"""
Args:
@@ -122,9 +124,9 @@
# [*, N_res, N_res, 1, C_z]
biases = [bias]
if chunk_size is not None:
- z = self._chunk(z, t, biases, chunk_size)
+ z = self._chunk(z, t, biases, chunk_size, use_lma=use_lma)
else:
- z = self.mha(q_x=z, kv_x=t, biases=biases)
+ z = self.mha(q_x=z, kv_x=t, biases=biases, use_lma=use_lma)
# [*, N_res, N_res, C_z]
z = z.squeeze(-2)
@@ -188,6 +190,7 @@ class TemplatePairStackBlock(nn.Module):
z: torch.Tensor,
mask: torch.Tensor,
chunk_size: Optional[int] = None,
+ use_lma: bool = False,
_mask_trans: bool = True
):
single_templates = [
@@ -204,14 +207,16 @@
self.tri_att_start(
single,
chunk_size=chunk_size,
- mask=single_mask
+ mask=single_mask,
+ use_lma=use_lma,
)
)
single = single + self.dropout_col(
self.tri_att_end(
single,
chunk_size=chunk_size,
- mask=single_mask
+ mask=single_mask,
+ use_lma=use_lma,
)
)
single = single + self.dropout_row(
@@ -298,6 +303,7 @@ class TemplatePairStack(nn.Module):
t: torch.tensor,
mask: torch.tensor,
chunk_size: int,
+ use_lma: bool = False,
_mask_trans: bool = True,
):
"""
@@ -320,6 +326,7 @@
b,
mask=mask,
chunk_size=chunk_size,
+ use_lma=use_lma,
_mask_trans=_mask_trans,
)
for b in self.blocks
......
@@ -62,6 +62,7 @@ class TriangleAttention(nn.Module):
x: torch.Tensor,
biases: List[torch.Tensor],
chunk_size: int,
+ use_lma: bool = False,
) -> torch.Tensor:
mha_inputs = {
"q_x": x,
@@ -69,7 +70,7 @@
"biases": biases,
}
return chunk_layer(
- partial(self.mha),
+ partial(self.mha, use_lma=use_lma),
mha_inputs,
chunk_size=chunk_size,
no_batch_dims=len(x.shape[:-2]),
@@ -78,7 +79,8 @@
def forward(self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
- chunk_size: Optional[int] = None
+ chunk_size: Optional[int] = None,
+ use_lma: bool = False,
) -> torch.Tensor:
"""
Args:
@@ -113,9 +115,9 @@
biases = [mask_bias, triangle_bias]
if chunk_size is not None:
- x = self._chunk(x, biases, chunk_size)
+ x = self._chunk(x, biases, chunk_size, use_lma=use_lma)
else:
- x = self.mha(q_x=x, kv_x=x, biases=biases)
+ x = self.mha(q_x=x, kv_x=x, biases=biases, use_lma=use_lma)
if not self.starting:
x = x.transpose(-2, -3)
......
@@ -15,6 +15,7 @@
import argparse
from datetime import date
import gc
+ import logging
import numpy as np
import os
@@ -76,18 +77,21 @@ def main(args):
else:
alignment_dir = args.use_precomputed_alignments
- # Gather input sequences
- with open(args.fasta_path, "r") as fp:
- data = fp.read()
+ for fasta_file in os.listdir(args.fasta_dir):
+ # Gather input sequences
+ with open(os.path.join(args.fasta_dir, fasta_file), "r") as fp:
+ data = fp.read()
- lines = [
- l.replace('\n', '')
- for prot in data.split('>') for l in prot.strip().split('\n', 1)
- ][1:]
- tags, seqs = lines[::2], lines[1::2]
+ lines = [
+ l.replace('\n', '')
+ for prot in data.split('>') for l in prot.strip().split('\n', 1)
+ ][1:]
+ tags, seqs = lines[::2], lines[1::2]
+ for tag, seq in zip(tags, seqs):
- fasta_path = os.path.join(args.output_dir, "tmp.fasta")
- assert len(seqs) == 1, "Input FASTAs may only contain one sequence"
- tag, seq = tags[0], seqs[0]
+ fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
with open(fasta_path, "w") as fp:
fp.write(f">{tag}\n{seq}")
@@ -123,7 +127,7 @@
processed_feature_dict = feature_processor.process_features(
feature_dict, mode='predict',
)
+ logging.info("Executing model...")
batch = processed_feature_dict
with torch.no_grad():
@@ -160,27 +164,28 @@
with open(unrelaxed_output_path, 'w') as f:
f.write(protein.to_pdb(unrelaxed_protein))
- amber_relaxer = relax.AmberRelaxation(
- use_gpu=(args.model_device != "cpu"),
- **config.relax,
- )
- # Relax the prediction.
- t = time.perf_counter()
- visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", default="")
- if("cuda" in args.model_device):
- device_no = args.model_device.split(":")[-1]
- os.environ["CUDA_VISIBLE_DEVICES"] = device_no
- relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
- os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices
- logging.info(f"Relaxation time: {time.perf_counter() - t}")
- # Save the relaxed PDB.
- relaxed_output_path = os.path.join(
- args.output_dir, f'{tag}_{args.model_name}_relaxed.pdb'
- )
- with open(relaxed_output_path, 'w') as f:
- f.write(relaxed_pdb_str)
+ if(not args.skip_relaxation):
+ amber_relaxer = relax.AmberRelaxation(
+ use_gpu=(args.model_device != "cpu"),
+ **config.relax,
+ )
+ # Relax the prediction.
+ t = time.perf_counter()
+ visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", default="")
+ if("cuda" in args.model_device):
+ device_no = args.model_device.split(":")[-1]
+ os.environ["CUDA_VISIBLE_DEVICES"] = device_no
+ relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
+ os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices
+ logging.info(f"Relaxation time: {time.perf_counter() - t}")
+ # Save the relaxed PDB.
+ relaxed_output_path = os.path.join(
+ args.output_dir, f'{tag}_{args.model_name}_relaxed.pdb'
+ )
+ with open(relaxed_output_path, 'w') as f:
+ f.write(relaxed_pdb_str)
if(args.save_outputs):
output_dict_path = os.path.join(
@@ -193,7 +198,8 @@
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
- "fasta_path", type=str,
+ "fasta_dir", type=str,
+ help="Path to directory containing FASTA files, one sequence per file"
)
parser.add_argument(
"template_mmcif_dir", type=str,
@@ -224,7 +230,7 @@
openfold/resources/params"""
)
parser.add_argument(
- "--save_outputs", type=bool, default=False,
+ "--save_outputs", action="store_true", default=False,
help="Whether to save all model outputs, including embeddings, etc."
)
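The switch from `type=bool` to `action="store_true"` fixes a standard argparse pitfall: `bool()` applied to a command-line string is truthy for any non-empty value, so `--save_outputs False` would previously have enabled the flag. A small demonstration:

import argparse

p = argparse.ArgumentParser()
p.add_argument("--old_style", type=bool, default=False)
p.add_argument("--new_style", action="store_true")
args = p.parse_args(["--old_style", "False"])
assert args.old_style is True   # bool("False") is True: non-empty string
assert args.new_style is False  # flag absent, stays False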
parser.add_argument(
@@ -232,11 +238,14 @@
help="""Number of CPUs with which to run alignment tools"""
)
parser.add_argument(
- '--preset', type=str, default='full_dbs',
+ "--preset", type=str, default='full_dbs',
choices=('reduced_dbs', 'full_dbs')
)
parser.add_argument(
- '--data_random_seed', type=str, default=None
+ "--data_random_seed", type=str, default=None
)
+ parser.add_argument(
+ "--skip_relaxation", action="store_true", default=False,
+ )
add_data_args(parser)
args = parser.parse_args()
......
@@ -18,7 +18,6 @@ import unittest
from openfold.model.primitives import (
Attention,
- LowMemoryAttention,
)
from tests.config import consts
@@ -31,8 +30,7 @@ class TestLMA(unittest.TestCase):
no_heads = 4
q = torch.rand(batch_size, n, c_hidden).cuda()
- k = torch.rand(batch_size, n, c_hidden).cuda()
- v = torch.rand(batch_size, n, c_hidden).cuda()
+ kv = torch.rand(batch_size, n, c_hidden).cuda()
bias = [torch.rand(no_heads, 1, n)]
bias = [b.cuda() for b in bias]
@@ -40,28 +38,13 @@
gating_fill = torch.rand(c_hidden * no_heads, c_hidden)
o_fill = torch.rand(c_hidden, c_hidden * no_heads)
- lma = LowMemoryAttention(
- c_hidden, c_hidden, c_hidden, c_hidden, no_heads
- ).cuda()
a = Attention(
c_hidden, c_hidden, c_hidden, c_hidden, no_heads
).cuda()
- with torch.no_grad():
- for n, p in lma.named_parameters():
- attrs = n.split('.')
- param = a
- for attr in attrs:
- param = getattr(param, attr)
- param.copy_(p)
- for m in [lma, a]:
- m.linear_g.weight.copy_(gating_fill)
- m.linear_o.weight.copy_(o_fill)
with torch.no_grad():
- l = lma(q, k, v, 1024, 4096, biases=bias)
- real = a(q, k, v, biases=bias)
+ l = a(q, kv, biases=bias, use_lma=True)
+ real = a(q, kv, biases=bias)
self.assertTrue(torch.max(torch.abs(l - real)) < consts.eps)
......