Unverified Commit 0d8e206c authored by Lintang Sutawika, committed by GitHub

Merge pull request #647 from EleutherAI/handle-multigpu-errors

[Refactor] Handle `cuda:0` device assignment
parents 59aef189 b3598058
@@ -93,11 +93,16 @@ class HFLM(LM):
         assert isinstance(batch_size, int)
 
         gpus = torch.cuda.device_count()
+        accelerator = Accelerator()
 
-        if gpus <= 1 and not parallelize:
+        if not (parallelize or accelerator.num_processes > 1):
             # use user-passed device
+            device_list = set(
+                ["cuda", "cpu"]
+                + [f"cuda:{i}" for i in range(torch.cuda.device_count())]
+            )
             if device:
-                if device not in ["cuda", "cpu"]:
+                if device not in device_list:
                     device = int(device)
                 self._device = torch.device(device)
                 eval_logger.info(f"Using device '{device}'")
@@ -111,7 +116,7 @@ class HFLM(LM):
                 )
             else:
                 eval_logger.info(
-                    f"Passed device '{device}', but using `accelerate launch` or `parallelize=True`. This will be overridden when placing model."
+                    f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model."
                 )
                 # TODO: include in warning that `load_in_8bit` etc. affect this too
                 self._device = device
@@ -217,7 +222,6 @@ class HFLM(LM):
 
         # multigpu data-parallel support when launched with accelerate
         if gpus > 1:
-            accelerator = Accelerator()
             if parallelize:
                 if accelerator.num_processes > 1:
                     raise RuntimeError(
...
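For context on the title fix: before this change, any device string other than "cuda" or "cpu" (including "cuda:0") fell through to int(device) and raised a ValueError; the new device_list accepts every visible per-GPU identifier, and the hunk at line 222 reuses the Accelerator() instance that is now created up front instead of constructing a second one inside the multi-GPU branch. Below is a minimal standalone sketch of the updated check; the helper name resolve_device is illustrative only and not part of the harness's API.

import torch


def resolve_device(device: str) -> torch.device:
    # Accept "cuda", "cpu", and any "cuda:<i>" string for a visible GPU;
    # anything else is assumed to be a bare device index such as "1".
    device_list = set(
        ["cuda", "cpu"] + [f"cuda:{i}" for i in range(torch.cuda.device_count())]
    )
    if device not in device_list:
        device = int(device)  # "cuda:0" used to hit this path and crash
    return torch.device(device)


print(resolve_device("cpu"))  # cpu
if torch.cuda.is_available():
    print(resolve_device("cuda:0"))  # cuda:0, no longer routed through int()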
@@ -10,7 +10,7 @@ import collections
 import importlib.util
 import fnmatch
-from typing import List, Union
+from typing import List, Literal, Union
 import gc
 
 import torch
@@ -453,7 +453,11 @@ def create_iterator(raw_iterator, rank, world_size, limit=None):
     return islice(raw_iterator, rank, limit, world_size)
 
 
-def pad_and_concat(max_length: int, tensors: List[torch.Tensor], padding_side="right"):
+def pad_and_concat(
+    max_length: int,
+    tensors: List[torch.Tensor],
+    padding_side: Literal["right", "left"] = "right",
+):
     """
     Method for padding a list of tensors given the maximum tensor
     length in the batch. Used for batching inputs and continuations in
...
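The utils change only tightens the signature of pad_and_concat: padding_side is now typed as Literal["right", "left"], which is why Literal joins the typing imports above. For readers unfamiliar with the helper, here is a simplified sketch of what such a padding routine does; this is an assumed implementation with a hypothetical name, only the signature is taken from the diff.

from typing import List, Literal

import torch


def pad_and_concat_sketch(
    max_length: int,
    tensors: List[torch.Tensor],
    padding_side: Literal["right", "left"] = "right",
) -> torch.Tensor:
    # Pad each 1-D tensor with zeros up to max_length, on the requested side,
    # then stack everything into a single (batch, max_length) tensor.
    padded = []
    for t in tensors:
        pad = torch.zeros(max_length - t.size(0), dtype=t.dtype)
        if padding_side == "right":
            padded.append(torch.cat([t, pad]).unsqueeze(0))
        else:
            padded.append(torch.cat([pad, t]).unsqueeze(0))
    return torch.cat(padded, dim=0)


# Example: left-padding two token-id sequences to length 5.
batch = pad_and_concat_sketch(
    5, [torch.tensor([1, 2, 3]), torch.tensor([4])], padding_side="left"
)
# tensor([[0, 0, 1, 2, 3],
#         [0, 0, 0, 0, 4]])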