Unverified Commit 49c46745 authored by Leymore, committed by GitHub

[Feature] Update llama2 (#372)

parent 3871188c
@@ -59,12 +59,18 @@ class Llama2(BaseModel):
         self.tokenizer = Tokenizer(tokenizer_path)
 
     def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
-        out = self.generator.text_completion(
-            inputs,
-            temperature=0,
+        prompt_tokens = []
+        for input in inputs:
+            tokens = self.tokenizer.encode(input, True, False)
+            num_token = min(self.model.params.max_seq_len, len(tokens))
+            prompt_tokens.append(tokens[-num_token:])
+        generation_tokens, _ = self.generator.generate(
+            prompt_tokens=prompt_tokens,
             max_gen_len=max_out_len,
+            temperature=0,
         )
-        return [i['generation'] for i in out]
+        results = [self.tokenizer.decode(t) for t in generation_tokens]
+        return results
 
     def get_ppl(self,
                 inputs: List[str],
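The rewritten `generate()` switches from `text_completion()` on raw strings to token-level generation: each prompt is encoded, left-truncated to the model's `max_seq_len`, batched, passed to `generator.generate()`, and the returned token ids are decoded back to text. Below is a minimal, self-contained sketch of just the prompt-preparation step; `toy_encode` and `prepare_prompt_tokens` are hypothetical stand-ins (not part of this commit) so the snippet runs without model weights or the real SentencePiece `Tokenizer`.

```python
from typing import List


def toy_encode(text: str) -> List[int]:
    """Hypothetical tokenizer stand-in: one token id per whitespace word."""
    return [hash(word) % 32000 for word in text.split()]


def prepare_prompt_tokens(inputs: List[str],
                          max_seq_len: int) -> List[List[int]]:
    """Encode each prompt and keep only its last `max_seq_len` tokens."""
    prompt_tokens = []
    for text in inputs:
        tokens = toy_encode(text)
        num_token = min(max_seq_len, len(tokens))
        # Keep the most recent tokens so the prompt fits the context window.
        prompt_tokens.append(tokens[-num_token:])
    return prompt_tokens


if __name__ == '__main__':
    batch = ['a short prompt', 'word ' * 5000]
    for tokens in prepare_prompt_tokens(batch, max_seq_len=4096):
        print(len(tokens))  # prints 3 and 4096: the long prompt is truncated
```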
@@ -183,7 +189,7 @@ class Llama2Chat(BaseModel):
             )
             return [r['generation']['content'] for r in results]
         except AssertionError:
-            self.warning('Batched data max token limit exceeded, '
+            self.logger.warning('Batched data max token limit exceeded, '
                                 'try to run one by one...')
             results = []
...
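The `Llama2Chat` hunk repairs the fallback path: the handler called `self.warning(...)`, which the model does not appear to provide, so the exception handler itself would fail; routing the message through `self.logger.warning(...)` lets the per-sample retry proceed. The underlying pattern (try the whole batch, fall back to one-by-one when the batched prompts exceed the token limit) can be sketched in isolation as follows; `run_batch`, `run_single`, and `generate_with_fallback` are hypothetical names standing in for the real `generator.chat_completion` calls.

```python
import logging
from typing import Callable, List

logger = logging.getLogger(__name__)


def generate_with_fallback(inputs: List[str],
                           run_batch: Callable[[List[str]], List[str]],
                           run_single: Callable[[str], str]) -> List[str]:
    """Try batched generation first; retry one-by-one on AssertionError."""
    try:
        return run_batch(inputs)
    except AssertionError:
        logger.warning('Batched data max token limit exceeded, '
                       'try to run one by one...')
        return [run_single(x) for x in inputs]


if __name__ == '__main__':
    def flaky_batch(batch: List[str]) -> List[str]:
        # Simulate the generator asserting when the batch is too large.
        assert len(batch) <= 1, 'max token limit exceeded'
        return [s.upper() for s in batch]

    print(generate_with_fallback(['a', 'b'], flaky_batch, str.upper))
```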