"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "ce2e8bd12edfe10647bec8f54fedc394d6287b58"
Unverified commit d5d78333, authored by Evgeny Tsykunov, committed by GitHub

Quantizer update when recipe was changed (#1814)



* Quantizer update

Signed-off-by: Evgeny Tsykunov <etsykunov@etsykunov-mlt.client.nvidia.com>

* Update import

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Introduce _update_weight_quantizers and _get_weight_tensors/_get_weight_quantizers

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Add test

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Move _quantizer to the QuantizedTensorBase

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix import

Signed-off-by: Evgeny Tsykunov <etsykunov@etsykunov-mlt.client.nvidia.com>

---------

Signed-off-by: Evgeny Tsykunov <etsykunov@etsykunov-mlt.client.nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Co-authored-by: Evgeny Tsykunov <etsykunov@etsykunov-mlt.client.nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Przemyslaw Tredak <ptredak@nvidia.com>
parent 204add8f
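For context before the diff: the scenario this commit addresses, distilled from the new test below, is a module whose weights were quantized under one recipe being run under another. A minimal sketch of that scenario (assumes a CUDA device with FP8 support; `fp8_model_init` and `fp8_autocast` are the same public entry points the test uses):

```python
import torch
from transformer_engine.pytorch import Linear, fp8_model_init, fp8_autocast
from transformer_engine.common.recipe import DelayedScaling

# Initialize the module under one recipe: its weights are stored as
# quantized tensors carrying a quantizer built from this recipe.
with fp8_model_init(recipe=DelayedScaling(amax_history_len=1024)):
    module = Linear(32, 32).cuda()

x = torch.randn(32, 32, device="cuda")

# Run under a different recipe: with this commit, the module refreshes the
# quantizer stored on each quantized weight and emits a UserWarning
# ("Quantizer is being updated, this may affect model behavior").
with fp8_autocast(enabled=True, fp8_recipe=DelayedScaling(amax_history_len=1)):
    y = module(x)
```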
@@ -20,7 +20,7 @@ from transformer_engine.pytorch.fp8 import (
 )
 from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
 import transformer_engine.pytorch.ops as te_ops
-from transformer_engine.pytorch import Linear
+from transformer_engine.pytorch import Linear, LayerNormLinear, LayerNormMLP, GroupedLinear
 from transformer_engine.pytorch.distributed import fp8_autocast
 from transformer_engine.common.recipe import DelayedScaling, Float8BlockScaling, MXFP8BlockScaling
 import transformer_engine_torch as tex
@@ -470,3 +470,34 @@ class TestFP8Recipe:
         # Final check
         for quantizer in linear.quantizers["scaling_fwd"]:
             assert isinstance(quantizer, expected_quantizer_type)
+
+    @pytest.mark.parametrize(
+        "module_class",
+        [
+            Linear,
+            LayerNormLinear,
+            LayerNormMLP,
+            GroupedLinear,
+        ],
+    )
+    def test_quantizer_update(self, module_class):
+        in_features = 32
+        out_features = 32
+        batch_size = 32
+
+        recipe = DelayedScaling(amax_history_len=1024)
+        with fp8_model_init(recipe=recipe):
+            if module_class == GroupedLinear:
+                module = module_class(1, in_features, out_features).cuda()
+            else:
+                module = module_class(in_features, out_features).cuda()
+
+        x = torch.randn(batch_size, in_features, device="cuda")
+
+        recipe = DelayedScaling(amax_history_len=1)
+        with fp8_autocast(enabled=True, fp8_recipe=recipe):
+            warn_msg = "Quantizer is being updated, this may affect model behavior"
+            with pytest.warns(UserWarning, match=warn_msg):
+                if module_class == GroupedLinear:
+                    y = module(x, [batch_size])
+                else:
+                    y = module(x)
@@ -639,6 +639,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             # Update quantizers with new amax pointers.
             self.quantizers[meta_key] = self.fp8_meta[meta_key].make_quantizers()
+        # Make sure weight tensors have correct quantizers
+        self._update_weight_quantizers()
 
         # Update the global buffers with new amax and history pointers.
         if FP8GlobalStateManager.get_buffer_info() in self.fp8_meta:
@@ -692,6 +694,30 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         self.fp8_meta[fp8_meta_tensor_key] = recipe_state
         self.quantizers[fp8_meta_tensor_key] = recipe_state.make_quantizers()
 
+    def _update_weight_quantizers(self) -> None:
+        """Update the quantizers for the weight tensors."""
+        weight_tensors = self._get_weight_tensors()
+        weight_quantizers = self._get_weight_quantizers()
+        assert len(weight_tensors) == len(weight_quantizers), (
+            f"Number of weight tensors ({len(weight_tensors)}) and quantizers "
+            f"({len(weight_quantizers)}) must match"
+        )
+        for weight, quantizer in zip(weight_tensors, weight_quantizers):
+            if quantizer is not None and isinstance(weight, QuantizedTensorBase):
+                weight.update_quantizer(quantizer)
+
+    def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
+        """Get the weight tensors of the module."""
+        raise NotImplementedError(
+            f"{self.__class__.__name__} class does not implement _get_weight_tensors function"
+        )
+
+    def _get_weight_quantizers(self) -> List[Quantizer]:
+        """Get the weight quantizers of the module."""
+        raise NotImplementedError(
+            f"{self.__class__.__name__} class does not implement _get_weight_quantizers function"
+        )
+
     def init_fp8_meta_tensors(self, recipe: Recipe) -> None:
         """Init scales and amaxes."""
         self.set_meta_tensor(True, recipe)
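The base-class hooks above form a small template method: `_update_weight_quantizers` pairs tensors with quantizers, and each module only enumerates its own weights. A hypothetical sketch of a subclass override (`MyLinear` and its single-weight layout are illustrative assumptions, not part of this commit; the real overrides follow in the per-module diffs below):

```python
import transformer_engine_torch as tex
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule

class MyLinear(TransformerEngineBaseModule):
    """Hypothetical single-GEMM module with one quantized weight."""

    def _get_weight_tensors(self):
        # One entry per weight the module quantizes, in a fixed order.
        return [self.weight]

    def _get_weight_quantizers(self):
        # Must align one-to-one with _get_weight_tensors();
        # None entries are skipped by _update_weight_quantizers.
        if not self.fp8:
            return [None]
        return [self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]]
```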
@@ -673,21 +673,11 @@ class GroupedLinear(TransformerEngineBaseModule):
             is_first_microbatch = False
 
         with self.prepare_forward(inp, num_gemms=self.num_gemms) as inp:
-            weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
+            weight_tensors = self._get_weight_tensors()
             bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
-            if not self.fp8 and any(isinstance(w, QuantizedTensorBase) for w in weight_tensors):
-                warnings.warn(
-                    "You are using quantized weights without quantized compute. "
-                    "Please make sure this is intentional."
-                )
-                weight_tensors = [
-                    w.dequantize() if isinstance(w, QuantizedTensorBase) else w
-                    for w in weight_tensors
-                ]
 
-            input_quantizers, weight_quantizers, output_quantizers = (
-                [None] * self.num_gemms,
+            weight_quantizers = self._get_weight_quantizers()
+            input_quantizers, output_quantizers = (
                 [None] * self.num_gemms,
                 [None] * self.num_gemms,
             )
@@ -702,14 +692,6 @@ class GroupedLinear(TransformerEngineBaseModule):
             # TODO: use internal after #1638 is merged. # pylint: disable=fixme
             for i in range(self.num_gemms):
                 input_quantizers[i].internal = False
-            weight_quantizers = [
-                self.quantizers["scaling_fwd"][
-                    self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"]
-                ]
-                for i in range(self.num_gemms)
-            ]
-            for i in range(self.num_gemms):
-                weight_quantizers[i].internal = True
             if torch.is_grad_enabled():
                 grad_output_quantizers = [
                     self.quantizers["scaling_bwd"][
@@ -808,3 +790,30 @@ class GroupedLinear(TransformerEngineBaseModule):
                 self.quantizers["scaling_bwd"][
                     self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"]
                 ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
+
+    def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
+        """Get the weight tensors of the module."""
+        weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
+        if not self.fp8 and any(isinstance(w, QuantizedTensorBase) for w in weight_tensors):
+            warnings.warn(
+                "You are using quantized weights without quantized compute. "
+                "Please make sure this is intentional."
+            )
+            weight_tensors = [
+                w.dequantize() if isinstance(w, QuantizedTensorBase) else w for w in weight_tensors
+            ]
+        return weight_tensors
+
+    def _get_weight_quantizers(self) -> List[Quantizer]:
+        """Get the weight quantizers of the module."""
+        if not self.fp8:
+            return [None] * self.num_gemms
+        weight_quantizers = [
+            self.quantizers["scaling_fwd"][
+                self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"]
+            ]
+            for i in range(self.num_gemms)
+        ]
+        for i in range(self.num_gemms):
+            weight_quantizers[i].internal = True
+        return weight_quantizers
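As a side note, the `scaling_fwd` indexing in `_get_weight_quantizers` packs each GEMM's forward FP8 tensors into a contiguous block. A toy illustration of the index arithmetic (the concrete offset and stride values here are assumptions for illustration, based on the usual input/weight/output layout, not an API guarantee):

```python
# Assumed layout: each GEMM contributes a block of 3 forward FP8 tensors
# (input, weight, output), so the weight quantizer of GEMM i lives at
#   offsets["weight"] + i * num_fp8_tensors_per_gemm_fwd
offsets_weight = 1                # weight's position inside a per-GEMM block (assumed)
num_fp8_tensors_per_gemm_fwd = 3  # block stride (assumed)
num_gemms = 4

weight_indices = [offsets_weight + i * num_fp8_tensors_per_gemm_fwd for i in range(num_gemms)]
print(weight_indices)  # [1, 4, 7, 10] -> one weight quantizer per GEMM
```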
@@ -5,7 +5,7 @@
 """LayerNormLinear API"""
 import os
 import warnings
-from typing import Callable, Dict, Optional, Tuple, Union
+from typing import Callable, Dict, Optional, Tuple, Union, List
 from functools import reduce
 from operator import mul as multiply_op
@@ -1471,20 +1471,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
         ) as inp:
 
             # Get concatenated weight and bias tensors
-            unfused_weights = [getattr(self, name) for name in self.weight_names]
-            if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
-                if self.fp8:
-                    if len(unfused_weights) != 1:
-                        raise RuntimeError(
-                            "Splitting QuantizedTensor into multiple params is not supported"
-                        )
-                else:
-                    warnings.warn(
-                        "You are using quantized weights without quantized compute. "
-                        "Please make sure this is intentional."
-                    )
-                    unfused_weights = [w.dequantize() for w in unfused_weights]
+            unfused_weights = self._get_weight_tensors()
             weight_tensor = noop_cat(unfused_weights)
             if self.use_bias:
                 bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
@@ -1590,8 +1577,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
         output_quantizer = None
         input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
         input_quantizer.internal = True
-        weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
-        weight_quantizer.internal = True
+        (weight_quantizer,) = self._get_weight_quantizers()
         if fp8_output:
             output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
         if torch.is_grad_enabled():
@@ -1666,3 +1652,28 @@ class LayerNormLinear(TransformerEngineBaseModule):
                 self.quantizers["scaling_bwd"][
                     tex.FP8BwdTensors.GRAD_OUTPUT1
                 ].amax_reduction_group = self.tp_group
+
+    def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
+        """Get the weight tensors of the module."""
+        unfused_weights = [getattr(self, name) for name in self.weight_names]
+        if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
+            if self.fp8:
+                if len(unfused_weights) != 1:
+                    raise RuntimeError(
+                        "Splitting QuantizedTensor into multiple params is not supported"
+                    )
+            else:
+                warnings.warn(
+                    "You are using quantized weights without quantized compute. "
+                    "Please make sure this is intentional."
+                )
+                unfused_weights = [w.dequantize() for w in unfused_weights]
+        return unfused_weights
+
+    def _get_weight_quantizers(self) -> List[Quantizer]:
+        """Get the weight quantizers of the module."""
+        if not self.fp8:
+            return [None]
+        weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
+        weight_quantizer.internal = True
+        return [weight_quantizer]
@@ -5,7 +5,7 @@
 """LayerNormMLP API"""
 import os
 import warnings
-from typing import Callable, Optional, Tuple, Union
+from typing import Callable, Optional, Tuple, Union, List
 from functools import reduce
 from operator import mul as multiply_op
@@ -1754,9 +1754,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
         ) = quantizers
 
         # Get weight tensors
-        fc1_weight = self.fc1_weight
+        fc1_weight, fc2_weight = self._get_weight_tensors()
         fc1_bias = self.fc1_bias if self.use_bias else None
-        fc2_weight = self.fc2_weight
         fc2_bias = self.fc2_bias if self.use_bias else None
         if not self.fp8:
             if isinstance(fc1_weight, Float8Tensor):
@@ -1847,31 +1846,26 @@ class LayerNormMLP(TransformerEngineBaseModule):
     def _get_quantizers(self, fp8_output):
         (
             fc1_input_quantizer,
-            fc1_weight_quantizer,
             fc1_output_quantizer,
             fc1_grad_input_quantizer,
             fc1_grad_weight_quantizer,
             fc1_grad_output_quantizer,
             fc2_input_quantizer,
-            fc2_weight_quantizer,
             fc2_output_quantizer,
             fc2_grad_input_quantizer,
             fc2_grad_weight_quantizer,
             fc2_grad_output_quantizer,
-        ) = [None] * 12
+        ) = [None] * 10
+        fc1_weight_quantizer, fc2_weight_quantizer = self._get_weight_quantizers()
         if self.fp8:
             fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
             fc1_input_quantizer.internal = True
-            fc1_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
-            fc1_weight_quantizer.internal = True
             fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
             fc2_input_quantizer.set_usage(
                 rowwise=True,
                 columnwise=isinstance(fc2_input_quantizer, (MXFP8Quantizer, Float8BlockQuantizer)),
             )
             fc1_input_quantizer.internal = True
-            fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT]
-            fc2_weight_quantizer.internal = True
             if fp8_output:
                 fc2_output_quantizer = self.quantizers["scaling_fwd"][
                     tex.FP8FwdTensors.GEMM2_OUTPUT
@@ -1988,6 +1982,20 @@ class LayerNormMLP(TransformerEngineBaseModule):
                 tex.FP8BwdTensors.GRAD_OUTPUT2
             ].amax_reduction_group = self.tp_group
 
+    def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
+        """Get the weight tensors of the module."""
+        return [self.fc1_weight, self.fc2_weight]
+
+    def _get_weight_quantizers(self) -> List[Quantizer]:
+        """Get the weight quantizers of the module."""
+        if not self.fp8:
+            return [None, None]
+        fc1_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
+        fc1_weight_quantizer.internal = True
+        fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT]
+        fc2_weight_quantizer.internal = True
+        return [fc1_weight_quantizer, fc2_weight_quantizer]
+
     def backward_dw(self):
         """
         Execute the delayed weight gradient computation.
@@ -3,7 +3,7 @@
 # See LICENSE for license information.
 
 """Linear API"""
-from typing import Callable, Dict, Optional, Tuple, Union
+from typing import Callable, Dict, Optional, Tuple, Union, List
 from functools import reduce
 from operator import mul as multiply_op
 import warnings
@@ -1290,20 +1290,7 @@ class Linear(TransformerEngineBaseModule):
         ) as inp:
 
             # Get concatenated weight and bias tensors
-            unfused_weights = [getattr(self, name) for name in self.weight_names]
-            if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
-                if self.fp8:
-                    if len(unfused_weights) != 1:
-                        raise RuntimeError(
-                            "Splitting QuantizedTensor into multiple params is not supported"
-                        )
-                else:
-                    warnings.warn(
-                        "You are using quantized weights without quantized compute. "
-                        "Please make sure this is intentional."
-                    )
-                    unfused_weights = [w.dequantize() for w in unfused_weights]
+            unfused_weights = self._get_weight_tensors()
             weight_tensor = noop_cat(unfused_weights)
             if self.use_bias:
                 bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
@@ -1333,12 +1320,6 @@ class Linear(TransformerEngineBaseModule):
             grad_output_quantizer,
         ) = quantizers
 
-        # Make sure weight tensor has correct quantizer
-        # Note: Quantizer might have changed if quantization
-        # recipe changed
-        if weight_quantizer is not None and isinstance(weight_tensor, QuantizedTensor):
-            weight_tensor._quantizer = weight_quantizer
-
         if torch.is_grad_enabled():
             linear_fn = _Linear.apply
             args = []
@@ -1399,8 +1380,7 @@ class Linear(TransformerEngineBaseModule):
         output_quantizer = None
         input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
         input_quantizer.internal = True
-        weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
-        weight_quantizer.internal = True
+        (weight_quantizer,) = self._get_weight_quantizers()
         if fp8_output:
             output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
         if torch.is_grad_enabled():
@@ -1474,3 +1454,28 @@ class Linear(TransformerEngineBaseModule):
                 self.quantizers["scaling_bwd"][
                     tex.FP8BwdTensors.GRAD_OUTPUT1
                 ].amax_reduction_group = self.tp_group
+
+    def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
+        """Get the weight tensors of the module."""
+        unfused_weights = [getattr(self, name) for name in self.weight_names]
+        if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
+            if self.fp8:
+                if len(unfused_weights) != 1:
+                    raise RuntimeError(
+                        "Splitting QuantizedTensor into multiple params is not supported"
+                    )
+            else:
+                warnings.warn(
+                    "You are using quantized weights without quantized compute. "
+                    "Please make sure this is intentional."
+                )
+                unfused_weights = [w.dequantize() for w in unfused_weights]
+        return unfused_weights
+
+    def _get_weight_quantizers(self) -> List[Quantizer]:
+        """Get the weight quantizers of the module."""
+        if not self.fp8:
+            return [None]
+        weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
+        weight_quantizer.internal = True
+        return [weight_quantizer]
@@ -8,6 +8,7 @@ from __future__ import annotations
 from typing import Optional, Tuple, Iterable, Any, Dict, Union
 import abc
 import copy
+import warnings
 
 import torch
 from torch.utils._pytree import tree_map
@@ -32,6 +33,8 @@ class QuantizedTensorBase:
     XTensor should only implement the functionality needed
     to behave like regular torch.Tensor (like __torch_dispatch__)."""
 
+    _quantizer: Optional[Quantizer]
+
     def update_usage(
         self,
         rowwise_usage: Optional[bool] = None,
@@ -70,6 +73,14 @@ class QuantizedTensorBase:
             f"{self.__class__.__name__} class does not implement restore_from_saved function"
         )
 
+    def update_quantizer(self, quantizer: Quantizer):
+        """Update quantizer for the tensor"""
+        if self._quantizer is None:
+            raise RuntimeError("To be updated, quantizer must be set")
+        if self._quantizer is not quantizer:
+            warnings.warn("Quantizer is being updated, this may affect model behavior")
+        self._quantizer = quantizer
+
     def prepare_for_saving(
         *tensors: Union[torch.Tensor, QuantizedTensorBase],
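To summarize the contract of the new `update_quantizer` method: passing the same quantizer object is a silent no-op, a different object warns and replaces, and a tensor with no quantizer raises. A small sketch (`qweight` stands for any `QuantizedTensorBase` subclass instance, e.g. a Float8 weight, and `new_quantizer` for a quantizer built from the new recipe; both are placeholders):

```python
import warnings

def show_update(qweight, new_quantizer):
    """Demonstrate update_quantizer semantics on a quantized tensor."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        qweight.update_quantizer(new_quantizer)  # raises if qweight._quantizer is None
    if caught:
        # Emitted only when the quantizer object actually changed.
        print(caught[0].message)  # "Quantizer is being updated, this may affect model behavior"
    assert qweight._quantizer is new_quantizer
```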