Unverified Commit 15dead11 authored by Paweł Gadziński's avatar Paweł Gadziński Committed by GitHub
Browse files

[PyTorch] Fix for CPU offloading (#2403)



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 41425476
......@@ -748,6 +748,11 @@ def get_cpu_offload_context(
double_buffering=double_buffering,
)
if not enabled:
if manual_synchronization:
return contextlib.nullcontext(), lambda x: x, None
return contextlib.nullcontext(), lambda x: x
if not offload_weights and not offload_activations:
raise ValueError(
"CPU Offloading is enabled while it is not "
......@@ -763,6 +768,8 @@ def get_cpu_offload_context(
# Weights offloading is deprecated but we maintain backward compatibility by doing nothing.
if not offload_activations:
if manual_synchronization:
return contextlib.nullcontext(), lambda x: x, None
return contextlib.nullcontext(), lambda x: x
if TEDebugState.debug_enabled:
......@@ -848,15 +855,13 @@ def get_cpu_offload_context(
cpu_offload_context = _CpuOffloadContext()
if enabled:
if manual_synchronization:
return (
cpu_offload_context,
cpu_offload_context.synchronization_function,
offload_synchronizer,
)
if manual_synchronization:
return (
cpu_offload_context,
cpu_offload_context.synchronization_function,
offload_synchronizer,
)
return contextlib.nullcontext(), lambda x: x
return (
cpu_offload_context,
cpu_offload_context.synchronization_function,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment