Commit ddc634f2 authored by svenhendrikx's avatar svenhendrikx
Browse files

Add logic to simple_evaluate to instantiate HFLM from transformers.PreTrainedModel instance

parent 42caa660
import collections import collections
import itertools import itertools
import numpy as np
import random import random
import lm_eval.metrics import lm_eval.metrics
import lm_eval.models import lm_eval.models
import lm_eval.tasks import lm_eval.tasks
import lm_eval.base import lm_eval.base
from lm_eval.utils import positional_deprecated, run_task_tests from lm_eval.utils import positional_deprecated, run_task_tests
from lm_eval.models.gpt2 import HFLM
import numpy as np
import transformers
@positional_deprecated @positional_deprecated
...@@ -69,6 +73,11 @@ def simple_evaluate( ...@@ -69,6 +73,11 @@ def simple_evaluate(
lm = lm_eval.models.get_model(model).create_from_arg_string( lm = lm_eval.models.get_model(model).create_from_arg_string(
model_args, {"batch_size": batch_size, "device": device} model_args, {"batch_size": batch_size, "device": device}
) )
elif isinstance(model, transformers.PreTrainedModel):
lm = HFLM(
pretrained=model,
)
no_cache = True
else: else:
assert isinstance(model, lm_eval.base.LM) assert isinstance(model, lm_eval.base.LM)
lm = model lm = model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment