Commit 380c62f9 authored by haileyschoelkopf's avatar haileyschoelkopf
Browse files

handle trust_remote_code models better

parent 1e2fbe6b
import torch
import transformers
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.models.auto.modeling_auto import (
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
)
from peft import __version__ as PEFT_VERSION, PeftModel
import copy
......@@ -147,6 +150,18 @@ class HFLM(LM):
if getattr(self._config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
elif (
not getattr(self._config, "model_type")
in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
):
if not trust_remote_code:
eval_logger.warning(
"HF model type is neither marked as CausalLM or Seq2SeqLM. \
This is expected if your model requires `trust_remote_code=True` but may be an error otherwise."
)
# if model type is neither in HF transformers causal or seq2seq model registries
# then we default to AutoModelForCausalLM
self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
else:
self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment