# Copyright (c) Opendatalab. All rights reserved.
import torch
import gc


def clean_memory(device='cuda'):
    """Release cached accelerator memory, then run Python garbage collection.

    Args:
        device: Target device. Matching is done on ``str(device)`` by prefix,
            so ``'cuda'``, ``'cuda:1'``, ``torch.device('cuda', 0)``,
            ``'npu:0'`` and ``'mps'`` are all recognized. Defaults to
            ``'cuda'``. Any other value (e.g. ``'cpu'``) only triggers
            ``gc.collect()``.

    Raises:
        ImportError: if an ``npu`` device is requested and ``torch_npu``
            cannot be imported.
    """
    device_str = str(device)
    if device_str.startswith("cuda"):
        # Prefix match (not equality) so indexed devices such as 'cuda:0' and
        # torch.device objects are handled, consistent with the branches below.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
    elif device_str.startswith("npu"):
        import torch_npu  # optional dependency, imported only when needed

        if torch_npu.npu.is_available():
            torch_npu.npu.empty_cache()
    elif device_str.startswith("mps"):
        # Guard like the other backends so requesting 'mps' on a host without
        # MPS is a no-op instead of calling into an unavailable backend.
        if torch.backends.mps.is_available():
            torch.mps.empty_cache()
    gc.collect()