Unverified Commit 9f61f8a5 authored by Paweł Gadziński, committed by GitHub

[PyTorch Debug] Debug support for GroupedLinear (#1953)



* main
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* docs
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* add
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* test fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent b3c25057
@@ -107,6 +107,8 @@ The ``TransformerLayer`` in Transformer Engine is a composition of multiple sub-
depending on the configuration. Some layers, like ``LayerNormLinear``, are fusions of two layers: ``LayerNorm`` and ``Linear``. When referring to such layers in precision debug tools, only the ``Linear`` part is affected.
For the ``GroupedLinear`` layer, the names of the underlying GEMMs are of the form ``layer_name.gemm_n``, where ``n`` is the index of the GEMM.
Below is an example ``TransformerLayer`` with four linear layers that can be influenced by the precision debug tools.

.. figure:: ./img/names.svg
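For illustration, a minimal sketch (layer name and sizes are arbitrary) of how these per-GEMM names come about: a GroupedLinear constructed with name="linear1" and three GEMMs is addressed by the debug tools as linear1.gemm_0, linear1.gemm_1 and linear1.gemm_2.

import torch
import transformer_engine.pytorch as te

# Three GEMMs grouped under one layer; with the precision debug tools enabled,
# their statistics are reported as "linear1.gemm_0", "linear1.gemm_1", "linear1.gemm_2".
grouped = te.GroupedLinear(3, 128, 128, name="linear1", params_dtype=torch.bfloat16)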
@@ -363,6 +363,28 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
    TEDebugState._reset()


def test_log_grouped_gemm(feature_dirs):
    if not fp8_available:
        pytest.skip(reason_for_no_fp8)

    log_all_stats_config = LOG_QUANTIZED_CONFIG_BASE.format(stats=", ".join(all_stats))
    with debug_session(log_all_stats_config, feature_dirs) as log_dir:
        model = te.GroupedLinear(3, 128, 128, name="linear1", params_dtype=torch.bfloat16)
        inp = torch.randn((1, 128, 128), dtype=torch.bfloat16).cuda()
        m_splits = [64, 32, 32]

        with te.fp8_autocast(fp8_recipe=recipe.DelayedScaling()):
            output = model(inp, m_splits=m_splits)
        loss = output.sum()
        loss.backward()
        debug_api.step()

    output = read_log(log_dir)
    assert "gemm_0" in output, "gemm_0 not found in output"
    assert "gemm_1" in output, "gemm_1 not found in output"
    assert "gemm_2" in output, "gemm_2 not found in output"


def test_compute_max_blockwise_dynamic_range_direct():
    """Direct unit test for compute_max_blockwise_dynamic_range function.
@@ -1270,6 +1270,9 @@ def test_linear_accuracy(dtype, bs, model, return_bias, bias):
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_accumulation):
    if NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("Delayed wgrad compute is not supported in debug mode.")

    config = model_configs[model]
    te_linear_ref = Linear(
@@ -1566,6 +1569,9 @@ def test_layernorm_linear_accuracy(
def test_layernorm_linear_accuracy_delay_wgrad_compute(
    dtype, bs, model, normalization, zero_centered_gamma, bias, fuse_wgrad_accumulation
):
    if NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("Delayed wgrad compute is not supported in debug mode.")

    config = model_configs[model]
    ln_linear_ref = LayerNormLinear(
@@ -1705,6 +1711,9 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute(
    bias,
    fuse_wgrad_accumulation,
):
    if NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("Delayed wgrad compute is not supported in debug mode.")

    config = model_configs[model]
    ln_mlp = LayerNormMLP(
@@ -1899,6 +1908,8 @@ def test_grouped_linear_accuracy(
    fp8 = recipe is not None
    if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")
    if NVTE_TEST_NVINSPECT_ENABLED and delay_wgrad_compute:
        pytest.skip("Delayed wgrad compute is not supported in debug mode.")

    config = model_configs[model]
    if config.max_seqlen_q % 16 != 0 and fp8:
@@ -2041,6 +2052,8 @@ def test_grouped_linear_accuracy_save_original_input(
        pytest.skip("FP8 parameters are not supported in debug mode.")
    if fp8 and recipe.delayed():
        pytest.skip("DelayedScaling recipe is not supported with save_original_input")
    if NVTE_TEST_NVINSPECT_ENABLED and delay_wgrad_compute:
        pytest.skip("Delayed wgrad compute is not supported in debug mode.")

    config = model_configs[model]
    if config.max_seqlen_q % 16 != 0 and fp8:
@@ -2757,6 +2770,7 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
        A,
        B,
        out,
        [None] * z,
        dtype,
        m_splits=m_splits,
        grad=grad,
@@ -2904,6 +2918,7 @@ def test_fp8_grouped_gemm(shape, accumulate):
        A_fp8,
        B_fp8,
        out,
        [None] * z,
        dtype,
        m_splits=m_splits,
        accumulate=accumulate,
@@ -9,7 +9,7 @@ These wrappers add logic related to debugging, using the nvdlfw_inspect package.
"""
from __future__ import annotations
from typing import Optional, Tuple, Iterable, Union
from typing import Optional, Tuple, Iterable, Union, List
import torch
import transformer_engine_torch as tex
@@ -556,6 +556,23 @@ class DebugQuantizer(Quantizer):
        if not self.output_tensor:
            self._update_parent_quantizer_usage()

    @classmethod
    def multi_tensor_quantize(
        cls,
        tensor: torch.Tensor,
        quantizers: List[Quantizer],
        m_splits: List[int],
        activation_dtype: torch.dtype,
    ) -> List[DebugQuantizedTensor]:
        """
        Splits a tensor along its first dimension according to ``m_splits`` and quantizes
        each split with the corresponding quantizer from ``quantizers``.
        """
        tensors = torch.split(tensor, m_splits)
        output = []
        for tensor, quantizer in zip(tensors, quantizers):
            output.append(quantizer.quantize(tensor, dtype=activation_dtype))
        return output


class DebugQuantizedTensor(QuantizedTensorStorage):
    """
@@ -623,9 +640,9 @@ class DebugQuantizedTensor(QuantizedTensorStorage):
        """Is used in the python gemm() to get tensor or transpose of the tensor."""
        return self.rowwise_gemm_tensor if not transpose else self.columnwise_gemm_tensor

    def size(self):
    def size(self, *args):
        """Size of the tensor."""
        return self.rowwise_gemm_tensor.size()
        return self.rowwise_gemm_tensor.size(*args)

    def update_usage(self, rowwise_usage: bool = None, columnwise_usage: bool = None):
        """Update usage of the tensor."""
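As a usage illustration of the new DebugQuantizer.multi_tensor_quantize helper above, a minimal sketch; the shapes and the pre-built debug_quantizers list are assumptions (GroupedLinear builds the real list internally from its debug quantizers):

import torch
from transformer_engine.debug.pytorch.debug_quantization import DebugQuantizer

# Assume `debug_quantizers` is a list of three DebugQuantizer objects, one per GEMM.
inp_view = torch.randn(128, 64, device="cuda", dtype=torch.bfloat16)
m_splits = [64, 32, 32]  # per-GEMM row counts; they must sum to inp_view.shape[0]
quantized = DebugQuantizer.multi_tensor_quantize(inp_view, debug_quantizers, m_splits, torch.bfloat16)
# quantized[i] is the DebugQuantizedTensor fed to the i-th GEMM.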
@@ -114,7 +114,6 @@ def general_gemm(
    assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported."
    transa = layout[0] == "T"
    transb = layout[1] == "T"
    # assert quantization_params is None, "FP8 output not supported yet"
    alpha = validate_gemm_scale(alpha, True)
    beta = validate_gemm_scale(beta, accumulate)
@@ -215,6 +214,7 @@ def general_grouped_gemm(
    A: List[torch.Tensor],
    B: List[torch.Tensor],
    out: List[torch.Tensor],
    quantization_params: List[Optional[Quantizer]],
    out_dtype: torch.dtype,
    layout: str = "TN",
    m_splits: Optional[List[int]] = None,
@@ -247,7 +247,7 @@
    if grad and use_bias:
        grad_bias = [
            torch.empty(B[i].shape[1], dtype=out[0].dtype, device="cuda") for i in range(num_gemms)
            torch.empty(B[i].size(1), dtype=out[0].dtype, device="cuda") for i in range(num_gemms)
        ]
    else:
        grad_bias = empty_tensors
@@ -257,6 +257,36 @@
    else:
        bias_dtype = TE_DType[torch.bfloat16]

    if isinstance(quantization_params[0], DebugQuantizer):
        # Debug path: bypass the grouped kernel and run one general_gemm call per GEMM,
        # so that the DebugQuantizer logic is applied to every GEMM in the group.
        assert not gelu, "GELU not supported in debug mode"
        if single_output:
            out_init = out[0]
            start_idx = 0
            out = [None] * num_gemms
            for i in range(num_gemms):
                size = m_splits[i]
                out[i] = out_init[start_idx : start_idx + size]
                start_idx += size
        for i in range(num_gemms):
            _, bias_or_grad, _, _ = general_gemm(
                A[i],
                B[i],
                quantization_params=quantization_params[i],
                out_dtype=out[0].dtype,
                layout=layout,
                accumulate=accumulate,
                out=out[i],
                bias=bias[i] if use_bias else None,
                use_split_accumulator=use_split_accumulator,
                grad=grad,
            )
            if grad and use_bias:
                grad_bias[i] = bias_or_grad
        if single_output:
            out = out_init
        return out, grad_bias if grad else bias, None

    if gelu:
        gelu_input = [
            torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format)
@@ -1165,18 +1165,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        # bgrad only if wgrad is in FP8, otherwise it is fused with wgrad and we return None
        if ctx.debug:
            grad_output_ = quantizer(grad_output)
            if (
                isinstance(
                    grad_output_.get_tensor(True),
                    (
                        QuantizedTensor,
                        Float8TensorStorage,
                        MXFP8TensorStorage,
                        Float8BlockwiseQTensorStorage,
                    ),
                )
                and ctx.use_bias
            ):
            if ctx.use_bias:
                grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
            else:
                grad_bias = None
@@ -1540,6 +1529,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        # we use the debug value from the first invocation in the iteration.
        debug = self.debug_enabled_in_this_iteration
        self.debug_last_iteration = TEDebugState.get_iteration()

        if self.wgrad_store is not None:
            if debug and self.wgrad_store.delay_wgrad_compute():
                raise RuntimeError("Delayed wgrad compute is not supported in debug mode.")

        return debug

    def no_debug_features_active(self, quantizers):
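To make the new restriction concrete, a minimal sketch (it assumes the debug tools were already initialized through nvdlfw_inspect and are active for this process):

import torch
import transformer_engine.pytorch as te

# delay_wgrad_compute cannot be combined with debug mode; the check above fires
# on the first forward call of the layer.
layer = te.GroupedLinear(2, 128, 128, name="linear1", delay_wgrad_compute=True)
inp = torch.randn(64, 128).cuda()
# Raises: RuntimeError: Delayed wgrad compute is not supported in debug mode.
out = layer(inp, m_splits=[32, 32])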
@@ -4,6 +4,7 @@
"""GroupedLinear API"""
from typing import Union, Optional, Callable, Tuple, List
from itertools import chain
import warnings
import functools
@@ -49,6 +50,8 @@ from ..quantized_tensor import (
    prepare_for_saving,
    restore_from_saved,
)
from ...debug.pytorch.debug_quantization import DebugQuantizer
from ...debug.pytorch.debug_state import TEDebugState

__all__ = ["GroupedLinear"]
@@ -58,6 +61,7 @@ class _GroupedLinear(torch.autograd.Function):
    Calls custom cuda extensions.
    """

    # pylint: disable=keyword-arg-before-vararg
    @staticmethod
    def forward(
        ctx,
@@ -79,6 +83,8 @@
            input_quantizers,
            weight_quantizers,
            output_quantizers,
            grad_input_quantizers,
            grad_weight_quantizers,
            grad_output_quantizers,
            fuse_wgrad_accumulation,
            cpu_offloading,
@@ -88,6 +94,7 @@
            module,
            skip_fp8_weight_update,
            save_original_input,
            debug,
        ) = non_tensor_args

        num_gemms = len(m_splits)
@@ -135,8 +142,12 @@
        )
        inp_view = inp.reshape(-1, in_features)
        inputmats: list
        if fp8:
        if fp8 and not debug:
            inputmats = tex.split_quantize(inp_view, m_splits, input_quantizers)
        elif debug:
            inputmats = DebugQuantizer.multi_tensor_quantize(
                inp_view, input_quantizers, m_splits, activation_dtype
            )
        else:
            inputmats = torch.split(cast_if_needed(inp_view, activation_dtype), m_splits)
@@ -145,7 +156,7 @@
        # Initialize weights
        weights_fp8: list
        if fp8:
        if fp8 or debug:
            # FP8 cast to workspace buffer
            weights_fp8 = []
            update_workspace = is_first_microbatch is None or is_first_microbatch
@@ -156,6 +167,7 @@
                    cache_name=(None if is_first_microbatch is None else f"weight{i}"),
                    update_workspace=update_workspace,
                    skip_update_flag=skip_fp8_weight_update,
                    workspace_dtype=activation_dtype,
                )
                weights_fp8.append(weight_fp8)
@@ -167,7 +179,6 @@
        if fp8 and activation_dtype == torch.float32:
            bias_dtype = torch.bfloat16  # FP8 GEMM only supports BF16/FP16 bias
        biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases
        # Initialize output tensor
        out = torch.empty(
            [sum(m_splits), weights_fp8[0].size(0)],
@@ -183,10 +194,11 @@
            use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator

        # Perform GEMM
        _ = general_grouped_gemm(
        general_grouped_gemm(
            weights_fp8,
            inputmats,
            [out],
            output_quantizers,
            activation_dtype,
            single_output=True,
            m_splits=m_splits,
@@ -244,6 +256,10 @@
            ctx.save_for_backward(*tensors_to_save)
            ctx.tensor_objects = tensor_objects

            ctx.grad_input_quantizers = grad_input_quantizers
            ctx.grad_output_quantizers = grad_output_quantizers
            ctx.grad_weight_quantizers = grad_weight_quantizers

            ctx.weights_requires_grad = weights[0].requires_grad
            if fuse_wgrad_accumulation and ctx.weights_requires_grad:
                # This check is needed to ensure that main_grad is not created
@@ -259,7 +275,7 @@
            else:
                ctx.main_grad_funcs = [lambda: None for i in range(num_gemms)]
            ctx.device = device
            ctx.grad_output_quantizers = grad_output_quantizers
            ctx.output_quantizers = output_quantizers
            ctx.m_splits = m_splits
            ctx.num_gemms = num_gemms
            ctx.activation_dtype = activation_dtype
@@ -279,6 +295,7 @@
                    or FP8GlobalStateManager.is_first_fp8_module()
                )
            ctx.wgrad_store = wgrad_store
            ctx.debug = debug
            ctx.save_original_input = save_original_input
            ctx.input_quantizers = input_quantizers
@@ -311,7 +328,7 @@
        grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1])
        grad_output = [None] * ctx.num_gemms
        grad_biases = [None] * ctx.num_gemms
        if ctx.fp8:
        if ctx.fp8 and not ctx.debug:
            if ctx.use_bias:
                grad_output_mats = torch.split(grad_output_view, ctx.m_splits)
                recipe = ctx.fp8_recipe
@@ -338,6 +355,13 @@
                    ctx.m_splits,
                    ctx.grad_output_quantizers,
                )
        elif ctx.debug:
            grad_output_mats = torch.split(grad_output_view, ctx.m_splits)
            for i in range(ctx.num_gemms):
                grad_biases[i] = grad_output_mats[i].sum(dim=0)
            grad_output = DebugQuantizer.multi_tensor_quantize(
                grad_output_view, ctx.grad_output_quantizers, ctx.m_splits, ctx.activation_dtype
            )
        else:
            # Only split grad output. Grad bias is fused with
            # wgrad GEMM.
@@ -355,7 +379,7 @@
        if ctx.requires_dgrad:
            dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD
            if ctx.fp8:
            if ctx.fp8 or ctx.debug:
                recipe = ctx.fp8_recipe
                if hasattr(recipe, "fp8_gemm_dgrad"):
                    dgrad_gemm_use_split_accumulator = (
@@ -375,6 +399,7 @@
                weights,
                grad_output,
                [dgrad],
                ctx.grad_input_quantizers,
                ctx.activation_dtype,
                single_output=True,
                layout="NN",
@@ -412,15 +437,19 @@
                else:
                    input_quantizer.set_usage(rowwise=False, columnwise=True)

            inputmats: list
            if ctx.fp8:
            if ctx.fp8 and not ctx.debug:
                inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers)
            elif ctx.debug:
                inputmats = DebugQuantizer.multi_tensor_quantize(
                    inp_view, ctx.input_quantizers, ctx.m_splits, ctx.activation_dtype
                )
            else:
                inputmats = torch.split(
                    cast_if_needed(inp_view, ctx.activation_dtype), ctx.m_splits
                )

            grouped_gemm_wgrad = functools.partial(
                general_grouped_gemm,
                quantization_params=ctx.grad_weight_quantizers,
                out_dtype=ctx.activation_dtype,
                layout="NT",
                grad=True,
@@ -576,6 +605,7 @@
        ub_name: Optional[str] = None,
        delay_wgrad_compute: bool = False,
        save_original_input: bool = False,
        name: Optional[str] = None,
    ) -> None:
        super().__init__()
@@ -596,6 +626,7 @@
        ), "GroupedLinear doesn't support Userbuffer overlap."
        self.get_rng_state_tracker = get_rng_state_tracker
        self.rng_tracker_name = rng_tracker_name
        self.name = name

        self.wgrad_store = WeightGradStore(delay_wgrad_compute)
@@ -745,6 +776,8 @@
                             first microbatch (since it is the first gradient being
                             produced)
        """
        debug = self.is_debug_iter()

        assert not isinstance(
            inp, QuantizedTensorStorage
        ), "GroupedLinear doesn't support input tensor in FP8."
@@ -756,31 +789,24 @@
        weight_tensors = self._get_weight_tensors()
        bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]

        weight_quantizers = self._get_weight_quantizers()
        input_quantizers, output_quantizers = (
            [None] * self.num_gemms,
            [None] * self.num_gemms,
        )
        grad_output_quantizers, _ = [None] * self.num_gemms, [None] * self.num_gemms
        if self.fp8:
            input_quantizers = [
                self.quantizers["scaling_fwd"][
                    self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["fwd"]
                ]
                for i in range(self.num_gemms)
            ]
            # TODO: use internal after #1638 is merged. # pylint: disable=fixme
            for i in range(self.num_gemms):
                input_quantizers[i].internal = False
            if is_grad_enabled:
                grad_output_quantizers = [
                    self.quantizers["scaling_bwd"][
                        self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"]
                    ]
                    for i in range(self.num_gemms)
                ]
                for i in range(self.num_gemms):
                    grad_output_quantizers[i].internal = True

        quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers()
        if debug:
            if self.no_debug_features_active(list(chain(*quantizers))):
                debug = False
                quantizers = self._get_quantizers()
            if isinstance(weight_tensors, QuantizedTensorStorage):
                raise RuntimeError("FP8 weights are not supported in debug mode.")

        (
            input_quantizers,
            weight_quantizers,
            output_quantizers,
            grad_input_quantizers,
            grad_weight_quantizers,
            grad_output_quantizers,
        ) = quantizers

        if is_grad_enabled:
            linear_fn = _GroupedLinear.apply
@@ -799,6 +825,8 @@
            input_quantizers,
            weight_quantizers,
            output_quantizers,
            grad_input_quantizers,
            grad_weight_quantizers,
            grad_output_quantizers,
            self.fuse_wgrad_accumulation,
            is_cpu_offload_enabled(),
@@ -808,6 +836,7 @@
            self,
            None,  # skip_fp8_weight_update
            self.save_original_input,
            debug,
        )
        out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors)
@@ -898,3 +927,55 @@
            for i in range(self.num_gemms):
                weight_quantizers[i].internal = True
        return weight_quantizers

    def _get_quantizers(self):
        weight_quantizers = self._get_weight_quantizers()
        input_quantizers, output_quantizers = (
            [None] * self.num_gemms,
            [None] * self.num_gemms,
        )
        grad_input_quantizers, grad_weight_quantizers, grad_output_quantizers = (
            [None] * self.num_gemms,
            [None] * self.num_gemms,
            [None] * self.num_gemms,
        )
        if self.fp8:
            input_quantizers = [
                self.quantizers["scaling_fwd"][
                    self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["fwd"]
                ]
                for i in range(self.num_gemms)
            ]
            # TODO: use internal after #1638 is merged. # pylint: disable=fixme
            for i in range(self.num_gemms):
                input_quantizers[i].internal = False
            if torch.is_grad_enabled():
                grad_output_quantizers = [
                    self.quantizers["scaling_bwd"][
                        self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"]
                    ]
                    for i in range(self.num_gemms)
                ]
                for i in range(self.num_gemms):
                    grad_output_quantizers[i].internal = True
        return (
            input_quantizers,
            weight_quantizers,
            output_quantizers,
            grad_input_quantizers,
            grad_weight_quantizers,
            grad_output_quantizers,
        )

    def _get_debug_quantizers(self):
        original_quantizers = self._get_quantizers()
        assert TEDebugState.debug_enabled
        names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"]
        return tuple(
            [
                DebugQuantizer(self.name + f".gemm_{q_id}", name, q, self.tp_group)
                for q_id, q in enumerate(qs)
            ]
            for name, qs in zip(names, original_quantizers)
        )
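Putting the pieces together, a minimal end-to-end sketch mirroring the new test (the config file path and feature directory are assumptions; any nvdlfw_inspect config that enables a logging feature for "linear1" would do):

import torch
import nvdlfw_inspect.api as debug_api
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Initialize the precision debug tools before constructing the layer.
debug_api.initialize(config_file="debug_config.yaml", feature_dirs=["transformer_engine/debug/features"])

model = te.GroupedLinear(3, 128, 128, name="linear1", params_dtype=torch.bfloat16)
inp = torch.randn((1, 128, 128), dtype=torch.bfloat16).cuda()

with te.fp8_autocast(fp8_recipe=recipe.DelayedScaling()):
    out = model(inp, m_splits=[64, 32, 32])
out.sum().backward()
debug_api.step()  # flushes per-GEMM statistics, logged as linear1.gemm_0 ... linear1.gemm_2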