Commit d7cad875 authored by Sugon_ldc

add new files
// Compatibility shims for older browsers
window.URL = window.URL || window.webkitURL;
// Request access to the user's capture devices (microphone)
navigator.getUserMedia = navigator.getUserMedia || navigator.webkitGetUserMedia || navigator.mozGetUserMedia || navigator.msGetUserMedia;
var HZRecorder = function (stream, config) {
config = config || {};
config.sampleBits = config.sampleBits || 16; // sample size in bits: 8 or 16
config.sampleRate = config.sampleRate || 16000; // sample rate, e.g. 16000
// Create an audio context
var audioContext = window.AudioContext || window.webkitAudioContext;
var context = new audioContext();
var audioInput = context.createMediaStreamSource(stream);
// Buffer size 4096; the second and third arguments are the input and output channel counts (1 = mono, 2 = stereo). Only channel 0 is consumed below, so mono would suffice.
var recorder = context.createScriptProcessor(4096, 2, 2);
var audioData = {
size: 0 // total length of the recording, in samples
, buffer: [] // recording buffer
, inputSampleRate: context.sampleRate // input sample rate
, inputSampleBits: 16 // input sample size in bits: 8 or 16
, outputSampleRate: config.sampleRate // output sample rate
, outputSampleBits: config.sampleBits // output sample size in bits: 8 or 16
, input: function (data) {
this.buffer.push(new Float32Array(data));
this.size += data.length;
}
, compress: function () { // merge buffers and downsample
// merge
var data = new Float32Array(this.size);
var offset = 0;
for (var i = 0; i < this.buffer.length; i++) {
data.set(this.buffer[i], offset);
offset += this.buffer[i].length;
}
// Downsample by simple decimation: keep every N-th sample (no low-pass filtering)
var compression = parseInt(this.inputSampleRate / this.outputSampleRate);
var length = data.length / compression;
var result = new Float32Array(length);
var index = 0, j = 0;
while (index < length) {
result[index] = data[j];
j += compression;
index++;
}
return result;
}
, encodeWAV: function () {
var sampleRate = Math.min(this.inputSampleRate, this.outputSampleRate);
var sampleBits = Math.min(this.inputSampleBits, this.outputSampleBits);
var bytes = this.compress();
var dataLength = bytes.length * (sampleBits / 8);
var buffer = new ArrayBuffer(44 + dataLength);
var data = new DataView(buffer);
var channelCount = 1; // mono
var offset = 0;
var writeString = function (str) {
for (var i = 0; i < str.length; i++) {
data.setUint8(offset + i, str.charCodeAt(i));
}
}
// RIFF chunk identifier
writeString('RIFF');
offset += 4;
// Total bytes from the next field to the end of the file, i.e. file size - 8
data.setUint32(offset, 36 + dataLength, true);
offset += 4;
// WAVE format identifier
writeString('WAVE');
offset += 4;
// "fmt " sub-chunk identifier
writeString('fmt ');
offset += 4;
// Size of the fmt sub-chunk, 0x10 = 16 for PCM
data.setUint32(offset, 16, true);
offset += 4;
// Audio format: 1 = PCM
data.setUint16(offset, 1, true);
offset += 2;
// Number of channels
data.setUint16(offset, channelCount, true);
offset += 2;
// Sample rate: samples per second per channel
data.setUint32(offset, sampleRate, true);
offset += 4;
// Byte rate (average bytes per second) = channels x sample rate x bits per sample / 8
data.setUint32(offset, channelCount * sampleRate * (sampleBits / 8), true);
offset += 4;
// Block align: bytes per sample frame = channels x bits per sample / 8
data.setUint16(offset, channelCount * (sampleBits / 8), true);
offset += 2;
// Bits per sample
data.setUint16(offset, sampleBits, true);
offset += 2;
// "data" sub-chunk identifier
writeString('data');
offset += 4;
// Size of the sample data in bytes, i.e. total size - 44
data.setUint32(offset, dataLength, true);
offset += 4;
// Write the sample data
if (sampleBits === 8) {
for (var i = 0; i < bytes.length; i++, offset++) {
var s = Math.max(-1, Math.min(1, bytes[i]));
var val = s < 0 ? s * 0x8000 : s * 0x7FFF;
// 8-bit PCM WAV stores unsigned samples: map [-32768, 32767] to [0, 255]
val = parseInt(255 / (65535 / (val + 32768)));
data.setUint8(offset, val);
}
} else {
for (var i = 0; i < bytes.length; i++, offset += 2) {
var s = Math.max(-1, Math.min(1, bytes[i]));
data.setInt16(offset, s < 0 ? s * 0x8000 : s * 0x7FFF, true);
}
}
return new Blob([data], {type: 'audio/wav'});
}
};
// Start recording
this.start = function () {
audioInput.connect(recorder);
recorder.connect(context.destination);
}
// Stop recording
this.stop = function () {
recorder.disconnect();
}
// Get the recording as a WAV Blob
this.getBlob = function () {
this.stop();
return audioData.encodeWAV();
}
// Playback
this.play = function (audio) {
audio.src = window.URL.createObjectURL(this.getBlob());
}
// Clear the recording buffer
this.clear = function () {
audioData.buffer = [];
audioData.size = 0;
}
// Upload the recording
this.upload = function (url, callback) {
var fd = new FormData();
// form field name and file data
fd.append("audio", this.getBlob());
var xhr = new XMLHttpRequest();
xhr.timeout = 60000
if (callback) {
xhr.upload.addEventListener("progress", function (e) {
callback('uploading', e);
}, false);
xhr.addEventListener("load", function (e) {
callback('ok', e);
}, false);
xhr.addEventListener("error", function (e) {
callback('error', e);
}, false);
xhr.addEventListener("abort", function (e) {
callback('cancel', e);
}, false);
}
xhr.open("POST", url);
xhr.send(fd);
}
// Audio capture callback
recorder.onaudioprocess = function (e) {
audioData.input(e.inputBuffer.getChannelData(0));
//record(e.inputBuffer.getChannelData(0));
}
};
// Throw an error
HZRecorder.throwError = function (message) {
alert(message);
throw new function () {
this.toString = function () {
return message;
}
}
}
// Whether recording is supported
HZRecorder.canRecording = (navigator.getUserMedia != null);
// Create a recorder instance
HZRecorder.get = function (callback, config) {
if (callback) {
if (navigator.getUserMedia) {
navigator.getUserMedia(
{audio: true} // audio only
, function (stream) {
var rec = new HZRecorder(stream, config);
callback(rec);
}
, function (error) {
switch (error.code || error.name) {
case 'PERMISSION_DENIED':
case 'PermissionDeniedError':
HZRecorder.throwError('用户拒绝提供信息。');
break;
case 'NOT_SUPPORTED_ERROR':
case 'NotSupportedError':
HZRecorder.throwError('浏览器不支持硬件设备。');
break;
case 'MANDATORY_UNSATISFIED_ERROR':
case 'MandatoryUnsatisfiedError':
HZRecorder.throwError('无法发现指定的硬件设备。');
break;
default:
HZRecorder.throwError('无法打开麦克风。异常信息:' + (error.code || error.name));
break;
}
});
} else {
window.alert('不是HTTPS协议或者localhost地址,不能使用录音功能!')
HZRecorder.throwError('当前浏览器不支持录音功能。');
return;
}
}
};
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>语音识别-夜雨飘零</title>
<script type="text/javascript" src="/static/record.js"></script>
<link href="/static/index.css" rel="stylesheet" type="text/css"/>
</head>
<body>
<div id="header">
<h1>夜雨飘零语音识别</h1>
</div>
<div id="content">
<div>
<a id="upload" onclick="uploadAudioFile()" class="file">选择音频文件</a>
<a id="play_btn" onclick="uploadRecordAudio()" class="file">上传录音</a>
<audio controls autoplay></audio>
<img id="record_btn" onclick="record()" src="/static/record.png" alt="录音"/>
</div>
<div id="result">
<label for="result_p"></label><textarea id="result_p"></textarea>
</div>
上传进度:<progress id="progress1" value="0" max="100"></progress>&nbsp;<span id="progress_text"></span>
</div>
<script>
let is_recording = false;
let is_playing = false;
let host = location.origin;
let recorder;
let audio = document.querySelector('audio');
let textarea = document.getElementById('result_p')
let progress1 = document.getElementById('progress1')
let progress_text = document.getElementById('progress_text')
function record() {
if (is_recording) {
is_recording = false;
stopRecording()
document.getElementById('record_btn').src = '/static/record.png'
startPlay();
stopPlay();
} else {
is_recording = true;
startRecording()
document.getElementById('record_btn').src = '/static/recording.gif'
}
}
function play() {
if (is_playing) {
is_playing = false;
stopPlay()
document.getElementById('play_btn').innerText = '播放音频'
} else {
is_playing = true;
startPlay()
document.getElementById('play_btn').innerText = '停止播放'
}
}
function startRecording() {
HZRecorder.get(function (rec) {
recorder = rec;
recorder.start();
});
}
function stopRecording() {
recorder.stop();
}
function startPlay() {
recorder.play(audio);
}
function stopPlay() {
audio.pause();
}
function cancelAudio() {
recorder.stop();
recorder.clear();
}
function uploadRecordAudio() {
recorder.upload(host + "/recognition", function (state, e) {
switch (state) {
case 'uploading':
const percentComplete = Math.round(e.loaded * 100 / e.total);
console.log(percentComplete + '%');
// update the progress bar
progress1.value = percentComplete
progress_text.innerText = percentComplete + '%'
break;
case 'ok':
console.log(e.target.responseText)
textarea.value = e.target.responseText
break;
case 'error':
alert("上传失败");
break;
case 'cancel':
alert("上传被取消");
break;
}
});
}
// Upload an audio file selected from disk
function uploadAudioFile() {
const input = document.createElement("input");
input.type = "file";
input.accept = "audio/*,video/*";
input.click();
input.onchange = function () {
const file = input.files[0];
console.log(file)
audio.src = window.URL.createObjectURL(file);
stopPlay();
upload_file(host + "/recognition", file, function (state, e) {
switch (state) {
case 'uploading':
const percentComplete = Math.round(e.loaded * 100 / e.total);
console.log(percentComplete + '%');
// update the progress bar
progress1.value = percentComplete
progress_text.innerText = percentComplete + '%'
break;
case 'ok':
console.log(e.target.responseText)
textarea.value = e.target.responseText
break;
case 'error':
alert("上传失败");
break;
case 'cancel':
alert("上传被取消");
break;
}
});
}
}
// Upload an audio file via XMLHttpRequest
const upload_file = function (url, file, callback) {
const fd = new FormData();
// form field name and file data
fd.append("audio", file);
const xhr = new XMLHttpRequest();
xhr.timeout = 60000
if (callback) {
xhr.upload.addEventListener("progress", function (e) {
callback('uploading', e);
}, false);
xhr.addEventListener("load", function (e) {
callback('ok', e);
}, false);
xhr.addEventListener("error", function (e) {
callback('error', e);
}, false);
xhr.addEventListener("abort", function (e) {
callback('cancel', e);
}, false);
}
xhr.open("POST", url);
xhr.send(fd);
}
</script>
</body>
</html>
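# The page above posts the recording as a multipart form field named "audio" to /recognition.
# The matching server endpoint is not included in this commit; the following is only a minimal
# sketch of a compatible handler, assuming Flask is used (recognize() is a placeholder).
import tempfile

from flask import Flask, request

app = Flask(__name__, static_folder="static")


def recognize(path):
    # placeholder: plug in faster-whisper / transformers inference here
    return ""


@app.route("/recognition", methods=["POST"])
def recognition():
    audio_file = request.files["audio"]  # field name used by record.js and index.html
    with tempfile.NamedTemporaryFile(suffix=".wav") as f:
        audio_file.save(f.name)
        return recognize(f.name)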
import argparse
import functools
import os
import sys
import time
import soundfile
from faster_whisper import WhisperModel
from tqdm import tqdm
sys.path.insert(0, sys.path[0] + "/../")
from utils.utils import print_arguments, add_arguments
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg("audio_path", type=str, default="../dataset/test_long.wav", help="预测的音频路径")
add_arg("model_path", type=str, default="../models/whisper-tiny-ct2", help="转换后的模型路径,转换方式看文档")
add_arg("use_gpu", type=bool, default=True, help="是否使用gpu进行预测")
add_arg("infer_num", type=int, default=10, help="预测的次数,不包括预热")
add_arg("use_int8", type=bool, default=False, help="是否使用int8进行预测")
add_arg("beam_size", type=int, default=1, help="解码搜索大小")
add_arg("local_files_only", type=bool, default=True, help="是否只在本地加载模型,不尝试下载")
args = parser.parse_args()
print_arguments(args)
# Check that the model path exists
assert os.path.exists(args.model_path), f"模型文件{args.model_path}不存在"
# Load the model
if args.use_gpu:
if not args.use_int8:
model = WhisperModel(args.model_path, device="cuda", compute_type="float16",
local_files_only=args.local_files_only)
else:
model = WhisperModel(args.model_path, device="cuda", compute_type="int8_float16",
local_files_only=args.local_files_only)
else:
model = WhisperModel(args.model_path, device="cpu", compute_type="int8",
local_files_only=args.local_files_only)
# Support for large-v3: it uses 128 mel bins, so patch the feature extractor accordingly
if 'large-v3' in args.model_path:
model.feature_extractor.mel_filters = \
model.feature_extractor.get_mel_filters(model.feature_extractor.sampling_rate,
model.feature_extractor.n_fft, n_mels=128)
sample, sr = soundfile.read(args.audio_path)
# Warm-up run
_, _ = model.transcribe(sample.copy())
start_time = time.time()
# Speech recognition benchmark loop
for i in tqdm(range(args.infer_num)):
segments, info = model.transcribe(sample.copy(), beam_size=args.beam_size)
for segment in segments:
_ = segment.text
print(f"音频时长:{int(len(sample) / sr)}s,预测平均耗时:{((time.time() - start_time) / args.infer_num):.3f}s")
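# The --model_path default above points at a CTranslate2 export of Whisper; the conversion itself
# is not shown in this commit. A minimal conversion sketch using the ctranslate2 Python converter
# (output directory, copied files and quantization are illustrative):
from ctranslate2.converters import TransformersConverter

converter = TransformersConverter("openai/whisper-tiny",
                                  copy_files=["tokenizer.json", "preprocessor_config.json"])
converter.convert("../models/whisper-tiny-ct2", quantization="float16")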
import argparse
import functools
import platform
import sys
import time
import soundfile
import torch
from tqdm import tqdm
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline, AutoModelForCausalLM
sys.path.insert(0, sys.path[0] + "/../")
from utils.utils import print_arguments, add_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg("audio_path", type=str, default="../dataset/test_long.wav", help="预测的音频路径")
add_arg("model_path", type=str, default="../base-model/openai/whisper-tiny", help="合并模型的路径,或者是huggingface上模型的名称")
add_arg("use_gpu", type=bool, default=True, help="是否使用gpu进行预测")
add_arg("num_beams", type=int, default=1, help="解码搜索大小")
add_arg("infer_num", type=int, default=10, help="预测的次数,不包括预热")
add_arg("batch_size", type=int, default=16, help="预测batch_size大小")
add_arg("use_compile", type=bool, default=False, help="是否使用Pytorch2.0的编译器")
add_arg("assistant_model_path", type=str, default=None, help="助手模型,可以提高推理速度,例如openai/whisper-tiny")
add_arg("local_files_only", type=bool, default=True, help="是否只在本地加载模型,不尝试下载")
add_arg("use_flash_attention_2", type=bool, default=False, help="是否使用FlashAttention2加速")
add_arg("use_bettertransformer", type=bool, default=False, help="是否使用BetterTransformer加速")
args = parser.parse_args()
print_arguments(args)
# Select the device and precision
device = "cuda:0" if torch.cuda.is_available() and args.use_gpu else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() and args.use_gpu else torch.float32
# Load the Whisper processor (feature extractor and tokenizer)
processor = AutoProcessor.from_pretrained(args.model_path)
# Load the model
model = AutoModelForSpeechSeq2Seq.from_pretrained(
args.model_path, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True,
use_flash_attention_2=args.use_flash_attention_2
)
if args.use_bettertransformer and not args.use_flash_attention_2:
model = model.to_bettertransformer()
# Optionally compile the model with the PyTorch 2.0 compiler
if args.use_compile:
if torch.__version__ >= "2" and platform.system().lower() != 'windows':
model = torch.compile(model)
model.to(device)
# Optional assistant model for speculative decoding
generate_kwargs_pipeline = None
if args.assistant_model_path is not None:
assistant_model = AutoModelForCausalLM.from_pretrained(
args.assistant_model_path, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
assistant_model.to(device)
generate_kwargs_pipeline = {"assistant_model": assistant_model}
# Build the ASR pipeline
infer_pipe = pipeline("automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
chunk_length_s=30,
batch_size=args.batch_size,
torch_dtype=torch_dtype,
generate_kwargs=generate_kwargs_pipeline,
device=device)
sample, sr = soundfile.read(args.audio_path)
# Warm-up run
_ = infer_pipe(sample.copy())
start_time = time.time()
for i in tqdm(range(args.infer_num)):
_ = infer_pipe(sample.copy(), generate_kwargs={"task": "transcribe", "num_beams": args.num_beams})
print(f"音频时长:{int(len(sample) / sr)}s,预测平均耗时:{((time.time() - start_time) / args.infer_num):.3f}s")
import argparse
import json
import logging
import math
import multiprocessing
import os
import sys
from multiprocessing import cpu_count
import ijson
import soundfile
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from tqdm import tqdm
sys.path.insert(0, sys.path[0] + "/../")
from utils.binary import DatasetWriter
logger = get_logger(log_level=logging.CRITICAL)
logger.setLevel(logging.CRITICAL)
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--wenetspeech_json', type=str, default='/media/WenetSpeech数据集/WenetSpeech.json',
help="WenetSpeech的标注json文件路径")
parser.add_argument('--add_pun', type=bool, default=True, help="是否添加标点符")
parser.add_argument('--annotation_dir', type=str, default='../dataset/', help="存放数据列表的文件夹路径")
args = parser.parse_args()
if not os.path.exists(args.annotation_dir):
os.makedirs(args.annotation_dir)
# Paths of the train and test data lists
train_list_path = os.path.join(args.annotation_dir, 'train_wenet.json')
test_net_path = os.path.join(args.annotation_dir, 'test_net.json')
test_meeting_path = os.path.join(args.annotation_dir, 'test_meeting.json')
# Read the annotation info
def get_data(wenetspeech_json):
data_list = []
input_dir = os.path.dirname(wenetspeech_json)
i = 0
# Start reading; the file is too large to show a progress total
with open(wenetspeech_json, 'r', encoding='utf-8') as f:
objects = ijson.items(f, 'audios.item')
print("开始读取数据")
while True:
try:
long_audio = objects.__next__()
i += 1
try:
long_audio_path = os.path.realpath(os.path.join(input_dir, long_audio['path']))
aid = long_audio['aid']
segments_lists = long_audio['segments']
assert (os.path.exists(long_audio_path))
except AssertionError:
print(f'''Warning: {long_audio_path} 不存在或者已经处理过自动删除了,跳过''')
continue
except Exception:
print(f'''Warning: {aid} 数据读取错误,跳过''')
continue
else:
data_list.append([long_audio_path.replace('\\', '/'), segments_lists])
except StopIteration:
print("数据读取完成")
break
return data_list
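# get_data() streams the top-level "audios" array with ijson, so the multi-gigabyte
# WenetSpeech.json never has to be loaded at once. Each element is expected to carry
# "aid", "path" and a "segments" list; an illustrative (made-up) record looks like:
example_audio = {
    "aid": "Y0000000001",
    "path": "audio/train/youtube/B00000/Y0000000001.opus",
    "segments": [
        {"begin_time": 20.08, "end_time": 24.40, "text": "example transcript", "confidence": 1.0},
    ],
}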
def main():
f_train = open(train_list_path, 'w', encoding='utf-8')
f_test_net = open(test_net_path, 'w', encoding='utf-8')
f_test_meeting = open(test_meeting_path, 'w', encoding='utf-8')
all_data = get_data(args.wenetspeech_json)
print(f'总数据量为:{len(all_data)}')
for data in tqdm(all_data):
long_audio_path, segments_lists = data
for segment_file in segments_lists:
start_time = float(segment_file['begin_time'])
end_time = float(segment_file['end_time'])
text = segment_file['text']
confidence = segment_file['confidence']
if confidence < 0.95: continue
line = dict(audio={"path": long_audio_path,
"start_time": round(start_time, 3),
"end_time": round(end_time, 3)},
sentence=text,
duration=round(end_time - start_time, 3))
data_type = long_audio_path.split('/')[-4]
if data_type == 'test_net':
f_test_net.write(json.dumps(line, ensure_ascii=False) + '\n')
if data_type == 'test_meeting':
f_test_meeting.write(json.dumps(line, ensure_ascii=False) + '\n')
if data_type == 'train':
f_train.write(json.dumps(line, ensure_ascii=False) + '\n')
f_train.close()
f_test_meeting.close()
f_test_net.close()
# Merge multiple segments into one sample, add timestamps, and speed up training
def merge_list():
for file_path in [train_list_path, test_net_path, test_meeting_path]:
with open(file_path, 'r', encoding='utf-8') as f:
lines = f.readlines()
with open(file_path, 'w', encoding='utf-8') as f:
sentences = []
duration = 0
start_time = 0
text = ''
for i in tqdm(range(len(lines))):
data = json.loads(lines[i])
sentence = data["sentence"]
# a new chunk starts here
if duration == 0:
start_time = data['audio']["start_time"]
duration = data['audio']["end_time"] - start_time
# per-sentence timestamps relative to the chunk start
sentences.append({"start": round(data['audio']["start_time"] - start_time, 2),
"end": round(data['audio']['end_time'] - start_time, 2),
"text": sentence})
text += sentence
name = data['audio']['path']
if i < len(lines) - 2:
next_data = json.loads(lines[i + 1])
next_name = next_data['audio']['path']
next_end_time = next_data['audio']["end_time"]
# If the next segment belongs to a different audio, or adding it would exceed 30 seconds, flush the chunk
if next_name != name or next_end_time - start_time >= 30:
data1 = dict()
data1['audio'] = {"path": data['audio']['path']}
data1['audio']['start_time'] = start_time
data1['audio']['end_time'] = data['audio']['end_time']
data1['duration'] = round(data['audio']['end_time'] - start_time, 2)
data1['sentence'] = text
data1['sentences'] = sentences
f.write(f'{json.dumps(data1, ensure_ascii=False)}\n')
sentences = []
duration = 0
start_time = 0
text = ''
else:
# handling of the final lines
data1 = dict()
data1['audio'] = {"path": data['audio']['path']}
data1['audio']['start_time'] = start_time
data1['audio']['end_time'] = data['audio']['end_time']
data1['duration'] = round(data['audio']['end_time'] - start_time, 2)
data1['sentence'] = text
data1['sentences'] = sentences
f.write(f'{json.dumps(data1, ensure_ascii=False)}\n')
sentences = []
duration = 0
start_time = 0
text = ''
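# merge_list() stitches consecutive segments of the same long audio into chunks of at most
# 30 seconds and rewrites the sentence timestamps relative to the chunk start. An illustrative
# (made-up) output line:
example_merged_line = {
    "audio": {"path": "audio/train/youtube/B00000/Y0000000001.opus",
              "start_time": 20.08, "end_time": 48.95},
    "duration": 28.87,
    "sentence": "first sentence second sentence",
    "sentences": [{"start": 0.0, "end": 4.32, "text": "first sentence"},
                  {"start": 4.5, "end": 28.87, "text": "second sentence"}],
}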
# Silence the given spans and convert the audio to flac
def process_audio(data, i):
for path, sentences in tqdm(data, desc=f"处理进程{i}"):
if not os.path.exists(path): continue
save_path = path[:-5] + '.flac'
if os.path.exists(save_path): continue
sample, sr = soundfile.read(path)
for sentence in sentences:
start, end = sentence
start = max(int((start + 0.1) * sr), 0)
end = min(int((end - 0.1) * sr), len(sample))
sample[start:end] = 0
soundfile.write(save_path, sample, sr)
# Silence the positions that have no annotation
def set_silence():
for file_path in [train_list_path, test_net_path]:
with open(file_path, 'r', encoding='utf-8') as f:
lines = f.readlines()
all_data = {}
for line in tqdm(lines, desc='读取数据列表'):
data = json.loads(line)
path = data['audio']['path']
if os.path.splitext(path)[-1] != '.opus': continue
start_a = data['audio']['start_time']
sentences = data['sentences']
last_end = start_a
for sentence in sentences:
start = round(start_a + sentence['start'], 3)
if start - last_end > 1:
if path in all_data.keys():
all_data[path].append([last_end, start])
else:
all_data[path] = [[last_end, start]]
else:
if path not in all_data.keys():
all_data[path] = []
last_end = round(start_a + sentence['end'], 3)
# Process the data with multiple processes
all_data = list(all_data.items())
num_worker = cpu_count()
length = math.ceil(len(all_data) / num_worker)
data = [all_data[i * length:(i + 1) * length] for i in range(num_worker)]
my_process = []
for i in range(num_worker):
process = multiprocessing.Process(target=process_audio, args=(data[i], i))
my_process.append(process)
for process in my_process:
process.start()
for process in my_process:
process.join()
# Rewrite the paths, since the audio was converted to flac
with open(file_path, 'w', encoding='utf-8') as f:
for line in tqdm(lines, desc='修改路径后缀'):
data = json.loads(line)
path = data['audio']['path']
path = path.replace('.opus', '.flac')
if not os.path.exists(path):
print(f'{path}文件不存在', file=sys.stderr)
continue
data['audio']['path'] = path
f.write(json.dumps(data, ensure_ascii=False) + '\n')
# Add punctuation
def process_pun(data, i):
inference_pipline = pipeline(task=Tasks.punctuation,
model='damo/punc_ct-transformer_cn-en-common-vocab471067-large',
model_revision="v1.0.0")
f = open(f'temp{i}.txt', 'w', encoding='utf-8')
for line in tqdm(data, desc=f"处理进程{i}"):
data = json.loads(line)
sentence = data['sentence']
sentence = sentence.replace(',', '').replace('。', '').replace('?', '').replace('!', '').replace('、', '')
sentence = inference_pipline(text_in=sentence)['text']
data['sentence'] = sentence
param_dict = {"cache": []}
sentences = data['sentences']
for i in range(len(sentences)):
text = sentences[i]['text']
text = text.replace(',', '').replace('。', '').replace('?', '').replace('!', '').replace('、', '')
text = inference_pipline(text_in=text, param_dict=param_dict)['text']
sentences[i]['text'] = text
f.write(json.dumps(data, ensure_ascii=False) + '\n')
# Add punctuation with multiple processes
def add_pun():
for file_path in [train_list_path, test_net_path, test_meeting_path]:
with open(file_path, 'r', encoding='utf-8') as f:
all_data = f.readlines()
# number of punctuation workers; adjust to your GPU memory
num_worker = 4
length = math.ceil(len(all_data) / num_worker)
data = [all_data[i * length:(i + 1) * length] for i in range(num_worker)]
my_process = []
for i in range(num_worker):
process = multiprocessing.Process(target=process_pun, args=(data[i], i))
my_process.append(process)
for process in my_process:
process.start()
for process in my_process:
process.join()
# Merge the per-worker temp files
with open(file_path, 'w', encoding='utf-8') as fw:
for i in range(num_worker):
with open(f'temp{i}.txt', 'r', encoding='utf-8') as fr:
lines = fr.readlines()
for line in lines:
fw.write(line)
# Convert the list to binary files to reduce memory usage
def create_binary():
print('正在把数据列表转成二进制文件...')
dataset_writer = DatasetWriter(f"{args.annotation_dir}/train")
with open(train_list_path, 'r', encoding='utf-8') as f:
lines = f.readlines()
for line in tqdm(lines):
line = line.replace('\n', '')
dataset_writer.add_data(line)
dataset_writer.close()
if __name__ == '__main__':
main()
# Merge multiple segments into one sample, add timestamps, and speed up training
merge_list()
# Silence the positions that have no annotation
set_silence()
# Add punctuation
if args.add_pun:
add_pun()
# Convert to binary files to reduce memory usage
create_binary()
#!/bin/bash
# Transformers models
python compute_speed_tf.py --model_path=openai/whisper-tiny
python compute_speed_tf.py --model_path=openai/whisper-tiny --use_compile=True
python compute_speed_tf.py --model_path=openai/whisper-tiny --use_bettertransformer=True
python compute_speed_tf.py --model_path=openai/whisper-tiny --use_flash_attention_2=True
python compute_speed_tf.py --model_path=openai/whisper-tiny --use_compile=True --use_bettertransformer=True
python compute_speed_tf.py --model_path=openai/whisper-tiny --use_compile=True --use_flash_attention_2=True
python compute_speed_tf.py --model_path=openai/whisper-base
python compute_speed_tf.py --model_path=openai/whisper-base --use_compile=True
python compute_speed_tf.py --model_path=openai/whisper-base --use_bettertransformer=True
python compute_speed_tf.py --model_path=openai/whisper-base --use_flash_attention_2=True
python compute_speed_tf.py --model_path=openai/whisper-base --use_compile=True --use_bettertransformer=True
python compute_speed_tf.py --model_path=openai/whisper-base --use_compile=True --use_flash_attention_2=True
python compute_speed_tf.py --model_path=openai/whisper-small
python compute_speed_tf.py --model_path=openai/whisper-small --use_compile=True
python compute_speed_tf.py --model_path=openai/whisper-small --use_bettertransformer=True
python compute_speed_tf.py --model_path=openai/whisper-small --use_flash_attention_2=True
python compute_speed_tf.py --model_path=openai/whisper-small --use_compile=True --use_bettertransformer=True
python compute_speed_tf.py --model_path=openai/whisper-small --use_compile=True --use_flash_attention_2=True
python compute_speed_tf.py --model_path=openai/whisper-medium
python compute_speed_tf.py --model_path=openai/whisper-medium --use_compile=True
python compute_speed_tf.py --model_path=openai/whisper-medium --use_bettertransformer=True
python compute_speed_tf.py --model_path=openai/whisper-medium --use_flash_attention_2=True
python compute_speed_tf.py --model_path=openai/whisper-medium --use_compile=True --use_bettertransformer=True
python compute_speed_tf.py --model_path=openai/whisper-medium --use_compile=True --use_flash_attention_2=True
python compute_speed_tf.py --model_path=openai/whisper-large-v2
python compute_speed_tf.py --model_path=openai/whisper-large-v2 --use_compile=True
python compute_speed_tf.py --model_path=openai/whisper-large-v2 --use_bettertransformer=True
python compute_speed_tf.py --model_path=openai/whisper-large-v2 --use_flash_attention_2=True
python compute_speed_tf.py --model_path=openai/whisper-large-v2 --use_compile=True --use_bettertransformer=True
python compute_speed_tf.py --model_path=openai/whisper-large-v2 --use_compile=True --use_flash_attention_2=True
python compute_speed_tf.py --model_path=openai/whisper-large-v3
python compute_speed_tf.py --model_path=openai/whisper-large-v3 --use_compile=True
python compute_speed_tf.py --model_path=openai/whisper-large-v3 --use_bettertransformer=True
python compute_speed_tf.py --model_path=openai/whisper-large-v3 --use_flash_attention_2=True
python compute_speed_tf.py --model_path=openai/whisper-large-v3 --use_compile=True --use_bettertransformer=True
python compute_speed_tf.py --model_path=openai/whisper-large-v3 --use_compile=True --use_flash_attention_2=True
# CTranslate2 models
python compute_speed_ct2.py --model_path=../models/whisper-tiny-ct2/
python compute_speed_ct2.py --model_path=../models/whisper-base-ct2/
python compute_speed_ct2.py --model_path=../models/whisper-small-ct2/
python compute_speed_ct2.py --model_path=../models/whisper-medium-ct2/
python compute_speed_ct2.py --model_path=../models/whisper-large-v2-ct2/
python compute_speed_ct2.py --model_path=../models/whisper-large-v3-ct2/
python compute_speed_ct2.py --model_path=../models/whisper-tiny-ct2/ --use_int8=True
python compute_speed_ct2.py --model_path=../models/whisper-base-ct2/ --use_int8=True
python compute_speed_ct2.py --model_path=../models/whisper-small-ct2/ --use_int8=True
python compute_speed_ct2.py --model_path=../models/whisper-medium-ct2/ --use_int8=True
python compute_speed_ct2.py --model_path=../models/whisper-large-v2-ct2/ --use_int8=True
python compute_speed_ct2.py --model_path=../models/whisper-large-v3-ct2/ --use_int8=True
import json
import mmap
import struct
from tqdm import tqdm
class DatasetWriter(object):
def __init__(self, prefix):
# Create the data file and its header (index) file
self.data_file = open(prefix + '.data', 'wb')
self.header_file = open(prefix + '.header', 'wb')
self.data_sum = 0
self.offset = 0
self.header = ''
def add_data(self, data):
key = str(self.data_sum)
data = bytes(data, encoding="utf8")
# Write the record: key length, key, data length, data
self.data_file.write(struct.pack('I', len(key)))
self.data_file.write(key.encode('ascii'))
self.data_file.write(struct.pack('I', len(data)))
self.data_file.write(data)
# Write the index entry: key, data offset, data length
self.offset += 4 + len(key) + 4
self.header = key + '\t' + str(self.offset) + '\t' + str(len(data)) + '\n'
self.header_file.write(self.header.encode('ascii'))
self.offset += len(data)
self.data_sum += 1
def close(self):
self.data_file.close()
self.header_file.close()
class DatasetReader(object):
def __init__(self, data_header_path, min_duration=0, max_duration=30):
self.keys = []
self.offset_dict = {}
self.fp = open(data_header_path.replace('.header', '.data'), 'rb')
self.m = mmap.mmap(self.fp.fileno(), 0, access=mmap.ACCESS_READ)
for line in tqdm(open(data_header_path, 'rb'), desc='读取数据列表'):
key, val_pos, val_len = line.split('\t'.encode('ascii'))
data = self.m[int(val_pos):int(val_pos) + int(val_len)]
data = str(data, encoding="utf-8")
data = json.loads(data)
# Skip audio outside the duration limits
if data["duration"] < min_duration:
continue
if max_duration != -1 and data["duration"] > max_duration:
continue
self.keys.append(key)
self.offset_dict[key] = (int(val_pos), int(val_len))
# Get one record by key
def get_data(self, key):
p = self.offset_dict.get(key, None)
if p is None:
return None
val_pos, val_len = p
data = self.m[val_pos:val_pos + val_len]
data = str(data, encoding="utf-8")
return json.loads(data)
# Get all keys
def get_keys(self):
return self.keys
def __len__(self):
return len(self.keys)
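# Round-trip usage sketch for the binary dataset format above (the "demo" prefix and record
# content are illustrative). DatasetWriter appends length-prefixed records to the .data file
# and a tab-separated "key offset length" index to the .header file; DatasetReader mmaps the
# .data file and uses the index for random access.
writer = DatasetWriter("demo")
writer.add_data(json.dumps({"sentence": "hello", "duration": 1.2}, ensure_ascii=False))
writer.close()

reader = DatasetReader("demo.header", min_duration=0, max_duration=30)
for key in reader.get_keys():
    print(reader.get_data(key)["sentence"])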
import os
import shutil
from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
# Callback that runs when the trainer saves a model
class SavePeftModelCallback(TrainerCallback):
def on_save(self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs, ):
if args.local_rank == 0 or args.local_rank == -1:
# Keep a copy of the best checkpoint
best_checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-best")
# Only the latest checkpoints are kept, so make sure the best one has not been rotated away
if os.path.exists(state.best_model_checkpoint):
if os.path.exists(best_checkpoint_folder):
shutil.rmtree(best_checkpoint_folder)
shutil.copytree(state.best_model_checkpoint, best_checkpoint_folder)
print(f"效果最好的检查点为:{state.best_model_checkpoint},评估结果为:{state.best_metric}")
return control
import re
from dataclasses import dataclass
from typing import Any, List, Dict, Union
import torch
from zhconv import convert
# Remove punctuation
def remove_punctuation(text: Union[str, List[str]]):
punctuation = '!,.;:?、!,。;:?'
if isinstance(text, str):
text = re.sub(r'[{}]+'.format(punctuation), '', text).strip()
return text
elif isinstance(text, list):
result_text = []
for t in text:
t = re.sub(r'[{}]+'.format(punctuation), '', t).strip()
result_text.append(t)
return result_text
else:
raise Exception(f'不支持该类型{type(text)}')
# Convert Traditional Chinese to Simplified Chinese
def to_simple(text: Union[str, List[str]]):
if isinstance(text, str):
text = convert(text, 'zh-cn')
return text
elif isinstance(text, list):
result_text = []
for t in text:
t = convert(t, 'zh-cn')
result_text.append(t)
return result_text
else:
raise Exception(f'不支持该类型{type(text)}')
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
processor: Any
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
# split inputs and labels since they have to be of different lengths and need different padding methods
# first treat the audio inputs by simply returning torch tensors
input_features = [{"input_features": feature["input_features"][0]} for feature in features]
batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
# get the tokenized label sequences
label_features = [{"input_ids": feature["labels"]} for feature in features]
# pad the labels to max length
labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
# replace padding with -100 to ignore loss correctly
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
# if bos token is appended in previous tokenization step,
# cut bos token here as it's append later anyways
if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
labels = labels[:, 1:]
batch["labels"] = labels
return batch
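# Usage sketch: the collator pads audio features and label ids separately and masks the
# label padding with -100 so it is ignored by the loss. The model name is illustrative.
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
# data_collator(features) expects each feature to provide "input_features" and "labels",
# exactly what the CustomDataset in utils/reader.py produces.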
import bitsandbytes as bnb
import torch
from transformers.trainer_pt_utils import LabelSmoother
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
def find_all_linear_names(use_8bit, model):
cls = bnb.nn.Linear8bitLt if use_8bit else torch.nn.Linear
lora_module_names = set()
for name, module in model.named_modules():
if isinstance(module, cls):
names = name.split('.')
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
target_modules = list(lora_module_names)
return target_modules
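# Usage sketch (illustrative): the discovered Linear layer names become the LoRA target
# modules of a peft configuration; the model name and hyperparameters are placeholders.
from peft import LoraConfig, get_peft_model
from transformers import WhisperForConditionalGeneration

base_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
lora_config = LoraConfig(r=32, lora_alpha=64, lora_dropout=0.05, bias="none",
                         target_modules=find_all_linear_names(use_8bit=False, model=base_model))
peft_model = get_peft_model(base_model, lora_config)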
def load_from_checkpoint(resume_from_checkpoint, model=None):
pass
import json
import os
import random
import sys
from typing import List
import librosa
import numpy as np
import soundfile
from torch.utils.data import Dataset
from tqdm import tqdm
from utils.binary import DatasetReader
class CustomDataset(Dataset):
def __init__(self,
data_list_path,
processor,
mono=True,
language=None,
timestamps=False,
sample_rate=16000,
min_duration=0.5,
max_duration=30,
augment_config_path=None):
"""
Args:
data_list_path: 数据列表文件的路径,或者二进制列表的头文件路径
processor: Whisper的预处理工具,WhisperProcessor.from_pretrained获取
mono: 是否将音频转换成单通道,这个必须是True
language: 微调数据的语言
timestamps: 微调时是否使用时间戳
sample_rate: 音频的采样率,默认是16000
min_duration: 小于这个时间段的音频将被截断,单位秒,不能小于0.5,默认0.5s
max_duration: 大于这个时间段的音频将被截断,单位秒,不能大于30,默认30s
augment_config_path: 数据增强配置参数文件路径
"""
super(CustomDataset, self).__init__()
assert min_duration >= 0.5, f"min_duration不能小于0.5,当前为:{min_duration}"
assert max_duration <= 30, f"max_duration不能大于30,当前为:{max_duration}"
self.data_list_path = data_list_path
self.processor = processor
self.sample_rate = sample_rate
self.mono = mono
self.language = language
self.timestamps = timestamps
self.min_duration = min_duration
self.max_duration = max_duration
self.vocab = self.processor.tokenizer.get_vocab()
self.startoftranscript = self.vocab['<|startoftranscript|>']
self.endoftext = self.vocab['<|endoftext|>']
if '<|nospeech|>' in self.vocab.keys():
self.nospeech = self.vocab['<|nospeech|>']
self.timestamp_begin = None
else:
# compatibility with older models
self.nospeech = self.vocab['<|nocaptions|>']
self.timestamp_begin = self.vocab['<|notimestamps|>'] + 1
self.data_list: List[dict] = []
# Load the data list
self._load_data_list()
# Data augmentation configuration
self.augment_configs = None
self.noises_path = None
self.speed_rates = None
if augment_config_path:
with open(augment_config_path, 'r', encoding='utf-8') as f:
self.augment_configs = json.load(f)
# Load the data list
def _load_data_list(self):
if self.data_list_path.endswith(".header"):
# Read the binary data list
self.dataset_reader = DatasetReader(data_header_path=self.data_list_path,
min_duration=self.min_duration,
max_duration=self.max_duration)
self.data_list = self.dataset_reader.get_keys()
else:
# Read the plain-text data list
with open(self.data_list_path, 'r', encoding='utf-8') as f:
lines = f.readlines()
self.data_list = []
for line in tqdm(lines, desc='读取数据列表'):
if isinstance(line, str):
line = json.loads(line)
if not isinstance(line, dict): continue
# Skip audio outside the duration limits
if line["duration"] < self.min_duration:
continue
if self.max_duration != -1 and line["duration"] > self.max_duration:
continue
self.data_list.append(dict(line))
# Get the audio data, sample rate and transcript for one item of the data list
def _get_list_data(self, idx):
if self.data_list_path.endswith(".header"):
data_list = self.dataset_reader.get_data(self.data_list[idx])
else:
data_list = self.data_list[idx]
# Split out the audio path and the label
audio_file = data_list["audio"]['path']
transcript = data_list["sentences"] if self.timestamps else data_list["sentence"]
language = data_list["language"] if 'language' in data_list.keys() else None
if 'start_time' not in data_list["audio"].keys():
sample, sample_rate = soundfile.read(audio_file, dtype='float32')
else:
start_time, end_time = data_list["audio"]["start_time"], data_list["audio"]["end_time"]
# Read only the requested slice of the audio
sample, sample_rate = self.slice_from_file(audio_file, start=start_time, end=end_time)
sample = sample.T
# Convert to mono
if self.mono:
sample = librosa.to_mono(sample)
# Data augmentation
if self.augment_configs:
sample, sample_rate = self.augment(sample, sample_rate)
# Resample if the sample rate differs from the target
if self.sample_rate != sample_rate:
sample = self.resample(sample, orig_sr=sample_rate, target_sr=self.sample_rate)
return sample, sample_rate, transcript, language
def _load_timestamps_transcript(self, transcript: List[dict]):
assert isinstance(transcript, list), f"transcript应该为list,当前为:{type(transcript)}"
data = dict()
labels = self.processor.tokenizer.prefix_tokens[:3]
for t in transcript:
# Encode the target text as label ids. Whisper timestamp tokens come in 0.02 s steps,
# so start/end times are snapped to even hundredths before being mapped to token ids.
start = t['start'] if round(t['start'] * 100) % 2 == 0 else t['start'] + 0.01
if self.timestamp_begin is None:
start = self.vocab[f'<|{start:.2f}|>']
else:
start = self.timestamp_begin + round(start * 100) // 2
end = t['end'] if round(t['end'] * 100) % 2 == 0 else t['end'] - 0.01
if self.timestamp_begin is None:
end = self.vocab[f'<|{end:.2f}|>']
else:
end = self.timestamp_begin + round(end * 100) // 2
label = self.processor(text=t['text']).input_ids[4:-1]
labels.extend([start])
labels.extend(label)
labels.extend([end])
data['labels'] = labels + [self.endoftext]
return data
def __getitem__(self, idx):
try:
# Get the audio data, sample rate and transcript from the data list
sample, sample_rate, transcript, language = self._get_list_data(idx=idx)
# The language can be set per sample
self.processor.tokenizer.set_prefix_tokens(language=language if language is not None else self.language)
if len(transcript) > 0:
# Load the transcript with timestamps
if self.timestamps:
data = self._load_timestamps_transcript(transcript=transcript)
# Compute the log-Mel input features from the audio array
data["input_features"] = self.processor(audio=sample, sampling_rate=self.sample_rate).input_features
else:
# Get the log-Mel features and the label ids
data = self.processor(audio=sample, sampling_rate=self.sample_rate, text=transcript)
else:
# If there is no text, use the <|nospeech|> token
data = self.processor(audio=sample, sampling_rate=self.sample_rate)
data['labels'] = [self.startoftranscript, self.nospeech, self.endoftext]
return data
except Exception as e:
print(f'读取数据出错,序号:{idx},错误信息:{e}', file=sys.stderr)
return self.__getitem__(random.randint(0, self.__len__() - 1))
def __len__(self):
return len(self.data_list)
# Read a slice of an audio file
@staticmethod
def slice_from_file(file, start, end):
sndfile = soundfile.SoundFile(file)
sample_rate = sndfile.samplerate
duration = round(float(len(sndfile)) / sample_rate, 3)
start = round(start, 3)
end = round(end, 3)
# Negative positions count from the end
if start < 0.0: start += duration
if end < 0.0: end += duration
# Keep the slice within bounds
if start < 0.0: start = 0.0
if end > duration: end = duration
if end < 0.0:
raise ValueError("切片结束位置(%f s)越界" % end)
if start > end:
raise ValueError("切片开始位置(%f s)晚于切片结束位置(%f s)" % (start, end))
start_frame = int(start * sample_rate)
end_frame = int(end * sample_rate)
sndfile.seek(start_frame)
sample = sndfile.read(frames=end_frame - start_frame, dtype='float32')
return sample, sample_rate
# Data augmentation
def augment(self, sample, sample_rate):
for config in self.augment_configs:
if config['type'] == 'speed' and random.random() < config['prob']:
if self.speed_rates is None:
min_speed_rate, max_speed_rate, num_rates = config['params']['min_speed_rate'], \
config['params']['max_speed_rate'], config['params']['num_rates']
self.speed_rates = np.linspace(min_speed_rate, max_speed_rate, num_rates, endpoint=True)
rate = random.choice(self.speed_rates)
sample = self.change_speed(sample, speed_rate=rate)
if config['type'] == 'shift' and random.random() < config['prob']:
min_shift_ms, max_shift_ms = config['params']['min_shift_ms'], config['params']['max_shift_ms']
shift_ms = random.randint(min_shift_ms, max_shift_ms)
sample = self.shift(sample, sample_rate, shift_ms=shift_ms)
if config['type'] == 'volume' and random.random() < config['prob']:
min_gain_dBFS, max_gain_dBFS = config['params']['min_gain_dBFS'], config['params']['max_gain_dBFS']
gain = random.randint(min_gain_dBFS, max_gain_dBFS)
sample = self.volume(sample, gain=gain)
if config['type'] == 'resample' and random.random() < config['prob']:
new_sample_rates = config['params']['new_sample_rates']
new_sample_rate = np.random.choice(new_sample_rates)
sample = self.resample(sample, orig_sr=sample_rate, target_sr=new_sample_rate)
sample_rate = new_sample_rate
if config['type'] == 'noise' and random.random() < config['prob']:
min_snr_dB, max_snr_dB = config['params']['min_snr_dB'], config['params']['max_snr_dB']
if self.noises_path is None:
self.noises_path = []
noise_dir = config['params']['noise_dir']
if os.path.exists(noise_dir):
for file in os.listdir(noise_dir):
self.noises_path.append(os.path.join(noise_dir, file))
noise_path = random.choice(self.noises_path)
snr_dB = random.randint(min_snr_dB, max_snr_dB)
sample = self.add_noise(sample, sample_rate, noise_path=noise_path, snr_dB=snr_dB)
return sample, sample_rate
# Change the speech rate
@staticmethod
def change_speed(sample, speed_rate):
if speed_rate == 1.0:
return sample
if speed_rate <= 0:
raise ValueError("速度速率应大于零")
old_length = sample.shape[0]
new_length = int(old_length / speed_rate)
old_indices = np.arange(old_length)
new_indices = np.linspace(start=0, stop=old_length, num=new_length)
sample = np.interp(new_indices, old_indices, sample).astype(np.float32)
return sample
# Shift the audio in time
@staticmethod
def shift(sample, sample_rate, shift_ms):
duration = sample.shape[0] / sample_rate
if abs(shift_ms) / 1000.0 > duration:
raise ValueError("shift_ms的绝对值应该小于音频持续时间")
shift_samples = int(shift_ms * sample_rate / 1000)
if shift_samples > 0:
sample[:-shift_samples] = sample[shift_samples:]
sample[-shift_samples:] = 0
elif shift_samples < 0:
sample[-shift_samples:] = sample[:shift_samples]
sample[:-shift_samples] = 0
return sample
# Change the volume
@staticmethod
def volume(sample, gain):
sample *= 10.**(gain / 20.)
return sample
# Resample the audio
@staticmethod
def resample(sample, orig_sr, target_sr):
sample = librosa.resample(sample, orig_sr=orig_sr, target_sr=target_sr)
return sample
# Add noise
def add_noise(self, sample, sample_rate, noise_path, snr_dB, max_gain_db=300.0):
noise_sample, sr = librosa.load(noise_path, sr=sample_rate)
# Normalize the audio volume so the noise does not become too loud
target_db = -20
gain = min(max_gain_db, target_db - self.rms_db(sample))
sample *= 10. ** (gain / 20.)
# Scale the noise to the requested SNR
sample_rms_db, noise_rms_db = self.rms_db(sample), self.rms_db(noise_sample)
noise_gain_db = min(sample_rms_db - noise_rms_db - snr_dB, max_gain_db)
noise_sample *= 10. ** (noise_gain_db / 20.)
# Match the noise length to the sample length
if noise_sample.shape[0] < sample.shape[0]:
diff_duration = sample.shape[0] - noise_sample.shape[0]
noise_sample = np.pad(noise_sample, (0, diff_duration), 'wrap')
elif noise_sample.shape[0] > sample.shape[0]:
start_frame = random.randint(0, noise_sample.shape[0] - sample.shape[0])
noise_sample = noise_sample[start_frame:sample.shape[0] + start_frame]
sample += noise_sample
return sample
@staticmethod
def rms_db(sample):
mean_square = np.mean(sample ** 2)
return 10 * np.log10(mean_square)
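# Usage sketch (illustrative, not part of this file): wiring CustomDataset together with the
# padding collator from utils.data_utils; the model name and the list path are placeholders.
from torch.utils.data import DataLoader
from transformers import WhisperProcessor
from utils.data_utils import DataCollatorSpeechSeq2SeqWithPadding

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny", language="Chinese", task="transcribe")
train_dataset = CustomDataset(data_list_path="../dataset/train_wenet.json",
                              processor=processor, language="Chinese", timestamps=True)
collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
loader = DataLoader(train_dataset, batch_size=8, collate_fn=collator)
batch = next(iter(loader))  # dict with padded "input_features" and "labels" tensors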
import hashlib
import os
import tarfile
import urllib.request
from tqdm import tqdm
def print_arguments(args):
print("----------- Configuration Arguments -----------")
for arg, value in vars(args).items():
print("%s: %s" % (arg, value))
print("------------------------------------------------")
def strtobool(val):
val = val.lower()
if val in ('y', 'yes', 't', 'true', 'on', '1'):
return True
elif val in ('n', 'no', 'f', 'false', 'off', '0'):
return False
else:
raise ValueError("invalid truth value %r" % (val,))
def str_none(val):
if val == 'None':
return None
else:
return val
def add_arguments(argname, type, default, help, argparser, **kwargs):
type = strtobool if type == bool else type
type = str_none if type == str else type
argparser.add_argument("--" + argname,
default=default,
type=type,
help=help + ' Default: %(default)s.',
**kwargs)
def md5file(fname):
hash_md5 = hashlib.md5()
f = open(fname, "rb")
for chunk in iter(lambda: f.read(4096), b""):
hash_md5.update(chunk)
f.close()
return hash_md5.hexdigest()
def download(url, md5sum, target_dir):
"""Download file from url to target_dir, and check md5sum."""
if not os.path.exists(target_dir): os.makedirs(target_dir)
filepath = os.path.join(target_dir, url.split("/")[-1])
if not (os.path.exists(filepath) and md5file(filepath) == md5sum):
print(f"Downloading {url} to {filepath} ...")
with urllib.request.urlopen(url) as source, open(filepath, "wb") as output:
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True,
unit_divisor=1024) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
print(f"\nMD5 checksum {filepath} ...")
if not md5file(filepath) == md5sum:
raise RuntimeError("MD5 checksum failed.")
else:
print(f"File exists, skip downloading. ({filepath})")
return filepath
def unpack(filepath, target_dir, rm_tar=False):
"""Unpack the file to the target_dir."""
print("Unpacking %s ..." % filepath)
tar = tarfile.open(filepath)
tar.extractall(target_dir)
tar.close()
if rm_tar:
os.remove(filepath)
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)