Unverified Commit ce328fac authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

[PyTorch] Fix backward compatibility for checkpoint API (#748)



* Args can be None
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix other arg types
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent e1e2b76e
......@@ -516,12 +516,6 @@ def checkpoint(
kwargs : dict
dictionary of string keys for keyword arguments to :attr:`function`.
"""
only_tensor_args = True
for arg in args:
if not isinstance(arg, torch.Tensor):
only_tensor_args = False
break
# Pop out te.distributed.checkpoint() arguments
global _USE_REENTRANT_ACTIVATION_RECOMPUTE
_USE_REENTRANT_ACTIVATION_RECOMPUTE = kwargs.pop("use_reentrant", True)
......@@ -530,23 +524,14 @@ def checkpoint(
get_rng_state_tracker = kwargs.pop("get_rng_state_tracker", None)
# Ensure backward compatibility.
if not only_tensor_args:
if (len(args) > 3 and isinstance(args[0], bool) and callable(args[1])
and isinstance(args[2], None | dist_group_type)):
warnings.warn(
"Passing non-tensor non-keyword arguments is deprecated and support will be removed in "
"future releases of TransformerEngine. `distribute_saved_activations`, `tp_group`, and "
"`get_rng_state_tracker` must be passed as keyword arguments to `checkpoint`.",
DeprecationWarning, stacklevel=2,
)
assert len(args) > 3, "Incorrect number of arguments for deprecated `checkpoint` API."
assert (
isinstance(args[0], bool) and callable(args[1])
and isinstance(args[2], None | dist_group_type)
), "Incorrect arguments for deprecated `checkpoint` API."
for arg in args[3:]:
assert (
isinstance(arg, None | torch.Tensor)
), f"Expected tensor argument, found {type(arg)}."
distribute_saved_activations, get_rng_state_tracker, tp_group = args[:3] # pylint: disable=unbalanced-tuple-unpacking
args = args[3:]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment