Unverified commit 9af1b6a8 authored by Raushan Turganbay, committed by GitHub

Musicgen special tokens in tensors (#31420)

fix: prepare the generation special tokens as tensors in the custom `generate()` overrides of Musicgen and Musicgen Melody, mirroring `GenerationMixin.generate()`
parent eed9ed67
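For readers skimming the diff below: `_prepare_special_tokens()` is the `GenerationMixin` step that turns the plain-int special token ids from the generation config into tensors placed on the generation device, and the Musicgen `generate()` overrides were not calling it. A minimal illustrative sketch of why that matters (the ids and variable names below are made up for illustration, not taken from this diff):

```python
# Illustrative only: why special token ids need to live in tensors on the generation device.
# The ids below are hypothetical; the real values come from the model's generation config.
import torch

eos_token_id = 2048                      # plain Python int in the generation config
next_tokens = torch.tensor([2048, 17])   # ids produced for two sequences at one decoding step

# Decoding-loop bookkeeping ("which sequences are finished?") is done with tensor ops,
# so the id is first wrapped in a tensor on the same device as the generated ids:
eos_token_tensor = torch.tensor(eos_token_id, device=next_tokens.device)
finished = torch.isin(next_tokens, eos_token_tensor)
print(finished)  # tensor([ True, False])
```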
@@ -1666,6 +1666,8 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
             inputs, generation_config.bos_token_id, model_kwargs
         )
         batch_size = input_ids.shape[0] // self.num_codebooks
+        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
+        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=input_ids.device)
 
         # 4. Define other model kwargs
         model_kwargs["use_cache"] = generation_config.use_cache
@@ -2738,6 +2740,8 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
             inputs, generation_config.bos_token_id, model_kwargs
         )
         batch_size = inputs_tensor.shape[0]
+        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
+        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=inputs_tensor.device)
 
         # 4. Define other model kwargs
         model_kwargs["use_cache"] = generation_config.use_cache
...
@@ -1587,6 +1587,8 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
             inputs, generation_config.bos_token_id, model_kwargs
         )
         batch_size = input_ids.shape[0] // self.num_codebooks
+        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
+        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=input_ids.device)
 
         # 4. Define other model kwargs
         model_kwargs["use_cache"] = generation_config.use_cache
@@ -2588,6 +2590,8 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
             inputs, generation_config.bos_token_id, model_kwargs
         )
         batch_size = inputs_tensor.shape[0]
+        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
+        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=inputs_tensor.device)
 
         # 4. Define other model kwargs
         model_kwargs["use_cache"] = generation_config.use_cache
...
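For completeness, a minimal usage sketch, not part of this diff; it assumes the `facebook/musicgen-small` checkpoint and a recent `transformers` install. With the change above, the special tokens are prepared as tensors on the inputs' device inside the custom `generate()` overrides, mirroring `GenerationMixin.generate()`:

```python
import torch
from transformers import AutoProcessor, MusicgenForConditionalGeneration

device = "cuda" if torch.cuda.is_available() else "cpu"

processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small").to(device)

inputs = processor(
    text=["80s pop track with bassy drums and synth"],
    padding=True,
    return_tensors="pt",
).to(device)

# generate() now runs self._prepare_special_tokens(...) with device=inputs_tensor.device
# before decoding, so the special token ids live on the same device as the audio codes.
audio_values = model.generate(**inputs, do_sample=True, max_new_tokens=256)
print(audio_values.shape)  # (batch_size, num_channels, sequence_length)
```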