Commit 203b8f90 authored by myhloli's avatar myhloli
Browse files

fix(device): enable MPS support and fix related issues

- Add MPS support for Apple Silicon devices
- Implement empty_cache() for MPS devices
- Set PYTORCH_ENABLE_MPS_FALLBACK environment variable
- Adjust MFR model device allocation for MPS
parent 0f401645
......@@ -12,4 +12,6 @@ def clean_memory(device='cuda'):
import torch_npu
if torch_npu.npu.is_available():
torch_npu.npu.empty_cache()
elif str(device).startswith("mps"):
torch.mps.empty_cache()
gc.collect()
\ No newline at end of file
......@@ -92,6 +92,8 @@ class CustomPEKModel:
import torch_npu
os.environ['FLAGS_npu_jit_compile'] = '0'
os.environ['FLAGS_use_stride_kernel'] = '0'
elif str(self.device).startswith("mps"):
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
logger.info('using device: {}'.format(self.device))
models_dir = kwargs.get(
......@@ -119,11 +121,12 @@ class CustomPEKModel:
os.path.join(models_dir, self.configs['weights'][self.mfr_model_name])
)
mfr_cfg_path = str(os.path.join(model_config_dir, 'UniMERNet', 'demo.yaml'))
self.mfr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.MFR,
mfr_weight_dir=mfr_weight_dir,
mfr_cfg_path=mfr_cfg_path,
device=self.device,
device='cpu' if str(self.device).startswith("mps") else self.device,
)
# 初始化layout模型
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment