Unverified Commit 1013dce6 authored by Lyu Han, committed by GitHub

adapt to lmdeploy v0.4.0 (#1073)

* adapt to lmdeploy v0.4.0

* keep backward compatibility with lmdeploy versions earlier than v0.4.0
parent 58a57a4c
@@ -50,6 +50,7 @@ class LmdeployPytorchModel(BaseModel):
                          max_seq_len=max_seq_len,
                          meta_template=meta_template)
         from lmdeploy.pytorch import engine as tm
+        from lmdeploy.version import version_info
         if engine_config is not None:
             from lmdeploy.messages import PytorchEngineConfig
@@ -71,6 +72,7 @@ class LmdeployPytorchModel(BaseModel):
         self.generator_ids = [i + 1 for i in range(concurrency)]
         self.gen_config = gen_config
         self.end_str = end_str
+        self.major_version, self.minor_version, _ = version_info
 
     def generate(
         self,
@@ -145,9 +147,16 @@ class LmdeployPytorchModel(BaseModel):
         assert type(
             prompt) is str, 'We only support string for TurboMind Python API'
         input_ids = self.tokenizer.encode(prompt)
-        _, output_ids, _ = generator.infer(session_id,
-                                           input_ids,
-                                           gen_config=gen_config)
+        if self.major_version >= 0 and self.minor_version >= 4:
+            outputs = generator.infer(session_id,
+                                      input_ids,
+                                      gen_config=gen_config)
+            output_ids = outputs.token_ids
+        else:
+            _, output_ids, _ = generator.infer(session_id,
+                                               input_ids,
+                                               gen_config=gen_config)
+
         # stop engine
         if hasattr(generator, 'end'):
             generator.end(session_id)
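Note on the version gate: version_info from lmdeploy.version unpacks as a (major, minor, patch) tuple, and the added check self.major_version >= 0 and self.minor_version >= 4 compares the two components independently. That correctly selects the new v0.4.0 return type (a response object exposing token_ids) for 0.4.x releases, but a future lmdeploy 1.0 (major 1, minor 0) would fail minor_version >= 4 and drop back to the legacy 3-tuple branch. Below is a minimal sketch of a lexicographic comparison that avoids this edge case; it assumes only the (major, minor, patch) tuple shape the diff itself unpacks, and the extract_token_ids helper plus the import fallback are illustrative, not lmdeploy API.

    # Sketch: lexicographic version gate for the two infer() conventions.
    # Assumes version_info is a (major, minor, patch) tuple, as unpacked
    # in the diff; the stub below only lets the sketch run without lmdeploy.
    try:
        from lmdeploy.version import version_info
    except ImportError:
        version_info = (0, 4, 0)  # stand-in value for illustration

    NEW_API = version_info[:2] >= (0, 4)  # True for 0.4.x and any 1.x+

    def extract_token_ids(outputs):
        # Normalize generator.infer() results across lmdeploy versions.
        if NEW_API:
            return outputs.token_ids  # v0.4.0+: response object
        _, output_ids, _ = outputs    # pre-v0.4.0: plain 3-tuple
        return output_ids

With a helper like this, the generate path reduces to output_ids = extract_token_ids(generator.infer(session_id, input_ids, gen_config=gen_config)) instead of branching at the call site.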
@@ -54,6 +54,7 @@ class TurboMindModel(BaseModel):
                          max_seq_len=max_seq_len,
                          meta_template=meta_template)
         from lmdeploy.turbomind import TurboMind
+        from lmdeploy.version import version_info
         if engine_config is not None:
             from lmdeploy.messages import TurbomindEngineConfig
@@ -70,6 +71,7 @@ class TurboMindModel(BaseModel):
         self.generator_ids = [i + 1 for i in range(concurrency)]
         self.gen_config = gen_config
         self.end_str = end_str
+        self.major_version, self.minor_version, _ = version_info
 
     def generate(self,
                  inputs: List[str],
@@ -165,7 +167,10 @@ class TurboMindModel(BaseModel):
                                           sequence_end=True,
                                           step=0,
                                           stream_output=False):
-            _, output_ids, _ = outputs
+            if self.major_version >= 0 and self.minor_version >= 4:
+                output_ids = outputs.token_ids
+            else:
+                _, output_ids, _ = outputs
             response = self.tokenizer.decode(output_ids)
             response = valid_str(response)
             # used to trim
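TurboMindModel's streaming loop applies the same branch to every chunk yielded by stream_infer. The installed lmdeploy version cannot change mid-run, so the check can be resolved once outside the loop. A short sketch under the same tuple assumption; decode_stream, outputs_iter, and EchoTokenizer are hypothetical names for illustration, not part of lmdeploy or this repository.

    # Sketch: resolve the version branch once, then reuse it per chunk.
    def decode_stream(outputs_iter, tokenizer, new_api):
        extract = (lambda o: o.token_ids) if new_api else (lambda o: o[1])
        for outputs in outputs_iter:
            yield tokenizer.decode(extract(outputs))

    # Usage with the legacy 3-tuple shape:
    class EchoTokenizer:
        def decode(self, ids):
            return ' '.join(str(i) for i in ids)

    chunks = [(0, [1, 2, 3], 3)]
    print(list(decode_stream(chunks, EchoTokenizer(), new_api=False)))  # ['1 2 3']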