Unverified commit 72ee34dc authored by Sven Hendrikx, committed by GitHub

Merge pull request #1 from EleutherAI/pass-automodel

Fixes for passing AutoModel
parents e6960b9a 1e98c74e
@@ -34,7 +34,7 @@ def simple_evaluate(
     """Instantiate and evaluate a model on a list of tasks.

     :param model: Union[str, LM]
-        Name of model or LM object, see lm_eval.models.get_model
+        Name of model, transformers.PreTrainedModel object, or LM object, see lm_eval.models.get_model
     :param model_args: Optional[str]
         String arguments for each model class, see LM.create_from_arg_string.
         Ignored if `model` argument is a LM object.
@@ -77,8 +77,9 @@ def simple_evaluate(
             model_args, {"batch_size": batch_size, "max_batch_size": max_batch_size, "device": device}
         )
     elif isinstance(model, transformers.PreTrainedModel):
-        lm = HFLM(
+        lm = lm_eval.models.get_model("hf-causal")(
             pretrained=model,
+            batch_size=batch_size,
         )
         no_cache = True
     else:
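With this hunk, a model loaded via transformers can be handed straight to `simple_evaluate`; it gets wrapped in the `hf-causal` adapter and caching is disabled, since an in-memory model has no stable name to key a cache on. A minimal usage sketch, assuming the import path and task name below (both illustrative, not taken from this diff):

```python
import transformers
from lm_eval import evaluator  # import path assumed for this branch

# Load any causal LM up front; "gpt2" is just an illustrative checkpoint.
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")

# Pass the model object directly instead of a model name string.
results = evaluator.simple_evaluate(
    model=model,
    tasks=["hellaswag"],  # illustrative task name
    batch_size=8,
)
```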
@@ -113,8 +114,13 @@ def simple_evaluate(
         )

     # add info about the model and few shot config
+    model_name = None
+    if isinstance(model, str):
+        model_name = model
+    elif isinstance(model, transformers.PreTrainedModel):
+        model_name = "pretrained=" + model.config._name_or_path
     results["config"] = {
-        "model": (model if isinstance(model, str) else model.model.config._name_or_path),
+        "model": model_name,
         "model_args": model_args,
         "num_fewshot": num_fewshot,
         "batch_size": batch_size,
...

@@ -39,8 +39,8 @@ class HFLM(BaseLM):

         # Initialize model
         if isinstance(pretrained, transformers.PreTrainedModel):
-            self.gpt2 = pretrained
-            self._device = self.gpt2.device
+            self.model = pretrained
+            self._device = self.model.device

             if tokenizer:
                 assert isinstance(
@@ -53,7 +53,7 @@ class HFLM(BaseLM):
                 self.tokenizer = tokenizer
             else:
                 # Get tokenizer
-                model_name = self.gpt2.name_or_path
+                model_name = self.model.name_or_path
                 self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                     model_name,
                     revision=revision,
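Since `self.gpt2` is renamed to `self.model` throughout, `HFLM` can also be constructed directly around a preloaded model: the device comes from the model itself, and the tokenizer is either the one supplied or recovered from `model.name_or_path`. A minimal sketch, with the module path assumed for this branch:

```python
import transformers
from lm_eval.models.gpt2 import HFLM  # module path assumed

model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")

# Explicit tokenizer...
lm = HFLM(pretrained=model, tokenizer=tokenizer, batch_size=8)
# ...or let HFLM resolve one from model.name_or_path.
lm = HFLM(pretrained=model, batch_size=8)
```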
@@ -81,7 +81,7 @@ class HFLM(BaseLM):
             revision = revision + ("/" + subfolder if subfolder is not None else "")

             # Initialize new model and tokenizer instances
-            self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
+            self.model = transformers.AutoModelForCausalLM.from_pretrained(
                 pretrained,
                 load_in_8bit=load_in_8bit,
                 low_cpu_mem_usage=low_cpu_mem_usage,
@@ -98,20 +98,10 @@ class HFLM(BaseLM):
         else:
             raise TypeError('Parameter pretrained should be of type str or transformers.PreTrainedModel')

-        self.gpt2.eval()
+        self.model.eval()

         self.vocab_size = self.tokenizer.vocab_size

-        if isinstance(
-            self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast)
-        ):
-            assert self.tokenizer.encode("hello\n\nhello") == [
-                31373,
-                198,
-                198,
-                31373,
-            ], self.tokenizer.encode("hello\n\nhello")
-
         # Validate batch_size
         assert isinstance(batch_size, (int, str))
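The deleted block was a GPT-2-only sanity check: it asserted that the tokenizer maps "hello\n\nhello" to fixed GPT-2 BPE ids, which cannot hold in general once arbitrary AutoModel checkpoints (and hence arbitrary tokenizers) are accepted. For reference, the invariant it enforced, using the ids from the removed code:

```python
import transformers

# Holds for GPT-2's BPE vocabulary: "hello" -> 31373, "\n" -> 198.
tok = transformers.GPT2TokenizerFast.from_pretrained("gpt2")
assert tok.encode("hello\n\nhello") == [31373, 198, 198, 31373]
```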
@@ -134,8 +124,8 @@ class HFLM(BaseLM):
             return self._max_length
         seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
         for attr in seqlen_config_attrs:
-            if hasattr(self.gpt2.config, attr):
-                return getattr(self.gpt2.config, attr)
+            if hasattr(self.model.config, attr):
+                return getattr(self.model.config, attr)
         if hasattr(self.tokenizer, "model_max_length"):
             if self.tokenizer.model_max_length == 1000000000000000019884624838656:
                 return self._DEFAULT_MAX_LENGTH
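For context, `max_length` probes several config attribute names because HF model configs disagree on where the context window lives (GPT-2 uses `n_ctx`/`n_positions`, most newer models use `max_position_embeddings`); the long integer is the sentinel transformers sets when `model_max_length` is unknown. A standalone sketch of the same resolution order; `resolve_max_length` is a hypothetical helper and the 2048 default is an assumption, not read from this diff:

```python
def resolve_max_length(model, tokenizer, default=2048):
    # Hypothetical helper mirroring HFLM.max_length's resolution order.
    for attr in ("n_positions", "max_position_embeddings", "n_ctx"):
        if hasattr(model.config, attr):
            return getattr(model.config, attr)
    if hasattr(tokenizer, "model_max_length"):
        # transformers uses this huge sentinel when no real limit is set.
        if tokenizer.model_max_length == 1000000000000000019884624838656:
            return default
        return tokenizer.model_max_length
    return default
```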
@@ -172,14 +162,14 @@ class HFLM(BaseLM):
             logits returned from the model
         """
         with torch.no_grad():
-            return self.gpt2(inps)[0]
+            return self.model(inps)[0]

     def _model_generate(self, context, max_length, eos_token_id):
         generation_kwargs = {"do_sample": False, "max_length": max_length}
         if eos_token_id is not None:
             generation_kwargs['eos_token_id'] = eos_token_id
             generation_kwargs['pad_token_id'] = eos_token_id  # setting eos_token_id as pad token
-        return self.gpt2.generate(context, **generation_kwargs)
+        return self.model.generate(context, **generation_kwargs)

 # for backwards compatibility
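The generation path is plain greedy decoding; when a stop token is supplied it doubles as the pad token so batched, padded prompts can run through `generate` without a separate pad id. An equivalent direct call against a transformers model, with prompt and lengths purely illustrative:

```python
import transformers

model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
tok = transformers.AutoTokenizer.from_pretrained("gpt2")

context = tok("The quick brown fox", return_tensors="pt").input_ids
out = model.generate(
    context,
    do_sample=False,                 # greedy, as in _model_generate
    max_length=32,
    eos_token_id=tok.eos_token_id,
    pad_token_id=tok.eos_token_id,   # eos reused as pad, matching the hunk above
)
print(tok.decode(out[0]))
```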
...