utils.py 428 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from contextlib import contextmanager

import torch


@contextmanager
def torch_cuda_wrapper():
    try:
        # replace cuda APIs with xpu APIs, this should work by default
        torch.cuda.Stream = torch.xpu.Stream
        torch.cuda.default_stream = torch.xpu.current_stream
        torch.cuda.current_stream = torch.xpu.current_stream
        torch.cuda.stream = torch.xpu.stream
        yield
    finally:
        pass