"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "60892c55a4356c1d6da471dc13cbd6bf4d4ee804"
Unverified Commit 7ffe25f2 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: end-to-end compilation (#30788)

* mvp

* added test (a few models need fixes)

* fix a few test cases

* test nits

* harder test 😈

* revert changes in stablelm

* test with improved condition

* add todo

* tmp commit

* merged with main

* nits

* add todo

* final corrections

* add docs for generation compilation

* docs nits

* add  tip

* PR suggestions

* add more details to the compilation docs

* fix cache positions

* cache is now init in generate; update docs

* tag test as flaky

* docs

* post rebase make fixup and other nits

* remove unintended changes

* whisper (encoder-decoder) not supported

* move token default updates to ; add tests for token defaults

* push changes

* manual rebase

* chameleon doesn't support this

* fix test_static_cache_mha_mqa_gqa (broken in another PR)

* docs: dynamic is better with end-to-end compilation
parent 49928892
......@@ -18,59 +18,109 @@ Basic inference is slow because LLMs have to be called repeatedly to generate th
This guide will show you how to use the optimization techniques available in Transformers to accelerate LLM inference.
> [!TIP]
> Hugging Face also provides [Text Generation Inference (TGI)](https://hf.co/docs/text-generation-inference), a library dedicated to deploying and serving highly optimized LLMs for inference. It includes more optimization features not included in Transformers, such as continuous batching for increasing throughput and tensor parallelism for multi-GPU inference.
> Hugging Face also provides [Text Generation Inference (TGI)](https://hf.co/docs/text-generation-inference), a library dedicated to deploying and serving highly optimized LLMs for inference. It includes deployment-oriented optimization features not included in Transformers, such as continuous batching for increasing throughput and tensor parallelism for multi-GPU inference.
## Static kv-cache and torch.compile
## Static kv-cache and `torch.compile`
During decoding, a LLM computes the key-value (kv) values for each input token and since it is autoregressive, it computes the same kv values each time because the generated output becomes part of the input now. This is not very efficient because you're recomputing the same kv values each time.
To optimize this, you can use a kv-cache to store the past keys and values instead of recomputing them each time. However, since the kv-cache grows with each generation step and is dynamic, it prevents you from taking advantage of [torch.compile](./perf_torch_compile), a powerful optimization tool that fuses PyTorch code into fast and optimized kernels.
To optimize this, you can use a kv-cache to store the past keys and values instead of recomputing them each time. However, since the kv-cache grows with each generation step and is dynamic, it prevents you from taking advantage of [`torch.compile`](./perf_torch_compile), a powerful optimization tool that fuses PyTorch code into fast and optimized kernels.
The *static kv-cache* solves this issue by pre-allocating the kv-cache size to a maximum value which allows you to combine it with torch.compile for up to a 4x speed up.
The *static kv-cache* solves this issue by pre-allocating the kv-cache size to a maximum value which allows you to combine it with `torch.compile` for up to a 4x speed up. Your speed up may vary depending on the model size (larger models have a smaller speed up) and hardware.
> [!WARNING]
> Currently, only [Llama](./model_doc/llama2) and a few other models support static kv-cache and torch.compile. Check [this issue](https://github.com/huggingface/transformers/issues/28981) for a live model compatibility list.
> Currently, only [Llama](./model_doc/llama2) and a few other models support static kv-cache and `torch.compile`. Check [this issue](https://github.com/huggingface/transformers/issues/28981) for a live model compatibility list.
For this example, let's load the [Gemma](https://hf.co/google/gemma-2b) model.
There are three flavors of static kv-cache usage, depending on the complexity of your task:
1. Basic usage: simply set a flag in `generation_config` (recommended);
2. Advanced usage: handle a cache object for multi-turn generation or a custom generation loop;
3. Advanced usage: compile the entire `generate` function into a single graph, if having a single graph is relevant for you.
Select the correct tab below for further instructions on each of these flavors.
> [!TIP]
> Regardless of the strategy used with `torch.compile`, you can avoid shape-related recompilations if you left-pad your LLM inputs to a limited set of values. The [`pad_to_multiple_of` tokenizer flag](https://huggingface.co/docs/transformers/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__.pad_to_multiple_of) is your friend!
<hfoptions id="static-kv">
<hfoption id="basic usage: generation_config">
For this example, let's use the [Gemma](https://hf.co/google/gemma-2b) model. All we need to do is to:
1. Access the model's `generation_config` attribute and set the `cache_implementation` to "static";
2. Call `torch.compile` on the model to compile the forward pass with the static kv-cache.
And that's it!
```py
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false" # To prevent long warnings :)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained(
"google/gemma-2b", device_map="auto"
)
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="auto")
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
input_text = "The theory of special relativity states "
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
outputs = model.generate(**input_ids)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
['The theory of special relativity states 1. The speed of light is constant in all inertial reference']
```
There are two ways you can configure the model to use a static kv-cache. For a 7B model on an A100, both methods get a 4x speed up in the forward pass. Your speed up may vary depending on the model size (larger models have a smaller speed up) and hardware. If you're using the [`~GenerationMixin.generate`] method, the speed up is ~3x. The forward pass (which still gets 4x speed up) is only a part of the whole [`~GenerationMixin.generate`] code.
Under the hood, `generate` will attempt to reuse the same cache object, removing the need for re-compilation at each call. Avoiding re-compilation is critical to get the most out of `torch.compile`, and you should be aware of the following:
1. If the batch size changes or the maximum output length increases between calls, the cache will have to be reinitialized, triggering a new compilation;
2. The first couple of calls of the compiled function are slower, as the function is being compiled.
<hfoptions id="static-kv">
<hfoption id="generation_config">
> [!WARNING]
> For a more advanced usage of the static cache, such as multi-turn conversations, we recommend instantiating and manipulating the cache object outside [`~GenerationMixin.generate`]. See the advanced usage tab.
</hfoption>
<hfoption id="advanced usage: control Static Cache">
Access the model's `generation_config` attribute and set the `cache_implementation` to "static".
A [`StaticCache`] object can be passed to the model's [`~GenerationMixin.generate`] under the `past_key_values` argument. The object will retain the cache contents, so you can pass it to a new [`~GenerationMixin.generate`] call to continue generation, like you would do with a dynamic cache.
```py
model.generation_config.cache_implementation = "static"
```
from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache
import torch
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false" # To prevent long warnings :)
Call torch.compile on the model to compile the forward pass with the static kv-cache.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="auto")
```py
compiled_model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
input_text = "The theory of special relativity states "
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
prompt_length = input_ids.input_ids.shape[1]
model.generation_config.max_new_tokens = 16
past_key_values = StaticCache(
config=model.config,
max_batch_size=1,
# If you plan to reuse the cache, make sure the cache length is large enough for all cases
max_cache_len=prompt_length+(model.generation_config.max_new_tokens*2),
device=model.device,
dtype=model.dtype
)
outputs = model.generate(**input_ids, past_key_values=past_key_values)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
['The theory of special relativity states 1. The speed of light is constant in all inertial reference frames. 2']
outputs = compiled_model.generate(**input_ids)
tokenizer.batch_decode(outputs, skip_special_tokens=True)
['The theory of special relativity states 1. The speed of light is constant in all inertial reference']
# pass in the generated text and the same cache object to continue generation from where it left off. Optionally, in a
# multi-turn conversation, append the new user input to the generated text.
new_input_ids = outputs
outputs = model.generate(new_input_ids, past_key_values=past_key_values)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
['The theory of special relativity states 1. The speed of light is constant in all inertial reference frames. 2. The speed of light is constant in all inertial reference frames. 3.']
```
Under the hood, `generate` will attempt to reuse the same cache object, removing the need for re-compilation at each call. However, if the batch size or the maximum output length increase between calls, the cache will have to be reinitialized, triggering a new compilation.
</hfoption>
<hfoption id="Static Cache">
> [!TIP]
> If you want to reuse the same [`StaticCache`] object on a new prompt, be sure to reset its contents with the `.reset()` method between calls
A [`StaticCache`] object can be passed to the model's forward pass under the `past_key_values` argument, enabling the use of this object as a static kv-cache. Using this strategy, you can write your own function to decode the next token given the current token and position and cache position of previously generated tokens. You can also pass the [`StaticCache`] object to [`~GenerationMixin.generate`] and use it across calls, like you would do with a dynamic cache.
If you want to go further down a level, the [`StaticCache`] object can also be passed to the model's forward pass under the same `past_key_values` argument. Using this strategy, you can write your own function to decode the next token given the current token and position and cache position of previously generated tokens.
```py
from transformers import LlamaTokenizer, LlamaForCausalLM, StaticCache, logging
......@@ -102,12 +152,9 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_valu
return new_token
```
There are a few important things you must do to enable static kv-cache and torch.compile with the `StaticCache` method:
There are a few important things you must do to enable static kv-cache and `torch.compile` with the `StaticCache` method:
1. Initialize the [`StaticCache`] instance before using the model for inference. There you can configure parameters like the maximum batch size and sequence length.
2. Call torch.compile on the model to compile the forward pass with the static kv-cache.
2. Call `torch.compile` on the model to compile the forward pass with the static kv-cache.
3. Set `enable_math=True` in the [torch.backends.cuda.sdp_kernel](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) context manager to enable the native PyTorch C++ implementation of scaled dot product attention to speed up inference even more.
```py
......@@ -142,8 +189,34 @@ text
'My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p']
```
> [!TIP]
> If you want to reuse the [`StaticCache`] object on a new prompt, be sure to reset its contents with the `.reset()` method
</hfoption>
<hfoption id="advanced usage: end-to-end generate compilation">
Compiling the entire `generate` function, in terms of code, is even simpler than in the basic usage: call `torch.compile` on `generate` to compile the entire function. No need to specify the use of the static cache: although it is compatible, dynamic cache (default) was faster in our benchmarks.
```py
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false" # To prevent long warnings :)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="auto")
model.generate = torch.compile(model.generate, mode="reduce-overhead", fullgraph=True)
input_text = "The theory of special relativity states "
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
outputs = model.generate(**input_ids)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
['The theory of special relativity states 1. The speed of light is constant in all inertial reference']
```
As a result, we compile not only the model forward pass, but also all input preparation, logit processor operations, and so on. The result should be a slightly `generate` call, compared to the basic usage example, and the compiled graph may be better suited to more exotic hardware devices or use cases. However, there are severe drawbacks in using this approach:
1. Compilation is much slower;
2. All parameterization of `generate` must be done through `generation_config`;
3. Many warnings and exceptions are suppressed -- we suggest testing with its uncompiled form first;
4. Although we are working on it, it is heavily feature restricted (for instance, at the time of writing, generation does not stop if an EOS token is selected).
</hfoption>
</hfoptions>
......
......@@ -9,7 +9,7 @@ import torch
from packaging import version
from .configuration_utils import PretrainedConfig
from .utils import is_hqq_available, is_quanto_available, logging
from .utils import is_hqq_available, is_quanto_available, is_torchdynamo_compiling, logging
if is_quanto_available():
......@@ -398,7 +398,6 @@ class DynamicCache(Cache):
def crop(self, max_length: int):
"""Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search."""
# In case it is negative
if max_length < 0:
max_length = self.get_seq_length() - abs(max_length)
......@@ -821,11 +820,13 @@ class StaticCache(Cache):
cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
for _ in range(config.num_hidden_layers):
# Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
# breaks when updating the cache.
# breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
# it is not needed anyway)
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
torch._dynamo.mark_static_address(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_value_cache)
if not is_torchdynamo_compiling():
torch._dynamo.mark_static_address(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_value_cache)
self.key_cache.append(new_layer_key_cache)
self.value_cache.append(new_layer_value_cache)
......
......@@ -1144,7 +1144,7 @@ class GenerationMixin:
Confirms that the model class is compatible with generation. If not, raises an exception that points to the
right class to use.
"""
if not self.can_generate():
if not is_torchdynamo_compiling() and not self.can_generate():
generate_compatible_mappings = [
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
......@@ -1247,6 +1247,10 @@ class GenerationMixin:
def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
"""Performs validation related to the resulting generated length"""
# Can't throw warnings/exceptions during compilation
if is_torchdynamo_compiling():
return
# 1. Max length warnings related to poor parameterization
if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
# 20 is the default max_length of the generation config
......@@ -1376,20 +1380,12 @@ class GenerationMixin:
self.generation_config = new_generation_config
using_model_generation_config = True
generation_config = self.generation_config
using_model_generation_config = True
# `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config`
# will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled.
if is_torchdynamo_compiling():
model_kwargs = kwargs
generate_attributes_in_kwargs = [
key for key, value in kwargs.items() if getattr(generation_config, key, None) != value
]
if len(generate_attributes_in_kwargs) > 0:
raise ValueError(
"`torch.compile` exception: all generation configuration attributes must be passed within a "
f"`generation_config` instance passed to `generate` (found: {generate_attributes_in_kwargs})."
)
else:
# will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled -- an
# exception will be raised in `_validate_model_kwargs`
if not is_torchdynamo_compiling():
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs)
# If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model
......@@ -1402,30 +1398,40 @@ class GenerationMixin:
generation_config.pad_token_id = self.generation_config.pad_token_id
if generation_config.decoder_start_token_id is None:
generation_config.decoder_start_token_id = self.generation_config.decoder_start_token_id
else:
model_kwargs = kwargs
return generation_config, model_kwargs
def _get_initial_cache_position(self, input_ids, model_kwargs):
"""Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""
# `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange`
if "inputs_embeds" in model_kwargs:
cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1
else:
cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1
past_length = 0
if model_kwargs.get("past_key_values") is not None:
cache = model_kwargs["past_key_values"]
past_length = 0
if not isinstance(cache, Cache):
past_length = cache[0][0].shape[2]
elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None:
past_length = cache.get_seq_length()
if "inputs_embeds" in model_kwargs:
cur_len = model_kwargs["inputs_embeds"].shape[1]
else:
cur_len = input_ids.shape[-1]
model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device)
# TODO(joao): this is not torch.compile-friendly, find a work-around. If the cache is not empty,
# end-to-end compilation will yield bad results because `cache_position` will be incorrect.
if not is_torchdynamo_compiling():
cache_position = cache_position[past_length:]
model_kwargs["cache_position"] = cache_position
return model_kwargs
def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int, model_kwargs) -> Cache:
"""
Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized a
new `generate` call requires a larger cache.
new `generate` call requires a larger cache or uses a different batch size.
Returns the resulting cache object.
"""
......@@ -1458,7 +1464,14 @@ class GenerationMixin:
if hasattr(self.config, "_pre_quantization_dtype"):
cache_dtype = self.config._pre_quantization_dtype
else:
cache_dtype = self.dtype
if not is_torchdynamo_compiling():
cache_dtype = self.dtype
else:
# NOTE: self.dtype is not compatible with torch.compile, as it calls `self.parameters()`.
# Workaround: trust the lm_head, whose attribute name is somewhat consistent across generative
# models. May cause trobles with non-text modalities.
cache_dtype = self.lm_head.weight.dtype
cache_kwargs = {
"config": self.config,
"max_batch_size": max_batch_size,
......@@ -1535,27 +1548,29 @@ class GenerationMixin:
pad_token_tensor = eos_token_tensor[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")
# we can't infer attn mask if pad token is set to be eos token in model's generation config
if eos_token_tensor is not None and pad_token_tensor in eos_token_tensor:
if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
logger.warning_once(
"The attention mask is not set and cannot be inferred from input because pad token is same as eos token."
"As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` "
"to obtain reliable results."
)
# Sanity checks/warnings
if self.config.is_encoder_decoder and decoder_start_token_tensor is None:
raise ValueError(
"`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
)
if eos_token_tensor is not None and (
torch.is_floating_point(eos_token_tensor) or (eos_token_tensor < 0).any()
):
logger.warning(
f"`eos_token_id` should consist of positive integers, but is {eos_token_tensor}. Your generation will not "
"stop until the maximum length is reached. Depending on other flags, it may even crash."
)
if not is_torchdynamo_compiling(): # Checks that depend on tensor-dependent control flow
if (
eos_token_tensor is not None
and torch.isin(elements=eos_token_tensor, test_elements=pad_token_tensor).any()
):
if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
logger.warning_once(
"The attention mask is not set and cannot be inferred from input because pad token is same as "
"eos token. As a consequence, you may observe unexpected behavior. Please pass your input's "
"`attention_mask` to obtain reliable results."
)
if eos_token_tensor is not None and (
torch.is_floating_point(eos_token_tensor) or (eos_token_tensor < 0).any()
):
logger.warning(
f"`eos_token_id` should consist of positive integers, but is {eos_token_tensor}. Your generation "
"will not stop until the maximum length is reached. Depending on other flags, it may even crash."
)
# Update generation config with the updated special tokens tensors
# NOTE: this must be written into a different attribute name than the one holding the original special tokens
......@@ -1764,6 +1779,12 @@ class GenerationMixin:
cache_name = "cache_params"
else:
cache_name = "past_key_values"
if (model_kwargs.get(cache_name) is not None) and is_torchdynamo_compiling():
raise ValueError(
"Passing `past_key_values` is not supported when compiling `model.generate` with torch.compile -- you "
"may get incorrect outputs. Please compile `model.forward` only or use the `cache_implementation` "
"input argument."
)
if generation_config.cache_implementation is not None and (model_kwargs.get(cache_name) is not None):
raise ValueError(
f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a "
......@@ -1840,7 +1861,7 @@ class GenerationMixin:
"`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
)
if self.device.type != input_ids.device.type:
if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type:
warnings.warn(
"You are calling .generate() with the `input_ids` being on a device type different"
f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
......@@ -2137,23 +2158,36 @@ class GenerationMixin:
result.past_key_values = result.past_key_values.to_legacy_cache()
return result
def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, device: torch.device) -> bool:
def _has_unfinished_sequences(
self,
this_peer_finished: bool,
synced_gpus: bool,
device: torch.device,
cur_len: Optional[int] = None,
max_length: Optional[int] = None,
) -> bool:
"""
Returns whether there are still unfinished sequences in the device. The existence of unfinished sequences is
fed through `this_peer_finished`. ZeRO stage 3-friendly.
"""
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
# The following logic allows an early break if all peers finished generating their sequence
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(device)
# send 0.0 if we finished, 1.0 otherwise
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
# did all peers finish? the reduced sum will be 0.0 then
if this_peer_finished_flag.item() == 0.0:
# torch.compile does not support data-dependent control flow. This is a workaround to allow torch.compile,
# although we lose the ability to stop when all sequences return an EOS token (and other stopping criteria)
# TODO (joao): remove this when torch's support for control flow is not experimental (https://pytorch.org/docs/stable/generated/torch.cond.html)
if is_torchdynamo_compiling():
return cur_len < max_length
else:
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
# The following logic allows an early break if all peers finished generating their sequence
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(device)
# send 0.0 if we finished, 1.0 otherwise
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
# did all peers finish? the reduced sum will be 0.0 then
if this_peer_finished_flag.item() == 0.0:
return False
elif this_peer_finished:
return False
elif this_peer_finished:
return False
return True
return True
def heal_tokens(
self, input_ids: torch.LongTensor, tokenizer: Optional["PreTrainedTokenizerBase"] = None
......@@ -2885,6 +2919,7 @@ class GenerationMixin:
output_scores = generation_config.output_scores
output_logits = generation_config.output_logits
return_dict_in_generate = generation_config.return_dict_in_generate
max_length = generation_config.max_length
has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
do_sample = generation_config.do_sample
if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
......@@ -2908,12 +2943,14 @@ class GenerationMixin:
)
# keep track of which sequences are already finished
batch_size = input_ids.shape[0]
batch_size, cur_len = input_ids.shape
this_peer_finished = False
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
while self._has_unfinished_sequences(
this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length
):
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
......@@ -2959,6 +2996,7 @@ class GenerationMixin:
# token selection
if do_sample:
probs = nn.functional.softmax(next_token_scores, dim=-1)
# TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
else:
next_tokens = torch.argmax(next_token_scores, dim=-1)
......@@ -2979,6 +3017,7 @@ class GenerationMixin:
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
this_peer_finished = unfinished_sequences.max() == 0
cur_len += 1
# This is needed to properly delete outputs.logits which may be very large for first iteration
# Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
......
......@@ -249,6 +249,7 @@ class DbrxConfig(PretrainedConfig):
self.use_cache = use_cache
self.initializer_range = initializer_range
self.output_router_logits = output_router_logits
self.num_key_value_heads = self.attn_config.kv_n_heads
tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
if tie_word_embeddings:
......
......@@ -513,7 +513,10 @@ def require_read_token(fn):
@wraps(fn)
def _inner(*args, **kwargs):
with patch("huggingface_hub.utils._headers.get_token", return_value=token):
if token is not None:
with patch("huggingface_hub.utils._headers.get_token", return_value=token):
return fn(*args, **kwargs)
else: # Allow running locally with the default token env variable
return fn(*args, **kwargs)
return _inner
......
......@@ -670,18 +670,20 @@ def is_torch_compile_available():
def is_torchdynamo_compiling():
if not is_torch_available():
return False
# Importing torch._dynamo causes issues with PyTorch profiler (https://github.com/pytorch/pytorch/issues/130622)
# hence rather relying on `torch.compiler.is_compiling()` when possible (torch>=2.3)
try:
# Importing torch._dynamo causes issues with PyTorch profiler (https://github.com/pytorch/pytorch/issues/130622) hence rather relying on `torch.compiler.is_compiling()` when possible.
if version.parse(_torch_version) >= version.parse("2.3.0"):
import torch
import torch
return torch.compiler.is_compiling()
else:
return torch.compiler.is_compiling()
except AttributeError:
try:
import torch._dynamo as dynamo # noqa: F401
return dynamo.is_compiling()
except Exception:
return False
except Exception:
return False
def is_torch_tensorrt_fx_available():
......
......@@ -1802,6 +1802,58 @@ class GenerationTesterMixin:
with self.assertRaises(ValueError):
model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
@require_torch_gpu
@slow
@is_flaky() # compilation may result in equivalent (!= same) FP ops, causing the argmax in `generate` to be flaky
def test_generate_compile_fullgraph(self):
"""
Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results.
⚠️ Runs two sequential generations to ensure the cache doesn't get stuck after the first compiled run! ⚠️
"""
for model_class in self.all_generative_model_classes:
if not model_class._supports_static_cache:
self.skipTest("This model doesn't support static cache")
# TODO (joao) -- fix and enable me :)
if any(model_name in model_class.__name__.lower() for model_name in ["whisper"]):
self.skipTest("whisper model end-to-end generate compile not yet supported")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# TODO (joao) -- fix and enable me :)
if config.is_encoder_decoder:
self.skipTest("Encoder-decoder model end-to-end generate compile not yet supported")
model = model_class(config).to(torch_device)
input_ids = inputs_dict["input_ids"].to(torch_device)
# creates two sets of *different* inputs with the same shape
half_batch_size = input_ids.shape[0] // 2
input_ids_sets = [input_ids[:half_batch_size, :], input_ids[half_batch_size : half_batch_size * 2, :]]
self.assertTrue(input_ids_sets[0].shape == input_ids_sets[1].shape)
generation_kwargs = {
"do_sample": False,
"max_new_tokens": 10,
}
for model_inputs in input_ids_sets:
# dynamic cache
output_dynamic = model.generate(model_inputs, **generation_kwargs)
# eager static cache
torch.compiler.reset()
model.generation_config.cache_implementation = "static"
output_static = model.generate(model_inputs, **generation_kwargs)
self.assertListEqual(output_dynamic.tolist(), output_static.tolist())
# compiled static cache (removes the cache initialized in the previous check, to confirm we can
# initialize the cache in full compiled mode)
model._cache = None
torch.compiler.reset()
generation_config = copy.deepcopy(model.generation_config)
generation_config.update(**generation_kwargs)
compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
output_compiled = compiled_generate(model_inputs, generation_config=generation_config)
self.assertListEqual(output_dynamic.tolist(), output_compiled.tolist())
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
batch_size, seq_length = input_ids.shape
num_sequences_in_output = batch_size * num_return_sequences
......
......@@ -370,6 +370,11 @@ class ChameleonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
def test_batching_equivalence(self):
pass
# TODO (joao, raushan): fix me -- the problem is in `cache_position[0] == 0`, i.e. dynamic control flow
@unittest.skip("Chameleon is not compatible with end-to-end generation compilation")
def test_generate_compile_fullgraph(self):
pass
@require_torch
class ChameleonIntegrationTest(unittest.TestCase):
......
......@@ -368,6 +368,10 @@ class DbrxModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
def test_disk_offload_bin(self):
pass
@unittest.skip("Dbrx does not support `torch.compile` with `fullgraph=True`.")
def test_generate_compile_fullgraph(self):
pass
@require_torch
class DbrxModelIntegrationTest(unittest.TestCase):
......
......@@ -31,6 +31,7 @@ from parameterized import parameterized
import transformers
from transformers import WhisperConfig
from transformers.testing_utils import (
is_flaky,
is_pt_flax_cross_test,
require_flash_attn,
require_torch,
......@@ -1785,6 +1786,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_return_sequences"]
)
@is_flaky() # TODO (joao, sanchit): fails ~9% of the times. Does the original test have the same issue?
def test_custom_4d_attention_mask(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = WhisperForConditionalGeneration(config).to(device=torch_device, dtype=torch.float32)
......
......@@ -143,7 +143,7 @@ class CacheTest(unittest.TestCase):
mha_config = LlamaConfig(num_attention_heads=32)
mha_static_cache = StaticCache(config=mha_config, max_batch_size=1, max_cache_len=10, device=torch_device)
cached_keys, cached_values = mha_static_cache.update(
*_random_kvs(mha_config), 0, cache_kwargs={"cache_position": torch.arange(1)}
*_random_kvs(mha_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
)
self.assertTrue(cached_keys.shape == (1, 32, 10, 128))
self.assertTrue(cached_values.shape == (1, 32, 10, 128))
......@@ -151,7 +151,7 @@ class CacheTest(unittest.TestCase):
gqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=4)
gqa_static_cache = StaticCache(config=gqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
cached_keys, cached_values = gqa_static_cache.update(
*_random_kvs(gqa_config), 0, cache_kwargs={"cache_position": torch.arange(1)}
*_random_kvs(gqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
)
self.assertTrue(cached_keys.shape == (1, 4, 10, 128))
self.assertTrue(cached_values.shape == (1, 4, 10, 128))
......@@ -159,7 +159,7 @@ class CacheTest(unittest.TestCase):
mqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=1)
mqa_static_cache = StaticCache(config=mqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
cached_keys, cached_values = mqa_static_cache.update(
*_random_kvs(mqa_config), 0, cache_kwargs={"cache_position": torch.arange(1)}
*_random_kvs(mqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
)
self.assertTrue(cached_keys.shape == (1, 1, 10, 128))
self.assertTrue(cached_values.shape == (1, 1, 10, 128))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment