Commit e0498dd7 authored by haileyschoelkopf

change self.gpt2 -> self.model

parent e6960b9a
@@ -39,8 +39,8 @@ class HFLM(BaseLM):
         # Initialize model
         if isinstance(pretrained, transformers.PreTrainedModel):
-            self.gpt2 = pretrained
-            self._device = self.gpt2.device
+            self.model = pretrained
+            self._device = self.model.device
             if tokenizer:
                 assert isinstance(
@@ -53,7 +53,7 @@ class HFLM(BaseLM):
                 self.tokenizer = tokenizer
             else:
                 # Get tokenizer
-                model_name = self.gpt2.name_or_path
+                model_name = self.model.name_or_path
                 self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                     model_name,
                     revision=revision,
@@ -81,7 +81,7 @@ class HFLM(BaseLM):
             revision = revision + ("/" + subfolder if subfolder is not None else "")
             # Initialize new model and tokenizer instances
-            self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
+            self.model = transformers.AutoModelForCausalLM.from_pretrained(
                 pretrained,
                 load_in_8bit=load_in_8bit,
                 low_cpu_mem_usage=low_cpu_mem_usage,
@@ -98,7 +98,7 @@ class HFLM(BaseLM):
         else:
             raise TypeError('Parameter pretrained should be of type str or transformers.PreTrainedModel')
-        self.gpt2.eval()
+        self.model.eval()
         self.vocab_size = self.tokenizer.vocab_size
@@ -134,8 +134,8 @@ class HFLM(BaseLM):
             return self._max_length
         seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
         for attr in seqlen_config_attrs:
-            if hasattr(self.gpt2.config, attr):
-                return getattr(self.gpt2.config, attr)
+            if hasattr(self.model.config, attr):
+                return getattr(self.model.config, attr)
         if hasattr(self.tokenizer, "model_max_length"):
             if self.tokenizer.model_max_length == 1000000000000000019884624838656:
                 return self._DEFAULT_MAX_LENGTH
@@ -172,14 +172,14 @@ class HFLM(BaseLM):
         logits returned from the model
         """
         with torch.no_grad():
-            return self.gpt2(inps)[0]
+            return self.model(inps)[0]
     def _model_generate(self, context, max_length, eos_token_id):
         generation_kwargs = {"do_sample": False, "max_length": max_length}
         if eos_token_id is not None:
             generation_kwargs['eos_token_id'] = eos_token_id
             generation_kwargs['pad_token_id'] = eos_token_id  # setting eos_token_id as pad token
-        return self.gpt2.generate(context, **generation_kwargs)
+        return self.model.generate(context, **generation_kwargs)
     # for backwards compatibility
...
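
The last hunk ends at the "# for backwards compatibility" comment, whose body is collapsed in this view. As a hypothetical sketch only (not the collapsed code from this commit), a rename like self.gpt2 -> self.model is commonly paired with a read-only property alias so that external code still reading .gpt2 keeps working:

# Hypothetical illustration; the class name is an example, not the collapsed
# code above. A property alias forwards the old attribute name to the new one.
class RenamedAttrExample:
    def __init__(self, pretrained):
        self.model = pretrained  # canonical attribute after the rename

    # for backwards compatibility: callers may still read `.gpt2`
    @property
    def gpt2(self):
        return self.model


obj = RenamedAttrExample(pretrained="dummy-model-object")
assert obj.gpt2 is obj.model  # both names resolve to the same object

Using a property rather than assigning a second plain attribute keeps the two names from drifting apart, since the alias always reads the current value of self.model.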