Unverified Commit 52e2c13d authored by Fakhir Ali's avatar Fakhir Ali Committed by GitHub
Browse files

[VITS] Fix speaker_embed device mismatch (#26115)

* [VITS] Fix speaker_embed device mismatch

- pass device arg to speaker_id tensor

* [VITS] put speaker_embed on device when int

* [VITS] device=self.device
instead of self.embed_speaker.weight.device

* [VITS] make tensor directly on device
using torch.full()
parent 098c3f40
...@@ -1435,7 +1435,9 @@ class VitsModel(VitsPreTrainedModel): ...@@ -1435,7 +1435,9 @@ class VitsModel(VitsPreTrainedModel):
if self.config.num_speakers > 1 and speaker_id is not None: if self.config.num_speakers > 1 and speaker_id is not None:
if not 0 <= speaker_id < self.config.num_speakers: if not 0 <= speaker_id < self.config.num_speakers:
raise ValueError(f"Set `speaker_id` in the range 0-{self.config.num_speakers - 1}.") raise ValueError(f"Set `speaker_id` in the range 0-{self.config.num_speakers - 1}.")
speaker_embeddings = self.embed_speaker(torch.tensor([speaker_id])).unsqueeze(-1) if isinstance(speaker_id, int):
speaker_id = torch.full(size=(1,), fill_value=speaker_id, device=self.device)
speaker_embeddings = self.embed_speaker(speaker_id).unsqueeze(-1)
else: else:
speaker_embeddings = None speaker_embeddings = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment