Unverified commit 92e6cf49, authored by shenggan, committed by GitHub

Merge pull request #18 from hpcaitech/sync_openfold_591d10d

sync with openfold 591d10d
Parents: 72444d5b, 5f052a0a
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import ml_collections as mlc
@@ -269,6 +255,7 @@ config = mlc.ConfigDict(
                 "clamp_prob": 0.9,
                 "max_distillation_msa_clusters": 1000,
                 "uniform_recycling": True,
+                "distillation_prob": 0.75,
             },
             "data_module": {
                 "use_small_bfd": False,
@@ -347,6 +334,7 @@ config = mlc.ConfigDict(
             "eps": eps, # 1e-6,
             "enabled": templates_enabled,
             "embed_angles": embed_template_torsion_angles,
+            "use_unit_vector": False,
         },
         "extra_msa": {
             "extra_msa_embedder": {
......
@@ -50,7 +50,7 @@ def cast_to_64bit_ints(protein):
 def make_one_hot(x, num_classes):
-    x_one_hot = torch.zeros(*x.shape, num_classes)
+    x_one_hot = torch.zeros(*x.shape, num_classes, device=x.device)
     x_one_hot.scatter_(-1, x.unsqueeze(-1), 1)
     return x_one_hot
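The `device=x.device` argument matters because `scatter_` requires destination, index, and source to live on the same device; zeros allocated on the CPU would raise a RuntimeError once `x` sits on a GPU. A quick standalone illustration (assumes a CUDA device is available):

    import torch

    if torch.cuda.is_available():
        x = torch.randint(0, 22, (8,), device="cuda")
        # Allocating on x.device keeps all of scatter_'s operands together;
        # plain torch.zeros(*x.shape, 22) would live on the CPU and fail here.
        x_one_hot = torch.zeros(*x.shape, 22, device=x.device)
        x_one_hot.scatter_(-1, x.unsqueeze(-1), 1)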
@@ -92,9 +92,9 @@ def fix_templates_aatype(protein):
     )
     # Map hhsearch-aatype to our aatype.
     new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
-    new_order = torch.tensor(new_order_list, dtype=torch.int64).expand(
-        num_templates, -1
-    )
+    new_order = torch.tensor(
+        new_order_list, dtype=torch.int64, device=protein["aatype"].device,
+    ).expand(num_templates, -1)
     protein["template_aatype"] = torch.gather(
         new_order, 1, index=protein["template_aatype"]
     )
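The remap works by indexing a per-template lookup row with `torch.gather` along the residue dimension. A toy sketch of the idea (the 4-letter mapping below is illustrative, not the real `MAP_HHBLITS_AATYPE_TO_OUR_AATYPE`):

    import torch

    mapping = torch.tensor([2, 0, 3, 1])                    # old index i -> new index
    template_aatype = torch.tensor([[0, 1, 2], [3, 3, 0]])  # [num_templates, num_res]
    new_order = mapping.expand(template_aatype.shape[0], -1)
    remapped = torch.gather(new_order, 1, template_aatype)  # [[2, 0, 3], [1, 1, 2]]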
@@ -106,7 +106,8 @@ def correct_msa_restypes(protein):
     """Correct MSA restype to have the same order as rc."""
     new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
     new_order = torch.tensor(
-        [new_order_list] * protein["msa"].shape[1], dtype=protein["msa"].dtype
+        [new_order_list] * protein["msa"].shape[1],
+        device=protein["msa"].device,
     ).transpose(0, 1)
     protein["msa"] = torch.gather(new_order, 0, protein["msa"])
@@ -187,7 +188,10 @@ def sample_msa(protein, max_seq, keep_extra, seed=None):
     if seed is not None:
         g.manual_seed(seed)
     shuffled = torch.randperm(num_seq - 1, generator=g) + 1
-    index_order = torch.cat((torch.tensor([0]), shuffled), dim=0)
+    index_order = torch.cat(
+        (torch.tensor([0], device=shuffled.device), shuffled),
+        dim=0
+    )
     num_sel = min(max_seq, num_seq)
     sel_seq, not_sel_seq = torch.split(
         index_order, [num_sel, num_seq - num_sel]
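For context, this transform keeps the query sequence (row 0) in place and shuffles only rows 1..N-1, optionally under a fixed seed for reproducibility. A standalone illustration of the pattern:

    import torch

    g = torch.Generator()
    g.manual_seed(42)
    num_seq = 10
    shuffled = torch.randperm(num_seq - 1, generator=g) + 1        # rows 1..9, random order
    index_order = torch.cat((torch.tensor([0]), shuffled), dim=0)  # query stays first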
@@ -242,7 +246,7 @@ def delete_extra_msa(protein):
 def block_delete_msa(protein, config):
     num_seq = protein["msa"].shape[0]
     block_num_seq = torch.floor(
-        torch.tensor(num_seq, dtype=torch.float32)
+        torch.tensor(num_seq, dtype=torch.float32, device=protein["msa"].device)
         * config.msa_fraction_per_block
     ).to(torch.int32)
@@ -275,7 +279,11 @@ def block_delete_msa(protein, config):
 @curry1
 def nearest_neighbor_clusters(protein, gap_agreement_weight=0.0):
     weights = torch.cat(
-        [torch.ones(21), gap_agreement_weight * torch.ones(1), torch.zeros(1)],
+        [
+            torch.ones(21, device=protein["msa"].device),
+            gap_agreement_weight * torch.ones(1, device=protein["msa"].device),
+            torch.zeros(1, device=protein["msa"].device)
+        ],
         0,
     )
@@ -324,7 +332,10 @@ def unsorted_segment_sum(data, segment_ids, num_segments):
     )
     segment_ids = segment_ids.expand(data.shape)
     shape = [num_segments] + list(data.shape[1:])
-    tensor = torch.zeros(*shape).scatter_add_(0, segment_ids, data.float())
+    tensor = (
+        torch.zeros(*shape, device=segment_ids.device)
+        .scatter_add_(0, segment_ids, data.float())
+    )
     tensor = tensor.type(data.dtype)
     return tensor
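`unsorted_segment_sum` mirrors the TensorFlow op of the same name: each row of `data` is accumulated into one of `num_segments` buckets selected by `segment_ids`, via `scatter_add_`. A toy call (values chosen for illustration):

    import torch

    data = torch.tensor([[1., 2.], [3., 4.], [5., 6.]])
    segment_ids = torch.tensor([0, 1, 0])
    out = unsorted_segment_sum(data, segment_ids, num_segments=2)
    # out == tensor([[6., 8.], [3., 4.]]): rows 0 and 2 summed into bucket 0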
@@ -401,7 +412,7 @@ def make_pseudo_beta(protein, prefix=""):
 @curry1
 def add_constant_field(protein, key, value):
-    protein[key] = torch.tensor(value)
+    protein[key] = torch.tensor(value, device=protein["msa"].device)
     return protein
@@ -431,7 +442,11 @@ def make_hhblits_profile(protein):
 def make_masked_msa(protein, config, replace_fraction):
     """Create data for BERT on raw MSA."""
     # Add a random amino acid uniformly.
-    random_aa = torch.tensor([0.05] * 20 + [0.0, 0.0], dtype=torch.float32)
+    random_aa = torch.tensor(
+        [0.05] * 20 + [0.0, 0.0],
+        dtype=torch.float32,
+        device=protein["aatype"].device
+    )
     categorical_probs = (
         config.uniform_prob * random_aa
@@ -644,7 +659,11 @@ def make_atom14_masks(protein):
 def make_atom14_masks_np(batch):
-    batch = tree_map(lambda n: torch.tensor(n), batch, np.ndarray)
+    batch = tree_map(
+        lambda n: torch.tensor(n, device=batch["aatype"].device),
+        batch,
+        np.ndarray
+    )
     out = make_atom14_masks(batch)
     out = tensor_tree_map(lambda t: np.array(t), out)
     return out
......
@@ -131,6 +131,7 @@ class AlphaFold(nn.Module):
             # [*, S_t, N, N, C_t]
             t = build_template_pair_feat(
                 single_template_feats,
+                use_unit_vector=self.config.template.use_unit_vector,
                 inf=self.config.template.inf,
                 eps=self.config.template.eps,
                 **self.config.template.distogram,
......
@@ -165,8 +165,6 @@ class RecyclingEmbedder(nn.Module):
         self.no_bins = no_bins
         self.inf = inf
-        self.bins = None
         self.linear = Linear(self.no_bins, self.c_z)
         self.layer_norm_m = LayerNorm(self.c_m)
         self.layer_norm_z = LayerNorm(self.c_z)
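A note on why the cached `self.bins` was dropped: a tensor stashed as a plain attribute is neither a parameter nor a registered buffer, so `module.to(device)` and dtype casts never move it, and it stays wherever the first forward pass created it. Recomputing the bins per call (as the next hunk does) lets them follow the inputs. A minimal reproduction of the pitfall:

    import torch
    import torch.nn as nn

    class Cached(nn.Module):
        def forward(self, z):
            if getattr(self, "bins", None) is None:
                self.bins = torch.linspace(2.0, 20.0, 15)  # created on first call only
            return self.bins

    m = Cached()
    m(torch.zeros(1))   # first call caches CPU bins
    if torch.cuda.is_available():
        m.to("cuda")    # moves parameters and buffers only; m.bins stays on the CPU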
@@ -191,8 +189,7 @@ class RecyclingEmbedder(nn.Module):
             z:
                 [*, N_res, N_res, C_z] pair embedding update
         """
-        if self.bins is None:
-            self.bins = torch.linspace(
+        bins = torch.linspace(
             self.min_bin,
             self.max_bin,
             self.no_bins,
@@ -207,7 +204,7 @@ class RecyclingEmbedder(nn.Module):
         # This squared method might become problematic in FP16 mode.
         # I'm using it because my homegrown method had a stubborn discrepancy I
         # couldn't find in time.
-        squared_bins = self.bins ** 2
+        squared_bins = bins ** 2
         upper = torch.cat(
             [squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1
         )
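The squared-bin method the comment refers to bins squared distances against squared edges, avoiding a square root over the full [*, N, N] distance map. A compact sketch of the binning step (tensor names and sizes are illustrative):

    import torch

    x = torch.randn(10, 3)  # toy coordinates
    squared_bins = torch.linspace(3.25, 20.75, 15) ** 2
    upper = torch.cat([squared_bins[1:], squared_bins.new_tensor([1e8])], dim=-1)
    d2 = torch.sum((x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdim=True)
    dgram = ((d2 > squared_bins) * (d2 < upper)).float()  # [N, N, no_bins] one-hot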
......
@@ -90,7 +90,10 @@ def build_template_angle_feat(template_feats):
 def build_template_pair_feat(
-    batch, min_bin, max_bin, no_bins, eps=1e-20, inf=1e8
+    batch,
+    min_bin, max_bin, no_bins,
+    use_unit_vector=False,
+    eps=1e-20, inf=1e8
 ):
     template_mask = batch["template_pseudo_beta_mask"]
     template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
@@ -101,7 +104,7 @@ def build_template_pair_feat(
         (tpb[..., None, :] - tpb[..., None, :, :]) ** 2, dim=-1, keepdim=True
     )
     lower = torch.linspace(min_bin, max_bin, no_bins, device=tpb.device) ** 2
-    upper = torch.cat([lower[:-1], lower.new_tensor([inf])], dim=-1)
+    upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
     dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)
     to_concat = [dgram, template_mask_2d[..., None]]
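This one-character change is a genuine bug fix: with `lower[:-1]`, each bin's upper edge equaled its own lower edge (`upper[i] == lower[i]`), so every bin except the last could never fire. With `lower[1:]`, consecutive edges tile the range. A tiny numeric check:

    import torch

    lower = torch.linspace(2.0, 4.0, 3) ** 2                              # [ 4.,  9., 16.]
    old_upper = torch.cat([lower[:-1], lower.new_tensor([1e8])], dim=-1)  # [ 4.,  9., 1e8]
    new_upper = torch.cat([lower[1:], lower.new_tensor([1e8])], dim=-1)   # [ 9., 16., 1e8]
    d2 = torch.tensor(10.0)
    print((d2 > lower) * (d2 < old_upper))  # [False, False, False]: no bin fires
    print((d2 > lower) * (d2 < new_upper))  # [False,  True, False]: lands in bin 1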
@@ -143,6 +146,10 @@ def build_template_pair_feat(
     inv_distance_scalar = inv_distance_scalar * template_mask_2d
     unit_vector = rigid_vec * inv_distance_scalar[..., None]
+    if(not use_unit_vector):
+        unit_vector = unit_vector * 0.
+
     to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1))
     to_concat.append(template_mask_2d[..., None])
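Design note, hedged: zeroing `unit_vector` rather than skipping the concatenation keeps the template pair feature's channel count, and therefore the downstream linear-layer shapes, identical whether or not the vectors are used; defaulting `use_unit_vector` to False matches AlphaFold, which ignores these channels in its standard presets.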
......
@@ -1352,8 +1352,8 @@ class Rigid:
         c2_rots[..., 0, 0] = cos_c2
         c2_rots[..., 0, 2] = sin_c2
         c2_rots[..., 1, 1] = 1
-        c1_rots[..., 2, 0] = -1 * sin_c2
-        c1_rots[..., 2, 2] = cos_c2
+        c2_rots[..., 2, 0] = -1 * sin_c2
+        c2_rots[..., 2, 2] = cos_c2
         c_rots = rot_matmul(c2_rots, c1_rots)
         n_xyz = rot_vec_mul(c_rots, n_xyz)
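The fix repairs a copy-paste slip: the bottom row of the second Euler rotation was being written into `c1_rots` instead of `c2_rots`. Completed, `c2_rots` is a rotation about the y-axis, which a standalone check confirms (toy angle, independent of the `Rigid` class):

    import math
    import torch

    angle = math.pi / 2
    cos_c2, sin_c2 = math.cos(angle), math.sin(angle)
    c2_rots = torch.zeros(3, 3)
    c2_rots[0, 0], c2_rots[0, 2] = cos_c2, sin_c2
    c2_rots[1, 1] = 1
    c2_rots[2, 0], c2_rots[2, 2] = -sin_c2, cos_c2   # the two assignments this commit fixes
    print(c2_rots @ torch.tensor([1.0, 0.0, 0.0]))   # ~[0., 0., -1.], a proper y-axis rotation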
......