Commit dd2b2de9 (gaoqiong/lm-evaluation-harness)
Authored Jun 19, 2023 by haileyschoelkopf
Parent commit: 1fb16673

re-add dtype flag
Showing 2 changed files, with 20 additions and 2 deletions:

  lm_eval/models/hf_causal.py  (+8, -2)
  lm_eval/utils.py             (+12, -0)
lm_eval/models/hf_causal.py

@@ -12,7 +12,7 @@ from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
 from accelerate import Accelerator
 from itertools import islice
 from typing import Optional, Union


 @register_model("hf-causal")

@@ -23,6 +23,7 @@ class HFLM(LM):
         pretrained="gpt2",
         revision="main",
         low_cpu_mem_usage=None,
+        dtype: Optional[Union[str, torch.dtype]] = "auto",
         subfolder=None,
         tokenizer=None,
         batch_size=1,

@@ -58,10 +59,15 @@ class HFLM(LM):
         revision = revision + ("/" + subfolder if subfolder is not None else "")
         self.model = transformers.AutoModelForCausalLM.from_pretrained(
-            pretrained, revision=revision, low_cpu_mem_usage=low_cpu_mem_usage
+            pretrained,
+            revision=revision,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            torch_dtype=utils.get_dtype(dtype),
         ).to(self.device)
         self.model.eval()
         print(self.model.dtype)
         self.tokenizer = transformers.AutoTokenizer.from_pretrained(
             pretrained if tokenizer is None else tokenizer,
             revision=revision,
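
For context, a minimal usage sketch of the re-added flag, assuming the HFLM constructor exactly as quoted in the hunks above; the model name and the repeated instantiations are illustrative only, not part of the commit:

# Sketch only: assumes lm-evaluation-harness at this commit, plus torch and
# transformers installed; "gpt2" is just a small example checkpoint.
import torch
from lm_eval.models.hf_causal import HFLM

# String dtypes go through utils.get_dtype, so these two are equivalent:
lm_fp16_str = HFLM(pretrained="gpt2", dtype="float16")
lm_fp16_obj = HFLM(pretrained="gpt2", dtype=torch.float16)

# The default, dtype="auto", is forwarded to from_pretrained unchanged, letting
# transformers resolve the dtype from the checkpoint config.
lm_default = HFLM(pretrained="gpt2")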
lm_eval/utils.py

@@ -419,3 +419,15 @@ def create_iterator(raw_iterator, rank, world_size, limit=None):
 def clear_torch_cache():
     gc.collect()
     torch.cuda.empty_cache()
+
+
+def get_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
+    """Converts `dtype` from `str` to torch.dtype when possible. Does not use an instantiated HF AutoConfig"""
+    if isinstance(dtype, str) and dtype != "auto":
+        # Convert `str` args torch dtype: `float16` -> `torch.float16`
+        _torch_dtype = getattr(torch, dtype)
+    else:
+        _torch_dtype = dtype
+    return _torch_dtype
\ No newline at end of file
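
To make the conversion rule concrete, here is a self-contained restatement of the helper added above, exercised on a few inputs; the asserts are illustrative and not part of the commit:

# Standalone illustration of the conversion performed by get_dtype.
from typing import Union
import torch

def get_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
    if isinstance(dtype, str) and dtype != "auto":
        return getattr(torch, dtype)  # e.g. "float16" -> torch.float16
    return dtype                      # "auto" and torch.dtype values pass through unchanged

assert get_dtype("float16") is torch.float16
assert get_dtype("bfloat16") is torch.bfloat16
assert get_dtype("auto") == "auto"            # left for transformers to resolve later
assert get_dtype(torch.float32) is torch.float32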