Unverified commit 11b209e1, authored by Yoach Lacombe and committed by GitHub

Architecture improvements (#65)



* add RoPE

* don't include padding in rope

* possibly use cross-attn for prompt

* fix rope

* fix cross-attn

* fix self-attn

* fix dummy model

* clean-up rope

* first gqa implementation

* fix wer eval

* feat: add flash attention and sdpa

* chore: add README for flash attention

* chore: add benchmark script

* chore: add benchmark attention approach

* multi node and fix wer and fix compile

* Update modeling_parler_tts.py

* fix FA2, SDPA and add cross-attn MHA and attention type forcing

* better cross_attention key values number of heads default + add training arguments for attn implementation

* fix audio padding when torch compile or pad_to_max_length=True

* correct multi node

* make rope faster

* fix encoder sdpa

* fix training with cross attention + with FA2

* use fp32 as default model dtype + fix generation when using FA2 with autocast

* remove redundant passes in generate + clean and fix attentions

* fix edge case in WER evaluation with long-form generation

* better multi-node mapping and saving / add eval dataloader num workers

* remove old benchmarks

* faster audio encoding + checkpointing + fix generation step

* better eval + add right padding + fix eval loss compute

* correct README

* correct config docstrings

* remove comment

* make style

---------
Co-authored-by: sanchit-gandhi <sanchit@huggingface.co>
Co-authored-by: sang-nguyen-ts <sang.nguyen@trustingsocial.com>
Co-authored-by: yoach@huggingface.co <Yoach Lacombe>
parent 8b8c576e
......@@ -53,7 +53,8 @@ if torch.xpu.is_available():
device = "xpu"
torch_dtype = torch.float16 if device != "cpu" else torch.float32
model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler_tts_mini_v0.1").to(device, dtype=torch_dtype)
model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler_tts_mini_v0.1", torch_dtype=torch_dtype).to(device)
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler_tts_mini_v0.1")
prompt = "Hey, how are you doing today?"
......@@ -60,8 +60,8 @@ if __name__ == "__main__":
model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
model.generation_config.do_sample = True # True
model.generation_config.guidance_scale = 1 # 3.0
model.config.pad_token_id = encodec_vocab_size
model.config.decoder_start_token_id = encodec_vocab_size+1
model.config.decoder_start_token_id = encodec_vocab_size + 1
model.save_pretrained(os.path.join(args.save_directory, "tiny-model"))
......@@ -58,4 +58,7 @@ if __name__ == "__main__":
model.generation_config.do_sample = True # True
model.generation_config.guidance_scale = 1 # 3.0
model.config.pad_token_id = encodec_vocab_size
model.config.decoder_start_token_id = encodec_vocab_size + 1
model.save_pretrained(os.path.join(args.save_directory, "tiny-model"))
......@@ -60,8 +60,8 @@ if __name__ == "__main__":
model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
model.generation_config.do_sample = True # True
model.generation_config.guidance_scale = 1 # 3.0
model.config.pad_token_id = encodec_vocab_size
model.config.decoder_start_token_id = encodec_vocab_size+1
model.config.decoder_start_token_id = encodec_vocab_size + 1
model.save_pretrained(os.path.join(args.save_directory, "parler-tts-untrained-600M/"))
......@@ -2,6 +2,7 @@ import dac
from parler_tts import DACConfig, DACModel
from transformers import AutoConfig, AutoModel
from transformers import EncodecFeatureExtractor
AutoConfig.register("dac", DACConfig)
AutoModel.register(DACConfig, DACModel)
......@@ -47,6 +47,17 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
Number of decoder layers.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer block.
num_key_value_heads (`int`, *optional*):
The number of key/value heads used to implement Grouped Query Attention (GQA). If
`num_key_value_heads=num_attention_heads`, the model will use Multi-Head Attention (MHA); if
`num_key_value_heads=1`, the model will use Multi-Query Attention (MQA); otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group's key and value head should be constructed
by mean-pooling all the original heads within that group. For more details, check out [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If not specified, defaults to
`num_attention_heads`.
num_cross_attention_key_value_heads (`int`, *optional*):
The number of key/value heads used to implement Grouped Query Attention in the cross-attention layers.
If not specified, defaults to `num_key_value_heads`.
ffn_dim (`int`, *optional*, defaults to 4096):
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer block.
activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
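To make the grouped-query behavior described above concrete, here is a minimal sketch of the usual key/value-head expansion (a Llama-style `repeat_kv`); it illustrates the technique and is not necessarily the exact helper used in `modeling_parler_tts.py`:

```python
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_key_value_heads, seq_len, head_dim)
    # -> (batch, num_key_value_heads * n_rep, seq_len, head_dim),
    # so each key/value head is shared by a group of n_rep query heads.
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states  # MHA: nothing to expand
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

# 16 query heads sharing 4 key/value heads -> groups of 4 (GQA)
keys = torch.randn(2, 4, 10, 64)
print(repeat_kv(keys, n_rep=16 // 4).shape)  # torch.Size([2, 16, 10, 64])
```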
......@@ -74,6 +85,12 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
The number of parallel codebooks forwarded to the model.
tie_word_embeddings(`bool`, *optional*, defaults to `False`):
Whether input and output word embeddings should be tied.
rope_embeddings (`bool`, *optional*, defaults to `False`):
Whether to use RoPE or absolute positional embeddings.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
cross_attention_implementation_strategy (`str`, *optional*):
If not specified, the cross-attention implementation will be the same as `_attn_implementation`. If `always_eager`, the eager implementation will always be used; if `always_sdpa`, the SDPA implementation will always be used.
"""
model_type = "parler_tts_decoder"
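For reference, a minimal sketch of rotary position embeddings (RoPE) using the `rope_theta` base period; the actual implementation in `modeling_parler_tts.py` may differ in details, for instance in how padding positions are excluded (see the "don't include padding in rope" commit above):

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(q, k, positions, rope_theta=10_000.0):
    # q, k: (batch, num_heads, seq_len, head_dim); positions: (seq_len,)
    head_dim = q.shape[-1]
    inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    freqs = torch.outer(positions.float(), inv_freq)  # (seq_len, head_dim // 2)
    emb = torch.cat((freqs, freqs), dim=-1)           # (seq_len, head_dim)
    cos, sin = emb.cos(), emb.sin()
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

q = k = torch.randn(1, 16, 8, 64)
q_rot, k_rot = apply_rope(q, k, positions=torch.arange(8))
```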
......@@ -86,6 +103,8 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
num_hidden_layers=24,
ffn_dim=4096,
num_attention_heads=16,
num_key_value_heads=None,
num_cross_attention_key_value_heads=None,
layerdrop=0.0,
use_cache=True,
activation_function="gelu",
......@@ -100,6 +119,9 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
bos_token_id=2049,
eos_token_id=2048,
tie_word_embeddings=False,
rope_embeddings=False,
rope_theta=10_000.0,
cross_attention_implementation_strategy=None,
**kwargs,
):
self.vocab_size = vocab_size
......@@ -108,6 +130,12 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
self.ffn_dim = ffn_dim
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
if num_cross_attention_key_value_heads is None:
num_cross_attention_key_value_heads = num_key_value_heads
self.num_cross_attention_key_value_heads = num_cross_attention_key_value_heads
self.dropout = dropout
self.attention_dropout = attention_dropout
self.activation_dropout = activation_dropout
......@@ -117,6 +145,9 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
self.use_cache = use_cache
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
self.num_codebooks = num_codebooks
self.rope_embeddings = rope_embeddings
self.rope_theta = rope_theta
self.cross_attention_implementation_strategy = cross_attention_implementation_strategy
super().__init__(
pad_token_id=pad_token_id,
......@@ -140,6 +171,8 @@ class ParlerTTSConfig(PretrainedConfig):
vocab_size (`int`, *optional*, defaults to 1024):
Vocabulary size of the prompt token ids. Defines the number of different tokens that can be
represented by `prompt_input_ids`.
prompt_cross_attention (`bool`, *optional*, defaults to `False`):
Whether to use cross-attention conditioning for the prompt (as well as the description).
kwargs (*optional*):
Dictionary of keyword arguments. Notably:
......@@ -190,7 +223,7 @@ class ParlerTTSConfig(PretrainedConfig):
model_type = "parler_tts"
is_composition = True
def __init__(self, vocab_size=1024, **kwargs):
def __init__(self, vocab_size=1024, prompt_cross_attention=False, **kwargs):
super().__init__(**kwargs)
if "text_encoder" not in kwargs or "audio_encoder" not in kwargs or "decoder" not in kwargs:
raise ValueError("Config has to be initialized with text_encoder, audio_encoder and decoder config")
......@@ -204,6 +237,7 @@ class ParlerTTSConfig(PretrainedConfig):
decoder_config = kwargs.pop("decoder")
self.vocab_size = vocab_size
self.prompt_cross_attention = prompt_cross_attention
self.text_encoder = AutoConfig.for_model(text_encoder_model_type, **text_encoder_config)
self.audio_encoder = AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config)
self.decoder = ParlerTTSDecoderConfig(**decoder_config)
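A configuration sketch tying the new decoder fields together (assuming `ParlerTTSDecoderConfig` is exported from the `parler_tts` package like the other configs); note that `prompt_cross_attention` is a field of the top-level `ParlerTTSConfig`, not of the decoder config:

```python
from parler_tts import ParlerTTSDecoderConfig

decoder_config = ParlerTTSDecoderConfig(
    num_attention_heads=16,
    num_key_value_heads=4,                   # GQA self-attention (MHA if 16, MQA if 1)
    num_cross_attention_key_value_heads=16,  # keep full MHA in the cross-attention layers
    rope_embeddings=True,                    # RoPE instead of absolute positional embeddings
    rope_theta=10_000.0,
    cross_attention_implementation_strategy="always_sdpa",
)
```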
......@@ -236,3 +270,21 @@ class ParlerTTSConfig(PretrainedConfig):
# This is a property because you might want to change the codec model on the fly
def sampling_rate(self):
return self.audio_encoder.sampling_rate
# Copied from musicgen
@property
def _attn_implementation(self):
# This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.)
if hasattr(self, "_attn_implementation_internal"):
if self._attn_implementation_internal is None:
# `config.attn_implementation` should never be None, for backward compatibility.
return "eager"
else:
return self._attn_implementation_internal
else:
return "eager"
@_attn_implementation.setter
def _attn_implementation(self, value):
self._attn_implementation_internal = value
self.decoder._attn_implementation = value
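The setter above propagates the chosen implementation to the decoder config; a usage sketch, assuming a `transformers` version whose `from_pretrained` accepts `attn_implementation`:

```python
import torch
from parler_tts import ParlerTTSForConditionalGeneration

model = ParlerTTSForConditionalGeneration.from_pretrained(
    "parler-tts/parler_tts_mini_v0.1",
    torch_dtype=torch.float16,
    attn_implementation="sdpa",  # or "eager" / "flash_attention_2"
)
# Optionally force the cross-attention layers to a fixed implementation:
model.config.decoder.cross_attention_implementation_strategy = "always_eager"
```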
......@@ -78,6 +78,22 @@ class ModelArguments:
"help": "Used to compute audio similarity during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
},
)
attn_implementation: str = field(
default="eager",
metadata={"help": "Attention implementation used. One of `eager`, `sdpa`, `flash_attention_2`"},
)
cross_attention_implementation_strategy: str = field(
default=None,
metadata={
"help": "If not specified, the cross-attention implementation will be the same as `_attn_implementation`. If `always_eager`, it will always be the eager implementation. If `always_sdpa`, it will always be the sdpa implementation."
},
)
prompt_padding_side: Optional[str] = field(
default="left",
metadata={
"help": "Prompt tokenizer padding side. Defaults to `left`. If the prompt is pre-pended to the codebooks hidden states, it should be padded on the left."
},
)
@dataclass
......@@ -290,6 +306,10 @@ class DataTrainingArguments:
},
)
temporary_save_to_disk: str = field(default=None, metadata={"help": "Temporarily save audio labels here."})
save_codec_steps: Optional[int] = field(
default=500,
metadata={"help": "Temporarily save the audio labels every `save_steps`."},
)
pad_to_multiple_of: Optional[int] = field(
default=2,
metadata={"help": ("Pad to multiple of for tokenizers.")},
......@@ -311,3 +331,32 @@ class ParlerTTSTrainingArguments(Seq2SeqTrainingArguments):
default=8,
metadata={"help": ("Specify the batch size of the audio encoding pre-processing steps.")},
)
eval_dataloader_num_workers: Optional[int] = field(
default=0,
metadata={
"help": (
"Number of subprocesses to use for evaluation data loading (PyTorch only). 0 means that the data will be loaded in the main process."
)
},
)
compute_clap_similarity_metric: bool = field(
default=True,
metadata={
"help": (
"Whether or not to compute the clap similarity metric between the description and the generation during evalution."
)
},
)
compute_noise_level_metric: bool = field(
default=True,
metadata={"help": ("Whether or not to compute the squim si-sdr measure of the generations.")},
)
noise_level_to_compute_clean_wer: float = field(
default=25,
metadata={
"help": (
"if `compute_noise_level_metric=True`, will compute a 'clean' WER on samples with generated noise higher than `noise_level_to_compute_clean_wer`."
"This is a proxy measure to compute WER on clean audios, provided that the model learn to generate clean audios."
)
},
)
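The fields added to `ModelArguments`, `DataTrainingArguments`, and `ParlerTTSTrainingArguments` above are plain dataclass arguments; a sketch of how they would typically be consumed, assuming the usual Hugging Face training-script pattern of parsing the three dataclasses together:

```python
from transformers import HfArgumentParser

# Hypothetical parsing step; the actual training script may wire this up differently.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, ParlerTTSTrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

# e.g. launched with:
#   --attn_implementation sdpa \
#   --cross_attention_implementation_strategy always_sdpa \
#   --prompt_padding_side left \
#   --compute_noise_level_metric True \
#   --noise_level_to_compute_clean_wer 25
```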
......@@ -30,6 +30,8 @@ class DataCollatorEncodecWithPadding:
# different padding methods
audios = [feature[self.audio_column_name]["array"] for feature in features]
len_audio = [len(audio) for audio in audios]
if self.max_length is not None:
audios = [audio[: min(l, self.max_length)] for audio, l in zip(audios, len_audio)]
# since resampling has already been performed in the 'load_multiple_datasets' function,
# a fixed sampling_rate (44100 Hz) is passed to the feature_extractor.
......@@ -81,7 +83,9 @@ class DataCollatorParlerTTSWithPadding:
# (bsz, seq_len, num_codebooks)
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
if self.audio_max_length is not None and self.padding == "max_length":
labels = torch.nn.functional.pad(labels, pad=(0, 0, 0, max(self.audio_max_length - labels.shape[1], 0)))
labels = torch.nn.functional.pad(
labels, pad=(0, 0, 0, max(self.audio_max_length - labels.shape[1], 0)), value=-100
)
input_ids = [{"input_ids": feature["input_ids"]} for feature in features]
......@@ -95,11 +99,6 @@ class DataCollatorParlerTTSWithPadding:
batch = {"labels": labels, **input_ids}
if self.audio_max_length is not None and self.padding == "max_length":
# if we do torch.compile, we need to also specify the attention_mask
decoder_attention_mask = torch.ones(labels.shape[:2], dtype=input_ids["attention_mask"].dtype)
batch["decoder_attention_mask"] = decoder_attention_mask
prompt_input_ids = [{"input_ids": feature["prompt_input_ids"]} for feature in features]
prompt_input_ids = self.prompt_tokenizer.pad(
prompt_input_ids,
......@@ -206,7 +205,7 @@ def load_multiple_datasets(
all_datasets = []
# iterate over the datasets we want to interleave
for dataset_dict in tqdm(dataset_names_dict, desc="Combining datasets..."):
with accelerator.main_process_first():
with accelerator.local_main_process_first():
dataset = load_dataset(
dataset_dict["name"],
dataset_dict["config"],
......@@ -242,7 +241,7 @@ def load_multiple_datasets(
# metadata_dataset = metadata_dataset.map(concat_ids, input_columns=["book_id", "speaker_id", "begin_time"], num_proc=24)
# metadata_dataset = metadata_dataset.rename_column(id_column_name, f"metadata_{id_column_name}")
if dataset_dict["name"] != "parler-tts/mls_eng_10k":
if dataset_dict["name"] not in {"parler-tts/mls_eng_10k", "parler-tts/mls_eng"}:
if id_column_name is not None and id_column_name not in dataset.column_names:
raise ValueError(
f"id_column_name={id_column_name} but has not been found in the dataset columns"
......@@ -272,7 +271,10 @@ def load_multiple_datasets(
dataset = concatenate_datasets([dataset, metadata_dataset], axis=1)
if id_column_name is not None and dataset_dict["name"] != "parler-tts/mls_eng_10k":
if id_column_name is not None and dataset_dict["name"] not in {
"parler-tts/mls_eng_10k",
"parler-tts/mls_eng",
}:
if (
len(
dataset.filter(
......@@ -304,7 +306,7 @@ def load_multiple_datasets(
seed=seed,
)
else:
with accelerator.main_process_first():
with accelerator.local_main_process_first():
interleaved_dataset = concatenate_datasets(all_datasets)
return interleaved_dataset
import torch
from torchaudio.pipelines import SQUIM_OBJECTIVE
import torchaudio
import evaluate
from transformers import AutoModel, AutoProcessor, pipeline, WhisperForConditionalGeneration, WhisperTokenizer, WhisperTokenizerFast
from transformers import (
AutoModel,
AutoProcessor,
pipeline,
WhisperForConditionalGeneration,
WhisperTokenizer,
WhisperTokenizerFast,
)
from accelerate.utils.memory import release_memory
import numpy as np
def clap_similarity(clap_model_name_or_path, texts, audios, device):
def clap_similarity(clap_model_name_or_path, texts, audios, device, input_sampling_rate=44100):
clap = AutoModel.from_pretrained(clap_model_name_or_path)
clap_processor = AutoProcessor.from_pretrained(clap_model_name_or_path)
clap_inputs = clap_processor(text=texts, audios=audios, padding=True, return_tensors="pt").to(device)
output_sampling_rate = clap_processor.feature_extractor.sampling_rate
if input_sampling_rate != output_sampling_rate:
audios = [
torchaudio.functional.resample(torch.from_numpy(audio), input_sampling_rate, output_sampling_rate).numpy()
for audio in audios
]
clap_inputs = clap_processor(
text=texts, audios=audios, padding=True, return_tensors="pt", sampling_rate=output_sampling_rate
).to(device)
clap.to(device)
with torch.no_grad():
text_features = clap.get_text_features(
......@@ -14,16 +34,52 @@ def clap_similarity(clap_model_name_or_path, texts, audios, device):
)
audio_features = clap.get_audio_features(clap_inputs["input_features"])
cosine_sim = torch.nn.functional.cosine_similarity(audio_features, text_features, dim=1, eps=1e-8)
cosine_sim = torch.nn.functional.cosine_similarity(audio_features, text_features, dim=1, eps=1e-8).mean()
cosine_sim = cosine_sim.to("cpu")
clap.to("cpu")
clap_inputs.to("cpu")
return cosine_sim.mean().to("cpu")
clap, clap_inputs, audio_features, text_features = release_memory(clap, clap_inputs, audio_features, text_features)
return cosine_sim
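A usage sketch of the updated `clap_similarity` signature; the CLAP checkpoint and audio below are placeholders:

```python
import numpy as np

descriptions = ["A female speaker with a calm voice."]
generated_audios = [np.random.randn(44100).astype(np.float32)]  # 1 s of placeholder audio at 44.1 kHz

score = clap_similarity(
    "laion/clap-htsat-unfused",  # placeholder CLAP checkpoint
    texts=descriptions,
    audios=generated_audios,
    device="cpu",
    input_sampling_rate=44100,
)
```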
def si_sdr(audios, device, input_sampling_rate=44100):
max_audio_length = 15 * SQUIM_OBJECTIVE.sample_rate
model = SQUIM_OBJECTIVE.get_model().to((device))
output_sampling_rate = SQUIM_OBJECTIVE.sample_rate
if input_sampling_rate != output_sampling_rate:
audios = [
torchaudio.functional.resample(
torch.tensor(audio)[None, :].to(device).float(), input_sampling_rate, output_sampling_rate
)
for audio in audios
]
def apply_squim(waveform):
with torch.no_grad():
waveform = waveform[:, : min(max_audio_length, waveform.shape[1])]
_, _, sdr_sample = model(waveform)
sdr_sample = sdr_sample.cpu()[0]
return sdr_sample
def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_size, sampling_rate):
si_sdrs = [apply_squim(audio) for audio in audios]
audios, model = release_memory(audios, model)
return si_sdrs
def wer(
asr_model_name_or_path,
prompts,
audios,
device,
per_device_eval_batch_size,
sampling_rate,
noise_level_to_compute_clean_wer,
si_sdr_measures,
):
metric = evaluate.load("wer")
asr_pipeline = pipeline(model=asr_model_name_or_path, device=device)
asr_pipeline = pipeline(model=asr_model_name_or_path, device=device, chunk_length_s=25.0)
return_language = None
if isinstance(asr_pipeline.model, WhisperForConditionalGeneration):
......@@ -47,7 +103,11 @@ def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_s
normalized_references = []
for pred, ref in zip(transcriptions, prompts):
normalizer = english_normalizer if return_language and pred["chunks"][0]["language"] == "english" else basic_normalizer
normalizer = (
english_normalizer
if isinstance(pred.get("chunks", None), list) and pred["chunks"][0].get("language", None) == "english"
else basic_normalizer
)
norm_ref = normalizer(ref)
if len(norm_ref) > 0:
norm_pred = normalizer(pred["text"])
......@@ -56,4 +116,21 @@ def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_s
word_error = 100 * metric.compute(predictions=normalized_predictions, references=normalized_references)
return word_error, [t["text"] for t in transcriptions]
clean_word_error = None
noisy_word_error = None
percent_clean_samples = 0
if noise_level_to_compute_clean_wer and si_sdr_measures:
si_sdr_measures = np.array(si_sdr_measures)
mask = si_sdr_measures >= noise_level_to_compute_clean_wer
if mask.any():
clean_word_error = 100 * metric.compute(
predictions=np.array(normalized_predictions)[mask], references=np.array(normalized_references)[mask]
)
noisy_word_error = 100 * metric.compute(
predictions=np.array(normalized_predictions)[~mask], references=np.array(normalized_references)[~mask]
)
percent_clean_samples = mask.sum() / len(mask)
asr_pipeline.model.to("cpu")
asr_pipeline = release_memory(asr_pipeline)
return word_error, [t["text"] for t in transcriptions], clean_word_error, noisy_word_error, percent_clean_samples
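`wer` now returns a 5-tuple; a call sketch, with placeholder prompts, audio, and ASR checkpoint:

```python
import numpy as np

prompts = ["Hey, how are you doing today?"]                     # reference transcripts
generated_audios = [np.random.randn(44100).astype(np.float32)]  # placeholder generations at 44.1 kHz
si_sdr_measures = si_sdr(generated_audios, device="cpu", input_sampling_rate=44100)

word_error, transcriptions, clean_wer, noisy_wer, pct_clean = wer(
    asr_model_name_or_path="distil-whisper/distil-large-v2",  # placeholder ASR checkpoint
    prompts=prompts,
    audios=generated_audios,
    device="cpu",
    per_device_eval_batch_size=16,
    sampling_rate=44100,
    noise_level_to_compute_clean_wer=25,
    si_sdr_measures=si_sdr_measures,
)
```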
......@@ -7,6 +7,7 @@ from typing import Dict, List
import torch
from wandb import Audio
from datasets import load_from_disk, concatenate_datasets
def list_field(default=None, metadata=None):
......@@ -14,6 +15,8 @@ def list_field(default=None, metadata=None):
_RE_CHECKPOINT = re.compile(r"^checkpoint-(\d+)-epoch-(\d+)$")
CHECKPOINT_CODEC_PREFIX = "checkpoint"
_RE_CODEC_CHECKPOINT = re.compile(r"^checkpoint-(\d+)$")
def get_last_checkpoint(folder):
......@@ -60,6 +63,59 @@ def rotate_checkpoints(save_total_limit=None, output_dir=None, checkpoint_prefix
shutil.rmtree(checkpoint, ignore_errors=True)
def save_codec_checkpoint(output_dir, dataset, step):
checkpoint_path = f"{CHECKPOINT_CODEC_PREFIX}-{step}"
output_path = os.path.join(output_dir, checkpoint_path)
dataset.save_to_disk(output_path)
def load_codec_checkpoint(checkpoint_path):
dataset = load_from_disk(checkpoint_path)
return dataset
def sorted_codec_checkpoints(output_dir=None) -> List[str]:
"""Helper function to sort saved checkpoints from oldest to newest."""
ordering_and_checkpoint_path = []
glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{CHECKPOINT_CODEC_PREFIX}-*")]
for path in glob_checkpoints:
regex_match = re.match(f".*{CHECKPOINT_CODEC_PREFIX}-([0-9]+)", path)
if regex_match is not None and regex_match.groups() is not None:
ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
checkpoints_sorted = sorted(ordering_and_checkpoint_path)
checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
return checkpoints_sorted
def load_all_codec_checkpoints(output_dir=None) -> List[str]:
"""Helper function to load and concat all checkpoints."""
checkpoints_sorted = sorted_codec_checkpoints(output_dir=output_dir)
datasets = [load_from_disk(checkpoint) for checkpoint in checkpoints_sorted]
datasets = concatenate_datasets(datasets, axis=0)
return datasets
def get_last_codec_checkpoint_step(folder) -> int:
if not os.path.exists(folder) or not os.path.isdir(folder):
os.makedirs(folder, exist_ok=True)
return 0
content = os.listdir(folder)
checkpoints = [path for path in content if _RE_CODEC_CHECKPOINT.search(path) is not None]
if len(checkpoints) == 0:
return 0
last_checkpoint = os.path.join(
folder, max(checkpoints, key=lambda x: int(_RE_CODEC_CHECKPOINT.search(x).groups()[0]))
)
# Extract the saved step count from the checkpoint name pattern
pattern = r"checkpoint-(\d+)"
match = re.search(pattern, last_checkpoint)
cur_step = int(match.group(1))
return cur_step
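A sketch of the round trip these codec-checkpoint helpers implement; the directory and dataset below are placeholders:

```python
from datasets import Dataset

encoded_dataset = Dataset.from_dict({"labels": [[1, 2, 3]]})       # placeholder encoded audio labels
save_codec_checkpoint("tmp_codec_dir", encoded_dataset, step=500)  # writes tmp_codec_dir/checkpoint-500
print(get_last_codec_checkpoint_step("tmp_codec_dir"))             # 500
all_labels = load_all_codec_checkpoints("tmp_codec_dir")           # concatenation of every checkpoint-*
```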
def log_metric(
accelerator,
metrics: Dict,
......@@ -86,6 +142,7 @@ def log_pred(
pred_prompts: List[str],
transcriptions: List[str],
audios: List[torch.Tensor],
si_sdr_measures: List[float],
sampling_rate: int,
step: int,
prefix: str = "eval",
......@@ -98,16 +155,33 @@ def log_pred(
cur_step_pretty = f"{int(step // 1000)}k" if step > 1000 else step
prefix_pretty = prefix.replace("/", "-")
# convert str data to a wandb compatible format
str_data = [[pred_descriptions[i], pred_prompts[i], transcriptions[i]] for i in range(len(pred_descriptions))]
# log as a table with the appropriate headers
wandb_tracker.log_table(
table_name=f"predictions/{prefix_pretty}-step-{cur_step_pretty}",
columns=["Target descriptions", "Target prompts", "Predicted transcriptions"],
data=str_data[:num_lines],
step=step,
commit=False,
)
if si_sdr_measures is None:
# convert str data to a wandb compatible format
str_data = [
[pred_descriptions[i], pred_prompts[i], transcriptions[i]] for i in range(len(pred_descriptions))
]
# log as a table with the appropriate headers
wandb_tracker.log_table(
table_name=f"predictions/{prefix_pretty}-step-{cur_step_pretty}",
columns=["Target descriptions", "Target prompts", "Predicted transcriptions"],
data=str_data[:num_lines],
step=step,
commit=False,
)
else:
# convert str data to a wandb compatible format
str_data = [
[pred_descriptions[i], pred_prompts[i], transcriptions[i], si_sdr_measures[i]]
for i in range(len(pred_descriptions))
]
# log as a table with the appropriate headers
wandb_tracker.log_table(
table_name=f"predictions/{prefix_pretty}-step-{cur_step_pretty}",
columns=["Target descriptions", "Target prompts", "Predicted transcriptions", "Noise estimation"],
data=str_data[:num_lines],
step=step,
commit=False,
)
# wandb can only load 100 audios per step
wandb_tracker.log(