Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
opencompass
Commits
04dd01a2
You need to sign in or sign up before continuing.
Commit
04dd01a2
authored
Jul 05, 2023
by
mzr1996
Browse files
Update configs and code
parent
c94cc943
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
61 additions
and
212 deletions
+61
-212
configs/datasets/CLUE_CMRC/CLUE_CMRC_gen.py
configs/datasets/CLUE_CMRC/CLUE_CMRC_gen.py
+4
-0
configs/datasets/FewCLUE_eprstmt/FewCLUE_eprstmt_ppl.py
configs/datasets/FewCLUE_eprstmt/FewCLUE_eprstmt_ppl.py
+4
-0
configs/datasets/SuperGLUE_AX_b/SuperGLUE_AX_b_ppl.py
configs/datasets/SuperGLUE_AX_b/SuperGLUE_AX_b_ppl.py
+4
-0
configs/datasets/XCOPA/XCOPA_ppl.py
configs/datasets/XCOPA/XCOPA_ppl.py
+4
-0
configs/datasets/agieval/agieval_gen.py
configs/datasets/agieval/agieval_gen.py
+4
-0
configs/datasets/flores/flores_gen.py
configs/datasets/flores/flores_gen.py
+4
-0
configs/datasets/mmlu/mmlu_ppl.py
configs/datasets/mmlu/mmlu_ppl.py
+4
-0
configs/datasets/summscreen/summscreen_gen.py
configs/datasets/summscreen/summscreen_gen.py
+4
-0
configs/datasets/winogrande/winogrande_gen.py
configs/datasets/winogrande/winogrande_gen.py
+4
-0
docs/en/advanced_guides/new_model.md
docs/en/advanced_guides/new_model.md
+1
-0
docs/zh_cn/user_guides/evaluation.md
docs/zh_cn/user_guides/evaluation.md
+1
-0
opencompass/datasets/strategyqa.py
opencompass/datasets/strategyqa.py
+14
-0
opencompass/models/__init__.py
opencompass/models/__init__.py
+6
-0
opencompass/models/xunfei_api.py
opencompass/models/xunfei_api.py
+0
-212
opencompass/runners/__init__.py
opencompass/runners/__init__.py
+3
-0
No files found.
configs/datasets/CLUE_CMRC/CLUE_CMRC_gen.py
0 → 100644
View file @
04dd01a2
from
mmengine.config
import
read_base
with
read_base
():
from
.CLUE_CMRC_gen_72a8d5
import
CMRC_datasets
# noqa: F401, F403
configs/datasets/FewCLUE_eprstmt/FewCLUE_eprstmt_ppl.py
0 → 100644
View file @
04dd01a2
from
mmengine.config
import
read_base
with
read_base
():
from
.FewCLUE_eprstmt_ppl_d3c387
import
eprstmt_datasets
# noqa: F401, F403
configs/datasets/SuperGLUE_AX_b/SuperGLUE_AX_b_ppl.py
0 → 100644
View file @
04dd01a2
from
mmengine.config
import
read_base
with
read_base
():
from
.SuperGLUE_AX_b_ppl_4bd960
import
AX_b_datasets
# noqa: F401, F403
configs/datasets/XCOPA/XCOPA_ppl.py
0 → 100644
View file @
04dd01a2
from
mmengine.config
import
read_base
with
read_base
():
from
.XCOPA_ppl_6215c4
import
XCOPA_datasets
# noqa: F401, F403
configs/datasets/agieval/agieval_gen.py
0 → 100644
View file @
04dd01a2
from
mmengine.config
import
read_base
with
read_base
():
from
.agieval_gen_dc7dae
import
agieval_datasets
# noqa: F401, F403
configs/datasets/flores/flores_gen.py
0 → 100644
View file @
04dd01a2
from
mmengine.config
import
read_base
with
read_base
():
from
.flores_gen_8eb9ca
import
flores_datasets
# noqa: F401, F403
configs/datasets/mmlu/mmlu_ppl.py
0 → 100644
View file @
04dd01a2
from
mmengine.config
import
read_base
with
read_base
():
from
.mmlu_ppl_c6bbe6
import
mmlu_datasets
# noqa: F401, F403
configs/datasets/summscreen/summscreen_gen.py
0 → 100644
View file @
04dd01a2
from
mmengine.config
import
read_base
with
read_base
():
from
.summscreen_gen_997ee2
import
summscreen_datasets
# noqa: F401, F403
configs/datasets/winogrande/winogrande_gen.py
0 → 100644
View file @
04dd01a2
from
mmengine.config
import
read_base
with
read_base
():
from
.winogrande_gen_c19d87
import
winogrande_datasets
# noqa: F401, F403
docs/en/advanced_guides/new_model.md
0 → 100644
View file @
04dd01a2
# Adding a New Model
docs/zh_cn/user_guides/evaluation.md
0 → 100644
View file @
04dd01a2
# 评估策略
opencompass/datasets/strategyqa.py
0 → 100644
View file @
04dd01a2
from
opencompass.registry
import
TEXT_POSTPROCESSORS
@TEXT_POSTPROCESSORS.register_module('strategyqa')
def strategyqa_pred_postprocess(text: str) -> str:
    """Extract the final answer from a StrategyQA model prediction.

    Only the text before the first blank line is considered. Within it,
    whatever follows the last occurrence of ``'So the answer is '`` is
    taken as the answer (the whole paragraph if the phrase is absent),
    stripped of surrounding whitespace and with every ``'.'`` removed.

    Args:
        text (str): Raw prediction text.

    Returns:
        str: The extracted answer string.
    """
    # Keep only the first paragraph of the prediction.
    first_paragraph, _, _ = text.partition('\n\n')
    # rpartition leaves the full paragraph in the tail slot when the marker
    # phrase is missing, matching ``split(...)[-1]``.
    _, _, answer = first_paragraph.rpartition('So the answer is ')
    answer = answer.strip()
    return answer.replace('.', '')
@TEXT_POSTPROCESSORS.register_module('strategyqa_dataset')
def strategyqa_dataset_postprocess(text: str) -> str:
    """Convert a StrategyQA reference label into ``'yes'`` / ``'no'``.

    Args:
        text (str): The reference label; it is converted with ``str()``
            before comparison, so a boolean ``True`` is accepted too.

    Returns:
        str: ``'yes'`` when the stringified label is exactly ``'True'``,
        otherwise ``'no'``.
    """
    if str(text) == 'True':
        return 'yes'
    return 'no'
opencompass/models/__init__.py
0 → 100644
View file @
04dd01a2
from
.base
import
BaseModel
,
LMTemplateParser
# noqa
from
.base_api
import
APITemplateParser
,
BaseAPIModel
# noqa
from
.glm
import
GLM130B
# noqa: F401, F403
from
.huggingface
import
HuggingFace
# noqa: F401, F403
from
.huggingface
import
HuggingFaceCausalLM
# noqa: F401, F403
from
.openai_api
import
OpenAI
# noqa: F401
opencompass/models/xunfei_api.py
deleted
100644 → 0
View file @
c94cc943
import
json
from
concurrent.futures
import
ThreadPoolExecutor
from
typing
import
Dict
,
List
,
Optional
,
Union
from
opencompass.registry
import
MODELS
from
opencompass.utils.prompt
import
PromptList
from
.base_api
import
BaseAPIModel
PromptType
=
Union
[
PromptList
,
str
]
@MODELS.register_module(name=['XunFei'])
class XunFei(BaseAPIModel):
    """Model wrapper around the XunFei websocket chat API.

    Args:
        path (str): Websocket URL of the API endpoint. Its host and path
            components are also used when signing each request.
        appid (str): Application id sent in the request header.
        api_secret (str): Secret used as the HMAC-SHA256 signing key.
        api_key (str): API key embedded in the authorization parameter.
        query_per_second (int): Forwarded to ``BaseAPIModel``.
            Defaults to 2.
        max_seq_len (int): Forwarded to ``BaseAPIModel``; unused by this
            class itself. Defaults to 2048.
        meta_template (Dict, optional): Forwarded to ``BaseAPIModel``.
            Defaults to None.
        retry (int): Number of attempts made per input before giving up.
            Defaults to 2.
    """

    def __init__(self,
                 path: str,
                 appid: str,
                 api_secret: str,
                 api_key: str,
                 query_per_second: int = 2,
                 max_seq_len: int = 2048,
                 meta_template: Optional[Dict] = None,
                 retry: int = 2):
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         query_per_second=query_per_second,
                         meta_template=meta_template,
                         retry=retry)
        # Imported lazily so that the optional ``websocket`` dependency is
        # only required when this model is actually instantiated.
        import ssl
        import threading
        from urllib.parse import urlencode, urlparse

        import websocket

        self.urlencode = urlencode
        self.websocket = websocket
        self.websocket.enableTrace(False)
        self.threading = threading
        self.ssl = ssl

        # Credentials used by get_url() to sign each connection.
        self.APISecret = api_secret
        self.APIKey = api_key
        self.appid = appid

        parsed = urlparse(path)
        self.hostname = parsed.netloc
        self.hostpath = parsed.path
        self.headers = {
            'content-type': 'application/json',
        }

    def get_url(self) -> str:
        """Build the signed websocket URL for one connection.

        The current time is formatted as an HTTP date, an HMAC-SHA256
        signature is computed over the ``host``, ``date`` and request line
        using the API secret, and the resulting authorization string is
        appended to ``self.path`` as query parameters.

        Returns:
            str: ``self.path`` followed by the ``authorization``, ``date``
            and ``host`` query parameters.
        """
        import base64
        import hashlib
        import hmac
        from datetime import datetime
        from time import mktime
        from wsgiref.handlers import format_date_time

        date = format_date_time(mktime(datetime.now().timetuple()))

        # The string that gets signed: host, date and request line.
        tmp = f'host: {self.hostname}\n'
        tmp += 'date: ' + date + '\n'
        tmp += 'GET ' + self.hostpath + ' HTTP/1.1'

        tmp_sha = hmac.new(self.APISecret.encode('utf-8'),
                           tmp.encode('utf-8'),
                           digestmod=hashlib.sha256).digest()
        signature = base64.b64encode(tmp_sha).decode(encoding='utf-8')

        authorization_origin = (f'api_key="{self.APIKey}", '
                                'algorithm="hmac-sha256", '
                                'headers="host date request-line", '
                                f'signature="{signature}"')
        authorization = base64.b64encode(
            authorization_origin.encode('utf-8')).decode(encoding='utf-8')

        query = {
            'authorization': authorization,
            'date': date,
            'host': self.hostname
        }
        # self.path is not assigned in this class; presumably it is stored
        # by BaseAPIModel.__init__ -- TODO confirm.
        return self.path + '?' + self.urlencode(query)

    def generate(
        self,
        inputs: List[PromptType],
        max_out_len: int = 512,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Each input is handled by :meth:`_generate` on a thread pool.

        Args:
            inputs (List[str or PromptList]): A list of strings or
                PromptLists. A PromptList should be organized in
                OpenCompass' API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings, in input order.
        """
        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._generate, inputs,
                             [max_out_len] * len(inputs)))
        return results

    def _generate(
        self,
        input: PromptType,
        max_out_len: int = 512,
    ) -> str:
        """Generate a result for a single input.

        Args:
            input (str or PromptList): A string or PromptList. A
                PromptList should be organized in OpenCompass' API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            str: The generated string. If every attempt fails and the last
            error code is 10013, the server's error message is returned
            instead.

        Raises:
            RuntimeError: If no attempt succeeds and the last error code
                is not 10013.
        """
        assert isinstance(input, (str, PromptList))

        # FIXME: messages only contains the last input
        if isinstance(input, str):
            messages = [{'role': 'user', 'content': input}]
        else:
            messages = []
            # TODO: Implement truncation in PromptList
            for item in input:
                entry = {'content': item['prompt']}
                # NOTE(review): items whose role is neither HUMAN nor BOT
                # are sent without a 'role' key -- confirm this is intended.
                if item['role'] == 'HUMAN':
                    entry['role'] = 'user'
                elif item['role'] == 'BOT':
                    entry['role'] = 'assistant'
                messages.append(entry)

        data = {
            'header': {
                'app_id': self.appid,
            },
            'parameter': {
                'chat': {
                    'domain': 'general',
                    'max_tokens': max_out_len,
                }
            },
            'payload': {
                'message': {
                    'text': messages
                }
            }
        }

        reply = ''
        err_code = None
        err_data = None
        # Set once the server has sent either an error or the final frame.
        content_received = self.threading.Event()

        def on_open(ws):
            ws.send(json.dumps(data))

        def on_message(ws, message):
            nonlocal reply, err_code, err_data
            err_data = json.loads(message)
            err_code = err_data['header']['code']
            if err_code != 0:
                content_received.set()
                ws.close()
                return
            choices = err_data['payload']['choices']
            reply += choices['text'][0]['content']
            # Status 2 is treated as the end of the streamed answer.
            if choices['status'] == 2:
                content_received.set()
                ws.close()

        for _ in range(self.retry):
            # Reset per-attempt state so a retry never appends to a
            # partial answer left over from a failed attempt.
            reply = ''
            err_code = None
            err_data = None
            content_received.clear()

            self.wait()
            # A fresh app (and freshly signed URL) for every attempt.
            ws = self.websocket.WebSocketApp(self.get_url(),
                                             on_message=on_message,
                                             on_open=on_open)
            # NOTE(security): certificate verification is disabled here.
            ws.run_forever(sslopt={'cert_reqs': self.ssl.CERT_NONE})

            # run_forever() blocks until the connection is closed, so the
            # event is only inspected, never waited on: waiting without a
            # timeout would hang forever if the connection failed or was
            # dropped before the final frame arrived.
            if content_received.is_set() and err_code == 0:
                return reply.strip()

        if err_code == 10013:
            return err_data['header']['message']
        raise RuntimeError(f'Code: {err_code}, data: {err_data}')
opencompass/runners/__init__.py
0 → 100644
View file @
04dd01a2
from
.dlc
import
*
# noqa: F401, F403
from
.local
import
*
# noqa: F401, F403
from
.slurm
import
*
# noqa: F401, F403
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment