Unverified Commit b2b749d1 authored by Yoach Lacombe's avatar Yoach Lacombe Committed by GitHub
Browse files

Merge pull request #7 from bghira/documentation/mps-xpu-example

add mps and xpu to examples
parents 840156b2 99f6e9ab
...@@ -25,6 +25,12 @@ Parler-TTS has light-weight dependencies and can be installed in one line: ...@@ -25,6 +25,12 @@ Parler-TTS has light-weight dependencies and can be installed in one line:
pip install git+https://github.com/huggingface/parler-tts.git pip install git+https://github.com/huggingface/parler-tts.git
``` ```
Apple Silicon users will need to run a follow-up command to make use the nightly PyTorch (2.4) build for bfloat16 support:
```sh
pip3 install --pre torch torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
```
## Usage ## Usage
> [!TIP] > [!TIP]
...@@ -38,9 +44,16 @@ from transformers import AutoTokenizer ...@@ -38,9 +44,16 @@ from transformers import AutoTokenizer
import soundfile as sf import soundfile as sf
import torch import torch
device = "cuda:0" if torch.cuda.is_available() else "cpu" device = "cpu"
if torch.cuda.is_available():
device = "cuda:0"
if torch.backends.mps.is_available():
device = "mps"
if torch.xpu.is_available():
device = "xpu"
torch_dtype = torch.float16 if device != "cpu" else torch.float32
model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler_tts_mini_v0.1").to(device) model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler_tts_mini_v0.1").to(device, dtype=torch_dtype)
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler_tts_mini_v0.1") tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler_tts_mini_v0.1")
prompt = "Hey, how are you doing today?" prompt = "Hey, how are you doing today?"
...@@ -49,7 +62,7 @@ description = "A female speaker with a slightly low-pitched voice delivers her w ...@@ -49,7 +62,7 @@ description = "A female speaker with a slightly low-pitched voice delivers her w
input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device) input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device) prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids) generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids).to(torch.float32)
audio_arr = generation.cpu().numpy().squeeze() audio_arr = generation.cpu().numpy().squeeze()
sf.write("parler_tts_out.wav", audio_arr, model.config.sampling_rate) sf.write("parler_tts_out.wav", audio_arr, model.config.sampling_rate)
``` ```
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment