Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
opencompass
Commits
a94598d9
Unverified
Commit
a94598d9
authored
Dec 13, 2023
by
Hubert
Committed by
GitHub
Dec 13, 2023
Browse files
[Feat] update python action and slurm (#694)
parent
61303941
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
89 additions
and
59 deletions
+89
-59
opencompass/lagent/actions/python_interpreter.py
opencompass/lagent/actions/python_interpreter.py
+67
-55
opencompass/runners/slurm.py
opencompass/runners/slurm.py
+11
-2
opencompass/runners/slurm_sequential.py
opencompass/runners/slurm_sequential.py
+11
-2
No files found.
opencompass/lagent/actions/python_interpreter.py
View file @
a94598d9
import
copy
import
io
import
signal
import
multiprocessing
from
contextlib
import
redirect_stdout
from
typing
import
Any
,
Optional
from
lagent.actions.base_action
import
BaseAction
from
lagent.schema
import
ActionReturn
,
ActionStatusCode
class
TimeoutError
(
Exception
):
pass
def
handler
(
signum
,
frame
):
raise
TimeoutError
()
from
opencompass.datasets.mbpp
import
TimeOutException
,
swallow_io
,
time_limit
class
GenericRuntime
:
...
...
@@ -90,55 +84,73 @@ class PythonInterpreter(BaseAction):
self
.
answer_from_stdout
=
answer_from_stdout
self
.
timeout
=
timeout
@
staticmethod
def
extract_code
(
command
:
str
)
->
str
:
if
'```python'
in
command
:
command
=
command
.
split
(
'```python'
)[
1
].
split
(
'```'
)[
0
]
elif
'```'
in
command
:
command
=
command
.
split
(
'```'
)[
1
].
split
(
'```'
)[
0
]
command
=
command
.
split
(
'
\n
'
)
return
command
def
__call__
(
self
,
command
:
str
)
->
ActionReturn
:
self
.
runtime
=
GenericRuntime
()
signal
.
signal
(
signal
.
SIGALRM
,
handler
)
signal
.
alarm
(
self
.
timeout
)
try
:
tool_return
=
self
.
_call
(
command
)
except
TimeoutError
as
e
:
tool_return
=
ActionReturn
(
url
=
None
,
args
=
None
,
type
=
self
.
name
)
tool_return
.
errmsg
=
repr
(
e
)
"""Execution function for running generation code.
Args:
command(str): Python code to be executed.
"""
extracted_command
=
self
.
extract_code
(
command
)
tool_return
=
ActionReturn
(
url
=
None
,
args
=
dict
(
text
=
command
,
extract_code
=
extracted_command
),
type
=
self
.
name
)
def
_execution
(
q
,
command
,
tool_return
):
try
:
with
swallow_io
():
# leave 1s for multiprocess
with
time_limit
(
self
.
timeout
-
1
):
res
=
self
.
_call
(
command
)
tool_return
.
result
=
dict
(
text
=
str
(
res
))
tool_return
.
state
=
ActionStatusCode
.
SUCCESS
except
TimeOutException
:
tool_return
.
errmsg
=
f
'Time out after
{
self
.
timeout
}
seconds.'
tool_return
.
state
=
ActionStatusCode
.
API_ERROR
except
BaseException
as
e
:
tool_return
.
errmsg
=
f
'Failed.
{
e
}
.'
tool_return
.
state
=
ActionStatusCode
.
API_ERROR
q
.
put
(
tool_return
)
# `signal` cannot be used in child thread, therefore, we
# need to create a process.
q
=
multiprocessing
.
Queue
()
p
=
multiprocessing
.
Process
(
target
=
_execution
,
args
=
(
q
,
extracted_command
,
tool_return
))
p
.
start
()
p
.
join
(
timeout
=
self
.
timeout
)
if
p
.
is_alive
():
p
.
kill
()
# return timeout due to some unknown error
tool_return
.
errmsg
=
f
'Time out after
{
self
.
timeout
}
seconds.'
tool_return
.
state
=
ActionStatusCode
.
API_ERROR
finally
:
signal
.
alarm
(
0
)
return
tool_return
return
tool_return
return
q
.
get
()
def
_call
(
self
,
command
:
str
)
->
ActionReturn
:
tool_return
=
ActionReturn
(
url
=
None
,
args
=
None
,
type
=
self
.
name
)
try
:
if
'```python'
in
command
:
command
=
command
.
split
(
'```python'
)[
1
].
split
(
'```'
)[
0
]
elif
'```'
in
command
:
command
=
command
.
split
(
'```'
)[
1
].
split
(
'```'
)[
0
]
tool_return
.
args
=
dict
(
text
=
'```python
\n
'
+
command
+
'
\n
```'
)
command
=
command
.
split
(
'
\n
'
)
if
self
.
answer_from_stdout
:
program_io
=
io
.
StringIO
()
with
redirect_stdout
(
program_io
):
self
.
runtime
.
exec_code
(
'
\n
'
.
join
(
command
))
program_io
.
seek
(
0
)
res
=
program_io
.
readlines
()[
-
1
]
elif
self
.
answer_symbol
:
self
.
runtime
.
exec_code
(
'
\n
'
.
join
(
command
))
res
=
self
.
runtime
.
_global_vars
[
self
.
answer_symbol
]
elif
self
.
answer_expr
:
self
.
runtime
=
GenericRuntime
()
if
self
.
answer_from_stdout
:
program_io
=
io
.
StringIO
()
with
redirect_stdout
(
program_io
):
self
.
runtime
.
exec_code
(
'
\n
'
.
join
(
command
))
res
=
self
.
runtime
.
eval_code
(
self
.
answer_expr
)
else
:
self
.
runtime
.
exec_code
(
'
\n
'
.
join
(
command
[:
-
1
]))
res
=
True
except
Exception
as
e
:
tool_return
.
errmsg
=
repr
(
e
)
tool_return
.
type
=
self
.
name
tool_return
.
state
=
ActionStatusCode
.
API_ERROR
return
tool_return
try
:
tool_return
.
result
=
dict
(
text
=
str
(
res
))
tool_return
.
state
=
ActionStatusCode
.
SUCCESS
except
Exception
as
e
:
tool_return
.
errmsg
=
repr
(
e
)
tool_return
.
type
=
self
.
name
tool_return
.
state
=
ActionStatusCode
.
API_ERROR
return
tool_return
program_io
.
seek
(
0
)
res
=
program_io
.
readlines
()[
-
1
]
elif
self
.
answer_symbol
:
self
.
runtime
.
exec_code
(
'
\n
'
.
join
(
command
))
res
=
self
.
runtime
.
_global_vars
[
self
.
answer_symbol
]
elif
self
.
answer_expr
:
self
.
runtime
.
exec_code
(
'
\n
'
.
join
(
command
))
res
=
self
.
runtime
.
eval_code
(
self
.
answer_expr
)
else
:
self
.
runtime
.
exec_code
(
'
\n
'
.
join
(
command
[:
-
1
]))
res
=
True
return
res
opencompass/runners/slurm.py
View file @
a94598d9
...
...
@@ -4,7 +4,7 @@ import random
import
subprocess
import
time
from
functools
import
partial
from
typing
import
Any
,
Dict
,
List
,
Tuple
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
import
mmengine
from
mmengine.config
import
ConfigDict
...
...
@@ -31,6 +31,8 @@ class SlurmRunner(BaseRunner):
qos (str): Slurm quality of service. Defaults to None.
debug (bool): Whether to run in debug mode. Defaults to False.
lark_bot_url (str): Lark bot url. Defaults to None.
extra_command (List, optional): Extra slurm command.
For example ['-c 12', '-w node1']. Defaults to None.
"""
def
__init__
(
self
,
...
...
@@ -41,13 +43,18 @@ class SlurmRunner(BaseRunner):
quotatype
:
str
=
None
,
qos
:
str
=
None
,
debug
:
bool
=
False
,
lark_bot_url
:
str
=
None
):
lark_bot_url
:
str
=
None
,
extra_command
:
Optional
[
List
[
str
]]
=
None
):
super
().
__init__
(
task
=
task
,
debug
=
debug
,
lark_bot_url
=
lark_bot_url
)
self
.
max_num_workers
=
max_num_workers
self
.
retry
=
retry
self
.
partition
=
partition
self
.
quotatype
=
quotatype
self
.
qos
=
qos
if
not
extra_command
:
extra_command
=
[]
assert
isinstance
(
extra_command
,
list
)
self
.
extra_command
=
extra_command
def
launch
(
self
,
tasks
:
List
[
Dict
[
str
,
Any
]])
->
List
[
Tuple
[
str
,
int
]]:
"""Launch multiple tasks.
...
...
@@ -101,6 +108,8 @@ class SlurmRunner(BaseRunner):
tmpl
+=
f
' --qos=
{
self
.
qos
}
'
if
num_gpus
>
0
:
tmpl
+=
f
' --gres=gpu:
{
num_gpus
}
'
for
extra_cmd
in
self
.
extra_command
:
tmpl
+=
f
'
{
extra_cmd
}
'
tmpl
+=
f
" -N1 -J '
{
task_name
[:
512
]
}
'"
+
' {task_cmd}'
get_cmd
=
partial
(
task
.
get_command
,
cfg_path
=
param_file
,
...
...
opencompass/runners/slurm_sequential.py
View file @
a94598d9
...
...
@@ -6,7 +6,7 @@ import time
import
traceback
from
functools
import
partial
from
multiprocessing
import
Pipe
,
Pool
from
typing
import
Any
,
Dict
,
List
,
Tuple
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
import
mmengine
from
mmengine.config
import
ConfigDict
...
...
@@ -45,6 +45,8 @@ class SlurmSequentialRunner(BaseRunner):
qos (str): Slurm quality of service. Defaults to None.
debug (bool): Whether to run in debug mode. Defaults to False.
lark_bot_url (str): Lark bot url. Defaults to None.
extra_command (List, optional): Extra slurm command.
For example ['-c 12', '-w node1']. Defaults to None.
"""
def
__init__
(
self
,
...
...
@@ -56,7 +58,8 @@ class SlurmSequentialRunner(BaseRunner):
quotatype
:
str
=
None
,
qos
:
str
=
None
,
debug
:
bool
=
False
,
lark_bot_url
:
str
=
None
):
lark_bot_url
:
str
=
None
,
extra_command
:
Optional
[
List
[
str
]]
=
None
):
super
().
__init__
(
task
=
task
,
debug
=
debug
,
lark_bot_url
=
lark_bot_url
)
self
.
max_num_workers
=
max_num_workers
self
.
retry
=
retry
...
...
@@ -64,6 +67,10 @@ class SlurmSequentialRunner(BaseRunner):
self
.
quotatype
=
quotatype
self
.
qos
=
qos
self
.
task_prefix
=
task_prefix
if
not
extra_command
:
extra_command
=
[]
assert
isinstance
(
extra_command
,
list
)
self
.
extra_command
=
extra_command
logger
=
get_logger
()
if
self
.
quotatype
in
[
'spot'
,
'auto'
]:
...
...
@@ -173,6 +180,8 @@ class SlurmSequentialRunner(BaseRunner):
tmpl
+=
f
' --qos=
{
self
.
qos
}
'
if
num_gpus
>
0
:
tmpl
+=
f
' --gres=gpu:
{
num_gpus
}
'
for
extra_cmd
in
self
.
extra_command
:
tmpl
+=
f
'
{
extra_cmd
}
'
tmpl
+=
f
" -N1 -J '
{
task_name
[:
512
]
}
'"
+
' {task_cmd}'
get_cmd
=
partial
(
task
.
get_command
,
cfg_path
=
param_file
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment