Unverified Commit 0d8e206c authored by Lintang Sutawika, committed by GitHub

Merge pull request #647 from EleutherAI/handle-multigpu-errors

[Refactor] Handle `cuda:0` device assignment
parents 59aef189 b3598058
@@ -93,11 +93,16 @@ class HFLM(LM):
         assert isinstance(batch_size, int)
         gpus = torch.cuda.device_count()
+        accelerator = Accelerator()
-        if gpus <= 1 and not parallelize:
+        if not (parallelize or accelerator.num_processes > 1):
             # use user-passed device
+            device_list = set(
+                ["cuda", "cpu"]
+                + [f"cuda:{i}" for i in range(torch.cuda.device_count())]
+            )
             if device:
-                if device not in ["cuda", "cpu"]:
+                if device not in device_list:
                     device = int(device)
                 self._device = torch.device(device)
                 eval_logger.info(f"Using device '{device}'")
@@ -111,7 +116,7 @@
                 )
         else:
             eval_logger.info(
-                f"Passed device '{device}', but using `accelerate launch` or `parallelize=True`. This will be overridden when placing model."
+                f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model."
             )
             # TODO: include in warning that `load_in_8bit` etc. affect this too
             self._device = device
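
The substance of the two hunks above is the explicit device allow-list: indexed devices such as `cuda:0` or `cuda:1` are now accepted verbatim instead of falling through to the `int(device)` coercion, which raises `ValueError` for a string like `"cuda:0"`. A minimal standalone sketch of that resolution logic follows; `resolve_device` is a hypothetical helper for illustration, while in the harness the check lives inline in `HFLM.__init__`.

```python
import torch


def resolve_device(device: str) -> torch.device:
    """Hypothetical helper mirroring the device check added in the diff above."""
    # Accept "cpu", "cuda", and every indexed CUDA device torch can see.
    device_list = set(
        ["cuda", "cpu"]
        + [f"cuda:{i}" for i in range(torch.cuda.device_count())]
    )
    if device not in device_list:
        # Anything else (e.g. a bare "0") is treated as a GPU ordinal.
        device = int(device)
    return torch.device(device)


# resolve_device("cuda:0") -> device(type='cuda', index=0)
# resolve_device("0")      -> device(type='cuda', index=0)
# resolve_device("cpu")    -> device(type='cpu')
```
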
@@ -217,7 +222,6 @@
         # multigpu data-parallel support when launched with accelerate
         if gpus > 1:
-            accelerator = Accelerator()
             if parallelize:
                 if accelerator.num_processes > 1:
                     raise RuntimeError(
......
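
The second `Accelerator()` instantiation is dropped here because the instance created near the top of `__init__` is reused; `accelerator.num_processes > 1` is the signal that the run was started with `accelerate launch`. A hedged sketch of that detection outside the harness (the device assignments are placeholders, not the harness's exact placement logic):

```python
from accelerate import Accelerator

accelerator = Accelerator()  # reads the launch environment set up by `accelerate launch`

if accelerator.num_processes > 1:
    # Data-parallel run: one process per GPU, each bound to its own device.
    device = accelerator.device
else:
    # Single-process run: a user-passed device (e.g. "cuda:0") is honoured.
    device = "cuda:0"  # placeholder for the user-supplied value

print(f"process {accelerator.process_index}/{accelerator.num_processes} -> {device}")
```
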
@@ -10,7 +10,7 @@ import collections
 import importlib.util
 import fnmatch
-from typing import List, Union
+from typing import List, Literal, Union
 import gc
 import torch
@@ -453,7 +453,11 @@ def create_iterator(raw_iterator, rank, world_size, limit=None):
     return islice(raw_iterator, rank, limit, world_size)


-def pad_and_concat(max_length: int, tensors: List[torch.Tensor], padding_side="right"):
+def pad_and_concat(
+    max_length: int,
+    tensors: List[torch.Tensor],
+    padding_side: Literal["right", "left"] = "right",
+):
     """
     Method for padding a list of tensors given the maximum tensor
     length in the batch. Used for batching inputs and continuations in
......
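
With the new signature, `padding_side` is typed as `Literal["right", "left"]`, so type checkers flag any other value. A rough sketch of what such a padder does; this is an illustrative reimplementation under those assumptions, not the harness's actual `pad_and_concat` body, which may handle shapes and dtypes differently.

```python
from typing import List, Literal

import torch
import torch.nn.functional as F


def pad_and_concat_sketch(
    max_length: int,
    tensors: List[torch.Tensor],
    padding_side: Literal["right", "left"] = "right",
) -> torch.Tensor:
    # Pad every 1-D tensor to max_length with zeros on the chosen side,
    # then stack them into a single (batch, max_length) tensor.
    padded = []
    for t in tensors:
        gap = max_length - t.size(0)
        pad = (0, gap) if padding_side == "right" else (gap, 0)
        padded.append(F.pad(t, pad, value=0))
    return torch.stack(padded, dim=0)


# pad_and_concat_sketch(4, [torch.tensor([1, 2]), torch.tensor([3])], "left")
# -> tensor([[0, 0, 1, 2],
#            [0, 0, 0, 3]])
```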