Unverified Commit edcfc284 authored by Kunlun Li, committed by GitHub

Warn when using fp8 weights + non-fp8 computation (#1712)



* Prevent using fp8 weights + non-fp8 computation
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Using warnings instead of raising an error
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Add dequantization back
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: kunlunl <kunlunl@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 4e9c2c39
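
For context, a minimal sketch of the situation this commit warns about, assuming Transformer Engine's fp8_model_init and fp8_autocast APIs; the layer sizes and device handling here are illustrative, and the code needs FP8-capable hardware:

import torch
import transformer_engine.pytorch as te

# Weights are created and stored in FP8.
with te.fp8_model_init(enabled=True):
    linear = te.Linear(768, 768, params_dtype=torch.bfloat16)

x = torch.randn(16, 768, device="cuda", dtype=torch.bfloat16)

# Forward pass *outside* an fp8_autocast region: the weights are FP8 but
# the compute is not, so the module now warns and dequantizes the weights
# on the fly instead of raising an error.
y = linear(x)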
@@ -4,6 +4,7 @@
 """GroupedLinear API"""
 from typing import Union, Optional, Callable, Tuple, List
+import warnings
 import functools
 import torch
@@ -676,6 +677,10 @@ class GroupedLinear(TransformerEngineBaseModule):
         weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
         bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
         if not self.fp8:
+            warnings.warn(
+                "You are using quantized weights without quantized compute. "
+                "Please make sure this is intentional."
+            )
             weight_tensors = [
                 w.dequantize() if isinstance(w, QuantizedTensor) else w for w in weight_tensors
             ]
@@ -1394,6 +1394,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
                     "Splitting QuantizedTensor into multiple params is not supported"
                 )
             else:
+                warnings.warn(
+                    "You are using quantized weights without quantized compute. "
+                    "Please make sure this is intentional."
+                )
                 unfused_weights = [w.dequantize() for w in unfused_weights]
         weight_tensor = noop_cat(unfused_weights)
@@ -6,6 +6,7 @@
 from typing import Callable, Dict, Optional, Tuple, Union
 from functools import reduce
 from operator import mul as multiply_op
+import warnings
 import functools
 import torch
@@ -1207,7 +1208,12 @@ class Linear(TransformerEngineBaseModule):
                     "Splitting QuantizedTensor into multiple params is not supported"
                 )
             else:
+                warnings.warn(
+                    "You are using quantized weights without quantized compute. "
+                    "Please make sure this is intentional."
+                )
                 unfused_weights = [w.dequantize() for w in unfused_weights]
         weight_tensor = noop_cat(unfused_weights)
         if self.use_bias:
             bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
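
Because the message is emitted through Python's standard warnings module, callers who mix FP8 weight storage with higher-precision compute on purpose can silence it with an ordinary filter; a sketch matching the message text added above:

import warnings

# Silence only this warning; "message" is a regex matched against the
# start of the warning text, so other warnings are unaffected.
warnings.filterwarnings(
    "ignore",
    message="You are using quantized weights without quantized compute",
)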