"...composable_kernel-1.git" did not exist on "1085794df3c6568832252ee7f2a06a72e488891d"
Unverified commit a94598d9, authored by Hubert and committed by GitHub
Browse files

[Feat] update python action and slurm (#694)

parent 61303941
import copy import copy
import io import io
import signal import multiprocessing
from contextlib import redirect_stdout from contextlib import redirect_stdout
from typing import Any, Optional from typing import Any, Optional
from lagent.actions.base_action import BaseAction from lagent.actions.base_action import BaseAction
from lagent.schema import ActionReturn, ActionStatusCode from lagent.schema import ActionReturn, ActionStatusCode
from opencompass.datasets.mbpp import TimeOutException, swallow_io, time_limit
class TimeoutError(Exception):
    # Marker exception with no extra state; it is raised by `handler` below.
    # NOTE(review): this name shadows Python's builtin `TimeoutError` inside
    # this module.
    pass
def handler(signum, frame):
    # Takes the (signum, frame) pair passed to signal handlers; both
    # arguments are ignored and a TimeoutError is raised unconditionally.
    raise TimeoutError()
class GenericRuntime: class GenericRuntime:
...@@ -90,55 +84,73 @@ class PythonInterpreter(BaseAction): ...@@ -90,55 +84,73 @@ class PythonInterpreter(BaseAction):
self.answer_from_stdout = answer_from_stdout self.answer_from_stdout = answer_from_stdout
self.timeout = timeout self.timeout = timeout
@staticmethod
def extract_code(command: str) -> list:
    """Pull the code out of a (possibly fenced) snippet and split it into lines.

    If ``command`` contains a ```` ```python ```` fence, the text between that
    fence and the next ```` ``` ```` is used. Otherwise, if it contains a bare
    ```` ``` ```` fence, the text between the first two fences is used. With
    no fence at all the whole string is used unchanged.

    Args:
        command (str): Raw text that may wrap code in a Markdown fence.

    Returns:
        list: The selected code split on ``'\\n'``. The newline that follows
        an opening fence and the one that precedes a closing fence are kept,
        so a fenced input yields leading/trailing empty strings.
    """
    # The '```python' check must come first: that marker also contains the
    # bare '```' fence, so testing in the other order would cut at the wrong
    # place and leave 'python' as the first line.
    if '```python' in command:
        command = command.split('```python')[1].split('```')[0]
    elif '```' in command:
        command = command.split('```')[1].split('```')[0]
    # A list of lines is returned (not a str): the code that consumes this
    # value joins it with '\n' and slices off the last line.
    return command.split('\n')
def __call__(self, command: str) -> ActionReturn: def __call__(self, command: str) -> ActionReturn:
self.runtime = GenericRuntime() """Execution function for running generation code.
signal.signal(signal.SIGALRM, handler)
signal.alarm(self.timeout) Args:
try: command(str): Python code to be executed.
tool_return = self._call(command) """
except TimeoutError as e: extracted_command = self.extract_code(command)
tool_return = ActionReturn(url=None, args=None, type=self.name) tool_return = ActionReturn(url=None,
tool_return.errmsg = repr(e) args=dict(text=command,
extract_code=extracted_command),
type=self.name)
def _execution(q, command, tool_return):
try:
with swallow_io():
# leave 1s for multiprocess
with time_limit(self.timeout - 1):
res = self._call(command)
tool_return.result = dict(text=str(res))
tool_return.state = ActionStatusCode.SUCCESS
except TimeOutException:
tool_return.errmsg = f'Time out after {self.timeout} seconds.'
tool_return.state = ActionStatusCode.API_ERROR
except BaseException as e:
tool_return.errmsg = f'Failed. {e}.'
tool_return.state = ActionStatusCode.API_ERROR
q.put(tool_return)
# `signal` cannot be used in child thread, therefore, we
# need to create a process.
q = multiprocessing.Queue()
p = multiprocessing.Process(target=_execution,
args=(q, extracted_command, tool_return))
p.start()
p.join(timeout=self.timeout)
if p.is_alive():
p.kill()
# return timeout due to some unknown error
tool_return.errmsg = f'Time out after {self.timeout} seconds.'
tool_return.state = ActionStatusCode.API_ERROR tool_return.state = ActionStatusCode.API_ERROR
finally: return tool_return
signal.alarm(0) return q.get()
return tool_return
def _call(self, command: str) -> ActionReturn: def _call(self, command: str) -> ActionReturn:
tool_return = ActionReturn(url=None, args=None, type=self.name) self.runtime = GenericRuntime()
try: if self.answer_from_stdout:
if '```python' in command: program_io = io.StringIO()
command = command.split('```python')[1].split('```')[0] with redirect_stdout(program_io):
elif '```' in command:
command = command.split('```')[1].split('```')[0]
tool_return.args = dict(text='```python\n' + command + '\n```')
command = command.split('\n')
if self.answer_from_stdout:
program_io = io.StringIO()
with redirect_stdout(program_io):
self.runtime.exec_code('\n'.join(command))
program_io.seek(0)
res = program_io.readlines()[-1]
elif self.answer_symbol:
self.runtime.exec_code('\n'.join(command))
res = self.runtime._global_vars[self.answer_symbol]
elif self.answer_expr:
self.runtime.exec_code('\n'.join(command)) self.runtime.exec_code('\n'.join(command))
res = self.runtime.eval_code(self.answer_expr) program_io.seek(0)
else: res = program_io.readlines()[-1]
self.runtime.exec_code('\n'.join(command[:-1])) elif self.answer_symbol:
res = True self.runtime.exec_code('\n'.join(command))
except Exception as e: res = self.runtime._global_vars[self.answer_symbol]
tool_return.errmsg = repr(e) elif self.answer_expr:
tool_return.type = self.name self.runtime.exec_code('\n'.join(command))
tool_return.state = ActionStatusCode.API_ERROR res = self.runtime.eval_code(self.answer_expr)
return tool_return else:
try: self.runtime.exec_code('\n'.join(command[:-1]))
tool_return.result = dict(text=str(res)) res = True
tool_return.state = ActionStatusCode.SUCCESS return res
except Exception as e:
tool_return.errmsg = repr(e)
tool_return.type = self.name
tool_return.state = ActionStatusCode.API_ERROR
return tool_return
...@@ -4,7 +4,7 @@ import random ...@@ -4,7 +4,7 @@ import random
import subprocess import subprocess
import time import time
from functools import partial from functools import partial
from typing import Any, Dict, List, Tuple from typing import Any, Dict, List, Optional, Tuple
import mmengine import mmengine
from mmengine.config import ConfigDict from mmengine.config import ConfigDict
...@@ -31,6 +31,8 @@ class SlurmRunner(BaseRunner): ...@@ -31,6 +31,8 @@ class SlurmRunner(BaseRunner):
qos (str): Slurm quality of service. Defaults to None. qos (str): Slurm quality of service. Defaults to None.
debug (bool): Whether to run in debug mode. Defaults to False. debug (bool): Whether to run in debug mode. Defaults to False.
lark_bot_url (str): Lark bot url. Defaults to None. lark_bot_url (str): Lark bot url. Defaults to None.
extra_command (List, optional): Extra slurm command.
For example ['-c 12', '-w node1']. Defaults to None.
""" """
def __init__(self, def __init__(self,
...@@ -41,13 +43,18 @@ class SlurmRunner(BaseRunner): ...@@ -41,13 +43,18 @@ class SlurmRunner(BaseRunner):
quotatype: str = None, quotatype: str = None,
qos: str = None, qos: str = None,
debug: bool = False, debug: bool = False,
lark_bot_url: str = None): lark_bot_url: str = None,
extra_command: Optional[List[str]] = None):
super().__init__(task=task, debug=debug, lark_bot_url=lark_bot_url) super().__init__(task=task, debug=debug, lark_bot_url=lark_bot_url)
self.max_num_workers = max_num_workers self.max_num_workers = max_num_workers
self.retry = retry self.retry = retry
self.partition = partition self.partition = partition
self.quotatype = quotatype self.quotatype = quotatype
self.qos = qos self.qos = qos
if not extra_command:
extra_command = []
assert isinstance(extra_command, list)
self.extra_command = extra_command
def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, int]]: def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, int]]:
"""Launch multiple tasks. """Launch multiple tasks.
...@@ -101,6 +108,8 @@ class SlurmRunner(BaseRunner): ...@@ -101,6 +108,8 @@ class SlurmRunner(BaseRunner):
tmpl += f' --qos={self.qos}' tmpl += f' --qos={self.qos}'
if num_gpus > 0: if num_gpus > 0:
tmpl += f' --gres=gpu:{num_gpus}' tmpl += f' --gres=gpu:{num_gpus}'
for extra_cmd in self.extra_command:
tmpl += f' {extra_cmd}'
tmpl += f" -N1 -J '{task_name[:512]}'" + ' {task_cmd}' tmpl += f" -N1 -J '{task_name[:512]}'" + ' {task_cmd}'
get_cmd = partial(task.get_command, get_cmd = partial(task.get_command,
cfg_path=param_file, cfg_path=param_file,
......
...@@ -6,7 +6,7 @@ import time ...@@ -6,7 +6,7 @@ import time
import traceback import traceback
from functools import partial from functools import partial
from multiprocessing import Pipe, Pool from multiprocessing import Pipe, Pool
from typing import Any, Dict, List, Tuple from typing import Any, Dict, List, Optional, Tuple
import mmengine import mmengine
from mmengine.config import ConfigDict from mmengine.config import ConfigDict
...@@ -45,6 +45,8 @@ class SlurmSequentialRunner(BaseRunner): ...@@ -45,6 +45,8 @@ class SlurmSequentialRunner(BaseRunner):
qos (str): Slurm quality of service. Defaults to None. qos (str): Slurm quality of service. Defaults to None.
debug (bool): Whether to run in debug mode. Defaults to False. debug (bool): Whether to run in debug mode. Defaults to False.
lark_bot_url (str): Lark bot url. Defaults to None. lark_bot_url (str): Lark bot url. Defaults to None.
extra_command (List, optional): Extra slurm command.
For example ['-c 12', '-w node1']. Defaults to None.
""" """
def __init__(self, def __init__(self,
...@@ -56,7 +58,8 @@ class SlurmSequentialRunner(BaseRunner): ...@@ -56,7 +58,8 @@ class SlurmSequentialRunner(BaseRunner):
quotatype: str = None, quotatype: str = None,
qos: str = None, qos: str = None,
debug: bool = False, debug: bool = False,
lark_bot_url: str = None): lark_bot_url: str = None,
extra_command: Optional[List[str]] = None):
super().__init__(task=task, debug=debug, lark_bot_url=lark_bot_url) super().__init__(task=task, debug=debug, lark_bot_url=lark_bot_url)
self.max_num_workers = max_num_workers self.max_num_workers = max_num_workers
self.retry = retry self.retry = retry
...@@ -64,6 +67,10 @@ class SlurmSequentialRunner(BaseRunner): ...@@ -64,6 +67,10 @@ class SlurmSequentialRunner(BaseRunner):
self.quotatype = quotatype self.quotatype = quotatype
self.qos = qos self.qos = qos
self.task_prefix = task_prefix self.task_prefix = task_prefix
if not extra_command:
extra_command = []
assert isinstance(extra_command, list)
self.extra_command = extra_command
logger = get_logger() logger = get_logger()
if self.quotatype in ['spot', 'auto']: if self.quotatype in ['spot', 'auto']:
...@@ -173,6 +180,8 @@ class SlurmSequentialRunner(BaseRunner): ...@@ -173,6 +180,8 @@ class SlurmSequentialRunner(BaseRunner):
tmpl += f' --qos={self.qos}' tmpl += f' --qos={self.qos}'
if num_gpus > 0: if num_gpus > 0:
tmpl += f' --gres=gpu:{num_gpus}' tmpl += f' --gres=gpu:{num_gpus}'
for extra_cmd in self.extra_command:
tmpl += f' {extra_cmd}'
tmpl += f" -N1 -J '{task_name[:512]}'" + ' {task_cmd}' tmpl += f" -N1 -J '{task_name[:512]}'" + ' {task_cmd}'
get_cmd = partial(task.get_command, get_cmd = partial(task.get_command,
cfg_path=param_file, cfg_path=param_file,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment