Commit 7ab782ec authored by haileyschoelkopf
Browse files

use Seq2Seq backend where either can be loaded from HF

parent a2992d28
......@@ -157,12 +157,17 @@ class HFLM(LM):
trust_remote_code=trust_remote_code,
)
if getattr(self._config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
elif (
not getattr(self._config, "model_type")
if (
getattr(self._config, "model_type")
in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
):
# first check if model type is listed under seq2seq models, since some
# models like MBart are listed in both seq2seq and causal mistakenly in HF transformers.
# these special cases should be treated as seq2seq models.
self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
elif getattr(self._config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
else:
if not trust_remote_code:
eval_logger.warning(
"HF model type is neither marked as CausalLM or Seq2SeqLM. \
......@@ -171,8 +176,6 @@ class HFLM(LM):
# if model type is neither in HF transformers causal or seq2seq model registries
# then we default to AutoModelForCausalLM
self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
else:
self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
assert self.AUTO_MODEL_CLASS in [
transformers.AutoModelForCausalLM,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment