"docs/en/git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "6e8183092aca74e8c721035de9af0c37e9060578"
Unverified Commit f84647fd authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Remove Generic from SerializableObject typings (#4844)

parent bef663b8
...@@ -13,7 +13,7 @@ import sys ...@@ -13,7 +13,7 @@ import sys
import types import types
import warnings import warnings
from io import IOBase from io import IOBase
from typing import Any, Dict, List, Optional, Type, TypeVar, Union, cast, Generic from typing import Any, Dict, List, Optional, Type, TypeVar, Union, cast
import cloudpickle # use cloudpickle as backend for unserializable types and instances import cloudpickle # use cloudpickle as backend for unserializable types and instances
import json_tricks # use json_tricks as serializer backend import json_tricks # use json_tricks as serializer backend
...@@ -115,13 +115,13 @@ def is_wrapped_with_trace(cls_or_func: Any) -> bool: ...@@ -115,13 +115,13 @@ def is_wrapped_with_trace(cls_or_func: Any) -> bool:
) )
class SerializableObject(Generic[T], Traceable): class SerializableObject(Traceable): # should be (Generic[T], Traceable), but cloudpickle is unhappy with Generic.
""" """
Serializable object is a wrapper of existing python objects, that supports dump and load easily. Serializable object is a wrapper of existing python objects, that supports dump and load easily.
Stores a symbol ``s`` and a dict of arguments ``args``, and the object can be restored with ``s(**args)``. Stores a symbol ``s`` and a dict of arguments ``args``, and the object can be restored with ``s(**args)``.
""" """
def __init__(self, symbol: T, args: List[Any], kwargs: Dict[str, Any], call_super: bool = False): def __init__(self, symbol: Type, args: List[Any], kwargs: Dict[str, Any], call_super: bool = False):
# use dict to avoid conflicts with user's getattr and setattr # use dict to avoid conflicts with user's getattr and setattr
self.__dict__['_nni_symbol'] = symbol self.__dict__['_nni_symbol'] = symbol
self.__dict__['_nni_args'] = args self.__dict__['_nni_args'] = args
...@@ -135,19 +135,19 @@ class SerializableObject(Generic[T], Traceable): ...@@ -135,19 +135,19 @@ class SerializableObject(Generic[T], Traceable):
**{kw: _argument_processor(arg) for kw, arg in kwargs.items()} **{kw: _argument_processor(arg) for kw, arg in kwargs.items()}
) )
def trace_copy(self) -> Union[T, 'SerializableObject']: def trace_copy(self) -> 'SerializableObject':
return SerializableObject( return SerializableObject(
self.trace_symbol, self.trace_symbol,
[copy.copy(arg) for arg in self.trace_args], [copy.copy(arg) for arg in self.trace_args],
{k: copy.copy(v) for k, v in self.trace_kwargs.items()}, {k: copy.copy(v) for k, v in self.trace_kwargs.items()},
) )
def get(self) -> T: def get(self) -> Any:
if not self._get_nni_attr('call_super'): if not self._get_nni_attr('call_super'):
# Reinitialize # Reinitialize
return trace(self.trace_symbol)(*self.trace_args, **self.trace_kwargs) return trace(self.trace_symbol)(*self.trace_args, **self.trace_kwargs)
return cast(T, self) return self
@property @property
def trace_symbol(self) -> Any: def trace_symbol(self) -> Any:
......
...@@ -3,6 +3,7 @@ __pycache__ ...@@ -3,6 +3,7 @@ __pycache__
tuner_search_space.json tuner_search_space.json
tuner_result.txt tuner_result.txt
assessor_result.txt assessor_result.txt
serialize_result.txt
_generated_model.py _generated_model.py
_generated_model_*.py _generated_model_*.py
......
import sys
import torch.nn as nn
# sys.argv[1] == 0 -> dump
# sys.argv[1] == 1 -> load
import nni
from nni.retiarii import model_wrapper
@model_wrapper
class Net(nn.Module):
something = 1
import cloudpickle
if sys.argv[1] == '0':
cloudpickle.dump(Net, open('serialize_result.txt', 'wb'))
# nni.dump(Net, fp=open('serialize_result.txt', 'w'))
else:
obj = cloudpickle.load(open('serialize_result.txt', 'rb'))
# obj = nni.load(fp=open('serialize_result.txt'))
assert obj().something == 1
import math import math
import os
import pickle import pickle
import subprocess
import sys import sys
from pathlib import Path from pathlib import Path
import pytest import pytest
import nni import nni
import torch import torch
import torch.nn as nn
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torchvision import transforms from torchvision import transforms
from torchvision.datasets import MNIST from torchvision.datasets import MNIST
...@@ -375,3 +378,24 @@ def test_get(): ...@@ -375,3 +378,24 @@ def test_get():
obj2 = obj1.trace_copy() obj2 = obj1.trace_copy()
obj2.trace_kwargs['a'] = -1 obj2.trace_kwargs['a'] = -1
assert obj2.get().bar() == 0 assert obj2.get().bar() == 0
def test_model_wrapper_serialize():
from nni.retiarii import model_wrapper
@model_wrapper
class Model(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
model = Model(3)
dumped = nni.dump(model)
loaded = nni.load(dumped)
assert loaded.in_channels == 3
def test_model_wrapper_across_process():
main_file = os.path.join(os.path.dirname(__file__), 'imported', '_test_serializer_main.py')
subprocess.run([sys.executable, main_file, '0'], check=True)
subprocess.run([sys.executable, main_file, '1'], check=True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment