"src/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "9d468d2c742491af2d2f506c648ddc95ffea6a64"
Unverified Commit e1ceae41 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

[Retiarii] Fix record conflicts (#3345)

parent 64efd60a
...@@ -4,7 +4,7 @@ from typing import Any, List ...@@ -4,7 +4,7 @@ from typing import Any, List
import torch import torch
import torch.nn as nn import torch.nn as nn
from ...utils import add_record, blackbox_module, uid, version_larger_equal from ...utils import add_record, blackbox_module, del_record, uid, version_larger_equal
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
...@@ -93,6 +93,9 @@ class Placeholder(nn.Module): ...@@ -93,6 +93,9 @@ class Placeholder(nn.Module):
def forward(self, x): def forward(self, x):
return x return x
def __del__(self):
del_record(id(self))
class ChosenInputs(nn.Module): class ChosenInputs(nn.Module):
""" """
...@@ -133,12 +136,18 @@ class Sequential(nn.Sequential): ...@@ -133,12 +136,18 @@ class Sequential(nn.Sequential):
add_record(id(self), {}) add_record(id(self), {})
super(Sequential, self).__init__(*args) super(Sequential, self).__init__(*args)
def __del__(self):
del_record(id(self))
class ModuleList(nn.ModuleList): class ModuleList(nn.ModuleList):
def __init__(self, *args): def __init__(self, *args):
add_record(id(self), {}) add_record(id(self), {})
super(ModuleList, self).__init__(*args) super(ModuleList, self).__init__(*args)
def __del__(self):
del_record(id(self))
Identity = blackbox_module(nn.Identity) Identity = blackbox_module(nn.Identity)
Linear = blackbox_module(nn.Linear) Linear = blackbox_module(nn.Linear)
......
...@@ -37,7 +37,7 @@ def add_record(key, value): ...@@ -37,7 +37,7 @@ def add_record(key, value):
""" """
global _records global _records
if _records is not None: if _records is not None:
assert key not in _records, '{} already in _records'.format(key) assert key not in _records, f'{key} already in _records. Conflict: {_records[key]}'
_records[key] = value _records[key] = value
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment