Commit 655512c1 authored by Tim Moon, committed by GitHub

[PyTorch] Inference mode disables initializing quantized weights with column-wise usage (#1847)



* Do not initialize quantized weights with column-wise usage in inference mode
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix bug in test
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use no-grad mode instead of inference mode in tests
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 8d4bdbc2
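
For context, a minimal sketch of the behavior this change targets, mirroring the test added below. It assumes a CUDA device with FP8 support and the delayed-scaling recipe; the attribute names (`_data`, `_transpose`) follow the checks in the new test rather than any public API.

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Initialize a quantized-weight Linear with gradients disabled. With this change,
# the FP8 weight should carry row-wise data only; the column-wise/transposed copy
# is consumed solely by backward-pass GEMMs, so it is skipped for inference.
with torch.no_grad():
    with te.fp8_model_init(enabled=True, recipe=recipe.DelayedScaling()):
        module = te.Linear(32, 32)

weight = module.weight               # a Float8Tensor when primary weights are FP8
assert weight._data is not None      # row-wise FP8 data is present
assert weight._transpose is None     # no column-wise (transpose) buffer allocated
```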
@@ -10,6 +10,7 @@ import torch
import pytest
import os
import transformer_engine.pytorch
from transformer_engine.pytorch.fp8 import (
    fp8_autocast,
    FP8GlobalStateManager,
@@ -38,9 +39,11 @@ from transformer_engine.pytorch.cpp_extensions import general_gemm
from transformer_engine.pytorch.module.base import get_workspace
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import (
-   Float8Quantizer,
    Float8CurrentScalingQuantizer,
+   Float8Quantizer,
    Float8Tensor,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
from transformer_engine.pytorch.tensor.utils import replace_raw_data
from transformer_engine.pytorch.distributed import checkpoint
from test_numerics import reset_rng_states, dtype_tols
@@ -1338,3 +1341,80 @@ def test_sanity_checkpointing_on_callables():
    # Assert that gradients are the same
    torch.testing.assert_close(grad_checkpoint, grad_standard)


@pytest.mark.parametrize(
    "module_name",
    ("Linear", "LayerNormLinear", "LayerNormMLP", "GroupedLinear", "ops.Linear"),
)
@pytest.mark.parametrize(
    "quantization",
    (None, "fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"),
)
def test_inference_mode(
    module_name: str,
    quantization: Optional[str],
) -> None:
    """Test heuristics for initializing quantized weights"""

    # Tensor dimensions
    sequence_length = 32
    hidden_size = 32

    # Skip invalid configurations
    if quantization in ("fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization == "mxfp8" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)

    # Construct quantization recipe
    with_quantization = quantization not in (None, "None")
    quantization_recipe = None
    if quantization == "fp8_delayed_scaling":
        quantization_recipe = recipe.DelayedScaling()
    elif quantization == "fp8_current_scaling":
        quantization_recipe = recipe.Float8CurrentScaling()
    elif quantization == "mxfp8":
        quantization_recipe = recipe.MXFP8BlockScaling()

    # Construct module
    module = None
    with torch.no_grad():
        with fp8_model_init(enabled=with_quantization, recipe=quantization_recipe):
            if module_name == "Linear":
                module = Linear(hidden_size, hidden_size)
            elif module_name == "LayerNormLinear":
                module = LayerNormLinear(hidden_size, hidden_size)
            elif module_name == "LayerNormMLP":
                module = LayerNormMLP(hidden_size, hidden_size)
            elif module_name == "GroupedLinear":
                module = GroupedLinear(1, hidden_size, hidden_size)
            elif module_name == "ops.Linear":
                module = transformer_engine.pytorch.ops.Linear(hidden_size, hidden_size)

    def check_weights():
        """Helper function to check that weight parameters have expected data"""
        for param in module.parameters():
            if isinstance(param, Float8Tensor):
                assert param._data is not None, "Missing FP8 data"
                assert (
                    param._transpose is None and param._transpose_invalid
                ), "FP8 transpose is not expected for inference"
            if isinstance(param, MXFP8Tensor):
                assert param._rowwise_data is not None, "Missing row-wise MXFP8 data"
                assert (
                    param._columnwise_data is None
                ), "Column-wise MXFP8 data is not expected for inference"

    # Check that modules have expected weights after initialization
    check_weights()

    # Check that modules have expected weights after forward pass
    with torch.inference_mode():
        x = torch.zeros(sequence_length, hidden_size, device="cuda")
        kwargs = {}
        if module_name == "GroupedLinear":
            kwargs["m_splits"] = [sequence_length]
        with fp8_autocast(enabled=with_quantization, fp8_recipe=quantization_recipe):
            y = module(x, **kwargs)
    check_weights()
@@ -1183,18 +1183,23 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
with get_rng_state_tracker().fork():
    init_fn(param)

-# If primary weights are in fp8, wrap the parameter as FP8Tensor
+# Wrap parameters in QuantizedTensor if needed
fp8_meta_index = self.param_init_meta[name].fp8_meta_index
high_precision_init_val = None
if self.primary_weights_in_fp8 and fp8_meta_index is not None:
    # Keep high-precision values on CPU if needed
    if self.preserve_high_precision_init_val:
        high_precision_init_val = param.detach().cpu()

    # Configure quantizer
    quantizer = self.quantizers["scaling_fwd"][fp8_meta_index]
-   assert (
-       quantizer is not None
-   ) # to use primary fp8 weight one needs to use FP8 autocast with specific recipe.
+   if quantizer is None:
+       raise RuntimeError("Weight quantizer has not been initialized")
    quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled())
    quantizer.internal = False

    # Quantize parameter
    param = quantizer(param)

# Redo parameter wrap in case we broke it above
@@ -1202,6 +1207,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# re-applying the nn.Parameter() wrap is a no-op when the input is already
# a parameter so we always re-apply it just for extra safety.
param = torch.nn.Parameter(param)

# Keep high-precision values on CPU if needed
if high_precision_init_val is not None:
    # - Master weights are initialized from model weights, if we use fp8 primary
@@ -1245,7 +1252,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
    fsdp_group: Optional[dist_group_type] = None,
    workspace_dtype: Optional[torch.dtype] = None,
) -> QuantizedTensor:
-   """Get FP8 workspace buffer and maybe update its values
+   """Get workspace buffer for weights and maybe update its values

    The workspace buffer may be cached for future function calls.
@@ -1271,12 +1278,15 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        for debug quantization, this is dtype of the tensor.
    """

-   # FP8 primary weights
+   # Handle case where weights are already quantized
+   # Note: Make sure weights have required usages, but do not
+   # destroy unnecessary usages since they may be used later.
    if isinstance(tensor, QuantizedTensor):
        if update_workspace and quantizer is not None:
+           update_rowwise_usage = True if quantizer.rowwise_usage else None
+           update_columnwise_usage = True if quantizer.columnwise_usage else None
            tensor.update_usage(
-               rowwise_usage=quantizer.rowwise_usage,
-               columnwise_usage=quantizer.columnwise_usage,
+               rowwise_usage=update_rowwise_usage,
+               columnwise_usage=update_columnwise_usage,
            )
        return tensor
......
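
A brief note on the tri-state arguments above (an inferred reading of the diff, not library documentation): `True` asks for a layout to be made available, while `None` leaves that layout untouched, so a column-wise copy created earlier (e.g., during training) is not destroyed just because the current quantizer does not need it. A self-contained sketch of that mapping, using a hypothetical helper name:

```python
from typing import Optional

def usage_update_args(rowwise_needed: bool, columnwise_needed: bool) -> dict[str, Optional[bool]]:
    """Map a quantizer's required usages onto update_usage() arguments.

    Assumed convention: True ensures the layout exists; None leaves it as-is,
    so data that other callers may still rely on is never freed here.
    """
    return {
        "rowwise_usage": True if rowwise_needed else None,
        "columnwise_usage": True if columnwise_needed else None,
    }

# Example: a row-wise-only quantizer leaves any cached column-wise copy alone.
assert usage_update_args(True, False) == {"rowwise_usage": True, "columnwise_usage": None}
```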
@@ -271,7 +271,7 @@ class _LayerNormLinear(torch.autograd.Function):
# Configure quantizer
if weight_quantizer is not None:
-   weight_quantizer.set_usage(rowwise=True, columnwise=True)
+   weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)

# Get quantized weight
update_workspace = is_first_microbatch is None or is_first_microbatch
......
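
The gating above follows PyTorch's grad mode because the column-wise (transposed) weight layout is only consumed by backward-pass GEMMs; under `no_grad`/`inference_mode` it can be skipped entirely. A toy illustration of the gating with a stand-in quantizer class (hypothetical, for illustration only, not the module's real call site):

```python
import torch
from dataclasses import dataclass

@dataclass
class _ToyQuantizer:
    """Stand-in that only records which usages were requested."""
    rowwise_usage: bool = True
    columnwise_usage: bool = True

    def set_usage(self, rowwise: bool, columnwise: bool) -> None:
        self.rowwise_usage = rowwise
        self.columnwise_usage = columnwise

quantizer = _ToyQuantizer()

# Request the transposed layout only if a backward pass can actually happen.
with torch.no_grad():
    quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled())

assert quantizer.columnwise_usage is False  # skipped when gradients are disabled
```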
@@ -325,8 +325,8 @@ class _LayerNormMLP(torch.autograd.Function):
# which handles weight caching etc.
# FP8 cast to workspace buffer
update_workspace = is_first_microbatch is None or is_first_microbatch
-fc1_weight_quantizer.set_usage(rowwise=True, columnwise=True)
-fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True)
+fc1_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
+fc2_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
fc1_weight_final = module.get_weight_workspace(
    tensor=fc1_weight,
    quantizer=fc1_weight_quantizer,
@@ -1762,9 +1762,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc2_bias = self.fc2_bias if self.use_bias else None
if not self.fp8:
    if isinstance(fc1_weight, Float8Tensor):
-       fc1_weight = fc1_weight.from_float8()
+       fc1_weight = fc1_weight.dequantize()
    if isinstance(fc2_weight, Float8Tensor):
-       fc2_weight = fc2_weight.from_float8()
+       fc2_weight = fc2_weight.dequantize()

# Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode
if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute():
......
@@ -384,6 +384,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
# Quantize to FP8
assert self._quantizer is not None, "Can't quantize without a quantizer"
self._quantizer.internal = False
self.data = self._quantizer.quantize(tensor)
if self.requires_grad != tensor.requires_grad:
    self.requires_grad_(requires_grad=tensor.requires_grad)
......