Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
0d8e206c
Unverified
Commit
0d8e206c
authored
Jul 03, 2023
by
Lintang Sutawika
Committed by
GitHub
Jul 03, 2023
Browse files
Merge pull request #647 from EleutherAI/handle-multigpu-errors
[Refactor] Handle `cuda:0` device assignment
parents
59aef189
b3598058
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
6 deletions
+14
-6
lm_eval/models/huggingface.py
lm_eval/models/huggingface.py
+8
-4
lm_eval/utils.py
lm_eval/utils.py
+6
-2
No files found.
lm_eval/models/huggingface.py
View file @
0d8e206c
...
...
@@ -93,11 +93,16 @@ class HFLM(LM):
assert
isinstance
(
batch_size
,
int
)
gpus
=
torch
.
cuda
.
device_count
()
accelerator
=
Accelerator
()
if
gpus
<=
1
and
not
parallelize
:
if
not
(
parallelize
or
accelerator
.
num_processes
>
1
)
:
# use user-passed device
device_list
=
set
(
[
"cuda"
,
"cpu"
]
+
[
f
"cuda:
{
i
}
"
for
i
in
range
(
torch
.
cuda
.
device_count
())]
)
if
device
:
if
device
not
in
[
"cuda"
,
"cpu"
]
:
if
device
not
in
device_list
:
device
=
int
(
device
)
self
.
_device
=
torch
.
device
(
device
)
eval_logger
.
info
(
f
"Using device '
{
device
}
'"
)
...
...
@@ -111,7 +116,7 @@ class HFLM(LM):
)
else
:
eval_logger
.
info
(
f
"
Passed device '
{
device
}
', but u
sing `accelerate launch` or `parallelize=True`
. This
will be overridden when placing model."
f
"
U
sing `accelerate launch` or `parallelize=True`
, device '
{
device
}
'
will be overridden when placing model."
)
# TODO: include in warning that `load_in_8bit` etc. affect this too
self
.
_device
=
device
...
...
@@ -217,7 +222,6 @@ class HFLM(LM):
# multigpu data-parallel support when launched with accelerate
if
gpus
>
1
:
accelerator
=
Accelerator
()
if
parallelize
:
if
accelerator
.
num_processes
>
1
:
raise
RuntimeError
(
...
...
lm_eval/utils.py
View file @
0d8e206c
...
...
@@ -10,7 +10,7 @@ import collections
import
importlib.util
import
fnmatch
from
typing
import
List
,
Union
from
typing
import
List
,
Literal
,
Union
import
gc
import
torch
...
...
@@ -453,7 +453,11 @@ def create_iterator(raw_iterator, rank, world_size, limit=None):
return
islice
(
raw_iterator
,
rank
,
limit
,
world_size
)
def
pad_and_concat
(
max_length
:
int
,
tensors
:
List
[
torch
.
Tensor
],
padding_side
=
"right"
):
def
pad_and_concat
(
max_length
:
int
,
tensors
:
List
[
torch
.
Tensor
],
padding_side
:
Literal
[
"right"
,
"left"
]
=
"right"
,
):
"""
Method for padding a list of tensors given the maximum tensor
length in the batch. Used for batching inputs and continuations in
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment