Unverified Commit beaecf84 authored by Paweł Gadziński, committed by GitHub

[Pytorch] NVIDIA-DL-Framework-Inspect support – part 1 – core (#1614)



* add
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* weight workspace fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* docs fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* file i forgot
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* lint fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Update transformer_engine/debug/pytorch/utils.py
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>

* setup fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* setup fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Update transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>

* all tensor types
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* removed check
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* move error
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* _reset
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Update transformer_engine/pytorch/module/linear.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>

* name documentation
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* added blockwise quantizer
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* make debug option optional
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Update transformer_engine/pytorch/tensor/quantized_tensor.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>

* names fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 0994fb48
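For context, a minimal usage sketch of the workflow this change enables (illustrative only: the config path, feature directory, and layer name are assumptions, and the exact debug_api call signatures are assumed from the nvdlfw-inspect package added as a dependency below):

import torch
import transformer_engine.pytorch as te
import nvdlfw_inspect.api as debug_api

# debug_api must be initialized before the first TE module is constructed;
# TEDebugState.initialize() raises an error otherwise.
debug_api.initialize(
    config_file="./debug_config.yaml",                   # hypothetical config file
    feature_dirs=["transformer_engine/debug/features"],  # assumed feature directory
)

# The new `name` argument identifies the layer in debug features and logs.
layer = te.Linear(1024, 1024, name="decoder.fc1")
out = layer(torch.randn(32, 1024, device="cuda"))

debug_api.step()  # advance the iteration counter consumed by DebugQuantizer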
......@@ -110,6 +110,10 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
if "pytorch" in frameworks:
install_reqs.extend(["torch>=2.1"])
install_reqs.append(
"nvdlfw-inspect @"
" git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git@v0.1#egg=nvdlfw-inspect"
)
# Blackwell is not supported as of Triton 3.2.0, need custom internal build
# install_reqs.append("triton")
test_reqs.extend(["numpy", "torchvision", "prettytable", "PyYAML"])
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Top level package for numerical debugging."""
try:
from . import pytorch
from .pytorch.debug_state import set_weight_tensor_tp_group_reduce
except ImportError as e:
pass
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
This file contains DebugQuantizer and DebugQuantizedTensor objects,
which are wrappers over Quantizer and QuantizedTensor.
These wrappers add logic related to debugging, using the nvdlfw_inspect package.
"""
from __future__ import annotations
from typing import Optional, Tuple, Iterable, Union
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.quantized_tensor import (
QuantizedTensor,
Quantizer,
prepare_for_saving,
restore_from_saved,
)
aten = torch.ops.aten
_tensor_to_gemm_names_map = {
"weight": ["fprop", "dgrad"],
"activation": ["fprop", "wgrad"],
"output": ["fprop", None],
"gradient": ["dgrad", "wgrad"],
"wgrad": ["wgrad", None],
"dgrad": ["dgrad", None],
}
API_CALL_MODIFY = "modify_tensor()"
STANDARD_FP8_QUANTIZE = "FP8 Quantize"
HIGH_PRECISION = "High Precision"
class DebugQuantizer(Quantizer):
"""
DebugQuantizer is a Quantizer used for debugging with nvidia-dlframework-inspect.
It allows inserting custom calls into the quantization process, which enables
modifying tensors or gathering tensor statistics.
"""
def __init__(
self,
layer_name: str,
tensor_name: str,
parent_quantizer: Optional[Quantizer],
tp_group: torch.distributed.ProcessGroup,
):
import nvdlfw_inspect.api as debug_api
super().__init__(rowwise=True, columnwise=True)
self.layer_name = layer_name
self.tensor_name = tensor_name
self.parent_quantizer = parent_quantizer
self.tp_group = tp_group # used in inspect_tensor calls
self.iteration = debug_api.DEBUG_MANAGER._trainer_iteration_count
self.rowwise_gemm_name, self.columnwise_gemm_name = _tensor_to_gemm_names_map[tensor_name]
# The inspect_tensor_enabled, inspect_tensor_postquantize_enabled,
# rowwise_tensor_plan, and columnwise_tensor_plan fields are computed here.
# They determine where API calls will be inserted.
#
# inspect_tensor*_enabled are bool fields indicating whether some feature
# will need to run inspect_tensor_* calls.
#
# *_tensor_plan is one of [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE, HIGH_PRECISION],
# determining what happens when the quantizer is used for that tensor.
self.output_tensor = tensor_name in ["output", "wgrad", "dgrad"]
if self.output_tensor:
self.inspect_tensor_enabled, self.rowwise_tensor_plan = (
self.get_plans_for_output_tensors()
)
else:
(
self.inspect_tensor_enabled,
self.inspect_tensor_postquantize_enabled_rowwise,
self.inspect_tensor_postquantize_enabled_columnwise,
) = self.get_enabled_look_at_tensors()
self.rowwise_tensor_plan, self.columnwise_tensor_plan = self.get_tensors_plan()
self.log_messages_about_plans()
def get_plans_for_output_tensors(self) -> Tuple[bool, str]:
"""
Returns a tuple (inspect_tensor_enabled: bool, plan: str). The plan is either
API_CALL_MODIFY or HIGH_PRECISION, because the debug quantizer does not support
GEMM output in FP8.
"""
import nvdlfw_inspect.api as debug_api
inspect_tensor_enabled = debug_api.transformer_engine.inspect_tensor_enabled(
layer_name=self.layer_name, tensor_name=self.tensor_name, iteration=self.iteration
)
modify_enabled = debug_api.transformer_engine.modify_tensor_enabled(
layer_name=self.layer_name,
gemm=self.rowwise_gemm_name,
tensor_name=self.tensor_name,
iteration=self.iteration,
)
plan = API_CALL_MODIFY if modify_enabled else HIGH_PRECISION
return inspect_tensor_enabled, plan
def get_enabled_look_at_tensors(self):
"""
Returns a tuple of booleans determining which inspect_tensor_*(...) calls should be made.
"""
import nvdlfw_inspect.api as debug_api
inspect_tensor_enabled = debug_api.transformer_engine.inspect_tensor_enabled(
layer_name=self.layer_name, tensor_name=self.tensor_name, iteration=self.iteration
)
inspect_tensor_postquantize_enabled_rowwise = (
debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
layer_name=self.layer_name,
tensor_name=self.tensor_name,
iteration=self.iteration,
gemm=self.rowwise_gemm_name,
)
)
inspect_tensor_postquantize_enabled_columnwise = (
debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
layer_name=self.layer_name,
tensor_name=self.tensor_name,
iteration=self.iteration,
gemm=self.columnwise_gemm_name,
)
)
return (
inspect_tensor_enabled,
inspect_tensor_postquantize_enabled_rowwise,
inspect_tensor_postquantize_enabled_columnwise,
)
def get_tensors_plan(self):
"""
Returns (rowwise_plan, columnwise_plan). Each element of the tuple is one of
API_CALL_MODIFY, STANDARD_FP8_QUANTIZE, or HIGH_PRECISION, indicating the behavior
of this quantizer with respect to these tensors.
"""
import nvdlfw_inspect.api as debug_api
rowwise_plan = None
columnwise_plan = None
modify_rowwise = debug_api.transformer_engine.modify_tensor_enabled(
layer_name=self.layer_name,
gemm=self.rowwise_gemm_name,
tensor_name=self.tensor_name,
iteration=self.iteration,
)
if modify_rowwise:
rowwise_plan = API_CALL_MODIFY
else:
if self.parent_quantizer is not None:
fp8_quantize = debug_api.transformer_engine.fp8_gemm_enabled(
layer_name=self.layer_name,
gemm=self.rowwise_gemm_name,
iteration=self.iteration,
)
if fp8_quantize:
rowwise_plan = STANDARD_FP8_QUANTIZE
if rowwise_plan is None:
rowwise_plan = HIGH_PRECISION
if self.columnwise_gemm_name is not None:
modify_columnwise = debug_api.transformer_engine.modify_tensor_enabled(
layer_name=self.layer_name,
gemm=self.columnwise_gemm_name,
tensor_name=self.tensor_name,
iteration=self.iteration,
)
if modify_columnwise:
columnwise_plan = API_CALL_MODIFY
else:
if self.parent_quantizer is not None:
fp8_quantize = debug_api.transformer_engine.fp8_gemm_enabled(
layer_name=self.layer_name,
gemm=self.columnwise_gemm_name,
iteration=self.iteration,
)
if fp8_quantize:
columnwise_plan = STANDARD_FP8_QUANTIZE
if columnwise_plan is None:
columnwise_plan = HIGH_PRECISION
return rowwise_plan, columnwise_plan
def log_messages_about_plans(self):
"""
Logs the messages about the plans for each of the tensors.
"""
import nvdlfw_inspect.api as debug_api
debug_api.log_message(
f"Tensor: {self.tensor_name}, gemm {self.rowwise_gemm_name} -"
f" {self.rowwise_tensor_plan}",
layer_name=self.layer_name,
extra_cachable_args=(self.rowwise_gemm_name, self.tensor_name),
)
debug_api.log_message(
f"Tensor: {self.tensor_name}, gemm {self.columnwise_gemm_name} -"
f" {self.columnwise_tensor_plan}",
layer_name=self.layer_name,
extra_cachable_args=(self.columnwise_gemm_name, self.tensor_name),
)
def _call_inspect_tensor_api(
self, tensor, rowwise_gemm_tensor=None, columnwise_gemm_tensor=None
):
import nvdlfw_inspect.api as debug_api
args = {
"layer_name": self.layer_name,
"tensor": tensor,
"tensor_name": self.tensor_name,
"iteration": debug_api.DEBUG_MANAGER._trainer_iteration_count,
"tp_group": self.tp_group,
}
if tensor is not None and self.inspect_tensor_enabled:
debug_api.transformer_engine.inspect_tensor(**args)
if self.output_tensor:
return
if (
self.rowwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE]
and self.inspect_tensor_postquantize_enabled_rowwise
):
args["tensor"] = rowwise_gemm_tensor
args["rowwise"] = True
debug_api.transformer_engine.inspect_tensor_postquantize(**args)
if (
self.columnwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE]
and self.inspect_tensor_postquantize_enabled_columnwise
):
args["tensor"] = columnwise_gemm_tensor
args["rowwise"] = False
debug_api.transformer_engine.inspect_tensor_postquantize(**args)
def quantize(
self,
tensor: torch.Tensor,
*,
out: Optional[Union[torch.Tensor, DebugQuantizedTensor]] = None,
dtype: torch.dtype = None,
):
"""Returns DebugQuantizedTensor object."""
import nvdlfw_inspect.api as debug_api
assert not self.output_tensor
if out is not None:
return self.update_quantized(tensor, self)
# 1. If FP8 quantization is needed for at least one of the GEMMs,
# quantization using self.parent_quantizer is performed.
# The rowwise GEMM corresponds to rowwise_usage in FP8, and similarly for columnwise.
rowwise_gemm_quantize = (
self.rowwise_usage and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE
)
columnwise_gemm_quantize = (
self.columnwise_usage and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE
)
if columnwise_gemm_quantize and not rowwise_gemm_quantize:
rowwise_gemm_quantize = True # only columnwise quantization not implemented
rowwise_gemm_tensor, columnwise_gemm_tensor = None, None
if STANDARD_FP8_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]:
self.parent_quantizer.set_usage(
rowwise=True,
columnwise=columnwise_gemm_quantize, # columnwise usage only is not supported
)
quantized_tensor = self.parent_quantizer(tensor)
# If both rowwise_tensor_plan and columnwise_tensor_plan need FP8,
# one tensor with rowwise=True and columnwise=True is computed
# and both rowwise_gemm_tensor and columnwise_gemm_tensor point to it.
if self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE:
rowwise_gemm_tensor = quantized_tensor
if self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE:
columnwise_gemm_tensor = quantized_tensor
# 2. modify_tensor() is called, if it is used.
if self.columnwise_tensor_plan == API_CALL_MODIFY:
columnwise_gemm_tensor = debug_api.transformer_engine.modify_tensor(
layer_name=self.layer_name,
tensor_name=self.tensor_name,
gemm=self.columnwise_gemm_name,
tensor=tensor,
default_quantizer=self.parent_quantizer,
iteration=self.iteration,
dtype=dtype,
)
if columnwise_gemm_tensor.dtype != dtype:
raise ValueError("Dtype does not match the output of the modify_tensor call")
if self.rowwise_tensor_plan == API_CALL_MODIFY:
rowwise_gemm_tensor = debug_api.transformer_engine.modify_tensor(
layer_name=self.layer_name,
tensor_name=self.tensor_name,
gemm=self.rowwise_gemm_name,
tensor=tensor,
default_quantizer=self.parent_quantizer,
iteration=self.iteration,
dtype=dtype,
)
if rowwise_gemm_tensor.dtype != dtype:
raise ValueError("Dtype does not match the output of the modify_tensor call")
# 3. If some tensors are still not defined, the high-precision tensor is used.
if self.rowwise_tensor_plan == HIGH_PRECISION:
rowwise_gemm_tensor = tensor.to(dtype)
if self.columnwise_tensor_plan == HIGH_PRECISION:
columnwise_gemm_tensor = tensor.to(dtype)
self._call_inspect_tensor_api(tensor, rowwise_gemm_tensor, columnwise_gemm_tensor)
# Sometimes we want to return a plain tensor containing only the rowwise GEMM data.
if self.tensor_name in ["wgrad", "dgrad", "output"]:
return rowwise_gemm_tensor
return DebugQuantizedTensor(
rowwise_gemm_tensor=rowwise_gemm_tensor,
columnwise_gemm_tensor=columnwise_gemm_tensor,
quantizer=self,
layer_name=self.layer_name,
tensor_name=self.tensor_name,
)
def process_gemm_output(self, tensor: torch.Tensor):
"""This call is invoked after the gemm to inspect and modify the output tensor."""
import nvdlfw_inspect.api as debug_api
assert self.parent_quantizer is None, "FP8 output is not supported for debug=True."
assert self.output_tensor
tensor_to_gemm = {"output": "fprop", "wgrad": "wgrad", "dgrad": "dgrad"}
if self.rowwise_tensor_plan == API_CALL_MODIFY:
tensor = debug_api.transformer_engine.modify_tensor(
layer_name=self.layer_name,
gemm=tensor_to_gemm[self.tensor_name],
tensor_name=self.tensor_name,
tensor=tensor,
iteration=self.iteration,
default_quantizer=self.parent_quantizer,
)
self._call_inspect_tensor_api(tensor)
return tensor
def make_empty(
self,
shape: Iterable[int],
*,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
) -> QuantizedTensor:
"""Override make_empty() from Quantizer class."""
if self.parent_quantizer is not None:
return self.parent_quantizer.make_empty(shape, dtype=dtype, device=device)
return torch.empty(shape, dtype=dtype, device=device)
def calibrate(self, tensor: torch.Tensor):
"""Calibration override, should not be invoked."""
raise RuntimeError("[NVTORCH-INSPECT ERROR] Calibration with debug is not supported")
def update_quantized(
self,
src: torch.Tensor,
dst: QuantizedTensor,
*,
noop_flag: Optional[torch.Tensor] = None,
) -> QuantizedTensor:
"""Update quantized tensor - used in weight caching."""
import nvdlfw_inspect.api as debug_api
assert noop_flag is None, "CUDA Graphs are not supported with debug=True!"
updated_rowwise_gemm = False
if self.parent_quantizer is not None:
if (
dst.rowwise_gemm_tensor is not None
and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE
):
if hasattr(dst.rowwise_gemm_tensor, "quantize_"):
dst.rowwise_gemm_tensor.quantize_(src, noop_flag=None)
else:
tex.quantize(src, self.parent_quantizer, dst.rowwise_gemm_tensor, None)
updated_rowwise_gemm = True
if (
dst.columnwise_gemm_tensor is not None
and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE
and not updated_rowwise_gemm
):
if hasattr(dst.columnwise_gemm_tensor, "quantize_"):
dst.columnwise_gemm_tensor.quantize_(src, noop_flag=None)
else:
tex.quantize(src, self.parent_quantizer, dst.columnwise_gemm_tensor, None)
if self.columnwise_tensor_plan == API_CALL_MODIFY:
out = debug_api.transformer_engine.modify_tensor(
layer_name=self.layer_name,
tensor_name=self.tensor_name,
gemm=self.columnwise_gemm_name,
tensor=src,
default_quantizer=self.parent_quantizer,
out=dst.columnwise_gemm_tensor,
iteration=self.iteration,
)
assert out is None, (
"API call debug_api.transformer_engine.modify_tensor with out != None should"
" return None"
)
if self.rowwise_tensor_plan == API_CALL_MODIFY:
debug_api.transformer_engine.modify_tensor(
layer_name=self.layer_name,
tensor_name=self.tensor_name,
gemm=self.rowwise_gemm_name,
tensor=src,
default_quantizer=self.parent_quantizer,
out=dst.rowwise_gemm_tensor,
iteration=self.iteration,
)
if self.rowwise_tensor_plan == HIGH_PRECISION:
dst.rowwise_gemm_tensor.copy_(src)
if self.columnwise_tensor_plan == HIGH_PRECISION:
# if they are the same tensor object, it is sufficient to update one
if dst.columnwise_gemm_tensor is not dst.rowwise_gemm_tensor:
dst.columnwise_gemm_tensor.copy_(src)
self._call_inspect_tensor_api(src, dst.rowwise_gemm_tensor, dst.columnwise_gemm_tensor)
def any_feature_enabled(self) -> bool:
"""Returns bool if there is at least one API call enabled."""
if self.output_tensor:
return self.inspect_tensor_enabled or self.rowwise_tensor_plan == API_CALL_MODIFY
if (
self.inspect_tensor_enabled
or self.inspect_tensor_postquantize_enabled_rowwise
or self.inspect_tensor_postquantize_enabled_columnwise
or self.rowwise_tensor_plan == API_CALL_MODIFY
or self.columnwise_tensor_plan == API_CALL_MODIFY
):
return True
if self.parent_quantizer is not None:
if self.rowwise_tensor_plan != STANDARD_FP8_QUANTIZE:
return True
if self.columnwise_tensor_plan != STANDARD_FP8_QUANTIZE:
return True
return False
class DebugQuantizedTensor:
"""
Class containing the tensors produced by debug quantization. Depending on the configuration
it can hold one or two different tensor objects, which can be accessed with the
get_tensor() method.
"""
def __init__(
self,
rowwise_gemm_tensor,
columnwise_gemm_tensor,
quantizer,
layer_name=None,
tensor_name=None,
):
self.rowwise_gemm_tensor = rowwise_gemm_tensor
self.columnwise_gemm_tensor = columnwise_gemm_tensor
self.quantizer = quantizer
self._layer_name = layer_name
self._tensor_name = tensor_name
def prepare_for_saving(self):
""" " Prepare for saving method override"""
self.tensors_to_save = (
[self.rowwise_gemm_tensor, self.columnwise_gemm_tensor]
if self.rowwise_gemm_tensor is not self.columnwise_gemm_tensor
else [self.rowwise_gemm_tensor]
)
tensor_list, tensor_objects_list = prepare_for_saving(*self.tensors_to_save)
self.tensors_to_save = tensor_objects_list
# pylint: disable=unbalanced-tuple-unpacking
return tensor_list, self
def restore_from_saved(self, tensors):
"""Restore from saved method override"""
tensor_objects_list, saved_tensors = restore_from_saved(
self.tensors_to_save,
tensors,
return_saved_tensors=True,
)
if len(tensor_objects_list) == 2:
# pylint: disable=unbalanced-tuple-unpacking
self.rowwise_gemm_tensor, self.columnwise_gemm_tensor = tensor_objects_list
else:
self.rowwise_gemm_tensor = tensor_objects_list[0]
self.columnwise_gemm_tensor = self.rowwise_gemm_tensor
return saved_tensors
def quantize_(self, tensor, *, noop_flag=None):
""" " quantize_ method override"""
assert noop_flag is None, "CUDA Graphs are not supported with debug=True!"
self.quantizer.update_quantized(tensor, self)
def dequantize(self, *, dtype=None):
""" " dequantize method override"""
if dtype is None:
dtype = self.rowwise_gemm_tensor.dtype
return self.rowwise_gemm_tensor.dequantize().to(dtype)
def get_tensor(self, transpose: bool):
"""Is used in the python gemm() to get tensor or transpose of the tensor."""
return self.rowwise_gemm_tensor if not transpose else self.columnwise_gemm_tensor
def size(self):
"""Size of the tensor."""
return self.rowwise_gemm_tensor.size()
def update_usage(self, rowwise_usage: bool, columnwise_usage: bool):
"""Update usage of the tensor."""
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Managing the state of all the debugged layers.
"""
import sys
class TEDebugState:
"""
A class to manage the state of debug layers.
"""
layer_count = 1
layers_initialized = {}
weight_tensor_tp_group_reduce = True
debug_enabled = None
@classmethod
def initialize(cls):
"""
Sets cls.debug_enabled to True if the debug_api module has been initialized.
"""
if "nvdlfw_inspect" in sys.modules:
import nvdlfw_inspect.api as debug_api
if cls.debug_enabled is False and debug_api.DEBUG_MANAGER is not None:
# This method is invoked when initializing TE modules.
# If this error is thrown, it means that some TE module had been initialized before
# debug_api was initialized, and now a new TE module is being initialized.
# This is likely to be a bug.
raise RuntimeError(
"[nv_dlfw_inspect] nv_dlfw_inspect module should be initialized before"
" initialization of the first TE module"
)
cls.debug_enabled = debug_api.DEBUG_MANAGER is not None
@classmethod
def _reset(cls):
"""Resets layer count and stats buffers."""
from ..features.utils.stats_buffer import STATS_BUFFERS
STATS_BUFFERS.reset()
cls.debug_enabled = None
cls.layers_initialized.clear()
@classmethod
def get_layer_count(cls):
"""
Layer counter is used when layer names are not provided to modules by the user.
"""
lc = cls.layer_count
cls.layer_count += 1
return lc
@classmethod
def set_weight_tensor_tp_group_reduce(cls, enabled):
"""Sets weight tensor reduction mode."""
cls.weight_tensor_tp_group_reduce = enabled
def set_weight_tensor_tp_group_reduce(enabled):
"""Sets weight tensor reduction mode."""
TEDebugState.set_weight_tensor_tp_group_reduce(enabled)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Utils functions for the debug module."""
def any_feature_enabled(quantizers):
"""Returns True if at least one API call is made from DebugQuantizer."""
return any(q.any_feature_enabled() for q in quantizers)
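A sketch of how a TE module is expected to consume these helpers, mirroring the quantizer wrapping added to LayerNormLinear later in this diff (the function below is illustrative and not part of any file in this commit):

from transformer_engine.debug.pytorch.debug_quantization import DebugQuantizer
from transformer_engine.debug.pytorch.utils import any_feature_enabled

def wrap_quantizers_for_debug(layer_name, quantizers, tp_group):
    """Wrap per-tensor quantizers in DebugQuantizers; fall back when no feature is active."""
    names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"]
    debug_quantizers = [
        DebugQuantizer(layer_name, tensor_name, quantizer, tp_group)
        for tensor_name, quantizer in zip(names, quantizers)
    ]
    if not any_feature_enabled(debug_quantizers):
        # No debug feature is enabled for this layer and iteration,
        # so the regular (faster) path with the original quantizers is used.
        return quantizers, False
    return debug_quantizers, True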
......@@ -19,6 +19,7 @@ from packaging.version import Version as PkgVersion
import torch
import transformer_engine_torch as tex
from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.pytorch.utils import (
get_cudnn_version,
nvtx_range_pop,
......@@ -6483,6 +6484,8 @@ class MultiheadAttention(torch.nn.Module):
equal length. Please note that these formats do not reflect how
tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
For that, please use `get_qkv_layout` to gain the layout information.
name: str, default = `None`
name of the module, currently used for debugging purposes.
Parallelism parameters
----------------------
......@@ -6561,6 +6564,7 @@ class MultiheadAttention(torch.nn.Module):
normalization: str = "LayerNorm",
device: Union[torch.device, str] = "cuda",
qkv_format: str = "sbhd",
name: str = None,
) -> None:
super().__init__()
......@@ -6612,6 +6616,8 @@ class MultiheadAttention(torch.nn.Module):
self.hidden_size_q = self.hidden_size_per_attention_head * num_attention_heads
self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups
self.name = name
common_gemm_kwargs = {
"fuse_wgrad_accumulation": fuse_wgrad_accumulation,
"tp_group": tp_group,
......@@ -6652,6 +6658,7 @@ class MultiheadAttention(torch.nn.Module):
ub_overlap_ag=ub_overlap_ag,
normalization=normalization,
ub_name="qkv",
name=name + ".layernorm_linear_qkv" if name is not None else None,
**common_gemm_kwargs,
)
else:
......@@ -6663,6 +6670,7 @@ class MultiheadAttention(torch.nn.Module):
return_bias=False,
parallel_mode=qkv_parallel_mode,
parameters_split=parameters_split,
name=name + ".linear_qkv" if name is not None else None,
**common_gemm_kwargs,
)
elif self.attention_type == "cross":
......@@ -6684,6 +6692,7 @@ class MultiheadAttention(torch.nn.Module):
ub_overlap_ag=ub_overlap_ag,
normalization=normalization,
ub_name="qkv",
name=name + ".layernorm_linear_q" if name is not None else None,
**common_gemm_kwargs,
)
else:
......@@ -6694,6 +6703,7 @@ class MultiheadAttention(torch.nn.Module):
bias=bias,
return_bias=False,
parallel_mode=qkv_parallel_mode,
name=name + ".linear_q" if name is not None else None,
**common_gemm_kwargs,
)
self.key_value = Linear(
......@@ -6704,6 +6714,7 @@ class MultiheadAttention(torch.nn.Module):
return_bias=False,
parallel_mode=qkv_parallel_mode,
parameters_split=("key", "value") if not fuse_qkv_params else None,
name=name + ".linear_kv" if name is not None else None,
**common_gemm_kwargs,
)
......@@ -6733,6 +6744,7 @@ class MultiheadAttention(torch.nn.Module):
ub_overlap_rs=ub_overlap_rs,
ub_overlap_ag=ub_overlap_ag,
ub_name="proj",
name=name + ".proj" if name is not None else None,
**common_gemm_kwargs,
)
......@@ -6923,6 +6935,9 @@ class MultiheadAttention(torch.nn.Module):
core_attention_bias_type in AttnBiasTypes
), f"core_attention_bias_type {core_attention_bias_type} is not supported!"
if TEDebugState.debug_enabled:
TransformerEngineBaseModule._validate_name(self)
# =================================================
# Pre-allocate memory for key-value cache for inference
# =================================================
......
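The name threading above means that a single user-provided module name fans out to the sub-layers; a short sketch of the resulting debug layer names (the constructor arguments are illustrative):

import transformer_engine.pytorch as te

# With name="decoder.self_attn", the submodules constructed above register as, e.g.:
#   "decoder.self_attn.linear_qkv"  (QKV projection; ".layernorm_linear_qkv" with input LayerNorm)
#   "decoder.self_attn.proj"        (output projection)
mha = te.MultiheadAttention(1024, 16, name="decoder.self_attn")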
......@@ -14,6 +14,7 @@ from ..utils import get_sm_count
from ..tensor.quantized_tensor import Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ...debug.pytorch.debug_quantization import DebugQuantizer
__all__ = [
"general_gemm",
......@@ -109,6 +110,13 @@ def general_gemm(
if not out.is_contiguous():
raise ValueError("Output tensor is not contiguous.")
debug_quantizer = None
if isinstance(quantization_params, DebugQuantizer):
debug_quantizer = quantization_params
quantization_params = quantization_params.parent_quantizer
A = A.get_tensor(not transa)
B = B.get_tensor(transb)
# Use bfloat16 as default bias_dtype
bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype]
......@@ -145,6 +153,9 @@ def general_gemm(
out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs)
reset_swizzled_inputs(A, B, original_scale_inverses)
if debug_quantizer is not None:
out = debug_quantizer.process_gemm_output(out)
return out, bias_grad, gelu_input, extra_output
......
......@@ -19,7 +19,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._common_utils import _get_module_fsdp_state
from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules
from .utils import non_tn_fp8_gemm_supported, safely_set_viewless_tensor_data
from .utils import non_tn_fp8_gemm_supported, safely_set_viewless_tensor_data, needs_quantized_gemm
from .constants import dist_group_type
from .fp8 import FP8GlobalStateManager, fp8_autocast
from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer
......@@ -29,6 +29,7 @@ from .tensor.quantized_tensor import QuantizedTensor, Quantizer
from .tensor._internal.float8_tensor_base import Float8TensorBase
from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from .tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
__all__ = ["checkpoint", "CudaRNGStatesTracker"]
......@@ -1195,6 +1196,28 @@ def gather_along_first_dim(
out_shape=out_shape,
)
# Debug case - call gather_along_first_dim on each tensor
if isinstance(inp, DebugQuantizedTensor):
out_obj = inp
rowwise = inp.get_tensor(False)
columnwise = inp.get_tensor(True)
final_quantizer = (
None if not needs_quantized_gemm(inp, rowwise=True) else quantizer.parent_quantizer
)
rowwise_total = gather_along_first_dim(rowwise, process_group, False, final_quantizer)[0]
out_obj.rowwise_gemm_tensor = rowwise_total
if rowwise is not columnwise:
final_quantizer_columnwise = (
None if not needs_quantized_gemm(inp, rowwise=False) else quantizer.parent_quantizer
)
columnwise_total, _ = gather_along_first_dim(
columnwise, process_group, False, final_quantizer_columnwise
)
out_obj.columnwise_gemm_tensor = columnwise_total
else:
out_obj.columnwise_gemm_tensor = out_obj.rowwise_gemm_tensor
return out_obj, None
# High-precision communication for quantized tensors
if quantizer is not None:
warnings.warn(
......
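For reference, a condensed view of a call site that reaches the debug branch above (mirroring the LayerNormLinear forward later in this diff; variable names are illustrative):

# `ln_out` was produced by a DebugQuantizer, so the gather dispatches to the branch
# above: each GEMM-specific tensor is gathered separately and quantized only if its
# GEMM actually runs in FP8.
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group, quantizer=input_quantizer)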
......@@ -10,6 +10,7 @@ import warnings
from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union
from contextlib import contextmanager
import logging
from types import MethodType
import torch
......@@ -39,6 +40,9 @@ from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor._internal.float8_tensor_base import Float8TensorBase
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ...common.recipe import Recipe
from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor
__all__ = ["initialize_ub", "destroy_ub"]
......@@ -413,6 +417,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
def __init__(self) -> None:
super().__init__()
assert torch.cuda.is_available(), "TransformerEngine needs CUDA."
self.name = None
self.fp8_initialized = False
self.fp8 = False
self.fp8_calibration = False
......@@ -432,6 +437,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self._fp8_workspaces: Dict[str, QuantizedTensor] = {}
self.activation_dtype: Optional[torch.dtype] = None
if not TEDebugState.debug_enabled:
TEDebugState.initialize()
# Names of attributes that can be set quickly (see __setattr__
# method)
_fast_setattr_names: Set[str] = {
......@@ -848,7 +856,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
gather_grad_output = row_parallel_mode and ctx.sequence_parallel
# Non-FP8 case: bgrad is fused with wgrad for this case.
if not ctx.fp8:
if not ctx.fp8 and not ctx.debug:
if gather_grad_output:
if not ctx.ub_overlap_ag:
grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group)
......@@ -858,6 +866,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
return grad_output, None
# FP8 with all-gather: unfused bgrad, fused cast + transpose
# Also supports debug quantization, which is handled inside gather_along_first_dim.
if gather_grad_output:
grad_bias = None
if ctx.use_bias:
......@@ -886,6 +895,23 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
)
return grad_output, grad_bias
# Debug without all-gather: unfused cast and bgrad.
# bgrad is computed here only if wgrad is in FP8; otherwise it is fused with wgrad and None is returned.
if ctx.debug:
grad_output_ = quantizer(grad_output)
if (
isinstance(
grad_output_.get_tensor(True),
(QuantizedTensor, Float8TensorBase, MXFP8TensorBase),
)
and ctx.use_bias
):
grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
else:
grad_bias = None
grad_output = grad_output_
return grad_output, grad_bias
# FP8 without all-gather: fused bgrad + cast + transpose
grad_bias = None
if ctx.use_bias:
......@@ -1002,6 +1028,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
update_workspace: bool = True,
skip_update_flag: Optional[torch.Tensor] = None,
fsdp_group: Optional[dist_group_type] = None,
workspace_dtype: Optional[torch.dtype] = None,
) -> QuantizedTensor:
"""Get FP8 workspace buffer and maybe update its values
......@@ -1024,6 +1051,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
over `update_workspace` if provided.
fsdp_group: bool, default = None
FSDP process group that the weights are distributed over.
workspace_dtype: torch.dtype, default = None
If the weight workspace contains a high-precision tensor, for example
for debug quantization, this is the dtype of that tensor.
"""
# FP8 primary weights
......@@ -1037,6 +1067,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# Try getting workspace from cache
out = None
if cache_name is not None:
out = self._fp8_workspaces.get(cache_name, None)
if quantizer is not None and isinstance(out, MXFP8TensorBase):
......@@ -1047,6 +1078,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
out = None
del self._fp8_workspaces[cache_name]
is_debug = isinstance(quantizer, DebugQuantizer)
is_out_debug_tensor = out is not None and isinstance(out, DebugQuantizedTensor)
if is_debug != is_out_debug_tensor:
out = None
# Gather cached Fp8 workspace if it's distributed
# NOTE: FSDP sharding is supported only for Fp8 buffers and will not work
# for models initialized with Fp8 primary weights.
......@@ -1064,7 +1100,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
raise ValueError(
"tensor and quantizer kwargs must be provided to construct FP8 workspace"
)
out = quantizer(tensor)
out = quantizer.quantize(tensor, dtype=workspace_dtype)
# Update cache
if cache_name is not None:
......@@ -1081,7 +1117,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
out.quantize_(tensor, noop_flag=skip_update_flag)
else:
tex.quantize(tensor, quantizer, out, skip_update_flag)
return out
def _load_from_state_dict(
......@@ -1104,3 +1139,47 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
super()._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)
def _validate_name(self):
"""
Validate the name passed to the module.
This is invoked in the forward() method, since module names are assigned after the model
is initialized in Megatron-LM. If no name is assigned, a default name based on the layer
counter is created.
"""
assert TEDebugState.debug_enabled
import nvdlfw_inspect.api as debug_api
if self.name is None:
debug_api.log_message(
"Names are not provided to debug modules. ",
"Creating and using generic names. Pass names to debug modules for better"
" insight. ",
level=logging.WARNING,
)
self.name = f"Layer_{TEDebugState.get_layer_count()}"
def _turn_off_unsupported_features_in_debug(self):
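"""Disable features that are not supported in debug mode (currently Userbuffers overlaps)."""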
if (
getattr(self, "ub_bulk_wgrad", False)
or getattr(self, "ub_bulk_dgrad", False)
or getattr(self, "ub_overlap_ag", False)
or getattr(self, "ub_overlap_rs_dgrad", False)
or getattr(self, "ub_overlap_rs", False)
):
import nvdlfw_inspect.api as debug_api
debug_api.log_message(
"UserBuffers are not supported in debug module. "
"Using UB optimization will not affect the debug module. ",
level=logging.WARNING,
)
if hasattr(self, "ub_bulk_wgrad"):
self.ub_bulk_wgrad = None
if hasattr(self, "ub_bulk_dgrad"):
self.ub_bulk_dgrad = None
if hasattr(self, "ub_overlap_ag"):
self.ub_overlap_ag = None
if hasattr(self, "ub_overlap_rs_dgrad"):
self.ub_overlap_rs_dgrad = None
if hasattr(self, "ub_overlap_rs"):
self.ub_overlap_rs = None
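A sketch of how the new workspace_dtype argument is used by the modules later in this diff (a condensed view of the LayerNormLinear/LayerNormMLP call sites; the surrounding variables are illustrative):

# Inside a module's forward pass, with `weight_quantizer` possibly being a DebugQuantizer:
weightmat = module.get_weight_workspace(
    tensor=weight,
    quantizer=weight_quantizer,
    cache_name=None if is_first_microbatch is None else "weight",
    update_workspace=is_first_microbatch is None or is_first_microbatch,
    skip_update_flag=skip_fp8_weight_update,
    fsdp_group=fsdp_group,
    workspace_dtype=activation_dtype,  # dtype of the high-precision workspace used in debug mode
)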
......@@ -35,6 +35,7 @@ from ..utils import (
nvtx_range_pop,
nvtx_range_push,
requires_grad,
needs_quantized_gemm,
)
from ..distributed import (
set_tensor_model_parallel_attributes,
......@@ -56,6 +57,8 @@ from ..tensor.quantized_tensor import (
prepare_for_saving,
restore_from_saved,
)
from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.utils import any_feature_enabled
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
......@@ -90,8 +93,9 @@ class _LayerNormLinear(torch.autograd.Function):
input_quantizer: Optional[Quantizer],
weight_quantizer: Optional[Quantizer],
output_quantizer: Optional[Quantizer],
grad_output_quantizer: Optional[Quantizer],
grad_input_quantizer: Optional[Quantizer],
grad_weight_quantizer: Optional[Quantizer],
grad_output_quantizer: Optional[Quantizer],
cpu_offloading: bool,
tp_group: Union[dist_group_type, None],
tp_size: int,
......@@ -116,6 +120,7 @@ class _LayerNormLinear(torch.autograd.Function):
fsdp_group: Union[dist_group_type, None],
module: torch.nn.Module,
skip_fp8_weight_update: bool,
debug: Optional[bool] = False,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# pylint: disable=missing-function-docstring
......@@ -214,12 +219,12 @@ class _LayerNormLinear(torch.autograd.Function):
# norm output will be returned
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
ln_out_return = ln_out_total
if fp8:
if fp8 or debug:
ln_out = input_quantizer(ln_out)
input_quantizer.set_usage(rowwise=True, columnwise=False)
ln_out_total = input_quantizer(ln_out_total)
else:
if fp8:
if fp8 or debug:
if not with_quantized_norm and not force_hp_blockwise_ln_out_gather:
ln_out = input_quantizer(ln_out)
input_quantizer.set_usage(rowwise=True, columnwise=False)
......@@ -233,18 +238,19 @@ class _LayerNormLinear(torch.autograd.Function):
ln_out_total, _ = gather_along_first_dim(
ln_out,
tp_group,
quantizer=(input_quantizer if fp8 else None),
quantizer=(input_quantizer if fp8 or debug else None),
)
else:
if fp8 and not with_quantized_norm:
if (fp8 or debug) and not with_quantized_norm:
ln_out = input_quantizer(ln_out)
ln_out_total = ln_out
nvtx_range_pop(f"{nvtx_label}.gemm_input_cast_comm")
# Cast weight to expected dtype
if not fp8:
weightmat = weight
quantized_weight = False
weightmat = cast_if_needed(weight, activation_dtype)
if not fp8 and not debug:
weightmat = cast_if_needed(weightmat, activation_dtype)
else:
quantized_weight = not isinstance(weight, QuantizedTensor)
......@@ -254,6 +260,7 @@ class _LayerNormLinear(torch.autograd.Function):
# FP8 cast to workspace buffer
update_workspace = is_first_microbatch is None or is_first_microbatch
weightmat = module.get_weight_workspace(
tensor=weight,
quantizer=weight_quantizer,
......@@ -261,11 +268,12 @@ class _LayerNormLinear(torch.autograd.Function):
update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update,
fsdp_group=fsdp_group,
workspace_dtype=activation_dtype,
)
# Cast bias to expected dtype
bias_dtype = activation_dtype
if fp8 and activation_dtype == torch.float32:
if needs_quantized_gemm(ln_out_total) and activation_dtype == torch.float32:
bias_dtype = torch.bfloat16
bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias
......@@ -400,6 +408,7 @@ class _LayerNormLinear(torch.autograd.Function):
if fuse_wgrad_accumulation and weight.requires_grad:
ctx.main_grad = weight.main_grad
ctx.grad_input_quantizer = grad_input_quantizer
ctx.grad_weight_quantizer = grad_weight_quantizer
ctx.grad_output_quantizer = grad_output_quantizer
ctx.input_quantizer = input_quantizer
ctx.owns_input = inputmat is not inp
......@@ -434,6 +443,7 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
if in_fp8_activation_recompute_phase():
FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
ctx.debug = debug
# Row Parallel Linear
if ub_overlap_rs_fprop:
......@@ -611,7 +621,7 @@ class _LayerNormLinear(torch.autograd.Function):
ln_out_total_work = None
if ctx.ln_out_needs_gather and not ctx.ub_bulk_dgrad:
quantizer = None
if ctx.fp8:
if ctx.input_quantizer is not None:
quantizer = ctx.input_quantizer
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
# If data is in FP8, we compute FP8 transposes manually
......@@ -757,6 +767,7 @@ class _LayerNormLinear(torch.autograd.Function):
out=main_grad if ctx.fuse_wgrad_accumulation else None,
use_split_accumulator=use_split_accumulator,
accumulate=accumulate_wgrad_into_param_main_grad,
quantization_params=ctx.grad_weight_quantizer,
ub=ub_obj_wgrad,
ub_type=ub_type_wgrad,
extra_output=rs_out,
......@@ -865,8 +876,9 @@ class _LayerNormLinear(torch.autograd.Function):
None, # input_quantizer
None, # weight_quantizer
None, # output_quantizer
None, # grad_output_quantizer
None, # grad_input_quantizer
None, # grad_weight_quantizer
None, # grad_output_quantizer
None, # cpu_offloading
None, # tp_group
None, # tp_size
......@@ -889,6 +901,7 @@ class _LayerNormLinear(torch.autograd.Function):
None, # ub_bulk_wgrad
None, # ub_name
None, # fsdp_group
None, # debug
None, # module
None, # skip_fp8_weight_update
)
......@@ -943,6 +956,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
name: str, default = `None`
name of the module, currently used for debugging purposes.
Parallelism parameters
----------------------
......@@ -1007,6 +1022,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
ub_bulk_wgrad: bool = False,
ub_bulk_dgrad: bool = False,
ub_name: Optional[str] = None,
name: str = None,
) -> None:
super().__init__()
......@@ -1023,6 +1039,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.return_layernorm_output_gathered = return_layernorm_output_gathered
self.zero_centered_gamma = zero_centered_gamma
self.name = name
if TEDebugState.debug_enabled:
self._turn_off_unsupported_features_in_debug() # turn off userbuffers
if tp_group is None:
self.tp_size = tp_size
if tp_size == 1:
......@@ -1312,6 +1332,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
first microbatch (since it is the first gradient being
produced)
"""
debug = TEDebugState.debug_enabled
if debug:
self._validate_name()
if FP8GlobalStateManager.fp8_graph_capturing():
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
......@@ -1348,13 +1371,28 @@ class LayerNormLinear(TransformerEngineBaseModule):
else:
bias_tensor = getattr(self, self.bias_names[0]) # Unused
quantizers = (
self._get_quantizers(fp8_output, fp8_grad)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad)
)
if debug:
if not any_feature_enabled(quantizers):
# If no feature is used, then run faster implementation with debug = False.
quantizers = self._get_quantizers(fp8_output, fp8_grad)
debug = False
if isinstance(weight_tensor, QuantizedTensor):
raise RuntimeError("FP8 weights are not supported in debug mode.")
(
input_quantizer,
weight_quantizer,
output_quantizer,
grad_output_quantizer,
grad_input_quantizer,
) = self._get_quantizers(fp8_output, fp8_grad)
grad_weight_quantizer,
grad_output_quantizer,
) = quantizers
if torch.is_grad_enabled():
fwd_fn = _LayerNormLinear.apply
......@@ -1376,8 +1414,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
input_quantizer,
weight_quantizer,
output_quantizer,
grad_output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
......@@ -1402,6 +1441,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.fsdp_group,
self,
skip_fp8_weight_update,
debug,
)
out = fwd_fn(*args)
......@@ -1421,8 +1461,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
def _get_quantizers(self, fp8_output, fp8_grad):
if not self.fp8:
return [None] * 5
return [None] * 6
grad_input_quantizer = None
grad_weight_quantizer = None
grad_output_quantizer = None
output_quantizer = None
input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
......@@ -1441,8 +1482,20 @@ class LayerNormLinear(TransformerEngineBaseModule):
input_quantizer,
weight_quantizer,
output_quantizer,
grad_output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
)
def _get_debug_quantizers(self, fp8_output, fp8_grad):
original_quantizers = self._get_quantizers(fp8_output, fp8_grad)
assert TEDebugState.debug_enabled
from ...debug.pytorch.debug_quantization import DebugQuantizer
names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"]
return tuple(
DebugQuantizer(self.name, name, q, self.tp_group)
for name, q in zip(names, original_quantizers)
)
def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
......
......@@ -41,6 +41,7 @@ from ..utils import (
clear_tensor_data,
requires_grad,
non_tn_fp8_gemm_supported,
needs_quantized_gemm,
)
from ..distributed import (
set_tensor_model_parallel_attributes,
......@@ -73,6 +74,8 @@ from ..tensor.quantized_tensor import (
from ..cpp_extensions import (
general_gemm,
)
from ...debug.pytorch.utils import any_feature_enabled
from ...debug.pytorch.debug_state import TEDebugState
__all__ = ["LayerNormMLP"]
......@@ -153,12 +156,16 @@ class _LayerNormMLP(torch.autograd.Function):
fuse_wgrad_accumulation: bool,
fc1_input_quantizer: Optional[Quantizer],
fc1_weight_quantizer: Optional[Quantizer],
fc1_output_quantizer: Optional[Quantizer],
fc1_grad_input_quantizer: Optional[Quantizer],
fc1_grad_weight_quantizer: Optional[Quantizer],
fc1_grad_output_quantizer: Optional[Quantizer],
fc2_input_quantizer: Optional[Quantizer],
fc2_weight_quantizer: Optional[Quantizer],
output_quantizer: Optional[Quantizer],
grad_fc2_output_quantizer: Optional[Quantizer],
grad_fc1_output_quantizer: Optional[Quantizer],
grad_input_quantizer: Optional[Quantizer],
fc2_output_quantizer: Optional[Quantizer],
fc2_grad_input_quantizer: Optional[Quantizer],
fc2_grad_weight_quantizer: Optional[Quantizer],
fc2_grad_output_quantizer: Optional[Quantizer],
cpu_offloading: bool,
tp_group: Union[dist_group_type, None],
tp_size: int,
......@@ -184,6 +191,7 @@ class _LayerNormMLP(torch.autograd.Function):
fsdp_group: Union[dist_group_type, None],
module: torch.nn.Module,
skip_fp8_weight_update: bool,
debug: Optional[bool] = False,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# pylint: disable=missing-function-docstring
......@@ -212,9 +220,16 @@ class _LayerNormMLP(torch.autograd.Function):
if ln_bias is not None:
ln_bias = cast_if_needed(ln_bias, activation_dtype)
# Avoid quantized norm kernel if norm output will be returned
# for fp8 DelayedScaling: layernorm output = FP8,
# only the output of the linear is returned
# for return_layernorm_output: layernorm output = high precision, then cast to FP8;
# the high-precision layernorm output and the output of the linear are returned
# for debug: layernorm output = high precision, to enable processing of this norm
with_quantized_norm = (
fp8 and not return_layernorm_output and not return_layernorm_output_gathered
fp8
and not return_layernorm_output
and not return_layernorm_output_gathered
and not debug
)
if isinstance(fc1_input_quantizer, Float8BlockQuantizer):
# Kernels not available for norm fusion.
......@@ -270,13 +285,13 @@ class _LayerNormMLP(torch.autograd.Function):
# norm output will be returned
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
ln_out_return = ln_out_total
if fp8:
if fp8 or debug:
if not force_hp_fc1_input_gather:
ln_out = fc1_input_quantizer(ln_out)
fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
ln_out_total = fc1_input_quantizer(ln_out_total)
else:
if fp8:
if fp8 or debug:
if not with_quantized_norm and not force_hp_fc1_input_gather:
ln_out = fc1_input_quantizer(ln_out)
fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
......@@ -290,21 +305,21 @@ class _LayerNormMLP(torch.autograd.Function):
ln_out_total, _ = gather_along_first_dim(
ln_out,
tp_group,
quantizer=(fc1_input_quantizer if fp8 else None),
quantizer=(fc1_input_quantizer if fp8 or debug else None),
)
else:
# NOTE: force_hp_fc1_input_gather is redundant with else, but
# here for clarity. We should not quantize ln_out if bwd needs
# to gather in hp.
if fp8 and not with_quantized_norm and not force_hp_fc1_input_gather:
if (fp8 or debug) and not with_quantized_norm and not force_hp_fc1_input_gather:
ln_out = fc1_input_quantizer(ln_out)
ln_out_total = ln_out
# Cast weights to expected dtype
if not fp8:
fc1_weight_final = cast_if_needed(fc1_weight, activation_dtype)
fc2_weight_final = cast_if_needed(fc2_weight, activation_dtype)
else:
fc1_weight_final = fc1_weight
fc2_weight_final = fc2_weight
if fp8 or debug:
# If weights are not quantized, we call get_weight_workspace,
# which handles weight caching etc.
# FP8 cast to workspace buffer
......@@ -316,6 +331,7 @@ class _LayerNormMLP(torch.autograd.Function):
update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update,
fsdp_group=fsdp_group,
workspace_dtype=activation_dtype,
)
fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True)
fc2_weight_final = module.get_weight_workspace(
......@@ -325,11 +341,15 @@ class _LayerNormMLP(torch.autograd.Function):
update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update,
fsdp_group=fsdp_group,
workspace_dtype=activation_dtype,
)
else:
fc1_weight_final = cast_if_needed(fc1_weight_final, activation_dtype)
fc2_weight_final = cast_if_needed(fc2_weight_final, activation_dtype)
# Cast biases to expected dtype
bias_dtype = activation_dtype
if fp8 and activation_dtype == torch.float32:
if needs_quantized_gemm(ln_out_total) and activation_dtype == torch.float32:
bias_dtype = torch.bfloat16
if fc1_bias is not None:
fc1_bias = cast_if_needed(fc1_bias, bias_dtype)
......@@ -359,13 +379,16 @@ class _LayerNormMLP(torch.autograd.Function):
gemm_gelu_fusion = True
if gemm_gelu_fusion and bias_gelu_fusion:
gemm_gelu_fusion = False
if debug:
gemm_gelu_fusion = False
fc1_outputs = general_gemm(
fc1_weight_final,
ln_out_total,
get_workspace(),
quantization_params=(
fc2_input_quantizer if gemm_gelu_fusion else None # fused gelu output is in fp8
fc2_input_quantizer
if gemm_gelu_fusion
else fc1_output_quantizer # fused gelu output is in fp8
),
out_dtype=activation_dtype,
bias=(
......@@ -376,6 +399,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub=ub_obj_lnout,
ub_type=tex.CommOverlapType.AG if ub_overlap_ag else None,
)
if not is_grad_enabled and (ln_out_total is not ln_out_return):
clear_tensor_data(ln_out_total)
......@@ -389,6 +413,10 @@ class _LayerNormMLP(torch.autograd.Function):
act_out = bias_gelu_fused(fc1_out_without_bias, fc1_bias)
elif gemm_gelu_fusion:
act_out, _, fc1_out, _ = fc1_outputs
elif debug:
fc1_out, *_ = fc1_outputs
act_out = activation_func(fc1_out, None)
act_out = fc2_input_quantizer(act_out)
else:
fc1_out, *_ = fc1_outputs
if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_block_scaling():
......@@ -426,7 +454,7 @@ class _LayerNormMLP(torch.autograd.Function):
get_workspace(),
out_dtype=activation_dtype,
bias=fc2_bias,
quantization_params=output_quantizer,
quantization_params=fc2_output_quantizer,
out=fc2_out,
use_split_accumulator=_2X_ACC_FPROP,
ub=ub_obj_fc2out,
......@@ -515,11 +543,14 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
ctx.force_hp_fc1_input_gather = force_hp_fc1_input_gather
ctx.grad_fc1_output_quantizer = grad_fc1_output_quantizer
ctx.grad_fc2_output_quantizer = grad_fc2_output_quantizer
ctx.grad_input_quantizer = grad_input_quantizer
ctx.fc2_input_quantizer = fc2_input_quantizer
ctx.fc1_grad_input_quantizer = fc1_grad_input_quantizer
ctx.fc1_grad_weight_quantizer = fc1_grad_weight_quantizer
ctx.fc1_grad_output_quantizer = fc1_grad_output_quantizer
ctx.fc2_grad_input_quantizer = fc2_grad_input_quantizer
ctx.fc2_grad_weight_quantizer = fc2_grad_weight_quantizer
ctx.fc2_grad_output_quantizer = fc2_grad_output_quantizer
ctx.fc1_input_quantizer = fc1_input_quantizer
ctx.fc2_input_quantizer = fc2_input_quantizer
ctx.fc1_weight_requires_grad = fc1_weight.requires_grad
ctx.fc2_weight_requires_grad = fc2_weight.requires_grad
......@@ -552,6 +583,7 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.ub_bulk_dgrad = ub_bulk_dgrad
ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad
ctx.ub_overlap_ag = ub_overlap_ag
ctx.debug = debug
ctx.requires_dgrad = (
inp.requires_grad or ln_weight.requires_grad or ln_bias.requires_grad
......@@ -675,18 +707,18 @@ class _LayerNormMLP(torch.autograd.Function):
# Configure quantizer for FC2 grad output tensor
# Note: dgrad GEMM requires row-wise usage, wgrad GEMM
# requires column-wise usage
if ctx.grad_fc2_output_quantizer is not None:
if ctx.fc2_grad_output_quantizer is not None:
rowwise_usage = True
columnwise_usage = True
if ctx.ub_overlap_ag and isinstance(
ctx.grad_fc2_output_quantizer,
ctx.fc2_grad_output_quantizer,
(Float8Quantizer, Float8CurrentScalingQuantizer),
):
# If data is in FP8 and communication is handled
# with Userbuffers, we compute FP8 transposes
# manually
columnwise_usage = False
ctx.grad_fc2_output_quantizer.set_usage(
ctx.fc2_grad_output_quantizer.set_usage(
rowwise=rowwise_usage,
columnwise=columnwise_usage,
)
......@@ -701,7 +733,7 @@ class _LayerNormMLP(torch.autograd.Function):
grad_output,
fc2_bias_grad,
) = TransformerEngineBaseModule.grad_output_preprocess(
ctx, grad_outputs[0], True, ctx.grad_fc2_output_quantizer
ctx, grad_outputs[0], True, ctx.fc2_grad_output_quantizer
)
# Launch tensor-parallel communication for FC1 GEMM input
......@@ -714,7 +746,7 @@ class _LayerNormMLP(torch.autograd.Function):
and not ctx.ub_bulk_dgrad
):
quantizer = None
if ctx.fp8:
if ctx.fp8 or ctx.debug:
quantizer = ctx.fc1_input_quantizer
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
# If data is in FP8, we compute FP8 transposes manually
......@@ -747,7 +779,10 @@ class _LayerNormMLP(torch.autograd.Function):
# 5 high-precision unfused: gemm, activation, FC1_bias + FC1_gemm
# 6 fp8 unfused: gemm, activation, FC1_bias + FC1_gemm
fc2_dgrad_gemm_gelu_fusion = (
not ctx.fp8 and (ctx.activation == "gelu") and (not ctx.bias_gelu_fusion)
not ctx.fp8
and (ctx.activation == "gelu")
and (not ctx.bias_gelu_fusion)
and (not ctx.debug)
)
# FC2 DGRAD; Unconditional
......@@ -763,7 +798,9 @@ class _LayerNormMLP(torch.autograd.Function):
layout="NN",
grad=True,
quantization_params=(
ctx.grad_fc1_output_quantizer if fc2_dgrad_gemm_gelu_fusion else None
ctx.fc1_grad_input_quantizer
if fc2_dgrad_gemm_gelu_fusion or ctx.debug
else None
), # high precision to activation
out_dtype=ctx.activation_dtype,
gelu=fc2_dgrad_gemm_gelu_fusion,
......@@ -798,7 +835,7 @@ class _LayerNormMLP(torch.autograd.Function):
if ctx.fuse_wgrad_accumulation
else ctx.activation_dtype
),
quantization_params=None, # wgrad in high precision
quantization_params=ctx.fc2_grad_weight_quantizer, # wgrad in high precision
layout="NT",
grad=grad_arg,
bias=fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None,
......@@ -817,15 +854,20 @@ class _LayerNormMLP(torch.autograd.Function):
# bias computation
fc1_bias_grad = None
fuse_gemm_and_bias_fc1_wgrad = False
if ctx.grad_fc1_output_quantizer is not None:
ctx.grad_fc1_output_quantizer.set_usage(rowwise=True, columnwise=True)
if ctx.fc1_grad_output_quantizer is not None:
ctx.fc1_grad_output_quantizer.set_usage(rowwise=True, columnwise=True)
if ctx.bias_gelu_fusion:
# Fusion: gemm, bias + gelu
assert ctx.activation == "gelu"
assert not ctx.fp8
fc1_bias_grad, dact = bgrad_dgelu_fused(fc2_dgrad, fc1_out_without_bias, fc1_bias)
if ctx.grad_fc1_output_quantizer is not None:
dact = ctx.grad_fc1_output_quantizer(dact)
if ctx.fc1_grad_output_quantizer is not None:
dact = ctx.fc1_grad_output_quantizer(dact)
elif ctx.debug:
dact_func = _act_func(ctx.activation)[1]
dact = dact_func(fc2_dgrad, fc1_out.to(ctx.activation_dtype), None)
fc1_bias_grad = dact.sum(dim=0)
dact = ctx.fc1_grad_output_quantizer(dact)
elif (
_act_func(ctx.activation, ctx.fp8_recipe if ctx.fp8 else None)[2] is not None
and ctx.fp8
......@@ -835,7 +877,7 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.activation, ctx.fp8_recipe if ctx.fp8 else None
)[2]
fc1_bias_grad, dact = dbias_dact_quantize_func(
fc2_dgrad, fc1_out.to(ctx.activation_dtype), ctx.grad_fc1_output_quantizer
fc2_dgrad, fc1_out.to(ctx.activation_dtype), ctx.fc1_grad_output_quantizer
) # quantize bgrad gelu fused
else:
# Fusion: gemm + gelu,
......@@ -849,12 +891,12 @@ class _LayerNormMLP(torch.autograd.Function):
if ctx.fp8:
# TODO float8 blockwise current scaling has no bgrad fusion for now
if isinstance(ctx.grad_fc1_output_quantizer, Float8BlockQuantizer):
if isinstance(ctx.fc1_grad_output_quantizer, Float8BlockQuantizer):
fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0)
dact = ctx.grad_fc1_output_quantizer(dact)
dact = ctx.fc1_grad_output_quantizer(dact)
else:
fc1_bias_grad, dact = tex.bgrad_quantize(
dact, ctx.grad_fc1_output_quantizer
dact, ctx.fc1_grad_output_quantizer
)
else:
fuse_gemm_and_bias_fc1_wgrad = (
......@@ -915,6 +957,7 @@ class _LayerNormMLP(torch.autograd.Function):
get_workspace(),
out=fc1_dgrad_bulk,
out_dtype=ctx.activation_dtype,
quantization_params=ctx.fc1_grad_input_quantizer,
layout="NN",
grad=True,
ub=ub_obj_fc1_dgrad,
......@@ -990,6 +1033,7 @@ class _LayerNormMLP(torch.autograd.Function):
else ctx.activation_dtype
),
layout="NT",
quantization_params=ctx.fc1_grad_weight_quantizer,
grad=fuse_gemm_and_bias_fc1_wgrad,
bias=fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None,
accumulate=accumulate_wgrad_into_param_main_grad,
......@@ -1123,14 +1167,18 @@ class _LayerNormMLP(torch.autograd.Function):
None, # fp8
None, # fp8_calibration
None, # fuse_wgrad_accumulation
None, # fc1_input_quantizer
None, # fc1_weight_quantizer
None, # fc2_input_quantizer
None, # fc2_weight_quantizer
None, # output_quantizer
None, # grad_fc2_output_quantizer
None, # grad_fc1_output_quantizer
None, # grad_input_quantizer
None, # fc1_input_quantizer,
None, # fc1_weight_quantizer,
None, # fc1_output_quantizer,
None, # fc1_grad_input_quantizer,
None, # fc1_grad_weight_quantizer,
None, # fc1_grad_output_quantizer,
None, # fc2_input_quantizer,
None, # fc2_weight_quantizer,
None, # fc2_output_quantizer,
None, # fc2_grad_input_quantizer,
None, # fc2_grad_weight_quantizer,
None, # fc2_grad_output_quantizer,
None, # cpu_offloading
None, # tp_group
None, # tp_size
......@@ -1156,6 +1204,7 @@ class _LayerNormMLP(torch.autograd.Function):
None, # fsdp_group
None, # module
None, # skip_fp8_weight_update
None, # debug
)
......@@ -1208,6 +1257,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
name: str, default = `None`
Name of the module, currently used only for debugging purposes.
Parallelism parameters
----------------------
......@@ -1277,6 +1328,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
zero_centered_gamma: bool = False,
device: Union[torch.device, str] = "cuda",
ub_overlap_ag: bool = False,
name: Optional[str] = None,
ub_overlap_rs: bool = False,
ub_overlap_rs_dgrad: bool = False,
ub_bulk_dgrad: bool = False,
......@@ -1306,6 +1358,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
and self.activation == "gelu"
and ((_ub_communicators is None) or (not get_ub("fc1_fprop").is_atomic_gemm()))
)
self.name = name
if TEDebugState.debug_enabled:
self._turn_off_unsupported_features_in_debug() # turn off userbuffers
if tp_group is None:
self.tp_size = tp_size
......@@ -1466,7 +1522,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
@no_torch_dynamo()
def forward(
self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None
self,
inp: torch.Tensor,
is_first_microbatch: Optional[bool] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
"""
Apply layer normalization to the input followed by a feedforward network (MLP Block).
......@@ -1489,6 +1547,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
first microbatch (since it is the first gradient being
produced)
"""
debug = TEDebugState.debug_enabled
if debug:
self._validate_name()
if FP8GlobalStateManager.fp8_graph_capturing():
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
......@@ -1503,17 +1564,35 @@ class LayerNormMLP(TransformerEngineBaseModule):
fp8_output = True
with self.prepare_forward(inp, num_gemms=2) as inp:
quantizers = (
self._get_quantizers(fp8_output)
if not debug
else self._get_debug_quantizers(fp8_output)
)
if debug:
if not any_feature_enabled(quantizers):
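# If no debug feature is enabled, fall back to the faster implementation with debug = False.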
quantizers = self._get_quantizers(fp8_output)
debug = False
if isinstance(self.fc1_weight, QuantizedTensor):
raise RuntimeError("FP8 weights are not supported in debug mode.")
# Get quantizers
(
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
output_quantizer,
grad_fc1_output_quantizer,
grad_fc2_output_quantizer,
grad_input_quantizer,
) = self._get_quantizers(fp8_output)
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
) = quantizers
# Get weight tensors
fc1_weight = self.fc1_weight
......@@ -1551,12 +1630,16 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.fuse_wgrad_accumulation,
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_fc1_output_quantizer,
grad_fc2_output_quantizer,
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
......@@ -1565,7 +1648,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.activation_dtype,
self.return_layernorm_output,
self.return_layernorm_output_gathered,
self.bias_gelu_nvfusion and not self.fp8,
self.bias_gelu_nvfusion and not self.fp8 and not debug,
self.set_parallel_mode,
torch.is_grad_enabled(),
self.fwd_ln_sm_margin if torch.is_grad_enabled() else self.inf_ln_sm_margin,
......@@ -1578,10 +1661,11 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.ub_overlap_rs_dgrad,
self.ub_bulk_dgrad,
self.ub_bulk_wgrad,
self.gemm_gelu_fusion,
self.gemm_gelu_fusion and not debug,
self.fsdp_group,
self,
skip_fp8_weight_update,
debug,
)
out = fwd_fn(*args)
......@@ -1603,13 +1687,17 @@ class LayerNormMLP(TransformerEngineBaseModule):
(
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
output_quantizer,
grad_fc1_output_quantizer,
grad_fc2_output_quantizer,
grad_input_quantizer,
) = [None] * 8
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
) = [None] * 12
if self.fp8:
fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
fc1_input_quantizer.internal = False # temporary
......@@ -1623,30 +1711,54 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT]
fc2_weight_quantizer.internal = True
if fp8_output:
output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_OUTPUT]
fc2_output_quantizer = self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM2_OUTPUT
]
if torch.is_grad_enabled():
grad_fc2_output_quantizer = self.quantizers["scaling_bwd"][
fc2_grad_output_quantizer = self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
]
grad_fc2_output_quantizer.internal = True
grad_fc1_output_quantizer = self.quantizers["scaling_bwd"][
fc2_grad_output_quantizer.internal = True
fc1_grad_output_quantizer = self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_INPUT1
]
grad_fc1_output_quantizer.internal = True
grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT2]
grad_input_quantizer.internal = True
fc1_grad_output_quantizer.internal = True
return (
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
output_quantizer,
grad_fc1_output_quantizer,
grad_fc2_output_quantizer,
grad_input_quantizer,
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
)
def _get_debug_quantizers(self, fp8_output):
from ...debug.pytorch.debug_quantization import DebugQuantizer
base_quantizers = list(self._get_quantizers(fp8_output))
assert TEDebugState.debug_enabled
def make_debug(prefix, offset):
labels = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"]
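# Labels follow the per-GEMM quantizer order (input, weight, output, grad_input, grad_weight, grad_output);
# dgrad and wgrad have no base quantizer to wrap here, so they receive None.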
return [
DebugQuantizer(
f"{self.name}.{prefix}",
label,
None if label in ("dgrad", "wgrad") else base_quantizers[i + offset],
self.tp_group,
)
for i, label in enumerate(labels)
]
return tuple(make_debug("fc1", 0) + make_debug("fc2", 6))
def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on current scaling recipe + layernorm_mlp."""
assert (
......@@ -1691,14 +1803,14 @@ class LayerNormMLP(TransformerEngineBaseModule):
tex.FP8FwdTensors.GEMM1_INPUT
].amax_reduction_group = self.tp_group
else:
# grad_fc2_output_quantizer: set configs about amax epsilon and power_2_scale for grad_fc2_output_quantizer
# fc2_grad_output_quantizer: set configs about amax epsilon and power_2_scale for fc2_grad_output_quantizer
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
# grad_fc1_output_quantizer: also set numerical configs for grad_fc1_output_quantizer
# fc1_grad_output_quantizer: also set numerical configs for fc1_grad_output_quantizer
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_INPUT1
].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale
......@@ -1706,7 +1818,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
tex.FP8BwdTensors.GRAD_INPUT1
].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
if self.sequence_parallel and self.set_parallel_mode:
# grad_fc2_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here
# fc2_grad_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
].with_amax_reduction = True
......
......@@ -28,11 +28,12 @@ from ..utils import (
clear_tensor_data,
divide,
init_method_constant,
requires_grad,
needs_quantized_gemm,
non_tn_fp8_gemm_supported,
assert_dim_for_fp8_exec,
nvtx_range_pop,
nvtx_range_push,
requires_grad,
)
from ..distributed import (
set_tensor_model_parallel_attributes,
......@@ -62,6 +63,8 @@ from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.utils import any_feature_enabled
__all__ = ["Linear"]
......@@ -84,8 +87,9 @@ class _Linear(torch.autograd.Function):
input_quantizer: Optional[Quantizer],
weight_quantizer: Optional[Quantizer],
output_quantizer: Optional[Quantizer],
grad_output_quantizer: Optional[Quantizer],
grad_input_quantizer: Optional[Quantizer],
grad_weight_quantizer: Optional[Quantizer],
grad_output_quantizer: Optional[Quantizer],
fuse_wgrad_accumulation: bool,
cpu_offloading: bool,
tp_group: Union[dist_group_type, None],
......@@ -106,6 +110,7 @@ class _Linear(torch.autograd.Function):
fsdp_group: Union[dist_group_type, None],
module: torch.nn.Module,
skip_fp8_weight_update: bool,
debug: Optional[bool] = False,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
......@@ -144,7 +149,7 @@ class _Linear(torch.autograd.Function):
"Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
" current scaling"
)
if fp8 or debug:
if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor")
if with_input_all_gather_nccl:
......@@ -196,9 +201,9 @@ class _Linear(torch.autograd.Function):
nvtx_range_pop(f"{nvtx_label}.input_cast_comm")
# Cast weight to expected dtype
if not fp8:
weightmat = cast_if_needed(weight, activation_dtype)
else:
weightmat = weight
if fp8 or debug:
# Configure quantizer
if weight_quantizer is not None:
columnwise_usage = is_grad_enabled and inp.requires_grad
......@@ -208,7 +213,6 @@ class _Linear(torch.autograd.Function):
and not in_fp8_activation_recompute_phase()
)
weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
# FP8 cast to workspace buffer
update_workspace = is_first_microbatch is None or is_first_microbatch
weightmat = module.get_weight_workspace(
......@@ -218,11 +222,14 @@ class _Linear(torch.autograd.Function):
update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update,
fsdp_group=fsdp_group,
workspace_dtype=activation_dtype,
)
else:
weightmat = cast_if_needed(weightmat, activation_dtype)
# Cast bias to expected dtype
bias_dtype = activation_dtype
if fp8 and activation_dtype == torch.float32:
if needs_quantized_gemm(inputmat_total) and activation_dtype == torch.float32:
bias_dtype = torch.bfloat16
bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias
......@@ -343,12 +350,14 @@ class _Linear(torch.autograd.Function):
ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
ctx.force_hp_input_gather = force_hp_input_gather
ctx.input_quantizer = input_quantizer
ctx.grad_output_quantizer = grad_output_quantizer
ctx.grad_input_quantizer = grad_input_quantizer
ctx.grad_weight_quantizer = grad_weight_quantizer
ctx.grad_output_quantizer = grad_output_quantizer
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
if fuse_wgrad_accumulation and weight.requires_grad:
ctx.main_grad = weight.main_grad
ctx.debug = debug
ctx.cpu_offloading = cpu_offloading
ctx.is_first_microbatch = is_first_microbatch
ctx.use_bias = bias is not None
......@@ -528,7 +537,7 @@ class _Linear(torch.autograd.Function):
inputmat_total_work = None
if ctx.backward_input_needs_gather and not ctx.ub_bulk_dgrad:
quantizer = None
if ctx.fp8:
if ctx.fp8 or ctx.debug:
quantizer = ctx.input_quantizer
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
# If data is in FP8, we compute FP8 transposes manually
......@@ -564,7 +573,6 @@ class _Linear(torch.autograd.Function):
# Update quantizer
if ctx.grad_input_quantizer is not None:
ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False)
# dgrad GEMM
nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD
......@@ -678,6 +686,7 @@ class _Linear(torch.autograd.Function):
out=main_grad if ctx.fuse_wgrad_accumulation else None,
use_split_accumulator=use_split_accumulator,
accumulate=accumulate_wgrad_into_param_main_grad,
quantization_params=ctx.grad_weight_quantizer,
ub=ub_obj_wgrad,
ub_type=ub_type_wgrad,
extra_output=rs_out,
......@@ -753,8 +762,9 @@ class _Linear(torch.autograd.Function):
None, # input_quantizer
None, # weight_quantizer
None, # output_quantizer
None, # grad_output_quantizer
None, # grad_input_quantizer
None, # grad_weight_quantizer
None, # grad_output_quantizer
None, # fuse_wgrad_accumulation
None, # cpu_offloading
None, # tp_group
......@@ -775,6 +785,7 @@ class _Linear(torch.autograd.Function):
None, # fsdp_group
None, # module
None, # skip_fp8_weight_update
None, # debug
)
......@@ -810,6 +821,8 @@ class Linear(TransformerEngineBaseModule):
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
name: str, default = `None`
Name of the module, currently used only for debugging purposes.
Parallelism parameters
----------------------
......@@ -871,6 +884,7 @@ class Linear(TransformerEngineBaseModule):
ub_bulk_dgrad: bool = False,
ub_bulk_wgrad: bool = False,
ub_name: Optional[str] = None,
name: Optional[str] = None,
) -> None:
super().__init__()
......@@ -883,6 +897,10 @@ class Linear(TransformerEngineBaseModule):
self.apply_bias = bias and not return_bias
self.get_rng_state_tracker = get_rng_state_tracker
self.rng_tracker_name = rng_tracker_name
self.name = name
if TEDebugState.debug_enabled:
self._turn_off_unsupported_features_in_debug() # turn off userbuffers
if device == "meta":
assert parameters_split is None, "Cannot split module parameters on 'meta' device."
......@@ -1126,6 +1144,10 @@ class Linear(TransformerEngineBaseModule):
first microbatch (since it is the first gradient being
produced)
"""
debug = TEDebugState.debug_enabled
if debug:
self._validate_name()
if FP8GlobalStateManager.fp8_graph_capturing():
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
else:
......@@ -1161,13 +1183,28 @@ class Linear(TransformerEngineBaseModule):
else:
bias_tensor = None
quantizers = (
self._get_quantizers(fp8_output, fp8_grad)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad)
)
if debug:
if not any_feature_enabled(quantizers):
# If no debug feature is enabled, fall back to the faster implementation with debug = False.
quantizers = self._get_quantizers(fp8_output, fp8_grad)
debug = False
if isinstance(weight_tensor, QuantizedTensor):
raise RuntimeError("FP8 weights are not supported in debug mode.")
(
input_quantizer,
weight_quantizer,
output_quantizer,
grad_output_quantizer,
grad_input_quantizer,
) = self._get_quantizers(fp8_output, fp8_grad)
grad_weight_quantizer,
grad_output_quantizer,
) = quantizers
# Make sure weight tensor has correct quantizer
# Note: Quantizer might have changed if quantization
......@@ -1191,8 +1228,9 @@ class Linear(TransformerEngineBaseModule):
input_quantizer,
weight_quantizer,
output_quantizer,
grad_output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(),
self.tp_group,
......@@ -1213,6 +1251,7 @@ class Linear(TransformerEngineBaseModule):
self.fsdp_group,
self,
skip_fp8_weight_update,
debug,
)
out = linear_fn(*args)
if self.gemm_bias_unfused_add:
......@@ -1224,8 +1263,9 @@ class Linear(TransformerEngineBaseModule):
def _get_quantizers(self, fp8_output, fp8_grad):
if not self.fp8:
return [None] * 5
return [None] * 6
grad_input_quantizer = None
grad_weight_quantizer = None
grad_output_quantizer = None
output_quantizer = None
input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
......@@ -1243,8 +1283,20 @@ class Linear(TransformerEngineBaseModule):
input_quantizer,
weight_quantizer,
output_quantizer,
grad_output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
)
def _get_debug_quantizers(self, fp8_output, fp8_grad):
original_quantizers = self._get_quantizers(fp8_output, fp8_grad)
assert TEDebugState.debug_enabled
from ...debug.pytorch.debug_quantization import DebugQuantizer
names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"]
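# Label order mirrors _get_quantizers: activation -> input, weight -> weight, output -> output,
# dgrad -> grad_input, wgrad -> grad_weight, gradient -> grad_output.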
return tuple(
DebugQuantizer(self.name, name, q, self.tp_group)
for name, q in zip(names, original_quantizers)
)
def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
......
......@@ -42,3 +42,27 @@ def _make_module_cast_func(dtype):
torch.nn.Module.float = _make_module_cast_func(torch.float32)
torch.nn.Module.half = _make_module_cast_func(torch.float16)
torch.nn.Module.bfloat16 = _make_module_cast_func(torch.bfloat16)
def get_all_tensor_types():
"""
Get all tensor-like types that can be used in TE.
"""
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8TensorBase
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8TensorBase
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
Float8BlockwiseQTensor,
Float8BlockwiseQTensorBase,
)
all_tensor_types = [
torch.Tensor,
torch.nn.Parameter,
Float8Tensor,
Float8TensorBase,
MXFP8Tensor,
MXFP8TensorBase,
Float8BlockwiseQTensor,
Float8BlockwiseQTensorBase,
]
return all_tensor_types
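A minimal usage sketch (the `is_te_tensor_like` helper below is illustrative, not part of this change): since the function returns a plain list of classes, it can feed an `isinstance` check directly.
def is_te_tensor_like(obj) -> bool:
    # True for torch.Tensor / Parameter and the TE quantized tensor classes listed above
    return isinstance(obj, tuple(get_all_tensor_types()))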
......@@ -27,12 +27,14 @@ class _FromFloat8Func(torch.autograd.Function):
dtype: torch.dtype,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
dtype = torch_to_transformer_engine_dtype[dtype]
te_dtype = torch_to_transformer_engine_dtype[dtype]
# Make sure FP8 data is in expected format
if tensor._data is not None:
if tensor._data.numel() == 0:
return torch.empty_like(tensor._data, dtype=dtype)
# Cast from FP8
return tex.dequantize(tensor, dtype)
return tex.dequantize(tensor, te_dtype)
raise NotImplementedError("Casting back from the transpose not implemented yet!")
......
......@@ -37,7 +37,8 @@ def prepare_for_saving(
def restore_from_saved(
tensors: list[Optional[Any]],
saved_tensors: list[Optional[Union[torch.Tensor, torch.nn.Parameter]]],
) -> list[Optional[Any]]:
return_saved_tensors: bool = False,
) -> list[Optional[Any]] | tuple[list[Optional[Any]], list[Optional[torch.Tensor]]]:
"""Recombine the tensor data and metadata during backward pass."""
tensor_objects = []
for tensor in tensors:
......@@ -47,6 +48,9 @@ def restore_from_saved(
else:
saved_tensors = tensor.restore_from_saved(saved_tensors)
tensor_objects.append(tensor)
if return_saved_tensors:
return tensor_objects, saved_tensors
return tensor_objects
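A hedged sketch of the new `return_saved_tensors` flag (variable names are illustrative): when set, the call also returns whatever entries of `saved_tensors` were not consumed while rebuilding the tensor objects, so callers can keep processing the remaining entries.
objs = restore_from_saved(metadata, saved)                               # previous behaviour
objs, leftover = restore_from_saved(metadata, saved, return_saved_tensors=True)
# `leftover` holds the tail of `saved` that was not needed to reconstruct `objs`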
......@@ -113,7 +117,11 @@ class Quantizer(abc.ABC):
"""Quantize tensor in-place"""
def quantize(
self, tensor: torch.Tensor, *, out: Optional[QuantizedTensor] = None
self,
tensor: torch.Tensor,
*,
out: Optional[QuantizedTensor] = None,
dtype: Optional[torch.dtype] = None, # pylint: disable=unused-argument # used by override
) -> QuantizedTensor:
"""Quantize tensor"""
if out is not None:
......
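The extra keyword is accepted (and ignored) by the base `Quantizer.quantize` so that subclass overrides, such as the debug quantizer introduced in this PR, can take an output dtype without breaking existing call sites. A hedged call-site sketch:
y = quantizer.quantize(x)                        # existing calls are unchanged
y = quantizer.quantize(x, dtype=torch.bfloat16)  # only overrides that honour `dtype` act on it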
......@@ -11,6 +11,7 @@ from typing import Callable, List, Optional, Tuple, Union
import torch
from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm
from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.pytorch.attention import (
MultiheadAttention,
)
......@@ -33,6 +34,7 @@ from transformer_engine.pytorch.constants import (
dist_group_type,
)
from transformer_engine.pytorch.distributed import get_distributed_world_size
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
warnings.filterwarnings("module", category=DeprecationWarning, module="transformer")
......@@ -184,6 +186,8 @@ class TransformerLayer(torch.nn.Module):
head size. Note that these formats are very closely
related to the `qkv_format` in the `MultiHeadAttention`
and `DotProductAttention` modules.
name: str, default = `None`
Name of the module, currently used only for debugging purposes.
Parallelism parameters
----------------------
......@@ -277,6 +281,7 @@ class TransformerLayer(torch.nn.Module):
normalization: str = "LayerNorm",
device: Union[torch.device, str] = "cuda",
attn_input_format: str = "sbhd",
name: Optional[str] = None,
) -> None:
super().__init__()
......@@ -336,6 +341,8 @@ class TransformerLayer(torch.nn.Module):
self.attn_input_format = attn_input_format
self.name = name
attention_args = (
hidden_size,
num_attention_heads,
......@@ -376,6 +383,7 @@ class TransformerLayer(torch.nn.Module):
return_bias=not self.parallel_attention_mlp,
normalization=normalization,
device=device,
name=name + ".self_attention" if name is not None else None,
)
if layer_type == "decoder":
......@@ -389,6 +397,7 @@ class TransformerLayer(torch.nn.Module):
return_bias=True,
normalization=normalization,
device=device,
name=name + ".inter_attention" if name is not None else None,
)
# LayerNorm -> activation(Linear + Bias) -> Linear
......@@ -423,6 +432,7 @@ class TransformerLayer(torch.nn.Module):
activation=activation,
normalization=normalization,
device=device,
name=name + ".layernorm_mlp" if name is not None else None,
)
self.hidden_dropout = hidden_dropout
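A hedged sketch of how the new `name` argument cascades from a `TransformerLayer` into the submodules constructed above (suffixes taken from the constructor calls in this hunk):
layer = TransformerLayer(1024, 4096, 16, name="layer_0")
# self-attention is created with name="layer_0.self_attention"
# decoder layers also create inter-attention with name="layer_0.inter_attention"
# the MLP block is created with name="layer_0.layernorm_mlp"
# with name=None (the default), every submodule name stays None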
......@@ -679,6 +689,9 @@ class TransformerLayer(torch.nn.Module):
enc_dec_attn_mask[i].dtype == torch.bool for i in range(len(enc_dec_attn_mask))
), "Encoder-decoder attention mask must be boolean tensor(s)"
if TEDebugState.debug_enabled:
TransformerEngineBaseModule._validate_name(self)
# For AMP
if torch.is_autocast_enabled():
hidden_states = cast_if_needed(hidden_states, torch.get_autocast_gpu_dtype())
......
......@@ -11,6 +11,7 @@ from typing import Any, Callable, List, Optional, Tuple
import torch
import transformer_engine.pytorch.cpp_extensions as ext
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
from .tensor.quantized_tensor import QuantizedTensor
......@@ -329,6 +330,19 @@ def round_up_to_nearest_multiple(value, multiple):
return ((value + multiple - 1) // multiple) * multiple
def needs_quantized_gemm(obj, rowwise=True):
"""Used to check if obj will need quantized gemm or normal gemm."""
if isinstance(obj, DebugQuantizedTensor):
return type(obj.get_tensor(not rowwise)) not in [ # pylint: disable=unidiomatic-typecheck
torch.Tensor,
torch.nn.Parameter,
]
return type(obj) not in [
torch.Tensor,
torch.nn.Parameter,
] # pylint: disable=unidiomatic-typecheck
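A hedged sketch of how this check is meant to steer the GEMM path (the two helper names are placeholders, not TE APIs): plain torch tensors, including those wrapped by a DebugQuantizedTensor, take the regular GEMM, anything else takes the quantized GEMM.
if needs_quantized_gemm(inputmat_total):
    out = run_quantized_gemm(inputmat_total, weightmat)   # placeholder for the quantized (FP8/MXFP8) GEMM call
else:
    out = run_regular_gemm(inputmat_total, weightmat)     # placeholder for the high-precision GEMM call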
@functools.lru_cache(maxsize=None)
def _nvtx_enabled() -> bool:
"""Check if NVTX range profiling is enabled"""
......