Unverified Commit be2acc26 authored by Sanchit Gandhi, committed by GitHub

Merge pull request #43 from sanchit-gandhi/fix-generation

[generation] use private greedy/sampling methods
parents 83d4a719 553d18f1
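
This PR tracks an internal API change in transformers: in the 4.39/4.40 series this PR pins to, the decoding loops are exposed as private methods (`_greedy_search`, `_sample`) that read `pad_token_id`, `eos_token_id`, `output_scores`, and `return_dict_in_generate` from a `GenerationConfig` passed in whole, rather than as individual keyword arguments. The public `generate()` API is unchanged, as in the minimal sketch below (the checkpoint name and prompts are illustrative assumptions, not part of this diff):

```python
from transformers import AutoTokenizer
from parler_tts import ParlerTTSForConditionalGeneration

# Hypothetical checkpoint name, for illustration only.
repo_id = "parler-tts/parler_tts_mini_v0.1"
model = ParlerTTSForConditionalGeneration.from_pretrained(repo_id)
tokenizer = AutoTokenizer.from_pretrained(repo_id)

description = tokenizer("A calm female voice, studio-quality recording.", return_tensors="pt")
prompt = tokenizer("Hey, how are you doing today?", return_tensors="pt")

# do_sample=True routes through the private _sample helper;
# do_sample=False would route through _greedy_search instead.
audio_codes = model.generate(
    input_ids=description.input_ids,
    prompt_input_ids=prompt.input_ids,
    do_sample=True,
)
```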
@@ -1386,8 +1386,6 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
         batch_size = input_ids.shape[0] // self.num_codebooks
 
         # 4. Define other model kwargs
-        model_kwargs["output_attentions"] = generation_config.output_attentions
-        model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
         model_kwargs["use_cache"] = generation_config.use_cache
         model_kwargs["guidance_scale"] = generation_config.guidance_scale
 
@@ -1481,14 +1479,11 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
             )
 
             # 11. run greedy search
-            outputs = self.greedy_search(
+            outputs = self._greedy_search(
                 input_ids,
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,
-                pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
-                output_scores=generation_config.output_scores,
-                return_dict_in_generate=generation_config.return_dict_in_generate,
+                generation_config=generation_config,
                 synced_gpus=synced_gpus,
                 streamer=streamer,
                 **model_kwargs,
@@ -1506,15 +1501,12 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
             )
 
             # 12. run sample
-            outputs = self.sample(
+            outputs = self._sample(
                 input_ids,
                 logits_processor=logits_processor,
                 logits_warper=logits_warper,
                 stopping_criteria=stopping_criteria,
-                pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
-                output_scores=generation_config.output_scores,
-                return_dict_in_generate=generation_config.return_dict_in_generate,
+                generation_config=generation_config,
                 synced_gpus=synced_gpus,
                 streamer=streamer,
                 **model_kwargs,
@@ -2198,8 +2190,8 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
         self,
         inputs_tensor: torch.Tensor,
         model_kwargs,
-        model_input_name: Optional[str] = None,
-        guidance_scale: Optional[float] = None,
+        model_input_name: Optional[str],
+        generation_config: GenerationConfig,
     ) -> Dict[str, Any]:
         # 1. get text encoder
         encoder = self.get_text_encoder()
@@ -2221,6 +2213,9 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
         encoder_kwargs = {
             argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
         }
+        encoder_kwargs["output_attentions"] = generation_config.output_attentions
+        encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states
+        guidance_scale = generation_config.guidance_scale
 
         # 3. make sure that encoder returns `ModelOutput`
         model_input_name = model_input_name if model_input_name is not None else self.text_encoder.main_input_name
@@ -2452,8 +2447,6 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
         batch_size = inputs_tensor.shape[0]
 
         # 4. Define other model kwargs
-        model_kwargs["output_attentions"] = generation_config.output_attentions
-        model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
         model_kwargs["use_cache"] = generation_config.use_cache
         model_kwargs["guidance_scale"] = generation_config.guidance_scale
 
@@ -2467,10 +2460,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
         if "encoder_outputs" not in model_kwargs:
             # encoder_outputs are created and added to `model_kwargs`
             model_kwargs = self._prepare_text_encoder_kwargs_for_generation(
-                inputs_tensor,
-                model_kwargs,
-                model_input_name,
-                guidance_scale=generation_config.guidance_scale,
+                inputs_tensor, model_kwargs, model_input_name, generation_config,
             )
 
         if "prompt_hidden_states" not in model_kwargs and "prompt_input_ids" in model_kwargs:
@@ -2579,14 +2569,11 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
             )
 
             # 11. run greedy search
-            outputs = self.greedy_search(
+            outputs = self._greedy_search(
                 input_ids,
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,
-                pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
-                output_scores=generation_config.output_scores,
-                return_dict_in_generate=generation_config.return_dict_in_generate,
+                generation_config=generation_config,
                 synced_gpus=synced_gpus,
                 streamer=streamer,
                 **model_kwargs,
@@ -2605,15 +2592,12 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
             )
 
             # 12. run sample
-            outputs = self.sample(
+            outputs = self._sample(
                 input_ids,
                 logits_processor=logits_processor,
                 logits_warper=logits_warper,
                 stopping_criteria=stopping_criteria,
-                pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
-                output_scores=generation_config.output_scores,
-                return_dict_in_generate=generation_config.return_dict_in_generate,
+                generation_config=generation_config,
                 synced_gpus=synced_gpus,
                 streamer=streamer,
                 **model_kwargs,
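
Relocating `output_attentions`/`output_hidden_states` out of the decoder's `model_kwargs` and into `encoder_kwargs` applies the flags where they take effect, on the text encoder's forward pass, while the decoder-side flags now travel inside `generation_config`. A standalone sketch of the filter-then-override idiom used above (function name and argument shapes are illustrative, not the exact helper):

```python
import inspect
from typing import Any, Dict


def build_encoder_kwargs(encoder, model_kwargs: Dict[str, Any], generation_config) -> Dict[str, Any]:
    # Keep only the kwargs the encoder's forward() actually accepts.
    encoder_signature = set(inspect.signature(encoder.forward).parameters)
    encoder_kwargs = {k: v for k, v in model_kwargs.items() if k in encoder_signature}
    # Then apply the generation-time output flags directly to the encoder call.
    encoder_kwargs["output_attentions"] = generation_config.output_attentions
    encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states
    return encoder_kwargs
```

The remaining hunk tightens the transformers pin accordingly: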
@@ -17,7 +17,7 @@ import setuptools
 
 _deps = [
-    "transformers>=4.34.0",
+    "transformers>=4.39.0,<4.41.0",
     "torch",
     "sentencepiece",
     "descript-audio-codec",
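
The dependency pin gains an upper bound because `_greedy_search` and `_sample` are private helpers whose signatures may change between minor transformers releases. A hypothetical import-time guard (not part of this PR) could fail fast on an incompatible install:

```python
from packaging import version

import transformers

# Hypothetical guard: parler_tts calls private transformers helpers, so
# refuse to import against versions outside the tested window.
_tv = version.parse(transformers.__version__)
if not (version.parse("4.39.0") <= _tv < version.parse("4.41.0")):
    raise ImportError(
        f"transformers>=4.39.0,<4.41.0 is required, found {transformers.__version__}."
    )
```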