Unverified Commit d8e56857 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Fix CGO import error (#4010)

parent e5b4bf1a
......@@ -13,6 +13,7 @@ websockets
filelock
prettytable
dataclasses ; python_version < "3.7"
typing_extensions ; python_version < "3.8"
numpy < 1.19.4 ; sys_platform == "win32"
numpy < 1.20 ; sys_platform != "win32" and python_version < "3.7"
numpy ; sys.platform != "win32" and python_version >= "3.7"
......
......@@ -13,7 +13,7 @@ from torchvision.datasets import CIFAR10
from blocks import ShuffleNetBlock, ShuffleXceptionBlock
from nn_meter import load_latency_predictors
from nn_meter import load_latency_predictor
class ShuffleNetV2(nn.Module):
......@@ -142,7 +142,7 @@ class LatencyFilter:
if reverse is `False`, then the model returns `True` when `latency < threshold`,
else otherwisse
"""
self.predictors = load_latency_predictors(predictor, predictor_version)
self.predictors = load_latency_predictor(predictor, predictor_version)
self.threshold = threshold
def __call__(self, ir_model):
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from dataclasses import dataclass
from typing import Literal
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal
@dataclass
class GPUDevice:
......
......@@ -13,7 +13,7 @@ from torch.utils.data import DataLoader
import nni
try:
import nni.retiarii.evaluator.pytorch.cgo.trainer as cgo_trainer
from .cgo import trainer as cgo_trainer
cgo_import_failed = False
except ImportError:
cgo_import_failed = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment