Commit 929b188a authored by Tom Hennigan's avatar Tom Hennigan Committed by Copybara-Service
Browse files

Add split_rng=False (current default) to sharded_map.

Haiku plans to make split_rng a required argument to hk.vmap in an upcoming
release. This change updates AlphaFold to preserve the current behaviour. We
also handle the case where users are using a release of Haiku without the
split_rng option, for these users split_rng=False is implied.

PiperOrigin-RevId: 428429472
Change-Id: I3292396eb330ffabf2f36f5364aa0ae1bc74cf71
parent b5ed6b76
......@@ -15,6 +15,7 @@
"""Specialized mapping functions."""
import functools
import inspect
from typing import Any, Callable, Optional, Sequence, Union
......@@ -75,7 +76,11 @@ def sharded_map(
Returns:
function with smap applied.
"""
vmapped_fun = hk.vmap(fun, in_axes, out_axes)
if 'split_rng' in inspect.signature(hk.vmap).parameters:
vmapped_fun = hk.vmap(fun, in_axes, out_axes, split_rng=False)
else:
# TODO(tomhennigan): Remove this when older versions of Haiku aren't used.
vmapped_fun = hk.vmap(fun, in_axes, out_axes)
return sharded_apply(vmapped_fun, shard_size, in_axes, out_axes)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment