Unverified Commit dd39baf7 authored by Kunshang Ji's avatar Kunshang Ji Committed by GitHub
Browse files

[XPU] Fix xpu model runner call torch.cuda APIs (#25011)


Signed-off-by: default avatarKunshang Ji <kunshang.ji@intel.com>
parent 43a62c51
...@@ -45,8 +45,12 @@ def _torch_cuda_wrapper(): ...@@ -45,8 +45,12 @@ def _torch_cuda_wrapper():
self.synchronize = lambda: None self.synchronize = lambda: None
try: try:
# replace cuda Event with xpu Event, this should work by default # replace cuda APIs with xpu APIs, this should work by default
torch.cuda.Event = torch.xpu.Event torch.cuda.Event = torch.xpu.Event
torch.cuda.Stream = torch.xpu.Stream
torch.cuda.default_stream = torch.xpu.current_stream
torch.cuda.current_stream = torch.xpu.current_stream
torch.cuda.stream = torch.xpu.stream
yield yield
finally: finally:
# if anything goes wrong, just patch it with a placeholder # if anything goes wrong, just patch it with a placeholder
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment