"vscode:/vscode.git/clone" did not exist on "419974bbef580ee6cbf8ea1aedcdcc3ddaa5452e"
Commit 6918c085 authored by Augustin Zidek's avatar Augustin Zidek Committed by Copybara-Service
Browse files

Fix incorrect type annotations.

PiperOrigin-RevId: 512029306
Change-Id: If2237867d38dd5e0002cd516f68f7eaa67b63443
parent a54b34c1
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
"""Ops for all atom representations.""" """Ops for all atom representations."""
from typing import Dict, Text from typing import Dict, Optional
from alphafold.common import residue_constants from alphafold.common import residue_constants
from alphafold.model import geometry from alphafold.model import geometry
...@@ -276,7 +276,7 @@ def atom37_to_frames( ...@@ -276,7 +276,7 @@ def atom37_to_frames(
aatype: jnp.ndarray, # (...) aatype: jnp.ndarray, # (...)
all_atom_positions: geometry.Vec3Array, # (..., 37) all_atom_positions: geometry.Vec3Array, # (..., 37)
all_atom_mask: jnp.ndarray, # (..., 37) all_atom_mask: jnp.ndarray, # (..., 37)
) -> Dict[Text, jnp.ndarray]: ) -> Dict[str, jnp.ndarray]:
"""Computes the frames for the up to 8 rigid groups for each residue.""" """Computes the frames for the up to 8 rigid groups for each residue."""
# 0: 'backbone group', # 0: 'backbone group',
# 1: 'pre-omega-group', (empty) # 1: 'pre-omega-group', (empty)
...@@ -498,7 +498,7 @@ def between_residue_bond_loss( ...@@ -498,7 +498,7 @@ def between_residue_bond_loss(
residue_index: jnp.ndarray, # (N) residue_index: jnp.ndarray, # (N)
aatype: jnp.ndarray, # (N) aatype: jnp.ndarray, # (N)
tolerance_factor_soft=12.0, tolerance_factor_soft=12.0,
tolerance_factor_hard=12.0) -> Dict[Text, jnp.ndarray]: tolerance_factor_hard=12.0) -> Dict[str, jnp.ndarray]:
"""Flat-bottom loss to penalize structural violations between residues.""" """Flat-bottom loss to penalize structural violations between residues."""
assert len(pred_atom_positions.shape) == 2 assert len(pred_atom_positions.shape) == 2
assert len(pred_atom_mask.shape) == 2 assert len(pred_atom_mask.shape) == 2
...@@ -600,7 +600,7 @@ def between_residue_clash_loss( ...@@ -600,7 +600,7 @@ def between_residue_clash_loss(
residue_index: jnp.ndarray, # (N) residue_index: jnp.ndarray, # (N)
asym_id: jnp.ndarray, # (N) asym_id: jnp.ndarray, # (N)
overlap_tolerance_soft=1.5, overlap_tolerance_soft=1.5,
overlap_tolerance_hard=1.5) -> Dict[Text, jnp.ndarray]: overlap_tolerance_hard=1.5) -> Dict[str, jnp.ndarray]:
"""Loss to penalize steric clashes between residues.""" """Loss to penalize steric clashes between residues."""
assert len(pred_positions.shape) == 2 assert len(pred_positions.shape) == 2
assert len(atom_exists.shape) == 2 assert len(atom_exists.shape) == 2
...@@ -682,7 +682,7 @@ def within_residue_violations( ...@@ -682,7 +682,7 @@ def within_residue_violations(
dists_lower_bound: jnp.ndarray, # (N, 14, 14) dists_lower_bound: jnp.ndarray, # (N, 14, 14)
dists_upper_bound: jnp.ndarray, # (N, 14, 14) dists_upper_bound: jnp.ndarray, # (N, 14, 14)
tighten_bounds_for_loss=0.0, tighten_bounds_for_loss=0.0,
) -> Dict[Text, jnp.ndarray]: ) -> Dict[str, jnp.ndarray]:
"""Find within-residue violations.""" """Find within-residue violations."""
assert len(pred_positions.shape) == 2 assert len(pred_positions.shape) == 2
assert len(atom_exists.shape) == 2 assert len(atom_exists.shape) == 2
...@@ -789,7 +789,7 @@ def frame_aligned_point_error( ...@@ -789,7 +789,7 @@ def frame_aligned_point_error(
pred_positions: geometry.Vec3Array, # shape (num_positions) pred_positions: geometry.Vec3Array, # shape (num_positions)
target_positions: geometry.Vec3Array, # shape (num_positions) target_positions: geometry.Vec3Array, # shape (num_positions)
positions_mask: jnp.ndarray, # shape (num_positions) positions_mask: jnp.ndarray, # shape (num_positions)
pair_mask: jnp.ndarray, # shape (num_frames, num_posiitons) pair_mask: Optional[jnp.ndarray], # shape (num_frames, num_posiitons)
l1_clamp_distance: float, l1_clamp_distance: float,
length_scale=20., length_scale=20.,
epsilon=1e-4) -> jnp.ndarray: # shape () epsilon=1e-4) -> jnp.ndarray: # shape ()
......
...@@ -66,7 +66,7 @@ class RunModel: ...@@ -66,7 +66,7 @@ class RunModel:
def __init__(self, def __init__(self,
config: ml_collections.ConfigDict, config: ml_collections.ConfigDict,
params: Optional[Mapping[str, Mapping[str, np.ndarray]]] = None): params: Optional[Mapping[str, Mapping[str, jax.Array]]] = None):
self.config = config self.config = config
self.params = params self.params = params
self.multimer_mode = config.model.global_config.multimer_mode self.multimer_mode = config.model.global_config.multimer_mode
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment