Commit 42427a1c authored by Shenggan

add inference pipeline from openfold/alphafold

parent b3b3b445
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Restrained Amber Minimization of a structure."""
import io
import time
from typing import Collection, Optional, Sequence, Dict
from absl import logging
from fastfold.common import (
protein,
residue_constants,
)
import fastfold.model.loss as loss
from fastfold.relax import cleanup, utils
import ml_collections
import numpy as np
from simtk import openmm
from simtk import unit
from simtk.openmm import app as openmm_app
from simtk.openmm.app.internal.pdbstructure import PdbStructure
ENERGY = unit.kilocalories_per_mole
LENGTH = unit.angstroms
def will_restrain(atom: openmm_app.Atom, rset: str) -> bool:
"""Returns True if the atom will be restrained by the given restraint set."""
if rset == "non_hydrogen":
return atom.element.name != "hydrogen"
elif rset == "c_alpha":
return atom.name == "CA"
def _add_restraints(
system: openmm.System,
reference_pdb: openmm_app.PDBFile,
stiffness: unit.Unit,
rset: str,
exclude_residues: Sequence[int],
):
"""Adds a harmonic potential that restrains the system to a structure."""
assert rset in ["non_hydrogen", "c_alpha"]
force = openmm.CustomExternalForce("0.5 * k * ((x-x0)^2 + (y-y0)^2 + (z-z0)^2)")
force.addGlobalParameter("k", stiffness)
for p in ["x0", "y0", "z0"]:
force.addPerParticleParameter(p)
for i, atom in enumerate(reference_pdb.topology.atoms()):
if atom.residue.index in exclude_residues:
continue
if will_restrain(atom, rset):
force.addParticle(i, reference_pdb.positions[i])
logging.info(
"Restraining %d / %d particles.",
force.getNumParticles(),
system.getNumParticles(),
)
system.addForce(force)
def _openmm_minimize(
pdb_str: str,
max_iterations: int,
tolerance: unit.Unit,
stiffness: unit.Unit,
restraint_set: str,
exclude_residues: Sequence[int],
use_gpu: bool,
):
"""Minimize energy via openmm."""
pdb_file = io.StringIO(pdb_str)
pdb = openmm_app.PDBFile(pdb_file)
force_field = openmm_app.ForceField("amber99sb.xml")
constraints = openmm_app.HBonds
system = force_field.createSystem(pdb.topology, constraints=constraints)
if stiffness > 0 * ENERGY / (LENGTH**2):
_add_restraints(system, pdb, stiffness, restraint_set, exclude_residues)
integrator = openmm.LangevinIntegrator(0, 0.01, 0.0)
platform = openmm.Platform.getPlatformByName("CUDA" if use_gpu else "CPU")
simulation = openmm_app.Simulation(pdb.topology, system, integrator, platform)
simulation.context.setPositions(pdb.positions)
ret = {}
state = simulation.context.getState(getEnergy=True, getPositions=True)
ret["einit"] = state.getPotentialEnergy().value_in_unit(ENERGY)
ret["posinit"] = state.getPositions(asNumpy=True).value_in_unit(LENGTH)
simulation.minimizeEnergy(maxIterations=max_iterations, tolerance=tolerance)
state = simulation.context.getState(getEnergy=True, getPositions=True)
ret["efinal"] = state.getPotentialEnergy().value_in_unit(ENERGY)
ret["pos"] = state.getPositions(asNumpy=True).value_in_unit(LENGTH)
ret["min_pdb"] = _get_pdb_string(simulation.topology, state.getPositions())
return ret
def _get_pdb_string(topology: openmm_app.Topology, positions: unit.Quantity):
"""Returns a pdb string provided OpenMM topology and positions."""
with io.StringIO() as f:
openmm_app.PDBFile.writeFile(topology, positions, f)
return f.getvalue()
def _check_cleaned_atoms(pdb_cleaned_string: str, pdb_ref_string: str):
"""Checks that no atom positions have been altered by cleaning."""
cleaned = openmm_app.PDBFile(io.StringIO(pdb_cleaned_string))
reference = openmm_app.PDBFile(io.StringIO(pdb_ref_string))
cl_xyz = np.array(cleaned.getPositions().value_in_unit(LENGTH))
ref_xyz = np.array(reference.getPositions().value_in_unit(LENGTH))
for ref_res, cl_res in zip(reference.topology.residues(), cleaned.topology.residues()):
assert ref_res.name == cl_res.name
for rat in ref_res.atoms():
for cat in cl_res.atoms():
if cat.name == rat.name:
if not np.array_equal(cl_xyz[cat.index], ref_xyz[rat.index]):
raise ValueError(f"Coordinates of cleaned atom {cat} do not match "
f"coordinates of reference atom {rat}.")
def _check_residues_are_well_defined(prot: protein.Protein):
"""Checks that all residues contain non-empty atom sets."""
if (prot.atom_mask.sum(axis=-1) == 0).any():
raise ValueError("Amber minimization can only be performed on proteins with"
" well-defined residues. This protein contains at least"
" one residue with no atoms.")
def _check_atom_mask_is_ideal(prot):
"""Sanity-check the atom mask is ideal, up to a possible OXT."""
atom_mask = prot.atom_mask
ideal_atom_mask = protein.ideal_atom_mask(prot)
utils.assert_equal_nonterminal_atom_types(atom_mask, ideal_atom_mask)
def clean_protein(prot: protein.Protein, checks: bool = True):
"""Adds missing atoms to Protein instance.
Args:
prot: A `protein.Protein` instance.
checks: A `bool` specifying whether to add additional checks to the cleaning
process.
Returns:
pdb_string: A string of the cleaned protein.
"""
_check_atom_mask_is_ideal(prot)
# Clean pdb.
prot_pdb_string = protein.to_pdb(prot)
pdb_file = io.StringIO(prot_pdb_string)
alterations_info = {}
fixed_pdb = cleanup.fix_pdb(pdb_file, alterations_info)
fixed_pdb_file = io.StringIO(fixed_pdb)
pdb_structure = PdbStructure(fixed_pdb_file)
cleanup.clean_structure(pdb_structure, alterations_info)
logging.info("alterations info: %s", alterations_info)
# Write pdb file of cleaned structure.
as_file = openmm_app.PDBFile(pdb_structure)
pdb_string = _get_pdb_string(as_file.getTopology(), as_file.getPositions())
if checks:
_check_cleaned_atoms(pdb_string, prot_pdb_string)
return pdb_string
def make_atom14_positions(prot):
"""Constructs denser atom positions (14 dimensions instead of 37)."""
restype_atom14_to_atom37 = [] # mapping (restype, atom14) --> atom37
restype_atom37_to_atom14 = [] # mapping (restype, atom37) --> atom14
restype_atom14_mask = []
for rt in residue_constants.restypes:
atom_names = residue_constants.restype_name_to_atom14_names[
residue_constants.restype_1to3[rt]]
restype_atom14_to_atom37.append([
(residue_constants.atom_order[name] if name else 0) for name in atom_names
])
atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
restype_atom37_to_atom14.append([
(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
for name in residue_constants.atom_types
])
restype_atom14_mask.append([(1.0 if name else 0.0) for name in atom_names])
# Add dummy mapping for restype 'UNK'.
restype_atom14_to_atom37.append([0] * 14)
restype_atom37_to_atom14.append([0] * 37)
restype_atom14_mask.append([0.0] * 14)
restype_atom14_to_atom37 = np.array(restype_atom14_to_atom37, dtype=np.int32)
restype_atom37_to_atom14 = np.array(restype_atom37_to_atom14, dtype=np.int32)
restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32)
# Create the mapping for (residx, atom14) --> atom37, i.e. an array
# with shape (num_res, 14) containing the atom37 indices for this protein.
residx_atom14_to_atom37 = restype_atom14_to_atom37[prot["aatype"]]
residx_atom14_mask = restype_atom14_mask[prot["aatype"]]
# Create a mask for known ground truth positions.
residx_atom14_gt_mask = residx_atom14_mask * np.take_along_axis(
prot["all_atom_mask"], residx_atom14_to_atom37, axis=1).astype(np.float32)
# Gather the ground truth positions.
residx_atom14_gt_positions = residx_atom14_gt_mask[:, :, None] * (np.take_along_axis(
prot["all_atom_positions"],
residx_atom14_to_atom37[..., None],
axis=1,
))
prot["atom14_atom_exists"] = residx_atom14_mask
prot["atom14_gt_exists"] = residx_atom14_gt_mask
prot["atom14_gt_positions"] = residx_atom14_gt_positions
prot["residx_atom14_to_atom37"] = residx_atom14_to_atom37.astype(np.int64)
# Create the gather indices for mapping back.
residx_atom37_to_atom14 = restype_atom37_to_atom14[prot["aatype"]]
prot["residx_atom37_to_atom14"] = residx_atom37_to_atom14.astype(np.int64)
# Create the corresponding mask.
restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
for restype, restype_letter in enumerate(residue_constants.restypes):
restype_name = residue_constants.restype_1to3[restype_letter]
atom_names = residue_constants.residue_atoms[restype_name]
for atom_name in atom_names:
atom_type = residue_constants.atom_order[atom_name]
restype_atom37_mask[restype, atom_type] = 1
residx_atom37_mask = restype_atom37_mask[prot["aatype"]]
prot["atom37_atom_exists"] = residx_atom37_mask
# As the atom naming is ambiguous for 7 of the 20 amino acids, provide
# alternative ground truth coordinates where the naming is swapped
restype_3 = [residue_constants.restype_1to3[res] for res in residue_constants.restypes]
restype_3 += ["UNK"]
# Matrices for renaming ambiguous atoms.
all_matrices = {res: np.eye(14, dtype=np.float32) for res in restype_3}
for resname, swap in residue_constants.residue_atom_renaming_swaps.items():
correspondences = np.arange(14)
for source_atom_swap, target_atom_swap in swap.items():
source_index = residue_constants.restype_name_to_atom14_names[resname].index(
source_atom_swap)
target_index = residue_constants.restype_name_to_atom14_names[resname].index(
target_atom_swap)
correspondences[source_index] = target_index
correspondences[target_index] = source_index
renaming_matrix = np.zeros((14, 14), dtype=np.float32)
for index, correspondence in enumerate(correspondences):
renaming_matrix[index, correspondence] = 1.0
all_matrices[resname] = renaming_matrix.astype(np.float32)
renaming_matrices = np.stack([all_matrices[restype] for restype in restype_3])
# Pick the transformation matrices for the given residue sequence
# shape (num_res, 14, 14).
renaming_transform = renaming_matrices[prot["aatype"]]
# Apply it to the ground truth positions. shape (num_res, 14, 3).
alternative_gt_positions = np.einsum("rac,rab->rbc", residx_atom14_gt_positions,
renaming_transform)
prot["atom14_alt_gt_positions"] = alternative_gt_positions
# Create the mask for the alternative ground truth (differs from the
# ground truth mask, if only one of the atoms in an ambiguous pair has a
# ground truth position).
alternative_gt_mask = np.einsum("ra,rab->rb", residx_atom14_gt_mask, renaming_transform)
prot["atom14_alt_gt_exists"] = alternative_gt_mask
# Create an ambiguous atoms mask. shape: (21, 14).
restype_atom14_is_ambiguous = np.zeros((21, 14), dtype=np.float32)
for resname, swap in residue_constants.residue_atom_renaming_swaps.items():
for atom_name1, atom_name2 in swap.items():
restype = residue_constants.restype_order[residue_constants.restype_3to1[resname]]
atom_idx1 = residue_constants.restype_name_to_atom14_names[resname].index(atom_name1)
atom_idx2 = residue_constants.restype_name_to_atom14_names[resname].index(atom_name2)
restype_atom14_is_ambiguous[restype, atom_idx1] = 1
restype_atom14_is_ambiguous[restype, atom_idx2] = 1
# From this create an ambiguous_mask for the given sequence.
prot["atom14_atom_is_ambiguous"] = restype_atom14_is_ambiguous[prot["aatype"]]
return prot
def find_violations(prot_np: protein.Protein):
"""Analyzes a protein and returns structural violation information.
Args:
prot_np: A protein.
Returns:
violations: A `dict` of structure components with structural violations.
violation_metrics: A `dict` of violation metrics.
"""
batch = {
"aatype": prot_np.aatype,
"all_atom_positions": prot_np.atom_positions.astype(np.float32),
"all_atom_mask": prot_np.atom_mask.astype(np.float32),
"residue_index": prot_np.residue_index,
}
batch["seq_mask"] = np.ones_like(batch["aatype"], np.float32)
batch = make_atom14_positions(batch)
violations = loss.find_structural_violations_np(
batch=batch,
atom14_pred_positions=batch["atom14_gt_positions"],
config=ml_collections.ConfigDict({
"violation_tolerance_factor": 12, # Taken from model config.
"clash_overlap_tolerance": 1.5, # Taken from model config.
}),
)
violation_metrics = loss.compute_violation_metrics_np(
batch=batch,
atom14_pred_positions=batch["atom14_gt_positions"],
violations=violations,
)
return violations, violation_metrics
def get_violation_metrics(prot: protein.Protein):
"""Computes violation and alignment metrics."""
structural_violations, struct_metrics = find_violations(prot)
violation_idx = np.flatnonzero(structural_violations["total_per_residue_violations_mask"])
struct_metrics["residue_violations"] = violation_idx
struct_metrics["num_residue_violations"] = len(violation_idx)
struct_metrics["structural_violations"] = structural_violations
return struct_metrics
def _run_one_iteration(
*,
pdb_string: str,
max_iterations: int,
tolerance: float,
stiffness: float,
restraint_set: str,
max_attempts: int,
exclude_residues: Optional[Collection[int]] = None,
use_gpu: bool,
):
"""Runs the minimization pipeline.
Args:
pdb_string: A pdb string.
max_iterations: An `int` specifying the maximum number of L-BFGS iterations.
A value of 0 specifies no limit.
tolerance: kcal/mol, the energy tolerance of L-BFGS.
stiffness: kcal/mol A**2, spring constant of heavy atom restraining
potential.
restraint_set: The set of atoms to restrain.
max_attempts: The maximum number of minimization attempts.
exclude_residues: An optional list of zero-indexed residues to exclude from
restraints.
use_gpu: Whether to run relaxation on GPU
Returns:
A `dict` of minimization info.
"""
exclude_residues = exclude_residues or []
# Assign physical dimensions.
tolerance = tolerance * ENERGY
stiffness = stiffness * ENERGY / (LENGTH**2)
start = time.perf_counter()
minimized = False
attempts = 0
while not minimized and attempts < max_attempts:
attempts += 1
try:
logging.info("Minimizing protein, attempt %d of %d.", attempts, max_attempts)
ret = _openmm_minimize(
pdb_string,
max_iterations=max_iterations,
tolerance=tolerance,
stiffness=stiffness,
restraint_set=restraint_set,
exclude_residues=exclude_residues,
use_gpu=use_gpu,
)
minimized = True
except Exception as e: # pylint: disable=broad-except
print(e)
logging.info(e)
if not minimized:
raise ValueError(f"Minimization failed after {max_attempts} attempts.")
ret["opt_time"] = time.perf_counter() - start
ret["min_attempts"] = attempts
return ret
def run_pipeline(
prot: protein.Protein,
stiffness: float,
use_gpu: bool,
max_outer_iterations: int = 1,
place_hydrogens_every_iteration: bool = True,
max_iterations: int = 0,
tolerance: float = 2.39,
restraint_set: str = "non_hydrogen",
max_attempts: int = 100,
checks: bool = True,
exclude_residues: Optional[Sequence[int]] = None,
):
"""Run iterative amber relax.
Successive relax iterations are performed until all violations have been
resolved. Each iteration involves a restrained Amber minimization, with
restraint exclusions determined by violation-participating residues.
Args:
prot: A protein to be relaxed.
stiffness: kcal/mol A**2, the restraint stiffness.
use_gpu: Whether to run relaxation on GPU.
max_outer_iterations: The maximum number of violation-informed outer relax iterations.
place_hydrogens_every_iteration: Whether hydrogens are re-initialized
prior to every minimization.
max_iterations: An `int` specifying the maximum number of L-BFGS steps
per relax iteration. A value of 0 specifies no limit.
tolerance: kcal/mol, the energy tolerance of L-BFGS.
The default value is the OpenMM default.
restraint_set: The set of atoms to restrain.
max_attempts: The maximum number of minimization attempts per iteration.
checks: Whether to perform cleaning checks.
exclude_residues: An optional list of zero-indexed residues to exclude from
restraints.
Returns:
out: A dictionary of output values.
"""
# `protein.to_pdb` will strip any poorly-defined residues so we need to
# perform this check before `clean_protein`.
_check_residues_are_well_defined(prot)
pdb_string = clean_protein(prot, checks=checks)
exclude_residues = exclude_residues or []
exclude_residues = set(exclude_residues)
violations = np.inf
iteration = 0
while violations > 0 and iteration < max_outer_iterations:
ret = _run_one_iteration(
pdb_string=pdb_string,
exclude_residues=exclude_residues,
max_iterations=max_iterations,
tolerance=tolerance,
stiffness=stiffness,
restraint_set=restraint_set,
max_attempts=max_attempts,
use_gpu=use_gpu,
)
prot = protein.from_pdb_string(ret["min_pdb"])
if place_hydrogens_every_iteration:
pdb_string = clean_protein(prot, checks=True)
else:
pdb_string = ret["min_pdb"]
ret.update(get_violation_metrics(prot))
ret.update({
"num_exclusions": len(exclude_residues),
"iteration": iteration,
})
violations = ret["violations_per_residue"]
exclude_residues = exclude_residues.union(ret["residue_violations"])
logging.info(
"Iteration completed: Einit %.2f Efinal %.2f Time %.2f s "
"num residue violations %d num residue exclusions %d ",
ret["einit"],
ret["efinal"],
ret["opt_time"],
ret["num_residue_violations"],
ret["num_exclusions"],
)
iteration += 1
return ret
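# Illustrative only: a minimal, hedged sketch of driving run_pipeline directly. The
# input path and the stiffness of 10.0 kcal/mol A**2 are assumptions for this
# example, not values mandated by this module.
#
#   with open("prediction.pdb") as f:
#       prot = protein.from_pdb_string(f.read())
#   out = run_pipeline(prot, stiffness=10.0, use_gpu=False)
#   # out["min_pdb"] holds the relaxed structure; out["einit"] / out["efinal"] are
#   # the pre-/post-minimization energies in kcal/mol.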
def get_initial_energies(
pdb_strs: Sequence[str],
stiffness: float = 0.0,
restraint_set: str = "non_hydrogen",
exclude_residues: Optional[Sequence[int]] = None,
):
"""Returns initial potential energies for a sequence of PDBs.
Assumes the input PDBs are ready for minimization, and all have the same
topology.
Allows time to be saved by not pdbfixing / rebuilding the system.
Args:
pdb_strs: List of PDB strings.
stiffness: kcal/mol A**2, spring constant of heavy atom restraining
potential.
restraint_set: Which atom types to restrain.
exclude_residues: An optional list of zero-indexed residues to exclude from
restraints.
Returns:
A list of initial energies in the same order as pdb_strs.
"""
exclude_residues = exclude_residues or []
openmm_pdbs = [openmm_app.PDBFile(PdbStructure(io.StringIO(p))) for p in pdb_strs]
force_field = openmm_app.ForceField("amber99sb.xml")
system = force_field.createSystem(openmm_pdbs[0].topology, constraints=openmm_app.HBonds)
stiffness = stiffness * ENERGY / (LENGTH**2)
if stiffness > 0 * ENERGY / (LENGTH**2):
_add_restraints(system, openmm_pdbs[0], stiffness, restraint_set, exclude_residues)
simulation = openmm_app.Simulation(
openmm_pdbs[0].topology,
system,
openmm.LangevinIntegrator(0, 0.01, 0.0),
openmm.Platform.getPlatformByName("CPU"),
)
energies = []
for pdb in openmm_pdbs:
try:
simulation.context.setPositions(pdb.positions)
state = simulation.context.getState(getEnergy=True)
energies.append(state.getPotentialEnergy().value_in_unit(ENERGY))
except Exception as e: # pylint: disable=broad-except
logging.error("Error getting initial energy, returning large value %s", e)
energies.append(unit.Quantity(1e20, ENERGY))
return energies
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Cleans up a PDB file using pdbfixer in preparation for OpenMM simulations.
fix_pdb uses a third-party tool. We also support fixing some additional edge
cases like removing chains of length one (see clean_structure).
"""
import io
import pdbfixer
from simtk.openmm import app
from simtk.openmm.app import element
def fix_pdb(pdbfile, alterations_info):
"""Apply pdbfixer to the contents of a PDB file; return a PDB string result.
1) Replaces nonstandard residues.
2) Removes heterogens (non protein residues) including water.
3) Adds missing residues and missing atoms within existing residues.
4) Adds hydrogens assuming pH=7.0.
5) KeepIds is currently true, so the fixer must keep the existing chain and
residue identifiers. This will fail for some files in wider PDB that have
invalid IDs.
Args:
pdbfile: Input PDB file handle.
alterations_info: A dict that will store details of changes made.
Returns:
A PDB string representing the fixed structure.
"""
fixer = pdbfixer.PDBFixer(pdbfile=pdbfile)
fixer.findNonstandardResidues()
alterations_info['nonstandard_residues'] = fixer.nonstandardResidues
fixer.replaceNonstandardResidues()
_remove_heterogens(fixer, alterations_info, keep_water=False)
fixer.findMissingResidues()
alterations_info['missing_residues'] = fixer.missingResidues
fixer.findMissingAtoms()
alterations_info['missing_heavy_atoms'] = fixer.missingAtoms
alterations_info['missing_terminals'] = fixer.missingTerminals
fixer.addMissingAtoms(seed=0)
fixer.addMissingHydrogens()
out_handle = io.StringIO()
app.PDBFile.writeFile(fixer.topology, fixer.positions, out_handle,
keepIds=True)
return out_handle.getvalue()
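# Illustrative only: a small sketch of calling fix_pdb on its own. The path
# "input.pdb" is an assumption for the example; alterations_info is filled in as a
# side effect and can be inspected afterwards.
#
#   alterations = {}
#   with open("input.pdb") as handle:
#       fixed_pdb_string = fix_pdb(handle, alterations)
#   # alterations now records nonstandard residues, removed heterogens, and the
#   # missing residues/atoms/terminals detected by pdbfixer.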
def clean_structure(pdb_structure, alterations_info):
"""Applies additional fixes to an OpenMM structure, to handle edge cases.
Args:
pdb_structure: An OpenMM structure to modify and fix.
alterations_info: A dict that will store details of changes made.
"""
_replace_met_se(pdb_structure, alterations_info)
_remove_chains_of_length_one(pdb_structure, alterations_info)
def _remove_heterogens(fixer, alterations_info, keep_water):
"""Removes the residues that Pdbfixer considers to be heterogens.
Args:
fixer: A Pdbfixer instance.
alterations_info: A dict that will store details of changes made.
keep_water: If True, water (HOH) is not considered to be a heterogen.
"""
initial_resnames = set()
for chain in fixer.topology.chains():
for residue in chain.residues():
initial_resnames.add(residue.name)
fixer.removeHeterogens(keepWater=keep_water)
final_resnames = set()
for chain in fixer.topology.chains():
for residue in chain.residues():
final_resnames.add(residue.name)
alterations_info['removed_heterogens'] = (
initial_resnames.difference(final_resnames))
def _replace_met_se(pdb_structure, alterations_info):
"""Replace the Se in any MET residues that were not marked as modified."""
modified_met_residues = []
for res in pdb_structure.iter_residues():
name = res.get_name_with_spaces().strip()
if name == 'MET':
s_atom = res.get_atom('SD')
if s_atom.element_symbol == 'Se':
s_atom.element_symbol = 'S'
s_atom.element = element.get_by_symbol('S')
modified_met_residues.append(s_atom.residue_number)
alterations_info['Se_in_MET'] = modified_met_residues
def _remove_chains_of_length_one(pdb_structure, alterations_info):
"""Removes chains that correspond to a single amino acid.
A single amino acid in a chain is both N and C terminus. There is no force
template for this case.
Args:
pdb_structure: An OpenMM pdb_structure to modify and fix.
alterations_info: A dict that will store details of changes made.
"""
removed_chains = {}
for model in pdb_structure.iter_models():
valid_chains = [c for c in model.iter_chains() if len(c) > 1]
invalid_chain_ids = [c.chain_id for c in model.iter_chains() if len(c) <= 1]
model.chains = valid_chains
for chain_id in invalid_chain_ids:
model.chains_by_id.pop(chain_id)
removed_chains[model.number] = invalid_chain_ids
alterations_info['removed_chains'] = removed_chains
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Amber relaxation."""
from typing import Any, Dict, Sequence, Tuple
from fastfold.common import protein
from fastfold.relax import amber_minimize
from fastfold.relax import utils
import numpy as np
class AmberRelaxation(object):
"""Amber relaxation."""
def __init__(
self,
*,
max_iterations: int,
tolerance: float,
stiffness: float,
exclude_residues: Sequence[int],
max_outer_iterations: int,
use_gpu: bool,
):
"""Initialize Amber Relaxer.
Args:
max_iterations: Maximum number of L-BFGS iterations. 0 means no max.
tolerance: kcal/mol, the energy tolerance of L-BFGS.
stiffness: kcal/mol A**2, spring constant of heavy atom restraining
potential.
exclude_residues: Residues to exclude from per-atom restraining.
Zero-indexed.
max_outer_iterations: Maximum number of violation-informed relax
iterations. A value of 1 will run the non-iterative procedure used in
CASP14. Use 20 so that >95% of the bad cases are relaxed. Relax finishes
as soon as there are no violations, hence in most cases this causes no
slowdown. In the worst case we do 20 outer iterations.
use_gpu: Whether to run on GPU
"""
self._max_iterations = max_iterations
self._tolerance = tolerance
self._stiffness = stiffness
self._exclude_residues = exclude_residues
self._max_outer_iterations = max_outer_iterations
self._use_gpu = use_gpu
def process(
self, *, prot: protein.Protein
) -> Tuple[str, Dict[str, Any], np.ndarray]:
"""Runs Amber relax on a prediction, adds hydrogens, returns PDB string."""
out = amber_minimize.run_pipeline(
prot=prot,
max_iterations=self._max_iterations,
tolerance=self._tolerance,
stiffness=self._stiffness,
exclude_residues=self._exclude_residues,
max_outer_iterations=self._max_outer_iterations,
use_gpu=self._use_gpu,
)
min_pos = out["pos"]
start_pos = out["posinit"]
rmsd = np.sqrt(np.sum((start_pos - min_pos) ** 2) / start_pos.shape[0])
debug_data = {
"initial_energy": out["einit"],
"final_energy": out["efinal"],
"attempts": out["min_attempts"],
"rmsd": rmsd,
}
pdb_str = amber_minimize.clean_protein(prot)
min_pdb = utils.overwrite_pdb_coordinates(pdb_str, min_pos)
min_pdb = utils.overwrite_b_factors(min_pdb, prot.b_factors)
utils.assert_equal_nonterminal_atom_types(
protein.from_pdb_string(min_pdb).atom_mask, prot.atom_mask
)
violations = out["structural_violations"][
"total_per_residue_violations_mask"
]
return min_pdb, debug_data, violations
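# Illustrative only: a minimal, hedged sketch of driving AmberRelaxation on a
# predicted structure. The PDB path comes from the command line, and the numeric
# settings below mirror commonly used AlphaFold-style relax values but are
# assumptions for this example rather than values fixed by this module.
if __name__ == "__main__":
    import sys

    with open(sys.argv[1]) as f:
        prot = protein.from_pdb_string(f.read())
    relaxer = AmberRelaxation(
        max_iterations=0,        # 0 means unlimited L-BFGS steps
        tolerance=2.39,          # kcal/mol
        stiffness=10.0,          # kcal/mol A**2
        exclude_residues=[],
        max_outer_iterations=1,
        use_gpu=False,
    )
    relaxed_pdb, debug_data, violations = relaxer.process(prot=prot)
    print(debug_data)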
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for minimization."""
import io
import numbers
import collections.abc
from fastfold.common import residue_constants
from Bio import PDB
import numpy as np
from simtk.openmm import app as openmm_app
from simtk.openmm.app.internal.pdbstructure import PdbStructure
def overwrite_pdb_coordinates(pdb_str: str, pos) -> str:
pdb_file = io.StringIO(pdb_str)
structure = PdbStructure(pdb_file)
topology = openmm_app.PDBFile(structure).getTopology()
with io.StringIO() as f:
openmm_app.PDBFile.writeFile(topology, pos, f)
return f.getvalue()
def overwrite_b_factors(pdb_str: str, bfactors: np.ndarray) -> str:
"""Overwrites the B-factors in pdb_str with contents of bfactors array.
Args:
pdb_str: An input PDB string.
bfactors: A numpy array with shape [n_residues, 37]. We assume that the
B-factors are per residue; i.e. that the nonzero entries are identical in
[i, :].
Returns:
A new PDB string with the B-factors replaced.
"""
if bfactors.shape[-1] != residue_constants.atom_type_num:
raise ValueError(f'Invalid final dimension size for bfactors: {bfactors.shape[-1]}.')
parser = PDB.PDBParser(QUIET=True)
handle = io.StringIO(pdb_str)
structure = parser.get_structure('', handle)
curr_resid = ('', '', '')
idx = -1
for atom in structure.get_atoms():
atom_resid = atom.parent.get_id()
if atom_resid != curr_resid:
idx += 1
if idx >= bfactors.shape[0]:
raise ValueError('Index into bfactors exceeds number of residues. '
f'B-factors shape: {bfactors.shape}, idx: {idx}.')
curr_resid = atom_resid
atom.bfactor = bfactors[idx, residue_constants.atom_order['CA']]
new_pdb = io.StringIO()
pdb_io = PDB.PDBIO()
pdb_io.set_structure(structure)
pdb_io.save(new_pdb)
return new_pdb.getvalue()
def assert_equal_nonterminal_atom_types(atom_mask: np.ndarray, ref_atom_mask: np.ndarray):
"""Checks that pre- and post-minimized proteins have same atom set."""
# Ignore any terminal OXT atoms which may have been added by minimization.
oxt = residue_constants.atom_order['OXT']
no_oxt_mask = np.ones(shape=atom_mask.shape, dtype=bool)
no_oxt_mask[..., oxt] = False
np.testing.assert_almost_equal(ref_atom_mask[no_oxt_mask], atom_mask[no_oxt_mask])
def mask_mean_np(mask, value, axis=None, drop_mask_channel=False, eps=1e-10):
"""Masked mean with numpy."""
if drop_mask_channel:
mask = mask[..., 0]
mask_shape = mask.shape
value_shape = value.shape
assert len(mask_shape) == len(value_shape)
if isinstance(axis, numbers.Integral):
axis = [axis]
elif axis is None:
axis = list(range(len(mask_shape)))
assert isinstance(
axis, collections.abc.Iterable), ('axis needs to be either an iterable, integer or "None"')
broadcast_factor = 1.
for axis_ in axis:
value_size = value_shape[axis_]
mask_size = mask_shape[axis_]
if mask_size == 1:
broadcast_factor *= value_size
else:
assert mask_size == value_size
axis = tuple(axis)
return (np.sum(mask * value, axis=axis) / (np.sum(mask, axis=axis) * broadcast_factor + eps))
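# Illustrative only: a tiny self-contained check of mask_mean_np. With the mask
# below, only the first two values contribute, so the masked mean is (1 + 2) / 2.
if __name__ == "__main__":
    value = np.array([[1.0, 2.0, 100.0]])
    mask = np.array([[1.0, 1.0, 0.0]])
    print(mask_mean_np(mask, value))  # ~1.5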
-from .inject_openfold import inject_openfold
-__all__ = ['inject_openfold']
+from .inject_fastnn import inject_fastnn
+__all__ = ['inject_fastnn']
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.utils.checkpoint
from typing import Any, Tuple, List, Callable, Optional
BLOCK_ARG = Any
BLOCK_ARGS = List[BLOCK_ARG]
def get_checkpoint_fn():
checkpoint = torch.utils.checkpoint.checkpoint
return checkpoint
@torch.jit.ignore
def checkpoint_blocks(
blocks: List[Callable],
args: BLOCK_ARGS,
blocks_per_ckpt: Optional[int],
) -> BLOCK_ARGS:
"""
Chunk a list of blocks and run each chunk with activation
checkpointing. We define a "block" as a callable whose only inputs are
the outputs of the previous block.
Implements Subsection 1.11.8
Args:
blocks:
List of blocks
args:
Tuple of arguments for the first block.
blocks_per_ckpt:
Size of each chunk. A higher value corresponds to fewer
checkpoints, and trades memory for speed. If None, no checkpointing
is performed.
Returns:
The output of the final block
"""
def wrap(a):
return (a,) if type(a) is not tuple else a
def exec(b, a):
for block in b:
a = wrap(block(*a))
return a
def chunker(s, e):
def exec_sliced(*a):
return exec(blocks[s:e], a)
return exec_sliced
# Avoids mishaps when the blocks take just one argument
args = wrap(args)
if blocks_per_ckpt is None:
return exec(blocks, args)
elif blocks_per_ckpt < 1 or blocks_per_ckpt > len(blocks):
raise ValueError("blocks_per_ckpt must be between 1 and len(blocks)")
checkpoint = get_checkpoint_fn()
for s in range(0, len(blocks), blocks_per_ckpt):
e = s + blocks_per_ckpt
args = checkpoint(chunker(s, e), *args)
args = wrap(args)
return args
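# Illustrative only: a minimal sketch of checkpoint_blocks with toy single-input
# blocks. Each "block" takes and returns one tensor; blocks_per_ckpt=2 groups the
# four blocks into two activation-checkpointed chunks.
if __name__ == "__main__":
    import torch.nn as nn

    layers = [nn.Linear(8, 8) for _ in range(4)]
    x = torch.randn(2, 8, requires_grad=True)
    (out,) = checkpoint_blocks(layers, (x,), blocks_per_ckpt=2)
    out.sum().backward()
    print(out.shape, x.grad.shape)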
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
import torch
import torch.nn as nn
from typing import Dict
from fastfold.common import protein
import fastfold.common.residue_constants as rc
from fastfold.utils.rigid_utils import Rotation, Rigid
from fastfold.utils.tensor_utils import (
batched_gather,
one_hot,
tree_map,
tensor_tree_map,
)
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
is_gly = aatype == rc.restype_order["G"]
ca_idx = rc.atom_order["CA"]
cb_idx = rc.atom_order["CB"]
pseudo_beta = torch.where(
is_gly[..., None].expand(*((-1,) * len(is_gly.shape)), 3),
all_atom_positions[..., ca_idx, :],
all_atom_positions[..., cb_idx, :],
)
if all_atom_masks is not None:
pseudo_beta_mask = torch.where(
is_gly,
all_atom_masks[..., ca_idx],
all_atom_masks[..., cb_idx],
)
return pseudo_beta, pseudo_beta_mask
else:
return pseudo_beta
def atom14_to_atom37(atom14, batch):
atom37_data = batched_gather(
atom14,
batch["residx_atom37_to_atom14"],
dim=-2,
no_batch_dims=len(atom14.shape[:-2]),
)
atom37_data = atom37_data * batch["atom37_atom_exists"][..., None]
return atom37_data
def build_template_angle_feat(template_feats):
template_aatype = template_feats["template_aatype"]
torsion_angles_sin_cos = template_feats["template_torsion_angles_sin_cos"]
alt_torsion_angles_sin_cos = template_feats[
"template_alt_torsion_angles_sin_cos"
]
torsion_angles_mask = template_feats["template_torsion_angles_mask"]
template_angle_feat = torch.cat(
[
nn.functional.one_hot(template_aatype, 22),
torsion_angles_sin_cos.reshape(
*torsion_angles_sin_cos.shape[:-2], 14
),
alt_torsion_angles_sin_cos.reshape(
*alt_torsion_angles_sin_cos.shape[:-2], 14
),
torsion_angles_mask,
],
dim=-1,
)
return template_angle_feat
def build_template_pair_feat(
batch, min_bin, max_bin, no_bins, eps=1e-20, inf=1e8
):
template_mask = batch["template_pseudo_beta_mask"]
template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
# Compute distogram (this seems to differ slightly from Alg. 5)
tpb = batch["template_pseudo_beta"]
dgram = torch.sum(
(tpb[..., None, :] - tpb[..., None, :, :]) ** 2, dim=-1, keepdim=True
)
lower = torch.linspace(min_bin, max_bin, no_bins, device=tpb.device) ** 2
upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)
to_concat = [dgram, template_mask_2d[..., None]]
aatype_one_hot = nn.functional.one_hot(
batch["template_aatype"],
rc.restype_num + 2,
)
n_res = batch["template_aatype"].shape[-1]
to_concat.append(
aatype_one_hot[..., None, :, :].expand(
*aatype_one_hot.shape[:-2], n_res, -1, -1
)
)
to_concat.append(
aatype_one_hot[..., None, :].expand(
*aatype_one_hot.shape[:-2], -1, n_res, -1
)
)
n, ca, c = [rc.atom_order[a] for a in ["N", "CA", "C"]]
rigids = Rigid.make_transform_from_reference(
n_xyz=batch["template_all_atom_positions"][..., n, :],
ca_xyz=batch["template_all_atom_positions"][..., ca, :],
c_xyz=batch["template_all_atom_positions"][..., c, :],
eps=eps,
)
points = rigids.get_trans()[..., None, :, :]
rigid_vec = rigids[..., None].invert_apply(points)
inv_distance_scalar = torch.rsqrt(eps + torch.sum(rigid_vec ** 2, dim=-1))
t_aa_masks = batch["template_all_atom_mask"]
template_mask = (
t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., c]
)
template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
inv_distance_scalar = inv_distance_scalar * template_mask_2d
unit_vector = rigid_vec * inv_distance_scalar[..., None]
to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1))
to_concat.append(template_mask_2d[..., None])
act = torch.cat(to_concat, dim=-1)
act = act * template_mask_2d[..., None]
return act
def build_extra_msa_feat(batch):
msa_1hot = nn.functional.one_hot(batch["extra_msa"], 23)
msa_feat = [
msa_1hot,
batch["extra_has_deletion"].unsqueeze(-1),
batch["extra_deletion_value"].unsqueeze(-1),
]
return torch.cat(msa_feat, dim=-1)
def torsion_angles_to_frames(
r: Rigid,
alpha: torch.Tensor,
aatype: torch.Tensor,
rrgdf: torch.Tensor,
):
# [*, N, 8, 4, 4]
default_4x4 = rrgdf[aatype, ...]
# [*, N, 8] transformations, i.e.
# One [*, N, 8, 3, 3] rotation matrix and
# One [*, N, 8, 3] translation matrix
default_r = r.from_tensor_4x4(default_4x4)
bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2))
bb_rot[..., 1] = 1
# [*, N, 8, 2]
alpha = torch.cat(
[bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2
)
# [*, N, 8, 3, 3]
# Produces rotation matrices of the form:
# [
# [1, 0 , 0 ],
# [0, a_2,-a_1],
# [0, a_1, a_2]
# ]
# This follows the original code rather than the supplement, which uses
# different indices.
all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape)
all_rots[..., 0, 0] = 1
all_rots[..., 1, 1] = alpha[..., 1]
all_rots[..., 1, 2] = -alpha[..., 0]
all_rots[..., 2, 1:] = alpha
all_rots = Rigid(Rotation(rot_mats=all_rots), None)
all_frames = default_r.compose(all_rots)
chi2_frame_to_frame = all_frames[..., 5]
chi3_frame_to_frame = all_frames[..., 6]
chi4_frame_to_frame = all_frames[..., 7]
chi1_frame_to_bb = all_frames[..., 4]
chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame)
chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)
all_frames_to_bb = Rigid.cat(
[
all_frames[..., :5],
chi2_frame_to_bb.unsqueeze(-1),
chi3_frame_to_bb.unsqueeze(-1),
chi4_frame_to_bb.unsqueeze(-1),
],
dim=-1,
)
all_frames_to_global = r[..., None].compose(all_frames_to_bb)
return all_frames_to_global
def frames_and_literature_positions_to_atom14_pos(
r: Rigid,
aatype: torch.Tensor,
default_frames,
group_idx,
atom_mask,
lit_positions,
):
# [*, N, 14, 4, 4]
default_4x4 = default_frames[aatype, ...]
# [*, N, 14]
group_mask = group_idx[aatype, ...]
# [*, N, 14, 8]
group_mask = nn.functional.one_hot(
group_mask,
num_classes=default_frames.shape[-3],
)
# [*, N, 14, 8]
t_atoms_to_global = r[..., None, :] * group_mask
# [*, N, 14]
t_atoms_to_global = t_atoms_to_global.map_tensor_fn(
lambda x: torch.sum(x, dim=-1)
)
# [*, N, 14, 1]
atom_mask = atom_mask[aatype, ...].unsqueeze(-1)
# [*, N, 14, 3]
lit_positions = lit_positions[aatype, ...]
pred_positions = t_atoms_to_global.apply(lit_positions)
pred_positions = pred_positions * atom_mask
return pred_positions
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
from dataclasses import dataclass
from functools import partial
import numpy as np
import torch
from typing import Union, List
_NPZ_KEY_PREFIX = "alphafold/alphafold_iteration/"
# With Param, a poor man's enum with attributes (Rust-style)
class ParamType(Enum):
LinearWeight = partial( # hack: partial prevents fns from becoming methods
lambda w: w.transpose(-1, -2)
)
LinearWeightMHA = partial(
lambda w: w.reshape(*w.shape[:-2], -1).transpose(-1, -2)
)
LinearMHAOutputWeight = partial(
lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2)
)
LinearBiasMHA = partial(lambda w: w.reshape(*w.shape[:-2], -1))
LinearWeightOPM = partial(
lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2)
)
Other = partial(lambda w: w)
def __init__(self, fn):
self.transformation = fn
@dataclass
class Param:
param: Union[torch.Tensor, List[torch.Tensor]]
param_type: ParamType = ParamType.Other
stacked: bool = False
def _process_translations_dict(d, top_layer=True):
flat = {}
for k, v in d.items():
if type(v) == dict:
prefix = _NPZ_KEY_PREFIX if top_layer else ""
sub_flat = {
(prefix + "/".join([k, k_prime])): v_prime
for k_prime, v_prime in _process_translations_dict(
v, top_layer=False
).items()
}
flat.update(sub_flat)
else:
k = "/" + k if not top_layer else k
flat[k] = v
return flat
def stacked(param_dict_list, out=None):
"""
Args:
param_dict_list:
A list of (nested) Param dicts to stack. The structure of
each dict must be the identical (down to the ParamTypes of
"parallel" Params). There must be at least one dict
in the list.
"""
if out is None:
out = {}
template = param_dict_list[0]
for k, _ in template.items():
v = [d[k] for d in param_dict_list]
if type(v[0]) is dict:
out[k] = {}
stacked(v, out=out[k])
elif type(v[0]) is Param:
stacked_param = Param(
param=[param.param for param in v],
param_type=v[0].param_type,
stacked=True,
)
out[k] = stacked_param
return out
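# Illustrative only: a small sketch of what stacked() produces. Given per-block
# Param dicts with identical structure, it returns a dict of Params whose .param
# fields are lists of the corresponding per-block tensors (marked stacked=True),
# ready for assign() to unbind a stacked weight array onto them.
#
#   per_block = [
#       {"linear": {"weights": Param(torch.zeros(3, 3))}},
#       {"linear": {"weights": Param(torch.ones(3, 3))}},
#   ]
#   merged = stacked(per_block)
#   # merged["linear"]["weights"].stacked is True
#   # merged["linear"]["weights"].param is a list of the two tensors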
def assign(translation_dict, orig_weights):
for k, param in translation_dict.items():
with torch.no_grad():
weights = torch.as_tensor(orig_weights[k])
ref, param_type = param.param, param.param_type
if param.stacked:
weights = torch.unbind(weights, 0)
else:
weights = [weights]
ref = [ref]
try:
weights = list(map(param_type.transformation, weights))
for p, w in zip(ref, weights):
p.copy_(w)
except:
print(k)
print(ref[0].shape)
print(weights[0].shape)
raise
def import_jax_weights_(model, npz_path, version="model_1"):
data = np.load(npz_path)
#######################
# Some templates
#######################
LinearWeight = lambda l: (Param(l, param_type=ParamType.LinearWeight))
LinearBias = lambda l: (Param(l))
LinearWeightMHA = lambda l: (Param(l, param_type=ParamType.LinearWeightMHA))
LinearBiasMHA = lambda b: (Param(b, param_type=ParamType.LinearBiasMHA))
LinearWeightOPM = lambda l: (Param(l, param_type=ParamType.LinearWeightOPM))
LinearParams = lambda l: {
"weights": LinearWeight(l.weight),
"bias": LinearBias(l.bias),
}
LayerNormParams = lambda l: {
"scale": Param(l.weight),
"offset": Param(l.bias),
}
AttentionParams = lambda att: {
"query_w": LinearWeightMHA(att.linear_q.weight),
"key_w": LinearWeightMHA(att.linear_k.weight),
"value_w": LinearWeightMHA(att.linear_v.weight),
"output_w": Param(
att.linear_o.weight,
param_type=ParamType.LinearMHAOutputWeight,
),
"output_b": LinearBias(att.linear_o.bias),
}
AttentionGatedParams = lambda att: dict(
**AttentionParams(att),
**{
"gating_w": LinearWeightMHA(att.linear_g.weight),
"gating_b": LinearBiasMHA(att.linear_g.bias),
},
)
GlobalAttentionParams = lambda att: dict(
AttentionGatedParams(att),
key_w=LinearWeight(att.linear_k.weight),
value_w=LinearWeight(att.linear_v.weight),
)
TriAttParams = lambda tri_att: {
"query_norm": LayerNormParams(tri_att.layer_norm),
"feat_2d_weights": LinearWeight(tri_att.linear.weight),
"attention": AttentionGatedParams(tri_att.mha),
}
TriMulOutParams = lambda tri_mul: {
"layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),
"left_projection": LinearParams(tri_mul.linear_a_p),
"right_projection": LinearParams(tri_mul.linear_b_p),
"left_gate": LinearParams(tri_mul.linear_a_g),
"right_gate": LinearParams(tri_mul.linear_b_g),
"center_layer_norm": LayerNormParams(tri_mul.layer_norm_out),
"output_projection": LinearParams(tri_mul.linear_z),
"gating_linear": LinearParams(tri_mul.linear_g),
}
# see commit b88f8da on the Alphafold repo
# Alphafold swaps the pseudocode's a and b between the incoming/outgoing
# iterations of triangle multiplication, which is confusing and not
# reproduced in our implementation.
TriMulInParams = lambda tri_mul: {
"layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),
"left_projection": LinearParams(tri_mul.linear_b_p),
"right_projection": LinearParams(tri_mul.linear_a_p),
"left_gate": LinearParams(tri_mul.linear_b_g),
"right_gate": LinearParams(tri_mul.linear_a_g),
"center_layer_norm": LayerNormParams(tri_mul.layer_norm_out),
"output_projection": LinearParams(tri_mul.linear_z),
"gating_linear": LinearParams(tri_mul.linear_g),
}
PairTransitionParams = lambda pt: {
"input_layer_norm": LayerNormParams(pt.layer_norm),
"transition1": LinearParams(pt.linear_1),
"transition2": LinearParams(pt.linear_2),
}
MSAAttParams = lambda matt: {
"query_norm": LayerNormParams(matt.layer_norm_m),
"attention": AttentionGatedParams(matt.mha),
}
MSAColAttParams = lambda matt: {
"query_norm": LayerNormParams(matt._msa_att.layer_norm_m),
"attention": AttentionGatedParams(matt._msa_att.mha),
}
MSAGlobalAttParams = lambda matt: {
"query_norm": LayerNormParams(matt.layer_norm_m),
"attention": GlobalAttentionParams(matt.global_attention),
}
MSAAttPairBiasParams = lambda matt: dict(
**MSAAttParams(matt),
**{
"feat_2d_norm": LayerNormParams(matt.layer_norm_z),
"feat_2d_weights": LinearWeight(matt.linear_z.weight),
},
)
IPAParams = lambda ipa: {
"q_scalar": LinearParams(ipa.linear_q),
"kv_scalar": LinearParams(ipa.linear_kv),
"q_point_local": LinearParams(ipa.linear_q_points),
"kv_point_local": LinearParams(ipa.linear_kv_points),
"trainable_point_weights": Param(
param=ipa.head_weights, param_type=ParamType.Other
),
"attention_2d": LinearParams(ipa.linear_b),
"output_projection": LinearParams(ipa.linear_out),
}
TemplatePairBlockParams = lambda b: {
"triangle_attention_starting_node": TriAttParams(b.tri_att_start),
"triangle_attention_ending_node": TriAttParams(b.tri_att_end),
"triangle_multiplication_outgoing": TriMulOutParams(b.tri_mul_out),
"triangle_multiplication_incoming": TriMulInParams(b.tri_mul_in),
"pair_transition": PairTransitionParams(b.pair_transition),
}
MSATransitionParams = lambda m: {
"input_layer_norm": LayerNormParams(m.layer_norm),
"transition1": LinearParams(m.linear_1),
"transition2": LinearParams(m.linear_2),
}
OuterProductMeanParams = lambda o: {
"layer_norm_input": LayerNormParams(o.layer_norm),
"left_projection": LinearParams(o.linear_1),
"right_projection": LinearParams(o.linear_2),
"output_w": LinearWeightOPM(o.linear_out.weight),
"output_b": LinearBias(o.linear_out.bias),
}
def EvoformerBlockParams(b, is_extra_msa=False):
if is_extra_msa:
col_att_name = "msa_column_global_attention"
msa_col_att_params = MSAGlobalAttParams(b.msa_att_col)
else:
col_att_name = "msa_column_attention"
msa_col_att_params = MSAColAttParams(b.msa_att_col)
d = {
"msa_row_attention_with_pair_bias": MSAAttPairBiasParams(
b.msa_att_row
),
col_att_name: msa_col_att_params,
"msa_transition": MSATransitionParams(b.core.msa_transition),
"outer_product_mean":
OuterProductMeanParams(b.core.outer_product_mean),
"triangle_multiplication_outgoing":
TriMulOutParams(b.core.tri_mul_out),
"triangle_multiplication_incoming":
TriMulInParams(b.core.tri_mul_in),
"triangle_attention_starting_node":
TriAttParams(b.core.tri_att_start),
"triangle_attention_ending_node":
TriAttParams(b.core.tri_att_end),
"pair_transition":
PairTransitionParams(b.core.pair_transition),
}
return d
ExtraMSABlockParams = partial(EvoformerBlockParams, is_extra_msa=True)
FoldIterationParams = lambda sm: {
"invariant_point_attention": IPAParams(sm.ipa),
"attention_layer_norm": LayerNormParams(sm.layer_norm_ipa),
"transition": LinearParams(sm.transition.layers[0].linear_1),
"transition_1": LinearParams(sm.transition.layers[0].linear_2),
"transition_2": LinearParams(sm.transition.layers[0].linear_3),
"transition_layer_norm": LayerNormParams(sm.transition.layer_norm),
"affine_update": LinearParams(sm.bb_update.linear),
"rigid_sidechain": {
"input_projection": LinearParams(sm.angle_resnet.linear_in),
"input_projection_1": LinearParams(sm.angle_resnet.linear_initial),
"resblock1": LinearParams(sm.angle_resnet.layers[0].linear_1),
"resblock2": LinearParams(sm.angle_resnet.layers[0].linear_2),
"resblock1_1": LinearParams(sm.angle_resnet.layers[1].linear_1),
"resblock2_1": LinearParams(sm.angle_resnet.layers[1].linear_2),
"unnormalized_angles": LinearParams(sm.angle_resnet.linear_out),
},
}
############################
# translations dict overflow
############################
tps_blocks = model.template_pair_stack.blocks
tps_blocks_params = stacked(
[TemplatePairBlockParams(b) for b in tps_blocks]
)
ems_blocks = model.extra_msa_stack.blocks
ems_blocks_params = stacked([ExtraMSABlockParams(b) for b in ems_blocks])
evo_blocks = model.evoformer.blocks
evo_blocks_params = stacked([EvoformerBlockParams(b) for b in evo_blocks])
translations = {
"evoformer": {
"preprocess_1d": LinearParams(model.input_embedder.linear_tf_m),
"preprocess_msa": LinearParams(model.input_embedder.linear_msa_m),
"left_single": LinearParams(model.input_embedder.linear_tf_z_i),
"right_single": LinearParams(model.input_embedder.linear_tf_z_j),
"prev_pos_linear": LinearParams(model.recycling_embedder.linear),
"prev_msa_first_row_norm": LayerNormParams(
model.recycling_embedder.layer_norm_m
),
"prev_pair_norm": LayerNormParams(
model.recycling_embedder.layer_norm_z
),
"pair_activiations": LinearParams(
model.input_embedder.linear_relpos
),
"template_embedding": {
"single_template_embedding": {
"embedding2d": LinearParams(
model.template_pair_embedder.linear
),
"template_pair_stack": {
"__layer_stack_no_state": tps_blocks_params,
},
"output_layer_norm": LayerNormParams(
model.template_pair_stack.layer_norm
),
},
"attention": AttentionParams(model.template_pointwise_att.mha),
},
"extra_msa_activations": LinearParams(
model.extra_msa_embedder.linear
),
"extra_msa_stack": ems_blocks_params,
"template_single_embedding": LinearParams(
model.template_angle_embedder.linear_1
),
"template_projection": LinearParams(
model.template_angle_embedder.linear_2
),
"evoformer_iteration": evo_blocks_params,
"single_activations": LinearParams(model.evoformer.linear),
},
"structure_module": {
"single_layer_norm": LayerNormParams(
model.structure_module.layer_norm_s
),
"initial_projection": LinearParams(
model.structure_module.linear_in
),
"pair_layer_norm": LayerNormParams(
model.structure_module.layer_norm_z
),
"fold_iteration": FoldIterationParams(model.structure_module),
},
"predicted_lddt_head": {
"input_layer_norm": LayerNormParams(
model.aux_heads.plddt.layer_norm
),
"act_0": LinearParams(model.aux_heads.plddt.linear_1),
"act_1": LinearParams(model.aux_heads.plddt.linear_2),
"logits": LinearParams(model.aux_heads.plddt.linear_3),
},
"distogram_head": {
"half_logits": LinearParams(model.aux_heads.distogram.linear),
},
"experimentally_resolved_head": {
"logits": LinearParams(
model.aux_heads.experimentally_resolved.linear
),
},
"masked_msa_head": {
"logits": LinearParams(model.aux_heads.masked_msa.linear),
},
}
no_templ = [
"model_3",
"model_4",
"model_5",
"model_3_ptm",
"model_4_ptm",
"model_5_ptm",
]
if version in no_templ:
evo_dict = translations["evoformer"]
keys = list(evo_dict.keys())
for k in keys:
if "template_" in k:
evo_dict.pop(k)
if "_ptm" in version:
translations["predicted_aligned_error_head"] = {
"logits": LinearParams(model.aux_heads.tm.linear)
}
# Flatten keys and insert missing key prefixes
flat = _process_translations_dict(translations)
# Sanity check
keys = list(data.keys())
flat_keys = list(flat.keys())
incorrect = [k for k in flat_keys if k not in keys]
missing = [k for k in keys if k not in flat_keys]
# print(f"Incorrect: {incorrect}")
# print(f"Missing: {missing}")
assert len(incorrect) == 0
# assert(sorted(list(flat.keys())) == sorted(list(data.keys())))
# Set weights
assign(flat, data)
@@ -6,7 +6,7 @@ import torch.nn as nn
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from fastfold.model import MSAStack, OutProductMean, PairStack
+from fastfold.model.fastnn import MSAStack, OutProductMean, PairStack
 from fastfold.distributed.comm_async import All_to_All_Async, All_to_All_Async_Opp
 from fastfold.distributed.comm import gather, scatter
@@ -179,19 +179,19 @@ def copy_para(block_fast, block_ori):
     copy_transition(block_fast.pair_stack.PairTransition, block_ori.core.pair_transition)
-def inject_openfold(model):
+def inject_fastnn(model):
     with torch.no_grad():
         fastfold_blocks = nn.ModuleList()
-        for block_id, openfold_block in enumerate(model.evoformer.blocks):
-            c_m = openfold_block.msa_att_row.c_in
-            c_z = openfold_block.msa_att_row.c_z
+        for block_id, ori_block in enumerate(model.evoformer.blocks):
+            c_m = ori_block.msa_att_row.c_in
+            c_z = ori_block.msa_att_row.c_z
             fastfold_block = EvoformerBlock(c_m=c_m,
                                             c_z=c_z,
                                             first_block=(block_id == 0),
                                             last_block=(block_id == len(model.evoformer.blocks) - 1))
-            copy_para(fastfold_block, openfold_block)
+            copy_para(fastfold_block, ori_block)
             fastfold_blocks.append(fastfold_block)
......
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Tuple, Any, Sequence, Callable, Optional
import numpy as np
import torch
def rot_matmul(
a: torch.Tensor,
b: torch.Tensor
) -> torch.Tensor:
"""
Performs matrix multiplication of two rotation matrix tensors. Written
out by hand to avoid AMP downcasting.
Args:
a: [*, 3, 3] left multiplicand
b: [*, 3, 3] right multiplicand
Returns:
The product ab
"""
row_1 = torch.stack(
[
a[..., 0, 0] * b[..., 0, 0]
+ a[..., 0, 1] * b[..., 1, 0]
+ a[..., 0, 2] * b[..., 2, 0],
a[..., 0, 0] * b[..., 0, 1]
+ a[..., 0, 1] * b[..., 1, 1]
+ a[..., 0, 2] * b[..., 2, 1],
a[..., 0, 0] * b[..., 0, 2]
+ a[..., 0, 1] * b[..., 1, 2]
+ a[..., 0, 2] * b[..., 2, 2],
],
dim=-1,
)
row_2 = torch.stack(
[
a[..., 1, 0] * b[..., 0, 0]
+ a[..., 1, 1] * b[..., 1, 0]
+ a[..., 1, 2] * b[..., 2, 0],
a[..., 1, 0] * b[..., 0, 1]
+ a[..., 1, 1] * b[..., 1, 1]
+ a[..., 1, 2] * b[..., 2, 1],
a[..., 1, 0] * b[..., 0, 2]
+ a[..., 1, 1] * b[..., 1, 2]
+ a[..., 1, 2] * b[..., 2, 2],
],
dim=-1,
)
row_3 = torch.stack(
[
a[..., 2, 0] * b[..., 0, 0]
+ a[..., 2, 1] * b[..., 1, 0]
+ a[..., 2, 2] * b[..., 2, 0],
a[..., 2, 0] * b[..., 0, 1]
+ a[..., 2, 1] * b[..., 1, 1]
+ a[..., 2, 2] * b[..., 2, 1],
a[..., 2, 0] * b[..., 0, 2]
+ a[..., 2, 1] * b[..., 1, 2]
+ a[..., 2, 2] * b[..., 2, 2],
],
dim=-1,
)
return torch.stack([row_1, row_2, row_3], dim=-2)
def rot_vec_mul(
r: torch.Tensor,
t: torch.Tensor
) -> torch.Tensor:
"""
Applies a rotation to a vector. Written out by hand to avoid AMP downcasting.
Args:
r: [*, 3, 3] rotation matrices
t: [*, 3] coordinate tensors
Returns:
[*, 3] rotated coordinates
"""
x = t[..., 0]
y = t[..., 1]
z = t[..., 2]
return torch.stack(
[
r[..., 0, 0] * x + r[..., 0, 1] * y + r[..., 0, 2] * z,
r[..., 1, 0] * x + r[..., 1, 1] * y + r[..., 1, 2] * z,
r[..., 2, 0] * x + r[..., 2, 1] * y + r[..., 2, 2] * z,
],
dim=-1,
)
def identity_rot_mats(
batch_dims: Tuple[int],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
requires_grad: bool = True,
) -> torch.Tensor:
rots = torch.eye(
3, dtype=dtype, device=device, requires_grad=requires_grad
)
rots = rots.view(*((1,) * len(batch_dims)), 3, 3)
rots = rots.expand(*batch_dims, -1, -1)
return rots
def identity_trans(
batch_dims: Tuple[int],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
requires_grad: bool = True,
) -> torch.Tensor:
trans = torch.zeros(
(*batch_dims, 3),
dtype=dtype,
device=device,
requires_grad=requires_grad
)
return trans
def identity_quats(
batch_dims: Tuple[int],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
requires_grad: bool = True,
) -> torch.Tensor:
quat = torch.zeros(
(*batch_dims, 4),
dtype=dtype,
device=device,
requires_grad=requires_grad
)
with torch.no_grad():
quat[..., 0] = 1
return quat
_quat_elements = ["a", "b", "c", "d"]
_qtr_keys = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements]
_qtr_ind_dict = {key: ind for ind, key in enumerate(_qtr_keys)}
def _to_mat(pairs):
mat = np.zeros((4, 4))
for pair in pairs:
key, value = pair
ind = _qtr_ind_dict[key]
mat[ind // 4][ind % 4] = value
return mat
_QTR_MAT = np.zeros((4, 4, 3, 3))
_QTR_MAT[..., 0, 0] = _to_mat([("aa", 1), ("bb", 1), ("cc", -1), ("dd", -1)])
_QTR_MAT[..., 0, 1] = _to_mat([("bc", 2), ("ad", -2)])
_QTR_MAT[..., 0, 2] = _to_mat([("bd", 2), ("ac", 2)])
_QTR_MAT[..., 1, 0] = _to_mat([("bc", 2), ("ad", 2)])
_QTR_MAT[..., 1, 1] = _to_mat([("aa", 1), ("bb", -1), ("cc", 1), ("dd", -1)])
_QTR_MAT[..., 1, 2] = _to_mat([("cd", 2), ("ab", -2)])
_QTR_MAT[..., 2, 0] = _to_mat([("bd", 2), ("ac", -2)])
_QTR_MAT[..., 2, 1] = _to_mat([("cd", 2), ("ab", 2)])
_QTR_MAT[..., 2, 2] = _to_mat([("aa", 1), ("bb", -1), ("cc", -1), ("dd", 1)])
def quat_to_rot(quat: torch.Tensor) -> torch.Tensor:
"""
Converts a quaternion to a rotation matrix.
Args:
quat: [*, 4] quaternions
Returns:
[*, 3, 3] rotation matrices
"""
# [*, 4, 4]
quat = quat[..., None] * quat[..., None, :]
# [4, 4, 3, 3]
mat = quat.new_tensor(_QTR_MAT, requires_grad=False)
# [*, 4, 4, 3, 3]
shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + mat.shape)
quat = quat[..., None, None] * shaped_qtr_mat
# [*, 3, 3]
return torch.sum(quat, dim=(-3, -4))
def rot_to_quat(
rot: torch.Tensor,
):
if(rot.shape[-2:] != (3, 3)):
raise ValueError("Input rotation is incorrectly shaped")
rot = [[rot[..., i, j] for j in range(3)] for i in range(3)]
[[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = rot
k = [
[ xx + yy + zz, zy - yz, xz - zx, yx - xy,],
[ zy - yz, xx - yy - zz, xy + yx, xz + zx,],
[ xz - zx, xy + yx, yy - xx - zz, yz + zy,],
[ yx - xy, xz + zx, yz + zy, zz - xx - yy,]
]
k = (1./3.) * torch.stack([torch.stack(t, dim=-1) for t in k], dim=-2)
_, vectors = torch.linalg.eigh(k)
return vectors[..., -1]
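# --- Illustrative self-check (not part of the original module) ---
# quat_to_rot and rot_to_quat are inverses up to the usual sign ambiguity of
# quaternions (q and -q encode the same rotation), so a round trip is best
# checked on the reconstructed rotation matrices. A minimal sketch with a
# hypothetical helper name:
def _check_quat_roundtrip():
    q = torch.rand(8, 4)
    q = q / torch.linalg.norm(q, dim=-1, keepdim=True)  # unit quaternions
    q_back = rot_to_quat(quat_to_rot(q))
    assert torch.allclose(quat_to_rot(q_back), quat_to_rot(q), atol=1e-4)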
_QUAT_MULTIPLY = np.zeros((4, 4, 4))
_QUAT_MULTIPLY[:, :, 0] = [[ 1, 0, 0, 0],
[ 0,-1, 0, 0],
[ 0, 0,-1, 0],
[ 0, 0, 0,-1]]
_QUAT_MULTIPLY[:, :, 1] = [[ 0, 1, 0, 0],
[ 1, 0, 0, 0],
[ 0, 0, 0, 1],
[ 0, 0,-1, 0]]
_QUAT_MULTIPLY[:, :, 2] = [[ 0, 0, 1, 0],
[ 0, 0, 0,-1],
[ 1, 0, 0, 0],
[ 0, 1, 0, 0]]
_QUAT_MULTIPLY[:, :, 3] = [[ 0, 0, 0, 1],
[ 0, 0, 1, 0],
[ 0,-1, 0, 0],
[ 1, 0, 0, 0]]
_QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :]
def quat_multiply(quat1, quat2):
"""Multiply a quaternion by another quaternion."""
mat = quat1.new_tensor(_QUAT_MULTIPLY)
reshaped_mat = mat.view((1,) * len(quat1.shape[:-1]) + mat.shape)
return torch.sum(
reshaped_mat *
quat1[..., :, None, None] *
quat2[..., None, :, None],
dim=(-3, -2)
)
def quat_multiply_by_vec(quat, vec):
"""Multiply a quaternion by a pure-vector quaternion."""
mat = quat.new_tensor(_QUAT_MULTIPLY_BY_VEC)
reshaped_mat = mat.view((1,) * len(quat.shape[:-1]) + mat.shape)
return torch.sum(
reshaped_mat *
quat[..., :, None, None] *
vec[..., None, :, None],
dim=(-3, -2)
)
def invert_rot_mat(rot_mat: torch.Tensor):
return rot_mat.transpose(-1, -2)
def invert_quat(quat: torch.Tensor):
quat_prime = quat.clone()
quat_prime[..., 1:] *= -1
inv = quat_prime / torch.sum(quat ** 2, dim=-1, keepdim=True)
return inv
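# --- Illustrative self-check (not part of the original module) ---
# For a unit quaternion, invert_quat reduces to the conjugate, and composing a
# quaternion with its inverse yields the identity quaternion (1, 0, 0, 0). A
# minimal sketch with a hypothetical helper name:
def _check_quat_inverse():
    q = torch.rand(4, 4)
    q = q / torch.linalg.norm(q, dim=-1, keepdim=True)
    identity = quat_multiply(q, invert_quat(q))
    expected = torch.zeros_like(identity)
    expected[..., 0] = 1
    assert torch.allclose(identity, expected, atol=1e-5)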
class Rotation:
"""
A 3D rotation. Depending on how the object is initialized, the
rotation is represented by either a rotation matrix or a
quaternion, though both formats are made available by helper functions.
To simplify gradient computation, the underlying format of the
rotation cannot be changed in-place. Like Rigid, the class is designed
to mimic the behavior of a torch Tensor, almost as if each Rotation
object were a tensor of rotations, in one format or another.
"""
def __init__(self,
rot_mats: Optional[torch.Tensor] = None,
quats: Optional[torch.Tensor] = None,
normalize_quats: bool = True,
):
"""
Args:
rot_mats:
A [*, 3, 3] rotation matrix tensor. Mutually exclusive with
quats
quats:
A [*, 4] quaternion. Mutually exclusive with rot_mats. If
normalize_quats is not True, must be a unit quaternion
normalize_quats:
If quats is specified, whether to normalize quats
"""
if((rot_mats is None and quats is None) or
(rot_mats is not None and quats is not None)):
raise ValueError("Exactly one input argument must be specified")
if((rot_mats is not None and rot_mats.shape[-2:] != (3, 3)) or
(quats is not None and quats.shape[-1] != 4)):
raise ValueError(
"Incorrectly shaped rotation matrix or quaternion"
)
# Force full-precision
if(quats is not None):
quats = quats.to(dtype=torch.float32)
if(rot_mats is not None):
rot_mats = rot_mats.to(dtype=torch.float32)
if(quats is not None and normalize_quats):
quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True)
self._rot_mats = rot_mats
self._quats = quats
@staticmethod
def identity(
shape,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
requires_grad: bool = True,
fmt: str = "quat",
) -> Rotation:
"""
Returns an identity Rotation.
Args:
shape:
The "shape" of the resulting Rotation object. See documentation
for the shape property
dtype:
The torch dtype for the rotation
device:
The torch device for the new rotation
requires_grad:
Whether the underlying tensors in the new rotation object
should require gradient computation
fmt:
One of "quat" or "rot_mat". Determines the underlying format
of the new object's rotation
Returns:
A new identity rotation
"""
if(fmt == "rot_mat"):
rot_mats = identity_rot_mats(
shape, dtype, device, requires_grad,
)
return Rotation(rot_mats=rot_mats, quats=None)
elif(fmt == "quat"):
quats = identity_quats(shape, dtype, device, requires_grad)
return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
else:
raise ValueError(f"Invalid format: f{fmt}")
# Magic methods
def __getitem__(self, index: Any) -> Rotation:
"""
Allows torch-style indexing over the virtual shape of the rotation
object. See documentation for the shape property.
Args:
index:
A torch index. E.g. (1, 3, 2), or (slice(None),)
Returns:
The indexed rotation
"""
if type(index) != tuple:
index = (index,)
if(self._rot_mats is not None):
rot_mats = self._rot_mats[index + (slice(None), slice(None))]
return Rotation(rot_mats=rot_mats)
elif(self._quats is not None):
quats = self._quats[index + (slice(None),)]
return Rotation(quats=quats, normalize_quats=False)
else:
raise ValueError("Both rotations are None")
def __mul__(self,
right: torch.Tensor,
) -> Rotation:
"""
Pointwise left multiplication of the rotation with a tensor. Can be
used to e.g. mask the Rotation.
Args:
right:
The tensor multiplicand
Returns:
The product
"""
if not(isinstance(right, torch.Tensor)):
raise TypeError("The other multiplicand must be a Tensor")
if(self._rot_mats is not None):
rot_mats = self._rot_mats * right[..., None, None]
return Rotation(rot_mats=rot_mats, quats=None)
elif(self._quats is not None):
quats = self._quats * right[..., None]
return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
else:
raise ValueError("Both rotations are None")
def __rmul__(self,
left: torch.Tensor,
) -> Rotation:
"""
Reverse pointwise multiplication of the rotation with a tensor.
Args:
left:
The left multiplicand
Returns:
The product
"""
return self.__mul__(left)
# Properties
@property
def shape(self) -> torch.Size:
"""
Returns the virtual shape of the rotation object. This shape is
defined as the batch dimensions of the underlying rotation matrix
or quaternion. If the Rotation was initialized with a [10, 3, 3]
rotation matrix tensor, for example, the resulting shape would be
[10].
Returns:
The virtual shape of the rotation object
"""
s = None
if(self._quats is not None):
s = self._quats.shape[:-1]
else:
s = self._rot_mats.shape[:-2]
return s
@property
def dtype(self) -> torch.dtype:
"""
Returns the dtype of the underlying rotation.
Returns:
The dtype of the underlying rotation
"""
if(self._rot_mats is not None):
return self._rot_mats.dtype
elif(self._quats is not None):
return self._quats.dtype
else:
raise ValueError("Both rotations are None")
@property
def device(self) -> torch.device:
"""
The device of the underlying rotation
Returns:
The device of the underlying rotation
"""
if(self._rot_mats is not None):
return self._rot_mats.device
elif(self._quats is not None):
return self._quats.device
else:
raise ValueError("Both rotations are None")
@property
def requires_grad(self) -> bool:
"""
Returns the requires_grad property of the underlying rotation
Returns:
The requires_grad property of the underlying tensor
"""
if(self._rot_mats is not None):
return self._rot_mats.requires_grad
elif(self._quats is not None):
return self._quats.requires_grad
else:
raise ValueError("Both rotations are None")
def get_rot_mats(self) -> torch.Tensor:
"""
Returns the underlying rotation as a rotation matrix tensor.
Returns:
The rotation as a rotation matrix tensor
"""
rot_mats = self._rot_mats
if(rot_mats is None):
if(self._quats is None):
raise ValueError("Both rotations are None")
else:
rot_mats = quat_to_rot(self._quats)
return rot_mats
def get_quats(self) -> torch.Tensor:
"""
Returns the underlying rotation as a quaternion tensor.
Depending on whether the Rotation was initialized with a
quaternion, this function may call torch.linalg.eigh.
Returns:
The rotation as a quaternion tensor.
"""
quats = self._quats
if(quats is None):
if(self._rot_mats is None):
raise ValueError("Both rotations are None")
else:
quats = rot_to_quat(self._rot_mats)
return quats
def get_cur_rot(self) -> torch.Tensor:
"""
Return the underlying rotation in its current form
Returns:
The stored rotation
"""
if(self._rot_mats is not None):
return self._rot_mats
elif(self._quats is not None):
return self._quats
else:
raise ValueError("Both rotations are None")
# Rotation functions
def compose_q_update_vec(self,
q_update_vec: torch.Tensor,
normalize_quats: bool = True
) -> Rotation:
"""
Returns a new quaternion Rotation after updating the current
object's underlying rotation with a quaternion update, formatted
as a [*, 3] tensor whose final three columns represent x, y, z such
that (1, x, y, z) is the desired (not necessarily unit) quaternion
update.
Args:
q_update_vec:
A [*, 3] quaternion update tensor
normalize_quats:
Whether to normalize the output quaternion
Returns:
An updated Rotation
"""
quats = self.get_quats()
new_quats = quats + quat_multiply_by_vec(quats, q_update_vec)
return Rotation(
rot_mats=None,
quats=new_quats,
normalize_quats=normalize_quats,
)
def compose_r(self, r: Rotation) -> Rotation:
"""
Compose the rotation matrices of the current Rotation object with
those of another.
Args:
r:
An update rotation object
Returns:
An updated rotation object
"""
r1 = self.get_rot_mats()
r2 = r.get_rot_mats()
new_rot_mats = rot_matmul(r1, r2)
return Rotation(rot_mats=new_rot_mats, quats=None)
def compose_q(self, r: Rotation, normalize_quats: bool = True) -> Rotation:
"""
Compose the quaternions of the current Rotation object with those
of another.
Depending on whether either Rotation was initialized with
quaternions, this function may call torch.linalg.eigh.
Args:
r:
An update rotation object
Returns:
An updated rotation object
"""
q1 = self.get_quats()
q2 = r.get_quats()
new_quats = quat_multiply(q1, q2)
return Rotation(
rot_mats=None, quats=new_quats, normalize_quats=normalize_quats
)
def apply(self, pts: torch.Tensor) -> torch.Tensor:
"""
Apply the current Rotation as a rotation matrix to a set of 3D
coordinates.
Args:
pts:
A [*, 3] set of points
Returns:
[*, 3] rotated points
"""
rot_mats = self.get_rot_mats()
return rot_vec_mul(rot_mats, pts)
def invert_apply(self, pts: torch.Tensor) -> torch.Tensor:
"""
The inverse of the apply() method.
Args:
pts:
A [*, 3] set of points
Returns:
[*, 3] inverse-rotated points
"""
rot_mats = self.get_rot_mats()
inv_rot_mats = invert_rot_mat(rot_mats)
return rot_vec_mul(inv_rot_mats, pts)
def invert(self) -> Rotation:
"""
Returns the inverse of the current Rotation.
Returns:
The inverse of the current Rotation
"""
if(self._rot_mats is not None):
return Rotation(
rot_mats=invert_rot_mat(self._rot_mats),
quats=None
)
elif(self._quats is not None):
return Rotation(
rot_mats=None,
quats=invert_quat(self._quats),
normalize_quats=False,
)
else:
raise ValueError("Both rotations are None")
# "Tensor" stuff
def unsqueeze(self,
dim: int,
) -> Rotation:
"""
Analogous to torch.unsqueeze. The dimension is relative to the
shape of the Rotation object.
Args:
dim: A positive or negative dimension index.
Returns:
The unsqueezed Rotation.
"""
if dim >= len(self.shape):
raise ValueError("Invalid dimension")
if(self._rot_mats is not None):
rot_mats = self._rot_mats.unsqueeze(dim if dim >= 0 else dim - 2)
return Rotation(rot_mats=rot_mats, quats=None)
elif(self._quats is not None):
quats = self._quats.unsqueeze(dim if dim >= 0 else dim - 1)
return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
else:
raise ValueError("Both rotations are None")
@staticmethod
def cat(
rs: Sequence[Rotation],
dim: int,
) -> Rotation:
"""
Concatenates rotations along one of the batch dimensions. Analogous
to torch.cat().
Note that the output of this operation is always a rotation matrix,
regardless of the format of input rotations.
Args:
rs:
A list of rotation objects
dim:
The dimension along which the rotations should be
concatenated
Returns:
A concatenated Rotation object in rotation matrix format
"""
rot_mats = [r.get_rot_mats() for r in rs]
rot_mats = torch.cat(rot_mats, dim=dim if dim >= 0 else dim - 2)
return Rotation(rot_mats=rot_mats, quats=None)
def map_tensor_fn(self,
fn: Callable[[torch.Tensor], torch.Tensor]
) -> Rotation:
"""
Apply a Tensor -> Tensor function to underlying rotation tensors,
mapping over the rotation dimension(s). Can be used e.g. to sum out
a one-hot batch dimension.
Args:
fn:
A Tensor -> Tensor function to be mapped over the Rotation
Returns:
The transformed Rotation object
"""
if(self._rot_mats is not None):
rot_mats = self._rot_mats.view(self._rot_mats.shape[:-2] + (9,))
rot_mats = torch.stack(
list(map(fn, torch.unbind(rot_mats, dim=-1))), dim=-1
)
rot_mats = rot_mats.view(rot_mats.shape[:-1] + (3, 3))
return Rotation(rot_mats=rot_mats, quats=None)
elif(self._quats is not None):
quats = torch.stack(
list(map(fn, torch.unbind(self._quats, dim=-1))), dim=-1
)
return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
else:
raise ValueError("Both rotations are None")
def cuda(self) -> Rotation:
"""
Analogous to the cuda() method of torch Tensors
Returns:
A copy of the Rotation in CUDA memory
"""
if(self._rot_mats is not None):
return Rotation(rot_mats=self._rot_mats.cuda(), quats=None)
elif(self._quats is not None):
return Rotation(
rot_mats=None,
quats=self._quats.cuda(),
normalize_quats=False
)
else:
raise ValueError("Both rotations are None")
def to(self,
device: Optional[torch.device],
dtype: Optional[torch.dtype]
) -> Rotation:
"""
Analogous to the to() method of torch Tensors
Args:
device:
A torch device
dtype:
A torch dtype
Returns:
A copy of the Rotation using the new device and dtype
"""
if(self._rot_mats is not None):
return Rotation(
rot_mats=self._rot_mats.to(device=device, dtype=dtype),
quats=None,
)
elif(self._quats is not None):
return Rotation(
rot_mats=None,
quats=self._quats.to(device=device, dtype=dtype),
normalize_quats=False,
)
else:
raise ValueError("Both rotations are None")
def detach(self) -> Rotation:
"""
Returns a copy of the Rotation whose underlying Tensor has been
detached from its torch graph.
Returns:
A copy of the Rotation whose underlying Tensor has been detached
from its torch graph
"""
if(self._rot_mats is not None):
return Rotation(rot_mats=self._rot_mats.detach(), quats=None)
elif(self._quats is not None):
return Rotation(
rot_mats=None,
quats=self._quats.detach(),
normalize_quats=False,
)
else:
raise ValueError("Both rotations are None")
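# --- Illustrative usage sketch (not part of the original module) ---
# A Rotation behaves like a tensor of rotations: it can be indexed, applied to
# points, inverted and converted between formats lazily. A short example with
# random quaternions and a hypothetical helper name:
def _rotation_usage_example():
    quats = torch.rand(10, 4)
    rots = Rotation(rot_mats=None, quats=quats, normalize_quats=True)
    assert rots.shape == (10,)
    pts = torch.rand(10, 3)
    rotated = rots.apply(pts)              # [10, 3] rotated points
    restored = rots.invert_apply(rotated)  # undoes the rotation
    assert torch.allclose(restored, pts, atol=1e-5)
    return rots.get_rot_mats()             # lazy conversion to [10, 3, 3]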
class Rigid:
"""
A class representing a rigid transformation. Little more than a wrapper
around two objects: a Rotation object and a [*, 3] translation tensor.
Designed to behave approximately like a single torch tensor with the
shape of the shared batch dimensions of its component parts.
"""
def __init__(self,
rots: Optional[Rotation],
trans: Optional[torch.Tensor],
):
"""
Args:
rots: A Rotation object (or None)
trans: A corresponding [*, 3] translation tensor (or None)
"""
# (we need device, dtype, etc. from at least one input)
batch_dims, dtype, device, requires_grad = None, None, None, None
if(trans is not None):
batch_dims = trans.shape[:-1]
dtype = trans.dtype
device = trans.device
requires_grad = trans.requires_grad
elif(rots is not None):
batch_dims = rots.shape
dtype = rots.dtype
device = rots.device
requires_grad = rots.requires_grad
else:
raise ValueError("At least one input argument must be specified")
if(rots is None):
rots = Rotation.identity(
batch_dims, dtype, device, requires_grad,
)
elif(trans is None):
trans = identity_trans(
batch_dims, dtype, device, requires_grad,
)
if((rots.shape != trans.shape[:-1]) or
(rots.device != trans.device)):
raise ValueError("Rots and trans incompatible")
# Force full precision. Happens to the rotations automatically.
trans = trans.to(dtype=torch.float32)
self._rots = rots
self._trans = trans
@staticmethod
def identity(
shape: Tuple[int],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
requires_grad: bool = True,
fmt: str = "quat",
) -> Rigid:
"""
Constructs an identity transformation.
Args:
shape:
The desired shape
dtype:
The dtype of both internal tensors
device:
The device of both internal tensors
requires_grad:
Whether grad should be enabled for the internal tensors
Returns:
The identity transformation
"""
return Rigid(
Rotation.identity(shape, dtype, device, requires_grad, fmt=fmt),
identity_trans(shape, dtype, device, requires_grad),
)
def __getitem__(self,
index: Any,
) -> Rigid:
"""
Indexes the affine transformation with PyTorch-style indices.
The index is applied to the shared dimensions of both the rotation
and the translation.
E.g.::
r = Rotation(rot_mats=torch.rand(10, 10, 3, 3), quats=None)
t = Rigid(r, torch.rand(10, 10, 3))
indexed = t[3, 4:6]
assert(indexed.shape == (2,))
assert(indexed.get_rots().shape == (2,))
assert(indexed.get_trans().shape == (2, 3))
Args:
index: A standard torch tensor index. E.g. 8, (10, None, 3),
or (3, slice(0, 1, None))
Returns:
The indexed tensor
"""
if type(index) != tuple:
index = (index,)
return Rigid(
self._rots[index],
self._trans[index + (slice(None),)],
)
def __mul__(self,
right: torch.Tensor,
) -> Rigid:
"""
Pointwise left multiplication of the transformation with a tensor.
Can be used to e.g. mask the Rigid.
Args:
right:
The tensor multiplicand
Returns:
The product
"""
if not(isinstance(right, torch.Tensor)):
raise TypeError("The other multiplicand must be a Tensor")
new_rots = self._rots * right
new_trans = self._trans * right[..., None]
return Rigid(new_rots, new_trans)
def __rmul__(self,
left: torch.Tensor,
) -> Rigid:
"""
Reverse pointwise multiplication of the transformation with a
tensor.
Args:
left:
The left multiplicand
Returns:
The product
"""
return self.__mul__(left)
@property
def shape(self) -> torch.Size:
"""
Returns the shape of the shared dimensions of the rotation and
the translation.
Returns:
The shape of the transformation
"""
s = self._trans.shape[:-1]
return s
@property
def device(self) -> torch.device:
"""
Returns the device on which the Rigid's tensors are located.
Returns:
The device on which the Rigid's tensors are located
"""
return self._trans.device
def get_rots(self) -> Rotation:
"""
Getter for the rotation.
Returns:
The rotation object
"""
return self._rots
def get_trans(self) -> torch.Tensor:
"""
Getter for the translation.
Returns:
The stored translation
"""
return self._trans
def compose_q_update_vec(self,
q_update_vec: torch.Tensor,
) -> Rigid:
"""
Composes the transformation with a quaternion update vector of
shape [*, 6], whose first three columns are the x, y, and z values
of a quaternion update of form (1, x, y, z) and whose last three
columns are a 3D translation.
Args:
q_update_vec: The quaternion update vector.
Returns:
The composed transformation.
"""
q_vec, t_vec = q_update_vec[..., :3], q_update_vec[..., 3:]
new_rots = self._rots.compose_q_update_vec(q_vec)
trans_update = self._rots.apply(t_vec)
new_translation = self._trans + trans_update
return Rigid(new_rots, new_translation)
def compose(self,
r: Rigid,
) -> Rigid:
"""
Composes the current rigid object with another.
Args:
r:
Another Rigid object
Returns:
The composition of the two transformations
"""
new_rot = self._rots.compose_r(r._rots)
new_trans = self._rots.apply(r._trans) + self._trans
return Rigid(new_rot, new_trans)
def apply(self,
pts: torch.Tensor,
) -> torch.Tensor:
"""
Applies the transformation to a coordinate tensor.
Args:
pts: A [*, 3] coordinate tensor.
Returns:
The transformed points.
"""
rotated = self._rots.apply(pts)
return rotated + self._trans
def invert_apply(self,
pts: torch.Tensor
) -> torch.Tensor:
"""
Applies the inverse of the transformation to a coordinate tensor.
Args:
pts: A [*, 3] coordinate tensor
Returns:
The transformed points.
"""
pts = pts - self._trans
return self._rots.invert_apply(pts)
def invert(self) -> Rigid:
"""
Inverts the transformation.
Returns:
The inverse transformation.
"""
rot_inv = self._rots.invert()
trn_inv = rot_inv.apply(self._trans)
return Rigid(rot_inv, -1 * trn_inv)
def map_tensor_fn(self,
fn: Callable[[torch.Tensor], torch.Tensor]
) -> Rigid:
"""
Apply a Tensor -> Tensor function to underlying translation and
rotation tensors, mapping over the translation/rotation dimensions
respectively.
Args:
fn:
A Tensor -> Tensor function to be mapped over the Rigid
Returns:
The transformed Rigid object
"""
new_rots = self._rots.map_tensor_fn(fn)
new_trans = torch.stack(
list(map(fn, torch.unbind(self._trans, dim=-1))),
dim=-1
)
return Rigid(new_rots, new_trans)
def to_tensor_4x4(self) -> torch.Tensor:
"""
Converts a transformation to a homogeneous transformation tensor.
Returns:
A [*, 4, 4] homogeneous transformation tensor
"""
tensor = self._trans.new_zeros((*self.shape, 4, 4))
tensor[..., :3, :3] = self._rots.get_rot_mats()
tensor[..., :3, 3] = self._trans
tensor[..., 3, 3] = 1
return tensor
@staticmethod
def from_tensor_4x4(
t: torch.Tensor
) -> Rigid:
"""
Constructs a transformation from a homogeneous transformation
tensor.
Args:
t: [*, 4, 4] homogeneous transformation tensor
Returns:
A Rigid object with shape [*]
"""
if(t.shape[-2:] != (4, 4)):
raise ValueError("Incorrectly shaped input tensor")
rots = Rotation(rot_mats=t[..., :3, :3], quats=None)
trans = t[..., :3, 3]
return Rigid(rots, trans)
def to_tensor_7(self) -> torch.Tensor:
"""
Converts a transformation to a tensor with 7 final columns, four
for the quaternion followed by three for the translation.
Returns:
A [*, 7] tensor representation of the transformation
"""
tensor = self._trans.new_zeros((*self.shape, 7))
tensor[..., :4] = self._rots.get_quats()
tensor[..., 4:] = self._trans
return tensor
@staticmethod
def from_tensor_7(
t: torch.Tensor,
normalize_quats: bool = False,
) -> Rigid:
if(t.shape[-1] != 7):
raise ValueError("Incorrectly shaped input tensor")
quats, trans = t[..., :4], t[..., 4:]
rots = Rotation(
rot_mats=None,
quats=quats,
normalize_quats=normalize_quats
)
return Rigid(rots, trans)
@staticmethod
def from_3_points(
p_neg_x_axis: torch.Tensor,
origin: torch.Tensor,
p_xy_plane: torch.Tensor,
eps: float = 1e-8
) -> Rigid:
"""
Implements algorithm 21. Constructs transformations from sets of 3
points using the Gram-Schmidt algorithm.
Args:
p_neg_x_axis: [*, 3] coordinates
origin: [*, 3] coordinates used as frame origins
p_xy_plane: [*, 3] coordinates
eps: Small epsilon value
Returns:
A transformation object of shape [*]
"""
p_neg_x_axis = torch.unbind(p_neg_x_axis, dim=-1)
origin = torch.unbind(origin, dim=-1)
p_xy_plane = torch.unbind(p_xy_plane, dim=-1)
e0 = [c1 - c2 for c1, c2 in zip(origin, p_neg_x_axis)]
e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane, origin)]
denom = torch.sqrt(sum((c * c for c in e0)) + eps)
e0 = [c / denom for c in e0]
dot = sum((c1 * c2 for c1, c2 in zip(e0, e1)))
e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)]
denom = torch.sqrt(sum((c * c for c in e1)) + eps)
e1 = [c / denom for c in e1]
e2 = [
e0[1] * e1[2] - e0[2] * e1[1],
e0[2] * e1[0] - e0[0] * e1[2],
e0[0] * e1[1] - e0[1] * e1[0],
]
rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1)
rots = rots.reshape(rots.shape[:-1] + (3, 3))
rot_obj = Rotation(rot_mats=rots, quats=None)
return Rigid(rot_obj, torch.stack(origin, dim=-1))
def unsqueeze(self,
dim: int,
) -> Rigid:
"""
Analogous to torch.unsqueeze. The dimension is relative to the
shared dimensions of the rotation/translation.
Args:
dim: A positive or negative dimension index.
Returns:
The unsqueezed transformation.
"""
if dim >= len(self.shape):
raise ValueError("Invalid dimension")
rots = self._rots.unsqueeze(dim)
trans = self._trans.unsqueeze(dim if dim >= 0 else dim - 1)
return Rigid(rots, trans)
@staticmethod
def cat(
ts: Sequence[Rigid],
dim: int,
) -> Rigid:
"""
Concatenates transformations along a new dimension.
Args:
ts:
A list of Rigid objects
dim:
The dimension along which the transformations should be
concatenated
Returns:
A concatenated transformation object
"""
rots = Rotation.cat([t._rots for t in ts], dim)
trans = torch.cat(
[t._trans for t in ts], dim=dim if dim >= 0 else dim - 1
)
return Rigid(rots, trans)
def apply_rot_fn(self, fn: Callable[[Rotation], Rotation]) -> Rigid:
"""
Applies a Rotation -> Rotation function to the stored rotation
object.
Args:
fn: A function of type Rotation -> Rotation
Returns:
A transformation object with a transformed rotation.
"""
return Rigid(fn(self._rots), self._trans)
def apply_trans_fn(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Rigid:
"""
Applies a Tensor -> Tensor function to the stored translation.
Args:
fn:
A function of type Tensor -> Tensor to be applied to the
translation
Returns:
A transformation object with a transformed translation.
"""
return Rigid(self._rots, fn(self._trans))
def scale_translation(self, trans_scale_factor: float) -> Rigid:
"""
Scales the translation by a constant factor.
Args:
trans_scale_factor:
The constant factor
Returns:
A transformation object with a scaled translation.
"""
fn = lambda t: t * trans_scale_factor
return self.apply_trans_fn(fn)
def stop_rot_gradient(self) -> Rigid:
"""
Detaches the underlying rotation object
Returns:
A transformation object with detached rotations
"""
fn = lambda r: r.detach()
return self.apply_rot_fn(fn)
@staticmethod
def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20):
"""
Returns a transformation object from reference coordinates.
Note that this method does not take care of symmetries. If you
provide the atom positions in the non-standard way, the N atom will
end up not at [-0.527250, 1.359329, 0.0] but instead at
[-0.527250, -1.359329, 0.0]. You need to take care of such cases in
your code.
Args:
n_xyz: A [*, 3] tensor of nitrogen xyz coordinates.
ca_xyz: A [*, 3] tensor of carbon alpha xyz coordinates.
c_xyz: A [*, 3] tensor of carbon xyz coordinates.
Returns:
A transformation object. After applying the translation and
rotation to the reference backbone, the coordinates will
be approximately equal to the input coordinates.
"""
translation = -1 * ca_xyz
n_xyz = n_xyz + translation
c_xyz = c_xyz + translation
c_x, c_y, c_z = [c_xyz[..., i] for i in range(3)]
norm = torch.sqrt(eps + c_x ** 2 + c_y ** 2)
sin_c1 = -c_y / norm
cos_c1 = c_x / norm
zeros = sin_c1.new_zeros(sin_c1.shape)
ones = sin_c1.new_ones(sin_c1.shape)
c1_rots = sin_c1.new_zeros((*sin_c1.shape, 3, 3))
c1_rots[..., 0, 0] = cos_c1
c1_rots[..., 0, 1] = -1 * sin_c1
c1_rots[..., 1, 0] = sin_c1
c1_rots[..., 1, 1] = cos_c1
c1_rots[..., 2, 2] = 1
norm = torch.sqrt(eps + c_x ** 2 + c_y ** 2 + c_z ** 2)
sin_c2 = c_z / norm
cos_c2 = torch.sqrt(c_x ** 2 + c_y ** 2) / norm
c2_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
c2_rots[..., 0, 0] = cos_c2
c2_rots[..., 0, 2] = sin_c2
c2_rots[..., 1, 1] = 1
c2_rots[..., 2, 0] = -1 * sin_c2
c2_rots[..., 2, 2] = cos_c2
c_rots = rot_matmul(c2_rots, c1_rots)
n_xyz = rot_vec_mul(c_rots, n_xyz)
_, n_y, n_z = [n_xyz[..., i] for i in range(3)]
norm = torch.sqrt(eps + n_y ** 2 + n_z ** 2)
sin_n = -n_z / norm
cos_n = n_y / norm
n_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
n_rots[..., 0, 0] = 1
n_rots[..., 1, 1] = cos_n
n_rots[..., 1, 2] = -1 * sin_n
n_rots[..., 2, 1] = sin_n
n_rots[..., 2, 2] = cos_n
rots = rot_matmul(n_rots, c_rots)
rots = rots.transpose(-1, -2)
translation = -1 * translation
rot_obj = Rotation(rot_mats=rots, quats=None)
return Rigid(rot_obj, translation)
def cuda(self) -> Rigid:
"""
Moves the transformation object to GPU memory
Returns:
A version of the transformation on GPU
"""
return Rigid(self._rots.cuda(), self._trans.cuda())
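# --- Illustrative usage sketch (not part of the original module) ---
# A Rigid couples a Rotation with a translation and is typically built from
# three atom positions per residue via from_3_points (Algorithm 21). A short
# example with random coordinates standing in for the three backbone atoms,
# using a hypothetical helper name:
def _rigid_usage_example():
    x1 = torch.rand(16, 3)
    origin = torch.rand(16, 3)
    x3 = torch.rand(16, 3)
    frames = Rigid.from_3_points(x1, origin, x3)  # Rigid of shape [16]
    local_pts = torch.rand(16, 3)
    global_pts = frames.apply(local_pts)          # into the global frame
    recovered = frames.invert_apply(global_pts)   # back to the local frame
    assert torch.allclose(recovered, local_pts, atol=1e-5)
    return frames.to_tensor_4x4()                 # [16, 4, 4] homogeneous form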
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import torch
import torch.nn as nn
from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional
def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
zero_index = -1 * len(inds)
first_inds = list(range(len(tensor.shape[:zero_index])))
return tensor.permute(first_inds + [zero_index + i for i in inds])
def flatten_final_dims(t: torch.Tensor, no_dims: int):
return t.reshape(t.shape[:-no_dims] + (-1,))
def masked_mean(mask, value, dim, eps=1e-4):
mask = mask.expand(*value.shape)
return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim))
def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64):
boundaries = torch.linspace(
min_bin, max_bin, no_bins - 1, device=pts.device
)
dists = torch.sqrt(
torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3)) ** 2, dim=-1)
)
return torch.bucketize(dists, boundaries)
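# --- Illustrative sketch (not part of the original module) ---
# pts_to_distogram buckets every pairwise distance into one of no_bins bins
# between min_bin and max_bin, which is how distogram targets are discretized.
# A tiny example with three collinear points and a hypothetical helper name:
def _distogram_example():
    pts = torch.tensor([[0.0, 0.0, 0.0],
                        [3.0, 0.0, 0.0],
                        [10.0, 0.0, 0.0]])
    bins = pts_to_distogram(pts)  # [3, 3] integer bin indices
    # Distances below min_bin fall in bin 0 (e.g. the zero diagonal), and
    # distances above max_bin fall in the last bin, no_bins - 1.
    return bins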
def dict_multimap(fn, dicts):
first = dicts[0]
new_dict = {}
for k, v in first.items():
all_v = [d[k] for d in dicts]
if type(v) is dict:
new_dict[k] = dict_multimap(fn, all_v)
else:
new_dict[k] = fn(all_v)
return new_dict
def one_hot(x, v_bins):
reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),))
diffs = x[..., None] - reshaped_bins
am = torch.argmin(torch.abs(diffs), dim=-1)
return nn.functional.one_hot(am, num_classes=len(v_bins)).float()
def batched_gather(data, inds, dim=0, no_batch_dims=0):
ranges = []
for i, s in enumerate(data.shape[:no_batch_dims]):
r = torch.arange(s)
r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
ranges.append(r)
remaining_dims = [
slice(None) for _ in range(len(data.shape) - no_batch_dims)
]
remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
ranges.extend(remaining_dims)
return data[ranges]
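# --- Illustrative sketch (not part of the original module) ---
# batched_gather indexes `data` along `dim` with a per-batch index tensor,
# treating the leading no_batch_dims dimensions as batch dimensions. A small
# example that picks a different row from each batch element, with a
# hypothetical helper name:
def _batched_gather_example():
    data = torch.arange(2 * 3 * 4).view(2, 3, 4)  # [batch=2, rows=3, cols=4]
    inds = torch.tensor([[0], [2]])               # row 0 for batch 0, row 2 for batch 1
    out = batched_gather(data, inds, dim=1, no_batch_dims=1)
    # out has shape [2, 1, 4] and holds data[0, 0] and data[1, 2]
    return out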
# dict_map / tree_map below: a poor man's version of JAX's tree_map for
# nested dicts, lists and tuples
def dict_map(fn, dic, leaf_type):
new_dict = {}
for k, v in dic.items():
if type(v) is dict:
new_dict[k] = dict_map(fn, v, leaf_type)
else:
new_dict[k] = tree_map(fn, v, leaf_type)
return new_dict
def tree_map(fn, tree, leaf_type):
if isinstance(tree, dict):
return dict_map(fn, tree, leaf_type)
elif isinstance(tree, list):
return [tree_map(fn, x, leaf_type) for x in tree]
elif isinstance(tree, tuple):
return tuple([tree_map(fn, x, leaf_type) for x in tree])
elif isinstance(tree, leaf_type):
return fn(tree)
else:
raise ValueError(f"Tree of type {type(tree)} not supported")
tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
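# --- Illustrative sketch (not part of the original module) ---
# tensor_tree_map applies a function to every torch.Tensor leaf of a nested
# dict/list/tuple structure, e.g. to cast or move whole feature dictionaries.
# A small example with a hypothetical helper name:
def _tensor_tree_map_example():
    feats = {
        "msa": torch.rand(4, 8),
        "pair": {"left": torch.rand(8, 8), "right": torch.rand(8, 8)},
    }
    # Cast every tensor leaf to half precision, keeping the nesting intact
    return tensor_tree_map(lambda t: t.half(), feats)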
def _fetch_dims(tree):
shapes = []
tree_type = type(tree)
if tree_type is dict:
for v in tree.values():
shapes.extend(_fetch_dims(v))
elif tree_type is list or tree_type is tuple:
for t in tree:
shapes.extend(_fetch_dims(t))
elif tree_type is torch.Tensor:
shapes.append(tree.shape)
else:
raise ValueError("Not supported")
return shapes
@torch.jit.ignore
def _flat_idx_to_idx(
flat_idx: int,
dims: Tuple[int],
) -> Tuple[int]:
idx = []
for d in reversed(dims):
idx.append(flat_idx % d)
flat_idx = flat_idx // d
return tuple(reversed(idx))
@torch.jit.ignore
def _get_minimal_slice_set(
start: Sequence[int],
end: Sequence[int],
dims: Sequence[int],
start_edges: Optional[Sequence[bool]] = None,
end_edges: Optional[Sequence[bool]] = None,
) -> Sequence[Tuple[int]]:
"""
Produces an ordered sequence of tensor slices that, when used in
sequence on a tensor with shape dims, yields tensors that contain every
leaf in the contiguous range [start, end]. Care is taken to yield a
short sequence of slices, and perhaps even the shortest possible (I'm
pretty sure it's the latter).
end is INCLUSIVE.
"""
# start_edges and end_edges both indicate whether, starting from any given
# dimension, the start/end index is at the top/bottom edge of the
# corresponding tensor, modeled as a tree
def reduce_edge_list(l):
tally = 1
for i in range(len(l)):
reversed_idx = -1 * (i + 1)
l[reversed_idx] *= tally
tally = l[reversed_idx]
if(start_edges is None):
start_edges = [s == 0 for s in start]
reduce_edge_list(start_edges)
if(end_edges is None):
end_edges = [e == (d - 1) for e,d in zip(end, dims)]
reduce_edge_list(end_edges)
# Base cases. Either start/end are empty and we're done, or the final,
# one-dimensional tensor can be simply sliced
if(len(start) == 0):
return [tuple()]
elif(len(start) == 1):
return [(slice(start[0], end[0] + 1),)]
slices = []
path = []
# Dimensions common to start and end can be selected directly
for s,e in zip(start, end):
if(s == e):
path.append(slice(s, s + 1))
else:
break
path = tuple(path)
divergence_idx = len(path)
# start == end, and we're done
if(divergence_idx == len(dims)):
return [tuple(path)]
def upper():
sdi = start[divergence_idx]
return [
path + (slice(sdi, sdi + 1),) + s for s in
_get_minimal_slice_set(
start[divergence_idx + 1:],
[d - 1 for d in dims[divergence_idx + 1:]],
dims[divergence_idx + 1:],
start_edges=start_edges[divergence_idx + 1:],
end_edges=[1 for _ in end_edges[divergence_idx + 1:]]
)
]
def lower():
edi = end[divergence_idx]
return [
path + (slice(edi, edi + 1),) + s for s in
_get_minimal_slice_set(
[0 for _ in start[divergence_idx + 1:]],
end[divergence_idx + 1:],
dims[divergence_idx + 1:],
start_edges=[1 for _ in start_edges[divergence_idx + 1:]],
end_edges=end_edges[divergence_idx + 1:],
)
]
# If both start and end are at the edges of the subtree rooted at
# divergence_idx, we can just select the whole subtree at once
if(start_edges[divergence_idx] and end_edges[divergence_idx]):
slices.append(
path + (slice(start[divergence_idx], end[divergence_idx] + 1),)
)
# If just start is at the edge, we can grab almost all of the subtree,
# treating only the ragged bottom edge as an edge case
elif(start_edges[divergence_idx]):
slices.append(
path + (slice(start[divergence_idx], end[divergence_idx]),)
)
slices.extend(lower())
# Analogous to the previous case, but the top is ragged this time
elif(end_edges[divergence_idx]):
slices.extend(upper())
slices.append(
path + (slice(start[divergence_idx] + 1, end[divergence_idx] + 1),)
)
# If both sides of the range are ragged, we need to handle both sides
# separately. If there's contiguous meat in between them, we can index it
# in one big chunk
else:
slices.extend(upper())
middle_ground = end[divergence_idx] - start[divergence_idx]
if(middle_ground > 1):
slices.append(
path + (slice(start[divergence_idx] + 1, end[divergence_idx]),)
)
slices.extend(lower())
return [tuple(s) for s in slices]
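# --- Illustrative sketch (not part of the original module) ---
# _get_minimal_slice_set covers an inclusive range of a multi-dimensional index
# space with as few slices as possible. For a [3, 4] space, the flat range
# 2..9 (i.e. (0, 2) through (2, 1)) decomposes into the ragged tail of row 0,
# all of row 1, and the ragged head of row 2. A sketch with a hypothetical
# helper name:
def _minimal_slice_set_example():
    slices = _get_minimal_slice_set(start=[0, 2], end=[2, 1], dims=(3, 4))
    # Expected cover: [0, 2:4], [1, :] and [2, 0:2]
    return slices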
@torch.jit.ignore
def _chunk_slice(
t: torch.Tensor,
flat_start: int,
flat_end: int,
no_batch_dims: int,
) -> torch.Tensor:
"""
Equivalent to
t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]
but without the need for the initial reshape call, which can be
memory-intensive in certain situations. The only reshape operations
in this function are performed on sub-tensors that scale with
(flat_end - flat_start), the chunk size.
"""
batch_dims = t.shape[:no_batch_dims]
start_idx = list(_flat_idx_to_idx(flat_start, batch_dims))
# _get_minimal_slice_set is inclusive
end_idx = list(_flat_idx_to_idx(flat_end - 1, batch_dims))
# Get an ordered list of slices to perform
slices = _get_minimal_slice_set(
start_idx,
end_idx,
batch_dims,
)
sliced_tensors = [t[s] for s in slices]
return torch.cat(
[s.view((-1,) + t.shape[no_batch_dims:]) for s in sliced_tensors]
)
def chunk_layer(
layer: Callable,
inputs: Dict[str, Any],
chunk_size: int,
no_batch_dims: int,
low_mem: bool = False,
) -> Any:
"""
Implements the "chunking" procedure described in section 1.11.8.
Layer outputs and inputs are assumed to be simple "pytrees,"
consisting only of (arbitrarily nested) lists, tuples, and dicts with
torch.Tensor leaves.
Args:
layer:
The layer to be applied chunk-wise
inputs:
A (non-nested) dictionary of keyworded inputs. All leaves must
be tensors and must share the same batch dimensions.
chunk_size:
The number of sub-batches per chunk. If multiple batch
dimensions are specified, a "sub-batch" is defined as a single
indexing of all batch dimensions simultaneously (s.t. the
number of sub-batches is the product of the batch dimensions).
no_batch_dims:
How many of the initial dimensions of each input tensor can
be considered batch dimensions.
low_mem:
Avoids flattening potentially large input tensors. Unnecessary
in most cases, and is ever so slightly slower than the default
setting.
Returns:
The reassembled output of the layer on the inputs.
"""
if not (len(inputs) > 0):
raise ValueError("Must provide at least one input")
initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)]
orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])
def _prep_inputs(t):
# TODO: make this more memory efficient. This sucks
if(not low_mem):
# Tensors whose batch dims are already all 1 are not expanded; they keep
# a leading 1 after flattening and broadcast against each chunk below
if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
t = t.reshape(-1, *t.shape[no_batch_dims:])
else:
t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
return t
prepped_inputs = tensor_tree_map(_prep_inputs, inputs)
flat_batch_dim = 1
for d in orig_batch_dims:
flat_batch_dim *= d
no_chunks = flat_batch_dim // chunk_size + (
flat_batch_dim % chunk_size != 0
)
i = 0
out = None
for _ in range(no_chunks):
# Chunk the input
if(not low_mem):
select_chunk = (
lambda t: t[i : i + chunk_size] if t.shape[0] != 1 else t
)
else:
select_chunk = (
partial(
_chunk_slice,
flat_start=i,
flat_end=min(flat_batch_dim, i + chunk_size),
no_batch_dims=len(orig_batch_dims)
)
)
chunks = tensor_tree_map(select_chunk, prepped_inputs)
# Run the layer on the chunk
output_chunk = layer(**chunks)
# Allocate space for the output
if out is None:
allocate = lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:])
out = tensor_tree_map(allocate, output_chunk)
# Put the chunk in its pre-allocated space
out_type = type(output_chunk)
if out_type is dict:
def assign(d1, d2):
for k, v in d1.items():
if type(v) is dict:
assign(v, d2[k])
else:
v[i : i + chunk_size] = d2[k]
assign(out, output_chunk)
elif out_type is tuple:
for x1, x2 in zip(out, output_chunk):
x1[i : i + chunk_size] = x2
elif out_type is torch.Tensor:
out[i : i + chunk_size] = output_chunk
else:
raise ValueError("Not supported")
i += chunk_size
reshape = lambda t: t.view(orig_batch_dims + t.shape[1:])
out = tensor_tree_map(reshape, out)
return out
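# --- Illustrative usage sketch (not part of the original module) ---
# chunk_layer runs a layer over sub-batches of its inputs and stitches the
# outputs back together, trading a little compute for peak memory (the
# sub-batching of section 1.11.8 referenced in the docstring). A toy example
# with a plain function standing in for an attention module and a hypothetical
# helper name:
def _chunk_layer_example():
    def toy_layer(a, b):
        return {"sum": a + b, "prod": a * b}

    a = torch.rand(6, 7, 32)
    b = torch.rand(6, 7, 32)
    out = chunk_layer(
        toy_layer,
        {"a": a, "b": b},
        chunk_size=4,     # 4 of the 6 * 7 = 42 sub-batches at a time
        no_batch_dims=2,  # the leading [6, 7] dims are batch dims
    )
    assert torch.allclose(out["sum"], a + b)
    return out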
@@ -18,38 +18,70 @@ import os
import random
import sys
import time
from datetime import date
import numpy as np
import torch
from fastfold.model.hub import AlphaFold
import fastfold
import openfold.np.relax.relax as relax
from fastfold.utils import inject_openfold
from openfold.config import model_config
from openfold.data import data_pipeline, feature_pipeline, templates
from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_preset_
from openfold.np import protein, residue_constants
from openfold.utils.import_weights import import_jax_weights_
from openfold.utils.tensor_utils import tensor_tree_map
from scripts.utils import add_data_args
import fastfold.relax.relax as relax
from fastfold.common import protein, residue_constants
from fastfold.config import model_config
from fastfold.data import data_pipeline, feature_pipeline, templates
from fastfold.utils import inject_fastnn
from fastfold.utils.import_weights import import_jax_weights_
from fastfold.utils.tensor_utils import tensor_tree_map
def add_data_args(parser: argparse.ArgumentParser):
parser.add_argument(
'--uniref90_database_path',
type=str,
default=None,
)
parser.add_argument(
'--mgnify_database_path',
type=str,
default=None,
)
parser.add_argument(
'--pdb70_database_path',
type=str,
default=None,
)
parser.add_argument(
'--uniclust30_database_path',
type=str,
default=None,
)
parser.add_argument(
'--bfd_database_path',
type=str,
default=None,
)
parser.add_argument('--jackhmmer_binary_path', type=str, default='/usr/bin/jackhmmer')
parser.add_argument('--hhblits_binary_path', type=str, default='/usr/bin/hhblits')
parser.add_argument('--hhsearch_binary_path', type=str, default='/usr/bin/hhsearch')
parser.add_argument('--kalign_binary_path', type=str, default='/usr/bin/kalign')
parser.add_argument(
'--max_template_date',
type=str,
default=date.today().strftime("%Y-%m-%d"),
)
parser.add_argument('--obsolete_pdbs_path', type=str, default=None)
parser.add_argument('--release_dates_path', type=str, default=None)
def main(args):
# init distributed for Dynamic Axial Parallelism
local_rank = int(os.getenv('LOCAL_RANK', -1))
if local_rank != -1:
distributed_inference_ = True
fastfold.distributed.init_dap()
else:
distributed_inference_ = False
fastfold.distributed.init_dap()
config = model_config(args.model_name)
model = AlphaFold(config)
import_jax_weights_(model, args.param_path, version=args.model_name)
model = inject_openfold(model)
model = inject_fastnn(model)
model = model.eval()
#script_preset_(model)
model = model.cuda()
@@ -62,7 +94,12 @@ def main(args):
release_dates_path=args.release_dates_path,
obsolete_pdbs_path=args.obsolete_pdbs_path)
use_small_bfd = (args.bfd_database_path is None)
use_small_bfd = args.preset == 'reduced_dbs' # (args.bfd_database_path is None)
if use_small_bfd:
assert args.bfd_database_path is not None
else:
assert args.bfd_database_path is not None
assert args.uniclust30_database_path is not None
data_processor = data_pipeline.DataPipeline(template_featurizer=template_featurizer,)
@@ -87,7 +124,7 @@ def main(args):
for tag, seq in zip(tags, seqs):
batch = [None]
if (not distributed_inference_) or (torch.distributed.get_rank() == 0):
if torch.distributed.get_rank() == 0:
fasta_path = os.path.join(args.output_dir, "tmp.fasta")
with open(fasta_path, "w") as fp:
fp.write(f">{tag}\n{seq}")
@@ -125,8 +162,8 @@ def main(args):
batch = [processed_feature_dict]
if distributed_inference_:
torch.distributed.broadcast_object_list(batch, src=0)
torch.distributed.broadcast_object_list(batch, src=0)
batch = batch[0]
print("Executing model...")
@@ -138,10 +175,9 @@ def main(args):
out = model(batch)
print(f"Inference time: {time.perf_counter() - t}")
if distributed_inference_:
torch.distributed.barrier()
torch.distributed.barrier()
if (not distributed_inference_) or (torch.distributed.get_rank() == 0):
if torch.distributed.get_rank() == 0:
# Toss out the recycling dimensions --- we don't need them anymore
batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
@@ -177,8 +213,7 @@ def main(args):
with open(relaxed_output_path, 'w') as f:
f.write(relaxed_pdb_str)
if distributed_inference_:
torch.distributed.barrier()
torch.distributed.barrier()
if __name__ == "__main__":
@@ -212,21 +247,20 @@ if __name__ == "__main__":
default=None,
help="""Path to model parameters. If None, parameters are selected
automatically according to the model name from
openfold/resources/params""")
./data/params""")
parser.add_argument("--cpus",
type=int,
default=12,
help="""Number of CPUs with which to run alignment tools""")
parser.add_argument('--preset',
type=str,
default='reduced_dbs',
default='full_dbs',
choices=('reduced_dbs', 'full_dbs'))
parser.add_argument('--data_random_seed', type=str, default=None)
add_data_args(parser)
args = parser.parse_args()
if (args.param_path is None):
args.param_path = os.path.join("openfold", "resources", "params",
"params_" + args.model_name + ".npz")
args.param_path = os.path.join("data", "params", "params_" + args.model_name + ".npz")
main(args)
#!/bin/bash
#
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads and unzips all required data for AlphaFold.
#
# Usage: bash download_all_data.sh /path/to/download/directory
set -e
if [[ $# -eq 0 ]]; then
echo "Error: download directory must be provided as an input argument."
exit 1
fi
if ! command -v aria2c &> /dev/null ; then
echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
exit 1
fi
DOWNLOAD_DIR="$1"
DOWNLOAD_MODE="${2:-full_dbs}" # Default mode to full_dbs.
if [[ "${DOWNLOAD_MODE}" != full_dbs && "${DOWNLOAD_MODE}" != reduced_dbs ]]
then
echo "DOWNLOAD_MODE ${DOWNLOAD_MODE} not recognized."
exit 1
fi
SCRIPT_DIR="$(dirname "$(realpath "$0")")"
echo "Downloading AlphaFold parameters..."
bash "${SCRIPT_DIR}/download_alphafold_params.sh" "${DOWNLOAD_DIR}"
if [[ "${DOWNLOAD_MODE}" = reduced_dbs ]] ; then
echo "Downloading Small BFD..."
bash "${SCRIPT_DIR}/download_small_bfd.sh" "${DOWNLOAD_DIR}"
else
echo "Downloading BFD..."
bash "${SCRIPT_DIR}/download_bfd.sh" "${DOWNLOAD_DIR}"
fi
echo "Downloading MGnify..."
bash "${SCRIPT_DIR}/download_mgnify.sh" "${DOWNLOAD_DIR}"
echo "Downloading PDB70..."
bash "${SCRIPT_DIR}/download_pdb70.sh" "${DOWNLOAD_DIR}"
echo "Downloading PDB mmCIF files..."
bash "${SCRIPT_DIR}/download_pdb_mmcif.sh" "${DOWNLOAD_DIR}"
echo "Downloading Uniclust30..."
bash "${SCRIPT_DIR}/download_uniclust30.sh" "${DOWNLOAD_DIR}"
echo "Downloading Uniref90..."
bash "${SCRIPT_DIR}/download_uniref90.sh" "${DOWNLOAD_DIR}"
# UniProt and PDB SeqRes for multimer version
# echo "Downloading UniProt..."
# bash "${SCRIPT_DIR}/download_uniprot.sh" "${DOWNLOAD_DIR}"
# echo "Downloading PDB SeqRes..."
# bash "${SCRIPT_DIR}/download_pdb_seqres.sh" "${DOWNLOAD_DIR}"
echo "All data downloaded."
#!/bin/bash
#
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads and unzips the AlphaFold parameters.
#
# Usage: bash download_alphafold_params.sh /path/to/download/directory
set -e
if [[ $# -eq 0 ]]; then
echo "Error: download directory must be provided as an input argument."
exit 1
fi
if ! command -v aria2c &> /dev/null ; then
echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
exit 1
fi
DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}/params"
SOURCE_URL="https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar"
BASENAME=$(basename "${SOURCE_URL}")
mkdir --parents "${ROOT_DIR}"
aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}"
tar --extract --verbose --file="${ROOT_DIR}/${BASENAME}" \
--directory="${ROOT_DIR}" --preserve-permissions
rm "${ROOT_DIR}/${BASENAME}"
#!/bin/bash
#
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads and unzips the BFD database for AlphaFold.
#
# Usage: bash download_bfd.sh /path/to/download/directory
set -e
if [[ $# -eq 0 ]]; then
echo "Error: download directory must be provided as an input argument."
exit 1
fi
if ! command -v aria2c &> /dev/null ; then
echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
exit 1
fi
DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}/bfd"
# Mirror of:
# https://bfd.mmseqs.com/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt.tar.gz.
SOURCE_URL="https://storage.googleapis.com/alphafold-databases/casp14_versions/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt.tar.gz"
BASENAME=$(basename "${SOURCE_URL}")
mkdir --parents "${ROOT_DIR}"
aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}"
tar --extract --verbose --file="${ROOT_DIR}/${BASENAME}" \
--directory="${ROOT_DIR}"
rm "${ROOT_DIR}/${BASENAME}"
#!/bin/bash
#
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads and unzips the MGnify database for AlphaFold.
#
# Usage: bash download_mgnify.sh /path/to/download/directory
set -e
if [[ $# -eq 0 ]]; then
echo "Error: download directory must be provided as an input argument."
exit 1
fi
if ! command -v aria2c &> /dev/null ; then
echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
exit 1
fi
DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}/mgnify"
# Mirror of:
# ftp://ftp.ebi.ac.uk/pub/databases/metagenomics/peptide_database/2018_12/mgy_clusters.fa.gz
SOURCE_URL="https://storage.googleapis.com/alphafold-databases/casp14_versions/mgy_clusters_2018_12.fa.gz"
BASENAME=$(basename "${SOURCE_URL}")
mkdir --parents "${ROOT_DIR}"
aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}"
pushd "${ROOT_DIR}"
gunzip "${ROOT_DIR}/${BASENAME}"
popd
#!/bin/bash
#
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads and unzips the PDB70 database for AlphaFold.
#
# Usage: bash download_pdb70.sh /path/to/download/directory
set -e
if [[ $# -eq 0 ]]; then
echo "Error: download directory must be provided as an input argument."
exit 1
fi
if ! command -v aria2c &> /dev/null ; then
echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
exit 1
fi
DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}/pdb70"
SOURCE_URL="http://wwwuser.gwdg.de/~compbiol/data/hhsuite/databases/hhsuite_dbs/old-releases/pdb70_from_mmcif_200401.tar.gz"
BASENAME=$(basename "${SOURCE_URL}")
mkdir --parents "${ROOT_DIR}"
aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}"
tar --extract --verbose --file="${ROOT_DIR}/${BASENAME}" \
--directory="${ROOT_DIR}"
rm "${ROOT_DIR}/${BASENAME}"
#!/bin/bash
#
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads, unzips and flattens the PDB database for AlphaFold.
#
# Usage: bash download_pdb_mmcif.sh /path/to/download/directory
set -e
if [[ $# -eq 0 ]]; then
echo "Error: download directory must be provided as an input argument."
exit 1
fi
if ! command -v aria2c &> /dev/null ; then
echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
exit 1
fi
if ! command -v rsync &> /dev/null ; then
echo "Error: rsync could not be found. Please install rsync."
exit 1
fi
DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}/pdb_mmcif"
RAW_DIR="${ROOT_DIR}/raw"
MMCIF_DIR="${ROOT_DIR}/mmcif_files"
echo "Running rsync to fetch all mmCIF files (note that the rsync progress estimate might be inaccurate)..."
echo "If the download speed is too slow, try changing the mirror to:"
echo " * rsync.ebi.ac.uk::pub/databases/pdb/data/structures/divided/mmCIF/ (Europe)"
echo " * ftp.pdbj.org::ftp_data/structures/divided/mmCIF/ (Asia)"
echo "or see https://www.wwpdb.org/ftp/pdb-ftp-sites for more download options."
mkdir --parents "${RAW_DIR}"
rsync --recursive --links --perms --times --compress --info=progress2 --delete --port=33444 \
rsync.rcsb.org::ftp_data/structures/divided/mmCIF/ \
"${RAW_DIR}"
echo "Unzipping all mmCIF files..."
find "${RAW_DIR}/" -type f -iname "*.gz" -exec gunzip {} +
echo "Flattening all mmCIF files..."
mkdir --parents "${MMCIF_DIR}"
find "${RAW_DIR}" -type d -empty -delete # Delete empty directories.
for subdir in "${RAW_DIR}"/*; do
mv "${subdir}/"*.cif "${MMCIF_DIR}"
done
# Delete empty download directory structure.
find "${RAW_DIR}" -type d -empty -delete
aria2c "ftp://ftp.wwpdb.org/pub/pdb/data/status/obsolete.dat" --dir="${ROOT_DIR}"
#!/bin/bash
#
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads and unzips the PDB SeqRes database for AlphaFold.
#
# Usage: bash download_pdb_seqres.sh /path/to/download/directory
set -e
if [[ $# -eq 0 ]]; then
echo "Error: download directory must be provided as an input argument."
exit 1
fi
if ! command -v aria2c &> /dev/null ; then
echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
exit 1
fi
DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}/pdb_seqres"
SOURCE_URL="ftp://ftp.wwpdb.org/pub/pdb/derived_data/pdb_seqres.txt"
BASENAME=$(basename "${SOURCE_URL}")
mkdir --parents "${ROOT_DIR}"
aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}"
#!/bin/bash
#
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads and unzips the Small BFD database for AlphaFold.
#
# Usage: bash download_small_bfd.sh /path/to/download/directory
set -e
if [[ $# -eq 0 ]]; then
echo "Error: download directory must be provided as an input argument."
exit 1
fi
if ! command -v aria2c &> /dev/null ; then
echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
exit 1
fi
DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}/small_bfd"
SOURCE_URL="https://storage.googleapis.com/alphafold-databases/reduced_dbs/bfd-first_non_consensus_sequences.fasta.gz"
BASENAME=$(basename "${SOURCE_URL}")
mkdir --parents "${ROOT_DIR}"
trickle -s -u 1024 -d 10240 aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}"
pushd "${ROOT_DIR}"
gunzip "${ROOT_DIR}/${BASENAME}"
popd