from contextlib import contextmanager import torch @contextmanager def torch_cuda_wrapper(): try: # replace cuda APIs with xpu APIs, this should work by default torch.cuda.Stream = torch.xpu.Stream torch.cuda.default_stream = torch.xpu.current_stream torch.cuda.current_stream = torch.xpu.current_stream torch.cuda.stream = torch.xpu.stream yield finally: pass