Commit 591d10d2 authored by Gustaf Ahdritz

Fix template bugs

parent 8c169fb6
@@ -255,6 +255,7 @@ config = mlc.ConfigDict(
             "clamp_prob": 0.9,
             "max_distillation_msa_clusters": 1000,
             "uniform_recycling": True,
+            "distillation_prob": 0.75,
         },
         "data_module": {
             "use_small_bfd": False,
@@ -333,6 +334,7 @@ config = mlc.ConfigDict(
                 "eps": eps,  # 1e-6,
                 "enabled": templates_enabled,
                 "embed_angles": embed_template_torsion_angles,
+                "use_unit_vector": False,
             },
             "extra_msa": {
                 "extra_msa_embedder": {
...
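For context, ml_collections.ConfigDict exposes nested dict keys as attributes, which is how the model later reads flags like the two added above. A minimal sketch (the nesting here is illustrative, not the repo's exact config tree):

import ml_collections as mlc

# Illustrative nesting only; the real tree lives in the repo's config.py.
config = mlc.ConfigDict({
    "train": {"distillation_prob": 0.75},
    "template": {"use_unit_vector": False},
})
assert config.template.use_unit_vector is False  # read as attributes
assert config.train.distillation_prob == 0.75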
@@ -50,7 +50,7 @@ def cast_to_64bit_ints(protein):
 def make_one_hot(x, num_classes):
-    x_one_hot = torch.zeros(*x.shape, num_classes)
+    x_one_hot = torch.zeros(*x.shape, num_classes, device=x.device)
     x_one_hot.scatter_(-1, x.unsqueeze(-1), 1)
     return x_one_hot
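The device fixes throughout this file all address the same failure mode; a minimal standalone sketch of the pattern (not part of the commit):

import torch

def make_one_hot(x, num_classes):
    # Allocating on x.device keeps the zeros tensor aligned with the index
    # tensor; a plain torch.zeros(...) lands on the CPU, and scatter_ then
    # raises a device-mismatch RuntimeError for CUDA inputs.
    x_one_hot = torch.zeros(*x.shape, num_classes, device=x.device)
    x_one_hot.scatter_(-1, x.unsqueeze(-1), 1)
    return x_one_hot

if torch.cuda.is_available():
    aatype = torch.randint(0, 21, (8,), device="cuda")
    one_hot = make_one_hot(aatype, 22)  # both tensors live on cuda:0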
@@ -92,9 +92,9 @@ def fix_templates_aatype(protein):
     )
     # Map hhsearch-aatype to our aatype.
     new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
-    new_order = torch.tensor(new_order_list, dtype=torch.int64).expand(
-        num_templates, -1
-    )
+    new_order = torch.tensor(
+        new_order_list, dtype=torch.int64, device=protein["aatype"].device,
+    ).expand(num_templates, -1)
     protein["template_aatype"] = torch.gather(
         new_order, 1, index=protein["template_aatype"]
     )
@@ -106,7 +106,8 @@ def correct_msa_restypes(protein):
     """Correct MSA restype to have the same order as rc."""
     new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
     new_order = torch.tensor(
-        [new_order_list] * protein["msa"].shape[1], dtype=protein["msa"].dtype
+        [new_order_list] * protein["msa"].shape[1],
+        device=protein["msa"].device,
     ).transpose(0, 1)
     protein["msa"] = torch.gather(new_order, 0, protein["msa"])
@@ -187,7 +188,10 @@ def sample_msa(protein, max_seq, keep_extra, seed=None):
     if seed is not None:
         g.manual_seed(seed)
     shuffled = torch.randperm(num_seq - 1, generator=g) + 1
-    index_order = torch.cat((torch.tensor([0]), shuffled), dim=0)
+    index_order = torch.cat(
+        (torch.tensor([0], device=shuffled.device), shuffled),
+        dim=0
+    )
     num_sel = min(max_seq, num_seq)
     sel_seq, not_sel_seq = torch.split(
         index_order, [num_sel, num_seq - num_sel]
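For reference, a standalone sketch of the sampling logic this hunk touches: row 0 (by convention the query sequence) is always kept first, and the seeded generator makes the shuffle reproducible.

import torch

num_seq, max_seq, seed = 6, 4, 17
g = torch.Generator()
g.manual_seed(seed)                                  # reproducible shuffle
shuffled = torch.randperm(num_seq - 1, generator=g) + 1
index_order = torch.cat(
    (torch.tensor([0], device=shuffled.device), shuffled), dim=0
)
num_sel = min(max_seq, num_seq)
sel_seq, not_sel_seq = torch.split(index_order, [num_sel, num_seq - num_sel])
# sel_seq indexes the clustered MSA; not_sel_seq feeds the extra-MSA stack
# when keep_extra is True.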
@@ -242,7 +246,7 @@ def delete_extra_msa(protein):
 def block_delete_msa(protein, config):
     num_seq = protein["msa"].shape[0]
     block_num_seq = torch.floor(
-        torch.tensor(num_seq, dtype=torch.float32)
+        torch.tensor(num_seq, dtype=torch.float32, device=protein["msa"].device)
         * config.msa_fraction_per_block
     ).to(torch.int32)
@@ -275,7 +279,11 @@ def block_delete_msa(protein, config):
 @curry1
 def nearest_neighbor_clusters(protein, gap_agreement_weight=0.0):
     weights = torch.cat(
-        [torch.ones(21), gap_agreement_weight * torch.ones(1), torch.zeros(1)],
+        [
+            torch.ones(21, device=protein["msa"].device),
+            gap_agreement_weight * torch.ones(1, device=protein["msa"].device),
+            torch.zeros(1, device=protein["msa"].device)
+        ],
         0,
     )
@@ -324,7 +332,10 @@ def unsorted_segment_sum(data, segment_ids, num_segments):
     )
     segment_ids = segment_ids.expand(data.shape)
     shape = [num_segments] + list(data.shape[1:])
-    tensor = torch.zeros(*shape).scatter_add_(0, segment_ids, data.float())
+    tensor = (
+        torch.zeros(*shape, device=segment_ids.device)
+        .scatter_add_(0, segment_ids, data.float())
+    )
     tensor = tensor.type(data.dtype)
     return tensor
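unsorted_segment_sum mirrors TensorFlow's op of the same name: rows of data with the same segment id are summed into one output row. A toy run of the scatter_add_ pattern:

import torch

data = torch.tensor([[1., 1.], [2., 2.], [3., 3.]])
segment_ids = torch.tensor([0, 1, 0])[:, None].expand(data.shape)
num_segments = 2
shape = [num_segments] + list(data.shape[1:])
out = torch.zeros(*shape, device=segment_ids.device).scatter_add_(
    0, segment_ids, data.float()
)
# Rows 0 and 2 share segment 0, row 1 is segment 1:
# out == tensor([[4., 4.], [2., 2.]])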
@@ -401,7 +412,7 @@ def make_pseudo_beta(protein, prefix=""):
 @curry1
 def add_constant_field(protein, key, value):
-    protein[key] = torch.tensor(value)
+    protein[key] = torch.tensor(value, device=protein["msa"].device)
     return protein
@@ -431,7 +442,11 @@ def make_hhblits_profile(protein):
 def make_masked_msa(protein, config, replace_fraction):
     """Create data for BERT on raw MSA."""
     # Add a random amino acid uniformly.
-    random_aa = torch.tensor([0.05] * 20 + [0.0, 0.0], dtype=torch.float32)
+    random_aa = torch.tensor(
+        [0.05] * 20 + [0.0, 0.0],
+        dtype=torch.float32,
+        device=protein["aatype"].device
+    )
     categorical_probs = (
         config.uniform_prob * random_aa
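A hedged sketch of how random_aa enters the BERT-style corruption distribution (toy shapes; the real profile comes from make_hhblits_profile, and the probabilities follow the AlphaFold supplement): a masked position is replaced by a uniform residue, the profile, or itself, with the leftover mass going to a dedicated [MASK] token.

import torch

uniform_prob, profile_prob, same_prob = 0.1, 0.1, 0.1
random_aa = torch.tensor([0.05] * 20 + [0.0, 0.0], dtype=torch.float32)
msa = torch.randint(0, 22, (4, 7))                 # toy [N_seq, N_res]
msa_one_hot = torch.nn.functional.one_hot(msa, 22).float()
hhblits_profile = msa_one_hot.mean(dim=0)          # toy [N_res, 22] profile
categorical_probs = (
    uniform_prob * random_aa
    + profile_prob * hhblits_profile
    + same_prob * msa_one_hot
)                                                  # [N_seq, N_res, 22]
mask_prob = 1.0 - uniform_prob - profile_prob - same_prob  # goes to [MASK]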
@@ -644,7 +659,11 @@ def make_atom14_masks(protein):
 def make_atom14_masks_np(batch):
-    batch = tree_map(lambda n: torch.tensor(n), batch, np.ndarray)
+    batch = tree_map(
+        lambda n: torch.tensor(n, device=batch["aatype"].device),
+        batch,
+        np.ndarray
+    )
     out = make_atom14_masks(batch)
     out = tensor_tree_map(lambda t: np.array(t), out)
     return out
...
@@ -40,10 +40,11 @@ def np_to_tensor_dict(
     Returns:
         A dictionary of features mapping feature names to features. Only the given
         features are returned, all other ones are filtered out.
     """
     tensor_dict = {
         k: torch.tensor(v) for k, v in np_example.items() if k in features
     }
     return tensor_dict
...
@@ -165,8 +165,6 @@ class RecyclingEmbedder(nn.Module):
         self.no_bins = no_bins
         self.inf = inf

-        self.bins = None
-
         self.linear = Linear(self.no_bins, self.c_z)
         self.layer_norm_m = LayerNorm(self.c_m)
         self.layer_norm_z = LayerNorm(self.c_z)
@@ -191,15 +189,14 @@ class RecyclingEmbedder(nn.Module):
             z:
                 [*, N_res, N_res, C_z] pair embedding update
         """
-        if self.bins is None:
-            self.bins = torch.linspace(
-                self.min_bin,
-                self.max_bin,
-                self.no_bins,
-                dtype=x.dtype,
-                device=x.device,
-                requires_grad=False,
-            )
+        bins = torch.linspace(
+            self.min_bin,
+            self.max_bin,
+            self.no_bins,
+            dtype=x.dtype,
+            device=x.device,
+            requires_grad=False,
+        )

         # [*, N, C_m]
         m_update = self.layer_norm_m(m)
@@ -207,7 +204,7 @@ class RecyclingEmbedder(nn.Module):
         # This squared method might become problematic in FP16 mode.
         # I'm using it because my homegrown method had a stubborn discrepancy I
        # couldn't find in time.
-        squared_bins = self.bins ** 2
+        squared_bins = bins ** 2
         upper = torch.cat(
             [squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1
         )
...
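Recomputing bins on every forward pass, rather than caching them on first use, keeps their dtype and device in sync if the module is later moved or run under autocast (a tensor stashed as a plain attribute is not touched by .to()). A standalone sketch of the squared-bin distogram they feed, with assumed toy shapes:

import torch

min_bin, max_bin, no_bins, inf = 3.25, 20.75, 15, 1e8
x = torch.randn(10, 3)  # toy CB coordinates, [N_res, 3]
bins = torch.linspace(min_bin, max_bin, no_bins, dtype=x.dtype, device=x.device)
squared_bins = bins ** 2
upper = torch.cat([squared_bins[1:], squared_bins.new_tensor([inf])], dim=-1)
# Squared pairwise distances, binned without ever taking a sqrt:
d = torch.sum((x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdim=True)
dgram = ((d > squared_bins) * (d < upper)).type(x.dtype)  # [N_res, N_res, no_bins]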
@@ -131,6 +131,7 @@ class AlphaFold(nn.Module):
                 # [*, S_t, N, N, C_t]
                 t = build_template_pair_feat(
                     single_template_feats,
+                    use_unit_vector=self.config.template.use_unit_vector,
                     inf=self.config.template.inf,
                     eps=self.config.template.eps,
                     **self.config.template.distogram,
...
@@ -90,7 +90,10 @@ def build_template_angle_feat(template_feats):
 def build_template_pair_feat(
-    batch, min_bin, max_bin, no_bins, eps=1e-20, inf=1e8
+    batch,
+    min_bin, max_bin, no_bins,
+    use_unit_vector=False,
+    eps=1e-20, inf=1e8
 ):
     template_mask = batch["template_pseudo_beta_mask"]
     template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
@@ -101,7 +104,7 @@ def build_template_pair_feat(
         (tpb[..., None, :] - tpb[..., None, :, :]) ** 2, dim=-1, keepdim=True
     )
     lower = torch.linspace(min_bin, max_bin, no_bins, device=tpb.device) ** 2
-    upper = torch.cat([lower[:-1], lower.new_tensor([inf])], dim=-1)
+    upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
     dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)
     to_concat = [dgram, template_mask_2d[..., None]]
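The lower[:-1] to lower[1:] change fixes an off-by-one in the distogram bin edges: bin i should span [lower[i], lower[i+1]), with the final bin open-ended. A toy check:

import torch

min_bin, max_bin, no_bins, inf = 1.0, 3.0, 3, 1e8
lower = torch.linspace(min_bin, max_bin, no_bins) ** 2   # tensor([1., 4., 9.])
upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
# upper == tensor([4., 9., 1e8]); with the old lower[:-1], the "upper" edges
# ([1., 4., 1e8]) sat at or below their own lower edges, so the interior
# bins could never fire.
d2 = torch.tensor(5.0)                                   # a squared distance
one_hot = ((d2 > lower) * (d2 < upper)).float()          # tensor([0., 1., 0.])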
@@ -143,6 +146,10 @@ def build_template_pair_feat(
     inv_distance_scalar = inv_distance_scalar * template_mask_2d
     unit_vector = rigid_vec * inv_distance_scalar[..., None]
+
+    if(not use_unit_vector):
+        unit_vector = unit_vector * 0.
+
     to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1))
     to_concat.append(template_mask_2d[..., None])
...
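Zeroing rather than dropping the unit-vector channels keeps the template feature dimension fixed; for reference, DeepMind's AlphaFold config likewise defaults use_template_unit_vector to False. A small sketch of the flag's effect:

import torch

unit_vector = torch.randn(4, 4, 3)        # toy [N_res, N_res, 3] directions
use_unit_vector = False
if not use_unit_vector:
    unit_vector = unit_vector * 0.        # channels remain, contents zeroed
assert unit_vector.abs().sum() == 0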
@@ -1352,8 +1352,8 @@ class Rigid:
         c2_rots[..., 0, 0] = cos_c2
         c2_rots[..., 0, 2] = sin_c2
         c2_rots[..., 1, 1] = 1
-        c1_rots[..., 2, 0] = -1 * sin_c2
-        c1_rots[..., 2, 2] = cos_c2
+        c2_rots[..., 2, 0] = -1 * sin_c2
+        c2_rots[..., 2, 2] = cos_c2

         c_rots = rot_matmul(c2_rots, c1_rots)
         n_xyz = rot_vec_mul(c_rots, n_xyz)
...
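The last two assignments previously wrote into c1_rots, leaving c2_rots without its bottom row and corrupting the composed frame. A standalone sanity check that the fixed block builds a proper rotation about the y-axis:

import torch

cos_c2 = torch.cos(torch.tensor(0.3))
sin_c2 = torch.sin(torch.tensor(0.3))
c2_rots = torch.zeros(3, 3)
c2_rots[0, 0] = cos_c2
c2_rots[0, 2] = sin_c2
c2_rots[1, 1] = 1
c2_rots[2, 0] = -1 * sin_c2
c2_rots[2, 2] = cos_c2
# Orthonormal with determinant +1, i.e. a valid rotation matrix:
assert torch.allclose(c2_rots @ c2_rots.T, torch.eye(3), atol=1e-6)
assert torch.isclose(torch.linalg.det(c2_rots), torch.tensor(1.0), atol=1e-6)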