Unverified Commit c0a539c6 authored by Tim Moon, committed by GitHub
Browse files

[PyTorch] Fix ONNX export bug with operation-based API (#1320)



Debug ONNX export with te.Sequential

ONNX export assumes that all state dict objects are tensors, even extra state.
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent 943f1e0a
......@@ -505,7 +505,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
basic_op_kwargs=[kwargs],
)
def get_extra_state(self) -> Optional[torch.Tensor]:
def get_extra_state(self) -> torch.Tensor:
"""Serialize extra state
Contains metadata for FP8 casting.
......@@ -534,7 +534,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
self.num_fp8_scales(mode) > 0 for mode in ("input", "param", "grad_output")
)
if not has_fp8_state:
return None
return torch.Tensor()
def to_cpu(src: torch.Tensor) -> torch.Tensor:
"""Helper function to make CPU copy of tensor
......@@ -588,7 +588,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
def set_extra_state(self, state: Optional[torch.Tensor]) -> None:
"""Load extra state"""
if state is None:
if state is None or state.numel() == 0:
return
# Deserialize state from byte tensor
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment