Unverified Commit 61312d6a authored by Paweł Gadziński, committed by GitHub

[PyTorch] Deprecate the weight offloading (#1678)



* drop
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 8ffbbabd
@@ -81,7 +81,7 @@ import transformer_engine.pytorch.dot_product_attention.utils as dpa_utils
 from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils as fa_utils
 from transformer_engine.pytorch.dot_product_attention.utils import AttentionLogging as attn_log
 from transformer_engine.pytorch.dot_product_attention.rope import apply_rotary_pos_emb
-from .cpu_offload import set_offloading_param
+from .cpu_offload import mark_activation_offload
 # Setup Attention Logging
@@ -4323,10 +4323,9 @@ class FlashAttention(torch.nn.Module):
         from .cpu_offload import CPUOffloadEnabled

         if CPUOffloadEnabled:
-            tensor_list = [query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv]
-            for tensor in tensor_list:
-                if tensor is not None:
-                    set_offloading_param(tensor, "activation_offloading", True)
+            mark_activation_offload(
+                query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv
+            )

         with self.attention_dropout_ctx():
             # | API | use cases
@@ -4729,13 +4728,8 @@ class FusedAttnFunc(torch.autograd.Function):
             tensor_list = [q, k, v, out_save]
             qkv_layout = "sbhd_sbhd_sbhd"

-            for tensor in tensor_list:
-                if tensor is not None:
-                    set_offloading_param(tensor, "activation_offloading", True)
-            for tensor in aux_ctx_tensors:
-                if tensor is not None:
-                    set_offloading_param(tensor, "activation_offloading", True)
+            mark_activation_offload(*tensor_list)
+            mark_activation_offload(*aux_ctx_tensors)

         ctx.is_input_fp8 = is_input_fp8
         ctx.is_output_fp8 = is_output_fp8
......
@@ -16,18 +16,22 @@ __all__ = ["get_cpu_offload_context"]
 CPUOffloadEnabled = False

-def set_offloading_param(tensor, param_name, value):
+def mark_activation_offload(*tensors):
     """Set the type of the offloading needed for a tensor."""
-    assert param_name in ["weight_offloading", "activation_offloading"]
-    if tensor is None:
-        return
-    if type(tensor) in [torch.Tensor, torch.nn.Parameter]:
-        setattr(tensor, param_name, value)
-    else:
-        data_tensors = tensor.get_data_tensors()
-        for tensor in data_tensors:
-            if tensor is not None:
-                setattr(tensor, param_name, value)
+    for tensor in tensors:
+        if tensor is None:
+            continue
+        if type(tensor) in [torch.Tensor, torch.nn.Parameter]:
+            tensor.activation_offloading = True
+        else:
+            data_tensors = tensor.get_data_tensors()
+            for tensor in data_tensors:
+                if tensor is not None:
+                    tensor.activation_offloading = True
+                    # This is a hack to force clear the tensor after it is offloaded.
+                    # It is needed, because .*TensorBase classes are saved in the ctx,
+                    # and they contain the reference to their data tensors.
+                    tensor.needs_force_clear = True

 def is_cpu_offload_enabled() -> bool:
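
For illustration, a minimal usage sketch of the new helper, assuming it is importable from transformer_engine/pytorch/cpu_offload.py as the import hunks below show; the forward function and tensor names here are hypothetical, not part of the library:

import torch
from transformer_engine.pytorch.cpu_offload import (
    is_cpu_offload_enabled,
    mark_activation_offload,
)

def forward_saving_tensors(ctx, inputmat, mu, rsigma, ln_out):
    # Hypothetical forward: tag every saved tensor in one call. None entries
    # are skipped, and quantized *TensorBase wrappers have each of their data
    # tensors tagged (plus needs_force_clear), as the function above does.
    if is_cpu_offload_enabled():
        mark_activation_offload(inputmat, mu, rsigma, ln_out)
    ctx.save_for_backward(inputmat, mu, rsigma, ln_out)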
@@ -459,8 +463,15 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
         torch.cuda.current_stream().wait_stream(self.d2h_stream)

         # Time to free the activation memory after usage
-        for tensor_tag, _ in self.tensor_tag_to_buf.items():
+        for tensor_tag, tensor_buf in self.tensor_tag_to_buf.items():
             if tensor_tag[0] == self.offloaded_group_count:
+                if hasattr(tensor_buf, "needs_force_clear"):
+                    # Need to clear activation tensor - sometimes references persist in the code.
+                    # This is the case for example with the Float8TensorBase class,
+                    # which is saved directly inside the ctx while its internal tensors are
+                    # saved inside save_for_backward.
+                    tensor_buf.data = torch.Tensor()
                 # Release the pointer to the tensor
                 self.tensor_tag_to_buf[tensor_tag] = None

         # Time to offload the next group
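
A toy illustration of why the force clear is needed; FakeQuantizedTensor is a hypothetical stand-in for the real *TensorBase wrappers, and only the reference-keeping behavior matters:

import torch

class FakeQuantizedTensor:
    """Hypothetical stand-in for a *TensorBase wrapper saved directly in ctx."""

    def __init__(self, data):
        self._data = data  # the wrapper keeps a reference to its data tensor

    def get_data_tensors(self):
        return [self._data]

wrapped = FakeQuantizedTensor(torch.zeros(1024))
buf = wrapped.get_data_tensors()[0]
buf.needs_force_clear = True  # what mark_activation_offload sets above

# Dropping only the handler's pointer (tensor_tag_to_buf[tag] = None) would not
# free the storage, because wrapped._data still references the same tensor;
# swapping in an empty tensor, as the handler does, releases the memory even
# while the wrapper object stays alive in ctx.
if hasattr(buf, "needs_force_clear"):
    buf.data = torch.Tensor()
assert wrapped.get_data_tensors()[0].numel() == 0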
@@ -538,7 +549,7 @@ def get_cpu_offload_context(
     num_layers: int = 1,
     model_layers: int = 1,
     offload_activations: bool = True,
-    offload_weights: bool = True,
+    offload_weights: bool = False,
 ):
     """
     This function returns the CPU Offload context and the synchronizer function that needs to be
@@ -570,28 +581,30 @@ def get_cpu_offload_context(
     """

-    def tensor_need_offloading_checker_activations(tensor):
-        return hasattr(tensor, "activation_offloading")
-
-    # This includes the Gradient Accumulation Buffer
-    def tensor_need_offloading_checker_weights(tensor):
-        return hasattr(tensor, "weight_offloading")
-
-    def tensor_need_offloading_checker_all(tensor):
-        return hasattr(tensor, "activation_offloading") or hasattr(tensor, "weight_offloading")
-
-    if offload_activations and offload_weights:
-        tensor_need_offloading_checker = tensor_need_offloading_checker_all
-    elif offload_activations:
-        tensor_need_offloading_checker = tensor_need_offloading_checker_activations
-    elif offload_weights:
-        tensor_need_offloading_checker = tensor_need_offloading_checker_weights
-    else:
+    if not offload_weights and not offload_activations:
         raise ValueError(
             "CPU Offloading is enabled while it is not "
             "mentioned what to offload (weights/activations)"
         )

+    if offload_weights:
+        import warnings
+
+        warnings.warn(
+            "Offloading weights is deprecated. Using offload_weights=True does not have any"
+            " effect.",
+            DeprecationWarning,
+        )
+
+        # Weights offloading is deprecated but we maintain backward compatibility by doing nothing.
+        if not offload_activations:
+            return nullcontext(), lambda x: x
+
+    def tensor_need_offloading_checker_activations(tensor):
+        return hasattr(tensor, "activation_offloading")
+
+    tensor_need_offloading_checker = tensor_need_offloading_checker_activations
+
     cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler(
         num_offload_group=num_layers,
         num_model_group=model_layers,
......
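
A sketch of the resulting user-facing behavior, assuming get_cpu_offload_context is re-exported from transformer_engine.pytorch and that parameters not shown in this hunk keep their defaults; the model depth below is a hypothetical value:

import warnings
from transformer_engine.pytorch import get_cpu_offload_context

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    offload_ctx, sync_fn = get_cpu_offload_context(
        num_layers=1,              # offload one layer group at a time
        model_layers=12,           # hypothetical model depth
        offload_activations=True,
        offload_weights=True,      # deprecated: warns, then has no effect
    )
assert any(issubclass(w.category, DeprecationWarning) for w in caught)

Per the hunk above, both flags False still raises ValueError, while offload_weights=True with offload_activations=False now returns a no-op nullcontext() and an identity synchronizer.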
@@ -63,7 +63,7 @@ from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantize
 from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
 from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
-from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
+from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ..cpp_extensions import (
     general_gemm,
@@ -355,15 +355,7 @@ class _LayerNormLinear(torch.autograd.Function):
                 weightmat.update_usage(columnwise_usage=True)

         if cpu_offloading:
-            if fp8 and weightmat is not None:
-                set_offloading_param(weightmat, "weight_offloading", True)
-            set_offloading_param(ln_weight, "weight_offloading", True)
-            set_offloading_param(weight, "weight_offloading", True)
-
-            set_offloading_param(inputmat, "activation_offloading", True)
-            set_offloading_param(mu, "activation_offloading", True)
-            set_offloading_param(rsigma, "activation_offloading", True)
-            set_offloading_param(ln_out, "activation_offloading", True)
+            mark_activation_offload(inputmat, mu, rsigma, ln_out)

         # Scatter intermediate/activation tensors saved for the backward pass
         # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already
......
@@ -64,7 +64,7 @@ from ..tensor.float8_tensor import (
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
 from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
 from ._common import apply_normalization, _fix_gathered_fp8_transpose
-from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
+from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ..tensor.quantized_tensor import (
     QuantizedTensor,
     Quantizer,
@@ -473,23 +473,9 @@ class _LayerNormMLP(torch.autograd.Function):
                    clear_tensor_data(act_out, fc1_out_without_bias, fc1_out)
            else:
                if cpu_offloading:
-                    if fp8 and fc1_weight_final is not None:
-                        set_offloading_param(fc1_weight_final, "weight_offloading", True)
-                    if fp8 and fc2_weight_final is not None:
-                        set_offloading_param(fc2_weight_final, "weight_offloading", True)
-                    set_offloading_param(ln_weight, "weight_offloading", True)
-                    set_offloading_param(fc1_weight, "weight_offloading", True)
-                    set_offloading_param(fc2_weight, "weight_offloading", True)
-                    set_offloading_param(fc1_bias, "weight_offloading", True)
-                    set_offloading_param(inputmat, "activation_offloading", True)
-                    set_offloading_param(mu, "activation_offloading", True)
-                    set_offloading_param(rsigma, "activation_offloading", True)
-                    set_offloading_param(mu, "activation_offloading", True)
-                    set_offloading_param(ln_out, "activation_offloading", True)
-                    set_offloading_param(fc1_out, "activation_offloading", True)
-                    set_offloading_param(fc1_out_without_bias, "activation_offloading", True)
-                    set_offloading_param(act_out, "activation_offloading", True)
+                    mark_activation_offload(
+                        inputmat, mu, rsigma, ln_out, fc1_out, fc1_out_without_bias, act_out
+                    )

        # Scatter intermediate/activation tensors saved for the backward pass
        # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already
......
@@ -62,7 +62,7 @@ from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantize
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
 from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
 from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
-from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
+from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ...debug.pytorch.debug_state import TEDebugState
 from ...debug.pytorch.utils import any_feature_enabled
@@ -307,11 +307,8 @@ class _Linear(torch.autograd.Function):
            if isinstance(weightmat, QuantizedTensor):
                weightmat.update_usage(columnwise_usage=True)

-        if cpu_offloading:
-            set_offloading_param(weight, "weight_offloading", True)
-            set_offloading_param(weightmat, "weight_offloading", True)
-            if saved_inputmat is not None:
-                set_offloading_param(saved_inputmat, "activation_offloading", True)
+        if cpu_offloading and saved_inputmat is not None:
+            mark_activation_offload(saved_inputmat)

        # Scatter intermediate/activation tensors saved for the backward pass
        # NOTE: FSDP sharding is not valid for models initialized with primary Fp8 weights
......