Unverified Commit beaecf84 authored by Paweł Gadziński's avatar Paweł Gadziński Committed by GitHub
Browse files

[Pytorch] NVIDIA-DL-Framework-Inspect support – part 1 – core (#1614)



* add
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* weight workspace fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* docs fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* file i forgot
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* lint fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* Update transformer_engine/debug/pytorch/utils.py
Co-authored-by: default avatarPrzemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: default avatarPaweł Gadziński <62263673+pggPL@users.noreply.github.com>

* setup fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* setup fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* Update transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py
Co-authored-by: default avatarPrzemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: default avatarPaweł Gadziński <62263673+pggPL@users.noreply.github.com>

* all tensor types
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* removed check
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* move error
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* _reset
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* Update transformer_engine/pytorch/module/linear.py
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: default avatarPaweł Gadziński <62263673+pggPL@users.noreply.github.com>

* name documentation
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* added blockwise quantizer
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* make debug option optional
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* Update transformer_engine/pytorch/tensor/quantized_tensor.py
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: default avatarPaweł Gadziński <62263673+pggPL@users.noreply.github.com>

* names fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: default avatarPaweł Gadziński <62263673+pggPL@users.noreply.github.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarPrzemyslaw Tredak <ptrendx@gmail.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
parent 0994fb48
...@@ -110,6 +110,10 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: ...@@ -110,6 +110,10 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
if "pytorch" in frameworks: if "pytorch" in frameworks:
install_reqs.extend(["torch>=2.1"]) install_reqs.extend(["torch>=2.1"])
install_reqs.append(
"nvdlfw-inspect @"
" git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git@v0.1#egg=nvdlfw-inspect"
)
# Blackwell is not supported as of Triton 3.2.0, need custom internal build # Blackwell is not supported as of Triton 3.2.0, need custom internal build
# install_reqs.append("triton") # install_reqs.append("triton")
test_reqs.extend(["numpy", "torchvision", "prettytable", "PyYAML"]) test_reqs.extend(["numpy", "torchvision", "prettytable", "PyYAML"])
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Top level package for numerical debugging."""
try:
    from . import pytorch
    from .pytorch.debug_state import set_weight_tensor_tp_group_reduce
except ImportError:
    # Debugging support is optional: if the debug dependencies
    # (e.g. nvdlfw_inspect) are not installed, the debug submodule is
    # simply unavailable and importing the package must not fail.
    pass
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
This file contains DebugQuantizer and DebugQuantizedTensor objects,
which are wrappers over Quantizer and QuantizedTensor.
These wrappers add logic related to debugging, using the nvdlfw_inspect package.
"""
from __future__ import annotations
from typing import Optional, Tuple, Iterable, Union
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.quantized_tensor import (
QuantizedTensor,
Quantizer,
prepare_for_saving,
restore_from_saved,
)
aten = torch.ops.aten

# Maps a tensor name to the pair (rowwise_gemm_name, columnwise_gemm_name):
# the GEMM that consumes the tensor in row-wise layout and the GEMM that
# consumes it in column-wise layout.  None means the tensor is not used
# column-wise (e.g. pure output tensors).
_tensor_to_gemm_names_map = {
    "weight": ["fprop", "dgrad"],
    "activation": ["fprop", "wgrad"],
    "output": ["fprop", None],
    "gradient": ["dgrad", "wgrad"],
    "wgrad": ["wgrad", None],
    "dgrad": ["dgrad", None],
}

# Labels for the three possible per-tensor processing plans
# (see DebugQuantizer.get_tensors_plan()).
API_CALL_MODIFY = "modify_tensor()"
STANDARD_FP8_QUANTIZE = "FP8 Quantize"
HIGH_PRECISION = "High Precision"
class DebugQuantizer(Quantizer):
    """
    DebugQuantizer is a Quantizer object used for debugging with nvidia-dlframework-inspect.
    It allows adding custom calls inside the quantization process - which enables modifying tensors
    or gathering tensor stats.
    """

    def __init__(
        self,
        layer_name: str,
        tensor_name: str,
        parent_quantizer: Optional[Quantizer],
        tp_group: torch.distributed.ProcessGroup,
    ):
        # Imported lazily so nvdlfw_inspect is only required when debugging is active.
        import nvdlfw_inspect.api as debug_api

        super().__init__(rowwise=True, columnwise=True)
        self.layer_name = layer_name
        self.tensor_name = tensor_name
        # Quantizer used for the actual FP8 quantization; None means
        # no FP8 quantization is available for this tensor.
        self.parent_quantizer = parent_quantizer
        self.tp_group = tp_group  # used in inspect_tensor calls
        self.iteration = debug_api.DEBUG_MANAGER._trainer_iteration_count
        self.rowwise_gemm_name, self.columnwise_gemm_name = _tensor_to_gemm_names_map[tensor_name]

        # The values of the inspect_tensor_enabled, inspect_tensor_postquantize_enabled,
        # rowwise_tensor_plan, and columnwise_tensor_plan are computed.
        # These fields indicate the path where API calls will be inserted.
        #
        # inspect_tensor*_enabled are bool fields,
        # indicating whether some feature will need to run inspect_tensor_* calls.
        #
        # *_tensor_plan are one of [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE, HIGH_PRECISION]
        # determining what will happen when the quantizer is used for that tensor.
        self.output_tensor = tensor_name in ["output", "wgrad", "dgrad"]
        if self.output_tensor:
            self.inspect_tensor_enabled, self.rowwise_tensor_plan = (
                self.get_plans_for_output_tensors()
            )
        else:
            (
                self.inspect_tensor_enabled,
                self.inspect_tensor_postquantize_enabled_rowwise,
                self.inspect_tensor_postquantize_enabled_columnwise,
            ) = self.get_enabled_look_at_tensors()
            self.rowwise_tensor_plan, self.columnwise_tensor_plan = self.get_tensors_plan()

        self.log_messages_about_plans()

    def get_plans_for_output_tensors(self) -> Tuple[bool, str]:
        """
        Returns tuple (inspect_tensor_enabled: bool, plan: str). Plan is one of the
        API_CALL_MODIFY or HIGH_PRECISION, because debug quantizer does not support
        gemm output in FP8.
        """
        import nvdlfw_inspect.api as debug_api

        inspect_tensor_enabled = debug_api.transformer_engine.inspect_tensor_enabled(
            layer_name=self.layer_name, tensor_name=self.tensor_name, iteration=self.iteration
        )
        modify_enabled = debug_api.transformer_engine.modify_tensor_enabled(
            layer_name=self.layer_name,
            gemm=self.rowwise_gemm_name,
            tensor_name=self.tensor_name,
            iteration=self.iteration,
        )
        plan = API_CALL_MODIFY if modify_enabled else HIGH_PRECISION
        return inspect_tensor_enabled, plan

    def get_enabled_look_at_tensors(self):
        """
        Returns a tuple of booleans determining which of the inspect_tensor_*(...)
        API calls should be made for this tensor:
        (inspect_tensor_enabled, postquantize_enabled_rowwise, postquantize_enabled_columnwise).
        """
        import nvdlfw_inspect.api as debug_api

        inspect_tensor_enabled = debug_api.transformer_engine.inspect_tensor_enabled(
            layer_name=self.layer_name, tensor_name=self.tensor_name, iteration=self.iteration
        )
        inspect_tensor_postquantize_enabled_rowwise = (
            debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
                layer_name=self.layer_name,
                tensor_name=self.tensor_name,
                iteration=self.iteration,
                gemm=self.rowwise_gemm_name,
            )
        )
        inspect_tensor_postquantize_enabled_columnwise = (
            debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
                layer_name=self.layer_name,
                tensor_name=self.tensor_name,
                iteration=self.iteration,
                gemm=self.columnwise_gemm_name,
            )
        )
        return (
            inspect_tensor_enabled,
            inspect_tensor_postquantize_enabled_rowwise,
            inspect_tensor_postquantize_enabled_columnwise,
        )

    def get_tensors_plan(self):
        """
        Returns (rowwise_plan, columnwise_plan). Each element of the tuple is one of
        API_CALL_MODIFY, STANDARD_FP8_QUANTIZE, or HIGH_PRECISION, indicating the behavior
        of this quantizer with respect to these tensors.
        """
        import nvdlfw_inspect.api as debug_api

        rowwise_plan = None
        columnwise_plan = None

        # Priority order: modify_tensor() > FP8 quantize > high precision.
        modify_rowwise = debug_api.transformer_engine.modify_tensor_enabled(
            layer_name=self.layer_name,
            gemm=self.rowwise_gemm_name,
            tensor_name=self.tensor_name,
            iteration=self.iteration,
        )
        if modify_rowwise:
            rowwise_plan = API_CALL_MODIFY
        else:
            # FP8 quantization is only possible when a parent quantizer exists.
            if self.parent_quantizer is not None:
                fp8_quantize = debug_api.transformer_engine.fp8_gemm_enabled(
                    layer_name=self.layer_name,
                    gemm=self.rowwise_gemm_name,
                    iteration=self.iteration,
                )
                if fp8_quantize:
                    rowwise_plan = STANDARD_FP8_QUANTIZE
        if rowwise_plan is None:
            rowwise_plan = HIGH_PRECISION

        # The columnwise plan is computed only if the tensor is used by a second GEMM.
        if self.columnwise_gemm_name is not None:
            modify_columnwise = debug_api.transformer_engine.modify_tensor_enabled(
                layer_name=self.layer_name,
                gemm=self.columnwise_gemm_name,
                tensor_name=self.tensor_name,
                iteration=self.iteration,
            )
            if modify_columnwise:
                columnwise_plan = API_CALL_MODIFY
            else:
                if self.parent_quantizer is not None:
                    fp8_quantize = debug_api.transformer_engine.fp8_gemm_enabled(
                        layer_name=self.layer_name,
                        gemm=self.columnwise_gemm_name,
                        iteration=self.iteration,
                    )
                    if fp8_quantize:
                        columnwise_plan = STANDARD_FP8_QUANTIZE
            if columnwise_plan is None:
                columnwise_plan = HIGH_PRECISION

        return rowwise_plan, columnwise_plan

    def log_messages_about_plans(self):
        """
        Logs the messages about the plans for each of the tensors.
        """
        import nvdlfw_inspect.api as debug_api

        debug_api.log_message(
            f"Tensor: {self.tensor_name}, gemm {self.rowwise_gemm_name} -"
            f" {self.rowwise_tensor_plan}",
            layer_name=self.layer_name,
            extra_cachable_args=(self.rowwise_gemm_name, self.tensor_name),
        )
        debug_api.log_message(
            f"Tensor: {self.tensor_name}, gemm {self.columnwise_gemm_name} -"
            f" {self.columnwise_tensor_plan}",
            layer_name=self.layer_name,
            extra_cachable_args=(self.columnwise_gemm_name, self.tensor_name),
        )

    def _call_inspect_tensor_api(
        self, tensor, rowwise_gemm_tensor=None, columnwise_gemm_tensor=None
    ):
        """Invokes the enabled inspect_tensor / inspect_tensor_postquantize API calls
        for the raw tensor and (for non-output tensors) its post-quantization forms."""
        import nvdlfw_inspect.api as debug_api

        args = {
            "layer_name": self.layer_name,
            "tensor": tensor,
            "tensor_name": self.tensor_name,
            "iteration": debug_api.DEBUG_MANAGER._trainer_iteration_count,
            "tp_group": self.tp_group,
        }
        if tensor is not None and self.inspect_tensor_enabled:
            debug_api.transformer_engine.inspect_tensor(**args)
        # Output tensors have no post-quantization inspection.
        if self.output_tensor:
            return
        if (
            self.rowwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE]
            and self.inspect_tensor_postquantize_enabled_rowwise
        ):
            args["tensor"] = rowwise_gemm_tensor
            args["rowwise"] = True
            debug_api.transformer_engine.inspect_tensor_postquantize(**args)
        if (
            self.columnwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE]
            and self.inspect_tensor_postquantize_enabled_columnwise
        ):
            args["tensor"] = columnwise_gemm_tensor
            args["rowwise"] = False
            debug_api.transformer_engine.inspect_tensor_postquantize(**args)

    def quantize(
        self,
        tensor: torch.Tensor,
        *,
        out: Optional[Union[torch.Tensor, DebugQuantizedTensor]] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Returns DebugQuantizedTensor object."""
        import nvdlfw_inspect.api as debug_api

        assert not self.output_tensor
        if out is not None:
            # NOTE(review): passes `self` (the quantizer) as the `dst` argument of
            # update_quantized, not `out`, and update_quantized has no return value —
            # presumably this should be `self.update_quantized(tensor, out)`; confirm.
            return self.update_quantized(tensor, self)

        # 1. If there is fp8 quantization in at least one of the gemms,
        # the quantization using the self.parent_quantizer is performed.
        # rowwise gemm corresponds to the rowwise_usage in fp8, similarly with columnwise
        rowwise_gemm_quantize = (
            self.rowwise_usage and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE
        )
        columnwise_gemm_quantize = (
            self.columnwise_usage and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE
        )
        if columnwise_gemm_quantize and not rowwise_gemm_quantize:
            rowwise_gemm_quantize = True  # only columnwise quantization not implemented

        rowwise_gemm_tensor, columnwise_gemm_tensor = None, None
        if STANDARD_FP8_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]:
            self.parent_quantizer.set_usage(
                rowwise=True,
                columnwise=columnwise_gemm_quantize,  # columnwise usage only is not supported
            )
            quantized_tensor = self.parent_quantizer(tensor)
            # if both rowwise_tensor_plan and columnwise_tensor_plan need to be in fp8,
            # one tensor with columnwise=True and rowwise=True is computed
            # and both rowwise_tensor_plan and columnwise_tensor_plan point to it.
            if self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE:
                rowwise_gemm_tensor = quantized_tensor
            if self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE:
                columnwise_gemm_tensor = quantized_tensor

        # 2. modify_tensor() is called, if it is used.
        if self.columnwise_tensor_plan == API_CALL_MODIFY:
            columnwise_gemm_tensor = debug_api.transformer_engine.modify_tensor(
                layer_name=self.layer_name,
                tensor_name=self.tensor_name,
                gemm=self.columnwise_gemm_name,
                tensor=tensor,
                default_quantizer=self.parent_quantizer,
                iteration=self.iteration,
                dtype=dtype,
            )
            if columnwise_gemm_tensor.dtype != dtype:
                raise ValueError("Dtype does not match the output of the modify_tensor call")
        if self.rowwise_tensor_plan == API_CALL_MODIFY:
            rowwise_gemm_tensor = debug_api.transformer_engine.modify_tensor(
                layer_name=self.layer_name,
                tensor_name=self.tensor_name,
                gemm=self.rowwise_gemm_name,
                tensor=tensor,
                default_quantizer=self.parent_quantizer,
                iteration=self.iteration,
                dtype=dtype,
            )
            if rowwise_gemm_tensor.dtype != dtype:
                raise ValueError("Dtype does not match the output of the modify_tensor call")

        # 3. If some tensors still are not defined we use high precision tensor.
        if self.rowwise_tensor_plan == HIGH_PRECISION:
            rowwise_gemm_tensor = tensor.to(dtype)
        if self.columnwise_tensor_plan == HIGH_PRECISION:
            columnwise_gemm_tensor = tensor.to(dtype)

        self._call_inspect_tensor_api(tensor, rowwise_gemm_tensor, columnwise_gemm_tensor)

        # sometimes we may want to return simple tensor with only rowwise_gemm
        if self.tensor_name in ["wgrad", "dgrad", "output"]:
            return rowwise_gemm_tensor

        return DebugQuantizedTensor(
            rowwise_gemm_tensor=rowwise_gemm_tensor,
            columnwise_gemm_tensor=columnwise_gemm_tensor,
            quantizer=self,
            layer_name=self.layer_name,
            tensor_name=self.tensor_name,
        )

    def process_gemm_output(self, tensor: torch.Tensor):
        """This call is invoked after the gemm to inspect and modify the output tensor."""
        import nvdlfw_inspect.api as debug_api

        assert self.parent_quantizer is None, "FP8 output is not supported for debug=True."
        assert self.output_tensor
        tensor_to_gemm = {"output": "fprop", "wgrad": "wgrad", "dgrad": "dgrad"}
        if self.rowwise_tensor_plan == API_CALL_MODIFY:
            tensor = debug_api.transformer_engine.modify_tensor(
                layer_name=self.layer_name,
                gemm=tensor_to_gemm[self.tensor_name],
                tensor_name=self.tensor_name,
                tensor=tensor,
                iteration=self.iteration,
                default_quantizer=self.parent_quantizer,
            )
        self._call_inspect_tensor_api(tensor)
        return tensor

    def make_empty(
        self,
        shape: Iterable[int],
        *,
        dtype: torch.dtype = torch.float32,
        device: Optional[torch.device] = None,
    ) -> QuantizedTensor:
        """Override make_empty() from Quantizer class."""
        # Delegate to the parent quantizer when FP8 is in use;
        # otherwise fall back to a plain high-precision tensor.
        if self.parent_quantizer is not None:
            return self.parent_quantizer.make_empty(shape, dtype=dtype, device=device)
        return torch.empty(shape, dtype=dtype, device=device)

    def calibrate(self, tensor: torch.Tensor):
        """Calibration override, should not be invoked."""
        raise RuntimeError("[NVTORCH-INSPECT ERROR] Calibration with debug is not supported")

    def update_quantized(
        self,
        src: torch.Tensor,
        dst: QuantizedTensor,
        *,
        noop_flag: Optional[torch.Tensor] = None,
    ) -> QuantizedTensor:
        """Update quantized tensor - used in weight caching."""
        import nvdlfw_inspect.api as debug_api

        assert noop_flag is None, "CUDA Graphs are not supported with debug=True!"

        updated_rowwise_gemm = False
        if self.parent_quantizer is not None:
            if (
                dst.rowwise_gemm_tensor is not None
                and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE
            ):
                if hasattr(dst.rowwise_gemm_tensor, "quantize_"):
                    dst.rowwise_gemm_tensor.quantize_(src, noop_flag=None)
                else:
                    tex.quantize(src, self.parent_quantizer, dst.rowwise_gemm_tensor, None)
                updated_rowwise_gemm = True
            # If rowwise and columnwise point at the same object, the update
            # above already covered the columnwise tensor.
            if (
                dst.columnwise_gemm_tensor is not None
                and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE
                and not updated_rowwise_gemm
            ):
                if hasattr(dst.columnwise_gemm_tensor, "quantize_"):
                    dst.columnwise_gemm_tensor.quantize_(src, noop_flag=None)
                else:
                    tex.quantize(src, self.parent_quantizer, dst.columnwise_gemm_tensor, None)

        if self.columnwise_tensor_plan == API_CALL_MODIFY:
            out = debug_api.transformer_engine.modify_tensor(
                layer_name=self.layer_name,
                tensor_name=self.tensor_name,
                gemm=self.columnwise_gemm_name,
                tensor=src,
                default_quantizer=self.parent_quantizer,
                out=dst.columnwise_gemm_tensor,
                iteration=self.iteration,
            )
            assert out is None, (
                "API call debug_api.transformer_engine.modify_tensor with out != None should"
                " return None"
            )
        if self.rowwise_tensor_plan == API_CALL_MODIFY:
            debug_api.transformer_engine.modify_tensor(
                layer_name=self.layer_name,
                tensor_name=self.tensor_name,
                gemm=self.rowwise_gemm_name,
                tensor=src,
                default_quantizer=self.parent_quantizer,
                out=dst.rowwise_gemm_tensor,
                iteration=self.iteration,
            )

        if self.rowwise_tensor_plan == HIGH_PRECISION:
            dst.rowwise_gemm_tensor.copy_(src)
        if self.columnwise_tensor_plan == HIGH_PRECISION:
            # if they are the same tensor object, it is sufficient to update one
            if dst.columnwise_gemm_tensor is not dst.rowwise_gemm_tensor:
                dst.columnwise_gemm_tensor.copy_(src)

        self._call_inspect_tensor_api(src, dst.rowwise_gemm_tensor, dst.columnwise_gemm_tensor)

    def any_feature_enabled(self) -> bool:
        """Returns bool if there is at least one API call enabled."""
        if self.output_tensor:
            return self.inspect_tensor_enabled or self.rowwise_tensor_plan == API_CALL_MODIFY
        if (
            self.inspect_tensor_enabled
            or self.inspect_tensor_postquantize_enabled_rowwise
            or self.inspect_tensor_postquantize_enabled_columnwise
            or self.rowwise_tensor_plan == API_CALL_MODIFY
            or self.columnwise_tensor_plan == API_CALL_MODIFY
        ):
            return True
        # With a parent quantizer present, any deviation from the standard
        # FP8 quantize path means a debug feature changed the plan.
        if self.parent_quantizer is not None:
            if self.rowwise_tensor_plan != STANDARD_FP8_QUANTIZE:
                return True
            if self.columnwise_tensor_plan != STANDARD_FP8_QUANTIZE:
                return True
        return False
class DebugQuantizedTensor:
    """
    Class containing quantized tensors after debug. Depending on configuration
    it can contain one or two different objects. These objects can be accessed by the method
    get_tensor().
    """

    def __init__(
        self,
        rowwise_gemm_tensor,
        columnwise_gemm_tensor,
        quantizer,
        layer_name=None,
        tensor_name=None,
    ):
        # Tensor consumed by the row-wise GEMM.
        self.rowwise_gemm_tensor = rowwise_gemm_tensor
        # Tensor consumed by the column-wise GEMM; may alias rowwise_gemm_tensor.
        self.columnwise_gemm_tensor = columnwise_gemm_tensor
        # The DebugQuantizer that produced this object.
        self.quantizer = quantizer
        self._layer_name = layer_name
        self._tensor_name = tensor_name

    def prepare_for_saving(self):
        """Prepare for saving method override."""
        # Save each distinct tensor only once when both fields alias the same object.
        self.tensors_to_save = (
            [self.rowwise_gemm_tensor, self.columnwise_gemm_tensor]
            if self.rowwise_gemm_tensor is not self.columnwise_gemm_tensor
            else [self.rowwise_gemm_tensor]
        )
        tensor_list, tensor_objects_list = prepare_for_saving(*self.tensors_to_save)
        self.tensors_to_save = tensor_objects_list
        # pylint: disable=unbalanced-tuple-unpacking
        return tensor_list, self

    def restore_from_saved(self, tensors):
        """Restore from saved method override."""
        tensor_objects_list, saved_tensors = restore_from_saved(
            self.tensors_to_save,
            tensors,
            return_saved_tensors=True,
        )
        if len(tensor_objects_list) == 2:
            # pylint: disable=unbalanced-tuple-unpacking
            self.rowwise_gemm_tensor, self.columnwise_gemm_tensor = tensor_objects_list
        else:
            # Only one tensor was saved, so both usages alias the same object.
            self.rowwise_gemm_tensor = tensor_objects_list[0]
            self.columnwise_gemm_tensor = self.rowwise_gemm_tensor
        return saved_tensors

    def quantize_(self, tensor, *, noop_flag=None):
        """quantize_ method override."""
        assert noop_flag is None, "CUDA Graphs are not supported with debug=True!"
        self.quantizer.update_quantized(tensor, self)

    def dequantize(self, *, dtype=None):
        """dequantize method override."""
        if dtype is None:
            dtype = self.rowwise_gemm_tensor.dtype
        return self.rowwise_gemm_tensor.dequantize().to(dtype)

    def get_tensor(self, transpose: bool):
        """Is used in the python gemm() to get tensor or transpose of the tensor."""
        return self.rowwise_gemm_tensor if not transpose else self.columnwise_gemm_tensor

    def size(self):
        """Size of the tensor."""
        return self.rowwise_gemm_tensor.size()

    def update_usage(self, rowwise_usage: bool, columnwise_usage: bool):
        """Update usage of the tensor."""
        # Body intentionally empty: both tensors are kept for the lifetime of
        # this object.  NOTE(review): presumably a deliberate no-op rather than
        # an unfinished implementation — confirm.
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Managing the state of all the debugged layers.
"""
import sys
class TEDebugState:
    """
    A class to manage the state of debug layers.
    """

    # Counter used to auto-generate layer names when none are provided; starts at 1.
    layer_count = 1
    # Registry of layers that were already initialized for debugging.
    layers_initialized = {}
    # Whether weight-tensor statistics are reduced within the tensor-parallel group.
    weight_tensor_tp_group_reduce = True
    # None until initialize() runs; afterwards True/False depending on debug_api state.
    debug_enabled = None

    @classmethod
    def initialize(cls):
        """
        If debug_api module is initialized, then sets cls.debug_enabled to True.

        Raises
        ------
        RuntimeError
            If a TE module was already initialized without debug and debug_api
            has been initialized since then.
        """
        if "nvdlfw_inspect" in sys.modules:
            import nvdlfw_inspect.api as debug_api

            if cls.debug_enabled is False and debug_api.DEBUG_MANAGER is not None:
                # This method is invoked when initializing TE modules.
                # If this error is thrown, it means that some TE module had been initialized before
                # debug_api was initialized, and now a new TE module is being initialized.
                # This is likely to be a bug.
                raise RuntimeError(
                    "[nv_dlfw_inspect] nv_dlfw_inspect module should be initialized before"
                    " initialization of the first TE module"
                )
            cls.debug_enabled = debug_api.DEBUG_MANAGER is not None

    @classmethod
    def _reset(cls):
        """Resets layer count and stats buffers."""
        from ..features.utils.stats_buffer import STATS_BUFFERS

        STATS_BUFFERS.reset()
        cls.debug_enabled = None
        # Reset the auto-naming counter as well, as the docstring promises;
        # previously only debug_enabled and layers_initialized were cleared.
        cls.layer_count = 1
        cls.layers_initialized.clear()

    @classmethod
    def get_layer_count(cls):
        """
        Layer counter is used when layer names are not provided to modules by the user.
        Returns the current count and increments it.
        """
        lc = cls.layer_count
        cls.layer_count += 1
        return lc

    @classmethod
    def set_weight_tensor_tp_group_reduce(cls, enabled):
        """Sets weight tensor reduction mode."""
        cls.weight_tensor_tp_group_reduce = enabled
def set_weight_tensor_tp_group_reduce(enabled):
    """Module-level convenience wrapper: forwards the weight-tensor
    TP-group reduction setting to TEDebugState."""
    TEDebugState.set_weight_tensor_tp_group_reduce(enabled)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Utils functions for the debug module."""
def any_feature_enabled(quantizers):
    """Returns True if at least one API call is made from DebugQuantizer."""
    for quantizer in quantizers:
        if quantizer.any_feature_enabled():
            return True
    return False
...@@ -19,6 +19,7 @@ from packaging.version import Version as PkgVersion ...@@ -19,6 +19,7 @@ from packaging.version import Version as PkgVersion
import torch import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.pytorch.utils import ( from transformer_engine.pytorch.utils import (
get_cudnn_version, get_cudnn_version,
nvtx_range_pop, nvtx_range_pop,
...@@ -6483,6 +6484,8 @@ class MultiheadAttention(torch.nn.Module): ...@@ -6483,6 +6484,8 @@ class MultiheadAttention(torch.nn.Module):
equal length. Please note that these formats do not reflect how equal length. Please note that these formats do not reflect how
tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory. tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
For that, please use `get_qkv_layout` to gain the layout information. For that, please use `get_qkv_layout` to gain the layout information.
name: str, default = `None`
name of the module, currently used for debugging purposes.
Parallelism parameters Parallelism parameters
---------------------- ----------------------
...@@ -6561,6 +6564,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -6561,6 +6564,7 @@ class MultiheadAttention(torch.nn.Module):
normalization: str = "LayerNorm", normalization: str = "LayerNorm",
device: Union[torch.device, str] = "cuda", device: Union[torch.device, str] = "cuda",
qkv_format: str = "sbhd", qkv_format: str = "sbhd",
name: str = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -6612,6 +6616,8 @@ class MultiheadAttention(torch.nn.Module): ...@@ -6612,6 +6616,8 @@ class MultiheadAttention(torch.nn.Module):
self.hidden_size_q = self.hidden_size_per_attention_head * num_attention_heads self.hidden_size_q = self.hidden_size_per_attention_head * num_attention_heads
self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups
self.name = name
common_gemm_kwargs = { common_gemm_kwargs = {
"fuse_wgrad_accumulation": fuse_wgrad_accumulation, "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
"tp_group": tp_group, "tp_group": tp_group,
...@@ -6652,6 +6658,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -6652,6 +6658,7 @@ class MultiheadAttention(torch.nn.Module):
ub_overlap_ag=ub_overlap_ag, ub_overlap_ag=ub_overlap_ag,
normalization=normalization, normalization=normalization,
ub_name="qkv", ub_name="qkv",
name=name + ".layernorm_linear_qkv" if name is not None else None,
**common_gemm_kwargs, **common_gemm_kwargs,
) )
else: else:
...@@ -6663,6 +6670,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -6663,6 +6670,7 @@ class MultiheadAttention(torch.nn.Module):
return_bias=False, return_bias=False,
parallel_mode=qkv_parallel_mode, parallel_mode=qkv_parallel_mode,
parameters_split=parameters_split, parameters_split=parameters_split,
name=name + ".linear_qkv" if name is not None else None,
**common_gemm_kwargs, **common_gemm_kwargs,
) )
elif self.attention_type == "cross": elif self.attention_type == "cross":
...@@ -6684,6 +6692,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -6684,6 +6692,7 @@ class MultiheadAttention(torch.nn.Module):
ub_overlap_ag=ub_overlap_ag, ub_overlap_ag=ub_overlap_ag,
normalization=normalization, normalization=normalization,
ub_name="qkv", ub_name="qkv",
name=name + ".layernorm_linear_q" if name is not None else None,
**common_gemm_kwargs, **common_gemm_kwargs,
) )
else: else:
...@@ -6694,6 +6703,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -6694,6 +6703,7 @@ class MultiheadAttention(torch.nn.Module):
bias=bias, bias=bias,
return_bias=False, return_bias=False,
parallel_mode=qkv_parallel_mode, parallel_mode=qkv_parallel_mode,
name=name + ".linear_q" if name is not None else None,
**common_gemm_kwargs, **common_gemm_kwargs,
) )
self.key_value = Linear( self.key_value = Linear(
...@@ -6704,6 +6714,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -6704,6 +6714,7 @@ class MultiheadAttention(torch.nn.Module):
return_bias=False, return_bias=False,
parallel_mode=qkv_parallel_mode, parallel_mode=qkv_parallel_mode,
parameters_split=("key", "value") if not fuse_qkv_params else None, parameters_split=("key", "value") if not fuse_qkv_params else None,
name=name + ".linear_kv" if name is not None else None,
**common_gemm_kwargs, **common_gemm_kwargs,
) )
...@@ -6733,6 +6744,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -6733,6 +6744,7 @@ class MultiheadAttention(torch.nn.Module):
ub_overlap_rs=ub_overlap_rs, ub_overlap_rs=ub_overlap_rs,
ub_overlap_ag=ub_overlap_ag, ub_overlap_ag=ub_overlap_ag,
ub_name="proj", ub_name="proj",
name=name + ".proj" if name is not None else None,
**common_gemm_kwargs, **common_gemm_kwargs,
) )
...@@ -6923,6 +6935,9 @@ class MultiheadAttention(torch.nn.Module): ...@@ -6923,6 +6935,9 @@ class MultiheadAttention(torch.nn.Module):
core_attention_bias_type in AttnBiasTypes core_attention_bias_type in AttnBiasTypes
), f"core_attention_bias_type {core_attention_bias_type} is not supported!" ), f"core_attention_bias_type {core_attention_bias_type} is not supported!"
if TEDebugState.debug_enabled:
TransformerEngineBaseModule._validate_name(self)
# ================================================= # =================================================
# Pre-allocate memory for key-value cache for inference # Pre-allocate memory for key-value cache for inference
# ================================================= # =================================================
......
...@@ -14,6 +14,7 @@ from ..utils import get_sm_count ...@@ -14,6 +14,7 @@ from ..utils import get_sm_count
from ..tensor.quantized_tensor import Quantizer from ..tensor.quantized_tensor import Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ...debug.pytorch.debug_quantization import DebugQuantizer
__all__ = [ __all__ = [
"general_gemm", "general_gemm",
...@@ -109,6 +110,13 @@ def general_gemm( ...@@ -109,6 +110,13 @@ def general_gemm(
if not out.is_contiguous(): if not out.is_contiguous():
raise ValueError("Output tensor is not contiguous.") raise ValueError("Output tensor is not contiguous.")
debug_quantizer = None
if isinstance(quantization_params, DebugQuantizer):
debug_quantizer = quantization_params
quantization_params = quantization_params.parent_quantizer
A = A.get_tensor(not transa)
B = B.get_tensor(transb)
# Use bfloat16 as default bias_dtype # Use bfloat16 as default bias_dtype
bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype] bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype]
...@@ -145,6 +153,9 @@ def general_gemm( ...@@ -145,6 +153,9 @@ def general_gemm(
out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs) out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs)
reset_swizzled_inputs(A, B, original_scale_inverses) reset_swizzled_inputs(A, B, original_scale_inverses)
if debug_quantizer is not None:
out = debug_quantizer.process_gemm_output(out)
return out, bias_grad, gelu_input, extra_output return out, bias_grad, gelu_input, extra_output
......
...@@ -19,7 +19,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP ...@@ -19,7 +19,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._common_utils import _get_module_fsdp_state from torch.distributed.fsdp._common_utils import _get_module_fsdp_state
from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules
from .utils import non_tn_fp8_gemm_supported, safely_set_viewless_tensor_data from .utils import non_tn_fp8_gemm_supported, safely_set_viewless_tensor_data, needs_quantized_gemm
from .constants import dist_group_type from .constants import dist_group_type
from .fp8 import FP8GlobalStateManager, fp8_autocast from .fp8 import FP8GlobalStateManager, fp8_autocast
from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer
...@@ -29,6 +29,7 @@ from .tensor.quantized_tensor import QuantizedTensor, Quantizer ...@@ -29,6 +29,7 @@ from .tensor.quantized_tensor import QuantizedTensor, Quantizer
from .tensor._internal.float8_tensor_base import Float8TensorBase from .tensor._internal.float8_tensor_base import Float8TensorBase
from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from .tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase from .tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
__all__ = ["checkpoint", "CudaRNGStatesTracker"] __all__ = ["checkpoint", "CudaRNGStatesTracker"]
...@@ -1195,6 +1196,28 @@ def gather_along_first_dim( ...@@ -1195,6 +1196,28 @@ def gather_along_first_dim(
out_shape=out_shape, out_shape=out_shape,
) )
# Debug case - call gather_along_first_dim on each tensor
if isinstance(inp, DebugQuantizedTensor):
out_obj = inp
rowwise = inp.get_tensor(False)
columnwise = inp.get_tensor(True)
final_quantizer = (
None if not needs_quantized_gemm(inp, rowwise=True) else quantizer.parent_quantizer
)
rowwise_total = gather_along_first_dim(rowwise, process_group, False, final_quantizer)[0]
out_obj.rowwise_gemm_tensor = rowwise_total
if rowwise is not columnwise:
final_quantizer_columnwise = (
None if not needs_quantized_gemm(inp, rowwise=False) else quantizer.parent_quantizer
)
columnwise_total, _ = gather_along_first_dim(
columnwise, process_group, False, final_quantizer_columnwise
)
out_obj.columnwise_gemm_tensor = columnwise_total
else:
out_obj.rowwise_gemm_tensor = out_obj.rowwise_gemm_tensor
return out_obj, None
# High-precision communication for quantized tensors # High-precision communication for quantized tensors
if quantizer is not None: if quantizer is not None:
warnings.warn( warnings.warn(
......
...@@ -10,6 +10,7 @@ import warnings ...@@ -10,6 +10,7 @@ import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union
from contextlib import contextmanager from contextlib import contextmanager
import logging
from types import MethodType from types import MethodType
import torch import torch
...@@ -39,6 +40,9 @@ from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer ...@@ -39,6 +40,9 @@ from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor._internal.float8_tensor_base import Float8TensorBase from ..tensor._internal.float8_tensor_base import Float8TensorBase
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ...common.recipe import Recipe
from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor
__all__ = ["initialize_ub", "destroy_ub"] __all__ = ["initialize_ub", "destroy_ub"]
...@@ -413,6 +417,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -413,6 +417,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
assert torch.cuda.is_available(), "TransformerEngine needs CUDA." assert torch.cuda.is_available(), "TransformerEngine needs CUDA."
self.name = None
self.fp8_initialized = False self.fp8_initialized = False
self.fp8 = False self.fp8 = False
self.fp8_calibration = False self.fp8_calibration = False
...@@ -432,6 +437,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -432,6 +437,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self._fp8_workspaces: Dict[str, QuantizedTensor] = {} self._fp8_workspaces: Dict[str, QuantizedTensor] = {}
self.activation_dtype: Optional[torch.dtype] = None self.activation_dtype: Optional[torch.dtype] = None
if not TEDebugState.debug_enabled:
TEDebugState.initialize()
# Names of attributes that can be set quickly (see __setattr__ # Names of attributes that can be set quickly (see __setattr__
# method) # method)
_fast_setattr_names: Set[str] = { _fast_setattr_names: Set[str] = {
...@@ -848,7 +856,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -848,7 +856,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
gather_grad_output = row_parallel_mode and ctx.sequence_parallel gather_grad_output = row_parallel_mode and ctx.sequence_parallel
# Non-FP8 case: bgrad is fused with wgrad for this case. # Non-FP8 case: bgrad is fused with wgrad for this case.
if not ctx.fp8: if not ctx.fp8 and not ctx.debug:
if gather_grad_output: if gather_grad_output:
if not ctx.ub_overlap_ag: if not ctx.ub_overlap_ag:
grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group) grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group)
...@@ -858,6 +866,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -858,6 +866,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
return grad_output, None return grad_output, None
# FP8 with all-gather: unfused bgrad, fused cast + transpose # FP8 with all-gather: unfused bgrad, fused cast + transpose
# Also supports debug quantization, which is handled inside gather_along_first_dim.
if gather_grad_output: if gather_grad_output:
grad_bias = None grad_bias = None
if ctx.use_bias: if ctx.use_bias:
...@@ -886,6 +895,23 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -886,6 +895,23 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
) )
return grad_output, grad_bias return grad_output, grad_bias
# Debug without all-gather: unfused cast and bgrad
# bgrad only if wgrad is in FP8, otherwise it is fused with wgrad and we return None
if ctx.debug:
grad_output_ = quantizer(grad_output)
if (
isinstance(
grad_output_.get_tensor(True),
(QuantizedTensor, Float8TensorBase, MXFP8TensorBase),
)
and ctx.use_bias
):
grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
else:
grad_bias = None
grad_output = grad_output_
return grad_output, grad_bias
# FP8 without all-gather: fused bgrad + cast + transpose # FP8 without all-gather: fused bgrad + cast + transpose
grad_bias = None grad_bias = None
if ctx.use_bias: if ctx.use_bias:
...@@ -1002,6 +1028,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1002,6 +1028,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
update_workspace: bool = True, update_workspace: bool = True,
skip_update_flag: Optional[torch.Tensor] = None, skip_update_flag: Optional[torch.Tensor] = None,
fsdp_group: Optional[dist_group_type] = None, fsdp_group: Optional[dist_group_type] = None,
workspace_dtype: Optional[torch.dtype] = None,
) -> QuantizedTensor: ) -> QuantizedTensor:
"""Get FP8 workspace buffer and maybe update its values """Get FP8 workspace buffer and maybe update its values
...@@ -1024,6 +1051,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1024,6 +1051,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
over `update_workspace` if provided. over `update_workspace` if provided.
fsdp_group: bool, default = None fsdp_group: bool, default = None
FSDP process group that the weights are distributed over. FSDP process group that the weights are distributed over.
workspace_dtype: torch.dtype, default = None
If weight workspace contains high-precision tensor - for example
for debug quantization, this is dtype of the tensor.
""" """
# FP8 primary weights # FP8 primary weights
...@@ -1037,6 +1067,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1037,6 +1067,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# Try getting workspace from cache # Try getting workspace from cache
out = None out = None
if cache_name is not None: if cache_name is not None:
out = self._fp8_workspaces.get(cache_name, None) out = self._fp8_workspaces.get(cache_name, None)
if quantizer is not None and isinstance(out, MXFP8TensorBase): if quantizer is not None and isinstance(out, MXFP8TensorBase):
...@@ -1047,6 +1078,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1047,6 +1078,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
out = None out = None
del self._fp8_workspaces[cache_name] del self._fp8_workspaces[cache_name]
is_debug = isinstance(quantizer, DebugQuantizer)
is_out_debug_tensor = out is not None and isinstance(out, DebugQuantizedTensor)
if is_debug != is_out_debug_tensor:
out = None
# Gather cached Fp8 workspace if it's distributed # Gather cached Fp8 workspace if it's distributed
# NOTE: FSDP sharding is supported only for Fp8 buffers and will not work # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work
# for models initialized with Fp8 primary weights. # for models initialized with Fp8 primary weights.
...@@ -1064,7 +1100,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1064,7 +1100,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
raise ValueError( raise ValueError(
"tensor and quantizer kwargs must be provided to construct FP8 workspace" "tensor and quantizer kwargs must be provided to construct FP8 workspace"
) )
out = quantizer(tensor) out = quantizer.quantize(tensor, dtype=workspace_dtype)
# Update cache # Update cache
if cache_name is not None: if cache_name is not None:
...@@ -1081,7 +1117,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1081,7 +1117,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
out.quantize_(tensor, noop_flag=skip_update_flag) out.quantize_(tensor, noop_flag=skip_update_flag)
else: else:
tex.quantize(tensor, quantizer, out, skip_update_flag) tex.quantize(tensor, quantizer, out, skip_update_flag)
return out return out
def _load_from_state_dict( def _load_from_state_dict(
...@@ -1104,3 +1139,47 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1104,3 +1139,47 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
super()._load_from_state_dict( super()._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
) )
def _validate_name(self):
"""
Validate name passed to the module.
This is invoked in the forward() method as module names are assigned after Model is initialized in Megatron-LM.
If no name is assigned, it creates a default name with layer count as the variable.
"""
assert TEDebugState.debug_enabled
import nvdlfw_inspect.api as debug_api
if self.name is None:
debug_api.log_message(
"Names are not provided to debug modules. ",
"Creating and using generic names. Pass names to debug modules for better"
" insight. ",
level=logging.WARNING,
)
self.name = f"Layer_{TEDebugState.get_layer_count()}"
def _turn_off_unsupported_features_in_debug(self):
if (
getattr(self, "ub_bulk_wgrad", False)
or getattr(self, "ub_bulk_dgrad", False)
or getattr(self, "ub_overlap_ag", False)
or getattr(self, "ub_overlap_rs_dgrad", False)
or getattr(self, "ub_overlap_rs", False)
):
import nvdlfw_inspect.api as debug_api
debug_api.log_message(
"UserBuffers are not supported in debug module. "
"Using UB optimization will not affect the debug module. ",
level=logging.WARNING,
)
if hasattr(self, "ub_bulk_wgrad"):
self.ub_bulk_wgrad = None
if hasattr(self, "ub_bulk_dgrad"):
self.ub_bulk_dgrad = None
if hasattr(self, "ub_overlap_ag"):
self.ub_overlap_ag = None
if hasattr(self, "ub_overlap_rs_dgrad"):
self.ub_overlap_rs_dgrad = None
if hasattr(self, "ub_overlap_rs"):
self.ub_overlap_rs = None
...@@ -35,6 +35,7 @@ from ..utils import ( ...@@ -35,6 +35,7 @@ from ..utils import (
nvtx_range_pop, nvtx_range_pop,
nvtx_range_push, nvtx_range_push,
requires_grad, requires_grad,
needs_quantized_gemm,
) )
from ..distributed import ( from ..distributed import (
set_tensor_model_parallel_attributes, set_tensor_model_parallel_attributes,
...@@ -56,6 +57,8 @@ from ..tensor.quantized_tensor import ( ...@@ -56,6 +57,8 @@ from ..tensor.quantized_tensor import (
prepare_for_saving, prepare_for_saving,
restore_from_saved, restore_from_saved,
) )
from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.utils import any_feature_enabled
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer
...@@ -90,8 +93,9 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -90,8 +93,9 @@ class _LayerNormLinear(torch.autograd.Function):
input_quantizer: Optional[Quantizer], input_quantizer: Optional[Quantizer],
weight_quantizer: Optional[Quantizer], weight_quantizer: Optional[Quantizer],
output_quantizer: Optional[Quantizer], output_quantizer: Optional[Quantizer],
grad_output_quantizer: Optional[Quantizer],
grad_input_quantizer: Optional[Quantizer], grad_input_quantizer: Optional[Quantizer],
grad_weight_quantizer: Optional[Quantizer],
grad_output_quantizer: Optional[Quantizer],
cpu_offloading: bool, cpu_offloading: bool,
tp_group: Union[dist_group_type, None], tp_group: Union[dist_group_type, None],
tp_size: int, tp_size: int,
...@@ -116,6 +120,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -116,6 +120,7 @@ class _LayerNormLinear(torch.autograd.Function):
fsdp_group: Union[dist_group_type, None], fsdp_group: Union[dist_group_type, None],
module: torch.nn.Module, module: torch.nn.Module,
skip_fp8_weight_update: bool, skip_fp8_weight_update: bool,
debug: Optional[bool] = False,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
...@@ -214,12 +219,12 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -214,12 +219,12 @@ class _LayerNormLinear(torch.autograd.Function):
# norm output will be returned # norm output will be returned
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
ln_out_return = ln_out_total ln_out_return = ln_out_total
if fp8: if fp8 or debug:
ln_out = input_quantizer(ln_out) ln_out = input_quantizer(ln_out)
input_quantizer.set_usage(rowwise=True, columnwise=False) input_quantizer.set_usage(rowwise=True, columnwise=False)
ln_out_total = input_quantizer(ln_out_total) ln_out_total = input_quantizer(ln_out_total)
else: else:
if fp8: if fp8 or debug:
if not with_quantized_norm and not force_hp_blockwise_ln_out_gather: if not with_quantized_norm and not force_hp_blockwise_ln_out_gather:
ln_out = input_quantizer(ln_out) ln_out = input_quantizer(ln_out)
input_quantizer.set_usage(rowwise=True, columnwise=False) input_quantizer.set_usage(rowwise=True, columnwise=False)
...@@ -233,18 +238,19 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -233,18 +238,19 @@ class _LayerNormLinear(torch.autograd.Function):
ln_out_total, _ = gather_along_first_dim( ln_out_total, _ = gather_along_first_dim(
ln_out, ln_out,
tp_group, tp_group,
quantizer=(input_quantizer if fp8 else None), quantizer=(input_quantizer if fp8 or debug else None),
) )
else: else:
if fp8 and not with_quantized_norm: if (fp8 or debug) and not with_quantized_norm:
ln_out = input_quantizer(ln_out) ln_out = input_quantizer(ln_out)
ln_out_total = ln_out ln_out_total = ln_out
nvtx_range_pop(f"{nvtx_label}.gemm_input_cast_comm") nvtx_range_pop(f"{nvtx_label}.gemm_input_cast_comm")
# Cast weight to expected dtype # Cast weight to expected dtype
if not fp8: weightmat = weight
quantized_weight = False quantized_weight = False
weightmat = cast_if_needed(weight, activation_dtype) if not fp8 and not debug:
weightmat = cast_if_needed(weightmat, activation_dtype)
else: else:
quantized_weight = not isinstance(weight, QuantizedTensor) quantized_weight = not isinstance(weight, QuantizedTensor)
...@@ -254,6 +260,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -254,6 +260,7 @@ class _LayerNormLinear(torch.autograd.Function):
# FP8 cast to workspace buffer # FP8 cast to workspace buffer
update_workspace = is_first_microbatch is None or is_first_microbatch update_workspace = is_first_microbatch is None or is_first_microbatch
weightmat = module.get_weight_workspace( weightmat = module.get_weight_workspace(
tensor=weight, tensor=weight,
quantizer=weight_quantizer, quantizer=weight_quantizer,
...@@ -261,11 +268,12 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -261,11 +268,12 @@ class _LayerNormLinear(torch.autograd.Function):
update_workspace=update_workspace, update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update, skip_update_flag=skip_fp8_weight_update,
fsdp_group=fsdp_group, fsdp_group=fsdp_group,
workspace_dtype=activation_dtype,
) )
# Cast bias to expected dtype # Cast bias to expected dtype
bias_dtype = activation_dtype bias_dtype = activation_dtype
if fp8 and activation_dtype == torch.float32: if needs_quantized_gemm(ln_out_total) and activation_dtype == torch.float32:
bias_dtype = torch.bfloat16 bias_dtype = torch.bfloat16
bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias
...@@ -400,6 +408,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -400,6 +408,7 @@ class _LayerNormLinear(torch.autograd.Function):
if fuse_wgrad_accumulation and weight.requires_grad: if fuse_wgrad_accumulation and weight.requires_grad:
ctx.main_grad = weight.main_grad ctx.main_grad = weight.main_grad
ctx.grad_input_quantizer = grad_input_quantizer ctx.grad_input_quantizer = grad_input_quantizer
ctx.grad_weight_quantizer = grad_weight_quantizer
ctx.grad_output_quantizer = grad_output_quantizer ctx.grad_output_quantizer = grad_output_quantizer
ctx.input_quantizer = input_quantizer ctx.input_quantizer = input_quantizer
ctx.owns_input = inputmat is not inp ctx.owns_input = inputmat is not inp
...@@ -434,6 +443,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -434,6 +443,7 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
if in_fp8_activation_recompute_phase(): if in_fp8_activation_recompute_phase():
FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
ctx.debug = debug
# Row Parallel Linear # Row Parallel Linear
if ub_overlap_rs_fprop: if ub_overlap_rs_fprop:
...@@ -611,7 +621,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -611,7 +621,7 @@ class _LayerNormLinear(torch.autograd.Function):
ln_out_total_work = None ln_out_total_work = None
if ctx.ln_out_needs_gather and not ctx.ub_bulk_dgrad: if ctx.ln_out_needs_gather and not ctx.ub_bulk_dgrad:
quantizer = None quantizer = None
if ctx.fp8: if ctx.input_quantizer is not None:
quantizer = ctx.input_quantizer quantizer = ctx.input_quantizer
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
# If data is in FP8, we compute FP8 transposes manually # If data is in FP8, we compute FP8 transposes manually
...@@ -757,6 +767,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -757,6 +767,7 @@ class _LayerNormLinear(torch.autograd.Function):
out=main_grad if ctx.fuse_wgrad_accumulation else None, out=main_grad if ctx.fuse_wgrad_accumulation else None,
use_split_accumulator=use_split_accumulator, use_split_accumulator=use_split_accumulator,
accumulate=accumulate_wgrad_into_param_main_grad, accumulate=accumulate_wgrad_into_param_main_grad,
quantization_params=ctx.grad_weight_quantizer,
ub=ub_obj_wgrad, ub=ub_obj_wgrad,
ub_type=ub_type_wgrad, ub_type=ub_type_wgrad,
extra_output=rs_out, extra_output=rs_out,
...@@ -865,8 +876,9 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -865,8 +876,9 @@ class _LayerNormLinear(torch.autograd.Function):
None, # input_quantizer None, # input_quantizer
None, # weight_quantizer None, # weight_quantizer
None, # output_quantizer None, # output_quantizer
None, # grad_output_quantizer
None, # grad_input_quantizer None, # grad_input_quantizer
None, # grad_weight_quantizer
None, # grad_output_quantizer
None, # cpu_offloading None, # cpu_offloading
None, # tp_group None, # tp_group
None, # tp_size None, # tp_size
...@@ -889,6 +901,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -889,6 +901,7 @@ class _LayerNormLinear(torch.autograd.Function):
None, # ub_bulk_wgrad None, # ub_bulk_wgrad
None, # ub_name None, # ub_name
None, # fsdp_group None, # fsdp_group
None, # debug
None, # module None, # module
None, # skip_fp8_weight_update None, # skip_fp8_weight_update
) )
...@@ -943,6 +956,8 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -943,6 +956,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
The device on which the parameters of the model will be allocated. It is the user's The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the responsibility to ensure all parameters are moved to the GPU before running the
forward pass. forward pass.
name: str, default = `None`
name of the module, currently used for debugging purposes.
Parallelism parameters Parallelism parameters
---------------------- ----------------------
...@@ -1007,6 +1022,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1007,6 +1022,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
ub_bulk_wgrad: bool = False, ub_bulk_wgrad: bool = False,
ub_bulk_dgrad: bool = False, ub_bulk_dgrad: bool = False,
ub_name: Optional[str] = None, ub_name: Optional[str] = None,
name: str = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -1023,6 +1039,10 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1023,6 +1039,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.return_layernorm_output_gathered = return_layernorm_output_gathered self.return_layernorm_output_gathered = return_layernorm_output_gathered
self.zero_centered_gamma = zero_centered_gamma self.zero_centered_gamma = zero_centered_gamma
self.name = name
if TEDebugState.debug_enabled:
self._turn_off_unsupported_features_in_debug() # turn off userbuffers
if tp_group is None: if tp_group is None:
self.tp_size = tp_size self.tp_size = tp_size
if tp_size == 1: if tp_size == 1:
...@@ -1312,6 +1332,9 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1312,6 +1332,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
first microbatch (since it is the first gradient being first microbatch (since it is the first gradient being
produced) produced)
""" """
debug = TEDebugState.debug_enabled
if debug:
self._validate_name()
if FP8GlobalStateManager.fp8_graph_capturing(): if FP8GlobalStateManager.fp8_graph_capturing():
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
...@@ -1348,13 +1371,28 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1348,13 +1371,28 @@ class LayerNormLinear(TransformerEngineBaseModule):
else: else:
bias_tensor = getattr(self, self.bias_names[0]) # Unused bias_tensor = getattr(self, self.bias_names[0]) # Unused
quantizers = (
self._get_quantizers(fp8_output, fp8_grad)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad)
)
if debug:
if not any_feature_enabled(quantizers):
# If no feature is used, then run faster implementation with debug = False.
quantizers = self._get_quantizers(fp8_output, fp8_grad)
debug = False
if isinstance(weight_tensor, QuantizedTensor):
raise RuntimeError("FP8 weights are not supported in debug mode.")
( (
input_quantizer, input_quantizer,
weight_quantizer, weight_quantizer,
output_quantizer, output_quantizer,
grad_output_quantizer,
grad_input_quantizer, grad_input_quantizer,
) = self._get_quantizers(fp8_output, fp8_grad) grad_weight_quantizer,
grad_output_quantizer,
) = quantizers
if torch.is_grad_enabled(): if torch.is_grad_enabled():
fwd_fn = _LayerNormLinear.apply fwd_fn = _LayerNormLinear.apply
...@@ -1376,8 +1414,9 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1376,8 +1414,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
input_quantizer, input_quantizer,
weight_quantizer, weight_quantizer,
output_quantizer, output_quantizer,
grad_output_quantizer,
grad_input_quantizer, grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
is_cpu_offload_enabled(), is_cpu_offload_enabled(),
self.tp_group, self.tp_group,
self.tp_size, self.tp_size,
...@@ -1402,6 +1441,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1402,6 +1441,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.fsdp_group, self.fsdp_group,
self, self,
skip_fp8_weight_update, skip_fp8_weight_update,
debug,
) )
out = fwd_fn(*args) out = fwd_fn(*args)
...@@ -1421,8 +1461,9 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1421,8 +1461,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
def _get_quantizers(self, fp8_output, fp8_grad): def _get_quantizers(self, fp8_output, fp8_grad):
if not self.fp8: if not self.fp8:
return [None] * 5 return [None] * 6
grad_input_quantizer = None grad_input_quantizer = None
grad_weight_quantizer = None
grad_output_quantizer = None grad_output_quantizer = None
output_quantizer = None output_quantizer = None
input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
...@@ -1441,8 +1482,20 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1441,8 +1482,20 @@ class LayerNormLinear(TransformerEngineBaseModule):
input_quantizer, input_quantizer,
weight_quantizer, weight_quantizer,
output_quantizer, output_quantizer,
grad_output_quantizer,
grad_input_quantizer, grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
)
def _get_debug_quantizers(self, fp8_output, fp8_grad):
original_quantizers = self._get_quantizers(fp8_output, fp8_grad)
assert TEDebugState.debug_enabled
from ...debug.pytorch.debug_quantization import DebugQuantizer
names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"]
return tuple(
DebugQuantizer(self.name, name, q, self.tp_group)
for name, q in zip(names, original_quantizers)
) )
def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
......
...@@ -41,6 +41,7 @@ from ..utils import ( ...@@ -41,6 +41,7 @@ from ..utils import (
clear_tensor_data, clear_tensor_data,
requires_grad, requires_grad,
non_tn_fp8_gemm_supported, non_tn_fp8_gemm_supported,
needs_quantized_gemm,
) )
from ..distributed import ( from ..distributed import (
set_tensor_model_parallel_attributes, set_tensor_model_parallel_attributes,
...@@ -73,6 +74,8 @@ from ..tensor.quantized_tensor import ( ...@@ -73,6 +74,8 @@ from ..tensor.quantized_tensor import (
from ..cpp_extensions import ( from ..cpp_extensions import (
general_gemm, general_gemm,
) )
from ...debug.pytorch.utils import any_feature_enabled
from ...debug.pytorch.debug_state import TEDebugState
__all__ = ["LayerNormMLP"] __all__ = ["LayerNormMLP"]
...@@ -153,12 +156,16 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -153,12 +156,16 @@ class _LayerNormMLP(torch.autograd.Function):
fuse_wgrad_accumulation: bool, fuse_wgrad_accumulation: bool,
fc1_input_quantizer: Optional[Quantizer], fc1_input_quantizer: Optional[Quantizer],
fc1_weight_quantizer: Optional[Quantizer], fc1_weight_quantizer: Optional[Quantizer],
fc1_output_quantizer: Optional[Quantizer],
fc1_grad_input_quantizer: Optional[Quantizer],
fc1_grad_weight_quantizer: Optional[Quantizer],
fc1_grad_output_quantizer: Optional[Quantizer],
fc2_input_quantizer: Optional[Quantizer], fc2_input_quantizer: Optional[Quantizer],
fc2_weight_quantizer: Optional[Quantizer], fc2_weight_quantizer: Optional[Quantizer],
output_quantizer: Optional[Quantizer], fc2_output_quantizer: Optional[Quantizer],
grad_fc2_output_quantizer: Optional[Quantizer], fc2_grad_input_quantizer: Optional[Quantizer],
grad_fc1_output_quantizer: Optional[Quantizer], fc2_grad_weight_quantizer: Optional[Quantizer],
grad_input_quantizer: Optional[Quantizer], fc2_grad_output_quantizer: Optional[Quantizer],
cpu_offloading: bool, cpu_offloading: bool,
tp_group: Union[dist_group_type, None], tp_group: Union[dist_group_type, None],
tp_size: int, tp_size: int,
...@@ -184,6 +191,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -184,6 +191,7 @@ class _LayerNormMLP(torch.autograd.Function):
fsdp_group: Union[dist_group_type, None], fsdp_group: Union[dist_group_type, None],
module: torch.nn.Module, module: torch.nn.Module,
skip_fp8_weight_update: bool, skip_fp8_weight_update: bool,
debug: Optional[bool] = False,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
...@@ -212,9 +220,16 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -212,9 +220,16 @@ class _LayerNormMLP(torch.autograd.Function):
if ln_bias is not None: if ln_bias is not None:
ln_bias = cast_if_needed(ln_bias, activation_dtype) ln_bias = cast_if_needed(ln_bias, activation_dtype)
# Avoid quantized norm kernel if norm output will be returned # for fp8 DelayedScaling: layernorm output = FP8
# only output of the linear is returned
# for return_layernorm_output: layernorm output = High precision, then cast to FP8
# high precision layernorm output and output of the linear are returned
# for debug: : layernorm output = High precision to enable processing of this norm
with_quantized_norm = ( with_quantized_norm = (
fp8 and not return_layernorm_output and not return_layernorm_output_gathered fp8
and not return_layernorm_output
and not return_layernorm_output_gathered
and not debug
) )
if isinstance(fc1_input_quantizer, Float8BlockQuantizer): if isinstance(fc1_input_quantizer, Float8BlockQuantizer):
# Kernels not available for norm fusion. # Kernels not available for norm fusion.
...@@ -270,13 +285,13 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -270,13 +285,13 @@ class _LayerNormMLP(torch.autograd.Function):
# norm output will be returned # norm output will be returned
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
ln_out_return = ln_out_total ln_out_return = ln_out_total
if fp8: if fp8 or debug:
if not force_hp_fc1_input_gather: if not force_hp_fc1_input_gather:
ln_out = fc1_input_quantizer(ln_out) ln_out = fc1_input_quantizer(ln_out)
fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
ln_out_total = fc1_input_quantizer(ln_out_total) ln_out_total = fc1_input_quantizer(ln_out_total)
else: else:
if fp8: if fp8 or debug:
if not with_quantized_norm and not force_hp_fc1_input_gather: if not with_quantized_norm and not force_hp_fc1_input_gather:
ln_out = fc1_input_quantizer(ln_out) ln_out = fc1_input_quantizer(ln_out)
fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
...@@ -290,21 +305,21 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -290,21 +305,21 @@ class _LayerNormMLP(torch.autograd.Function):
ln_out_total, _ = gather_along_first_dim( ln_out_total, _ = gather_along_first_dim(
ln_out, ln_out,
tp_group, tp_group,
quantizer=(fc1_input_quantizer if fp8 else None), quantizer=(fc1_input_quantizer if fp8 or debug else None),
) )
else: else:
# NOTE: force_hp_fc1_input_gather is redundant with else, but # NOTE: force_hp_fc1_input_gather is redundant with else, but
# here for clarity. We should not quantize ln_out if bwd needs # here for clarity. We should not quantize ln_out if bwd needs
# to gather in hp. # to gather in hp.
if fp8 and not with_quantized_norm and not force_hp_fc1_input_gather: if (fp8 or debug) and not with_quantized_norm and not force_hp_fc1_input_gather:
ln_out = fc1_input_quantizer(ln_out) ln_out = fc1_input_quantizer(ln_out)
ln_out_total = ln_out ln_out_total = ln_out
# Cast weights to expected dtype # Cast weights to expected dtype
if not fp8: fc1_weight_final = fc1_weight
fc1_weight_final = cast_if_needed(fc1_weight, activation_dtype) fc2_weight_final = fc2_weight
fc2_weight_final = cast_if_needed(fc2_weight, activation_dtype)
else: if fp8 or debug:
# If weights are not quantized, we call get_weight_workspace, # If weights are not quantized, we call get_weight_workspace,
# which handles weight caching etc. # which handles weight caching etc.
# FP8 cast to workspace buffer # FP8 cast to workspace buffer
...@@ -316,6 +331,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -316,6 +331,7 @@ class _LayerNormMLP(torch.autograd.Function):
update_workspace=update_workspace, update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update, skip_update_flag=skip_fp8_weight_update,
fsdp_group=fsdp_group, fsdp_group=fsdp_group,
workspace_dtype=activation_dtype,
) )
fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True) fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True)
fc2_weight_final = module.get_weight_workspace( fc2_weight_final = module.get_weight_workspace(
...@@ -325,11 +341,15 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -325,11 +341,15 @@ class _LayerNormMLP(torch.autograd.Function):
update_workspace=update_workspace, update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update, skip_update_flag=skip_fp8_weight_update,
fsdp_group=fsdp_group, fsdp_group=fsdp_group,
workspace_dtype=activation_dtype,
) )
else:
fc1_weight_final = cast_if_needed(fc1_weight_final, activation_dtype)
fc2_weight_final = cast_if_needed(fc2_weight_final, activation_dtype)
# Cast biases to expected dtype # Cast biases to expected dtype
bias_dtype = activation_dtype bias_dtype = activation_dtype
if fp8 and activation_dtype == torch.float32: if needs_quantized_gemm(ln_out_total) and activation_dtype == torch.float32:
bias_dtype = torch.bfloat16 bias_dtype = torch.bfloat16
if fc1_bias is not None: if fc1_bias is not None:
fc1_bias = cast_if_needed(fc1_bias, bias_dtype) fc1_bias = cast_if_needed(fc1_bias, bias_dtype)
...@@ -359,13 +379,16 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -359,13 +379,16 @@ class _LayerNormMLP(torch.autograd.Function):
gemm_gelu_fusion = True gemm_gelu_fusion = True
if gemm_gelu_fusion and bias_gelu_fusion: if gemm_gelu_fusion and bias_gelu_fusion:
gemm_gelu_fusion = False gemm_gelu_fusion = False
if debug:
gemm_gelu_fusion = False
fc1_outputs = general_gemm( fc1_outputs = general_gemm(
fc1_weight_final, fc1_weight_final,
ln_out_total, ln_out_total,
get_workspace(), get_workspace(),
quantization_params=( quantization_params=(
fc2_input_quantizer if gemm_gelu_fusion else None # fused gelu output is in fp8 fc2_input_quantizer
if gemm_gelu_fusion
else fc1_output_quantizer # fused gelu output is in fp8
), ),
out_dtype=activation_dtype, out_dtype=activation_dtype,
bias=( bias=(
...@@ -376,6 +399,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -376,6 +399,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub=ub_obj_lnout, ub=ub_obj_lnout,
ub_type=tex.CommOverlapType.AG if ub_overlap_ag else None, ub_type=tex.CommOverlapType.AG if ub_overlap_ag else None,
) )
if not is_grad_enabled and (ln_out_total is not ln_out_return): if not is_grad_enabled and (ln_out_total is not ln_out_return):
clear_tensor_data(ln_out_total) clear_tensor_data(ln_out_total)
...@@ -389,6 +413,10 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -389,6 +413,10 @@ class _LayerNormMLP(torch.autograd.Function):
act_out = bias_gelu_fused(fc1_out_without_bias, fc1_bias) act_out = bias_gelu_fused(fc1_out_without_bias, fc1_bias)
elif gemm_gelu_fusion: elif gemm_gelu_fusion:
act_out, _, fc1_out, _ = fc1_outputs act_out, _, fc1_out, _ = fc1_outputs
elif debug:
fc1_out, *_ = fc1_outputs
act_out = activation_func(fc1_out, None)
act_out = fc2_input_quantizer(act_out)
else: else:
fc1_out, *_ = fc1_outputs fc1_out, *_ = fc1_outputs
if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_block_scaling(): if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_block_scaling():
...@@ -426,7 +454,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -426,7 +454,7 @@ class _LayerNormMLP(torch.autograd.Function):
get_workspace(), get_workspace(),
out_dtype=activation_dtype, out_dtype=activation_dtype,
bias=fc2_bias, bias=fc2_bias,
quantization_params=output_quantizer, quantization_params=fc2_output_quantizer,
out=fc2_out, out=fc2_out,
use_split_accumulator=_2X_ACC_FPROP, use_split_accumulator=_2X_ACC_FPROP,
ub=ub_obj_fc2out, ub=ub_obj_fc2out,
...@@ -515,11 +543,14 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -515,11 +543,14 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
ctx.force_hp_fc1_input_gather = force_hp_fc1_input_gather ctx.force_hp_fc1_input_gather = force_hp_fc1_input_gather
ctx.grad_fc1_output_quantizer = grad_fc1_output_quantizer ctx.fc1_grad_input_quantizer = fc1_grad_input_quantizer
ctx.grad_fc2_output_quantizer = grad_fc2_output_quantizer ctx.fc1_grad_weight_quantizer = fc1_grad_weight_quantizer
ctx.grad_input_quantizer = grad_input_quantizer ctx.fc1_grad_output_quantizer = fc1_grad_output_quantizer
ctx.fc2_input_quantizer = fc2_input_quantizer ctx.fc2_grad_input_quantizer = fc2_grad_input_quantizer
ctx.fc2_grad_weight_quantizer = fc2_grad_weight_quantizer
ctx.fc2_grad_output_quantizer = fc2_grad_output_quantizer
ctx.fc1_input_quantizer = fc1_input_quantizer ctx.fc1_input_quantizer = fc1_input_quantizer
ctx.fc2_input_quantizer = fc2_input_quantizer
ctx.fc1_weight_requires_grad = fc1_weight.requires_grad ctx.fc1_weight_requires_grad = fc1_weight.requires_grad
ctx.fc2_weight_requires_grad = fc2_weight.requires_grad ctx.fc2_weight_requires_grad = fc2_weight.requires_grad
...@@ -552,6 +583,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -552,6 +583,7 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.ub_bulk_dgrad = ub_bulk_dgrad ctx.ub_bulk_dgrad = ub_bulk_dgrad
ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad
ctx.ub_overlap_ag = ub_overlap_ag ctx.ub_overlap_ag = ub_overlap_ag
ctx.debug = debug
ctx.requires_dgrad = ( ctx.requires_dgrad = (
inp.requires_grad or ln_weight.requires_grad or ln_bias.requires_grad inp.requires_grad or ln_weight.requires_grad or ln_bias.requires_grad
...@@ -675,18 +707,18 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -675,18 +707,18 @@ class _LayerNormMLP(torch.autograd.Function):
# Configure quantizer for FC2 grad output tensor # Configure quantizer for FC2 grad output tensor
# Note: dgrad GEMM requires row-wise usage, wgrad GEMM # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
# requires column-wise usage # requires column-wise usage
if ctx.grad_fc2_output_quantizer is not None: if ctx.fc2_grad_output_quantizer is not None:
rowwise_usage = True rowwise_usage = True
columnwise_usage = True columnwise_usage = True
if ctx.ub_overlap_ag and isinstance( if ctx.ub_overlap_ag and isinstance(
ctx.grad_fc2_output_quantizer, ctx.fc2_grad_output_quantizer,
(Float8Quantizer, Float8CurrentScalingQuantizer), (Float8Quantizer, Float8CurrentScalingQuantizer),
): ):
# If data is in FP8 and communication is handled # If data is in FP8 and communication is handled
# with Userbuffers, we compute FP8 transposes # with Userbuffers, we compute FP8 transposes
# manually # manually
columnwise_usage = False columnwise_usage = False
ctx.grad_fc2_output_quantizer.set_usage( ctx.fc2_grad_output_quantizer.set_usage(
rowwise=rowwise_usage, rowwise=rowwise_usage,
columnwise=columnwise_usage, columnwise=columnwise_usage,
) )
...@@ -701,7 +733,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -701,7 +733,7 @@ class _LayerNormMLP(torch.autograd.Function):
grad_output, grad_output,
fc2_bias_grad, fc2_bias_grad,
) = TransformerEngineBaseModule.grad_output_preprocess( ) = TransformerEngineBaseModule.grad_output_preprocess(
ctx, grad_outputs[0], True, ctx.grad_fc2_output_quantizer ctx, grad_outputs[0], True, ctx.fc2_grad_output_quantizer
) )
# Launch tensor-parallel communication for FC1 GEMM input # Launch tensor-parallel communication for FC1 GEMM input
...@@ -714,7 +746,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -714,7 +746,7 @@ class _LayerNormMLP(torch.autograd.Function):
and not ctx.ub_bulk_dgrad and not ctx.ub_bulk_dgrad
): ):
quantizer = None quantizer = None
if ctx.fp8: if ctx.fp8 or ctx.debug:
quantizer = ctx.fc1_input_quantizer quantizer = ctx.fc1_input_quantizer
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
# If data is in FP8, we compute FP8 transposes manually # If data is in FP8, we compute FP8 transposes manually
...@@ -747,7 +779,10 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -747,7 +779,10 @@ class _LayerNormMLP(torch.autograd.Function):
# 5 high-precision unfused: gemm, activation, FC1_bias + FC1_gemm # 5 high-precision unfused: gemm, activation, FC1_bias + FC1_gemm
# 6 fp8 unfused: gemm, activation, FC1_bias + FC1_gemm # 6 fp8 unfused: gemm, activation, FC1_bias + FC1_gemm
fc2_dgrad_gemm_gelu_fusion = ( fc2_dgrad_gemm_gelu_fusion = (
not ctx.fp8 and (ctx.activation == "gelu") and (not ctx.bias_gelu_fusion) not ctx.fp8
and (ctx.activation == "gelu")
and (not ctx.bias_gelu_fusion)
and (not ctx.debug)
) )
# FC2 DGRAD; Unconditional # FC2 DGRAD; Unconditional
...@@ -763,7 +798,9 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -763,7 +798,9 @@ class _LayerNormMLP(torch.autograd.Function):
layout="NN", layout="NN",
grad=True, grad=True,
quantization_params=( quantization_params=(
ctx.grad_fc1_output_quantizer if fc2_dgrad_gemm_gelu_fusion else None ctx.fc1_grad_input_quantizer
if fc2_dgrad_gemm_gelu_fusion or ctx.debug
else None
), # high precision to activation ), # high precision to activation
out_dtype=ctx.activation_dtype, out_dtype=ctx.activation_dtype,
gelu=fc2_dgrad_gemm_gelu_fusion, gelu=fc2_dgrad_gemm_gelu_fusion,
...@@ -798,7 +835,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -798,7 +835,7 @@ class _LayerNormMLP(torch.autograd.Function):
if ctx.fuse_wgrad_accumulation if ctx.fuse_wgrad_accumulation
else ctx.activation_dtype else ctx.activation_dtype
), ),
quantization_params=None, # wgrad in high precision quantization_params=ctx.fc2_grad_weight_quantizer, # wgrad in high precision
layout="NT", layout="NT",
grad=grad_arg, grad=grad_arg,
bias=fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None, bias=fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None,
...@@ -817,15 +854,20 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -817,15 +854,20 @@ class _LayerNormMLP(torch.autograd.Function):
# bias computation # bias computation
fc1_bias_grad = None fc1_bias_grad = None
fuse_gemm_and_bias_fc1_wgrad = False fuse_gemm_and_bias_fc1_wgrad = False
if ctx.grad_fc1_output_quantizer is not None: if ctx.fc1_grad_output_quantizer is not None:
ctx.grad_fc1_output_quantizer.set_usage(rowwise=True, columnwise=True) ctx.fc1_grad_output_quantizer.set_usage(rowwise=True, columnwise=True)
if ctx.bias_gelu_fusion: if ctx.bias_gelu_fusion:
# Fusion: gemm, bias + gelu # Fusion: gemm, bias + gelu
assert ctx.activation == "gelu" assert ctx.activation == "gelu"
assert not ctx.fp8 assert not ctx.fp8
fc1_bias_grad, dact = bgrad_dgelu_fused(fc2_dgrad, fc1_out_without_bias, fc1_bias) fc1_bias_grad, dact = bgrad_dgelu_fused(fc2_dgrad, fc1_out_without_bias, fc1_bias)
if ctx.grad_fc1_output_quantizer is not None: if ctx.fc1_grad_output_quantizer is not None:
dact = ctx.grad_fc1_output_quantizer(dact) dact = ctx.fc1_grad_output_quantizer(dact)
elif ctx.debug:
dact_func = _act_func(ctx.activation)[1]
dact = dact_func(fc2_dgrad, fc1_out.to(ctx.activation_dtype), None)
fc1_bias_grad = dact.sum(dim=0)
dact = ctx.fc1_grad_output_quantizer(dact)
elif ( elif (
_act_func(ctx.activation, ctx.fp8_recipe if ctx.fp8 else None)[2] is not None _act_func(ctx.activation, ctx.fp8_recipe if ctx.fp8 else None)[2] is not None
and ctx.fp8 and ctx.fp8
...@@ -835,7 +877,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -835,7 +877,7 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.activation, ctx.fp8_recipe if ctx.fp8 else None ctx.activation, ctx.fp8_recipe if ctx.fp8 else None
)[2] )[2]
fc1_bias_grad, dact = dbias_dact_quantize_func( fc1_bias_grad, dact = dbias_dact_quantize_func(
fc2_dgrad, fc1_out.to(ctx.activation_dtype), ctx.grad_fc1_output_quantizer fc2_dgrad, fc1_out.to(ctx.activation_dtype), ctx.fc1_grad_output_quantizer
) # quantize bgrad gelu fused ) # quantize bgrad gelu fused
else: else:
# Fusion: gemm + gelu, # Fusion: gemm + gelu,
...@@ -849,12 +891,12 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -849,12 +891,12 @@ class _LayerNormMLP(torch.autograd.Function):
if ctx.fp8: if ctx.fp8:
# TODO float8 blockwise current scaling has no bgrad fusion for now # TODO float8 blockwise current scaling has no bgrad fusion for now
if isinstance(ctx.grad_fc1_output_quantizer, Float8BlockQuantizer): if isinstance(ctx.fc1_grad_output_quantizer, Float8BlockQuantizer):
fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0) fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0)
dact = ctx.grad_fc1_output_quantizer(dact) dact = ctx.fc1_grad_output_quantizer(dact)
else: else:
fc1_bias_grad, dact = tex.bgrad_quantize( fc1_bias_grad, dact = tex.bgrad_quantize(
dact, ctx.grad_fc1_output_quantizer dact, ctx.fc1_grad_output_quantizer
) )
else: else:
fuse_gemm_and_bias_fc1_wgrad = ( fuse_gemm_and_bias_fc1_wgrad = (
...@@ -915,6 +957,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -915,6 +957,7 @@ class _LayerNormMLP(torch.autograd.Function):
get_workspace(), get_workspace(),
out=fc1_dgrad_bulk, out=fc1_dgrad_bulk,
out_dtype=ctx.activation_dtype, out_dtype=ctx.activation_dtype,
quantization_params=ctx.fc1_grad_input_quantizer,
layout="NN", layout="NN",
grad=True, grad=True,
ub=ub_obj_fc1_dgrad, ub=ub_obj_fc1_dgrad,
...@@ -990,6 +1033,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -990,6 +1033,7 @@ class _LayerNormMLP(torch.autograd.Function):
else ctx.activation_dtype else ctx.activation_dtype
), ),
layout="NT", layout="NT",
quantization_params=ctx.fc1_grad_weight_quantizer,
grad=fuse_gemm_and_bias_fc1_wgrad, grad=fuse_gemm_and_bias_fc1_wgrad,
bias=fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None, bias=fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None,
accumulate=accumulate_wgrad_into_param_main_grad, accumulate=accumulate_wgrad_into_param_main_grad,
...@@ -1123,14 +1167,18 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1123,14 +1167,18 @@ class _LayerNormMLP(torch.autograd.Function):
None, # fp8 None, # fp8
None, # fp8_calibration None, # fp8_calibration
None, # fuse_wgrad_accumulation None, # fuse_wgrad_accumulation
None, # fc1_input_quantizer None, # fc1_input_quantizer,
None, # fc1_weight_quantizer None, # fc1_weight_quantizer,
None, # fc2_input_quantizer None, # fc1_output_quantizer,
None, # fc2_weight_quantizer None, # fc1_grad_input_quantizer,
None, # output_quantizer None, # fc1_grad_weight_quantizer,
None, # grad_fc2_output_quantizer None, # fc1_grad_output_quantizer,
None, # grad_fc1_output_quantizer None, # fc2_input_quantizer,
None, # grad_input_quantizer None, # fc2_weight_quantizer,
None, # fc2_output_quantizer,
None, # fc2_grad_input_quantizer,
None, # fc2_grad_weight_quantizer,
None, # fc2_grad_output_quantizer,
None, # cpu_offloading None, # cpu_offloading
None, # tp_group None, # tp_group
None, # tp_size None, # tp_size
...@@ -1156,6 +1204,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1156,6 +1204,7 @@ class _LayerNormMLP(torch.autograd.Function):
None, # fsdp_group None, # fsdp_group
None, # module None, # module
None, # skip_fp8_weight_update None, # skip_fp8_weight_update
None, # debug
) )
...@@ -1208,6 +1257,8 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1208,6 +1257,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
The device on which the parameters of the model will be allocated. It is the user's The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the responsibility to ensure all parameters are moved to the GPU before running the
forward pass. forward pass.
name: str, default = `None`
name of the module, currently used for debugging purposes.
Parallelism parameters Parallelism parameters
---------------------- ----------------------
...@@ -1277,6 +1328,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1277,6 +1328,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
zero_centered_gamma: bool = False, zero_centered_gamma: bool = False,
device: Union[torch.device, str] = "cuda", device: Union[torch.device, str] = "cuda",
ub_overlap_ag: bool = False, ub_overlap_ag: bool = False,
name: str = None,
ub_overlap_rs: bool = False, ub_overlap_rs: bool = False,
ub_overlap_rs_dgrad: bool = False, ub_overlap_rs_dgrad: bool = False,
ub_bulk_dgrad: bool = False, ub_bulk_dgrad: bool = False,
...@@ -1306,6 +1358,10 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1306,6 +1358,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
and self.activation == "gelu" and self.activation == "gelu"
and ((_ub_communicators is None) or (not get_ub("fc1_fprop").is_atomic_gemm())) and ((_ub_communicators is None) or (not get_ub("fc1_fprop").is_atomic_gemm()))
) )
self.name = name
if TEDebugState.debug_enabled:
self._turn_off_unsupported_features_in_debug() # turn off userbuffers
if tp_group is None: if tp_group is None:
self.tp_size = tp_size self.tp_size = tp_size
...@@ -1466,7 +1522,9 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1466,7 +1522,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
@no_torch_dynamo() @no_torch_dynamo()
def forward( def forward(
self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None self,
inp: torch.Tensor,
is_first_microbatch: Optional[bool] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
""" """
Apply layer normalization to the input followed by a feedforward network (MLP Block). Apply layer normalization to the input followed by a feedforward network (MLP Block).
...@@ -1489,6 +1547,9 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1489,6 +1547,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
first microbatch (since it is the first gradient being first microbatch (since it is the first gradient being
produced) produced)
""" """
debug = TEDebugState.debug_enabled
if debug:
self._validate_name()
if FP8GlobalStateManager.fp8_graph_capturing(): if FP8GlobalStateManager.fp8_graph_capturing():
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
...@@ -1503,17 +1564,35 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1503,17 +1564,35 @@ class LayerNormMLP(TransformerEngineBaseModule):
fp8_output = True fp8_output = True
with self.prepare_forward(inp, num_gemms=2) as inp: with self.prepare_forward(inp, num_gemms=2) as inp:
quantizers = (
self._get_quantizers(fp8_output)
if not debug
else self._get_debug_quantizers(fp8_output)
)
if debug:
if not any_feature_enabled(quantizers):
quantizers = self._get_quantizers(fp8_output)
debug = False
if isinstance(self.fc1_weight, QuantizedTensor):
raise RuntimeError("FP8 weights are not supported in debug mode.")
# Get quantizers # Get quantizers
( (
fc1_input_quantizer, fc1_input_quantizer,
fc1_weight_quantizer, fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer, fc2_input_quantizer,
fc2_weight_quantizer, fc2_weight_quantizer,
output_quantizer, fc2_output_quantizer,
grad_fc1_output_quantizer, fc2_grad_input_quantizer,
grad_fc2_output_quantizer, fc2_grad_weight_quantizer,
grad_input_quantizer, fc2_grad_output_quantizer,
) = self._get_quantizers(fp8_output) ) = quantizers
# Get weight tensors # Get weight tensors
fc1_weight = self.fc1_weight fc1_weight = self.fc1_weight
...@@ -1551,12 +1630,16 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1551,12 +1630,16 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.fuse_wgrad_accumulation, self.fuse_wgrad_accumulation,
fc1_input_quantizer, fc1_input_quantizer,
fc1_weight_quantizer, fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer, fc2_input_quantizer,
fc2_weight_quantizer, fc2_weight_quantizer,
output_quantizer, fc2_output_quantizer,
grad_input_quantizer, fc2_grad_input_quantizer,
grad_fc1_output_quantizer, fc2_grad_weight_quantizer,
grad_fc2_output_quantizer, fc2_grad_output_quantizer,
is_cpu_offload_enabled(), is_cpu_offload_enabled(),
self.tp_group, self.tp_group,
self.tp_size, self.tp_size,
...@@ -1565,7 +1648,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1565,7 +1648,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.activation_dtype, self.activation_dtype,
self.return_layernorm_output, self.return_layernorm_output,
self.return_layernorm_output_gathered, self.return_layernorm_output_gathered,
self.bias_gelu_nvfusion and not self.fp8, self.bias_gelu_nvfusion and not self.fp8 and not debug,
self.set_parallel_mode, self.set_parallel_mode,
torch.is_grad_enabled(), torch.is_grad_enabled(),
self.fwd_ln_sm_margin if torch.is_grad_enabled() else self.inf_ln_sm_margin, self.fwd_ln_sm_margin if torch.is_grad_enabled() else self.inf_ln_sm_margin,
...@@ -1578,10 +1661,11 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1578,10 +1661,11 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.ub_overlap_rs_dgrad, self.ub_overlap_rs_dgrad,
self.ub_bulk_dgrad, self.ub_bulk_dgrad,
self.ub_bulk_wgrad, self.ub_bulk_wgrad,
self.gemm_gelu_fusion, self.gemm_gelu_fusion and not debug,
self.fsdp_group, self.fsdp_group,
self, self,
skip_fp8_weight_update, skip_fp8_weight_update,
debug,
) )
out = fwd_fn(*args) out = fwd_fn(*args)
...@@ -1603,13 +1687,17 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1603,13 +1687,17 @@ class LayerNormMLP(TransformerEngineBaseModule):
( (
fc1_input_quantizer, fc1_input_quantizer,
fc1_weight_quantizer, fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer, fc2_input_quantizer,
fc2_weight_quantizer, fc2_weight_quantizer,
output_quantizer, fc2_output_quantizer,
grad_fc1_output_quantizer, fc2_grad_input_quantizer,
grad_fc2_output_quantizer, fc2_grad_weight_quantizer,
grad_input_quantizer, fc2_grad_output_quantizer,
) = [None] * 8 ) = [None] * 12
if self.fp8: if self.fp8:
fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
fc1_input_quantizer.internal = False # temporary fc1_input_quantizer.internal = False # temporary
...@@ -1623,30 +1711,54 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1623,30 +1711,54 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT] fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT]
fc2_weight_quantizer.internal = True fc2_weight_quantizer.internal = True
if fp8_output: if fp8_output:
output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_OUTPUT] fc2_output_quantizer = self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM2_OUTPUT
]
if torch.is_grad_enabled(): if torch.is_grad_enabled():
grad_fc2_output_quantizer = self.quantizers["scaling_bwd"][ fc2_grad_output_quantizer = self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1 tex.FP8BwdTensors.GRAD_OUTPUT1
] ]
grad_fc2_output_quantizer.internal = True fc2_grad_output_quantizer.internal = True
grad_fc1_output_quantizer = self.quantizers["scaling_bwd"][ fc1_grad_output_quantizer = self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_INPUT1 tex.FP8BwdTensors.GRAD_INPUT1
] ]
grad_fc1_output_quantizer.internal = True fc1_grad_output_quantizer.internal = True
grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT2]
grad_input_quantizer.internal = True
return ( return (
fc1_input_quantizer, fc1_input_quantizer,
fc1_weight_quantizer, fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer, fc2_input_quantizer,
fc2_weight_quantizer, fc2_weight_quantizer,
output_quantizer, fc2_output_quantizer,
grad_fc1_output_quantizer, fc2_grad_input_quantizer,
grad_fc2_output_quantizer, fc2_grad_weight_quantizer,
grad_input_quantizer, fc2_grad_output_quantizer,
) )
def _get_debug_quantizers(self, fp8_output):
from ...debug.pytorch.debug_quantization import DebugQuantizer
base_quantizers = list(self._get_quantizers(fp8_output))
assert TEDebugState.debug_enabled
def make_debug(prefix, offset):
labels = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"]
return [
DebugQuantizer(
f"{self.name}.{prefix}",
label,
None if label in ("dgrad", "wgrad") else base_quantizers[i + offset],
self.tp_group,
)
for i, label in enumerate(labels)
]
return tuple(make_debug("fc1", 0) + make_debug("fc2", 6))
def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on current scaling recipe + layernorm_mlp.""" """Customize quantizers based on current scaling recipe + layernorm_mlp."""
assert ( assert (
...@@ -1691,14 +1803,14 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1691,14 +1803,14 @@ class LayerNormMLP(TransformerEngineBaseModule):
tex.FP8FwdTensors.GEMM1_INPUT tex.FP8FwdTensors.GEMM1_INPUT
].amax_reduction_group = self.tp_group ].amax_reduction_group = self.tp_group
else: else:
# grad_fc2_output_quantizer: set configs about amax epsilon and power_2_scale for grad_fc2_output_quantizer # fc2_grad_output_quantizer: set configs about amax epsilon and power_2_scale for fc2_grad_output_quantizer
self.quantizers["scaling_bwd"][ self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1 tex.FP8BwdTensors.GRAD_OUTPUT1
].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale
self.quantizers["scaling_bwd"][ self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1 tex.FP8BwdTensors.GRAD_OUTPUT1
].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
# grad_fc1_output_quantizer: also set numerical configs for grad_fc1_output_quantizer # fc1_grad_output_quantizer: also set numerical configs for fc1_grad_output_quantizer
self.quantizers["scaling_bwd"][ self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_INPUT1 tex.FP8BwdTensors.GRAD_INPUT1
].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale
...@@ -1706,7 +1818,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1706,7 +1818,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
tex.FP8BwdTensors.GRAD_INPUT1 tex.FP8BwdTensors.GRAD_INPUT1
].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
if self.sequence_parallel and self.set_parallel_mode: if self.sequence_parallel and self.set_parallel_mode:
# grad_fc2_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here # fc2_grad_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here
self.quantizers["scaling_bwd"][ self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1 tex.FP8BwdTensors.GRAD_OUTPUT1
].with_amax_reduction = True ].with_amax_reduction = True
......
...@@ -28,11 +28,12 @@ from ..utils import ( ...@@ -28,11 +28,12 @@ from ..utils import (
clear_tensor_data, clear_tensor_data,
divide, divide,
init_method_constant, init_method_constant,
requires_grad,
needs_quantized_gemm,
non_tn_fp8_gemm_supported, non_tn_fp8_gemm_supported,
assert_dim_for_fp8_exec, assert_dim_for_fp8_exec,
nvtx_range_pop, nvtx_range_pop,
nvtx_range_push, nvtx_range_push,
requires_grad,
) )
from ..distributed import ( from ..distributed import (
set_tensor_model_parallel_attributes, set_tensor_model_parallel_attributes,
...@@ -62,6 +63,8 @@ from ..tensor.mxfp8_tensor import MXFP8Quantizer ...@@ -62,6 +63,8 @@ from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.utils import any_feature_enabled
__all__ = ["Linear"] __all__ = ["Linear"]
...@@ -84,8 +87,9 @@ class _Linear(torch.autograd.Function): ...@@ -84,8 +87,9 @@ class _Linear(torch.autograd.Function):
input_quantizer: Optional[Quantizer], input_quantizer: Optional[Quantizer],
weight_quantizer: Optional[Quantizer], weight_quantizer: Optional[Quantizer],
output_quantizer: Optional[Quantizer], output_quantizer: Optional[Quantizer],
grad_output_quantizer: Optional[Quantizer],
grad_input_quantizer: Optional[Quantizer], grad_input_quantizer: Optional[Quantizer],
grad_weight_quantizer: Optional[Quantizer],
grad_output_quantizer: Optional[Quantizer],
fuse_wgrad_accumulation: bool, fuse_wgrad_accumulation: bool,
cpu_offloading: bool, cpu_offloading: bool,
tp_group: Union[dist_group_type, None], tp_group: Union[dist_group_type, None],
...@@ -106,6 +110,7 @@ class _Linear(torch.autograd.Function): ...@@ -106,6 +110,7 @@ class _Linear(torch.autograd.Function):
fsdp_group: Union[dist_group_type, None], fsdp_group: Union[dist_group_type, None],
module: torch.nn.Module, module: torch.nn.Module,
skip_fp8_weight_update: bool, skip_fp8_weight_update: bool,
debug: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
...@@ -144,7 +149,7 @@ class _Linear(torch.autograd.Function): ...@@ -144,7 +149,7 @@ class _Linear(torch.autograd.Function):
"Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor" "Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
" current scaling" " current scaling"
) )
if fp8 or debug:
if input_quantizer is None: if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor") raise ValueError("Missing quantizer for input tensor")
if with_input_all_gather_nccl: if with_input_all_gather_nccl:
...@@ -196,9 +201,9 @@ class _Linear(torch.autograd.Function): ...@@ -196,9 +201,9 @@ class _Linear(torch.autograd.Function):
nvtx_range_pop(f"{nvtx_label}.input_cast_comm") nvtx_range_pop(f"{nvtx_label}.input_cast_comm")
# Cast weight to expected dtype # Cast weight to expected dtype
if not fp8: weightmat = weight
weightmat = cast_if_needed(weight, activation_dtype)
else: if fp8 or debug:
# Configure quantizer # Configure quantizer
if weight_quantizer is not None: if weight_quantizer is not None:
columnwise_usage = is_grad_enabled and inp.requires_grad columnwise_usage = is_grad_enabled and inp.requires_grad
...@@ -208,7 +213,6 @@ class _Linear(torch.autograd.Function): ...@@ -208,7 +213,6 @@ class _Linear(torch.autograd.Function):
and not in_fp8_activation_recompute_phase() and not in_fp8_activation_recompute_phase()
) )
weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
# FP8 cast to workspace buffer # FP8 cast to workspace buffer
update_workspace = is_first_microbatch is None or is_first_microbatch update_workspace = is_first_microbatch is None or is_first_microbatch
weightmat = module.get_weight_workspace( weightmat = module.get_weight_workspace(
...@@ -218,11 +222,14 @@ class _Linear(torch.autograd.Function): ...@@ -218,11 +222,14 @@ class _Linear(torch.autograd.Function):
update_workspace=update_workspace, update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update, skip_update_flag=skip_fp8_weight_update,
fsdp_group=fsdp_group, fsdp_group=fsdp_group,
workspace_dtype=activation_dtype,
) )
else:
weightmat = cast_if_needed(weightmat, activation_dtype)
# Cast bias to expected dtype # Cast bias to expected dtype
bias_dtype = activation_dtype bias_dtype = activation_dtype
if fp8 and activation_dtype == torch.float32: if needs_quantized_gemm(inputmat_total) and activation_dtype == torch.float32:
bias_dtype = torch.bfloat16 bias_dtype = torch.bfloat16
bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias
...@@ -343,12 +350,14 @@ class _Linear(torch.autograd.Function): ...@@ -343,12 +350,14 @@ class _Linear(torch.autograd.Function):
ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
ctx.force_hp_input_gather = force_hp_input_gather ctx.force_hp_input_gather = force_hp_input_gather
ctx.input_quantizer = input_quantizer ctx.input_quantizer = input_quantizer
ctx.grad_output_quantizer = grad_output_quantizer
ctx.grad_input_quantizer = grad_input_quantizer ctx.grad_input_quantizer = grad_input_quantizer
ctx.grad_weight_quantizer = grad_weight_quantizer
ctx.grad_output_quantizer = grad_output_quantizer
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
if fuse_wgrad_accumulation and weight.requires_grad: if fuse_wgrad_accumulation and weight.requires_grad:
ctx.main_grad = weight.main_grad ctx.main_grad = weight.main_grad
ctx.debug = debug
ctx.cpu_offloading = cpu_offloading ctx.cpu_offloading = cpu_offloading
ctx.is_first_microbatch = is_first_microbatch ctx.is_first_microbatch = is_first_microbatch
ctx.use_bias = bias is not None ctx.use_bias = bias is not None
...@@ -528,7 +537,7 @@ class _Linear(torch.autograd.Function): ...@@ -528,7 +537,7 @@ class _Linear(torch.autograd.Function):
inputmat_total_work = None inputmat_total_work = None
if ctx.backward_input_needs_gather and not ctx.ub_bulk_dgrad: if ctx.backward_input_needs_gather and not ctx.ub_bulk_dgrad:
quantizer = None quantizer = None
if ctx.fp8: if ctx.fp8 or ctx.debug:
quantizer = ctx.input_quantizer quantizer = ctx.input_quantizer
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
# If data is in FP8, we compute FP8 transposes manually # If data is in FP8, we compute FP8 transposes manually
...@@ -564,7 +573,6 @@ class _Linear(torch.autograd.Function): ...@@ -564,7 +573,6 @@ class _Linear(torch.autograd.Function):
# Update quantizer # Update quantizer
if ctx.grad_input_quantizer is not None: if ctx.grad_input_quantizer is not None:
ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False)
# dgrad GEMM # dgrad GEMM
nvtx_range_push(f"{nvtx_label}.dgrad_gemm") nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD
...@@ -678,6 +686,7 @@ class _Linear(torch.autograd.Function): ...@@ -678,6 +686,7 @@ class _Linear(torch.autograd.Function):
out=main_grad if ctx.fuse_wgrad_accumulation else None, out=main_grad if ctx.fuse_wgrad_accumulation else None,
use_split_accumulator=use_split_accumulator, use_split_accumulator=use_split_accumulator,
accumulate=accumulate_wgrad_into_param_main_grad, accumulate=accumulate_wgrad_into_param_main_grad,
quantization_params=ctx.grad_weight_quantizer,
ub=ub_obj_wgrad, ub=ub_obj_wgrad,
ub_type=ub_type_wgrad, ub_type=ub_type_wgrad,
extra_output=rs_out, extra_output=rs_out,
...@@ -753,8 +762,9 @@ class _Linear(torch.autograd.Function): ...@@ -753,8 +762,9 @@ class _Linear(torch.autograd.Function):
None, # input_quantizer None, # input_quantizer
None, # weight_quantizer None, # weight_quantizer
None, # output_quantizer None, # output_quantizer
None, # grad_output_quantizer
None, # grad_input_quantizer None, # grad_input_quantizer
None, # grad_weight_quantizer
None, # grad_output_quantizer
None, # fuse_wgrad_accumulation None, # fuse_wgrad_accumulation
None, # cpu_offloading None, # cpu_offloading
None, # tp_group None, # tp_group
...@@ -775,6 +785,7 @@ class _Linear(torch.autograd.Function): ...@@ -775,6 +785,7 @@ class _Linear(torch.autograd.Function):
None, # fsdp_group None, # fsdp_group
None, # module None, # module
None, # skip_fp8_weight_update None, # skip_fp8_weight_update
None, # debug
) )
...@@ -810,6 +821,8 @@ class Linear(TransformerEngineBaseModule): ...@@ -810,6 +821,8 @@ class Linear(TransformerEngineBaseModule):
The device on which the parameters of the model will be allocated. It is the user's The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the responsibility to ensure all parameters are moved to the GPU before running the
forward pass. forward pass.
name: str, default = `None`
name of the module, currently used for debugging purposes.
Parallelism parameters Parallelism parameters
---------------------- ----------------------
...@@ -871,6 +884,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -871,6 +884,7 @@ class Linear(TransformerEngineBaseModule):
ub_bulk_dgrad: bool = False, ub_bulk_dgrad: bool = False,
ub_bulk_wgrad: bool = False, ub_bulk_wgrad: bool = False,
ub_name: Optional[str] = None, ub_name: Optional[str] = None,
name: Optional[str] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -883,6 +897,10 @@ class Linear(TransformerEngineBaseModule): ...@@ -883,6 +897,10 @@ class Linear(TransformerEngineBaseModule):
self.apply_bias = bias and not return_bias self.apply_bias = bias and not return_bias
self.get_rng_state_tracker = get_rng_state_tracker self.get_rng_state_tracker = get_rng_state_tracker
self.rng_tracker_name = rng_tracker_name self.rng_tracker_name = rng_tracker_name
self.name = name
if TEDebugState.debug_enabled:
self._turn_off_unsupported_features_in_debug() # turn off userbuffers
if device == "meta": if device == "meta":
assert parameters_split is None, "Cannot split module parameters on 'meta' device." assert parameters_split is None, "Cannot split module parameters on 'meta' device."
...@@ -1126,6 +1144,10 @@ class Linear(TransformerEngineBaseModule): ...@@ -1126,6 +1144,10 @@ class Linear(TransformerEngineBaseModule):
first microbatch (since it is the first gradient being first microbatch (since it is the first gradient being
produced) produced)
""" """
debug = TEDebugState.debug_enabled
if debug:
self._validate_name()
if FP8GlobalStateManager.fp8_graph_capturing(): if FP8GlobalStateManager.fp8_graph_capturing():
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
else: else:
...@@ -1161,13 +1183,28 @@ class Linear(TransformerEngineBaseModule): ...@@ -1161,13 +1183,28 @@ class Linear(TransformerEngineBaseModule):
else: else:
bias_tensor = None bias_tensor = None
quantizers = (
self._get_quantizers(fp8_output, fp8_grad)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad)
)
if debug:
if not any_feature_enabled(quantizers):
# If no feature is used, then run faster implementation with debug = False.
quantizers = self._get_quantizers(fp8_output, fp8_grad)
debug = False
if isinstance(weight_tensor, QuantizedTensor):
raise RuntimeError("FP8 weights are not supported in debug mode.")
( (
input_quantizer, input_quantizer,
weight_quantizer, weight_quantizer,
output_quantizer, output_quantizer,
grad_output_quantizer,
grad_input_quantizer, grad_input_quantizer,
) = self._get_quantizers(fp8_output, fp8_grad) grad_weight_quantizer,
grad_output_quantizer,
) = quantizers
# Make sure weight tensor has correct quantizer # Make sure weight tensor has correct quantizer
# Note: Quantizer might have changed if quantization # Note: Quantizer might have changed if quantization
...@@ -1191,8 +1228,9 @@ class Linear(TransformerEngineBaseModule): ...@@ -1191,8 +1228,9 @@ class Linear(TransformerEngineBaseModule):
input_quantizer, input_quantizer,
weight_quantizer, weight_quantizer,
output_quantizer, output_quantizer,
grad_output_quantizer,
grad_input_quantizer, grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
self.fuse_wgrad_accumulation, self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(), is_cpu_offload_enabled(),
self.tp_group, self.tp_group,
...@@ -1213,6 +1251,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -1213,6 +1251,7 @@ class Linear(TransformerEngineBaseModule):
self.fsdp_group, self.fsdp_group,
self, self,
skip_fp8_weight_update, skip_fp8_weight_update,
debug,
) )
out = linear_fn(*args) out = linear_fn(*args)
if self.gemm_bias_unfused_add: if self.gemm_bias_unfused_add:
...@@ -1224,8 +1263,9 @@ class Linear(TransformerEngineBaseModule): ...@@ -1224,8 +1263,9 @@ class Linear(TransformerEngineBaseModule):
def _get_quantizers(self, fp8_output, fp8_grad): def _get_quantizers(self, fp8_output, fp8_grad):
if not self.fp8: if not self.fp8:
return [None] * 5 return [None] * 6
grad_input_quantizer = None grad_input_quantizer = None
grad_weight_quantizer = None
grad_output_quantizer = None grad_output_quantizer = None
output_quantizer = None output_quantizer = None
input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
...@@ -1243,8 +1283,20 @@ class Linear(TransformerEngineBaseModule): ...@@ -1243,8 +1283,20 @@ class Linear(TransformerEngineBaseModule):
input_quantizer, input_quantizer,
weight_quantizer, weight_quantizer,
output_quantizer, output_quantizer,
grad_output_quantizer,
grad_input_quantizer, grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
)
def _get_debug_quantizers(self, fp8_output, fp8_grad):
original_quantizers = self._get_quantizers(fp8_output, fp8_grad)
assert TEDebugState.debug_enabled
from ...debug.pytorch.debug_quantization import DebugQuantizer
names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"]
return tuple(
DebugQuantizer(self.name, name, q, self.tp_group)
for name, q in zip(names, original_quantizers)
) )
def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
......
...@@ -42,3 +42,27 @@ def _make_module_cast_func(dtype): ...@@ -42,3 +42,27 @@ def _make_module_cast_func(dtype):
torch.nn.Module.float = _make_module_cast_func(torch.float32) torch.nn.Module.float = _make_module_cast_func(torch.float32)
torch.nn.Module.half = _make_module_cast_func(torch.float16) torch.nn.Module.half = _make_module_cast_func(torch.float16)
torch.nn.Module.bfloat16 = _make_module_cast_func(torch.bfloat16) torch.nn.Module.bfloat16 = _make_module_cast_func(torch.bfloat16)
def get_all_tensor_types():
"""
Get all tensor-like types that can be used in TE.
"""
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8TensorBase
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8TensorBase
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
Float8BlockwiseQTensor,
Float8BlockwiseQTensorBase,
)
all_tensor_types = [
torch.Tensor,
torch.nn.Parameter,
Float8Tensor,
Float8TensorBase,
MXFP8Tensor,
MXFP8TensorBase,
Float8BlockwiseQTensor,
Float8BlockwiseQTensorBase,
]
return all_tensor_types
...@@ -27,12 +27,14 @@ class _FromFloat8Func(torch.autograd.Function): ...@@ -27,12 +27,14 @@ class _FromFloat8Func(torch.autograd.Function):
dtype: torch.dtype, dtype: torch.dtype,
) -> torch.Tensor: ) -> torch.Tensor:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
dtype = torch_to_transformer_engine_dtype[dtype] te_dtype = torch_to_transformer_engine_dtype[dtype]
# Make sure FP8 data is in expected format # Make sure FP8 data is in expected format
if tensor._data is not None: if tensor._data is not None:
if tensor._data.numel() == 0:
return torch.empty_like(tensor._data, dtype=dtype)
# Cast from FP8 # Cast from FP8
return tex.dequantize(tensor, dtype) return tex.dequantize(tensor, te_dtype)
raise NotImplementedError("Casting back from the transpose not implemented yet!") raise NotImplementedError("Casting back from the transpose not implemented yet!")
......
...@@ -37,7 +37,8 @@ def prepare_for_saving( ...@@ -37,7 +37,8 @@ def prepare_for_saving(
def restore_from_saved( def restore_from_saved(
tensors: list[Optional[Any]], tensors: list[Optional[Any]],
saved_tensors: list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], saved_tensors: list[Optional[Union[torch.Tensor, torch.nn.Parameter]]],
) -> list[Optional[Any]]: return_saved_tensors: bool = False,
) -> list[Optional[Any]] | tuple[list[Optional[Any]], list[Optional[torch.Tensor]]]:
"""Recombine the tensor data and metadata during backward pass.""" """Recombine the tensor data and metadata during backward pass."""
tensor_objects = [] tensor_objects = []
for tensor in tensors: for tensor in tensors:
...@@ -47,6 +48,9 @@ def restore_from_saved( ...@@ -47,6 +48,9 @@ def restore_from_saved(
else: else:
saved_tensors = tensor.restore_from_saved(saved_tensors) saved_tensors = tensor.restore_from_saved(saved_tensors)
tensor_objects.append(tensor) tensor_objects.append(tensor)
if return_saved_tensors:
return tensor_objects, saved_tensors
return tensor_objects return tensor_objects
...@@ -113,7 +117,11 @@ class Quantizer(abc.ABC): ...@@ -113,7 +117,11 @@ class Quantizer(abc.ABC):
"""Quantize tensor in-place""" """Quantize tensor in-place"""
def quantize( def quantize(
self, tensor: torch.Tensor, *, out: Optional[QuantizedTensor] = None self,
tensor: torch.Tensor,
*,
out: Optional[QuantizedTensor] = None,
dtype: Optional[torch.dtype] = None, # pylint: disable=unused-argument # used by override
) -> QuantizedTensor: ) -> QuantizedTensor:
"""Quantize tensor""" """Quantize tensor"""
if out is not None: if out is not None:
......
...@@ -11,6 +11,7 @@ from typing import Callable, List, Optional, Tuple, Union ...@@ -11,6 +11,7 @@ from typing import Callable, List, Optional, Tuple, Union
import torch import torch
from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm
from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.pytorch.attention import ( from transformer_engine.pytorch.attention import (
MultiheadAttention, MultiheadAttention,
) )
...@@ -33,6 +34,7 @@ from transformer_engine.pytorch.constants import ( ...@@ -33,6 +34,7 @@ from transformer_engine.pytorch.constants import (
dist_group_type, dist_group_type,
) )
from transformer_engine.pytorch.distributed import get_distributed_world_size from transformer_engine.pytorch.distributed import get_distributed_world_size
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
warnings.filterwarnings("module", category=DeprecationWarning, module="transformer") warnings.filterwarnings("module", category=DeprecationWarning, module="transformer")
...@@ -184,6 +186,8 @@ class TransformerLayer(torch.nn.Module): ...@@ -184,6 +186,8 @@ class TransformerLayer(torch.nn.Module):
head size. Note that these formats are very closely head size. Note that these formats are very closely
related to the `qkv_format` in the `MultiHeadAttention` related to the `qkv_format` in the `MultiHeadAttention`
and `DotProductAttention` modules. and `DotProductAttention` modules.
name: str, default = `None`
name of the module, currently used for debugging purposes.
Parallelism parameters Parallelism parameters
---------------------- ----------------------
...@@ -277,6 +281,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -277,6 +281,7 @@ class TransformerLayer(torch.nn.Module):
normalization: str = "LayerNorm", normalization: str = "LayerNorm",
device: Union[torch.device, str] = "cuda", device: Union[torch.device, str] = "cuda",
attn_input_format: str = "sbhd", attn_input_format: str = "sbhd",
name: str = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -336,6 +341,8 @@ class TransformerLayer(torch.nn.Module): ...@@ -336,6 +341,8 @@ class TransformerLayer(torch.nn.Module):
self.attn_input_format = attn_input_format self.attn_input_format = attn_input_format
self.name = name
attention_args = ( attention_args = (
hidden_size, hidden_size,
num_attention_heads, num_attention_heads,
...@@ -376,6 +383,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -376,6 +383,7 @@ class TransformerLayer(torch.nn.Module):
return_bias=not self.parallel_attention_mlp, return_bias=not self.parallel_attention_mlp,
normalization=normalization, normalization=normalization,
device=device, device=device,
name=name + ".self_attention" if name is not None else None,
) )
if layer_type == "decoder": if layer_type == "decoder":
...@@ -389,6 +397,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -389,6 +397,7 @@ class TransformerLayer(torch.nn.Module):
return_bias=True, return_bias=True,
normalization=normalization, normalization=normalization,
device=device, device=device,
name=name + ".inter_attention" if name is not None else None,
) )
# LayerNorm -> activation(Linear + Bias) -> Linear # LayerNorm -> activation(Linear + Bias) -> Linear
...@@ -423,6 +432,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -423,6 +432,7 @@ class TransformerLayer(torch.nn.Module):
activation=activation, activation=activation,
normalization=normalization, normalization=normalization,
device=device, device=device,
name=name + ".layernorm_mlp" if name is not None else None,
) )
self.hidden_dropout = hidden_dropout self.hidden_dropout = hidden_dropout
...@@ -679,6 +689,9 @@ class TransformerLayer(torch.nn.Module): ...@@ -679,6 +689,9 @@ class TransformerLayer(torch.nn.Module):
enc_dec_attn_mask[i].dtype == torch.bool for i in range(len(enc_dec_attn_mask)) enc_dec_attn_mask[i].dtype == torch.bool for i in range(len(enc_dec_attn_mask))
), "Encoder-decoder attention mask must be boolean tensor(s)" ), "Encoder-decoder attention mask must be boolean tensor(s)"
if TEDebugState.debug_enabled:
TransformerEngineBaseModule._validate_name(self)
# For AMP # For AMP
if torch.is_autocast_enabled(): if torch.is_autocast_enabled():
hidden_states = cast_if_needed(hidden_states, torch.get_autocast_gpu_dtype()) hidden_states = cast_if_needed(hidden_states, torch.get_autocast_gpu_dtype())
......
...@@ -11,6 +11,7 @@ from typing import Any, Callable, List, Optional, Tuple ...@@ -11,6 +11,7 @@ from typing import Any, Callable, List, Optional, Tuple
import torch import torch
import transformer_engine.pytorch.cpp_extensions as ext import transformer_engine.pytorch.cpp_extensions as ext
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
from .tensor.quantized_tensor import QuantizedTensor from .tensor.quantized_tensor import QuantizedTensor
...@@ -329,6 +330,19 @@ def round_up_to_nearest_multiple(value, multiple): ...@@ -329,6 +330,19 @@ def round_up_to_nearest_multiple(value, multiple):
return ((value + multiple - 1) // multiple) * multiple return ((value + multiple - 1) // multiple) * multiple
def needs_quantized_gemm(obj, rowwise=True):
"""Used to check if obj will need quantized gemm or normal gemm."""
if isinstance(obj, DebugQuantizedTensor):
return type(obj.get_tensor(not rowwise)) not in [ # pylint: disable=unidiomatic-typecheck
torch.Tensor,
torch.nn.Parameter,
]
return type(obj) not in [
torch.Tensor,
torch.nn.Parameter,
] # pylint: disable=unidiomatic-typecheck
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
def _nvtx_enabled() -> bool: def _nvtx_enabled() -> bool:
"""Check if NVTX range profiling is enabled""" """Check if NVTX range profiling is enabled"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment