"git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "6e923dbd30411b89876ec465d1c95282225ba85e"
Unverified commit e6d4ec39, authored by Stella Biderman and committed by GitHub
Browse files

Merge pull request #1024 from EleutherAI/fix-mbart

[Refactor] Use correct HF model type for MBart-like models
parents b072bb0d 7ab782ec
......@@ -158,12 +158,17 @@ class HFLM(LM):
trust_remote_code=trust_remote_code,
)
if getattr(self._config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
elif (
not getattr(self._config, "model_type")
if (
getattr(self._config, "model_type")
in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
):
# first check if model type is listed under seq2seq models, since some
# models like MBart are listed in both seq2seq and causal mistakenly in HF transformers.
# these special cases should be treated as seq2seq models.
self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
elif getattr(self._config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
else:
if not trust_remote_code:
eval_logger.warning(
"HF model type is neither marked as CausalLM or Seq2SeqLM. \
......@@ -172,8 +177,6 @@ class HFLM(LM):
# if model type is neither in HF transformers causal or seq2seq model registries
# then we default to AutoModelForCausalLM
self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
else:
self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
assert self.AUTO_MODEL_CLASS in [
transformers.AutoModelForCausalLM,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment