Commit 98caef21 authored by Jake VanderPlas's avatar Jake VanderPlas Committed by Copybara-Service
Browse files

Ensure values passed to jax.numpy functions are arrays rather than lists.

Why? This will soon be a requirement in JAX; see https://github.com/google/jax/issues/7737

PiperOrigin-RevId: 394105499
Change-Id: I6f76e3f47b7906616c1f988be58d29c828414060
parent 495d81ac
...@@ -750,10 +750,10 @@ def find_structural_violations( ...@@ -750,10 +750,10 @@ def find_structural_violations(
# Compute the Van der Waals radius for every atom # Compute the Van der Waals radius for every atom
# (the first letter of the atom name is the element type). # (the first letter of the atom name is the element type).
# Shape: (N, 14). # Shape: (N, 14).
atomtype_radius = [ atomtype_radius = jnp.array([
residue_constants.van_der_waals_radius[name[0]] residue_constants.van_der_waals_radius[name[0]]
for name in residue_constants.atom_types for name in residue_constants.atom_types
] ])
atom14_atom_radius = batch['atom14_atom_exists'] * utils.batched_gather( atom14_atom_radius = batch['atom14_atom_exists'] * utils.batched_gather(
atomtype_radius, batch['residx_atom14_to_atom37']) atomtype_radius, batch['residx_atom14_to_atom37'])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment