Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
e6d4ec39
Unverified
Commit
e6d4ec39
authored
Nov 28, 2023
by
Stella Biderman
Committed by
GitHub
Nov 28, 2023
Browse files
Merge pull request #1024 from EleutherAI/fix-mbart
[Refactor] Use correct HF model type for MBart-like models
parents
b072bb0d
7ab782ec
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
6 deletions
+9
-6
lm_eval/models/huggingface.py
lm_eval/models/huggingface.py
+9
-6
No files found.
lm_eval/models/huggingface.py
View file @
e6d4ec39
...
@@ -158,12 +158,17 @@ class HFLM(LM):
...
@@ -158,12 +158,17 @@ class HFLM(LM):
trust_remote_code
=
trust_remote_code
,
trust_remote_code
=
trust_remote_code
,
)
)
if
getattr
(
self
.
_config
,
"model_type"
)
in
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
:
if
(
self
.
AUTO_MODEL_CLASS
=
transformers
.
AutoModelForCausalLM
getattr
(
self
.
_config
,
"model_type"
)
elif
(
not
getattr
(
self
.
_config
,
"model_type"
)
in
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
in
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
):
):
# first check if model type is listed under seq2seq models, since some
# models like MBart are listed in both seq2seq and causal mistakenly in HF transformers.
# these special cases should be treated as seq2seq models.
self
.
AUTO_MODEL_CLASS
=
transformers
.
AutoModelForSeq2SeqLM
elif
getattr
(
self
.
_config
,
"model_type"
)
in
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
:
self
.
AUTO_MODEL_CLASS
=
transformers
.
AutoModelForCausalLM
else
:
if
not
trust_remote_code
:
if
not
trust_remote_code
:
eval_logger
.
warning
(
eval_logger
.
warning
(
"HF model type is neither marked as CausalLM or Seq2SeqLM.
\
"HF model type is neither marked as CausalLM or Seq2SeqLM.
\
...
@@ -172,8 +177,6 @@ class HFLM(LM):
...
@@ -172,8 +177,6 @@ class HFLM(LM):
# if model type is neither in HF transformers causal or seq2seq model registries
# if model type is neither in HF transformers causal or seq2seq model registries
# then we default to AutoModelForCausalLM
# then we default to AutoModelForCausalLM
self
.
AUTO_MODEL_CLASS
=
transformers
.
AutoModelForCausalLM
self
.
AUTO_MODEL_CLASS
=
transformers
.
AutoModelForCausalLM
else
:
self
.
AUTO_MODEL_CLASS
=
transformers
.
AutoModelForSeq2SeqLM
assert
self
.
AUTO_MODEL_CLASS
in
[
assert
self
.
AUTO_MODEL_CLASS
in
[
transformers
.
AutoModelForCausalLM
,
transformers
.
AutoModelForCausalLM
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment