Commit b8154374 authored by haileyschoelkopf

handle device assignment better

parent d15ee17a
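In short: in the single-process path, `device` may now be any of "cuda", "cpu", an explicit "cuda:N" for a visible GPU, or a bare integer index. A minimal standalone sketch of the widened resolution rule (the `resolve_device` helper is hypothetical; the real logic lives inline in `HFLM.__init__`, see the diff below):

```python
# Hypothetical standalone helper mirroring the rule this commit adds.
from typing import Optional, Union

import torch


def resolve_device(device: Optional[Union[str, int]] = "cuda") -> torch.device:
    # "cuda"/"cpu" plus every explicit "cuda:i" now pass through verbatim.
    device_list = set(
        ["cuda", "cpu"] + [f"cuda:{i}" for i in range(torch.cuda.device_count())]
    )
    if device not in device_list:
        # Anything else (e.g. "1" or 1) is treated as a bare GPU index.
        device = int(device)
    return torch.device(device)


if torch.cuda.is_available():
    print(resolve_device("cuda:0"))  # per-GPU strings are honored as-is
print(resolve_device("cpu"))
```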
@@ -58,7 +58,7 @@ class HFLM(LM):
     def __init__(
         self,
-        device="cuda",
+        device: Optional[Union[str, int]] = "cuda",
         pretrained="gpt2",
         revision="main",
         low_cpu_mem_usage=None,
@@ -82,11 +82,16 @@ class HFLM(LM):
         assert isinstance(batch_size, int)
         gpus = torch.cuda.device_count()
-        if gpus <= 1 and not parallelize:
+        accelerator = Accelerator()
+        if not (parallelize or accelerator.num_processes > 1):
             # use user-passed device
+            device_list = set(
+                ["cuda", "cpu"]
+                + [f"cuda:{i}" for i in range(torch.cuda.device_count())]
+            )
             if device:
-                if device not in ["cuda", "cpu"]:
+                if device not in device_list:
                     device = int(device)
                 self._device = torch.device(device)
                 eval_logger.info(f"Using device '{device}'")
@@ -100,7 +105,7 @@ class HFLM(LM):
                 )
         else:
             eval_logger.info(
-                f"Passed device '{device}', but using `accelerate launch` or `parallelize=True`. This will be overridden when placing model."
+                f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model."
             )
             # TODO: include in warning that `load_in_8bit` etc. affect this too
             self._device = device
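(Aside: constructing `Accelerator()` eagerly, as the hunk above does, works because its `num_processes` reflects how the script was launched, which is exactly what the new branch condition keys on. An illustrative probe, not part of the commit:)

```python
# Under plain `python script.py`, num_processes == 1; under
# `accelerate launch script.py`, it equals the number of spawned processes.
from accelerate import Accelerator

accelerator = Accelerator()
print(accelerator.num_processes, accelerator.local_process_index)
```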
@@ -162,7 +167,6 @@ class HFLM(LM):
         # multigpu data-parallel support when launched with accelerate
         if gpus > 1:
-            accelerator = Accelerator()
             if parallelize:
                 if accelerator.num_processes > 1:
                     raise RuntimeError(
@@ -170,34 +174,39 @@ class HFLM(LM):
                     )
                 else:
                     pass
-            elif gpus > accelerator.num_processes:
-                # TODO: make sure there's still never an edge case where we unintentionally default to CPU
-                eval_logger.warning(
-                    "WARNING: The number of total system GPUs does not match the number of spawned processes. "
-                    "If you would like to use data parallelism, please launch the script "
-                    "with 'accelerate launch *script*'. "
-                    f"Current run will proceed with {accelerator.num_processes} devices."
-                )
+            else:
+                if gpus > accelerator.num_processes:
+                    # TODO: make sure there's still never an edge case where we unintentionally default to CPU
+                    eval_logger.warning(
+                        "WARNING: The number of total system GPUs does not match the number of spawned processes. "
+                        "If you would like to use data parallelism, please launch the script "
+                        "with 'accelerate launch *script*'. "
+                        f"Current run will proceed with {accelerator.num_processes} devices."
+                    )
                 self._rank = accelerator.local_process_index
                 self._world_size = accelerator.num_processes
-                # manually set model to use gpu, for case where many GPUs available but
-                # only seek to use one
-                self._device = (
-                    torch.device(f"cuda:{accelerator.local_process_index}")
-                    if torch.cuda.is_available()
-                    else torch.device("cpu")
-                )
-                self.model.to(self.device)
-            else:
                 self._model = accelerator.prepare(self.model)
                 self._device = torch.device(f"cuda:{accelerator.local_process_index}")
                 self.accelerator = accelerator
                 if self.accelerator.is_local_main_process:
                     eval_logger.info(f"Using {gpus} devices with data parallelism")
-                self._rank = self.accelerator.local_process_index
-                self._world_size = self.accelerator.num_processes
+            # manually set model to use gpu, for case where many GPUs available but
+            # only seek to use one
+            # self._device = (
+            #     torch.device(f"cuda:{accelerator.local_process_index}")
+            #     if torch.cuda.is_available()
+            #     else torch.device("cpu")
+            # )
+            # self.model.to(self.device)
+            # else:
+            #     self._model = accelerator.prepare(self.model)
+            #     self._device = torch.device(f"cuda:{accelerator.local_process_index}")
+            #     self.accelerator = accelerator
+            #     self._rank = self.accelerator.local_process_index
+            #     self._world_size = self.accelerator.num_processes
     @property
     def config(self):
...
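For context, a sketch of how the two paths above are reached. The import path and the `parallelize` kwarg are inferred from this diff, not confirmed by it:

```python
# Hypothetical usage; the full HFLM signature is not shown in this commit.
from lm_eval.models.huggingface import HFLM

# Single process: the user-passed device is honored, including "cuda:1"
# or a bare index, via the widened device_list check.
lm = HFLM(pretrained="gpt2", device="cuda:1")

# With parallelize=True (or under `accelerate launch`), the passed device
# is overridden when the model is placed.
lm_mp = HFLM(pretrained="gpt2", parallelize=True)
```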