Unverified Commit 7ffe25f2 authored by Joao Gante, committed by GitHub

Generate: end-to-end compilation (#30788)

* mvp

* added test (a few models need fixes)

* fix a few test cases

* test nits

* harder test 😈

* revert changes in stablelm

* test with improved condition

* add todo

* tmp commit

* merged with main

* nits

* add todo

* final corrections

* add docs for generation compilation

* docs nits

* add  tip

* PR suggestions

* add more details to the compilation docs

* fix cache positions

* cache is now init in generate; update docs

* tag test as flaky

* docs

* post rebase make fixup and other nits

* remove unintended changes

* whisper (encoder-decoder) not supported

* move token default updates to ; add tests for token defaults

* push changes

* manual rebase

* chameleon doesn't support this

* fix test_static_cache_mha_mqa_gqa (broken in another PR)

* docs: dynamic is better with end-to-end compilation
parent 49928892
...@@ -18,59 +18,109 @@ Basic inference is slow because LLMs have to be called repeatedly to generate th

This guide will show you how to use the optimization techniques available in Transformers to accelerate LLM inference.

> [!TIP]
> Hugging Face also provides [Text Generation Inference (TGI)](https://hf.co/docs/text-generation-inference), a library dedicated to deploying and serving highly optimized LLMs for inference. It includes deployment-oriented optimization features not included in Transformers, such as continuous batching for increasing throughput and tensor parallelism for multi-GPU inference.

## Static kv-cache and `torch.compile`

During decoding, an LLM computes the key-value (kv) values for each input token, and because it is autoregressive, the generated output becomes part of the input at every step, so the same kv values are recomputed over and over. This is not very efficient.

To optimize this, you can use a kv-cache to store the past keys and values instead of recomputing them each time. However, since the kv-cache grows with each generation step and is dynamic, it prevents you from taking advantage of [`torch.compile`](./perf_torch_compile), a powerful optimization tool that fuses PyTorch code into fast and optimized kernels.

The *static kv-cache* solves this issue by pre-allocating the kv-cache to a maximum size, which allows you to combine it with `torch.compile` for up to a 4x speed up. Your speed up may vary depending on the model size (larger models have a smaller speed up) and hardware.

> [!WARNING]
> Currently, only [Llama](./model_doc/llama2) and a few other models support static kv-cache and `torch.compile`. Check [this issue](https://github.com/huggingface/transformers/issues/28981) for a live model compatibility list.
There are three flavors of static kv-cache usage, depending on the complexity of your task:
1. Basic usage: simply set a flag in `generation_config` (recommended);
2. Advanced usage: handle a cache object for multi-turn generation or a custom generation loop;
3. Advanced usage: compile the entire `generate` function into a single graph, if having a single graph is relevant for you.
Select the correct tab below for further instructions on each of these flavors.
> [!TIP]
> Regardless of the strategy used with `torch.compile`, you can avoid shape-related recompilations if you left-pad your LLM inputs to a limited set of values. The [`pad_to_multiple_of` tokenizer flag](https://huggingface.co/docs/transformers/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__.pad_to_multiple_of) is your friend!
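
As a minimal sketch of the tip above, here is one way to left-pad a batch so its shapes fall into a small, repeating set; the prompts and the multiple of 8 are arbitrary choices for illustration, not values from this guide:

```py
from transformers import AutoTokenizer

# left padding keeps the prompt tokens at the end of the sequence, which is what decoder-only LLMs expect
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b", padding_side="left")

prompts = ["The theory of special relativity states ", "Gravity is"]
inputs = tokenizer(prompts, return_tensors="pt", padding=True, pad_to_multiple_of=8)

# the sequence dimension is rounded up to a multiple of 8, so prompts of different lengths map to a
# small set of shapes and avoid shape-related recompilations
print(inputs["input_ids"].shape)
```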
<hfoptions id="static-kv">
<hfoption id="basic usage: generation_config">
For this example, let's use the [Gemma](https://hf.co/google/gemma-2b) model. All we need to do is to:
1. Access the model's `generation_config` attribute and set the `cache_implementation` to "static";
2. Call `torch.compile` on the model to compile the forward pass with the static kv-cache.
And that's it!
```py
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # To prevent long warnings :)

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="auto")

model.generation_config.cache_implementation = "static"

model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

input_text = "The theory of special relativity states "
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
['The theory of special relativity states 1. The speed of light is constant in all inertial reference']
```
If you're using the [`~GenerationMixin.generate`] method, the speed up is ~3x. The forward pass (which still gets 4x speed up) is only a part of the whole [`~GenerationMixin.generate`] code. Under the hood, `generate` will attempt to reuse the same cache object, removing the need for re-compilation at each call. Avoiding re-compilation is critical to get the most out of `torch.compile`, and you should be aware of the following:

1. If the batch size changes or the maximum output length increases between calls, the cache will have to be reinitialized, triggering a new compilation;
2. The first couple of calls of the compiled function are slower, as the function is being compiled (see the timing sketch below).
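
A rough timing sketch of these two points, reusing `model` and `input_ids` from the example above; the loop count and the use of `time.perf_counter` are illustrative assumptions:

```py
import time
import torch

# the first call pays the compilation (and cache initialization) cost, so it is the slowest;
# subsequent calls with the same batch size and maximum output length reuse the compiled graph
for i in range(3):
    start = time.perf_counter()
    _ = model.generate(**input_ids)
    torch.cuda.synchronize()  # make sure the GPU work is finished before reading the timer
    print(f"generate call {i}: {time.perf_counter() - start:.2f}s")
```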
<hfoptions id="static-kv"> > [!WARNING]
<hfoption id="generation_config"> > For a more advanced usage of the static cache, such as multi-turn conversations, we recommend instantiating and manipulating the cache object outside [`~GenerationMixin.generate`]. See the advanced usage tab.
</hfoption>
<hfoption id="advanced usage: control Static Cache">
Access the model's `generation_config` attribute and set the `cache_implementation` to "static". A [`StaticCache`] object can be passed to the model's [`~GenerationMixin.generate`] under the `past_key_values` argument. The object will retain the cache contents, so you can pass it to a new [`~GenerationMixin.generate`] call to continue generation, like you would do with a dynamic cache.
```py ```py
model.generation_config.cache_implementation = "static" from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache
``` import torch
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false" # To prevent long warnings :)
Call torch.compile on the model to compile the forward pass with the static kv-cache. tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="auto")
```py model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
compiled_model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
input_text = "The theory of special relativity states " input_text = "The theory of special relativity states "
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda") input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
prompt_length = input_ids.input_ids.shape[1]
model.generation_config.max_new_tokens = 16
past_key_values = StaticCache(
config=model.config,
max_batch_size=1,
# If you plan to reuse the cache, make sure the cache length is large enough for all cases
max_cache_len=prompt_length+(model.generation_config.max_new_tokens*2),
device=model.device,
dtype=model.dtype
)
outputs = model.generate(**input_ids, past_key_values=past_key_values)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
['The theory of special relativity states 1. The speed of light is constant in all inertial reference frames. 2']
outputs = compiled_model.generate(**input_ids) # pass in the generated text and the same cache object to continue generation from where it left off. Optionally, in a
tokenizer.batch_decode(outputs, skip_special_tokens=True) # multi-turn conversation, append the new user input to the generated text.
['The theory of special relativity states 1. The speed of light is constant in all inertial reference'] new_input_ids = outputs
outputs = model.generate(new_input_ids, past_key_values=past_key_values)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
['The theory of special relativity states 1. The speed of light is constant in all inertial reference frames. 2. The speed of light is constant in all inertial reference frames. 3.']
``` ```
> [!TIP]
> If you want to reuse the same [`StaticCache`] object on a new prompt, be sure to reset its contents with the `.reset()` method between calls.
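
For instance, a minimal sketch of reusing the cache above on an unrelated prompt after resetting it; the new prompt text is an arbitrary choice:

```py
# wipe the cached keys/values so leftover content from the previous prompt doesn't leak into the new one
past_key_values.reset()

new_input = tokenizer("The theory of general relativity states ", return_tensors="pt").to("cuda")
outputs = model.generate(**new_input, past_key_values=past_key_values)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```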
If you want to go further down a level, the [`StaticCache`] object can also be passed to the model's forward pass under the same `past_key_values` argument. Using this strategy, you can write your own function to decode the next token given the current token and position and cache position of previously generated tokens.
```py
from transformers import LlamaTokenizer, LlamaForCausalLM, StaticCache, logging
...@@ -102,12 +152,9 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_valu
    return new_token
```
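
The body of `decode_one_tokens` is collapsed in the diff above; a minimal sketch of such a function, assuming greedy decoding, could look like this:

```py
import torch

def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values):
    # a single forward pass that writes the new key/value pair into the pre-allocated static cache
    logits = model(
        cur_token,
        position_ids=input_pos,
        cache_position=cache_position,
        past_key_values=past_key_values,
        return_dict=False,
        use_cache=True,
    )[0]
    # greedy selection of the next token; keep the batch dimension for the next iteration
    new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
    return new_token
```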
There are a few important things you must do to enable static kv-cache and `torch.compile` with the `StaticCache` method:

1. Initialize the [`StaticCache`] instance before using the model for inference. There you can configure parameters like the maximum batch size and sequence length.
2. Call `torch.compile` on the model to compile the forward pass with the static kv-cache.
3. Set `enable_math=True` in the [torch.backends.cuda.sdp_kernel](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) context manager to enable the native PyTorch C++ implementation of scaled dot product attention to speed up inference even more.
```py
...@@ -142,8 +189,34 @@ text
'My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p']
```

</hfoption>
<hfoption id="advanced usage: end-to-end generate compilation">
Compiling the entire `generate` function is, in terms of code, even simpler than the basic usage: call `torch.compile` on `generate` to compile the entire function. There is no need to specify the use of the static cache: although it is compatible, the dynamic cache (default) was faster in our benchmarks.
```py
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false" # To prevent long warnings :)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="auto")
model.generate = torch.compile(model.generate, mode="reduce-overhead", fullgraph=True)
input_text = "The theory of special relativity states "
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
outputs = model.generate(**input_ids)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
['The theory of special relativity states 1. The speed of light is constant in all inertial reference']
```
As a result, we compile not only the model forward pass, but also all input preparation, logit processor operations, and so on. The result should be a slightly faster `generate` call, compared to the basic usage example, and the compiled graph may be better suited to more exotic hardware devices or use cases. However, there are severe drawbacks in using this approach:

1. Compilation is much slower;
2. All parameterization of `generate` must be done through `generation_config` (see the sketch after this list);
3. Many warnings and exceptions are suppressed -- we suggest testing with its uncompiled form first;
4. Although we are working on it, it is heavily feature restricted (for instance, at the time of writing, generation does not stop if an EOS token is selected).
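
For point 2, a minimal sketch of parameterizing a compiled `generate` through a `GenerationConfig` object, reusing `model`, `tokenizer`, and `input_ids` from the snippet above; the chosen options are arbitrary:

```py
from transformers import GenerationConfig

# with a compiled `generate`, options go through a GenerationConfig object instead of ad-hoc keyword arguments
generation_config = GenerationConfig(max_new_tokens=32, do_sample=False)

outputs = model.generate(**input_ids, generation_config=generation_config)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```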
</hfoption>
</hfoptions>
......
...@@ -9,7 +9,7 @@ import torch
from packaging import version

from .configuration_utils import PretrainedConfig
from .utils import is_hqq_available, is_quanto_available, is_torchdynamo_compiling, logging


if is_quanto_available():
...@@ -398,7 +398,6 @@ class DynamicCache(Cache):
    def crop(self, max_length: int):
        """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
        negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search."""
        # In case it is negative
        if max_length < 0:
            max_length = self.get_seq_length() - abs(max_length)
...@@ -821,11 +820,13 @@ class StaticCache(Cache):
        cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
        for _ in range(config.num_hidden_layers):
            # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
            # breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
            # it is not needed anyway)
            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
            if not is_torchdynamo_compiling():
                torch._dynamo.mark_static_address(new_layer_key_cache)
                torch._dynamo.mark_static_address(new_layer_value_cache)
            self.key_cache.append(new_layer_key_cache)
            self.value_cache.append(new_layer_value_cache)
......
...@@ -1144,7 +1144,7 @@ class GenerationMixin:
        Confirms that the model class is compatible with generation. If not, raises an exception that points to the
        right class to use.
        """
        if not is_torchdynamo_compiling() and not self.can_generate():
            generate_compatible_mappings = [
                MODEL_FOR_CAUSAL_LM_MAPPING,
                MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
...@@ -1247,6 +1247,10 @@ class GenerationMixin:
    def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
        """Performs validation related to the resulting generated length"""
        # Can't throw warnings/exceptions during compilation
        if is_torchdynamo_compiling():
            return

        # 1. Max length warnings related to poor parameterization
        if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
            # 20 is the default max_length of the generation config
...@@ -1376,20 +1380,12 @@ class GenerationMixin:
                    self.generation_config = new_generation_config
                using_model_generation_config = True
            generation_config = self.generation_config

        # `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config`
        # will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled -- an
        # exception will be raised in `_validate_model_kwargs`
        if not is_torchdynamo_compiling():
            generation_config = copy.deepcopy(generation_config)
            model_kwargs = generation_config.update(**kwargs)
            # If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model
...@@ -1402,30 +1398,40 @@ class GenerationMixin:
                    generation_config.pad_token_id = self.generation_config.pad_token_id
                if generation_config.decoder_start_token_id is None:
                    generation_config.decoder_start_token_id = self.generation_config.decoder_start_token_id
        else:
            model_kwargs = kwargs

        return generation_config, model_kwargs

    def _get_initial_cache_position(self, input_ids, model_kwargs):
        """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""
        # `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange`
        if "inputs_embeds" in model_kwargs:
            cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1
        else:
            cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1

        past_length = 0
        if model_kwargs.get("past_key_values") is not None:
            cache = model_kwargs["past_key_values"]
            past_length = 0
            if not isinstance(cache, Cache):
                past_length = cache[0][0].shape[2]
            elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None:
                past_length = cache.get_seq_length()

            # TODO(joao): this is not torch.compile-friendly, find a work-around. If the cache is not empty,
            # end-to-end compilation will yield bad results because `cache_position` will be incorrect.
            if not is_torchdynamo_compiling():
                cache_position = cache_position[past_length:]

        model_kwargs["cache_position"] = cache_position
        return model_kwargs

    def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int, model_kwargs) -> Cache:
        """
        Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized if a
        new `generate` call requires a larger cache or uses a different batch size.

        Returns the resulting cache object.
        """
...@@ -1458,7 +1464,14 @@ class GenerationMixin:
            if hasattr(self.config, "_pre_quantization_dtype"):
                cache_dtype = self.config._pre_quantization_dtype
            else:
                if not is_torchdynamo_compiling():
                    cache_dtype = self.dtype
                else:
                    # NOTE: self.dtype is not compatible with torch.compile, as it calls `self.parameters()`.
                    # Workaround: trust the lm_head, whose attribute name is somewhat consistent across generative
                    # models. May cause troubles with non-text modalities.
                    cache_dtype = self.lm_head.weight.dtype

        cache_kwargs = {
            "config": self.config,
            "max_batch_size": max_batch_size,
...@@ -1535,27 +1548,29 @@ class GenerationMixin:
                pad_token_tensor = eos_token_tensor[0]
            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")

        # Sanity checks/warnings
        if self.config.is_encoder_decoder and decoder_start_token_tensor is None:
            raise ValueError(
                "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
            )
        if not is_torchdynamo_compiling():  # Checks that depend on tensor-dependent control flow
            # we can't infer attn mask if pad token is set to be eos token in model's generation config
            if (
                eos_token_tensor is not None
                and torch.isin(elements=eos_token_tensor, test_elements=pad_token_tensor).any()
            ):
                if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
                    logger.warning_once(
                        "The attention mask is not set and cannot be inferred from input because pad token is same as "
                        "eos token. As a consequence, you may observe unexpected behavior. Please pass your input's "
                        "`attention_mask` to obtain reliable results."
                    )
            if eos_token_tensor is not None and (
                torch.is_floating_point(eos_token_tensor) or (eos_token_tensor < 0).any()
            ):
                logger.warning(
                    f"`eos_token_id` should consist of positive integers, but is {eos_token_tensor}. Your generation "
                    "will not stop until the maximum length is reached. Depending on other flags, it may even crash."
                )

        # Update generation config with the updated special tokens tensors
        # NOTE: this must be written into a different attribute name than the one holding the original special tokens
...@@ -1764,6 +1779,12 @@ class GenerationMixin:
            cache_name = "cache_params"
        else:
            cache_name = "past_key_values"

        if (model_kwargs.get(cache_name) is not None) and is_torchdynamo_compiling():
            raise ValueError(
                "Passing `past_key_values` is not supported when compiling `model.generate` with torch.compile -- you "
                "may get incorrect outputs. Please compile `model.forward` only or use the `cache_implementation` "
                "input argument."
            )
        if generation_config.cache_implementation is not None and (model_kwargs.get(cache_name) is not None):
            raise ValueError(
                f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a "
...@@ -1840,7 +1861,7 @@ class GenerationMixin:
                "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
            )

        if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type:
            warnings.warn(
                "You are calling .generate() with the `input_ids` being on a device type different"
                f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
...@@ -2137,23 +2158,36 @@ class GenerationMixin:
            result.past_key_values = result.past_key_values.to_legacy_cache()
        return result

    def _has_unfinished_sequences(
        self,
        this_peer_finished: bool,
        synced_gpus: bool,
        device: torch.device,
        cur_len: Optional[int] = None,
        max_length: Optional[int] = None,
    ) -> bool:
        """
        Returns whether there are still unfinished sequences in the device. The existence of unfinished sequences is
        fed through `this_peer_finished`. ZeRO stage 3-friendly.
        """
        # torch.compile does not support data-dependent control flow. This is a workaround to allow torch.compile,
        # although we lose the ability to stop when all sequences return an EOS token (and other stopping criteria)
        # TODO (joao): remove this when torch's support for control flow is not experimental (https://pytorch.org/docs/stable/generated/torch.cond.html)
        if is_torchdynamo_compiling():
            return cur_len < max_length
        else:
            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    return False
            elif this_peer_finished:
                return False
            return True

    def heal_tokens(
        self, input_ids: torch.LongTensor, tokenizer: Optional["PreTrainedTokenizerBase"] = None
...@@ -2885,6 +2919,7 @@ class GenerationMixin:
        output_scores = generation_config.output_scores
        output_logits = generation_config.output_logits
        return_dict_in_generate = generation_config.return_dict_in_generate
        max_length = generation_config.max_length
        has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
        do_sample = generation_config.do_sample
        if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
...@@ -2908,12 +2943,14 @@ class GenerationMixin:
        )

        # keep track of which sequences are already finished
        batch_size, cur_len = input_ids.shape
        this_peer_finished = False
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

        while self._has_unfinished_sequences(
            this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length
        ):
            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
...@@ -2959,6 +2996,7 @@ class GenerationMixin:
            # token selection
            if do_sample:
                probs = nn.functional.softmax(next_token_scores, dim=-1)
                # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(next_token_scores, dim=-1)
...@@ -2979,6 +3017,7 @@ class GenerationMixin:
            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
            this_peer_finished = unfinished_sequences.max() == 0
            cur_len += 1

            # This is needed to properly delete outputs.logits which may be very large for first iteration
            # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
......
...@@ -249,6 +249,7 @@ class DbrxConfig(PretrainedConfig):
        self.use_cache = use_cache
        self.initializer_range = initializer_range
        self.output_router_logits = output_router_logits
        self.num_key_value_heads = self.attn_config.kv_n_heads

        tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
        if tie_word_embeddings:
......
...@@ -513,7 +513,10 @@ def require_read_token(fn):
    @wraps(fn)
    def _inner(*args, **kwargs):
        if token is not None:
            with patch("huggingface_hub.utils._headers.get_token", return_value=token):
                return fn(*args, **kwargs)
        else:  # Allow running locally with the default token env variable
            return fn(*args, **kwargs)

    return _inner
......
...@@ -670,18 +670,20 @@ def is_torch_compile_available():
def is_torchdynamo_compiling():
    if not is_torch_available():
        return False

    # Importing torch._dynamo causes issues with PyTorch profiler (https://github.com/pytorch/pytorch/issues/130622)
    # hence rather relying on `torch.compiler.is_compiling()` when possible (torch>=2.3)
    try:
        import torch

        return torch.compiler.is_compiling()
    except AttributeError:
        try:
            import torch._dynamo as dynamo  # noqa: F401

            return dynamo.is_compiling()
        except Exception:
            return False


def is_torch_tensorrt_fx_available():
......
...@@ -1802,6 +1802,58 @@ class GenerationTesterMixin:
            with self.assertRaises(ValueError):
                model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)

    @require_torch_gpu
    @slow
    @is_flaky()  # compilation may result in equivalent (!= same) FP ops, causing the argmax in `generate` to be flaky
    def test_generate_compile_fullgraph(self):
        """
        Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results.
        ⚠️ Runs two sequential generations to ensure the cache doesn't get stuck after the first compiled run! ⚠️
        """
        for model_class in self.all_generative_model_classes:
            if not model_class._supports_static_cache:
                self.skipTest("This model doesn't support static cache")

            # TODO (joao) -- fix and enable me :)
            if any(model_name in model_class.__name__.lower() for model_name in ["whisper"]):
                self.skipTest("whisper model end-to-end generate compile not yet supported")

            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

            # TODO (joao) -- fix and enable me :)
            if config.is_encoder_decoder:
                self.skipTest("Encoder-decoder model end-to-end generate compile not yet supported")

            model = model_class(config).to(torch_device)
            input_ids = inputs_dict["input_ids"].to(torch_device)

            # creates two sets of *different* inputs with the same shape
            half_batch_size = input_ids.shape[0] // 2
            input_ids_sets = [input_ids[:half_batch_size, :], input_ids[half_batch_size : half_batch_size * 2, :]]
            self.assertTrue(input_ids_sets[0].shape == input_ids_sets[1].shape)

            generation_kwargs = {
                "do_sample": False,
                "max_new_tokens": 10,
            }

            for model_inputs in input_ids_sets:
                # dynamic cache
                output_dynamic = model.generate(model_inputs, **generation_kwargs)

                # eager static cache
                torch.compiler.reset()
                model.generation_config.cache_implementation = "static"
                output_static = model.generate(model_inputs, **generation_kwargs)
                self.assertListEqual(output_dynamic.tolist(), output_static.tolist())

                # compiled static cache (removes the cache initialized in the previous check, to confirm we can
                # initialize the cache in full compiled mode)
                model._cache = None
                torch.compiler.reset()
                generation_config = copy.deepcopy(model.generation_config)
                generation_config.update(**generation_kwargs)
                compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
                output_compiled = compiled_generate(model_inputs, generation_config=generation_config)
                self.assertListEqual(output_dynamic.tolist(), output_compiled.tolist())

    def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
        batch_size, seq_length = input_ids.shape
        num_sequences_in_output = batch_size * num_return_sequences
......
...@@ -370,6 +370,11 @@ class ChameleonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
    def test_batching_equivalence(self):
        pass

    # TODO (joao, raushan): fix me -- the problem is in `cache_position[0] == 0`, i.e. dynamic control flow
    @unittest.skip("Chameleon is not compatible with end-to-end generation compilation")
    def test_generate_compile_fullgraph(self):
        pass


@require_torch
class ChameleonIntegrationTest(unittest.TestCase):
......
...@@ -368,6 +368,10 @@ class DbrxModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
    def test_disk_offload_bin(self):
        pass

    @unittest.skip("Dbrx does not support `torch.compile` with `fullgraph=True`.")
    def test_generate_compile_fullgraph(self):
        pass


@require_torch
class DbrxModelIntegrationTest(unittest.TestCase):
......
...@@ -31,6 +31,7 @@ from parameterized import parameterized
import transformers
from transformers import WhisperConfig
from transformers.testing_utils import (
    is_flaky,
    is_pt_flax_cross_test,
    require_flash_attn,
    require_torch,
...@@ -1785,6 +1786,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
            output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_return_sequences"]
        )

    @is_flaky()  # TODO (joao, sanchit): fails ~9% of the times. Does the original test have the same issue?
    def test_custom_4d_attention_mask(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        model = WhisperForConditionalGeneration(config).to(device=torch_device, dtype=torch.float32)
......
...@@ -143,7 +143,7 @@ class CacheTest(unittest.TestCase):
        mha_config = LlamaConfig(num_attention_heads=32)
        mha_static_cache = StaticCache(config=mha_config, max_batch_size=1, max_cache_len=10, device=torch_device)
        cached_keys, cached_values = mha_static_cache.update(
            *_random_kvs(mha_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
        )
        self.assertTrue(cached_keys.shape == (1, 32, 10, 128))
        self.assertTrue(cached_values.shape == (1, 32, 10, 128))
...@@ -151,7 +151,7 @@ class CacheTest(unittest.TestCase):
        gqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=4)
        gqa_static_cache = StaticCache(config=gqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
        cached_keys, cached_values = gqa_static_cache.update(
            *_random_kvs(gqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
        )
        self.assertTrue(cached_keys.shape == (1, 4, 10, 128))
        self.assertTrue(cached_values.shape == (1, 4, 10, 128))
...@@ -159,7 +159,7 @@ class CacheTest(unittest.TestCase):
        mqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=1)
        mqa_static_cache = StaticCache(config=mqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
        cached_keys, cached_values = mqa_static_cache.update(
            *_random_kvs(mqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
        )
        self.assertTrue(cached_keys.shape == (1, 1, 10, 128))
        self.assertTrue(cached_values.shape == (1, 1, 10, 128))
......