Unverified commit 8194199d, authored by jiangjin1999 and committed by GitHub
Browse files

[Feature] *_batch_generate* function, add the MultiTokenEOSCriteria (#772)



* jiangjin1999: in the _batch_generate function, add the MultiTokenEOSCriteria feature to speed up inference.

* jiangjin1999: in the _batch_generate function, add the MultiTokenEOSCriteria feature to speed up inference.

---------
Co-authored-by: jiangjin08 <jiangjin08@MBP-2F32S5MD6P-0029.local>
Co-authored-by: jiangjin08 <jiangjin08@a.sh.vip.dianping.com>
parent f78fcf6e
......@@ -241,6 +241,7 @@ class HuggingFace(BaseModel):
if self.batch_padding and len(inputs) > 1:
return self._batch_generate(inputs=inputs,
max_out_len=max_out_len,
stopping_criteria=stopping_criteria,
**generation_kwargs)
else:
return sum(
......@@ -250,7 +251,9 @@ class HuggingFace(BaseModel):
**generation_kwargs)
for input_ in inputs), [])
def _batch_generate(self, inputs: List[str], max_out_len: int,
def _batch_generate(self, inputs: List[str],
max_out_len: int,
stopping_criteria: List[str] = [],
**kwargs) -> List[str]:
"""Support for batch prompts inference.
......@@ -289,6 +292,19 @@ class HuggingFace(BaseModel):
for k in tokens if k in ['input_ids', 'attention_mask']
}
if stopping_criteria:
# Construct huggingface stopping criteria
if self.tokenizer.eos_token is not None:
stopping_criteria = stopping_criteria + [self.tokenizer.eos_token]
stopping_criteria = transformers.StoppingCriteriaList([
*[
MultiTokenEOSCriteria(sequence, self.tokenizer,
tokens['input_ids'].shape[0])
for sequence in stopping_criteria
],
])
kwargs['stopping_criteria'] = stopping_criteria
# step-2: conduct model forward to generate output
outputs = self.model.generate(**tokens,
max_new_tokens=max_out_len,
......@@ -359,6 +375,7 @@ class HuggingFace(BaseModel):
if stopping_criteria:
# Construct huggingface stopping criteria
if self.tokenizer.eos_token is not None:
stopping_criteria = stopping_criteria + [self.tokenizer.eos_token]
stopping_criteria = transformers.StoppingCriteriaList([
*[
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment