Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
bd81b8c0
Commit
bd81b8c0
authored
Sep 14, 2023
by
baberabb
Browse files
Update device list and dtype detection for MPS
parent
b8d1cef9
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
29 additions
and
6 deletions
+29
-6
lm_eval/api/model.py
lm_eval/api/model.py
+23
-3
lm_eval/models/huggingface.py
lm_eval/models/huggingface.py
+6
-3
No files found.
lm_eval/api/model.py
View file @
bd81b8c0
import
abc
import
os
from
typing
import
Union
,
List
,
Tuple
import
torch
from
typing
import
Union
,
List
,
Tuple
,
Optional
,
Type
,
TypeVar
from
sqlitedict
import
SqliteDict
import
json
import
hashlib
...
...
@@ -11,6 +12,8 @@ from tqdm import tqdm
from
lm_eval
import
utils
from
lm_eval.logger
import
eval_logger
T
=
TypeVar
(
"T"
,
bound
=
"LM"
)
class
LM
(
abc
.
ABC
):
def
__init__
(
self
)
->
None
:
...
...
@@ -111,11 +114,28 @@ class LM(abc.ABC):
pass
@
classmethod
def
create_from_arg_string
(
cls
,
arg_string
,
additional_config
=
None
):
def
create_from_arg_string
(
cls
:
Type
[
T
],
arg_string
:
str
,
additional_config
:
Optional
[
dict
]
=
None
)
->
T
:
"""
Creates an instance of the LM class using the given argument string and additional config.
Parameters:
- arg_string: A string containing arguments in the format key1=value1,key2=value2.
- additional_config: Optional dictionary containing additional configuration parameters.
Returns:
- Instance of the LM class.
"""
additional_config
=
{}
if
additional_config
is
None
else
additional_config
args
=
utils
.
simple_parse_args_string
(
arg_string
)
args2
=
{
k
:
v
for
k
,
v
in
additional_config
.
items
()
if
v
is
not
None
}
if
args2
.
get
(
"device"
)
==
"mps"
or
args
.
get
(
"device"
)
==
"mps"
:
# TODO: delete once float16 MPS is fixed in torch stable
if
(
args2
.
get
(
"device"
)
in
(
"mps"
,
"mps:0"
)
or
args
.
get
(
"device"
)
in
(
"mps"
,
"mps:0"
)
and
"dev"
not
in
torch
.
__version__
):
args
[
"dtype"
]
=
"float32"
return
cls
(
**
args
,
**
args2
)
...
...
lm_eval/models/huggingface.py
View file @
bd81b8c0
...
...
@@ -107,17 +107,20 @@ class HFLM(LM):
if
not
(
parallelize
or
accelerator
.
num_processes
>
1
):
# use user-passed device
device_list
=
set
(
[
"cuda"
,
"cpu"
,
"mps"
]
[
"cuda"
,
"cpu"
]
+
[
f
"cuda:
{
i
}
"
for
i
in
range
(
torch
.
cuda
.
device_count
())]
+
[
"mps"
,
"mps:0"
]
)
if
device
:
if
device
not
in
device_list
:
device
=
int
(
device
)
self
.
_device
=
torch
.
device
(
device
)
eval_logger
.
info
(
f
"Using device '
{
device
}
'"
)
if
device
==
"mps"
:
if
device
in
(
"mps"
,
"mps:0"
)
and
"dev"
not
in
torch
.
__version__
:
eval_logger
.
info
(
"MPS is still in beta and only supports float32; setting dtype to float32."
"MPS: Setting dtype to float32. To use float16 with MPS, please install a nightly build of "
"PyTorch: pip3 install --pre torch torchvision torchaudio --index-url "
"https://download.pytorch.org/whl/nightly/cpu"
)
else
:
eval_logger
.
info
(
"Device not specified"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment