Commit 203b8f90 authored by myhloli's avatar myhloli
Browse files

fix(device): enable MPS support and fix related issues

- Add MPS support for Apple Silicon devices
- Implement empty_cache() for MPS devices
- Set PYTORCH_ENABLE_MPS_FALLBACK environment variable
- Adjust MFR model device allocation for MPS
parent 0f401645
...@@ -12,4 +12,6 @@ def clean_memory(device='cuda'): ...@@ -12,4 +12,6 @@ def clean_memory(device='cuda'):
import torch_npu import torch_npu
if torch_npu.npu.is_available(): if torch_npu.npu.is_available():
torch_npu.npu.empty_cache() torch_npu.npu.empty_cache()
elif str(device).startswith("mps"):
torch.mps.empty_cache()
gc.collect() gc.collect()
\ No newline at end of file
...@@ -92,6 +92,8 @@ class CustomPEKModel: ...@@ -92,6 +92,8 @@ class CustomPEKModel:
import torch_npu import torch_npu
os.environ['FLAGS_npu_jit_compile'] = '0' os.environ['FLAGS_npu_jit_compile'] = '0'
os.environ['FLAGS_use_stride_kernel'] = '0' os.environ['FLAGS_use_stride_kernel'] = '0'
elif str(self.device).startswith("mps"):
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
logger.info('using device: {}'.format(self.device)) logger.info('using device: {}'.format(self.device))
models_dir = kwargs.get( models_dir = kwargs.get(
...@@ -119,11 +121,12 @@ class CustomPEKModel: ...@@ -119,11 +121,12 @@ class CustomPEKModel:
os.path.join(models_dir, self.configs['weights'][self.mfr_model_name]) os.path.join(models_dir, self.configs['weights'][self.mfr_model_name])
) )
mfr_cfg_path = str(os.path.join(model_config_dir, 'UniMERNet', 'demo.yaml')) mfr_cfg_path = str(os.path.join(model_config_dir, 'UniMERNet', 'demo.yaml'))
self.mfr_model = atom_model_manager.get_atom_model( self.mfr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.MFR, atom_model_name=AtomicModel.MFR,
mfr_weight_dir=mfr_weight_dir, mfr_weight_dir=mfr_weight_dir,
mfr_cfg_path=mfr_cfg_path, mfr_cfg_path=mfr_cfg_path,
device=self.device, device='cpu' if str(self.device).startswith("mps") else self.device,
) )
# 初始化layout模型 # 初始化layout模型
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment