Unverified Commit a94598d9 authored by Hubert's avatar Hubert Committed by GitHub
Browse files

[Feat] update python action and slurm (#694)

parent 61303941
import copy
import io
import signal
import multiprocessing
from contextlib import redirect_stdout
from typing import Any, Optional
from lagent.actions.base_action import BaseAction
from lagent.schema import ActionReturn, ActionStatusCode
class TimeoutError(Exception):
    """Raised when executed code exceeds the allotted time budget."""
def handler(signum, frame):
    """SIGALRM handler: abort the running code by raising ``TimeoutError``.

    Args:
        signum: Signal number delivered by the OS (unused).
        frame: Stack frame active at delivery time (unused).
    """
    raise TimeoutError()
from opencompass.datasets.mbpp import TimeOutException, swallow_io, time_limit
class GenericRuntime:
......@@ -90,55 +84,73 @@ class PythonInterpreter(BaseAction):
self.answer_from_stdout = answer_from_stdout
self.timeout = timeout
@staticmethod
def extract_code(command: str) -> str:
if '```python' in command:
command = command.split('```python')[1].split('```')[0]
elif '```' in command:
command = command.split('```')[1].split('```')[0]
command = command.split('\n')
return command
def __call__(self, command: str) -> ActionReturn:
    """Execution function for running generation code.

    Runs the extracted code in a separate process with a hard timeout.

    Args:
        command (str): Python code to be executed, optionally wrapped
            in a markdown code fence.

    Returns:
        ActionReturn: Success result with the program output, or an
        error record (timeout or execution failure).
    """
    extracted_command = self.extract_code(command)
    tool_return = ActionReturn(url=None,
                               args=dict(text=command,
                                         extract_code=extracted_command),
                               type=self.name)

    def _execution(q, command, tool_return):
        # Runs in the child process; fills `tool_return` and ships it
        # back through the queue.
        try:
            with swallow_io():
                # leave 1s for multiprocess startup/teardown overhead
                with time_limit(self.timeout - 1):
                    res = self._call(command)
            tool_return.result = dict(text=str(res))
            tool_return.state = ActionStatusCode.SUCCESS
        except TimeOutException:
            tool_return.errmsg = f'Time out after {self.timeout} seconds.'
            tool_return.state = ActionStatusCode.API_ERROR
        except BaseException as e:
            tool_return.errmsg = f'Failed. {e}.'
            tool_return.state = ActionStatusCode.API_ERROR
        q.put(tool_return)

    # `signal` cannot be used in a child thread, therefore we
    # need to create a process.
    q = multiprocessing.Queue()
    p = multiprocessing.Process(target=_execution,
                                args=(q, extracted_command, tool_return))
    p.start()
    p.join(timeout=self.timeout)
    if p.is_alive():
        p.kill()
        # return timeout due to some unknown error
        tool_return.errmsg = f'Time out after {self.timeout} seconds.'
        tool_return.state = ActionStatusCode.API_ERROR
        return tool_return
    return q.get()
def _call(self, command) -> Any:
    """Execute already-extracted code lines and collect the answer.

    Runs inside the sandbox subprocess spawned by ``__call__``; any
    exception it raises is caught and reported there.

    Args:
        command: The code as a list of lines (the output of
            ``extract_code``), re-joined before execution.

    Returns:
        Any: The captured answer — last stdout line, the value of
        ``answer_symbol``, the value of ``answer_expr``, or ``True``
        when the code is executed for side effects only.
    """
    # Fresh runtime per call so state never leaks between executions.
    self.runtime = GenericRuntime()
    if self.answer_from_stdout:
        program_io = io.StringIO()
        with redirect_stdout(program_io):
            self.runtime.exec_code('\n'.join(command))
        program_io.seek(0)
        # The answer is, by convention, the last printed line.
        res = program_io.readlines()[-1]
    elif self.answer_symbol:
        self.runtime.exec_code('\n'.join(command))
        res = self.runtime._global_vars[self.answer_symbol]
    elif self.answer_expr:
        self.runtime.exec_code('\n'.join(command))
        res = self.runtime.eval_code(self.answer_expr)
    else:
        # Last line is assumed to be the answer expression; it is
        # dropped here and success is reported as a bare True.
        self.runtime.exec_code('\n'.join(command[:-1]))
        res = True
    return res
......@@ -4,7 +4,7 @@ import random
import subprocess
import time
from functools import partial
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Optional, Tuple
import mmengine
from mmengine.config import ConfigDict
......@@ -31,6 +31,8 @@ class SlurmRunner(BaseRunner):
qos (str): Slurm quality of service. Defaults to None.
debug (bool): Whether to run in debug mode. Defaults to False.
lark_bot_url (str): Lark bot url. Defaults to None.
extra_command (List, optional): Extra command-line options passed to
    the slurm submission command, e.g. ['-c 12', '-w node1']. Defaults to None.
"""
def __init__(self,
......@@ -41,13 +43,18 @@ class SlurmRunner(BaseRunner):
quotatype: str = None,
qos: str = None,
debug: bool = False,
lark_bot_url: str = None):
lark_bot_url: str = None,
extra_command: Optional[List[str]] = None):
super().__init__(task=task, debug=debug, lark_bot_url=lark_bot_url)
self.max_num_workers = max_num_workers
self.retry = retry
self.partition = partition
self.quotatype = quotatype
self.qos = qos
if not extra_command:
extra_command = []
assert isinstance(extra_command, list)
self.extra_command = extra_command
def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, int]]:
"""Launch multiple tasks.
......@@ -101,6 +108,8 @@ class SlurmRunner(BaseRunner):
tmpl += f' --qos={self.qos}'
if num_gpus > 0:
tmpl += f' --gres=gpu:{num_gpus}'
for extra_cmd in self.extra_command:
tmpl += f' {extra_cmd}'
tmpl += f" -N1 -J '{task_name[:512]}'" + ' {task_cmd}'
get_cmd = partial(task.get_command,
cfg_path=param_file,
......
......@@ -6,7 +6,7 @@ import time
import traceback
from functools import partial
from multiprocessing import Pipe, Pool
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Optional, Tuple
import mmengine
from mmengine.config import ConfigDict
......@@ -45,6 +45,8 @@ class SlurmSequentialRunner(BaseRunner):
qos (str): Slurm quality of service. Defaults to None.
debug (bool): Whether to run in debug mode. Defaults to False.
lark_bot_url (str): Lark bot url. Defaults to None.
extra_command (List, optional): Extra command-line options passed to
    the slurm submission command, e.g. ['-c 12', '-w node1']. Defaults to None.
"""
def __init__(self,
......@@ -56,7 +58,8 @@ class SlurmSequentialRunner(BaseRunner):
quotatype: str = None,
qos: str = None,
debug: bool = False,
lark_bot_url: str = None):
lark_bot_url: str = None,
extra_command: Optional[List[str]] = None):
super().__init__(task=task, debug=debug, lark_bot_url=lark_bot_url)
self.max_num_workers = max_num_workers
self.retry = retry
......@@ -64,6 +67,10 @@ class SlurmSequentialRunner(BaseRunner):
self.quotatype = quotatype
self.qos = qos
self.task_prefix = task_prefix
if not extra_command:
extra_command = []
assert isinstance(extra_command, list)
self.extra_command = extra_command
logger = get_logger()
if self.quotatype in ['spot', 'auto']:
......@@ -173,6 +180,8 @@ class SlurmSequentialRunner(BaseRunner):
tmpl += f' --qos={self.qos}'
if num_gpus > 0:
tmpl += f' --gres=gpu:{num_gpus}'
for extra_cmd in self.extra_command:
tmpl += f' {extra_cmd}'
tmpl += f" -N1 -J '{task_name[:512]}'" + ' {task_cmd}'
get_cmd = partial(task.get_command,
cfg_path=param_file,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment