"include/ck/utility/functional2.hpp" did not exist on "0c05f4279f65f35185d2335da27118f327ebfc62"
Unverified Commit 862f8418 authored by eustlb, committed by GitHub

Add static cache (#89)



* add rope

* don't include padding in rope

* possibly use cross-attn for prompt

* fix rope

* fix cross-attn

* fix self-attn

* fix dummy model

* clean-up rope
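
The rope commits above concern rotary position embeddings (RoPE); a minimal sketch of the underlying transform, with illustrative function names rather than the project's actual API:

```python
import torch

def rotate_half(x):
    # split the last dimension into two halves and swap them with a sign flip
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(q, k, cos, sin):
    # rotate query/key vectors by position-dependent angles; with cos=1 and
    # sin=0 this reduces to the identity, a handy sanity check
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin
```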

* first gqa implementation
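
Grouped-query attention (GQA) shares a small number of key/value heads across many query heads; a minimal sketch of the usual `repeat_kv` expansion, illustrative rather than the exact implementation here:

```python
import torch

def repeat_kv(hidden_states, n_rep):
    # expand (batch, num_kv_heads, seq_len, head_dim) to
    # (batch, num_kv_heads * n_rep, seq_len, head_dim) via a broadcasted view
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, seq_len, head_dim
    )
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)
```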

* fix wer eval

* feat: add flash attention and SDPA
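
The SDPA path ultimately goes through `torch.nn.functional.scaled_dot_product_attention`, which dispatches to flash or memory-efficient kernels when available; a minimal usage sketch:

```python
import torch
import torch.nn.functional as F

# shapes: (batch, num_heads, seq_len, head_dim)
q = torch.randn(1, 4, 8, 16)
k = torch.randn(1, 4, 8, 16)
v = torch.randn(1, 4, 8, 16)

# is_causal=True applies the causal mask inside the kernel, avoiding an
# explicit materialized (seq_len, seq_len) mask
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```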

* chore: add README for flash attention

* chore: add benchmark script

* chore: add benchmark attention approach

* multi node and fix wer and fix compile

* Update modeling_parler_tts.py

* fix FA2, SDPA and add cross-attn MHA and attention type forcing

* better default for the number of cross-attention key/value heads + add training arguments for attn implementation

* fix audio padding when torch compile or pad_to_max_length=True

* correct multi node

* make rope faster

* fix encoder sdpa

* fix training with cross attention + with FA2

* use fp32 as default model dtype + fix generation when using FA2 with autocast

* remove redundant passes in generate + clean and fix attentions

* fix edge case in WER evaluation with long-form generation

* better multi-node mapping and saving / add eval dataloader num workers

* remove old benchmarks

* faster audio encoding + checkpointing + fix generation step

* unpin trfms

* remove CFG

* imports and constants
Co-Authored-By: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>

* attention modifications to handle static cache
Co-Authored-By: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>
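
For context on what "handle static cache" means in the attention layers: a static cache preallocates key/value buffers at a fixed maximum length and writes into slots indexed by `cache_position`, so tensor shapes stay constant across decoding steps (which is what `torch.compile` wants). A toy sketch, not the transformers `StaticCache` API:

```python
import torch

class TinyStaticCache:
    # preallocate fixed-size key/value buffers per layer
    def __init__(self, num_layers, batch, num_heads, max_len, head_dim):
        shape = (batch, num_heads, max_len, head_dim)
        self.keys = [torch.zeros(shape) for _ in range(num_layers)]
        self.values = [torch.zeros(shape) for _ in range(num_layers)]

    def update(self, layer, key_states, value_states, cache_position):
        # write new entries into fixed slots instead of concatenating,
        # so the returned tensors always have length max_len
        self.keys[layer][:, :, cache_position] = key_states
        self.values[layer][:, :, cache_position] = value_states
        return self.keys[layer], self.values[layer]
```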

* decoder layer modification to handle static cache
Co-Authored-By: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>

* ParlerTTSPreTrainedModel modifs to handle static cache
Co-Authored-By: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>

* ParlerTTSDecoder modifs to handle static cache
Co-Authored-By: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>

* ParlerTTSModel + ParlerTTSForCausalLM modifs
Co-Authored-By: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>

* ParlerTTSForConditionalGeneration modifs
Co-Authored-By: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>

* decoder_attention_mask for static cache
Co-Authored-By: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>
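
With a static cache, the decoder attention mask has to span the full preallocated cache length, not just the tokens seen so far; a hedged sketch of building such an additive causal mask (illustrative shapes and dtype handling):

```python
import torch

def static_causal_mask(seq_len, max_cache_len, dtype=torch.float32):
    # additive mask of shape (seq_len, max_cache_len): 0 where attention is
    # allowed, dtype-min where blocked (future positions and unwritten slots)
    min_value = torch.finfo(dtype).min
    mask = torch.full((seq_len, max_cache_len), min_value, dtype=dtype)
    return torch.triu(mask, diagonal=1)
```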

* create inputs_embeds early to have a good cache initialization
Co-Authored-By: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>

* _get_cache method
Co-Authored-By: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>

* init the cache
Co-Authored-By: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>

* ensure good device
Co-Authored-By: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>

* pin trfms version
Co-Authored-By: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>

* fix attention_mask FA2

* remove unnecessary method

* Update parler_tts/modeling_parler_tts.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update parler_tts/modeling_parler_tts.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* remove unnecessary imports

* replace the hardcoded cache_position with a more elegant approach
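
Replacing the hardcoded `cache_position` amounts to deriving it from how many tokens the cache has already seen; a minimal sketch of that bookkeeping:

```python
import torch

def next_cache_position(past_seen_tokens, seq_len):
    # absolute slot indices the current forward pass will write into;
    # with a static cache these index into the preallocated buffers
    return torch.arange(past_seen_tokens, past_seen_tokens + seq_len)
```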

* make style

* unpin transformers

* pin transformers

* pin torch

* refactor + unpin torch

* Update parler_tts/modeling_parler_tts.py
Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com>

* update training script to match 11b209e1



* Update parler_tts/modeling_parler_tts.py
Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com>

* ensure compatibility with trfms 4.43.3, changes taken from #31980 on trfms

* fix input_ids_length

* warn on full attention mask creation

* changes for training compatibility

---------
Co-authored-by: sanchit-gandhi <sanchit@huggingface.co>
Co-authored-by: Yoach Lacombe <yoach.lacombe@gmail.com>
Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com>
Co-authored-by: sang-nguyen-ts <sang.nguyen@trustingsocial.com>
Co-authored-by: Yoach Lacombe <yoach@huggingface.co>
Co-authored-by: sang-nguyen-ts <sang-nguyen-ts@users.noreply.github.com>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
parent 11b209e1
import gradio as gr
import torch
from transformers import AutoFeatureExtractor, AutoTokenizer, set_seed
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
device = "cuda:0" if torch.cuda.is_available() else "cpu"
@@ -57,7 +58,7 @@ css = """
background-color: #000000;
justify-content: center;
align-items: center;
border-radius: 9999px !important;
border-radius: 9999px !important;
width: 13rem;
margin-top: 10px;
margin-left: auto;
......
from parler_tts import ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, ParlerTTSDecoderConfig
from transformers import AutoConfig
import os
import argparse
import os
from transformers import AutoConfig
from parler_tts import ParlerTTSDecoderConfig, ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration
if __name__ == "__main__":
......
from parler_tts import ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, ParlerTTSDecoderConfig
from transformers import AutoConfig
import os
import argparse
import os
from transformers import AutoConfig
from parler_tts import ParlerTTSDecoderConfig, ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration
if __name__ == "__main__":
parser = argparse.ArgumentParser()
......
from parler_tts import ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, ParlerTTSDecoderConfig
from transformers import AutoConfig
import os
import argparse
import os
from transformers import AutoConfig
from parler_tts import ParlerTTSDecoderConfig, ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration
if __name__ == "__main__":
......
import dac
from transformers import AutoConfig, AutoModel, EncodecFeatureExtractor
from parler_tts import DACConfig, DACModel
from transformers import AutoConfig, AutoModel
from transformers import EncodecFeatureExtractor
......
from transformers import AutoFeatureExtractor, AutoTokenizer
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer, AutoFeatureExtractor
path = "TODO"
repo_id = "parler_tts_600M"
......
__version__ = "0.1"
from transformers import AutoConfig, AutoModel
from .configuration_parler_tts import ParlerTTSConfig, ParlerTTSDecoderConfig
from .dac_wrapper import DACConfig, DACModel
from .modeling_parler_tts import (
ParlerTTSForCausalLM,
ParlerTTSForConditionalGeneration,
@@ -9,8 +12,6 @@ from .modeling_parler_tts import (
build_delay_pattern_mask,
)
from .dac_wrapper import DACConfig, DACModel
from transformers import AutoConfig, AutoModel
AutoConfig.register("dac", DACConfig)
AutoModel.register(DACConfig, DACModel)
from transformers import PretrainedConfig
from typing import List
class DACConfig(PretrainedConfig):
......
import torch
from dac.model import DAC
from transformers import PreTrainedModel
from transformers.models.encodec.modeling_encodec import EncodecEncoderOutput, EncodecDecoderOutput
from .configuration_dac import DACConfig
from transformers.models.encodec.modeling_encodec import EncodecDecoderOutput, EncodecEncoderOutput
from dac.model import DAC
from .configuration_dac import DACConfig
# model doesn't support batching yet
@@ -134,4 +133,4 @@ class DACModel(PreTrainedModel):
return EncodecDecoderOutput(audio_values)
def forward(self, tensor):
raise ValueError(f"`DACModel.forward` not implemented yet")
raise ValueError("`DACModel.forward` not implemented yet")
@@ -13,11 +13,12 @@
# limitations under the License.
import os
import setuptools
_deps = [
"transformers>=4.39.0,<4.41.0",
"transformers>=4.43.0,<=4.43.3",
"torch",
"sentencepiece",
"descript-audio-codec",
@@ -60,7 +61,7 @@ setuptools.setup(
packages=setuptools.find_packages(),
install_requires=_deps,
extras_require={
"dev": [_extras_dev_deps],
"train": [_extras_training_deps],
"dev": _extras_dev_deps,
"train": _extras_training_deps,
},
)
import logging
from dataclasses import dataclass
from typing import Dict, List, Optional, Union, Set
from typing import Dict, List, Optional, Set, Union
import torch
import numpy as np
import datasets
from datasets import load_dataset, Dataset, IterableDataset, interleave_datasets, concatenate_datasets
from transformers import AutoFeatureExtractor, AutoTokenizer
from tqdm import tqdm
import numpy as np
import torch
from accelerate import Accelerator
from datasets import Dataset, IterableDataset, concatenate_datasets, interleave_datasets, load_dataset
from tqdm import tqdm
from transformers import AutoFeatureExtractor, AutoTokenizer
@dataclass
......
import os
import re
import shutil
from pathlib import Path
from dataclasses import field
from pathlib import Path
from typing import Dict, List
import torch
from datasets import concatenate_datasets, load_from_disk
from wandb import Audio
from datasets import load_from_disk, concatenate_datasets
......