Unverified Commit 8d2215be authored by dosemeion's avatar dosemeion Committed by GitHub
Browse files

add pathlib (#4395)

parent 3935a099
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import queue
import logging
import copy
import logging
from pathlib import Path
import queue
import torch
import torch.nn as nn
......@@ -63,7 +65,7 @@ class ModelSpeedup:
# load the mask tensor to the same device with the dummy_input
# self.masks save the mask tensors pruned by the user and the infered
# masks of the others modules
if isinstance(masks_file, str) and os.path.exists(masks_file):
if isinstance(masks_file, (str, Path)) and Path(masks_file).exists():
self.masks = torch.load(
masks_file, map_location if map_location is not None else str(self.device))
elif isinstance(masks_file, dict):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment