Commit 39ac40a9 authored by chenzk

v1.0
xlsx_sample_num: 5
dataset:
wenet-e2e/wenetspeech:
ratio: 1.0
data_paths:
- datasets/jsonl/wenet-e2e/wenetspeech/L_fixed.jsonl
- datasets/jsonl/wenet-e2e/wenetspeech/DEV_fixed.jsonl
Wenetspeech4TTS/Wenetspeech4TTS:
ratio: 1.0
data_paths:
- datasets/jsonl/Wenetspeech4TTS/WenetSpeech4TTS/Basic.jsonl
fixie-ai/librispeech_asr:
ratio: 1.0
data_paths:
- datasets/jsonl/fixie-ai/librispeech_asr/train.100.clean.jsonl
- datasets/jsonl/fixie-ai/librispeech_asr/train.360.clean.jsonl
- datasets/jsonl/fixie-ai/librispeech_asr/train.500.other.jsonl
mythicinfinity/libritts:
ratio: 1.0
data_paths:
- datasets/jsonl/mythicinfinity/libritts/train.clean.100.jsonl
- datasets/jsonl/mythicinfinity/libritts/train.clean.360.jsonl
- datasets/jsonl/mythicinfinity/libritts/train.other.500.jsonl
- datasets/jsonl/mythicinfinity/libritts_r/train.clean.100.jsonl
- datasets/jsonl/mythicinfinity/libritts_r/train.clean.360.jsonl
- datasets/jsonl/mythicinfinity/libritts_r/train.other.500.jsonl
parler-tts/mls_eng:
ratio: 1.0
data_paths:
#- datasets/jsonl/parler-tts/mls_eng_10k/train.jsonl
- datasets/jsonl/parler-tts/mls_eng/train.jsonl
mozilla-foundation/common_voice_17_0:
ratio: 1.0
data_paths:
- datasets/jsonl/mozilla-foundation/common_voice_17_0/en/train.jsonl
- datasets/jsonl/mozilla-foundation/common_voice_17_0/zh-CN/train.jsonl
MushanW/GLOBE_V2:
ratio: 1.0
data_paths:
- datasets/jsonl/MushanW/GLOBE_V2/train.jsonl
amphion/Emilia-Dataset:
ratio: 0.5
data_paths:
- datasets/jsonl/amphion/Emilia-Dataset/ZH_B000000_B000100.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/ZH_B000100_B000200.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/ZH_B000200_B000300.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/ZH_B000300_B000400.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/ZH_B000400_B000500.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/ZH_B000500_B000600.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/ZH_B000600_B000700.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/ZH_B000700_B000800.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/ZH_B000800_B000900.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/ZH_B000900_B001000.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B000000_B000100.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B000100_B000200.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B000200_B000300.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B000300_B000400.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B000400_B000500.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B000500_B000600.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B000600_B000700.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B000700_B000800.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B000800_B000900.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B000900_B001000.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B001000_B001100.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B001100_B001200.jsonl
amphion/Emilia-Dataset/speaker_prompt:
ratio: 0.5
data_paths:
- datasets/jsonl/amphion/Emilia-Dataset/ZH_B000000_B000100_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/ZH_B000100_B000200_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/ZH_B000200_B000300_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/ZH_B000300_B000400_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/ZH_B000400_B000500_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/ZH_B000500_B000600_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/ZH_B000600_B000700_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/ZH_B000700_B000800_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/ZH_B000800_B000900_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/ZH_B000900_B001000_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B000000_B000100_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B000100_B000200_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B000200_B000300_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B000300_B000400_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B000400_B000500_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B000500_B000600_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B000600_B000700_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B000700_B000800_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B000800_B000900_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B000900_B001000_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B001000_B001100_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B001100_B001200_speak_prompt.jsonl
openslr:
ratio: 1.0
data_paths:
- datasets/jsonl/openslr/SLR68/train.jsonl
- datasets/jsonl/openslr/SLR68/dev.jsonl
speechcolab/gigaspeech:
ratio: 1.0
data_paths:
- datasets/jsonl/speechcolab/gigaspeech/xl.jsonl
- datasets/jsonl/speechcolab/gigaspeech/dev.jsonl
MLCommons/peoples_speech:
ratio: 1.0
data_paths:
- datasets/jsonl/MLCommons/peoples_speech/clean.jsonl
- datasets/jsonl/MLCommons/peoples_speech/clean_sa.jsonl
- datasets/jsonl/MLCommons/peoples_speech/dirty.jsonl
- datasets/jsonl/MLCommons/peoples_speech/dirty_sa.jsonl
- datasets/jsonl/MLCommons/peoples_speech/validation.jsonl
facebook/voxpopuli:
ratio: 1.0
data_paths:
- datasets/jsonl/facebook/voxpopuli/en_train.jsonl
- datasets/jsonl/facebook/voxpopuli/en_accented_test.jsonl
shenyunhang:
ratio: 1.0
data_paths:
- datasets/jsonl/shenyunhang/AISHELL-1/train.jsonl
- datasets/jsonl/shenyunhang/AISHELL-1/dev.jsonl
- datasets/jsonl/shenyunhang/AISHELL-2/data.jsonl
- datasets/jsonl/shenyunhang/AISHELL-3/data.jsonl
- datasets/jsonl/shenyunhang/AISHELL-4/data.jsonl
gpt-omni/VoiceAssistant-400K:
ratio: 0.0
data_paths:
- datasets/jsonl/gpt-omni/VoiceAssistant-400K/data.jsonl
VITA-MLLM/AudioQA-1M:
ratio: 0.0
data_paths:
- datasets/jsonl/VITA-MLLM/AudioQA-1M/data.jsonl
BAAI/Infinity-Instruct:
ratio: 1.0
data_paths:
#- datasets/jsonl/BAAI/Infinity-Instruct/3M.jsonl
#- datasets/jsonl/BAAI/Infinity-Instruct/7M.jsonl
#- datasets/jsonl/BAAI/Infinity-Instruct/7M_domains.jsonl
- datasets/jsonl/BAAI/Infinity-Instruct/0625.jsonl
#- datasets/jsonl/BAAI/Infinity-Instruct/Gen.jsonl
OpenHermes:
ratio: 1.0
data_paths:
- datasets/jsonl/teknium/OpenHermes-2.5/openhermes2_5.jsonl
lima:
ratio: 1.0
data_paths:
- datasets/jsonl/GAIR/lima/train.jsonl
databricks-dolly-15k:
ratio: 1.0
data_paths:
- datasets/jsonl/databricks/databricks-dolly-15k/databricks-dolly-15k.jsonl
MetaMathQA:
ratio: 1.0
data_paths:
- datasets/jsonl/meta-math/MetaMathQA/MetaMathQA-395K.jsonl
MathInstruct:
ratio: 1.0
data_paths:
- datasets/jsonl/TIGER-Lab/MathInstruct/MathInstruct.jsonl
orca-math-word-problems-200k:
ratio: 1.0
data_paths:
- datasets/jsonl/microsoft/orca-math-word-problems-200k/data.jsonl
atlas-math-sets:
ratio: 1.0
num: 100000
data_paths:
- datasets/jsonl/AtlasUnified/atlas-math-sets/train.jsonl
goat:
ratio: 1.0
num: 30000
data_paths:
- datasets/jsonl/tiedong/goat/dataset.jsonl
camel-ai:
ratio: 1.0
data_paths:
- datasets/jsonl/camel-ai/math/math.jsonl
Long-Instruction-with-Paraphrasing:
ratio: 1.0
data_paths:
- datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/booksum_en.jsonl
- datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/multi_doc_qa_en_paraphrasing.jsonl
- datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/sharegpt_en.jsonl
- datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/short_instruction_from_alpaca_en.jsonl
- datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/single_doc_qa_en_paraphrasing.jsonl
- datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/translation_en2zh.jsonl
- datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/booksum_zh.jsonl
- datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/multi_doc_qa_zh_paraphrasing.jsonl
- datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/sharegpt_zh.jsonl
- datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/short_instruction_from_llama_chinese.jsonl
- datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/single_doc_qa_zh_paraphrasing.jsonl
Long:
ratio: 1.0
data_paths:
- datasets/jsonl/akoksal/LongForm/data.jsonl
- datasets/jsonl/THUDM/LongAlign-10k/long.jsonl
- datasets/jsonl/THUDM/LongCite-45k/long.jsonl
- datasets/jsonl/THUDM/LongWriter-6k/long.jsonl
- datasets/jsonl/YeungNLP/LongQLoRA-Dataset/LongQLoRA-SFT-Data-39k.jsonl
- datasets/jsonl/Yukang/LongAlpaca-12k/LongAlpaca-12k.jsonl
- datasets/jsonl/togethercomputer/Long-Data-Collections/natural_questions_10_200_docs.jsonl
- datasets/jsonl/togethercomputer/Long-Data-Collections/booksum.jsonl
- datasets/jsonl/KnutJaegersberg/longinstruct/longinstruct.jsonl
open-thoughts/OpenThoughts2-1M:
ratio: 0.0
num: 200000
data_paths:
- datasets/jsonl/open-thoughts/OpenThoughts2-1M/data.jsonl
nvidia/Llama-Nemotron-Post-Training-Dataset:
ratio: 0.0
num: 200000
data_paths:
- datasets/jsonl/nvidia/Llama-Nemotron-Post-Training-Dataset/SFT_chat.jsonl
- datasets/jsonl/nvidia/Llama-Nemotron-Post-Training-Dataset/SFT_code.jsonl
#- datasets/jsonl/nvidia/Llama-Nemotron-Post-Training-Dataset/SFT_math.jsonl
#- datasets/jsonl/nvidia/Llama-Nemotron-Post-Training-Dataset/SFT_safety.jsonl
- datasets/jsonl/nvidia/Llama-Nemotron-Post-Training-Dataset/SFT_science.jsonl
glaiveai/reasoning-v1-20m:
ratio: 0.0
num: 200000
data_paths:
- datasets/jsonl/glaiveai/reasoning-v1-20m/data.jsonl
nvidia/OpenCodeReasoning:
ratio: 0.0
num: 200000
data_paths:
- datasets/jsonl/nvidia/OpenCodeReasoning/split_0.jsonl
Congliu/Chinese-DeepSeek-R1-Distill-data-110k-SFT:
ratio: 0.0
num: 200000
data_paths:
- datasets/jsonl/Congliu/Chinese-DeepSeek-R1-Distill-data-110k-SFT/data.jsonl
open-r1/OpenR1-Math-220k:
ratio: 0.0
num: 200000
data_paths:
#- datasets/jsonl/open-r1/OpenR1-Math-220k/default.jsonl
- datasets/jsonl/open-r1/OpenR1-Math-220k/all.jsonl
#- datasets/jsonl/open-r1/OpenR1-Math-220k/extended.jsonl
xlsx_sample_num: 5
dataset:
wenet-e2e/wenetspeech:
ratio: 0.05
data_paths:
- datasets/jsonl/wenet-e2e/wenetspeech/L_fixed.jsonl
- datasets/jsonl/wenet-e2e/wenetspeech/DEV_fixed.jsonl
Wenetspeech4TTS/Wenetspeech4TTS:
ratio: 0.05
data_paths:
- datasets/jsonl/Wenetspeech4TTS/WenetSpeech4TTS/Basic.jsonl
fixie-ai/librispeech_asr:
ratio: 0.05
data_paths:
- datasets/jsonl/fixie-ai/librispeech_asr/train.100.clean.jsonl
- datasets/jsonl/fixie-ai/librispeech_asr/train.360.clean.jsonl
- datasets/jsonl/fixie-ai/librispeech_asr/train.500.other.jsonl
mythicinfinity/libritts:
ratio: 0.05
data_paths:
- datasets/jsonl/mythicinfinity/libritts/train.clean.100.jsonl
- datasets/jsonl/mythicinfinity/libritts/train.clean.360.jsonl
- datasets/jsonl/mythicinfinity/libritts/train.other.500.jsonl
- datasets/jsonl/mythicinfinity/libritts_r/train.clean.100.jsonl
- datasets/jsonl/mythicinfinity/libritts_r/train.clean.360.jsonl
- datasets/jsonl/mythicinfinity/libritts_r/train.other.500.jsonl
parler-tts/mls_eng:
ratio: 0.05
data_paths:
#- datasets/jsonl/parler-tts/mls_eng_10k/train.jsonl
- datasets/jsonl/parler-tts/mls_eng/train.jsonl
mozilla-foundation/common_voice_17_0:
ratio: 0.05
data_paths:
- datasets/jsonl/mozilla-foundation/common_voice_17_0/en/train.jsonl
- datasets/jsonl/mozilla-foundation/common_voice_17_0/zh-CN/train.jsonl
MushanW/GLOBE_V2:
ratio: 0.05
data_paths:
- datasets/jsonl/MushanW/GLOBE_V2/train.jsonl
amphion/Emilia-Dataset:
ratio: 0.025
data_paths:
- datasets/jsonl/amphion/Emilia-Dataset/ZH_B000000_B000100.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/ZH_B000100_B000200.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/ZH_B000200_B000300.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/ZH_B000300_B000400.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/ZH_B000400_B000500.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/ZH_B000500_B000600.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/ZH_B000600_B000700.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/ZH_B000700_B000800.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/ZH_B000800_B000900.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/ZH_B000900_B001000.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B000000_B000100.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B000100_B000200.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B000200_B000300.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B000300_B000400.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B000400_B000500.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B000500_B000600.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B000600_B000700.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B000700_B000800.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B000800_B000900.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B000900_B001000.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B001000_B001100.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B001100_B001200.jsonl
amphion/Emilia-Dataset/speaker_prompt:
ratio: 0.025
data_paths:
- datasets/jsonl/amphion/Emilia-Dataset/ZH_B000000_B000100_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/ZH_B000100_B000200_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/ZH_B000200_B000300_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/ZH_B000300_B000400_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/ZH_B000400_B000500_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/ZH_B000500_B000600_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/ZH_B000600_B000700_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/ZH_B000700_B000800_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/ZH_B000800_B000900_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/ZH_B000900_B001000_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B000000_B000100_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B000100_B000200_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B000200_B000300_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B000300_B000400_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B000400_B000500_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B000500_B000600_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B000600_B000700_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B000700_B000800_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B000800_B000900_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B000900_B001000_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B001000_B001100_speak_prompt.jsonl
- datasets/jsonl/amphion/Emilia-Dataset/EN_B001100_B001200_speak_prompt.jsonl
openslr:
ratio: 0.05
data_paths:
- datasets/jsonl/openslr/SLR68/train.jsonl
- datasets/jsonl/openslr/SLR68/dev.jsonl
speechcolab/gigaspeech:
ratio: 0.05
data_paths:
- datasets/jsonl/speechcolab/gigaspeech/xl.jsonl
- datasets/jsonl/speechcolab/gigaspeech/dev.jsonl
MLCommons/peoples_speech:
ratio: 0.05
data_paths:
- datasets/jsonl/MLCommons/peoples_speech/clean.jsonl
- datasets/jsonl/MLCommons/peoples_speech/clean_sa.jsonl
- datasets/jsonl/MLCommons/peoples_speech/dirty.jsonl
- datasets/jsonl/MLCommons/peoples_speech/dirty_sa.jsonl
- datasets/jsonl/MLCommons/peoples_speech/validation.jsonl
facebook/voxpopuli:
ratio: 0.05
data_paths:
- datasets/jsonl/facebook/voxpopuli/en_train.jsonl
- datasets/jsonl/facebook/voxpopuli/en_accented_test.jsonl
shenyunhang:
ratio: 0.05
data_paths:
- datasets/jsonl/shenyunhang/AISHELL-1/train.jsonl
- datasets/jsonl/shenyunhang/AISHELL-1/dev.jsonl
- datasets/jsonl/shenyunhang/AISHELL-2/data.jsonl
- datasets/jsonl/shenyunhang/AISHELL-3/data.jsonl
- datasets/jsonl/shenyunhang/AISHELL-4/data.jsonl
gpt-omni/VoiceAssistant-400K:
ratio: 2.0
data_paths:
- datasets/jsonl/gpt-omni/VoiceAssistant-400K/data.jsonl
VITA-MLLM/AudioQA-1M:
ratio: 2.0
data_paths:
- datasets/jsonl/VITA-MLLM/AudioQA-1M/data.jsonl
BAAI/Infinity-Instruct:
ratio: 0.05
data_paths:
#- datasets/jsonl/BAAI/Infinity-Instruct/3M.jsonl
#- datasets/jsonl/BAAI/Infinity-Instruct/7M.jsonl
#- datasets/jsonl/BAAI/Infinity-Instruct/7M_domains.jsonl
- datasets/jsonl/BAAI/Infinity-Instruct/0625.jsonl
#- datasets/jsonl/BAAI/Infinity-Instruct/Gen.jsonl
OpenHermes:
ratio: 0.05
data_paths:
- datasets/jsonl/teknium/OpenHermes-2.5/openhermes2_5.jsonl
lima:
ratio: 0.05
data_paths:
- datasets/jsonl/GAIR/lima/train.jsonl
databricks-dolly-15k:
ratio: 0.05
data_paths:
- datasets/jsonl/databricks/databricks-dolly-15k/databricks-dolly-15k.jsonl
MetaMathQA:
ratio: 0.05
data_paths:
- datasets/jsonl/meta-math/MetaMathQA/MetaMathQA-395K.jsonl
MathInstruct:
ratio: 0.05
data_paths:
- datasets/jsonl/TIGER-Lab/MathInstruct/MathInstruct.jsonl
orca-math-word-problems-200k:
ratio: 0.05
data_paths:
- datasets/jsonl/microsoft/orca-math-word-problems-200k/data.jsonl
atlas-math-sets:
ratio: 0.05
num: 100000
data_paths:
- datasets/jsonl/AtlasUnified/atlas-math-sets/train.jsonl
goat:
ratio: 0.05
num: 30000
data_paths:
- datasets/jsonl/tiedong/goat/dataset.jsonl
camel-ai:
ratio: 0.05
data_paths:
- datasets/jsonl/camel-ai/math/math.jsonl
Long-Instruction-with-Paraphrasing:
ratio: 0.05
data_paths:
- datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/booksum_en.jsonl
- datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/multi_doc_qa_en_paraphrasing.jsonl
- datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/sharegpt_en.jsonl
- datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/short_instruction_from_alpaca_en.jsonl
- datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/single_doc_qa_en_paraphrasing.jsonl
- datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/translation_en2zh.jsonl
- datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/booksum_zh.jsonl
- datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/multi_doc_qa_zh_paraphrasing.jsonl
- datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/sharegpt_zh.jsonl
- datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/short_instruction_from_llama_chinese.jsonl
- datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/single_doc_qa_zh_paraphrasing.jsonl
Long:
ratio: 0.05
data_paths:
- datasets/jsonl/akoksal/LongForm/data.jsonl
- datasets/jsonl/THUDM/LongAlign-10k/long.jsonl
- datasets/jsonl/THUDM/LongCite-45k/long.jsonl
- datasets/jsonl/THUDM/LongWriter-6k/long.jsonl
- datasets/jsonl/YeungNLP/LongQLoRA-Dataset/LongQLoRA-SFT-Data-39k.jsonl
- datasets/jsonl/Yukang/LongAlpaca-12k/LongAlpaca-12k.jsonl
- datasets/jsonl/togethercomputer/Long-Data-Collections/natural_questions_10_200_docs.jsonl
- datasets/jsonl/togethercomputer/Long-Data-Collections/booksum.jsonl
- datasets/jsonl/KnutJaegersberg/longinstruct/longinstruct.jsonl
open-thoughts/OpenThoughts2-1M:
ratio: 0.0
num: 10000
data_paths:
- datasets/jsonl/open-thoughts/OpenThoughts2-1M/data.jsonl
nvidia/Llama-Nemotron-Post-Training-Dataset:
ratio: 0.0
num: 10000
data_paths:
- datasets/jsonl/nvidia/Llama-Nemotron-Post-Training-Dataset/SFT_chat.jsonl
- datasets/jsonl/nvidia/Llama-Nemotron-Post-Training-Dataset/SFT_code.jsonl
#- datasets/jsonl/nvidia/Llama-Nemotron-Post-Training-Dataset/SFT_math.jsonl
#- datasets/jsonl/nvidia/Llama-Nemotron-Post-Training-Dataset/SFT_safety.jsonl
- datasets/jsonl/nvidia/Llama-Nemotron-Post-Training-Dataset/SFT_science.jsonl
glaiveai/reasoning-v1-20m:
ratio: 0.0
num: 10000
data_paths:
- datasets/jsonl/glaiveai/reasoning-v1-20m/data.jsonl
nvidia/OpenCodeReasoning:
ratio: 0.0
num: 10000
data_paths:
- datasets/jsonl/nvidia/OpenCodeReasoning/split_0.jsonl
Congliu/Chinese-DeepSeek-R1-Distill-data-110k-SFT:
ratio: 0.0
num: 10000
data_paths:
- datasets/jsonl/Congliu/Chinese-DeepSeek-R1-Distill-data-110k-SFT/data.jsonl
open-r1/OpenR1-Math-220k:
ratio: 0.0
num: 10000
data_paths:
#- datasets/jsonl/open-r1/OpenR1-Math-220k/default.jsonl
- datasets/jsonl/open-r1/OpenR1-Math-220k/all.jsonl
#- datasets/jsonl/open-r1/OpenR1-Math-220k/extended.jsonl
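The two YAML blocks above are data-mixture configs for two training stages: each dataset entry lists jsonl shards under data_paths, a sampling ratio, and an optional num cap. Below is a minimal sketch of how such a config might be consumed; the loader itself, and the reading of ratio as a sampling weight and num as a per-source cap, are assumptions for illustration rather than the training code in this commit.

# load_mixture_sketch.py -- illustrative loader for the YAML configs above;
# the interpretation of `ratio` (sampling weight) and `num` (per-source cap)
# is an assumption, not taken from the training code.
import json
import random

import yaml


def load_mixture(config_path):
    with open(config_path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    mixture = []
    for name, spec in cfg["dataset"].items():
        ratio = float(spec.get("ratio", 1.0))
        if ratio <= 0.0:
            continue  # ratio 0.0 disables the source
        samples = []
        for path in spec["data_paths"]:
            with open(path, "r", encoding="utf-8") as jf:
                samples.extend(json.loads(line) for line in jf)
        cap = spec.get("num")
        if cap is not None:
            samples = samples[: int(cap)]
        mixture.append((name, ratio, samples))
    return mixture


def sample_one(mixture, rng=random):
    # pick a source proportionally to its ratio, then one example from it;
    # ratios above 1.0 (e.g. 2.0 in the second config) simply upweight a source
    names, weights, pools = zip(*(m for m in mixture if m[2]))
    idx = rng.choices(range(len(pools)), weights=weights, k=1)[0]
    return names[idx], rng.choice(pools[idx])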
FROM image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.4.1-ubuntu22.04-dtk25.04-py3.10-fixpy
ENV DEBIAN_FRONTEND=noninteractive
# RUN yum update && yum install -y git cmake wget build-essential
# RUN source /opt/dtk-dtk25.04/env.sh
# Install pip dependencies
COPY requirements.txt requirements.txt
RUN pip3 install -r requirements.txt -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
docker run -it --shm-size=64G -v $PWD/VITA-Audio:/home/VITA-Audio -v /public/DL_DATA/AI:/home/AI -v /opt/hyhal:/opt/hyhal:ro --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name va 6063b673703a bash
# python -m torch.utils.collect_env
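Once the container above is running, a minimal sanity-check sketch (not specific to this repo) for verifying that PyTorch sees the accelerator could look like this:

# env_check.py -- illustrative sanity check inside the container
import torch

print("torch version:", torch.__version__)
print("accelerator visible:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("device 0:", torch.cuda.get_device_name(0))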
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import json
import re
import string
import sys
import unicodedata
from word2number import w2n
# def is_list_in_string(text, candidate):
# return any([all(xx in text for xx in x.split(" ")) if isinstance(x, str) else all([xx in text for xx in x]) for x in candidate])
def is_string_in_string(text, candidate):
return all(x in text for x in candidate.split(" "))
def is_list_in_string(text, candidate):
return any(
[
is_string_in_string(text, x) if isinstance(x, str) else is_list_in_string(text, x)
for x in candidate
]
)
def clean_punctuation(value):
punctuation = string.punctuation
punctuation = punctuation.replace("'", "")
value = re.sub(f"[{punctuation}]", " ", value)
return value
if __name__ == "__main__":
pred_gt_json_file = sys.argv[1]
with open(pred_gt_json_file, "r") as f:
pred_gt = json.load(f)
acc = 0
for line in pred_gt:
pred = line[0]
gt = line[1]
# pred = clean_punctuation(pred)
pred = pred.lower()
if isinstance(gt, list):
pass
else:
gt = [
gt,
]
gt = [clean_punctuation(x) for x in gt]
gt = [x.lower().strip() for x in gt]
try:
gt_number = [str(w2n.word_to_num(x.lower())) for x in gt]
        except Exception:
            # fall back to the raw references if word-to-number conversion fails
            gt_number = gt
if is_list_in_string(pred, gt):
acc += 1
elif is_list_in_string(pred, gt_number):
acc += 1
else:
print("======================================================")
print(f"{line[0]=}")
print(f"{line[1]=}")
print("======================================================")
print(f"{acc=}")
print(f"{len(pred_gt)=}")
print("======================================================")
acc = acc / len(pred_gt) * 100
print("======================================================")
print(f"{acc=}")
print("======================================================")
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import sys
import unicodedata
import codecs
remove_tag = True
spacelist = [' ', '\t', '\r', '\n']
puncts = [
'!', ',', '?', '、', '。', '!', ',', ';', '?', ':', '「', '」', '︰', '『', '』',
'《', '》'
]
def characterize(string):
res = []
i = 0
while i < len(string):
char = string[i]
if char in puncts:
i += 1
continue
cat1 = unicodedata.category(char)
# https://unicodebook.readthedocs.io/unicode.html#unicode-categories
if cat1 == 'Zs' or cat1 == 'Cn' or char in spacelist: # space or not assigned
i += 1
continue
if cat1 == 'Lo': # letter-other
res.append(char)
i += 1
else:
# some input looks like: <unk><noise>, we want to separate it to two words.
sep = ' '
if char == '<':
sep = '>'
j = i + 1
while j < len(string):
c = string[j]
if ord(c) >= 128 or (c in spacelist) or (c == sep):
break
j += 1
if j < len(string) and string[j] == '>':
j += 1
res.append(string[i:j])
i = j
return res
def stripoff_tags(x):
if not x:
return ''
chars = []
i = 0
T = len(x)
while i < T:
if x[i] == '<':
while i < T and x[i] != '>':
i += 1
i += 1
else:
chars.append(x[i])
i += 1
return ''.join(chars)
def normalize(sentence, ignore_words, cs, split=None):
""" sentence, ignore_words are both in unicode
"""
new_sentence = []
for token in sentence:
x = token
if not cs:
x = x.upper()
if x in ignore_words:
continue
if remove_tag:
x = stripoff_tags(x)
if not x:
continue
if split and x in split:
new_sentence += split[x]
        elif x.isalnum():
            # split alphanumeric tokens into individual characters
for k in x:
new_sentence.append(k)
else:
new_sentence.append(x)
return new_sentence
class Calculator:
def __init__(self):
self.data = {}
self.space = []
self.cost = {}
self.cost['cor'] = 0
self.cost['sub'] = 1
self.cost['del'] = 1
self.cost['ins'] = 1
def calculate(self, lab, rec):
# Initialization
lab.insert(0, '')
rec.insert(0, '')
while len(self.space) < len(lab):
self.space.append([])
for row in self.space:
for element in row:
element['dist'] = 0
element['error'] = 'non'
while len(row) < len(rec):
row.append({'dist': 0, 'error': 'non'})
for i in range(len(lab)):
self.space[i][0]['dist'] = i
self.space[i][0]['error'] = 'del'
for j in range(len(rec)):
self.space[0][j]['dist'] = j
self.space[0][j]['error'] = 'ins'
self.space[0][0]['error'] = 'non'
for token in lab:
if token not in self.data and len(token) > 0:
self.data[token] = {
'all': 0,
'cor': 0,
'sub': 0,
'ins': 0,
'del': 0
}
for token in rec:
if token not in self.data and len(token) > 0:
self.data[token] = {
'all': 0,
'cor': 0,
'sub': 0,
'ins': 0,
'del': 0
}
# Computing edit distance
for i, lab_token in enumerate(lab):
for j, rec_token in enumerate(rec):
if i == 0 or j == 0:
continue
min_dist = sys.maxsize
min_error = 'none'
dist = self.space[i - 1][j]['dist'] + self.cost['del']
error = 'del'
if dist < min_dist:
min_dist = dist
min_error = error
dist = self.space[i][j - 1]['dist'] + self.cost['ins']
error = 'ins'
if dist < min_dist:
min_dist = dist
min_error = error
if lab_token == rec_token:
dist = self.space[i - 1][j - 1]['dist'] + self.cost['cor']
error = 'cor'
else:
dist = self.space[i - 1][j - 1]['dist'] + self.cost['sub']
error = 'sub'
if dist < min_dist:
min_dist = dist
min_error = error
self.space[i][j]['dist'] = min_dist
self.space[i][j]['error'] = min_error
# Tracing back
result = {
'lab': [],
'rec': [],
'all': 0,
'cor': 0,
'sub': 0,
'ins': 0,
'del': 0
}
i = len(lab) - 1
j = len(rec) - 1
while True:
if self.space[i][j]['error'] == 'cor': # correct
if len(lab[i]) > 0:
self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1
result['all'] = result['all'] + 1
result['cor'] = result['cor'] + 1
result['lab'].insert(0, lab[i])
result['rec'].insert(0, rec[j])
i = i - 1
j = j - 1
elif self.space[i][j]['error'] == 'sub': # substitution
if len(lab[i]) > 0:
self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1
result['all'] = result['all'] + 1
result['sub'] = result['sub'] + 1
result['lab'].insert(0, lab[i])
result['rec'].insert(0, rec[j])
i = i - 1
j = j - 1
elif self.space[i][j]['error'] == 'del': # deletion
if len(lab[i]) > 0:
self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1
result['all'] = result['all'] + 1
result['del'] = result['del'] + 1
result['lab'].insert(0, lab[i])
result['rec'].insert(0, "")
i = i - 1
elif self.space[i][j]['error'] == 'ins': # insertion
if len(rec[j]) > 0:
self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1
result['ins'] = result['ins'] + 1
result['lab'].insert(0, "")
result['rec'].insert(0, rec[j])
j = j - 1
elif self.space[i][j]['error'] == 'non': # starting point
break
else: # shouldn't reach here
print('this should not happen , i={i} , j={j} , \
error={error}'.format(i=i,
j=j,
error=self.space[i][j]['error']))
return result
def overall(self):
result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
for token in self.data:
result['all'] = result['all'] + self.data[token]['all']
result['cor'] = result['cor'] + self.data[token]['cor']
result['sub'] = result['sub'] + self.data[token]['sub']
result['ins'] = result['ins'] + self.data[token]['ins']
result['del'] = result['del'] + self.data[token]['del']
return result
def cluster(self, data):
result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
for token in data:
if token in self.data:
result['all'] = result['all'] + self.data[token]['all']
result['cor'] = result['cor'] + self.data[token]['cor']
result['sub'] = result['sub'] + self.data[token]['sub']
result['ins'] = result['ins'] + self.data[token]['ins']
result['del'] = result['del'] + self.data[token]['del']
return result
def keys(self):
return list(self.data.keys())
def width(string):
return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string)
def default_cluster(word):
unicode_names = [unicodedata.name(char) for char in word]
for i in reversed(range(len(unicode_names))):
if unicode_names[i].startswith('DIGIT'): # 1
unicode_names[i] = 'Number' # 'DIGIT'
elif (unicode_names[i].startswith('CJK UNIFIED IDEOGRAPH')
or unicode_names[i].startswith('CJK COMPATIBILITY IDEOGRAPH')):
# 明 / 郎
unicode_names[i] = 'Mandarin' # 'CJK IDEOGRAPH'
elif (unicode_names[i].startswith('LATIN CAPITAL LETTER')
or unicode_names[i].startswith('LATIN SMALL LETTER')):
# A / a
unicode_names[i] = 'English' # 'LATIN LETTER'
elif unicode_names[i].startswith('HIRAGANA LETTER'): # は こ め
unicode_names[i] = 'Japanese' # 'GANA LETTER'
elif (unicode_names[i].startswith('AMPERSAND')
or unicode_names[i].startswith('APOSTROPHE')
or unicode_names[i].startswith('COMMERCIAL AT')
or unicode_names[i].startswith('DEGREE CELSIUS')
or unicode_names[i].startswith('EQUALS SIGN')
or unicode_names[i].startswith('FULL STOP')
or unicode_names[i].startswith('HYPHEN-MINUS')
or unicode_names[i].startswith('LOW LINE')
or unicode_names[i].startswith('NUMBER SIGN')
or unicode_names[i].startswith('PLUS SIGN')
or unicode_names[i].startswith('SEMICOLON')):
# & / ' / @ / ℃ / = / . / - / _ / # / + / ;
del unicode_names[i]
else:
return 'Other'
if len(unicode_names) == 0:
return 'Other'
if len(unicode_names) == 1:
return unicode_names[0]
for i in range(len(unicode_names) - 1):
if unicode_names[i] != unicode_names[i + 1]:
return 'Other'
return unicode_names[0]
def usage():
print("compute-wer.py : compute word error rate (WER) \
and align recognition results and references.")
print(" usage : python compute-wer.py [--cs={0,1}] \
[--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] \
[--padding-symbol={space,underline}] test.ref test.hyp > test.wer")
if __name__ == '__main__':
if len(sys.argv) == 1:
usage()
sys.exit(0)
calculator = Calculator()
cluster_file = ''
ignore_words = set()
tochar = False
verbose = 1
padding_symbol = ' '
case_sensitive = False
max_words_per_line = sys.maxsize
split = None
while len(sys.argv) > 3:
a = '--maxw='
if sys.argv[1].startswith(a):
b = sys.argv[1][len(a):]
del sys.argv[1]
max_words_per_line = int(b)
continue
a = '--rt='
if sys.argv[1].startswith(a):
b = sys.argv[1][len(a):].lower()
del sys.argv[1]
remove_tag = (b == 'true') or (b != '0')
continue
a = '--cs='
if sys.argv[1].startswith(a):
b = sys.argv[1][len(a):].lower()
del sys.argv[1]
case_sensitive = (b == 'true') or (b != '0')
continue
a = '--cluster='
if sys.argv[1].startswith(a):
cluster_file = sys.argv[1][len(a):]
del sys.argv[1]
continue
a = '--splitfile='
if sys.argv[1].startswith(a):
split_file = sys.argv[1][len(a):]
del sys.argv[1]
split = dict()
with codecs.open(split_file, 'r', 'utf-8') as fh:
for line in fh: # line in unicode
words = line.strip().split()
if len(words) >= 2:
split[words[0]] = words[1:]
continue
a = '--ig='
if sys.argv[1].startswith(a):
ignore_file = sys.argv[1][len(a):]
del sys.argv[1]
with codecs.open(ignore_file, 'r', 'utf-8') as fh:
for line in fh: # line in unicode
line = line.strip()
if len(line) > 0:
ignore_words.add(line)
continue
a = '--char='
if sys.argv[1].startswith(a):
b = sys.argv[1][len(a):].lower()
del sys.argv[1]
tochar = (b == 'true') or (b != '0')
continue
a = '--v='
if sys.argv[1].startswith(a):
b = sys.argv[1][len(a):].lower()
del sys.argv[1]
verbose = 0
try:
verbose = int(b)
except Exception:
if b == 'true' or b != '0':
verbose = 1
continue
a = '--padding-symbol='
if sys.argv[1].startswith(a):
b = sys.argv[1][len(a):].lower()
del sys.argv[1]
if b == 'space':
padding_symbol = ' '
elif b == 'underline':
padding_symbol = '_'
continue
if True or sys.argv[1].startswith('-'):
# ignore invalid switch
del sys.argv[1]
continue
if not case_sensitive:
ig = set([w.upper() for w in ignore_words])
ignore_words = ig
default_clusters = {}
default_words = {}
ref_file = sys.argv[1]
hyp_file = sys.argv[2]
rec_set = {}
if split and not case_sensitive:
newsplit = dict()
for w in split:
words = split[w]
for i in range(len(words)):
words[i] = words[i].upper()
newsplit[w.upper()] = words
split = newsplit
with codecs.open(hyp_file, 'r', 'utf-8') as fh:
for line in fh:
if tochar:
array = characterize(line)
else:
array = line.strip().split()
if len(array) == 0:
continue
fid = array[0]
rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive,
split)
    # compute error rate on the intersection of reference file and hyp file
for line in open(ref_file, 'r', encoding='utf-8'):
if tochar:
array = characterize(line)
else:
array = line.rstrip('\n').split()
if len(array) == 0:
continue
fid = array[0]
if fid not in rec_set:
continue
lab = normalize(array[1:], ignore_words, case_sensitive, split)
rec = rec_set[fid]
if verbose:
print('\nutt: %s' % fid)
for word in rec + lab:
if word not in default_words:
default_cluster_name = default_cluster(word)
if default_cluster_name not in default_clusters:
default_clusters[default_cluster_name] = {}
if word not in default_clusters[default_cluster_name]:
default_clusters[default_cluster_name][word] = 1
default_words[word] = default_cluster_name
result = calculator.calculate(lab, rec)
if verbose:
if result['all'] != 0:
wer = float(result['ins'] + result['sub'] +
result['del']) * 100.0 / result['all']
else:
wer = 0.0
print('WER: %4.2f %%' % wer, end=' ')
print('N=%d C=%d S=%d D=%d I=%d' %
(result['all'], result['cor'], result['sub'], result['del'],
result['ins']))
space = {}
space['lab'] = []
space['rec'] = []
for idx in range(len(result['lab'])):
len_lab = width(result['lab'][idx])
len_rec = width(result['rec'][idx])
length = max(len_lab, len_rec)
space['lab'].append(length - len_lab)
space['rec'].append(length - len_rec)
upper_lab = len(result['lab'])
upper_rec = len(result['rec'])
lab1, rec1 = 0, 0
while lab1 < upper_lab or rec1 < upper_rec:
if verbose > 1:
print('lab(%s):' % fid.encode('utf-8'), end=' ')
else:
print('lab:', end=' ')
lab2 = min(upper_lab, lab1 + max_words_per_line)
for idx in range(lab1, lab2):
token = result['lab'][idx]
print('{token}'.format(token=token), end='')
for n in range(space['lab'][idx]):
print(padding_symbol, end='')
print(' ', end='')
print()
if verbose > 1:
print('rec(%s):' % fid.encode('utf-8'), end=' ')
else:
print('rec:', end=' ')
rec2 = min(upper_rec, rec1 + max_words_per_line)
for idx in range(rec1, rec2):
token = result['rec'][idx]
print('{token}'.format(token=token), end='')
for n in range(space['rec'][idx]):
print(padding_symbol, end='')
print(' ', end='')
print('\n', end='\n')
lab1 = lab2
rec1 = rec2
if verbose:
print('==================================================='
'========================')
print()
result = calculator.overall()
if result['all'] != 0:
wer = float(result['ins'] + result['sub'] +
result['del']) * 100.0 / result['all']
else:
wer = 0.0
print('Overall -> %4.2f %%' % wer, end=' ')
print('N=%d C=%d S=%d D=%d I=%d' %
(result['all'], result['cor'], result['sub'], result['del'],
result['ins']))
if not verbose:
print()
if verbose:
for cluster_id in default_clusters:
result = calculator.cluster(k
for k in default_clusters[cluster_id])
if result['all'] != 0:
wer = float(result['ins'] + result['sub'] +
result['del']) * 100.0 / result['all']
else:
wer = 0.0
print('%s -> %4.2f %%' % (cluster_id, wer), end=' ')
print('N=%d C=%d S=%d D=%d I=%d' %
(result['all'], result['cor'], result['sub'], result['del'],
result['ins']))
if len(cluster_file) > 0: # compute separated WERs for word clusters
cluster_id = ''
cluster = []
for line in open(cluster_file, 'r', encoding='utf-8'):
                # `line` is already a str when opened with encoding='utf-8'
                for token in line.rstrip('\n').split():
# end of cluster reached, like </Keyword>
if token[0:2] == '</' and token[len(token) - 1] == '>' and \
token.lstrip('</').rstrip('>') == cluster_id :
result = calculator.cluster(cluster)
if result['all'] != 0:
wer = float(result['ins'] + result['sub'] +
result['del']) * 100.0 / result['all']
else:
wer = 0.0
print('%s -> %4.2f %%' % (cluster_id, wer), end=' ')
print('N=%d C=%d S=%d D=%d I=%d' %
(result['all'], result['cor'], result['sub'],
result['del'], result['ins']))
cluster_id = ''
cluster = []
# begin of cluster reached, like <Keyword>
elif (token[0] == '<' and token[len(token) - 1] == '>'
and cluster_id == ''):
cluster_id = token.lstrip('<').rstrip('>')
cluster = []
# general terms, like WEATHER / CAR / ...
else:
cluster.append(token)
print()
print('======================================='
'====================================')
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import re, sys, unicodedata
import codecs
remove_tag = True
spacelist = [' ', '\t', '\r', '\n']
puncts = [
'!', ',', '?', '、', '。', '!', ',', ';', '?', ':', '「', '」', '︰', '『', '』',
'《', '》'
]
def characterize(string):
res = []
i = 0
while i < len(string):
char = string[i]
if char in puncts:
i += 1
continue
cat1 = unicodedata.category(char)
#https://unicodebook.readthedocs.io/unicode.html#unicode-categories
if cat1 == 'Zs' or cat1 == 'Cn' or char in spacelist: # space or not assigned
i += 1
continue
if cat1 == 'Lo': # letter-other
res.append(char)
i += 1
else:
# some input looks like: <unk><noise>, we want to separate it to two words.
sep = ' '
if char == '<': sep = '>'
j = i + 1
while j < len(string):
c = string[j]
if ord(c) >= 128 or (c in spacelist) or (c == sep):
break
j += 1
if j < len(string) and string[j] == '>':
j += 1
res.append(string[i:j])
i = j
return res
def stripoff_tags(x):
if not x: return ''
chars = []
i = 0
T = len(x)
while i < T:
if x[i] == '<':
while i < T and x[i] != '>':
i += 1
i += 1
else:
chars.append(x[i])
i += 1
return ''.join(chars)
def normalize(sentence, ignore_words, cs, split=None):
""" sentence, ignore_words are both in unicode
"""
new_sentence = []
for token in sentence:
x = token
if not cs:
x = x.upper()
if x in ignore_words:
continue
if remove_tag:
x = stripoff_tags(x)
if not x:
continue
if split and x in split:
new_sentence += split[x]
else:
new_sentence.append(x)
return new_sentence
class Calculator:
def __init__(self):
self.data = {}
self.space = []
self.cost = {}
self.cost['cor'] = 0
self.cost['sub'] = 1
self.cost['del'] = 1
self.cost['ins'] = 1
def calculate(self, lab, rec):
# Initialization
lab.insert(0, '')
rec.insert(0, '')
while len(self.space) < len(lab):
self.space.append([])
for row in self.space:
for element in row:
element['dist'] = 0
element['error'] = 'non'
while len(row) < len(rec):
row.append({'dist': 0, 'error': 'non'})
for i in range(len(lab)):
self.space[i][0]['dist'] = i
self.space[i][0]['error'] = 'del'
for j in range(len(rec)):
self.space[0][j]['dist'] = j
self.space[0][j]['error'] = 'ins'
self.space[0][0]['error'] = 'non'
for token in lab:
if token not in self.data and len(token) > 0:
self.data[token] = {
'all': 0,
'cor': 0,
'sub': 0,
'ins': 0,
'del': 0
}
for token in rec:
if token not in self.data and len(token) > 0:
self.data[token] = {
'all': 0,
'cor': 0,
'sub': 0,
'ins': 0,
'del': 0
}
# Computing edit distance
for i, lab_token in enumerate(lab):
for j, rec_token in enumerate(rec):
if i == 0 or j == 0:
continue
min_dist = sys.maxsize
min_error = 'none'
dist = self.space[i - 1][j]['dist'] + self.cost['del']
error = 'del'
if dist < min_dist:
min_dist = dist
min_error = error
dist = self.space[i][j - 1]['dist'] + self.cost['ins']
error = 'ins'
if dist < min_dist:
min_dist = dist
min_error = error
if lab_token == rec_token:
dist = self.space[i - 1][j - 1]['dist'] + self.cost['cor']
error = 'cor'
else:
dist = self.space[i - 1][j - 1]['dist'] + self.cost['sub']
error = 'sub'
if dist < min_dist:
min_dist = dist
min_error = error
self.space[i][j]['dist'] = min_dist
self.space[i][j]['error'] = min_error
# Tracing back
result = {
'lab': [],
'rec': [],
'all': 0,
'cor': 0,
'sub': 0,
'ins': 0,
'del': 0
}
i = len(lab) - 1
j = len(rec) - 1
while True:
if self.space[i][j]['error'] == 'cor': # correct
if len(lab[i]) > 0:
self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1
result['all'] = result['all'] + 1
result['cor'] = result['cor'] + 1
result['lab'].insert(0, lab[i])
result['rec'].insert(0, rec[j])
i = i - 1
j = j - 1
elif self.space[i][j]['error'] == 'sub': # substitution
if len(lab[i]) > 0:
self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1
result['all'] = result['all'] + 1
result['sub'] = result['sub'] + 1
result['lab'].insert(0, lab[i])
result['rec'].insert(0, rec[j])
i = i - 1
j = j - 1
elif self.space[i][j]['error'] == 'del': # deletion
if len(lab[i]) > 0:
self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1
result['all'] = result['all'] + 1
result['del'] = result['del'] + 1
result['lab'].insert(0, lab[i])
result['rec'].insert(0, "")
i = i - 1
elif self.space[i][j]['error'] == 'ins': # insertion
if len(rec[j]) > 0:
self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1
result['ins'] = result['ins'] + 1
result['lab'].insert(0, "")
result['rec'].insert(0, rec[j])
j = j - 1
elif self.space[i][j]['error'] == 'non': # starting point
break
else: # shouldn't reach here
print(
'this should not happen , i = {i} , j = {j} , error = {error}'
.format(i=i, j=j, error=self.space[i][j]['error']))
return result
def overall(self):
result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
for token in self.data:
result['all'] = result['all'] + self.data[token]['all']
result['cor'] = result['cor'] + self.data[token]['cor']
result['sub'] = result['sub'] + self.data[token]['sub']
result['ins'] = result['ins'] + self.data[token]['ins']
result['del'] = result['del'] + self.data[token]['del']
return result
def cluster(self, data):
result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
for token in data:
if token in self.data:
result['all'] = result['all'] + self.data[token]['all']
result['cor'] = result['cor'] + self.data[token]['cor']
result['sub'] = result['sub'] + self.data[token]['sub']
result['ins'] = result['ins'] + self.data[token]['ins']
result['del'] = result['del'] + self.data[token]['del']
return result
def keys(self):
return list(self.data.keys())
def width(string):
return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string)
def default_cluster(word):
unicode_names = [unicodedata.name(char) for char in word]
for i in reversed(range(len(unicode_names))):
if unicode_names[i].startswith('DIGIT'): # 1
unicode_names[i] = 'Number' # 'DIGIT'
elif (unicode_names[i].startswith('CJK UNIFIED IDEOGRAPH')
or unicode_names[i].startswith('CJK COMPATIBILITY IDEOGRAPH')):
# 明 / 郎
unicode_names[i] = 'Mandarin' # 'CJK IDEOGRAPH'
elif (unicode_names[i].startswith('LATIN CAPITAL LETTER')
or unicode_names[i].startswith('LATIN SMALL LETTER')):
# A / a
unicode_names[i] = 'English' # 'LATIN LETTER'
elif unicode_names[i].startswith('HIRAGANA LETTER'): # は こ め
unicode_names[i] = 'Japanese' # 'GANA LETTER'
elif (unicode_names[i].startswith('AMPERSAND')
or unicode_names[i].startswith('APOSTROPHE')
or unicode_names[i].startswith('COMMERCIAL AT')
or unicode_names[i].startswith('DEGREE CELSIUS')
or unicode_names[i].startswith('EQUALS SIGN')
or unicode_names[i].startswith('FULL STOP')
or unicode_names[i].startswith('HYPHEN-MINUS')
or unicode_names[i].startswith('LOW LINE')
or unicode_names[i].startswith('NUMBER SIGN')
or unicode_names[i].startswith('PLUS SIGN')
or unicode_names[i].startswith('SEMICOLON')):
# & / ' / @ / ℃ / = / . / - / _ / # / + / ;
del unicode_names[i]
else:
return 'Other'
if len(unicode_names) == 0:
return 'Other'
if len(unicode_names) == 1:
return unicode_names[0]
for i in range(len(unicode_names) - 1):
if unicode_names[i] != unicode_names[i + 1]:
return 'Other'
return unicode_names[0]
def usage():
print(
"compute-wer.py : compute word error rate (WER) and align recognition results and references."
)
print(
" usage : python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer"
)
if __name__ == '__main__':
if len(sys.argv) == 1:
usage()
sys.exit(0)
calculator = Calculator()
cluster_file = ''
ignore_words = set()
tochar = False
verbose = 1
padding_symbol = ' '
case_sensitive = False
max_words_per_line = sys.maxsize
split = None
while len(sys.argv) > 3:
a = '--maxw='
if sys.argv[1].startswith(a):
b = sys.argv[1][len(a):]
del sys.argv[1]
max_words_per_line = int(b)
continue
a = '--rt='
if sys.argv[1].startswith(a):
b = sys.argv[1][len(a):].lower()
del sys.argv[1]
remove_tag = (b == 'true') or (b != '0')
continue
a = '--cs='
if sys.argv[1].startswith(a):
b = sys.argv[1][len(a):].lower()
del sys.argv[1]
case_sensitive = (b == 'true') or (b != '0')
continue
a = '--cluster='
if sys.argv[1].startswith(a):
cluster_file = sys.argv[1][len(a):]
del sys.argv[1]
continue
a = '--splitfile='
if sys.argv[1].startswith(a):
split_file = sys.argv[1][len(a):]
del sys.argv[1]
split = dict()
with codecs.open(split_file, 'r', 'utf-8') as fh:
for line in fh: # line in unicode
words = line.strip().split()
if len(words) >= 2:
split[words[0]] = words[1:]
continue
a = '--ig='
if sys.argv[1].startswith(a):
ignore_file = sys.argv[1][len(a):]
del sys.argv[1]
with codecs.open(ignore_file, 'r', 'utf-8') as fh:
for line in fh: # line in unicode
line = line.strip()
if len(line) > 0:
ignore_words.add(line)
continue
a = '--char='
if sys.argv[1].startswith(a):
b = sys.argv[1][len(a):].lower()
del sys.argv[1]
tochar = (b == 'true') or (b != '0')
continue
a = '--v='
if sys.argv[1].startswith(a):
b = sys.argv[1][len(a):].lower()
del sys.argv[1]
verbose = 0
try:
verbose = int(b)
            except Exception:
if b == 'true' or b != '0':
verbose = 1
continue
a = '--padding-symbol='
if sys.argv[1].startswith(a):
b = sys.argv[1][len(a):].lower()
del sys.argv[1]
if b == 'space':
padding_symbol = ' '
elif b == 'underline':
padding_symbol = '_'
continue
if True or sys.argv[1].startswith('-'):
#ignore invalid switch
del sys.argv[1]
continue
if not case_sensitive:
ig = set([w.upper() for w in ignore_words])
ignore_words = ig
default_clusters = {}
default_words = {}
ref_file = sys.argv[1]
hyp_file = sys.argv[2]
rec_set = {}
if split and not case_sensitive:
newsplit = dict()
for w in split:
words = split[w]
for i in range(len(words)):
words[i] = words[i].upper()
newsplit[w.upper()] = words
split = newsplit
with codecs.open(hyp_file, 'r', 'utf-8') as fh:
for line in fh:
if tochar:
array = characterize(line)
else:
array = line.strip().split()
if len(array) == 0: continue
fid = array[0]
rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive,
split)
    # compute error rate on the intersection of reference file and hyp file
for line in open(ref_file, 'r', encoding='utf-8'):
if tochar:
array = characterize(line)
else:
array = line.rstrip('\n').split()
if len(array) == 0: continue
fid = array[0]
if fid not in rec_set:
continue
lab = normalize(array[1:], ignore_words, case_sensitive, split)
rec = rec_set[fid]
if verbose:
print('\nutt: %s' % fid)
for word in rec + lab:
if word not in default_words:
default_cluster_name = default_cluster(word)
if default_cluster_name not in default_clusters:
default_clusters[default_cluster_name] = {}
if word not in default_clusters[default_cluster_name]:
default_clusters[default_cluster_name][word] = 1
default_words[word] = default_cluster_name
result = calculator.calculate(lab, rec)
if verbose:
if result['all'] != 0:
wer = float(result['ins'] + result['sub'] +
result['del']) * 100.0 / result['all']
else:
wer = 0.0
print('WER: %4.2f %%' % wer, end=' ')
print('N=%d C=%d S=%d D=%d I=%d' %
(result['all'], result['cor'], result['sub'], result['del'],
result['ins']))
space = {}
space['lab'] = []
space['rec'] = []
for idx in range(len(result['lab'])):
len_lab = width(result['lab'][idx])
len_rec = width(result['rec'][idx])
length = max(len_lab, len_rec)
space['lab'].append(length - len_lab)
space['rec'].append(length - len_rec)
upper_lab = len(result['lab'])
upper_rec = len(result['rec'])
lab1, rec1 = 0, 0
while lab1 < upper_lab or rec1 < upper_rec:
if verbose > 1:
print('lab(%s):' % fid.encode('utf-8'), end=' ')
else:
print('lab:', end=' ')
lab2 = min(upper_lab, lab1 + max_words_per_line)
for idx in range(lab1, lab2):
token = result['lab'][idx]
print('{token}'.format(token=token), end='')
for n in range(space['lab'][idx]):
print(padding_symbol, end='')
print(' ', end='')
print()
if verbose > 1:
print('rec(%s):' % fid.encode('utf-8'), end=' ')
else:
print('rec:', end=' ')
rec2 = min(upper_rec, rec1 + max_words_per_line)
for idx in range(rec1, rec2):
token = result['rec'][idx]
print('{token}'.format(token=token), end='')
for n in range(space['rec'][idx]):
print(padding_symbol, end='')
print(' ', end='')
print('\n', end='\n')
lab1 = lab2
rec1 = rec2
if verbose:
print(
'==========================================================================='
)
print()
result = calculator.overall()
if result['all'] != 0:
wer = float(result['ins'] + result['sub'] +
result['del']) * 100.0 / result['all']
else:
wer = 0.0
print('Overall -> %4.2f %%' % wer, end=' ')
print('N=%d C=%d S=%d D=%d I=%d' %
(result['all'], result['cor'], result['sub'], result['del'],
result['ins']))
if not verbose:
print()
if verbose:
for cluster_id in default_clusters:
result = calculator.cluster(
[k for k in default_clusters[cluster_id]])
if result['all'] != 0:
wer = float(result['ins'] + result['sub'] +
result['del']) * 100.0 / result['all']
else:
wer = 0.0
print('%s -> %4.2f %%' % (cluster_id, wer), end=' ')
print('N=%d C=%d S=%d D=%d I=%d' %
(result['all'], result['cor'], result['sub'], result['del'],
result['ins']))
if len(cluster_file) > 0: # compute separated WERs for word clusters
cluster_id = ''
cluster = []
for line in open(cluster_file, 'r', encoding='utf-8'):
                # `line` is already a str when opened with encoding='utf-8'
                for token in line.rstrip('\n').split():
# end of cluster reached, like </Keyword>
if token[0:2] == '</' and token[len(token)-1] == '>' and \
token.lstrip('</').rstrip('>') == cluster_id :
result = calculator.cluster(cluster)
if result['all'] != 0:
wer = float(result['ins'] + result['sub'] +
result['del']) * 100.0 / result['all']
else:
wer = 0.0
print('%s -> %4.2f %%' % (cluster_id, wer), end=' ')
print('N=%d C=%d S=%d D=%d I=%d' %
(result['all'], result['cor'], result['sub'],
result['del'], result['ins']))
cluster_id = ''
cluster = []
# begin of cluster reached, like <Keyword>
elif token[0] == '<' and token[len(token)-1] == '>' and \
cluster_id == '' :
cluster_id = token.lstrip('<').rstrip('>')
cluster = []
# general terms, like WEATHER / CAR / ...
else:
cluster.append(token)
print()
print(
'==========================================================================='
)
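Both scorer variants above read a reference file and a hypothesis file with one utterance per line, the utterance id in the first column and tokens after it, which matches the hyp/ref text files written by the inference script further below. A small illustrative setup follows; the file contents are made up and the flags come from the scripts' own usage text.

# wer_demo.py -- illustrative ref/hyp files for the scorers above
ref_lines = [
    "0 hello world",
    "1 speech recognition is fun",
]
hyp_lines = [
    "0 hello word",
    "1 speech recognition is fun",
]
with open("test.ref", "w", encoding="utf-8") as f:
    f.write("\n".join(ref_lines) + "\n")
with open("test.hyp", "w", encoding="utf-8") as f:
    f.write("\n".join(hyp_lines) + "\n")

# example invocations (flags come from the usage() text above):
#   python compute-wer.py --v=1 test.ref test.hyp > test.wer
#   python compute-wer.py --char=1 --v=1 test.ref test.hyp > test.wer   # split to characters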
import argparse
import itertools
import json
import os
import random
import sys
import uuid
from datetime import timedelta
from functools import partial
from pathlib import Path
import torch
import tqdm
from datasets import load_dataset
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
import torchaudio
from vita_audio.data.processor.audio_processor import add_audio_input_contiguous
from vita_audio.tokenizer import get_audio_tokenizer
def collate_fn(batches):
input_ids = [sample["input_ids"] for sample in batches]
audios = [sample["audios"] for sample in batches]
audio_indices = [sample["audio_indices"] for sample in batches]
refs = [sample["ref"] for sample in batches]
return input_ids, audios, audio_indices, refs
class ASRDataset(torch.utils.data.Dataset):
def __init__(
self,
json_path,
tokenizer,
audio_tokenizer,
default_system_message=None,
add_generation_prompt=True,
):
data = load_dataset("json", data_files=json_path, keep_in_memory=False)
self.data = data["train"]
self.tokenizer = tokenizer
self.add_generation_prompt = add_generation_prompt
self.audio_tokenizer = audio_tokenizer
self.default_system_message = default_system_message
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
sample = self.data[idx]
# print(f"sample {sample}")
audio_path = sample["audios"][0]
if self.audio_tokenizer.apply_to_role("user", is_discrete=True):
# discrete codec
audio_tokens = self.audio_tokenizer.encode(audio_path)
audio_tokens = "".join(f"<|audio_{i}|>" for i in audio_tokens)
else:
audio_tokens = None
messages = []
if len(sample["messages"]) == 2:
assert len(sample["messages"]) == 2
assert sample["messages"][0]["role"] == "user"
assert sample["messages"][1]["role"] == "assistant"
if self.default_system_message is not None:
messages = self.default_system_message + messages
elif len(sample["messages"]) == 3:
assert len(sample["messages"]) == 3
assert sample["messages"][0]["role"] == "system"
assert sample["messages"][1]["role"] == "user"
assert sample["messages"][2]["role"] == "assistant"
else:
raise NotImplementedError
# print(sample)
for conv in sample["messages"][:-1]:
new_conv = {}
new_conv["role"] = conv["role"]
content = conv["content"]
if audio_tokens is not None:
content = content.replace(
"<|audio|>", f"<|begin_of_audio|>{audio_tokens}<|end_of_audio|>"
)
new_conv["content"] = content
messages.append(new_conv)
input_ids = self.tokenizer.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=self.add_generation_prompt,
# return_tensors="pt",
)
ref = sample["messages"][-1]["content"]
if self.audio_tokenizer.apply_to_role("user", is_contiguous=True):
# contiguous codec
input_ids, audios, audio_indices = add_audio_input_contiguous(
input_ids, [audio_path], self.tokenizer, self.audio_tokenizer
)
else:
audios = None
audio_indices = None
input_ids = torch.tensor([input_ids], dtype=torch.long)
return {
"input_ids": input_ids,
"audios": audios,
"audio_indices": audio_indices,
"ref": ref,
}
class InferenceSampler(torch.utils.data.sampler.Sampler):
def __init__(self, size):
self._size = int(size)
assert size > 0
self._rank = torch.distributed.get_rank()
self._world_size = torch.distributed.get_world_size()
self._local_indices = self._get_local_indices(size, self._world_size, self._rank)
@staticmethod
def _get_local_indices(total_size, world_size, rank):
shard_size = total_size // world_size
left = total_size % world_size
shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
begin = sum(shard_sizes[:rank])
end = min(sum(shard_sizes[: rank + 1]), total_size)
return range(begin, end)
def __iter__(self):
yield from self._local_indices
def __len__(self):
return len(self._local_indices)
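# Sharding example (illustrative): with 10 samples and a world size of 3,
# _get_local_indices yields shard sizes [4, 3, 3], i.e. rank 0 -> range(0, 4),
# rank 1 -> range(4, 7), rank 2 -> range(7, 10), so every sample is evaluated
# exactly once across ranks.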
def inference(model, tokenizer, audio_tokenizer, dataloader, output_dir):
audio_offset = tokenizer.convert_tokens_to_ids("<|audio_0|>")
outputs = []
for _, (batched_input_ids, batched_audios, batched_audio_indices, batched_ref) in enumerate(
tqdm.tqdm(dataloader)
):
for input_ids, audios, audio_indices, ref in zip(
batched_input_ids, batched_audios, batched_audio_indices, batched_ref
):
kwargs = {
# "temperature": 0.2,
# "top_p": 0.8,
# "do_sample": False,
# "temperature": 1.0,
"max_new_tokens": max([len(x) for x in batched_ref]) + 10,
"min_new_tokens": 1,
}
if audios is not None:
kwargs["audios"] = audios
kwargs["audio_indices"] = audio_indices
responses = model.generate(
input_ids=input_ids.cuda(),
**kwargs,
)
response = responses[0][len(input_ids[0]) :]
text_tokens = []
audio_tokens = []
for token_id in response:
if token_id >= audio_offset:
audio_tokens.append(token_id - audio_offset)
else:
text_tokens.append(token_id)
hyp = tokenizer.decode(text_tokens, skip_special_tokens=True)
outputs.append((hyp, ref))
print("")
print("=" * 100)
print(f"{hyp=}")
print(f"{ref=}")
return outputs
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument("--model_name_or_path", type=str, required=True, help="model_name_or_path")
parser.add_argument(
"--audio_tokenizer_path", type=str, required=True, help="audio_tokenizer_path"
)
parser.add_argument(
"--audio_tokenizer_type", type=str, required=True, help="audio_tokenizer_type"
)
parser.add_argument("--flow_path", type=str, required=True, help="flow_path")
parser.add_argument("--json_path", type=str, required=True, help="json_path")
parser.add_argument("--output_dir", type=str, required=True, help="output_dir")
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--num_workers", type=int, default=0)
args = parser.parse_args()
print(f"{args=}")
torch.distributed.init_process_group(
backend="nccl",
world_size=int(os.getenv("WORLD_SIZE", "1")),
rank=int(os.getenv("RANK", "0")),
timeout=timedelta(seconds=7200),
)
torch.cuda.set_device(int(os.getenv("LOCAL_RANK", 0)))
random.seed(42)
torch.manual_seed(42)
config = AutoConfig.from_pretrained(
args.model_name_or_path,
trust_remote_code=True,
)
# ================================================================
if "glm" in config.model_type.lower():
from get_chat_template import glm4_chat_template as chat_template
add_generation_prompt = True
default_system_message = [
{
"role": "system",
"content": "User will provide you with a speech instruction. Do it step by step. First, think about the instruction and respond in a interleaved manner, with 13 text token followed by 26 audio tokens.",
}
]
if "qwen2" in config.model_type.lower():
from get_chat_template import qwen2_chat_template as chat_template
add_generation_prompt = True
default_system_message = []
if "hunyuan" in config.model_type.lower():
from get_chat_template import hunyuan_chat_template as chat_template
add_generation_prompt = False
default_system_message = [
{
"role": "system",
"content": "You are a helpful AI assistant.",
}
]
# ================================================================
print("Loading model")
# device_map = "auto"
device_map = "cuda"
# torch_dtype=torch.float16
torch_dtype = torch.bfloat16
rank = torch.distributed.get_rank()
audio_tokenizer = get_audio_tokenizer(
args.audio_tokenizer_path, args.audio_tokenizer_type, flow_path=args.flow_path, rank=rank
)
tokenizer = AutoTokenizer.from_pretrained(
args.model_name_or_path,
trust_remote_code=True,
chat_template=chat_template,
)
# print("tokenizer", tokenizer)
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path,
trust_remote_code=True,
device_map=device_map,
torch_dtype=torch_dtype,
attn_implementation="flash_attention_2",
).eval()
# print("model", model)
model.generation_config = GenerationConfig.from_pretrained(
args.model_name_or_path, trust_remote_code=True
)
model.generation_config.max_new_tokens = 4096
model.generation_config.chat_format = "chatml"
model.generation_config.max_window_size = 8192
model.generation_config.use_cache = True
model.generation_config.do_sample = False
model.generation_config.pad_token_id = tokenizer.pad_token_id
if model.config.model_type == "hunyuan":
model.generation_config.eos_token_id = tokenizer.eos_id
# ================================================================
print("Loading data")
dataset = ASRDataset(
json_path=args.json_path,
tokenizer=tokenizer,
audio_tokenizer=audio_tokenizer,
default_system_message=default_system_message,
add_generation_prompt=add_generation_prompt,
)
dataloader = torch.utils.data.DataLoader(
dataset=dataset,
sampler=InferenceSampler(len(dataset)),
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=True,
drop_last=False,
collate_fn=partial(
collate_fn,
),
)
# ================================================================
outputs = inference(model, tokenizer, audio_tokenizer, dataloader, args.output_dir)
torch.distributed.barrier()
world_size = torch.distributed.get_world_size()
merged_outputs = [None for _ in range(world_size)]
torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs))
merged_outputs = [json.loads(_) for _ in merged_outputs]
merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)]
if torch.distributed.get_rank() == 0:
# json_name = Path("_".join(os.path.normpath(args.json_path).split(os.sep)[-2:])).stem
json_name = Path(os.path.normpath(args.json_path).split(os.sep)[-1]).stem
hyp_path = os.path.join(args.output_dir, f"{json_name}_hyp.txt")
ref_path = os.path.join(args.output_dir, f"{json_name}_ref.txt")
        os.makedirs(args.output_dir, exist_ok=True)
        with open(hyp_path, "w") as hyp_file, open(ref_path, "w") as ref_file:
            for sample_idx, (hyp, ref) in enumerate(merged_outputs):
                hyp_file.write(f"{sample_idx} {hyp}\n")
                ref_file.write(f"{sample_idx} {ref}\n")
        hyp_ref_path = os.path.join(args.output_dir, f"{json_name}_hyp_ref.json")
        with open(hyp_ref_path, "w") as hyp_ref_file:
            json.dump(merged_outputs, hyp_ref_file, indent=4)
torch.distributed.barrier()
print("Done.")
import argparse
import itertools
import json
import os
import random
import re
import sys
import uuid
from datetime import timedelta
from functools import partial
from pathlib import Path
import torch
import tqdm
from datasets import load_dataset
from tn.english.normalizer import Normalizer as EnNormalizer
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
import torchaudio
from vita_audio.tokenizer import get_audio_tokenizer
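# The collate function below does no padding; it simply gathers the per-sample fields into
# lists, and generation still runs one sample at a time inside the inference loop.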
def collate_fn(batches):
input_ids = [sample["input_ids"] for sample in batches]
refs = [sample["ref"] for sample in batches]
filenames = [sample["filename"] for sample in batches]
return input_ids, refs, filenames
class TTSDataset(torch.utils.data.Dataset):
def __init__(self, json_path, tokenizer, audio_tokenizer, default_system_message=None, add_generation_prompt=True):
data = load_dataset("json", data_files=json_path, keep_in_memory=False)
self.data = data["train"]
self.tokenizer = tokenizer
self.audio_tokenizer = audio_tokenizer
self.default_system_message = default_system_message
self.add_generation_prompt = add_generation_prompt
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
sample = self.data[idx]
messages = []
if self.default_system_message is not None:
messages = self.default_system_message + messages
role = "user"
content = sample["messages"][0]["content"]
messages.append(
{
"role": role,
"content": content,
}
)
input_ids = self.tokenizer.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=self.add_generation_prompt,
return_tensors="pt",
)
ref = sample["messages"][0]["content"]
ref = ref.replace("Convert the text to speech.\n", "")
ref = ref.strip()
filepath = sample["audios"][0]
filename = os.path.basename(filepath)
return {
"input_ids": input_ids,
"ref": ref,
"filename": filename,
}
class InferenceSampler(torch.utils.data.sampler.Sampler):
def __init__(self, size):
self._size = int(size)
assert size > 0
self._rank = torch.distributed.get_rank()
self._world_size = torch.distributed.get_world_size()
self._local_indices = self._get_local_indices(size, self._world_size, self._rank)
@staticmethod
def _get_local_indices(total_size, world_size, rank):
shard_size = total_size // world_size
left = total_size % world_size
shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
begin = sum(shard_sizes[:rank])
end = min(sum(shard_sizes[: rank + 1]), total_size)
return range(begin, end)
def __iter__(self):
yield from self._local_indices
def __len__(self):
return len(self._local_indices)
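# TTS inference: generated codec tokens are decoded to a 22.05 kHz waveform, written to
# <output_dir>/audio/<filename>.wav, transcribed with the Whisper pipeline, and both the
# transcription and the reference text are normalized (English TN plus punctuation stripping)
# before being collected. Samples that yield no audio tokens are skipped.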
def inference(model, tokenizer, audio_tokenizer, dataloader, output_dir, asr_model):
audio_offset = tokenizer.convert_tokens_to_ids("<|audio_0|>")
en_tn_model = EnNormalizer(overwrite_cache=True)
outputs = []
for _, (
batched_input_ids,
batched_ref,
batched_filename,
) in enumerate(tqdm.tqdm(dataloader)):
for input_ids, ref, filename in zip(
batched_input_ids, batched_ref, batched_filename
):
responses = model.generate(
input_ids=input_ids.cuda(),
# temperature=0.2,
# top_p=0.8,
# do_sample=False,
# temperature=1.0,
max_new_tokens=1024,
min_new_tokens=1,
)
response = responses[0][len(input_ids[0]) :]
text_tokens = []
audio_tokens = []
for token_id in response:
if token_id >= audio_offset:
audio_tokens.append(token_id - audio_offset)
else:
text_tokens.append(token_id)
if len(audio_tokens) == 0:
continue
tts_speech = audio_tokenizer.decode(audio_tokens)
wav_dir = os.path.join(output_dir, "audio")
wav_path = os.path.join(wav_dir, filename + ".wav")
os.makedirs(os.path.dirname(wav_path), exist_ok=True)
torchaudio.save(wav_path, tts_speech.unsqueeze(0), 22050, format="wav")
hyp = asr_model(wav_path, return_timestamps=True)["text"].strip()
hyp = en_tn_model.normalize(hyp)
ref = en_tn_model.normalize(ref)
hyp = re.sub(r"\W+", " ", hyp)
ref = re.sub(r"\W+", " ", ref)
outputs.append((hyp, ref))
print("")
print("=" * 100)
# print(f"{len(input_id)=}")
# print(f"{len(response)=}")
print(f"{tokenizer.decode(response, skip_special_tokens=False)}")
print(f"{filename=}")
return outputs
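# Whisper large-v3 is loaded from a local checkpoint path and wrapped in a transformers ASR
# pipeline; each rank keeps its own copy on its own GPU.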
def load_asr_model():
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
rank = torch.distributed.get_rank()
device = f"cuda:{rank}"
torch_dtype = torch.float16
model_id = "/data/models/openai/whisper-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
torch_dtype=torch_dtype,
device=device,
)
return pipe
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument("--model_name_or_path", type=str, required=True, help="model_name_or_path")
parser.add_argument(
"--audio_tokenizer_path", type=str, required=True, help="audio_tokenizer_path"
)
parser.add_argument(
"--audio_tokenizer_type", type=str, required=True, help="audio_tokenizer_type"
)
parser.add_argument("--flow_path", type=str, required=True, help="flow_path")
parser.add_argument("--json_path", type=str, required=True, help="json_path")
parser.add_argument("--output_dir", type=str, required=True, help="output_dir")
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--num_workers", type=int, default=0)
parser.add_argument("--speaker_prompt", action=argparse.BooleanOptionalAction, default=False)
args = parser.parse_args()
print(f"{args=}")
torch.distributed.init_process_group(
backend="nccl",
world_size=int(os.getenv("WORLD_SIZE", "1")),
rank=int(os.getenv("RANK", "0")),
timeout=timedelta(seconds=7200),
)
torch.cuda.set_device(int(os.getenv("LOCAL_RANK", 0)))
random.seed(42)
torch.manual_seed(42)
config = AutoConfig.from_pretrained(
args.model_name_or_path,
trust_remote_code=True,
)
# ================================================================
if "glm" in config.model_type.lower():
from get_chat_template import glm4_chat_template as chat_template
add_generation_prompt = True
default_system_message = [
{
"role": "system",
"content": "User will provide you with a speech instruction. Do it step by step. First, think about the instruction and respond in a interleaved manner, with 13 text token followed by 26 audio tokens.",
}
]
if "qwen2" in config.model_type.lower():
from get_chat_template import qwen2_chat_template as chat_template
add_generation_prompt = True
default_system_message = []
if "hunyuan" in config.model_type.lower():
from get_chat_template import hunyuan_chat_template as chat_template
add_generation_prompt = False
default_system_message = [
{
"role": "system",
"content": "You are a helpful AI assistant.",
}
]
# ================================================================
print("Loading model")
device = "cuda"
# device_map = "auto"
device_map = "cuda"
# torch_dtype=torch.float16
torch_dtype = torch.bfloat16
rank = torch.distributed.get_rank()
audio_tokenizer = get_audio_tokenizer(
args.audio_tokenizer_path, args.audio_tokenizer_type, flow_path=args.flow_path, rank=rank
)
tokenizer = AutoTokenizer.from_pretrained(
args.model_name_or_path,
trust_remote_code=True,
chat_template=chat_template,
)
# print("tokenizer", tokenizer)
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path,
trust_remote_code=True,
device_map=device_map,
torch_dtype=torch_dtype,
attn_implementation="flash_attention_2",
).eval()
# print("model", model)
model.generation_config = GenerationConfig.from_pretrained(
args.model_name_or_path, trust_remote_code=True
)
model.generation_config.max_new_tokens = 4096
model.generation_config.chat_format = "chatml"
model.generation_config.max_window_size = 8192
model.generation_config.use_cache = True
model.generation_config.do_sample = True
model.generation_config.pad_token_id = tokenizer.pad_token_id
if model.config.model_type == "hunyuan":
model.generation_config.eos_token_id = tokenizer.eos_id
asr_model = load_asr_model()
# ================================================================
print("Loading data")
dataset = TTSDataset(
json_path=args.json_path,
tokenizer=tokenizer,
audio_tokenizer=audio_tokenizer,
default_system_message=default_system_message,
add_generation_prompt=add_generation_prompt,
)
dataloader = torch.utils.data.DataLoader(
dataset=dataset,
sampler=InferenceSampler(len(dataset)),
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=True,
drop_last=False,
collate_fn=partial(
collate_fn,
),
)
# ================================================================
outputs = inference(model, tokenizer, audio_tokenizer, dataloader, args.output_dir, asr_model)
torch.distributed.barrier()
world_size = torch.distributed.get_world_size()
merged_outputs = [None for _ in range(world_size)]
torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs))
merged_outputs = [json.loads(_) for _ in merged_outputs]
merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)]
if torch.distributed.get_rank() == 0:
# json_name = Path("_".join(os.path.normpath(args.json_path).split(os.sep)[-2:])).stem
json_name = Path(os.path.normpath(args.json_path).split(os.sep)[-1]).stem
hyp_path = os.path.join(args.output_dir, f"{json_name}_hyp.txt")
ref_path = os.path.join(args.output_dir, f"{json_name}_ref.txt")
        os.makedirs(args.output_dir, exist_ok=True)
        with open(hyp_path, "w") as hyp_file, open(ref_path, "w") as ref_file:
            for sample_idx, (hyp, ref) in enumerate(merged_outputs):
                hyp_file.write(f"{sample_idx} {hyp}\n")
                ref_file.write(f"{sample_idx} {ref}\n")
        hyp_ref_path = os.path.join(args.output_dir, f"{json_name}_hyp_ref.json")
        with open(hyp_ref_path, "w") as hyp_ref_file:
            json.dump(merged_outputs, hyp_ref_file, indent=4)
torch.distributed.barrier()
print("Done.")
import argparse
import itertools
import json
import os
import random
import sys
import uuid
from datetime import timedelta
from functools import partial
from pathlib import Path
import torch
import tqdm
from datasets import load_dataset
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
import torchaudio
from vita_audio.tokenizer import get_audio_tokenizer
def collate_fn(batches):
input_ids = [sample["input_ids"] for sample in batches]
refs = [sample["ref"] for sample in batches]
filenames = [sample["filename"] for sample in batches]
prompt_audio_path = [sample["prompt_audio_path"] for sample in batches]
return input_ids, refs, filenames, prompt_audio_path
class SeedTTSDataset(torch.utils.data.Dataset):
def __init__(
self,
data_path,
tokenizer,
audio_tokenizer,
default_system_message=None,
speaker_prompt=False,
add_generation_prompt=True,
):
self.data = []
meta_path = os.path.join(data_path, f"seedtts_testset/zh/meta.lst")
with open(meta_path, "r") as f:
lines = f.readlines()
for line in lines:
line = line.strip().split("|")
filename = line[0]
prompt_text = line[1]
prompt_audio = line[2]
text = line[3]
self.data.append(["zh", filename, prompt_text, prompt_audio, text])
meta_path = os.path.join(data_path, f"seedtts_testset/zh/hardcase.lst")
with open(meta_path, "r") as f:
lines = f.readlines()
for line in lines:
line = line.strip().split("|")
filename = line[0]
prompt_text = line[1]
prompt_audio = line[2]
text = line[3]
self.data.append(["hardcase", filename, prompt_text, prompt_audio, text])
meta_path = os.path.join(data_path, f"seedtts_testset/en/meta.lst")
with open(meta_path, "r") as f:
lines = f.readlines()
for line in lines:
line = line.strip().split("|")
filename = line[0]
prompt_text = line[1]
prompt_audio = line[2]
text = line[3]
self.data.append(["en", filename, prompt_text, prompt_audio, text])
self.tokenizer = tokenizer
self.audio_tokenizer = audio_tokenizer
self.default_system_message = default_system_message
self.add_generation_prompt = add_generation_prompt
self.data_path = data_path
self.speaker_prompt = speaker_prompt
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
sample = self.data[idx]
split, filename, prompt_text, prompt_audio, text = sample
messages = []
if self.default_system_message is not None:
messages = self.default_system_message + messages
if self.speaker_prompt:
if split == "hardcase":
prompt_audio_path = os.path.join(
self.data_path, "seedtts_testset", "zh", prompt_audio
)
else:
prompt_audio_path = os.path.join(
self.data_path, "seedtts_testset", split, prompt_audio
)
if self.audio_tokenizer.apply_to_role("system", is_discrete=True):
# discrete codec
prompt_audio_tokens = self.audio_tokenizer.encode(prompt_audio_path)
prompt_audio_tokens = "".join(f"<|audio_{i}|>" for i in prompt_audio_tokens)
prompt_text = f"Speaker Metadata:\nAudio: <|begin_of_audio|>{prompt_audio_tokens}<|end_of_audio|>\n"
if len(messages) > 0 and messages[0]["role"] == "system":
messages[0]["content"] += prompt_text
else:
messages.append(
{
"role": "system",
"content": prompt_text,
}
)
else:
prompt_audio_path = None
role = "user"
content = "Convert the text to speech.\n" + text
messages.append(
{
"role": role,
"content": content,
}
)
input_ids = self.tokenizer.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=self.add_generation_prompt,
return_tensors="pt",
)
ref = text
return {
"input_ids": input_ids,
"ref": ref,
"filename": split + "/" + filename,
"prompt_audio_path": prompt_audio_path,
}
class InferenceSampler(torch.utils.data.sampler.Sampler):
def __init__(self, size):
self._size = int(size)
assert size > 0
self._rank = torch.distributed.get_rank()
self._world_size = torch.distributed.get_world_size()
self._local_indices = self._get_local_indices(size, self._world_size, self._rank)
@staticmethod
def _get_local_indices(total_size, world_size, rank):
shard_size = total_size // world_size
left = total_size % world_size
shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
begin = sum(shard_sizes[:rank])
end = min(sum(shard_sizes[: rank + 1]), total_size)
return range(begin, end)
def __iter__(self):
yield from self._local_indices
def __len__(self):
return len(self._local_indices)
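# Seed-TTS inference: generated codec tokens are vocoded with the prompt audio passed as the
# 16 kHz speaker reference (None when --speaker_prompt is off) and the result is saved as
# <output_dir>/<split>/<filename>.wav.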
def inference(model, tokenizer, audio_tokenizer, dataloader, output_dir):
audio_offset = tokenizer.convert_tokens_to_ids("<|audio_0|>")
outputs = []
for _, (
batched_input_ids,
batched_ref,
batched_filename,
batched_prompt_audio_path,
) in enumerate(tqdm.tqdm(dataloader)):
for input_ids, ref, filename, prompt_audio_path in zip(
batched_input_ids, batched_ref, batched_filename, batched_prompt_audio_path
):
responses = model.generate(
input_ids=input_ids.cuda(),
# temperature=0.2,
# top_p=0.8,
# do_sample=False,
# temperature=1.0,
max_new_tokens=1024,
min_new_tokens=1,
)
response = responses[0][len(input_ids[0]) :]
text_tokens = []
audio_tokens = []
for token_id in response:
if token_id >= audio_offset:
audio_tokens.append(token_id - audio_offset)
else:
text_tokens.append(token_id)
if len(audio_tokens) == 0:
continue
tts_speech = audio_tokenizer.decode(audio_tokens, source_speech_16k=prompt_audio_path)
wav_path = os.path.join(output_dir, filename + ".wav")
os.makedirs(os.path.dirname(wav_path), exist_ok=True)
torchaudio.save(wav_path, tts_speech.unsqueeze(0), 22050, format="wav")
outputs.append((wav_path, filename))
print("")
print("=" * 100)
# print(f"{len(input_id)=}")
# print(f"{len(response)=}")
print(f"{tokenizer.decode(response, skip_special_tokens=False)}")
print(f"{filename=}")
return outputs
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument("--model_name_or_path", type=str, required=True, help="model_name_or_path")
parser.add_argument(
"--audio_tokenizer_path", type=str, required=True, help="audio_tokenizer_path"
)
parser.add_argument(
"--audio_tokenizer_type", type=str, required=True, help="audio_tokenizer_type"
)
parser.add_argument("--flow_path", type=str, required=True, help="flow_path")
parser.add_argument("--data_path", type=str, required=True, help="data_path")
parser.add_argument("--output_dir", type=str, required=True, help="output_dir")
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--num_workers", type=int, default=0)
parser.add_argument("--speaker_prompt", action=argparse.BooleanOptionalAction, default=False)
args = parser.parse_args()
print(f"{args=}")
torch.distributed.init_process_group(
backend="nccl",
world_size=int(os.getenv("WORLD_SIZE", "1")),
rank=int(os.getenv("RANK", "0")),
timeout=timedelta(seconds=7200),
)
torch.cuda.set_device(int(os.getenv("LOCAL_RANK", 0)))
random.seed(42)
torch.manual_seed(42)
config = AutoConfig.from_pretrained(
args.model_name_or_path,
trust_remote_code=True,
)
# ================================================================
if "glm" in config.model_type.lower():
from get_chat_template import glm4_chat_template as chat_template
add_generation_prompt = True
default_system_message = [
{
"role": "system",
"content": "User will provide you with a speech instruction. Do it step by step. First, think about the instruction and respond in a interleaved manner, with 13 text token followed by 26 audio tokens.",
}
]
if "qwen2" in config.model_type.lower():
from get_chat_template import qwen2_chat_template as chat_template
add_generation_prompt = True
default_system_message = []
if "hunyuan" in config.model_type.lower():
from get_chat_template import hunyuan_chat_template as chat_template
add_generation_prompt = False
default_system_message = [
{
"role": "system",
"content": "You are a helpful AI assistant.",
}
]
# ================================================================
print("Loading model")
device = "cuda"
# device_map = "auto"
device_map = "cuda"
# torch_dtype=torch.float16
torch_dtype = torch.bfloat16
rank = torch.distributed.get_rank()
audio_tokenizer = get_audio_tokenizer(
args.audio_tokenizer_path, args.audio_tokenizer_type, flow_path=args.flow_path, rank=rank
)
tokenizer = AutoTokenizer.from_pretrained(
args.model_name_or_path,
trust_remote_code=True,
chat_template=chat_template,
)
# print("tokenizer", tokenizer)
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path,
trust_remote_code=True,
device_map=device_map,
torch_dtype=torch_dtype,
attn_implementation="flash_attention_2",
).eval()
# print("model", model)
model.generation_config = GenerationConfig.from_pretrained(
args.model_name_or_path, trust_remote_code=True
)
model.generation_config.max_new_tokens = 4096
model.generation_config.chat_format = "chatml"
model.generation_config.max_window_size = 8192
model.generation_config.use_cache = True
model.generation_config.do_sample = True
model.generation_config.pad_token_id = tokenizer.pad_token_id
if model.config.model_type == "hunyuan":
model.generation_config.eos_token_id = tokenizer.eos_id
# ================================================================
print("Loading data")
dataset = SeedTTSDataset(
data_path=args.data_path,
tokenizer=tokenizer,
audio_tokenizer=audio_tokenizer,
default_system_message=default_system_message,
speaker_prompt=args.speaker_prompt,
add_generation_prompt=add_generation_prompt,
)
dataloader = torch.utils.data.DataLoader(
dataset=dataset,
sampler=InferenceSampler(len(dataset)),
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=True,
drop_last=False,
collate_fn=partial(
collate_fn,
),
)
# ================================================================
outputs = inference(model, tokenizer, audio_tokenizer, dataloader, args.output_dir)
torch.distributed.barrier()
world_size = torch.distributed.get_world_size()
merged_outputs = [None for _ in range(world_size)]
torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs))
merged_outputs = [json.loads(_) for _ in merged_outputs]
merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)]
torch.distributed.barrier()
print("Done.")
import argparse
import itertools
import json
import os
import random
import sys
import uuid
from datetime import timedelta
from functools import partial
from pathlib import Path
import torch
import tqdm
from datasets import load_dataset
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
import torchaudio
from vita_audio.data.processor.audio_processor import add_audio_input_contiguous
from vita_audio.tokenizer import get_audio_tokenizer
def collate_fn(batches):
input_ids = [sample["input_ids"] for sample in batches]
audios = [sample["audios"] for sample in batches]
audio_indices = [sample["audio_indices"] for sample in batches]
refs = [sample["ref"] for sample in batches]
filenames = [sample["filename"] for sample in batches]
return input_ids, audios, audio_indices, refs, filenames
class STSDataset(torch.utils.data.Dataset):
def __init__(self, json_path, tokenizer, audio_tokenizer, default_system_message=None, add_generation_prompt=True):
data = load_dataset("json", data_files=json_path, keep_in_memory=False)
self.data = data["train"]
self.tokenizer = tokenizer
self.add_generation_prompt = add_generation_prompt
self.audio_tokenizer = audio_tokenizer
self.default_system_message = default_system_message
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
sample = self.data[idx]
assert len(sample["audios"]) == 1
audio_path = sample["audios"][0]
if self.audio_tokenizer.apply_to_role("user", is_discrete=True):
# discrete codec
audio_tokens = self.audio_tokenizer.encode(audio_path)
audio_tokens = "".join(f"<|audio_{i}|>" for i in audio_tokens)
else:
audio_tokens = None
messages = []
if len(sample["messages"]) == 2:
assert len(sample["messages"]) == 2
assert sample["messages"][0]["role"] == "user"
assert sample["messages"][1]["role"] == "assistant"
if self.default_system_message is not None:
messages = self.default_system_message + messages
elif len(sample["messages"]) == 3:
assert len(sample["messages"]) == 3
assert sample["messages"][0]["role"] == "system"
assert sample["messages"][1]["role"] == "user"
assert sample["messages"][2]["role"] == "assistant"
else:
raise NotImplementedError
for conv in sample["messages"][:-1]:
new_conv = {}
new_conv["role"] = conv["role"]
content = conv["content"]
if isinstance(content, list):
assert len(content) == 1
content = content[0]
if audio_tokens is not None:
content = content.replace(
"<|audio|>", f"<|begin_of_audio|>{audio_tokens}<|end_of_audio|>"
)
new_conv["content"] = content
messages.append(new_conv)
input_ids = self.tokenizer.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=self.add_generation_prompt,
# return_tensors="pt",
)
ref = sample["messages"][-1]["content"]
if self.audio_tokenizer.apply_to_role("user", is_contiguous=True):
# contiguous codec
input_ids, audios, audio_indices = add_audio_input_contiguous(
input_ids, [audio_path], self.tokenizer, self.audio_tokenizer
)
else:
audios = None
audio_indices = None
input_ids = torch.tensor([input_ids], dtype=torch.long)
filename = os.path.basename(audio_path)
filename = os.path.splitext(filename)[0]
return {
"input_ids": input_ids,
"audios": audios,
"audio_indices": audio_indices,
"ref": ref,
"filename": filename,
}
class InferenceSampler(torch.utils.data.sampler.Sampler):
def __init__(self, size):
self._size = int(size)
assert size > 0
self._rank = torch.distributed.get_rank()
self._world_size = torch.distributed.get_world_size()
self._local_indices = self._get_local_indices(size, self._world_size, self._rank)
@staticmethod
def _get_local_indices(total_size, world_size, rank):
shard_size = total_size // world_size
left = total_size % world_size
shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
begin = sum(shard_sizes[:rank])
end = min(sum(shard_sizes[: rank + 1]), total_size)
return range(begin, end)
def __iter__(self):
yield from self._local_indices
def __len__(self):
return len(self._local_indices)
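# STS inference: tokens below the audio offset are decoded into hyp_text; audio tokens are
# decoded to a waveform, saved under <output_dir>/audio, and transcribed into hyp_speech.
# Samples that produce no audio tokens are skipped.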
def inference(model, tokenizer, audio_tokenizer, dataloader, output_dir, asr_model):
audio_offset = tokenizer.convert_tokens_to_ids("<|audio_0|>")
outputs = []
for _, (batched_input_ids, batched_audios, batched_audio_indices, batched_ref, batched_filename) in enumerate(
tqdm.tqdm(dataloader)
):
for input_ids, audios, audio_indices, ref, filename in zip(
batched_input_ids, batched_audios, batched_audio_indices, batched_ref, batched_filename
):
responses = model.generate(
input_ids=input_ids.cuda(),
audios=audios,
audio_indices=audio_indices,
# temperature=0.2,
# top_p=0.8,
# do_sample=False,
# temperature=1.0,
max_new_tokens=1024,
min_new_tokens=1,
)
response = responses[0][len(input_ids[0]) :]
text_tokens = []
audio_tokens = []
for token_id in response:
if token_id >= audio_offset:
audio_tokens.append(token_id - audio_offset)
else:
text_tokens.append(token_id)
hyp_text = tokenizer.decode(text_tokens, skip_special_tokens=True)
if len(audio_tokens) == 0:
continue
tts_speech = audio_tokenizer.decode(audio_tokens)
wav_dir = os.path.join(output_dir, "audio")
wav_path = os.path.join(wav_dir, filename + ".wav")
os.makedirs(os.path.dirname(wav_path), exist_ok=True)
torchaudio.save(wav_path, tts_speech.unsqueeze(0), 22050, format="wav")
# hyp_speech = asr_model.transcribe(wav_path)["text"].strip()
hyp_speech = asr_model(wav_path, return_timestamps=True)["text"].strip()
# hyp_speech = ""
outputs.append((hyp_text, hyp_speech, ref))
print("")
print("=" * 100)
print(f"{tokenizer.decode(response, skip_special_tokens=False)}")
print(f" {hyp_text=}")
print(f"{hyp_speech=}")
print(f" {ref=}")
print(f"{filename=}")
return outputs
def load_asr_model():
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
rank = torch.distributed.get_rank()
device = f"cuda:{rank}"
torch_dtype = torch.float16
model_id = "/data/models/openai/whisper-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
torch_dtype=torch_dtype,
device=device,
)
return pipe
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument("--model_name_or_path", type=str, required=True, help="model_name_or_path")
parser.add_argument(
"--audio_tokenizer_path", type=str, required=True, help="audio_tokenizer_path"
)
parser.add_argument(
"--audio_tokenizer_type", type=str, required=True, help="audio_tokenizer_type"
)
parser.add_argument("--flow_path", type=str, required=True, help="flow_path")
parser.add_argument("--json_path", type=str, required=True, help="json_path")
parser.add_argument("--output_dir", type=str, required=True, help="output_dir")
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--num_workers", type=int, default=0)
args = parser.parse_args()
print(f"{args=}")
torch.distributed.init_process_group(
backend="nccl",
world_size=int(os.getenv("WORLD_SIZE", "1")),
rank=int(os.getenv("RANK", "0")),
timeout=timedelta(seconds=7200),
)
torch.cuda.set_device(int(os.getenv("LOCAL_RANK", 0)))
random.seed(42)
torch.manual_seed(42)
config = AutoConfig.from_pretrained(
args.model_name_or_path,
trust_remote_code=True,
)
# ================================================================
if "glm" in config.model_type.lower():
from get_chat_template import glm4_chat_template as chat_template
add_generation_prompt = True
default_system_message = [
{
"role": "system",
"content": "User will provide you with a speech instruction. Do it step by step. First, think about the instruction and respond in a interleaved manner, with 13 text token followed by 26 audio tokens.",
}
]
if "qwen2" in config.model_type.lower():
from get_chat_template import qwen2_chat_template as chat_template
add_generation_prompt = True
default_system_message = []
if "hunyuan" in config.model_type.lower():
from get_chat_template import hunyuan_chat_template as chat_template
add_generation_prompt = False
default_system_message = [
{
"role": "system",
"content": "You are a helpful AI assistant.",
}
]
default_system_message = [
{
"role": "system",
# "content": "Your Name: Luke\nYour Gender: male\nRespond in a text-audio interleaved manner.",
# "content": "Your Name: Lucy\nYour Gender: female\nRespond in a text-audio interleaved manner.",
"content": "Your Name: Omni\nYour Gender: female\nRespond in a text-audio interleaved manner.",
},
]
# ================================================================
print("Loading model")
device = "cuda"
# device_map = "auto"
device_map = "cuda"
# torch_dtype=torch.float16
torch_dtype = torch.bfloat16
rank = torch.distributed.get_rank()
audio_tokenizer = get_audio_tokenizer(
args.audio_tokenizer_path, args.audio_tokenizer_type, flow_path=args.flow_path, rank=rank
)
tokenizer = AutoTokenizer.from_pretrained(
args.model_name_or_path,
trust_remote_code=True,
chat_template=chat_template,
)
# print("tokenizer", tokenizer)
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path,
trust_remote_code=True,
device_map=device_map,
torch_dtype=torch_dtype,
attn_implementation="flash_attention_2",
).eval()
# print("model", model)
model.generation_config = GenerationConfig.from_pretrained(
args.model_name_or_path, trust_remote_code=True
)
model.generation_config.max_new_tokens = 4096
model.generation_config.chat_format = "chatml"
model.generation_config.max_window_size = 8192
model.generation_config.use_cache = True
model.generation_config.do_sample = False
model.generation_config.temperature = None
model.generation_config.top_p = None
model.generation_config.top_k = None
model.generation_config.pad_token_id = tokenizer.pad_token_id
if model.config.model_type == "hunyuan":
model.generation_config.eos_token_id = tokenizer.eos_id
asr_model = load_asr_model()
# ================================================================
print("Loading data")
dataset = STSDataset(
json_path=args.json_path,
tokenizer=tokenizer,
audio_tokenizer=audio_tokenizer,
default_system_message=default_system_message,
add_generation_prompt=add_generation_prompt,
)
dataloader = torch.utils.data.DataLoader(
dataset=dataset,
sampler=InferenceSampler(len(dataset)),
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=True,
drop_last=False,
collate_fn=partial(
collate_fn,
),
)
# ================================================================
outputs = inference(model, tokenizer, audio_tokenizer, dataloader, args.output_dir, asr_model)
torch.distributed.barrier()
world_size = torch.distributed.get_world_size()
merged_outputs = [None for _ in range(world_size)]
torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs))
merged_outputs = [json.loads(_) for _ in merged_outputs]
merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)]
if torch.distributed.get_rank() == 0:
# json_name = Path("_".join(os.path.normpath(args.json_path).split(os.sep)[-2:])).stem
json_name = Path(os.path.normpath(args.json_path).split(os.sep)[-1]).stem
hyp_text_path = os.path.join(args.output_dir, f"{json_name}_hyp_text.txt")
hyp_speech_path = os.path.join(args.output_dir, f"{json_name}_hyp_speech.txt")
ref_path = os.path.join(args.output_dir, f"{json_name}_ref.txt")
        os.makedirs(args.output_dir, exist_ok=True)
        with open(hyp_text_path, "w") as hyp_text_file, open(
            hyp_speech_path, "w"
        ) as hyp_speech_file, open(ref_path, "w") as ref_file:
            for sample_idx, (hyp_text, hyp_speech, ref) in enumerate(merged_outputs):
                hyp_text_file.write(f"{sample_idx} {hyp_text}\n")
                hyp_speech_file.write(f"{sample_idx} {hyp_speech}\n")
                ref_file.write(f"{sample_idx} {ref}\n")
        outputs_speech = [[x[1], x[2]] for x in merged_outputs]
        outputs_text = [[x[0], x[2]] for x in merged_outputs]
        hyp_ref_path = os.path.join(args.output_dir, f"{json_name}_hyp_ref_text.json")
        with open(hyp_ref_path, "w") as hyp_ref_file:
            json.dump(outputs_text, hyp_ref_file, indent=4)
        hyp_ref_path = os.path.join(args.output_dir, f"{json_name}_hyp_ref_speech.json")
        with open(hyp_ref_path, "w") as hyp_ref_file:
            json.dump(outputs_speech, hyp_ref_file, indent=4)
torch.distributed.barrier()
print("Done.")
qwen2_chat_template = """
{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n
"""
qwen3_chat_template = """
"{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set content = message.content %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is defined and message.reasoning_content is not none %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '</think>' in message.content %}\n {%- set content = message.content.split('</think>')[-1].lstrip('\\n') %}\n {%- set reasoning_content = message.content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n 
{{- '<think>\\n\\n</think>\\n\\n' }}\n {%- endif %}\n{%- endif %}"
"""
hunyuan_chat_template = """
{% set context = {'has_head': true} %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = message['content'] %}{% if loop.index0 == 0 %}{% if content == '' %}{% set _ = context.update({'has_head': false}) %}{% else %}{% set content = '<|startoftext|>' + content + '<|extra_4|>' %}{% endif %}{% endif %}{% if message['role'] == 'user' %}{% if loop.index0 == 1 and not context.has_head %}{% set content = '<|startoftext|>' + content %}{% endif %}{% if loop.index0 == 1 and context.has_head %}{% set content = content + '<|extra_0|>' %}{% else %}{% set content = '<|startoftext|>' + content + '<|extra_0|>' %}{% endif %}{% elif message['role'] == 'assistant' %}{% set content = content + '<|eos|>' %}{% endif %}{{ content }}{% endfor %}
"""
glm4_chat_template = """
{%- for message in messages %} {%- if (message.role == "system") %} {{- '<|system|>' + '\n' + message.content }} {%- elif (message.role == "user") %} {{- '<|user|>' + '\n' + message.content }} {%- elif message.role == "assistant" %} {{- '<|assistant|>' }} {%- if message.content %} {{- 'streaming_transcription\n' + message.content }} {%- endif %} {%- endif %} {%- endfor %} {%- if add_generation_prompt %} {{- '<|assistant|>streaming_transcription\n' }} {%- endif %}
"""
if __name__ == "__main__":
from transformers import AutoTokenizer
chat = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing great. How can I help you today?"},
{"role": "user", "content": "I'd like to show off how chat templating works!"},
]
# print("=" * 100)
# tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-Nemo-Instruct-2407")
# print(tokenizer.get_chat_template())
# message = tokenizer.apply_chat_template(chat, tokenize=False)
# print(message)
print("=" * 100)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-14B-Instruct")
print(tokenizer.get_chat_template())
message = tokenizer.apply_chat_template(chat, tokenize=False)
print(message)
print("=" * 100)
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3.1-70B-Instruct")
print(tokenizer.get_chat_template())
message = tokenizer.apply_chat_template(chat, tokenize=False)
print(message)
print("=" * 100)
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1")
print(tokenizer.get_chat_template())
message = tokenizer.apply_chat_template(chat, tokenize=False)
print(message)
message = tokenizer.apply_chat_template(chat, tokenize=True)
print(message)
print("=" * 100)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
print(tokenizer.get_chat_template())
message = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True, enable_thinking=True)
print(message)
message = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True, enable_thinking=False)
print(message)