Commit c10c08a2 authored by Benjamin Fattori's avatar Benjamin Fattori
Browse files

unwrap gpt2 model to get n_ctx

parent e4f1dfb6
......@@ -100,7 +100,10 @@ class HFLM(LM):
@property
def max_length(self):
try:
return self.gpt2.config.n_ctx
if hasattr(self, 'accelerator'):
return self.accelerator.unwrap_model(self.gpt2).config.n_ctx
else:
return self.gpt2.config.n_ctx
except AttributeError:
# gptneoconfig doesn't have n_ctx apparently
if hasattr(self, 'accelerator'):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment