"test/ut/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "0a8fbbed5e4a8939b0b3094297b9c594d3ba5797"
Unverified Commit d8e56857 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Fix CGO import error (#4010)

parent e5b4bf1a
...@@ -13,6 +13,7 @@ websockets ...@@ -13,6 +13,7 @@ websockets
filelock filelock
prettytable prettytable
dataclasses ; python_version < "3.7" dataclasses ; python_version < "3.7"
typing_extensions ; python_version < "3.8"
numpy < 1.19.4 ; sys_platform == "win32" numpy < 1.19.4 ; sys_platform == "win32"
numpy < 1.20 ; sys_platform != "win32" and python_version < "3.7" numpy < 1.20 ; sys_platform != "win32" and python_version < "3.7"
numpy ; sys.platform != "win32" and python_version >= "3.7" numpy ; sys.platform != "win32" and python_version >= "3.7"
......
...@@ -13,7 +13,7 @@ from torchvision.datasets import CIFAR10 ...@@ -13,7 +13,7 @@ from torchvision.datasets import CIFAR10
from blocks import ShuffleNetBlock, ShuffleXceptionBlock from blocks import ShuffleNetBlock, ShuffleXceptionBlock
from nn_meter import load_latency_predictors from nn_meter import load_latency_predictor
class ShuffleNetV2(nn.Module): class ShuffleNetV2(nn.Module):
...@@ -142,7 +142,7 @@ class LatencyFilter: ...@@ -142,7 +142,7 @@ class LatencyFilter:
if reverse is `False`, then the model returns `True` when `latency < threshold`, if reverse is `False`, then the model returns `True` when `latency < threshold`,
else otherwisse else otherwisse
""" """
self.predictors = load_latency_predictors(predictor, predictor_version) self.predictors = load_latency_predictor(predictor, predictor_version)
self.threshold = threshold self.threshold = threshold
def __call__(self, ir_model): def __call__(self, ir_model):
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from dataclasses import dataclass from dataclasses import dataclass
from typing import Literal try:
from typing import Literal
except ImportError:
from typing_extensions import Literal
@dataclass @dataclass
class GPUDevice: class GPUDevice:
......
...@@ -13,7 +13,7 @@ from torch.utils.data import DataLoader ...@@ -13,7 +13,7 @@ from torch.utils.data import DataLoader
import nni import nni
try: try:
import nni.retiarii.evaluator.pytorch.cgo.trainer as cgo_trainer from .cgo import trainer as cgo_trainer
cgo_import_failed = False cgo_import_failed = False
except ImportError: except ImportError:
cgo_import_failed = True cgo_import_failed = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment