"git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "1b20f2d67f02ad6eca3a75c7477e8fdedb05dc58"
Unverified Commit d5d78333 authored by Evgeny Tsykunov's avatar Evgeny Tsykunov Committed by GitHub
Browse files

Quantizer update when recipe was changed (#1814)



* Quantizer update
Signed-off-by: default avatarEvgeny Tsykunov <etsykunov@etsykunov-mlt.client.nvidia.com>

* Update import
Signed-off-by: default avatarEvgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Introduce _update_weight_quantizers and _get_weight_tensors/_get_weight_quantizers
Signed-off-by: default avatarEvgeny <etsykunov@nvidia.com>

* Add test
Signed-off-by: default avatarEvgeny <etsykunov@nvidia.com>

* Move _quantizer to the QuantizedTensorBase
Signed-off-by: default avatarEvgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix import
Signed-off-by: default avatarEvgeny Tsykunov <etsykunov@etsykunov-mlt.client.nvidia.com>

---------
Signed-off-by: default avatarEvgeny Tsykunov <etsykunov@etsykunov-mlt.client.nvidia.com>
Signed-off-by: default avatarEvgeny <etsykunov@nvidia.com>
Co-authored-by: default avatarEvgeny Tsykunov <etsykunov@etsykunov-mlt.client.nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarPrzemyslaw Tredak <ptredak@nvidia.com>
parent 204add8f
......@@ -20,7 +20,7 @@ from transformer_engine.pytorch.fp8 import (
)
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch import Linear
from transformer_engine.pytorch import Linear, LayerNormLinear, LayerNormMLP, GroupedLinear
from transformer_engine.pytorch.distributed import fp8_autocast
from transformer_engine.common.recipe import DelayedScaling, Float8BlockScaling, MXFP8BlockScaling
import transformer_engine_torch as tex
......@@ -470,3 +470,34 @@ class TestFP8Recipe:
# Final check
for quantizer in linear.quantizers["scaling_fwd"]:
assert isinstance(quantizer, expected_quantizer_type)
@pytest.mark.parametrize(
    "module_class",
    [
        Linear,
        LayerNormLinear,
        LayerNormMLP,
        GroupedLinear,
    ],
)
def test_quantizer_update(self, module_class):
    """Changing the FP8 recipe between model init and forward must update the
    weight quantizers and warn that model behavior may be affected."""
    in_features = 32
    out_features = 32
    batch_size = 32

    # Build the module (and its quantized weights) under one recipe...
    init_recipe = DelayedScaling(amax_history_len=1024)
    with fp8_model_init(recipe=init_recipe):
        if module_class == GroupedLinear:
            # GroupedLinear takes the number of GEMMs as its first argument.
            module = module_class(1, in_features, out_features).cuda()
        else:
            module = module_class(in_features, out_features).cuda()

    x = torch.randn(batch_size, in_features, device="cuda")

    # ...then run the forward pass under a different recipe, which should
    # trigger the quantizer-update warning emitted by update_quantizer().
    forward_recipe = DelayedScaling(amax_history_len=1)
    with fp8_autocast(enabled=True, fp8_recipe=forward_recipe):
        warn_msg = "Quantizer is being updated, this may affect model behavior"
        with pytest.warns(UserWarning, match=warn_msg):
            if module_class == GroupedLinear:
                module(x, [batch_size])
            else:
                module(x)
......@@ -639,6 +639,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# Update quantizers with new amax pointers.
self.quantizers[meta_key] = self.fp8_meta[meta_key].make_quantizers()
# Make sure weight tensors have correct quantizers
self._update_weight_quantizers()
# Update the global buffers with new amax and history pointers.
if FP8GlobalStateManager.get_buffer_info() in self.fp8_meta:
......@@ -692,6 +694,30 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.fp8_meta[fp8_meta_tensor_key] = recipe_state
self.quantizers[fp8_meta_tensor_key] = recipe_state.make_quantizers()
def _update_weight_quantizers(self) -> None:
    """Point every quantized weight tensor at this module's current quantizer."""
    tensors = self._get_weight_tensors()
    quantizers = self._get_weight_quantizers()
    assert len(tensors) == len(quantizers), (
        f"Number of weight tensors ({len(tensors)}) and quantizers "
        f"({len(quantizers)}) must match"
    )
    for tensor, quantizer in zip(tensors, quantizers):
        if quantizer is None:
            continue
        if isinstance(tensor, QuantizedTensorBase):
            tensor.update_quantizer(quantizer)
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
    """Return the module's weight tensors; concrete modules must override this."""
    message = (
        f"{self.__class__.__name__} class does not implement _get_weight_tensors function"
    )
    raise NotImplementedError(message)
def _get_weight_quantizers(self) -> List[Quantizer]:
    """Return the module's weight quantizers; concrete modules must override this."""
    message = (
        f"{self.__class__.__name__} class does not implement _get_weight_quantizers function"
    )
    raise NotImplementedError(message)
def init_fp8_meta_tensors(self, recipe: Recipe) -> None:
"""Init scales and amaxes."""
self.set_meta_tensor(True, recipe)
......
......@@ -673,21 +673,11 @@ class GroupedLinear(TransformerEngineBaseModule):
is_first_microbatch = False
with self.prepare_forward(inp, num_gemms=self.num_gemms) as inp:
weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
weight_tensors = self._get_weight_tensors()
bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
if not self.fp8 and any(isinstance(w, QuantizedTensorBase) for w in weight_tensors):
warnings.warn(
"You are using quantized weights without quantized compute. "
"Please make sure this is intentional."
)
weight_tensors = [
w.dequantize() if isinstance(w, QuantizedTensorBase) else w
for w in weight_tensors
]
input_quantizers, weight_quantizers, output_quantizers = (
[None] * self.num_gemms,
weight_quantizers = self._get_weight_quantizers()
input_quantizers, output_quantizers = (
[None] * self.num_gemms,
[None] * self.num_gemms,
)
......@@ -702,14 +692,6 @@ class GroupedLinear(TransformerEngineBaseModule):
# TODO: use internal after #1638 is merged. # pylint: disable=fixme
for i in range(self.num_gemms):
input_quantizers[i].internal = False
weight_quantizers = [
self.quantizers["scaling_fwd"][
self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"]
]
for i in range(self.num_gemms)
]
for i in range(self.num_gemms):
weight_quantizers[i].internal = True
if torch.is_grad_enabled():
grad_output_quantizers = [
self.quantizers["scaling_bwd"][
......@@ -808,3 +790,30 @@ class GroupedLinear(TransformerEngineBaseModule):
self.quantizers["scaling_bwd"][
self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"]
].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
    """Collect the per-GEMM weight tensors, dequantizing them when FP8 compute is off."""
    weights = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
    if self.fp8:
        return weights
    # Quantized weights with non-quantized compute: warn, then hand back
    # dequantized copies so the GEMMs run in the original precision.
    if any(isinstance(w, QuantizedTensorBase) for w in weights):
        warnings.warn(
            "You are using quantized weights without quantized compute. "
            "Please make sure this is intentional."
        )
        weights = [
            w.dequantize() if isinstance(w, QuantizedTensorBase) else w
            for w in weights
        ]
    return weights
def _get_weight_quantizers(self) -> List[Quantizer]:
    """Return one forward weight quantizer per GEMM (all ``None`` without FP8)."""
    if not self.fp8:
        return [None] * self.num_gemms
    # Weight quantizers are laid out with a fixed stride per GEMM in the
    # forward scaling metadata.
    base = self._offsets["weight"]
    stride = self._num_fp8_tensors_per_gemm["fwd"]
    quantizers = []
    for i in range(self.num_gemms):
        quantizer = self.quantizers["scaling_fwd"][base + i * stride]
        quantizer.internal = True
        quantizers.append(quantizer)
    return quantizers
......@@ -5,7 +5,7 @@
"""LayerNormLinear API"""
import os
import warnings
from typing import Callable, Dict, Optional, Tuple, Union
from typing import Callable, Dict, Optional, Tuple, Union, List
from functools import reduce
from operator import mul as multiply_op
......@@ -1471,20 +1471,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
) as inp:
# Get concatenated weight and bias tensors
unfused_weights = [getattr(self, name) for name in self.weight_names]
if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
if self.fp8:
if len(unfused_weights) != 1:
raise RuntimeError(
"Splitting QuantizedTensor into multiple params is not supported"
)
else:
warnings.warn(
"You are using quantized weights without quantized compute. "
"Please make sure this is intentional."
)
unfused_weights = [w.dequantize() for w in unfused_weights]
unfused_weights = self._get_weight_tensors()
weight_tensor = noop_cat(unfused_weights)
if self.use_bias:
bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
......@@ -1590,8 +1577,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
output_quantizer = None
input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
input_quantizer.internal = True
weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
weight_quantizer.internal = True
(weight_quantizer,) = self._get_weight_quantizers()
if fp8_output:
output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
if torch.is_grad_enabled():
......@@ -1666,3 +1652,28 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
].amax_reduction_group = self.tp_group
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
    """Fetch the (possibly quantized) weight parameters of this module."""
    weights = [getattr(self, name) for name in self.weight_names]
    if not any(isinstance(w, QuantizedTensor) for w in weights):
        return weights
    if self.fp8:
        # A single QuantizedTensor cannot be split across several params.
        if len(weights) != 1:
            raise RuntimeError(
                "Splitting QuantizedTensor into multiple params is not supported"
            )
        return weights
    # Quantized weights but non-quantized compute: warn and dequantize.
    warnings.warn(
        "You are using quantized weights without quantized compute. "
        "Please make sure this is intentional."
    )
    return [w.dequantize() for w in weights]
def _get_weight_quantizers(self) -> List[Quantizer]:
    """Return the single forward weight quantizer (``[None]`` without FP8)."""
    if not self.fp8:
        return [None]
    quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
    quantizer.internal = True
    return [quantizer]
......@@ -5,7 +5,7 @@
"""LayerNormMLP API"""
import os
import warnings
from typing import Callable, Optional, Tuple, Union
from typing import Callable, Optional, Tuple, Union, List
from functools import reduce
from operator import mul as multiply_op
......@@ -1754,9 +1754,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
) = quantizers
# Get weight tensors
fc1_weight = self.fc1_weight
fc1_weight, fc2_weight = self._get_weight_tensors()
fc1_bias = self.fc1_bias if self.use_bias else None
fc2_weight = self.fc2_weight
fc2_bias = self.fc2_bias if self.use_bias else None
if not self.fp8:
if isinstance(fc1_weight, Float8Tensor):
......@@ -1847,31 +1846,26 @@ class LayerNormMLP(TransformerEngineBaseModule):
def _get_quantizers(self, fp8_output):
(
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
) = [None] * 12
) = [None] * 10
fc1_weight_quantizer, fc2_weight_quantizer = self._get_weight_quantizers()
if self.fp8:
fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
fc1_input_quantizer.internal = True
fc1_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
fc1_weight_quantizer.internal = True
fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
fc2_input_quantizer.set_usage(
rowwise=True,
columnwise=isinstance(fc2_input_quantizer, (MXFP8Quantizer, Float8BlockQuantizer)),
)
fc1_input_quantizer.internal = True
fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT]
fc2_weight_quantizer.internal = True
if fp8_output:
fc2_output_quantizer = self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM2_OUTPUT
......@@ -1988,6 +1982,20 @@ class LayerNormMLP(TransformerEngineBaseModule):
tex.FP8BwdTensors.GRAD_OUTPUT2
].amax_reduction_group = self.tp_group
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
    """Return the FC1 and FC2 weight tensors, in forward order."""
    return [getattr(self, name) for name in ("fc1_weight", "fc2_weight")]
def _get_weight_quantizers(self) -> List[Quantizer]:
    """Return forward weight quantizers for FC1 and FC2 (``[None, None]`` without FP8)."""
    if not self.fp8:
        return [None, None]
    quantizers = [
        self.quantizers["scaling_fwd"][key]
        for key in (tex.FP8FwdTensors.GEMM1_WEIGHT, tex.FP8FwdTensors.GEMM2_WEIGHT)
    ]
    for quantizer in quantizers:
        quantizer.internal = True
    return quantizers
def backward_dw(self):
"""
Execute the delayed weight gradient computation.
......
......@@ -3,7 +3,7 @@
# See LICENSE for license information.
"""Linear API"""
from typing import Callable, Dict, Optional, Tuple, Union
from typing import Callable, Dict, Optional, Tuple, Union, List
from functools import reduce
from operator import mul as multiply_op
import warnings
......@@ -1290,20 +1290,7 @@ class Linear(TransformerEngineBaseModule):
) as inp:
# Get concatenated weight and bias tensors
unfused_weights = [getattr(self, name) for name in self.weight_names]
if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
if self.fp8:
if len(unfused_weights) != 1:
raise RuntimeError(
"Splitting QuantizedTensor into multiple params is not supported"
)
else:
warnings.warn(
"You are using quantized weights without quantized compute. "
"Please make sure this is intentional."
)
unfused_weights = [w.dequantize() for w in unfused_weights]
unfused_weights = self._get_weight_tensors()
weight_tensor = noop_cat(unfused_weights)
if self.use_bias:
bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
......@@ -1333,12 +1320,6 @@ class Linear(TransformerEngineBaseModule):
grad_output_quantizer,
) = quantizers
# Make sure weight tensor has correct quantizer
# Note: Quantizer might have changed if quantization
# recipe changed
if weight_quantizer is not None and isinstance(weight_tensor, QuantizedTensor):
weight_tensor._quantizer = weight_quantizer
if torch.is_grad_enabled():
linear_fn = _Linear.apply
args = []
......@@ -1399,8 +1380,7 @@ class Linear(TransformerEngineBaseModule):
output_quantizer = None
input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
input_quantizer.internal = True
weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
weight_quantizer.internal = True
(weight_quantizer,) = self._get_weight_quantizers()
if fp8_output:
output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
if torch.is_grad_enabled():
......@@ -1474,3 +1454,28 @@ class Linear(TransformerEngineBaseModule):
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
].amax_reduction_group = self.tp_group
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
    """Fetch the (possibly quantized) weight parameters of this module."""
    weights = [getattr(self, name) for name in self.weight_names]
    if not any(isinstance(w, QuantizedTensor) for w in weights):
        return weights
    if self.fp8:
        # A single QuantizedTensor cannot be split across several params.
        if len(weights) != 1:
            raise RuntimeError(
                "Splitting QuantizedTensor into multiple params is not supported"
            )
        return weights
    # Quantized weights but non-quantized compute: warn and dequantize.
    warnings.warn(
        "You are using quantized weights without quantized compute. "
        "Please make sure this is intentional."
    )
    return [w.dequantize() for w in weights]
def _get_weight_quantizers(self) -> List[Quantizer]:
    """Return the single forward weight quantizer (``[None]`` without FP8)."""
    if not self.fp8:
        return [None]
    quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
    quantizer.internal = True
    return [quantizer]
......@@ -8,6 +8,7 @@ from __future__ import annotations
from typing import Optional, Tuple, Iterable, Any, Dict, Union
import abc
import copy
import warnings
import torch
from torch.utils._pytree import tree_map
......@@ -32,6 +33,8 @@ class QuantizedTensorBase:
XTensor should only implement the functionality needed
to behave like regular torch.Tensor (like __torch_dispatch__)."""
_quantizer: Optional[Quantizer]
def update_usage(
self,
rowwise_usage: Optional[bool] = None,
......@@ -70,6 +73,14 @@ class QuantizedTensorBase:
f"{self.__class__.__name__} class does not implement restore_from_saved function"
)
def update_quantizer(self, quantizer: Quantizer):
    """Replace this tensor's quantizer, warning when it actually changes."""
    current = self._quantizer
    if current is None:
        raise RuntimeError("To be updated, quantizer must be set")
    if current is quantizer:
        # Already pointing at the requested quantizer; nothing to do.
        return
    warnings.warn("Quantizer is being updated, this may affect model behavior")
    self._quantizer = quantizer
def prepare_for_saving(
*tensors: Union[torch.Tensor, QuantizedTensorBase],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment