Unverified Commit 201de5f7 authored by Peter St. John's avatar Peter St. John Committed by GitHub
Browse files

Use an empty torch tensor to indicate no fp8 information in extra_state (#1799)



* Use an empty torch tensor to indicate no fp8 information in extra_state
Signed-off-by: Peter St. John <pstjohn@nvidia.com>

* Add huggingface from_pretrained / save_pretrained tests

Adds integration tests to ensure models containing TransformerLayer
objects can be saved and loaded using the from_pretrained and
save_pretrained methods.
Signed-off-by: Peter St. John <pstjohn@nvidia.com>

---------
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 3baaf3ff
...@@ -44,6 +44,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entro ...@@ -44,6 +44,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entro
NVTE_FLASH_ATTN=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" NVTE_FLASH_ATTN=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/fused_attn/test_kv_cache.py || test_fail "test_kv_cache.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/fused_attn/test_kv_cache.py || test_fail "test_kv_cache.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
if [ "$RET" -ne 0 ]; then if [ "$RET" -ne 0 ]; then
echo "Error in the following test cases:$FAILED_CASES" echo "Error in the following test cases:$FAILED_CASES"
......
...@@ -123,7 +123,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: ...@@ -123,7 +123,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
) )
# Blackwell is not supported as of Triton 3.2.0, need custom internal build # Blackwell is not supported as of Triton 3.2.0, need custom internal build
# install_reqs.append("triton") # install_reqs.append("triton")
test_reqs.extend(["numpy", "torchvision"]) test_reqs.extend(["numpy", "torchvision", "transformers"])
if "jax" in frameworks: if "jax" in frameworks:
setup_reqs.extend(["jax[cuda12]", "flax>=0.7.1"]) setup_reqs.extend(["jax[cuda12]", "flax>=0.7.1"])
install_reqs.extend(["jax", "flax>=0.7.1"]) install_reqs.extend(["jax", "flax>=0.7.1"])
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
from transformer_engine.pytorch.transformer import TransformerLayer
from transformer_engine.pytorch.utils import is_bf16_compatible
class SimpleTEModel(PreTrainedModel):
    """Minimal HuggingFace model wrapping a single Transformer Engine layer.

    Exists solely so the tests below can exercise ``save_pretrained`` /
    ``from_pretrained`` round-trips on a model containing a
    ``TransformerLayer``.
    """

    config_class = PretrainedConfig

    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        # Small fixed geometry keeps construction cheap for tests.
        layer_kwargs = dict(
            hidden_size=320,
            num_attention_heads=16,
            ffn_hidden_size=1024,
            layer_number=None,
        )
        self.my_layer = TransformerLayer(**layer_kwargs)

    def forward(self, hidden_states, attention_mask):
        """Apply the wrapped TransformerLayer to the inputs."""
        return self.my_layer(hidden_states, attention_mask)
def test_save_hf_model(tmp_path):
    """save_pretrained on a TE-containing model must complete without error."""
    save_dir = tmp_path / "simple_te_model"
    SimpleTEModel(PretrainedConfig()).save_pretrained(save_dir)
@pytest.mark.xfail(reason="This test is failing until huggingface/transformers#38155 is merged.")
def test_save_and_load_hf_model(tmp_path):
    """Round-trip a TE-containing model through save_pretrained/from_pretrained."""
    save_dir = tmp_path / "simple_te_model"
    original = SimpleTEModel(PretrainedConfig())
    original.save_pretrained(save_dir)
    # Drop the in-memory model so the load below cannot alias it.
    del original

    restored = SimpleTEModel.from_pretrained(save_dir)
    assert restored is not None
...@@ -731,7 +731,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -731,7 +731,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
reset("scaling_fwd") reset("scaling_fwd")
reset("scaling_bwd") reset("scaling_bwd")
def get_extra_state(self) -> Optional[torch.Tensor]: def get_extra_state(self) -> torch.Tensor:
"""Save before checkpointing.""" """Save before checkpointing."""
# This implementation is working around a few issues: # This implementation is working around a few issues:
...@@ -766,7 +766,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -766,7 +766,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
state = None state = None
fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration
if not fp8_checkpoint: if not fp8_checkpoint:
return None return torch.empty(0, dtype=torch.uint8)
# Copy tensors to CPU and store # Copy tensors to CPU and store
state = {} state = {}
...@@ -792,13 +792,13 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -792,13 +792,13 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8) state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8)
return state_serialized return state_serialized
def set_extra_state(self, state: Optional[torch.Tensor]) -> None: def set_extra_state(self, state: torch.Tensor) -> None:
"""Load previous state.""" """Load previous state."""
if state is None:
return
# Load state # Load state
if isinstance(state, torch.Tensor): if isinstance(state, torch.Tensor):
# No FP8 is indicated by an empty tensor we don't need to unpickle.
if state.numel() == 0:
return
# Default format: byte tensor with pickled data # Default format: byte tensor with pickled data
state = pickle.loads(state.detach().cpu().numpy().tobytes()) state = pickle.loads(state.detach().cpu().numpy().tobytes())
elif isinstance(state, io.BytesIO): elif isinstance(state, io.BytesIO):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment