Commit d7cad875 authored by Sugon_ldc's avatar Sugon_ldc
Browse files

add new files

parents
Pipeline #1560 failed with stages
in 0 seconds
#Fri Jun 09 11:28:08 CST 2023
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-8.0-bin.zip
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists
#!/usr/bin/env sh
#
# Copyright 2015 the original author or authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
##############################################################################
##
## Gradle start up script for UN*X
##
##############################################################################
# Attempt to set APP_HOME
# Resolve links: $0 may be a link
PRG="$0"
# Need this for relative symlinks.
while [ -h "$PRG" ] ; do
ls=`ls -ld "$PRG"`
link=`expr "$ls" : '.*-> \(.*\)$'`
if expr "$link" : '/.*' > /dev/null; then
PRG="$link"
else
PRG=`dirname "$PRG"`"/$link"
fi
done
SAVED="`pwd`"
cd "`dirname \"$PRG\"`/" >/dev/null
APP_HOME="`pwd -P`"
cd "$SAVED" >/dev/null
APP_NAME="Gradle"
APP_BASE_NAME=`basename "$0"`
# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
# Use the maximum available, or set MAX_FD != -1 to use that value.
MAX_FD="maximum"
warn () {
echo "$*"
}
die () {
echo
echo "$*"
echo
exit 1
}
# OS specific support (must be 'true' or 'false').
cygwin=false
msys=false
darwin=false
nonstop=false
case "`uname`" in
CYGWIN* )
cygwin=true
;;
Darwin* )
darwin=true
;;
MINGW* )
msys=true
;;
NONSTOP* )
nonstop=true
;;
esac
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
# Determine the Java command to use to start the JVM.
if [ -n "$JAVA_HOME" ] ; then
if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
# IBM's JDK on AIX uses strange locations for the executables
JAVACMD="$JAVA_HOME/jre/sh/java"
else
JAVACMD="$JAVA_HOME/bin/java"
fi
if [ ! -x "$JAVACMD" ] ; then
die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi
else
JAVACMD="java"
which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi
# Increase the maximum file descriptors if we can.
if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then
MAX_FD_LIMIT=`ulimit -H -n`
if [ $? -eq 0 ] ; then
if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
MAX_FD="$MAX_FD_LIMIT"
fi
ulimit -n $MAX_FD
if [ $? -ne 0 ] ; then
warn "Could not set maximum file descriptor limit: $MAX_FD"
fi
else
warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
fi
fi
# For Darwin, add options to specify how the application appears in the dock
if $darwin; then
GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
fi
# For Cygwin or MSYS, switch paths to Windows format before running java
if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then
APP_HOME=`cygpath --path --mixed "$APP_HOME"`
CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
JAVACMD=`cygpath --unix "$JAVACMD"`
# We build the pattern for arguments to be converted via cygpath
ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
SEP=""
for dir in $ROOTDIRSRAW ; do
ROOTDIRS="$ROOTDIRS$SEP$dir"
SEP="|"
done
OURCYGPATTERN="(^($ROOTDIRS))"
# Add a user-defined pattern to the cygpath arguments
if [ "$GRADLE_CYGPATTERN" != "" ] ; then
OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
fi
# Now convert the arguments - kludge to limit ourselves to /bin/sh
i=0
for arg in "$@" ; do
CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option
if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition
eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
else
eval `echo args$i`="\"$arg\""
fi
i=`expr $i + 1`
done
case $i in
0) set -- ;;
1) set -- "$args0" ;;
2) set -- "$args0" "$args1" ;;
3) set -- "$args0" "$args1" "$args2" ;;
4) set -- "$args0" "$args1" "$args2" "$args3" ;;
5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
esac
fi
# Escape application args
save () {
for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done
echo " "
}
APP_ARGS=`save "$@"`
# Collect all arguments for the java command, following the shell quoting and substitution rules
eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS"
exec "$JAVACMD" "$@"
@rem
@rem Copyright 2015 the original author or authors.
@rem
@rem Licensed under the Apache License, Version 2.0 (the "License");
@rem you may not use this file except in compliance with the License.
@rem You may obtain a copy of the License at
@rem
@rem https://www.apache.org/licenses/LICENSE-2.0
@rem
@rem Unless required by applicable law or agreed to in writing, software
@rem distributed under the License is distributed on an "AS IS" BASIS,
@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@rem See the License for the specific language governing permissions and
@rem limitations under the License.
@rem
@if "%DEBUG%" == "" @echo off
@rem ##########################################################################
@rem
@rem Gradle startup script for Windows
@rem
@rem ##########################################################################
@rem Set local scope for the variables with windows NT shell
if "%OS%"=="Windows_NT" setlocal
set DIRNAME=%~dp0
if "%DIRNAME%" == "" set DIRNAME=.
set APP_BASE_NAME=%~n0
set APP_HOME=%DIRNAME%
@rem Resolve any "." and ".." in APP_HOME to make it shorter.
for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi
@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m"
@rem Find java.exe
if defined JAVA_HOME goto findJavaFromJavaHome
set JAVA_EXE=java.exe
%JAVA_EXE% -version >NUL 2>&1
if "%ERRORLEVEL%" == "0" goto execute
echo.
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.
goto fail
:findJavaFromJavaHome
set JAVA_HOME=%JAVA_HOME:"=%
set JAVA_EXE=%JAVA_HOME%/bin/java.exe
if exist "%JAVA_EXE%" goto execute
echo.
echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.
goto fail
:execute
@rem Setup the command line
set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
@rem Execute Gradle
"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %*
:end
@rem End local scope for the variables with windows NT shell
if "%ERRORLEVEL%"=="0" goto mainEnd
:fail
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
rem the _cmd.exe /c_ return code!
if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
exit /b 1
:mainEnd
if "%OS%"=="Windows_NT" endlocal
:omega
pluginManagement {
repositories {
google()
mavenCentral()
gradlePluginPortal()
}
}
dependencyResolutionManagement {
repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
repositories {
google()
mavenCentral()
}
}
rootProject.name = "AndroidDemo"
include ':app'
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
# 微调Whisper语音识别模型和加速推理
简体中文 | [English](./README_en.md)
![python version](https://img.shields.io/badge/python-3.8+-orange.svg)
![GitHub forks](https://img.shields.io/github/forks/yeyupiaoling/Whisper-Finetune)
![GitHub Repo stars](https://img.shields.io/github/stars/yeyupiaoling/Whisper-Finetune)
![GitHub](https://img.shields.io/github/license/yeyupiaoling/Whisper-Finetune)
![支持系统](https://img.shields.io/badge/支持系统-Win/Linux/MAC-9cf)
## 前言
OpenAI在开源了号称其英文语音辨识能力已达到人类水准的Whisper项目,且它亦支持其它98种语言的自动语音辨识。Whisper所提供的自动语音识与翻译任务,它们能将各种语言的语音变成文本,也能将这些文本翻译成英文。本项目主要的目的是为了对Whisper模型使用Lora进行微调,**支持无时间戳数据训练,有时间戳数据训练、无语音数据训练**。目前开源了好几个模型,具体可以在[openai](https://huggingface.co/openai)查看,下面列出了常用的几个模型。另外项目最后还支持CTranslate2加速推理和GGML加速推理,提示一下,加速推理支持直接使用Whisper原模型转换,并不一定需要微调。支持Windows桌面应用,Android应用和服务器部署。
### 请先点 :star:
## 支持模型
- openai/whisper-tiny
- openai/whisper-base
- openai/whisper-small
- openai/whisper-medium
- openai/whisper-large
- openai/whisper-large-v2
- openai/whisper-large-v3
**欢迎大家扫码入知识星球(左)或者QQ群(右)讨论,知识星球里面提供项目的模型文件和博主其他相关项目的模型文件,也包括其他一些资源。**
<div align="center">
<img src="https://yeyupiaoling.cn/zsxq.png" alt="知识星球" width="400">
<img src="https://yeyupiaoling.cn/qq.png" alt="QQ群" width="400">
</div>
**使用环境:**
- Anaconda 3
- Python 3.8
- Pytorch 1.13.1
- Ubuntu 18.04
- GPU A100-PCIE-40GB*1
### 视频讲解:[哔哩哔哩](https://www.bilibili.com/video/BV1S8411o7rm/)
### 演示地址:[Web部署](https://whisper.yeyupiaoling.cn:8082/)
## 目录
- [项目主要程序介绍](#项目主要程序介绍)
- [模型测试表](#模型测试表)
- [安装环境](#安装环境)
- [准备数据](#准备数据)
- [微调模型](#微调模型)
- [单卡训练](#单卡训练)
- [多卡训练](#多卡训练)
- [合并模型](#合并模型)
- [评估模型](#评估模型)
- [预测](#预测)
- [GUI界面预测](#GUI界面预测)
- [Web部署](#Web部署)
- [接口文档](#接口文档)
- [使用Ctranslate2格式模型预测](#使用Ctranslate2格式模型预测)
- [Android部署](#Android部署)
- [Windows桌面应用](#Windows桌面应用)
- [打赏作者](#打赏作者)
<a name='项目主要程序介绍'></a>
## 项目主要程序介绍
1. `aishell.py`:制作AIShell训练数据。
2. `finetune.py`:微调模型。
3. `merge_lora.py`:合并Whisper和Lora的模型。
4. `evaluation.py`:评估使用微调后的模型或者Whisper原模型。
5. `infer.py`:使用调用微调后的模型或者transformers上的Whisper模型预测。
6. `infer_ct2.py`:使用转换为CTranslate2的模型预测,主要参考这个程序用法。
7. `infer_gui.py`:有GUI界面操作,使用调用微调后的模型或者transformers上的Whisper模型预测。
8. `infer_server.py`:使用调用微调后的模型或者transformers上的Whisper模型部署到服务器端,提供给客户端调用。
9. `convert-ggml.py`:转换模型为GGML格式模型,给Android应用或者Windows应用使用。
10. `AndroidDemo`:该目录存放的是部署模型到Android的源码。
11. `WhisperDesktop`:该目录存放的是Windows桌面应用的程序。
<a name='模型测试表'></a>
## 模型测试表
1. 原始模型字错率测试表。
| 使用模型 | 指定语言 | aishell_test | test_net | test_meeting | 粤语测试集 | 模型获取 |
|:----------------:|:-------:|:------------:|:--------:|:------------:|:-------:|:--------:|
| whisper-tiny | Chinese | 0.31898 | 0.40482 | 0.75332 | N/A | 加入知识星球获取 |
| whisper-base | Chinese | 0.22196 | 0.30404 | 0.50378 | N/A | 加入知识星球获取 |
| whisper-small | Chinese | 0.13897 | 0.18417 | 0.31154 | N/A | 加入知识星球获取 |
| whisper-medium | Chinese | 0.09538 | 0.13591 | 0.26669 | N/A | 加入知识星球获取 |
| whisper-large | Chinese | 0.08969 | 0.12933 | 0.23439 | N/A | 加入知识星球获取 |
| whisper-large-v2 | Chinese | 0.08817 | 0.12332 | 0.26547 | N/A | 加入知识星球获取 |
| whisper-large-v3 | Chinese | 0.08086 | 0.11452 | 0.19878 | 0.18782 | 加入知识星球获取 |
2. 微调数据集后字错率测试表。
| 使用模型 | 指定语言 | 数据集 | aishell_test | test_net | test_meeting | 粤语测试集 | 模型获取 |
|:----------------:|:---------:|:----------------------------------------------------------:|:------------:|:--------:|:------------:|:-------:|:--------:|
| whisper-tiny | Chinese | [AIShell](https://openslr.magicdatatech.com/resources/33/) | 0.13043 | 0.4463 | 0.57728 | N/A | 加入知识星球获取 |
| whisper-base | Chinese | [AIShell](https://openslr.magicdatatech.com/resources/33/) | 0.08999 | 0.33089 | 0.40713 | N/A | 加入知识星球获取 |
| whisper-small | Chinese | [AIShell](https://openslr.magicdatatech.com/resources/33/) | 0.05452 | 0.19831 | 0.24229 | N/A | 加入知识星球获取 |
| whisper-medium | Chinese | [AIShell](https://openslr.magicdatatech.com/resources/33/) | 0.03681 | 0.13073 | 0.16939 | N/A | 加入知识星球获取 |
| whisper-large-v2 | Chinese | [AIShell](https://openslr.magicdatatech.com/resources/33/) | 0.03139 | 0.12201 | 0.15776 | N/A | 加入知识星球获取 |
| whisper-large-v3 | Chinese | [AIShell](https://openslr.magicdatatech.com/resources/33/) | 0.03660 | 0.09835 | 0.13706 | 0.20060 | 加入知识星球获取 |
| whisper-large-v3 | Cantonese | 粤语数据集 | 0.06857 | 0.11369 | 0.17452 | 0.03524 | 加入知识星球获取 |
| whisper-tiny | Chinese | [WenetSpeech](./tools/create_wenetspeech_data.py) | 0.17711 | 0.24783 | 0.39226 | N/A | 加入知识星球获取 |
| whisper-base | Chinese | [WenetSpeech](./tools/create_wenetspeech_data.py) | 0.14548 | 0.17747 | 0.30590 | N/A | 加入知识星球获取 |
| whisper-small | Chinese | [WenetSpeech](./tools/create_wenetspeech_data.py) | 0.08484 | 0.11801 | 0.23471 | N/A | 加入知识星球获取 |
| whisper-medium | Chinese | [WenetSpeech](./tools/create_wenetspeech_data.py) | 0.05861 | 0.08794 | 0.19486 | N/A | 加入知识星球获取 |
| whisper-large-v2 | Chinese | [WenetSpeech](./tools/create_wenetspeech_data.py) | 0.05443 | 0.08367 | 0.19087 | N/A | 加入知识星球获取 |
| whisper-large-v3 | Chinese | [WenetSpeech](./tools/create_wenetspeech_data.py) | 0.04947 | 0.10711 | 0.17429 | 0.47431 | 加入知识星球获取 |
3. 推理速度测试表,使用GPU为GTX3090(24G),音频为`test_long.wav`,时长为3分钟整,测试程序在`tools/run_compute.sh`
| 加速方式 | tiny | base | small | medium | large-v2 | large-v3 |
|:-------------------------------------------------------------------------:|:------:|:------:|:------:|:-------:|:--------:|:--------:|
| Transformers (`fp16` + `batch_size=16`) | 1.458s | 1.671s | 2.331s | 11.071s | 4.779s | 12.826s |
| Transformers (`fp16` + `batch_size=16` + `Compile`) | 1.477s | 1.675s | 2.357s | 11.003s | 4.799s | 12.643s |
| Transformers (`fp16` + `batch_size=16` + `BetterTransformer`) | 1.461s | 1.676s | 2.301s | 11.062s | 4.608s | 12.505s |
| Transformers (`fp16` + `batch_size=16` + `Flash Attention 2`) | 1.436s | 1.630s | 2.258s | 10.533s | 4.344s | 11.651s |
| Transformers (`fp16` + `batch_size=16` + `Compile` + `BetterTransformer`) | 1.442s | 1.686s | 2.277s | 11.000s | 4.543s | 12.592s |
| Transformers (`fp16` + `batch_size=16` + `Compile` + `Flash Attention 2`) | 1.409s | 1.643s | 2.220s | 10.390s | 4.377s | 11.703s |
| Faster Whisper (`fp16` + `beam_size=1` ) | 2.179s | 1.492s | 2.327s | 3.752s | 5.677s | 31.541s |
| Faster Whisper (`8-bit` + `beam_size=1` ) | 2.609s | 1.728s | 2.744s | 4.688s | 6.571s | 29.307s |
4. 经过处理的数据列表。
| 数据列表处理方式 | AiShell | WenetSpeech |
|:----------:|:--------:|:-----------:|
| 添加标点符号 | 加入知识星球获取 | 加入知识星球获取 |
| 添加标点符号和时间戳 | 加入知识星球获取 | 加入知识星球获取 |
**重要说明:**
1. 在评估的时候移除模型输出的标点符号,并把繁体中文转成简体中文。
2. `aishell_test`为AIShell的测试集,`test_net``test_meeting`为WenetSpeech的测试集。
3. 测试速度的音频为`dataset/test_long.wav`,时长为3分钟整。
4. 训练数据使用的是带标点符号的数据,字错率高一点。
5. 微调AiShell数据不带时间戳,微调WenetSpeech带时间戳。
<a name='安装环境'></a>
## 安装环境
- 首先安装的是Pytorch的GPU版本,以下介绍两种安装Pytorch的方式,只需要选择一种即可。
1. 以下是使用Anaconda安装Pytorch环境,如果已经安装过了,请跳过。
```shell
conda install pytorch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 pytorch-cuda=11.8 -c pytorch -c nvidia
```
2. 以下是使用Docker镜像,拉取一个Pytorch环境的镜像。
```shell
sudo docker pull pytorch/pytorch:2.1.0-cuda11.8-cudnn8-devel
```
然后进入到镜像中,同时将当前路径挂载到容器的`/workspace`目录下。
```shell
sudo nvidia-docker run --name pytorch -it -v $PWD:/workspace pytorch/pytorch:2.1.0-cuda11.8-cudnn8-devel /bin/bash
```
- 安装所需的依赖库。
```shell
python -m pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
```
- Windows需要单独安装bitsandbytes。
```shell
python -m pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.2.post2-py3-none-win_amd64.whl
```
<a name='准备数据'></a>
## 准备数据
训练的数据集如下,是一个jsonlines的数据列表,也就是每一行都是一个JSON数据,数据格式如下。本项目提供了一个制作AIShell数据集的程序`aishell.py`,执行这个程序可以自动下载并生成如下列格式的训练集和测试集,**注意:** 这个程序可以通过指定AIShell的压缩文件来跳过下载过程的,如果直接下载会非常慢,可以使用一些如迅雷等下载器下载该数据集,然后通过参数`--filepath`指定下载的压缩文件路径,如`/home/test/data_aishell.tgz`
**小提示:**
1. 如果不使用时间戳训练,可以不包含`sentences`字段的数据。
2. 如果只有一种语言的数据,可以不包含`language`字段数据。
3. 如果训练空语音数据,`sentences`字段为`[]``sentence`字段为`""``language`字段可以不存在。
4. 数据可以不包含标点符号,但微调的模型会损失添加符号能力。
```json
{
"audio": {
"path": "dataset/0.wav"
},
"sentence": "近几年,不但我用书给女儿压岁,也劝说亲朋不要给女儿压岁钱,而改送压岁书。",
"language": "Chinese",
"sentences": [
{
"start": 0,
"end": 1.4,
"text": "近几年,"
},
{
"start": 1.42,
"end": 8.4,
"text": "不但我用书给女儿压岁,也劝说亲朋不要给女儿压岁钱,而改送压岁书。"
}
],
"duration": 7.37
}
```
<a name='微调模型'></a>
## 微调模型
准备好数据之后,就可以开始微调模型了。训练最重要的两个参数分别是,`--base_model`指定微调的Whisper模型,这个参数值需要在[HuggingFace](https://huggingface.co/openai)存在的,这个不需要提前下载,启动训练时可以自动下载,当然也可以提前下载,那么`--base_model`指定就是路径,同时`--local_files_only`设置为True。第二个`--output_path`是是训练时保存的Lora检查点路径,因为我们使用Lora来微调模型。如果想存足够的话,最好将`--use_8bit`设置为False,这样训练速度快很多。其他更多的参数请查看这个程序。
<a name='单卡训练'></a>
### 单卡训练
单卡训练命令如下,Windows系统可以不添加`CUDA_VISIBLE_DEVICES`参数。
```shell
CUDA_VISIBLE_DEVICES=0 python finetune.py --base_model=openai/whisper-tiny --output_dir=output/
```
<a name='多卡训练'></a>
### 多卡训练
多卡训练有两种方法,分别是torchrun和accelerate,开发者可以根据自己的习惯使用对应的方式。
1. 使用torchrun启动多卡训练,命令如下,通过`--nproc_per_node`指定使用的显卡数量。
```shell
torchrun --nproc_per_node=2 finetune.py --base_model=openai/whisper-tiny --output_dir=output/
```
2. 使用accelerate启动多卡训练,如果是第一次使用accelerate,要配置训练参数,方式如下。
首先配置训练参数,过程是让开发者回答几个问题,基本都是默认就可以,但有几个参数需要看实际情况设置。
```shell
accelerate config
```
大概过程就是这样:
```
--------------------------------------------------------------------In which compute environment are you running?
This machine
--------------------------------------------------------------------Which type of machine are you using?
multi-GPU
How many different machines will you use (use more than 1 for multi-node training)? [1]:
Do you wish to optimize your script with torch dynamo?[yes/NO]:
Do you want to use DeepSpeed? [yes/NO]:
Do you want to use FullyShardedDataParallel? [yes/NO]:
Do you want to use Megatron-LM ? [yes/NO]:
How many GPU(s) should be used for distributed training? [1]:2
What GPU(s) (by id) should be used for training on this machine as a comma-seperated list? [all]:
--------------------------------------------------------------------Do you wish to use FP16 or BF16 (mixed precision)?
fp16
accelerate configuration saved at /home/test/.cache/huggingface/accelerate/default_config.yaml
```
配置完成之后,可以使用以下命令查看配置。
```shell
accelerate env
```
开始训练命令如下。
```shell
accelerate launch finetune.py --base_model=openai/whisper-tiny --output_dir=output/
```
输出日志如下:
```shell
{'loss': 0.9098, 'learning_rate': 0.000999046843662503, 'epoch': 0.01}
{'loss': 0.5898, 'learning_rate': 0.0009970611012927184, 'epoch': 0.01}
{'loss': 0.5583, 'learning_rate': 0.0009950753589229333, 'epoch': 0.02}
{'loss': 0.5469, 'learning_rate': 0.0009930896165531485, 'epoch': 0.02}
{'loss': 0.5959, 'learning_rate': 0.0009911038741833634, 'epoch': 0.03}
```
<a name='合并模型'></a>
## 合并模型
微调完成之后会有两个模型,第一个是Whisper基础模型,第二个是Lora模型,需要把这两个模型合并之后才能之后的操作。这个程序只需要传递两个参数,`--lora_model`指定的是训练结束后保存的Lora模型路径,其实就是检查点文件夹路径,第二个`--output_dir`是合并后模型的保存目录。
```shell
python merge_lora.py --lora_model=output/whisper-tiny/checkpoint-best/ --output_dir=models/
```
<a name='评估模型'></a>
## 评估模型
执行以下程序进行评估模型,最重要的两个参数分别是。第一个`--model_path`指定的是合并后的模型路径,同时也支持直接使用Whisper原模型,例如直接指定`openai/whisper-large-v2`,第二个是`--metric`指定的是评估方法,例如有字错率`cer`和词错率`wer`**提示:** 没有微调的模型,可能输出带有标点符号,影响准确率。其他更多的参数请查看这个程序。
```shell
python evaluation.py --model_path=models/whisper-tiny-finetune --metric=cer
```
<a name='预测'></a>
## 预测
执行以下程序进行语音识别,这个使用transformers直接调用微调后的模型或者Whisper原模型预测,支持Pytorch2.0的编译器加速、FlashAttention2加速、BetterTransformer加速。第一个`--audio_path`参数指定的是要预测的音频路径。第二个`--model_path`指定的是合并后的模型路径,同时也支持直接使用Whisper原模型,例如直接指定`openai/whisper-large-v2`。其他更多的参数请查看这个程序。
```shell
python infer.py --audio_path=dataset/test.wav --model_path=models/whisper-tiny-finetune
```
<a name='GUI界面预测'></a>
## GUI界面预测
`--model_path`指定Transformers模型。其他更多的参数请查看这个程序。
```shell
python infer_gui.py --model_path=models/whisper-tiny-finetune
```
启动后界面如下:
<div align="center">
<img src="./docs/images/gui.jpg" alt="GUI界面" width="600"/>
</div>
<a name='Web部署'></a>
## Web部署
`--host`指定服务启动的地址,这里设置为`0.0.0.0`,即任何地址都可以访问。`--port`指定使用的端口号。`--model_path`指定的Transformers模型。`--num_workers`指定是使用多少个线程并发推理,这在Web部署上很重要,当有多个并发访问是可以同时推理。其他更多的参数请查看这个程序。
```shell
python infer_server.py --host=0.0.0.0 --port=5000 --model_path=models/whisper-tiny-finetune --num_workers=2
```
### 接口文档
目前提供识别接口`/recognition`,接口参数如下。
| 字段 | 是否必须 | 类型 | 默认值 | 说明 |
|:----------:|:----:|:------:|:----------:|:-----------------------------:|
| audio | 是 | File | | 要识别的音频文件 |
| to_simple | 否 | int | 1 | 是否繁体转简体 |
| remove_pun | 否 | int | 0 | 是否移除标点符号 |
| task | 否 | String | transcribe | 识别任务类型,支持transcribe和translate |
| language | 否 | String | zh | 设置语言,简写,如果为None则自动检测语言 |
返回结果:
| 字段 | 类型 | 说明 |
|:-------:|:----:|:-------------:|
| results | list | 分割的识别结果 |
| +result | str | 每片分隔的文本结果 |
| +start | int | 每片分隔的开始时间,单位秒 |
| +end | int | 每片分隔的结束时间,单位秒 |
| code | int | 错误码,0即为成功识别 |
示例如下:
```json
{
"results": [
{
"result": "近几年,不但我用书给女儿压碎,也全说亲朋不要给女儿压碎钱,而改送压碎书。",
"start": 0,
"end": 8
}
],
"code": 0
}
```
为了方便理解,这里提供了调用Web接口的Python代码,下面的是`/recognition`的调用方式。
```python
import requests
response = requests.post(url="http://127.0.0.1:5000/recognition",
files=[("audio", ("test.wav", open("dataset/test.wav", 'rb'), 'audio/wav'))],
json={"to_simple": 1, "remove_pun": 0, "language": "zh", "task": "transcribe"}, timeout=20)
print(response.text)
```
提供的测试页面如下:
首页`http://127.0.0.1:5000/` 的页面如下:
<div align="center">
<img src="./docs/images/web.jpg" alt="首页" width="600"/>
</div>
文档页面`http://127.0.0.1:5000/docs` 的页面如下:
<a name='使用Ctranslate2格式模型预测'></a>
## 使用Ctranslate2格式模型预测
这里提供了一个CTranslate2加速的方式,尽管使用Transformers的pipeline推理速度已经很快了,首先要转换模型,把合并后的模型转换为CTranslate2模型。如下命令,`--model`参数指定的是合并后的模型路径,同时也支持直接使用Whisper原模型,例如直接指定`openai/whisper-large-v2``--output_dir`参数指定的是转换后的CTranslate2模型路径,`--quantization`参数指定的是量化模型大小,不希望量化模型的可以直接去掉这个参数。
```shell
ct2-transformers-converter --model models/whisper-tiny-finetune --output_dir models/whisper-tiny-finetune-ct2 --copy_files tokenizer.json preprocessor_config.json --quantization float16
```
执行以下程序进行语音识别,`--audio_path`参数指定的是要预测的音频路径。`--model_path`指定的是转换后的CTranslate2模型。其他更多的参数请查看这个程序。
```shell
python infer_ct2.py --audio_path=dataset/test.wav --model_path=models/whisper-tiny-finetune-ct2
```
输出结果如下:
```shell
----------- Configuration Arguments -----------
audio_path: dataset/test.wav
model_path: models/whisper-tiny-finetune-ct2
language: zh
use_gpu: True
use_int8: False
beam_size: 10
num_workers: 1
vad_filter: False
local_files_only: True
------------------------------------------------
[0.0 - 8.0]:近几年,不但我用书给女儿压碎,也全说亲朋不要给女儿压碎钱,而改送压碎书。
```
<a name='Android部署'></a>
## Android部署
安装部署的源码在[AndroidDemo](./AndroidDemo)目录下,具体文档可以到该目录下的[README.md](AndroidDemo/README.md)查看。
<br/>
<div align="center">
<img src="./docs/images/android2.jpg" alt="Android效果图" width="200">
<img src="./docs/images/android1.jpg" alt="Android效果图" width="200">
<img src="./docs/images/android3.jpg" alt="Android效果图" width="200">
<img src="./docs/images/android4.jpg" alt="Android效果图" width="200">
</div>
<a name='Windows桌面应用'></a>
## Windows桌面应用
程序在[WhisperDesktop](./WhisperDesktop)目录下,具体文档可以到该目录下的[README.md](WhisperDesktop/README.md)查看。
<br/>
<div align="center">
<img src="./docs/images/desktop1.jpg" alt="Windows桌面应用效果图">
</div>
<a name='打赏作者'></a>
## 打赏作者
<br/>
<div align="center">
<p>打赏一块钱支持一下作者</p>
<img src="https://yeyupiaoling.cn/reward.png" alt="打赏作者" width="400">
</div>
## 参考资料
1. https://github.com/huggingface/peft
2. https://github.com/guillaumekln/faster-whisper
3. https://github.com/ggerganov/whisper.cpp
4. https://github.com/Const-me/Whisper
# Fine-tune Whisper speech recognition models and speed up reasoning
[简体中文](./README.md) | English
![python version](https://img.shields.io/badge/python-3.8+-orange.svg)
![GitHub forks](https://img.shields.io/github/forks/yeyupiaoling/Whisper-Finetune)
![GitHub Repo stars](https://img.shields.io/github/stars/yeyupiaoling/Whisper-Finetune)
![GitHub](https://img.shields.io/github/license/yeyupiaoling/Whisper-Finetune)
![支持系统](https://img.shields.io/badge/支持系统-Win/Linux/MAC-9cf)
**Disclaimer, this document was obtained through machine translation, please check the original document [here](./README.md).**
## Introduction
OpenAI open-sourced project Whisper, which claims to have human-level speech recognition in English, and it also supports automatic speech recognition in 98 other languages. Whisper provides automatic speech recognition and translation tasks. They can turn speech into text in various languages and translate that text into English. The main purpose of this project is to fine-tune the Whisper model using Lora. It supports training on non-timestamped data, with timestamped data, and without speech data. Currently open source for several models, specific can be [openai](https://huggingface.co/openai) to view, the following is a list of commonly used several models. In addition, the project also supports CTranslate2 accelerated reasoning and GGML accelerated reasoning. As a hint, accelerated reasoning supports direct use of Whisper original model transformation, and does not necessarily need to be fine-tuned. Supports Windows desktop applications, Android applications, and server deployments.
### please :star:
## Supporting models
- openai/whisper-tiny
- openai/whisper-base
- openai/whisper-small
- openai/whisper-medium
- openai/whisper-large
- openai/whisper-large-v2
**Environment:**
- Anaconda 3
- Python 3.8
- Pytorch 1.13.1
- Ubuntu 18.04
- GPU A100-PCIE-40GB*1
## Catalogue
- [Introduction of the main program of the project](#项目主要程序介绍)
- [Test table](#模型测试表)
- [Install](#安装环境)
- [Prepare data](#准备数据)
- [Fine-tuning](#微调模型)
- [Single-GPU](#单卡训练)
- [Multi-GPU](#多卡训练)
- [Merge model](#合并模型)
- [Evaluation](#评估模型)
- [Inference](#预测)
- [GUI inference](#GUI界面预测)
- [Web deploy](#Web部署)
- [API docs](#接口文档)
- [Ctranslate2 inference](#使用Ctranslate2格式模型预测)
- [Android](#Android部署)
- [Windows Desktop](#Windows桌面应用)
<a name='项目主要程序介绍'></a>
## Introduction of the main program of the project
1. `aishell.py`: Create AIShell training data.
2. `finetune.py`: Fine-tune the model.
3. `merge_lora.py`: Merge Whisper and Lora models.
4. `evaluation.py`: Evaluate the fine-tuned model or the original Whisper model.
5. `infer.py`: Call the fine-tuned model or Whisper model on transformers for prediction.
6. `infer_ct2.py`: Use the converted CTranslate2 model for prediction, primarily as a reference for program usage.
7. `infer_gui.py`: There is a GUI to make predictions using either the fine-tuned model or the Whisper model on transformers.
8. `infer_server.py`: Call the fine-tuned model or Whisper model on transformers and deploy it to the server for the client to call.
9. `convert-ggml.py`: Converts the model to GGML format for use in Android or Windows applications.
10. `AndroidDemo`: Contains the source code for deploying the model to Android.
11. `WhisperDesktop`: Contains the program for the Windows desktop application.
<a name='模型测试表'></a>
## Test table
1. Test table for cer of the original model.
| Model | Language | aishell_test | test_net | test_meeting |
|:----------------:|:--------:|:------------:|:--------:|:------------:|
| whisper-tiny | Chinese | 0.31898 | 0.40482 | 0.75332 |
| whisper-base | Chinese | 0.22196 | 0.30404 | 0.50378 |
| whisper-small | Chinese | 0.13897 | 0.18417 | 0.31154 |
| whisper-medium | Chinese | 0.09538 | 0.13591 | 0.26669 |
| whisper-large | Chinese | 0.08969 | 0.12933 | 0.23439 |
| whisper-large-v2 | Chinese | 0.08817 | 0.12332 | 0.26547 |
| whisper-large-v3 | Chinese | 0.08086 | 0.11452 | 0.19878 |
2. Cer test table after fine-tuning the dataset.
| Model | Language | Dataset | aishell_test | test_net | test_meeting |
|:----------------:|:--------:|:----------------------------------------------------------:|:------------:|:--------:|:------------:|
| whisper-tiny | Chinese | [AIShell](https://openslr.magicdatatech.com/resources/33/) | 0.13043 | 0.4463 | 0.57728 |
| whisper-base | Chinese | [AIShell](https://openslr.magicdatatech.com/resources/33/) | 0.08999 | 0.33089 | 0.40713 |
| whisper-small | Chinese | [AIShell](https://openslr.magicdatatech.com/resources/33/) | 0.05452 | 0.19831 | 0.24229 |
| whisper-medium | Chinese | [AIShell](https://openslr.magicdatatech.com/resources/33/) | 0.03681 | 0.13073 | 0.16939 |
| whisper-large-v2 | Chinese | [AIShell](https://openslr.magicdatatech.com/resources/33/) | 0.03139 | 0.12201 | 0.15776 |
| whisper-large-v2 | Chinese | [AIShell](https://openslr.magicdatatech.com/resources/33/) | 0.03660 | 0.09835 | 0.13706 |
| whisper-tiny | Chinese | [WenetSpeech](./tools/create_wenetspeech_data.py) | 0.21009 | 0.29352 | 0.41506 |
| whisper-base | Chinese | [WenetSpeech](./tools/create_wenetspeech_data.py) | 0.14548 | 0.17747 | 0.30590 |
| whisper-small | Chinese | [WenetSpeech](./tools/create_wenetspeech_data.py) | 0.08484 | 0.11801 | 0.23471 |
| whisper-medium | Chinese | [WenetSpeech](./tools/create_wenetspeech_data.py) | 0.05861 | 0.08794 | 0.19486 |
| whisper-large-v2 | Chinese | [WenetSpeech](./tools/create_wenetspeech_data.py) | 0.05443 | 0.08367 | 0.19087 |
| whisper-large-v3 | Chinese | [WenetSpeech](./tools/create_wenetspeech_data.py) | 0.04947 | 0.10711 | 0.17429 |
3. inference speed test table, using the GPU GTX3090 (24G), The audio is' test long.wav 'and is 3 minutes long. Test in `'tools/run_compute.sh`.
| Mode of acceleration | tiny | base | small | medium | large-v2 | large-v3 |
|:-------------------------------------------------------------------------:|:------:|:------:|:------:|:-------:|:--------:|:--------:|
| Transformers (`fp16` + `batch_size=16`) | 1.458s | 1.671s | 2.331s | 11.071s | 4.779s | 12.826s |
| Transformers (`fp16` + `batch_size=16` + `Compile`) | 1.477s | 1.675s | 2.357s | 11.003s | 4.799s | 12.643s |
| Transformers (`fp16` + `batch_size=16` + `BetterTransformer`) | 1.461s | 1.676s | 2.301s | 11.062s | 4.608s | 12.505s |
| Transformers (`fp16` + `batch_size=16` + `Flash Attention 2`) | 1.436s | 1.630s | 2.258s | 10.533s | 4.344s | 11.651s |
| Transformers (`fp16` + `batch_size=16` + `Compile` + `BetterTransformer`) | 1.442s | 1.686s | 2.277s | 11.000s | 4.543s | 12.592s |
| Transformers (`fp16` + `batch_size=16` + `Compile` + `Flash Attention 2`) | 1.409s | 1.643s | 2.220s | 10.390s | 4.377s | 11.703s |
| Faster Whisper (`fp16` + `beam_size=1` ) | 2.179s | 1.492s | 2.327s | 3.752s | 5.677s | 31.541s |
| Faster Whisper (`8-bit` + `beam_size=1` ) | 2.609s | 1.728s | 2.744s | 4.688s | 6.571s | 29.307s |
**Important explanation**:
1. Remove the punctuation marks from the model output during evaluation, and convert traditional Chinese to simplified Chinese.
2. `aishell_test` is the test set of AIShell, while `test_net` and `test_meeting` are the test sets of WenetSpeech.
3. The audio for testing speed is `dataset/test_long.wav`, with a audio is' test long.wav 'and is 3 minutes long.
4. The training data uses data with punctuation marks, resulting in a slightly higher cer.
5. The AiShell data used for fine-tuning does not include timestamp information, while the WenetSpeech data used for fine-tuning includes timestamp information.
<a name='安装环境'></a>
## 安装环境
- The GPU version of Pytorch will be installed first. You can choose one of two ways to install Pytorch.
1. Here's how to install Pytorch using Anaconda. If you already have it, please skip it.
```shell
conda install pytorch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 pytorch-cuda=11.8 -c pytorch -c nvidia
```
2. Here's how to pull an image of a Pytorch environment using a Docker image.
```shell
sudo docker pull pytorch/pytorch:2.1.0-cuda11.8-cudnn8-devel
```
It then moves into the image and mounts the current path to the container's '/workspace' directory.
```shell
sudo nvidia-docker run --name pytorch -it -v $PWD:/workspace pytorch/pytorch:1.13.1-cuda11.6-cudnn8-devel /bin/bash
```
- Install the required libraries.
```shell
python -m pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
```
- Windows requires a separate installation of bitsandbytes.
```shell
python -m pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.2.post2-py3-none-win_amd64.whl
```
<a name='准备数据'></a>
## Prepare the data
The training dataset is a list of jsonlines, meaning that each line is a JSON data in the following format: This project provides a program to make the AIShell dataset, 'aishell.py'. Executing this program will automatically download and generate the training and test sets in the following format. This program can skip the download process by specifying the compressed file of AIShell. If the direct download would be very slow, you can use some downloader such as thunderbolt to download the dataset and then specify the compressed filepath through the '--filepath' parameter. Like `/home/test/data_aishell.tgz`.
**Note:**
1. If timestamp training is not used, the `sentences` field can be excluded from the data.
2. If data is only available for one language, the language field can be excluded from the data.
3. If training empty speech data, the `sentences` field should be `[]`, the `sentence` field should be `""`, and the language field can be absent.
4. Data may exclude punctuation marks, but the fine-tuned model may lose the ability to add punctuation marks.
```json
{
"audio": {
"path": "dataset/0.wav"
},
"sentence": "近几年,不但我用书给女儿压岁,也劝说亲朋不要给女儿压岁钱,而改送压岁书。",
"language": "Chinese",
"sentences": [
{
"start": 0,
"end": 1.4,
"text": "近几年,"
},
{
"start": 1.42,
"end": 8.4,
"text": "不但我用书给女儿压岁,也劝说亲朋不要给女儿压岁钱,而改送压岁书。"
}
],
"duration": 7.37
}
```
<a name='微调模型'></a>
## Fine-tune
Once we have our data ready, we are ready to fine-tune our model. Training is the most important two parameters, respectively, `--base_model` specified fine-tuning the Whisper of model, the parameter values need to be in [HuggingFace](https://huggingface.co/openai), the don't need to download in advance, It can be downloaded automatically when starting training, or in advance, if `--base_model` is specified as the path and `--local_files_only`is set to True. The second `--output_path` is the Lora checkpoint path saved during training as we use Lora to fine-tune the model. If you want to save enough, it's best to set `--use_8bit` to False, which makes training much faster. See this program for more parameters.
<a name='单卡训练'></a>
### Single-GPU
The single card training command is as follows. Windows can do this without the `CUDA_VISIBLE_DEVICES` parameter.
```shell
CUDA_VISIBLE_DEVICES=0 python finetune.py --base_model=openai/whisper-tiny --output_dir=output/
```
<a name='多卡训练'></a>
### Multi-GPU
torchrun and accelerate are two different methods for multi-card training, which developers can use according to their preferences.
1. To start multi-card training with torchrun, use `--nproc_per_node` to specify the number of graphics cards to use.
```shell
torchrun --nproc_per_node=2 finetune.py --base_model=openai/whisper-tiny --output_dir=output/
```
2. Start multi-card training with accelerate, and if this is the first time you're using accelerate, configure the training parameters as follows:
The first step is to configure the training parameters. The process is to ask the developer to answer a few questions. Basically, the default is ok, but there are a few parameters that need to be set according to the actual situation.
```shell
accelerate config
```
Here's how it goes:
```
--------------------------------------------------------------------In which compute environment are you running?
This machine
--------------------------------------------------------------------Which type of machine are you using?
multi-GPU
How many different machines will you use (use more than 1 for multi-node training)? [1]:
Do you wish to optimize your script with torch dynamo?[yes/NO]:
Do you want to use DeepSpeed? [yes/NO]:
Do you want to use FullyShardedDataParallel? [yes/NO]:
Do you want to use Megatron-LM ? [yes/NO]:
How many GPU(s) should be used for distributed training? [1]:2
What GPU(s) (by id) should be used for training on this machine as a comma-seperated list? [all]:
--------------------------------------------------------------------Do you wish to use FP16 or BF16 (mixed precision)?
fp16
accelerate configuration saved at /home/test/.cache/huggingface/accelerate/default_config.yaml
```
Once the configuration is complete, you can view the configuration using the following command:
```shell
accelerate env
```
Start fine-tune:
```shell
accelerate launch finetune.py --base_model=openai/whisper-tiny --output_dir=output/
```
log:
```shell
{'loss': 0.9098, 'learning_rate': 0.000999046843662503, 'epoch': 0.01}
{'loss': 0.5898, 'learning_rate': 0.0009970611012927184, 'epoch': 0.01}
{'loss': 0.5583, 'learning_rate': 0.0009950753589229333, 'epoch': 0.02}
{'loss': 0.5469, 'learning_rate': 0.0009930896165531485, 'epoch': 0.02}
{'loss': 0.5959, 'learning_rate': 0.0009911038741833634, 'epoch': 0.03}
```
<a name='合并模型'></a>
## Merge model
After fine-tuning, there will be two models, the first is the Whisper base model, and the second is the Lora model. These two models need to be merged before the next operation. This program only needs to pass two arguments, `--lora_model` is the path of the Lora model saved after training, which is the checkpoint folder, and the second `--output_dir` is the saved directory of the merged model.
```shell
python merge_lora.py --lora_model=output/whisper-tiny/checkpoint-best/ --output_dir=models/
```
<a name='评估模型'></a>
## Evaluation
The following procedure is performed to evaluate the model, the most important two parameters are respectively. The first `--model_path` specifies the path of the merged model, but also supports direct use of the original whisper model, such as directly specifying `openai/Whisper-large-v2`, and the second `--metric` specifies the evaluation method. For example, there are word error rate `cer` and word error rate `wer`. Note: Models without fine-tuning may have punctuation in their output, affecting accuracy. See this program for more parameters.
```shell
python evaluation.py --model_path=models/whisper-tiny-finetune --metric=cer
```
<a name='预测'></a>
## Inference
Execute the following program for speech recognition, this uses transformers to directly call the fine-tuned model or Whisper's original model prediction, only suitable for reasoning short audio, long speech or refer to the use of `infer_ct2.py`. The first `--audio_path` argument specifies the audio path to predict. The second `--model_path` specifies the path of the merged model. It also allows you to use the original whisper model directly, for example `openai/whisper-large-v2`. See this program for more parameters.
```shell
python infer_tfs.py --audio_path=dataset/test.wav --model_path=models/whisper-tiny-finetune
```
<a name='GUI界面预测'></a>
## GUI inference
`--model_path` specifies Transformers model. See this program for more parameters.
```shell
python infer_gui.py --model_path=models/whisper-tiny-finetune
```
After startup, the screen is as follows:
<div align="center">
<img src="./docs/images/gui.jpg" alt="GUI界面" width="600"/>
</div>
<a name='Web部署'></a>
## Web deploy
`--host` specifies the address where the service will be started, here `0.0.0.0`, which means any address will be accessible. `--port`specifies the port number to use. `--model_path` specifies Transformers model. `--num_workers` specifies how many threads to use for concurrent inference, which is important in Web deployments where multiple concurrent accesses can be inferred at the same time. See this program for more parameters.
```shell
python infer_server.py --host=0.0.0.0 --port=5000 --model_path=models/whisper-tiny-finetune-ct2 --num_workers=2
```
### API docs
At present recognition interface `/recognition`, and the interface parameters are as follows.
| Field | Need | type | Default | Explain |
|:----------:|:----:|:------:|:----------:|:-------------------------------------------------------------------------:|
| audio | Yes | File | | Audio File |
| to_simple | No | int | 1 | Traditional Chinese to Simplified Chinese |
| remove_pun | No | int | 0 | Whether to remove punctuation |
| task | No | String | transcribe | Identify task types and support transcribe and translate |
| language | No | String | zh | Set the language, shorthand, to automatically detect the language if None |
Return result:
| Field | type | Explain |
|:-------:|:----:|:---------------------------------------------------:|
| results | list | Recognition results separated into individual parts |
| +result | str | Text recognition result for each separated part |
| +start | int | Start time in seconds for each separated part |
| +end | int | End time in seconds for each separated part |
| code | int | Error code, 0 indicates successful recognition |
Example:
```json
{
"results": [
{
"result": "近几年,不但我用书给女儿压碎,也全说亲朋不要给女儿压碎钱,而改送压碎书。",
"start": 0,
"end": 8
}
],
"code": 0
}
```
To make it easier to understand, here is the Python code to call the Web interface. Here is how to call `/recognition`.
```python
import requests
response = requests.post(url="http://127.0.0.1:5000/recognition",
files=[("audio", ("test.wav", open("dataset/test.wav", 'rb'), 'audio/wav'))],
json={"to_simple": 1, "remove_pun": 0, "language": "zh", "task": "transcribe"}, timeout=20)
print(response.text)
```
The provided test page is as follows:
The home page `http://127.0.0.1:5000/` looks like this:
<div align="center">
<img src="./docs/images/web.jpg" alt="首页" width="600"/>
</div>
Document page `http://127.0.0.1:5000/docs` page is as follows:
<a name='使用Ctranslate2格式模型预测'></a>
## Ctranslate2 inference
As we all know, directly using the Whisper model reasoning is relatively slow, so here provides a way to accelerate, mainly using CTranslate2 for acceleration, first to transform the model, transform the combined model into CTranslate2 model. In the following command, the `--model` parameter is the path of the merged model, but it is also possible to use the original whisper model directly, such as `openai/whisper-large-v2`. The `--output_dir` parameter specifies the path of the transformed CTranslate2 model, and the `--quantization` parameter quantizes the model size. If you don't want to quantize the model, you can drop this parameter.
```shell
ct2-transformers-converter --model models/whisper-tiny-finetune --output_dir models/whisper-tiny-finetune-ct2 --copy_files tokenizer.json preprocessor_config.json --quantization float16
```
Execute the following program to accelerate speech recognition, where the `--audio_path` argument specifies the audio path to predict. `--model_path` specifies the transformed CTranslate2 model. See this program for more parameters.
```shell
python infer_ct2.py --audio_path=dataset/test.wav --model_path=models/whisper-tiny-finetune-ct2
```
Output:
```shell
----------- Configuration Arguments -----------
audio_path: dataset/test.wav
model_path: models/whisper-tiny-finetune-ct2
language: zh
use_gpu: True
use_int8: False
beam_size: 10
num_workers: 1
vad_filter: False
local_files_only: True
------------------------------------------------
[0.0 - 8.0]:近几年,不但我用书给女儿压碎,也全说亲朋不要给女儿压碎钱,而改送压碎书。
```
<a name='Android部署'></a>
## Android
The source code for the installation and deployment can be found in [AndroidDemo](./AndroidDemo) and the documentation can be found in [README.md](AndroidDemo/README.md).
<br/>
<div align="center">
<img src="./docs/images/android2.jpg" alt="Android效果图" width="200">
<img src="./docs/images/android1.jpg" alt="Android效果图" width="200">
<img src="./docs/images/android3.jpg" alt="Android效果图" width="200">
<img src="./docs/images/android4.jpg" alt="Android效果图" width="200">
</div>
<a name='Windows桌面应用'></a>
## Windows Desktop
The program is in the [WhisperDesktop](./WhisperDesktop) directory, and the documentation can be found in [README.md](WhisperDesktop/README.md).
<br/>
<div align="center">
<img src="./docs/images/desktop1.jpg" alt="Windows桌面应用效果图">
</div>
## Reference
1. https://huggingface.co/blog/fine-tune-whisper
2. https://github.com/huggingface/peft
3. https://github.com/guillaumekln/faster-whisper
4. https://github.com/ggerganov/whisper.cpp
5. https://github.com/Const-me/Whisper
# Windows桌面应用
简体中文 | [English](./README_en.md)
该程序是使用[Whisper](https://github.com/Const-me/Whisper)翻译得的,源码可以前面该项目查看。该程序使用的模型格式是GGML格式,跟Android部署的一样,所以需要转换模型格式才能使用。
## 转换模型
1. 然后开始转换模型,请在`Whisper-Finetune`项目根目录下执行`convert-ggml.py`程序,把模型转换为Android项目所需的ggml格式的模型,需要转换的模型可以是原始的Transformers模型,也可以是微调的模型。
```shell
python convert-ggml.py --model_dir=models/whisper-tiny-finetune/ --output_path=models/whisper-tiny-finetune-ggml.bin
```
## 效果图
效果图如下:
<br/>
<div align="center">
<img src="../docs/images/desktop1.jpg" alt="Windows桌面应用效果图"><br/>
图1:加载模型页面
<br/>
<img src="../docs/images/desktop2.jpg" alt="Windows桌面应用效果图"><br/>
图2:选择音频文件转录
<br/>
<img src="../docs/images/desktop3.jpg" alt="Windows桌面应用效果图"><br/>
图3:录音转录
</div>
# Windows Desktop
[简体中文](./README.md) | English
**Disclaimer, this document was obtained through machine translation, please check the original document [here](./README.md).**
The program was translated using [Whisper](https://github.com/Const-me/Whisper), and the source code can be found in the previous project. The model format is GGML, which is the same as the Android deployment, so you'll need to convert the model format before you can use it.
## Convert model
1. To convert your models, run `convert-ggml.py` from the root of your `Whisper-Finetune` project to convert your models to ggml format for your Android project. The models you need to convert can be original Transformers. It can also be a fine-tuned model.
```shell
python convert-ggml.py --model_dir=models/whisper-tiny-finetune/ --whisper_dir=whisper/ --output_path=models/whisper-tiny-finetune-ggml.bin
```
## Effect picture
效果图如下:
<br/>
<div align="center">
<img src="../docs/images/desktop1.jpg" alt="Windows桌面应用效果图"><br/>
图1:加载模型页面
<br/>
<img src="../docs/images/desktop2.jpg" alt="Windows桌面应用效果图"><br/>
图2:选择音频文件转录
<br/>
<img src="../docs/images/desktop3.jpg" alt="Windows桌面应用效果图"><br/>
图3:录音转录
</div>
import argparse
import json
import os
import functools
import soundfile
from tqdm import tqdm
from utils.utils import download, unpack
from utils.utils import add_arguments, print_arguments
DATA_URL = 'https://openslr.elda.org/resources/33/data_aishell.tgz'
MD5_DATA = '2f494334227864a8a8fec932999db9d8'
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg("filepath", default=None, type=str, help="压缩包data_aishell.tgz文件路径,不指定会自动下载")
add_arg("target_dir", default="dataset/audio/", type=str, help="存放音频文件的目录")
add_arg("annotation_text", default="dataset/", type=str, help="存放音频标注文件的目录")
add_arg('add_pun', default=False, type=bool, help="是否添加标点符")
args = parser.parse_args()
def create_annotation_text(data_dir, annotation_path):
print('Create Aishell annotation text ...')
if args.add_pun:
import logging
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
logger = get_logger(log_level=logging.CRITICAL)
logger.setLevel(logging.CRITICAL)
inference_pipline = pipeline(task=Tasks.punctuation,
model='damo/punc_ct-transformer_cn-en-common-vocab471067-large',
model_revision="v1.0.0")
if not os.path.exists(annotation_path):
os.makedirs(annotation_path)
f_train = open(os.path.join(annotation_path, 'train.json'), 'w', encoding='utf-8')
f_test = open(os.path.join(annotation_path, 'test.json'), 'w', encoding='utf-8')
transcript_path = os.path.join(data_dir, 'transcript', 'aishell_transcript_v0.8.txt')
transcript_dict = {}
with open(transcript_path, 'r', encoding='utf-8') as f:
lines = f.readlines()
for line in tqdm(lines):
line = line.strip()
if line == '': continue
audio_id, text = line.split(' ', 1)
# remove space
text = ''.join(text.split())
if args.add_pun:
text = inference_pipline(text_in=text)['text']
transcript_dict[audio_id] = text
# 训练集
data_types = ['train', 'dev']
lines = []
for type in data_types:
audio_dir = os.path.join(data_dir, 'wav', type)
for subfolder, _, filelist in sorted(os.walk(audio_dir)):
for fname in filelist:
audio_path = os.path.join(subfolder, fname)
audio_id = fname[:-4]
# if no transcription for audio then skipped
if audio_id not in transcript_dict:
continue
text = transcript_dict[audio_id]
line = {"audio": {"path": audio_path}, "sentence": text}
lines.append(line)
# 添加音频时长
for i in tqdm(range(len(lines))):
audio_path = lines[i]['audio']['path']
sample, sr = soundfile.read(audio_path)
duration = round(sample.shape[-1] / float(sr), 2)
lines[i]["duration"] = duration
lines[i]["sentences"] = [{"start": 0, "end": duration, "text": lines[i]["sentence"]}]
for line in lines:
f_train.write(json.dumps(line, ensure_ascii=False) + "\n")
# 测试集
audio_dir = os.path.join(data_dir, 'wav', 'test')
lines = []
for subfolder, _, filelist in sorted(os.walk(audio_dir)):
for fname in filelist:
audio_path = os.path.join(subfolder, fname)
audio_id = fname[:-4]
# if no transcription for audio then skipped
if audio_id not in transcript_dict:
continue
text = transcript_dict[audio_id]
line = {"audio": {"path": audio_path}, "sentence": text}
lines.append(line)
# 添加音频时长
for i in tqdm(range(len(lines))):
audio_path = lines[i]['audio']['path']
sample, sr = soundfile.read(audio_path)
duration = round(sample.shape[-1] / float(sr), 2)
lines[i]["duration"] = duration
lines[i]["sentences"] = [{"start": 0, "end": duration, "text": lines[i]["sentence"]}]
for line in lines:
f_test.write(json.dumps(line, ensure_ascii=False)+"\n")
f_test.close()
f_train.close()
def prepare_dataset(url, md5sum, target_dir, annotation_path, filepath=None):
"""Download, unpack and create manifest file."""
data_dir = os.path.join(target_dir, 'data_aishell')
if not os.path.exists(data_dir):
if filepath is None:
filepath = download(url, md5sum, target_dir)
unpack(filepath, target_dir)
# unpack all audio tar files
audio_dir = os.path.join(data_dir, 'wav')
for subfolder, _, filelist in sorted(os.walk(audio_dir)):
for ftar in filelist:
unpack(os.path.join(subfolder, ftar), subfolder, True)
os.remove(filepath)
else:
print("Skip downloading and unpacking. Aishell data already exists in %s." % target_dir)
create_annotation_text(data_dir, annotation_path)
def main():
print_arguments(args)
if args.target_dir.startswith('~'):
args.target_dir = os.path.expanduser(args.target_dir)
prepare_dataset(url=DATA_URL,
md5sum=MD5_DATA,
target_dir=args.target_dir,
annotation_path=args.annotation_text,
filepath=args.filepath)
if __name__ == '__main__':
main()
[
{
"type": "resample",
"params": {
"new_sample_rates": [8000, 32000, 44100]
},
"prob": 0.0
},
{
"type": "noise",
"params": {
"min_snr_dB": 10,
"max_snr_dB": 50,
"noise_dir": "dataset/noise"
},
"prob": 0.2
},
{
"type": "speed",
"params": {
"min_speed_rate": 0.9,
"max_speed_rate": 1.1,
"num_rates": 3
},
"prob": 0.5
},
{
"type": "shift",
"params": {
"min_shift_ms": -5,
"max_shift_ms": 5
},
"prob": 0.0
},
{
"type": "volume",
"params": {
"min_gain_dBFS": -15,
"max_gain_dBFS": 15
},
"prob": 0.5
}
]
\ No newline at end of file
import argparse
import functools
import json
import os
import struct
import numpy as np
import torch
from transformers import WhisperForConditionalGeneration
from utils.utils import add_arguments, print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg("model_dir", type=str, default="models/whisper-tiny-finetune", help="需要转换的模型路径")
add_arg("output_path", type=str, default="models/ggml-model.bin", help="转换保存模型的路径")
add_arg("use_f16", type=bool, default=True, help="是否量化为半精度")
args = parser.parse_args()
print_arguments(args)
conv_map = {
'self_attn.k_proj': 'attn.key',
'self_attn.q_proj': 'attn.query',
'self_attn.v_proj': 'attn.value',
'self_attn.out_proj': 'attn.out',
'self_attn_layer_norm': 'attn_ln',
'encoder_attn.q_proj': 'cross_attn.query',
'encoder_attn.v_proj': 'cross_attn.value',
'encoder_attn.out_proj': 'cross_attn.out',
'encoder_attn_layer_norm': 'cross_attn_ln',
'fc1': 'mlp.0',
'fc2': 'mlp.2',
'final_layer_norm': 'mlp_ln',
'encoder.layer_norm.bias': 'encoder.ln_post.bias',
'encoder.layer_norm.weight': 'encoder.ln_post.weight',
'encoder.embed_positions.weight': 'encoder.positional_embedding',
'decoder.layer_norm.bias': 'decoder.ln.bias',
'decoder.layer_norm.weight': 'decoder.ln.weight',
'decoder.embed_positions.weight': 'decoder.positional_embedding',
'decoder.embed_tokens.weight': 'decoder.token_embedding.weight',
'proj_out.weight': 'decoder.proj.weight',
}
def bytes_to_unicode():
bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
cs = bs[:]
n = 0
for b in range(2 ** 8):
if b not in bs:
bs.append(b)
cs.append(2 ** 8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
encoder = json.load(open(f"{args.model_dir}/vocab.json", "r", encoding="utf8"))
encoder_added = json.load(open(f"{args.model_dir}/added_tokens.json", "r", encoding="utf8"))
hparams = json.load(open(f"{args.model_dir}/config.json", "r", encoding="utf8"))
# 支持large-v3模型
if "max_length" not in hparams.keys():
hparams["max_length"] = hparams["max_target_positions"]
model = WhisperForConditionalGeneration.from_pretrained(args.model_dir)
n_mels = hparams["num_mel_bins"]
with np.load(f"tools/mel_filters.npz") as f:
filters = torch.from_numpy(f[f"mel_{n_mels}"])
tokens = json.load(open(f"{args.model_dir}/vocab.json", "r", encoding="utf8"))
os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
fout = open(args.output_path, "wb")
fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
fout.write(struct.pack("i", hparams["vocab_size"]))
fout.write(struct.pack("i", hparams["max_source_positions"]))
fout.write(struct.pack("i", hparams["d_model"]))
fout.write(struct.pack("i", hparams["encoder_attention_heads"]))
fout.write(struct.pack("i", hparams["encoder_layers"]))
fout.write(struct.pack("i", hparams["max_length"]))
fout.write(struct.pack("i", hparams["d_model"]))
fout.write(struct.pack("i", hparams["decoder_attention_heads"]))
fout.write(struct.pack("i", hparams["decoder_layers"]))
fout.write(struct.pack("i", hparams["num_mel_bins"]))
fout.write(struct.pack("i", args.use_f16))
fout.write(struct.pack("i", filters.shape[0]))
fout.write(struct.pack("i", filters.shape[1]))
for i in range(filters.shape[0]):
for j in range(filters.shape[1]):
fout.write(struct.pack("f", filters[i][j]))
byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}
fout.write(struct.pack("i", len(tokens)))
tokens = sorted(tokens.items(), key=lambda x: x[1])
for key in tokens:
text = bytearray([byte_decoder[c] for c in key[0]])
fout.write(struct.pack("i", len(text)))
fout.write(text)
list_vars = model.state_dict()
for name in list_vars.keys():
# this seems to not be used
if name == "proj_out.weight":
print('Skipping', name)
continue
src = name
nn = name
if name != "proj_out.weight":
nn = nn.split(".")[1:]
else:
nn = nn.split(".")
if nn[1] == "layers":
nn[1] = "blocks"
if ".".join(nn[3:-1]) == "encoder_attn.k_proj":
mapped = "attn.key" if nn[0] == "encoder" else "cross_attn.key"
else:
mapped = conv_map[".".join(nn[3:-1])]
name = ".".join(nn[:3] + [mapped] + nn[-1:])
else:
name = ".".join(nn)
name = conv_map[name] if name in conv_map else name
print(src, ' -> ', name)
data = list_vars[src].squeeze().numpy()
data = data.astype(np.float16)
# reshape conv bias from [n] to [n, 1]
if name in ["encoder.conv1.bias", "encoder.conv2.bias"]:
data = data.reshape(data.shape[0], 1)
print(" Reshaped variable: ", name, " to shape: ", data.shape)
n_dims = len(data.shape)
print(name, n_dims, data.shape)
# looks like the whisper models are in f16 by default
# so we need to convert the small tensors to f32 until we fully support f16 in ggml
# ftype == 0 -> float32, ftype == 1 -> float16
ftype = 1
if args.use_f16:
if n_dims < 2 or \
name == "encoder.conv1.bias" or \
name == "encoder.conv2.bias" or \
name == "encoder.positional_embedding" or \
name == "decoder.positional_embedding":
print(" Converting to float32")
data = data.astype(np.float32)
ftype = 0
else:
data = data.astype(np.float32)
ftype = 0
# header
str_ = name.encode('utf-8')
fout.write(struct.pack("iii", n_dims, len(str_), ftype))
for i in range(n_dims):
fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
fout.write(str_)
# data
data.tofile(fout)
fout.close()
print(f"导出模型: {args.output_path}")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment