"...composable_kernel.git" did not exist on "827301d95af93d581ddac8d2734ec759ea215c6c"
Unverified Commit 721a45c6 authored by Songyang Zhang, committed by GitHub

[Bug] Update api with generation_kargs (#614)



* update api

* update generation_kwargs impl

---------
Co-authored-by: Leymore <zfz-960727@163.com>
parent eb56fd6d
```diff
@@ -19,6 +19,8 @@ class BaseModel:
         meta_template (Dict, optional): The model's meta prompt
             template if needed, in case the requirement of injecting or
             wrapping of any meta instructions.
+        generation_kwargs (Dict, optional): The generation kwargs for the
+            model. Defaults to dict().
     """
 
     is_api: bool = False
@@ -27,7 +29,8 @@ class BaseModel:
                  path: str,
                  max_seq_len: int = 2048,
                  tokenizer_only: bool = False,
-                 meta_template: Optional[Dict] = None):
+                 meta_template: Optional[Dict] = None,
+                 generation_kwargs: Optional[Dict] = dict()):
         self.path = path
         self.max_seq_len = max_seq_len
         self.tokenizer_only = tokenizer_only
@@ -36,6 +39,7 @@ class BaseModel:
         self.eos_token_id = None
         if meta_template and 'eos_token_id' in meta_template:
             self.eos_token_id = meta_template['eos_token_id']
+        self.generation_kwargs = generation_kwargs
 
     @abstractmethod
     def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
```
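With this change, every `BaseModel` subclass inherits a `generation_kwargs` dict from the base constructor instead of defining its own. A minimal sketch of how a concrete subclass might consume it (the `MyModel` class, the import path, and `_backend_generate` are illustrative assumptions, not part of this commit):

```python
from typing import Dict, List, Optional

from opencompass.models.base import BaseModel  # import path assumed


class MyModel(BaseModel):
    """Hypothetical subclass showing the new constructor signature."""

    def __init__(self,
                 path: str,
                 max_seq_len: int = 2048,
                 tokenizer_only: bool = False,
                 meta_template: Optional[Dict] = None,
                 generation_kwargs: Optional[Dict] = dict()):
        # The base class now stores generation_kwargs as
        # self.generation_kwargs, so subclasses no longer need to.
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         tokenizer_only=tokenizer_only,
                         meta_template=meta_template,
                         generation_kwargs=generation_kwargs)

    def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
        # Forward the stored kwargs to the backend (hypothetical helper).
        return [
            self._backend_generate(x, max_out_len, **self.generation_kwargs)
            for x in inputs
        ]
```

Note that a `dict()` default is evaluated once at definition time and shared by every call that omits the argument; it is benign here because the value is only read, but `None` with an explicit fallback is the more defensive idiom.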
```diff
@@ -28,6 +28,8 @@ class BaseAPIModel(BaseModel):
         meta_template (Dict, optional): The model's meta prompt
             template if needed, in case the requirement of injecting or
             wrapping of any meta instructions.
+        generation_kwargs (Dict, optional): The generation kwargs for the
+            model. Defaults to dict().
     """
 
     is_api: bool = True
@@ -37,7 +39,8 @@ class BaseAPIModel(BaseModel):
                  query_per_second: int = 1,
                  retry: int = 2,
                  max_seq_len: int = 2048,
-                 meta_template: Optional[Dict] = None):
+                 meta_template: Optional[Dict] = None,
+                 generation_kwargs: Dict = dict()):
         self.path = path
         self.max_seq_len = max_seq_len
         self.meta_template = meta_template
@@ -46,6 +49,7 @@ class BaseAPIModel(BaseModel):
         self.token_bucket = TokenBucket(query_per_second)
         self.template_parser = APITemplateParser(meta_template)
         self.logger = get_logger()
+        self.generation_kwargs = generation_kwargs
 
     @abstractmethod
     def generate(self, inputs: List[PromptType],
```
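`BaseAPIModel` gets the same treatment, so API-backed models can thread sampling parameters through to their HTTP requests. A hedged sketch of that pattern (the endpoint shape, payload keys, and `MyAPIModel` are illustrative assumptions):

```python
from typing import List

import requests

from opencompass.models.base_api import BaseAPIModel  # import path assumed


class MyAPIModel(BaseAPIModel):
    """Hypothetical API model merging generation_kwargs into each request."""

    def __init__(self, path: str, url: str, **kwargs):
        super().__init__(path=path, **kwargs)
        self.url = url  # illustrative endpoint

    def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
        outputs = []
        for prompt in inputs:
            # Per-request sampling parameters come straight from the
            # dict stored by the base class.
            payload = {
                'inputs': prompt,
                'parameters': {'max_new_tokens': max_out_len,
                               **self.generation_kwargs},
            }
            resp = requests.post(self.url, json=payload)
            outputs.append(resp.json().get('generated_text', ''))
        return outputs
```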
```diff
@@ -16,25 +16,22 @@ class LightllmAPI(BaseAPIModel):
     is_api: bool = True
 
     def __init__(
             self,
             path: str = 'LightllmAPI',
             url: str = 'http://localhost:8080/generate',
             max_seq_len: int = 2048,
             meta_template: Optional[Dict] = None,
             retry: int = 2,
-            generation_kwargs: Optional[Dict] = None,
+            generation_kwargs: Optional[Dict] = dict(),
     ):
         super().__init__(path=path,
                          max_seq_len=max_seq_len,
                          meta_template=meta_template,
-                         retry=retry)
+                         retry=retry,
+                         generation_kwargs=generation_kwargs)
         self.logger = get_logger()
         self.url = url
-        if generation_kwargs is not None:
-            self.generation_kwargs = generation_kwargs
-        else:
-            self.generation_kwargs = {}
         self.do_sample = self.generation_kwargs.get('do_sample', False)
         self.ignore_eos = self.generation_kwargs.get('ignore_eos', False)
```
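With the `None` check gone, `LightllmAPI` relies entirely on the base class to store `generation_kwargs`, then reads its two flags from the shared dict. A usage sketch (the sampling values are illustrative):

```python
# The kwargs are stored by BaseAPIModel.__init__; the two flags below
# are then read from that dict inside LightllmAPI.__init__.
model = LightllmAPI(
    url='http://localhost:8080/generate',
    generation_kwargs=dict(do_sample=True, ignore_eos=False),
)
assert model.do_sample is True
assert model.ignore_eos is False
```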
```diff
@@ -54,7 +54,6 @@ class TurboMindModel(BaseModel):
             tm_model.create_instance() for i in range(concurrency)
         ]
         self.generator_ids = [i + 1 for i in range(concurrency)]
-        self.generation_kwargs = dict()
 
     def generate(
         self,
```
```diff
@@ -53,7 +53,6 @@ class TurboMindTisModel(BaseModel):
         if meta_template and 'eos_token_id' in meta_template:
             self.eos_token_id = meta_template['eos_token_id']
         self.tis_addr = tis_addr
-        self.generation_kwargs = dict()
 
     def generate(
         self,
```
```diff
@@ -130,7 +130,7 @@ class GenInferencer(BaseInferencer):
                 entry, max_out_len=self.max_out_len)
             generated = results
 
-            num_return_sequences = self.model.generation_kwargs.get(
+            num_return_sequences = self.model.get('generation_kwargs', {}).get(
                 'num_return_sequences', 1)
             # 5-3. Save current output
             for prompt, prediction, gold in zip(
```
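On the inferencer side, the lookup now defaults `num_return_sequences` to 1 when no generation kwargs are configured. The fallback chain is easiest to see with plain dicts standing in for the model config (values are illustrative):

```python
# Missing 'generation_kwargs' falls back to {}, and a missing
# 'num_return_sequences' key falls back to 1.
cfg_with = {'generation_kwargs': {'num_return_sequences': 4}}
cfg_without = {}

assert cfg_with.get('generation_kwargs', {}).get('num_return_sequences', 1) == 4
assert cfg_without.get('generation_kwargs', {}).get('num_return_sequences', 1) == 1
```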