"git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "568a6a55794a96aa284b2955c396eef7e2e0d7e5"
Unverified Commit fa070e9d authored by Chia-Yu Hung's avatar Chia-Yu Hung Committed by GitHub
Browse files

Update tangoflux.py

parent 5eefc23f
...@@ -26,9 +26,11 @@ class TangoFluxInference: ...@@ -26,9 +26,11 @@ class TangoFluxInference:
def __init__(self,name='declare-lab/TangoFlux',device="cuda"): def __init__(self,name='declare-lab/TangoFlux',device="cuda"):
self.vae = AutoencoderOobleck.from_pretrained("stabilityai/stable-audio-open-1.0",subfolder='vae') self.vae = AutoencoderOobleck()
paths = snapshot_download(repo_id=name) paths = snapshot_download(repo_id=name)
vae_weights = load_file("{}/vae.safetensors".format(paths))
self.vae.load_state_dict(vae_weights)
weights = load_file("{}/tangoflux.safetensors".format(paths)) weights = load_file("{}/tangoflux.safetensors".format(paths))
with open('{}/config.json'.format(paths),'r') as f: with open('{}/config.json'.format(paths),'r') as f:
...@@ -51,7 +53,7 @@ class TangoFluxInference: ...@@ -51,7 +53,7 @@ class TangoFluxInference:
wave = self.vae.decode(latents.transpose(2,1)).sample.cpu()[0] wave = self.vae.decode(latents.transpose(2,1)).sample.cpu()[0]
waveform_end = int(duration * self.vae.config.sampling_rate) waveform_end = int(duration * self.vae.config.sampling_rate)
wave = wave[:, :, :waveform_end] wave = wave[:, :waveform_end]
return wave return wave
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment