Unverified Commit 1013dce6 authored by Lyu Han's avatar Lyu Han Committed by GitHub
Browse files

adapt to lmdeploy v0.4.0 (#1073)

* adapt to lmdeploy v0.4.0

* compatible
parent 58a57a4c
...@@ -50,6 +50,7 @@ class LmdeployPytorchModel(BaseModel): ...@@ -50,6 +50,7 @@ class LmdeployPytorchModel(BaseModel):
max_seq_len=max_seq_len, max_seq_len=max_seq_len,
meta_template=meta_template) meta_template=meta_template)
from lmdeploy.pytorch import engine as tm from lmdeploy.pytorch import engine as tm
from lmdeploy.version import version_info
if engine_config is not None: if engine_config is not None:
from lmdeploy.messages import PytorchEngineConfig from lmdeploy.messages import PytorchEngineConfig
...@@ -71,6 +72,7 @@ class LmdeployPytorchModel(BaseModel): ...@@ -71,6 +72,7 @@ class LmdeployPytorchModel(BaseModel):
self.generator_ids = [i + 1 for i in range(concurrency)] self.generator_ids = [i + 1 for i in range(concurrency)]
self.gen_config = gen_config self.gen_config = gen_config
self.end_str = end_str self.end_str = end_str
self.major_version, self.minor_version, _ = version_info
def generate( def generate(
self, self,
...@@ -145,9 +147,16 @@ class LmdeployPytorchModel(BaseModel): ...@@ -145,9 +147,16 @@ class LmdeployPytorchModel(BaseModel):
assert type( assert type(
prompt) is str, 'We only support string for TurboMind Python API' prompt) is str, 'We only support string for TurboMind Python API'
input_ids = self.tokenizer.encode(prompt) input_ids = self.tokenizer.encode(prompt)
if self.major_version >= 0 and self.minor_version >= 4:
outputs = generator.infer(session_id,
input_ids,
gen_config=gen_config)
output_ids = outputs.token_ids
else:
_, output_ids, _ = generator.infer(session_id, _, output_ids, _ = generator.infer(session_id,
input_ids, input_ids,
gen_config=gen_config) gen_config=gen_config)
# stop engine # stop engine
if hasattr(generator, 'end'): if hasattr(generator, 'end'):
generator.end(session_id) generator.end(session_id)
......
...@@ -54,6 +54,7 @@ class TurboMindModel(BaseModel): ...@@ -54,6 +54,7 @@ class TurboMindModel(BaseModel):
max_seq_len=max_seq_len, max_seq_len=max_seq_len,
meta_template=meta_template) meta_template=meta_template)
from lmdeploy.turbomind import TurboMind from lmdeploy.turbomind import TurboMind
from lmdeploy.version import version_info
if engine_config is not None: if engine_config is not None:
from lmdeploy.messages import TurbomindEngineConfig from lmdeploy.messages import TurbomindEngineConfig
...@@ -70,6 +71,7 @@ class TurboMindModel(BaseModel): ...@@ -70,6 +71,7 @@ class TurboMindModel(BaseModel):
self.generator_ids = [i + 1 for i in range(concurrency)] self.generator_ids = [i + 1 for i in range(concurrency)]
self.gen_config = gen_config self.gen_config = gen_config
self.end_str = end_str self.end_str = end_str
self.major_version, self.minor_version, _ = version_info
def generate(self, def generate(self,
inputs: List[str], inputs: List[str],
...@@ -165,6 +167,9 @@ class TurboMindModel(BaseModel): ...@@ -165,6 +167,9 @@ class TurboMindModel(BaseModel):
sequence_end=True, sequence_end=True,
step=0, step=0,
stream_output=False): stream_output=False):
if self.major_version >= 0 and self.minor_version >= 4:
output_ids = outputs.token_ids
else:
_, output_ids, _ = outputs _, output_ids, _ = outputs
response = self.tokenizer.decode(output_ids) response = self.tokenizer.decode(output_ids)
response = valid_str(response) response = valid_str(response)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment