Compute KL-divergence between the label probabilities of the generated audio with respect to the original audio.
Both generated audio (in eval_path) and original audio (in ref_path) are represented by the same prompt/description.
Each audio is identified by an id: the file name, shared by both directories, that links the audio to its prompt/description.
Segmenting the audio:
For inputs longer than the 10 sec PaSST was trained on, we split the input into overlapping analysis windows.
We then aggregate/collect the resulting logits across windows via 'mean' (default) or 'max' pooling, and apply a softmax to the pooled logits.
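Illustration of the pooling step described above. This is a minimal sketch, not the code used by this script: pool_and_softmax is a hypothetical helper, and it assumes each analysis window has already produced one logits vector.

```python
import math

def pool_and_softmax(window_logits, collect="mean"):
    # window_logits: one logits vector (list of floats) per analysis window.
    if not window_logits:
        raise ValueError("need at least one window")
    # Regroup the values so that each tuple holds one class across all windows.
    per_class = list(zip(*window_logits))
    if collect == "mean":
        pooled = [sum(c) / len(c) for c in per_class]
    elif collect == "max":
        pooled = [max(c) for c in per_class]
    else:
        raise ValueError("collect must be 'mean' or 'max'")
    # Numerically stable softmax over the pooled logits.
    peak = max(pooled)
    exps = [math.exp(x - peak) for x in pooled]
    total = sum(exps)
    return [e / total for e in exps]

# Two windows, two classes (toy values).
probs_mean = pool_and_softmax([[2.0, 0.0], [0.0, 0.0]], collect="mean")
probs_max = pool_and_softmax([[2.0, 0.0], [0.0, 0.0]], collect="max")
```

With 'mean', a class that is active in only one window is down-weighted; with 'max', it keeps its strongest activation.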
This evaluation script assumes that ids are in both ref_path and eval_path.
We compute label probabilities with the PaSST model: https://github.com/kkoutini/PaSST
Computation runs on the GPU (the model is moved to CUDA).
Extracting the probabilities is time-consuming. Once computed, they are stored.
We store pre-computed reference probabilities in load/
To load those and save computation, just set the path in load_ref_probabilities.
If load_ref_probabilities is set, ref_path is not required.
Params:
-- ids: list of ids present in both eval_path and ref_path.
-- eval_path: path where the generated audio files to evaluate are available.
-- eval_files_extension: file extension (default .wav) in eval_path.
-- ref_path: path where the reference audio files are available. (instead of load_ref_probabilities)
-- ref_files_extension: file extension (default .wav) in ref_path.
-- load_ref_probabilities: path to the reference probabilities. (instead of ref_path)
-- no_ids: list of ids to ignore, since some reference audio may be corrupted or not present.
-- collect (default='mean'): for longer inputs, aggregate/collect via 'mean' or 'max' pooling along the logits vector.
Returns:
-- KL divergence
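Illustration of the returned quantity. This is a minimal sketch, not the code used by this script: kl_divergence is a hypothetical helper, the direction KL(p || q) and the eps smoothing are assumptions, and the per-audio values would still need to be combined over all ids.

```python
import math

def kl_divergence(p, q, eps=1e-6):
    # KL(p || q) = sum_i p_i * log(p_i / q_i); eps guards against log(0).
    if len(p) != len(q):
        raise ValueError("p and q must have the same length")
    return sum(
        pi * math.log((pi + eps) / (qi + eps))
        for pi, qi in zip(p, q)
        if pi > 0  # terms with p_i == 0 contribute nothing
    )

same = kl_divergence([0.5, 0.5], [0.5, 0.5])
different = kl_divergence([1.0, 0.0], [0.5, 0.5])
```

Identical distributions give zero; the more the two label distributions disagree, the larger the value.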
"""
    with open(os.devnull, 'w') as f, contextlib.redirect_stdout(f):  # capturing all useless outputs from passt
        # load model
        model = get_basic_model(mode="logits")
        model.eval()
        model = model.cuda()

    if not os.path.isdir(eval_path):
        if not os.path.isfile(eval_path):
            raise ValueError('eval_path does not exist')

    if load_ref_probabilities:
        if not os.path.exists(load_ref_probabilities):
            raise ValueError('load_ref_probabilities does not exist')