Unverified Commit 77745a84 authored by LZHgrla's avatar LZHgrla Committed by GitHub
Browse files

[Fix] Fix bugs for PeftModel generate (#252)

* fix bugs

* fix typo
parent 2a5cef29
......@@ -203,7 +203,9 @@ class HuggingFace(BaseModel):
max_length=self.max_seq_len -
max_out_len)['input_ids']
input_ids = torch.tensor(input_ids, device=self.model.device)
outputs = self.model.generate(input_ids,
# To accommodate the PeftModel, parameters should be passed in
# key-value format for generate.
outputs = self.model.generate(input_ids=input_ids,
max_new_tokens=max_out_len,
**kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment