Commit a1c29028 authored by zhangqha

update uni-fold
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Python wrapper for Kalign."""
import os
import subprocess
from typing import Sequence
from absl import logging
from . import utils
def _to_a3m(sequences: Sequence[str]) -> str:
"""Converts sequences to an a3m file."""
names = ["sequence %d" % i for i in range(1, len(sequences) + 1)]
a3m = []
for sequence, name in zip(sequences, names):
a3m.append(">" + name + "\n")
a3m.append(sequence + "\n")
return "".join(a3m)
class Kalign:
"""Python wrapper of the Kalign binary."""
def __init__(self, *, binary_path: str):
"""Initializes the Python Kalign wrapper.
Args:
binary_path: The path to the Kalign binary.
Raises:
RuntimeError: If the Kalign binary is not found at the given path.
"""
self.binary_path = binary_path
def align(self, sequences: Sequence[str]) -> str:
"""Aligns the sequences and returns the alignment in A3M string.
Args:
sequences: A list of query sequence strings. Each sequence must be at
least 6 residues long (a Kalign requirement). Note that the order in
which the sequences are given may slightly alter the output, as a
different alignment tree may be constructed.
Returns:
A string with the alignment in a3m format.
Raises:
RuntimeError: If Kalign fails.
ValueError: If any of the sequences is less than 6 residues long.
"""
logging.info("Aligning %d sequences", len(sequences))
for s in sequences:
if len(s) < 6:
raise ValueError(
"Kalign requires all sequences to be at least 6 "
"residues long. Got %s (%d residues)." % (s, len(s))
)
with utils.tmpdir_manager() as query_tmp_dir:
input_fasta_path = os.path.join(query_tmp_dir, "input.fasta")
output_a3m_path = os.path.join(query_tmp_dir, "output.a3m")
with open(input_fasta_path, "w") as f:
f.write(_to_a3m(sequences))
cmd = [
self.binary_path,
"-i",
input_fasta_path,
"-o",
output_a3m_path,
"-format",
"fasta",
]
logging.info('Launching subprocess "%s"', " ".join(cmd))
process = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
with utils.timing("Kalign query"):
stdout, stderr = process.communicate()
retcode = process.wait()
logging.info(
"Kalign stdout:\n%s\n\nstderr:\n%s\n",
stdout.decode("utf-8"),
stderr.decode("utf-8"),
)
if retcode:
raise RuntimeError(
"Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n"
% (stdout.decode("utf-8"), stderr.decode("utf-8"))
)
with open(output_a3m_path) as f:
a3m = f.read()
return a3m
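# Example usage (a minimal sketch; the binary path and sequences below are
# hypothetical):
#
#   kalign = Kalign(binary_path="/usr/bin/kalign")
#   a3m = kalign.align(["MKTAYIAKQRQISFVK", "MKTAYIARQRQISFVK"])
#   # `a3m` now holds the alignment as an A3M/FASTA string.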
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common utilities for data pipeline tools."""
import contextlib
import shutil
import tempfile
import time
from typing import Optional
from absl import logging
@contextlib.contextmanager
def tmpdir_manager(base_dir: Optional[str] = None):
"""Context manager that deletes a temporary directory on exit."""
tmpdir = tempfile.mkdtemp(dir=base_dir)
try:
yield tmpdir
finally:
shutil.rmtree(tmpdir, ignore_errors=True)
@contextlib.contextmanager
def timing(msg: str):
"""Context manager that logs the wall-clock duration of the wrapped block."""
logging.info("Started %s", msg)
tic = time.time()
yield
toc = time.time()
logging.info("Finished %s in %.3f seconds", msg, toc - tic)
from absl import logging
import json
import os
from typing import Mapping, Sequence
from unifold.data import protein
def get_chain_id_map(
sequences: Sequence[str],
descriptions: Sequence[str],
):
"""
Makes a mapping from PDB-format chain ID to sequence and description,
and parses the order of multi-chains
"""
unique_seqs = []
for seq in sequences:
if seq not in unique_seqs:
unique_seqs.append(seq)
chain_id_map = {
chain_id: {"descriptions": [], "sequence": seq}
for chain_id, seq in zip(protein.PDB_CHAIN_IDS, unique_seqs)
}
chain_order = []
for seq, des in zip(sequences, descriptions):
chain_id = protein.PDB_CHAIN_IDS[unique_seqs.index(seq)]
chain_id_map[chain_id]["descriptions"].append(des)
chain_order.append(chain_id)
return chain_id_map, chain_order
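# Example (illustrative): a homodimer plus one distinct chain. Duplicate
# sequences share one chain entry; chain IDs are drawn from
# protein.PDB_CHAIN_IDS (conventionally "A", "B", ...).
# >>> chain_id_map, chain_order = get_chain_id_map(
# ...     ["SEQA", "SEQA", "SEQB"], ["des1", "des2", "des3"])
# >>> chain_order
# ['A', 'A', 'B']
# >>> chain_id_map["A"]["descriptions"]
# ['des1', 'des2']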
def divide_multi_chains(
fasta_name: str,
output_dir_base: str,
sequences: Sequence[str],
descriptions: Sequence[str],
):
"""
Divides the multi-chains fasta into several single fasta files and
records multi-chains mapping information.
"""
if len(sequences) != len(descriptions):
raise ValueError(
"sequences and descriptions must have equal length. "
f"Got {len(sequences)} != {len(descriptions)}."
)
if len(sequences) > protein.PDB_MAX_CHAINS:
raise ValueError(
"Cannot process more chains than the PDB format supports. "
f"Got {len(sequences)} chains."
)
chain_id_map, chain_order = get_chain_id_map(sequences, descriptions)
output_dir = os.path.join(output_dir_base, fasta_name)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
chain_id_map_path = os.path.join(output_dir, "chain_id_map.json")
with open(chain_id_map_path, "w") as f:
json.dump(chain_id_map, f, indent=4, sort_keys=True)
chain_order_path = os.path.join(output_dir, "chains.txt")
with open(chain_order_path, "w") as f:
f.write(" ".join(chain_order))
logging.info(
"Mapping multi-chains fasta with chain order: %s", " ".join(chain_order)
)
temp_names = []
temp_paths = []
for chain_id in chain_id_map.keys():
temp_name = fasta_name + "_{}".format(chain_id)
temp_path = os.path.join(output_dir, temp_name + ".fasta")
des = "chain_{}".format(chain_id)
seq = chain_id_map[chain_id]["sequence"]
with open(temp_path, "w") as f:
f.write(">" + des + "\n" + seq)
temp_names.append(temp_name)
temp_paths.append(temp_path)
return temp_names, temp_paths
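# Example usage (illustrative sketch; paths are hypothetical):
#
#   names, paths = divide_multi_chains(
#       "target", "./out", ["SEQA", "SEQB"], ["chain 1", "chain 2"])
#   # Writes ./out/target/chain_id_map.json and ./out/target/chains.txt,
#   # plus one single-chain FASTA per unique sequence (target_A.fasta, ...).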
# Copyright 2022 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation for UF-Symmetry."""
from .model import UFSymmetry
from .config import uf_symmetry_config
from .assemble import assembly_from_prediction
from .dataset import load_and_process_symmetry
import torch
import numpy as np
from unifold.data.protein import Protein
from ..modules.featurization import atom14_to_atom37
from ..modules.frame import Frame
def expand_frames(frames: Frame, ops: Frame) -> torch.Tensor:
"""
Args:
frames: Frame of shape [*, NR]
ops: Frame of shape [NG]
Returns:
Tensor of shape [*, NG*NR, 4, 4]
"""
batch_shape = frames.shape[:-1]
ret = ops[..., None].compose(frames[..., None, :]).to_tensor_4x4()
ret = ret.reshape(*batch_shape, -1, 4, 4)
return ret
def expand_sc_frames(sc_frames: Frame, ops: Frame) -> torch.Tensor:
"""
Args:
sc_frames: Frame of shape [*, NR, NF] (NF rigid groups per residue)
ops: Frame of shape [NG]
Returns:
Tensor of shape [*, NG*NR, NF, 4, 4]
"""
batch_shape = sc_frames.shape[:-2]
ret = ops[..., None, None].compose(sc_frames[..., None, :, :]).to_tensor_4x4()
ret = ret.reshape(*batch_shape, -1, sc_frames.shape[-1], 4, 4)
return ret
def expand_atom_positions(positions: torch.Tensor, ops: Frame) -> torch.Tensor:
"""
Args:
positions: Tensor of shape [*, NR, 37, 3]
ops: Frame of shape [NG]
Returns:
Tensor of shape [*, NG*NR, 37, 3]
"""
batch_shape = positions.shape[:-3]
position_shape = positions.shape[-2:]
ret = ops[..., None, None].apply(positions[..., None, :, :, :])
ret = ret.reshape(*batch_shape, -1, *position_shape)
return ret
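# A minimal sketch of the expansion semantics (using the Frame API imported
# above): entry (g, r) of the output composes operator g with the frame (or
# position) at residue r, and the (NG, NR) pair of dims is then flattened.
# With NG identity operators the expansion simply tiles the input:
#
#   ops = Frame.from_tensor_4x4(torch.eye(4)[None].repeat(2, 1, 1))     # NG=2
#   frames = Frame.from_tensor_4x4(torch.eye(4)[None].repeat(5, 1, 1))  # NR=5
#   expanded = expand_frames(frames, ops)  # tensor of shape [10, 4, 4]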
def expand_symmetry(sm_out, batch):
ops = Frame.from_tensor_4x4(batch["symmetry_opers"][-1, 0, ...].float()) # reduce recycle and batch dims.
num_expand = ops.shape[0]
frames = Frame.from_tensor_4x4(sm_out["frames"].float())
sidechain_frames = Frame.from_tensor_4x4(sm_out["sidechain_frames"].float())
positions = sm_out["positions"].float()
def repeat_fn(tensor, repeats, dim):
shape = [1 for _ in tensor.shape]
shape[dim] = repeats
return tensor.repeat(shape)
symm_out = {
"frames": expand_frames(frames, ops),
"sidechain_frames": expand_sc_frames(sidechain_frames, ops),
"unnormalized_angles": repeat_fn(sm_out["unnormalized_angles"], num_expand, dim=-3),
"angles": repeat_fn(sm_out["angles"], num_expand, dim=-3),
"single": repeat_fn(sm_out["single"], num_expand, dim=-2),
"positions": expand_atom_positions(positions, ops),
}
feats_expand_dims = {
"residx_atom37_to_atom14": -2,
"entity_id": -1,
"num_sym": -1,
"aatype": -1,
"residue_index": -1,
"atom37_atom_exists": -2,
"seq_mask": -1,
}
symm_feats = {
k: repeat_fn(batch[k], num_expand, dim=v)[-1] for k, v in feats_expand_dims.items() if k in batch
}
asym_id = batch["asym_id"]
def asym_fn(asym_id, i, num_asym):
ret = asym_id + num_asym * i
ret[asym_id == 0] = 0
return ret
asym_ids = torch.cat(
[asym_fn(asym_id, i, batch["num_asym"]) for i in range(num_expand)], dim=-1
).long()
symm_feats["asym_id"] = asym_ids[-1]
symm_feats["num_sym"] = symm_feats["num_sym"] * num_expand
symm_feats["num_asym"] = batch["num_asym"][-1] * num_expand
if "all_atom_positions" in batch:
symm_feats["all_atom_positions"] = expand_atom_positions(batch["all_atom_positions"], ops)[-1]
symm_feats["all_atom_mask"] = repeat_fn(batch["all_atom_mask"], num_expand, -2)[-1]
symm_out["expand_final_atom_positions"] = atom14_to_atom37(symm_out["positions"], symm_feats)
symm_out["expand_final_atom_mask"] = symm_feats["atom37_atom_exists"]
return symm_feats, symm_out
def assembly_from_prediction(
result,
b_factors=None) -> Protein:
chain_index = result["expand_batch"]["asym_id"]
aatype = result["expand_batch"]["aatype"]
residue_index = result["expand_batch"]["residue_index"]
atom_positions = result["expand_final_atom_positions"]
atom_mask = result["expand_final_atom_mask"]
if b_factors is None:
b_factors = np.zeros_like(atom_mask)
return Protein(
aatype=aatype,
atom_positions=atom_positions,
atom_mask=atom_mask,
residue_index=residue_index + 1,
chain_index=chain_index - 1,
b_factors=b_factors
)
import ml_collections as mlc
from ..config import model_config, recursive_set
def uf_symmetry_config():
config = model_config("multimer", train=False)
config.data.common.features.symmetry_opers = [None, 3, 3]
config.data.common.features.num_asym = [None]
config.data.common.features.pseudo_residue_feat = [None]
recursive_set(config, "max_msa_clusters", 256)
config.model.heads.pae.enabled = True # pTM is in development, not reliable.
config.loss.pae.weight = 0.0
config.model.heads.experimentally_resolved.enabled = True
config.loss.experimentally_resolved.weight = 0.0
config.model.pseudo_residue_embedder = mlc.ConfigDict({
"d_in": 8,
"d_hidden": 48,
"d_out": 48,
"num_blocks": 4,
})
config.model.input_embedder.pr_dim = 48
config.model.heads.pae.disable_enhance_head = True
return config
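# Example usage (illustrative sketch):
#
#   config = uf_symmetry_config()
#   # A multimer inference config with the UF-Symmetry extras enabled:
#   # symmetry_opers / num_asym / pseudo_residue_feat input features and
#   # the pseudo-residue embedder dimensions.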
import numpy as np
import ml_collections as mlc
from typing import *
from unifold.symmetry.geometry_utils import get_transform
from ..dataset import load_and_process
import torch
def get_pseudo_residue_feat(symmetry: str):
circ = 2. * np.pi
symmetry = "C1" if symmetry is None else symmetry
if symmetry == 'C1':
ret = np.array([1., 0., 0., 0., 0., 0., 1., 0.], dtype=float)
elif symmetry[0] == 'C':
theta = circ / float(symmetry[1:])
ret = np.array([0., 1., 0., 0., 0., 0., np.cos(theta), np.sin(theta)], dtype=float)
elif symmetry[0] == 'D':
theta = circ / float(symmetry[1:])
ret = np.array([0., 0., 1., 0., 0., 0., np.cos(theta), np.sin(theta)], dtype=float)
elif symmetry == 'I':
ret = np.array([0., 0., 0., 1., 0., 0., 1., 0.], dtype=float)
elif symmetry == 'O':
ret = np.array([0., 0., 0., 0., 1., 0., 1., 0.], dtype=float)
elif symmetry == 'T':
ret = np.array([1., 0., 0., 0., 0., 1., 1., 0.], dtype=float)
elif symmetry == 'H':
raise NotImplementedError("helical structures not supported currently.")
else:
raise ValueError(f"unknown symmetry type {symmetry}")
return ret
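# Example (illustrative): the 8-dim feature one-hot encodes the symmetry
# family in the first six dims and appends (cos theta, sin theta) of the
# elementary rotation angle.
# >>> feat = get_pseudo_residue_feat("C2")
# >>> feat[:6].tolist()
# [0.0, 1.0, 0.0, 0.0, 0.0, 0.0]
# >>> np.allclose(feat[6:], [np.cos(np.pi), np.sin(np.pi)])
# True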
def load_and_process_symmetry(
config: mlc.ConfigDict,
mode: str,
seed: int = 0,
batch_idx: Optional[int] = None,
data_idx: Optional[int] = None,
is_distillation: bool = False,
symmetry: str = 'C1',
**load_kwargs,
):
if mode == "train":
raise NotImplementedError("training UF-Symmetry not implemented.")
if not symmetry.startswith('C'):
raise NotImplementedError(f"symmetry group {symmetry} not supported currently.")
feats, _ = load_and_process(config, mode, seed, batch_idx, data_idx, is_distillation, **load_kwargs)
feats["symmetry_opers"] = torch.tensor(get_transform(symmetry), dtype=float)[None, :]
feats["pseudo_residue_feat"] = torch.tensor(get_pseudo_residue_feat(symmetry), dtype=float)[None, :]
feats["num_asym"] = torch.max(feats["asym_id"])[None]
return feats, None
import numpy as np
from typing import List, Tuple, Sequence, Optional
def get_rotation_from_axis_theta(axis: Sequence[float], theta: float) -> np.ndarray:
"""
Calculates a rotation matrix given an axis and angle.
Parameters
----------
axis : sequence of float
The rotation axis.
theta : float
The rotation angle.
Returns
-------
rot_mat : np.ndarray
The rotation matrix.
"""
assert len(axis) == 3
k_x, k_y, k_z = axis
c, s = np.cos(theta), np.sin(theta)
r_00 = c + (k_x ** 2) * (1 - c)
r_11 = c + (k_y ** 2) * (1 - c)
r_22 = c + (k_z ** 2) * (1 - c)
r_01 = -s * k_z + (1 - c) * k_x * k_y
r_10 = s * k_z + (1 - c) * k_x * k_y
r_20 = -s * k_y + (1 - c) * k_x * k_z
r_02 = s * k_y + (1 - c) * k_x * k_z
r_12 = -s * k_x + (1 - c) * k_y * k_z
r_21 = s * k_x + (1 - c) * k_y * k_z
return np.array([
[r_00, r_01, r_02],
[r_10, r_11, r_12],
[r_20, r_21, r_22],
], dtype=np.float64)
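# Quick sanity checks (illustrative): a Rodrigues rotation matrix is
# orthogonal, and a quarter turn about z maps x to y.
# >>> r = get_rotation_from_axis_theta([0., 0., 1.], np.pi / 2)
# >>> np.allclose(r @ r.T, np.eye(3))
# True
# >>> np.allclose(r @ np.array([1., 0., 0.]), [0., 1., 0.])
# True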
STANDARD_AXES_C = np.array([
[0, 0, 1],
], dtype=np.float64)
STANDARD_AXES_D = np.array([
[0, 0, 1],
[0, 1, 0],
], dtype=np.float64)
NUM_SYM_T = [3, 2, 2]
STANDARD_AXES_T = np.array([
[1 / np.sqrt(3), 1 / np.sqrt(3), 1 / np.sqrt(3)],
[0, 1, 0],
[1, 0, 0],
], dtype=np.float64)
NUM_SYM_O = [3, 4, 2]
STANDARD_AXES_O = np.array([
[1 / np.sqrt(3), 1 / np.sqrt(3), 1 / np.sqrt(3)],
[0, 0, 1],
[1, 0, 0],
], dtype=np.float64)
NUM_SYM_I = [3, 5, 2, 2]
STANDARD_AXES_I = np.array([
[0, -1 / np.sqrt(3), np.sqrt(2) / np.sqrt(3)],
[0, 0, 1],
[1 / 3, -2 * np.sqrt(2) / 3, 0],
[0, -np.sqrt(2) / np.sqrt(3), 1 / np.sqrt(3)],
], dtype=np.float64)
def get_standard_syms_axes(symmetry: str) -> Tuple[List[int], np.ndarray]:
"""
Gets the spin-axis information of a symmetry type: the axis vectors and their cyclic numbers.
Parameters
----------
symmetry : str
Symmetry type.
Returns
-------
list_num_sym : List of int
The axes' cyclic numbers.
standard_axes: np.ndarray
The spin axes (normalized).
"""
if symmetry.startswith('C'):
list_num_sym = [int(symmetry[1:])]
standard_axes = STANDARD_AXES_C
elif symmetry.startswith('D'):
list_num_sym = [int(symmetry[1:]), 2]
standard_axes = STANDARD_AXES_D
elif symmetry == 'T':
list_num_sym = NUM_SYM_T
standard_axes = STANDARD_AXES_T
elif symmetry == 'O':
list_num_sym = NUM_SYM_O
standard_axes = STANDARD_AXES_O
elif symmetry == 'I':
list_num_sym = NUM_SYM_I
standard_axes = STANDARD_AXES_I
else:
raise ValueError(f"unknown symmetry type {symmetry}")
return list_num_sym, standard_axes
def get_num_AU(symmetry: Optional[str]) -> int:
"""
Returns the number of asymmetric units (AUs) in the given symmetry group.
Parameters
----------
symmetry : str, optional
Symmetry type.
Returns
-------
num_AU : int
Number of asymmetric units.
"""
if symmetry is None:
return 1
elif symmetry.startswith('C'):
return int(symmetry[1:])
elif symmetry.startswith('D'):
return int(symmetry[1:]) * 2
elif symmetry == 'T':
return 12
elif symmetry == 'O':
return 24
elif symmetry == 'I':
return 60
elif symmetry == 'H':
raise NotImplementedError("helical structures not supported currently.")
else:
raise ValueError(f"unknown symmetry type {symmetry}")
def rotation_z(theta):
ca = np.cos(theta)
sa = np.sin(theta)
ret = np.array([[ ca, -sa, 0., 0.],
[ sa, ca, 0., 0.],
[ 0., 0., 1., 0.],
[ 0., 0., 0., 1.]])
ret[np.abs(ret) < 1e-10] = 0
return ret
def get_transform_C(grpnum):
interval = 2 * np.pi / grpnum
ret = np.stack([rotation_z(theta).astype(float) for theta in np.arange(0, 2 * np.pi, step=interval)])
return ret
def get_transform_D(grpnum):
assert grpnum % 2 == 0
c_transform = get_transform_C(grpnum // 2)
rot_y = np.array([[-1., 0., 0., 0.],
[ 0., 1., 0., 0.],
[ 0., 0., -1., 0.],
[ 0., 0., 0., 1.]])
ret = np.concatenate([c_transform, c_transform @ rot_y], axis=0)
return ret
def combine_rigid_groups(rigid_groups: List[List[np.ndarray]]) -> np.ndarray:
list_rigid = [np.eye(4, dtype=np.float64)]
for rigid_group in rigid_groups:
temp_list_rigid = []
for r1 in list_rigid:
for r2 in rigid_group:
temp_list_rigid.append(r2 @ r1)
list_rigid = temp_list_rigid
return np.stack(list_rigid)
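# A minimal sketch of the combination rule: the result contains every product
# of one element per group (applied right-to-left), so its size is the product
# of the group sizes. The matrices below are placeholders chosen only to show
# the output shape.
# >>> g1 = [np.eye(4)] * 2
# >>> g2 = [np.eye(4)] * 3
# >>> combine_rigid_groups([g1, g2]).shape
# (6, 4, 4)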
def get_transform_TOI(symmetry: str) -> np.ndarray:
list_num_sym, standard_axes = get_standard_syms_axes(symmetry)
n = len(list_num_sym)
rigid_groups = []
for i in range(n):
num_sym, axis = list_num_sym[i], standard_axes[i, :]
angles = [j * 2 * np.pi / num_sym for j in range(num_sym)]
rigid_group = []
for angle in angles:
rigid = np.eye(4, dtype=np.float64)
rigid[:3, :3] = get_rotation_from_axis_theta(axis, angle)
rigid_group.append(rigid)
rigid_groups.append(rigid_group)
if symmetry == 'I':
xs = [0, 3, 2, 1]
rigid_groups = [rigid_groups[i] for i in xs]
return combine_rigid_groups(rigid_groups)
TRANSFORM_T = get_transform_TOI('T')
TRANSFORM_O = get_transform_TOI('O')
TRANSFORM_I = get_transform_TOI('I')
assert TRANSFORM_T.shape[0] == 12
assert TRANSFORM_O.shape[0] == 24
assert TRANSFORM_I.shape[0] == 60
def get_transform(symmetry: str) -> np.ndarray:
"""
Get symmetry operators of the given symmetry.
Parameters
----------
symmetry : str
Symmetry type.
Returns
-------
sym_opers: np.ndarray
(N * 4 * 4) Symmetry operators.
"""
if symmetry is None:
ret = get_transform_C(1)
elif symmetry.startswith('C'):
ret = get_transform_C(get_num_AU(symmetry))
elif symmetry.startswith('D'):
ret = get_transform_D(get_num_AU(symmetry))
elif symmetry == 'T':
ret = TRANSFORM_T
elif symmetry == 'O':
ret = TRANSFORM_O
elif symmetry == 'I':
ret = TRANSFORM_I
elif symmetry == 'H':
raise NotImplementedError("helical structures not supported currently.")
else:
raise ValueError(f"unknown symmetry type {symmetry}")
return ret
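# Example (illustrative): C2 yields two operators, the identity and a half
# turn about the z axis.
# >>> opers = get_transform("C2")
# >>> opers.shape
# (2, 4, 4)
# >>> np.allclose(opers[0], np.eye(4))
# True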
''' Yet to release. '''
import torch.nn.functional as F
from ..modules.alphafold import *
from .modules import SymmInputEmbedder, PseudoResidueEmbedder, SymmStructureModule
from .assemble import expand_symmetry
class UFSymmetry(AlphaFold):  # inherits the main model; UF-Symmetry alterations are implemented here.
def __init__(self, config):
assert not config.globals.alphafold_original_mode
super(UFSymmetry, self).__init__(config)
# replace input embedder with symm input embedder
self.input_embedder = SymmInputEmbedder(**config.model["input_embedder"], use_chain_relative=True)
self.pseudo_residue_embedder = PseudoResidueEmbedder(**config.model["pseudo_residue_embedder"])
self.structure_module = SymmStructureModule(
**config.model["structure_module"],
)
def __make_input_float__(self):
super().__make_input_float__()
self.pseudo_residue_embedder = self.pseudo_residue_embedder.float()
def iteration_evoformer(self, feats, m_1_prev, z_prev, x_prev):
batch_dims = feats["target_feat"].shape[:-2]
n = feats["target_feat"].shape[-2] + 1 # pr
seq_mask = feats["seq_mask"]
msa_mask = feats["msa_mask"]
pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
pr_feat = feats["pseudo_residue_feat"]
pr_feat = self.pseudo_residue_embedder(pr_feat)
m, z, pr_m = self.input_embedder(
feats["target_feat"],
feats["msa_feat"],
pr_feat
)
if m_1_prev is None:
m_1_prev = m.new_zeros(
(*batch_dims, n, self.config.input_embedder.d_msa),
requires_grad=False,
)
if z_prev is None:
z_prev = z.new_zeros(
(*batch_dims, n, n, self.config.input_embedder.d_pair),
requires_grad=False,
)
if x_prev is None:
x_prev = z.new_zeros(
(*batch_dims, n, residue_constants.atom_type_num, 3),
requires_grad=False,
)
x_prev_ = pseudo_beta_fn(feats["aatype"], x_prev[..., 1:, :, :], None)
x_prev = torch.cat([x_prev[..., 0:1, 0, :], x_prev_], dim=-2)
z += self.recycling_embedder.recyle_pos(x_prev)
m_1_prev_emb, z_prev_emb = self.recycling_embedder(
m_1_prev,
z_prev,
)
m[..., 0, :, :] += m_1_prev_emb
z += z_prev_emb
relpos = self.input_embedder.relpos_emb(
feats["residue_index"].long(),
feats.get("sym_id", None),
feats.get("asym_id", None),
feats.get("entity_id", None),
feats.get("num_sym", None),
)
z[..., 1:, 1:, :] += relpos
m = m.type(self.dtype)
z = z.type(self.dtype)
tri_start_attn_mask, tri_end_attn_mask = gen_tri_attn_mask(pair_mask, self.inf)
# m_in = m[..., 1:, :]
z_in = z[..., 1:, 1:, :]
if self.config.template.enabled:
template_mask = feats["template_mask"]
if torch.any(template_mask):
z_in = residual(
z_in,
self.embed_templates_pair(
feats,
z_in,
pair_mask,
tri_start_attn_mask,
tri_end_attn_mask,
templ_dim=-4,
),
self.training,
)
if self.config.extra_msa.enabled:
a = self.extra_msa_embedder(build_extra_msa_feat(feats))
extra_msa_row_mask = gen_msa_attn_mask(
feats["extra_msa_mask"],
inf=self.inf,
gen_col_mask=False,
)
z_in = self.extra_msa_stack(
a,
z_in,
msa_mask=feats["extra_msa_mask"],
chunk_size=self.globals.chunk_size,
block_size=self.globals.block_size,
pair_mask=pair_mask,
msa_row_attn_mask=extra_msa_row_mask,
msa_col_attn_mask=None,
tri_start_attn_mask=tri_start_attn_mask,
tri_end_attn_mask=tri_end_attn_mask,
)
if self.config.template.embed_angles:
template_1d_feat, template_1d_mask = self.embed_templates_angle(feats)
expand_shape = list(pr_m.shape)
expand_shape[-3] = template_1d_feat.shape[-3]
template_1d_feat = torch.cat([pr_m.expand(expand_shape), template_1d_feat], dim=-2)
msa_mask = torch.cat([feats["msa_mask"], template_1d_mask], dim=-2)
# compose tensors back
m = torch.cat([m, template_1d_feat], dim=-3)
z_tmp = torch.cat([z[..., 1:, 0:1, :], z_in], dim=-2)
z = torch.cat([z[..., 0:1, :, :], z_tmp], dim=-3)
# pad pr mask
pad_fn = lambda msk: F.pad(msk, (1, 0), "constant", 1.)
seq_mask = pad_fn(seq_mask)
msa_mask = pad_fn(msa_mask)
pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
msa_row_mask, msa_col_mask = gen_msa_attn_mask(
msa_mask,
inf=self.inf,
)
tri_start_attn_mask, tri_end_attn_mask = gen_tri_attn_mask(pair_mask, self.inf)
m, z, s = self.evoformer(
m,
z,
msa_mask=msa_mask,
pair_mask=pair_mask,
msa_row_attn_mask=msa_row_mask,
msa_col_attn_mask=msa_col_mask,
tri_start_attn_mask=tri_start_attn_mask,
tri_end_attn_mask=tri_end_attn_mask,
chunk_size=self.globals.chunk_size,
block_size=self.globals.block_size,
)
return m, z, s, msa_mask, pair_mask, m_1_prev_emb, z_prev_emb
def iteration_evoformer_structure_module(
self, batch, m_1_prev, z_prev, x_prev, cycle_no, num_recycling, num_ensembles=1
):
z, s = 0, 0
n_seq = batch["msa_feat"].shape[-3]
assert num_ensembles >= 1
for ensemble_no in range(num_ensembles):
idx = cycle_no * num_ensembles + ensemble_no
fetch_cur_batch = lambda t: t[min(t.shape[0] - 1, idx), ...]
feats = tensor_tree_map(fetch_cur_batch, batch)
m, z0, s0, msa_mask, pair_mask, m_1_prev_emb, z_prev_emb = self.iteration_evoformer(
feats, m_1_prev, z_prev, x_prev
)
z += z0
s += s0
del z0, s0
if num_ensembles > 1:
z /= float(num_ensembles)
s /= float(num_ensembles)
outputs = {}
outputs["msa"] = m[..., :n_seq, 1:, :]
outputs["pair"] = z[..., 1:, 1:, :]
outputs["single"] = s[..., 1:, :]
# norm loss
if (not getattr(self, "inference", False)) and num_recycling == (cycle_no + 1):
delta_msa = m
delta_msa[..., 0, :, :] = delta_msa[..., 0, :, :] - m_1_prev_emb.detach()
delta_pair = z - z_prev_emb.detach()
outputs["delta_msa"] = delta_msa
outputs["delta_pair"] = delta_pair
outputs["msa_norm_mask"] = msa_mask
outputs["pair_norm_mask"] = pair_mask
outputs["sm"] = self.structure_module(
s,
z,
feats["aatype"],
mask=feats["seq_mask"],
)
outputs["final_atom_positions"] = atom14_to_atom37(
outputs["sm"]["positions"], feats
)
outputs["final_atom_mask"] = feats["atom37_atom_exists"]
outputs["pred_frame_tensor"] = outputs["sm"]["frames"][-1]
global_center = outputs['sm']['global_center_position'][..., None, :]
global_center = global_center.repeat(*((1,) * len(global_center.shape[:-2])), 37, 1)
x_prev = torch.cat([global_center, outputs["final_atom_positions"]], dim=-3)
# use float32 for numerical stability
if (not getattr(self, "inference", False)):
m_1_prev = m[..., 0, :, :].float()
z_prev = z.float()
x_prev = x_prev.float()
else:
m_1_prev = m[..., 0, :, :]
z_prev = z
x_prev = x_prev
return outputs, m_1_prev, z_prev, x_prev
def forward(self, batch, expand=True):
m_1_prev = batch.get("m_1_prev", None)
z_prev = batch.get("z_prev", None)
x_prev = batch.get("x_prev", None)
is_grad_enabled = torch.is_grad_enabled()
num_iters = int(batch["num_recycling_iters"]) + 1
num_ensembles = int(batch["msa_mask"].shape[0]) // num_iters
if self.training:
# don't use ensemble during training
assert num_ensembles == 1
# convert dtypes in batch
batch = self.__convert_input_dtype__(batch)
for cycle_no in range(num_iters):
is_final_iter = cycle_no == (num_iters - 1)
with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
(
outputs,
m_1_prev,
z_prev,
x_prev,
) = self.iteration_evoformer_structure_module(
batch,
m_1_prev,
z_prev,
x_prev,
cycle_no=cycle_no,
num_recycling=num_iters,
num_ensembles=num_ensembles,
)
if not is_final_iter:
del outputs
if expand:
symmetry_feat, symmetry_output = expand_symmetry(outputs["sm"], batch)
outputs["expand_batch"] = symmetry_feat
outputs["expand_sm"] = symmetry_output
outputs["expand_final_atom_positions"] = symmetry_output["expand_final_atom_positions"]
outputs["expand_final_atom_mask"] = symmetry_output["expand_final_atom_mask"]
if "asym_id" in batch:
outputs["asym_id"] = batch["asym_id"][0, ...]
outputs.update(self.aux_heads(outputs))
return outputs
import torch
import torch.nn as nn
import torch.nn.functional as F
from unifold.modules.structure_module import *
from ..modules.common import Linear
from ..modules.embedders import InputEmbedder
from typing import *
import torch
import torch.nn as nn
class PseudoResidueResnetBlock(nn.Module):
def __init__(self, c_hidden):
"""
Args:
c_hidden:
Hidden channel dimension
"""
super(PseudoResidueResnetBlock, self).__init__()
self.c_hidden = c_hidden
self.linear_1 = Linear(self.c_hidden, self.c_hidden)
self.act = nn.GELU()
self.linear_2 = Linear(self.c_hidden, self.c_hidden)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x_0 = x
x = self.act(x)
x = self.linear_1(x)
x = self.act(x)
x = self.linear_2(x)
return x + x_0
class PseudoResidueEmbedder(nn.Module):
def __init__(
self,
d_in: int,
d_out: int,
d_hidden: int,
num_blocks: int,
**kwargs,
):
"""
Args:
d_in:
Input channel dimension
d_hidden:
Hidden channel dimension
d_out:
Output channel dimension
num_blocks:
Number of residual blocks
"""
super(PseudoResidueEmbedder, self).__init__()
self.d_in = d_in
self.d_out = d_out
self.d_hidden = d_hidden
self.num_blocks = num_blocks
self.linear_in = Linear(self.d_in, self.d_hidden)
self.act = nn.GELU()
self.layers = nn.ModuleList()
for _ in range(self.num_blocks):
layer = PseudoResidueResnetBlock(c_hidden=self.d_hidden)
self.layers.append(layer)
self.linear_out = Linear(self.d_hidden, self.d_out)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x:
[*, C_in] pseudo residue feature
Returns:
[*, C_out] embedding
"""
x = x.type(self.linear_in.weight.dtype)
x = self.linear_in(x)
x = self.act(x)
for l in self.layers:
x = l(x)
x = self.linear_out(x)
return x
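# Example usage (a minimal sketch with the dims from uf_symmetry_config):
#
#   embedder = PseudoResidueEmbedder(d_in=8, d_hidden=48, d_out=48, num_blocks=4)
#   feat = torch.randn(8)      # raw 8-dim pseudo-residue feature
#   emb = embedder(feat)       # embedded feature of shape [48]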
class SymmInputEmbedder(InputEmbedder):
def __init__(
self,
pr_dim: Optional[int] = None,
**kwargs,
):
super(SymmInputEmbedder, self).__init__(**kwargs)
d_pair = kwargs.get("d_pair")
d_msa = kwargs.get("d_msa")
self.pr_dim = pr_dim
self.linear_pr_z_i = Linear(pr_dim, d_pair)
self.linear_pr_z_j = Linear(pr_dim, d_pair)
self.linear_pr_m = Linear(pr_dim, d_msa)
def forward(
self,
tf: torch.Tensor,
msa: torch.Tensor,
prf: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# [*, N_res, c_z]
if self.tf_dim == 21:
# the multimer model uses a 21-dim target feature
tf = tf[...,1:]
# convert type if necessary
tf = tf.type(self.linear_tf_z_i.weight.dtype)
msa = msa.type(self.linear_tf_z_i.weight.dtype)
tf_emb_i = self.linear_tf_z_i(tf) # [*, N_res, c_z]
tf_emb_j = self.linear_tf_z_j(tf)
pr_emb_i = self.linear_pr_z_i(prf) # [*, c_z]
pr_emb_j = self.linear_pr_z_j(prf)
tf_emb_i = torch.cat([pr_emb_i[..., None, :], tf_emb_i], dim=-2)
tf_emb_j = torch.cat([pr_emb_j[..., None, :], tf_emb_j], dim=-2)
# [*, N_res, N_res, c_z]
pair_emb = tf_emb_i[..., :, None, :] + tf_emb_j[..., None, :, :]
# [*, N_clust, N_res, c_m]
n_clust = msa.shape[-3]
tf_m = (
self.linear_tf_m(tf)
.unsqueeze(-3)
.expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))
)
msa_emb = self.linear_msa_m(msa) + tf_m
pr_m = self.linear_pr_m(prf)[..., None, None, :]
pr_m_expand = pr_m.expand((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1))
msa_emb = torch.cat([pr_m_expand, msa_emb], dim=-2)
return msa_emb, pair_emb, pr_m
class SymmStructureModule(StructureModule):
def forward(
self,
s,
z,
aatype,
mask=None,
):
if mask is None:
mask = s.new_ones(s.shape[:-1])
mask = F.pad(mask, (1, 0), "constant", 1.)
# generate square mask
square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
square_mask = gen_attn_mask(square_mask, -self.inf).unsqueeze(-3)
s = self.layer_norm_s(s)
z = self.layer_norm_z(z)
initial_s = s
s = self.linear_in(s)
quat_encoder = Quaternion.identity(
s.shape[:-1],
s.dtype,
s.device,
requires_grad=False,
)
backb_to_global = Frame(
Rotation(
mat=quat_encoder.get_rot_mats(),
),
quat_encoder.get_trans(),
)
outputs = []
for i in range(self.num_blocks):
s = residual(s, self.ipa(s, z, backb_to_global, square_mask), self.training)
s = self.ipa_dropout(s)
s = self.layer_norm_ipa(s)
s = self.transition(s)
# update quaternion encoder
# use backb_to_global to avoid quat-to-rot conversion
quat_encoder = quat_encoder.compose_update_vec(
self.bb_update(s), pre_rot_mat=backb_to_global.get_rots()
)
# initial_s is always used to update the backbone
unnormalized_angles, angles = self.angle_resnet(s[..., 1:, :], initial_s[..., 1:, :])
# convert quaternion to rotation matrix
backb_to_global = Frame(
Rotation(
mat=quat_encoder.get_rot_mats(),
),
quat_encoder.get_trans(),
)
global_frame = backb_to_global[..., 0:1]
local_frames = backb_to_global[..., 1:]
local_frames = global_frame.compose(local_frames)
preds = {
"frames": local_frames.scale_translation(
self.trans_scale_factor
).to_tensor_4x4(), # no pr
"unnormalized_angles": unnormalized_angles,
"angles": angles,
}
outputs.append(preds)
if i < (self.num_blocks - 1):
# stop gradient in iteration
quat_encoder = quat_encoder.stop_rot_gradient()
backb_to_global = backb_to_global.stop_rot_gradient()
else:
all_frames_to_global = self.torsion_angles_to_frames(
local_frames.scale_translation(self.trans_scale_factor),
angles,
aatype,
) # no pr
pred_positions = self.frames_and_literature_positions_to_atom14_pos(
all_frames_to_global,
aatype,
) # no pr
outputs = dict_multimap(torch.stack, outputs)
outputs["sidechain_frames"] = all_frames_to_global.to_tensor_4x4()
outputs["positions"] = pred_positions
outputs["single"] = s[..., 1:, :]
outputs["global_center_position"] = global_frame.get_trans()
return outputs
import logging
import os
import contextlib
from typing import Optional
import numpy as np
from unifold.dataset import UnifoldDataset, UnifoldMultimerDataset
from unicore.data import data_utils
from unicore.tasks import UnicoreTask, register_task
logger = logging.getLogger(__name__)
@register_task("af2")
class AlphafoldTask(UnicoreTask):
"""Task for training masked language models (e.g., BERT)."""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument(
"data",
)
parser.add_argument("--disable-sd", action="store_true")
parser.add_argument(
"--json-prefix",
type=str,
default="",
)
parser.add_argument(
"--max-chains",
type=int,
default=18,
)
parser.add_argument(
"--sd-prob",
type=float,
default=0.75,
)
def __init__(self, args):
super().__init__(args)
self.seed = args.seed
@classmethod
def setup_task(cls, args, **kwargs):
return cls(args)
def load_dataset(self, split, combine=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
if self.config.model.is_multimer:
data_class = UnifoldMultimerDataset
else:
data_class = UnifoldDataset
if split == "train":
dataset = data_class(
self.args,
self.args.seed + 81,
self.config,
self.args.data,
mode="train",
max_step=self.args.max_update,
disable_sd=self.args.disable_sd,
json_prefix=self.args.json_prefix,
)
else:
dataset = data_class(
self.args,
self.args.seed + 81,
self.config,
self.args.data,
mode="eval",
max_step=None,
json_prefix=self.args.json_prefix,
)
self.datasets[split] = dataset
def build_model(self, args):
from unicore import models
model = models.build_model(args, self)
self.config = model.config
return model
def disable_shuffling(self) -> bool:
return True