Unverified commit 315b47db authored by Tim Moon, committed by GitHub

[PyTorch] Debug linear layer when saving original input and using debug quantizer (#1963)



* Debug linear layer when saving original input and using debug quantizer
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Work around bugs when quantizing with only column-wise usage
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Remove unused imports
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Avoid unnecessary row-wise data
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Work around bugs when quantizing with only column-wise usage

FP8 does not support transpose-only cast.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 0d802283
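
For context on the "transpose-only cast" limitation named in the commit message: plain FP8 cast kernels produce the row-wise (row-major) copy first and derive the column-wise (transposed) copy from it, so a column-wise-only request cannot be served directly. Below is a minimal, self-contained sketch of that failure mode and the workaround; MockFP8Quantizer and its methods are illustrative stand-ins, not the Transformer Engine API (only the rowwise/columnwise flag names mirror the diff that follows).

class MockFP8Quantizer:
    """Illustrative stand-in showing how usage flags interact for plain FP8."""

    def __init__(self):
        self.rowwise = True
        self.columnwise = False

    def set_usage(self, rowwise=None, columnwise=None):
        if rowwise is not None:
            self.rowwise = rowwise
        if columnwise is not None:
            self.columnwise = columnwise

    def quantize(self, shape):
        # A column-wise-only request would require casting straight to the
        # transposed layout, which plain FP8 kernels do not support.
        if self.columnwise and not self.rowwise:
            raise NotImplementedError("FP8 does not support transpose-only cast")
        return {
            "rowwise_data": shape if self.rowwise else None,
            "columnwise_data": shape[::-1] if self.columnwise else None,
        }

quantizer = MockFP8Quantizer()
quantizer.set_usage(rowwise=False, columnwise=True)
try:
    quantizer.quantize((128, 256))  # fails: transpose-only cast
except NotImplementedError as err:
    print(err)
quantizer.set_usage(rowwise=True)  # workaround: keep row-wise data as well
print(quantizer.quantize((128, 256)))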
@@ -65,8 +65,6 @@ from ..tensor.quantized_tensor import (
 )
 from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
-from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
-from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
 from ..export import is_in_onnx_export_mode, assert_warmed_up
 from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ...debug.pytorch.debug_state import TEDebugState
@@ -170,16 +168,19 @@ class _Linear(torch.autograd.Function):
             if input_quantizer is None:
                 raise ValueError("Missing quantizer for input tensor")
             if not isinstance(inputmat, QuantizedTensorBase):
-                input_quantizer.set_usage(
-                    rowwise=True, columnwise=backward_needs_input and not save_original_input
-                )
+                own_quantized_input = True
+                input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
                 if isinstance(
                     input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
                 ):
                     # All-gather is not supported with FP8 column-wise data
                     input_quantizer.set_usage(columnwise=False)
+                if save_original_input:
+                    # No need for column-wise data since this
+                    # tensor will not be cached for backward pass
+                    input_quantizer.set_usage(columnwise=False)
+                    own_quantized_input = False
                 inputmat = input_quantizer(inputmat)
-                own_quantized_input = True
             else:
                 inputmat = cast_if_needed(inp, activation_dtype)  # Cast for AMP
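
The flag juggling in this hunk reduces to a small pure function. A hedged summary sketch (hypothetical helper, not part of the module), where is_plain_fp8 stands for the isinstance check against Float8Quantizer and Float8CurrentScalingQuantizer:

def forward_input_usage(backward_needs_input, save_original_input, is_plain_fp8):
    """Hypothetical summary of the usage flags chosen in the hunk above.

    Returns (rowwise, columnwise, own_quantized_input).
    """
    rowwise = True  # forward GEMM always consumes row-wise data
    columnwise = backward_needs_input
    own_quantized_input = True
    if is_plain_fp8:
        # All-gather is not supported with FP8 column-wise data
        columnwise = False
    if save_original_input:
        # Tensor will not be cached for backward, so column-wise data is wasted
        columnwise = False
        own_quantized_input = False
    return rowwise, columnwise, own_quantized_input

# The case this PR fixes: when the original input is saved, the freshly
# quantized tensor is not "owned" as the copy cached for the backward pass.
assert forward_input_usage(True, True, False) == (True, False, False)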
@@ -344,23 +345,29 @@ class _Linear(torch.autograd.Function):
         inputmat = inp
         ctx.weight_quantizer = weight_quantizer
-        saved_inputmat = None
+        ctx.backward_input_needs_gather = (
+            weight.requires_grad and parallel_mode == "column" and sequence_parallel
+        )
+
+        # Discard unneeded data in input tensor
+        if (
+            backward_needs_input
+            and own_quantized_input
+            and isinstance(inputmat, QuantizedTensorBase)
+        ):
+            if ctx.backward_input_needs_gather and isinstance(
+                quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
+            ):
+                # All-gather is not supported with FP8 column-wise data
+                inputmat.update_usage(rowwise_usage=True, columnwise_usage=False)
+            else:
+                # Discard row-wise data since it is not needed in backward pass
+                inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
+
+        # Cached input tensor
+        saved_inputmat = None
         if backward_needs_input:
             if not save_original_input:
-                if own_quantized_input and isinstance(inputmat, QuantizedTensorBase):
-                    # For sequence parallel in vanilla FP8, rowwise data is
-                    # needed to gather the input. For MXFP8, columnwise only data
-                    # can be allgathered.
-                    if (
-                        isinstance(inputmat, (MXFP8TensorBase, Float8BlockwiseQTensorBase))
-                        or not ctx.backward_input_needs_gather
-                    ):
-                        inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
                 saved_inputmat = inputmat

         # Weight with column-wise usage is needed for dgrad GEMM.
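
The update_usage calls above amount to a two-input decision table: keep row-wise data only when a plain-FP8 input must still be all-gathered in the backward pass. A hedged sketch (hypothetical helper, not part of the module):

def cached_input_usage(needs_gather, is_plain_fp8):
    """Hypothetical summary: which copies of the quantized input to keep.

    Returns (rowwise_usage, columnwise_usage) as passed to update_usage above.
    """
    if needs_gather and is_plain_fp8:
        # All-gather is not supported with FP8 column-wise data, so keep the
        # row-wise copy and form the transpose after gathering in backward
        return True, False
    # Otherwise only the column-wise (transposed) copy feeds the wgrad GEMM;
    # MXFP8 and blockwise FP8 can all-gather column-wise data directly
    return False, True

for needs_gather in (False, True):
    for is_plain_fp8 in (False, True):
        print(needs_gather, is_plain_fp8, cached_input_usage(needs_gather, is_plain_fp8))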
@@ -572,20 +579,26 @@ class _Linear(torch.autograd.Function):
         inputmat_total = None
         inputmat_total_work = None
         if ctx.requires_wgrad:
-            input_is_quantized = isinstance(inputmat, QuantizedTensorBase)
             if ctx.fp8 or ctx.debug:
-                if not input_is_quantized:
+                if isinstance(inputmat, QuantizedTensorBase):
+                    # Input tensor is already quantized
+                    pass
+                elif ctx.debug:
+                    # Debug quantizer will be applied immediately before wgrad GEMM
+                    pass
+                else:
+                    # Quantize input tensor
                     quantizer = ctx.input_quantizer
-                    if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
-                        quantizer.set_usage(
-                            rowwise=True,
-                            columnwise=not ctx.backward_input_needs_gather,
-                        )
+                    if ctx.backward_input_needs_gather and isinstance(
+                        quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
+                    ):
+                        # All-gather is not supported with FP8 column-wise data
+                        quantizer.set_usage(rowwise=True, columnwise=False)
                     else:
-                        quantizer.set_usage(rowwise=False, columnwise=True)
+                        quantizer.set_usage(rowwise=True, columnwise=True)
                     inputmat = quantizer(inputmat)
             else:
-                if input_is_quantized:
+                if isinstance(inputmat, QuantizedTensorBase):
                     inputmat = inputmat.dequantize(dtype=ctx.activation_dtype)
                 else:
                     inputmat = cast_if_needed(inputmat, ctx.activation_dtype)
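
The restructured branch can be read as choosing one of five actions for the wgrad input. A hedged sketch (hypothetical names, not part of the module) that also encodes the case this PR fixes:

from enum import Enum, auto

class WgradInputAction(Enum):
    KEEP = auto()            # already quantized: use as-is
    DEFER_TO_DEBUG = auto()  # debug quantizer runs right before wgrad GEMM
    QUANTIZE = auto()        # quantize now with ctx.input_quantizer
    DEQUANTIZE = auto()      # non-quantized path given a quantized input
    CAST = auto()            # plain dtype cast for the non-quantized path

def pick_wgrad_input_action(is_quantized, fp8, debug):
    """Hypothetical summary of the three-way branch introduced above."""
    if fp8 or debug:
        if is_quantized:
            return WgradInputAction.KEEP
        if debug:
            return WgradInputAction.DEFER_TO_DEBUG
        return WgradInputAction.QUANTIZE
    return WgradInputAction.DEQUANTIZE if is_quantized else WgradInputAction.CAST

# The bug this hunk fixes: in debug mode with save_original_input, the cached
# input is unquantized; the old code quantized it with ctx.input_quantizer,
# while the new code leaves it to the debug quantizer at the wgrad GEMM.
assert pick_wgrad_input_action(is_quantized=False, fp8=False, debug=True) \
    is WgradInputAction.DEFER_TO_DEBUG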