Unverified Commit 15dead11 authored by Paweł Gadziński's avatar Paweł Gadziński Committed by GitHub
Browse files

[PyTorch] Fix for CPU offloading (#2403)



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 41425476
...@@ -748,6 +748,11 @@ def get_cpu_offload_context( ...@@ -748,6 +748,11 @@ def get_cpu_offload_context(
double_buffering=double_buffering, double_buffering=double_buffering,
) )
if not enabled:
if manual_synchronization:
return contextlib.nullcontext(), lambda x: x, None
return contextlib.nullcontext(), lambda x: x
if not offload_weights and not offload_activations: if not offload_weights and not offload_activations:
raise ValueError( raise ValueError(
"CPU Offloading is enabled while it is not " "CPU Offloading is enabled while it is not "
...@@ -763,6 +768,8 @@ def get_cpu_offload_context( ...@@ -763,6 +768,8 @@ def get_cpu_offload_context(
# Weights offloading is deprecated but we maintain backward compatibility by doing nothing. # Weights offloading is deprecated but we maintain backward compatibility by doing nothing.
if not offload_activations: if not offload_activations:
if manual_synchronization:
return contextlib.nullcontext(), lambda x: x, None
return contextlib.nullcontext(), lambda x: x return contextlib.nullcontext(), lambda x: x
if TEDebugState.debug_enabled: if TEDebugState.debug_enabled:
...@@ -848,7 +855,6 @@ def get_cpu_offload_context( ...@@ -848,7 +855,6 @@ def get_cpu_offload_context(
cpu_offload_context = _CpuOffloadContext() cpu_offload_context = _CpuOffloadContext()
if enabled:
if manual_synchronization: if manual_synchronization:
return ( return (
cpu_offload_context, cpu_offload_context,
...@@ -859,4 +865,3 @@ def get_cpu_offload_context( ...@@ -859,4 +865,3 @@ def get_cpu_offload_context(
cpu_offload_context, cpu_offload_context,
cpu_offload_context.synchronization_function, cpu_offload_context.synchronization_function,
) )
return contextlib.nullcontext(), lambda x: x
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment