import os
import types
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional
from transformers import AutoModelForCausalLM, AutoTokenizer, T5ForConditionalGeneration
from peft import PeftModel, get_peft_model
from mooer.utils.utils import print_module_size, compute_accuracy
from mooer.utils.config_utils import generate_peft_config
from mooer.models.encoder import WhisperWrappedEncoder, HubertEncoder, W2vBert2Encoder, ParaformerEncoder
from mooer.models.adapter import LinearAdapter
import logging
logger = logging.getLogger(__name__)
def init_model(model_config, train_config=None, peft_config=None):
tokenizer = setup_tokenizer(model_config)
encoder = setup_encoder(model_config, train_config)
adapter = setup_adapter(model_config, train_config)
llm = setup_llm(model_config, train_config, peft_config)
model = MooerModel(
encoder,
llm,
adapter,
tokenizer,
model_config,
train_config
)
# load adapter
ckpt_path = model_config.get("adapter_path", "")
if os.path.isdir(ckpt_path):
logger.info("CKPT: loading DeepSpeed Model from: {}".format(ckpt_path))
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
ckpt_dict = get_fp32_state_dict_from_zero_checkpoint(ckpt_path)
logger.info("Merging ZeRO-3 checkpoint to FP32...")
# report checkpoint keys that are absent from the model; load_state_dict(strict=False) below will skip them
model_state_dict = model.state_dict()
missing_keys = [k for k in ckpt_dict.keys() if k not in model_state_dict]
for key in missing_keys:
logger.info(f"MISSING KEY: {key}")
model.load_state_dict(ckpt_dict, strict=False)
if model_config.get('save_lora_weights', False):
logger.info("Saving LoRA weights...")
model.llm.save_pretrained(os.path.join(ckpt_path, 'new_llm'))
logger.info("Save finished...")
exit()
elif os.path.exists(ckpt_path):
logger.info("CKPT: loading other parts from: {}".format(ckpt_path))
ckpt_dict = torch.load(ckpt_path, map_location="cpu")
model_state_dict = model.state_dict()
missing_keys = [k for k in ckpt_dict.keys() if k not in model_state_dict]
for key in missing_keys:
logger.info(f"MISSING KEY: {key}")
model.load_state_dict(ckpt_dict, strict=False)
return model, tokenizer
def setup_tokenizer(model_config):
if "qwen" in model_config.llm_name.lower():
tokenizer = AutoTokenizer.from_pretrained(
model_config.llm_path,
padding_side="right",
use_fast=False,
)
else:
tokenizer = AutoTokenizer.from_pretrained(model_config.llm_path)
tokenizer.pad_token_id = tokenizer.eos_token_id
return tokenizer
def setup_encoder(model_config, train_config=None):
encoder_name = model_config.encoder_name
if encoder_name == "whisper":
encoder = WhisperWrappedEncoder.load(model_config)
elif encoder_name == "hubert":
encoder = HubertEncoder.load(model_config)
elif encoder_name == "w2v_bert2.0":
encoder = W2vBert2Encoder.load(model_config)
elif encoder_name == "paraformer":
encoder = ParaformerEncoder.load(model_config)
else:
raise KeyError(f"unsupported encoder: {encoder_name}")
print_module_size(encoder, encoder_name, 0, "====Total Params====")
if train_config is None or train_config.freeze_encoder:
for name, param in encoder.named_parameters():
param.requires_grad = False
encoder.eval()
print_module_size(encoder, encoder_name, 0, "====Trainable Params====")
return encoder
def setup_llm(model_config, train_config=None, peft_config=None):
if model_config.load_dtype == "float16":
load_dtype = torch.float16
elif model_config.load_dtype == "bfloat16":
load_dtype = torch.bfloat16
else:
load_dtype = torch.float32
if "qwen" in model_config.llm_name.lower():
model = AutoModelForCausalLM.from_pretrained(
model_config.llm_path,
use_cache=None,
torch_dtype=load_dtype,
)
else:
# load your own LLM
model = AutoModelForCausalLM.from_pretrained(
model_config.llm_path,
use_cache=None,
torch_dtype=load_dtype,
)
print_module_size(model, model_config.llm_name, 0, "====Total Params====")
if train_config is None or train_config.freeze_llm:
for name, param in model.named_parameters():
param.requires_grad = False
model.eval()
if model_config.get("lora_dir", None) and os.path.exists(model_config.get("lora_dir", "")):
if model_config.is_inference:
logger.info("Inference load lora...")
logger.info("loading lora_dir from: {}".format(model_config.get("lora_dir")))
model = PeftModel.from_pretrained(model=model, model_id=model_config.get("lora_dir"), is_trainable=False)
logger.info("Start Merging LLM and Adaptor...")
model = model.merge_and_unload()
model.eval()
logger.info("Finish Merging LLM and Adaptor...")
else:
# continuous training
logger.info("Training load lora...")
logger.info("loading lora_dir from: {}".format(model_config.get("lora_dir")))
model = PeftModel.from_pretrained(model=model, model_id=model_config.get("lora_dir"), is_trainable=True)
logger.info("Start Merging LLM and Adaptor...")
model = model.merge_and_unload()
logger.info("Finish Merging LLM and Adaptor...")
elif train_config is not None and train_config.use_peft:
assert peft_config is not None
logger.info("setup peft...")
peft_config = generate_peft_config(peft_config)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
print_module_size(model, model_config.llm_name, 0, "====Trainable Params====")
return model
def setup_adapter(model_config, train_config=None):
if model_config.adapter == "linear":
adapter = LinearAdapter(model_config)
else:
raise KeyError(f"unsupported adapter: {model_config.adapter}")
print_module_size(adapter, model_config.adapter, 0, "====Total Params====")
if train_config is None or train_config.freeze_projector:
for name, param in adapter.named_parameters():
param.requires_grad = False
adapter.eval()
print_module_size(adapter, model_config.adapter, 0, "====Trainable Params====")
return adapter
class MooerModel(nn.Module):
def __init__(
self,
encoder: nn.Module,
llm: nn.Module,
adapter: nn.Module,
tokenizer,
model_config,
train_config=None
):
super().__init__()
self.encoder = encoder
self.llm = llm
self.encoder_projector = adapter
self.tokenizer = tokenizer
self.model_config = model_config
self.train_config = train_config
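# When DeepSpeed training is enabled, patch every LayerNorm to compute in fp32 and cast the result back, avoiding precision issues under fp16/bf16.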
if self.train_config is not None and train_config.get("enable_deepspeed", False):
def new_forward(self, input):
output = F.layer_norm(
input.float(),
self.normalized_shape,
self.weight.float() if self.weight is not None else None,
self.bias.float() if self.bias is not None else None,
self.eps,
)
return output.type_as(input)
for item in self.modules():
if isinstance(item, nn.LayerNorm):
item.forward = types.MethodType(new_forward, item)
def forward(self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
**kwargs,
):
audio_mel = kwargs.get("audio_mel", None)
audio_mel_post_mask = kwargs.get("audio_mel_post_mask", None)
audio = kwargs.get("audio", None)
audio_mask = kwargs.get("audio_mask", None)
visual = kwargs.get("visual", None)
modality_mask = kwargs.get("modality_mask", None)
# for paraformer
audio_mel_reallen = kwargs.get("audio_mel_reallen", None)
gradient_checkpoint = self.model_config.get("gradient_checkpoint", False)
encoder_outs = None
if audio_mel is not None or audio is not None or visual is not None:
self.encoder.eval()
if self.model_config.encoder_name == "whisper":
encoder_outs = self.encoder.extract_variable_length_features(
audio_mel.permute(0, 2, 1), gradient_checkpoint=gradient_checkpoint) # bs*seq*dim
if self.model_config.encoder_name == "hubert":
results = self.encoder(source=audio, padding_mask=1 - audio_mask)
if self.model_config.encoder_type == "pretrain":
encoder_outs, audio_mel_post_mask = results["x"], results["padding_mask"]
if self.model_config.encoder_type == "finetune":
encoder_outs, audio_mel_post_mask = results["encoder_out"], results["padding_mask"]
encoder_outs = encoder_outs.transpose(0, 1)
if self.model_config.encoder_name == 'w2v_bert2.0':
encoder_outs = self.encoder.extract_features(source=audio_mel, attention_mask=audio_mel_post_mask)
if self.model_config.encoder_name == 'paraformer':
encoder_outs = self.encoder.extract_features(source=audio_mel, reallen=audio_mel_reallen)
if self.encoder is None:
encoder_outs = audio_mel if audio_mel is not None else audio
if self.model_config.adapter == "linear":
encoder_outs = self.encoder_projector(encoder_outs, gradient_checkpoint=gradient_checkpoint)
if input_ids is not None:
input_ids[input_ids == -1] = 0
if isinstance(self.llm, T5ForConditionalGeneration):
inputs_embeds = self.llm.shared(input_ids)
else:
if hasattr(self.llm.model, "embed_tokens"):
inputs_embeds = self.llm.model.embed_tokens(input_ids)
elif hasattr(self.llm.model.model, "embed_tokens"):
inputs_embeds = self.llm.model.model.embed_tokens(input_ids)
else:
inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids)
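# Overwrite the audio placeholder span of the text embeddings with the projected encoder features, using modality_mask to locate it.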
if modality_mask is not None:
modality_mask_start_indices = (modality_mask == True).float().argmax(dim=1)
modality_lengths = torch.clamp(modality_mask.sum(dim=1), max=encoder_outs.shape[1]).tolist()
encoder_outs_pad = torch.zeros_like(inputs_embeds)
for i in range(encoder_outs.shape[0]):
encoder_outs_pad[
i, modality_mask_start_indices[i]:modality_mask_start_indices[i] + modality_lengths[i]
] = encoder_outs[i][:modality_lengths[i]]
inputs_embeds = encoder_outs_pad + inputs_embeds * (~modality_mask[:, :, None])
if self.model_config.get("is_inference", False):
return inputs_embeds, attention_mask
else:
model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)
with torch.no_grad():
preds = torch.argmax(model_outputs.logits, -1)
acc = compute_accuracy(preds.detach()[:, :-1], labels.detach()[:, 1:], ignore_label=-100)
return model_outputs, acc
@torch.no_grad()
def generate(self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
compute_llm=True,
**kwargs,
):
kwargs["inference_mode"] = True
inputs_embeds, attention_mask = self.forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)
if not compute_llm:
return inputs_embeds, attention_mask, kwargs
model_outputs = self.llm.generate(
inputs_embeds=inputs_embeds,
max_new_tokens=kwargs.get("max_new_tokens", 200),
num_beams=kwargs.get("num_beams", 4),
do_sample=kwargs.get("do_sample", False),
min_length=kwargs.get("min_length", 1),
top_p=kwargs.get("top_p", 1.0),
repetition_penalty=kwargs.get("repetition_penalty", 1.0),
length_penalty=kwargs.get("length_penalty", 0.8),
temperature=kwargs.get("temperature", 1.0),
attention_mask=attention_mask,
bos_token_id=self.tokenizer.bos_token_id,
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.tokenizer.pad_token_id,
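# mildly penalize EOS so beams do not terminate too early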
sequence_bias={(self.tokenizer.eos_token_id,): -0.2}
)
return model_outputs
import os
import time
import logging
import argparse
# nn
import torch
try:
import torch_musa
except ImportError as e:
print("You should install torch_musa if you want to run on Moore Threads GPU")
from mooer.models import mooer_model
from mooer.utils.utils import get_device
from mooer.utils.config_utils import parse_asr_configs
from mooer.datasets.speech_dataset_shard import SpeechDatasetShard
def parse_args():
parser = argparse.ArgumentParser(description='DeepSpeed Training Script')
parser.add_argument('--test_config', type=str, required=True, help='Path to the testing configuration file.')
parser.add_argument('--test_data_dir', type=str, default='', help='Path to the testing sets.')
parser.add_argument('--test_sets', type=str, default='', help='test sets inside test_data_dir, separated by "/", e.g. aishell1/aishell2')
parser.add_argument('--decode_path', type=str, required=True, help='Path to save decoded text for WER computation')
args = parser.parse_args()
return args
def main():
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger()
logger.setLevel(logging.INFO)
args = parse_args()
device = str(get_device())
configs = parse_asr_configs(args.test_config)
train_config = configs['TrainConfig']
model_config = configs['ModelConfig']
dataset_config = configs['DataConfig']
# reset test epoch
logger.info("set epoch_num=1 for testing")
dataset_config.num_epochs = 1
# update paths
if os.path.exists(args.test_data_dir):
dataset_config.test_data_dir = args.test_data_dir
dataset_config.test_sets = args.test_sets
os.makedirs(args.decode_path, exist_ok=True)
model, tokenizer = mooer_model.init_model(
model_config=model_config,
train_config=train_config)
model.to(device)
model.eval()
# dataset_config = generate_dataset_config(train_config, kwargs)
logger.info("dataset_config: {}".format(dataset_config))
test_data_dir = dataset_config.test_data_dir
test_sets = dataset_config.test_sets
decode_path = args.decode_path
for test_set in test_sets.strip().split('/'):
test_set_path = os.path.join(test_data_dir, test_set, "data.list")
decode_dir = os.path.join(decode_path, test_set)
os.makedirs(decode_dir, exist_ok=True)
logging.info(f"Test for {test_set_path}")
if dataset_config.get('test_data_type', 'shard') == 'shard':
logging.info("Use shard for testing...")
dataset_test_items = SpeechDatasetShard(dataset_config=dataset_config,
tokenizer=tokenizer,
normalize=dataset_config.normalize,
mel_size=dataset_config.mel_size)
dataset_test = dataset_test_items.dataset(
data_type='shard',
data_list_file=test_set_path,
shuffle=False,
partition=False
)
collator = dataset_test_items.collator
test_dataloader = torch.utils.data.DataLoader(
dataset_test,
num_workers=train_config.num_workers_dataloader,
pin_memory=True,
shuffle=False,
batch_size=train_config.val_batch_size,
drop_last=False,
collate_fn=collator
)
else:
raise KeyError(f"unsupported test_data_type: {dataset_config.get('test_data_type')}")
logger.info("=====================================")
pred_path = os.path.join(decode_dir, 'text')
ss = time.perf_counter()
dtype = torch.float32
if train_config.use_fp16:
dtype = torch.float16
elif train_config.use_bf16:
dtype = torch.bfloat16
logging.info(f"Input data type: {dtype}")
with torch.no_grad():
with open(pred_path, "w") as pred:
for step, batch in enumerate(test_dataloader):
for key in batch.keys():
batch[key] = batch[key].to(device) if isinstance(batch[key], torch.Tensor) else batch[key]
with torch.cuda.amp.autocast(dtype=dtype):
model_outputs = model.generate(**batch)
output_text = model.tokenizer.batch_decode(model_outputs, add_special_tokens=False,
skip_special_tokens=True)
for key, text, target in zip(batch["keys"], output_text, batch["targets"]):
logging.info(f"{key} {text}")
pred.write(key + "\t" + text + "\n")
logging.info(f"Infer {test_set} Cost: {time.perf_counter() - ss}")
if __name__ == "__main__":
main()
import os
import random
import logging
import argparse
import deepspeed
# nn
import torch
try:
import torch_musa
except ImportError as e:
print("You should install torch_musa if you want to run on Moore Threads GPU")
from mooer.models import mooer_model
from mooer.utils.utils import get_device
from mooer.utils.config_utils import parse_asr_configs
from mooer.utils.train_utils import train, clear_gpu_cache
from mooer.datasets.speech_dataset_shard import SpeechDatasetShard
def parse_args():
parser = argparse.ArgumentParser(description='DeepSpeed Training Script')
parser.add_argument("--local_rank", type=int, default=-1)
parser.add_argument('--training_config', type=str, required=True, help='Path to the training configuration file.')
args = parser.parse_args()
return args
def main():
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
args = parse_args()
device = str(get_device())
configs = parse_asr_configs(args.training_config)
train_config = configs['TrainConfig']
model_config = configs['ModelConfig']
dataset_config = configs['DataConfig']
peft_config = configs['PeftConfig']
deepspeed_config = train_config.deepspeed_config
logger = logging.getLogger()
logger.setLevel(logging.INFO)
local_rank = int(os.environ["LOCAL_RANK"])
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
logger.info(f"local_rank: {local_rank}, rank: {rank}, world_size: {world_size}")
if 'musa' in device:
torch.musa.manual_seed(train_config.seed + train_config.resume_epoch)
torch.manual_seed(train_config.seed + train_config.resume_epoch)
random.seed(train_config.seed + train_config.resume_epoch)
torch.musa.set_device(local_rank)
else:
torch.cuda.manual_seed(train_config.seed + train_config.resume_epoch)
torch.manual_seed(train_config.seed + train_config.resume_epoch)
random.seed(train_config.seed + train_config.resume_epoch)
torch.cuda.set_device(local_rank)
clear_gpu_cache(local_rank)
model, tokenizer = mooer_model.init_model(
model_config=model_config,
train_config=train_config, peft_config=peft_config)
parameters = filter(lambda p: p.requires_grad, model.parameters())
model_engine, _, _, _ = deepspeed.initialize(
model=model,
model_parameters=parameters,
config=deepspeed_config,
)
if dataset_config.get('train_data_type', 'shard') == 'shard':
logging.info("Use shard for training...")
dataset_train_items = SpeechDatasetShard(dataset_config=dataset_config,
tokenizer=tokenizer,
normalize=dataset_config.normalize,
mel_size=dataset_config.mel_size)
dataset_train = dataset_train_items.dataset(
data_type='shard',
data_list_file=dataset_config['train_data_path'],
shuffle=True,
partition=True
)
train_dl_kwargs = {
'batch_size': train_config.batch_size_training,
'drop_last': True,
'collate_fn': dataset_train_items.collator
}
train_dataloader = torch.utils.data.DataLoader(
dataset_train,
num_workers=train_config.num_workers_dataloader,
pin_memory=True,
**train_dl_kwargs,
)
else:
raise KeyError(f"unsupported train_data_type: {dataset_config.get('train_data_type')}")
# Start the training process
train(
model_engine,
train_dataloader,
train_config,
local_rank=local_rank,
rank=rank,
train_data_set=dataset_train if dataset_config.get('train_data_type', 'shard') == 'shard' else None,
model_org=model
)
if __name__ == "__main__":
main()
import os
import logging
import torch
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
logger = logging.getLogger(__name__)
def save_model_checkpoint_deepspeed(model_engine, cfg, checkpoint_name="checkpoint", merge_rank=False, model=None):
logger.info(f"--> saving model ...")
save_dir = os.path.join(cfg.output_dir, checkpoint_name)
os.makedirs(save_dir, exist_ok=True)
save_full_path = save_dir
model_engine.save_checkpoint(save_dir=save_full_path, exclude_frozen_parameters=True)
logger.info(f"model checkpoint saved at {save_full_path}")
# if merged, it will be fast when decoding
if merge_rank:
assert model is not None
save_dir_merge = os.path.join(cfg.output_dir, checkpoint_name + '_merged')
os.makedirs(save_dir_merge, exist_ok=True)
logger.info("CKPT: loading DeepSpeed Model from: {}".format(save_full_path))
ckpt_dict = get_fp32_state_dict_from_zero_checkpoint(save_full_path)
logger.info("Merging ZeRO-3 checkpoint to FP32...")
logger.info("Saving LoRA weights...")
model.llm.save_pretrained(os.path.join(save_dir_merge, 'new_llm'))
logger.info(f"Save finished... {os.path.join(save_dir_merge, 'new_llm')}")
ckpt_dict_new = {}
for key in ckpt_dict.keys():
if 'llm' not in key:
ckpt_dict_new[key] = ckpt_dict[key].to('cpu').clone()
torch.save(ckpt_dict_new, os.path.join(save_dir_merge, 'adapter_project.pt'))
logger.info(f"Save finished... {os.path.join(save_dir_merge, 'adapter_project.pt')}")
import sys
import os
import logging
from importlib import import_module
from peft import (
LoraConfig,
AdaptionPromptConfig,
PrefixTuningConfig,
)
def parse_asr_configs(file_path):
file_dir = os.path.dirname(file_path)
module_name = os.path.splitext(os.path.basename(file_path))[0]
sys.path.insert(0, file_dir)
module = import_module(module_name)
sys.path.pop(0)
ModelConfig = getattr(module, 'ModelConfig', None)
PeftConfig = getattr(module, 'PeftConfig', None)
TrainConfig = getattr(module, 'TrainConfig', None)
DataConfig = getattr(module, 'DataConfig', None)
update_function = getattr(module, 'update', None)
if None in (ModelConfig, PeftConfig, TrainConfig, DataConfig, update_function):
raise ImportError(f"Could not find all expected classes or function in {file_path}")
model_config_instance = ModelConfig()
peft_config_instance = PeftConfig()
train_config_instance = TrainConfig()
data_config_instance = DataConfig()
# update something
update_function(model_config_instance,
train_config_instance,
data_config_instance)
items = {
'ModelConfig': model_config_instance,
'PeftConfig': peft_config_instance,
'TrainConfig': train_config_instance,
'DataConfig': data_config_instance,
'update_function': update_function
}
for key in items.keys():
logging.info(f"################# {key} #################")
instance = items[key]
if isinstance(instance, (ModelConfig, PeftConfig, TrainConfig, DataConfig)):
for attr_name, attr_value in vars(instance).items():
logging.info(f"{attr_name}: {attr_value}")
return items
def generate_peft_config(peft_config):
peft_configs = {"lora": LoraConfig,
"llama_adapter": AdaptionPromptConfig,
"prefix": PrefixTuningConfig
}
params = {}
for attr_name, attr_value in vars(peft_config).items():
params[attr_name] = attr_value
params.pop("peft_method", None)
peft_config_parse = peft_configs[peft_config.get("peft_method", "lora")](**params)
return peft_config_parse
PROMPT_TEMPLATE_DICT = {
'qwen': "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
}
PROMPT_DICT = {
'asr': "Transcribe speech to text. ",
'ast': "Translate speech to english text. ",
}
import time
from contextlib import nullcontext
import torch.distributed as dist
try:
import torch_musa
except ImportError as e:
print("You should install torch_musa if you want to run on Moore Threads GPU")
from mooer.utils.utils import *
from mooer.utils.checpoint_io import save_model_checkpoint_deepspeed
logger = logging.getLogger(__name__)
device = str(get_device())
def clear_gpu_cache(rank=None):
"""Clear the GPU cache for all ranks"""
if rank == 0:
logger.info(f"Clearing GPU cache for all ranks")
if 'musa' in device:
torch.musa.empty_cache()
else:
torch.cuda.empty_cache()
def train(
model,
train_dataloader,
train_config,
local_rank=None,
rank=None,
train_data_set=None,
model_org=None
):
epoch_times = []
total_step = 0
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
model_context = model.join
else:
model_context = nullcontext
for epoch in range(train_config.resume_epoch, train_config.num_epochs):
if train_data_set is not None:
dist.barrier()
logging.info(f"RANK:{rank} Reset Dataset Epoch {epoch}...")
train_data_set.set_epoch(epoch)
train_dataloader_iterator = iter(train_dataloader)
dist.barrier()
epoch_start_time = time.perf_counter()
model.train()
total_loss = 0.0
total_acc = 0.0
step = 0
input_dtype = torch.float32
if train_config.use_fp16:
input_dtype = torch.float16
elif train_config.use_bf16:
input_dtype = torch.bfloat16
logging.info(f"Input data type: {input_dtype}")
with model_context():
should_continue = True
while should_continue:
try:
batch = next(train_dataloader_iterator)
total_step += 1
for key in batch.keys():
batch[key] = (
batch[key].to(local_rank).to(input_dtype)
if isinstance(batch[key], torch.Tensor)
and batch[key].dtype == torch.float32
else (
batch[key].to(local_rank)
if isinstance(batch[key], torch.Tensor)
else batch[key]
)
)
outputs, acc = model(**batch)
loss = outputs.loss
acc_report = acc
loss_report = loss.detach().float()
total_loss += loss.detach().float()
total_acc += acc
model.backward(loss)
model.step()
current_lr = model.optimizer.param_groups[0]['lr']
if rank == 0 and step % train_config.log_interval == 0:
logging.info(
f"Training Epoch: {epoch + 1}/{train_config.num_epochs}, "
f"step {step} lr {current_lr} "
f"completed (loss: {loss_report}, "
f"acc: {acc_report})")
if step % train_config.save_interval == 0:
checkpoint_name = f"{train_config.model_name}_epoch_{str(epoch + 1)}_step_{step + 1}"
save_model_checkpoint_deepspeed(
model, train_config, checkpoint_name,
merge_rank=train_config.get('save_merge_rank', True),
model=model_org
)
step += 1
except Exception as e:
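# StopIteration from the exhausted dataloader lands here and ends the epoch; any other error is logged, a checkpoint is saved, and the epoch ends as well.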
logging.error(f"Exception occurred on Rank {rank}: {e}")
epoch_end_time = time.perf_counter() - epoch_start_time
logging.info(f"Epoch {epoch + 1}, Cost Time: {epoch_end_time}")
epoch_times.append(epoch_end_time)
# save model
checkpoint_name = f"{train_config.model_name}_epoch_{str(epoch + 2)}_step_1"
save_model_checkpoint_deepspeed(
model, train_config, checkpoint_name,
merge_rank=train_config.get('save_merge_rank', True),
model=model_org
)
break
import logging
import re
import torch
def parse_key_text(input_text):
result = {}
with open(input_text, 'r') as r:
for line in r.readlines():
line = line.strip()
if line == '':
continue
line = line.split(maxsplit=1)
if len(line) != 2:
continue
key, text = line
result[key] = text
return result
def print_module_size(module, module_name, rank: int = 0, info=None) -> None:
if rank == 0:
if info:
logging.info(info)
logging.info(f"--> Module {module_name}")
total_params = sum(p.numel() for p in module.parameters() if p.requires_grad)
logging.info(f"--> {module_name} has {total_params / 1e6} Million params\n")
# device register/check/create by default
_device_registry = []
def _enable_cuda() -> bool:
return torch.cuda.is_available()
def _enable_musa() -> bool:
try:
import torch_musa
except ImportError:
return False
return torch_musa.is_available()
def _create_cuda_device() -> torch.device:
return torch.device("cuda")
def _create_musa_device() -> torch.device:
return torch.device("musa")
def _register_device(priority, checker, creator):
device_elem = (priority, checker, creator)
_device_registry.append(device_elem)
_device_registry.sort(key=lambda entry: entry[0])  # sort by priority only; avoids comparing the function objects
_register_device(10, _enable_musa, _create_musa_device)
_register_device(20, _enable_cuda, _create_cuda_device)
def get_device() -> torch.device:
for (_, checker, creator) in _device_registry:
if checker():
return creator()
return torch.device("cpu")
def compute_accuracy(pad_outputs, pad_targets, ignore_label):
"""Calculate accuracy.
Args:
pad_outputs (LongTensor): Prediction tensors (B, Lmax).
pad_targets (LongTensor): Target label tensors (B, Lmax).
ignore_label (int): Ignore label id.
Returns:
float: Accuracy value (0.0 - 1.0).
"""
mask = pad_targets != ignore_label
numerator = torch.sum(
pad_outputs.masked_select(mask) == pad_targets.masked_select(mask)
)
denominator = torch.sum(mask)
return numerator.float() / denominator.float()
def extract_audio_token_from_string(input_str):
pattern = r"<A_(\d+)>"
matches = re.findall(pattern, input_str)
return [int(i) for i in matches]
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import re, sys, unicodedata
import codecs
remove_tag = True
spacelist= [' ', '\t', '\r', '\n']
puncts = ['!', ',', '?',
'、', '。', '!', ',', ';', '?',
':', '「', '」', '︰', '『', '』', '《', '》']
def characterize(string) :
res = []
i = 0
while i < len(string):
char = string[i]
if char in puncts:
i += 1
continue
cat1 = unicodedata.category(char)
#https://unicodebook.readthedocs.io/unicode.html#unicode-categories
if cat1 == 'Zs' or cat1 == 'Cn' or char in spacelist: # space or not assigned
i += 1
continue
if cat1 == 'Lo': # letter-other
res.append(char)
i += 1
else:
# some input looks like: <unk><noise>, we want to separate it to two words.
sep = ' '
if char == '<': sep = '>'
j = i+1
while j < len(string):
c = string[j]
if ord(c) >= 128 or (c in spacelist) or (c==sep):
break
j += 1
if j < len(string) and string[j] == '>':
j += 1
res.append(string[i:j])
i = j
return res
def stripoff_tags(x):
if not x: return ''
chars = []
i = 0; T=len(x)
while i < T:
if x[i] == '<':
while i < T and x[i] != '>':
i += 1
i += 1
else:
chars.append(x[i])
i += 1
return ''.join(chars)
def normalize(sentence, ignore_words, cs, split=None):
""" sentence, ignore_words are both in unicode
"""
new_sentence = []
for token in sentence:
x = token
if not cs:
x = x.upper()
if x in ignore_words:
continue
if remove_tag:
x = stripoff_tags(x)
if not x:
continue
if split and x in split:
new_sentence += split[x]
else:
new_sentence.append(x)
return new_sentence
class Calculator :
def __init__(self) :
self.data = {}
self.space = []
self.cost = {}
self.cost['cor'] = 0
self.cost['sub'] = 1
self.cost['del'] = 1
self.cost['ins'] = 1
def calculate(self, lab, rec) :
# Initialization
lab.insert(0, '')
rec.insert(0, '')
while len(self.space) < len(lab) :
self.space.append([])
for row in self.space :
for element in row :
element['dist'] = 0
element['error'] = 'non'
while len(row) < len(rec) :
row.append({'dist' : 0, 'error' : 'non'})
for i in range(len(lab)) :
self.space[i][0]['dist'] = i
self.space[i][0]['error'] = 'del'
for j in range(len(rec)) :
self.space[0][j]['dist'] = j
self.space[0][j]['error'] = 'ins'
self.space[0][0]['error'] = 'non'
for token in lab :
if token not in self.data and len(token) > 0 :
self.data[token] = {'all' : 0, 'cor' : 0, 'sub' : 0, 'ins' : 0, 'del' : 0}
for token in rec :
if token not in self.data and len(token) > 0 :
self.data[token] = {'all' : 0, 'cor' : 0, 'sub' : 0, 'ins' : 0, 'del' : 0}
# Computing edit distance
for i, lab_token in enumerate(lab) :
for j, rec_token in enumerate(rec) :
if i == 0 or j == 0 :
continue
min_dist = sys.maxsize
min_error = 'none'
dist = self.space[i-1][j]['dist'] + self.cost['del']
error = 'del'
if dist < min_dist :
min_dist = dist
min_error = error
dist = self.space[i][j-1]['dist'] + self.cost['ins']
error = 'ins'
if dist < min_dist :
min_dist = dist
min_error = error
if lab_token == rec_token :
dist = self.space[i-1][j-1]['dist'] + self.cost['cor']
error = 'cor'
else :
dist = self.space[i-1][j-1]['dist'] + self.cost['sub']
error = 'sub'
if dist < min_dist :
min_dist = dist
min_error = error
self.space[i][j]['dist'] = min_dist
self.space[i][j]['error'] = min_error
# Tracing back
result = {'lab':[], 'rec':[], 'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0}
i = len(lab) - 1
j = len(rec) - 1
while True :
if self.space[i][j]['error'] == 'cor' : # correct
if len(lab[i]) > 0 :
self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1
result['all'] = result['all'] + 1
result['cor'] = result['cor'] + 1
result['lab'].insert(0, lab[i])
result['rec'].insert(0, rec[j])
i = i - 1
j = j - 1
elif self.space[i][j]['error'] == 'sub' : # substitution
if len(lab[i]) > 0 :
self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1
result['all'] = result['all'] + 1
result['sub'] = result['sub'] + 1
result['lab'].insert(0, lab[i])
result['rec'].insert(0, rec[j])
i = i - 1
j = j - 1
elif self.space[i][j]['error'] == 'del' : # deletion
if len(lab[i]) > 0 :
self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1
result['all'] = result['all'] + 1
result['del'] = result['del'] + 1
result['lab'].insert(0, lab[i])
result['rec'].insert(0, "")
i = i - 1
elif self.space[i][j]['error'] == 'ins' : # insertion
if len(rec[j]) > 0 :
self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1
result['ins'] = result['ins'] + 1
result['lab'].insert(0, "")
result['rec'].insert(0, rec[j])
j = j - 1
elif self.space[i][j]['error'] == 'non' : # starting point
break
else : # shouldn't reach here
print('this should not happen , i = {i} , j = {j} , error = {error}'.format(i = i, j = j, error = self.space[i][j]['error']))
return result
def overall(self) :
result = {'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0}
for token in self.data :
result['all'] = result['all'] + self.data[token]['all']
result['cor'] = result['cor'] + self.data[token]['cor']
result['sub'] = result['sub'] + self.data[token]['sub']
result['ins'] = result['ins'] + self.data[token]['ins']
result['del'] = result['del'] + self.data[token]['del']
return result
def cluster(self, data) :
result = {'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0}
for token in data :
if token in self.data :
result['all'] = result['all'] + self.data[token]['all']
result['cor'] = result['cor'] + self.data[token]['cor']
result['sub'] = result['sub'] + self.data[token]['sub']
result['ins'] = result['ins'] + self.data[token]['ins']
result['del'] = result['del'] + self.data[token]['del']
return result
def keys(self) :
return list(self.data.keys())
def width(string):
return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string)
def default_cluster(word) :
unicode_names = [ unicodedata.name(char) for char in word ]
for i in reversed(range(len(unicode_names))) :
if unicode_names[i].startswith('DIGIT') : # 1
unicode_names[i] = 'Number' # 'DIGIT'
elif (unicode_names[i].startswith('CJK UNIFIED IDEOGRAPH') or
unicode_names[i].startswith('CJK COMPATIBILITY IDEOGRAPH')) :
# 明 / 郎
unicode_names[i] = 'Mandarin' # 'CJK IDEOGRAPH'
elif (unicode_names[i].startswith('LATIN CAPITAL LETTER') or
unicode_names[i].startswith('LATIN SMALL LETTER')) :
# A / a
unicode_names[i] = 'English' # 'LATIN LETTER'
elif unicode_names[i].startswith('HIRAGANA LETTER') : # は こ め
unicode_names[i] = 'Japanese' # 'GANA LETTER'
elif (unicode_names[i].startswith('AMPERSAND') or
unicode_names[i].startswith('APOSTROPHE') or
unicode_names[i].startswith('COMMERCIAL AT') or
unicode_names[i].startswith('DEGREE CELSIUS') or
unicode_names[i].startswith('EQUALS SIGN') or
unicode_names[i].startswith('FULL STOP') or
unicode_names[i].startswith('HYPHEN-MINUS') or
unicode_names[i].startswith('LOW LINE') or
unicode_names[i].startswith('NUMBER SIGN') or
unicode_names[i].startswith('PLUS SIGN') or
unicode_names[i].startswith('SEMICOLON')) :
# & / ' / @ / ℃ / = / . / - / _ / # / + / ;
del unicode_names[i]
else :
return 'Other'
if len(unicode_names) == 0 :
return 'Other'
if len(unicode_names) == 1 :
return unicode_names[0]
for i in range(len(unicode_names)-1) :
if unicode_names[i] != unicode_names[i+1] :
return 'Other'
return unicode_names[0]
def is_ch(char):
if '\u4e00' <= char <= '\u9fa5':
return True
return False
def is_en(char):
if (u'\u0041' <= char <= u'\u005a') or (u'\u0061' <= char <= u'\u007a'):
return True
return False
def merge_en_single2word(text):
new_text = ''
en_word = ''
for word in text.strip().split():
if word == '':
continue
if len(word) > 1 or is_ch(word[0]):
if len(en_word) > 0:
if len(new_text) > 0 and is_en(new_text[-1]) and is_en(new_text[-2]):
new_text += ' '
new_text += en_word
en_word = ''
if is_en(word[0]) and not is_en(word[1]):
new_text += word[0]
word = word[1:]
elif is_en(word[0]) and is_en(word[1]):
word = ' ' + word
elif is_en(word[0]) and (len(new_text) > 0 and is_en(new_text[-1])):
word = ' ' + word
new_text += word
elif len(word) == 1 and is_en(word[0]):
en_word += word
else:
raise KeyError
if len(en_word) > 0:
if len(new_text) > 0 and is_en(new_text[-1]) and is_en(new_text[-2]):
new_text += ' '
new_text += en_word
return new_text
def usage() :
print("compute-wer.py : compute word error rate (WER) and align recognition results and references.")
print(" usage : python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer")
if __name__ == '__main__':
if len(sys.argv) == 1 :
usage()
sys.exit(0)
calculator = Calculator()
cluster_file = ''
ignore_words = set()
tochar = False
verbose= 1
padding_symbol= ' '
case_sensitive = False
max_words_per_line = sys.maxsize
split = None
concat_all = False
merge_en_single = False
while len(sys.argv) > 3:
a = '--maxw='
if sys.argv[1].startswith(a):
b = sys.argv[1][len(a):]
del sys.argv[1]
max_words_per_line = int(b)
continue
a = '--rt='
if sys.argv[1].startswith(a):
b = sys.argv[1][len(a):].lower()
del sys.argv[1]
remove_tag = (b == 'true') or (b != '0')
continue
a = '--cs='
if sys.argv[1].startswith(a):
b = sys.argv[1][len(a):].lower()
del sys.argv[1]
case_sensitive = (b == 'true') or (b != '0')
continue
a = '--cluster='
if sys.argv[1].startswith(a):
cluster_file = sys.argv[1][len(a):]
del sys.argv[1]
continue
a = '--splitfile='
if sys.argv[1].startswith(a):
split_file = sys.argv[1][len(a):]
del sys.argv[1]
split = dict()
with codecs.open(split_file, 'r', 'utf-8') as fh:
for line in fh: # line in unicode
words = line.strip().split()
if len(words) >= 2:
split[words[0]] = words[1:]
continue
a = '--ig='
if sys.argv[1].startswith(a):
ignore_file = sys.argv[1][len(a):]
del sys.argv[1]
with codecs.open(ignore_file, 'r', 'utf-8') as fh:
for line in fh: # line in unicode
line = line.strip()
if len(line) > 0:
ignore_words.add(line)
continue
a = '--char='
if sys.argv[1].startswith(a):
b = sys.argv[1][len(a):].lower()
del sys.argv[1]
tochar = (b == 'true') or (b != '0')
continue
a = '--v='
if sys.argv[1].startswith(a):
b = sys.argv[1][len(a):].lower()
del sys.argv[1]
verbose=0
try:
verbose=int(b)
except:
if b == 'true' or b != '0':
verbose = 1
continue
a = '--padding-symbol='
if sys.argv[1].startswith(a):
b = sys.argv[1][len(a):].lower()
del sys.argv[1]
if b == 'space':
padding_symbol= ' '
elif b == 'underline':
padding_symbol= '_'
continue
a = '--concat-all='
if sys.argv[1].startswith(a):
b = sys.argv[1][len(a):].lower()
del sys.argv[1]
concat_all = (b == 'true') or (b != '0')
continue
a = '--merge_en_single='
if sys.argv[1].startswith(a):
b = sys.argv[1][len(a):].lower()
del sys.argv[1]
merge_en_single = (b == 'true') or (b != '0')
continue
# ignore any other (invalid) switch
del sys.argv[1]
continue
if not case_sensitive:
ig=set([w.upper() for w in ignore_words])
ignore_words = ig
default_clusters = {}
default_words = {}
ref_file = sys.argv[1]
hyp_file = sys.argv[2]
if concat_all:
ref_file_cat = ref_file + '.cat'
hyp_file_cat = hyp_file + '.cat'
os.system('rm -f %s %s' % (ref_file_cat, hyp_file_cat))
ref_cat = {}
hyp_cat = {}
with open(ref_file, 'r') as rr:
for line in rr.readlines():
line = line.strip()
if line == '':
continue
line = line.split(maxsplit=1)
if len(line) == 2:
ref_cat[line[0].strip()] = line[1].strip()
else:
ref_cat[line[0].strip()] = ''
with open(hyp_file, 'r') as rh:
for line in rh.readlines():
line = line.strip()
if line == '':
continue
line = line.split(maxsplit=1)
if len(line) == 2:
hyp_cat[line[0].strip()] = line[1].strip()
else:
hyp_cat[line[0].strip()] = ''
ref_cat = list(sorted(ref_cat.items(), key=lambda x: int(x[0].split('_')[-1])))
hyp_cat = list(sorted(hyp_cat.items(), key=lambda x: int(x[0].split('_')[-1])))
ref_cat = " ".join([x[1] for x in ref_cat])
hyp_cat = " ".join([x[1] for x in hyp_cat])
with open(ref_file_cat, 'w') as wr, open(hyp_file_cat, 'w') as wh:
wr.write("concat %s" % ref_cat)
wh.write("concat %s" % hyp_cat)
ref_file = ref_file_cat
hyp_file = hyp_file_cat
if merge_en_single:
ref_file_merge = ref_file + '.merge'
hyp_file_merge = hyp_file + '.merge'
os.system('rm -f %s %s' % (ref_file_merge, hyp_file_merge))
with open(ref_file, 'r') as rr, open(ref_file_merge, 'w') as wr:
for line in rr.readlines():
line = line.strip()
if line == '':
continue
line = line.split(maxsplit=1)
if len(line) == 2:
wr.write("%s %s\n" % (line[0].strip(), merge_en_single2word(line[1].strip())))
else:
wr.write("%s %s\n" % (line[0].strip(), ''))
with open(hyp_file, 'r') as rh, open(hyp_file_merge, 'w') as wh:
for line in rh.readlines():
line = line.strip()
if line == '':
continue
line = line.split(maxsplit=1)
if len(line) == 2:
wh.write("%s %s\n" % (line[0].strip(), merge_en_single2word(line[1].strip())))
else:
wh.write("%s %s\n" % (line[0].strip(), ''))
ref_file = ref_file_merge
hyp_file = hyp_file_merge
rec_set = {}
if split and not case_sensitive:
newsplit = dict()
for w in split:
words = split[w]
for i in range(len(words)):
words[i] = words[i].upper()
newsplit[w.upper()] = words
split = newsplit
with codecs.open(hyp_file, 'r', 'utf-8') as fh:
for line in fh:
if tochar:
array = characterize(line)
else:
array = line.strip().split()
if len(array)==0: continue
fid = array[0]
rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive, split)
# compute error rate on the interaction of reference file and hyp file
for line in open(ref_file, 'r', encoding='utf-8') :
if tochar:
array = characterize(line)
else:
array = line.rstrip('\n').split()
if len(array)==0: continue
fid = array[0]
if fid not in rec_set:
continue
lab = normalize(array[1:], ignore_words, case_sensitive, split)
rec = rec_set[fid]
if verbose:
print('\nutt: %s' % fid)
for word in rec + lab :
if word not in default_words :
default_cluster_name = default_cluster(word)
if default_cluster_name not in default_clusters :
default_clusters[default_cluster_name] = {}
if word not in default_clusters[default_cluster_name] :
default_clusters[default_cluster_name][word] = 1
default_words[word] = default_cluster_name
result = calculator.calculate(lab, rec)
if verbose:
if result['all'] != 0 :
wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
else :
wer = 0.0
print('WER: %4.2f %%' % wer, end = ' ')
print('N=%d C=%d S=%d D=%d I=%d' %
(result['all'], result['cor'], result['sub'], result['del'], result['ins']))
space = {}
space['lab'] = []
space['rec'] = []
for idx in range(len(result['lab'])) :
len_lab = width(result['lab'][idx])
len_rec = width(result['rec'][idx])
length = max(len_lab, len_rec)
space['lab'].append(length-len_lab)
space['rec'].append(length-len_rec)
upper_lab = len(result['lab'])
upper_rec = len(result['rec'])
lab1, rec1 = 0, 0
while lab1 < upper_lab or rec1 < upper_rec:
if verbose > 1:
print('lab(%s):' % fid.encode('utf-8'), end = ' ')
else:
print('lab:', end = ' ')
lab2 = min(upper_lab, lab1 + max_words_per_line)
for idx in range(lab1, lab2):
token = result['lab'][idx]
print('{token}'.format(token = token), end = '')
for n in range(space['lab'][idx]) :
print(padding_symbol, end = '')
print(' ',end='')
print()
if verbose > 1:
print('rec(%s):' % fid.encode('utf-8'), end = ' ')
else:
print('rec:', end = ' ')
rec2 = min(upper_rec, rec1 + max_words_per_line)
for idx in range(rec1, rec2):
token = result['rec'][idx]
print('{token}'.format(token = token), end = '')
for n in range(space['rec'][idx]) :
print(padding_symbol, end = '')
print(' ',end='')
print('\n', end='\n')
lab1 = lab2
rec1 = rec2
if verbose:
print('===========================================================================')
print()
result = calculator.overall()
if result['all'] != 0 :
wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
else :
wer = 0.0
print('Overall -> %4.2f %%' % wer, end = ' ')
print('N=%d C=%d S=%d D=%d I=%d' %
(result['all'], result['cor'], result['sub'], result['del'], result['ins']))
if not verbose:
print()
if verbose:
for cluster_id in default_clusters :
result = calculator.cluster([ k for k in default_clusters[cluster_id] ])
if result['all'] != 0 :
wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
else :
wer = 0.0
print('%s -> %4.2f %%' % (cluster_id, wer), end = ' ')
print('N=%d C=%d S=%d D=%d I=%d' %
(result['all'], result['cor'], result['sub'], result['del'], result['ins']))
if len(cluster_file) > 0 : # compute separated WERs for word clusters
cluster_id = ''
cluster = []
for line in open(cluster_file, 'r', encoding='utf-8') :
for token in line.rstrip('\n').split() :
# end of cluster reached, like </Keyword>
if token[0:2] == '</' and token[len(token)-1] == '>' and \
token.lstrip('</').rstrip('>') == cluster_id :
result = calculator.cluster(cluster)
if result['all'] != 0 :
wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
else :
wer = 0.0
print('%s -> %4.2f %%' % (cluster_id, wer), end = ' ')
print('N=%d C=%d S=%d D=%d I=%d' %
(result['all'], result['cor'], result['sub'], result['del'], result['ins']))
cluster_id = ''
cluster = []
# begin of cluster reached, like <Keyword>
elif token[0] == '<' and token[len(token)-1] == '>' and \
cluster_id == '' :
cluster_id = token.lstrip('<').rstrip('>')
cluster = []
# general terms, like WEATHER / CAR / ...
else :
cluster.append(token)
print()
print('===========================================================================')
#!/bin/bash
set -x # for better debug view
export PATH=$PWD:$PATH
export PATH=$PWD/../:$PATH
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=../:$PYTHONPATH
export LC_ALL=C
THIS_DIR="$( cd "$( dirname "$0" )" && pwd )"
echo "##### config #####"
wav_scp='data/wav.scp'
text='data/text'
write_dir='data/pkg/path'
write_num=1
write_prefix='data.#.list'
text_norm=true
shuffle=true
data_type=raw  # raw or shard
num_threads=32
# parse config
. ${THIS_DIR}/parse_options.sh || exit 1;
mkdir -p $write_dir
process_root=${THIS_DIR}/process_tmp
mkdir -p $process_root
cp $text ${process_root}/text.org
cp $wav_scp ${process_root}/wav.scp.org
# do text normalization
if [ $text_norm = true ]; then
echo "do text normalization"
paste -d " " <(cut -f 1 -d" " ${process_root}/text.org) \
<(cut -f 2- -d" " ${process_root}/text.org \
| tr 'a-z' 'A-Z' | sed 's/\([A-Z]\) \([A-Z]\)/\1▁\2/g' \
| sed 's/\([A-Z]\) \([A-Z]\)/\1▁\2/g' | tr -d " ") \
> ${process_root}/text.process
sed -i 's/\xEF\xBB\xBF//' ${process_root}/text.process
else
cp ${process_root}/text.org ${process_root}/text.process
fi
if [ $data_type = shard ]; then
python3 ${THIS_DIR}/make_shard_list.py --resample 16000 --num_utts_per_shard 100000 \
--num_threads $num_threads ${process_root}/wav.scp.org ${process_root}/text.process $write_dir \
${process_root}/data.list
else
echo "data_type only supports shard, but got $data_type" && exit 1
fi
# shuffle
if [ $shuffle = true ]; then
shuf ${process_root}/data.list -o ${process_root}/data.list.shuffle
else
cp ${process_root}/data.list ${process_root}/data.list.shuffle
fi
# split and rename
lines_num=`cat ${process_root}/data.list.shuffle | wc -l`
lines_each=`echo $((lines_num / write_num)) | bc -l`
echo "All samples: $lines_num ; writing $write_num file(s) ; each file has $lines_each samples"
mkdir ${process_root}/split
split -l $lines_each -d ${process_root}/data.list.shuffle ${process_root}/split/$write_prefix
i=0
for path in `ls ${process_root}/split | grep $write_prefix`; do
write_file=`echo $write_prefix | sed "s|#|${i}|g"`
cp ${process_root}/split/$path ${write_dir}/$write_file
i=$((i + 1))
done
# remove process dir
rm -r ${THIS_DIR}/process_tmp
#!/usr/bin/env python3
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import io
import logging
import os
import tarfile
import time
import multiprocessing
import torch
import torchaudio
import torchaudio.backend.sox_io_backend as sox
AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])
def write_tar_file(data_list,
no_segments,
tar_file,
resample=16000,
index=0,
total=1):
logging.info('Processing {} {}/{}'.format(tar_file, index, total))
read_time = 0.0
save_time = 0.0
write_time = 0.0
try:
with tarfile.open(tar_file, "w") as tar:
prev_wav = None
for item in data_list:
try:
if no_segments:
key, txt, wav = item
else:
key, txt, wav, start, end = item
suffix = wav.split('.')[-1]
assert suffix in AUDIO_FORMAT_SETS
if no_segments:
ts = time.time()
with open(wav, 'rb') as fin:
data = fin.read()
read_time += (time.time() - ts)
else:
if wav != prev_wav:
ts = time.time()
waveforms, sample_rate = sox.load(wav, normalize=False)
read_time += (time.time() - ts)
prev_wav = wav
start = int(start * sample_rate)
end = int(end * sample_rate)
audio = waveforms[:1, start:end]
# resample
if sample_rate != resample:
if not audio.is_floating_point():
# normalize the audio before resample
# because resample can't process int audio
audio = audio / (1 << 15)
audio = torchaudio.transforms.Resample(
sample_rate, resample)(audio)
audio = (audio * (1 << 15)).short()
else:
audio = torchaudio.transforms.Resample(
sample_rate, resample)(audio)
ts = time.time()
f = io.BytesIO()
sox.save(f, audio, resample, format="wav", bits_per_sample=16)
# Save to wav for segments file
suffix = "wav"
f.seek(0)
data = f.read()
save_time += (time.time() - ts)
assert isinstance(txt, str)
ts = time.time()
txt_file = key + '.txt'
txt = txt.encode('utf8')
txt_data = io.BytesIO(txt)
txt_info = tarfile.TarInfo(txt_file)
txt_info.size = len(txt)
tar.addfile(txt_info, txt_data)
wav_file = key + '.' + suffix
wav_data = io.BytesIO(data)
wav_info = tarfile.TarInfo(wav_file)
wav_info.size = len(data)
tar.addfile(wav_info, wav_data)
write_time += (time.time() - ts)
except Exception as e:
print(e)
continue
logging.info('read {} save {} write {}'.format(read_time, save_time,
write_time))
except Exception as e:
print(e)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='')
parser.add_argument('--num_utts_per_shard',
type=int,
default=1000,
help='num utts per shard')
parser.add_argument('--num_threads',
type=int,
default=1,
help='num threads for make shards')
parser.add_argument('--prefix',
default='shards',
help='prefix of shards tar file')
parser.add_argument('--segments', default=None, help='segments file')
parser.add_argument('--resample',
type=int,
default=16000,
help='target sample rate for resampling')
parser.add_argument('wav_file', help='wav file')
parser.add_argument('text_file', help='text file')
parser.add_argument('shards_dir', help='output shards dir')
parser.add_argument('shards_list', help='output shards list file')
args = parser.parse_args()
logging.basicConfig(level=logging.INFO,
format='%(asctime)s %(levelname)s %(message)s')
torch.set_num_threads(1)
wav_table = {}
with open(args.wav_file, 'r', encoding='utf8') as fin:
for line in fin:
arr = line.strip().split()
assert len(arr) == 2
wav_table[arr[0]] = arr[1]
no_segments = True
segments_table = {}
if args.segments is not None:
no_segments = False
with open(args.segments, 'r', encoding='utf8') as fin:
for line in fin:
arr = line.strip().split()
assert len(arr) == 4
segments_table[arr[0]] = (arr[1], float(arr[2]), float(arr[3]))
data = []
with open(args.text_file, 'r', encoding='utf8') as fin:
for line in fin:
arr = line.strip().split(maxsplit=1)
key = arr[0]
txt = arr[1] if len(arr) > 1 else ''
if no_segments:
if not wav_table.get(key, False):
print("NOTE: %s not found in wav.scp!" % key)
continue
wav = wav_table[key]
data.append((key, txt, wav))
else:
wav_key, start, end = segments_table[key]
wav = wav_table[wav_key]
data.append((key, txt, wav, start, end))
num = args.num_utts_per_shard
chunks = [data[i:i + num] for i in range(0, len(data), num)]
os.makedirs(args.shards_dir, exist_ok=True)
# Use a process pool to speed up packing
pool = multiprocessing.Pool(processes=args.num_threads)
shards_list = []
tasks_list = []
num_chunks = len(chunks)
for i, chunk in enumerate(chunks):
print("chunk {}, length {}".format(i, len(chunk)))
tar_file = os.path.join(args.shards_dir,
'{}_{:09d}.tar'.format(args.prefix, i))
shards_list.append(tar_file)
pool.apply_async(
write_tar_file,
(chunk, no_segments, tar_file, args.resample, i, num_chunks))
pool.close()
pool.join()
with open(args.shards_list, 'w', encoding='utf8') as fout:
for name in shards_list:
fout.write(name + '\n')
#!/bin/bash
# Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
# Arnab Ghoshal, Karel Vesely
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# Parse command-line options.
# To be sourced by another script (as in ". parse_options.sh").
# Option format is: --option-name arg
# and shell variable "option_name" gets set to value "arg."
# The exception is --help, which takes no arguments, but prints the
# $help_message variable (if defined).
###
### The --config file options have lower priority to command line
### options, so we need to import them first...
###
# Now import all the configs specified by command-line, in left-to-right order
for ((argpos=1; argpos<$#; argpos++)); do
if [ "${!argpos}" == "--config" ]; then
argpos_plus1=$((argpos+1))
config=${!argpos_plus1}
[ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
. $config # source the config file.
fi
done
###
### Now we process the command line options
###
while true; do
[ -z "${1:-}" ] && break; # break if there are no arguments
case "$1" in
# If the enclosing script is called with --help option, print the help
# message and exit. Scripts should put help messages in $help_message
--help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
else printf "$help_message\n" 1>&2 ; fi;
exit 0 ;;
--*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
exit 1 ;;
# If the first command-line argument begins with "--" (e.g. --foo-bar),
# then work out the variable name as $name, which will equal "foo_bar".
--*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
# Next we test whether the variable in question is undefined -- if so it's
# an invalid option and we die. Note: $0 evaluates to the name of the
# enclosing script.
# The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
# is undefined. We then have to wrap this test inside "eval" because
# foo_bar is itself inside a variable ($name).
eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
oldval="`eval echo \\$$name`";
# Work out whether we seem to be expecting a Boolean argument.
if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
was_bool=true;
else
was_bool=false;
fi
# Set the variable to the right value-- the escaped quotes make it work if
# the option had spaces, like --cmd "queue.pl -sync y"
eval $name=\"$2\";
# Check that Boolean-valued arguments are really Boolean.
if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
exit 1;
fi
shift 2;
;;
*) break;
esac
done
# Check for an empty argument to the --cmd option, which can easily occur as a
# result of scripting errors.
[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
true; # so this script returns exit code 0.
# MooER Training Tutorial
## Table of Contents
1. [Introduction](#introduction)
2. [Environment Setup](#environment-setup)
3. [Data Packaging](#data-packaging)
   - [Data Preparation](#data-preparation)
   - [Packaging Workflow](#packaging-workflow)
4. [Model Training](#model-training)
   - [Configuration Files](#configuration-files)
   - [Training Workflow](#training-workflow)
5. [Decoding and Inference](#decoding-and-inference)
   - [Decoding](#decoding)
   - [Inference Example](#inference-example)
   - [Result Analysis](#result-analysis)
---
## Introduction
MooER is an LLM-based audio understanding model developed by Moore Threads. It covers end-to-end speech-to-speech interaction, speech-to-speech translation, speech-to-text translation, and speech recognition. Our open-sourced Chinese end-to-end speech interaction model and speech-to-speech translation model were both built on a modified version of this training framework, and you can easily use the framework to train and run inference for other audio understanding tasks you have in mind. This tutorial is based on the training parameters and configuration of the ASR model trained on MT5K (5,000 hours of training data) and covers ***training data processing -> model training -> model testing***. You can also fine-tune from our open-sourced base model trained on 80,000 hours of data.
## Environment Setup
See https://github.com/MooreThreads/MooER/blob/master/README.md#%EF%B8%8F-build-environtment
## Data Packaging
### Data Preparation
You need to prepare your audio files and the corresponding training labels, e.g. audio and matching transcripts for ASR. Audio files should be mono, 16 kHz, in `.wav` format. The `wav.scp` file records each utterance ID (UTTID) and the path of its audio file, so the audio can be located during training.
**`wav.scp` example:**
```plaintext
BAC009S0764W0121 /nfs2/speech/data/asr/aishell/data_aishell/wav/test/S0764/BAC009S0764W0121.wav
BAC009S0764W0122 /nfs2/speech/data/asr/aishell/data_aishell/wav/test/S0764/BAC009S0764W0122.wav
BAC009S0764W0123 /nfs2/speech/data/asr/aishell/data_aishell/wav/test/S0764/BAC009S0764W0123.wav
BAC009S0764W0124 /nfs2/speech/data/asr/aishell/data_aishell/wav/test/S0764/BAC009S0764W0124.wav
BAC009S0764W0125 /nfs2/speech/data/asr/aishell/data_aishell/wav/test/S0764/BAC009S0764W0125.wav
```
The `text` file records the label of the audio with the corresponding UTTID, e.g. an ASR transcript, an AST translation, an FAQ label, and so on. We normalize the text for ASR training but keep the original text for AST training. A consistency check for these two files is sketched after the example below.
**`text` example:**
```plaintext
BAC009S0764W0121 甚至出现交易几乎停滞的情况
BAC009S0764W0122 一二线城市虽然也处于调整中
BAC009S0764W0123 但因为聚集了过多公共资源
BAC009S0764W0124 为了规避三四线城市明显过剩的市场风险
BAC009S0764W0125 标杆房企必然调整市场战略
```
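Before packing, it can help to confirm that the two files agree. The following is a minimal sketch (not part of the toolkit; the `data/...` paths are illustrative) that checks `wav.scp` and `text` share the same UTTIDs and that each audio file is mono 16 kHz WAV:
```python
# Sanity-check wav.scp/text consistency and audio format (illustrative sketch).
import wave

def load_kv(path):
    """Read a Kaldi-style two-column file into {uttid: value}."""
    table = {}
    with open(path, encoding='utf-8') as f:
        for line in f:
            parts = line.strip().split(maxsplit=1)
            if parts:
                table[parts[0]] = parts[1] if len(parts) > 1 else ''
    return table

wavs = load_kv('data/wav.scp')
texts = load_kv('data/text')
assert wavs.keys() == texts.keys(), 'wav.scp and text must share the same UTTIDs'
for uttid, wav_path in wavs.items():
    with wave.open(wav_path, 'rb') as w:
        assert w.getnchannels() == 1 and w.getframerate() == 16000, uttid
```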
### Packaging Workflow
Given the wav.scp and text paths above, the raw audio is packed into shard files used for training and testing. If you are packing test sets, we recommend packing each test set into a single shard file. By default, 100000 utterances are written per shard; change the `num_utts_per_shard` parameter in `src/tools/data_package.sh` if needed. (A sketch for inspecting a packed shard appears at the end of this section.)
**Packing a training set:**
Since training sets are usually large, you can **pack with multiple processes to speed things up** by increasing `num_threads`. `text_norm` normalizes spaces and English text.
```shell
bash src/tools/data_package.sh \
--wav_scp /nfs1/zhenlin.liang/data/training/wav.scp \
--text /nfs1/zhenlin.liang/data/training/text \
--write_dir /jfs/zhenlin.liang/tmp/pkg_training \
--write_prefix data.list \
--shuffle false \
--text_norm true \
--data_type shard \
--num_threads 32
```
This produces the file /jfs/zhenlin.liang/tmp/pkg_training/data.list with contents like:
```
/jfs/zhenlin.liang/tmp/aishell_pkg_test/shards_000000000.tar
/jfs/zhenlin.liang/tmp/aishell_pkg_test/shards_000000001.tar
/jfs/zhenlin.liang/tmp/aishell_pkg_test/shards_000000002.tar
/jfs/zhenlin.liang/tmp/aishell_pkg_test/shards_000000003.tar
...
```
**Packing a test set:**
```shell
bash src/tools/data_package.sh \
--wav_scp /nfs1/zhenlin.liang/data/testsets_wavscp/test_aishell/wav.scp \
--text /nfs1/zhenlin.liang/data/testsets_wavscp/test_aishell/text \
--write_dir /jfs/zhenlin.liang/tmp/aishell_pkg_test \
--write_prefix data.list \
--shuffle false \
--text_norm true \
--data_type shard \
--num_threads 1
```
This produces the file /jfs/zhenlin.liang/tmp/aishell_pkg_test/data.list with contents:
```
/jfs/zhenlin.liang/tmp/aishell_pkg_test/shards_000000000.tar
```
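Each shard is a plain tar file holding one `<key>.txt` / `<key>.wav` pair per utterance (this is how `make_shard_list.py` writes them), so you can inspect a shard with the standard library. A minimal sketch, with an illustrative path:
```python
# Peek inside a packed shard to verify its contents (illustrative path).
import tarfile

with tarfile.open('/jfs/zhenlin.liang/tmp/aishell_pkg_test/shards_000000000.tar') as tar:
    for member in tar.getmembers()[:10]:  # first few entries only
        if member.name.endswith('.txt'):
            label = tar.extractfile(member).read().decode('utf8')
            print(member.name, '->', label)
        else:
            print(member.name, member.size, 'bytes')
```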
## Model Training
Once the data is packed, you can happily train your audio understanding LLM! (We have only tried ASR, AST, S2S, and a few others, but experiments suggest that most audio understanding tasks can be handled well.)
### Configuration Files
Two configuration files matter most:
- src/mooer/configs/asr_config_training.py
- src/mooer/configs/deepspeed_config_zero2.json
`asr_config_training.py` is the configuration file used for training (`training config`). You can also create a new configuration file anywhere you like, e.g. `config.py`. A few key parameters must be set:
- **self.llm_path**: `path of the LLM to load, e.g. pretrained_models/Qwen2-7B-Instruct`
- **self.encoder_path**: `path of the encoder to load, e.g. pretrained_models/paraformer_encoder/paraformer-encoder.pth; the model can be downloaded from our HuggingFace or ModelScope`
- **self.output_dir**: `directory where checkpoints are saved`
- **self.train_data_path**: `path of the training data, e.g. the /jfs/zhenlin.liang/tmp/pkg_training/data.list packed above`
There are also some other important parameters, for example:
- **self.adapter_path**: `if set, loads a pretrained adapter, e.g. to fine-tune from our 80,000-hour model`
- **self.lora_dir**: `if set, loads pretrained LoRA weights, e.g. to fine-tune from our 80,000-hour model`
- **self.save_merge_rank**: `defaults to True; merges the weights DeepSpeed saves on different ranks into a single .pt file, and saves the LoRA weights separately`
- **self.gradient_checkpoint**: `if enabled, applies gradient checkpointing to the encoder, which noticeably reduces Whisper's memory usage`
- **self.find_unused_parameters**: `must be set to True when the encoder is being trained`
- **self.deepspeed_config**: `path of the DeepSpeed configuration file, default src/mooer/configs/deepspeed_config_zero2.json`
Our large-scale training (80,000 hours) runs on DeepSpeed, and `deepspeed_config_zero2.json` contains the configuration we used. You can also create a new configuration file anywhere, e.g. `ds_config.py`. A ZeRO-3 configuration is provided as well: `deepspeed_config_zero3.json`. In the configuration file you can adjust the training strategy, batch size, memory management options, and so on.
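For reference, `parse_asr_configs` imports the configuration file as a Python module and requires it to define `ModelConfig`, `PeftConfig`, `TrainConfig`, `DataConfig`, and an `update` function. Below is a minimal sketch of such a module; all values are illustrative assumptions rather than project defaults, and the real config classes expose many more fields (including the dict-style `.get()` used throughout the code):
```python
# config.py -- minimal training config sketch (values are illustrative, not defaults).
class ModelConfig:
    def __init__(self):
        self.llm_name = 'qwen'
        self.llm_path = 'pretrained_models/Qwen2-7B-Instruct'
        self.encoder_name = 'paraformer'
        self.encoder_path = 'pretrained_models/paraformer_encoder/paraformer-encoder.pth'
        self.adapter = 'linear'
        self.load_dtype = 'bfloat16'

class PeftConfig:
    def __init__(self):
        self.peft_method = 'lora'  # mapped to peft.LoraConfig by generate_peft_config

class TrainConfig:
    def __init__(self):
        self.output_dir = 'exp/my_model'
        self.num_epochs = 10
        self.deepspeed_config = 'src/mooer/configs/deepspeed_config_zero2.json'

class DataConfig:
    def __init__(self):
        self.train_data_path = '/jfs/zhenlin.liang/tmp/pkg_training/data.list'

def update(model_config, train_config, data_config):
    # hook for deriving or overriding values after the instances are created; no-op here
    pass
```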
### Training Workflow
Edit the relevant settings in `train.sh` and start training. You can train on multiple nodes and GPUs via a hostfile. Point `training_config` in `train.sh` to the path of your training configuration file, then launch:
```shell
nohup bash train.sh > your_awesome_training_log 2>&1 &
```
You should then see training logs like:
```plaintext
[2024-08-15 21:08:16][root][INFO] - Training Epoch: 1/10, step 1000 lr 8.996566681120063e-05 completed (loss: 4.442868709564209, acc: 0.2647058963775635)
[2024-08-15 21:10:21][root][INFO] - Training Epoch: 1/10, step 1100 lr 9.134542298314146e-05 completed (loss: 4.4935736656188965, acc: 0.30645161867141724)
[2024-08-15 21:12:21,408] [INFO] [logging.py:96:log_dist] [Rank 0] step=600, skipped=0, lr=[9.26050416794548e-05], mom=[(0.9, 0.999)]
[2024-08-15 21:12:21,418] [INFO] [timer.py:260:stop] epoch=0/micro_step=1200/global_step=600, RunningAvgSamplesPerSec=51.18608989565034, CurrSamplesPerSec=56.79136995208808, MemAllocated=15.78GB, MaxMemAllocated=43.28GB
```
## Decoding and Inference
### Decoding
After training finishes, you will find checkpoints like the following under your output directory:
- asr_epoch_1_step_320001_merged
- adapter_project.pt
- new_llm
- adapter_config.json
- adapter_model.bin
- README.md
One configuration file matters most:
- src/mooer/configs/asr_config_inference.py
`asr_config_inference.py` is the configuration file used for testing (`inference config`). You can also create a new configuration file anywhere, e.g. `config.py`. A few key parameters must be set (a loading sketch follows the list):
- **self.llm_path**: `path of the LLM to load, e.g. pretrained_models/Qwen2-7B-Instruct`
- **self.encoder_path**: `path of the encoder to load, e.g. pretrained_models/paraformer_encoder/paraformer-encoder.pth; the model can be downloaded from our HuggingFace or ModelScope`
- **self.adapter_path**: `path of the trained adapter, e.g. asr_epoch_1_step_320001_merged/adapter_project.pt`
- **self.lora_dir**: `path of the trained LoRA weights, e.g. asr_epoch_1_step_320001_merged/new_llm`
- **val_batch_size**: `batch size for decoding`
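These fields are consumed by `init_model`: `adapter_path` is loaded with `torch.load` and applied via `load_state_dict(strict=False)`, while `lora_dir` is loaded with `PeftModel.from_pretrained` and merged into the LLM for inference. A minimal loading sketch (the config path is illustrative):
```python
# Load a trained MooER checkpoint for decoding (illustrative config path).
from mooer.models import mooer_model
from mooer.utils.config_utils import parse_asr_configs

configs = parse_asr_configs('src/mooer/configs/asr_config_inference.py')
model, tokenizer = mooer_model.init_model(
    model_config=configs['ModelConfig'],
    train_config=configs['TrainConfig'],
)
model.eval()
```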
See the data packaging section for how to pack a test set. A complete test set contains:
- aishell
- data.list
- text
Edit `inference.sh` to batch-decode your test sets and compute metrics.
- **test_data_dir**: `directory that holds the test sets; e.g. if the test set is /testsets/aishell, this is /testsets`
- **test_sets**: `names of the test sets; multiple sets are separated by "/", e.g. test-clean/test-other`
- **decode_path**: `directory where decoding results are saved`
Launch decoding:
```shell
nohup bash inference.sh > your_awesome_decode_log 2>&1 &
```
After decoding finishes, the results are laid out as:
- your_decode_path
- testset1
- text
- wer
- testset2
- text
- wer
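The `wer` file next to each `text` can be generated with the bundled `compute-wer.py` tool (presumably invoked by `inference.sh`); its documented usage is `python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer`.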
### Inference Example
You can also run inference on wav files directly; see https://github.com/MooreThreads/MooER#-inference