Unverified commit 61312d6a, authored by Paweł Gadziński, committed by GitHub
Browse files

[PyTorch] Deprecate the weight offloading (#1678)



* drop
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 8ffbbabd
...@@ -81,7 +81,7 @@ import transformer_engine.pytorch.dot_product_attention.utils as dpa_utils ...@@ -81,7 +81,7 @@ import transformer_engine.pytorch.dot_product_attention.utils as dpa_utils
from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils as fa_utils from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils as fa_utils
from transformer_engine.pytorch.dot_product_attention.utils import AttentionLogging as attn_log from transformer_engine.pytorch.dot_product_attention.utils import AttentionLogging as attn_log
from transformer_engine.pytorch.dot_product_attention.rope import apply_rotary_pos_emb from transformer_engine.pytorch.dot_product_attention.rope import apply_rotary_pos_emb
from .cpu_offload import set_offloading_param from .cpu_offload import mark_activation_offload
# Setup Attention Logging # Setup Attention Logging
...@@ -4323,10 +4323,9 @@ class FlashAttention(torch.nn.Module): ...@@ -4323,10 +4323,9 @@ class FlashAttention(torch.nn.Module):
from .cpu_offload import CPUOffloadEnabled from .cpu_offload import CPUOffloadEnabled
if CPUOffloadEnabled: if CPUOffloadEnabled:
tensor_list = [query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv] mark_activation_offload(
for tensor in tensor_list: query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv
if tensor is not None: )
set_offloading_param(tensor, "activation_offloading", True)
with self.attention_dropout_ctx(): with self.attention_dropout_ctx():
# | API | use cases # | API | use cases
...@@ -4729,13 +4728,8 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -4729,13 +4728,8 @@ class FusedAttnFunc(torch.autograd.Function):
tensor_list = [q, k, v, out_save] tensor_list = [q, k, v, out_save]
qkv_layout = "sbhd_sbhd_sbhd" qkv_layout = "sbhd_sbhd_sbhd"
for tensor in tensor_list: mark_activation_offload(*tensor_list)
if tensor is not None: mark_activation_offload(*aux_ctx_tensors)
set_offloading_param(tensor, "activation_offloading", True)
for tensor in aux_ctx_tensors:
if tensor is not None:
set_offloading_param(tensor, "activation_offloading", True)
ctx.is_input_fp8 = is_input_fp8 ctx.is_input_fp8 = is_input_fp8
ctx.is_output_fp8 = is_output_fp8 ctx.is_output_fp8 = is_output_fp8
......
...@@ -16,18 +16,22 @@ __all__ = ["get_cpu_offload_context"] ...@@ -16,18 +16,22 @@ __all__ = ["get_cpu_offload_context"]
CPUOffloadEnabled = False CPUOffloadEnabled = False
def set_offloading_param(tensor, param_name, value): def mark_activation_offload(*tensors):
"""Set the type of the offloading needed for a tensor.""" """Set the type of the offloading needed for a tensor."""
assert param_name in ["weight_offloading", "activation_offloading"] for tensor in tensors:
if tensor is None: if tensor is None:
return continue
if type(tensor) in [torch.Tensor, torch.nn.Parameter]: if type(tensor) in [torch.Tensor, torch.nn.Parameter]:
setattr(tensor, param_name, value) tensor.activation_offloading = True
else: else:
data_tensors = tensor.get_data_tensors() data_tensors = tensor.get_data_tensors()
for tensor in data_tensors: for tensor in data_tensors:
if tensor is not None: if tensor is not None:
setattr(tensor, param_name, value) tensor.activation_offloading = True
# This is a hack to force clear the tensor after it is offloaded.
# It is needed, because .*TensorBase classes are saved in the ctx,
# and they contain the reference to their data tensors.
tensor.needs_force_clear = True
def is_cpu_offload_enabled() -> bool: def is_cpu_offload_enabled() -> bool:
...@@ -459,8 +463,15 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): ...@@ -459,8 +463,15 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
torch.cuda.current_stream().wait_stream(self.d2h_stream) torch.cuda.current_stream().wait_stream(self.d2h_stream)
# Time to free the activation memory after usage # Time to free the activation memory after usage
for tensor_tag, _ in self.tensor_tag_to_buf.items(): for tensor_tag, tensor_buf in self.tensor_tag_to_buf.items():
if tensor_tag[0] == self.offloaded_group_count: if tensor_tag[0] == self.offloaded_group_count:
if hasattr(tensor_buf, "needs_force_clear"):
# Need to clear activation tensor - sometimes references persist in the code.
# This is the case for example with the Float8TensorBase class,
# which is saved directly inside the ctx while its internal tensors are
# saved inside save_for_backward.
tensor_buf.data = torch.Tensor()
# Release the pointer to the tensor
self.tensor_tag_to_buf[tensor_tag] = None self.tensor_tag_to_buf[tensor_tag] = None
# Time to offload the next group # Time to offload the next group
...@@ -538,7 +549,7 @@ def get_cpu_offload_context( ...@@ -538,7 +549,7 @@ def get_cpu_offload_context(
num_layers: int = 1, num_layers: int = 1,
model_layers: int = 1, model_layers: int = 1,
offload_activations: bool = True, offload_activations: bool = True,
offload_weights: bool = True, offload_weights: bool = False,
): ):
""" """
This function returns the CPU Offload context and the synchronizer function that needs to be This function returns the CPU Offload context and the synchronizer function that needs to be
...@@ -570,28 +581,30 @@ def get_cpu_offload_context( ...@@ -570,28 +581,30 @@ def get_cpu_offload_context(
""" """
def tensor_need_offloading_checker_activations(tensor): if not offload_weights and not offload_activations:
return hasattr(tensor, "activation_offloading")
# This includes the Gradient Accumulation Buffer
def tensor_need_offloading_checker_weights(tensor):
return hasattr(tensor, "weight_offloading")
def tensor_need_offloading_checker_all(tensor):
return hasattr(tensor, "activation_offloading") or hasattr(tensor, "weight_offloading")
if offload_activations and offload_weights:
tensor_need_offloading_checker = tensor_need_offloading_checker_all
elif offload_activations:
tensor_need_offloading_checker = tensor_need_offloading_checker_activations
elif offload_weights:
tensor_need_offloading_checker = tensor_need_offloading_checker_weights
else:
raise ValueError( raise ValueError(
"CPU Offloading is enabled while it is not " "CPU Offloading is enabled while it is not "
"mentioned what to offload (weights/activations)" "mentioned what to offload (weights/activations)"
) )
if offload_weights:
import warnings
warnings.warn(
"Offloading weights is deprecated. Using offload_weights=True does not have any"
" effect.",
DeprecationWarning,
)
# Weights offloading is deprecated but we maintain backward compatibility by doing nothing.
if not offload_activations:
return nullcontext(), lambda x: x
def tensor_need_offloading_checker_activations(tensor):
return hasattr(tensor, "activation_offloading")
tensor_need_offloading_checker = tensor_need_offloading_checker_activations
cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler( cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler(
num_offload_group=num_layers, num_offload_group=num_layers,
num_model_group=model_layers, num_model_group=model_layers,
......
...@@ -63,7 +63,7 @@ from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantize ...@@ -63,7 +63,7 @@ from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantize
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ..cpp_extensions import ( from ..cpp_extensions import (
general_gemm, general_gemm,
...@@ -355,15 +355,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -355,15 +355,7 @@ class _LayerNormLinear(torch.autograd.Function):
weightmat.update_usage(columnwise_usage=True) weightmat.update_usage(columnwise_usage=True)
if cpu_offloading: if cpu_offloading:
if fp8 and weightmat is not None: mark_activation_offload(inputmat, mu, rsigma, ln_out)
set_offloading_param(weightmat, "weight_offloading", True)
set_offloading_param(ln_weight, "weight_offloading", True)
set_offloading_param(weight, "weight_offloading", True)
set_offloading_param(inputmat, "activation_offloading", True)
set_offloading_param(mu, "activation_offloading", True)
set_offloading_param(rsigma, "activation_offloading", True)
set_offloading_param(ln_out, "activation_offloading", True)
# Scatter intermediate/activation tensors saved for the backward pass # Scatter intermediate/activation tensors saved for the backward pass
# NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already
......
...@@ -64,7 +64,7 @@ from ..tensor.float8_tensor import ( ...@@ -64,7 +64,7 @@ from ..tensor.float8_tensor import (
from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ._common import apply_normalization, _fix_gathered_fp8_transpose from ._common import apply_normalization, _fix_gathered_fp8_transpose
from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ..tensor.quantized_tensor import ( from ..tensor.quantized_tensor import (
QuantizedTensor, QuantizedTensor,
Quantizer, Quantizer,
...@@ -473,23 +473,9 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -473,23 +473,9 @@ class _LayerNormMLP(torch.autograd.Function):
clear_tensor_data(act_out, fc1_out_without_bias, fc1_out) clear_tensor_data(act_out, fc1_out_without_bias, fc1_out)
else: else:
if cpu_offloading: if cpu_offloading:
if fp8 and fc1_weight_final is not None: mark_activation_offload(
set_offloading_param(fc1_weight_final, "weight_offloading", True) inputmat, mu, rsigma, ln_out, fc1_out, fc1_out_without_bias, act_out
if fp8 and fc2_weight_final is not None: )
set_offloading_param(fc2_weight_final, "weight_offloading", True)
set_offloading_param(ln_weight, "weight_offloading", True)
set_offloading_param(fc1_weight, "weight_offloading", True)
set_offloading_param(fc2_weight, "weight_offloading", True)
set_offloading_param(fc1_bias, "weight_offloading", True)
set_offloading_param(inputmat, "activation_offloading", True)
set_offloading_param(mu, "activation_offloading", True)
set_offloading_param(rsigma, "activation_offloading", True)
set_offloading_param(mu, "activation_offloading", True)
set_offloading_param(ln_out, "activation_offloading", True)
set_offloading_param(fc1_out, "activation_offloading", True)
set_offloading_param(fc1_out_without_bias, "activation_offloading", True)
set_offloading_param(act_out, "activation_offloading", True)
# Scatter intermediate/activation tensors saved for the backward pass # Scatter intermediate/activation tensors saved for the backward pass
# NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already
......
...@@ -62,7 +62,7 @@ from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantize ...@@ -62,7 +62,7 @@ from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantize
from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.utils import any_feature_enabled from ...debug.pytorch.utils import any_feature_enabled
...@@ -307,11 +307,8 @@ class _Linear(torch.autograd.Function): ...@@ -307,11 +307,8 @@ class _Linear(torch.autograd.Function):
if isinstance(weightmat, QuantizedTensor): if isinstance(weightmat, QuantizedTensor):
weightmat.update_usage(columnwise_usage=True) weightmat.update_usage(columnwise_usage=True)
if cpu_offloading: if cpu_offloading and saved_inputmat is not None:
set_offloading_param(weight, "weight_offloading", True) mark_activation_offload(saved_inputmat)
set_offloading_param(weightmat, "weight_offloading", True)
if saved_inputmat is not None:
set_offloading_param(saved_inputmat, "activation_offloading", True)
# Scatter intermediate/activation tensors saved for the backward pass # Scatter intermediate/activation tensors saved for the backward pass
# NOTE: FSDP sharding is not valid for models initialized with primary Fp8 weights # NOTE: FSDP sharding is not valid for models initialized with primary Fp8 weights
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment