Unverified Commit 315b47db authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

[PyTorch] Debug linear layer when saving original input and using debug quantizer (#1963)



* Debug linear layer when saving original input and using debug quantizer
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Workaround bugs with quantizing with only column-wise usage
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Remove unused imports
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Avoid unnecessary row-wise data
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Workaround bugs with quantizing with only column-wise usage

FP8 does not support transpose-only cast.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 0d802283
...@@ -65,8 +65,6 @@ from ..tensor.quantized_tensor import ( ...@@ -65,8 +65,6 @@ from ..tensor.quantized_tensor import (
) )
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ..export import is_in_onnx_export_mode, assert_warmed_up from ..export import is_in_onnx_export_mode, assert_warmed_up
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.debug_state import TEDebugState
...@@ -170,16 +168,19 @@ class _Linear(torch.autograd.Function): ...@@ -170,16 +168,19 @@ class _Linear(torch.autograd.Function):
if input_quantizer is None: if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor") raise ValueError("Missing quantizer for input tensor")
if not isinstance(inputmat, QuantizedTensorBase): if not isinstance(inputmat, QuantizedTensorBase):
input_quantizer.set_usage( own_quantized_input = True
rowwise=True, columnwise=backward_needs_input and not save_original_input input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
)
if isinstance( if isinstance(
input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
): ):
# All-gather is not supported with FP8 column-wise data # All-gather is not supported with FP8 column-wise data
input_quantizer.set_usage(columnwise=False) input_quantizer.set_usage(columnwise=False)
if save_original_input:
# No need for column-wise data since this
# tensor will not be cached for backward pass
input_quantizer.set_usage(columnwise=False)
own_quantized_input = False
inputmat = input_quantizer(inputmat) inputmat = input_quantizer(inputmat)
own_quantized_input = True
else: else:
inputmat = cast_if_needed(inp, activation_dtype) # Cast for AMP inputmat = cast_if_needed(inp, activation_dtype) # Cast for AMP
...@@ -344,23 +345,29 @@ class _Linear(torch.autograd.Function): ...@@ -344,23 +345,29 @@ class _Linear(torch.autograd.Function):
inputmat = inp inputmat = inp
ctx.weight_quantizer = weight_quantizer ctx.weight_quantizer = weight_quantizer
saved_inputmat = None
ctx.backward_input_needs_gather = ( ctx.backward_input_needs_gather = (
weight.requires_grad and parallel_mode == "column" and sequence_parallel weight.requires_grad and parallel_mode == "column" and sequence_parallel
) )
if backward_needs_input: # Discard unneeded data in input tensor
if not save_original_input:
if own_quantized_input and isinstance(inputmat, QuantizedTensorBase):
# For sequence parallel in vanilla FP8, rowwise data is
# to gather the input. For MXFP8, columnwise only data
# can be allgathered.
if ( if (
isinstance(inputmat, (MXFP8TensorBase, Float8BlockwiseQTensorBase)) backward_needs_input
or not ctx.backward_input_needs_gather and own_quantized_input
and isinstance(inputmat, QuantizedTensorBase)
): ):
if ctx.backward_input_needs_gather and isinstance(
quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
):
# All-gather is not supported with FP8 column-wise data
inputmat.update_usage(rowwise_usage=True, columnwise_usage=False)
else:
# Discard row-wise data since it is not needed in backward pass
inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
# Cached input tensor
saved_inputmat = None
if backward_needs_input:
saved_inputmat = inputmat saved_inputmat = inputmat
# Weight with column-wise usage is needed for dgrad GEMM. # Weight with column-wise usage is needed for dgrad GEMM.
...@@ -572,20 +579,26 @@ class _Linear(torch.autograd.Function): ...@@ -572,20 +579,26 @@ class _Linear(torch.autograd.Function):
inputmat_total = None inputmat_total = None
inputmat_total_work = None inputmat_total_work = None
if ctx.requires_wgrad: if ctx.requires_wgrad:
input_is_quantized = isinstance(inputmat, QuantizedTensorBase)
if ctx.fp8 or ctx.debug: if ctx.fp8 or ctx.debug:
if not input_is_quantized: if isinstance(inputmat, QuantizedTensorBase):
# Input tensor is already quantized
pass
elif ctx.debug:
# Debug quantizer will be applied immediately before wgrad GEMM
pass
else:
# Quantize input tensor
quantizer = ctx.input_quantizer quantizer = ctx.input_quantizer
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): if ctx.backward_input_needs_gather and isinstance(
quantizer.set_usage( quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
rowwise=True, ):
columnwise=not ctx.backward_input_needs_gather, # All-gather is not supported with FP8 column-wise data
) quantizer.set_usage(rowwise=True, columnwise=False)
else: else:
quantizer.set_usage(rowwise=False, columnwise=True) quantizer.set_usage(rowwise=True, columnwise=True)
inputmat = quantizer(inputmat) inputmat = quantizer(inputmat)
else: else:
if input_is_quantized: if isinstance(inputmat, QuantizedTensorBase):
inputmat = inputmat.dequantize(dtype=ctx.activation_dtype) inputmat = inputmat.dequantize(dtype=ctx.activation_dtype)
else: else:
inputmat = cast_if_needed(inputmat, ctx.activation_dtype) inputmat = cast_if_needed(inputmat, ctx.activation_dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment