Unverified commit 655512c1, authored by Tim Moon, committed by GitHub

[PyTorch] Inference mode disables initializing quantized weights with column-wise usage (#1847)



* Do not initialize quantized weights with column-wise usage in inference mode
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix bug in test
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use no-grad mode instead of inference mode in tests
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 8d4bdbc2
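
In short: when quantized weights are initialized under no-grad or inference mode, the column-wise (transposed) copy is skipped, since it is only consumed by backward-pass GEMMs. A minimal sketch of the resulting behavior, mirroring the test added below (assumes a CUDA device with FP8 support; the private attributes `_data` and `_transpose` are internal details, used here only for illustration):

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor

# Initialize a module with FP8 primary weights while grads are disabled.
with torch.no_grad():
    with te.fp8_model_init(enabled=True, recipe=recipe.DelayedScaling()):
        module = te.Linear(32, 32)

# The weight keeps its row-wise FP8 data, but no transpose is materialized.
for param in module.parameters():
    if isinstance(param, Float8Tensor):
        assert param._data is not None
        assert param._transpose is None
```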
@@ -10,6 +10,7 @@ import torch
 import pytest
 import os
+import transformer_engine.pytorch
 from transformer_engine.pytorch.fp8 import (
     fp8_autocast,
     FP8GlobalStateManager,
@@ -38,9 +39,11 @@ from transformer_engine.pytorch.cpp_extensions import general_gemm
 from transformer_engine.pytorch.module.base import get_workspace
 from transformer_engine.pytorch.tensor import QuantizedTensor
 from transformer_engine.pytorch.tensor.float8_tensor import (
-    Float8Quantizer,
     Float8CurrentScalingQuantizer,
+    Float8Quantizer,
+    Float8Tensor,
 )
+from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
 from transformer_engine.pytorch.tensor.utils import replace_raw_data
 from transformer_engine.pytorch.distributed import checkpoint
 from test_numerics import reset_rng_states, dtype_tols
@@ -1338,3 +1341,80 @@ def test_sanity_checkpointing_on_callables():
 
     # Assert that gradients are the same
     torch.testing.assert_close(grad_checkpoint, grad_standard)
+
+
+@pytest.mark.parametrize(
+    "module_name",
+    ("Linear", "LayerNormLinear", "LayerNormMLP", "GroupedLinear", "ops.Linear"),
+)
+@pytest.mark.parametrize(
+    "quantization",
+    (None, "fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"),
+)
+def test_inference_mode(
+    module_name: str,
+    quantization: Optional[str],
+) -> None:
+    """Test heuristics for initializing quantized weights"""
+
+    # Tensor dimensions
+    sequence_length = 32
+    hidden_size = 32
+
+    # Skip invalid configurations
+    if quantization in ("fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if quantization == "mxfp8" and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
+
+    # Construct quantization recipe
+    with_quantization = quantization not in (None, "None")
+    quantization_recipe = None
+    if quantization == "fp8_delayed_scaling":
+        quantization_recipe = recipe.DelayedScaling()
+    elif quantization == "fp8_current_scaling":
+        quantization_recipe = recipe.Float8CurrentScaling()
+    elif quantization == "mxfp8":
+        quantization_recipe = recipe.MXFP8BlockScaling()
+
+    # Construct module
+    module = None
+    with torch.no_grad():
+        with fp8_model_init(enabled=with_quantization, recipe=quantization_recipe):
+            if module_name == "Linear":
+                module = Linear(hidden_size, hidden_size)
+            elif module_name == "LayerNormLinear":
+                module = LayerNormLinear(hidden_size, hidden_size)
+            elif module_name == "LayerNormMLP":
+                module = LayerNormMLP(hidden_size, hidden_size)
+            elif module_name == "GroupedLinear":
+                module = GroupedLinear(1, hidden_size, hidden_size)
+            elif module_name == "ops.Linear":
+                module = transformer_engine.pytorch.ops.Linear(hidden_size, hidden_size)
+
+    def check_weights():
+        """Helper function to check that weight parameters have expected data"""
+        for param in module.parameters():
+            if isinstance(param, Float8Tensor):
+                assert param._data is not None, "Missing FP8 data"
+                assert (
+                    param._transpose is None and param._transpose_invalid
+                ), "FP8 transpose is not expected for inference"
+            if isinstance(param, MXFP8Tensor):
+                assert param._rowwise_data is not None, "Missing row-wise MXFP8 data"
+                assert (
+                    param._columnwise_data is None
+                ), "Column-wise MXFP8 data is not expected for inference"
+
+    # Check that modules have expected weights after initialization
+    check_weights()
+
+    # Check that modules have expected weights after forward pass
+    with torch.inference_mode():
+        x = torch.zeros(sequence_length, hidden_size, device="cuda")
+        kwargs = {}
+        if module_name == "GroupedLinear":
+            kwargs["m_splits"] = [sequence_length]
+        with fp8_autocast(enabled=with_quantization, fp8_recipe=quantization_recipe):
+            y = module(x, **kwargs)
+    check_weights()
@@ -1183,18 +1183,23 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                 with get_rng_state_tracker().fork():
                     init_fn(param)
 
-            # If primary weights are in fp8, wrap the parameter as FP8Tensor
+            # Wrap parameters in QuantizedTensor if needed
             fp8_meta_index = self.param_init_meta[name].fp8_meta_index
             high_precision_init_val = None
             if self.primary_weights_in_fp8 and fp8_meta_index is not None:
+
+                # Keep high-precision values on CPU if needed
                 if self.preserve_high_precision_init_val:
                     high_precision_init_val = param.detach().cpu()
+
+                # Configure quantizer
                 quantizer = self.quantizers["scaling_fwd"][fp8_meta_index]
-                assert (
-                    quantizer is not None
-                )  # to use primary fp8 weight one needs to use FP8 autocast with specific recipe.
+                if quantizer is None:
+                    raise RuntimeError("Weight quantizer has not been initialized")
+                quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled())
                 quantizer.internal = False
+
+                # Quantize parameter
                 param = quantizer(param)
 
             # Redo parameter wrap in case we broke it above
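
The `columnwise=torch.is_grad_enabled()` heuristic works because module construction under `torch.no_grad()` or `torch.inference_mode()` reports grads as disabled. A quick standalone check of that assumption:

```python
import torch

assert torch.is_grad_enabled()          # default: grads enabled
with torch.no_grad():
    assert not torch.is_grad_enabled()  # no-grad mode disables grads
with torch.inference_mode():
    assert not torch.is_grad_enabled()  # inference mode does too
```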
@@ -1202,6 +1207,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             # re-applying the nn.Parameter() wrap is a no-op when the input is already
             # a parameter so we always re-apply it just for extra safety.
             param = torch.nn.Parameter(param)
+
+            # Keep high-precision values on CPU if needed
             if high_precision_init_val is not None:
                 # - Master weights are initialized from model weights, if we use fp8 primary
@@ -1245,7 +1252,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         fsdp_group: Optional[dist_group_type] = None,
         workspace_dtype: Optional[torch.dtype] = None,
     ) -> QuantizedTensor:
-        """Get FP8 workspace buffer and maybe update its values
+        """Get workspace buffer for weights and maybe update its values
 
         The workspace buffer may be cached for future function calls.
@@ -1271,12 +1278,15 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             for debug quantization, this is dtype of the tensor.
         """
 
-        # FP8 primary weights
+        # Handle case where weights are already quantized
+        # Note: Make sure weights have required usages, but do not
+        # destroy unnecessary usages since they may be used later.
         if isinstance(tensor, QuantizedTensor):
-            if update_workspace and quantizer is not None:
-                tensor.update_usage(
-                    rowwise_usage=quantizer.rowwise_usage,
-                    columnwise_usage=quantizer.columnwise_usage,
-                )
+            update_rowwise_usage = True if quantizer.rowwise_usage else None
+            update_columnwise_usage = True if quantizer.columnwise_usage else None
+            tensor.update_usage(
+                rowwise_usage=update_rowwise_usage,
+                columnwise_usage=update_columnwise_usage,
+            )
             return tensor
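The `True`/`None` distinction above matters: on a hedged reading of the new comment, `update_usage` treats `None` as "leave this usage alone," so a representation the quantizer does not currently need is preserved rather than destroyed. A small illustrative helper (hypothetical name, not part of the diff):

```python
from typing import Dict, Optional

def usage_updates(rowwise_needed: bool, columnwise_needed: bool) -> Dict[str, Optional[bool]]:
    """Hypothetical helper: map required usages to update_usage() arguments.
    True forces a representation to exist; None leaves existing data untouched."""
    return {
        "rowwise_usage": True if rowwise_needed else None,
        "columnwise_usage": True if columnwise_needed else None,
    }

# e.g. tensor.update_usage(**usage_updates(quantizer.rowwise_usage,
#                                          quantizer.columnwise_usage))
```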
@@ -271,7 +271,7 @@ class _LayerNormLinear(torch.autograd.Function):
 
         # Configure quantizer
         if weight_quantizer is not None:
-            weight_quantizer.set_usage(rowwise=True, columnwise=True)
+            weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
 
         # Get quantized weight
         update_workspace = is_first_microbatch is None or is_first_microbatch
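For context on `columnwise=is_grad_enabled`: the forward GEMM consumes the row-wise weight layout, while only the backward data-gradient GEMM needs the column-wise (transposed) layout, so tying the usage flag to autograd state skips that extra copy during inference. A schematic illustration in plain PyTorch (illustrative only; the layout requirement comes from FP8 GEMMs, which this float sketch does not capture):

```python
import torch

x = torch.randn(32, 64, requires_grad=True)  # activations
w = torch.randn(128, 64)                     # weight, "row-wise" layout

y = x @ w.t()       # forward GEMM consumes one orientation of the weight
y.sum().backward()  # dgrad computes grad_y @ w: the other orientation
```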
@@ -325,8 +325,8 @@ class _LayerNormMLP(torch.autograd.Function):
         # which handles weight caching etc.
         # FP8 cast to workspace buffer
         update_workspace = is_first_microbatch is None or is_first_microbatch
-        fc1_weight_quantizer.set_usage(rowwise=True, columnwise=True)
-        fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True)
+        fc1_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
+        fc2_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
         fc1_weight_final = module.get_weight_workspace(
             tensor=fc1_weight,
             quantizer=fc1_weight_quantizer,
@@ -1762,9 +1762,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
         fc2_bias = self.fc2_bias if self.use_bias else None
         if not self.fp8:
             if isinstance(fc1_weight, Float8Tensor):
-                fc1_weight = fc1_weight.from_float8()
+                fc1_weight = fc1_weight.dequantize()
             if isinstance(fc2_weight, Float8Tensor):
-                fc2_weight = fc2_weight.from_float8()
+                fc2_weight = fc2_weight.dequantize()
 
         # Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode
         if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute():
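On the `from_float8()` to `dequantize()` change above: `dequantize()` is the generic `QuantizedTensor` method, so the non-FP8 path no longer relies on a `Float8Tensor`-specific helper. A hedged sketch of the pattern (hypothetical wrapper function, for illustration only):

```python
from transformer_engine.pytorch.tensor import QuantizedTensor

def to_high_precision(weight):
    """Hypothetical helper: dequantize a weight when running without FP8.
    dequantize() works for any QuantizedTensor subclass, not just Float8Tensor."""
    if isinstance(weight, QuantizedTensor):
        return weight.dequantize()
    return weight
```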
@@ -384,6 +384,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
 
         # Quantize to FP8
         assert self._quantizer is not None, "Can't quantize without a quantizer"
+        self._quantizer.internal = False
         self.data = self._quantizer.quantize(tensor)
         if self.requires_grad != tensor.requires_grad:
             self.requires_grad_(requires_grad=tensor.requires_grad)