Commit 65127577 authored by Benjamin Fattori's avatar Benjamin Fattori
Browse files

unwrap model to get pos embs

parent f1372139
......@@ -103,7 +103,10 @@ class HFLM(LM):
return self.gpt2.config.n_ctx
except AttributeError:
# gptneoconfig doesn't have n_ctx apparently
return self.gpt2.config.max_position_embeddings
if hasattr(self, 'accelerator'):
return self.accelerator.unwrap_model(self.gpt2).config.max_position_embeddings
else:
return self.gpt2.config.max_position_embeddings
@property
def max_gen_toks(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment