OpenDAS / opencompass · Commits

Commit 2cee870e, authored May 28, 2024 by huangwb

    add tgi eval support

Parent: 2337da18

Showing 10 changed files with 688 additions and 0 deletions (+688 −0)
configs/tgi/eval_llama2_13b_tgi.py          +33  −0
configs/tgi/eval_llama2_7b_chat_tgi.py      +27  −0
configs/tgi/eval_llama2_7b_tgi.py           +32  −0
configs/tgi/eval_llama3_8b_tgi.py           +33  −0
configs/tgi/eval_qwen1.5_14b_chat_tgi.py    +27  −0
configs/tgi/eval_qwen1.5_32b_chat_tgi.py    +27  −0
configs/tgi/eval_qwen1.5_7b_chat_tgi.py     +27  −0
opencompass/models/__init__.py              +2   −0
opencompass/models/tgi_base_api.py          +191 −0
opencompass/models/tgi_chat_api.py          +289 −0
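These configs assume a text-generation-inference (TGI) server is already serving the model at the URL given in each config: the base-model configs post to http://localhost:3001/generate, and the chat wrapper defaults to http://localhost:3000/v1/chat/completions. With a server running, an evaluation can presumably be launched with OpenCompass's standard entry point, e.g. `python run.py configs/tgi/eval_llama2_7b_tgi.py` (the command is shown for illustration only; it is not part of this commit).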
configs/tgi/eval_llama2_13b_tgi.py (new file, mode 100644)
from mmengine.config import read_base

with read_base():
    from ..datasets.ARC_c.ARC_c_gen_1e0de5 import ARC_c_datasets
    from ..datasets.ARC_e.ARC_e_gen_1e0de5 import ARC_e_datasets
    # from .datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
    from ..summarizers.example import summarizer

datasets = sum([v for k, v in locals().items()
                if k.endswith('_datasets') or k == 'datasets'], [])

work_dir = './outputs/llama2_13b/'

from opencompass.models import TGIBASEAPI

models = [
    dict(
        abbr='llama2_13b_tgi',
        path='/models/llama-2-13B',
        type=TGIBASEAPI,
        url='http://localhost:3001/generate',
        meta_template=None,
        batch_size=32,
        rate_per_worker=32,
        retry=4,
        generation_kwargs=dict(
            do_sample=False,
            ignore_eos=False,
            max_new_tokens=100,
            temperature=1,
            top_k=1,
            top_p=0.8,
        ),
    ),
]
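Note: every config in this commit builds its dataset list with the same aggregation idiom, summing all imported *_datasets variables into one flat list. A toy illustration with hypothetical values (not part of the commit):

# Hypothetical stand-ins for the imported dataset lists.
ARC_c_datasets = [dict(abbr='ARC-c')]
ARC_e_datasets = [dict(abbr='ARC-e')]

# Same aggregation as in the configs: flatten every *_datasets list.
datasets = sum([v for k, v in list(locals().items())
                if k.endswith('_datasets')], [])
# datasets == [{'abbr': 'ARC-c'}, {'abbr': 'ARC-e'}]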
configs/tgi/eval_llama2_7b_chat_tgi.py (new file, mode 100644)
from mmengine.config import read_base

with read_base():
    from ..datasets.ARC_c.ARC_c_gen_1e0de5 import ARC_c_datasets
    from ..datasets.ARC_e.ARC_e_gen_1e0de5 import ARC_e_datasets
    from ..datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
    from ..summarizers.example import summarizer

datasets = sum([v for k, v in locals().items()
                if k.endswith('_datasets') or k == 'datasets'], [])

work_dir = './outputs/llama2_7b_chat/'

from opencompass.models import TGICHATAPI

api_meta_template = dict(
    round=[
        dict(role='HUMAN', api_role='HUMAN'),
        dict(role='BOT', api_role='BOT', generate=True),
    ],
)

models = [
    dict(
        abbr='llama2_7b_chat',
        type=TGICHATAPI,
        path='/data/models/Llama-2-7b-chat-hf',
        meta_template=api_meta_template,
        query_per_second=1,
        max_out_len=2048,
        max_seq_len=4096,
        batch_size=8,
    ),
]
configs/tgi/eval_llama2_7b_tgi.py (new file, mode 100644)
from mmengine.config import read_base

with read_base():
    from ..datasets.ARC_c.ARC_c_gen_1e0de5 import ARC_c_datasets
    from ..datasets.ARC_e.ARC_e_gen_1e0de5 import ARC_e_datasets
    # from .datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
    from ..summarizers.example import summarizer

datasets = sum([v for k, v in locals().items()
                if k.endswith('_datasets') or k == 'datasets'], [])

work_dir = './outputs/llama2_7b/'

from opencompass.models import TGIBASEAPI

models = [
    dict(
        abbr='llama2_7b_tgi',
        type=TGIBASEAPI,
        url='http://localhost:3001/generate',
        meta_template=None,
        batch_size=32,
        rate_per_worker=32,
        retry=4,
        generation_kwargs=dict(
            do_sample=False,
            ignore_eos=False,
            max_new_tokens=100,
            temperature=1,
            top_k=1,
            top_p=0.8,
        ),
    ),
]
configs/tgi/eval_llama3_8b_tgi.py (new file, mode 100644)
from mmengine.config import read_base

with read_base():
    # from ..datasets.ARC_c.ARC_c_gen_1e0de5 import ARC_c_datasets
    # from ..datasets.ARC_e.ARC_e_gen_1e0de5 import ARC_e_datasets
    from ..datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
    from ..summarizers.example import summarizer

datasets = sum([v for k, v in locals().items()
                if k.endswith('_datasets') or k == 'datasets'], [])

work_dir = './outputs/Meta-Llama-3-8B/'

from opencompass.models import TGIBASEAPI

models = [
    dict(
        abbr='llama3_8b_tgi',
        path='/data/models/Meta-Llama-3-8B',
        type=TGIBASEAPI,
        url='http://localhost:3001/generate',
        meta_template=None,
        batch_size=32,
        rate_per_worker=32,
        retry=4,
        generation_kwargs=dict(
            do_sample=False,
            ignore_eos=False,
            max_new_tokens=100,
            temperature=1,
            top_k=1,
            top_p=0.8,
        ),
    ),
]
configs/tgi/eval_qwen1.5_14b_chat_tgi.py (new file, mode 100644)
from mmengine.config import read_base

with read_base():
    from ..datasets.ARC_c.ARC_c_gen_1e0de5 import ARC_c_datasets
    from ..datasets.ARC_e.ARC_e_gen_1e0de5 import ARC_e_datasets
    from ..datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
    from ..summarizers.example import summarizer

datasets = sum([v for k, v in locals().items()
                if k.endswith('_datasets') or k == 'datasets'], [])

work_dir = './outputs/qwen1.5_14b_chat/'

from opencompass.models import TGICHATAPI

api_meta_template = dict(
    round=[
        dict(role='HUMAN', api_role='HUMAN'),
        dict(role='BOT', api_role='BOT', generate=True),
    ],
)

models = [
    dict(
        abbr='qwen1.5_14b_chat',
        type=TGICHATAPI,
        path='/models/Qwen1.5-14B-Chat',
        meta_template=api_meta_template,
        query_per_second=1,
        max_out_len=2048,
        max_seq_len=4096,
        batch_size=8,
    ),
]
configs/tgi/eval_qwen1.5_32b_chat_tgi.py (new file, mode 100644)
from mmengine.config import read_base

with read_base():
    from ..datasets.ARC_c.ARC_c_gen_1e0de5 import ARC_c_datasets
    from ..datasets.ARC_e.ARC_e_gen_1e0de5 import ARC_e_datasets
    from ..datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
    from ..summarizers.example import summarizer

datasets = sum([v for k, v in locals().items()
                if k.endswith('_datasets') or k == 'datasets'], [])

work_dir = './outputs/qwen1.5_32b_chat/'

from opencompass.models import TGICHATAPI

api_meta_template = dict(
    round=[
        dict(role='HUMAN', api_role='HUMAN'),
        dict(role='BOT', api_role='BOT', generate=True),
    ],
)

models = [
    dict(
        abbr='qwen1.5_32b_chat',
        type=TGICHATAPI,
        path='/models/Qwen1.5-32B-Chat',
        meta_template=api_meta_template,
        query_per_second=1,
        max_out_len=2048,
        max_seq_len=4096,
        batch_size=8,
    ),
]
configs/tgi/eval_qwen1.5_7b_chat_tgi.py (new file, mode 100644)
from mmengine.config import read_base

with read_base():
    from ..datasets.ARC_c.ARC_c_gen_1e0de5 import ARC_c_datasets
    from ..datasets.ARC_e.ARC_e_gen_1e0de5 import ARC_e_datasets
    from ..datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
    from ..summarizers.example import summarizer

datasets = sum([v for k, v in locals().items()
                if k.endswith('_datasets') or k == 'datasets'], [])

work_dir = './outputs/qwen1.5_7b_chat/'

from opencompass.models import TGICHATAPI

api_meta_template = dict(
    round=[
        dict(role='HUMAN', api_role='HUMAN'),
        dict(role='BOT', api_role='BOT', generate=True),
    ],
)

models = [
    dict(
        abbr='qwen1.5_7b_chat',
        type=TGICHATAPI,
        path='/models/Qwen1.5-7B-Chat',
        meta_template=api_meta_template,
        query_per_second=1,
        max_out_len=2048,
        max_seq_len=4096,
        batch_size=8,
    ),
]
opencompass/models/__init__.py (modified)
@@ -33,6 +33,8 @@ from .turbomind import TurboMindModel  # noqa: F401
 from .turbomind_tis import TurboMindTisModel  # noqa: F401
 from .unigpt_api import UniGPT  # noqa: F401
 from .vllm import VLLM  # noqa: F401
+from .tgi_chat_api import TGICHATAPI
+from .tgi_base_api import TGIBASEAPI
 from .xunfei_api import XunFei  # noqa: F401
 from .yayi_api import Yayi  # noqa: F401
 from .zhipuai_api import ZhiPuAI  # noqa: F401
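With both wrappers exported here, the configs above can import them directly from the package. A minimal check, assuming an environment with this commit installed (purely illustrative):

from opencompass.models import TGIBASEAPI, TGICHATAPI

# Both wrappers declare themselves as API-backed models (is_api = True).
assert TGIBASEAPI.is_api and TGICHATAPI.is_api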
opencompass/models/tgi_base_api.py (new file, mode 100644)
import json
import os
import re
import time
from concurrent.futures import ThreadPoolExecutor
from threading import Lock
from typing import Dict, List, Optional, Union

import numpy as np
import requests

from opencompass.registry import MODELS
from opencompass.utils.logging import get_logger

from .base import BaseModel
from .base_api import TokenBucket


class TGIBASEAPI(BaseModel):

    is_api: bool = True

    def __init__(
        self,
        path: str = 'LightllmAPI',
        url: str = 'http://localhost:3001/generate',
        meta_template: Optional[Dict] = None,
        rate_per_worker: int = 2,
        retry: int = 2,
        generation_kwargs: Optional[Dict] = dict(),
    ):
        super().__init__(path=path,
                         meta_template=meta_template,
                         generation_kwargs=generation_kwargs)
        self.logger = get_logger()
        self.url = url
        self.retry = retry
        self.generation_kwargs = generation_kwargs
        self.max_out_len = self.generation_kwargs.get('max_new_tokens', 1024)
        self.meta_template = meta_template
        self.token_bucket = TokenBucket(rate_per_worker, False)

    def generate(self, inputs: List[str], max_out_len: int,
                 **kwargs) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """
        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._generate, inputs,
                             [self.max_out_len] * len(inputs)))
        return results

    def _generate(self, input: str, max_out_len: int) -> str:
        max_num_retries = 0
        while max_num_retries < self.retry:
            self.wait()
            header = {'content-type': 'application/json'}
            try:
                data = dict(inputs=input, parameters=self.generation_kwargs)
                raw_response = requests.post(self.url,
                                             headers=header,
                                             data=json.dumps(data))
            except requests.ConnectionError:
                self.logger.error('Got connection error, retrying...')
                continue
            try:
                response = raw_response.json()
                generated_text = response['generated_text']
                if isinstance(generated_text, list):
                    generated_text = generated_text[0]
                return generated_text
            except requests.JSONDecodeError:
                self.logger.error('JsonDecode error, got '
                                  + str(raw_response.content))
            except KeyError:
                self.logger.error(f'KeyError. Response: {str(response)}')
            max_num_retries += 1

        raise RuntimeError('Calling the TGI API failed after retrying for '
                           f'{max_num_retries} times. Check the logs for '
                           'details.')

    def get_ppl(self, inputs: List[str], max_out_len: int,
                **kwargs) -> List[float]:
        """Compute perplexity-style losses for a list of inputs.

        Args:
            inputs (List[str]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[float]: A list of per-sample cross-entropy losses.
        """
        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._get_ppl, inputs,
                             [self.max_out_len] * len(inputs)))
        return np.array(results)

    def _get_ppl(self, input: str, max_out_len: int) -> float:
        max_num_retries = 0
        if max_out_len is None:
            max_out_len = 1
        while max_num_retries < self.retry:
            self.wait()
            header = {'content-type': 'application/json'}
            try:
                data = dict(inputs=input, parameters=self.generation_kwargs)
                raw_response = requests.post(self.url,
                                             headers=header,
                                             data=json.dumps(data))
            except requests.ConnectionError:
                self.logger.error('Got connection error, retrying...')
                continue
            try:
                response = raw_response.json()
                assert ('prompt_token_ids' in response
                        and 'prompt_logprobs' in response), (
                    'prompt_token_ids and prompt_logprobs must be in the '
                    'output. Please consider adding the '
                    '--return_all_prompt_logprobs argument when starting '
                    'the service. Response: ' + str(response))
                prompt_token_ids = response['prompt_token_ids'][1:]
                prompt_logprobs = [
                    item[1] for item in response['prompt_logprobs']
                ]
                logprobs = [
                    item[str(token_id)]
                    for token_id, item in zip(prompt_token_ids,
                                              prompt_logprobs)
                ]
                if len(logprobs) == 0:
                    return 0.0
                ce_loss = -sum(logprobs) / len(logprobs)
                return ce_loss
            except requests.JSONDecodeError:
                self.logger.error('JsonDecode error, got '
                                  + str(raw_response.content))
            max_num_retries += 1
        raise RuntimeError('Calling the TGI API failed after retrying for '
                           f'{max_num_retries} times. Check the logs for '
                           'details.')

    def wait(self):
        """Wait till the next query can be sent.

        Applicable in both single-thread and multi-thread environments.
        """
        return self.token_bucket.get_token()

    def get_token_len(self, prompt: str) -> int:
        """Get lengths of the tokenized string. Only English and Chinese
        characters are counted for now. Users are encouraged to override this
        method if more accurate length is needed.

        Args:
            prompt (str): Input string.

        Returns:
            int: Length of the input tokens
        """
        english_parts = re.findall(r'[A-Za-z0-9]+', prompt)
        chinese_parts = re.findall(r'[\u4e00-\u9FFF]+', prompt)

        # Count English words
        english_count = sum(len(part.split()) for part in english_parts)

        # Count Chinese characters
        chinese_count = sum(len(part) for part in chinese_parts)

        return english_count + chinese_count
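For reference, the contract TGIBASEAPI assumes is TGI's /generate endpoint: it posts {"inputs": ..., "parameters": {...}} and reads back generated_text. A minimal standalone sketch under that assumption (server address and prompt are illustrative, not part of the commit):

import json
import requests

# Illustrative payload mirroring what TGIBASEAPI._generate posts.
payload = {
    'inputs': 'The capital of France is',
    'parameters': {'do_sample': False, 'max_new_tokens': 16},
}
resp = requests.post('http://localhost:3001/generate',
                     headers={'content-type': 'application/json'},
                     data=json.dumps(payload))
print(resp.json()['generated_text'])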
opencompass/models/tgi_chat_api.py (new file, mode 100644)
import json
import os
import re
import time
from concurrent.futures import ThreadPoolExecutor
from threading import Lock
from typing import Dict, List, Optional, Union

import jieba
import requests

from opencompass.registry import MODELS
from opencompass.utils.prompt import PromptList

from .base_api import BaseAPIModel

from transformers import AutoConfig, AutoTokenizer
from transformers.models.qwen2 import Qwen2Tokenizer
from transformers.models.llama import LlamaTokenizer

PromptType = Union[PromptList, str]
OPENAI_API_BASE = 'http://localhost:3000/v1/chat/completions'


class TGICHATAPI(BaseAPIModel):
    """Model wrapper for chat models served behind an OpenAI-compatible
    ``/v1/chat/completions`` endpoint (e.g. TGI's messages API).

    Args:
        path (str): Path of the served model; used to load its tokenizer.
        max_seq_len (int): The maximum allowed sequence length of a model.
            Note that the length of prompt + generated tokens shall not exceed
            this value. Defaults to 4096.
        query_per_second (int): The maximum queries allowed per second
            between two consecutive calls of the API. Defaults to 1.
        retry (int): Number of retries if the API call fails. Defaults to 2.
        org (str or List[str], optional): OpenAI organization(s). If not
            specified, the default organization bound to each API key is used.
            If specified, the orgs will be posted with each request in
            round-robin manner. Defaults to None.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.
        openai_api_base (str): The base url of the OpenAI-compatible API.
            Defaults to 'http://localhost:3000/v1/chat/completions'.
        mode (str, optional): The method of input truncation when input length
            exceeds max_seq_len. 'front', 'mid' and 'rear' represents the part
            of input to truncate. Defaults to 'none'.
        temperature (float, optional): What sampling temperature to use.
            If not None, will override the temperature in the `generate()`
            call. Defaults to None.
    """

    is_api: bool = True

    def __init__(self,
                 path: str,
                 max_seq_len: int = 4096,
                 query_per_second: int = 1,
                 rpm_verbose: bool = False,
                 retry: int = 2,
                 org: Optional[Union[str, List[str]]] = None,
                 meta_template: Optional[Dict] = None,
                 openai_api_base: str = OPENAI_API_BASE,
                 mode: str = 'none',
                 logprobs: Optional[bool] = False,
                 top_logprobs: Optional[int] = None,
                 temperature: Optional[float] = None):
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         meta_template=meta_template,
                         query_per_second=query_per_second,
                         rpm_verbose=rpm_verbose,
                         retry=retry)
        self.temperature = temperature
        assert mode in ['none', 'front', 'mid', 'rear']
        self.mode = mode
        self.logprobs = logprobs
        self.top_logprobs = top_logprobs
        self.url = openai_api_base
        self.path = path
        # self.tokenizer = AutoTokenizer.from_pretrained(path)
        # self.tokenizer = LlamaTokenizer.from_pretrained(path, padding_side="left", truncation_side="left")
        self.tokenizer = AutoTokenizer.from_pretrained(
            path,
            padding_side="left",
            truncation_side="left",
        )
        # self.tokenizer = Qwen2Tokenizer.from_pretrained(path, padding_side="left", truncation_side="left")

    def generate(self,
                 inputs: List[PromptType],
                 max_out_len: int = 512,
                 temperature: float = 0.7,
                 **kwargs) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[PromptType]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
            temperature (float): What sampling temperature to use,
                between 0 and 2. Higher values like 0.8 will make the output
                more random, while lower values like 0.2 will make it more
                focused and deterministic. Defaults to 0.7.

        Returns:
            List[str]: A list of generated strings.
        """
        if self.temperature is not None:
            temperature = self.temperature

        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._generate, inputs,
                             [max_out_len] * len(inputs),
                             [temperature] * len(inputs)))
        return results

    def _generate(self, input: PromptType, max_out_len: int,
                  temperature: float) -> str:
        """Generate results given a single input.

        Args:
            input (PromptType): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
            temperature (float): What sampling temperature to use,
                between 0 and 2. Higher values like 0.8 will make the output
                more random, while lower values like 0.2 will make it more
                focused and deterministic.

        Returns:
            str: The generated string.
        """
        assert isinstance(input, (str, PromptList))

        # max num token for gpt-3.5-turbo is 4097
        context_window = 4096
        if '32k' in self.path:
            context_window = 32768
        elif '16k' in self.path:
            context_window = 16384
        elif 'gpt-4' in self.path:
            context_window = 8192

        # will leave 100 tokens as prompt buffer, triggered if input is str
        if isinstance(input, str) and self.mode != 'none':
            context_window = self.max_seq_len
            input = self.bin_trim(input, context_window - 100 - max_out_len)

        if isinstance(input, str):
            messages = [{'role': 'user', 'content': input}]
        else:
            messages = []
            for item in input:
                msg = {'content': item['prompt']}
                if item['role'] == 'HUMAN':
                    msg['role'] = 'user'
                elif item['role'] == 'BOT':
                    msg['role'] = 'assistant'
                elif item['role'] == 'SYSTEM':
                    msg['role'] = 'system'
                messages.append(msg)

        # Hold out 100 tokens due to potential errors in tiktoken calculation
        max_out_len = min(
            max_out_len,
            context_window - self.get_token_len(str(input)) - 100)
        if max_out_len <= 0:
            return ''

        max_num_retries = 0
        while max_num_retries < self.retry:
            self.wait()

            header = {'content-type': 'application/json'}

            try:
                data = dict(model='tgi',
                            messages=messages,
                            max_tokens=max_out_len,
                            n=1,
                            logprobs=self.logprobs,
                            top_logprobs=self.top_logprobs,
                            stop=None,
                            temperature=temperature)
                raw_response = requests.post(self.url,
                                             headers=header,
                                             data=json.dumps(data))
            except requests.ConnectionError:
                self.logger.error('Got connection error, retrying...')
                continue
            try:
                response = raw_response.json()
            except requests.JSONDecodeError:
                self.logger.error('JsonDecode error, got '
                                  + str(raw_response.content))
                continue
            self.logger.debug(str(response))
            try:
                if self.logprobs:
                    return response['choices']
                else:
                    return response['choices'][0]['message']['content'].strip()
            except KeyError:
                if 'error' in response:
                    if response['error']['code'] == 'rate_limit_exceeded':
                        time.sleep(10)
                        self.logger.warn('Rate limit exceeded, retrying...')
                        continue
                    # elif response['error']['code'] == 'insufficient_quota':
                    #     self.invalid_keys.add(key)
                    #     self.logger.warn(f'insufficient_quota key: {key}')
                    #     continue
                    elif response['error']['code'] == 'invalid_prompt':
                        self.logger.warn('Invalid prompt: ' + str(input))
                        return ''

                    self.logger.error('Find error message in response: '
                                      + str(response['error']))
            max_num_retries += 1

        raise RuntimeError('Calling the TGI chat API failed after retrying '
                           f'for {max_num_retries} times. Check the logs for '
                           'details.')

    def get_token_len(self, prompt: str) -> int:
        """Get the length of the tokenized string, using the tokenizer loaded
        from ``path``. Users are encouraged to override this method if a
        different length estimate is needed.

        Args:
            prompt (str): Input string.

        Returns:
            int: Length of the input tokens
        """
        # enc = self.tiktoken.encoding_for_model(self.path)
        return len(self.tokenizer.encode(prompt))

    def bin_trim(self, prompt: str, num_token: int) -> str:
        """Get a suffix of prompt which is no longer than num_token tokens.

        Args:
            prompt (str): Input string.
            num_token (int): The upper bound of token numbers.

        Returns:
            str: The trimmed prompt.
        """
        token_len = self.get_token_len(prompt)
        if token_len <= num_token:
            return prompt
        pattern = re.compile(r'[\u4e00-\u9fa5]')
        if pattern.search(prompt):
            words = list(jieba.cut(prompt, cut_all=False))
            sep = ''
        else:
            words = prompt.split(' ')
            sep = ' '

        l, r = 1, len(words)
        while l + 2 < r:
            mid = (l + r) // 2
            if self.mode == 'front':
                cur_prompt = sep.join(words[-mid:])
            elif self.mode == 'mid':
                cur_prompt = sep.join(words[:mid]) + sep.join(words[-mid:])
            elif self.mode == 'rear':
                cur_prompt = sep.join(words[:mid])

            if self.get_token_len(cur_prompt) <= num_token:
                l = mid  # noqa: E741
            else:
                r = mid

        if self.mode == 'front':
            prompt = sep.join(words[-l:])
        elif self.mode == 'mid':
            prompt = sep.join(words[:l]) + sep.join(words[-l:])
        elif self.mode == 'rear':
            prompt = sep.join(words[:l])
        return prompt
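TGICHATAPI, in contrast, speaks the OpenAI-compatible messages format. A minimal sketch of the request it builds (endpoint and fields taken from the defaults above; the running server and the prompt are assumptions, not part of this commit):

import json
import requests

# Illustrative payload mirroring what TGICHATAPI._generate posts.
payload = {
    'model': 'tgi',
    'messages': [{'role': 'user', 'content': 'Say hello in one word.'}],
    'max_tokens': 16,
    'n': 1,
    'temperature': 0.7,
    'stop': None,
}
resp = requests.post('http://localhost:3000/v1/chat/completions',
                     headers={'content-type': 'application/json'},
                     data=json.dumps(payload))
print(resp.json()['choices'][0]['message']['content'].strip())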