Commit 63eabe79 authored by lintangsutawika's avatar lintangsutawika
Browse files

added `unwrap_model` for generate

parent 6df13d93
...@@ -152,13 +152,23 @@ class HFLM(LM): ...@@ -152,13 +152,23 @@ class HFLM(LM):
return self.gpt2(inps)[0] return self.gpt2(inps)[0]
def _model_generate(self, context, max_length, eos_token_id): def _model_generate(self, context, max_length, eos_token_id):
return self.gpt2.generate(
context, if hasattr(self, "accelerator"):
max_length=max_length, return self.accelerator.unwrap_model(self.gpt2).generate(
pad_token_id=eos_token_id, context,
eos_token_id=eos_token_id, max_length=max_length,
do_sample=False, pad_token_id=eos_token_id,
) eos_token_id=eos_token_id,
do_sample=False,
)
else:
return self.gpt2.generate(
context,
max_length=max_length,
pad_token_id=eos_token_id,
eos_token_id=eos_token_id,
do_sample=False,
)
def loglikelihood(self, requests): def loglikelihood(self, requests):
new_reqs = [] new_reqs = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment