Unverified Commit c4c15917 authored by Huazhong Ji, committed by GitHub

[HFLM] Use Accelerate's API to reduce hard-coded CUDA code (#1880)

parent b5afe229
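
This change removes the assumption that every accelerator is an NVIDIA GPU: the device count is passed in as an explicit `gpus` argument instead of being read from `torch.cuda.device_count()`, and device strings come from Accelerate's `accelerator.device` rather than hand-built `cuda:{local_process_index}` strings, so the same code paths work on CUDA, MPS, and the other backends Accelerate supports. A minimal sketch (not part of this diff) of the core idea:

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()

    # Old pattern, CUDA-only:
    #     device = torch.device(f"cuda:{accelerator.local_process_index}")
    # New pattern, backend-agnostic (resolves to e.g. cuda:0, mps, or cpu):
    device = torch.device(f"{accelerator.device}")
    print(device)
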
@@ -44,13 +44,13 @@ def _get_accelerate_args(
     max_memory_per_gpu: Optional[Union[int, str]] = None,
     max_cpu_memory: Optional[Union[int, str]] = None,
     offload_folder: Optional[str] = "./offload",
+    gpus: Optional[int] = None,
 ) -> dict:
     """Returns the kwargs needed to apply `accelerate` in `AutoModel.from_pretrained`."""
     max_memory = {}
     if max_memory_per_gpu is not None:
         max_memory_per_gpu_map = {
-            device_idx: max_memory_per_gpu
-            for device_idx in range(torch.cuda.device_count())
+            device_idx: max_memory_per_gpu for device_idx in range(gpus)
         }
         max_memory.update(max_memory_per_gpu_map)
     if max_cpu_memory is not None:
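
For context, a hedged sketch of how the kwargs assembled by `_get_accelerate_args` are consumed downstream; the checkpoint name and memory figures are placeholders, but `device_map`, `max_memory`, and `offload_folder` are real `from_pretrained` arguments:

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "gpt2",  # placeholder checkpoint
        device_map="auto",
        # caps per device index, plus an optional "cpu" entry, mirroring the
        # max_memory dict that _get_accelerate_args builds from `gpus`
        max_memory={0: "20GiB", 1: "20GiB", "cpu": "64GiB"},
        offload_folder="./offload",
    )
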
@@ -157,7 +157,7 @@ class HFLM(TemplateLM):
             # use user-passed device
             device_list = set(
                 ["cuda", "cpu"]
-                + [f"cuda:{i}" for i in range(torch.cuda.device_count())]
+                + [f"cuda:{i}" for i in range(gpus)]
                 + ["mps", "mps:0"]
             )
             if device and device in device_list:
@@ -216,6 +216,7 @@ class HFLM(TemplateLM):
             dtype=dtype,
             trust_remote_code=trust_remote_code,
             parallelize=parallelize,
+            gpus=gpus,
             device_map_option=device_map_option,
             max_memory_per_gpu=max_memory_per_gpu,
             max_cpu_memory=max_cpu_memory,
@@ -330,9 +331,7 @@ class HFLM(TemplateLM):
                 self._model = accelerator.prepare_model(
                     self.model, evaluation_mode=True
                 )
-                self._device = torch.device(
-                    f"cuda:{accelerator.local_process_index}"
-                )
+                self._device = torch.device(f"{accelerator.device}")
                 self.accelerator = accelerator
                 if self.accelerator.is_local_main_process:
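
A self-contained sketch of the data-parallel path in the hunk above, under the assumption that torch, transformers, and accelerate are installed; the checkpoint is a placeholder:

    import torch
    from accelerate import Accelerator
    from transformers import AutoModelForCausalLM

    accelerator = Accelerator()
    model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder

    # wrap the model for evaluation on this rank; evaluation_mode=True
    # skips gradient-related preparation
    model = accelerator.prepare_model(model, evaluation_mode=True)

    # the device now comes from Accelerate, matching the new diff line above
    device = torch.device(f"{accelerator.device}")
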
@@ -489,6 +488,7 @@ class HFLM(TemplateLM):
         # only used if `parallelize=True`.
         # (accelerate naive PP (device_map) options)
         parallelize: Optional[bool] = False,
+        gpus: Optional[int] = None,
         device_map_option: Optional[str] = "auto",
         max_memory_per_gpu: Optional[Union[int, str]] = None,
         max_cpu_memory: Optional[Union[int, str]] = None,
@@ -520,6 +520,7 @@ class HFLM(TemplateLM):
                     max_memory_per_gpu,
                     max_cpu_memory,
                     offload_folder,
+                    gpus,
                 )
             )
         elif "device_map" not in model_kwargs:
@@ -528,9 +529,7 @@ class HFLM(TemplateLM):
             # for quantized models now seems to be device_map="auto"
             # which breaks data-parallel mode.
             if hasattr(self, "accelerator"):
-                model_kwargs.update(
-                    {"device_map": {"": f"cuda:{self.accelerator.local_process_index}"}}
-                )
+                model_kwargs.update({"device_map": {"": f"{self.accelerator.device}"}})
             else:
                 model_kwargs.update({"device_map": {"": str(self.device)}})
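
The empty-string key in `device_map` pins the entire model to a single device, which keeps each data-parallel rank self-contained instead of letting `device_map="auto"` shard the model across devices. A hedged sketch with a placeholder checkpoint:

    from accelerate import Accelerator
    from transformers import AutoModelForCausalLM

    accelerator = Accelerator()

    # "" means "the whole module": the full model lands on this rank's device
    model = AutoModelForCausalLM.from_pretrained(
        "gpt2",  # placeholder
        device_map={"": f"{accelerator.device}"},
    )
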
@@ -583,7 +582,9 @@ class HFLM(TemplateLM):
             if self._model.config.vocab_size != len(self.tokenizer):
                 # resize model for LoRAs with added tokens
                 self._model.resize_token_embeddings(len(self.tokenizer))
-                eval_logger.info(f"Model config indicates vocab_size='{self._model.config.vocab_size}', but found tokenizer with vocab size '{len(self.tokenizer)}'. Resizing model embedding layer...")
+                eval_logger.info(
+                    f"Model config indicates vocab_size='{self._model.config.vocab_size}', but found tokenizer with vocab size '{len(self.tokenizer)}'. Resizing model embedding layer..."
+                )
             self._model = PeftModel.from_pretrained(
                 self._model, peft, revision=revision
             )
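
For reference, a hedged sketch of the PEFT path in the hunk above: when a LoRA adapter's tokenizer added tokens, the base model's embedding matrix must be resized to the tokenizer's length before the adapter weights are loaded. The checkpoint and adapter names are placeholders:

    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    base = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder
    tokenizer = AutoTokenizer.from_pretrained("my-lora-adapter")  # placeholder

    # grow (or shrink) the embedding table to match the adapter's tokenizer
    if base.config.vocab_size != len(tokenizer):
        base.resize_token_embeddings(len(tokenizer))

    model = PeftModel.from_pretrained(base, "my-lora-adapter")  # placeholder
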
...