Unverified Commit 49c46745 authored by Leymore's avatar Leymore Committed by GitHub
Browse files

[Feature] Update llama2 (#372)

parent 3871188c
......@@ -59,12 +59,18 @@ class Llama2(BaseModel):
self.tokenizer = Tokenizer(tokenizer_path)
def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
out = self.generator.text_completion(
inputs,
temperature=0,
prompt_tokens = []
for input in inputs:
tokens = self.tokenizer.encode(input, True, False)
num_token = min(self.model.params.max_seq_len, len(tokens))
prompt_tokens.append(tokens[-num_token:])
generation_tokens, _ = self.generator.generate(
prompt_tokens=prompt_tokens,
max_gen_len=max_out_len,
temperature=0,
)
return [i['generation'] for i in out]
results = [self.tokenizer.decode(t) for t in generation_tokens]
return results
def get_ppl(self,
inputs: List[str],
......@@ -183,7 +189,7 @@ class Llama2Chat(BaseModel):
)
return [r['generation']['content'] for r in results]
except AssertionError:
self.warning('Batched data max token limit exceeded, '
self.logger.warning('Batched data max token limit exceeded, '
'try to run one by one...')
results = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment