Unverified Commit 0a57438b authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Fix unintended modification of trace symbol after dump (#4993)

parent baf60758
......@@ -538,6 +538,7 @@ def _trace_cls(base, kw_only, call_super=True, inheritable=False):
# Pickle can't handle type objects.
if '_nni_symbol' in obj_:
obj_ = dict(obj_) # copy the object to keep the original symbol unchanged
obj_['_nni_symbol'] = cloudpickle.dumps(obj_['_nni_symbol'])
return _pickling_object, (type_, kw_only, obj_)
......
......@@ -371,6 +371,23 @@ def test_subclass():
assert isinstance(obj, Super)
class ConsistencyTest1:
pass
class ConsistencyTest2:
def __init__(self):
self.test = nni.trace(ConsistencyTest1)()
def test_dump_consistency():
test2 = ConsistencyTest2()
symbol1 = test2.test.trace_symbol
pickle.dumps(test2)
symbol2 = test2.test.trace_symbol
assert symbol1 == symbol2
def test_get():
@nni.trace
class Foo:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment