Unverified Commit 72b7f0c0 authored by Hailey Schoelkopf, committed by GitHub

Merge pull request #601 from svenhendrikx/instantiate-model-from-Automodel

Instantiate model from AutoModel
parents 9f4862f6 72ee34dc
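The headline change: simple_evaluate now accepts an already-instantiated transformers.PreTrainedModel, in addition to a model-name string or an lm_eval.base.LM object, so a model fine-tuned in-process can be evaluated without saving and reloading it. A minimal usage sketch of the new path; the checkpoint name and task are illustrative, not part of this diff:

import transformers
import lm_eval.evaluator

# Load (or fine-tune) a causal LM in-process...
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")

# ...then hand the live object straight to the harness.
results = lm_eval.evaluator.simple_evaluate(
    model=model,              # a PreTrainedModel, not a name string
    tasks=["lambada_openai"],
    num_fewshot=0,
    batch_size=1,
)
print(results["config"]["model"])  # "pretrained=gpt2"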
import collections
import itertools
import numpy as np
import random
import lm_eval.metrics
import lm_eval.models
import lm_eval.tasks
import lm_eval.base
from lm_eval.utils import positional_deprecated, run_task_tests
from lm_eval.models.gpt2 import HFLM
import transformers
@positional_deprecated
@@ -30,7 +34,7 @@ def simple_evaluate(
"""Instantiate and evaluate a model on a list of tasks.
:param model: Union[str, LM]
        Name of model, transformers.PreTrainedModel object, or LM object, see lm_eval.models.get_model
:param model_args: Optional[str]
String arguments for each model class, see LM.create_from_arg_string.
Ignored if `model` argument is a LM object.
@@ -72,6 +76,12 @@ def simple_evaluate(
lm = lm_eval.models.get_model(model).create_from_arg_string(
model_args, {"batch_size": batch_size, "max_batch_size": max_batch_size, "device": device}
)
elif isinstance(model, transformers.PreTrainedModel):
lm = lm_eval.models.get_model("hf-causal")(
pretrained=model,
batch_size=batch_size,
)
no_cache = True
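        # Note: the on-disk results cache is keyed on the model-name string,
        # which an in-memory model lacks, so caching is disabled for this path.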
else:
assert isinstance(model, lm_eval.base.LM)
lm = model
@@ -104,8 +114,13 @@ def simple_evaluate(
)
# add info about the model and few shot config
model_name = None
if isinstance(model, str):
model_name = model
elif isinstance(model, transformers.PreTrainedModel):
model_name = "pretrained=" + model.config._name_or_path
results["config"] = {
"model": (model if isinstance(model, str) else model.model.config._name_or_path),
"model": model_name,
"model_args": model_args,
"num_fewshot": num_fewshot,
"batch_size": batch_size,
......
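With the model_name logic above, the "model" entry recorded in results["config"] now resolves uniformly for all three accepted input types; a summary, with illustrative names:

# model given as a string          -> "model": "hf-causal"        (the string itself)
# model given as a PreTrainedModel -> "model": "pretrained=gpt2"  (config._name_or_path)
# model given as an LM object      -> "model": None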
@@ -29,56 +29,82 @@ class HFLM(BaseLM):
subfolder=None,
tokenizer=None,
batch_size=1,
        max_length=None,
load_in_8bit: Optional[bool] = False,
trust_remote_code: Optional[bool] = False,
        dtype: Optional[Union[str, torch.dtype]] = "auto",
):
super().__init__()
# Initialize model
if isinstance(pretrained, transformers.PreTrainedModel):
self.model = pretrained
self._device = self.model.device
if tokenizer:
                assert isinstance(
                    tokenizer,
                    (transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast),
                )
self.tokenizer = tokenizer
else:
# Get tokenizer
model_name = self.model.name_or_path
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
model_name,
revision=revision,
trust_remote_code=trust_remote_code,
)
elif isinstance(pretrained, str):
# Initialize device
assert isinstance(device, str)
device_list = set(
["cuda", "cpu"] + [f"cuda:{i}" for i in range(torch.cuda.device_count())]
)
if device and device in device_list:
self._device = torch.device(device)
print(f"Using device '{device}'")
else:
print("Device not specified")
print(f"Cuda Available? {torch.cuda.is_available()}")
self._device = (
torch.device("cuda")
if torch.cuda.is_available()
else torch.device("cpu")
)
revision = revision + ("/" + subfolder if subfolder is not None else "")
# Initialize new model and tokenizer instances
self.model = transformers.AutoModelForCausalLM.from_pretrained(
pretrained,
load_in_8bit=load_in_8bit,
low_cpu_mem_usage=low_cpu_mem_usage,
revision=revision,
torch_dtype=_get_dtype(dtype),
trust_remote_code=trust_remote_code,
).to(self.device)
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
tokenizer if tokenizer else pretrained,
revision=revision,
trust_remote_code=trust_remote_code,
)
        else:
            raise TypeError(
                "Parameter pretrained should be of type str or transformers.PreTrainedModel"
            )
self.model.eval()
self.vocab_size = self.tokenizer.vocab_size
# Validate batch_size
assert isinstance(batch_size, (int, str))
# setup for automatic batch size detection
if batch_size == "auto":
self.batch_size_per_gpu = batch_size
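HFLM now branches the same way at construction time: pass a checkpoint name and it resolves device, model, and tokenizer itself, or pass a live model (optionally with a matching tokenizer) and it adopts them as-is, including the model's current device. A minimal sketch against the constructor above; the checkpoint name is illustrative:

import transformers
from lm_eval.models.gpt2 import HFLM

# Path A: from a name string; AutoModelForCausalLM/AutoTokenizer are loaded internally.
lm = HFLM(pretrained="gpt2", device="cpu", batch_size=1)

# Path B: from live objects; the wrapper skips loading and reuses model.device.
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
lm = HFLM(pretrained=model, tokenizer=tokenizer)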
@@ -98,8 +124,8 @@ class HFLM(BaseLM):
return self._max_length
seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
for attr in seqlen_config_attrs:
            if hasattr(self.model.config, attr):
                return getattr(self.model.config, attr)
if hasattr(self.tokenizer, "model_max_length"):
if self.tokenizer.model_max_length == 1000000000000000019884624838656:
return self._DEFAULT_MAX_LENGTH
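Spelled out, the max_length resolution order is: an explicitly passed max_length wins; otherwise the first of n_positions, max_position_embeddings, or n_ctx found on the model config; otherwise tokenizer.model_max_length, unless it equals int(1e30) (the placeholder transformers reports when a tokenizer has no real limit), in which case the harness default applies. A standalone sketch of the same logic; resolve_max_length is a hypothetical helper and the 2048 default is an assumption:

def resolve_max_length(config, tokenizer, explicit=None, default=2048):
    # Mirrors HFLM.max_length: explicit value > config attrs > tokenizer limit > default.
    if explicit is not None:
        return explicit
    for attr in ("n_positions", "max_position_embeddings", "n_ctx"):
        if hasattr(config, attr):
            return getattr(config, attr)
    if tokenizer.model_max_length != 1000000000000000019884624838656:
        return tokenizer.model_max_length
    return default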
@@ -136,14 +162,14 @@ class HFLM(BaseLM):
logits returned from the model
"""
with torch.no_grad():
            return self.model(inps)[0]
def _model_generate(self, context, max_length, eos_token_id):
generation_kwargs = {"do_sample": False, "max_length": max_length}
if eos_token_id is not None:
generation_kwargs['eos_token_id'] = eos_token_id
generation_kwargs['pad_token_id'] = eos_token_id # setting eos_token_id as pad token
        return self.model.generate(context, **generation_kwargs)
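_model_generate pins decoding to greedy (do_sample=False) so generation-style tasks are deterministic, and reuses the stop token as pad_token_id because transformers requires a pad id for batched generation. An equivalent direct call, with a hypothetical tokenizer tok and context tensor:

# What the wrapper dispatches to, written out:
output = lm.model.generate(
    context,                        # (batch, seq) tensor of input token ids
    do_sample=False,                # greedy decoding
    max_length=max_length,
    eos_token_id=tok.eos_token_id,  # stop token...
    pad_token_id=tok.eos_token_id,  # ...doubling as the required pad token
)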
# for backwards compatibility
......