"googlemock/vscode:/vscode.git/clone" did not exist on "e588eb1ff9ff6598666279b737b27f983156ad85"
Unverified Commit 33a3d02f authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

[PyTorch] Update Sequential container to handle changes in module base class (#1028)



* Update sequential container constructor to handle modules in plain dicts
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Avoid initializing Sequential with dicts
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

---------
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
parent 238df4ce
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
"""Sequential container for fusible operations.""" """Sequential container for fusible operations."""
from __future__ import annotations from __future__ import annotations
from collections import OrderedDict
from collections.abc import Iterable, Iterator from collections.abc import Iterable, Iterator
from typing import Optional from typing import Optional
...@@ -39,7 +38,7 @@ class Sequential(torch.nn.Module): ...@@ -39,7 +38,7 @@ class Sequential(torch.nn.Module):
self._module_groups = None self._module_groups = None
# Add modules # Add modules
if len(args) == 1 and isinstance(args[0], OrderedDict): if len(args) == 1 and isinstance(args[0], dict):
for key, module in args[0].items(): for key, module in args[0].items():
self.add_module(key, module) self.add_module(key, module)
else: else:
...@@ -82,8 +81,9 @@ class Sequential(torch.nn.Module): ...@@ -82,8 +81,9 @@ class Sequential(torch.nn.Module):
) -> Sequential | torch.nn.Module: ) -> Sequential | torch.nn.Module:
keys = self._get_keys_by_idx(idx) keys = self._get_keys_by_idx(idx)
if isinstance(idx, slice): if isinstance(idx, slice):
modules = OrderedDict((str(i), self._modules[key]) for i, key in enumerate(keys)) out = Sequential()
return self.__class__(modules) out.extend(self._modules[key] for key in keys)
return out
return self._modules[keys[0]] return self._modules[keys[0]]
def __setitem__(self, idx: int, module: torch.nn.Module) -> None: def __setitem__(self, idx: int, module: torch.nn.Module) -> None:
...@@ -129,11 +129,12 @@ class Sequential(torch.nn.Module): ...@@ -129,11 +129,12 @@ class Sequential(torch.nn.Module):
del self[idx] del self[idx]
return out return out
def __iadd__(self, other: Sequential) -> Sequential: def __iadd__(self, modules: Iterable[torch.nn.Modules]) -> Sequential:
return self.extend(other) return self.extend(modules)
def __add__(self, modules: Iterable[torch.nn.Modules]) -> Sequential: def __add__(self, modules: Iterable[torch.nn.Modules]) -> Sequential:
out = self.__class__(self._modules) out = Sequential()
out.extend(self)
out.extend(modules) out.extend(modules)
return out return out
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment