"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "d5239bab5bd87d8f719aff7827edee63aad51f60"
Unverified Commit f8394268 authored by Younes Belkada, committed by GitHub

[`pipeline`] A simple fix for half-precision & 8bit models (#21479)



* v1 fix

* adapt from suggestions

* make style

* fix tests

* add gpu tests

* update docs

* fix other tests

* Apply suggestions from code review
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>

* better fix

* make fixup

* better example

* revert changes

* proposal

* more elegant solution

* Update src/transformers/pipelines/automatic_speech_recognition.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

---------
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 97d3390f
@@ -105,6 +105,8 @@ If the model is too large for a single GPU, you can set `device_map="auto"` to a
generator(model="openai/whisper-large", device_map="auto")
```
Note that if `device_map="auto"` is passed, there is no need to add the argument `device=device` when instantiating your `pipeline`; doing so may lead to unexpected behavior!
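For clarity, here is a minimal illustrative sketch (not part of the diff) of the recommended call, with the conflicting variant shown commented out:
```py
from transformers import pipeline

# Recommended: let 🤗 Accelerate place the model and do not pass `device` as well.
generator = pipeline(model="openai/whisper-large", device_map="auto")

# Avoid: passing both placement arguments; `device` overrides `device_map`
# and the pipeline logs a warning about the conflict.
# generator = pipeline(model="openai/whisper-large", device_map="auto", device=0)
```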
### Batch size
By default, pipelines will not batch inference for reasons explained in detail [here](https://huggingface.co/docs/transformers/main_classes/pipelines#pipeline-batching). The reason is that batching is not necessarily faster, and can actually be quite a bit slower in some cases.
@@ -257,4 +259,32 @@ sudo apt install -y tesseract-ocr
pip install pytesseract
```
</Tip>
\ No newline at end of file
## Using `pipeline` on large models with 🤗 `accelerate`
You can easily run `pipeline` on large models using 🤗 `accelerate`! First make sure you have installed `accelerate` with `pip install accelerate`.
Then load your model using `device_map="auto"`! We will use `facebook/opt-1.3b` for our example.
```py
# pip install accelerate
import torch
from transformers import pipeline
pipe = pipeline(model="facebook/opt-1.3b", torch_dtype=torch.bfloat16, device_map="auto")
output = pipe("This is a cool example!", do_sample=True, top_p=0.95)
```
You can also load models in 8-bit if you install `bitsandbytes` and pass the argument `load_in_8bit=True` through `model_kwargs`:
```py
# pip install accelerate bitsandbytes
import torch
from transformers import pipeline
pipe = pipeline(model="facebook/opt-1.3b", device_map="auto", model_kwargs={"load_in_8bit": True})
output = pipe("This is a cool example!", do_sample=True, top_p=0.95)
```
Note that you can replace the checkpoint with any Hugging Face model that supports large model loading, such as BLOOM!
\ No newline at end of file
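For instance, here is a hedged sketch of the same pattern with a small BLOOM checkpoint (`bigscience/bloom-560m` is chosen purely for illustration):
```py
# pip install accelerate
import torch
from transformers import pipeline

# Any checkpoint that supports large model loading works the same way.
pipe = pipeline(model="bigscience/bloom-560m", torch_dtype=torch.bfloat16, device_map="auto")
output = pipe("This is a cool example!", do_sample=True, top_p=0.95)
```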
@@ -738,6 +738,11 @@ def pipeline(
                'You cannot use both `pipeline(... device_map=..., model_kwargs={"device_map":...})` as those'
                " arguments might conflict, use only one.)"
            )
+        if device is not None:
+            logger.warning(
+                "Both `device` and `device_map` are specified. `device` will override `device_map`. You"
+                " will most likely encounter unexpected behavior. Please remove `device` and keep `device_map`."
+            )
        model_kwargs["device_map"] = device_map
    if torch_dtype is not None:
        if "torch_dtype" in model_kwargs:
......
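For illustration (not part of the diff), this is the call pattern that now triggers the warning; the checkpoint is only an example:
```py
import torch
from transformers import pipeline

# Passing both placement arguments now logs a warning: `device` overrides the
# `accelerate` device map, which is most likely not what you want.
pipe = pipeline(model="facebook/opt-1.3b", device_map="auto", device=0, torch_dtype=torch.float16)
```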
@@ -286,9 +286,9 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
            installed. If no framework is specified, will default to the one currently installed. If no framework is
            specified and both frameworks are installed, will default to the framework of the `model`, or to PyTorch if
            no model is provided.
-        device (`int`, *optional*, defaults to -1):
-            Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, a positive will run the model on
-            the associated CUDA device id.
+        device (Union[`int`, `torch.device`], *optional*):
+            Device ordinal for CPU/GPU supports. Setting this to `None` will leverage CPU, a positive will run the
+            model on the associated CUDA device id.
        decoder (`pyctcdecode.BeamSearchDecoderCTC`, *optional*):
            [PyCTCDecode's
            BeamSearchDecoderCTC](https://github.com/kensho-technologies/pyctcdecode/blob/2fd33dc37c4111417e08d89ccd23d28e9b308d19/pyctcdecode/decoder.py#L180)
......
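To illustrate the updated `device` semantics documented above (checkpoint chosen only as an example):
```py
from transformers import pipeline

# device=None (the new default) keeps the model on CPU ...
asr_cpu = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")

# ... while an int (or a torch.device) selects the matching CUDA device.
asr_gpu = pipeline("automatic-speech-recognition", model="openai/whisper-tiny", device=0)
```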
@@ -749,7 +749,7 @@ class Pipeline(_ScikitCompat):
        framework: Optional[str] = None,
        task: str = "",
        args_parser: ArgumentHandler = None,
-        device: Union[int, str, "torch.device"] = -1,
+        device: Union[int, str, "torch.device"] = None,
        torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
        binary_output: bool = False,
        **kwargs,
@@ -764,6 +764,19 @@ class Pipeline(_ScikitCompat):
        self.image_processor = image_processor
        self.modelcard = modelcard
        self.framework = framework
+        if self.framework == "pt" and device is not None:
+            self.model = self.model.to(device=device)
+
+        if device is None:
+            # `accelerate` device map
+            hf_device_map = getattr(self.model, "hf_device_map", None)
+            if hf_device_map is not None:
+                # Take the first device used by `accelerate`.
+                device = next(iter(hf_device_map.values()))
+            else:
+                device = -1
+
        if is_torch_available() and self.framework == "pt":
            if isinstance(device, torch.device):
                self.device = device
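A standalone sketch (the helper name and the bare `model` argument are hypothetical) of the device-inference logic added above:
```py
def infer_pipeline_device(model, device=None):
    """Reuse the first device of the `accelerate` device map when `device` is not given, else fall back to CPU (-1)."""
    if device is not None:
        return device
    hf_device_map = getattr(model, "hf_device_map", None)
    if hf_device_map is not None:
        # Take the first device used by `accelerate`.
        return next(iter(hf_device_map.values()))
    return -1
```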
@@ -774,14 +787,10 @@ class Pipeline(_ScikitCompat):
            else:
                self.device = torch.device(f"cuda:{device}")
        else:
-            self.device = device
+            self.device = device if device is not None else -1
        self.torch_dtype = torch_dtype
        self.binary_output = binary_output
-
-        # Special handling
-        if self.framework == "pt" and self.device.type != "cpu":
-            self.model = self.model.to(self.device)

        # Update config with task specific parameters
        task_specific_params = self.model.config.task_specific_params
        if task_specific_params is not None and task in task_specific_params:
......
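A hedged reading of the removal above: a model dispatched with `device_map="auto"` (or loaded in 8-bit) should not be moved again with `.to()`, so the pipeline now only moves the model when `device` is explicitly provided. Illustrative sketch (tiny test checkpoint used only as an example):
```py
import torch
from transformers import pipeline

# Explicit device: the pipeline moves the model to that device, as before.
pipe_gpu = pipeline(model="hf-internal-testing/tiny-random-bloom", device=0, torch_dtype=torch.float16)

# accelerate-dispatched model: no extra `.to()` call, the device map is respected.
pipe_map = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto", torch_dtype=torch.float16)
```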
@@ -255,7 +255,6 @@ class QuestionAnsweringPipeline(ChunkPipeline):
        tokenizer: PreTrainedTokenizer,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
-        device: int = -1,
        task: str = "",
        **kwargs,
    ):
@@ -264,7 +263,6 @@ class QuestionAnsweringPipeline(ChunkPipeline):
            tokenizer=tokenizer,
            modelcard=modelcard,
            framework=framework,
-            device=device,
            task=task,
            **kwargs,
        )
......
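The removal above does not change the public API: `device` is still accepted by `QuestionAnsweringPipeline` and is simply forwarded through `**kwargs` to the base `Pipeline`. A small usage sketch (checkpoint chosen only as an example):
```py
from transformers import pipeline

# `device` still works; it now reaches the base Pipeline via **kwargs.
qa = pipeline("question-answering", model="distilbert-base-cased-distilled-squad", device=-1)
qa(question="Where does Clara live?", context="My name is Clara and I live in Berkeley.")
```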
@@ -312,3 +312,12 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM
        pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device=0, torch_dtype=torch.float16)
        pipe("This is a test")
+
+    @require_torch
+    @require_accelerate
+    @require_torch_gpu
+    def test_pipeline_accelerate_top_p(self):
+        import torch
+
+        pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto", torch_dtype=torch.float16)
+        pipe("This is a test", do_sample=True, top_p=0.5)