Unverified Commit e415ddf9 authored by Ke Bao's avatar Ke Bao Committed by GitHub
Browse files

[Fix] Fix turbomind_tis (#992)

parent 054e9fa7
...@@ -6,7 +6,7 @@ with read_base(): ...@@ -6,7 +6,7 @@ with read_base():
from .datasets.mmlu.mmlu_gen_a484b3 import mmlu_datasets from .datasets.mmlu.mmlu_gen_a484b3 import mmlu_datasets
from .datasets.ceval.ceval_gen_5f30c7 import ceval_datasets from .datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
from .datasets.SuperGLUE_WiC.SuperGLUE_WiC_gen_d06864 import WiC_datasets from .datasets.SuperGLUE_WiC.SuperGLUE_WiC_gen_d06864 import WiC_datasets
from .datasets.SuperGLUE_WSC.SuperGLUE_WSC_gen_6dc406 import WSC_datasets from .datasets.SuperGLUE_WSC.SuperGLUE_WSC_gen_7902a7 import WSC_datasets
from .datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets from .datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
from .datasets.gsm8k.gsm8k_gen_1d7fe4 import gsm8k_datasets from .datasets.gsm8k.gsm8k_gen_1d7fe4 import gsm8k_datasets
from .datasets.humaneval.humaneval_gen_8e312c import humaneval_datasets from .datasets.humaneval.humaneval_gen_8e312c import humaneval_datasets
......
...@@ -47,6 +47,8 @@ class TurboMindTisModel(BaseModel): ...@@ -47,6 +47,8 @@ class TurboMindTisModel(BaseModel):
super().__init__(path=path, super().__init__(path=path,
max_seq_len=max_seq_len, max_seq_len=max_seq_len,
meta_template=meta_template) meta_template=meta_template)
from lmdeploy.serve.turbomind.utils import Preprocessor
self.preprocess = Preprocessor(tis_addr)
self.logger = get_logger() self.logger = get_logger()
self.template_parser = LMTemplateParser(meta_template) self.template_parser = LMTemplateParser(meta_template)
self.eos_token_id = None self.eos_token_id = None
...@@ -83,6 +85,10 @@ class TurboMindTisModel(BaseModel): ...@@ -83,6 +85,10 @@ class TurboMindTisModel(BaseModel):
[temperature] * len(inputs))) [temperature] * len(inputs)))
return results return results
def get_token_len(self, prompt: str) -> int:
input_ids, _ = self.preprocess(prompt)
return input_ids.shape[-1]
def wait(self): def wait(self):
"""Wait till the next query can be sent. """Wait till the next query can be sent.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment