Commit bc7f52e6 authored by haileyschoelkopf, committed by lintangsutawika

automatically unwrap model when needed

parent 1bd6229c
@@ -74,9 +74,10 @@ class HFLM(LM):
         assert self.AUTO_MODEL_CLASS in [transformers.AutoModelForCausalLM, transformers.AutoModelForSeq2SeqLM]
 
-        self.model = self.AUTO_MODEL_CLASS.from_pretrained(
+        self._model = self.AUTO_MODEL_CLASS.from_pretrained(
             pretrained, revision=revision, low_cpu_mem_usage=low_cpu_mem_usage
         ).to(self.device)
+        # forever after, access self._model through self.model property
         self.model.eval()
 
         self.tokenizer = transformers.AutoTokenizer.from_pretrained(
@@ -125,28 +126,27 @@ class HFLM(LM):
         # return the associated transformers.AutoConfig for the given pretrained model.
         return self._config
 
+    @property
+    def model(self):
+        # returns the model, unwrapping it if using Accelerate
+        if hasattr(self, "accelerator"):
+            return self.accelerator.unwrap_model(self._model)
+        else:
+            return self._model
+
     @property
     def eot_token_id(self):
         # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
         return self.tokenizer.eos_token_id
 
-    # TODO: make model at self._model, have self.model property unwrap accelerator if needed under hood?
     @property
     def max_length(self):
         try:
-            if hasattr(self, "accelerator"):
-                return self.accelerator.unwrap_model(self.model).config.n_ctx
-            else:
-                return self.model.config.n_ctx
+            return self.model.config.n_ctx
         except AttributeError:
             # gptneoconfig doesn't have n_ctx apparently
-            if hasattr(self, "accelerator"):
-                return self.accelerator.unwrap_model(
-                    self.model
-                ).config.max_position_embeddings
-            else:
-                return self.model.config.max_position_embeddings
+            return self.model.config.max_position_embeddings
 
     @property
     def max_gen_toks(self):
         return 256
@@ -236,24 +236,14 @@ class HFLM(LM):
         stopping_criteria = stop_sequences_criteria(
             self.tokenizer, stop, 1, context.shape[0]
         )
-        if hasattr(self, "accelerator"):
-            return self.accelerator.unwrap_model(self.model).generate(
-                context,
-                max_length=max_length,
-                stopping_criteria=stopping_criteria,
-                pad_token_id=self.eot_token_id,
-                use_cache=True,
-                **generation_kwargs,
-            )
-        else:
-            return self.model.generate(
-                context,
-                max_length=max_length,
-                stopping_criteria=stopping_criteria,
-                pad_token_id=self.eot_token_id,
-                use_cache=True,
-                **generation_kwargs,
-            )
+        return self.model.generate(
+            context,
+            max_length=max_length,
+            stopping_criteria=stopping_criteria,
+            pad_token_id=self.eot_token_id,
+            use_cache=True,
+            **generation_kwargs,
+        )
 
     def _select_cont_toks(self, logits, contlen=None, inplen=None):
         if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
@@ -299,7 +289,7 @@ class HFLM(LM):
                 )
             )
 
-            #TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder
+            #TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
            rolling_token_windows = [(None,) + x for x in rolling_token_windows]
 
             pad_amnt = 0
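The pattern above keeps the raw model in self._model and unwraps the Accelerate wrapper lazily through a model property, so call sites no longer need their own hasattr(self, "accelerator") branches. A minimal, self-contained sketch of that pattern, assuming the accelerate and torch packages are installed; the ModelHolder class and use_accelerate flag below are illustrative names, not part of the repository:

    # Illustrative sketch of the unwrap-on-access pattern; ModelHolder is a
    # hypothetical stand-in for a class like HFLM.
    import torch.nn as nn
    from accelerate import Accelerator


    class ModelHolder:
        def __init__(self, model: nn.Module, use_accelerate: bool = False):
            self._model = model
            if use_accelerate:
                # Accelerate may wrap the model (e.g. in DistributedDataParallel).
                self.accelerator = Accelerator()
                self._model = self.accelerator.prepare(self._model)

        @property
        def model(self):
            # Always hand back the unwrapped module, whether or not
            # Accelerate is in use.
            if hasattr(self, "accelerator"):
                return self.accelerator.unwrap_model(self._model)
            return self._model


    # Usage: holder.model behaves the same with or without Accelerate,
    # so attribute access like holder.model.config stays uniform.
    holder = ModelHolder(nn.Linear(4, 4), use_accelerate=False)
    print(type(holder.model))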