Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
opencompass
Commits
46c96457
Unverified
Commit
46c96457
authored
Jul 28, 2023
by
Haodong Duan
Committed by
GitHub
Jul 28, 2023
Browse files
[Feature] Allow explicitly setting the temperature for API model (#121)
* allow explicitly setting the temperature * update
parent
80ce18f8
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
19 additions
and
11 deletions
+19
-11
opencompass/models/openai_api.py
opencompass/models/openai_api.py
+19
-11
No files found.
opencompass/models/openai_api.py
View file @
46c96457
...
@@ -12,6 +12,7 @@ from opencompass.utils.prompt import PromptList
...
@@ -12,6 +12,7 @@ from opencompass.utils.prompt import PromptList
from
.base_api
import
BaseAPIModel
from
.base_api
import
BaseAPIModel
PromptType
=
Union
[
PromptList
,
str
]
PromptType
=
Union
[
PromptList
,
str
]
OPENAI_API_BASE
=
'https://api.openai.com/v1/chat/completions'
@
MODELS
.
register_module
()
@
MODELS
.
register_module
()
...
@@ -40,21 +41,24 @@ class OpenAI(BaseAPIModel):
...
@@ -40,21 +41,24 @@ class OpenAI(BaseAPIModel):
wrapping of any meta instructions.
wrapping of any meta instructions.
openai_api_base (str): The base url of OpenAI's API. Defaults to
openai_api_base (str): The base url of OpenAI's API. Defaults to
'https://api.openai.com/v1/chat/completions'.
'https://api.openai.com/v1/chat/completions'.
temperature (float, optional): What sampling temperature to use.
If not None, will override the temperature in the `generate()`
call. Defaults to None.
"""
"""
is_api
:
bool
=
True
is_api
:
bool
=
True
def
__init__
(
def
__init__
(
self
,
self
,
path
:
str
,
path
:
str
,
max_seq_len
:
int
=
2048
,
max_seq_len
:
int
=
2048
,
query_per_second
:
int
=
1
,
query_per_second
:
int
=
1
,
retry
:
int
=
2
,
retry
:
int
=
2
,
key
:
Union
[
str
,
List
[
str
]]
=
'ENV'
,
key
:
Union
[
str
,
List
[
str
]]
=
'ENV'
,
org
:
Optional
[
Union
[
str
,
List
[
str
]]
]
=
None
,
org
:
Optional
[
Union
[
str
,
List
[
str
]]
]
=
None
,
meta_template
:
Optional
[
Dict
]
=
None
,
meta_template
:
Optional
[
Dict
]
=
None
,
openai_api_base
:
str
=
OPENAI_API_BASE
,
openai_api_base
:
str
=
'https://api.openai.com/v1/chat/completions'
temperature
:
Optional
[
float
]
=
None
):
):
# noqa
super
().
__init__
(
path
=
path
,
super
().
__init__
(
path
=
path
,
max_seq_len
=
max_seq_len
,
max_seq_len
=
max_seq_len
,
meta_template
=
meta_template
,
meta_template
=
meta_template
,
...
@@ -62,6 +66,7 @@ class OpenAI(BaseAPIModel):
...
@@ -62,6 +66,7 @@ class OpenAI(BaseAPIModel):
retry
=
retry
)
retry
=
retry
)
import
tiktoken
import
tiktoken
self
.
tiktoken
=
tiktoken
self
.
tiktoken
=
tiktoken
self
.
temperature
=
temperature
if
isinstance
(
key
,
str
):
if
isinstance
(
key
,
str
):
self
.
keys
=
[
os
.
getenv
(
'OPENAI_API_KEY'
)
if
key
==
'ENV'
else
key
]
self
.
keys
=
[
os
.
getenv
(
'OPENAI_API_KEY'
)
if
key
==
'ENV'
else
key
]
...
@@ -96,6 +101,9 @@ class OpenAI(BaseAPIModel):
...
@@ -96,6 +101,9 @@ class OpenAI(BaseAPIModel):
Returns:
Returns:
List[str]: A list of generated strings.
List[str]: A list of generated strings.
"""
"""
if
self
.
temperature
is
not
None
:
temperature
=
self
.
temperature
with
ThreadPoolExecutor
()
as
executor
:
with
ThreadPoolExecutor
()
as
executor
:
results
=
list
(
results
=
list
(
executor
.
map
(
self
.
_generate
,
inputs
,
executor
.
map
(
self
.
_generate
,
inputs
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment