Unverified Commit 6c66c6c8 authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

Add warning in `generate` & `device_map=auto` & half precision models (#19468)



* fix device mismatch

* make fixup

* added slow tests

- added slow tests on `bnb` models to make sure generate works correctly

* replace with `self.device`

* revert force device assign

* Update src/transformers/generation_utils.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* set the warning in `generate` instead of `sample`
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent a3008c5a
...@@ -1349,6 +1349,17 @@ class GenerationMixin: ...@@ -1349,6 +1349,17 @@ class GenerationMixin:
"Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`." "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
) )
if self.device.type != input_ids.device.type:
warnings.warn(
"You are calling .generate() with the `input_ids` being on a device type different"
f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
" Please make sure that you have put `input_ids` to the"
f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
" running `.generate()`.",
UserWarning,
)
# 7. prepare distribution pre_processing samplers # 7. prepare distribution pre_processing samplers
logits_processor = self._get_logits_processor( logits_processor = self._get_logits_processor(
repetition_penalty=repetition_penalty, repetition_penalty=repetition_penalty,
...@@ -1976,7 +1987,6 @@ class GenerationMixin: ...@@ -1976,7 +1987,6 @@ class GenerationMixin:
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the'] ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the']
```""" ```"""
# init values # init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment