Unverified Commit c0a539c6 authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

[PyTorch] Fix ONNX export bug with operation-based API (#1320)



Debug ONNX export with te.Sequential

ONNX export assumes that all state dict objects are tensor, even extra state.
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
parent 943f1e0a
...@@ -505,7 +505,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): ...@@ -505,7 +505,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
basic_op_kwargs=[kwargs], basic_op_kwargs=[kwargs],
) )
def get_extra_state(self) -> Optional[torch.Tensor]: def get_extra_state(self) -> torch.Tensor:
"""Serialize extra state """Serialize extra state
Contains metadata for FP8 casting. Contains metadata for FP8 casting.
...@@ -534,7 +534,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): ...@@ -534,7 +534,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
self.num_fp8_scales(mode) > 0 for mode in ("input", "param", "grad_output") self.num_fp8_scales(mode) > 0 for mode in ("input", "param", "grad_output")
) )
if not has_fp8_state: if not has_fp8_state:
return None return torch.Tensor()
def to_cpu(src: torch.Tensor) -> torch.Tensor: def to_cpu(src: torch.Tensor) -> torch.Tensor:
"""Helper function to make CPU copy of tensor """Helper function to make CPU copy of tensor
...@@ -588,7 +588,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): ...@@ -588,7 +588,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
def set_extra_state(self, state: Optional[torch.Tensor]) -> None: def set_extra_state(self, state: Optional[torch.Tensor]) -> None:
"""Load extra state""" """Load extra state"""
if state is None: if state is None or state.numel() == 0:
return return
# Deserialize state from byte tensor # Deserialize state from byte tensor
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment