"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "e861793ab710f7a617225610d52c20c339c8def6"
Unverified Commit 8d2215be authored by dosemeion's avatar dosemeion Committed by GitHub
Browse files

add pathlib (#4395)

parent 3935a099
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
import os
import queue
import logging
import copy import copy
import logging
from pathlib import Path
import queue
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -63,7 +65,7 @@ class ModelSpeedup: ...@@ -63,7 +65,7 @@ class ModelSpeedup:
# load the mask tensor to the same device with the dummy_input # load the mask tensor to the same device with the dummy_input
# self.masks save the mask tensors pruned by the user and the infered # self.masks save the mask tensors pruned by the user and the infered
# masks of the others modules # masks of the others modules
if isinstance(masks_file, str) and os.path.exists(masks_file): if isinstance(masks_file, (str, Path)) and Path(masks_file).exists():
self.masks = torch.load( self.masks = torch.load(
masks_file, map_location if map_location is not None else str(self.device)) masks_file, map_location if map_location is not None else str(self.device))
elif isinstance(masks_file, dict): elif isinstance(masks_file, dict):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment