OpenDAS / opencompass · Commit 001e77fe (unverified)

[Feature] add support for gemini (#931)

* add gemini
* add gemini
* add gemini

Authored Feb 28, 2024 by bittersweet1999; committed by GitHub on Feb 28, 2024
Parent: 9afbfa36
Showing 4 changed files with 277 additions and 2 deletions (+277 −2):
configs/datasets/subjective/alignbench/alignbench_judgeby_critiquellm.py  (+2 −2)
configs/models/gemini/gemini_pro.py  (+23 −0)
opencompass/models/__init__.py  (+1 −0)
opencompass/models/gemini_api.py  (+251 −0)
configs/datasets/subjective/alignbench/alignbench_judgeby_critiquellm.py
@@ -14,8 +14,8 @@ subjective_all_sets = [
 ]
 data_path = "data/subjective/alignment_bench"
-alignment_bench_config_path = "data/subjective/alignment_bench/"
-alignment_bench_config_name = 'config/multi-dimension'
+alignment_bench_config_path = "data/subjective/alignment_bench/config"
+alignment_bench_config_name = 'multi-dimension'
 subjective_datasets = []
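The two replaced lines move the config subdirectory from alignment_bench_config_name into alignment_bench_config_path, so the location resolved on disk stays the same. A minimal sketch of that equivalence (the assumption that the two settings are combined path-style by the dataset code is the editor's, not something shown in this diff):

import os.path as osp

# Old layout: the path pointed at the dataset root and the name carried the subdirectory.
old = osp.join('data/subjective/alignment_bench/', 'config/multi-dimension')
# New layout: the subdirectory lives in the path and the name is just the config.
new = osp.join('data/subjective/alignment_bench/config', 'multi-dimension')
assert old == new  # both resolve to data/subjective/alignment_bench/config/multi-dimension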
configs/models/gemini/gemini_pro.py (new file, mode 100644)
from opencompass.models import Gemini

api_meta_template = dict(
    round=[
        dict(role='HUMAN', api_role='HUMAN'),
        dict(role='BOT', api_role='BOT', generate=True),
    ],
)

models = [
    dict(
        abbr='gemini',
        type=Gemini,
        path='gemini-pro',
        key='your keys',  # The key will be obtained from Environment, but you can write down your key here as well
        url="your url",
        meta_template=api_meta_template,
        query_per_second=16,
        max_out_len=100,
        max_seq_len=2048,
        batch_size=1,
        temperature=1,
    )
]
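The new config defines the models list in the standard OpenCompass pattern, so it can be pulled into a top-level evaluation config via read_base(). A minimal sketch of such a config (the eval_gemini.py file name is hypothetical and not part of this commit; the dataset side is omitted):

# configs/eval_gemini.py  (hypothetical file, for illustration only)
from mmengine.config import read_base

with read_base():
    # Reuse the model list defined by the new gemini_pro.py config above.
    from .models.gemini.gemini_pro import models as gemini_models

models = [*gemini_models]

Such a config would then typically be launched with python run.py configs/eval_gemini.py; the 'your keys' and 'your url' placeholders in the model dict still need to be replaced with real values (or supplied through the environment, as the inline comment notes).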
opencompass/models/__init__.py
@@ -7,6 +7,7 @@ from .base import BaseModel, LMTemplateParser  # noqa
 from .base_api import APITemplateParser, BaseAPIModel  # noqa
 from .bytedance_api import ByteDance  # noqa: F401
 from .claude_api import Claude  # noqa: F401
+from .gemini_api import Gemini, GeminiAllesAPIN  # noqa: F401, F403
 from .glm import GLM130B  # noqa: F401, F403
 from .huggingface import HuggingFace  # noqa: F401, F403
 from .huggingface import HuggingFaceCausalLM  # noqa: F401, F403
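This one-line addition is what lets the gemini_pro.py config above refer to the wrapper simply as type=Gemini. A quick sanity check one could run against an installation that includes this commit (a sketch, not part of the diff):

# Both wrappers are now importable from the package root used by model configs,
# and the AllesAPIN variant subclasses the plain Gemini wrapper (see gemini_api.py below).
from opencompass.models import Gemini, GeminiAllesAPIN

assert issubclass(GeminiAllesAPIN, Gemini)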
opencompass/models/gemini_api.py (new file, mode 100644)
# flake8: noqa: E501
import json
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Union

import requests

from opencompass.utils.prompt import PromptList

from .base_api import BaseAPIModel

PromptType = Union[PromptList, str, float]


class Gemini(BaseAPIModel):
    """Model wrapper around Gemini models.

    Documentation:

    Args:
        path (str): The name of the Gemini model, e.g. `gemini-pro`.
        key (str): Authorization key.
        query_per_second (int): The maximum queries allowed per second
            between two consecutive calls of the API. Defaults to 1.
        max_seq_len (int): Unused here.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.
        retry (int): Number of retries if the API call fails. Defaults to 2.
    """

    def __init__(
        self,
        key: str,
        path: str,
        query_per_second: int = 2,
        max_seq_len: int = 2048,
        meta_template: Optional[Dict] = None,
        retry: int = 2,
        temperature: float = 1.0,
        top_p: float = 0.8,
        top_k: float = 10.0,
    ):
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         query_per_second=query_per_second,
                         meta_template=meta_template,
                         retry=retry)
        self.url = f'https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent?key={key}'
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.headers = {
            'content-type': 'application/json',
        }

    def generate(
        self,
        inputs: List[str or PromptList],
        max_out_len: int = 512,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str or PromptList]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass' API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """
        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._generate, inputs,
                             [max_out_len] * len(inputs)))
        self.flush()
        return results

    def _generate(
        self,
        input: str or PromptList,
        max_out_len: int = 512,
    ) -> str:
        """Generate results given an input.

        Args:
            input (str or PromptList): A string or PromptDict.
                The PromptDict should be organized in OpenCompass' API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            str: The generated string.
        """
        assert isinstance(input, (str, PromptList))
        if isinstance(input, str):
            messages = [{'role': 'user', 'parts': [{'text': input}]}]
        else:
            messages = []
            system_prompt = None
            for item in input:
                if item['role'] == 'SYSTEM':
                    system_prompt = item['prompt']
            for item in input:
                if system_prompt is not None:
                    msg = {
                        'parts': [{
                            'text': system_prompt + '\n' + item['prompt']
                        }]
                    }
                else:
                    msg = {'parts': [{'text': item['prompt']}]}
                if item['role'] == 'HUMAN':
                    msg['role'] = 'user'
                    messages.append(msg)
                elif item['role'] == 'BOT':
                    msg['role'] = 'model'
                    messages.append(msg)
                elif item['role'] == 'SYSTEM':
                    pass
            # The model may respond to user and system roles
            # when an agent is involved.
            assert msg['role'] in ['user', 'system']

        data = {
            'model': self.path,
            'contents': messages,
            'safetySettings': [
                {
                    'category': 'HARM_CATEGORY_DANGEROUS_CONTENT',
                    'threshold': 'BLOCK_NONE'
                },
                {
                    'category': 'HARM_CATEGORY_HATE_SPEECH',
                    'threshold': 'BLOCK_NONE'
                },
                {
                    'category': 'HARM_CATEGORY_HARASSMENT',
                    'threshold': 'BLOCK_NONE'
                },
                {
                    'category': 'HARM_CATEGORY_DANGEROUS_CONTENT',
                    'threshold': 'BLOCK_NONE'
                },
            ],
            'generationConfig': {
                'candidate_count': 1,
                'temperature': self.temperature,
                'maxOutputTokens': 2048,
                'topP': self.top_p,
                'topK': self.top_k
            }
        }

        for _ in range(self.retry):
            self.wait()
            raw_response = requests.post(self.url,
                                         headers=self.headers,
                                         data=json.dumps(data))
            try:
                response = raw_response.json()
            except requests.JSONDecodeError:
                self.logger.error('JsonDecode error, got',
                                  str(raw_response.content))
                time.sleep(1)
                continue
            if raw_response.status_code == 200 and response['msg'] == 'ok':
                body = response['body']
                if 'candidates' not in body:
                    self.logger.error(response)
                else:
                    if 'content' not in body['candidates'][0]:
                        return "Due to Google's restrictive policies, I am unable to respond to this question."
                    else:
                        return body['candidates'][0]['content']['parts'][0][
                            'text'].strip()
            self.logger.error(response['msg'])
            self.logger.error(response)
            time.sleep(1)

        raise RuntimeError('API call failed.')


class GeminiAllesAPIN(Gemini):
    """Model wrapper around Gemini models.

    Documentation:

    Args:
        path (str): The name of the Gemini model, e.g. `gemini-pro`.
        key (str): Authorization key.
        query_per_second (int): The maximum queries allowed per second
            between two consecutive calls of the API. Defaults to 1.
        max_seq_len (int): Unused here.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.
        retry (int): Number of retries if the API call fails. Defaults to 2.
    """

    def __init__(
        self,
        path: str,
        key: str,
        url: str,
        query_per_second: int = 2,
        max_seq_len: int = 2048,
        meta_template: Optional[Dict] = None,
        retry: int = 2,
        temperature: float = 1.0,
        top_p: float = 0.8,
        top_k: float = 10.0,
    ):
        super().__init__(key=key,
                         path=path,
                         max_seq_len=max_seq_len,
                         query_per_second=query_per_second,
                         meta_template=meta_template,
                         retry=retry)
        # Replace the url and headers with the AllesAPIN endpoint and token.
        self.url = url
        self.headers = {
            'alles-apin-token': key,
            'content-type': 'application/json',
        }

    def generate(
        self,
        inputs: List[str or PromptList],
        max_out_len: int = 512,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str or PromptList]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass' API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """
        return super().generate(inputs, max_out_len)
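The heart of _generate is the conversion from OpenCompass' HUMAN/BOT/SYSTEM prompt items into Gemini's contents structure: the SYSTEM text is folded into every turn, HUMAN maps to user, BOT maps to model, and SYSTEM items themselves are dropped. A minimal sketch of the expected result for a typical prompt (a hand-written expectation for illustration, not output captured from the code):

# An OpenCompass-style PromptList item list ('role' / 'prompt' dicts).
prompt = [
    {'role': 'SYSTEM', 'prompt': 'You are a terse assistant.'},
    {'role': 'HUMAN', 'prompt': 'Name one prime number.'},
]

# What the conversion in _generate is expected to build and send as 'contents':
expected_contents = [
    {
        'role': 'user',
        'parts': [{'text': 'You are a terse assistant.\nName one prime number.'}],
    },
]

The request body then wraps this list together with the safetySettings and generationConfig shown above before it is POSTed to the Google endpoint (or, for GeminiAllesAPIN, to the configured AllesAPIN url).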