Commit 2785f604 authored by myhloli's avatar myhloli
Browse files

fix: support NPU device in UnimernetModel initialization

parent 3cdcd76c
......@@ -21,7 +21,7 @@ class MathDataset(Dataset):
class UnimernetModel(object):
def __init__(self, weight_dir, _device_="cpu"):
from .unimernet_hf import UnimernetModel
if _device_.startswith("mps"):
if _device_.startswith("mps") or _device_.startswith("npu"):
self.model = UnimernetModel.from_pretrained(weight_dir, attn_implementation="eager")
else:
self.model = UnimernetModel.from_pretrained(weight_dir)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment