Unverified Commit 33a3d02f authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

[PyTorch] Update Sequential container to handle changes in module base class (#1028)



* Update sequential container constructor to handle modules in plain dicts
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Avoid initializing Sequential with dicts
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

---------
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
parent 238df4ce
......@@ -5,7 +5,6 @@
"""Sequential container for fusible operations."""
from __future__ import annotations
from collections import OrderedDict
from collections.abc import Iterable, Iterator
from typing import Optional
......@@ -39,7 +38,7 @@ class Sequential(torch.nn.Module):
self._module_groups = None
# Add modules
if len(args) == 1 and isinstance(args[0], OrderedDict):
if len(args) == 1 and isinstance(args[0], dict):
for key, module in args[0].items():
self.add_module(key, module)
else:
......@@ -82,8 +81,9 @@ class Sequential(torch.nn.Module):
) -> Sequential | torch.nn.Module:
keys = self._get_keys_by_idx(idx)
if isinstance(idx, slice):
modules = OrderedDict((str(i), self._modules[key]) for i, key in enumerate(keys))
return self.__class__(modules)
out = Sequential()
out.extend(self._modules[key] for key in keys)
return out
return self._modules[keys[0]]
def __setitem__(self, idx: int, module: torch.nn.Module) -> None:
......@@ -129,11 +129,12 @@ class Sequential(torch.nn.Module):
del self[idx]
return out
def __iadd__(self, other: Sequential) -> Sequential:
return self.extend(other)
def __iadd__(self, modules: Iterable[torch.nn.Modules]) -> Sequential:
return self.extend(modules)
def __add__(self, modules: Iterable[torch.nn.Modules]) -> Sequential:
out = self.__class__(self._modules)
out = Sequential()
out.extend(self)
out.extend(modules)
return out
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment