Unverified Commit d3e3d71b authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Fix typehint of make divisible (#4862)

parent c8d1ed9d
......@@ -2,7 +2,7 @@
# Licensed under the MIT license.
import math
from typing import Optional, Callable, List, Tuple, Iterator, cast
from typing import Optional, Callable, List, Tuple, Iterator, Union, cast, overload
import torch
import nni.retiarii.nn.pytorch as nn
......@@ -12,7 +12,17 @@ from .utils.fixed import FixedFactory
from .utils.pretrained import load_pretrained_weight
def make_divisible(v, divisor, min_val=None):
@overload
def make_divisible(v: Union[int, float], divisor, min_val=None) -> int:
...
@overload
def make_divisible(v: Union[nn.ChoiceOf[int], nn.ChoiceOf[float]], divisor, min_val=None) -> nn.ChoiceOf[int]:
...
def make_divisible(v: Union[nn.ChoiceOf[int], nn.ChoiceOf[float], int, float], divisor, min_val=None) -> nn.MaybeChoice[int]:
"""
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by 8
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment