Commit 9ce8713c authored by Gustaf Ahdritz

Improve memory efficiency, fix OpenMM CUDA + loss bugs

parent 754d2ba8
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
 import torch
 import torch.nn as nn
@@ -44,6 +45,7 @@ from openfold.utils.loss import (
 )
 from openfold.utils.tensor_utils import (
+    dict_multimap,
     tensor_tree_map,
 )
@@ -103,21 +105,28 @@ class AlphaFold(nn.Module):
         self.config = config

-    def embed_templates(self, batch, z, pair_mask):
+    def embed_templates(self, batch, z, pair_mask, templ_dim):
+        # Embed the templates one at a time (with a poor man's vmap)
+        template_embeds = []
+        n_templ = batch["template_aatype"].shape[-2]
+        for i in range(n_templ):
+            idx = batch["template_aatype"].new_tensor(i)
+            single_template_feats = tensor_tree_map(
+                lambda t: torch.index_select(t, templ_dim, idx),
+                batch,
+            )
+
             # Build template angle feats
             angle_feats = atom37_to_torsion_angles(
-                batch["template_aatype"],
-                batch["template_all_atom_positions"],
-                batch["template_all_atom_masks"],
+                single_template_feats["template_aatype"],
+                single_template_feats["template_all_atom_positions"],
+                single_template_feats["template_all_atom_masks"],
                 eps=1e-8
             )

-        # Stow this away for later
-        batch["torsion_angles_mask"] = angle_feats["torsion_angles_mask"]
-
             template_angle_feat = build_template_angle_feat(
                 angle_feats,
-                batch["template_aatype"],
+                single_template_feats["template_aatype"],
             )

             # [*, S_t, N, C_m]
@@ -125,7 +134,7 @@ class AlphaFold(nn.Module):
             # [*, S_t, N, N, C_t]
             t = build_template_pair_feat(
-                batch,
+                single_template_feats,
                 eps=self.config.template.eps,
                 **self.config.template.distogram
             )
@@ -136,15 +145,30 @@ class AlphaFold(nn.Module):
                 _mask_trans=self.config._mask_trans
             )

+            template_embeds.append({
+                "angle": a,
+                "pair": t,
+                "torsion_mask": angle_feats["torsion_angles_mask"]
+            })
+
+        template_embeds = dict_multimap(
+            partial(torch.cat, dim=templ_dim),
+            template_embeds,
+        )
+
         # [*, N, N, C_z]
         t = self.template_pointwise_att(
-            t,
+            template_embeds["pair"],
             z,
             template_mask=batch["template_mask"]
         )
         t *= torch.sum(batch["template_mask"]) > 0

-        return a, t
+        return {
+            "template_angle_embedding": a,
+            "template_pair_embedding": t,
+            "torsion_angles_mask": angle_feats["torsion_angles_mask"],
+        }
     def forward(self, batch):
         """
@@ -210,6 +234,7 @@ class AlphaFold(nn.Module):
         # Grab some data about the input
         batch_dims = feats["target_feat"].shape[:-2]
+        no_batch_dims = len(batch_dims)
         n = feats["target_feat"].shape[-2]
         n_seq = feats["msa_feat"].shape[-3]
         device = feats["target_feat"].device
@@ -257,7 +282,7 @@ class AlphaFold(nn.Module):
         x_prev = pseudo_beta_fn(
             feats["aatype"],
             x_prev,
-            None  # TODO: figure this part out
+            None
         )

         # m_1_prev_emb: [*, N, C_m]
@@ -276,17 +301,28 @@ class AlphaFold(nn.Module):
         # Embed the templates + merge with MSA/pair embeddings
         if(self.config.template.enabled):
-            a, t = self.embed_templates(feats, z, pair_mask)
+            template_feats = {
+                k:v for k,v in feats.items() if "template_" in k
+            }
+
+            template_embeds = self.embed_templates(
+                template_feats,
+                z,
+                pair_mask,
+                no_batch_dims,
+            )

             # [*, N, N, C_z]
-            z += t
+            z += template_embeds["template_pair_embedding"]

             if(self.config.template.embed_angles):
                 # [*, S = S_c + S_t, N, C_m]
-                m = torch.cat([m, a], dim=-3)
+                m = torch.cat(
+                    [m, template_embeds["template_angle_embedding"]],
+                    dim=-3
+                )

                 # [*, S, N]
-                torsion_angles_mask = feats["torsion_angles_mask"]
+                torsion_angles_mask = template_embeds["torsion_angles_mask"]
                 msa_mask = torch.cat(
                     [feats["msa_mask"], torsion_angles_mask[..., 2]], axis=-2
                 )
...
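Note: the rewritten `embed_templates` trades one big batched template embedding for a per-template loop, so only a single template's activations are alive at a time. A minimal sketch of the slice-and-reassemble pattern (hypothetical shapes; the real code maps over every `template_*` feature with `tensor_tree_map`):

```python
import torch

# Hypothetical [S_t, N] template feature; templ_dim indexes the template axis.
batch = {"template_aatype": torch.randint(0, 22, (4, 16))}
templ_dim = 0

per_template = []
for i in range(batch["template_aatype"].shape[templ_dim]):
    idx = batch["template_aatype"].new_tensor([i])
    # index_select keeps the template dim (now size 1), so per-template
    # results can later be concatenated back along templ_dim.
    single = {k: torch.index_select(t, templ_dim, idx) for k, t in batch.items()}
    per_template.append(single["template_aatype"])

rebuilt = torch.cat(per_template, dim=templ_dim)
assert torch.equal(rebuilt, batch["template_aatype"])
```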
@@ -106,6 +106,7 @@ class MSAAttention(nn.Module):
             (*((-1,) * len(bias.shape[:-4])), -1, self.no_heads, n_res, -1)
         )

+        biases = [bias]
         if(self.pair_bias):
             # [*, N_res, N_res, C_z]
             z = self.layer_norm_z(z)
@@ -116,14 +117,13 @@ class MSAAttention(nn.Module):
             # [*, 1, no_heads, N_res, N_res]
             z = permute_final_dims(z, 2, 0, 1).unsqueeze(-4)

-            # [*, N_seq, no_heads, N_res, N_res]
-            bias = bias + z
+            biases.append(z)

         mha_inputs = {
             "q_x": m,
             "k_x": m,
             "v_x": m,
-            "biases": [bias]
+            "biases": biases
         }

         if(not self.training and self.chunk_size is not None):
             m = chunk_layer(
...
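Note: `MSAAttention` now hands the attention a list of biases instead of pre-summing them into one tensor. Adding each bias to the logits separately avoids building an extra full-size summed bias up front, and lets chunked attention slice or broadcast each bias independently. A sketch with made-up shapes:

```python
import torch

logits = torch.randn(8, 4, 32, 32)     # [N_seq, no_heads, N_res, N_res]
mask_bias = torch.randn(8, 1, 1, 32)   # broadcasts over heads and queries
pair_bias = torch.randn(1, 4, 32, 32)  # broadcasts over sequences

# Pre-summing the biases would allocate a full [8, 4, 32, 32] intermediate
# before attention even runs; adding them one at a time does not.
for b in [mask_bias, pair_bias]:
    logits = logits + b
attn = torch.softmax(logits, dim=-1)
```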
@@ -96,17 +96,6 @@ class TriangleAttention(nn.Module):
         # [*, 1, H, I, J]
         triangle_bias = triangle_bias.unsqueeze(-4)

-        # Broadcasting and chunking doesn't really work yet (TODO)
-        # [*, I, H, I, J]
-        i = x.shape[-3]
-        triangle_bias = triangle_bias.expand(
-            (*((-1,) * len(triangle_bias.shape[:-4])), i, -1, -1, -1)
-        )
-
-        #print(x.shape)
-        #print(mask_bias.shape)
-        #print(triangle_bias.shape)
-
         mha_inputs = {
             "q_x": x,
             "k_x": x,
...
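Note: the deleted block expanded `triangle_bias` to `[*, I, H, I, J]` solely so that chunking would work; with the broadcast-aware `chunk_layer` changes further down, the size-1 dim can stay as-is and broadcast at add time. Sketch:

```python
import torch

triangle_bias = torch.randn(1, 4, 64, 64)  # [1, H, I, J]
logits = torch.randn(64, 4, 64, 64)        # [I, H, I, J]

# No expand needed: the size-1 leading dim broadcasts during the add.
out = logits + triangle_bias
```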
@@ -486,7 +486,7 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-6, inf=1e8
     to_concat = [dgram, template_mask_2d[..., None]]

     aatype_one_hot = nn.functional.one_hot(
-        batch["template_aatype"], batch["target_feat"].shape[-1]
+        batch["template_aatype"], residue_constants.restype_num + 2,
     )

     n_res = batch["template_aatype"].shape[-1]
@@ -502,19 +502,6 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-6, inf=1e8
     )

     n, ca, c = [residue_constants.atom_order[a] for a in ['N', 'CA', 'C']]
-    #t_aa_pos = batch["template_all_atom_positions"]
-    #affines = T.make_transform_from_reference(
-    #    n_xyz=t_aa_pos[..., n],
-    #    ca_xyz=t_aa_pos[..., ca],
-    #    c_xyz=t_aa_pos[..., c],
-    #)
-    #rots = affines.rots
-    #trans = affines.trans
-    #affine_vec = rot_mul_vec(
-    #    rots.transpose(-1, -2),
-    #    trans[..., None, :, :] - trans[..., None, :],
-    #)
-    #inverted_dists = torch.rsqrt(eps + torch.sum(inverted_dists**2, dim=-1))
     t_aa_masks = batch["template_all_atom_masks"]
     template_mask = (
@@ -522,10 +509,6 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-6, inf=1e8
     )
     template_mask_2d = template_mask[..., None] * template_mask[..., None, :]

-    #inverted_dists *= template_mask_2d
-    #unit_vector = affine_vec * inverted_dists.unsqueeze(-1)
-    #unit_vector = unit_vector.unsqueeze(-2)
     unit_vector = template_mask_2d.new_zeros(*template_mask_2d.shape, 3)
     to_concat.append(unit_vector)
     to_concat.append(template_mask_2d[..., None])
...
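Note: since `embed_templates` now passes only `template_*` features into this function, the one-hot width can no longer be read off `batch["target_feat"]`; it is pinned to `residue_constants.restype_num + 2` instead, which should be the same 22 classes (20 standard residues plus unknown and gap). Sketch with the constant inlined:

```python
import torch
import torch.nn as nn

restype_num = 20  # residue_constants.restype_num
template_aatype = torch.randint(0, restype_num + 2, (1, 64))

# 22 classes: 20 amino acids + unknown + gap
aatype_one_hot = nn.functional.one_hot(template_aatype, restype_num + 2)
```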
@@ -32,7 +32,7 @@ from openfold.utils.tensor_utils import (
 def softmax_cross_entropy(logits, labels):
     loss = -1 * torch.sum(
-        labels * torch.nn.functional.log_softmax(logits),
+        labels * torch.nn.functional.log_softmax(logits, dim=-1),
         dim=-1,
     )
     return loss
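Note: without an explicit `dim`, `log_softmax` falls back to PyTorch's deprecated implicit-dimension rule (dim 0 for 1-D and 3-D inputs, dim 1 otherwise), which normalizes over the wrong axis for these `[..., no_bins]` logits; that is one of the loss bugs this commit fixes. Sketch:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(2, 16, 16, 64)  # [..., no_bins]
labels = F.one_hot(torch.randint(0, 64, (2, 16, 16)), 64).float()

# dim=-1 normalizes over the bin axis; the implicit default would have
# picked dim 1 for this 4-D tensor.
loss = -(labels * F.log_softmax(logits, dim=-1)).sum(dim=-1)
```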
@@ -219,18 +219,19 @@ def supervised_chi_loss(
     chi_weight: float,
     angle_norm_weight: float,
     eps=1e-6,
+    **kwargs,
 ) -> torch.Tensor:
     pred_angles = angles_sin_cos[..., 3:, :]
     residue_type_one_hot = torch.nn.functional.one_hot(
         aatype, residue_constants.restype_num + 1,
-    ).unsqueeze(-3)
+    )
     chi_pi_periodic = torch.einsum(
         "...ij,jk->ik",
-        residue_type_one_hot,
-        aatype.new_tensor(residue_constants.chi_pi_periodic)
+        residue_type_one_hot.type(angles_sin_cos.dtype),
+        angles_sin_cos.new_tensor(residue_constants.chi_pi_periodic),
     )

-    true_chi = chi_angles.unsqueeze(-3)
+    true_chi = chi_angles
     sin_true_chi = torch.sin(true_chi)
     cos_true_chi = torch.cos(true_chi)
     sin_cos_true_chi = torch.stack([sin_true_chi, cos_true_chi], dim=-1)
@@ -247,7 +248,7 @@ def supervised_chi_loss(
     sq_chi_error = torch.minimum(sq_chi_error, sq_chi_error_shifted)
     sq_chi_loss = masked_mean(
-        sq_chi_error, chi_mask.unsqueeze(-3), dim=(-1, -2, -3)
+        chi_mask, sq_chi_error, dim=(-1, -2)
     )

     loss = 0
@@ -258,7 +259,7 @@ def supervised_chi_loss(
     )
     norm_error = torch.abs(angle_norm - 1.)
     angle_norm_loss = masked_mean(
-        norm_error, sequence_mask[..., None, :, None], dim=(-1, -2, -3)
+        seq_mask[..., None], norm_error, dim=(-1, -2)
     )

     loss += angle_norm_weight * angle_norm_loss
@@ -390,11 +391,11 @@ def distogram_loss(
         keepdims=True
     )

-    true_bins = torch.sum(dists > sq_breaks, dim=-1)
+    true_bins = torch.sum(dists > boundaries, dim=-1)

     errors = softmax_cross_entropy(
         logits,
-        torch.nn.functional.one_hot(true_bins, num_bins),
+        torch.nn.functional.one_hot(true_bins, no_bins),
     )

     square_mask = pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :]
@@ -1240,6 +1241,7 @@ def experimentally_resolved_loss(
         (resolution >= min_resolution) &
         (resolution <= max_resolution)
     )
+
     return loss
...
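Note: both `masked_mean` call sites above are corrected to a `(mask, value, dim)` argument order; the old calls passed the error tensor where the mask belonged. A minimal sketch of `masked_mean` under that assumed signature:

```python
import torch

def masked_mean(mask, value, dim, eps=1e-10):
    # Mean of `value` over `dim`, counting only positions where mask == 1.
    return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim))

value = torch.randn(4, 7, 2)
mask = (torch.rand(4, 7, 2) > 0.5).float()
print(masked_mean(mask, value, dim=(-1, -2)))
```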
@@ -43,6 +43,19 @@ def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64):
     return torch.bucketize(dists, boundaries)

+def dict_multimap(fn, dicts):
+    first = dicts[0]
+    new_dict = {}
+    for k, v in first.items():
+        all_v = [d[k] for d in dicts]
+        if(type(v) is dict):
+            new_dict[k] = dict_multimap(fn, all_v)
+        else:
+            new_dict[k] = fn(all_v)
+
+    return new_dict
+
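Note: `dict_multimap` applies `fn` leaf-wise across a list of identically-structured dicts (the recursive branch has to pass `fn` through, as above). For example, merging per-template outputs:

```python
from functools import partial

import torch

from openfold.utils.tensor_utils import dict_multimap

dicts = [
    {"pair": torch.ones(1, 2), "angle": torch.zeros(1, 3)},
    {"pair": torch.ones(1, 2), "angle": torch.zeros(1, 3)},
]
merged = dict_multimap(partial(torch.cat, dim=0), dicts)
# merged["pair"].shape == (2, 2); merged["angle"].shape == (2, 3)
```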
 def stack_tensor_dicts(dicts):
     first = dicts[0]
     new_dict = {}
@@ -154,13 +167,12 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
     orig_batch_dims = [max(s) for s in zip(*initial_dims)]

     def prep_inputs(t):
-        # TODO: make this more memory efficient. This sucks
-        t = t.expand(*orig_batch_dims, *t.shape[no_batch_dims:])
+        if(not sum(t.shape[:no_batch_dims]) == no_batch_dims):
+            t = t.expand(*orig_batch_dims, *t.shape[no_batch_dims:])
         t = t.reshape(-1, *t.shape[no_batch_dims:])
         return t

-    #shape = lambda t: t.shape
-    #print(tensor_tree_map(shape, inputs))
     flattened_inputs = tensor_tree_map(prep_inputs, inputs)

     flat_batch_dim = 1
@@ -175,7 +187,7 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
     out = None
     for _ in range(no_chunks):
         # Chunk the input
-        select_chunk = lambda t: t[i:i+chunk_size]
+        select_chunk = lambda t: t[i:i+chunk_size] if t.shape[0] != 1 else t
         chunks = tensor_tree_map(select_chunk, flattened_inputs)

         # Run the layer on the chunk
...
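Note: both `chunk_layer` tweaks rest on the same fact: `expand` is a free view, but `reshape` on an expanded view has to copy. Inputs whose batch dims are all 1 are now left un-expanded and re-used by every chunk through broadcasting:

```python
import torch

shared_bias = torch.randn(1, 256, 256)   # singleton batch dim
big_input = torch.randn(4096, 256, 256)  # the tensor actually being chunked

# Old path: shared_bias.expand(4096, -1, -1).reshape(-1, 256, 256)
# materializes a ~4096x copy. New path: leave it alone and broadcast.
chunk = big_input[0:512]
out = chunk + shared_bias
```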
@@ -14,8 +14,8 @@
 # limitations under the License.

 import os
-import sys
-sys.path.append("lib/conda/lib/python3.9/site-packages")
+#import sys
+#sys.path.append("lib/conda/lib/python3.9/site-packages")
 import math
 import pickle
@@ -26,9 +26,13 @@ import numpy as np
 from config import model_config
 from openfold.model.model import AlphaFold
-import openfold.np.protein as protein
+from openfold.np import residue_constants, protein
+
+#os.environ["OPENMM_DEFAULT_PLATFORM"] = "CPU"
+os.environ["OPENMM_DEFAULT_PLATFORM"] = "OpenCL"
+#os.environ["OPENMM_CPU_THREADS"] = "16"
 import openfold.np.relax.relax as relax
-from openfold.np import residue_constants
 from openfold.utils.import_weights import (
     import_jax_weights_,
 )
@@ -41,37 +45,35 @@ from openfold.utils.tensor_utils import (
 MODEL_NAME = "model_1"
 MODEL_DEVICE = "cuda:1"
 PARAM_PATH = "openfold/resources/params/params_model_1.npz"
-FEAT_PATH = "tests/test_data/sample_feats.pickle"
+#FEAT_PATH = "tests/test_data/sample_feats.pickle"
+FEAT_PATH = "prediction/1OJN_feats.pickle"

 config = model_config(MODEL_NAME)
 model = AlphaFold(config.model)
 model = model.eval()
 import_jax_weights_(model, PARAM_PATH)
-model_device = 'cuda:1'
-model = model.to(model_device)
+model = model.to(MODEL_DEVICE)

 with open(FEAT_PATH, "rb") as f:
     batch = pickle.load(f)

-batch = {k:torch.as_tensor(v, device=model_device) for k,v in batch.items()}
-
-longs = [
-    "aatype",
-    "template_aatype",
-    "extra_msa",
-    "residx_atom37_to_atom14",
-    "residx_atom14_to_atom37",
-]
-for l in longs:
-    batch[l] = batch[l].long()
-
-# Move the recycling dimension to the end
-move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0).contiguous()
-batch = tensor_tree_map(move_dim, batch)
-
-with torch.no_grad():
+with torch.no_grad():
+    batch = {k:torch.as_tensor(v, device=MODEL_DEVICE) for k,v in batch.items()}
+
+    longs = [
+        "aatype",
+        "template_aatype",
+        "extra_msa",
+        "residx_atom37_to_atom14",
+        "residx_atom14_to_atom37",
+    ]
+    for l in longs:
+        batch[l] = batch[l].long()
+
+    # Move the recycling dimension to the end
+    move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0).contiguous()
+    batch = tensor_tree_map(move_dim, batch)
+
     t = time.time()
     out = model(batch)
     print(f"Inference time: {time.time() - t}")
...
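Note: OpenMM consults the `OPENMM_DEFAULT_PLATFORM` environment variable when picking a default platform, so setting it before the relax step steers Amber minimization onto OpenCL rather than contending with PyTorch for the CUDA device, which is presumably the "OpenMM CUDA" fix in the commit title. Equivalent explicit selection, as a sketch (using the 2021-era `simtk` import path):

```python
from simtk.openmm import Platform

# Valid names include "Reference", "CPU", "CUDA", and "OpenCL".
platform = Platform.getPlatformByName("OpenCL")
```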