Commit e0498dd7 authored by haileyschoelkopf

change self.gpt2 -> self.model

parent e6960b9a
@@ -39,8 +39,8 @@ class HFLM(BaseLM):
         # Initialize model
         if isinstance(pretrained, transformers.PreTrainedModel):
-            self.gpt2 = pretrained
-            self._device = self.gpt2.device
+            self.model = pretrained
+            self._device = self.model.device
 
             if tokenizer:
                 assert isinstance(
@@ -53,7 +53,7 @@ class HFLM(BaseLM):
                 self.tokenizer = tokenizer
             else:
                 # Get tokenizer
-                model_name = self.gpt2.name_or_path
+                model_name = self.model.name_or_path
                 self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                     model_name,
                     revision=revision,
@@ -81,7 +81,7 @@ class HFLM(BaseLM):
             revision = revision + ("/" + subfolder if subfolder is not None else "")
 
             # Initialize new model and tokenizer instances
-            self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
+            self.model = transformers.AutoModelForCausalLM.from_pretrained(
                 pretrained,
                 load_in_8bit=load_in_8bit,
                 low_cpu_mem_usage=low_cpu_mem_usage,
@@ -98,7 +98,7 @@ class HFLM(BaseLM):
         else:
             raise TypeError('Parameter pretrained should be of type str or transformers.PreTrainedModel')
 
-        self.gpt2.eval()
+        self.model.eval()
 
         self.vocab_size = self.tokenizer.vocab_size
@@ -134,8 +134,8 @@ class HFLM(BaseLM):
             return self._max_length
         seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
         for attr in seqlen_config_attrs:
-            if hasattr(self.gpt2.config, attr):
-                return getattr(self.gpt2.config, attr)
+            if hasattr(self.model.config, attr):
+                return getattr(self.model.config, attr)
         if hasattr(self.tokenizer, "model_max_length"):
             if self.tokenizer.model_max_length == 1000000000000000019884624838656:
                 return self._DEFAULT_MAX_LENGTH
@@ -172,14 +172,14 @@ class HFLM(BaseLM):
         logits returned from the model
         """
         with torch.no_grad():
-            return self.gpt2(inps)[0]
+            return self.model(inps)[0]
 
     def _model_generate(self, context, max_length, eos_token_id):
         generation_kwargs = {"do_sample": False, "max_length": max_length}
         if eos_token_id is not None:
             generation_kwargs['eos_token_id'] = eos_token_id
             generation_kwargs['pad_token_id'] = eos_token_id  # setting eos_token_id as pad token
-        return self.gpt2.generate(context, **generation_kwargs)
+        return self.model.generate(context, **generation_kwargs)
 
 
 # for backwards compatibility
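
For downstream callers, the only visible change is the attribute name: the wrapped Hugging Face model is now exposed as .model instead of .gpt2. Below is a minimal usage sketch, not a definitive example; the import path and default constructor arguments are assumptions (they vary across versions of the harness), while the acceptance of a PreTrainedModel instance and the .model attribute follow directly from the diff above.

    # Hypothetical usage sketch of the patched HFLM wrapper.
    import transformers
    from lm_eval.models.gpt2 import HFLM  # assumed import path; may differ by version

    # Per the diff, a transformers.PreTrainedModel instance is accepted directly.
    pretrained = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
    lm = HFLM(pretrained=pretrained)

    # The wrapped model is now reachable as lm.model (previously lm.gpt2).
    assert lm.model is pretrained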