Architecture improvements (#65)
* add RoPE (see the RoPE sketch below)
* don't include padding positions in RoPE
* optionally use cross-attention for the prompt
* fix RoPE
* fix cross-attention
* fix self-attention
* fix dummy model
* clean up RoPE
* first GQA implementation (see the GQA sketch below)
* fix WER eval
* feat: add Flash Attention and SDPA
* chore: add README for Flash Attention
* chore: add benchmark script
* chore: add benchmark for attention approaches
* multi-node support; fix WER eval and torch compile
* Update modeling_parler_tts.py
* fix FA2 and SDPA, add cross-attention MHA and attention type forcing (see the backend-selection sketch below)
* better default for the number of cross-attention key/value heads; add training arguments for the attention implementation
* fix audio padding when using torch compile or pad_to_max_length=True
* correct multi-node setup
* make RoPE faster
* fix encoder SDPA
* fix training with cross-attention and with FA2
* use fp32 as the default model dtype; fix generation when using FA2 with autocast
* remove redundant passes in generate; clean up and fix attentions
* fix edge case in WER evaluation with long-form generation
* better multi-node mapping and saving; add eval dataloader num workers
* remove old benchmarks
* faster audio encoding, checkpointing, and a fixed generation step
* better eval; add right padding; fix eval loss computation
* correct README
* correct config docstrings
* remove comment
* make style
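For context on the RoPE items above, here is a minimal sketch of rotary position embeddings in PyTorch, with position ids derived from the attention mask so that padding does not shift the rotary phase. All function and variable names here are illustrative, not the repo's actual API.

```python
import torch

def rope_freqs(head_dim: int, seq_len: int, base: float = 10000.0) -> torch.Tensor:
    # Per-position rotation angles for each channel pair,
    # shape (seq_len, head_dim // 2).
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(seq_len).float()
    return torch.outer(positions, inv_freq)

def position_ids_from_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    # Count only real tokens so padding is excluded from the position count;
    # index the table as freqs[position_ids] (add a head dim before apply_rope).
    return (attention_mask.cumsum(-1) - 1).clamp(min=0)

def apply_rope(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # x: (batch, num_heads, seq_len, head_dim). Rotate each channel pair by a
    # position-dependent angle so scores depend on relative offsets.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    cos, sin = freqs.cos(), freqs.sin()
    rotated_even = x1 * cos - x2 * sin
    rotated_odd = x1 * sin + x2 * cos
    return torch.stack((rotated_even, rotated_odd), dim=-1).flatten(-2)
```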
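The GQA item refers to grouped-query attention, where several query heads share one key/value head. A common way to implement it is to project to fewer KV heads and repeat them before the attention score computation; the sketch below follows that pattern and is an assumption about the approach, not the repo's exact code.

```python
import torch

def repeat_kv(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand (batch, num_kv_heads, seq_len, head_dim) so that each key/value
    # head is shared by n_rep query heads (GQA). n_rep == 1 is plain MHA
    # layout; num_kv_heads == 1 with n_rep == num_heads is MQA.
    if n_rep == 1:
        return hidden
    b, kv_heads, s, d = hidden.shape
    hidden = hidden[:, :, None, :, :].expand(b, kv_heads, n_rep, s, d)
    return hidden.reshape(b, kv_heads * n_rep, s, d)
```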
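The attention type forcing mentioned above follows the usual Hugging Face pattern of selecting the backend at load time. A minimal usage sketch, assuming the `parler_tts` package and a public checkpoint id (adjust both to your setup):

```python
import torch
from parler_tts import ParlerTTSForConditionalGeneration

model = ParlerTTSForConditionalGeneration.from_pretrained(
    "parler-tts/parler-tts-mini-v1",  # assumed checkpoint id
    attn_implementation="sdpa",       # or "eager" / "flash_attention_2"
    torch_dtype=torch.float16,        # Flash Attention 2 requires fp16/bf16
)
```

Note that this PR makes fp32 the default model dtype, so when using FA2 you need an explicit half-precision dtype (or autocast, whose interaction with FA2 generation is also fixed here).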
---------

Co-authored-by: sanchit-gandhi <sanchit@huggingface.co>
Co-authored-by: sang-nguyen-ts <sang.nguyen@trustingsocial.com>
Co-authored-by: Yoach Lacombe <yoach@huggingface.co>