Commit 481b8be2 authored by Rich Evans's avatar Rich Evans Committed by Copybara-Service
Browse files

Update open source between_residue_clash_loss to take into account multiple chains.

PiperOrigin-RevId: 432934821
Change-Id: If5decaf654823a37c4270151039a12606f526f7a
parent 3e046ad9
...@@ -598,6 +598,7 @@ def between_residue_clash_loss( ...@@ -598,6 +598,7 @@ def between_residue_clash_loss(
atom_exists: jnp.ndarray, # (N, 14) atom_exists: jnp.ndarray, # (N, 14)
atom_radius: jnp.ndarray, # (N, 14) atom_radius: jnp.ndarray, # (N, 14)
residue_index: jnp.ndarray, # (N) residue_index: jnp.ndarray, # (N)
asym_id: jnp.ndarray, # (N)
overlap_tolerance_soft=1.5, overlap_tolerance_soft=1.5,
overlap_tolerance_hard=1.5) -> Dict[Text, jnp.ndarray]: overlap_tolerance_hard=1.5) -> Dict[Text, jnp.ndarray]:
"""Loss to penalize steric clashes between residues.""" """Loss to penalize steric clashes between residues."""
...@@ -624,8 +625,9 @@ def between_residue_clash_loss( ...@@ -624,8 +625,9 @@ def between_residue_clash_loss(
# Backbone C--N bond between subsequent residues is no clash. # Backbone C--N bond between subsequent residues is no clash.
c_one_hot = jax.nn.one_hot(2, num_classes=14) c_one_hot = jax.nn.one_hot(2, num_classes=14)
n_one_hot = jax.nn.one_hot(0, num_classes=14) n_one_hot = jax.nn.one_hot(0, num_classes=14)
neighbour_mask = ((residue_index[:, None, None, None] + neighbour_mask = ((residue_index[:, None] + 1) == residue_index[None, :])
1) == residue_index[None, :, None, None]) neighbour_mask &= (asym_id[:, None] == asym_id[None, :])
neighbour_mask = neighbour_mask[..., None, None]
c_n_bonds = neighbour_mask * c_one_hot[None, None, :, c_n_bonds = neighbour_mask * c_one_hot[None, None, :,
None] * n_one_hot[None, None, None, :] None] * n_one_hot[None, None, None, :]
dists_mask *= (1. - c_n_bonds) dists_mask *= (1. - c_n_bonds)
......
...@@ -664,7 +664,8 @@ class StructureModule(hk.Module): ...@@ -664,7 +664,8 @@ class StructureModule(hk.Module):
residue_index=residue_index, residue_index=residue_index,
mask=pred_mask, mask=pred_mask,
pred_positions=pred_positions, pred_positions=pred_positions,
config=self.config) config=self.config,
asym_id=batch['asym_id'])
sidechains = value['sidechains'] sidechains = value['sidechains']
...@@ -890,7 +891,8 @@ def find_structural_violations( ...@@ -890,7 +891,8 @@ def find_structural_violations(
residue_index: jnp.ndarray, residue_index: jnp.ndarray,
mask: jnp.ndarray, mask: jnp.ndarray,
pred_positions: geometry.Vec3Array, # (N, 14) pred_positions: geometry.Vec3Array, # (N, 14)
config: ml_collections.ConfigDict config: ml_collections.ConfigDict,
asym_id: jnp.ndarray,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Computes several checks for structural Violations.""" """Computes several checks for structural Violations."""
...@@ -921,7 +923,8 @@ def find_structural_violations( ...@@ -921,7 +923,8 @@ def find_structural_violations(
atom_radius=atom_radius, atom_radius=atom_radius,
residue_index=residue_index, residue_index=residue_index,
overlap_tolerance_soft=config.clash_overlap_tolerance, overlap_tolerance_soft=config.clash_overlap_tolerance,
overlap_tolerance_hard=config.clash_overlap_tolerance) overlap_tolerance_hard=config.clash_overlap_tolerance,
asym_id=asym_id)
# Compute all within-residue violations (clashes, # Compute all within-residue violations (clashes,
# bond length and angle violations). # bond length and angle violations).
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment