Unverified Commit dfe1a65a authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

[PyTorch] Avoid spurious warning with non-FP8 GroupedLinear (#1758)



* Avoid spurious warning with non-FP8 GroupedLinear
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use `QuantizedTensorBase`
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent b33dd083
......@@ -44,7 +44,7 @@ from ..graph import is_graph_capturing
from ..cpu_offload import is_cpu_offload_enabled
from ..tensor.quantized_tensor import (
QuantizedTensor,
QuantizedTensorBase,
Quantizer,
prepare_for_saving,
restore_from_saved,
......@@ -183,11 +183,11 @@ class _GroupedLinear(torch.autograd.Function):
# TODO: update after #1638 is merged. # pylint: disable=fixme
if weight_requires_grad:
for inputmat in inputmats:
if isinstance(inputmat, QuantizedTensor):
if isinstance(inputmat, QuantizedTensorBase):
inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
if inp.requires_grad:
for weight in weights_fp8:
if isinstance(weight, QuantizedTensor):
if isinstance(weight, QuantizedTensorBase):
weight.update_usage(columnwise_usage=True)
tensors_to_save, tensor_objects = prepare_for_saving(
......@@ -300,7 +300,7 @@ class _GroupedLinear(torch.autograd.Function):
)
for weight, quantizer in zip(weights, ctx.weight_quantizers):
if quantizer is not None and isinstance(weight, QuantizedTensor):
if quantizer is not None and isinstance(weight, QuantizedTensorBase):
weight.update_usage(
rowwise_usage=quantizer.rowwise_usage,
columnwise_usage=quantizer.columnwise_usage,
......@@ -664,7 +664,7 @@ class GroupedLinear(TransformerEngineBaseModule):
produced)
"""
assert not isinstance(
inp, QuantizedTensor
inp, QuantizedTensorBase
), "GroupedLinear doesn't support input tensor in FP8."
assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
......@@ -676,13 +676,14 @@ class GroupedLinear(TransformerEngineBaseModule):
weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
if not self.fp8:
if not self.fp8 and any(isinstance(w, QuantizedTensorBase) for w in weight_tensors):
warnings.warn(
"You are using quantized weights without quantized compute. "
"Please make sure this is intentional."
)
weight_tensors = [
w.dequantize() if isinstance(w, QuantizedTensor) else w for w in weight_tensors
w.dequantize() if isinstance(w, QuantizedTensorBase) else w
for w in weight_tensors
]
input_quantizers, weight_quantizers, output_quantizers = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment