Unverified Commit cebc22f3 authored by Chenguang Li's avatar Chenguang Li Committed by GitHub
Browse files

[Misc]Replace `cuda` hard code with `current_platform` in Ray (#14668)


Signed-off-by: default avatarnoemotiovon <757486878@qq.com>
parent 6c6dcd86
...@@ -87,9 +87,8 @@ try: ...@@ -87,9 +87,8 @@ try:
# TODO(swang): This is needed right now because Ray Compiled Graph # TODO(swang): This is needed right now because Ray Compiled Graph
# executes on a background thread, so we need to reset torch's # executes on a background thread, so we need to reset torch's
# current device. # current device.
import torch
if not self.compiled_dag_cuda_device_set: if not self.compiled_dag_cuda_device_set:
torch.cuda.set_device(self.worker.device) current_platform.set_device(self.worker.device)
self.compiled_dag_cuda_device_set = True self.compiled_dag_cuda_device_set = True
output = self.worker._execute_model_spmd(execute_model_req, output = self.worker._execute_model_spmd(execute_model_req,
...@@ -113,8 +112,7 @@ try: ...@@ -113,8 +112,7 @@ try:
# Not needed # Not needed
pass pass
else: else:
import torch current_platform.set_device(self.worker.device)
torch.cuda.set_device(self.worker.device)
self.compiled_dag_cuda_device_set = True self.compiled_dag_cuda_device_set = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment