Unverified Commit 115a27ef authored by Paweł Gadziński, committed by GitHub

[PyTorch] Fixed bug with loading calibrated weights (#771)



* Calibration fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Lint fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Pawel Gadzinski <pgadzinski@nvidia.com>
parent 06539514
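For context on the bug being fixed: the affected checkpoints come from calibration, where a model runs in high precision while FP8 statistics are collected. A minimal sketch of that flow, assuming the public `transformer_engine.pytorch` API (the layer sizes and iteration count here are illustrative, not from this commit):

```python
import io
import torch
import transformer_engine.pytorch as te

# Build a TE module with high-precision (non-FP8) primary weights.
model = te.Linear(768, 768, params_dtype=torch.float32, device="cuda")

# Calibration: FP8 execution is disabled, but amax history and scaling
# factors are still recorded in the module's fp8_meta.
with torch.no_grad(), te.fp8_autocast(enabled=False, calibrating=True):
    for _ in range(8):
        model(torch.randn(32, 768, device="cuda"))

# The checkpoint now holds high-precision weight tensors (no _scale_inv)
# plus the FP8 scaling factors stored in the module's extra state.
byte_stream = io.BytesIO()
torch.save(model.state_dict(), byte_stream)
```

Loading such a checkpoint into a module whose weights live in FP8 is exactly the path this commit repairs.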
@@ -17,3 +17,4 @@ NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_a
 pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py
 NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
 pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
+pytest -v -s $TE_PATH/tests/pytorch/test_torch_save_load.py
\ No newline at end of file
@@ -66,6 +66,9 @@ def test_export_loaded_checkpoint(scale_fwd, scale_bwd, history_fwd, history_bwd
         self.weights_type = tex.DType.kFloat8E4M3
         self.outp_type = precision
 
+    def get_fp8_weights_scratchpad(self, is_first_microbatch):
+        raise RuntimeError("Method get_fp8_weights_scratchpad is dummy and should not be invoked.")
+
     def forward(self, inp, weight):
         inp_fp8 = cast_to_fp8(
             inp,
@@ -145,14 +148,11 @@ def test_fp8_model_checkpoint(
         params_dtype=dtype,
         device=device,
     )
 
     # Keep track of model output
     x = torch.randn(dims, dtype=dtype, device=device)
     with te.fp8_autocast():
         y_ref = model(x.detach().clone()).detach().clone()
 
-    # Keep track of weights and FP8 scaling factors
-    weight_ref = model.weight.float().detach().clone()
-
     fp8_meta_ref = { "scaling_fwd": {}, "scaling_bwd": {} }
     with te.fp8_autocast(), torch.no_grad():
         fp8_meta_fwd = model.fp8_meta["scaling_fwd"]
@@ -169,6 +169,18 @@ def test_fp8_model_checkpoint(
         fp8_meta_bwd.scale_inv.copy_(fp8_meta_bwd_ref["scale_inv"])
     del fp8_meta_fwd, fp8_meta_bwd
 
+    # [ This is part of the logic that tests save_fp8_model=False and load_fp8_model=True ]
+    # The following copy propagates the FP8 scale_inv from the model metadata into
+    # the weight's FP8 tensor; copying the weight onto itself is simply the easiest
+    # way to set the tensor's scale_inv. The two values must match, so setting
+    # scale_inv only in the model metadata is not enough.
+    model.weight.data.copy_(model.weight.float().cuda())
+    # The copy recomputes the metadata scale_inv from the amax history,
+    # so the reference values are restored afterwards.
+    model.fp8_meta["scaling_fwd"].scale = fp8_meta_fwd_ref["scale"]
+    model.fp8_meta["scaling_fwd"].scale_inv = fp8_meta_fwd_ref["scale_inv"]
+
+    # Keep track of weights and FP8 scaling factors
+    weight_ref = model.weight.float().detach().clone()
+
     # Save checkpoint
     byte_stream = io.BytesIO()
     torch.save(model.state_dict(), byte_stream)
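The invariant this block establishes, that the weight tensor's own `_scale_inv` agrees with the `scale_inv` slot in the module's FP8 metadata, is what the test asserts at the end. A hypothetical helper capturing it (`weight_scale_inv_consistent` is not part of the test; it mirrors the assertion added further down):

```python
def weight_scale_inv_consistent(module) -> bool:
    """Check that the FP8 weight's scale_inv matches the module's fp8_meta."""
    meta_index = module.weight._fp8_meta_index
    tensor_val = module.weight._scale_inv.item()
    meta_val = module.fp8_meta["scaling_fwd"].scale_inv[meta_index].item()
    return tensor_val == meta_val
```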
@@ -214,6 +226,18 @@ def test_fp8_model_checkpoint(
         with pytest.raises(AssertionError):
             torch.testing.assert_close(y, y_ref, **tols)
 
+    # [ This is part of the logic that tests save_fp8_model=False and load_fp8_model=True ]
+    # With save_fp8_model=False, the checkpoint holds the weights in high precision,
+    # so they do not carry _scale_inv, while the FP8 scaling factor is still present
+    # in the metadata. This scenario occurs, for example, when calibrating with
+    # te.fp8_autocast(enabled=False, calibrating=True).
+    #
+    # In that case the default behavior of load_state_dict is incorrect: it loads
+    # the tensors first and the FP8 metadata afterwards, which leaves the tensor
+    # with a wrong _scale_inv. This is corrected by overriding _load_from_state_dict
+    # in TransformerEngineBaseModule so that the FP8 metadata is loaded before the
+    # tensors.
+
     # Load checkpoint
     model.load_state_dict(torch.load(io.BytesIO(model_bytes)))
     del model_bytes
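The ordering problem described above is generic PyTorch behavior, not TE-specific: `Module._load_from_state_dict` copies parameters before it calls `set_extra_state`. A self-contained toy that makes the default order visible (no TE involved; `LoggingParam` and `Demo` are made up for illustration):

```python
import torch

events = []

class LoggingParam(torch.nn.Parameter):
    """Parameter subclass that records when load_state_dict copies into it."""
    def copy_(self, src, non_blocking=False):
        events.append("param copy_")
        return super().copy_(src)

class Demo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = LoggingParam(torch.zeros(1))

    def get_extra_state(self):
        return {"fp8_meta": "stand-in"}

    def set_extra_state(self, state):
        events.append("set_extra_state")

Demo().load_state_dict(Demo().state_dict())
print(events)  # ['param copy_', 'set_extra_state'] -- tensors first, metadata last
```

The override below flips this order for FP8 primary weights by loading the extra state first.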
@@ -232,3 +256,10 @@ def test_fp8_model_checkpoint(
     with te.fp8_autocast():
         y = model(x.detach().clone())
     torch.testing.assert_close(y, y_ref, **tols)
+
+    if load_fp8_model:
+        # [ This is part of the logic that tests save_fp8_model=False and load_fp8_model=True ]
+        # The tensor's scale_inv must match the value stored in its metadata;
+        # otherwise it would be ambiguous which of the two values is correct.
+        meta_index = model.weight._fp8_meta_index
+        torch.testing.assert_close(model.weight._scale_inv.item(), fp8_meta_fwd_ref["scale_inv"][meta_index].item())
\ No newline at end of file
@@ -858,3 +858,22 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         is_first_microbatch: Union[bool, None],
     ) -> List[torch.Tensor]:
         """Needs override."""
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        """
+        Load tensors and extra state, which includes the FP8 metadata.
+        The metadata is essential when copying into FP8 tensors, because the
+        copy_ function uses the scale_inv from fp8_meta to set the correct
+        scaling factor for the new tensor.
+        The extra state must therefore be loaded before the tensors are
+        copied, not after, as the default _load_from_state_dict would do.
+        Tensors are copied into FP8 tensors only when
+        self.primary_weights_in_fp8 is True; otherwise this early load is
+        not needed.
+        """
+        if self.primary_weights_in_fp8:
+            extra_state_key = prefix + torch.nn.modules.module._EXTRA_STATE_KEY_SUFFIX
+            if extra_state_key in state_dict:
+                self.set_extra_state(state_dict[extra_state_key])
+        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
+                                      missing_keys, unexpected_keys, error_msgs)
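With the override in place, the end-to-end round trip looks roughly like this. A sketch only: it assumes `te.fp8_model_init` (the TE context manager for FP8 primary weights, available in recent releases) and reuses the hypothetical `byte_stream` checkpoint from the calibration sketch above.

```python
import io
import torch
import transformer_engine.pytorch as te

# Destination module whose primary weights are FP8 tensors, so
# load_state_dict goes through the copy_ path that consumes fp8_meta.
with te.fp8_model_init(enabled=True):
    model = te.Linear(768, 768, device="cuda")

# Thanks to the override, set_extra_state runs first, so the subsequent
# weight copy quantizes with the checkpoint's calibrated scaling factors.
model.load_state_dict(torch.load(io.BytesIO(byte_stream.getvalue())))
```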