Unverified Commit e1ceae41 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

[Retiarii] Fix record conflicts (#3345)

parent 64efd60a
......@@ -4,7 +4,7 @@ from typing import Any, List
import torch
import torch.nn as nn
from ...utils import add_record, blackbox_module, uid, version_larger_equal
from ...utils import add_record, blackbox_module, del_record, uid, version_larger_equal
_logger = logging.getLogger(__name__)
......@@ -93,6 +93,9 @@ class Placeholder(nn.Module):
def forward(self, x):
return x
def __del__(self):
del_record(id(self))
class ChosenInputs(nn.Module):
"""
......@@ -133,12 +136,18 @@ class Sequential(nn.Sequential):
add_record(id(self), {})
super(Sequential, self).__init__(*args)
def __del__(self):
del_record(id(self))
class ModuleList(nn.ModuleList):
def __init__(self, *args):
add_record(id(self), {})
super(ModuleList, self).__init__(*args)
def __del__(self):
del_record(id(self))
Identity = blackbox_module(nn.Identity)
Linear = blackbox_module(nn.Linear)
......
......@@ -37,7 +37,7 @@ def add_record(key, value):
"""
global _records
if _records is not None:
assert key not in _records, '{} already in _records'.format(key)
assert key not in _records, f'{key} already in _records. Conflict: {_records[key]}'
_records[key] = value
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment