"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "31c56f2e0b908a7f7f5669b3d535c65f156f6556"
Unverified Commit 914771cb authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

[TTA Pipeline] Fix MusicGen test (#26348)

* fix musicgen pipeline test

* fix wav2vec2 doctest

* revert wav2vec2
parent 368a58e6
...@@ -71,13 +71,13 @@ class TextToAudioPipeline(Pipeline): ...@@ -71,13 +71,13 @@ class TextToAudioPipeline(Pipeline):
if self.sampling_rate is None: if self.sampling_rate is None:
# get sampling_rate from config and generation config # get sampling_rate from config and generation config
config = self.model.config.to_dict() config = self.model.config
gen_config = self.model.__dict__.get("generation_config", None) gen_config = self.model.__dict__.get("generation_config", None)
if gen_config is not None: if gen_config is not None:
config.update(gen_config.to_dict()) config.update(gen_config.to_dict())
for sampling_rate_name in ["sample_rate", "sampling_rate"]: for sampling_rate_name in ["sample_rate", "sampling_rate"]:
sampling_rate = config.get(sampling_rate_name, None) sampling_rate = getattr(config, sampling_rate_name, None)
if sampling_rate is not None: if sampling_rate is not None:
self.sampling_rate = sampling_rate self.sampling_rate = sampling_rate
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment