Unverified commit e6d4ec39, authored by Stella Biderman and committed by GitHub
Browse files

Merge pull request #1024 from EleutherAI/fix-mbart

[Refactor] Use correct HF model type for MBart-like models
parents b072bb0d 7ab782ec
...@@ -158,12 +158,17 @@ class HFLM(LM): ...@@ -158,12 +158,17 @@ class HFLM(LM):
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
if getattr(self._config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: if (
self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM getattr(self._config, "model_type")
elif (
not getattr(self._config, "model_type")
in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
): ):
# first check if model type is listed under seq2seq models, since some
# models like MBart are listed in both seq2seq and causal mistakenly in HF transformers.
# these special cases should be treated as seq2seq models.
self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
elif getattr(self._config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
else:
if not trust_remote_code: if not trust_remote_code:
eval_logger.warning( eval_logger.warning(
"HF model type is neither marked as CausalLM or Seq2SeqLM. \ "HF model type is neither marked as CausalLM or Seq2SeqLM. \
...@@ -172,8 +177,6 @@ class HFLM(LM): ...@@ -172,8 +177,6 @@ class HFLM(LM):
# if model type is neither in HF transformers causal or seq2seq model registries # if model type is neither in HF transformers causal or seq2seq model registries
# then we default to AutoModelForCausalLM # then we default to AutoModelForCausalLM
self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
else:
self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
assert self.AUTO_MODEL_CLASS in [ assert self.AUTO_MODEL_CLASS in [
transformers.AutoModelForCausalLM, transformers.AutoModelForCausalLM,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment