"vscode:/vscode.git/clone" did not exist on "5ec12cec6c097a4d3706edb0fa0e51f02dfc1b4c"
Commit 63eabe79 authored by lintangsutawika's avatar lintangsutawika
Browse files

added `unwrap_model` for generate

parent 6df13d93
...@@ -152,6 +152,16 @@ class HFLM(LM): ...@@ -152,6 +152,16 @@ class HFLM(LM):
return self.gpt2(inps)[0] return self.gpt2(inps)[0]
def _model_generate(self, context, max_length, eos_token_id): def _model_generate(self, context, max_length, eos_token_id):
if hasattr(self, "accelerator"):
return self.accelerator.unwrap_model(self.gpt2).generate(
context,
max_length=max_length,
pad_token_id=eos_token_id,
eos_token_id=eos_token_id,
do_sample=False,
)
else:
return self.gpt2.generate( return self.gpt2.generate(
context, context,
max_length=max_length, max_length=max_length,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment