Unverified Commit 3756b607 authored by F-G Fernandez's avatar F-G Fernandez Committed by GitHub
Browse files

Added annotation typing to models._utils (#2854)

* style: Added typing in models._utils

* fix: Removed non-necessary import

* fix: Removed type annotation of forward method

* refactor: Removed un-necessary import
parent 8a0434d2
......@@ -2,7 +2,7 @@ from collections import OrderedDict
import torch
from torch import nn
from torch.jit.annotations import Dict
from typing import Dict
class IntermediateLayerGetter(nn.ModuleDict):
......@@ -41,7 +41,7 @@ class IntermediateLayerGetter(nn.ModuleDict):
"return_layers": Dict[str, str],
}
def __init__(self, model, return_layers):
def __init__(self, model: nn.Module, return_layers: Dict[str, str]):
if not set(return_layers).issubset([name for name, _ in model.named_children()]):
raise ValueError("return_layers are not present in model")
orig_return_layers = return_layers
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment