"docs/Makefile" did not exist on "f1cd1381528a773fa83bcb571f5778dc3a14c000"
Commit 9ce8713c authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Improve memory efficiency, fix OpenMM CUDA + loss bugs

parent 754d2ba8
......@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import torch
import torch.nn as nn
......@@ -44,6 +45,7 @@ from openfold.utils.loss import (
)
from openfold.utils.tensor_utils import (
dict_multimap,
tensor_tree_map,
)
......@@ -103,21 +105,28 @@ class AlphaFold(nn.Module):
self.config = config
def embed_templates(self, batch, z, pair_mask):
def embed_templates(self, batch, z, pair_mask, templ_dim):
# Embed the templates one at a time (with a poor man's vmap)
template_embeds = []
n_templ = batch["template_aatype"].shape[-2]
for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map(
lambda t: torch.index_select(t, templ_dim, idx),
batch,
)
# Build template angle feats
angle_feats = atom37_to_torsion_angles(
batch["template_aatype"],
batch["template_all_atom_positions"],
batch["template_all_atom_masks"],
single_template_feats["template_aatype"],
single_template_feats["template_all_atom_positions"],
single_template_feats["template_all_atom_masks"],
eps=1e-8
)
# Stow this away for later
batch["torsion_angles_mask"] = angle_feats["torsion_angles_mask"]
template_angle_feat = build_template_angle_feat(
angle_feats,
batch["template_aatype"],
single_template_feats["template_aatype"],
)
# [*, S_t, N, C_m]
......@@ -125,7 +134,7 @@ class AlphaFold(nn.Module):
# [*, S_t, N, N, C_t]
t = build_template_pair_feat(
batch,
single_template_feats,
eps=self.config.template.eps,
**self.config.template.distogram
)
......@@ -136,15 +145,30 @@ class AlphaFold(nn.Module):
_mask_trans=self.config._mask_trans
)
template_embeds.append({
"angle": a,
"pair": t,
"torsion_mask": angle_feats["torsion_angles_mask"]
})
template_embeds = dict_multimap(
partial(torch.cat, dim=templ_dim),
template_embeds,
)
# [*, N, N, C_z]
t = self.template_pointwise_att(
t,
template_embeds["pair"],
z,
template_mask=batch["template_mask"]
)
t *= torch.sum(batch["template_mask"]) > 0
return a, t
return {
"template_angle_embedding": a,
"template_pair_embedding": t,
"torsion_angles_mask": angle_feats["torsion_angles_mask"],
}
def forward(self, batch):
"""
......@@ -210,6 +234,7 @@ class AlphaFold(nn.Module):
# Grab some data about the input
batch_dims = feats["target_feat"].shape[:-2]
no_batch_dims = len(batch_dims)
n = feats["target_feat"].shape[-2]
n_seq = feats["msa_feat"].shape[-3]
device = feats["target_feat"].device
......@@ -257,7 +282,7 @@ class AlphaFold(nn.Module):
x_prev = pseudo_beta_fn(
feats["aatype"],
x_prev,
None # TODO: figure this part out
None
)
# m_1_prev_emb: [*, N, C_m]
......@@ -276,17 +301,28 @@ class AlphaFold(nn.Module):
# Embed the templates + merge with MSA/pair embeddings
if(self.config.template.enabled):
a, t = self.embed_templates(feats, z, pair_mask)
template_feats = {
k:v for k,v in feats.items() if "template_" in k
}
template_embeds = self.embed_templates(
template_feats,
z,
pair_mask,
no_batch_dims,
)
# [*, N, N, C_z]
z += t
z += template_embeds["template_pair_embedding"]
if(self.config.template.embed_angles):
# [*, S = S_c + S_t, N, C_m]
m = torch.cat([m, a], dim=-3)
m = torch.cat(
[m, template_embeds["template_angle_embedding"]],
dim=-3
)
# [*, S, N]
torsion_angles_mask = feats["torsion_angles_mask"]
torsion_angles_mask = template_embeds["torsion_angles_mask"]
msa_mask = torch.cat(
[feats["msa_mask"], torsion_angles_mask[..., 2]], axis=-2
)
......
......@@ -106,6 +106,7 @@ class MSAAttention(nn.Module):
(*((-1,) * len(bias.shape[:-4])), -1, self.no_heads, n_res, -1)
)
biases = [bias]
if(self.pair_bias):
# [*, N_res, N_res, C_z]
z = self.layer_norm_z(z)
......@@ -116,14 +117,13 @@ class MSAAttention(nn.Module):
# [*, 1, no_heads, N_res, N_res]
z = permute_final_dims(z, 2, 0, 1).unsqueeze(-4)
# [*, N_seq, no_heads, N_res, N_res]
bias = bias + z
biases.append(z)
mha_inputs = {
"q_x": m,
"k_x": m,
"v_x": m,
"biases": [bias]
"biases": biases
}
if(not self.training and self.chunk_size is not None):
m = chunk_layer(
......
......@@ -96,17 +96,6 @@ class TriangleAttention(nn.Module):
# [*, 1, H, I, J]
triangle_bias = triangle_bias.unsqueeze(-4)
# Broadcasting and chunking doesn't really work yet (TODO)
# [*, I, H, I, J]
i = x.shape[-3]
triangle_bias = triangle_bias.expand(
(*((-1,) * len(triangle_bias.shape[:-4])), i, -1, -1, -1)
)
#print(x.shape)
#print(mask_bias.shape)
#print(triangle_bias.shape)
mha_inputs = {
"q_x": x,
"k_x": x,
......
......@@ -486,7 +486,7 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-6, inf=1e8
to_concat = [dgram, template_mask_2d[..., None]]
aatype_one_hot = nn.functional.one_hot(
batch["template_aatype"], batch["target_feat"].shape[-1]
batch["template_aatype"], residue_constants.restype_num + 2,
)
n_res = batch["template_aatype"].shape[-1]
......@@ -502,19 +502,6 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-6, inf=1e8
)
n, ca, c = [residue_constants.atom_order[a] for a in ['N', 'CA', 'C']]
#t_aa_pos = batch["template_all_atom_positions"]
#affines = T.make_transform_from_reference(
# n_xyz=t_aa_pos[..., n],
# ca_xyz=t_aa_pos[..., ca],
# c_xyz=t_aa_pos[..., c],
#)
#rots = affines.rots
#trans = affines.trans
#affine_vec = rot_mul_vec(
# rots.transpose(-1, -2),
# trans[..., None, :, :] - trans[..., None, :],
#)
#inverted_dists = torch.rsqrt(eps + torch.sum(inverted_dists**2, dim=-1))
t_aa_masks = batch["template_all_atom_masks"]
template_mask = (
......@@ -522,10 +509,6 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-6, inf=1e8
)
template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
#inverted_dists *= template_mask_2d
#unit_vector = affine_vec * inverted_dists.unsqueeze(-1)
#unit_vector = unit_vector.unsqueeze(-2)
unit_vector = template_mask_2d.new_zeros(*template_mask_2d.shape, 3)
to_concat.append(unit_vector)
to_concat.append(template_mask_2d[..., None])
......
......@@ -32,7 +32,7 @@ from openfold.utils.tensor_utils import (
def softmax_cross_entropy(logits, labels):
    """Compute softmax cross-entropy between logits and (possibly soft) labels.

    Args:
        logits: [..., no_bins] unnormalized scores.
        labels: [..., no_bins] target distribution (e.g. one-hot), same shape
            as logits.

    Returns:
        [...] per-element loss tensor (the trailing bin dimension is reduced).
    """
    # dim=-1 must be passed explicitly: calling log_softmax without a dim is
    # deprecated and may normalize over the wrong axis for >1-D inputs.
    loss = -1 * torch.sum(
        labels * torch.nn.functional.log_softmax(logits, dim=-1),
        dim=-1,
    )
    return loss
......@@ -219,18 +219,19 @@ def supervised_chi_loss(
chi_weight: float,
angle_norm_weight: float,
eps=1e-6,
**kwargs,
) -> torch.Tensor:
pred_angles = angles_sin_cos[..., 3:, :]
residue_type_one_hot = torch.nn.functional.one_hot(
aatype, residue_constants.restype_num + 1,
).unsqueeze(-3)
)
chi_pi_periodic = torch.einsum(
"...ij,jk->ik",
residue_type_one_hot,
aatype.new_tensor(residue_constants.chi_pi_periodic)
residue_type_one_hot.type(angles_sin_cos.dtype),
angles_sin_cos.new_tensor(residue_constants.chi_pi_periodic),
)
true_chi = chi_angles.unsqueeze(-3)
true_chi = chi_angles
sin_true_chi = torch.sin(true_chi)
cos_true_chi = torch.cos(true_chi)
sin_cos_true_chi = torch.stack([sin_true_chi, cos_true_chi], dim=-1)
......@@ -247,7 +248,7 @@ def supervised_chi_loss(
sq_chi_error = torch.minimum(sq_chi_error, sq_chi_error_shifted)
sq_chi_loss = masked_mean(
sq_chi_error, chi_mask.unsqueeze(-3), dim=(-1, -2, -3)
chi_mask, sq_chi_error, dim=(-1, -2)
)
loss = 0
......@@ -258,7 +259,7 @@ def supervised_chi_loss(
)
norm_error = torch.abs(angle_norm - 1.)
angle_norm_loss = masked_mean(
norm_error, sequence_mask[..., None, :, None], dim=(-1, -2, -3)
seq_mask[..., None], norm_error, dim=(-1, -2)
)
loss += angle_norm_weight * angle_norm_loss
......@@ -390,11 +391,11 @@ def distogram_loss(
keepdims=True
)
true_bins = torch.sum(dists > sq_breaks, dim=-1)
true_bins = torch.sum(dists > boundaries, dim=-1)
errors = softmax_cross_entropy(
logits,
torch.nn.functional.one_hot(true_bins, num_bins),
torch.nn.functional.one_hot(true_bins, no_bins),
)
square_mask = pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :]
......@@ -1240,6 +1241,7 @@ def experimentally_resolved_loss(
(resolution >= min_resolution) &
(resolution <= max_resolution)
)
return loss
......
......@@ -43,6 +43,19 @@ def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64):
return torch.bucketize(dists, boundaries)
def dict_multimap(fn, dicts):
    """Apply fn to lists of corresponding leaf values across a list of dicts.

    All dicts are assumed to share the key structure of dicts[0]. Nested
    dicts are recursed into; at each leaf, fn receives the list
    [d[k] for d in dicts].

    Args:
        fn: callable taking a list of values and returning the merged value.
        dicts: non-empty list of dicts with identical (nested) key structure.

    Returns:
        A new dict mirroring the structure of dicts[0] with merged leaves.
    """
    first = dicts[0]
    new_dict = {}
    for k, v in first.items():
        all_v = [d[k] for d in dicts]
        if type(v) is dict:
            # BUG FIX: the recursive call must forward fn; the original
            # called dict_multimap(all_v), raising TypeError on nested dicts.
            new_dict[k] = dict_multimap(fn, all_v)
        else:
            new_dict[k] = fn(all_v)

    return new_dict
def stack_tensor_dicts(dicts):
first = dicts[0]
new_dict = {}
......@@ -154,13 +167,12 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
orig_batch_dims = [max(s) for s in zip(*initial_dims)]
def prep_inputs(t):
# TODO: make this more memory efficient. This sucks
if(not sum(t.shape[:no_batch_dims]) == no_batch_dims):
t = t.expand(*orig_batch_dims, *t.shape[no_batch_dims:])
t = t.reshape(-1, *t.shape[no_batch_dims:])
return t
#shape = lambda t: t.shape
#print(tensor_tree_map(shape, inputs))
flattened_inputs = tensor_tree_map(prep_inputs, inputs)
flat_batch_dim = 1
......@@ -175,7 +187,7 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
out = None
for _ in range(no_chunks):
# Chunk the input
select_chunk = lambda t: t[i:i+chunk_size]
select_chunk = lambda t: t[i:i+chunk_size] if t.shape[0] != 1 else t
chunks = tensor_tree_map(select_chunk, flattened_inputs)
# Run the layer on the chunk
......
......@@ -14,8 +14,8 @@
# limitations under the License.
import os
import sys
sys.path.append("lib/conda/lib/python3.9/site-packages")
#import sys
#sys.path.append("lib/conda/lib/python3.9/site-packages")
import math
import pickle
......@@ -26,9 +26,13 @@ import numpy as np
from config import model_config
from openfold.model.model import AlphaFold
import openfold.np.protein as protein
from openfold.np import residue_constants, protein
#os.environ["OPENMM_DEFAULT_PLATFORM"] = "CPU"
os.environ["OPENMM_DEFAULT_PLATFORM"] = "OpenCL"
#os.environ["OPENMM_CPU_THREADS"] = "16"
import openfold.np.relax.relax as relax
from openfold.np import residue_constants
from openfold.utils.import_weights import (
import_jax_weights_,
)
......@@ -41,37 +45,35 @@ from openfold.utils.tensor_utils import (
MODEL_NAME = "model_1"
MODEL_DEVICE = "cuda:1"
PARAM_PATH = "openfold/resources/params/params_model_1.npz"
FEAT_PATH = "tests/test_data/sample_feats.pickle"
#FEAT_PATH = "tests/test_data/sample_feats.pickle"
FEAT_PATH = "prediction/1OJN_feats.pickle"
config = model_config(MODEL_NAME)
model = AlphaFold(config.model)
model = model.eval()
import_jax_weights_(model, PARAM_PATH)
model_device = 'cuda:1'
model = model.to(model_device)
model = model.to(MODEL_DEVICE)
with open(FEAT_PATH, "rb") as f:
batch = pickle.load(f)
batch = {k:torch.as_tensor(v, device=model_device) for k,v in batch.items()}
with torch.no_grad():
batch = {k:torch.as_tensor(v, device=MODEL_DEVICE) for k,v in batch.items()}
longs = [
longs = [
"aatype",
"template_aatype",
"extra_msa",
"residx_atom37_to_atom14",
"residx_atom14_to_atom37",
]
for l in longs:
]
for l in longs:
batch[l] = batch[l].long()
# Move the recycling dimension to the end
move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0).contiguous()
batch = tensor_tree_map(move_dim, batch)
# Move the recycling dimension to the end
move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0).contiguous()
batch = tensor_tree_map(move_dim, batch)
with torch.no_grad():
t = time.time()
out = model(batch)
print(f"Inference time: {time.time() - t}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment