Unverified Commit ccc08978 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Adding support for `device_map` directly in `pipeline(..)` function. (#17902)

* Adding support for `device_map` directly in `pipeline(..)` function.

* Updating the docstring.

* Adding a better docstring

* Put back type hints.

* Blacked. (`make fixup` didn't work ??!!)
parent fca66ec4
......@@ -389,6 +389,8 @@ def pipeline(
revision: Optional[str] = None,
use_fast: bool = True,
use_auth_token: Optional[Union[str, bool]] = None,
device_map=None,
torch_dtype=None,
model_kwargs: Dict[str, Any] = None,
pipeline_class: Optional[Any] = None,
**kwargs
......@@ -472,6 +474,20 @@ def pipeline(
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `transformers-cli login` (stored in `~/.huggingface`).
device_map (`str` or `Dict[str, Union[int, str, torch.device]`, *optional*):
Sent directly as `model_kwargs` (just a simpler shortcut). When `accelerate` library is present, set
`device_map="auto"` to compute the most optimized `device_map` automatically. [More
information](https://huggingface.co/docs/accelerate/main/en/big_modeling#accelerate.cpu_offload)
<Tip warning={true}>
Do not use `device_map` AND `device` at the same time as they will conflict
</Tip>
torch_dtype (`str` or `torch.dtype`, *optional*):
Sent directly as `model_kwargs` (just a simpler shortcut) to use the available precision for this model
(`torch.float16`, `torch.bfloat16`, ... or `"auto"`).
model_kwargs:
Additional dictionary of keyword arguments passed along to the model's `from_pretrained(...,
**model_kwargs)` function.
......@@ -547,6 +563,20 @@ def pipeline(
# Retrieve use_auth_token and add it to model_kwargs to be used in .from_pretrained
model_kwargs["use_auth_token"] = model_kwargs.get("use_auth_token", use_auth_token)
if device_map is not None:
if "device_map" in model_kwargs:
raise ValueError(
'You cannot use both `pipeline(... device_map=..., model_kwargs={"device_map":...})` as those'
" arguments might conflict, use only one.)"
)
model_kwargs["device_map"] = device_map
if torch_dtype is not None:
if "torch_dtype" in model_kwargs:
raise ValueError(
'You cannot use both `pipeline(... torch_dtype=..., model_kwargs={"torch_dtype":...})` as those'
" arguments might conflict, use only one.)"
)
model_kwargs["torch_dtype"] = torch_dtype
# Config is the primordial information item.
# Instantiate config if needed
......
......@@ -15,7 +15,13 @@
import unittest
from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_CAUSAL_LM_MAPPING, TextGenerationPipeline, pipeline
from transformers.testing_utils import is_pipeline_test, require_tf, require_torch
from transformers.testing_utils import (
is_pipeline_test,
require_accelerate,
require_tf,
require_torch,
require_torch_gpu,
)
from .test_pipelines_common import ANY, PipelineTestCaseMeta
......@@ -215,3 +221,63 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM
handle_long_generation="hole",
max_new_tokens=tokenizer.model_max_length + 10,
)
@require_torch
@require_accelerate
@require_torch_gpu
def test_small_model_pt_bloom_accelerate(self):
import torch
# Classic `model_kwargs`
pipe = pipeline(
model="hf-internal-testing/tiny-random-bloom",
model_kwargs={"device_map": "auto", "torch_dtype": torch.bfloat16},
)
self.assertEqual(pipe.model.device, torch.device(0))
self.assertEqual(pipe.model.lm_head.weight.dtype, torch.bfloat16)
out = pipe("This is a test")
self.assertEqual(
out,
[
{
"generated_text": (
"This is a test test test test test test test test test test test test test test test test"
" test"
)
}
],
)
# Upgraded those two to real pipeline arguments (they just get sent for the model as they're unlikely to mean anything else.)
pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto", torch_dtype=torch.bfloat16)
self.assertEqual(pipe.model.device, torch.device(0))
self.assertEqual(pipe.model.lm_head.weight.dtype, torch.bfloat16)
out = pipe("This is a test")
self.assertEqual(
out,
[
{
"generated_text": (
"This is a test test test test test test test test test test test test test test test test"
" test"
)
}
],
)
# torch_dtype not necessary
pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto")
self.assertEqual(pipe.model.device, torch.device(0))
self.assertEqual(pipe.model.lm_head.weight.dtype, torch.bfloat16)
out = pipe("This is a test")
self.assertEqual(
out,
[
{
"generated_text": (
"This is a test test test test test test test test test test test test test test test test"
" test"
)
}
],
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment