Commit a757f293 authored by haileyschoelkopf

automatically unwrap model when needed

parent 605b1cef
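The diff below implements the TODO removed in the second hunk: the HF model handle moves to `self._model`, and a new `model` property transparently unwraps the Accelerate wrapper, so call sites no longer need to branch on `hasattr(self, "accelerator")`. A minimal sketch of the pattern, assuming `accelerate` is installed; the `WrappedLM` class and its constructor arguments are illustrative, not taken from the harness:

```python
import torch.nn as nn
from accelerate import Accelerator


class WrappedLM:
    """Illustrative only: stores a possibly-wrapped module at self._model."""

    def __init__(self, module: nn.Module, use_accelerate: bool = False):
        self._model = module
        if use_accelerate:
            self.accelerator = Accelerator()
            # prepare() may wrap the module, e.g. in DistributedDataParallel
            self._model = self.accelerator.prepare(self._model)

    @property
    def model(self) -> nn.Module:
        # unwrap_model() peels off Accelerate's wrapper so attributes like
        # .config and .generate resolve on the underlying transformers model
        if hasattr(self, "accelerator"):
            return self.accelerator.unwrap_model(self._model)
        return self._model
```

Every downstream access then goes through `self.model`, which is how the harness uses it after this commit.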
@@ -74,9 +74,10 @@ class HFLM(LM):
         assert self.AUTO_MODEL_CLASS in [transformers.AutoModelForCausalLM, transformers.AutoModelForSeq2SeqLM]
-        self.model = self.AUTO_MODEL_CLASS.from_pretrained(
+        self._model = self.AUTO_MODEL_CLASS.from_pretrained(
             pretrained, revision=revision, low_cpu_mem_usage=low_cpu_mem_usage
         ).to(self.device)
+        # forever after, access self._model through self.model property
         self.model.eval()
         self.tokenizer = transformers.AutoTokenizer.from_pretrained(
@@ -125,28 +126,27 @@ class HFLM(LM):
         # return the associated transformers.AutoConfig for the given pretrained model.
         return self._config

+    @property
+    def model(self):
+        # returns the model, unwrapping it if using Accelerate
+        if hasattr(self, "accelerator"):
+            return self.accelerator.unwrap_model(self._model)
+        else:
+            return self._model
+
     @property
     def eot_token_id(self):
         # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
         return self.tokenizer.eos_token_id

-    # TODO: make model at self._model, have self.model property unwrap accelerator if needed under hood?
     @property
     def max_length(self):
         try:
-            if hasattr(self, "accelerator"):
-                return self.accelerator.unwrap_model(self.model).config.n_ctx
-            else:
-                return self.model.config.n_ctx
+            return self.model.config.n_ctx
         except AttributeError:
             # gptneoconfig doesn't have n_ctx apparently
-            if hasattr(self, "accelerator"):
-                return self.accelerator.unwrap_model(
-                    self.model
-                ).config.max_position_embeddings
-            else:
-                return self.model.config.max_position_embeddings
+            return self.model.config.max_position_embeddings

     @property
     def max_gen_toks(self):
         return 256
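Design note on the `max_length` hunk above: the try/except survives the refactor because the context-window attribute is config-dependent; GPT-2-style configs expose `n_ctx`, while others (the inline comment names `gptneoconfig`) only provide `max_position_embeddings`. The new `model` property removes only the duplicated unwrapping, not that fallback logic.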
@@ -236,24 +236,14 @@ class HFLM(LM):
         stopping_criteria = stop_sequences_criteria(
             self.tokenizer, stop, 1, context.shape[0]
         )
-        if hasattr(self, "accelerator"):
-            return self.accelerator.unwrap_model(self.model).generate(
-                context,
-                max_length=max_length,
-                stopping_criteria=stopping_criteria,
-                pad_token_id=self.eot_token_id,
-                use_cache=True,
-                **generation_kwargs,
-            )
-        else:
-            return self.model.generate(
-                context,
-                max_length=max_length,
-                stopping_criteria=stopping_criteria,
-                pad_token_id=self.eot_token_id,
-                use_cache=True,
-                **generation_kwargs,
-            )
+        return self.model.generate(
+            context,
+            max_length=max_length,
+            stopping_criteria=stopping_criteria,
+            pad_token_id=self.eot_token_id,
+            use_cache=True,
+            **generation_kwargs,
+        )

     def _select_cont_toks(self, logits, contlen=None, inplen=None):
         if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
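One note on the `generate` hunk: when Accelerate wraps the module (for example in `DistributedDataParallel`), `generate` is typically only reachable on the unwrapped model, which is why the old code unwrapped before generating; routing the call through the `model` property preserves that behavior while collapsing the duplicated branch.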
@@ -299,7 +289,7 @@ class HFLM(LM):
             )
         )

-        #TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder
+        #TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
         rolling_token_windows = [(None,) + x for x in rolling_token_windows]

         pad_amnt = 0