Unverified Commit 3d8bd119 authored by Joao Gante, committed by GitHub

Generate: fix end to end compilation (#32465)

parent 6a03942d
@@ -1024,19 +1024,22 @@ class StaticCache(Cache):
         # Note: There will be significant perf decrease if switching to use 5D tensors instead.
         cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
         for idx in range(config.num_hidden_layers):
-            # Note: `torch.export()` requires mutations to be registered as buffers.
-            self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
-            self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
-            key_cache = getattr(self, f"key_cache_{idx}")
-            value_cache = getattr(self, f"value_cache_{idx}")
-            # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
-            # breaks when updating the cache. It can't be used if the cache code is being compiled (but in that
-            # case it is not needed anyway).
-            if not is_torchdynamo_compiling():
-                torch._dynamo.mark_static_address(key_cache)
-                torch._dynamo.mark_static_address(value_cache)
-            self.key_cache.append(key_cache)
-            self.value_cache.append(value_cache)
+            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
+            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
+            # Notes:
+            # 1. `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
+            #    breaks when updating the cache. It can't be used if the cache code is being compiled (but in
+            #    that case it is not needed anyway).
+            # 2. `torch.export()` requires mutations to be registered as buffers.
+            if not is_torchdynamo_compiling():
+                self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
+                self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
+                new_layer_key_cache = getattr(self, f"key_cache_{idx}")
+                new_layer_value_cache = getattr(self, f"value_cache_{idx}")
+                torch._dynamo.mark_static_address(new_layer_key_cache)
+                torch._dynamo.mark_static_address(new_layer_value_cache)
+            self.key_cache.append(new_layer_key_cache)
+            self.value_cache.append(new_layer_value_cache)

     def update(
         self,
...
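The reordering above is the core of the fix: buffer registration and `mark_static_address` are skipped entirely when the cache is built inside a compiled region, which is what lets `generate` itself be compiled end to end. A minimal usage sketch, not part of this commit; the checkpoint name and compile mode are illustrative assumptions:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; any architecture that supports the static cache works.
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
model.generation_config.cache_implementation = "static"

# With cache init gated on is_torchdynamo_compiling(), the whole generation
# loop can be wrapped in torch.compile, not just model.forward.
compiled_generate = torch.compile(model.generate, mode="reduce-overhead")
inputs = tokenizer("Hello", return_tensors="pt")
output = compiled_generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(output[0], skip_special_tokens=True))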
@@ -1429,7 +1429,9 @@ class GenerationMixin:
         model_kwargs["cache_position"] = cache_position
         return model_kwargs

-    def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int, model_kwargs) -> Cache:
+    def _get_cache(
+        self, cache_implementation: str, max_batch_size: int, max_cache_len: int, device: torch.device, model_kwargs
+    ) -> Cache:
         """
         Sets a cache for `generate` that will persist across calls. A new cache will only be initialized if a
         new `generate` call requires a larger cache or uses a different batch size.
@@ -1477,7 +1479,7 @@
             "config": self.config,
             "max_batch_size": max_batch_size,
             "max_cache_len": max_cache_len,
-            "device": self.device,
+            "device": device,
             "dtype": cache_dtype,
         }
         self._cache = cache_cls(**cache_kwargs)
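With `device` threaded through explicitly, the cache lands on the device the caller chooses rather than on `self.device`, which can be misleading for models dispatched across several devices. A hedged sketch of building a `StaticCache` with the same kwargs this diff assembles; the tiny `LlamaConfig` is a stand-in for a real model config:

import torch
from transformers import LlamaConfig, StaticCache

# Stand-in config; a real call would pass model.config instead.
config = LlamaConfig(num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=4, hidden_size=64)
cache = StaticCache(
    config=config,
    max_batch_size=2,            # num_beams * num_return_sequences * batch_size
    max_cache_len=128,           # generation_config.max_length
    device=torch.device("cpu"),  # explicit, no longer inferred from self.device
    dtype=torch.float32,
)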
@@ -1813,12 +1815,11 @@
                 "issue: https://github.com/huggingface/transformers/issues/28981"
             )
             model_kwargs[cache_name] = self._get_cache(
-                generation_config.cache_implementation,
-                getattr(generation_config, "num_beams", 1)
-                * getattr(generation_config, "num_return_sequences", 1)
-                * batch_size,
-                generation_config.max_length,
-                model_kwargs,
+                cache_implementation=generation_config.cache_implementation,
+                max_batch_size=generation_config.num_beams * generation_config.num_return_sequences * batch_size,
+                max_cache_len=generation_config.max_length,
+                device=device,
+                model_kwargs=model_kwargs,
             )
         elif generation_config.cache_implementation == "quantized":
             if not self._supports_quantized_cache:
...
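The `getattr` fallbacks could be dropped because `GenerationConfig` always defines `num_beams` and `num_return_sequences` (both default to 1), so plain attribute access is safe. A small sketch of the resulting cache sizing, with illustrative values:

from transformers import GenerationConfig

generation_config = GenerationConfig(num_beams=4, num_return_sequences=2)
batch_size = 3  # rows in the input batch

# One static-cache row is needed per candidate sequence kept alive during decoding.
max_batch_size = generation_config.num_beams * generation_config.num_return_sequences * batch_size
assert max_batch_size == 24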