OpenDAS / opencompass / Commits / 2cee870e
"...composable_kernel-1.git" did not exist on "1a0cd5d160dfbe107a454f975a26599fc6daddd4"
Commit 2cee870e authored May 28, 2024 by huangwb

add tgi eval support

parent 2337da18
Showing 10 changed files with 688 additions and 0 deletions (+688 -0)
configs/tgi/eval_llama2_13b_tgi.py        +33  -0
configs/tgi/eval_llama2_7b_chat_tgi.py    +27  -0
configs/tgi/eval_llama2_7b_tgi.py         +32  -0
configs/tgi/eval_llama3_8b_tgi.py         +33  -0
configs/tgi/eval_qwen1.5_14b_chat_tgi.py  +27  -0
configs/tgi/eval_qwen1.5_32b_chat_tgi.py  +27  -0
configs/tgi/eval_qwen1.5_7b_chat_tgi.py   +27  -0
opencompass/models/__init__.py            +2   -0
opencompass/models/tgi_base_api.py        +191 -0
opencompass/models/tgi_chat_api.py        +289 -0
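The new configs fall into two groups: the *_tgi.py base configs point TGIBASEAPI at TGI's native /generate route on port 3001, while the *_chat_tgi.py configs point TGICHATAPI at an OpenAI-compatible /v1/chat/completions route (port 3000 by default). The request shapes below mirror what the two wrappers in this commit actually post; it is a minimal sketch, assuming TGI servers are already listening on those ports, and the prompts are placeholders.

import json

import requests

# Completion-style request, as posted by TGIBASEAPI (see tgi_base_api.py below).
base_payload = dict(
    inputs='The capital of France is',
    parameters=dict(do_sample=False, max_new_tokens=100),
)
r = requests.post('http://localhost:3001/generate',
                  headers={'content-type': 'application/json'},
                  data=json.dumps(base_payload))
print(r.json()['generated_text'])

# Chat request, as posted by TGICHATAPI (see tgi_chat_api.py below).
chat_payload = dict(
    model='tgi',
    messages=[{'role': 'user', 'content': 'Say hello.'}],
    max_tokens=64,
    temperature=0.7,
)
r = requests.post('http://localhost:3000/v1/chat/completions',
                  headers={'content-type': 'application/json'},
                  data=json.dumps(chat_payload))
print(r.json()['choices'][0]['message']['content'])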
configs/tgi/eval_llama2_13b_tgi.py
0 → 100644 · View file @ 2cee870e

from mmengine.config import read_base

with read_base():
    from ..datasets.ARC_c.ARC_c_gen_1e0de5 import ARC_c_datasets
    from ..datasets.ARC_e.ARC_e_gen_1e0de5 import ARC_e_datasets
    # from .datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
    from ..summarizers.example import summarizer

datasets = sum([v for k, v in locals().items() if k.endswith("_datasets") or k == 'datasets'], [])

work_dir = './outputs/llama2_13b/'

from opencompass.models import TGIBASEAPI

models = [
    dict(
        abbr='llama2_13b_tgi',
        path='/models/llama-2-13B',
        type=TGIBASEAPI,
        url='http://localhost:3001/generate',
        meta_template=None,
        batch_size=32,
        rate_per_worker=32,
        retry=4,
        generation_kwargs=dict(
            do_sample=False,
            ignore_eos=False,
            max_new_tokens=100,
            temperature=1,
            top_k=1,
            top_n=0.8,
        ),
    ),
]
\ No newline at end of file
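This config is consumed like any other OpenCompass config, normally through the framework's run.py entry point. As a minimal sketch of inspecting what it resolves to, assuming opencompass and mmengine are installed and the file sits under configs/tgi/ next to the stock dataset configs (the expected values in the comments are illustrative):

from mmengine.config import Config

# Load the config the same way OpenCompass' runner would.
cfg = Config.fromfile('configs/tgi/eval_llama2_13b_tgi.py')

print(len(cfg.datasets))        # number of resolved ARC-c / ARC-e entries
print(cfg.models[0]['abbr'])    # 'llama2_13b_tgi'
print(cfg.models[0]['url'])     # 'http://localhost:3001/generate'
print(cfg.work_dir)             # './outputs/llama2_13b/'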
configs/tgi/eval_llama2_7b_chat_tgi.py
0 → 100644 · View file @ 2cee870e

from mmengine.config import read_base

with read_base():
    from ..datasets.ARC_c.ARC_c_gen_1e0de5 import ARC_c_datasets
    from ..datasets.ARC_e.ARC_e_gen_1e0de5 import ARC_e_datasets
    from ..datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
    from ..summarizers.example import summarizer

datasets = sum([v for k, v in locals().items() if k.endswith("_datasets") or k == 'datasets'], [])

work_dir = './outputs/llama2_7b_chat/'

from opencompass.models import TGICHATAPI

api_meta_template = dict(
    round=[
        dict(role='HUMAN', api_role='HUMAN'),
        dict(role='BOT', api_role='BOT', generate=True),
    ],
)

models = [
    dict(
        abbr='llama2_7b_chat',
        type=TGICHATAPI,
        path='/data/models/Llama-2-7b-chat-hf',
        meta_template=api_meta_template,
        query_per_second=1,
        max_out_len=2048,
        max_seq_len=4096,
        batch_size=8,
    ),
]
\ No newline at end of file
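The api_meta_template above only tags turns as HUMAN or BOT; the translation into OpenAI-style chat messages happens inside TGICHATAPI._generate (added in opencompass/models/tgi_chat_api.py further down in this commit). A standalone sketch of that mapping, with a made-up prompt list for illustration:

# Role mapping applied by TGICHATAPI._generate; the prompts here are invented.
prompt_list = [
    {'role': 'SYSTEM', 'prompt': 'You are a helpful assistant.'},
    {'role': 'HUMAN', 'prompt': 'Is the answer A or B?'},
]

role_map = {'HUMAN': 'user', 'BOT': 'assistant', 'SYSTEM': 'system'}
messages = [{'role': role_map[item['role']], 'content': item['prompt']}
            for item in prompt_list]

# `messages` is what ends up in the `messages` field of the
# /v1/chat/completions request.
print(messages)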
configs/tgi/eval_llama2_7b_tgi.py
0 → 100644 · View file @ 2cee870e

from mmengine.config import read_base

with read_base():
    from ..datasets.ARC_c.ARC_c_gen_1e0de5 import ARC_c_datasets
    from ..datasets.ARC_e.ARC_e_gen_1e0de5 import ARC_e_datasets
    # from .datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
    from ..summarizers.example import summarizer

datasets = sum([v for k, v in locals().items() if k.endswith("_datasets") or k == 'datasets'], [])

work_dir = './outputs/llama2_7b/'

from opencompass.models import TGIBASEAPI

models = [
    dict(
        abbr='llama2_7b_tgi',
        type=TGIBASEAPI,
        url='http://localhost:3001/generate',
        meta_template=None,
        batch_size=32,
        rate_per_worker=32,
        retry=4,
        generation_kwargs=dict(
            do_sample=False,
            ignore_eos=False,
            max_new_tokens=100,
            temperature=1,
            top_k=1,
            top_n=0.8,
        ),
    ),
]
\ No newline at end of file
configs/tgi/eval_llama3_8b_tgi.py
0 → 100644 · View file @ 2cee870e

from mmengine.config import read_base

with read_base():
    # from ..datasets.ARC_c.ARC_c_gen_1e0de5 import ARC_c_datasets
    # from ..datasets.ARC_e.ARC_e_gen_1e0de5 import ARC_e_datasets
    from ..datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
    from ..summarizers.example import summarizer

datasets = sum([v for k, v in locals().items() if k.endswith("_datasets") or k == 'datasets'], [])

work_dir = './outputs/Meta-Llama-3-8B/'

from opencompass.models import TGIBASEAPI

models = [
    dict(
        abbr='llama3_8b_tgi',
        path='/data/models/Meta-Llama-3-8B',
        type=TGIBASEAPI,
        url='http://localhost:3001/generate',
        meta_template=None,
        batch_size=32,
        rate_per_worker=32,
        retry=4,
        generation_kwargs=dict(
            do_sample=False,
            ignore_eos=False,
            max_new_tokens=100,
            temperature=1,
            top_k=1,
            top_n=0.8,
        ),
    ),
]
\ No newline at end of file
configs/tgi/eval_qwen1.5_14b_chat_tgi.py
0 → 100644 · View file @ 2cee870e

from mmengine.config import read_base

with read_base():
    from ..datasets.ARC_c.ARC_c_gen_1e0de5 import ARC_c_datasets
    from ..datasets.ARC_e.ARC_e_gen_1e0de5 import ARC_e_datasets
    from ..datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
    from ..summarizers.example import summarizer

datasets = sum([v for k, v in locals().items() if k.endswith("_datasets") or k == 'datasets'], [])

work_dir = './outputs/qwen1.5_14b_chat/'

from opencompass.models import TGICHATAPI

api_meta_template = dict(
    round=[
        dict(role='HUMAN', api_role='HUMAN'),
        dict(role='BOT', api_role='BOT', generate=True),
    ],
)

models = [
    dict(
        abbr='qwen1.5_14b_chat',
        type=TGICHATAPI,
        path='/models/Qwen1.5-14B-Chat',
        meta_template=api_meta_template,
        query_per_second=1,
        max_out_len=2048,
        max_seq_len=4096,
        batch_size=8,
    ),
]
\ No newline at end of file
configs/tgi/eval_qwen1.5_32b_chat_tgi.py
0 → 100644 · View file @ 2cee870e

from mmengine.config import read_base

with read_base():
    from ..datasets.ARC_c.ARC_c_gen_1e0de5 import ARC_c_datasets
    from ..datasets.ARC_e.ARC_e_gen_1e0de5 import ARC_e_datasets
    from ..datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
    from ..summarizers.example import summarizer

datasets = sum([v for k, v in locals().items() if k.endswith("_datasets") or k == 'datasets'], [])

work_dir = './outputs/qwen1.5_32b_chat/'

from opencompass.models import TGICHATAPI

api_meta_template = dict(
    round=[
        dict(role='HUMAN', api_role='HUMAN'),
        dict(role='BOT', api_role='BOT', generate=True),
    ],
)

models = [
    dict(
        abbr='qwen1.5_32b_chat',
        type=TGICHATAPI,
        path='/models/Qwen1.5-32B-Chat',
        meta_template=api_meta_template,
        query_per_second=1,
        max_out_len=2048,
        max_seq_len=4096,
        batch_size=8,
    ),
]
\ No newline at end of file
configs/tgi/eval_qwen1.5_7b_chat_tgi.py
0 → 100644 · View file @ 2cee870e

from mmengine.config import read_base

with read_base():
    from ..datasets.ARC_c.ARC_c_gen_1e0de5 import ARC_c_datasets
    from ..datasets.ARC_e.ARC_e_gen_1e0de5 import ARC_e_datasets
    from ..datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
    from ..summarizers.example import summarizer

datasets = sum([v for k, v in locals().items() if k.endswith("_datasets") or k == 'datasets'], [])

work_dir = './outputs/qwen1.5_7b_chat/'

from opencompass.models import TGICHATAPI

api_meta_template = dict(
    round=[
        dict(role='HUMAN', api_role='HUMAN'),
        dict(role='BOT', api_role='BOT', generate=True),
    ],
)

models = [
    dict(
        abbr='qwen1.5_7b_chat',
        type=TGICHATAPI,
        path='/models/Qwen1.5-7B-Chat',
        meta_template=api_meta_template,
        query_per_second=1,
        max_out_len=2048,
        max_seq_len=4096,
        batch_size=8,
    ),
]
\ No newline at end of file
opencompass/models/__init__.py
View file @ 2cee870e

@@ -33,6 +33,8 @@ from .turbomind import TurboMindModel  # noqa: F401
 from .turbomind_tis import TurboMindTisModel  # noqa: F401
 from .unigpt_api import UniGPT  # noqa: F401
 from .vllm import VLLM  # noqa: F401
+from .tgi_chat_api import TGICHATAPI
+from .tgi_base_api import TGIBASEAPI
 from .xunfei_api import XunFei  # noqa: F401
 from .yayi_api import Yayi  # noqa: F401
 from .zhipuai_api import ZhiPuAI  # noqa: F401
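With these two exports in place, the configs above can reference the wrappers by class. A quick smoke test, assuming the package is installed from this branch:

from opencompass.models import TGIBASEAPI, TGICHATAPI

# Both wrappers are flagged as API-backed models.
print(TGIBASEAPI.is_api, TGICHATAPI.is_api)  # True True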
opencompass/models/tgi_base_api.py
0 → 100644 · View file @ 2cee870e

import json
import os
import re
import time
from concurrent.futures import ThreadPoolExecutor
from threading import Lock
from typing import Dict, List, Optional, Union

import numpy as np
import requests

from opencompass.registry import MODELS
from opencompass.utils.logging import get_logger

from .base import BaseModel
from .base_api import TokenBucket


class TGIBASEAPI(BaseModel):
    """Model wrapper that queries the /generate endpoint of a
    text-generation-inference (TGI) server."""

    is_api: bool = True

    def __init__(
        self,
        path: str = 'LightllmAPI',
        url: str = 'http://localhost:3001/generate',
        meta_template: Optional[Dict] = None,
        rate_per_worker: int = 2,
        retry: int = 2,
        generation_kwargs: Optional[Dict] = dict(),
    ):
        super().__init__(path=path,
                         meta_template=meta_template,
                         generation_kwargs=generation_kwargs)
        self.logger = get_logger()
        self.url = url
        self.retry = retry
        self.generation_kwargs = generation_kwargs
        self.max_out_len = self.generation_kwargs.get('max_new_tokens', 1024)
        self.meta_template = meta_template
        self.token_bucket = TokenBucket(rate_per_worker, False)

    def generate(self, inputs: List[str], max_out_len: int,
                 **kwargs) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """
        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._generate, inputs,
                             [self.max_out_len] * len(inputs)))
        return results

    def _generate(self, input: str, max_out_len: int) -> str:
        max_num_retries = 0
        while max_num_retries < self.retry:
            self.wait()
            header = {'content-type': 'application/json'}
            try:
                data = dict(inputs=input, parameters=self.generation_kwargs)
                raw_response = requests.post(self.url,
                                             headers=header,
                                             data=json.dumps(data))
            except requests.ConnectionError:
                self.logger.error('Got connection error, retrying...')
                continue
            try:
                response = raw_response.json()
                generated_text = response['generated_text']
                if isinstance(generated_text, list):
                    generated_text = generated_text[0]
                return generated_text
            except requests.JSONDecodeError:
                self.logger.error('JsonDecode error, got',
                                  str(raw_response.content))
            except KeyError:
                self.logger.error(f'KeyError. Response: {str(response)}')
            max_num_retries += 1

        raise RuntimeError('Calling LightllmAPI failed after retrying for '
                           f'{max_num_retries} times. Check the logs for '
                           'details.')

    def get_ppl(self, inputs: List[str], max_out_len: int,
                **kwargs) -> List[float]:
        """Compute perplexity-style scores for a list of inputs.

        Args:
            inputs (List[str]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[float]: The average negative log-likelihood of each prompt.
        """
        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._get_ppl, inputs,
                             [self.max_out_len] * len(inputs)))
        return np.array(results)

    def _get_ppl(self, input: str, max_out_len: int) -> float:
        max_num_retries = 0
        if max_out_len is None:
            max_out_len = 1
        while max_num_retries < self.retry:
            self.wait()
            header = {'content-type': 'application/json'}
            try:
                data = dict(inputs=input, parameters=self.generation_kwargs)
                raw_response = requests.post(self.url,
                                             headers=header,
                                             data=json.dumps(data))
            except requests.ConnectionError:
                self.logger.error('Got connection error, retrying...')
                continue
            try:
                response = raw_response.json()
                assert ('prompt_token_ids' in response
                        and 'prompt_logprobs' in response), \
                    f'prompt_token_ids and prompt_logprobs \
                    must be in the output. \
                    Please consider adding \
                    --return_all_prompt_logprobs argument \
                    when starting lightllm service. Response: {str(response)}'
                prompt_token_ids = response['prompt_token_ids'][1:]
                prompt_logprobs = [
                    item[1] for item in response['prompt_logprobs']
                ]
                logprobs = [
                    item[str(token_id)] for token_id, item in zip(
                        prompt_token_ids, prompt_logprobs)
                ]
                if len(logprobs) == 0:
                    return 0.0
                # Average negative log-likelihood over the prompt tokens.
                ce_loss = -sum(logprobs) / len(logprobs)
                return ce_loss
            except requests.JSONDecodeError:
                self.logger.error('JsonDecode error, got',
                                  str(raw_response.content))
            max_num_retries += 1
        raise RuntimeError('Calling LightllmAPI failed after retrying for '
                           f'{max_num_retries} times. Check the logs for '
                           'details.')

    def wait(self):
        """Wait till the next query can be sent.

        Applicable in both single-thread and multi-thread environments.
        """
        return self.token_bucket.get_token()

    def get_token_len(self, prompt: str) -> int:
        """Get lengths of the tokenized string. Only English and Chinese
        characters are counted for now. Users are encouraged to override this
        method if more accurate length is needed.

        Args:
            prompt (str): Input string.

        Returns:
            int: Length of the input tokens
        """
        english_parts = re.findall(r'[A-Za-z0-9]+', prompt)
        chinese_parts = re.findall(r'[\u4e00-\u9FFF]+', prompt)

        # Count English words
        english_count = sum(len(part.split()) for part in english_parts)

        # Count Chinese words
        chinese_count = sum(len(part) for part in chinese_parts)

        return english_count + chinese_count
\ No newline at end of file
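Outside a full evaluation run, the wrapper can be exercised on its own. A minimal sketch, assuming a TGI server is already serving a model on localhost:3001 (the prompt is a placeholder):

from opencompass.models import TGIBASEAPI

model = TGIBASEAPI(
    url='http://localhost:3001/generate',
    rate_per_worker=2,
    retry=2,
    generation_kwargs=dict(do_sample=False, max_new_tokens=64),
)

# generate() fans the inputs out over a thread pool and returns one string each.
outputs = model.generate(['The capital of France is'], max_out_len=64)
print(outputs[0])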
opencompass/models/tgi_chat_api.py
0 → 100644 · View file @ 2cee870e

import json
import os
import re
import time
from concurrent.futures import ThreadPoolExecutor
from threading import Lock
from typing import Dict, List, Optional, Union

import jieba
import requests

from opencompass.registry import MODELS
from opencompass.utils.prompt import PromptList

from .base_api import BaseAPIModel
from transformers import AutoConfig, AutoTokenizer
from transformers.models.qwen2 import Qwen2Tokenizer
from transformers.models.llama import LlamaTokenizer

PromptType = Union[PromptList, str]
OPENAI_API_BASE = 'http://localhost:3000/v1/chat/completions'


class TGICHATAPI(BaseAPIModel):
    """Model wrapper around the OpenAI-compatible chat completions endpoint
    exposed by a text-generation-inference (TGI) server.

    Args:
        path (str): Path or name of the served model. Also used to load a
            local tokenizer for token counting.
        max_seq_len (int): The maximum allowed sequence length of a model.
            Note that the length of prompt + generated tokens shall not exceed
            this value. Defaults to 4096.
        query_per_second (int): The maximum queries allowed per second
            between two consecutive calls of the API. Defaults to 1.
        retry (int): Number of retries if the API call fails. Defaults to 2.
        org (str or List[str], optional): OpenAI organization(s). If not
            specified, OpenAI uses the default organization bound to each API
            key. If specified, the orgs will be posted with each request in
            round-robin manner. Defaults to None.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.
        openai_api_base (str): The base url of the OpenAI-compatible API.
            Defaults to 'http://localhost:3000/v1/chat/completions'.
        mode (str, optional): The method of input truncation when input length
            exceeds max_seq_len. 'front', 'mid' and 'rear' represents the part
            of input to truncate. Defaults to 'none'.
        temperature (float, optional): What sampling temperature to use.
            If not None, will override the temperature in the `generate()`
            call. Defaults to None.
    """

    is_api: bool = True

    def __init__(self,
                 path: str,
                 max_seq_len: int = 4096,
                 query_per_second: int = 1,
                 rpm_verbose: bool = False,
                 retry: int = 2,
                 org: Optional[Union[str, List[str]]] = None,
                 meta_template: Optional[Dict] = None,
                 openai_api_base: str = OPENAI_API_BASE,
                 mode: str = 'none',
                 logprobs: Optional[bool] = False,
                 top_logprobs: Optional[int] = None,
                 temperature: Optional[float] = None):
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         meta_template=meta_template,
                         query_per_second=query_per_second,
                         rpm_verbose=rpm_verbose,
                         retry=retry)
        self.temperature = temperature
        assert mode in ['none', 'front', 'mid', 'rear']
        self.mode = mode
        self.logprobs = logprobs
        self.top_logprobs = top_logprobs
        self.url = openai_api_base
        self.path = path
        # self.tokenizer = AutoTokenizer.from_pretrained(path)
        # self.tokenizer = LlamaTokenizer.from_pretrained(
        #     path, padding_side="left", truncation_side="left")
        self.tokenizer = AutoTokenizer.from_pretrained(
            path,
            padding_side="left",
            truncation_side="left",
        )
        # self.tokenizer = Qwen2Tokenizer.from_pretrained(
        #     path, padding_side="left", truncation_side="left")

    def generate(self,
                 inputs: List[PromptType],
                 max_out_len: int = 512,
                 temperature: float = 0.7,
                 **kwargs) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[PromptType]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
            temperature (float): What sampling temperature to use,
                between 0 and 2. Higher values like 0.8 will make the output
                more random, while lower values like 0.2 will make it more
                focused and deterministic. Defaults to 0.7.

        Returns:
            List[str]: A list of generated strings.
        """
        if self.temperature is not None:
            temperature = self.temperature

        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._generate, inputs,
                             [max_out_len] * len(inputs),
                             [temperature] * len(inputs)))
        return results

    def _generate(self, input: PromptType, max_out_len: int,
                  temperature: float) -> str:
        """Generate a result for a single input.

        Args:
            input (PromptType): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
            temperature (float): What sampling temperature to use,
                between 0 and 2. Higher values like 0.8 will make the output
                more random, while lower values like 0.2 will make it more
                focused and deterministic.

        Returns:
            str: The generated string.
        """
        assert isinstance(input, (str, PromptList))

        # max num token for gpt-3.5-turbo is 4097
        context_window = 4096
        if '32k' in self.path:
            context_window = 32768
        elif '16k' in self.path:
            context_window = 16384
        elif 'gpt-4' in self.path:
            context_window = 8192

        # will leave 100 tokens as prompt buffer, triggered if input is str
        if isinstance(input, str) and self.mode != 'none':
            context_window = self.max_seq_len
            input = self.bin_trim(input, context_window - 100 - max_out_len)

        if isinstance(input, str):
            messages = [{'role': 'user', 'content': input}]
        else:
            messages = []
            for item in input:
                msg = {'content': item['prompt']}
                if item['role'] == 'HUMAN':
                    msg['role'] = 'user'
                elif item['role'] == 'BOT':
                    msg['role'] = 'assistant'
                elif item['role'] == 'SYSTEM':
                    msg['role'] = 'system'
                messages.append(msg)

        # Hold out 100 tokens due to potential errors in tiktoken calculation
        max_out_len = min(
            max_out_len,
            context_window - self.get_token_len(str(input)) - 100)
        if max_out_len <= 0:
            return ''

        max_num_retries = 0
        while max_num_retries < self.retry:
            self.wait()

            header = {
                'content-type': 'application/json',
            }

            try:
                data = dict(
                    model='tgi',
                    messages=messages,
                    max_tokens=max_out_len,
                    n=1,
                    logprobs=self.logprobs,
                    top_logprobs=self.top_logprobs,
                    stop=None,
                    temperature=temperature,
                )
                raw_response = requests.post(self.url,
                                             headers=header,
                                             data=json.dumps(data))
            except requests.ConnectionError:
                self.logger.error('Got connection error, retrying...')
                continue
            try:
                response = raw_response.json()
            except requests.JSONDecodeError:
                self.logger.error('JsonDecode error, got',
                                  str(raw_response.content))
                continue
            self.logger.debug(str(response))
            try:
                if self.logprobs:
                    return response['choices']
                else:
                    return response['choices'][0]['message']['content'].strip()
            except KeyError:
                if 'error' in response:
                    if response['error']['code'] == 'rate_limit_exceeded':
                        time.sleep(10)
                        self.logger.warn('Rate limit exceeded, retrying...')
                        continue
                    # elif response['error']['code'] == 'insufficient_quota':
                    #     self.invalid_keys.add(key)
                    #     self.logger.warn(f'insufficient_quota key: {key}')
                    #     continue
                    elif response['error']['code'] == 'invalid_prompt':
                        self.logger.warn('Invalid prompt:', str(input))
                        return ''

                    self.logger.error('Find error message in response: ',
                                      str(response['error']))
            max_num_retries += 1

        raise RuntimeError('Calling OpenAI failed after retrying for '
                           f'{max_num_retries} times. Check the logs for '
                           'details.')

    def get_token_len(self, prompt: str) -> int:
        """Get lengths of the tokenized string. Only English and Chinese
        characters are counted for now. Users are encouraged to override this
        method if more accurate length is needed.

        Args:
            prompt (str): Input string.

        Returns:
            int: Length of the input tokens
        """
        # enc = self.tiktoken.encoding_for_model(self.path)
        return len(self.tokenizer.encode(prompt))

    def bin_trim(self, prompt: str, num_token: int) -> str:
        """Get a suffix of prompt which is no longer than num_token tokens.

        Args:
            prompt (str): Input string.
            num_token (int): The upper bound of token numbers.

        Returns:
            str: The trimmed prompt.
        """
        token_len = self.get_token_len(prompt)
        if token_len <= num_token:
            return prompt
        pattern = re.compile(r'[\u4e00-\u9fa5]')
        if pattern.search(prompt):
            words = list(jieba.cut(prompt, cut_all=False))
            sep = ''
        else:
            words = prompt.split(' ')
            sep = ' '

        l, r = 1, len(words)
        while l + 2 < r:
            mid = (l + r) // 2
            if self.mode == 'front':
                cur_prompt = sep.join(words[-mid:])
            elif self.mode == 'mid':
                cur_prompt = sep.join(words[:mid]) + sep.join(words[-mid:])
            elif self.mode == 'rear':
                cur_prompt = sep.join(words[:mid])

            if self.get_token_len(cur_prompt) <= num_token:
                l = mid  # noqa: E741
            else:
                r = mid

        if self.mode == 'front':
            prompt = sep.join(words[-l:])
        elif self.mode == 'mid':
            prompt = sep.join(words[:l]) + sep.join(words[-l:])
        elif self.mode == 'rear':
            prompt = sep.join(words[:l])
        return prompt
\ No newline at end of file
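The chat wrapper can be exercised the same way. A minimal sketch, assuming a TGI instance exposes the OpenAI-compatible /v1/chat/completions route on localhost:3000 and that tokenizer files exist locally at the given path (the constructor loads them with AutoTokenizer.from_pretrained); the prompt is a placeholder:

from opencompass.models import TGICHATAPI

api_meta_template = dict(round=[
    dict(role='HUMAN', api_role='HUMAN'),
    dict(role='BOT', api_role='BOT', generate=True),
])

model = TGICHATAPI(
    path='/data/models/Llama-2-7b-chat-hf',  # local tokenizer files
    openai_api_base='http://localhost:3000/v1/chat/completions',
    meta_template=api_meta_template,
    max_seq_len=4096,
)

# A plain string is wrapped into a single user message before being posted.
outputs = model.generate(['Summarize what text-generation-inference does.'],
                         max_out_len=64)
print(outputs[0])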