Commit 6918c085 authored by Augustin Zidek's avatar Augustin Zidek Committed by Copybara-Service
Browse files

Fix incorrect type annotations.

PiperOrigin-RevId: 512029306
Change-Id: If2237867d38dd5e0002cd516f68f7eaa67b63443
parent a54b34c1
......@@ -13,7 +13,7 @@
# limitations under the License.
"""Ops for all atom representations."""
from typing import Dict, Text
from typing import Dict, Optional
from alphafold.common import residue_constants
from alphafold.model import geometry
......@@ -276,7 +276,7 @@ def atom37_to_frames(
aatype: jnp.ndarray, # (...)
all_atom_positions: geometry.Vec3Array, # (..., 37)
all_atom_mask: jnp.ndarray, # (..., 37)
) -> Dict[Text, jnp.ndarray]:
) -> Dict[str, jnp.ndarray]:
"""Computes the frames for the up to 8 rigid groups for each residue."""
# 0: 'backbone group',
# 1: 'pre-omega-group', (empty)
......@@ -498,7 +498,7 @@ def between_residue_bond_loss(
residue_index: jnp.ndarray, # (N)
aatype: jnp.ndarray, # (N)
tolerance_factor_soft=12.0,
tolerance_factor_hard=12.0) -> Dict[Text, jnp.ndarray]:
tolerance_factor_hard=12.0) -> Dict[str, jnp.ndarray]:
"""Flat-bottom loss to penalize structural violations between residues."""
assert len(pred_atom_positions.shape) == 2
assert len(pred_atom_mask.shape) == 2
......@@ -600,7 +600,7 @@ def between_residue_clash_loss(
residue_index: jnp.ndarray, # (N)
asym_id: jnp.ndarray, # (N)
overlap_tolerance_soft=1.5,
overlap_tolerance_hard=1.5) -> Dict[Text, jnp.ndarray]:
overlap_tolerance_hard=1.5) -> Dict[str, jnp.ndarray]:
"""Loss to penalize steric clashes between residues."""
assert len(pred_positions.shape) == 2
assert len(atom_exists.shape) == 2
......@@ -682,7 +682,7 @@ def within_residue_violations(
dists_lower_bound: jnp.ndarray, # (N, 14, 14)
dists_upper_bound: jnp.ndarray, # (N, 14, 14)
tighten_bounds_for_loss=0.0,
) -> Dict[Text, jnp.ndarray]:
) -> Dict[str, jnp.ndarray]:
"""Find within-residue violations."""
assert len(pred_positions.shape) == 2
assert len(atom_exists.shape) == 2
......@@ -789,7 +789,7 @@ def frame_aligned_point_error(
pred_positions: geometry.Vec3Array, # shape (num_positions)
target_positions: geometry.Vec3Array, # shape (num_positions)
positions_mask: jnp.ndarray, # shape (num_positions)
pair_mask: jnp.ndarray, # shape (num_frames, num_posiitons)
pair_mask: Optional[jnp.ndarray], # shape (num_frames, num_posiitons)
l1_clamp_distance: float,
length_scale=20.,
epsilon=1e-4) -> jnp.ndarray: # shape ()
......
......@@ -66,7 +66,7 @@ class RunModel:
def __init__(self,
config: ml_collections.ConfigDict,
params: Optional[Mapping[str, Mapping[str, np.ndarray]]] = None):
params: Optional[Mapping[str, Mapping[str, jax.Array]]] = None):
self.config = config
self.params = params
self.multimer_mode = config.model.global_config.multimer_mode
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment