Commit c10c08a2 authored by Benjamin Fattori's avatar Benjamin Fattori
Browse files

unwrap gpt2 model to get n_ctx

parent e4f1dfb6
...@@ -100,7 +100,10 @@ class HFLM(LM): ...@@ -100,7 +100,10 @@ class HFLM(LM):
@property @property
def max_length(self): def max_length(self):
try: try:
return self.gpt2.config.n_ctx if hasattr(self, 'accelerator'):
return self.accelerator.unwrap_model(self.gpt2).config.n_ctx
else:
return self.gpt2.config.n_ctx
except AttributeError: except AttributeError:
# gptneoconfig doesn't have n_ctx apparently # gptneoconfig doesn't have n_ctx apparently
if hasattr(self, 'accelerator'): if hasattr(self, 'accelerator'):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment