Commit d9e5e1d9 authored by Augustin Zidek, committed by Copybara-Service

Fix jax.tree_multimap deprecation warning.

PiperOrigin-RevId: 451994826
Change-Id: I4573baf61d33010c75de717d3b49f47bc9c6a8ac
parent 197bd19e
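Background for the change: `jax.tree_multimap` was originally the multi-tree variant of `jax.tree_map`, but `jax.tree_map` itself has accepted extra trees as positional arguments for some time, so `tree_multimap` became a deprecated alias and the rename is a drop-in fix. A minimal sketch of the before/after (the dict is a hypothetical pytree, not from this codebase):

    import jax
    import jax.numpy as jnp

    params_a = {'w': jnp.ones(3), 'b': jnp.zeros(3)}
    params_b = {'w': jnp.full(3, 2.0), 'b': jnp.ones(3)}

    # Before -- emits a DeprecationWarning on JAX versions around this commit:
    #   summed = jax.tree_multimap(lambda x, y: x + y, params_a, params_b)
    # After -- identical behaviour, no warning:
    summed = jax.tree_map(lambda x, y: x + y, params_a, params_b)

(Later JAX releases expose the same function again as `jax.tree.map`; at the time of this commit, `jax.tree_map` was the supported spelling.)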
@@ -426,7 +426,7 @@ def torsion_angles_to_frames(
   chi3_frame_to_backb = chi2_frame_to_backb @ all_frames[:, 6]
   chi4_frame_to_backb = chi3_frame_to_backb @ all_frames[:, 7]
 
-  all_frames_to_backb = jax.tree_multimap(
+  all_frames_to_backb = jax.tree_map(
       lambda *x: jnp.concatenate(x, axis=-1), all_frames[:, 0:5],
       chi2_frame_to_backb[:, None], chi3_frame_to_backb[:, None],
       chi4_frame_to_backb[:, None])
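For readers unfamiliar with the idiom in this hunk: passing several pytrees to `jax.tree_map` applies the function to matching leaves, so `lambda *x: jnp.concatenate(x, axis=-1)` merges the per-group frames along the trailing axis. A toy sketch with plain dicts standing in for the `Rigid3Array` pytrees (all names hypothetical):

    import jax
    import jax.numpy as jnp

    # Two hypothetical frame pytrees with trailing group axes of 5 and 1.
    frames_a = {'rot': jnp.ones((4, 5)), 'trans': jnp.zeros((4, 5))}
    frames_b = {'rot': jnp.ones((4, 1)), 'trans': jnp.zeros((4, 1))}

    # Matching leaves are concatenated on the last axis: (4, 5) + (4, 1) -> (4, 6).
    merged = jax.tree_map(lambda *x: jnp.concatenate(x, axis=-1),
                          frames_a, frames_b)
    assert merged['rot'].shape == (4, 6)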
@@ -546,7 +546,7 @@ def generate_monomer_rigids(representations: Mapping[str, jnp.ndarray],
         )
     outputs.append(output)
 
-  output = jax.tree_multimap(lambda *x: jnp.stack(x), *outputs)
+  output = jax.tree_map(lambda *x: jnp.stack(x), *outputs)
   # Pass along for LDDT-Head.
   output['act'] = activations['act']
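The `*outputs` unpacking in this hunk is the standard trick for stacking a list of identically structured pytrees into one pytree of batched arrays. A small sketch with hypothetical keys:

    import jax
    import jax.numpy as jnp

    # Hypothetical per-iteration outputs, as collected by the loop above.
    outputs = [{'act': jnp.zeros(7), 'sc': jnp.zeros(14)} for _ in range(8)]

    # Unpacking hands tree_map one tree per iteration; stacking matching
    # leaves adds a leading iteration axis.
    stacked = jax.tree_map(lambda *x: jnp.stack(x), *outputs)
    assert stacked['act'].shape == (8, 7)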
@@ -823,7 +823,7 @@ def compute_frames(
   alt_gt_frames = frames_batch['rigidgroups_alt_gt_frames']
   use_alt = use_alt[:, None]
 
-  renamed_gt_frames = jax.tree_multimap(
+  renamed_gt_frames = jax.tree_map(
       lambda x, y: (1. - use_alt) * x + use_alt * y, gt_frames, alt_gt_frames)
 
   return renamed_gt_frames, frames_batch['rigidgroups_gt_exists']
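The blend in this hunk is a leafwise select between two pytrees: wherever the broadcast `use_alt` mask is 1 the alternative ground-truth frame is taken, otherwise the original. A minimal sketch with a hypothetical single-leaf tree:

    import jax
    import jax.numpy as jnp

    gt = {'xyz': jnp.ones((3, 4))}
    alt = {'xyz': jnp.full((3, 4), 2.0)}
    use_alt = jnp.array([[0.], [1.], [0.]])  # (3, 1), broadcasts per leaf

    picked = jax.tree_map(lambda x, y: (1. - use_alt) * x + use_alt * y,
                          gt, alt)
    # picked['xyz'] row 1 comes from alt, rows 0 and 2 from gt.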
@@ -1160,4 +1160,3 @@ class MultiRigidSidechain(hk.Module):
         'frames': all_frames_to_global,  # geometry.Rigid3Array (N, 8)
     })
     return outputs
-
@@ -53,10 +53,10 @@ class Vec3Array:
     assert all([x == z for x, z in zip(self.x.shape, self.z.shape)])
 
   def __add__(self, other: Vec3Array) -> Vec3Array:
-    return jax.tree_multimap(lambda x, y: x + y, self, other)
+    return jax.tree_map(lambda x, y: x + y, self, other)
 
   def __sub__(self, other: Vec3Array) -> Vec3Array:
-    return jax.tree_multimap(lambda x, y: x - y, self, other)
+    return jax.tree_map(lambda x, y: x - y, self, other)
 
   def __mul__(self, other: Float) -> Vec3Array:
     return jax.tree_map(lambda x: x * other, self)
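`Vec3Array` can be passed to `jax.tree_map` directly because it is registered with JAX as a pytree, so the lambda is applied to the matching `x`/`y`/`z` components of the two operands. A self-contained sketch of the same pattern with a hypothetical two-component class:

    import dataclasses
    import jax
    import jax.numpy as jnp

    # Hypothetical miniature of Vec3Array: registering the dataclass as a
    # pytree lets tree_map pair up the x/y fields of two instances.
    @jax.tree_util.register_pytree_node_class
    @dataclasses.dataclass
    class Vec2Array:
      x: jnp.ndarray
      y: jnp.ndarray

      def tree_flatten(self):
        return (self.x, self.y), None

      @classmethod
      def tree_unflatten(cls, aux, children):
        return cls(*children)

      def __add__(self, other):
        return jax.tree_map(lambda a, b: a + b, self, other)

    v = Vec2Array(jnp.ones(3), jnp.zeros(3)) + Vec2Array(jnp.ones(3), jnp.ones(3))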
@@ -198,8 +198,8 @@ class LayerStackTest(parameterized.TestCase):
     assert_fn = functools.partial(
         np.testing.assert_allclose, atol=1e-4, rtol=1e-4)
 
-    jax.tree_multimap(assert_fn, unrolled_grad,
-                      _slice_layers_params(layer_stack_grad))
+    jax.tree_map(assert_fn, unrolled_grad,
+                 _slice_layers_params(layer_stack_grad))
 
   def test_random(self):
     """Random numbers should be handled correctly."""
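The test idiom in this hunk runs an assertion over every pair of gradient leaves at once; `tree_map` is used purely for its traversal, and the first mismatching leaf pair raises. A sketch with hypothetical gradient trees:

    import functools
    import jax
    import numpy as np

    assert_fn = functools.partial(
        np.testing.assert_allclose, atol=1e-4, rtol=1e-4)

    grads_a = {'layer': np.ones((2, 2)), 'bias': np.zeros(4)}
    grads_b = {'layer': np.ones((2, 2)), 'bias': np.zeros(4)}

    # Applies assert_allclose to each matching pair of leaves.
    jax.tree_map(assert_fn, grads_a, grads_b)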
@@ -125,7 +125,7 @@ def sharded_apply(
     # Expand in axes and Determine Loop range
     in_axes_ = _expand_axes(in_axes, args)
 
-    in_sizes = jax.tree_multimap(_maybe_get_size, args, in_axes_)
+    in_sizes = jax.tree_map(_maybe_get_size, args, in_axes_)
     flat_sizes = jax.tree_flatten(in_sizes)[0]
     in_size = max(flat_sizes)
     assert all(i in {in_size, -1} for i in flat_sizes)
@@ -137,7 +137,7 @@ def sharded_apply(
       last_shard_size = shard_size if last_shard_size == 0 else last_shard_size
 
     def apply_fun_to_slice(slice_start, slice_size):
-      input_slice = jax.tree_multimap(
+      input_slice = jax.tree_map(
          lambda array, axis: _maybe_slice(array, slice_start, slice_size, axis
                                          ), args, in_axes_)
      return fun(*input_slice)
@@ -158,11 +158,11 @@ def sharded_apply(
           shard_shape[axis] * num_extra_shards +
           remainder_shape[axis],) + shard_shape[axis + 1:]
 
-    out_shapes = jax.tree_multimap(make_output_shape, out_axes_, shard_shapes,
-                                   out_shapes)
+    out_shapes = jax.tree_map(make_output_shape, out_axes_, shard_shapes,
+                              out_shapes)
 
     # Calls dynamic Update slice with different argument order
-    # This is here since tree_multimap only works with positional arguments
+    # This is here since tree_map only works with positional arguments
     def dynamic_update_slice_in_dim(full_array, update, axis, i):
       return jax.lax.dynamic_update_slice_in_dim(full_array, update, i, axis)
@@ -170,7 +170,7 @@ def sharded_apply(
       slice_out = apply_fun_to_slice(slice_start, slice_size)
       update_slice = partial(
           dynamic_update_slice_in_dim, i=slice_start)
-      return jax.tree_multimap(update_slice, outputs, slice_out, out_axes_)
+      return jax.tree_map(update_slice, outputs, slice_out, out_axes_)
 
     def scan_iteration(outputs, i):
       new_outputs = compute_shard(outputs, i, shard_size)
@@ -181,7 +181,7 @@ def sharded_apply(
     def allocate_buffer(dtype, shape):
       return jnp.zeros(shape, dtype=dtype)
 
-    outputs = jax.tree_multimap(allocate_buffer, out_dtypes, out_shapes)
+    outputs = jax.tree_map(allocate_buffer, out_dtypes, out_shapes)
 
     if slice_starts.shape[0] > 0:
       outputs, _ = hk.scan(scan_iteration, outputs, slice_starts)
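The last hunk's `allocate_buffer` call zero-initialises the output pytree from one tree of dtypes and one of shapes; because the dtype tree comes first, the shape tuples in the second tree are matched against its leaf positions and kept whole rather than flattened further. A sketch with a hypothetical output spec:

    import jax
    import jax.numpy as jnp

    out_dtypes = {'pos': jnp.float32, 'mask': jnp.bool_}
    out_shapes = {'pos': (16, 3), 'mask': (16,)}

    # tree_map pairs each dtype leaf with its shape tuple and allocates the
    # zeroed buffer that the scan over shards then fills in place.
    outputs = jax.tree_map(lambda dtype, shape: jnp.zeros(shape, dtype=dtype),
                           out_dtypes, out_shapes)
    assert outputs['pos'].shape == (16, 3)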