# clean_memory.py
1
2
3
4
5
# Copyright (c) Opendatalab. All rights reserved.
import torch
import gc


6
7
8
9
10
11
12
13
14
15
def clean_memory(device='cuda'):
    """Release cached accelerator memory for *device*, then run the garbage collector.

    Args:
        device: Target device, given as a string such as ``'cuda'``,
            ``'cuda:0'`` or ``'npu:0'``, or as any object whose ``str()``
            yields such a name (for example ``torch.device``). Any other
            value skips the accelerator step and only triggers ``gc.collect()``.

    Raises:
        ImportError: If an NPU device is requested and ``torch_npu`` cannot
            be imported.
    """
    # Normalize once so indexed names ('cuda:0') and torch.device objects take
    # the same branch as the bare device type.
    device_name = str(device)

    if device_name.startswith('cuda'):
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
    elif device_name.startswith('npu'):
        # Imported lazily so the other branches keep working when torch_npu
        # is not installed.
        import torch_npu

        npu = torch_npu.npu
        if npu.is_available():
            npu.empty_cache()
            # NOTE(review): ipc_collect may be missing from some torch_npu
            # releases; call it only when the backend exposes it -- confirm.
            ipc_collect = getattr(npu, 'ipc_collect', None)
            if callable(ipc_collect):
                ipc_collect()

    # Always finish with a Python-level collection, whatever the device.
    gc.collect()