Unverified Commit c5257605 authored by Paweł Gadziński, committed by GitHub

[PyTorch] Activation offloading refactor (#1762)
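
For orientation, a minimal sketch of how the refactored CPU offloading is driven from user code, mirroring the te.get_cpu_offload_context call exercised by the test added in this PR; the module list and sizes below are illustrative, not part of the change:

import torch
import transformer_engine.pytorch as te

# Build a few TE layers; activations saved for backward are offloaded to host memory.
layers = [te.Linear(512, 512, params_dtype=torch.bfloat16) for _ in range(3)]
offload_ctx, sync_fn = te.get_cpu_offload_context(
    enabled=True,
    num_layers=len(layers),        # layers whose activations are offloaded
    model_layers=len(layers) + 1,  # total layers wrapped by the context
    offload_activations=True,
    offload_weights=False,
)

x = torch.randn(128, 512, 512, dtype=torch.bfloat16, device="cuda", requires_grad=True)
for layer in layers:
    with offload_ctx:
        x = layer(x)
    x = sync_fn(x)  # synchronizes the offload of the previous layer's activations
x.sum().backward()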



* init
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* offloading
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* all types
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* typo
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* init
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* api change
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* code drop
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* refactor
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* tests
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* code drop
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* example
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* cpu offload + debug warning
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* test
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* change empty_like implementation to use make_like
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* main_grad fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* manual synchronization
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* old path
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* remove example
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* api changes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* reverted grouped linear
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* make old code path work for modules
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* attention old code path
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* legacy tests
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* legacy tests
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* updated code path
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update transformer_engine/pytorch/tensor/quantized_tensor.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* nvfp4 support
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update tests/pytorch/test_cpu_offloading.py
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>

* small fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* docs change
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: root <root@ptyche0312.ptyche.clusters.nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
parent a0754757
......@@ -42,7 +42,8 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
NVTE_FLASH_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading_v1.xml $TE_PATH/tests/pytorch/test_cpu_offloading_v1.py || test_fail "test_cpu_offloading_v1.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import contextlib
import gc
import os
from typing import Iterable, Optional
import pytest
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
from utils import ModelConfig, get_available_attention_backends
# Check supported quantization schemes
fp8_available = te.is_fp8_available()
mxfp8_available = te.is_mxfp8_available()
quantization_recipes: list[Optional[recipe.Recipe]] = [None]
if fp8_available:
quantization_recipes.extend((recipe.Float8CurrentScaling(), recipe.DelayedScaling()))
model_config = {
"small": ModelConfig(8, 512, 8, 64, num_layers=5, eps=0.1),
}
SIZE = model_config["small"].hidden_size
NUM_HEADS = model_config["small"].num_heads
NUM_LAYERS = model_config["small"].num_layers
EPSILON = model_config["small"].eps
# Flash attention saves some internal tensor for the backward pass
# that cannot be offloaded to CPU.
assert os.getenv("NVTE_FLASH_ATTN") == "0"
# CPU offload v1 code path is enabled
assert os.environ.get("NVTE_CPU_OFFLOAD_V1", "0") == "1"
# Attention activation offloading is only supported with the fused and flash
# attention backends, which is why bfloat16 inputs are used here.
#
# For the TransformerLayer, activation offloading with dropout is not supported,
# so we set hidden_dropout to 0.0.
model_types = {
"linear": lambda: te.Linear(SIZE, SIZE, params_dtype=torch.bfloat16),
"layernorm_mlp": lambda: te.LayerNormMLP(SIZE, SIZE, params_dtype=torch.bfloat16),
"layernorm_linear": lambda: te.LayerNormLinear(SIZE, SIZE, params_dtype=torch.bfloat16),
"multihead_attention": lambda: te.MultiheadAttention(
SIZE, NUM_HEADS, params_dtype=torch.bfloat16
),
"transformer_layer": lambda: te.TransformerLayer(
SIZE, SIZE, NUM_HEADS, params_dtype=torch.bfloat16, hidden_dropout=0.0
),
"linear_op": lambda: te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
"layernorm_mlp_ops": lambda: te.ops.Sequential(
te.ops.LayerNorm(SIZE, dtype=torch.bfloat16),
te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
te.ops.GELU(),
te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
),
}
def _make_input() -> torch.Tensor:
"""Generate random input tensor."""
return torch.randn(
(128, SIZE, SIZE),
dtype=torch.bfloat16,
device="cuda",
requires_grad=True,
)
def _warmup_model(
modules: Iterable[torch.nn.Module],
quantization_recipe: Optional[recipe.Recipe],
) -> None:
"""Perform forward and backward pass"""
tensor = _make_input()
for module in modules:
with te.autocast(
enabled=quantization_recipe is not None,
recipe=quantization_recipe,
):
tensor = module(tensor)
tensor.sum().backward()
def _estimate_cached_weight_size(
model_name: str,
modules: Iterable[torch.nn.Module],
quantization_recipe: Optional[recipe.Recipe],
) -> float:
"""Calculate the memory (in MiB) needed for weight caching."""
# The weight params are cached directly for unquantized compute
if quantization_recipe is None:
return 0
# Count number of weight param elements
param_elements = 0
for module in modules:
for param in module.parameters():
if param.dim() == 2:
param_elements += param.numel()
# FP8 tensor-scaling caches one byte per element
if quantization_recipe.delayed() or quantization_recipe.float8_current_scaling():
if not is_non_tn_fp8_gemm_supported() and model_name not in (
"linear_op",
"layernorm_mlp_ops",
):
# Modules do not deallocate FP8 transpose for weights
return 2 * param_elements / 1024**2
return param_elements / 1024**2
# MXFP8 caches one data byte per element and one scale byte per 32
# elements
if quantization_recipe.mxfp8():
if model_name not in ("linear_op", "layernorm_mlp_ops"):
# Modules do not deallocate column-wise MXFP8 data for weights
return 2 * param_elements * (1 + 1 / 32) / 1024**2
return param_elements * (1 + 1 / 32) / 1024**2
raise NotImplementedError(f"Unrecognized recipe ({quantization_recipe})")
def _measure_cached_memory(
modules: Iterable[torch.nn.Module],
quantization_recipe: Optional[recipe.Recipe],
cpu_offload: bool,
) -> float:
"""Measure the growth in allocated GPU memory in MiB after a model forward pass.
Memory measurement excludes the input and output tensors.
"""
# Reset memory
gc.collect()
torch.cuda.empty_cache()
# Context and sync function for CPU offloading
if cpu_offload:
offload_context, sync_function = te.get_cpu_offload_context(
enabled=True,
num_layers=len(modules),
model_layers=len(modules) + 1,
offload_activations=True,
offload_weights=False,
)
else:
offload_context = contextlib.nullcontext()
sync_function = lambda x: x
# Forward pass, with dummy step to trigger offload for last module
inp = _make_input()
tensor = inp
memory_before_forward = torch.cuda.memory_allocated() / (1024**2)
for module in modules:
with te.autocast(
enabled=quantization_recipe is not None, recipe=quantization_recipe
), offload_context:
tensor = module(tensor)
tensor = sync_function(tensor)
with offload_context:
tensor = tensor.clone()
tensor = sync_function(tensor)
memory_after_forward = (torch.cuda.memory_allocated() - tensor.nbytes) / (1024**2)
# Backward pass
tensor.sum().backward()
torch.cuda.synchronize()
# Memory usage in MiB
return memory_after_forward - memory_before_forward
@pytest.mark.parametrize("quantization_recipe", quantization_recipes)
@pytest.mark.parametrize("model_name", model_types.keys())
def test_cpu_offload(quantization_recipe: Optional[recipe.Recipe], model_name: str) -> None:
"""Check that CPU offloading runs and has expected memory usage."""
# Construct model
modules_list = [model_types[model_name]() for _ in range(NUM_LAYERS)]
if model_name in ["multihead_attention", "transformer_layer"]:
available_backends, *_ = get_available_attention_backends(
model_config["small"],
qkv_dtype=torch.bfloat16,
qkv_layout="sbhd_sbhd_sbhd",
)
_, fused_attn_supported, _ = available_backends
if not fused_attn_supported:
pytest.skip("Fused attention backend not available.")
os.environ["NVTE_FLASH_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
# Warmup
_warmup_model(modules_list, quantization_recipe)
# Measure cached memory after forward pass
memory_without_offload = _measure_cached_memory(modules_list, quantization_recipe, False)
memory_with_offload = _measure_cached_memory(modules_list, quantization_recipe, True)
# Check for expected memory usage
assert memory_with_offload < memory_without_offload
memory_from_cached_weights = _estimate_cached_weight_size(
model_name,
modules_list,
quantization_recipe,
)
assert abs(memory_with_offload - memory_from_cached_weights) < EPSILON
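# Usage note (restating the environment assertions at the top of this file):
# run with NVTE_FLASH_ATTN=0 and NVTE_CPU_OFFLOAD_V1=1 set, e.g.
#   NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 python -m pytest <path to this file>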
......@@ -50,6 +50,13 @@ from transformer_engine.pytorch.attention.dot_product_attention.context_parallel
)
from transformer_engine.pytorch.attention.dot_product_attention.softmax import FusedScaleMaskSoftmax
from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.cpu_offload import (
is_cpu_offload_enabled,
start_offload,
mark_activation_offload,
NVTE_CPU_OFFLOAD_V1,
)
from transformer_engine.pytorch.cpu_offload_v1 import is_current_layer_offloaded
# Import attention utils
import transformer_engine.pytorch.attention.dot_product_attention.utils as dpa_utils
......@@ -737,6 +744,9 @@ class FlashAttention(torch.nn.Module):
x.contiguous() for x in (query_layer._data, key_layer._data, value_layer._data)
]
if is_cpu_offload_enabled():
start_offload(query_layer, key_layer, value_layer, offload_base_tensor=True)
# get batch_size, max_seqlen and cu_seqlens
batch_size, context_len = None, None
if inference_params is None:
......@@ -877,12 +887,7 @@ class FlashAttention(torch.nn.Module):
fp8_output=fp8_output,
)
else:
from transformer_engine.pytorch.cpu_offload import (
CPUOffloadEnabled,
mark_activation_offload,
)
if CPUOffloadEnabled:
if is_cpu_offload_enabled():
mark_activation_offload(
query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv
)
......@@ -1116,6 +1121,9 @@ class FusedAttnFunc(torch.autograd.Function):
nvtx_label = "transformer_engine.FusedAttnFunc.forward"
nvtx_range_push(f"{nvtx_label}")
if is_cpu_offload_enabled():
start_offload(q, k, v, offload_base_tensor=True)
# recipe passed in through autocast or set by NVTE_DPA_FP8_RECIPE;
# may be different from fp8_meta["recipe"]
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
......@@ -1293,12 +1301,7 @@ class FusedAttnFunc(torch.autograd.Function):
# used when some tensors are base tensors and lose the "dtype" attribute
ctx.nominal_dtype = out_nominal_dtype
from transformer_engine.pytorch.cpu_offload import (
CPUOffloadEnabled,
mark_activation_offload,
)
if CPUOffloadEnabled:
if is_cpu_offload_enabled() and NVTE_CPU_OFFLOAD_V1:
if ctx.fp8:
tensor_list = fp8_tensors
else:
......@@ -1309,6 +1312,7 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.is_input_fp8 = is_input_fp8
ctx.is_output_fp8 = is_output_fp8
tensors_to_save, tensor_objects = prepare_for_saving(
*fp8_tensors,
*qkvo_tensors,
......@@ -1339,27 +1343,26 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.dropout_p = dropout_p
ctx.fast_zero_fill = fast_zero_fill
from transformer_engine.pytorch.cpu_offload import (
CPUOffloadedLayer,
)
# If interleaved tensor is offloaded, reloaded tensor will be
# non-interleaved, so we need to modify the QKV layout
# for backward
if CPUOffloadedLayer and CPUOffloadEnabled:
reload_layout = ""
split_list = qkv_layout.split("_")
for split in split_list:
temp_layout = ""
rep_count = 1
for s in split:
if s.isalpha():
temp_layout = temp_layout + s
else:
rep_count = int(s)
for _ in range(rep_count):
reload_layout = reload_layout + temp_layout + "_"
ctx.qkv_layout = reload_layout[:-1]
if NVTE_CPU_OFFLOAD_V1:
# If interleaved tensor is offloaded, reloaded tensor will be
# non-interleaved, so we need to modify the QKV layout
# for backward
if is_current_layer_offloaded() and is_cpu_offload_enabled():
reload_layout = ""
split_list = qkv_layout.split("_")
for split in split_list:
temp_layout = ""
rep_count = 1
for s in split:
if s.isalpha():
temp_layout = temp_layout + s
else:
rep_count = int(s)
for _ in range(rep_count):
reload_layout = reload_layout + temp_layout + "_"
ctx.qkv_layout = reload_layout[:-1]
else:
ctx.qkv_layout = qkv_layout
else:
ctx.qkv_layout = qkv_layout
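# Illustrative expansion performed by the loop above (layout strings assumed from
# the existing qkv_layout conventions): an interleaved layout such as "sbh3d" is
# rewritten to "sbhd_sbhd_sbhd", and "sbhd_sb2hd" becomes "sbhd_sbhd_sbhd", since
# the reloaded Q, K and V copies come back as separate, non-interleaved tensors.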
......
......@@ -1494,14 +1494,6 @@ class DotProductAttention(TransformerEngineBaseModule):
fp8_output=fp8_output,
)
from transformer_engine.pytorch.cpu_offload import CPUOffloadEnabled
if CPUOffloadEnabled:
warnings.warn(
"Attention activation Offloading is only implemented"
"with Flash Attention and Fused Attention!"
)
if use_unfused_attention:
allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1"
if checkpoint_core_attention:
......
......@@ -33,6 +33,8 @@ from transformer_engine.pytorch.attention.dot_product_attention import DotProduc
from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.attention.rope import apply_rotary_pos_emb
from transformer_engine.pytorch.cpu_offload import start_offload, is_cpu_offload_enabled
# Force DotProductAttention to use a different recipe than the fp8_recipe set in autocast().
# Useful when GEMMs and attention use different recipes. Supported values are "DelayedScaling"
# and "Float8CurrentScaling". Use other relevant variables here to define the recipe, e.g. fp8_dpa.
......@@ -971,7 +973,8 @@ class MultiheadAttention(torch.nn.Module):
# ===========================
# Core attention computation
# ===========================
if is_cpu_offload_enabled():
start_offload(query_layer, key_layer, value_layer, offload_base_tensor=True)
context_layer = self.core_attention(
query_layer,
key_layer,
......
......@@ -41,7 +41,7 @@ from ..cpp_extensions import (
from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ..cpu_offload import is_cpu_offload_enabled
from ..cpu_offload import is_cpu_offload_enabled, mark_not_offload, start_offload
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..quantized_tensor import (
......@@ -135,6 +135,9 @@ class _GroupedLinear(torch.autograd.Function):
else:
inputmats = torch.split(cast_if_needed(inp_view, activation_dtype), m_splits)
if cpu_offloading:
start_offload(*inputmats)
# Initialize weights
weights_fp8: list
if fp8:
......@@ -196,6 +199,9 @@ class _GroupedLinear(torch.autograd.Function):
for i in range(num_gemms):
weight_quantizers[i].calibrate(weights[i])
if cpu_offloading:
mark_not_offload(*weights_fp8, *weights)
if is_grad_enabled:
ctx.weight_quantizers = weight_quantizers
ctx.weights_shape_1 = weights[0].shape[1]
......
......@@ -66,10 +66,15 @@ from ..quantized_tensor import (
from ...debug.pytorch.debug_state import TEDebugState
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..cpu_offload import (
is_cpu_offload_enabled,
start_offload,
mark_not_offload,
mark_activation_offload,
)
from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
from ..export import is_in_onnx_export_mode, assert_warmed_up
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ..cpp_extensions import (
general_gemm,
......@@ -158,6 +163,9 @@ class _LayerNormLinear(torch.autograd.Function):
ln_bias = cast_if_needed(ln_bias, activation_dtype)
nvtx_range_pop(f"{nvtx_label}.norm_input_cast")
if is_cpu_offload_enabled():
start_offload(inputmat)
tp_world_size = get_distributed_world_size(tp_group)
weight_requires_grad = weight.requires_grad
......@@ -434,8 +442,14 @@ class _LayerNormLinear(torch.autograd.Function):
nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")
if cpu_offloading:
mark_not_offload(
weightmat,
weight,
bias,
ln_weight,
ln_bias,
)
ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
if ctx.grad_added_to_main_grad:
# If you are passing torch.nn.Parameter through the Torch hooks, you will
# get back torch.Tensor. Torch rips off the Parameter wrapper.
......@@ -542,6 +556,7 @@ class _LayerNormLinear(torch.autograd.Function):
mu,
rsigma,
) = restore_from_saved(ctx.tensor_objects, saved_tensors)
# Delete the references to tensor objects once they've been consumed
# by the `restore_from_saved` method to construct back the actual tensors.
ctx.tensor_objects = None
......
......@@ -69,7 +69,12 @@ from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.nvfp4_tensor import NVFP4Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ._common import apply_normalization, WeightGradStore
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ..cpu_offload import (
is_cpu_offload_enabled,
start_offload,
mark_not_offload,
mark_activation_offload,
)
from ..quantized_tensor import (
QuantizedTensorStorage,
Quantizer,
......@@ -235,6 +240,8 @@ class _LayerNormMLP(torch.autograd.Function):
ln_weight = cast_if_needed(ln_weight, activation_dtype)
if ln_bias is not None:
ln_bias = cast_if_needed(ln_bias, activation_dtype)
if is_cpu_offload_enabled():
start_offload(inputmat)
tp_world_size = get_distributed_world_size(tp_group)
backwards_needs_fc1_input = is_grad_enabled and fc1_weight.requires_grad
......@@ -577,6 +584,18 @@ class _LayerNormMLP(torch.autograd.Function):
clear_tensor_data(act_out)
act_out = None
if cpu_offloading:
mark_not_offload(
ln_weight,
ln_bias,
fc1_weight_final,
fc1_weight,
fc1_bias,
fc2_weight_final,
fc2_weight,
fc2_bias,
)
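# (assumed semantics) mark_not_offload excludes the weights, biases and norm
# parameters listed above from offloading, so only activation tensors saved for
# backward are candidates for the host copy.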
tensors_to_save, tensor_objects = prepare_for_saving(
inputmat,
ln_weight,
......
......@@ -68,7 +68,12 @@ from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantize
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.utils import is_custom
from ..export import is_in_onnx_export_mode, assert_warmed_up
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ..cpu_offload import (
is_cpu_offload_enabled,
start_offload,
mark_not_offload,
mark_activation_offload,
)
from ...debug.pytorch.debug_state import TEDebugState
__all__ = ["Linear"]
......@@ -229,6 +234,9 @@ class _Linear(torch.autograd.Function):
else:
inputmat = cast_if_needed(inp, activation_dtype) # Cast for AMP
inputmat_total = inputmat
if is_cpu_offload_enabled():
start_offload(inputmat)
nvtx_range_pop(f"{nvtx_label}.input_cast_comm")
# ------------------------------------------------------
# Input tensor is ready for GEMM...
......@@ -417,6 +425,7 @@ class _Linear(torch.autograd.Function):
# weights if weights are externally touched outside this module
ctx.weight_object = weight
mark_not_offload(weight, weightmat, bias)
# TODO(ksivamani): Check memory usage
tensors_to_save, tensor_objects = prepare_for_saving(
saved_inputmat,
......
......@@ -372,9 +372,9 @@ class FusedAdam(torch.optim.Optimizer):
"""
dtype = self.name_to_dtype_map[state_name]
if store_param_remainders:
data = torch.zeros_like(param, dtype=torch.int16)
data = torch.zeros(param.shape, dtype=torch.int16, device=param.device)
else:
data = torch.empty_like(param, dtype=dtype)
data = torch.empty(param.shape, dtype=dtype, device=param.device)
if zero_buffer:
data.zero_()
......
......@@ -9,6 +9,7 @@ from typing import Optional, Tuple, Iterable, Any, Dict, Union
import abc
import copy
import warnings
import math
import torch
from torch.utils._pytree import tree_map
......@@ -20,6 +21,11 @@ from transformer_engine.pytorch.tensor._quantization_helpers import (
_stride_from_shape,
)
_quantized_tensor_cpu_supported_ops = (
torch.ops.aten.empty_like.default,
torch.ops.aten.copy_.default,
)
class QuantizedTensorStorage:
r"""Base class for all *TensorStorage classes.
......@@ -35,7 +41,7 @@ class QuantizedTensorStorage:
XTensorStorage should contain all data members needed to
implement the functionality of the tensor, while
XTensor should only implement the functionality needed
to behave like regular torch.Tensor (liek __torch_dispatch__)."""
to behave like regular torch.Tensor (like __torch_dispatch__)."""
_quantizer: Optional[Quantizer]
......@@ -63,6 +69,12 @@ class QuantizedTensorStorage:
f"{self.__class__.__name__} class does not implement update_usage function"
)
def get_usages(self) -> Dict[str, bool]:
"""Get the usage of the tensor"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement get_usages function"
)
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]:
"""Prepare the tensor base for saving for backward"""
raise NotImplementedError(
......@@ -128,6 +140,7 @@ def prepare_for_saving(
t, t_obj = tensor.prepare_for_saving()
tensor_list.extend(t)
tensor_objects_list.append(t_obj)
return tensor_list, tensor_objects_list
......@@ -314,6 +327,13 @@ class Quantizer(abc.ABC):
"""Returns whether or not given tensor can be quantized"""
return True
def get_usages(self) -> Dict[str, bool]:
"""Get the usage of the quantizer"""
return {
"rowwise": self.rowwise_usage,
"columnwise": self.columnwise_usage,
}
class QuantizedTensor(torch.Tensor):
"""Abstract base class for tensor with quantized data
......@@ -325,7 +345,14 @@ class QuantizedTensor(torch.Tensor):
"""
def __new__(cls, shape: Iterable[int], dtype: torch.dtype, *, requires_grad: bool = False):
def __new__(
cls,
shape: Iterable[int],
dtype: torch.dtype,
*,
requires_grad: bool = False,
device: Optional[torch.device] = None,
):
# We are assuming only contiguous tensors
stride = _stride_from_shape(shape)
instance = torch.Tensor._make_wrapper_subclass(
......@@ -336,7 +363,7 @@ class QuantizedTensor(torch.Tensor):
dtype=dtype,
layout=torch.strided,
requires_grad=requires_grad,
device=torch.cuda.current_device(),
device=torch.cuda.current_device() if device is None else device,
)
return instance
......@@ -366,6 +393,9 @@ class QuantizedTensor(torch.Tensor):
def clear(self):
"""Deallocate this tensor's memory. Typically not needed and must be used carefully"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement clear function"
)
def __repr__(self, *, tensor_contents=None) -> str:
return f"{self.__class__.__name__}(data={self.dequantize(dtype=self.dtype)})"
......@@ -407,6 +437,26 @@ class QuantizedTensor(torch.Tensor):
if func == torch.ops.aten.copy_.default:
dst = args[0]
src = args[1]
if (
isinstance(dst, QuantizedTensor)
and isinstance(src, QuantizedTensor)
and type(dst._quantizer) is type(src._quantizer)
and set(src.get_usages().keys()) == set(dst.get_usages().keys())
and all(
src.get_usages()[usage] == dst.get_usages()[usage]
for usage in src.get_usages().keys()
)
):
dst_tensors, dst_tensor_obj = dst.prepare_for_saving()
src_tensors, src_tensor_obj = src.prepare_for_saving()
for dst_tensor, src_tensor in zip(dst_tensors, src_tensors):
if dst_tensor is not None:
dst_tensor.copy_(src_tensor, *args[2:], **kwargs)
dst_tensor_obj.restore_from_saved(dst_tensors)
src_tensor_obj.restore_from_saved(src_tensors)
return None
if isinstance(dst, QuantizedTensor):
dst.quantize_(src)
else:
......@@ -419,6 +469,36 @@ class QuantizedTensor(torch.Tensor):
if func == torch.ops.aten.view.default:
raise NotImplementedError("{cls.__name__} class does not support tensor views")
# Empty like op
if func == torch.ops.aten.empty_like.default:
tensor = args[0]
device = kwargs.get("device", tensor.device)
requires_grad = kwargs.get("requires_grad", tensor.requires_grad)
pin_memory = kwargs.get("pin_memory", False)
usage = tensor.get_usages()
quantizer_usage = tensor._quantizer.get_usages()
tensor._quantizer.set_usage(**usage)
out = tensor._quantizer.make_empty(
shape=tensor.shape,
dtype=tensor.dtype,
device=device,
requires_grad=requires_grad,
pin_memory=pin_memory,
)
tensor._quantizer.set_usage(**quantizer_usage)
return out
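# Illustrative offload round trip using the empty_like and copy_ handlers above
# (assumed call pattern, not taken from this diff):
#   cpu_copy = torch.empty_like(quantized_tensor, device="cpu", pin_memory=True)
#   cpu_copy.copy_(quantized_tensor, non_blocking=True)   # device-to-host copy
#   quantized_tensor.copy_(cpu_copy, non_blocking=True)   # host-to-device reload
# Both aten ops are listed in _quantized_tensor_cpu_supported_ops, so they remain
# allowed on CPU-resident QuantizedTensors.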
if func == torch.ops.aten.numel.default:
tensor = args[0]
return math.prod(tensor.size())
if func == torch.ops.aten.is_pinned.default:
tensor = args[0]
for t in tensor.get_data_tensors():
if t is not None:
return func(t)
return False # Or error out?
def maybe_unwrap(arg):
if isinstance(arg, QuantizedTensor):
return arg.dequantize(dtype=arg.dtype)
......@@ -463,6 +543,16 @@ class QuantizedTensor(torch.Tensor):
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
def check_if_cpu(arg):
if isinstance(arg, QuantizedTensor) and arg.device.type == "cpu":
assert (
func in _quantized_tensor_cpu_supported_ops
), f"QuantizedTensor on CPU does not support this operation: {func}"
return arg
args = tree_map(check_if_cpu, args)
# Do not force the QuantizedTensor type on the returned tensor
return torch._C._disabled_torch_function_impl(func, types, args, kwargs)
......
......@@ -214,6 +214,7 @@ class Float8BlockQuantizer(Quantizer):
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
requires_grad: bool = False,
pin_memory: bool = False,
) -> Float8BlockwiseQTensor:
"""Construct quantized tensor with uninitialized data"""
if device is None:
......@@ -229,12 +230,13 @@ class Float8BlockQuantizer(Quantizer):
data = None
scale_inv = None
if self.rowwise_usage:
data = torch.empty(shape, dtype=torch.uint8, device=device)
data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory)
scale_shape = self.get_scale_shape(shape, columnwise=False)
scale_inv = torch.empty(
scale_shape,
dtype=torch.float32,
device=device,
pin_memory=pin_memory,
)
# Allocate FP8 data transpose if needed
......@@ -242,13 +244,17 @@ class Float8BlockQuantizer(Quantizer):
columnwise_scale_inv = None
if self.columnwise_usage:
columnwise_data = torch.empty(
self.get_columnwise_shape(shape), dtype=torch.uint8, device=device
self.get_columnwise_shape(shape),
dtype=torch.uint8,
device=device,
pin_memory=pin_memory,
)
columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True)
columnwise_scale_inv = torch.empty(
columnwise_scale_shape,
dtype=torch.float32,
device=device,
pin_memory=pin_memory,
)
# Construct FP8 tensor
......
......@@ -101,6 +101,7 @@ class Float8Quantizer(Quantizer):
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
requires_grad: bool = False,
pin_memory: bool = False,
) -> Float8Tensor:
# Canonicalize tensor attributes
......@@ -108,16 +109,19 @@ class Float8Quantizer(Quantizer):
device = torch.device("cuda")
# Allocate FP8 data
data = torch.empty(shape, dtype=torch.uint8, device=device)
data = None
if self.rowwise_usage:
data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory)
# Allocate FP8 data transpose if needed
data_transpose = None
if self.columnwise_usage:
transpose_shape = [data.size(-1)] + list(data.shape[:-1])
transpose_shape = [shape[-1]] + list(shape[:-1])
data_transpose = torch.empty(
transpose_shape,
dtype=torch.uint8,
device=device,
pin_memory=pin_memory,
)
# Construct FP8 tensor
......@@ -125,7 +129,7 @@ class Float8Quantizer(Quantizer):
shape=shape,
dtype=dtype,
data=data,
fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device),
fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device, pin_memory=pin_memory),
fp8_dtype=self.dtype,
requires_grad=requires_grad,
data_transpose=data_transpose,
......@@ -287,6 +291,7 @@ class Float8CurrentScalingQuantizer(Quantizer):
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
requires_grad: bool = False,
pin_memory: bool = False,
) -> Float8Tensor:
# Canonicalize tensor attributes
......@@ -294,23 +299,26 @@ class Float8CurrentScalingQuantizer(Quantizer):
device = torch.device("cuda")
# Allocate FP8 data
data = torch.empty(shape, dtype=torch.uint8, device=device)
data = None
if self.rowwise_usage:
data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory)
# Allocate FP8 data transpose if needed
data_transpose = None
if self.columnwise_usage:
transpose_shape = [data.size(-1)] + list(data.shape[:-1])
transpose_shape = [shape[-1]] + list(shape[:-1])
data_transpose = torch.empty(
transpose_shape,
dtype=torch.uint8,
device=device,
pin_memory=pin_memory,
)
# Construct FP8 tensor
return Float8Tensor(
shape=shape,
dtype=dtype,
data=data,
fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device),
fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device, pin_memory=pin_memory),
fp8_dtype=self.dtype,
requires_grad=requires_grad,
data_transpose=data_transpose,
......@@ -715,14 +723,22 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
return cls.detach(args[0])
if func == torch.ops.aten.clone.default:
return cls.clone(args[0])
if func == torch.ops.aten.copy_.default:
dst, src = args[0], args[1]
# Just copy FP8 attrs if copying between Float8Tensors
if isinstance(src, Float8Tensor) and isinstance(dst, Float8Tensor):
dst._data.copy_(src._data.detach())
dst._scale_inv.copy_(src._scale_inv.view(dst._scale_inv.size()))
if src._transpose is not None or dst._transpose is not None:
dst._create_transpose()
if dst._data is not None:
dst._data.copy_(src._data.detach(), *args[2:], **kwargs)
if dst._scale_inv is not None:
dst._scale_inv.copy_(
src._scale_inv.view(dst._scale_inv.size()), *args[2:], **kwargs
)
if dst._transpose is not None and not dst._transpose_invalid:
if not src._transpose_invalid:
dst._transpose.copy_(src._transpose, *args[2:], **kwargs)
else:
dst._create_transpose()
return dst
elif func in _ops_to_preserve_subclass_in_fsdp2:
# Ops in _ops_to_preserve_subclass_in_fsdp2 are recommended to return the same class instance to work correctly with torch FSDP2
......
......@@ -90,6 +90,7 @@ class MXFP8Quantizer(Quantizer):
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
requires_grad: bool = False,
pin_memory: bool = False,
) -> MXFP8Tensor:
# Canonicalize tensor attributes
......@@ -105,24 +106,29 @@ class MXFP8Quantizer(Quantizer):
)
# Allocate FP8 data
data = torch.empty(shape, dtype=torch.uint8, device=device)
scale_inv = torch.empty(
round_up_to_nearest_multiple(math.prod(shape[:-1]), 128),
round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4),
dtype=torch.uint8,
device=device,
)
data = None
scale_inv = None
if self.rowwise_usage:
data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory)
scale_inv = torch.empty(
round_up_to_nearest_multiple(math.prod(shape[:-1]), 128),
round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4),
dtype=torch.uint8,
device=device,
pin_memory=pin_memory,
)
# Allocate FP8 data transpose if needed
columnwise_data = None
columnwise_scale_inv = None
if self.columnwise_usage:
columnwise_data = torch.empty_like(data)
columnwise_data = torch.empty_like(data, pin_memory=pin_memory)
columnwise_scale_inv = torch.empty(
round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4),
round_up_to_nearest_multiple(shape[-1], 128),
dtype=torch.uint8,
device=device,
pin_memory=pin_memory,
)
# Construct FP8 tensor
......@@ -348,11 +354,17 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
)
if rowwise_matches and columnwise_matches:
if dst._rowwise_data is not None:
dst._rowwise_data.copy_(src._rowwise_data.detach())
dst._rowwise_scale_inv.copy_(src._rowwise_scale_inv.detach())
dst._rowwise_data.copy_(src._rowwise_data.detach(), *args[2:], **kwargs)
dst._rowwise_scale_inv.copy_(
src._rowwise_scale_inv.detach(), *args[2:], **kwargs
)
if dst._columnwise_data is not None:
dst._columnwise_data.copy_(src._columnwise_data.detach())
dst._columnwise_scale_inv.copy_(src._columnwise_scale_inv.detach())
dst._columnwise_data.copy_(
src._columnwise_data.detach(), *args[2:], **kwargs
)
dst._columnwise_scale_inv.copy_(
src._columnwise_scale_inv.detach(), *args[2:], **kwargs
)
return dst
# FSDP2 related functions.
......
......@@ -6,7 +6,7 @@
from __future__ import annotations
from collections.abc import Iterable
import math
from typing import Optional, Tuple, Union
from typing import Dict, Optional, Tuple, Union
import functools
import torch
......@@ -265,6 +265,7 @@ class NVFP4Quantizer(Quantizer):
*,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
pin_memory: bool = False,
requires_grad: bool = False,
) -> NVFP4Tensor:
......@@ -288,11 +289,18 @@ class NVFP4Quantizer(Quantizer):
scale_inv = None
amax_rowwise = None
if self.rowwise_usage:
data = torch.empty(self.convert_shape_for_fp4(shape), dtype=torch.uint8, device=device)
data = torch.empty(
self.convert_shape_for_fp4(shape),
dtype=torch.uint8,
device=device,
pin_memory=pin_memory,
)
scale_shape = self.get_scale_shape(shape, columnwise=False)
scale_inv = torch.empty(scale_shape, dtype=torch.uint8, device=device)
scale_inv = torch.empty(
scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory
)
# Allocate per tensor scale inverse. FP32 format.
amax_rowwise = torch.zeros(1, dtype=torch.float32, device=device)
amax_rowwise = torch.zeros(1, dtype=torch.float32, device=device, pin_memory=pin_memory)
# Allocate FP8 data transpose if needed
columnwise_data = None
......@@ -306,12 +314,15 @@ class NVFP4Quantizer(Quantizer):
self.convert_shape_for_fp4(self.get_columnwise_shape(shape_2d)),
dtype=torch.uint8,
device=device,
pin_memory=pin_memory,
)
columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True)
columnwise_scale_inv = torch.empty(
columnwise_scale_shape, dtype=torch.uint8, device=device
columnwise_scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory
)
amax_columnwise = torch.zeros(
1, dtype=torch.float32, device=device, pin_memory=pin_memory
)
amax_columnwise = torch.zeros(1, dtype=torch.float32, device=device)
# Construct FP8 tensor
return NVFP4Tensor(
......@@ -498,6 +509,12 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
return self
raise ValueError("NVFP4Tensor does not support different memory formats!")
def get_usages(self) -> Dict[str, bool]:
return {
"rowwise": self._rowwise_data is not None,
"columnwise": self._columnwise_data is not None,
}
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
......@@ -520,16 +537,20 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
)
if tensor._rowwise_data is not None:
rowwise_data = data_init_func(tensor._rowwise_data)
rowwise_scale_inv = scale_inv_init_func(tensor._rowwise_scale_inv)
amax_rowwise = torch.zeros_like(tensor._amax_rowwise)
rowwise_data = data_init_func(tensor._rowwise_data, *args[1:], **kwargs)
rowwise_scale_inv = scale_inv_init_func(
tensor._rowwise_scale_inv, *args[1:], **kwargs
)
amax_rowwise = torch.zeros_like(tensor._amax_rowwise, *args[1:], **kwargs)
else:
rowwise_data, rowwise_scale_inv, amax_rowwise = None, None, None
if tensor._columnwise_data is not None:
columnwise_data = data_init_func(tensor._columnwise_data)
columnwise_scale_inv = scale_inv_init_func(tensor._columnwise_scale_inv)
amax_columnwise = torch.zeros_like(tensor._amax_columnwise)
columnwise_data = data_init_func(tensor._columnwise_data, *args[1:], **kwargs)
columnwise_scale_inv = scale_inv_init_func(
tensor._columnwise_scale_inv, *args[1:], **kwargs
)
amax_columnwise = torch.zeros_like(tensor._amax_columnwise, *args[1:], **kwargs)
else:
columnwise_data, columnwise_scale_inv, amax_columnwise = (
None,
......
......@@ -420,3 +420,10 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
return
return
def get_usages(self) -> Dict[str, bool]:
"""Get the usage of the tensor"""
return {
"rowwise": self._rowwise_data is not None,
"columnwise": self._columnwise_data is not None,
}
......@@ -225,3 +225,12 @@ class Float8TensorStorage(QuantizedTensorStorage):
if not needs_data_transpose:
self._transpose = None
self._transpose_invalid = True
def get_usages(self) -> Dict[str, bool]:
"""Get the usage of the tensor"""
usages = {"rowwise": self._data is not None}
if is_non_tn_fp8_gemm_supported():
usages["columnwise"] = self._data is not None
else:
usages["columnwise"] = self._transpose is not None and not self._transpose_invalid
return usages