Unverified Commit 9f61f8a5 authored by Paweł Gadziński, committed by GitHub

[PyTorch Debug] Debug support for GroupedLinear (#1953)



* main
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* docs
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* add
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* test fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent b3c25057
@@ -107,6 +107,8 @@ The ``TransformerLayer`` in Transformer Engine is a composition of multiple sub-
depending on the configuration. Some layers, like ``LayerNormLinear``, are fusions of two layers: ``LayerNorm`` and ``Linear``. When referring to such layers in precision debug tools, only the ``Linear`` part is affected.
+For the ``GroupedLinear`` layer, the names of the underlying GEMMs are of the form ``layer_name.gemm_n``, where ``n`` is the index of the GEMM.
Below is an example ``TransformerLayer`` with four linear layers that can be influenced by the precision debug tools.
.. figure:: ./img/names.svg
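As an aside (an editor's illustration, not part of the patch; the name ``proj`` and the tensor sizes are made up), a ``GroupedLinear`` constructed with ``name="proj"`` is addressed by the precision debug tools through the per-GEMM names ``proj.gemm_0``, ``proj.gemm_1`` and ``proj.gemm_2``::

    import torch
    import transformer_engine.pytorch as te

    # One GroupedLinear wrapping three GEMMs; with debugging enabled, the
    # precision debug tools report statistics for them under the names
    # proj.gemm_0, proj.gemm_1 and proj.gemm_2.
    model = te.GroupedLinear(3, 128, 128, name="proj", params_dtype=torch.bfloat16)
    inp = torch.randn((128, 128), dtype=torch.bfloat16, device="cuda")
    out = model(inp, m_splits=[64, 32, 32])  # m_splits must sum to inp.shape[0]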
......
@@ -363,6 +363,28 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
    TEDebugState._reset()
+def test_log_grouped_gemm(feature_dirs):
+    if not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    log_all_stats_config = LOG_QUANTIZED_CONFIG_BASE.format(stats=", ".join(all_stats))
+    with debug_session(log_all_stats_config, feature_dirs) as log_dir:
+        model = te.GroupedLinear(3, 128, 128, name="linear1", params_dtype=torch.bfloat16)
+        inp = torch.randn((1, 128, 128), dtype=torch.bfloat16).cuda()
+        m_splits = [64, 32, 32]
+        with te.fp8_autocast(fp8_recipe=recipe.DelayedScaling()):
+            output = model(inp, m_splits=m_splits)
+        loss = output.sum()
+        loss.backward()
+        debug_api.step()
+        output = read_log(log_dir)
+        assert "gemm_0" in output, "gemm0 not found in output"
+        assert "gemm_1" in output, "gemm1 not found in output"
+        assert "gemm_2" in output, "gemm2 not found in output"
def test_compute_max_blockwise_dynamic_range_direct():
    """Direct unit test for compute_max_blockwise_dynamic_range function.
......
@@ -1270,6 +1270,9 @@ def test_linear_accuracy(dtype, bs, model, return_bias, bias):
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_accumulation):
+    if NVTE_TEST_NVINSPECT_ENABLED:
+        pytest.skip("Delayed wgrad compute is not supported in debug mode.")
    config = model_configs[model]
    te_linear_ref = Linear(
@@ -1566,6 +1569,9 @@ def test_layernorm_linear_accuracy(
def test_layernorm_linear_accuracy_delay_wgrad_compute(
    dtype, bs, model, normalization, zero_centered_gamma, bias, fuse_wgrad_accumulation
):
+    if NVTE_TEST_NVINSPECT_ENABLED:
+        pytest.skip("Delayed wgrad compute is not supported in debug mode.")
    config = model_configs[model]
    ln_linear_ref = LayerNormLinear(
@@ -1705,6 +1711,9 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute(
    bias,
    fuse_wgrad_accumulation,
):
+    if NVTE_TEST_NVINSPECT_ENABLED:
+        pytest.skip("Delayed wgrad compute is not supported in debug mode.")
    config = model_configs[model]
    ln_mlp = LayerNormMLP(
@@ -1899,6 +1908,8 @@ def test_grouped_linear_accuracy(
    fp8 = recipe is not None
    if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")
+    if NVTE_TEST_NVINSPECT_ENABLED and delay_wgrad_compute:
+        pytest.skip("Delayed wgrad compute is not supported in debug mode.")
    config = model_configs[model]
    if config.max_seqlen_q % 16 != 0 and fp8:
@@ -2041,6 +2052,8 @@ def test_grouped_linear_accuracy_save_original_input(
        pytest.skip("FP8 parameters are not supported in debug mode.")
    if fp8 and recipe.delayed():
        pytest.skip("DelayedScaling recipe is not supported with save_original_input")
+    if NVTE_TEST_NVINSPECT_ENABLED and delay_wgrad_compute:
+        pytest.skip("Delayed wgrad compute is not supported in debug mode.")
    config = model_configs[model]
    if config.max_seqlen_q % 16 != 0 and fp8:
@@ -2757,6 +2770,7 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
        A,
        B,
        out,
+        [None] * z,
        dtype,
        m_splits=m_splits,
        grad=grad,
@@ -2904,6 +2918,7 @@ def test_fp8_grouped_gemm(shape, accumulate):
        A_fp8,
        B_fp8,
        out,
+        [None] * z,
        dtype,
        m_splits=m_splits,
        accumulate=accumulate,
......
@@ -9,7 +9,7 @@ These wrappers add logic related to debugging, using the nvdlfw_inspect package.
"""
from __future__ import annotations
-from typing import Optional, Tuple, Iterable, Union
+from typing import Optional, Tuple, Iterable, Union, List
import torch
import transformer_engine_torch as tex
@@ -556,6 +556,23 @@ class DebugQuantizer(Quantizer):
        if not self.output_tensor:
            self._update_parent_quantizer_usage()
+    @classmethod
+    def multi_tensor_quantize(
+        cls,
+        tensor: torch.Tensor,
+        quantizers: List[Quantizer],
+        m_splits: List[int],
+        activation_dtype: torch.dtype,
+    ) -> List[DebugQuantizedTensor]:
+        """
+        Splits a tensor into a list of tensors and quantizes each tensor using a list of quantizers.
+        """
+        tensors = torch.split(tensor, m_splits)
+        output = []
+        for tensor, quantizer in zip(tensors, quantizers):
+            output.append(quantizer.quantize(tensor, dtype=activation_dtype))
+        return output
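For reference, a minimal usage sketch of the new helper (the variable names are illustrative; it mirrors the call sites added to grouped_linear.py in this change):

    # inp_view: [sum(m_splits), hidden] tensor, with one DebugQuantizer per GEMM.
    inputmats = DebugQuantizer.multi_tensor_quantize(
        inp_view,          # tensor split along dim 0 according to m_splits
        input_quantizers,  # List[DebugQuantizer], one per GEMM
        m_splits,          # split sizes, e.g. [64, 32, 32]
        torch.bfloat16,    # activation dtype forwarded to each quantize() call
    )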
class DebugQuantizedTensor(QuantizedTensorStorage):
    """
@@ -623,9 +640,9 @@ class DebugQuantizedTensor(QuantizedTensorStorage):
        """Is used in the python gemm() to get tensor or transpose of the tensor."""
        return self.rowwise_gemm_tensor if not transpose else self.columnwise_gemm_tensor
-    def size(self):
+    def size(self, *args):
        """Size of the tensor."""
-        return self.rowwise_gemm_tensor.size()
+        return self.rowwise_gemm_tensor.size(*args)
    def update_usage(self, rowwise_usage: bool = None, columnwise_usage: bool = None):
        """Update usage of the tensor."""
......
@@ -114,7 +114,6 @@ def general_gemm(
    assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported."
    transa = layout[0] == "T"
    transb = layout[1] == "T"
-    # assert quantization_params is None, "FP8 output not supported yet"
    alpha = validate_gemm_scale(alpha, True)
    beta = validate_gemm_scale(beta, accumulate)
@@ -215,6 +214,7 @@ def general_grouped_gemm(
    A: List[torch.Tensor],
    B: List[torch.Tensor],
    out: List[torch.Tensor],
+    quantization_params: List[Optional[Quantizer]],
    out_dtype: torch.dtype,
    layout: str = "TN",
    m_splits: Optional[List[int]] = None,
@@ -247,7 +247,7 @@ def general_grouped_gemm(
    if grad and use_bias:
        grad_bias = [
-            torch.empty(B[i].shape[1], dtype=out[0].dtype, device="cuda") for i in range(num_gemms)
+            torch.empty(B[i].size(1), dtype=out[0].dtype, device="cuda") for i in range(num_gemms)
        ]
    else:
        grad_bias = empty_tensors
@@ -257,6 +257,36 @@ def general_grouped_gemm(
    else:
        bias_dtype = TE_DType[torch.bfloat16]
+    if isinstance(quantization_params[0], DebugQuantizer):
+        assert not gelu, "GELU not supported in debug mode"
+        if single_output:
+            out_init = out[0]
+            start_idx = 0
+            out = [None] * num_gemms
+            for i in range(num_gemms):
+                size = m_splits[i]
+                out[i] = out_init[start_idx : start_idx + size]
+                start_idx += size
+        for i in range(num_gemms):
+            _, bias_or_grad, _, _ = general_gemm(
+                A[i],
+                B[i],
+                quantization_params=quantization_params[i],
+                out_dtype=out[0].dtype,
+                layout=layout,
+                accumulate=accumulate,
+                out=out[i],
+                bias=bias[i] if use_bias else None,
+                use_split_accumulator=use_split_accumulator,
+                grad=grad,
+            )
+            if grad and use_bias:
+                grad_bias[i] = bias_or_grad
+        if single_output:
+            out = out_init
+        return out, grad_bias if grad else bias, None
    if gelu:
        gelu_input = [
            torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format)
......
@@ -1165,18 +1165,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        # bgrad only if wgrad is in FP8, otherwise it is fused with wgrad and we return None
        if ctx.debug:
            grad_output_ = quantizer(grad_output)
-            if (
-                isinstance(
-                    grad_output_.get_tensor(True),
-                    (
-                        QuantizedTensor,
-                        Float8TensorStorage,
-                        MXFP8TensorStorage,
-                        Float8BlockwiseQTensorStorage,
-                    ),
-                )
-                and ctx.use_bias
-            ):
+            if ctx.use_bias:
                grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
            else:
                grad_bias = None
@@ -1540,6 +1529,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
            # we use the debug value from the first invocation in the iteration.
            debug = self.debug_enabled_in_this_iteration
+        self.debug_last_iteration = TEDebugState.get_iteration()
+        if self.wgrad_store is not None:
+            if debug and self.wgrad_store.delay_wgrad_compute():
+                raise RuntimeError("Delayed wgrad compute is not supported in debug mode.")
        return debug
    def no_debug_features_active(self, quantizers):
......
@@ -4,6 +4,7 @@
"""GroupedLinear API"""
from typing import Union, Optional, Callable, Tuple, List
+from itertools import chain
import warnings
import functools
@@ -49,6 +50,8 @@ from ..quantized_tensor import (
    prepare_for_saving,
    restore_from_saved,
)
+from ...debug.pytorch.debug_quantization import DebugQuantizer
+from ...debug.pytorch.debug_state import TEDebugState
__all__ = ["GroupedLinear"]
@@ -58,6 +61,7 @@ class _GroupedLinear(torch.autograd.Function):
    Calls custom cuda extensions.
    """
+    # pylint: disable=keyword-arg-before-vararg
    @staticmethod
    def forward(
        ctx,
@@ -79,6 +83,8 @@ class _GroupedLinear(torch.autograd.Function):
            input_quantizers,
            weight_quantizers,
            output_quantizers,
+            grad_input_quantizers,
+            grad_weight_quantizers,
            grad_output_quantizers,
            fuse_wgrad_accumulation,
            cpu_offloading,
@@ -88,6 +94,7 @@ class _GroupedLinear(torch.autograd.Function):
            module,
            skip_fp8_weight_update,
            save_original_input,
+            debug,
        ) = non_tensor_args
        num_gemms = len(m_splits)
@@ -135,8 +142,12 @@ class _GroupedLinear(torch.autograd.Function):
        )
        inp_view = inp.reshape(-1, in_features)
        inputmats: list
-        if fp8:
+        if fp8 and not debug:
            inputmats = tex.split_quantize(inp_view, m_splits, input_quantizers)
+        elif debug:
+            inputmats = DebugQuantizer.multi_tensor_quantize(
+                inp_view, input_quantizers, m_splits, activation_dtype
+            )
        else:
            inputmats = torch.split(cast_if_needed(inp_view, activation_dtype), m_splits)
@@ -145,7 +156,7 @@ class _GroupedLinear(torch.autograd.Function):
        # Initialize weights
        weights_fp8: list
-        if fp8:
+        if fp8 or debug:
            # FP8 cast to workspace buffer
            weights_fp8 = []
            update_workspace = is_first_microbatch is None or is_first_microbatch
@@ -156,6 +167,7 @@ class _GroupedLinear(torch.autograd.Function):
                    cache_name=(None if is_first_microbatch is None else f"weight{i}"),
                    update_workspace=update_workspace,
                    skip_update_flag=skip_fp8_weight_update,
+                    workspace_dtype=activation_dtype,
                )
                weights_fp8.append(weight_fp8)
@@ -167,7 +179,6 @@ class _GroupedLinear(torch.autograd.Function):
        if fp8 and activation_dtype == torch.float32:
            bias_dtype = torch.bfloat16  # FP8 GEMM only supports BF16/FP16 bias
        biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases
        # Initialize output tensor
        out = torch.empty(
            [sum(m_splits), weights_fp8[0].size(0)],
@@ -183,10 +194,11 @@ class _GroupedLinear(torch.autograd.Function):
            use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator
        # Perform GEMM
-        _ = general_grouped_gemm(
+        general_grouped_gemm(
            weights_fp8,
            inputmats,
            [out],
+            output_quantizers,
            activation_dtype,
            single_output=True,
            m_splits=m_splits,
@@ -244,6 +256,10 @@ class _GroupedLinear(torch.autograd.Function):
            ctx.save_for_backward(*tensors_to_save)
            ctx.tensor_objects = tensor_objects
+            ctx.grad_input_quantizers = grad_input_quantizers
+            ctx.grad_output_quantizers = grad_output_quantizers
+            ctx.grad_weight_quantizers = grad_weight_quantizers
            ctx.weights_requires_grad = weights[0].requires_grad
            if fuse_wgrad_accumulation and ctx.weights_requires_grad:
                # This check is needed to ensure that main_grad is not created
@@ -259,7 +275,7 @@ class _GroupedLinear(torch.autograd.Function):
            else:
                ctx.main_grad_funcs = [lambda: None for i in range(num_gemms)]
            ctx.device = device
-            ctx.grad_output_quantizers = grad_output_quantizers
+            ctx.output_quantizers = output_quantizers
            ctx.m_splits = m_splits
            ctx.num_gemms = num_gemms
            ctx.activation_dtype = activation_dtype
@@ -279,6 +295,7 @@ class _GroupedLinear(torch.autograd.Function):
                or FP8GlobalStateManager.is_first_fp8_module()
            )
            ctx.wgrad_store = wgrad_store
+            ctx.debug = debug
            ctx.save_original_input = save_original_input
            ctx.input_quantizers = input_quantizers
@@ -311,7 +328,7 @@ class _GroupedLinear(torch.autograd.Function):
        grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1])
        grad_output = [None] * ctx.num_gemms
        grad_biases = [None] * ctx.num_gemms
-        if ctx.fp8:
+        if ctx.fp8 and not ctx.debug:
            if ctx.use_bias:
                grad_output_mats = torch.split(grad_output_view, ctx.m_splits)
                recipe = ctx.fp8_recipe
@@ -338,6 +355,13 @@ class _GroupedLinear(torch.autograd.Function):
                    ctx.m_splits,
                    ctx.grad_output_quantizers,
                )
+        elif ctx.debug:
+            grad_output_mats = torch.split(grad_output_view, ctx.m_splits)
+            for i in range(ctx.num_gemms):
+                grad_biases[i] = grad_output_mats[i].sum(dim=0)
+            grad_output = DebugQuantizer.multi_tensor_quantize(
+                grad_output_view, ctx.grad_output_quantizers, ctx.m_splits, ctx.activation_dtype
+            )
        else:
            # Only split grad output. Grad bias is fused with
            # wgrad GEMM.
@@ -355,7 +379,7 @@ class _GroupedLinear(torch.autograd.Function):
        if ctx.requires_dgrad:
            dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD
-            if ctx.fp8:
+            if ctx.fp8 or ctx.debug:
                recipe = ctx.fp8_recipe
                if hasattr(recipe, "fp8_gemm_dgrad"):
                    dgrad_gemm_use_split_accumulator = (
@@ -375,6 +399,7 @@ class _GroupedLinear(torch.autograd.Function):
                weights,
                grad_output,
                [dgrad],
+                ctx.grad_input_quantizers,
                ctx.activation_dtype,
                single_output=True,
                layout="NN",
@@ -412,15 +437,19 @@ class _GroupedLinear(torch.autograd.Function):
                else:
                    input_quantizer.set_usage(rowwise=False, columnwise=True)
        inputmats: list
-        if ctx.fp8:
+        if ctx.fp8 and not ctx.debug:
            inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers)
+        elif ctx.debug:
+            inputmats = DebugQuantizer.multi_tensor_quantize(
+                inp_view, ctx.input_quantizers, ctx.m_splits, ctx.activation_dtype
+            )
        else:
            inputmats = torch.split(
                cast_if_needed(inp_view, ctx.activation_dtype), ctx.m_splits
            )
        grouped_gemm_wgrad = functools.partial(
            general_grouped_gemm,
+            quantization_params=ctx.grad_weight_quantizers,
            out_dtype=ctx.activation_dtype,
            layout="NT",
            grad=True,
@@ -576,6 +605,7 @@ class GroupedLinear(TransformerEngineBaseModule):
        ub_name: Optional[str] = None,
        delay_wgrad_compute: bool = False,
        save_original_input: bool = False,
+        name: Optional[str] = None,
    ) -> None:
        super().__init__()
@@ -596,6 +626,7 @@ class GroupedLinear(TransformerEngineBaseModule):
        ), "GroupedLinear doesn't support Userbuffer overlap."
        self.get_rng_state_tracker = get_rng_state_tracker
        self.rng_tracker_name = rng_tracker_name
+        self.name = name
        self.wgrad_store = WeightGradStore(delay_wgrad_compute)
@@ -745,6 +776,8 @@ class GroupedLinear(TransformerEngineBaseModule):
               first microbatch (since it is the first gradient being
               produced)
        """
+        debug = self.is_debug_iter()
        assert not isinstance(
            inp, QuantizedTensorStorage
        ), "GroupedLinear doesn't support input tensor in FP8."
@@ -756,31 +789,24 @@ class GroupedLinear(TransformerEngineBaseModule):
        weight_tensors = self._get_weight_tensors()
        bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
-        weight_quantizers = self._get_weight_quantizers()
-        input_quantizers, output_quantizers = (
-            [None] * self.num_gemms,
-            [None] * self.num_gemms,
-        )
-        grad_output_quantizers, _ = [None] * self.num_gemms, [None] * self.num_gemms
-        if self.fp8:
-            input_quantizers = [
-                self.quantizers["scaling_fwd"][
-                    self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["fwd"]
-                ]
-                for i in range(self.num_gemms)
-            ]
-            # TODO: use internal after #1638 is merged.  # pylint: disable=fixme
-            for i in range(self.num_gemms):
-                input_quantizers[i].internal = False
-            if is_grad_enabled:
-                grad_output_quantizers = [
-                    self.quantizers["scaling_bwd"][
-                        self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"]
-                    ]
-                    for i in range(self.num_gemms)
-                ]
-                for i in range(self.num_gemms):
-                    grad_output_quantizers[i].internal = True
+        quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers()
+        if debug:
+            if self.no_debug_features_active(list(chain(*quantizers))):
+                debug = False
+                quantizers = self._get_quantizers()
+            if isinstance(weight_tensors, QuantizedTensorStorage):
+                raise RuntimeError("FP8 weights are not supported in debug mode.")
+        (
+            input_quantizers,
+            weight_quantizers,
+            output_quantizers,
+            grad_input_quantizers,
+            grad_weight_quantizers,
+            grad_output_quantizers,
+        ) = quantizers
        if is_grad_enabled:
            linear_fn = _GroupedLinear.apply
@@ -799,6 +825,8 @@ class GroupedLinear(TransformerEngineBaseModule):
            input_quantizers,
            weight_quantizers,
            output_quantizers,
+            grad_input_quantizers,
+            grad_weight_quantizers,
            grad_output_quantizers,
            self.fuse_wgrad_accumulation,
            is_cpu_offload_enabled(),
@@ -808,6 +836,7 @@ class GroupedLinear(TransformerEngineBaseModule):
            self,
            None,  # skip_fp8_weight_update
            self.save_original_input,
+            debug,
        )
        out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors)
@@ -898,3 +927,55 @@ class GroupedLinear(TransformerEngineBaseModule):
        for i in range(self.num_gemms):
            weight_quantizers[i].internal = True
        return weight_quantizers
+    def _get_quantizers(self):
+        weight_quantizers = self._get_weight_quantizers()
+        input_quantizers, output_quantizers = (
+            [None] * self.num_gemms,
+            [None] * self.num_gemms,
+        )
+        grad_input_quantizers, grad_weight_quantizers, grad_output_quantizers = (
+            [None] * self.num_gemms,
+            [None] * self.num_gemms,
+            [None] * self.num_gemms,
+        )
+        if self.fp8:
+            input_quantizers = [
+                self.quantizers["scaling_fwd"][
+                    self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["fwd"]
+                ]
+                for i in range(self.num_gemms)
+            ]
+            # TODO: use internal after #1638 is merged.  # pylint: disable=fixme
+            for i in range(self.num_gemms):
+                input_quantizers[i].internal = False
+            if torch.is_grad_enabled():
+                grad_output_quantizers = [
+                    self.quantizers["scaling_bwd"][
+                        self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"]
+                    ]
+                    for i in range(self.num_gemms)
+                ]
+                for i in range(self.num_gemms):
+                    grad_output_quantizers[i].internal = True
+        return (
+            input_quantizers,
+            weight_quantizers,
+            output_quantizers,
+            grad_input_quantizers,
+            grad_weight_quantizers,
+            grad_output_quantizers,
+        )
+    def _get_debug_quantizers(self):
+        original_quantizers = self._get_quantizers()
+        assert TEDebugState.debug_enabled
+        names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"]
+        return tuple(
+            [
+                DebugQuantizer(self.name + f".gemm_{q_id}", name, q, self.tp_group)
+                for q_id, q in enumerate(qs)
+            ]
+            for name, qs in zip(names, original_quantizers)
+        )