Commit c7c75720 authored by chenpangpang's avatar chenpangpang

feat: initial commit

parent 6c21ac17
Pipeline #1518 failed with stages in 0 seconds
.idea
chenyh
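# Build stage: clone the demo repository and install the web-demo Python requirements.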
FROM image.sourcefind.cn:5000/gpu/admin/base/jupyterlab-pytorch:2.2.0-python3.10-cuda12.1-ubuntu22.04 as base
ARG IMAGE=qwen2-audio-instruct-bot
ARG IMAGE_UPPER=Qwen2-Audio
ARG BRANCH=gpu
RUN cd /root && git clone -b $BRANCH http://developer.hpccube.com/codes/chenpangpang/$IMAGE.git
WORKDIR /root/$IMAGE/$IMAGE_UPPER
RUN pip install -r demo/requirements_web_demo.txt
#########
# Prod #
#########
FROM image.sourcefind.cn:5000/gpu/admin/base/jupyterlab-pytorch:2.2.0-python3.10-cuda12.1-ubuntu22.04
ARG IMAGE=qwen2-audio-instruct-bot
ARG IMAGE_UPPER=Qwen2-Audio
COPY chenyh/$IMAGE/frpc_linux_amd64_v0.2 /opt/conda/lib/python3.10/site-packages/gradio/
RUN chmod +x /opt/conda/lib/python3.10/site-packages/gradio/frpc_linux_amd64_v0.2
COPY chenyh/$IMAGE/Qwen/Qwen2-Audio-7B-Instruct /root/$IMAGE_UPPER/Qwen/Qwen2-Audio-7B-Instruct
COPY --from=base /opt/conda/lib/python3.10/site-packages /opt/conda/lib/python3.10/site-packages
COPY --from=base /root/$IMAGE/$IMAGE_UPPER /root/$IMAGE_UPPER
COPY --from=base /root/$IMAGE/启动器.ipynb /root/$IMAGE/start.sh /root/
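A minimal build sketch for the Dockerfile above; the image tag is an assumption, and running the container will additionally need whatever GPU runtime flags your environment uses:

```bash
# Hypothetical tag; adjust to match your registry naming.
docker build -t qwen2-audio-instruct-bot:latest .
```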
echo $CUDA_VISIBLE_DEVICES
SERVER_PORT=9001
MASTER_ADDR=localhost
MASTER_PORT="3${SERVER_PORT}"
NNODES=${WORLD_SIZE:-1}
NODE_RANK=${RANK:-0}
GPUS_PER_NODE=1
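# Launch the Gradio web demo under torch.distributed (defaults: single node, one process).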
python -m torch.distributed.launch --use_env \
--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES \
--node_rank $NODE_RANK \
--master_addr=${MASTER_ADDR:-127.0.0.1} \
--master_port=$MASTER_PORT \
web_demo_audio.py \
--server-port ${SERVER_PORT}
gradio==4.31.3
modelscope-studio
import gradio as gr
import modelscope_studio as mgr
import librosa
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration
from argparse import ArgumentParser
DEFAULT_CKPT_PATH = 'Qwen/Qwen2-Audio-7B-Instruct'
def _get_args():
parser = ArgumentParser()
parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH,
help="Checkpoint name or path, default to %(default)r")
parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")
parser.add_argument("--inbrowser", action="store_true", default=False,
help="Automatically launch the interface in a new tab on the default browser.")
parser.add_argument("--server-port", type=int, default=8000,
help="Demo server port.")
parser.add_argument("--server-name", type=str, default="127.0.0.1",
help="Demo server name.")
args = parser.parse_args()
return args
def add_text(chatbot, task_history, input):
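"""Append the user's text/audio input to the visible chat and to the model-facing task history."""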
text_content = input.text
content = []
if len(input.files) > 0:
for i in input.files:
content.append({'type': 'audio', 'audio_url': i.path})
if text_content:
content.append({'type': 'text', 'text': text_content})
task_history.append({"role": "user", "content": content})
chatbot.append([{
"text": input.text,
"files": input.files,
}, None])
return chatbot, task_history, None
def add_file(chatbot, task_history, audio_file):
"""Add audio file to the chat history."""
task_history.append({"role": "user", "content": [{"audio": audio_file.name}]})
chatbot.append((f"[Audio file: {audio_file.name}]", None))
return chatbot, task_history
def reset_user_input():
"""Reset the user input field."""
return gr.update(value='')  # gradio 4.x removed the per-component .update() classmethod
def reset_state(task_history):
"""Reset the chat history."""
return [], []
def regenerate(chatbot, task_history):
"""Regenerate the last bot response."""
if task_history and task_history[-1]['role'] == 'assistant':
task_history.pop()
chatbot.pop()
if task_history:
chatbot, task_history = predict(chatbot, task_history)
return chatbot, task_history
def predict(chatbot, task_history):
"""Generate a response from the model."""
print(f"{task_history=}")
print(f"{chatbot=}")
text = processor.apply_chat_template(task_history, add_generation_prompt=True, tokenize=False)
audios = []
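# Gather every audio clip referenced in the conversation, resampled to the feature extractor's rate.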
for message in task_history:
if isinstance(message["content"], list):
for ele in message["content"]:
if ele["type"] == "audio":
audios.append(
librosa.load(ele['audio_url'], sr=processor.feature_extractor.sampling_rate)[0]
)
if len(audios)==0:
audios=None
print(f"{text=}")
print(f"{audios=}")
inputs = processor(text=text, audios=audios, return_tensors="pt", padding=True)
if not _get_args().cpu_only:
inputs["input_ids"] = inputs.input_ids.to("cuda")
generate_ids = model.generate(**inputs, max_length=256)
generate_ids = generate_ids[:, inputs.input_ids.size(1):]
response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
print(f"{response=}")
task_history.append({'role': 'assistant',
'content': response})
chatbot.append((None, response)) # Add the response to chatbot
return chatbot, task_history
def _launch_demo(args):
with gr.Blocks() as demo:
gr.Markdown(
"""<p align="center"><img src="https://qianwen-res.oss-cn-beijing.aliyuncs.com/assets/blog/qwenaudio/qwen2audio_logo.png" style="height: 80px"/><p>""")
gr.Markdown("""<center><font size=8>Qwen2-Audio-Instruct Bot</center>""")
gr.Markdown(
"""\
<center><font size=3>This WebUI is based on Qwen2-Audio-Instruct, developed by Alibaba Cloud. \
(本WebUI基于Qwen2-Audio-Instruct打造,实现聊天机器人功能。)</center>""")
gr.Markdown("""\
<center><font size=4>Qwen2-Audio <a href="https://modelscope.cn/models/qwen/Qwen2-Audio-7B">🤖 </a>
| <a href="https://huggingface.co/Qwen/Qwen2-Audio-7B">🤗</a>&nbsp |
Qwen2-Audio-Instruct <a href="https://modelscope.cn/models/qwen/Qwen2-Audio-7B-Instruct">🤖 </a> |
<a href="https://huggingface.co/Qwen/Qwen2-Audio-7B-Instruct">🤗</a>&nbsp |
&nbsp<a href="https://github.com/QwenLM/Qwen2-Audio">Github</a></center>""")
chatbot = mgr.Chatbot(label='Qwen2-Audio-7B-Instruct', elem_classes="control-height", height=750)
user_input = mgr.MultimodalInput(
interactive=True,
sources=['microphone', 'upload'],
submit_button_props=dict(value="🚀 Submit (发送)"),
upload_button_props=dict(value="📁 Upload (上传文件)", show_progress=True),
)
task_history = gr.State([])
with gr.Row():
empty_bin = gr.Button("🧹 Clear History (清除历史)")
regen_btn = gr.Button("🤔️ Regenerate (重试)")
user_input.submit(fn=add_text,
inputs=[chatbot, task_history, user_input],
outputs=[chatbot, task_history, user_input]).then(
predict, [chatbot, task_history], [chatbot, task_history], show_progress=True
)
empty_bin.click(reset_state, outputs=[chatbot, task_history], show_progress=True)
regen_btn.click(regenerate, [chatbot, task_history], [chatbot, task_history], show_progress=True)
demo.queue().launch(
share=True,
inbrowser=args.inbrowser,
server_port=args.server_port,
server_name=args.server_name,
)
if __name__ == "__main__":
args = _get_args()
if args.cpu_only:
device_map = "cpu"
else:
device_map = "auto"
model = Qwen2AudioForConditionalGeneration.from_pretrained(
args.checkpoint_path,
torch_dtype="auto",
device_map=device_map,
resume_download=True,
).eval()
model.generation_config.max_new_tokens = 2048 # For chat.
print("generation_config", model.generation_config)
processor = AutoProcessor.from_pretrained(args.checkpoint_path, resume_download=True)
_launch_demo(args)
## Evaluation
### Dependencies
```bash
apt-get update
apt-get install openjdk-8-jdk
pip install evaluate
pip install sacrebleu==1.5.1
pip install edit_distance
pip install editdistance
pip install jiwer
pip install scikit-image
pip install textdistance
pip install sed_eval
pip install more_itertools
pip install zhconv
```
### ASR
- Data
> LibriSpeech: https://www.openslr.org/12
> Aishell2: https://www.aishelltech.com/aishell_2
> Common Voice 15: https://commonvoice.mozilla.org/en/datasets
> Fleurs: https://huggingface.co/datasets/google/fleurs
```bash
mkdir -p data/asr && cd data/asr
# download audios from above links
# download converted files
wget https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/evaluation/librispeech_eval.jsonl
wget https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/evaluation/aishell2_eval.jsonl
wget https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/evaluation/cv15_asr_en_eval.jsonl
wget https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/evaluation/cv15_asr_zh_eval.jsonl
wget https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/evaluation/cv15_asr_yue_eval.jsonl
wget https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/evaluation/cv15_asr_fr_eval.jsonl
wget https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/evaluation/fleurs_asr_zh_eval.jsonl
cd ../..
```
```bash
for ds in "librispeech" "aishell2" "cv15_en" "cv15_zh" "cv15_yue" "cv15_fr" "fluers_zh"
do
python -m torch.distributed.launch --use_env \
--nproc_per_node ${NPROC_PER_NODE:-8} --nnodes 1 \
evaluate_asr.py \
--checkpoint $checkpoint \
--dataset $ds \
--batch-size 20 \
--num-workers 2
done
```
### S2TT
- Data
> CoVoST 2: https://github.com/facebookresearch/covost
```bash
mkdir -p data/st && cd data/st
# download audios from https://commonvoice.mozilla.org/en/datasets
# download converted files
wget https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/evaluation/covost2_eval.jsonl
cd ../..
```
- Evaluate
```bash
ds="covost2"
python -m torch.distributed.launch --use-env \
--nproc_per_node ${NPROC_PER_NODE:-8} --nnodes 1 \
evaluate_st.py \
--checkpoint $checkpoint \
--dataset $ds \
--batch-size 8 \
--num-workers 2
```
### SER
- Data
> MELD: https://affective-meld.github.io/
```bash
mkdir -p data/ser && cd data/ser
# download MELD datasets from above link
# download converted files
wget https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/evaluation/meld_eval.jsonl
cd ../..
```
- Evaluate
```bash
ds="meld"
python -m torch.distributed.launch --use-env \
--nproc_per_node ${NPROC_PER_NODE:-8} --nnodes 1 \
evaluate_emotion.py \
--checkpoint $checkpoint \
--dataset $ds \
--batch-size 8 \
--num-workers 2
```
### VSC
- Data
> VocalSound: https://github.com/YuanGongND/vocalsound
```bash
mkdir -p data/vsc && cd data/vsc
# download dataset from the above link
# download converted files
wget https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/evaluation/vocalsound_eval.jsonl
cd ../..
```
- Evaluate
```bash
ds="vocalsound"
python -m torch.distributed.launch --use-env \
--nproc_per_node ${NPROC_PER_NODE:-8} --nnodes 1 \
evaluate_aqa.py \
--checkpoint $checkpoint \
--dataset $ds \
--batch-size 8 \
--num-workers 2
```
### AIR-BENCH
- Data
> AIR-BENCH: https://huggingface.co/datasets/qyang1021/AIR-Bench-Dataset
```bash
mkdir -p data/airbench && cd data/airbench
# download dataset from the above link
# download converted files
wget https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/evaluation/airbench_level_3_eval.jsonl
cd ../..
```
```bash
ds="airbench_level3"
python -m torch.distributed.launch --use-env \
--nproc_per_node ${NPROC_PER_NODE:-8} --nnodes 1 \
evaluate_chat.py \
--checkpoint $checkpoint \
--dataset $ds \
--batch-size 8 \
--num-workers 2
```
### Acknowledgement
Part of this code is borrowed from [Whisper](https://github.com/openai/whisper) and [speechio](https://github.com/speechio/chinese_text_normalization); thanks for their wonderful work.
import argparse
import itertools
import json
import os
import random
import time
from functools import partial
import re
from evaluate_tokenizer import EvaluationTokenizer
import editdistance as ed
import torch
from transformers.pipelines.audio_utils import ffmpeg_read
import requests
from whisper_normalizer.english import EnglishTextNormalizer
from whisper_normalizer.basic import BasicTextNormalizer
from cn_tn import TextNorm
import zhconv
english_normalizer = EnglishTextNormalizer()
chinese_normalizer = TextNorm(
to_banjiao = False,
to_upper = False,
to_lower = False,
remove_fillers = False,
remove_erhua =False,
check_chars = False,
remove_space = False,
cc_mode = '',
)
basic_normalizer = BasicTextNormalizer()
from tqdm import tqdm
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration
PUNCS = '!,.?;:'
ds_collections = {
'librispeech': {'path': 'asr/librispeech_eval.jsonl','language': 'en'},
'aishell2': {'path': 'asr/aishell2_eval.jsonl', 'language': 'zh'},
'cv15_en': {'path': 'asr/cv15_asr_en_eval.jsonl', 'language': 'en'},
'cv15_zh': {'path': 'asr/cv15_asr_zh_eval.jsonl', 'language': 'zh'},
'cv15_yue': {'path': 'asr/cv15_asr_yue_eval.jsonl', 'language': 'yue'},
'cv15_fr': {'path': 'asr/cv15_asr_fr_eval.jsonl', 'language': 'fr'},
'fluers_zh': {'path': 'asr/fleurs_asr_zh_eval.jsonl', 'language': 'zh'},
}
class AudioDataset(torch.utils.data.Dataset):
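"""JSON-lines eval set: each record provides an audio path, prompt, source tag, and ground-truth text."""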
def __init__(self, ds):
path = ds['path']
self.datas = open(path).readlines()
def __len__(self):
return len(self.datas)
def __getitem__(self, idx):
data = json.loads(self.datas[idx].strip())
audio = data['audio']
source = data['source']
prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>"+data['prompt']
gt = data['gt']
return {
'audio': audio,
'prompt': prompt,
'source': source,
'gt': gt
}
def read_audio(audio_path):
if audio_path.startswith("http://") or audio_path.startswith("https://"):
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
# like http_huggingface_co.png
inputs = requests.get(audio_path).content
else:
with open(audio_path, "rb") as f:
inputs = f.read()
return inputs
def collate_fn(inputs, processor):
input_texts = [_['prompt'] for _ in inputs]
source = [_['source'] for _ in inputs]
gt = [_['gt'] for _ in inputs]
audio_path = [_['audio'] for _ in inputs]
input_audios = [ffmpeg_read(read_audio(_['audio']),sampling_rate=processor.feature_extractor.sampling_rate) for _ in inputs]
inputs = processor(text=input_texts, audios=input_audios, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt", padding=True)
return inputs, audio_path, source, gt
class InferenceSampler(torch.utils.data.sampler.Sampler):
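"""Shards dataset indices contiguously across distributed ranks so each rank scores a disjoint slice."""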
def __init__(self, size):
self._size = int(size)
assert size > 0
self._rank = torch.distributed.get_rank()
self._world_size = torch.distributed.get_world_size()
self._local_indices = self._get_local_indices(size, self._world_size,
self._rank)
@staticmethod
def _get_local_indices(total_size, world_size, rank):
shard_size = total_size // world_size
left = total_size % world_size
shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
begin = sum(shard_sizes[:rank])
end = min(sum(shard_sizes[:rank + 1]), total_size)
return range(begin, end)
def __iter__(self):
yield from self._local_indices
def __len__(self):
return len(self._local_indices)
def remove_sp(text, language):
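"""Strip special tokens and normalize spacing around punctuation before scoring."""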
gt = re.sub(r"<\|.*?\|>", " ", text)
gt = re.sub(rf"\s+", r" ", gt) # 将文本中的连续空格替换为单个空格
gt = re.sub(f" ?([{PUNCS}])", r"\1", gt)
gt = gt.lstrip(" ")
if language == "zh":
gt = re.sub(rf"\s+", r"", gt)
return gt
def compute_wer(refs, hyps, language):
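"""Corpus-level WER (character-level for zh/yue) from edit distance over normalized, tokenized refs and hyps."""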
distance = 0
ref_length = 0
tokenizer = EvaluationTokenizer(
tokenizer_type="none",
lowercase=True,
punctuation_removal=True,
character_tokenization=False,
)
for i in range(len(refs)):
ref = refs[i]
pred = hyps[i]
if language in ["yue"]:
ref = zhconv.convert(ref, 'zh-cn')
pred = zhconv.convert(pred, 'zh-cn')
if language in ["en"]:
ref = english_normalizer(ref)
pred = english_normalizer(pred)
if language in ["zh"]:
ref = chinese_normalizer(ref)
pred = chinese_normalizer(pred)
else:
ref = basic_normalizer(ref)
pred = basic_normalizer(pred)
ref_items = tokenizer.tokenize(ref).split()
pred_items = tokenizer.tokenize(pred).split()
if language in ["zh", "yue"]:
ref_items = [x for x in "".join(ref_items)]
pred_items = [x for x in "".join(pred_items)]
if i==0:
print(f"ref: {ref}")
print(f"pred: {pred}")
print(f"ref_items:\n{ref_items}\n{len(ref_items)}\n{ref_items[0]}")
print(f"pred_items:\n{pred_items}\n{len(ref_items)}\n{ref_items[0]}")
distance += ed.eval(ref_items, pred_items)
ref_length += len(ref_items)
return distance/ref_length
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint', type=str, default='Qwen/Qwen2-Audio')
parser.add_argument('--dataset', type=str, default='')
parser.add_argument('--batch-size', type=int, default=1)
parser.add_argument('--num-workers', type=int, default=1)
parser.add_argument('--seed', type=int, default=0)
args = parser.parse_args()
torch.distributed.init_process_group(
backend='nccl',
world_size=int(os.getenv('WORLD_SIZE', '1')),
rank=int(os.getenv('RANK', '0')),
)
torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))
model = Qwen2AudioForConditionalGeneration.from_pretrained(
args.checkpoint, device_map='cuda', torch_dtype='auto', trust_remote_code=True).eval()
processor = AutoProcessor.from_pretrained(args.checkpoint)
processor.tokenizer.padding_side = 'left'
random.seed(args.seed)
dataset = AudioDataset(
ds=ds_collections[args.dataset],
)
data_loader = torch.utils.data.DataLoader(
dataset=dataset,
sampler=InferenceSampler(len(dataset)),
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=True,
drop_last=False,
collate_fn=partial(collate_fn, processor=processor),
)
gts = []
sources = []
rets = []
audio_paths = []
for _, (inputs, audio_path, source, gt) in tqdm(enumerate(data_loader)):
inputs['input_ids'] = inputs['input_ids'].to('cuda')
output_ids = model.generate(**inputs, max_new_tokens=256, min_new_tokens=1, do_sample=False)
output_ids = output_ids[:, inputs.input_ids.size(1):]
output = processor.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
gts.extend(gt)
rets.extend(output)
sources.extend(source)
audio_paths.extend(audio_path)
torch.distributed.barrier()
world_size = torch.distributed.get_world_size()
merged_gts = [None for _ in range(world_size)]
merged_sources = [None for _ in range(world_size)]
merged_responses = [None for _ in range(world_size)]
merged_audio_paths = [None for _ in range(world_size)]
torch.distributed.all_gather_object(merged_gts, gts)
torch.distributed.all_gather_object(merged_sources, sources)
torch.distributed.all_gather_object(merged_responses, rets)
torch.distributed.all_gather_object(merged_audio_paths, audio_paths)
merged_gts = [_ for _ in itertools.chain.from_iterable(merged_gts)]
merged_sources = [_ for _ in itertools.chain.from_iterable(merged_sources)]
merged_audio_paths = [_ for _ in itertools.chain.from_iterable(merged_audio_paths)]
merged_responses = [
_ for _ in itertools.chain.from_iterable(merged_responses)
]
if torch.distributed.get_rank() == 0:
print(f"Evaluating {args.dataset} ...")
results = []
for gt, response, source, audio_path in zip(merged_gts, merged_responses, merged_sources, merged_audio_paths):
results.append({
'gt': gt,
'response': response,
'source': source,
'audio_path': audio_path,
})
time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
results_file = f'{args.dataset}_{time_prefix}.json'
json.dump(results, open(results_file, 'w'))
results_dict = {}
for item in tqdm(results):
source = item["source"]
results_dict.setdefault(source, []).append(item)
lan = ds_collections[args.dataset]['language']
for source in results_dict:
refs, hyps = [], []
results_list = results_dict[source]
for result in results_list:
gt = result["gt"]
response = result["response"]
gt = remove_sp(gt, lan)
response = remove_sp(response, lan)
refs.append(gt)
hyps.append(response)
wer = compute_wer(refs, hyps, lan)
print(f"source: {source} cnt: {len(refs)} wer: {wer:.4f}")
torch.distributed.barrier()
import argparse
import itertools
import json
import os
import random
import time
from functools import partial
import torch
import requests
from tqdm import tqdm
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration
from transformers.pipelines.audio_utils import ffmpeg_read
ds_collections = {
'airbench_level3': {'path': 'chat/airbench-level-3.jsonl'}
}
class AudioChatDataset(torch.utils.data.Dataset):
def __init__(self, ds):
path = ds['path']
self.datas = open(path).readlines()
def __len__(self):
return len(self.datas)
def __getitem__(self, idx):
data = json.loads(self.datas[idx].strip())
audio = data['audio']
data_idx = data['id']
query = data['query']
return {
'audio': audio,
'data_idx': data_idx,
'query': query,
}
def read_audio(audio_path):
if audio_path.startswith("http://") or audio_path.startswith("https://"):
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
# like http_huggingface_co.png
inputs = requests.get(audio_path).content
else:
with open(audio_path, "rb") as f:
inputs = f.read()
return inputs
def collate_fn(inputs, processor):
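# Wrap each query in the chat template and decode the paired audio clips for the processor.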
text_list = []
for _ in inputs:
query = _['query']
conversation = [
{'role': 'system', 'content': 'You are a helpful assistant.'},
{'role': 'user', 'content': query}
]
text = processor.tokenizer.apply_chat_template(
conversation,
add_generation_prompt=True,
return_tensors='pt',
tokenize=False
)
text_list.append(text)
audio_path = [_['audio'] for _ in inputs]
data_idxs = [_['data_idx'] for _ in inputs]
input_audios = [ffmpeg_read(read_audio(_['audio']), sampling_rate=processor.feature_extractor.sampling_rate) for _ in inputs]
inputs = processor(text=text_list, audios=input_audios, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt", padding=True)
return inputs, audio_path, data_idxs
class InferenceSampler(torch.utils.data.sampler.Sampler):
def __init__(self, size):
self._size = int(size)
assert size > 0
self._rank = torch.distributed.get_rank()
self._world_size = torch.distributed.get_world_size()
self._local_indices = self._get_local_indices(size, self._world_size,
self._rank)
@staticmethod
def _get_local_indices(total_size, world_size, rank):
shard_size = total_size // world_size
left = total_size % world_size
shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
begin = sum(shard_sizes[:rank])
end = min(sum(shard_sizes[:rank + 1]), total_size)
return range(begin, end)
def __iter__(self):
yield from self._local_indices
def __len__(self):
return len(self._local_indices)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint', type=str, default='Qwen/Qwen2-Audio-7B-Instruct')
parser.add_argument('--dataset', type=str, default='')
parser.add_argument('--batch-size', type=int, default=1)
parser.add_argument('--num-workers', type=int, default=1)
parser.add_argument('--seed', type=int, default=0)
args = parser.parse_args()
torch.distributed.init_process_group(
backend='nccl',
world_size=int(os.getenv('WORLD_SIZE', '1')),
rank=int(os.getenv('RANK', '0')),
)
torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))
model = Qwen2AudioForConditionalGeneration.from_pretrained(
args.checkpoint, device_map='cuda', torch_dtype='auto', trust_remote_code=True).eval()
processor = AutoProcessor.from_pretrained(args.checkpoint)
processor.tokenizer.padding_side = 'left'
random.seed(args.seed)
dataset = AudioChatDataset(
ds=ds_collections[args.dataset],
)
data_loader = torch.utils.data.DataLoader(
dataset=dataset,
sampler=InferenceSampler(len(dataset)),
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=True,
drop_last=False,
collate_fn=partial(collate_fn, processor=processor),
)
idxs = []
rets = []
audio_paths = []
for _, (inputs, audio_path, data_idxs) in tqdm(enumerate(data_loader)):
inputs['input_ids'] = inputs['input_ids'].to('cuda')
output_ids = model.generate(**inputs, max_new_tokens=256, min_new_tokens=1,do_sample=False)
output_ids = output_ids[:, inputs.input_ids.size(1):]
output = processor.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
rets.extend(output)
audio_paths.extend(audio_path)
idxs.extend(data_idxs)
torch.distributed.barrier()
world_size = torch.distributed.get_world_size()
merged_idxs = [None for _ in range(world_size)]
merged_responses = [None for _ in range(world_size)]
merged_audio_paths = [None for _ in range(world_size)]
torch.distributed.all_gather_object(merged_idxs, idxs)
torch.distributed.all_gather_object(merged_responses, rets)
torch.distributed.all_gather_object(merged_audio_paths, audio_paths)
merged_idxs = [_ for _ in itertools.chain.from_iterable(merged_idxs)]
merged_audio_paths = [_ for _ in itertools.chain.from_iterable(merged_audio_paths)]
merged_responses = [
_ for _ in itertools.chain.from_iterable(merged_responses)
]
if torch.distributed.get_rank() == 0:
print(f"Evaluating {args.dataset} ...")
results = []
for idx, response, audio_path in zip(merged_idxs, merged_responses, merged_audio_paths):
results.append({
'idx': idx,
'response': response,
'audio_path': audio_path,
})
time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
results_file = f'{args.dataset}_{time_prefix}.json'
json.dump(results, open(results_file, 'w'))
torch.distributed.barrier()
import argparse
import itertools
import json
import os
import random
import time
from functools import partial
import torch
import requests
from tqdm import tqdm
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration
from transformers.pipelines.audio_utils import ffmpeg_read
from sklearn.metrics import accuracy_score
ds_collections = {
'meld': {'path': 'ser/meld_eval.jsonl'}
}
class AudioDataset(torch.utils.data.Dataset):
def __init__(self, ds):
path = ds['path']
self.datas = open(path).readlines()
def __len__(self):
return len(self.datas)
def __getitem__(self, idx):
data = json.loads(self.datas[idx].strip())
audio = data['audio']
source = data['source']
prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>"+data['prompt']
gt = data['gt']
return {
'audio': audio,
'prompt': prompt,
'source': source,
'gt': gt
}
def read_audio(audio_path):
if audio_path.startswith("http://") or audio_path.startswith("https://"):
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
# like http_huggingface_co.png
inputs = requests.get(audio_path).content
else:
with open(audio_path, "rb") as f:
inputs = f.read()
return inputs
def collate_fn(inputs, processor):
input_texts = [_['prompt'] for _ in inputs]
source = [_['source'] for _ in inputs]
gt = [_['gt'] for _ in inputs]
audio_path = [_['audio'] for _ in inputs]
input_audios = [ffmpeg_read(read_audio(_['audio']), sampling_rate=processor.feature_extractor.sampling_rate) for _ in inputs]
inputs = processor(text=input_texts, audios=input_audios, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt", padding=True)
return inputs, audio_path, source, gt
class InferenceSampler(torch.utils.data.sampler.Sampler):
def __init__(self, size):
self._size = int(size)
assert size > 0
self._rank = torch.distributed.get_rank()
self._world_size = torch.distributed.get_world_size()
self._local_indices = self._get_local_indices(size, self._world_size,
self._rank)
@staticmethod
def _get_local_indices(total_size, world_size, rank):
shard_size = total_size // world_size
left = total_size % world_size
shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
begin = sum(shard_sizes[:rank])
end = min(sum(shard_sizes[:rank + 1]), total_size)
return range(begin, end)
def __iter__(self):
yield from self._local_indices
def __len__(self):
return len(self._local_indices)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint', type=str, default='Qwen/Qwen2-Audio-7B')
parser.add_argument('--dataset', type=str, default='')
parser.add_argument('--batch-size', type=int, default=1)
parser.add_argument('--num-workers', type=int, default=1)
parser.add_argument('--seed', type=int, default=0)
args = parser.parse_args()
torch.distributed.init_process_group(
backend='nccl',
world_size=int(os.getenv('WORLD_SIZE', '1')),
rank=int(os.getenv('RANK', '0')),
)
torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))
model = Qwen2AudioForConditionalGeneration.from_pretrained(
args.checkpoint, device_map='cuda', trust_remote_code=True, torch_dtype='auto').eval()
processor = AutoProcessor.from_pretrained(args.checkpoint)
processor.tokenizer.padding_side = 'left'
random.seed(args.seed)
dataset = AudioDataset(
ds=ds_collections[args.dataset],
)
data_loader = torch.utils.data.DataLoader(
dataset=dataset,
sampler=InferenceSampler(len(dataset)),
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=True,
drop_last=False,
collate_fn=partial(collate_fn, processor=processor),
)
gts = []
sources = []
rets = []
audio_paths = []
for _, (inputs, audio_path, source, gt) in tqdm(enumerate(data_loader)):
inputs['input_ids'] = inputs['input_ids'].to('cuda')
output_ids = model.generate(**inputs, max_new_tokens=256, min_new_tokens=1, do_sample=False)
output_ids = output_ids[:, inputs.input_ids.size(1):]
output = processor.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
gts.extend(gt)
rets.extend(output)
sources.extend(source)
audio_paths.extend(audio_path)
torch.distributed.barrier()
world_size = torch.distributed.get_world_size()
merged_gts = [None for _ in range(world_size)]
merged_sources = [None for _ in range(world_size)]
merged_responses = [None for _ in range(world_size)]
merged_audio_paths = [None for _ in range(world_size)]
torch.distributed.all_gather_object(merged_gts, gts)
torch.distributed.all_gather_object(merged_sources, sources)
torch.distributed.all_gather_object(merged_responses, rets)
torch.distributed.all_gather_object(merged_audio_paths, audio_paths)
merged_gts = [_ for _ in itertools.chain.from_iterable(merged_gts)]
merged_sources = [_ for _ in itertools.chain.from_iterable(merged_sources)]
merged_audio_paths = [_ for _ in itertools.chain.from_iterable(merged_audio_paths)]
merged_responses = [
_ for _ in itertools.chain.from_iterable(merged_responses)
]
if torch.distributed.get_rank() == 0:
print(f"Evaluating {args.dataset} ...")
results = []
for gt, response, source, audio_path in zip(merged_gts, merged_responses, merged_sources, merged_audio_paths):
results.append({
'gt': gt,
'response': response,
'source': source,
'audio_path': audio_path,
})
time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
results_file = f'{args.dataset}_{time_prefix}.json'
json.dump(results, open(results_file, 'w'))
results_dict = {}
for item in tqdm(results):
source = item["source"]
results_dict.setdefault(source, []).append(item)
for source in results_dict:
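# Exact-match accuracy per source split.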
refs, hyps = [], []
bi_refs, bi_hyps = [], []
results_list = results_dict[source]
for result in results_list:
gt = result["gt"]
response = result["response"].lstrip()
refs.append(gt)
hyps.append(response)
score = accuracy_score(refs, hyps)
print(f"{source} ACC_score:", score, len(hyps))
torch.distributed.barrier()
import argparse
import itertools
import json
import os
import random
import time
from functools import partial
import sacrebleu
import torch
import requests
from tqdm import tqdm
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration
from transformers.pipelines.audio_utils import ffmpeg_read
ds_collections = {
'covost2': {'path': 'st/covost2_eval.jsonl'}
}
class AudioDataset(torch.utils.data.Dataset):
def __init__(self, ds):
path = ds['path']
self.datas = open(path).readlines()
def __len__(self):
return len(self.datas)
def __getitem__(self, idx):
data = json.loads(self.datas[idx].strip())
audio = data['audio']
source = data['source']
prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>"+data['prompt']
gt = data['gt']
return {
'audio': audio,
'prompt': prompt,
'source': source,
'gt': gt
}
def read_audio(audio_path):
if audio_path.startswith("http://") or audio_path.startswith("https://"):
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
# like http_huggingface_co.png
inputs = requests.get(audio_path).content
else:
with open(audio_path, "rb") as f:
inputs = f.read()
return inputs
def collate_fn(inputs, processor):
input_texts = [_['prompt'] for _ in inputs]
source = [_['source'] for _ in inputs]
gt = [_['gt'] for _ in inputs]
audio_path = [_['audio'] for _ in inputs]
input_audios = [ffmpeg_read(read_audio(_['audio']),sampling_rate=processor.feature_extractor.sampling_rate) for _ in inputs]
inputs = processor(text=input_texts, audios=input_audios, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt", padding=True)
return inputs, audio_path, source, gt
class InferenceSampler(torch.utils.data.sampler.Sampler):
def __init__(self, size):
self._size = int(size)
assert size > 0
self._rank = torch.distributed.get_rank()
self._world_size = torch.distributed.get_world_size()
self._local_indices = self._get_local_indices(size, self._world_size,
self._rank)
@staticmethod
def _get_local_indices(total_size, world_size, rank):
shard_size = total_size // world_size
left = total_size % world_size
shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
begin = sum(shard_sizes[:rank])
end = min(sum(shard_sizes[:rank + 1]), total_size)
return range(begin, end)
def __iter__(self):
yield from self._local_indices
def __len__(self):
return len(self._local_indices)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint', type=str, default='Qwen/Qwen2-Audio-7B')
parser.add_argument('--dataset', type=str, default='')
parser.add_argument('--batch-size', type=int, default=1)
parser.add_argument('--num-workers', type=int, default=1)
parser.add_argument('--seed', type=int, default=0)
args = parser.parse_args()
torch.distributed.init_process_group(
backend='nccl',
world_size=int(os.getenv('WORLD_SIZE', '1')),
rank=int(os.getenv('RANK', '0')),
)
torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))
model = Qwen2AudioForConditionalGeneration.from_pretrained(
args.checkpoint, device_map='cuda', trust_remote_code=True, torch_dtype='auto').eval()
processor = AutoProcessor.from_pretrained(args.checkpoint)
processor.tokenizer.padding_side = 'left'
random.seed(args.seed)
dataset = AudioDataset(
ds=ds_collections[args.dataset],
)
data_loader = torch.utils.data.DataLoader(
dataset=dataset,
sampler=InferenceSampler(len(dataset)),
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=True,
drop_last=False,
collate_fn=partial(collate_fn, processor=processor),
)
gts = []
sources = []
rets = []
audio_paths = []
for _, (inputs, audio_path, source, gt) in tqdm(enumerate(data_loader)):
inputs['input_ids'] = inputs['input_ids'].to('cuda')
output_ids = model.generate(**inputs, max_new_tokens=256, min_new_tokens=1, do_sample=False)
output_ids = output_ids[:, inputs.input_ids.size(1):]
output = processor.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
gts.extend(gt)
rets.extend(output)
sources.extend(source)
audio_paths.extend(audio_path)
torch.distributed.barrier()
world_size = torch.distributed.get_world_size()
merged_gts = [None for _ in range(world_size)]
merged_sources = [None for _ in range(world_size)]
merged_responses = [None for _ in range(world_size)]
merged_audio_paths = [None for _ in range(world_size)]
torch.distributed.all_gather_object(merged_gts, gts)
torch.distributed.all_gather_object(merged_sources, sources)
torch.distributed.all_gather_object(merged_responses, rets)
torch.distributed.all_gather_object(merged_audio_paths, audio_paths)
merged_gts = [_ for _ in itertools.chain.from_iterable(merged_gts)]
merged_sources = [_ for _ in itertools.chain.from_iterable(merged_sources)]
merged_audio_paths = [_ for _ in itertools.chain.from_iterable(merged_audio_paths)]
merged_responses = [
_ for _ in itertools.chain.from_iterable(merged_responses)
]
if torch.distributed.get_rank() == 0:
print(f"Evaluating {args.dataset} ...")
results = []
for gt, response, source, audio_path in zip(merged_gts, merged_responses, merged_sources, merged_audio_paths):
results.append({
'gt': gt,
'response': response,
'source': source,
'audio_path': audio_path,
})
time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
results_file = f'{args.dataset}_{time_prefix}.json'
json.dump(results, open(results_file, 'w'))
results_dict = {}
for item in tqdm(results):
source = item["source"]
results_dict.setdefault(source, []).append(item)
for source in results_dict:
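# Pick the sacreBLEU tokenizer from the target-language suffix embedded in the source name.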
text_lan = source.split("_")[-2]
if text_lan == "ja":
text_lan = "ja-mecab"
elif text_lan == "zh":
text_lan = "zh"
else:
text_lan = "13a"
refs, hyps = [], []
results_list = results_dict[source]
for result in results_list:
gt = result["gt"]
response = result["response"]
refs.append(gt)
hyps.append(response)
bleu = sacrebleu.corpus_bleu(hyps,[refs], tokenize=text_lan).score
print(f"source: {source} cnt: {len(refs)} bleu score: {bleu:.4f}")
torch.distributed.barrier()
# Copyright 2022 The OFA-Sys Team. All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.
import unicodedata
class EvaluationTokenizer(object):
"""A generic evaluation-time tokenizer, which leverages built-in tokenizers
in sacreBLEU (https://github.com/mjpost/sacrebleu). It additionally provides
lowercasing, punctuation removal and character tokenization, which are
applied after sacreBLEU tokenization.
Args:
tokenizer_type (str): the type of sacreBLEU tokenizer to apply.
lowercase (bool): lowercase the text.
punctuation_removal (bool): remove punctuation (based on unicode
category) from text.
character_tokenization (bool): tokenize the text to characters.
"""
SPACE = chr(32)
SPACE_ESCAPE = chr(9601)
# ALL_TOKENIZER_TYPES = ChoiceEnum(["none", "13a", "intl", "zh", "ja-mecab"])
def __init__(
self,
tokenizer_type: str = "13a",
lowercase: bool = False,
punctuation_removal: bool = False,
character_tokenization: bool = False,
):
from sacrebleu.tokenizers import TOKENIZERS
assert tokenizer_type in TOKENIZERS, f"{tokenizer_type}, {TOKENIZERS}"
self.lowercase = lowercase
self.punctuation_removal = punctuation_removal
self.character_tokenization = character_tokenization
self.tokenizer = TOKENIZERS[tokenizer_type]
@classmethod
def remove_punctuation(cls, sent: str):
"""Remove punctuation based on Unicode category."""
return cls.SPACE.join(
t for t in sent.split(cls.SPACE) if not all(unicodedata.category(c)[0] == "P" for c in t)
)
def tokenize(self, sent: str):
tokenized = self.tokenizer()(sent)
if self.punctuation_removal:
tokenized = self.remove_punctuation(tokenized)
if self.character_tokenization:
tokenized = self.SPACE.join(list(tokenized.replace(self.SPACE, self.SPACE_ESCAPE)))
if self.lowercase:
tokenized = tokenized.lower()
return tokenized
import argparse
import itertools
import json
import os
import random
import time
from functools import partial
import torch
import requests
from tqdm import tqdm
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration
from transformers.pipelines.audio_utils import ffmpeg_read
from sklearn.metrics import accuracy_score
ds_collections = {
'vocalsound': {'path': 'vsc/vocalsound_eval.jsonl'}
}
class AudioDataset(torch.utils.data.Dataset):
def __init__(self, ds):
path = ds['path']
self.datas = open(path).readlines()
def __len__(self):
return len(self.datas)
def __getitem__(self, idx):
data = json.loads(self.datas[idx].strip())
audio = data['audio']
source = data['source']
prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>"+data['prompt']
gt = data['gt']
return {
'audio': audio,
'prompt': prompt,
'source': source,
'gt': gt
}
def read_audio(audio_path):
if audio_path.startswith("http://") or audio_path.startswith("https://"):
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
# like http_huggingface_co.png
inputs = requests.get(audio_path).content
else:
with open(audio_path, "rb") as f:
inputs = f.read()
return inputs
def collate_fn(inputs, processor):
input_texts = [_['prompt'] for _ in inputs]
source = [_['source'] for _ in inputs]
gt = [_['gt'] for _ in inputs]
audio_path = [_['audio'] for _ in inputs]
input_audios = [ffmpeg_read(read_audio(_['audio']),sampling_rate=processor.feature_extractor.sampling_rate) for _ in inputs]
inputs = processor(text=input_texts, audios=input_audios, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt", padding=True)
return inputs, audio_path, source, gt
class InferenceSampler(torch.utils.data.sampler.Sampler):
def __init__(self, size):
self._size = int(size)
assert size > 0
self._rank = torch.distributed.get_rank()
self._world_size = torch.distributed.get_world_size()
self._local_indices = self._get_local_indices(size, self._world_size,
self._rank)
@staticmethod
def _get_local_indices(total_size, world_size, rank):
shard_size = total_size // world_size
left = total_size % world_size
shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
begin = sum(shard_sizes[:rank])
end = min(sum(shard_sizes[:rank + 1]), total_size)
return range(begin, end)
def __iter__(self):
yield from self._local_indices
def __len__(self):
return len(self._local_indices)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint', type=str, default='Qwen/Qwen2-Audio-7B')
parser.add_argument('--dataset', type=str, default='')
parser.add_argument('--batch-size', type=int, default=1)
parser.add_argument('--num-workers', type=int, default=1)
parser.add_argument('--seed', type=int, default=0)
args = parser.parse_args()
torch.distributed.init_process_group(
backend='nccl',
world_size=int(os.getenv('WORLD_SIZE', '1')),
rank=int(os.getenv('RANK', '0')),
)
torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))
model = Qwen2AudioForConditionalGeneration.from_pretrained(
args.checkpoint, device_map='cuda', trust_remote_code=True, torch_dtype='auto').eval()
processor = AutoProcessor.from_pretrained(args.checkpoint)
processor.tokenizer.padding_side = 'left'
random.seed(args.seed)
dataset = AudioDataset(
ds=ds_collections[args.dataset],
)
data_loader = torch.utils.data.DataLoader(
dataset=dataset,
sampler=InferenceSampler(len(dataset)),
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=True,
drop_last=False,
collate_fn=partial(collate_fn, processor=processor),
)
gts = []
sources = []
rets = []
audio_paths = []
for _, (inputs, audio_path, source, gt) in tqdm(enumerate(data_loader)):
inputs['input_ids'] = inputs['input_ids'].to('cuda')
output_ids = model.generate(**inputs, max_new_tokens=256, min_new_tokens=1, do_sample=False)
output_ids = output_ids[:, inputs.input_ids.size(1):]
output = processor.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
gts.extend(gt)
rets.extend(output)
sources.extend(source)
audio_paths.extend(audio_path)
torch.distributed.barrier()
world_size = torch.distributed.get_world_size()
merged_gts = [None for _ in range(world_size)]
merged_sources = [None for _ in range(world_size)]
merged_responses = [None for _ in range(world_size)]
merged_audio_paths = [None for _ in range(world_size)]
torch.distributed.all_gather_object(merged_gts, gts)
torch.distributed.all_gather_object(merged_sources, sources)
torch.distributed.all_gather_object(merged_responses, rets)
torch.distributed.all_gather_object(merged_audio_paths, audio_paths)
merged_gts = [_ for _ in itertools.chain.from_iterable(merged_gts)]
merged_sources = [_ for _ in itertools.chain.from_iterable(merged_sources)]
merged_audio_paths = [_ for _ in itertools.chain.from_iterable(merged_audio_paths)]
merged_responses = [
_ for _ in itertools.chain.from_iterable(merged_responses)
]
if torch.distributed.get_rank() == 0:
print(f"Evaluating {args.dataset} ...")
results = []
for gt, response, source, audio_path in zip(merged_gts, merged_responses, merged_sources, merged_audio_paths):
results.append({
'gt': gt,
'response': response,
'source': source,
'audio_path': audio_path,
})
time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
results_file = f'{args.dataset}_{time_prefix}.json'
json.dump(results, open(results_file, 'w'))
results_dict = {}
for item in tqdm(results):
source = item["source"]
results_dict.setdefault(source, []).append(item)
for source in results_dict:
refs, hyps = [], []
bi_refs, bi_hyps = [], []
results_list = results_dict[source]
for result in results_list:
gt = result["gt"]
response = result["response"].lstrip()
refs.append(gt)
hyps.append(response)
score = accuracy_score(refs, hyps)
print(f"{source} ACC_score:", score, len(hyps))
torch.distributed.barrier()
import re
import unicodedata
import regex
# non-ASCII letters that are not separated by "NFKD" normalization
ADDITIONAL_DIACRITICS = {
"œ": "oe",
"Œ": "OE",
"ø": "o",
"Ø": "O",
"æ": "ae",
"Æ": "AE",
"ß": "ss",
"ẞ": "SS",
"đ": "d",
"Đ": "D",
"ð": "d",
"Ð": "D",
"þ": "th",
"Þ": "th",
"ł": "l",
"Ł": "L",
}
def remove_symbols_and_diacritics(s: str, keep=""):
"""
Replace any other markers, symbols, and punctuations with a space,
and drop any diacritics (category 'Mn' and some manual mappings)
"""
return "".join(
c
if c in keep
else ADDITIONAL_DIACRITICS[c]
if c in ADDITIONAL_DIACRITICS
else ""
if unicodedata.category(c) == "Mn"
else " "
if unicodedata.category(c)[0] in "MSP"
else c
for c in unicodedata.normalize("NFKD", s)
)
def remove_symbols(s: str):
"""
Replace any other markers, symbols, punctuations with a space, keeping diacritics
"""
return "".join(
" " if unicodedata.category(c)[0] in "MSP" else c
for c in unicodedata.normalize("NFKC", s)
)
class BasicTextNormalizer:
def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
self.clean = (
remove_symbols_and_diacritics if remove_diacritics else remove_symbols
)
self.split_letters = split_letters
def __call__(self, s: str):
s = s.lower()
s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
s = self.clean(s).lower()
if self.split_letters:
s = " ".join(regex.findall(r"\X", s, regex.U))
s = re.sub(
r"\s+", " ", s
) # replace any successive whitespace characters with a space
return s