Unverified commit dfe1a65a authored by Tim Moon, committed by GitHub

[PyTorch] Avoid spurious warning with non-FP8 GroupedLinear (#1758)



* Avoid spurious warning with non-FP8 GroupedLinear
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use `QuantizedTensorBase`
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent b33dd083
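
This change makes `GroupedLinear.forward` warn about "quantized weights without quantized compute" only when some weight actually is quantized; previously the warning fired on every non-FP8 forward pass, even with plain `torch.nn.Parameter` weights. A minimal sketch of the corrected guard (the stand-in class and helper function below are illustrative, not the transformer_engine API):

```python
import warnings

import torch


class QuantizedTensorBase:
    """Stand-in for transformer_engine's quantized-tensor base class."""

    def dequantize(self):
        raise NotImplementedError


def maybe_dequantize_weights(weight_tensors, fp8_enabled):
    # Warn (and fall back to high-precision weights) only when a weight is
    # actually quantized -- mirrors the guard this commit adds to forward().
    if not fp8_enabled and any(isinstance(w, QuantizedTensorBase) for w in weight_tensors):
        warnings.warn(
            "You are using quantized weights without quantized compute. "
            "Please make sure this is intentional."
        )
        weight_tensors = [
            w.dequantize() if isinstance(w, QuantizedTensorBase) else w
            for w in weight_tensors
        ]
    return weight_tensors


# Plain bf16 parameters with FP8 disabled: the old code warned here, the new
# guard stays silent because nothing is quantized.
plain_weights = [torch.nn.Parameter(torch.empty(8, 8, dtype=torch.bfloat16)) for _ in range(2)]
with warnings.catch_warnings():
    warnings.simplefilter("error")  # any warning would raise
    maybe_dequantize_weights(plain_weights, fp8_enabled=False)
```

The accompanying switch from `QuantizedTensor` to `QuantizedTensorBase` in the `isinstance` checks presumably widens them to every quantized-tensor variant deriving from the base class, not just those that also subclass `torch.Tensor`.
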
@@ -44,7 +44,7 @@ from ..graph import is_graph_capturing
 from ..cpu_offload import is_cpu_offload_enabled
 from ..tensor.quantized_tensor import (
-    QuantizedTensor,
+    QuantizedTensorBase,
     Quantizer,
     prepare_for_saving,
     restore_from_saved,
@@ -183,11 +183,11 @@ class _GroupedLinear(torch.autograd.Function):
         # TODO: update after #1638 is merged.  # pylint: disable=fixme
         if weight_requires_grad:
             for inputmat in inputmats:
-                if isinstance(inputmat, QuantizedTensor):
+                if isinstance(inputmat, QuantizedTensorBase):
                     inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
         if inp.requires_grad:
             for weight in weights_fp8:
-                if isinstance(weight, QuantizedTensor):
+                if isinstance(weight, QuantizedTensorBase):
                     weight.update_usage(columnwise_usage=True)

         tensors_to_save, tensor_objects = prepare_for_saving(
@@ -300,7 +300,7 @@ class _GroupedLinear(torch.autograd.Function):
             )
             for weight, quantizer in zip(weights, ctx.weight_quantizers):
-                if quantizer is not None and isinstance(weight, QuantizedTensor):
+                if quantizer is not None and isinstance(weight, QuantizedTensorBase):
                     weight.update_usage(
                         rowwise_usage=quantizer.rowwise_usage,
                         columnwise_usage=quantizer.columnwise_usage,
@@ -664,7 +664,7 @@ class GroupedLinear(TransformerEngineBaseModule):
                  produced)
         """
         assert not isinstance(
-            inp, QuantizedTensor
+            inp, QuantizedTensorBase
         ), "GroupedLinear doesn't support input tensor in FP8."
         assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
@@ -676,13 +676,14 @@ class GroupedLinear(TransformerEngineBaseModule):
         weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
         bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
-        if not self.fp8:
+        if not self.fp8 and any(isinstance(w, QuantizedTensorBase) for w in weight_tensors):
             warnings.warn(
                 "You are using quantized weights without quantized compute. "
                 "Please make sure this is intentional."
             )
             weight_tensors = [
-                w.dequantize() if isinstance(w, QuantizedTensor) else w for w in weight_tensors
+                w.dequantize() if isinstance(w, QuantizedTensorBase) else w
+                for w in weight_tensors
             ]

         input_quantizers, weight_quantizers, output_quantizers = (
...
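
The `@@ -183` hunk also shows why usage flags are narrowed before tensors are saved for backward: the weight-gradient GEMM consumes the column-wise (transposed) copy of the input, and the data-gradient GEMM the column-wise copy of the weight, so the row-wise input buffer can be freed as soon as the forward GEMM is done. A toy sketch of that idea (a stand-in class, not transformer_engine's `update_usage` implementation):

```python
class ToyQuantizedTensor:
    """Stand-in illustrating usage narrowing; not the transformer_engine API."""

    def __init__(self):
        self.rowwise_data = object()     # layout consumed by the forward GEMM
        self.columnwise_data = object()  # layout consumed by the backward GEMMs

    def update_usage(self, rowwise_usage=None, columnwise_usage=None):
        if columnwise_usage:
            # Backward still needs the transposed copy; keep it around.
            assert self.columnwise_data is not None
        if rowwise_usage is False:
            # Forward is done with this tensor; drop the row-wise buffer.
            self.rowwise_data = None


# Mirrors `inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)`:
# after this call only the transposed copy for the wgrad GEMM survives.
t = ToyQuantizedTensor()
t.update_usage(rowwise_usage=False, columnwise_usage=True)
assert t.rowwise_data is None and t.columnwise_data is not None
```
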