Unverified Commit e03a9cc0 authored by Marc Sun, committed by GitHub

Modify device_map behavior when loading a model using from_pretrained (#23922)



* Modify device map behavior for 4/8 bits model

* Remove device_map arg for training 4/8 bit model

* Remove index
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Add Exceptions

* Modify comment
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Fix formatting

* Get current device with accelerate

* Revert "Get current device with accelerate"

This reverts commit 46f00799103bbe15bd58762ba029aab35363c4f7.

* Fix Exception

* Modify quantization doc

* Fix error
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

---------
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent d1fa349e
@@ -86,6 +86,7 @@ With this integration we were able to load large models on smaller devices and r
<Tip warning={true}>
Note that once a model has been loaded in 8-bit, it is currently not possible to push the quantized weights to the Hub unless you use the latest `transformers` and `bitsandbytes`. Note also that you cannot train 8-bit weights, as this is not supported yet. However, you can use 8-bit models to train extra parameters; this will be covered in the next section.
Note also that `device_map` is optional, but setting `device_map = 'auto'` is preferred for inference as it will efficiently dispatch the model across the available resources.
</Tip>
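For illustration, a minimal sketch of the inference setup recommended in the tip above, using `bigscience/bloom-560m` purely as an example checkpoint and assuming `bitsandbytes` and `accelerate` are installed:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Example checkpoint; any causal LM on the Hub works the same way.
model_id = "bigscience/bloom-560m"

tokenizer = AutoTokenizer.from_pretrained(model_id)
# load_in_8bit requires bitsandbytes; device_map="auto" lets Accelerate
# place the quantized weights across the available GPUs/CPU automatically.
model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")

inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```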
@@ -162,9 +163,10 @@ You can load a quantized model from the Hub by using `from_pretrained` method. M
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("{your_username}/bloom-560m-8bit", device_map="auto")
```
Note that in this case, you don't need to specify the argument `load_in_8bit=True`, but you need to make sure that `bitsandbytes` and `accelerate` are installed.
Note also that `device_map` is optional, but setting `device_map = 'auto'` is preferred for inference as it will efficiently dispatch the model across the available resources.
### Advanced use cases
@@ -253,6 +255,8 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
With the official support of adapters in the Hugging Face ecosystem, you can fine-tune models that have been loaded in 8-bit.
This enables fine-tuning large models such as `flan-t5-large` or `facebook/opt-6.7b` in a single Google Colab. Please have a look at the [`peft`](https://github.com/huggingface/peft) library for more details.
Note that you don't need to pass `device_map` when loading the model for training; it will automatically be loaded on your GPU. You can also set the device map to a specific device if needed (e.g. `cuda:0`, `0`, `torch.device('cuda:0')`), as in the sketch below. Please note that `device_map='auto'` should be used for inference only.
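For illustration only, a minimal sketch of the loading behavior described above, assuming `facebook/opt-350m` as an example checkpoint: omitting `device_map` places the quantized model on the current GPU, while a specific device can still be pinned explicitly.

```python
import torch
from transformers import AutoModelForCausalLM

model_id = "facebook/opt-350m"  # example checkpoint

# For training, device_map can simply be omitted: the 8-bit model is placed
# on the current GPU (torch.cuda.current_device()).
model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True)

# Equivalent ways to pin the model to a specific device instead:
# model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map=0)
# model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map="cuda:0")
# model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map=torch.device("cuda:0"))

print(model.hf_device_map)  # e.g. {'': 0}
```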
### BitsAndBytesConfig
[[autodoc]] BitsAndBytesConfig
......
@@ -65,6 +65,7 @@ from transformers import AutoModelForCausalLM
model_name = "bigscience/bloom-2b5"
model_8bit = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_4bit=True)
```
Note that `device_map` is optional, but setting `device_map = 'auto'` is preferred for inference as it will efficiently dispatch the model across the available resources.
### Running FP4 models - multi GPU setup
......
@@ -2055,10 +2055,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
</Tip>
device_map (`str` or `Dict[str, Union[int, str, torch.device]]` or `int` or `torch.device`, *optional*):
    A map that specifies where each submodule should go. It doesn't need to be refined to each
    parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
    same device. If only a device is passed (*e.g.* `"cpu"`, `"cuda:1"`, `"mps"`, or a GPU ordinal rank
    like `1`), the device map will map the entire model to that device. Passing `device_map = 0` means
    the whole model will be placed on GPU 0.

    To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
    more information about each option see [designing a device
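For illustration, a minimal sketch (the checkpoint name is an example, not part of this diff) of the device placements the docstring above now describes; an int, a device string, or a `torch.device` is expanded internally to a map covering the whole model:

```python
import torch
from transformers import AutoModelForCausalLM

model_id = "bigscience/bloom-560m"  # example checkpoint

# These calls are intended to be equivalent: the whole model ends up on GPU 0,
# i.e. device_map is expanded to {"": device} internally.
model_a = AutoModelForCausalLM.from_pretrained(model_id, device_map=0)
model_b = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda:0")
model_c = AutoModelForCausalLM.from_pretrained(model_id, device_map=torch.device("cuda:0"))

# "auto" still asks Accelerate to compute an optimized map across the available devices.
model_d = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
```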
@@ -2186,6 +2188,26 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
    "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is"
    " ignored."
)
# change device_map into a map if we passed an int, a str or a torch.device
if isinstance(device_map, torch.device):
    device_map = {"": device_map}
elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
    try:
        device_map = {"": torch.device(device_map)}
    except RuntimeError:
        raise ValueError(
            "When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or "
            f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}."
        )
elif isinstance(device_map, int):
    if device_map < 0:
        raise ValueError(
            "You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' "
        )
    else:
        device_map = {"": device_map}
if device_map is not None:
    if low_cpu_mem_usage is None:
        low_cpu_mem_usage = True
@@ -2226,12 +2248,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
    "`quantization_config` argument at the same time."
)
# in the case a user loads an 8bit model from the Hub and assigns a new quantization_config
if device_map is None:
    device_map = "auto"
if low_cpu_mem_usage is None:
    low_cpu_mem_usage = True
if load_in_8bit or load_in_4bit:
    if not (is_accelerate_available() and is_bitsandbytes_available()):
        raise ImportError(
@@ -2251,10 +2267,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
    torch_dtype = torch.float16
if device_map is None:
    if torch.cuda.is_available():
        device_map = {"": torch.cuda.current_device()}
    else:
        raise RuntimeError("No GPU found. A GPU is needed for quantization.")
    logger.info(
        "The device_map was not initialized. "
        "Setting device_map to {'': torch.cuda.current_device()}. "
        "If you want to use the model for inference, please set device_map='auto'."
    )
if low_cpu_mem_usage is None:
    low_cpu_mem_usage = True
if from_tf or from_flax:
    raise ValueError(
        "Converting into 4-bit or 8-bit weights from tf/flax weights is currently not supported, please make"
@@ -2318,10 +2342,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if torch_dtype is None:
    torch_dtype = torch.float16
if device_map is None:
    if torch.cuda.is_available():
        device_map = {"": torch.cuda.current_device()}
    else:
        raise RuntimeError("No GPU found. A GPU is needed for quantization.")
    logger.info(
        "The device_map was not initialized. "
        "Setting device_map to {'': torch.cuda.current_device()}. "
        "If you want to use the model for inference, please set device_map='auto'."
    )
if low_cpu_mem_usage is None:
    low_cpu_mem_usage = True
elif not is_8bit_serializable and not load_in_8bit and hasattr(config, "quantization_config"):
    logger.warning(
        "Detected the presence of a `quantization_config` attribute in the model's configuration but you don't have the correct"
......
@@ -429,7 +429,9 @@ class Bnb4BitTestTraining(Base4bitTest):
return

# Step 1: freeze all parameters
model = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_4bit=True)
self.assertEqual(set(model.hf_device_map.values()), {torch.cuda.current_device()})

for param in model.parameters():
    param.requires_grad = False  # freeze the model - train adapters later
......
@@ -684,7 +684,9 @@ class MixedInt8TestTraining(BaseMixedInt8Test):
return

# Step 1: freeze all parameters
model = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True)
self.assertEqual(set(model.hf_device_map.values()), {torch.cuda.current_device()})

for param in model.parameters():
    param.requires_grad = False  # freeze the model - train adapters later
......