Unverified Commit edcfc284 authored by Kunlun Li, committed by GitHub

Warn when using fp8 weights + non-fp8 computation (#1712)



* Prevent using fp8 weights + non-fp8 computation
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Using warnings instead of raising an error
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Add dequantization back
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: kunlunl <kunlunl@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 4e9c2c39
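
The same change is applied in three modules below (GroupedLinear, LayerNormLinear, Linear): when a module holds quantized (FP8) weights but the forward pass is not running with FP8 compute, it now warns instead of raising and falls back to dequantized weights. The following is a minimal, self-contained sketch of that pattern; the QuantizedTensor stand-in and the prepare_weights helper are illustrative only, not Transformer Engine's actual classes or API.

import warnings


class QuantizedTensor:
    # Stand-in for Transformer Engine's QuantizedTensor, included only to
    # make this sketch runnable. The real class stores FP8 data plus
    # scaling metadata.
    def __init__(self, data):
        self.data = data

    def dequantize(self):
        # The real method materializes a higher-precision tensor from the
        # FP8 storage; here we simply unwrap the payload.
        return self.data


def prepare_weights(weights, fp8_compute):
    # Mirrors the logic this commit adds: when compute is not FP8, warn
    # (instead of raising, per the second commit-message bullet) and
    # dequantize any quantized weights so the GEMM can still run. As in
    # the GroupedLinear hunk, the warning is gated only on fp8_compute.
    if not fp8_compute:
        warnings.warn(
            "You are using quantized weights without quantized compute. "
            "Please make sure this is intentional."
        )
        weights = [
            w.dequantize() if isinstance(w, QuantizedTensor) else w
            for w in weights
        ]
    return weights
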
@@ -4,6 +4,7 @@
 """GroupedLinear API"""
 from typing import Union, Optional, Callable, Tuple, List
+import warnings
 import functools
 import torch
@@ -676,6 +677,10 @@ class GroupedLinear(TransformerEngineBaseModule):
         weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
         bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
         if not self.fp8:
+            warnings.warn(
+                "You are using quantized weights without quantized compute. "
+                "Please make sure this is intentional."
+            )
             weight_tensors = [
                 w.dequantize() if isinstance(w, QuantizedTensor) else w for w in weight_tensors
             ]
@@ -1394,6 +1394,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
                         "Splitting QuantizedTensor into multiple params is not supported"
                     )
             else:
+                warnings.warn(
+                    "You are using quantized weights without quantized compute. "
+                    "Please make sure this is intentional."
+                )
                 unfused_weights = [w.dequantize() for w in unfused_weights]
         weight_tensor = noop_cat(unfused_weights)
@@ -6,6 +6,7 @@
 from typing import Callable, Dict, Optional, Tuple, Union
 from functools import reduce
 from operator import mul as multiply_op
+import warnings
 import functools
 import torch
@@ -1207,7 +1208,12 @@ class Linear(TransformerEngineBaseModule):
                         "Splitting QuantizedTensor into multiple params is not supported"
                     )
             else:
+                warnings.warn(
+                    "You are using quantized weights without quantized compute. "
+                    "Please make sure this is intentional."
+                )
                 unfused_weights = [w.dequantize() for w in unfused_weights]
         weight_tensor = noop_cat(unfused_weights)
         if self.use_bias:
             bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
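
For reference, a hedged sketch of the scenario this warning targets, assuming Transformer Engine is installed and a CUDA GPU is available. te.fp8_model_init and te.fp8_autocast are real TE entry points, but the exact repro depends on your TE version.

import torch
import transformer_engine.pytorch as te

# Store the module's parameters in FP8 from the start.
with te.fp8_model_init(enabled=True):
    linear = te.Linear(1024, 1024)

x = torch.randn(16, 1024, device="cuda")

# Forward pass without te.fp8_autocast(): compute is not FP8 while the
# weights are quantized, so the new warning fires and the weights are
# dequantized before the GEMM.
y = linear(x)

Wrapping the same forward in a te.fp8_autocast() context keeps the compute path in FP8, so the weights stay quantized and no warning is emitted.
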