Unverified Commit 3756b607 authored by F-G Fernandez's avatar F-G Fernandez Committed by GitHub
Browse files

Added annotation typing to models._utils (#2854)

* style: Added typing in models._utils

* fix: Removed non-necessary import

* fix: Removed type annotation of forward method

* refactor: Removed un-necessary import
parent 8a0434d2
...@@ -2,7 +2,7 @@ from collections import OrderedDict ...@@ -2,7 +2,7 @@ from collections import OrderedDict
import torch import torch
from torch import nn from torch import nn
from torch.jit.annotations import Dict from typing import Dict
class IntermediateLayerGetter(nn.ModuleDict): class IntermediateLayerGetter(nn.ModuleDict):
...@@ -41,7 +41,7 @@ class IntermediateLayerGetter(nn.ModuleDict): ...@@ -41,7 +41,7 @@ class IntermediateLayerGetter(nn.ModuleDict):
"return_layers": Dict[str, str], "return_layers": Dict[str, str],
} }
def __init__(self, model, return_layers): def __init__(self, model: nn.Module, return_layers: Dict[str, str]):
if not set(return_layers).issubset([name for name, _ in model.named_children()]): if not set(return_layers).issubset([name for name, _ in model.named_children()]):
raise ValueError("return_layers are not present in model") raise ValueError("return_layers are not present in model")
orig_return_layers = return_layers orig_return_layers = return_layers
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment