gaoqiong / lm-evaluation-harness · Commits

Commit 994bdb3f (Unverified)
Authored Feb 01, 2024 by Baber Abbasi
Committed by GitHub on Feb 01, 2024
Hf: minor edge cases (#1380)

* Handle edge cases where a variable might not be assigned.
* Fix a type hint.
Parent: f5408b6b
Showing 1 changed file with 4 additions and 3 deletions.

lm_eval/models/huggingface.py (+4, -3)
@@ -108,8 +108,8 @@ class HFLM(LM):
             assert not parallelize, "`parallelize=True` is not compatible with passing pre-initialized model to `pretrained`"
             self._model = pretrained
             self._device = self._model.device
             self._config = self._model.config
-
+            gpus = 0

             if tokenizer:
                 assert isinstance(
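The `gpus = 0` addition covers the branch where a pre-initialized model object is passed as `pretrained`: `gpus` is read later in `__init__` but, before this change, was only assigned on the load-from-string path. A minimal sketch of the failure mode, with hypothetical names (`init_model` is not harness code):

    # A minimal sketch (not harness code) of the unassigned-variable bug that
    # `gpus = 0` guards against: the variable was previously bound only on the
    # string-loading path, so the pre-initialized-model path could raise
    # NameError further down.
    import torch

    def init_model(pretrained):
        if isinstance(pretrained, str):
            gpus = torch.cuda.device_count()  # bound only on this path pre-fix
            model = f"loaded:{pretrained}"    # stand-in for real model loading
        else:
            model = pretrained
            gpus = 0  # the fix: give the variable a defined value on this path
        if gpus > 1:  # pre-fix, the object path reached this line with `gpus` unbound
            pass
        return model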
@@ -372,7 +372,7 @@ class HFLM(LM):
     def _get_backend(
         self,
-        config: transformers.AutoConfig,
+        config: Union[transformers.PretrainedConfig, transformers.AutoConfig],
         backend: Optional[Literal["default", "causal", "seq2seq"]] = "default",
         trust_remote_code: Optional[bool] = False,
     ) -> None:
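The widened hint reflects what callers actually hold: `AutoConfig.from_pretrained(...)` returns a `PretrainedConfig` subclass rather than an `AutoConfig` instance, so `Union[transformers.PretrainedConfig, transformers.AutoConfig]` accepts both spellings. A quick check, assuming `transformers` is installed and the example model id is reachable:

    # Demonstrates why the annotation is widened: the config object returned by
    # AutoConfig.from_pretrained is a PretrainedConfig subclass, not an
    # AutoConfig. "gpt2" is just an example model id.
    import transformers

    config = transformers.AutoConfig.from_pretrained("gpt2")
    print(type(config).__name__)                              # GPT2Config
    print(isinstance(config, transformers.PretrainedConfig))  # True
    print(isinstance(config, transformers.AutoConfig))        # False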
@@ -1059,6 +1059,7 @@ class HFLM(LM):
             return -len(toks), x[0]

         pbar = tqdm(total=len(requests), disable=(self.rank != 0))
+        adaptive_batch_size = None
         if self.batch_size == "auto":
             # using rolling window with maximum context
             print("Passed argument batch_size = auto. Detecting largest batch size")
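`adaptive_batch_size` is assigned inside the `batch_size == "auto"` branch but read again after it, so a fixed batch size could leave the name unbound. Initializing it to `None` up front makes the later check well-defined. A sketch of the pattern under illustrative values (`pick_batch_size` and the sizes are hypothetical):

    # Sketch of the initialize-before-conditional-assignment pattern this hunk
    # applies; `pick_batch_size` and the detected size are illustrative only.
    def pick_batch_size(batch_size_arg):
        adaptive_batch_size = None  # the added default
        if batch_size_arg == "auto":
            adaptive_batch_size = 8  # stand-in for the detected largest batch
        # later code can now safely distinguish "auto-detected" from "fixed"
        return adaptive_batch_size or batch_size_arg

    print(pick_batch_size("auto"))  # 8
    print(pick_batch_size(4))       # 4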
@@ -1103,7 +1104,7 @@ class HFLM(LM):
                     )
                 else:
                     raise ValueError(
-                        f"Expected `kwargs` to be of type `dict` but got {kwargs}"
+                        f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                     )
                 if not until:
                     until = [self.tok_decode(self.eot_token_id)]
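The old f-string interpolated `kwargs`, a name that is not necessarily bound at that point; the replacement reports `type(gen_kwargs)`, which names the object that actually failed the `dict` check. Roughly, assuming `gen_kwargs` carries the per-request generation arguments:

    # Illustration of why reporting type(gen_kwargs) is the useful diagnostic;
    # `gen_kwargs` here is a stand-in value, not pulled from the harness.
    gen_kwargs = ["until"]  # wrong type slipped in where a dict was expected

    if not isinstance(gen_kwargs, dict):
        raise ValueError(
            f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
        )
    # -> ValueError: Expected `kwargs` to be of type `dict` but got <class 'list'>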