Commit 99f6e9ab authored by bghira

add mps and xpu to examples

parent 10016fb0
@@ -25,6 +25,12 @@ Parler-TTS has light-weight dependencies and can be installed in one line:
pip install git+https://github.com/huggingface/parler-tts.git
```
+Apple Silicon users will need to run a follow-up command to make use of the nightly PyTorch (2.4) build for bfloat16 support:
+```sh
+pip3 install --pre torch torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
+```
## Usage
> [!TIP]
@@ -38,9 +44,16 @@ from transformers import AutoTokenizer
import soundfile as sf
import torch
-device = "cuda:0" if torch.cuda.is_available() else "cpu"
+device = "cpu"
+if torch.cuda.is_available():
+    device = "cuda:0"
+if torch.backends.mps.is_available():
+    device = "mps"
+if torch.xpu.is_available():
+    device = "xpu"
+torch_dtype = torch.float16 if device != "cpu" else torch.float32
-model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler_tts_mini_v0.1").to(device)
+model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler_tts_mini_v0.1").to(device, dtype=torch_dtype)
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler_tts_mini_v0.1")
prompt = "Hey, how are you doing today?"
@@ -49,7 +62,7 @@ description = "A female speaker with a slightly low-pitched voice delivers her w
input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
-generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
+generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids).to(torch.float32)
audio_arr = generation.cpu().numpy().squeeze()
sf.write("parler_tts_out.wav", audio_arr, model.config.sampling_rate)
```
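For a quick check outside the diff, the new device-selection logic can be mirrored in a few lines. This is a minimal sketch, not part of the commit; it assumes PyTorch 2.4+ (where the in-tree `torch.xpu` module exists) and adds a `hasattr` guard, which is not in the original, for older builds:

```py
import torch

# Mirror of the README's new selection order: later checks win,
# so XPU is preferred over MPS, and MPS over CUDA.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda:0"
if torch.backends.mps.is_available():
    device = "mps"
if hasattr(torch, "xpu") and torch.xpu.is_available():  # guard is an addition for pre-2.4 builds
    device = "xpu"

# CPU stays in float32; accelerators use float16, matching the README change.
torch_dtype = torch.float16 if device != "cpu" else torch.float32
print(f"device={device}, dtype={torch_dtype}")
```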