Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
opencompass
Commits
a94598d9
You need to sign in or sign up before continuing.
Unverified
Commit
a94598d9
authored
Dec 13, 2023
by
Hubert
Committed by
GitHub
Dec 13, 2023
Browse files
[Feat] update python action and slurm (#694)
parent
61303941
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
89 additions
and
59 deletions
+89
-59
opencompass/lagent/actions/python_interpreter.py
opencompass/lagent/actions/python_interpreter.py
+67
-55
opencompass/runners/slurm.py
opencompass/runners/slurm.py
+11
-2
opencompass/runners/slurm_sequential.py
opencompass/runners/slurm_sequential.py
+11
-2
No files found.
opencompass/lagent/actions/python_interpreter.py
View file @
a94598d9
import
copy
import
copy
import
io
import
io
import
signal
import
multiprocessing
from
contextlib
import
redirect_stdout
from
contextlib
import
redirect_stdout
from
typing
import
Any
,
Optional
from
typing
import
Any
,
Optional
from
lagent.actions.base_action
import
BaseAction
from
lagent.actions.base_action
import
BaseAction
from
lagent.schema
import
ActionReturn
,
ActionStatusCode
from
lagent.schema
import
ActionReturn
,
ActionStatusCode
from
opencompass.datasets.mbpp
import
TimeOutException
,
swallow_io
,
time_limit
class
TimeoutError
(
Exception
):
pass
def
handler
(
signum
,
frame
):
raise
TimeoutError
()
class
GenericRuntime
:
class
GenericRuntime
:
...
@@ -90,30 +84,60 @@ class PythonInterpreter(BaseAction):
...
@@ -90,30 +84,60 @@ class PythonInterpreter(BaseAction):
self
.
answer_from_stdout
=
answer_from_stdout
self
.
answer_from_stdout
=
answer_from_stdout
self
.
timeout
=
timeout
self
.
timeout
=
timeout
def
__call__
(
self
,
command
:
str
)
->
ActionReturn
:
@
staticmethod
self
.
runtime
=
GenericRuntime
()
def
extract_code
(
command
:
str
)
->
str
:
signal
.
signal
(
signal
.
SIGALRM
,
handler
)
signal
.
alarm
(
self
.
timeout
)
try
:
tool_return
=
self
.
_call
(
command
)
except
TimeoutError
as
e
:
tool_return
=
ActionReturn
(
url
=
None
,
args
=
None
,
type
=
self
.
name
)
tool_return
.
errmsg
=
repr
(
e
)
tool_return
.
state
=
ActionStatusCode
.
API_ERROR
finally
:
signal
.
alarm
(
0
)
return
tool_return
def
_call
(
self
,
command
:
str
)
->
ActionReturn
:
tool_return
=
ActionReturn
(
url
=
None
,
args
=
None
,
type
=
self
.
name
)
try
:
if
'```python'
in
command
:
if
'```python'
in
command
:
command
=
command
.
split
(
'```python'
)[
1
].
split
(
'```'
)[
0
]
command
=
command
.
split
(
'```python'
)[
1
].
split
(
'```'
)[
0
]
elif
'```'
in
command
:
elif
'```'
in
command
:
command
=
command
.
split
(
'```'
)[
1
].
split
(
'```'
)[
0
]
command
=
command
.
split
(
'```'
)[
1
].
split
(
'```'
)[
0
]
tool_return
.
args
=
dict
(
text
=
'```python
\n
'
+
command
+
'
\n
```'
)
command
=
command
.
split
(
'
\n
'
)
command
=
command
.
split
(
'
\n
'
)
return
command
def
__call__
(
self
,
command
:
str
)
->
ActionReturn
:
"""Execution function for running generation code.
Args:
command(str): Python code to be executed.
"""
extracted_command
=
self
.
extract_code
(
command
)
tool_return
=
ActionReturn
(
url
=
None
,
args
=
dict
(
text
=
command
,
extract_code
=
extracted_command
),
type
=
self
.
name
)
def
_execution
(
q
,
command
,
tool_return
):
try
:
with
swallow_io
():
# leave 1s for multiprocess
with
time_limit
(
self
.
timeout
-
1
):
res
=
self
.
_call
(
command
)
tool_return
.
result
=
dict
(
text
=
str
(
res
))
tool_return
.
state
=
ActionStatusCode
.
SUCCESS
except
TimeOutException
:
tool_return
.
errmsg
=
f
'Time out after
{
self
.
timeout
}
seconds.'
tool_return
.
state
=
ActionStatusCode
.
API_ERROR
except
BaseException
as
e
:
tool_return
.
errmsg
=
f
'Failed.
{
e
}
.'
tool_return
.
state
=
ActionStatusCode
.
API_ERROR
q
.
put
(
tool_return
)
# `signal` cannot be used in child thread, therefore, we
# need to create a process.
q
=
multiprocessing
.
Queue
()
p
=
multiprocessing
.
Process
(
target
=
_execution
,
args
=
(
q
,
extracted_command
,
tool_return
))
p
.
start
()
p
.
join
(
timeout
=
self
.
timeout
)
if
p
.
is_alive
():
p
.
kill
()
# return timeout due to some unknown error
tool_return
.
errmsg
=
f
'Time out after
{
self
.
timeout
}
seconds.'
tool_return
.
state
=
ActionStatusCode
.
API_ERROR
return
tool_return
return
q
.
get
()
def
_call
(
self
,
command
:
str
)
->
ActionReturn
:
self
.
runtime
=
GenericRuntime
()
if
self
.
answer_from_stdout
:
if
self
.
answer_from_stdout
:
program_io
=
io
.
StringIO
()
program_io
=
io
.
StringIO
()
with
redirect_stdout
(
program_io
):
with
redirect_stdout
(
program_io
):
...
@@ -129,16 +153,4 @@ class PythonInterpreter(BaseAction):
...
@@ -129,16 +153,4 @@ class PythonInterpreter(BaseAction):
else
:
else
:
self
.
runtime
.
exec_code
(
'
\n
'
.
join
(
command
[:
-
1
]))
self
.
runtime
.
exec_code
(
'
\n
'
.
join
(
command
[:
-
1
]))
res
=
True
res
=
True
except
Exception
as
e
:
return
res
tool_return
.
errmsg
=
repr
(
e
)
tool_return
.
type
=
self
.
name
tool_return
.
state
=
ActionStatusCode
.
API_ERROR
return
tool_return
try
:
tool_return
.
result
=
dict
(
text
=
str
(
res
))
tool_return
.
state
=
ActionStatusCode
.
SUCCESS
except
Exception
as
e
:
tool_return
.
errmsg
=
repr
(
e
)
tool_return
.
type
=
self
.
name
tool_return
.
state
=
ActionStatusCode
.
API_ERROR
return
tool_return
opencompass/runners/slurm.py
View file @
a94598d9
...
@@ -4,7 +4,7 @@ import random
...
@@ -4,7 +4,7 @@ import random
import
subprocess
import
subprocess
import
time
import
time
from
functools
import
partial
from
functools
import
partial
from
typing
import
Any
,
Dict
,
List
,
Tuple
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
import
mmengine
import
mmengine
from
mmengine.config
import
ConfigDict
from
mmengine.config
import
ConfigDict
...
@@ -31,6 +31,8 @@ class SlurmRunner(BaseRunner):
...
@@ -31,6 +31,8 @@ class SlurmRunner(BaseRunner):
qos (str): Slurm quality of service. Defaults to None.
qos (str): Slurm quality of service. Defaults to None.
debug (bool): Whether to run in debug mode. Defaults to False.
debug (bool): Whether to run in debug mode. Defaults to False.
lark_bot_url (str): Lark bot url. Defaults to None.
lark_bot_url (str): Lark bot url. Defaults to None.
extra_command (List, optional): Extra slurm command.
For example ['-c 12', '-w node1']. Defaults to None.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -41,13 +43,18 @@ class SlurmRunner(BaseRunner):
...
@@ -41,13 +43,18 @@ class SlurmRunner(BaseRunner):
quotatype
:
str
=
None
,
quotatype
:
str
=
None
,
qos
:
str
=
None
,
qos
:
str
=
None
,
debug
:
bool
=
False
,
debug
:
bool
=
False
,
lark_bot_url
:
str
=
None
):
lark_bot_url
:
str
=
None
,
extra_command
:
Optional
[
List
[
str
]]
=
None
):
super
().
__init__
(
task
=
task
,
debug
=
debug
,
lark_bot_url
=
lark_bot_url
)
super
().
__init__
(
task
=
task
,
debug
=
debug
,
lark_bot_url
=
lark_bot_url
)
self
.
max_num_workers
=
max_num_workers
self
.
max_num_workers
=
max_num_workers
self
.
retry
=
retry
self
.
retry
=
retry
self
.
partition
=
partition
self
.
partition
=
partition
self
.
quotatype
=
quotatype
self
.
quotatype
=
quotatype
self
.
qos
=
qos
self
.
qos
=
qos
if
not
extra_command
:
extra_command
=
[]
assert
isinstance
(
extra_command
,
list
)
self
.
extra_command
=
extra_command
def
launch
(
self
,
tasks
:
List
[
Dict
[
str
,
Any
]])
->
List
[
Tuple
[
str
,
int
]]:
def
launch
(
self
,
tasks
:
List
[
Dict
[
str
,
Any
]])
->
List
[
Tuple
[
str
,
int
]]:
"""Launch multiple tasks.
"""Launch multiple tasks.
...
@@ -101,6 +108,8 @@ class SlurmRunner(BaseRunner):
...
@@ -101,6 +108,8 @@ class SlurmRunner(BaseRunner):
tmpl
+=
f
' --qos=
{
self
.
qos
}
'
tmpl
+=
f
' --qos=
{
self
.
qos
}
'
if
num_gpus
>
0
:
if
num_gpus
>
0
:
tmpl
+=
f
' --gres=gpu:
{
num_gpus
}
'
tmpl
+=
f
' --gres=gpu:
{
num_gpus
}
'
for
extra_cmd
in
self
.
extra_command
:
tmpl
+=
f
'
{
extra_cmd
}
'
tmpl
+=
f
" -N1 -J '
{
task_name
[:
512
]
}
'"
+
' {task_cmd}'
tmpl
+=
f
" -N1 -J '
{
task_name
[:
512
]
}
'"
+
' {task_cmd}'
get_cmd
=
partial
(
task
.
get_command
,
get_cmd
=
partial
(
task
.
get_command
,
cfg_path
=
param_file
,
cfg_path
=
param_file
,
...
...
opencompass/runners/slurm_sequential.py
View file @
a94598d9
...
@@ -6,7 +6,7 @@ import time
...
@@ -6,7 +6,7 @@ import time
import
traceback
import
traceback
from
functools
import
partial
from
functools
import
partial
from
multiprocessing
import
Pipe
,
Pool
from
multiprocessing
import
Pipe
,
Pool
from
typing
import
Any
,
Dict
,
List
,
Tuple
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
import
mmengine
import
mmengine
from
mmengine.config
import
ConfigDict
from
mmengine.config
import
ConfigDict
...
@@ -45,6 +45,8 @@ class SlurmSequentialRunner(BaseRunner):
...
@@ -45,6 +45,8 @@ class SlurmSequentialRunner(BaseRunner):
qos (str): Slurm quality of service. Defaults to None.
qos (str): Slurm quality of service. Defaults to None.
debug (bool): Whether to run in debug mode. Defaults to False.
debug (bool): Whether to run in debug mode. Defaults to False.
lark_bot_url (str): Lark bot url. Defaults to None.
lark_bot_url (str): Lark bot url. Defaults to None.
extra_command (List, optional): Extra slurm command.
For example ['-c 12', '-w node1']. Defaults to None.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -56,7 +58,8 @@ class SlurmSequentialRunner(BaseRunner):
...
@@ -56,7 +58,8 @@ class SlurmSequentialRunner(BaseRunner):
quotatype
:
str
=
None
,
quotatype
:
str
=
None
,
qos
:
str
=
None
,
qos
:
str
=
None
,
debug
:
bool
=
False
,
debug
:
bool
=
False
,
lark_bot_url
:
str
=
None
):
lark_bot_url
:
str
=
None
,
extra_command
:
Optional
[
List
[
str
]]
=
None
):
super
().
__init__
(
task
=
task
,
debug
=
debug
,
lark_bot_url
=
lark_bot_url
)
super
().
__init__
(
task
=
task
,
debug
=
debug
,
lark_bot_url
=
lark_bot_url
)
self
.
max_num_workers
=
max_num_workers
self
.
max_num_workers
=
max_num_workers
self
.
retry
=
retry
self
.
retry
=
retry
...
@@ -64,6 +67,10 @@ class SlurmSequentialRunner(BaseRunner):
...
@@ -64,6 +67,10 @@ class SlurmSequentialRunner(BaseRunner):
self
.
quotatype
=
quotatype
self
.
quotatype
=
quotatype
self
.
qos
=
qos
self
.
qos
=
qos
self
.
task_prefix
=
task_prefix
self
.
task_prefix
=
task_prefix
if
not
extra_command
:
extra_command
=
[]
assert
isinstance
(
extra_command
,
list
)
self
.
extra_command
=
extra_command
logger
=
get_logger
()
logger
=
get_logger
()
if
self
.
quotatype
in
[
'spot'
,
'auto'
]:
if
self
.
quotatype
in
[
'spot'
,
'auto'
]:
...
@@ -173,6 +180,8 @@ class SlurmSequentialRunner(BaseRunner):
...
@@ -173,6 +180,8 @@ class SlurmSequentialRunner(BaseRunner):
tmpl
+=
f
' --qos=
{
self
.
qos
}
'
tmpl
+=
f
' --qos=
{
self
.
qos
}
'
if
num_gpus
>
0
:
if
num_gpus
>
0
:
tmpl
+=
f
' --gres=gpu:
{
num_gpus
}
'
tmpl
+=
f
' --gres=gpu:
{
num_gpus
}
'
for
extra_cmd
in
self
.
extra_command
:
tmpl
+=
f
'
{
extra_cmd
}
'
tmpl
+=
f
" -N1 -J '
{
task_name
[:
512
]
}
'"
+
' {task_cmd}'
tmpl
+=
f
" -N1 -J '
{
task_name
[:
512
]
}
'"
+
' {task_cmd}'
get_cmd
=
partial
(
task
.
get_command
,
get_cmd
=
partial
(
task
.
get_command
,
cfg_path
=
param_file
,
cfg_path
=
param_file
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment