"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "bbddcb92896f604a478c9e94ab697c71d838638f"
Unverified Commit 201de5f7 authored by Peter St. John's avatar Peter St. John Committed by GitHub
Browse files

Use an empty torch tensor to indicate no fp8 information in extra_state (#1799)



* Use an empty torch tensor to indicate no fp8 information in extra_state
Signed-off-by: default avatarPeter St. John <pstjohn@nvidia.com>

* Add huggingface from_pretrained / save_pretrained tests

Adds integration tests to ensure models containing TransformerLayer
objects can be saved and loaded using the from_pretrained and
save_pretrained methods.
Signed-off-by: default avatarPeter St. John <pstjohn@nvidia.com>

---------
Signed-off-by: default avatarPeter St. John <pstjohn@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 3baaf3ff
......@@ -44,6 +44,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entro
NVTE_FLASH_ATTN=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/fused_attn/test_kv_cache.py || test_fail "test_kv_cache.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
if [ "$RET" -ne 0 ]; then
echo "Error in the following test cases:$FAILED_CASES"
......
......@@ -123,7 +123,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
)
# Blackwell is not supported as of Triton 3.2.0, need custom internal build
# install_reqs.append("triton")
test_reqs.extend(["numpy", "torchvision"])
test_reqs.extend(["numpy", "torchvision", "transformers"])
if "jax" in frameworks:
setup_reqs.extend(["jax[cuda12]", "flax>=0.7.1"])
install_reqs.extend(["jax", "flax>=0.7.1"])
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
from transformer_engine.pytorch.transformer import TransformerLayer
from transformer_engine.pytorch.utils import is_bf16_compatible
class SimpleTEModel(PreTrainedModel):
    """Minimal HuggingFace ``PreTrainedModel`` wrapping one TransformerEngine layer.

    Used to exercise save_pretrained / from_pretrained round-trips for models
    that contain ``TransformerLayer`` modules.
    """

    config_class = PretrainedConfig

    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        # Fixed, arbitrary dimensions — only need to be self-consistent for
        # a save/load round-trip.
        self.my_layer = TransformerLayer(
            hidden_size=320,
            ffn_hidden_size=1024,
            num_attention_heads=16,
            layer_number=None,
        )

    def forward(self, hidden_states, attention_mask):
        """Apply the wrapped TransformerLayer to *hidden_states*."""
        output = self.my_layer(hidden_states, attention_mask)
        return output
def test_save_hf_model(tmp_path):
    """Saving a model that contains TE layers via ``save_pretrained`` must not raise."""
    save_dir = tmp_path / "simple_te_model"
    SimpleTEModel(PretrainedConfig()).save_pretrained(save_dir)
@pytest.mark.xfail(reason="This test is failing until huggingface/transformers#38155 is merged.")
def test_save_and_load_hf_model(tmp_path):
    """Round-trip a TE-containing model through ``save_pretrained``/``from_pretrained``."""
    save_dir = tmp_path / "simple_te_model"
    original = SimpleTEModel(PretrainedConfig())
    original.save_pretrained(save_dir)
    # Drop the in-memory model so the reload below cannot alias it.
    del original
    reloaded = SimpleTEModel.from_pretrained(save_dir)
    assert reloaded is not None
......@@ -731,7 +731,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
reset("scaling_fwd")
reset("scaling_bwd")
def get_extra_state(self) -> Optional[torch.Tensor]:
def get_extra_state(self) -> torch.Tensor:
"""Save before checkpointing."""
# This implementation is working around a few issues:
......@@ -766,7 +766,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
state = None
fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration
if not fp8_checkpoint:
return None
return torch.empty(0, dtype=torch.uint8)
# Copy tensors to CPU and store
state = {}
......@@ -792,13 +792,13 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8)
return state_serialized
def set_extra_state(self, state: Optional[torch.Tensor]) -> None:
def set_extra_state(self, state: torch.Tensor) -> None:
"""Load previous state."""
if state is None:
return
# Load state
if isinstance(state, torch.Tensor):
# No FP8 is indicated by an empty tensor we don't need to unpickle.
if state.numel() == 0:
return
# Default format: byte tensor with pickled data
state = pickle.loads(state.detach().cpu().numpy().tobytes())
elif isinstance(state, io.BytesIO):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment