"torchvision/vscode:/vscode.git/clone" did not exist on "ec203153095ad3d2e79fbf2865d80fe6076618fa"
Commit 98caef21 authored by Jake VanderPlas's avatar Jake VanderPlas Committed by Copybara-Service
Browse files

Ensure values passed to jax.numpy functions are arrays rather than lists.

Why? This will soon be a requirement in JAX; see https://github.com/google/jax/issues/7737

PiperOrigin-RevId: 394105499
Change-Id: I6f76e3f47b7906616c1f988be58d29c828414060
parent 495d81ac
......@@ -750,10 +750,10 @@ def find_structural_violations(
# Compute the Van der Waals radius for every atom
# (the first letter of the atom name is the element type).
# Shape: (N, 14).
atomtype_radius = [
atomtype_radius = jnp.array([
residue_constants.van_der_waals_radius[name[0]]
for name in residue_constants.atom_types
]
])
atom14_atom_radius = batch['atom14_atom_exists'] * utils.batched_gather(
atomtype_radius, batch['residx_atom14_to_atom37'])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment