"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "e97e22dfcde0805379ffa25526a53835f887a860"
Unverified commit 8194199d, authored by jiangjin1999, committed by GitHub
Browse files

[Feature] Add MultiTokenEOSCriteria to the *_batch_generate* function (#772)



* jiangjin1999: in the _batch_generate function, add the MultiTokenEOSCriteria feature to speed up inference.

* jiangjin1999: in the _batch_generate function, add the MultiTokenEOSCriteria feature to speed up inference.

---------
Co-authored-by: jiangjin08 <jiangjin08@MBP-2F32S5MD6P-0029.local>
Co-authored-by: jiangjin08 <jiangjin08@a.sh.vip.dianping.com>
parent f78fcf6e
...@@ -241,6 +241,7 @@ class HuggingFace(BaseModel): ...@@ -241,6 +241,7 @@ class HuggingFace(BaseModel):
if self.batch_padding and len(inputs) > 1: if self.batch_padding and len(inputs) > 1:
return self._batch_generate(inputs=inputs, return self._batch_generate(inputs=inputs,
max_out_len=max_out_len, max_out_len=max_out_len,
stopping_criteria=stopping_criteria,
**generation_kwargs) **generation_kwargs)
else: else:
return sum( return sum(
...@@ -250,7 +251,9 @@ class HuggingFace(BaseModel): ...@@ -250,7 +251,9 @@ class HuggingFace(BaseModel):
**generation_kwargs) **generation_kwargs)
for input_ in inputs), []) for input_ in inputs), [])
def _batch_generate(self, inputs: List[str], max_out_len: int, def _batch_generate(self, inputs: List[str],
max_out_len: int,
stopping_criteria: List[str] = [],
**kwargs) -> List[str]: **kwargs) -> List[str]:
"""Support for batch prompts inference. """Support for batch prompts inference.
...@@ -289,6 +292,19 @@ class HuggingFace(BaseModel): ...@@ -289,6 +292,19 @@ class HuggingFace(BaseModel):
for k in tokens if k in ['input_ids', 'attention_mask'] for k in tokens if k in ['input_ids', 'attention_mask']
} }
if stopping_criteria:
# Construct huggingface stopping criteria
if self.tokenizer.eos_token is not None:
stopping_criteria = stopping_criteria + [self.tokenizer.eos_token]
stopping_criteria = transformers.StoppingCriteriaList([
*[
MultiTokenEOSCriteria(sequence, self.tokenizer,
tokens['input_ids'].shape[0])
for sequence in stopping_criteria
],
])
kwargs['stopping_criteria'] = stopping_criteria
# step-2: conduct model forward to generate output # step-2: conduct model forward to generate output
outputs = self.model.generate(**tokens, outputs = self.model.generate(**tokens,
max_new_tokens=max_out_len, max_new_tokens=max_out_len,
...@@ -359,6 +375,7 @@ class HuggingFace(BaseModel): ...@@ -359,6 +375,7 @@ class HuggingFace(BaseModel):
if stopping_criteria: if stopping_criteria:
# Construct huggingface stopping criteria # Construct huggingface stopping criteria
if self.tokenizer.eos_token is not None:
stopping_criteria = stopping_criteria + [self.tokenizer.eos_token] stopping_criteria = stopping_criteria + [self.tokenizer.eos_token]
stopping_criteria = transformers.StoppingCriteriaList([ stopping_criteria = transformers.StoppingCriteriaList([
*[ *[
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment