Unverified commit c5257605, authored by Paweł Gadziński, committed by GitHub

[PyTorch] Activation offloading refactor (#1762)



* init
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* offloading
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* all types
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* typo
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* init
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* api change
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* code drop
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* refactor
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* tests
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* code drop
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* example
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* cpu offload + debug warning
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* test
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* change empty_like implementation to use make_like
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* main_grad fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* manual synchronization
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* old path
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* remove example
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* api changes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* reverted grouped linear
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* make old code path work for modules
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* attention old code path
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* legacy tests
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* legacy tests
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* updated code path
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update transformer_engine/pytorch/tensor/quantized_tensor.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* nvfp4 support
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update tests/pytorch/test_cpu_offloading.py
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>

* small fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* docs change
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: root <root@ptyche0312.ptyche.clusters.nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
parent a0754757
@@ -42,7 +42,8 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
-NVTE_FLASH_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
+NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading_v1.xml $TE_PATH/tests/pytorch/test_cpu_offloading_v1.py || test_fail "test_cpu_offloading_v1.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import contextlib
import gc
import os
from typing import Iterable, Optional

import pytest
import torch

import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
from utils import ModelConfig, get_available_attention_backends

# Check supported quantization schemes
fp8_available = te.is_fp8_available()
mxfp8_available = te.is_mxfp8_available()
quantization_recipes: Optional[recipe.Recipe] = [None]
if fp8_available:
    quantization_recipes.extend((recipe.Float8CurrentScaling(), recipe.DelayedScaling()))

model_config = {
    "small": ModelConfig(8, 512, 8, 64, num_layers=5, eps=0.1),
}
SIZE = model_config["small"].hidden_size
NUM_HEADS = model_config["small"].num_heads
NUM_LAYERS = model_config["small"].num_layers
EPSILON = model_config["small"].eps

# Flash attention saves some internal tensor for the backward pass
# that cannot be offloaded to CPU.
assert os.getenv("NVTE_FLASH_ATTN") == "0"

# CPU offload v1 code path is enabled
assert os.environ.get("NVTE_CPU_OFFLOAD_V1", "0") == "1"

# Offloading is supported for attention only for fused and flash attention backends,
# so the use of bfloat16 is required.
#
# For the TransformerLayer, activation offloading with dropout is not supported,
# so we set hidden_dropout to 0.0.
model_types = {
    "linear": lambda: te.Linear(SIZE, SIZE, params_dtype=torch.bfloat16),
    "layernorm_mlp": lambda: te.LayerNormMLP(SIZE, SIZE, params_dtype=torch.bfloat16),
    "layernorm_linear": lambda: te.LayerNormLinear(SIZE, SIZE, params_dtype=torch.bfloat16),
    "multihead_attention": lambda: te.MultiheadAttention(
        SIZE, NUM_HEADS, params_dtype=torch.bfloat16
    ),
    "transformer_layer": lambda: te.TransformerLayer(
        SIZE, SIZE, NUM_HEADS, params_dtype=torch.bfloat16, hidden_dropout=0.0
    ),
    "linear_op": lambda: te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
    "layernorm_mlp_ops": lambda: te.ops.Sequential(
        te.ops.LayerNorm(SIZE, dtype=torch.bfloat16),
        te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
        te.ops.GELU(),
        te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
    ),
}


def _make_input() -> torch.Tensor:
    """Generate random input tensor."""
    return torch.randn(
        (128, SIZE, SIZE),
        dtype=torch.bfloat16,
        device="cuda",
        requires_grad=True,
    )


def _warmup_model(
    modules: Iterable[torch.nn.Module],
    quantization_recipe: Optional[recipe.Recipe],
) -> None:
    """Perform forward and backward pass"""
    tensor = _make_input()
    for module in modules:
        with te.autocast(
            enabled=quantization_recipe is not None,
            recipe=quantization_recipe,
        ):
            tensor = module(tensor)
    tensor.sum().backward()


def _estimate_cached_weight_size(
    model_name: str,
    modules: Iterable[torch.nn.Module],
    quantization_recipe: Optional[recipe.Recipe],
) -> float:
    """Calculate the memory (in MiB) needed for weight caching."""

    # The weight params are cached directly for unquantized compute
    if quantization_recipe is None:
        return 0

    # Count number of weight param elements
    param_elements = 0
    for module in modules:
        for param in module.parameters():
            if param.dim() == 2:
                param_elements += param.numel()

    # FP8 tensor-scaling caches one byte per element
    if quantization_recipe.delayed() or quantization_recipe.float8_current_scaling():
        if not is_non_tn_fp8_gemm_supported() and model_name not in (
            "linear_op",
            "layernorm_mlp_ops",
        ):
            # Modules do not deallocate FP8 transpose for weights
            return 2 * param_elements / 1024**2
        return param_elements / 1024**2

    # MXFP8 caches one data byte per element and one scale byte per 32
    # elements
    if quantization_recipe.mxfp8():
        if model_name not in ("linear_op", "layernorm_mlp_ops"):
            # Modules do not deallocate column-wise MXFP8 data for weights
            return 2 * param_elements * (1 + 1 / 32) / 1024**2
        return param_elements * (1 + 1 / 32) / 1024**2

    raise NotImplementedError(f"Unrecognized recipe ({quantization_recipe})")


def _measure_cached_memory(
    modules: Iterable[torch.nn.Module],
    quantization_recipe: Optional[recipe.Recipe],
    cpu_offload: bool,
) -> float:
    """Measure the growth in allocated GPU memory in MiB after a model forward pass.

    Memory measurement excludes the input and output tensors.
    """

    # Reset memory
    gc.collect()
    torch.cuda.empty_cache()

    # Context and sync function for CPU offloading
    if cpu_offload:
        offload_context, sync_function = te.get_cpu_offload_context(
            enabled=True,
            num_layers=len(modules),
            model_layers=len(modules) + 1,
            offload_activations=True,
            offload_weights=False,
        )
    else:
        offload_context = contextlib.nullcontext()
        sync_function = lambda x: x

    # Forward pass, with dummy step to trigger offload for last module
    inp = _make_input()
    tensor = inp
    memory_before_forward = torch.cuda.memory_allocated() / (1024**2)
    for module in modules:
        with te.autocast(
            enabled=quantization_recipe is not None, recipe=quantization_recipe
        ), offload_context:
            tensor = module(tensor)
        tensor = sync_function(tensor)
    with offload_context:
        tensor = tensor.clone()
    tensor = sync_function(tensor)
    memory_after_forward = (torch.cuda.memory_allocated() - tensor.nbytes) / (1024**2)

    # Backward pass
    tensor.sum().backward()
    torch.cuda.synchronize()

    # Memory usage in MiB
    return memory_after_forward - memory_before_forward


@pytest.mark.parametrize("quantization_recipe", quantization_recipes)
@pytest.mark.parametrize("model_name", model_types.keys())
def test_cpu_offload(quantization_recipe: Optional[recipe.Recipe], model_name: str) -> None:
    """Check that CPU offloading runs and has expected memory usage."""

    # Construct model
    modules_list = [model_types[model_name]() for _ in range(NUM_LAYERS)]
    if model_name in ["multihead_attention", "transformer_layer"]:
        available_backends, *_ = get_available_attention_backends(
            model_config["small"],
            qkv_dtype=torch.bfloat16,
            qkv_layout="sbhd_sbhd_sbhd",
        )
        _, fused_attn_supported, _ = available_backends
        if not fused_attn_supported:
            pytest.skip("Fused attention backend not available.")
        os.environ["NVTE_FLASH_ATTN"] = "0"
        _attention_backends["backend_selection_requires_update"] = True

    # Warmup
    _warmup_model(modules_list, quantization_recipe)

    # Measure cached memory after forward pass
    memory_without_offload = _measure_cached_memory(modules_list, quantization_recipe, False)
    memory_with_offload = _measure_cached_memory(modules_list, quantization_recipe, True)

    # Check for expected memory usage
    assert memory_with_offload < memory_without_offload
    memory_from_cached_weights = _estimate_cached_weight_size(
        model_name,
        modules_list,
        quantization_recipe,
    )
    assert abs(memory_with_offload - memory_from_cached_weights) < EPSILON
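For orientation, here is a minimal sketch (not part of the commit) of how the offload context exercised by this test is wired into a layer loop. The layer stack and shapes below are made up; the `te.get_cpu_offload_context` call and the per-layer `sync` pattern mirror `_measure_cached_memory` above.

```python
import torch
import transformer_engine.pytorch as te

# Hypothetical stack of TE layers; mirrors the per-layer wiring in the test above.
layers = [te.Linear(512, 512, params_dtype=torch.bfloat16) for _ in range(4)]

offload_ctx, sync = te.get_cpu_offload_context(
    enabled=True,
    num_layers=len(layers),        # layers whose activations are offloaded
    model_layers=len(layers) + 1,  # total forward steps, including the dummy step below
    offload_activations=True,
    offload_weights=False,
)

x = torch.randn(64, 8, 512, dtype=torch.bfloat16, device="cuda", requires_grad=True)
for layer in layers:
    with offload_ctx:   # activations saved inside are scheduled for device-to-host copies
        x = layer(x)
    x = sync(x)         # wait on the offload stream before running the next layer
with offload_ctx:       # extra step so the last real layer's offload is triggered
    x = x.clone()
x = sync(x)
x.sum().backward()      # offloaded tensors are reloaded to the GPU on demand in backward
```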
@@ -50,6 +50,13 @@ from transformer_engine.pytorch.attention.dot_product_attention.context_parallel
 )
 from transformer_engine.pytorch.attention.dot_product_attention.softmax import FusedScaleMaskSoftmax
 from transformer_engine.pytorch.attention.inference import InferenceParams
+from transformer_engine.pytorch.cpu_offload import (
+    is_cpu_offload_enabled,
+    start_offload,
+    mark_activation_offload,
+    NVTE_CPU_OFFLOAD_V1,
+)
+from transformer_engine.pytorch.cpu_offload_v1 import is_current_layer_offloaded

 # Import attention utils
 import transformer_engine.pytorch.attention.dot_product_attention.utils as dpa_utils

@@ -737,6 +744,9 @@ class FlashAttention(torch.nn.Module):
             x.contiguous() for x in (query_layer._data, key_layer._data, value_layer._data)
         ]

+        if is_cpu_offload_enabled():
+            start_offload(query_layer, key_layer, value_layer, offload_base_tensor=True)
+
         # get batch_size, max_seqlen and cu_seqlens
         batch_size, context_len = None, None
         if inference_params is None:

@@ -877,12 +887,7 @@ class FlashAttention(torch.nn.Module):
                 fp8_output=fp8_output,
             )
         else:
-            from transformer_engine.pytorch.cpu_offload import (
-                CPUOffloadEnabled,
-                mark_activation_offload,
-            )
-
-            if CPUOffloadEnabled:
+            if is_cpu_offload_enabled():
                 mark_activation_offload(
                     query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv
                 )

@@ -1116,6 +1121,9 @@ class FusedAttnFunc(torch.autograd.Function):
         nvtx_label = "transformer_engine.FusedAttnFunc.forward"
         nvtx_range_push(f"{nvtx_label}")

+        if is_cpu_offload_enabled():
+            start_offload(q, k, v, offload_base_tensor=True)
+
         # recipe passed in through autocast or set by NVTE_DPA_FP8_RECIPE;
         # may be different from fp8_meta["recipe"]
         fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()

@@ -1293,12 +1301,7 @@ class FusedAttnFunc(torch.autograd.Function):
         # used when some tensors are base tensors and loose the "dtype" attribute
         ctx.nominal_dtype = out_nominal_dtype

-        from transformer_engine.pytorch.cpu_offload import (
-            CPUOffloadEnabled,
-            mark_activation_offload,
-        )
-
-        if CPUOffloadEnabled:
+        if is_cpu_offload_enabled() and NVTE_CPU_OFFLOAD_V1:
             if ctx.fp8:
                 tensor_list = fp8_tensors
             else:

@@ -1309,6 +1312,7 @@ class FusedAttnFunc(torch.autograd.Function):
         ctx.is_input_fp8 = is_input_fp8
         ctx.is_output_fp8 = is_output_fp8
+
         tensors_to_save, tensor_objects = prepare_for_saving(
             *fp8_tensors,
             *qkvo_tensors,

@@ -1339,27 +1343,26 @@ class FusedAttnFunc(torch.autograd.Function):
         ctx.dropout_p = dropout_p
         ctx.fast_zero_fill = fast_zero_fill

-        from transformer_engine.pytorch.cpu_offload import (
-            CPUOffloadedLayer,
-        )
-
-        # If interleaved tensor is offloaded, reloaded tensor will be
-        # non-interleaved, so we need to modify the QKV layout
-        # for backward
-        if CPUOffloadedLayer and CPUOffloadEnabled:
-            reload_layout = ""
-            split_list = qkv_layout.split("_")
-            for split in split_list:
-                temp_layout = ""
-                rep_count = 1
-                for s in split:
-                    if s.isalpha():
-                        temp_layout = temp_layout + s
-                    else:
-                        rep_count = int(s)
-                for _ in range(rep_count):
-                    reload_layout = reload_layout + temp_layout + "_"
-            ctx.qkv_layout = reload_layout[:-1]
+        if NVTE_CPU_OFFLOAD_V1:
+            # If interleaved tensor is offloaded, reloaded tensor will be
+            # non-interleaved, so we need to modify the QKV layout
+            # for backward
+            if is_current_layer_offloaded() and is_cpu_offload_enabled():
+                reload_layout = ""
+                split_list = qkv_layout.split("_")
+                for split in split_list:
+                    temp_layout = ""
+                    rep_count = 1
+                    for s in split:
+                        if s.isalpha():
+                            temp_layout = temp_layout + s
+                        else:
+                            rep_count = int(s)
+                    for _ in range(rep_count):
+                        reload_layout = reload_layout + temp_layout + "_"
+                ctx.qkv_layout = reload_layout[:-1]
+            else:
+                ctx.qkv_layout = qkv_layout
         else:
             ctx.qkv_layout = qkv_layout
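The hunk above rewrites an interleaved QKV layout string into its non-interleaved form for the backward pass, since a tensor reloaded from CPU is no longer interleaved. A small self-contained sketch of the same string transformation, with example layouts assumed for illustration:

```python
def expand_qkv_layout(qkv_layout: str) -> str:
    """Expand an interleaved layout such as "sbh3d" into "sbhd_sbhd_sbhd"."""
    reload_layout = ""
    for split in qkv_layout.split("_"):
        temp_layout = ""
        rep_count = 1
        for s in split:
            if s.isalpha():
                temp_layout += s          # accumulate dimension letters
            else:
                rep_count = int(s)        # digit gives the interleave factor
        reload_layout += (temp_layout + "_") * rep_count
    return reload_layout[:-1]             # drop the trailing underscore

assert expand_qkv_layout("sbh3d") == "sbhd_sbhd_sbhd"
assert expand_qkv_layout("sbhd_sbh2d") == "sbhd_sbhd_sbhd"
```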
@@ -1494,14 +1494,6 @@ class DotProductAttention(TransformerEngineBaseModule):
                 fp8_output=fp8_output,
             )

-        from transformer_engine.pytorch.cpu_offload import CPUOffloadEnabled
-
-        if CPUOffloadEnabled:
-            warnings.warn(
-                "Attention activation Offloading is only implemented"
-                "with Flash Attention and Fused Attention!"
-            )
-
         if use_unfused_attention:
             allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1"
             if checkpoint_core_attention:
@@ -33,6 +33,8 @@ from transformer_engine.pytorch.attention.dot_product_attention import DotProduc
 from transformer_engine.pytorch.attention.inference import InferenceParams
 from transformer_engine.pytorch.attention.rope import apply_rotary_pos_emb
+from transformer_engine.pytorch.cpu_offload import start_offload, is_cpu_offload_enabled
+
 # Force DotProductAttention to use a different recipe than the fp8_recipe set in autocast().
 # Useful when GEMMs and attention use different recipes. Supported values are "DelayedScaling"
 # and "Float8CurrentScaling". Use other relevant variables here to define the recipe, e.g. fp8_dpa.

@@ -971,7 +973,8 @@ class MultiheadAttention(torch.nn.Module):
         # ===========================
         # Core attention computation
         # ===========================
+        if is_cpu_offload_enabled():
+            start_offload(query_layer, key_layer, value_layer, offload_base_tensor=True)
         context_layer = self.core_attention(
             query_layer,
             key_layer,
@@ -41,7 +41,7 @@ from ..cpp_extensions import (
 from ..constants import GemmParallelModes, dist_group_type
 from ..jit import no_torch_dynamo
 from ..graph import is_graph_capturing
-from ..cpu_offload import is_cpu_offload_enabled
+from ..cpu_offload import is_cpu_offload_enabled, mark_not_offload, start_offload
 from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
 from ..quantized_tensor import (

@@ -135,6 +135,9 @@ class _GroupedLinear(torch.autograd.Function):
         else:
             inputmats = torch.split(cast_if_needed(inp_view, activation_dtype), m_splits)

+        if cpu_offloading:
+            start_offload(*inputmats)
+
         # Initialize weights
         weights_fp8: list
         if fp8:

@@ -196,6 +199,9 @@ class _GroupedLinear(torch.autograd.Function):
             for i in range(num_gemms):
                 weight_quantizers[i].calibrate(weights[i])

+        if cpu_offloading:
+            mark_not_offload(*weights_fp8, *weights)
+
         if is_grad_enabled:
             ctx.weight_quantizers = weight_quantizers
             ctx.weights_shape_1 = weights[0].shape[1]
@@ -66,10 +66,15 @@ from ..quantized_tensor import (
 from ...debug.pytorch.debug_state import TEDebugState
 from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
+from ..cpu_offload import (
+    is_cpu_offload_enabled,
+    start_offload,
+    mark_not_offload,
+    mark_activation_offload,
+)
 from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
 from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
 from ..export import is_in_onnx_export_mode, assert_warmed_up
-from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload

 from ..cpp_extensions import (

@@ -158,6 +163,9 @@ class _LayerNormLinear(torch.autograd.Function):
             ln_bias = cast_if_needed(ln_bias, activation_dtype)
         nvtx_range_pop(f"{nvtx_label}.norm_input_cast")

+        if is_cpu_offload_enabled():
+            start_offload(inputmat)
+
         tp_world_size = get_distributed_world_size(tp_group)
         weight_requires_grad = weight.requires_grad

@@ -434,8 +442,14 @@ class _LayerNormLinear(torch.autograd.Function):
             nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")

         if cpu_offloading:
+            mark_not_offload(
+                weightmat,
+                weight,
+                bias,
+                ln_weight,
+                ln_bias,
+            )
             ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
             if ctx.grad_added_to_main_grad:
                 # If you are passing torch.nn.Parameter through the Torch hooks, you will
                 # get back torch.Tensor. Torch rips off the Parameter wrapper.

@@ -542,6 +556,7 @@ class _LayerNormLinear(torch.autograd.Function):
             mu,
             rsigma,
         ) = restore_from_saved(ctx.tensor_objects, saved_tensors)
+
         # Delete the references to tensor objects once they've been consumed
         # by the `restore_from_saved` method to construct back the actual tensors.
         ctx.tensor_objects = None
@@ -69,7 +69,12 @@ from ..tensor.mxfp8_tensor import MXFP8Quantizer
 from ..tensor.nvfp4_tensor import NVFP4Quantizer
 from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
 from ._common import apply_normalization, WeightGradStore
-from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
+from ..cpu_offload import (
+    is_cpu_offload_enabled,
+    start_offload,
+    mark_not_offload,
+    mark_activation_offload,
+)
 from ..quantized_tensor import (
     QuantizedTensorStorage,
     Quantizer,

@@ -235,6 +240,8 @@ class _LayerNormMLP(torch.autograd.Function):
             ln_weight = cast_if_needed(ln_weight, activation_dtype)
             if ln_bias is not None:
                 ln_bias = cast_if_needed(ln_bias, activation_dtype)
+        if is_cpu_offload_enabled():
+            start_offload(inputmat)

         tp_world_size = get_distributed_world_size(tp_group)
         backwards_needs_fc1_input = is_grad_enabled and fc1_weight.requires_grad

@@ -577,6 +584,18 @@ class _LayerNormMLP(torch.autograd.Function):
                 clear_tensor_data(act_out)
                 act_out = None

+        if cpu_offloading:
+            mark_not_offload(
+                ln_weight,
+                ln_bias,
+                fc1_weight_final,
+                fc1_weight,
+                fc1_bias,
+                fc2_weight_final,
+                fc2_weight,
+                fc2_bias,
+            )
+
         tensors_to_save, tensor_objects = prepare_for_saving(
             inputmat,
             ln_weight,
@@ -68,7 +68,12 @@ from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantize
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
 from ..tensor.utils import is_custom
 from ..export import is_in_onnx_export_mode, assert_warmed_up
-from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
+from ..cpu_offload import (
+    is_cpu_offload_enabled,
+    start_offload,
+    mark_not_offload,
+    mark_activation_offload,
+)
 from ...debug.pytorch.debug_state import TEDebugState

 __all__ = ["Linear"]

@@ -229,6 +234,9 @@ class _Linear(torch.autograd.Function):
             else:
                 inputmat = cast_if_needed(inp, activation_dtype)  # Cast for AMP
                 inputmat_total = inputmat
+
+            if is_cpu_offload_enabled():
+                start_offload(inputmat)
             nvtx_range_pop(f"{nvtx_label}.input_cast_comm")
             # ------------------------------------------------------
             # Input tensor is ready for GEMM...

@@ -417,6 +425,7 @@ class _Linear(torch.autograd.Function):
             # weights if weights are externally touched outside this module
             ctx.weight_object = weight

+            mark_not_offload(weight, weightmat, bias)
         # TODO(ksivamani): Check memory usage
         tensors_to_save, tensor_objects = prepare_for_saving(
             saved_inputmat,
@@ -372,9 +372,9 @@ class FusedAdam(torch.optim.Optimizer):
         """
         dtype = self.name_to_dtype_map[state_name]
         if store_param_remainders:
-            data = torch.zeros_like(param, dtype=torch.int16)
+            data = torch.zeros(param.shape, dtype=torch.int16, device=param.device)
         else:
-            data = torch.empty_like(param, dtype=dtype)
+            data = torch.empty(param.shape, dtype=dtype, device=param.device)
         if zero_buffer:
             data.zero_()
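The FusedAdam change above swaps `torch.zeros_like`/`torch.empty_like` for explicit `torch.zeros`/`torch.empty` calls. A likely motivation (an assumption, not stated in the diff) is that `*_like` factories dispatch through the parameter's tensor subclass, which now has its own `empty_like` handling, whereas optimizer state should always be a plain tensor. A sketch of the observable difference:

```python
import torch

param = torch.nn.Parameter(torch.randn(4, 4))

# For a plain tensor these produce equivalent buffers ...
a = torch.zeros_like(param, dtype=torch.int16)
b = torch.zeros(param.shape, dtype=torch.int16, device=param.device)
assert a.shape == b.shape and a.dtype == b.dtype and a.device == b.device

# ... but *_like ops route through the input's __torch_dispatch__, so for a wrapper
# subclass (e.g. a quantized tensor) they may return a subclass instance rather than
# the plain optimizer-state buffer that explicit torch.zeros/torch.empty always give.
```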
@@ -9,6 +9,7 @@ from typing import Optional, Tuple, Iterable, Any, Dict, Union
 import abc
 import copy
 import warnings
+import math

 import torch
 from torch.utils._pytree import tree_map

@@ -20,6 +21,11 @@ from transformer_engine.pytorch.tensor._quantization_helpers import (
     _stride_from_shape,
 )

+_quantized_tensor_cpu_supported_ops = (
+    torch.ops.aten.empty_like.default,
+    torch.ops.aten.copy_.default,
+)
+

 class QuantizedTensorStorage:
     r"""Base class for all *TensorStorage classes.

@@ -35,7 +41,7 @@ class QuantizedTensorStorage:
     XTensorStorage should contain all data members needed to
     implement the functionality of the tensor, while
     XTensor should only implement the functionality needed
-    to behave like regular torch.Tensor (liek __torch_dispatch__)."""
+    to behave like regular torch.Tensor (like __torch_dispatch__)."""

     _quantizer: Optional[Quantizer]

@@ -63,6 +69,12 @@ class QuantizedTensorStorage:
             f"{self.__class__.__name__} class does not implement update_usage function"
         )

+    def get_usages(self) -> Dict[str, bool]:
+        """Get the usage of the tensor"""
+        raise NotImplementedError(
+            f"{self.__class__.__name__} class does not implement get_usages function"
+        )
+
     def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]:
         """Prepare the tensor base for saving for backward"""
         raise NotImplementedError(

@@ -128,6 +140,7 @@ def prepare_for_saving(
             t, t_obj = tensor.prepare_for_saving()
             tensor_list.extend(t)
             tensor_objects_list.append(t_obj)
+
     return tensor_list, tensor_objects_list

@@ -314,6 +327,13 @@ class Quantizer(abc.ABC):
         """Returns whether or not given tensor can be quantized"""
         return True

+    def get_usages(self) -> Dict[str, bool]:
+        """Get the usage of the quantizer"""
+        return {
+            "rowwise": self.rowwise_usage,
+            "columnwise": self.columnwise_usage,
+        }
+

 class QuantizedTensor(torch.Tensor):
     """Abstract base class for tensor with quantized data

@@ -325,7 +345,14 @@ class QuantizedTensor(torch.Tensor):
     """

-    def __new__(cls, shape: Iterable[int], dtype: torch.dtype, *, requires_grad: bool = False):
+    def __new__(
+        cls,
+        shape: Iterable[int],
+        dtype: torch.dtype,
+        *,
+        requires_grad: bool = False,
+        device: Optional[torch.device] = None,
+    ):
         # We are assuming only contiguous tensors
         stride = _stride_from_shape(shape)
         instance = torch.Tensor._make_wrapper_subclass(

@@ -336,7 +363,7 @@ class QuantizedTensor(torch.Tensor):
             dtype=dtype,
             layout=torch.strided,
             requires_grad=requires_grad,
-            device=torch.cuda.current_device(),
+            device=torch.cuda.current_device() if device is None else device,
         )
         return instance

@@ -366,6 +393,9 @@ class QuantizedTensor(torch.Tensor):
     def clear(self):
         """Deallocate this tensor's memory. Typically not needed and must be used carefully"""
+        raise NotImplementedError(
+            f"{self.__class__.__name__} class does not implement clear function"
+        )

     def __repr__(self, *, tensor_contents=None) -> str:
         return f"{self.__class__.__name__}(data={self.dequantize(dtype=self.dtype)})"

@@ -407,6 +437,26 @@ class QuantizedTensor(torch.Tensor):
         if func == torch.ops.aten.copy_.default:
             dst = args[0]
             src = args[1]
+
+            if (
+                isinstance(dst, QuantizedTensor)
+                and isinstance(src, QuantizedTensor)
+                and type(dst._quantizer) is type(src._quantizer)
+                and set(src.get_usages().keys()) == set(dst.get_usages().keys())
+                and all(
+                    src.get_usages()[usage] == dst.get_usages()[usage]
+                    for usage in src.get_usages().keys()
+                )
+            ):
+                dst_tensors, dst_tensor_obj = dst.prepare_for_saving()
+                src_tensors, src_tensor_obj = src.prepare_for_saving()
+                for dst_tensor, src_tensor in zip(dst_tensors, src_tensors):
+                    if dst_tensor is not None:
+                        dst_tensor.copy_(src_tensor, *args[2:], **kwargs)
+                dst_tensor_obj.restore_from_saved(dst_tensors)
+                src_tensor_obj.restore_from_saved(src_tensors)
+                return None
+
             if isinstance(dst, QuantizedTensor):
                 dst.quantize_(src)
             else:

@@ -419,6 +469,36 @@ class QuantizedTensor(torch.Tensor):
         if func == torch.ops.aten.view.default:
             raise NotImplementedError("{cls.__name__} class does not support tensor views")

+        # Empty like op
+        if func == torch.ops.aten.empty_like.default:
+            tensor = args[0]
+            device = kwargs.get("device", tensor.device)
+            requires_grad = kwargs.get("requires_grad", tensor.requires_grad)
+            pin_memory = kwargs.get("pin_memory", False)
+            usage = tensor.get_usages()
+            quantizer_usage = tensor._quantizer.get_usages()
+            tensor._quantizer.set_usage(**usage)
+            out = tensor._quantizer.make_empty(
+                shape=tensor.shape,
+                dtype=tensor.dtype,
+                device=device,
+                requires_grad=requires_grad,
+                pin_memory=pin_memory,
+            )
+            tensor._quantizer.set_usage(**quantizer_usage)
+            return out
+
+        if func == torch.ops.aten.numel.default:
+            tensor = args[0]
+            return math.prod(tensor.size())
+
+        if func == torch.ops.aten.is_pinned.default:
+            tensor = args[0]
+            for t in tensor.get_data_tensors():
+                if t is not None:
+                    return func(t)
+            return False  # Or error out?
+
         def maybe_unwrap(arg):
             if isinstance(arg, QuantizedTensor):
                 return arg.dequantize(dtype=arg.dtype)

@@ -463,6 +543,16 @@ class QuantizedTensor(torch.Tensor):
     def __torch_function__(cls, func, types, args=(), kwargs=None):
         if kwargs is None:
             kwargs = {}
+
+        def check_if_cpu(arg):
+            if isinstance(cls, QuantizedTensor) and arg.device.type == "cpu":
+                assert (
+                    func in _quantized_tensor_cpu_supported_ops
+                ), f"QuantizedTensor on CPU does not support this operation: {func}"
+            return arg
+
+        args = tree_map(check_if_cpu, args)
+
         # Do not force the QuantizedTensor type on the returned tensor
         return torch._C._disabled_torch_function_impl(func, types, args, kwargs)
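Taken together, the `empty_like`, `copy_`, `numel`, and `is_pinned` handlers plus the `pin_memory` flag threaded through the quantizers' `make_empty` let quantized tensors be staged in pinned host memory and copied asynchronously. A rough sketch of that offload pattern under these assumptions (`offload_to_cpu`/`reload_to_gpu` are illustrative helpers, not the library's actual offload code):

```python
import torch

def offload_to_cpu(gpu_tensor: torch.Tensor, stream: torch.cuda.Stream) -> torch.Tensor:
    # empty_like on a QuantizedTensor now routes through its quantizer's
    # make_empty(..., pin_memory=True), so this works for plain and quantized tensors.
    cpu_copy = torch.empty_like(gpu_tensor, device="cpu", pin_memory=True)
    with torch.cuda.stream(stream):                      # device-to-host copy on a side stream
        cpu_copy.copy_(gpu_tensor, non_blocking=True)    # async only because the buffer is pinned
    return cpu_copy

def reload_to_gpu(cpu_tensor: torch.Tensor, stream: torch.cuda.Stream) -> torch.Tensor:
    gpu_copy = torch.empty_like(cpu_tensor, device="cuda")
    with torch.cuda.stream(stream):                      # host-to-device copy on a side stream
        gpu_copy.copy_(cpu_tensor, non_blocking=True)
    return gpu_copy
```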
@@ -214,6 +214,7 @@ class Float8BlockQuantizer(Quantizer):
         dtype: torch.dtype = torch.float32,
         device: Optional[torch.device] = None,
         requires_grad: bool = False,
+        pin_memory: bool = False,
     ) -> Float8BlockwiseQTensor:
         """Construct quantized tensor with uninitialized data"""
         if device is None:

@@ -229,12 +230,13 @@ class Float8BlockQuantizer(Quantizer):
         data = None
         scale_inv = None
         if self.rowwise_usage:
-            data = torch.empty(shape, dtype=torch.uint8, device=device)
+            data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory)
             scale_shape = self.get_scale_shape(shape, columnwise=False)
             scale_inv = torch.empty(
                 scale_shape,
                 dtype=torch.float32,
                 device=device,
+                pin_memory=pin_memory,
             )

         # Allocate FP8 data transpose if needed

@@ -242,13 +244,17 @@ class Float8BlockQuantizer(Quantizer):
         columnwise_scale_inv = None
         if self.columnwise_usage:
             columnwise_data = torch.empty(
-                self.get_columnwise_shape(shape), dtype=torch.uint8, device=device
+                self.get_columnwise_shape(shape),
+                dtype=torch.uint8,
+                device=device,
+                pin_memory=pin_memory,
             )
             columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True)
             columnwise_scale_inv = torch.empty(
                 columnwise_scale_shape,
                 dtype=torch.float32,
                 device=device,
+                pin_memory=pin_memory,
             )

         # Construct FP8 tensor
@@ -101,6 +101,7 @@ class Float8Quantizer(Quantizer):
         dtype: torch.dtype = torch.float32,
         device: Optional[torch.device] = None,
         requires_grad: bool = False,
+        pin_memory: bool = False,
     ) -> Float8Tensor:

         # Canonicalize tensor attributes

@@ -108,16 +109,19 @@ class Float8Quantizer(Quantizer):
             device = torch.device("cuda")

         # Allocate FP8 data
-        data = torch.empty(shape, dtype=torch.uint8, device=device)
+        data = None
+        if self.rowwise_usage:
+            data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory)

         # Allocate FP8 data transpose if needed
         data_transpose = None
         if self.columnwise_usage:
-            transpose_shape = [data.size(-1)] + list(data.shape[:-1])
+            transpose_shape = [shape[-1]] + list(shape[:-1])
             data_transpose = torch.empty(
                 transpose_shape,
                 dtype=torch.uint8,
                 device=device,
+                pin_memory=pin_memory,
             )

@@ -125,7 +129,7 @@ class Float8Quantizer(Quantizer):
             shape=shape,
             dtype=dtype,
             data=data,
-            fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device),
+            fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device, pin_memory=pin_memory),
             fp8_dtype=self.dtype,
             requires_grad=requires_grad,
             data_transpose=data_transpose,

@@ -287,6 +291,7 @@ class Float8CurrentScalingQuantizer(Quantizer):
         dtype: torch.dtype = torch.float32,
         device: Optional[torch.device] = None,
         requires_grad: bool = False,
+        pin_memory: bool = False,
     ) -> Float8Tensor:

         # Canonicalize tensor attributes

@@ -294,23 +299,26 @@ class Float8CurrentScalingQuantizer(Quantizer):
             device = torch.device("cuda")

         # Allocate FP8 data
-        data = torch.empty(shape, dtype=torch.uint8, device=device)
+        data = None
+        if self.rowwise_usage:
+            data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory)

         # Allocate FP8 data transpose if needed
         data_transpose = None
         if self.columnwise_usage:
-            transpose_shape = [data.size(-1)] + list(data.shape[:-1])
+            transpose_shape = [shape[-1]] + list(shape[:-1])
             data_transpose = torch.empty(
                 transpose_shape,
                 dtype=torch.uint8,
                 device=device,
+                pin_memory=pin_memory,
             )

         # Construct FP8 tensor
         return Float8Tensor(
             shape=shape,
             dtype=dtype,
             data=data,
-            fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device),
+            fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device, pin_memory=pin_memory),
             fp8_dtype=self.dtype,
             requires_grad=requires_grad,
             data_transpose=data_transpose,

@@ -715,14 +723,22 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
             return cls.detach(args[0])
         if func == torch.ops.aten.clone.default:
             return cls.clone(args[0])
         if func == torch.ops.aten.copy_.default:
             dst, src = args[0], args[1]
             # Just copy FP8 attrs if copying between Float8Tensors
             if isinstance(src, Float8Tensor) and isinstance(dst, Float8Tensor):
-                dst._data.copy_(src._data.detach())
-                dst._scale_inv.copy_(src._scale_inv.view(dst._scale_inv.size()))
-                if src._transpose is not None or dst._transpose is not None:
-                    dst._create_transpose()
+                if dst._data is not None:
+                    dst._data.copy_(src._data.detach(), *args[2:], **kwargs)
+                if dst._scale_inv is not None:
+                    dst._scale_inv.copy_(
+                        src._scale_inv.view(dst._scale_inv.size()), *args[2:], **kwargs
+                    )
+                if dst._transpose is not None and not dst._transpose_invalid:
+                    if not src._transpose_invalid:
+                        dst._transpose.copy_(src._transpose, *args[2:], **kwargs)
+                    else:
+                        dst._create_transpose()
                 return dst
         elif func in _ops_to_preserve_subclass_in_fsdp2:
             # Ops in the _ops_to_preserve_subclass_in_fsdp2 are recommened to return the same class instance to work fine with the torch fsdp2
@@ -90,6 +90,7 @@ class MXFP8Quantizer(Quantizer):
         dtype: torch.dtype = torch.float32,
         device: Optional[torch.device] = None,
         requires_grad: bool = False,
+        pin_memory: bool = False,
     ) -> MXFP8Tensor:

         # Canonicalize tensor attributes

@@ -105,24 +106,29 @@ class MXFP8Quantizer(Quantizer):
             )

         # Allocate FP8 data
-        data = torch.empty(shape, dtype=torch.uint8, device=device)
-        scale_inv = torch.empty(
-            round_up_to_nearest_multiple(math.prod(shape[:-1]), 128),
-            round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4),
-            dtype=torch.uint8,
-            device=device,
-        )
+        data = None
+        scale_inv = None
+        if self.rowwise_usage:
+            data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory)
+            scale_inv = torch.empty(
+                round_up_to_nearest_multiple(math.prod(shape[:-1]), 128),
+                round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4),
+                dtype=torch.uint8,
+                device=device,
+                pin_memory=pin_memory,
+            )

         # Allocate FP8 data transpose if needed
         columnwise_data = None
         columnwise_scale_inv = None
         if self.columnwise_usage:
-            columnwise_data = torch.empty_like(data)
+            columnwise_data = torch.empty_like(data, pin_memory=pin_memory)
             columnwise_scale_inv = torch.empty(
                 round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4),
                 round_up_to_nearest_multiple(shape[-1], 128),
                 dtype=torch.uint8,
                 device=device,
+                pin_memory=pin_memory,
             )

         # Construct FP8 tensor

@@ -348,11 +354,17 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
                 )
                 if rowwise_matches and columnwise_matches:
                     if dst._rowwise_data is not None:
-                        dst._rowwise_data.copy_(src._rowwise_data.detach())
-                        dst._rowwise_scale_inv.copy_(src._rowwise_scale_inv.detach())
+                        dst._rowwise_data.copy_(src._rowwise_data.detach(), *args[2:], **kwargs)
+                        dst._rowwise_scale_inv.copy_(
+                            src._rowwise_scale_inv.detach(), *args[2:], **kwargs
+                        )
                     if dst._columnwise_data is not None:
-                        dst._columnwise_data.copy_(src._columnwise_data.detach())
-                        dst._columnwise_scale_inv.copy_(src._columnwise_scale_inv.detach())
+                        dst._columnwise_data.copy_(
+                            src._columnwise_data.detach(), *args[2:], **kwargs
+                        )
+                        dst._columnwise_scale_inv.copy_(
+                            src._columnwise_scale_inv.detach(), *args[2:], **kwargs
+                        )
                     return dst

     # FSDP2 related functions.
@@ -6,7 +6,7 @@
 from __future__ import annotations
 from collections.abc import Iterable
 import math
-from typing import Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union
 import functools

 import torch

@@ -265,6 +265,7 @@ class NVFP4Quantizer(Quantizer):
         *,
         dtype: torch.dtype = torch.float32,
         device: Optional[torch.device] = None,
+        pin_memory: bool = False,
         requires_grad: bool = False,
     ) -> NVFP4Tensor:

@@ -288,11 +289,18 @@ class NVFP4Quantizer(Quantizer):
         scale_inv = None
         amax_rowwise = None
         if self.rowwise_usage:
-            data = torch.empty(self.convert_shape_for_fp4(shape), dtype=torch.uint8, device=device)
+            data = torch.empty(
+                self.convert_shape_for_fp4(shape),
+                dtype=torch.uint8,
+                device=device,
+                pin_memory=pin_memory,
+            )
             scale_shape = self.get_scale_shape(shape, columnwise=False)
-            scale_inv = torch.empty(scale_shape, dtype=torch.uint8, device=device)
+            scale_inv = torch.empty(
+                scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory
+            )
             # Allocate per tensor scale inverse. FP32 format.
-            amax_rowwise = torch.zeros(1, dtype=torch.float32, device=device)
+            amax_rowwise = torch.zeros(1, dtype=torch.float32, device=device, pin_memory=pin_memory)

         # Allocate FP8 data transpose if needed
         columnwise_data = None

@@ -306,12 +314,15 @@ class NVFP4Quantizer(Quantizer):
                 self.convert_shape_for_fp4(self.get_columnwise_shape(shape_2d)),
                 dtype=torch.uint8,
                 device=device,
+                pin_memory=pin_memory,
             )
             columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True)
             columnwise_scale_inv = torch.empty(
-                columnwise_scale_shape, dtype=torch.uint8, device=device
+                columnwise_scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory
             )
-            amax_columnwise = torch.zeros(1, dtype=torch.float32, device=device)
+            amax_columnwise = torch.zeros(
+                1, dtype=torch.float32, device=device, pin_memory=pin_memory
+            )

         # Construct FP8 tensor
         return NVFP4Tensor(

@@ -498,6 +509,12 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
                 return self
         raise ValueError("NVFP4Tensor does not support different memory formats!")

+    def get_usages(self) -> Dict[str, bool]:
+        return {
+            "rowwise": self._rowwise_data is not None,
+            "columnwise": self._columnwise_data is not None,
+        }
+
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):

@@ -520,16 +537,20 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
             )

             if tensor._rowwise_data is not None:
-                rowwise_data = data_init_func(tensor._rowwise_data)
-                rowwise_scale_inv = scale_inv_init_func(tensor._rowwise_scale_inv)
-                amax_rowwise = torch.zeros_like(tensor._amax_rowwise)
+                rowwise_data = data_init_func(tensor._rowwise_data, *args[1:], **kwargs)
+                rowwise_scale_inv = scale_inv_init_func(
+                    tensor._rowwise_scale_inv, *args[1:], **kwargs
+                )
+                amax_rowwise = torch.zeros_like(tensor._amax_rowwise, *args[1:], **kwargs)
             else:
                 rowwise_data, rowwise_scale_inv, amax_rowwise = None, None, None

             if tensor._columnwise_data is not None:
-                columnwise_data = data_init_func(tensor._columnwise_data)
-                columnwise_scale_inv = scale_inv_init_func(tensor._columnwise_scale_inv)
-                amax_columnwise = torch.zeros_like(tensor._amax_columnwise)
+                columnwise_data = data_init_func(tensor._columnwise_data, *args[1:], **kwargs)
+                columnwise_scale_inv = scale_inv_init_func(
+                    tensor._columnwise_scale_inv, *args[1:], **kwargs
+                )
+                amax_columnwise = torch.zeros_like(tensor._amax_columnwise, *args[1:], **kwargs)
             else:
                 columnwise_data, columnwise_scale_inv, amax_columnwise = (
                     None,
@@ -420,3 +420,10 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
             return

         return
+
+    def get_usages(self) -> Dict[str, bool]:
+        """Get the usage of the tensor"""
+        return {
+            "rowwise": self._rowwise_data is not None,
+            "columnwise": self._columnwise_data is not None,
+        }
@@ -225,3 +225,12 @@ class Float8TensorStorage(QuantizedTensorStorage):
         if not needs_data_transpose:
             self._transpose = None
             self._transpose_invalid = True
+
+    def get_usages(self) -> Dict[str, bool]:
+        """Get the usage of the tensor"""
+        usages = {"rowwise": self._data is not None}
+        if is_non_tn_fp8_gemm_supported():
+            usages["columnwise"] = self._data is not None
+        else:
+            usages["columnwise"] = self._transpose is not None and not self._transpose_invalid
+        return usages
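A sketch of how these `get_usages` reports are consumed: a raw buffer-by-buffer copy between two quantized tensors is only attempted when the quantizer types and usages match, mirroring the `aten.copy_` fast path added in quantized_tensor.py above (the helper name is hypothetical; `dst` and `src` are assumed to be quantized tensors of the same class):

```python
def can_copy_raw(dst, src) -> bool:
    """Raw per-buffer copy is only safe when both tensors hold the same set of buffers."""
    if type(dst._quantizer) is not type(src._quantizer):
        return False
    dst_usages, src_usages = dst.get_usages(), src.get_usages()
    return dst_usages.keys() == src_usages.keys() and all(
        dst_usages[k] == src_usages[k] for k in src_usages
    )

# When the usages do not match, the slower fallback is to requantize: dst.quantize_(src).
```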