Unverified Commit d3e3d71b authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Fix typehint of make divisible (#4862)

parent c8d1ed9d
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# Licensed under the MIT license. # Licensed under the MIT license.
import math import math
from typing import Optional, Callable, List, Tuple, Iterator, cast from typing import Optional, Callable, List, Tuple, Iterator, Union, cast, overload
import torch import torch
import nni.retiarii.nn.pytorch as nn import nni.retiarii.nn.pytorch as nn
...@@ -12,7 +12,17 @@ from .utils.fixed import FixedFactory ...@@ -12,7 +12,17 @@ from .utils.fixed import FixedFactory
from .utils.pretrained import load_pretrained_weight from .utils.pretrained import load_pretrained_weight
def make_divisible(v, divisor, min_val=None): @overload
def make_divisible(v: Union[int, float], divisor, min_val=None) -> int:
...
@overload
def make_divisible(v: Union[nn.ChoiceOf[int], nn.ChoiceOf[float]], divisor, min_val=None) -> nn.ChoiceOf[int]:
...
def make_divisible(v: Union[nn.ChoiceOf[int], nn.ChoiceOf[float], int, float], divisor, min_val=None) -> nn.MaybeChoice[int]:
""" """
This function is taken from the original tf repo. This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by 8 It ensures that all layers have a channel number that is divisible by 8
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment