Unverified Commit b7ee0ff3 authored by shenggan's avatar shenggan Committed by GitHub
Browse files

fix wrong para import and update protein class (#130)

parent a80d5263
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Protein data type.""" """Protein data type."""
import dataclasses import dataclasses
import io import io
...@@ -22,10 +23,15 @@ from fastfold.common import residue_constants ...@@ -22,10 +23,15 @@ from fastfold.common import residue_constants
from Bio.PDB import PDBParser from Bio.PDB import PDBParser
import numpy as np import numpy as np
FeatureDict = Mapping[str, np.ndarray] FeatureDict = Mapping[str, np.ndarray]
ModelOutput = Mapping[str, Any] # Is a nested dict. ModelOutput = Mapping[str, Any] # Is a nested dict.
PICO_TO_ANGSTROM = 0.01 PICO_TO_ANGSTROM = 0.01
PDB_CHAIN_IDS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
PDB_MAX_CHAINS = len(PDB_CHAIN_IDS)
assert(PDB_MAX_CHAINS == 62)
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class Protein: class Protein:
...@@ -46,24 +52,31 @@ class Protein: ...@@ -46,24 +52,31 @@ class Protein:
# Residue index as used in PDB. It is not necessarily continuous or 0-indexed. # Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
residue_index: np.ndarray # [num_res] residue_index: np.ndarray # [num_res]
# 0-indexed number corresponding to the chain in the protein that this
# residue belongs to
chain_index: np.ndarray # [num_res]
# B-factors, or temperature factors, of each residue (in sq. angstroms units), # B-factors, or temperature factors, of each residue (in sq. angstroms units),
# representing the displacement of the residue from its ground truth mean # representing the displacement of the residue from its ground truth mean
# value. # value.
b_factors: np.ndarray # [num_res, num_atom_type] b_factors: np.ndarray # [num_res, num_atom_type]
def __post_init__(self):
if(len(np.unique(self.chain_index)) > PDB_MAX_CHAINS):
raise ValueError(
f"Cannot build an instance with more than {PDB_MAX_CHAINS} "
"chains because these cannot be written to PDB format"
)
def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein: def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
"""Takes a PDB string and constructs a Protein object. """Takes a PDB string and constructs a Protein object.
WARNING: All non-standard residue types will be converted into UNK. All WARNING: All non-standard residue types will be converted into UNK. All
non-standard atoms will be ignored. non-standard atoms will be ignored.
Args: Args:
pdb_str: The contents of the pdb file pdb_str: The contents of the pdb file
chain_id: If None, then the pdb file must contain a single chain (which chain_id: If chain_id is specified (e.g. A), then only that chain is
will be parsed). If chain_id is specified (e.g. A), then only that chain parsed. Else, all chains are parsed.
is parsed.
Returns: Returns:
A new `Protein` parsed from the pdb contents. A new `Protein` parsed from the pdb contents.
""" """
...@@ -72,32 +85,33 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein: ...@@ -72,32 +85,33 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
structure = parser.get_structure("none", pdb_fh) structure = parser.get_structure("none", pdb_fh)
models = list(structure.get_models()) models = list(structure.get_models())
if len(models) != 1: if len(models) != 1:
raise ValueError(f"Only single model PDBs are supported. Found {len(models)} models.") raise ValueError(
f"Only single model PDBs are supported. Found {len(models)} models."
)
model = models[0] model = models[0]
if chain_id is not None:
chain = model[chain_id]
else:
chains = list(model.get_chains())
if len(chains) != 1:
raise ValueError("Only single chain PDBs are supported when chain_id not specified. "
f"Found {len(chains)} chains.")
else:
chain = chains[0]
atom_positions = [] atom_positions = []
aatype = [] aatype = []
atom_mask = [] atom_mask = []
residue_index = [] residue_index = []
chain_ids = []
b_factors = [] b_factors = []
for chain in model:
if(chain_id is not None and chain.id != chain_id):
continue
for res in chain: for res in chain:
if res.id[2] != " ": if res.id[2] != " ":
raise ValueError(f"PDB contains an insertion code at chain {chain.id} and residue " raise ValueError(
f"index {res.id[1]}. These are not supported.") f"PDB contains an insertion code at chain {chain.id} and residue "
f"index {res.id[1]}. These are not supported."
)
res_shortname = residue_constants.restype_3to1.get(res.resname, "X") res_shortname = residue_constants.restype_3to1.get(res.resname, "X")
restype_idx = residue_constants.restype_order.get(res_shortname, restype_idx = residue_constants.restype_order.get(
residue_constants.restype_num) res_shortname, residue_constants.restype_num
)
pos = np.zeros((residue_constants.atom_type_num, 3)) pos = np.zeros((residue_constants.atom_type_num, 3))
mask = np.zeros((residue_constants.atom_type_num,)) mask = np.zeros((residue_constants.atom_type_num,))
res_b_factors = np.zeros((residue_constants.atom_type_num,)) res_b_factors = np.zeros((residue_constants.atom_type_num,))
...@@ -106,28 +120,40 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein: ...@@ -106,28 +120,40 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
continue continue
pos[residue_constants.atom_order[atom.name]] = atom.coord pos[residue_constants.atom_order[atom.name]] = atom.coord
mask[residue_constants.atom_order[atom.name]] = 1.0 mask[residue_constants.atom_order[atom.name]] = 1.0
res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor res_b_factors[
residue_constants.atom_order[atom.name]
] = atom.bfactor
if np.sum(mask) < 0.5: if np.sum(mask) < 0.5:
# If no known atom positions are reported for the residue then skip it. # If no known atom positions are reported for the residue then skip it.
continue continue
aatype.append(restype_idx) aatype.append(restype_idx)
atom_positions.append(pos) atom_positions.append(pos)
atom_mask.append(mask) atom_mask.append(mask)
residue_index.append(res.id[1]) residue_index.append(res.id[1])
chain_ids.append(chain.id)
b_factors.append(res_b_factors) b_factors.append(res_b_factors)
# Chain IDs are usually characters so map these to ints
unique_chain_ids = np.unique(chain_ids)
chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)}
chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids])
return Protein( return Protein(
atom_positions=np.array(atom_positions), atom_positions=np.array(atom_positions),
atom_mask=np.array(atom_mask), atom_mask=np.array(atom_mask),
aatype=np.array(aatype), aatype=np.array(aatype),
residue_index=np.array(residue_index), residue_index=np.array(residue_index),
chain_index=chain_index,
b_factors=np.array(b_factors), b_factors=np.array(b_factors),
) )
def from_proteinnet_string(proteinnet_str: str) -> Protein: def from_proteinnet_string(proteinnet_str: str) -> Protein:
tag_re = r'(\[[A-Z]+\]\n)' tag_re = r'(\[[A-Z]+\]\n)'
tags = [tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0] tags = [
tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0
]
groups = zip(tags[0::2], [l.split('\n') for l in tags[1::2]]) groups = zip(tags[0::2], [l.split('\n') for l in tags[1::2]])
atoms = ['N', 'CA', 'C'] atoms = ['N', 'CA', 'C']
...@@ -135,32 +161,34 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein: ...@@ -135,32 +161,34 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein:
atom_positions = None atom_positions = None
atom_mask = None atom_mask = None
for g in groups: for g in groups:
if ("[PRIMARY]" == g[0]): if("[PRIMARY]" == g[0]):
seq = g[1][0].strip() seq = g[1][0].strip()
for i in range(len(seq)): for i in range(len(seq)):
if (seq[i] not in residue_constants.restypes): if(seq[i] not in residue_constants.restypes):
seq[i] = 'X' seq[i] = 'X'
aatype = np.array([ aatype = np.array([
residue_constants.restype_order.get(res_symbol, residue_constants.restype_num) residue_constants.restype_order.get(
for res_symbol in seq res_symbol, residue_constants.restype_num
) for res_symbol in seq
]) ])
elif ("[TERTIARY]" == g[0]): elif("[TERTIARY]" == g[0]):
tertiary = [] tertiary = []
for axis in range(3): for axis in range(3):
tertiary.append(list(map(float, g[1][axis].split()))) tertiary.append(list(map(float, g[1][axis].split())))
tertiary_np = np.array(tertiary) tertiary_np = np.array(tertiary)
atom_positions = np.zeros( atom_positions = np.zeros(
(len(tertiary[0]) // 3, residue_constants.atom_type_num, 3)).astype(np.float32) (len(tertiary[0])//3, residue_constants.atom_type_num, 3)
).astype(np.float32)
for i, atom in enumerate(atoms): for i, atom in enumerate(atoms):
atom_positions[:, residue_constants.atom_order[atom], :] = (np.transpose( atom_positions[:, residue_constants.atom_order[atom], :] = (
tertiary_np[:, i::3])) np.transpose(tertiary_np[:, i::3])
)
atom_positions *= PICO_TO_ANGSTROM atom_positions *= PICO_TO_ANGSTROM
elif ("[MASK]" == g[0]): elif("[MASK]" == g[0]):
mask = np.array(list(map({'-': 0, '+': 1}.get, g[1][0].strip()))) mask = np.array(list(map({'-': 0, '+': 1}.get, g[1][0].strip())))
atom_mask = np.zeros(( atom_mask = np.zeros(
len(mask), (len(mask), residue_constants.atom_type_num,)
residue_constants.atom_type_num, ).astype(np.float32)
)).astype(np.float32)
for i, atom in enumerate(atoms): for i, atom in enumerate(atoms):
atom_mask[:, residue_constants.atom_order[atom]] = 1 atom_mask[:, residue_constants.atom_order[atom]] = 1
atom_mask *= mask[..., None] atom_mask *= mask[..., None]
...@@ -174,12 +202,18 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein: ...@@ -174,12 +202,18 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein:
) )
def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str:
chain_end = 'TER'
return(
f'{chain_end:<6}{atom_index:>5} {end_resname:>3} '
f'{chain_name:>1}{residue_index:>4}'
)
def to_pdb(prot: Protein) -> str: def to_pdb(prot: Protein) -> str:
"""Converts a `Protein` instance to a PDB string. """Converts a `Protein` instance to a PDB string.
Args: Args:
prot: The protein to convert to PDB. prot: The protein to convert to PDB.
Returns: Returns:
PDB string. PDB string.
""" """
...@@ -193,19 +227,43 @@ def to_pdb(prot: Protein) -> str: ...@@ -193,19 +227,43 @@ def to_pdb(prot: Protein) -> str:
aatype = prot.aatype aatype = prot.aatype
atom_positions = prot.atom_positions atom_positions = prot.atom_positions
residue_index = prot.residue_index.astype(np.int32) residue_index = prot.residue_index.astype(np.int32)
chain_index = prot.chain_index.astype(np.int32)
b_factors = prot.b_factors b_factors = prot.b_factors
if np.any(aatype > residue_constants.restype_num): if np.any(aatype > residue_constants.restype_num):
raise ValueError("Invalid aatypes.") raise ValueError("Invalid aatypes.")
# Construct a mapping from chain integer indices to chain ID strings.
chain_ids = {}
for i in np.unique(chain_index): # np.unique gives sorted output.
if i >= PDB_MAX_CHAINS:
raise ValueError(
f"The PDB format supports at most {PDB_MAX_CHAINS} chains."
)
chain_ids[i] = PDB_CHAIN_IDS[i]
pdb_lines.append("MODEL 1") pdb_lines.append("MODEL 1")
atom_index = 1 atom_index = 1
chain_id = "A" last_chain_index = chain_index[0]
# Add all atom sites. # Add all atom sites.
for i in range(aatype.shape[0]): for i in range(aatype.shape[0]):
# Close the previous chain if in a multichain PDB.
if last_chain_index != chain_index[i]:
pdb_lines.append(
_chain_end(
atom_index,
res_1to3(aatype[i - 1]),
chain_ids[chain_index[i - 1]],
residue_index[i - 1]
)
)
last_chain_index = chain_index[i]
atom_index += 1 # Atom index increases at the TER symbol.
res_name_3 = res_1to3(aatype[i]) res_name_3 = res_1to3(aatype[i])
for atom_name, pos, mask, b_factor in zip(atom_types, atom_positions[i], atom_mask[i], for atom_name, pos, mask, b_factor in zip(
b_factors[i]): atom_types, atom_positions[i], atom_mask[i], b_factors[i]
):
if mask < 0.5: if mask < 0.5:
continue continue
...@@ -214,40 +272,47 @@ def to_pdb(prot: Protein) -> str: ...@@ -214,40 +272,47 @@ def to_pdb(prot: Protein) -> str:
alt_loc = "" alt_loc = ""
insertion_code = "" insertion_code = ""
occupancy = 1.00 occupancy = 1.00
element = atom_name[0] # Protein supports only C, N, O, S, this works. element = atom_name[
0
] # Protein supports only C, N, O, S, this works.
charge = "" charge = ""
# PDB is a columnar format, every space matters here! # PDB is a columnar format, every space matters here!
atom_line = (f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}" atom_line = (
f"{res_name_3:>3} {chain_id:>1}" f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
f"{res_name_3:>3} {chain_ids[chain_index[i]]:>1}"
f"{residue_index[i]:>4}{insertion_code:>1} " f"{residue_index[i]:>4}{insertion_code:>1} "
f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}" f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
f"{occupancy:>6.2f}{b_factor:>6.2f} " f"{occupancy:>6.2f}{b_factor:>6.2f} "
f"{element:>2}{charge:>2}") f"{element:>2}{charge:>2}"
)
pdb_lines.append(atom_line) pdb_lines.append(atom_line)
atom_index += 1 atom_index += 1
# Close the chain. # Close the final chain.
chain_end = "TER" pdb_lines.append(
chain_termination_line = (f"{chain_end:<6}{atom_index:>5} {res_1to3(aatype[-1]):>3} " _chain_end(
f"{chain_id:>1}{residue_index[-1]:>4}") atom_index,
pdb_lines.append(chain_termination_line) res_1to3(aatype[-1]),
pdb_lines.append("ENDMDL") chain_ids[chain_index[-1]],
residue_index[-1]
)
)
pdb_lines.append("ENDMDL")
pdb_lines.append("END") pdb_lines.append("END")
pdb_lines.append("")
return "\n".join(pdb_lines) # Pad all lines to 80 characters
pdb_lines = [line.ljust(80) for line in pdb_lines]
return '\n'.join(pdb_lines) + '\n' # Add terminating newline.
def ideal_atom_mask(prot: Protein) -> np.ndarray: def ideal_atom_mask(prot: Protein) -> np.ndarray:
"""Computes an ideal atom mask. """Computes an ideal atom mask.
`Protein.atom_mask` typically is defined according to the atoms that are `Protein.atom_mask` typically is defined according to the atoms that are
reported in the PDB. This function computes a mask according to heavy atoms reported in the PDB. This function computes a mask according to heavy atoms
that should be present in the given sequence of amino acids. that should be present in the given sequence of amino acids.
Args: Args:
prot: `Protein` whose fields are `numpy.ndarray` objects. prot: `Protein` whose fields are `numpy.ndarray` objects.
Returns: Returns:
An ideal atom mask. An ideal atom mask.
""" """
...@@ -258,24 +323,37 @@ def from_prediction( ...@@ -258,24 +323,37 @@ def from_prediction(
features: FeatureDict, features: FeatureDict,
result: ModelOutput, result: ModelOutput,
b_factors: Optional[np.ndarray] = None, b_factors: Optional[np.ndarray] = None,
remove_leading_feature_dimension: bool = False,
) -> Protein: ) -> Protein:
"""Assembles a protein from a prediction. """Assembles a protein from a prediction.
Args: Args:
features: Dictionary holding model inputs. features: Dictionary holding model inputs.
result: Dictionary holding model outputs. result: Dictionary holding model outputs.
b_factors: (Optional) B-factors to use for the protein. b_factors: (Optional) B-factors to use for the protein.
remove_leading_feature_dimension: Whether to remove the leading dimension
of the `features` values
Returns: Returns:
A protein instance. A protein instance.
""" """
def _maybe_remove_leading_dim(arr: np.ndarray) -> np.ndarray:
return arr[0] if remove_leading_feature_dimension else arr
if 'asym_id' in features:
chain_index = _maybe_remove_leading_dim(features["asym_id"])
else:
chain_index = np.zeros_like(
_maybe_remove_leading_dim(features["aatype"])
)
if b_factors is None: if b_factors is None:
b_factors = np.zeros_like(result["final_atom_mask"]) b_factors = np.zeros_like(result["final_atom_mask"])
return Protein( return Protein(
aatype=features["aatype"], aatype=_maybe_remove_leading_dim(features["aatype"]),
atom_positions=result["final_atom_positions"], atom_positions=result["final_atom_positions"],
atom_mask=result["final_atom_mask"], atom_mask=result["final_atom_mask"],
residue_index=features["residue_index"] + 1, residue_index=_maybe_remove_leading_dim(features["residue_index"]) + 1,
chain_index=chain_index,
b_factors=b_factors, b_factors=b_factors,
) )
...@@ -249,7 +249,7 @@ def run_msa_tool( ...@@ -249,7 +249,7 @@ def run_msa_tool(
max_sto_sequences: Optional[int] = None, max_sto_sequences: Optional[int] = None,
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
"""Runs an MSA tool, checking if output already exists first.""" """Runs an MSA tool, checking if output already exists first."""
if(msa_format == "sto" and max_sto_sequences is not None): if(msa_format == "sto"):
result = msa_runner.query(fasta_path, max_sto_sequences)[0] result = msa_runner.query(fasta_path, max_sto_sequences)[0]
else: else:
result = msa_runner.query(fasta_path) result = msa_runner.query(fasta_path)
......
...@@ -606,3 +606,22 @@ def import_jax_weights_(model, npz_path, version="model_1"): ...@@ -606,3 +606,22 @@ def import_jax_weights_(model, npz_path, version="model_1"):
# Set weights # Set weights
assign(flat, data) assign(flat, data)
if is_fused_triangle_multiplication():
# (NOTE) in multimer v3, alphafold use fused tri, so need change left/right here
for b in model.template_embedder.template_pair_stack.blocks:
_change_tri_mul_in_left_right(b.tri_mul_in)
for b in model.extra_msa_stack.blocks:
_change_tri_mul_in_left_right(b.core.tri_mul_in)
for b in model.evoformer.blocks:
_change_tri_mul_in_left_right(b.core.tri_mul_in)
def _change_tri_mul_in_left_right(module):
def _change_para(para):
left_right_para = para.clone().chunk(2, dim=0)
return torch.cat((left_right_para[1], left_right_para[0]), dim=0)
with torch.no_grad():
module.linear_p.weight.copy_(_change_para(module.linear_p.weight))
module.linear_p.bias.copy_(_change_para(module.linear_p.bias))
module.linear_g.weight.copy_(_change_para(module.linear_g.weight))
module.linear_g.bias.copy_(_change_para(module.linear_g.bias))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment