Unverified Commit beaecf84 authored by Paweł Gadziński's avatar Paweł Gadziński Committed by GitHub
Browse files

[Pytorch] NVIDIA-DL-Framework-Inspect support – part 1 – core (#1614)



* add
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* weight workspace fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* docs fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* file i forgot
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* lint fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* Update transformer_engine/debug/pytorch/utils.py
Co-authored-by: default avatarPrzemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: default avatarPaweł Gadziński <62263673+pggPL@users.noreply.github.com>

* setup fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* setup fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* Update transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py
Co-authored-by: default avatarPrzemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: default avatarPaweł Gadziński <62263673+pggPL@users.noreply.github.com>

* all tensor types
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* removed check
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* move error
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* _reset
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* Update transformer_engine/pytorch/module/linear.py
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: default avatarPaweł Gadziński <62263673+pggPL@users.noreply.github.com>

* name documentation
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* added blockwise quantizer
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* make debug option optional
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* Update transformer_engine/pytorch/tensor/quantized_tensor.py
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: default avatarPaweł Gadziński <62263673+pggPL@users.noreply.github.com>

* names fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: default avatarPaweł Gadziński <62263673+pggPL@users.noreply.github.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarPrzemyslaw Tredak <ptrendx@gmail.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
parent 0994fb48
...@@ -110,6 +110,10 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: ...@@ -110,6 +110,10 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
if "pytorch" in frameworks: if "pytorch" in frameworks:
install_reqs.extend(["torch>=2.1"]) install_reqs.extend(["torch>=2.1"])
install_reqs.append(
"nvdlfw-inspect @"
" git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git@v0.1#egg=nvdlfw-inspect"
)
# Blackwell is not supported as of Triton 3.2.0, need custom internal build # Blackwell is not supported as of Triton 3.2.0, need custom internal build
# install_reqs.append("triton") # install_reqs.append("triton")
test_reqs.extend(["numpy", "torchvision", "prettytable", "PyYAML"]) test_reqs.extend(["numpy", "torchvision", "prettytable", "PyYAML"])
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Top level package for numerical debugging."""
try:
from . import pytorch
from .pytorch.debug_state import set_weight_tensor_tp_group_reduce
except ImportError as e:
pass
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
This diff is collapsed.
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Managing the state of all the debugged layers.
"""
import sys
class TEDebugState:
"""
A class to manage the state of debug layers.
"""
layer_count = 1
layers_initialized = {}
weight_tensor_tp_group_reduce = True
debug_enabled = None
@classmethod
def initialize(cls):
"""
If debug_api module is initialized, then sets cls.debug_enabled to True.
"""
if "nvdlfw_inspect" in sys.modules:
import nvdlfw_inspect.api as debug_api
if cls.debug_enabled is False and debug_api.DEBUG_MANAGER is not None:
# This method is invoked when initializing TE modules.
# If this error is thrown, it means that some TE module had been initialized before
# debug_api was initialized, and now a new TE module is being initialized.
# This is likely to be a bug.
raise RuntimeError(
"[nv_dlfw_inspect] nv_dlfw_inspect module should be initialized before"
" initialization of the first TE module"
)
cls.debug_enabled = debug_api.DEBUG_MANAGER is not None
@classmethod
def _reset(cls):
"""Resets layer count and stats buffers."""
from ..features.utils.stats_buffer import STATS_BUFFERS
STATS_BUFFERS.reset()
cls.debug_enabled = None
cls.layers_initialized.clear()
@classmethod
def get_layer_count(cls):
"""
Layer counter is used when layer names are not provided to modules by the user.
"""
lc = cls.layer_count
cls.layer_count += 1
return lc
@classmethod
def set_weight_tensor_tp_group_reduce(cls, enabled):
"""Sets weight tensor reduction mode."""
cls.weight_tensor_tp_group_reduce = enabled
def set_weight_tensor_tp_group_reduce(enabled):
"""Sets weight tensor reduction mode."""
TEDebugState.set_weight_tensor_tp_group_reduce(enabled)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Utils functions for the debug module."""
def any_feature_enabled(quantizers):
"""Returns True if at least one API call is made from DebugQuantizer."""
return any(q.any_feature_enabled() for q in quantizers)
...@@ -19,6 +19,7 @@ from packaging.version import Version as PkgVersion ...@@ -19,6 +19,7 @@ from packaging.version import Version as PkgVersion
import torch import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.pytorch.utils import ( from transformer_engine.pytorch.utils import (
get_cudnn_version, get_cudnn_version,
nvtx_range_pop, nvtx_range_pop,
...@@ -6483,6 +6484,8 @@ class MultiheadAttention(torch.nn.Module): ...@@ -6483,6 +6484,8 @@ class MultiheadAttention(torch.nn.Module):
equal length. Please note that these formats do not reflect how equal length. Please note that these formats do not reflect how
tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory. tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
For that, please use `get_qkv_layout` to gain the layout information. For that, please use `get_qkv_layout` to gain the layout information.
name: str, default = `None`
name of the module, currently used for debugging purposes.
Parallelism parameters Parallelism parameters
---------------------- ----------------------
...@@ -6561,6 +6564,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -6561,6 +6564,7 @@ class MultiheadAttention(torch.nn.Module):
normalization: str = "LayerNorm", normalization: str = "LayerNorm",
device: Union[torch.device, str] = "cuda", device: Union[torch.device, str] = "cuda",
qkv_format: str = "sbhd", qkv_format: str = "sbhd",
name: str = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -6612,6 +6616,8 @@ class MultiheadAttention(torch.nn.Module): ...@@ -6612,6 +6616,8 @@ class MultiheadAttention(torch.nn.Module):
self.hidden_size_q = self.hidden_size_per_attention_head * num_attention_heads self.hidden_size_q = self.hidden_size_per_attention_head * num_attention_heads
self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups
self.name = name
common_gemm_kwargs = { common_gemm_kwargs = {
"fuse_wgrad_accumulation": fuse_wgrad_accumulation, "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
"tp_group": tp_group, "tp_group": tp_group,
...@@ -6652,6 +6658,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -6652,6 +6658,7 @@ class MultiheadAttention(torch.nn.Module):
ub_overlap_ag=ub_overlap_ag, ub_overlap_ag=ub_overlap_ag,
normalization=normalization, normalization=normalization,
ub_name="qkv", ub_name="qkv",
name=name + ".layernorm_linear_qkv" if name is not None else None,
**common_gemm_kwargs, **common_gemm_kwargs,
) )
else: else:
...@@ -6663,6 +6670,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -6663,6 +6670,7 @@ class MultiheadAttention(torch.nn.Module):
return_bias=False, return_bias=False,
parallel_mode=qkv_parallel_mode, parallel_mode=qkv_parallel_mode,
parameters_split=parameters_split, parameters_split=parameters_split,
name=name + ".linear_qkv" if name is not None else None,
**common_gemm_kwargs, **common_gemm_kwargs,
) )
elif self.attention_type == "cross": elif self.attention_type == "cross":
...@@ -6684,6 +6692,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -6684,6 +6692,7 @@ class MultiheadAttention(torch.nn.Module):
ub_overlap_ag=ub_overlap_ag, ub_overlap_ag=ub_overlap_ag,
normalization=normalization, normalization=normalization,
ub_name="qkv", ub_name="qkv",
name=name + ".layernorm_linear_q" if name is not None else None,
**common_gemm_kwargs, **common_gemm_kwargs,
) )
else: else:
...@@ -6694,6 +6703,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -6694,6 +6703,7 @@ class MultiheadAttention(torch.nn.Module):
bias=bias, bias=bias,
return_bias=False, return_bias=False,
parallel_mode=qkv_parallel_mode, parallel_mode=qkv_parallel_mode,
name=name + ".linear_q" if name is not None else None,
**common_gemm_kwargs, **common_gemm_kwargs,
) )
self.key_value = Linear( self.key_value = Linear(
...@@ -6704,6 +6714,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -6704,6 +6714,7 @@ class MultiheadAttention(torch.nn.Module):
return_bias=False, return_bias=False,
parallel_mode=qkv_parallel_mode, parallel_mode=qkv_parallel_mode,
parameters_split=("key", "value") if not fuse_qkv_params else None, parameters_split=("key", "value") if not fuse_qkv_params else None,
name=name + ".linear_kv" if name is not None else None,
**common_gemm_kwargs, **common_gemm_kwargs,
) )
...@@ -6733,6 +6744,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -6733,6 +6744,7 @@ class MultiheadAttention(torch.nn.Module):
ub_overlap_rs=ub_overlap_rs, ub_overlap_rs=ub_overlap_rs,
ub_overlap_ag=ub_overlap_ag, ub_overlap_ag=ub_overlap_ag,
ub_name="proj", ub_name="proj",
name=name + ".proj" if name is not None else None,
**common_gemm_kwargs, **common_gemm_kwargs,
) )
...@@ -6923,6 +6935,9 @@ class MultiheadAttention(torch.nn.Module): ...@@ -6923,6 +6935,9 @@ class MultiheadAttention(torch.nn.Module):
core_attention_bias_type in AttnBiasTypes core_attention_bias_type in AttnBiasTypes
), f"core_attention_bias_type {core_attention_bias_type} is not supported!" ), f"core_attention_bias_type {core_attention_bias_type} is not supported!"
if TEDebugState.debug_enabled:
TransformerEngineBaseModule._validate_name(self)
# ================================================= # =================================================
# Pre-allocate memory for key-value cache for inference # Pre-allocate memory for key-value cache for inference
# ================================================= # =================================================
......
...@@ -14,6 +14,7 @@ from ..utils import get_sm_count ...@@ -14,6 +14,7 @@ from ..utils import get_sm_count
from ..tensor.quantized_tensor import Quantizer from ..tensor.quantized_tensor import Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ...debug.pytorch.debug_quantization import DebugQuantizer
__all__ = [ __all__ = [
"general_gemm", "general_gemm",
...@@ -109,6 +110,13 @@ def general_gemm( ...@@ -109,6 +110,13 @@ def general_gemm(
if not out.is_contiguous(): if not out.is_contiguous():
raise ValueError("Output tensor is not contiguous.") raise ValueError("Output tensor is not contiguous.")
debug_quantizer = None
if isinstance(quantization_params, DebugQuantizer):
debug_quantizer = quantization_params
quantization_params = quantization_params.parent_quantizer
A = A.get_tensor(not transa)
B = B.get_tensor(transb)
# Use bfloat16 as default bias_dtype # Use bfloat16 as default bias_dtype
bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype] bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype]
...@@ -145,6 +153,9 @@ def general_gemm( ...@@ -145,6 +153,9 @@ def general_gemm(
out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs) out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs)
reset_swizzled_inputs(A, B, original_scale_inverses) reset_swizzled_inputs(A, B, original_scale_inverses)
if debug_quantizer is not None:
out = debug_quantizer.process_gemm_output(out)
return out, bias_grad, gelu_input, extra_output return out, bias_grad, gelu_input, extra_output
......
...@@ -19,7 +19,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP ...@@ -19,7 +19,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._common_utils import _get_module_fsdp_state from torch.distributed.fsdp._common_utils import _get_module_fsdp_state
from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules
from .utils import non_tn_fp8_gemm_supported, safely_set_viewless_tensor_data from .utils import non_tn_fp8_gemm_supported, safely_set_viewless_tensor_data, needs_quantized_gemm
from .constants import dist_group_type from .constants import dist_group_type
from .fp8 import FP8GlobalStateManager, fp8_autocast from .fp8 import FP8GlobalStateManager, fp8_autocast
from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer
...@@ -29,6 +29,7 @@ from .tensor.quantized_tensor import QuantizedTensor, Quantizer ...@@ -29,6 +29,7 @@ from .tensor.quantized_tensor import QuantizedTensor, Quantizer
from .tensor._internal.float8_tensor_base import Float8TensorBase from .tensor._internal.float8_tensor_base import Float8TensorBase
from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from .tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase from .tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
__all__ = ["checkpoint", "CudaRNGStatesTracker"] __all__ = ["checkpoint", "CudaRNGStatesTracker"]
...@@ -1195,6 +1196,28 @@ def gather_along_first_dim( ...@@ -1195,6 +1196,28 @@ def gather_along_first_dim(
out_shape=out_shape, out_shape=out_shape,
) )
# Debug case - call gather_along_first_dim on each tensor
if isinstance(inp, DebugQuantizedTensor):
out_obj = inp
rowwise = inp.get_tensor(False)
columnwise = inp.get_tensor(True)
final_quantizer = (
None if not needs_quantized_gemm(inp, rowwise=True) else quantizer.parent_quantizer
)
rowwise_total = gather_along_first_dim(rowwise, process_group, False, final_quantizer)[0]
out_obj.rowwise_gemm_tensor = rowwise_total
if rowwise is not columnwise:
final_quantizer_columnwise = (
None if not needs_quantized_gemm(inp, rowwise=False) else quantizer.parent_quantizer
)
columnwise_total, _ = gather_along_first_dim(
columnwise, process_group, False, final_quantizer_columnwise
)
out_obj.columnwise_gemm_tensor = columnwise_total
else:
out_obj.rowwise_gemm_tensor = out_obj.rowwise_gemm_tensor
return out_obj, None
# High-precision communication for quantized tensors # High-precision communication for quantized tensors
if quantizer is not None: if quantizer is not None:
warnings.warn( warnings.warn(
......
...@@ -10,6 +10,7 @@ import warnings ...@@ -10,6 +10,7 @@ import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union
from contextlib import contextmanager from contextlib import contextmanager
import logging
from types import MethodType from types import MethodType
import torch import torch
...@@ -39,6 +40,9 @@ from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer ...@@ -39,6 +40,9 @@ from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor._internal.float8_tensor_base import Float8TensorBase from ..tensor._internal.float8_tensor_base import Float8TensorBase
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ...common.recipe import Recipe
from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor
__all__ = ["initialize_ub", "destroy_ub"] __all__ = ["initialize_ub", "destroy_ub"]
...@@ -413,6 +417,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -413,6 +417,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
assert torch.cuda.is_available(), "TransformerEngine needs CUDA." assert torch.cuda.is_available(), "TransformerEngine needs CUDA."
self.name = None
self.fp8_initialized = False self.fp8_initialized = False
self.fp8 = False self.fp8 = False
self.fp8_calibration = False self.fp8_calibration = False
...@@ -432,6 +437,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -432,6 +437,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self._fp8_workspaces: Dict[str, QuantizedTensor] = {} self._fp8_workspaces: Dict[str, QuantizedTensor] = {}
self.activation_dtype: Optional[torch.dtype] = None self.activation_dtype: Optional[torch.dtype] = None
if not TEDebugState.debug_enabled:
TEDebugState.initialize()
# Names of attributes that can be set quickly (see __setattr__ # Names of attributes that can be set quickly (see __setattr__
# method) # method)
_fast_setattr_names: Set[str] = { _fast_setattr_names: Set[str] = {
...@@ -848,7 +856,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -848,7 +856,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
gather_grad_output = row_parallel_mode and ctx.sequence_parallel gather_grad_output = row_parallel_mode and ctx.sequence_parallel
# Non-FP8 case: bgrad is fused with wgrad for this case. # Non-FP8 case: bgrad is fused with wgrad for this case.
if not ctx.fp8: if not ctx.fp8 and not ctx.debug:
if gather_grad_output: if gather_grad_output:
if not ctx.ub_overlap_ag: if not ctx.ub_overlap_ag:
grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group) grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group)
...@@ -858,6 +866,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -858,6 +866,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
return grad_output, None return grad_output, None
# FP8 with all-gather: unfused bgrad, fused cast + transpose # FP8 with all-gather: unfused bgrad, fused cast + transpose
# Also supports debug quantization, which is handled inside gather_along_first_dim.
if gather_grad_output: if gather_grad_output:
grad_bias = None grad_bias = None
if ctx.use_bias: if ctx.use_bias:
...@@ -886,6 +895,23 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -886,6 +895,23 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
) )
return grad_output, grad_bias return grad_output, grad_bias
# Debug without all-gather: unfused cast and bgrad
# bgrad only if wgrad is in FP8, otherwise it is fused with wgrad and we return None
if ctx.debug:
grad_output_ = quantizer(grad_output)
if (
isinstance(
grad_output_.get_tensor(True),
(QuantizedTensor, Float8TensorBase, MXFP8TensorBase),
)
and ctx.use_bias
):
grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
else:
grad_bias = None
grad_output = grad_output_
return grad_output, grad_bias
# FP8 without all-gather: fused bgrad + cast + transpose # FP8 without all-gather: fused bgrad + cast + transpose
grad_bias = None grad_bias = None
if ctx.use_bias: if ctx.use_bias:
...@@ -1002,6 +1028,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1002,6 +1028,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
update_workspace: bool = True, update_workspace: bool = True,
skip_update_flag: Optional[torch.Tensor] = None, skip_update_flag: Optional[torch.Tensor] = None,
fsdp_group: Optional[dist_group_type] = None, fsdp_group: Optional[dist_group_type] = None,
workspace_dtype: Optional[torch.dtype] = None,
) -> QuantizedTensor: ) -> QuantizedTensor:
"""Get FP8 workspace buffer and maybe update its values """Get FP8 workspace buffer and maybe update its values
...@@ -1024,6 +1051,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1024,6 +1051,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
over `update_workspace` if provided. over `update_workspace` if provided.
fsdp_group: bool, default = None fsdp_group: bool, default = None
FSDP process group that the weights are distributed over. FSDP process group that the weights are distributed over.
workspace_dtype: torch.dtype, default = None
If weight workspace contains high-precision tensor - for example
for debug quantization, this is dtype of the tensor.
""" """
# FP8 primary weights # FP8 primary weights
...@@ -1037,6 +1067,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1037,6 +1067,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# Try getting workspace from cache # Try getting workspace from cache
out = None out = None
if cache_name is not None: if cache_name is not None:
out = self._fp8_workspaces.get(cache_name, None) out = self._fp8_workspaces.get(cache_name, None)
if quantizer is not None and isinstance(out, MXFP8TensorBase): if quantizer is not None and isinstance(out, MXFP8TensorBase):
...@@ -1047,6 +1078,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1047,6 +1078,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
out = None out = None
del self._fp8_workspaces[cache_name] del self._fp8_workspaces[cache_name]
is_debug = isinstance(quantizer, DebugQuantizer)
is_out_debug_tensor = out is not None and isinstance(out, DebugQuantizedTensor)
if is_debug != is_out_debug_tensor:
out = None
# Gather cached Fp8 workspace if it's distributed # Gather cached Fp8 workspace if it's distributed
# NOTE: FSDP sharding is supported only for Fp8 buffers and will not work # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work
# for models initialized with Fp8 primary weights. # for models initialized with Fp8 primary weights.
...@@ -1064,7 +1100,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1064,7 +1100,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
raise ValueError( raise ValueError(
"tensor and quantizer kwargs must be provided to construct FP8 workspace" "tensor and quantizer kwargs must be provided to construct FP8 workspace"
) )
out = quantizer(tensor) out = quantizer.quantize(tensor, dtype=workspace_dtype)
# Update cache # Update cache
if cache_name is not None: if cache_name is not None:
...@@ -1081,7 +1117,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1081,7 +1117,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
out.quantize_(tensor, noop_flag=skip_update_flag) out.quantize_(tensor, noop_flag=skip_update_flag)
else: else:
tex.quantize(tensor, quantizer, out, skip_update_flag) tex.quantize(tensor, quantizer, out, skip_update_flag)
return out return out
def _load_from_state_dict( def _load_from_state_dict(
...@@ -1104,3 +1139,47 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1104,3 +1139,47 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
super()._load_from_state_dict( super()._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
) )
def _validate_name(self):
"""
Validate name passed to the module.
This is invoked in the forward() method as module names are assigned after Model is initialized in Megatron-LM.
If no name is assigned, it creates a default name with layer count as the variable.
"""
assert TEDebugState.debug_enabled
import nvdlfw_inspect.api as debug_api
if self.name is None:
debug_api.log_message(
"Names are not provided to debug modules. ",
"Creating and using generic names. Pass names to debug modules for better"
" insight. ",
level=logging.WARNING,
)
self.name = f"Layer_{TEDebugState.get_layer_count()}"
def _turn_off_unsupported_features_in_debug(self):
if (
getattr(self, "ub_bulk_wgrad", False)
or getattr(self, "ub_bulk_dgrad", False)
or getattr(self, "ub_overlap_ag", False)
or getattr(self, "ub_overlap_rs_dgrad", False)
or getattr(self, "ub_overlap_rs", False)
):
import nvdlfw_inspect.api as debug_api
debug_api.log_message(
"UserBuffers are not supported in debug module. "
"Using UB optimization will not affect the debug module. ",
level=logging.WARNING,
)
if hasattr(self, "ub_bulk_wgrad"):
self.ub_bulk_wgrad = None
if hasattr(self, "ub_bulk_dgrad"):
self.ub_bulk_dgrad = None
if hasattr(self, "ub_overlap_ag"):
self.ub_overlap_ag = None
if hasattr(self, "ub_overlap_rs_dgrad"):
self.ub_overlap_rs_dgrad = None
if hasattr(self, "ub_overlap_rs"):
self.ub_overlap_rs = None
...@@ -35,6 +35,7 @@ from ..utils import ( ...@@ -35,6 +35,7 @@ from ..utils import (
nvtx_range_pop, nvtx_range_pop,
nvtx_range_push, nvtx_range_push,
requires_grad, requires_grad,
needs_quantized_gemm,
) )
from ..distributed import ( from ..distributed import (
set_tensor_model_parallel_attributes, set_tensor_model_parallel_attributes,
...@@ -56,6 +57,8 @@ from ..tensor.quantized_tensor import ( ...@@ -56,6 +57,8 @@ from ..tensor.quantized_tensor import (
prepare_for_saving, prepare_for_saving,
restore_from_saved, restore_from_saved,
) )
from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.utils import any_feature_enabled
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer
...@@ -90,8 +93,9 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -90,8 +93,9 @@ class _LayerNormLinear(torch.autograd.Function):
input_quantizer: Optional[Quantizer], input_quantizer: Optional[Quantizer],
weight_quantizer: Optional[Quantizer], weight_quantizer: Optional[Quantizer],
output_quantizer: Optional[Quantizer], output_quantizer: Optional[Quantizer],
grad_output_quantizer: Optional[Quantizer],
grad_input_quantizer: Optional[Quantizer], grad_input_quantizer: Optional[Quantizer],
grad_weight_quantizer: Optional[Quantizer],
grad_output_quantizer: Optional[Quantizer],
cpu_offloading: bool, cpu_offloading: bool,
tp_group: Union[dist_group_type, None], tp_group: Union[dist_group_type, None],
tp_size: int, tp_size: int,
...@@ -116,6 +120,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -116,6 +120,7 @@ class _LayerNormLinear(torch.autograd.Function):
fsdp_group: Union[dist_group_type, None], fsdp_group: Union[dist_group_type, None],
module: torch.nn.Module, module: torch.nn.Module,
skip_fp8_weight_update: bool, skip_fp8_weight_update: bool,
debug: Optional[bool] = False,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
...@@ -214,12 +219,12 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -214,12 +219,12 @@ class _LayerNormLinear(torch.autograd.Function):
# norm output will be returned # norm output will be returned
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
ln_out_return = ln_out_total ln_out_return = ln_out_total
if fp8: if fp8 or debug:
ln_out = input_quantizer(ln_out) ln_out = input_quantizer(ln_out)
input_quantizer.set_usage(rowwise=True, columnwise=False) input_quantizer.set_usage(rowwise=True, columnwise=False)
ln_out_total = input_quantizer(ln_out_total) ln_out_total = input_quantizer(ln_out_total)
else: else:
if fp8: if fp8 or debug:
if not with_quantized_norm and not force_hp_blockwise_ln_out_gather: if not with_quantized_norm and not force_hp_blockwise_ln_out_gather:
ln_out = input_quantizer(ln_out) ln_out = input_quantizer(ln_out)
input_quantizer.set_usage(rowwise=True, columnwise=False) input_quantizer.set_usage(rowwise=True, columnwise=False)
...@@ -233,18 +238,19 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -233,18 +238,19 @@ class _LayerNormLinear(torch.autograd.Function):
ln_out_total, _ = gather_along_first_dim( ln_out_total, _ = gather_along_first_dim(
ln_out, ln_out,
tp_group, tp_group,
quantizer=(input_quantizer if fp8 else None), quantizer=(input_quantizer if fp8 or debug else None),
) )
else: else:
if fp8 and not with_quantized_norm: if (fp8 or debug) and not with_quantized_norm:
ln_out = input_quantizer(ln_out) ln_out = input_quantizer(ln_out)
ln_out_total = ln_out ln_out_total = ln_out
nvtx_range_pop(f"{nvtx_label}.gemm_input_cast_comm") nvtx_range_pop(f"{nvtx_label}.gemm_input_cast_comm")
# Cast weight to expected dtype # Cast weight to expected dtype
if not fp8: weightmat = weight
quantized_weight = False quantized_weight = False
weightmat = cast_if_needed(weight, activation_dtype) if not fp8 and not debug:
weightmat = cast_if_needed(weightmat, activation_dtype)
else: else:
quantized_weight = not isinstance(weight, QuantizedTensor) quantized_weight = not isinstance(weight, QuantizedTensor)
...@@ -254,6 +260,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -254,6 +260,7 @@ class _LayerNormLinear(torch.autograd.Function):
# FP8 cast to workspace buffer # FP8 cast to workspace buffer
update_workspace = is_first_microbatch is None or is_first_microbatch update_workspace = is_first_microbatch is None or is_first_microbatch
weightmat = module.get_weight_workspace( weightmat = module.get_weight_workspace(
tensor=weight, tensor=weight,
quantizer=weight_quantizer, quantizer=weight_quantizer,
...@@ -261,11 +268,12 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -261,11 +268,12 @@ class _LayerNormLinear(torch.autograd.Function):
update_workspace=update_workspace, update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update, skip_update_flag=skip_fp8_weight_update,
fsdp_group=fsdp_group, fsdp_group=fsdp_group,
workspace_dtype=activation_dtype,
) )
# Cast bias to expected dtype # Cast bias to expected dtype
bias_dtype = activation_dtype bias_dtype = activation_dtype
if fp8 and activation_dtype == torch.float32: if needs_quantized_gemm(ln_out_total) and activation_dtype == torch.float32:
bias_dtype = torch.bfloat16 bias_dtype = torch.bfloat16
bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias
...@@ -400,6 +408,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -400,6 +408,7 @@ class _LayerNormLinear(torch.autograd.Function):
if fuse_wgrad_accumulation and weight.requires_grad: if fuse_wgrad_accumulation and weight.requires_grad:
ctx.main_grad = weight.main_grad ctx.main_grad = weight.main_grad
ctx.grad_input_quantizer = grad_input_quantizer ctx.grad_input_quantizer = grad_input_quantizer
ctx.grad_weight_quantizer = grad_weight_quantizer
ctx.grad_output_quantizer = grad_output_quantizer ctx.grad_output_quantizer = grad_output_quantizer
ctx.input_quantizer = input_quantizer ctx.input_quantizer = input_quantizer
ctx.owns_input = inputmat is not inp ctx.owns_input = inputmat is not inp
...@@ -434,6 +443,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -434,6 +443,7 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
if in_fp8_activation_recompute_phase(): if in_fp8_activation_recompute_phase():
FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
ctx.debug = debug
# Row Parallel Linear # Row Parallel Linear
if ub_overlap_rs_fprop: if ub_overlap_rs_fprop:
...@@ -611,7 +621,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -611,7 +621,7 @@ class _LayerNormLinear(torch.autograd.Function):
ln_out_total_work = None ln_out_total_work = None
if ctx.ln_out_needs_gather and not ctx.ub_bulk_dgrad: if ctx.ln_out_needs_gather and not ctx.ub_bulk_dgrad:
quantizer = None quantizer = None
if ctx.fp8: if ctx.input_quantizer is not None:
quantizer = ctx.input_quantizer quantizer = ctx.input_quantizer
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
# If data is in FP8, we compute FP8 transposes manually # If data is in FP8, we compute FP8 transposes manually
...@@ -757,6 +767,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -757,6 +767,7 @@ class _LayerNormLinear(torch.autograd.Function):
out=main_grad if ctx.fuse_wgrad_accumulation else None, out=main_grad if ctx.fuse_wgrad_accumulation else None,
use_split_accumulator=use_split_accumulator, use_split_accumulator=use_split_accumulator,
accumulate=accumulate_wgrad_into_param_main_grad, accumulate=accumulate_wgrad_into_param_main_grad,
quantization_params=ctx.grad_weight_quantizer,
ub=ub_obj_wgrad, ub=ub_obj_wgrad,
ub_type=ub_type_wgrad, ub_type=ub_type_wgrad,
extra_output=rs_out, extra_output=rs_out,
...@@ -865,8 +876,9 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -865,8 +876,9 @@ class _LayerNormLinear(torch.autograd.Function):
None, # input_quantizer None, # input_quantizer
None, # weight_quantizer None, # weight_quantizer
None, # output_quantizer None, # output_quantizer
None, # grad_output_quantizer
None, # grad_input_quantizer None, # grad_input_quantizer
None, # grad_weight_quantizer
None, # grad_output_quantizer
None, # cpu_offloading None, # cpu_offloading
None, # tp_group None, # tp_group
None, # tp_size None, # tp_size
...@@ -889,6 +901,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -889,6 +901,7 @@ class _LayerNormLinear(torch.autograd.Function):
None, # ub_bulk_wgrad None, # ub_bulk_wgrad
None, # ub_name None, # ub_name
None, # fsdp_group None, # fsdp_group
None, # debug
None, # module None, # module
None, # skip_fp8_weight_update None, # skip_fp8_weight_update
) )
...@@ -943,6 +956,8 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -943,6 +956,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
The device on which the parameters of the model will be allocated. It is the user's The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the responsibility to ensure all parameters are moved to the GPU before running the
forward pass. forward pass.
name: str, default = `None`
name of the module, currently used for debugging purposes.
Parallelism parameters Parallelism parameters
---------------------- ----------------------
...@@ -1007,6 +1022,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1007,6 +1022,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
ub_bulk_wgrad: bool = False, ub_bulk_wgrad: bool = False,
ub_bulk_dgrad: bool = False, ub_bulk_dgrad: bool = False,
ub_name: Optional[str] = None, ub_name: Optional[str] = None,
name: str = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -1023,6 +1039,10 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1023,6 +1039,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.return_layernorm_output_gathered = return_layernorm_output_gathered self.return_layernorm_output_gathered = return_layernorm_output_gathered
self.zero_centered_gamma = zero_centered_gamma self.zero_centered_gamma = zero_centered_gamma
self.name = name
if TEDebugState.debug_enabled:
self._turn_off_unsupported_features_in_debug() # turn off userbuffers
if tp_group is None: if tp_group is None:
self.tp_size = tp_size self.tp_size = tp_size
if tp_size == 1: if tp_size == 1:
...@@ -1312,6 +1332,9 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1312,6 +1332,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
first microbatch (since it is the first gradient being first microbatch (since it is the first gradient being
produced) produced)
""" """
debug = TEDebugState.debug_enabled
if debug:
self._validate_name()
if FP8GlobalStateManager.fp8_graph_capturing(): if FP8GlobalStateManager.fp8_graph_capturing():
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
...@@ -1348,13 +1371,28 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1348,13 +1371,28 @@ class LayerNormLinear(TransformerEngineBaseModule):
else: else:
bias_tensor = getattr(self, self.bias_names[0]) # Unused bias_tensor = getattr(self, self.bias_names[0]) # Unused
quantizers = (
self._get_quantizers(fp8_output, fp8_grad)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad)
)
if debug:
if not any_feature_enabled(quantizers):
# If no feature is used, then run faster implementation with debug = False.
quantizers = self._get_quantizers(fp8_output, fp8_grad)
debug = False
if isinstance(weight_tensor, QuantizedTensor):
raise RuntimeError("FP8 weights are not supported in debug mode.")
( (
input_quantizer, input_quantizer,
weight_quantizer, weight_quantizer,
output_quantizer, output_quantizer,
grad_output_quantizer,
grad_input_quantizer, grad_input_quantizer,
) = self._get_quantizers(fp8_output, fp8_grad) grad_weight_quantizer,
grad_output_quantizer,
) = quantizers
if torch.is_grad_enabled(): if torch.is_grad_enabled():
fwd_fn = _LayerNormLinear.apply fwd_fn = _LayerNormLinear.apply
...@@ -1376,8 +1414,9 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1376,8 +1414,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
input_quantizer, input_quantizer,
weight_quantizer, weight_quantizer,
output_quantizer, output_quantizer,
grad_output_quantizer,
grad_input_quantizer, grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
is_cpu_offload_enabled(), is_cpu_offload_enabled(),
self.tp_group, self.tp_group,
self.tp_size, self.tp_size,
...@@ -1402,6 +1441,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1402,6 +1441,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.fsdp_group, self.fsdp_group,
self, self,
skip_fp8_weight_update, skip_fp8_weight_update,
debug,
) )
out = fwd_fn(*args) out = fwd_fn(*args)
...@@ -1421,8 +1461,9 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1421,8 +1461,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
def _get_quantizers(self, fp8_output, fp8_grad): def _get_quantizers(self, fp8_output, fp8_grad):
if not self.fp8: if not self.fp8:
return [None] * 5 return [None] * 6
grad_input_quantizer = None grad_input_quantizer = None
grad_weight_quantizer = None
grad_output_quantizer = None grad_output_quantizer = None
output_quantizer = None output_quantizer = None
input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
...@@ -1441,8 +1482,20 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1441,8 +1482,20 @@ class LayerNormLinear(TransformerEngineBaseModule):
input_quantizer, input_quantizer,
weight_quantizer, weight_quantizer,
output_quantizer, output_quantizer,
grad_output_quantizer,
grad_input_quantizer, grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
)
def _get_debug_quantizers(self, fp8_output, fp8_grad):
original_quantizers = self._get_quantizers(fp8_output, fp8_grad)
assert TEDebugState.debug_enabled
from ...debug.pytorch.debug_quantization import DebugQuantizer
names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"]
return tuple(
DebugQuantizer(self.name, name, q, self.tp_group)
for name, q in zip(names, original_quantizers)
) )
def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
......
...@@ -28,11 +28,12 @@ from ..utils import ( ...@@ -28,11 +28,12 @@ from ..utils import (
clear_tensor_data, clear_tensor_data,
divide, divide,
init_method_constant, init_method_constant,
requires_grad,
needs_quantized_gemm,
non_tn_fp8_gemm_supported, non_tn_fp8_gemm_supported,
assert_dim_for_fp8_exec, assert_dim_for_fp8_exec,
nvtx_range_pop, nvtx_range_pop,
nvtx_range_push, nvtx_range_push,
requires_grad,
) )
from ..distributed import ( from ..distributed import (
set_tensor_model_parallel_attributes, set_tensor_model_parallel_attributes,
...@@ -62,6 +63,8 @@ from ..tensor.mxfp8_tensor import MXFP8Quantizer ...@@ -62,6 +63,8 @@ from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.utils import any_feature_enabled
__all__ = ["Linear"] __all__ = ["Linear"]
...@@ -84,8 +87,9 @@ class _Linear(torch.autograd.Function): ...@@ -84,8 +87,9 @@ class _Linear(torch.autograd.Function):
input_quantizer: Optional[Quantizer], input_quantizer: Optional[Quantizer],
weight_quantizer: Optional[Quantizer], weight_quantizer: Optional[Quantizer],
output_quantizer: Optional[Quantizer], output_quantizer: Optional[Quantizer],
grad_output_quantizer: Optional[Quantizer],
grad_input_quantizer: Optional[Quantizer], grad_input_quantizer: Optional[Quantizer],
grad_weight_quantizer: Optional[Quantizer],
grad_output_quantizer: Optional[Quantizer],
fuse_wgrad_accumulation: bool, fuse_wgrad_accumulation: bool,
cpu_offloading: bool, cpu_offloading: bool,
tp_group: Union[dist_group_type, None], tp_group: Union[dist_group_type, None],
...@@ -106,6 +110,7 @@ class _Linear(torch.autograd.Function): ...@@ -106,6 +110,7 @@ class _Linear(torch.autograd.Function):
fsdp_group: Union[dist_group_type, None], fsdp_group: Union[dist_group_type, None],
module: torch.nn.Module, module: torch.nn.Module,
skip_fp8_weight_update: bool, skip_fp8_weight_update: bool,
debug: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
...@@ -144,7 +149,7 @@ class _Linear(torch.autograd.Function): ...@@ -144,7 +149,7 @@ class _Linear(torch.autograd.Function):
"Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor" "Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
" current scaling" " current scaling"
) )
if fp8 or debug:
if input_quantizer is None: if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor") raise ValueError("Missing quantizer for input tensor")
if with_input_all_gather_nccl: if with_input_all_gather_nccl:
...@@ -196,9 +201,9 @@ class _Linear(torch.autograd.Function): ...@@ -196,9 +201,9 @@ class _Linear(torch.autograd.Function):
nvtx_range_pop(f"{nvtx_label}.input_cast_comm") nvtx_range_pop(f"{nvtx_label}.input_cast_comm")
# Cast weight to expected dtype # Cast weight to expected dtype
if not fp8: weightmat = weight
weightmat = cast_if_needed(weight, activation_dtype)
else: if fp8 or debug:
# Configure quantizer # Configure quantizer
if weight_quantizer is not None: if weight_quantizer is not None:
columnwise_usage = is_grad_enabled and inp.requires_grad columnwise_usage = is_grad_enabled and inp.requires_grad
...@@ -208,7 +213,6 @@ class _Linear(torch.autograd.Function): ...@@ -208,7 +213,6 @@ class _Linear(torch.autograd.Function):
and not in_fp8_activation_recompute_phase() and not in_fp8_activation_recompute_phase()
) )
weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
# FP8 cast to workspace buffer # FP8 cast to workspace buffer
update_workspace = is_first_microbatch is None or is_first_microbatch update_workspace = is_first_microbatch is None or is_first_microbatch
weightmat = module.get_weight_workspace( weightmat = module.get_weight_workspace(
...@@ -218,11 +222,14 @@ class _Linear(torch.autograd.Function): ...@@ -218,11 +222,14 @@ class _Linear(torch.autograd.Function):
update_workspace=update_workspace, update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update, skip_update_flag=skip_fp8_weight_update,
fsdp_group=fsdp_group, fsdp_group=fsdp_group,
workspace_dtype=activation_dtype,
) )
else:
weightmat = cast_if_needed(weightmat, activation_dtype)
# Cast bias to expected dtype # Cast bias to expected dtype
bias_dtype = activation_dtype bias_dtype = activation_dtype
if fp8 and activation_dtype == torch.float32: if needs_quantized_gemm(inputmat_total) and activation_dtype == torch.float32:
bias_dtype = torch.bfloat16 bias_dtype = torch.bfloat16
bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias
...@@ -343,12 +350,14 @@ class _Linear(torch.autograd.Function): ...@@ -343,12 +350,14 @@ class _Linear(torch.autograd.Function):
ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
ctx.force_hp_input_gather = force_hp_input_gather ctx.force_hp_input_gather = force_hp_input_gather
ctx.input_quantizer = input_quantizer ctx.input_quantizer = input_quantizer
ctx.grad_output_quantizer = grad_output_quantizer
ctx.grad_input_quantizer = grad_input_quantizer ctx.grad_input_quantizer = grad_input_quantizer
ctx.grad_weight_quantizer = grad_weight_quantizer
ctx.grad_output_quantizer = grad_output_quantizer
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
if fuse_wgrad_accumulation and weight.requires_grad: if fuse_wgrad_accumulation and weight.requires_grad:
ctx.main_grad = weight.main_grad ctx.main_grad = weight.main_grad
ctx.debug = debug
ctx.cpu_offloading = cpu_offloading ctx.cpu_offloading = cpu_offloading
ctx.is_first_microbatch = is_first_microbatch ctx.is_first_microbatch = is_first_microbatch
ctx.use_bias = bias is not None ctx.use_bias = bias is not None
...@@ -528,7 +537,7 @@ class _Linear(torch.autograd.Function): ...@@ -528,7 +537,7 @@ class _Linear(torch.autograd.Function):
inputmat_total_work = None inputmat_total_work = None
if ctx.backward_input_needs_gather and not ctx.ub_bulk_dgrad: if ctx.backward_input_needs_gather and not ctx.ub_bulk_dgrad:
quantizer = None quantizer = None
if ctx.fp8: if ctx.fp8 or ctx.debug:
quantizer = ctx.input_quantizer quantizer = ctx.input_quantizer
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
# If data is in FP8, we compute FP8 transposes manually # If data is in FP8, we compute FP8 transposes manually
...@@ -564,7 +573,6 @@ class _Linear(torch.autograd.Function): ...@@ -564,7 +573,6 @@ class _Linear(torch.autograd.Function):
# Update quantizer # Update quantizer
if ctx.grad_input_quantizer is not None: if ctx.grad_input_quantizer is not None:
ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False)
# dgrad GEMM # dgrad GEMM
nvtx_range_push(f"{nvtx_label}.dgrad_gemm") nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD
...@@ -678,6 +686,7 @@ class _Linear(torch.autograd.Function): ...@@ -678,6 +686,7 @@ class _Linear(torch.autograd.Function):
out=main_grad if ctx.fuse_wgrad_accumulation else None, out=main_grad if ctx.fuse_wgrad_accumulation else None,
use_split_accumulator=use_split_accumulator, use_split_accumulator=use_split_accumulator,
accumulate=accumulate_wgrad_into_param_main_grad, accumulate=accumulate_wgrad_into_param_main_grad,
quantization_params=ctx.grad_weight_quantizer,
ub=ub_obj_wgrad, ub=ub_obj_wgrad,
ub_type=ub_type_wgrad, ub_type=ub_type_wgrad,
extra_output=rs_out, extra_output=rs_out,
...@@ -753,8 +762,9 @@ class _Linear(torch.autograd.Function): ...@@ -753,8 +762,9 @@ class _Linear(torch.autograd.Function):
None, # input_quantizer None, # input_quantizer
None, # weight_quantizer None, # weight_quantizer
None, # output_quantizer None, # output_quantizer
None, # grad_output_quantizer
None, # grad_input_quantizer None, # grad_input_quantizer
None, # grad_weight_quantizer
None, # grad_output_quantizer
None, # fuse_wgrad_accumulation None, # fuse_wgrad_accumulation
None, # cpu_offloading None, # cpu_offloading
None, # tp_group None, # tp_group
...@@ -775,6 +785,7 @@ class _Linear(torch.autograd.Function): ...@@ -775,6 +785,7 @@ class _Linear(torch.autograd.Function):
None, # fsdp_group None, # fsdp_group
None, # module None, # module
None, # skip_fp8_weight_update None, # skip_fp8_weight_update
None, # debug
) )
...@@ -810,6 +821,8 @@ class Linear(TransformerEngineBaseModule): ...@@ -810,6 +821,8 @@ class Linear(TransformerEngineBaseModule):
The device on which the parameters of the model will be allocated. It is the user's The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the responsibility to ensure all parameters are moved to the GPU before running the
forward pass. forward pass.
name: str, default = `None`
name of the module, currently used for debugging purposes.
Parallelism parameters Parallelism parameters
---------------------- ----------------------
...@@ -871,6 +884,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -871,6 +884,7 @@ class Linear(TransformerEngineBaseModule):
ub_bulk_dgrad: bool = False, ub_bulk_dgrad: bool = False,
ub_bulk_wgrad: bool = False, ub_bulk_wgrad: bool = False,
ub_name: Optional[str] = None, ub_name: Optional[str] = None,
name: Optional[str] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -883,6 +897,10 @@ class Linear(TransformerEngineBaseModule): ...@@ -883,6 +897,10 @@ class Linear(TransformerEngineBaseModule):
self.apply_bias = bias and not return_bias self.apply_bias = bias and not return_bias
self.get_rng_state_tracker = get_rng_state_tracker self.get_rng_state_tracker = get_rng_state_tracker
self.rng_tracker_name = rng_tracker_name self.rng_tracker_name = rng_tracker_name
self.name = name
if TEDebugState.debug_enabled:
self._turn_off_unsupported_features_in_debug() # turn off userbuffers
if device == "meta": if device == "meta":
assert parameters_split is None, "Cannot split module parameters on 'meta' device." assert parameters_split is None, "Cannot split module parameters on 'meta' device."
...@@ -1126,6 +1144,10 @@ class Linear(TransformerEngineBaseModule): ...@@ -1126,6 +1144,10 @@ class Linear(TransformerEngineBaseModule):
first microbatch (since it is the first gradient being first microbatch (since it is the first gradient being
produced) produced)
""" """
debug = TEDebugState.debug_enabled
if debug:
self._validate_name()
if FP8GlobalStateManager.fp8_graph_capturing(): if FP8GlobalStateManager.fp8_graph_capturing():
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
else: else:
...@@ -1161,13 +1183,28 @@ class Linear(TransformerEngineBaseModule): ...@@ -1161,13 +1183,28 @@ class Linear(TransformerEngineBaseModule):
else: else:
bias_tensor = None bias_tensor = None
quantizers = (
self._get_quantizers(fp8_output, fp8_grad)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad)
)
if debug:
if not any_feature_enabled(quantizers):
# If no feature is used, then run faster implementation with debug = False.
quantizers = self._get_quantizers(fp8_output, fp8_grad)
debug = False
if isinstance(weight_tensor, QuantizedTensor):
raise RuntimeError("FP8 weights are not supported in debug mode.")
( (
input_quantizer, input_quantizer,
weight_quantizer, weight_quantizer,
output_quantizer, output_quantizer,
grad_output_quantizer,
grad_input_quantizer, grad_input_quantizer,
) = self._get_quantizers(fp8_output, fp8_grad) grad_weight_quantizer,
grad_output_quantizer,
) = quantizers
# Make sure weight tensor has correct quantizer # Make sure weight tensor has correct quantizer
# Note: Quantizer might have changed if quantization # Note: Quantizer might have changed if quantization
...@@ -1191,8 +1228,9 @@ class Linear(TransformerEngineBaseModule): ...@@ -1191,8 +1228,9 @@ class Linear(TransformerEngineBaseModule):
input_quantizer, input_quantizer,
weight_quantizer, weight_quantizer,
output_quantizer, output_quantizer,
grad_output_quantizer,
grad_input_quantizer, grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
self.fuse_wgrad_accumulation, self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(), is_cpu_offload_enabled(),
self.tp_group, self.tp_group,
...@@ -1213,6 +1251,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -1213,6 +1251,7 @@ class Linear(TransformerEngineBaseModule):
self.fsdp_group, self.fsdp_group,
self, self,
skip_fp8_weight_update, skip_fp8_weight_update,
debug,
) )
out = linear_fn(*args) out = linear_fn(*args)
if self.gemm_bias_unfused_add: if self.gemm_bias_unfused_add:
...@@ -1224,8 +1263,9 @@ class Linear(TransformerEngineBaseModule): ...@@ -1224,8 +1263,9 @@ class Linear(TransformerEngineBaseModule):
def _get_quantizers(self, fp8_output, fp8_grad): def _get_quantizers(self, fp8_output, fp8_grad):
if not self.fp8: if not self.fp8:
return [None] * 5 return [None] * 6
grad_input_quantizer = None grad_input_quantizer = None
grad_weight_quantizer = None
grad_output_quantizer = None grad_output_quantizer = None
output_quantizer = None output_quantizer = None
input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
...@@ -1243,8 +1283,20 @@ class Linear(TransformerEngineBaseModule): ...@@ -1243,8 +1283,20 @@ class Linear(TransformerEngineBaseModule):
input_quantizer, input_quantizer,
weight_quantizer, weight_quantizer,
output_quantizer, output_quantizer,
grad_output_quantizer,
grad_input_quantizer, grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
)
def _get_debug_quantizers(self, fp8_output, fp8_grad):
original_quantizers = self._get_quantizers(fp8_output, fp8_grad)
assert TEDebugState.debug_enabled
from ...debug.pytorch.debug_quantization import DebugQuantizer
names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"]
return tuple(
DebugQuantizer(self.name, name, q, self.tp_group)
for name, q in zip(names, original_quantizers)
) )
def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
......
...@@ -42,3 +42,27 @@ def _make_module_cast_func(dtype): ...@@ -42,3 +42,27 @@ def _make_module_cast_func(dtype):
torch.nn.Module.float = _make_module_cast_func(torch.float32) torch.nn.Module.float = _make_module_cast_func(torch.float32)
torch.nn.Module.half = _make_module_cast_func(torch.float16) torch.nn.Module.half = _make_module_cast_func(torch.float16)
torch.nn.Module.bfloat16 = _make_module_cast_func(torch.bfloat16) torch.nn.Module.bfloat16 = _make_module_cast_func(torch.bfloat16)
def get_all_tensor_types():
"""
Get all tensor-like types that can be used in TE.
"""
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8TensorBase
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8TensorBase
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
Float8BlockwiseQTensor,
Float8BlockwiseQTensorBase,
)
all_tensor_types = [
torch.Tensor,
torch.nn.Parameter,
Float8Tensor,
Float8TensorBase,
MXFP8Tensor,
MXFP8TensorBase,
Float8BlockwiseQTensor,
Float8BlockwiseQTensorBase,
]
return all_tensor_types
...@@ -27,12 +27,14 @@ class _FromFloat8Func(torch.autograd.Function): ...@@ -27,12 +27,14 @@ class _FromFloat8Func(torch.autograd.Function):
dtype: torch.dtype, dtype: torch.dtype,
) -> torch.Tensor: ) -> torch.Tensor:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
dtype = torch_to_transformer_engine_dtype[dtype] te_dtype = torch_to_transformer_engine_dtype[dtype]
# Make sure FP8 data is in expected format # Make sure FP8 data is in expected format
if tensor._data is not None: if tensor._data is not None:
if tensor._data.numel() == 0:
return torch.empty_like(tensor._data, dtype=dtype)
# Cast from FP8 # Cast from FP8
return tex.dequantize(tensor, dtype) return tex.dequantize(tensor, te_dtype)
raise NotImplementedError("Casting back from the transpose not implemented yet!") raise NotImplementedError("Casting back from the transpose not implemented yet!")
......
...@@ -37,7 +37,8 @@ def prepare_for_saving( ...@@ -37,7 +37,8 @@ def prepare_for_saving(
def restore_from_saved( def restore_from_saved(
tensors: list[Optional[Any]], tensors: list[Optional[Any]],
saved_tensors: list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], saved_tensors: list[Optional[Union[torch.Tensor, torch.nn.Parameter]]],
) -> list[Optional[Any]]: return_saved_tensors: bool = False,
) -> list[Optional[Any]] | tuple[list[Optional[Any]], list[Optional[torch.Tensor]]]:
"""Recombine the tensor data and metadata during backward pass.""" """Recombine the tensor data and metadata during backward pass."""
tensor_objects = [] tensor_objects = []
for tensor in tensors: for tensor in tensors:
...@@ -47,6 +48,9 @@ def restore_from_saved( ...@@ -47,6 +48,9 @@ def restore_from_saved(
else: else:
saved_tensors = tensor.restore_from_saved(saved_tensors) saved_tensors = tensor.restore_from_saved(saved_tensors)
tensor_objects.append(tensor) tensor_objects.append(tensor)
if return_saved_tensors:
return tensor_objects, saved_tensors
return tensor_objects return tensor_objects
...@@ -113,7 +117,11 @@ class Quantizer(abc.ABC): ...@@ -113,7 +117,11 @@ class Quantizer(abc.ABC):
"""Quantize tensor in-place""" """Quantize tensor in-place"""
def quantize( def quantize(
self, tensor: torch.Tensor, *, out: Optional[QuantizedTensor] = None self,
tensor: torch.Tensor,
*,
out: Optional[QuantizedTensor] = None,
dtype: Optional[torch.dtype] = None, # pylint: disable=unused-argument # used by override
) -> QuantizedTensor: ) -> QuantizedTensor:
"""Quantize tensor""" """Quantize tensor"""
if out is not None: if out is not None:
......
...@@ -11,6 +11,7 @@ from typing import Callable, List, Optional, Tuple, Union ...@@ -11,6 +11,7 @@ from typing import Callable, List, Optional, Tuple, Union
import torch import torch
from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm
from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.pytorch.attention import ( from transformer_engine.pytorch.attention import (
MultiheadAttention, MultiheadAttention,
) )
...@@ -33,6 +34,7 @@ from transformer_engine.pytorch.constants import ( ...@@ -33,6 +34,7 @@ from transformer_engine.pytorch.constants import (
dist_group_type, dist_group_type,
) )
from transformer_engine.pytorch.distributed import get_distributed_world_size from transformer_engine.pytorch.distributed import get_distributed_world_size
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
warnings.filterwarnings("module", category=DeprecationWarning, module="transformer") warnings.filterwarnings("module", category=DeprecationWarning, module="transformer")
...@@ -184,6 +186,8 @@ class TransformerLayer(torch.nn.Module): ...@@ -184,6 +186,8 @@ class TransformerLayer(torch.nn.Module):
head size. Note that these formats are very closely head size. Note that these formats are very closely
related to the `qkv_format` in the `MultiHeadAttention` related to the `qkv_format` in the `MultiHeadAttention`
and `DotProductAttention` modules. and `DotProductAttention` modules.
name: str, default = `None`
name of the module, currently used for debugging purposes.
Parallelism parameters Parallelism parameters
---------------------- ----------------------
...@@ -277,6 +281,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -277,6 +281,7 @@ class TransformerLayer(torch.nn.Module):
normalization: str = "LayerNorm", normalization: str = "LayerNorm",
device: Union[torch.device, str] = "cuda", device: Union[torch.device, str] = "cuda",
attn_input_format: str = "sbhd", attn_input_format: str = "sbhd",
name: str = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -336,6 +341,8 @@ class TransformerLayer(torch.nn.Module): ...@@ -336,6 +341,8 @@ class TransformerLayer(torch.nn.Module):
self.attn_input_format = attn_input_format self.attn_input_format = attn_input_format
self.name = name
attention_args = ( attention_args = (
hidden_size, hidden_size,
num_attention_heads, num_attention_heads,
...@@ -376,6 +383,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -376,6 +383,7 @@ class TransformerLayer(torch.nn.Module):
return_bias=not self.parallel_attention_mlp, return_bias=not self.parallel_attention_mlp,
normalization=normalization, normalization=normalization,
device=device, device=device,
name=name + ".self_attention" if name is not None else None,
) )
if layer_type == "decoder": if layer_type == "decoder":
...@@ -389,6 +397,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -389,6 +397,7 @@ class TransformerLayer(torch.nn.Module):
return_bias=True, return_bias=True,
normalization=normalization, normalization=normalization,
device=device, device=device,
name=name + ".inter_attention" if name is not None else None,
) )
# LayerNorm -> activation(Linear + Bias) -> Linear # LayerNorm -> activation(Linear + Bias) -> Linear
...@@ -423,6 +432,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -423,6 +432,7 @@ class TransformerLayer(torch.nn.Module):
activation=activation, activation=activation,
normalization=normalization, normalization=normalization,
device=device, device=device,
name=name + ".layernorm_mlp" if name is not None else None,
) )
self.hidden_dropout = hidden_dropout self.hidden_dropout = hidden_dropout
...@@ -679,6 +689,9 @@ class TransformerLayer(torch.nn.Module): ...@@ -679,6 +689,9 @@ class TransformerLayer(torch.nn.Module):
enc_dec_attn_mask[i].dtype == torch.bool for i in range(len(enc_dec_attn_mask)) enc_dec_attn_mask[i].dtype == torch.bool for i in range(len(enc_dec_attn_mask))
), "Encoder-decoder attention mask must be boolean tensor(s)" ), "Encoder-decoder attention mask must be boolean tensor(s)"
if TEDebugState.debug_enabled:
TransformerEngineBaseModule._validate_name(self)
# For AMP # For AMP
if torch.is_autocast_enabled(): if torch.is_autocast_enabled():
hidden_states = cast_if_needed(hidden_states, torch.get_autocast_gpu_dtype()) hidden_states = cast_if_needed(hidden_states, torch.get_autocast_gpu_dtype())
......
...@@ -11,6 +11,7 @@ from typing import Any, Callable, List, Optional, Tuple ...@@ -11,6 +11,7 @@ from typing import Any, Callable, List, Optional, Tuple
import torch import torch
import transformer_engine.pytorch.cpp_extensions as ext import transformer_engine.pytorch.cpp_extensions as ext
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
from .tensor.quantized_tensor import QuantizedTensor from .tensor.quantized_tensor import QuantizedTensor
...@@ -329,6 +330,19 @@ def round_up_to_nearest_multiple(value, multiple): ...@@ -329,6 +330,19 @@ def round_up_to_nearest_multiple(value, multiple):
return ((value + multiple - 1) // multiple) * multiple return ((value + multiple - 1) // multiple) * multiple
def needs_quantized_gemm(obj, rowwise=True):
"""Used to check if obj will need quantized gemm or normal gemm."""
if isinstance(obj, DebugQuantizedTensor):
return type(obj.get_tensor(not rowwise)) not in [ # pylint: disable=unidiomatic-typecheck
torch.Tensor,
torch.nn.Parameter,
]
return type(obj) not in [
torch.Tensor,
torch.nn.Parameter,
] # pylint: disable=unidiomatic-typecheck
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
def _nvtx_enabled() -> bool: def _nvtx_enabled() -> bool:
"""Check if NVTX range profiling is enabled""" """Check if NVTX range profiling is enabled"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment