Unverified Commit 4a782f46 authored by Sanchit Gandhi, committed by GitHub

[AudioLDM2] Fix cache pos for GPT-2 generation (#8964)


Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent cdd12bde
@@ -286,6 +286,7 @@ class AudioLDM2Pipeline(DiffusionPipeline):
                 The sequence of generated hidden-states.
             """
         max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens
+        model_kwargs = self.language_model._get_initial_cache_position(inputs_embeds, model_kwargs)
         for _ in range(max_new_tokens):
             # prepare model inputs
             model_inputs = prepare_inputs_for_generation(inputs_embeds, **model_kwargs)
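For context, a minimal sketch of the cache-position bookkeeping that the added call performs before the GPT-2 loop. This is illustrative only, not the diffusers or transformers implementation; the helper names `seed_cache_position` and `advance_cache_position` and the tensor shapes are assumptions made for the example.

    import torch

    def seed_cache_position(inputs_embeds: torch.Tensor, model_kwargs: dict) -> dict:
        # On the first (prefill) step the cache position covers every prompt token,
        # i.e. [0, 1, ..., seq_len - 1] for inputs_embeds of shape (batch, seq_len, dim).
        seq_len = inputs_embeds.shape[1]
        model_kwargs["cache_position"] = torch.arange(seq_len, device=inputs_embeds.device)
        return model_kwargs

    def advance_cache_position(model_kwargs: dict) -> dict:
        # After each generated token only the newest position is needed, so the
        # tensor shrinks to a single entry: previous last position + 1.
        model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
        return model_kwargs

    # Usage: seed once before the loop (the step this commit adds), then advance per step.
    kwargs = seed_cache_position(torch.zeros(1, 8, 768), {})
    for _ in range(3):
        kwargs = advance_cache_position(kwargs)

Without the seeding step, recent language-model code paths that slice the key/value cache by an explicit cache position have nothing to index with on the first iteration of the manual generation loop, which is what the commit fixes.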