Unverified Commit f84647fd authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Remove Generic from SerializableObject typings (#4844)

parent bef663b8
......@@ -13,7 +13,7 @@ import sys
import types
import warnings
from io import IOBase
from typing import Any, Dict, List, Optional, Type, TypeVar, Union, cast, Generic
from typing import Any, Dict, List, Optional, Type, TypeVar, Union, cast
import cloudpickle # use cloudpickle as backend for unserializable types and instances
import json_tricks # use json_tricks as serializer backend
......@@ -115,13 +115,13 @@ def is_wrapped_with_trace(cls_or_func: Any) -> bool:
)
class SerializableObject(Generic[T], Traceable):
class SerializableObject(Traceable): # should be (Generic[T], Traceable), but cloudpickle is unhappy with Generic.
"""
Serializable object is a wrapper of existing python objects, that supports dump and load easily.
Stores a symbol ``s`` and a dict of arguments ``args``, and the object can be restored with ``s(**args)``.
"""
def __init__(self, symbol: T, args: List[Any], kwargs: Dict[str, Any], call_super: bool = False):
def __init__(self, symbol: Type, args: List[Any], kwargs: Dict[str, Any], call_super: bool = False):
# use dict to avoid conflicts with user's getattr and setattr
self.__dict__['_nni_symbol'] = symbol
self.__dict__['_nni_args'] = args
......@@ -135,19 +135,19 @@ class SerializableObject(Generic[T], Traceable):
**{kw: _argument_processor(arg) for kw, arg in kwargs.items()}
)
def trace_copy(self) -> Union[T, 'SerializableObject']:
def trace_copy(self) -> 'SerializableObject':
return SerializableObject(
self.trace_symbol,
[copy.copy(arg) for arg in self.trace_args],
{k: copy.copy(v) for k, v in self.trace_kwargs.items()},
)
def get(self) -> T:
def get(self) -> Any:
if not self._get_nni_attr('call_super'):
# Reinitialize
return trace(self.trace_symbol)(*self.trace_args, **self.trace_kwargs)
return cast(T, self)
return self
@property
def trace_symbol(self) -> Any:
......
......@@ -3,6 +3,7 @@ __pycache__
tuner_search_space.json
tuner_result.txt
assessor_result.txt
serialize_result.txt
_generated_model.py
_generated_model_*.py
......
import sys
import torch.nn as nn
# sys.argv[1] == 0 -> dump
# sys.argv[1] == 1 -> load
import nni
from nni.retiarii import model_wrapper
@model_wrapper
class Net(nn.Module):
something = 1
import cloudpickle
if sys.argv[1] == '0':
cloudpickle.dump(Net, open('serialize_result.txt', 'wb'))
# nni.dump(Net, fp=open('serialize_result.txt', 'w'))
else:
obj = cloudpickle.load(open('serialize_result.txt', 'rb'))
# obj = nni.load(fp=open('serialize_result.txt'))
assert obj().something == 1
import math
import os
import pickle
import subprocess
import sys
from pathlib import Path
import pytest
import nni
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
......@@ -375,3 +378,24 @@ def test_get():
obj2 = obj1.trace_copy()
obj2.trace_kwargs['a'] = -1
assert obj2.get().bar() == 0
def test_model_wrapper_serialize():
from nni.retiarii import model_wrapper
@model_wrapper
class Model(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
model = Model(3)
dumped = nni.dump(model)
loaded = nni.load(dumped)
assert loaded.in_channels == 3
def test_model_wrapper_across_process():
main_file = os.path.join(os.path.dirname(__file__), 'imported', '_test_serializer_main.py')
subprocess.run([sys.executable, main_file, '0'], check=True)
subprocess.run([sys.executable, main_file, '1'], check=True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment