Unverified Commit 77745a84 authored by LZHgrla's avatar LZHgrla Committed by GitHub
Browse files

[Fix] Fix bugs for PeftModel generate (#252)

* fix bugs

* fix typo
parent 2a5cef29
...@@ -203,7 +203,9 @@ class HuggingFace(BaseModel): ...@@ -203,7 +203,9 @@ class HuggingFace(BaseModel):
max_length=self.max_seq_len - max_length=self.max_seq_len -
max_out_len)['input_ids'] max_out_len)['input_ids']
input_ids = torch.tensor(input_ids, device=self.model.device) input_ids = torch.tensor(input_ids, device=self.model.device)
outputs = self.model.generate(input_ids, # To accommodate the PeftModel, parameters should be passed in
# key-value format for generate.
outputs = self.model.generate(input_ids=input_ids,
max_new_tokens=max_out_len, max_new_tokens=max_out_len,
**kwargs) **kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment