Unverified Commit 70fcdda6 authored by Zhenhua Han's avatar Zhenhua Han Committed by GitHub
Browse files

CGO execution engine handles missing GPU indices in RemoteMachineConfig (#4270)

parent e428db54
...@@ -81,6 +81,7 @@ To enable CGO execution engine, you need to follow these steps: ...@@ -81,6 +81,7 @@ To enable CGO execution engine, you need to follow these steps:
# ... # ...
# server configuration in rm_conf # server configuration in rm_conf
rm_conf.gpu_indices = [0, 1, 2, 3] # gpu_indices must be set in RemoteMachineConfig for CGO execution engine
config.training_service.machine_list = [rm_conf] config.training_service.machine_list = [rm_conf]
exp.run(config, 8099) exp.run(config, 8099)
......
...@@ -219,7 +219,8 @@ class RetiariiExperiment(Experiment): ...@@ -219,7 +219,8 @@ class RetiariiExperiment(Experiment):
elif self.config.execution_engine == 'cgo': elif self.config.execution_engine == 'cgo':
from ..execution.cgo_engine import CGOExecutionEngine from ..execution.cgo_engine import CGOExecutionEngine
# assert self.config.trial_gpu_number==1, "trial_gpu_number must be 1 to use CGOExecutionEngine" assert self.config.training_service.platform == 'remote', \
"CGO execution engine currently only supports remote training service"
assert self.config.batch_waiting_time is not None assert self.config.batch_waiting_time is not None
devices = self._construct_devices() devices = self._construct_devices()
engine = CGOExecutionEngine(devices, engine = CGOExecutionEngine(devices,
...@@ -273,11 +274,10 @@ class RetiariiExperiment(Experiment): ...@@ -273,11 +274,10 @@ class RetiariiExperiment(Experiment):
devices = [] devices = []
if hasattr(self.config.training_service, 'machine_list'): if hasattr(self.config.training_service, 'machine_list'):
for machine in self.config.training_service.machine_list: for machine in self.config.training_service.machine_list:
assert machine.gpu_indices is not None, \
'gpu_indices must be set in RemoteMachineConfig for CGO execution engine'
for gpu_idx in machine.gpu_indices: for gpu_idx in machine.gpu_indices:
devices.append(GPUDevice(machine.host, gpu_idx)) devices.append(GPUDevice(machine.host, gpu_idx))
else:
for gpu_idx in self.config.training_service.gpu_indices:
devices.append(GPUDevice('local', gpu_idx))
return devices return devices
def _create_dispatcher(self): def _create_dispatcher(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment