Commit 431278fa authored by “change”'s avatar “change”
Browse files

Initial commit

parent 8c252776
Pipeline #1949 failed with stages
in 0 seconds
#!/usr/bin/env bash
# Tokenize text shards in parallel and merge the per-job results.
#
# Positional arguments:
#   $1  text_dir    directory containing txt/text.JOB.txt shards
#   $2  seg_file    file passed to utils/text_tokenize.py via -s
#   $3  logdir      directory for the per-job logs
#   $4  output_dir  receives txt/, text and text_shape
#
# The defaults below may be overridden on the command line; option handling
# is delegated to utils/parse_options.sh.

# Begin configuration section.
nj=32
cmd=utils/run.pl
echo "$0 $*"
. utils/parse_options.sh || exit 1;

if [ $# -lt 4 ]; then
  echo "Usage: $0 [options] <text-dir> <seg-file> <log-dir> <output-dir>" >&2
  exit 1
fi

# tokenize configuration
text_dir=$1
seg_file=$2
logdir=$3
output_dir=$4
txt_dir="${output_dir}/txt"
mkdir -p "${txt_dir}" "${logdir}" || exit 1

# $cmd is intentionally left unquoted: it may carry its own options
# (e.g. "utils/run.pl --some-flag") and must be word-split.
$cmd JOB=1:"$nj" "$logdir/text_tokenize.JOB.log" \
  python utils/text_tokenize.py -t "${text_dir}/txt/text.JOB.txt" \
  -s "${seg_file}" -i JOB -o "${txt_dir}" \
  || exit 1;

# concatenate the text files together.
for n in $(seq "$nj"); do
  cat "${txt_dir}/text.$n.txt" || exit 1
done > "${output_dir}/text" || exit 1

for n in $(seq "$nj"); do
  cat "${txt_dir}/len.$n" || exit 1
done > "${output_dir}/text_shape" || exit 1

echo "$0: Succeeded text tokenize"
#!/usr/bin/env python3
# coding=utf-8
# Authors:
# 2019.5 Zhiyang Zhou (https://github.com/Joee1995/chn_text_norm.git)
# 2019.9 Jiayu DU
#
# requirements:
# - python 3.X
# notes: python 2.X WILL fail or produce misleading results
import sys, os, argparse, codecs, string, re
# ================================================================================ #
# basic constant
# ================================================================================ #
# Digits 0-9 in everyday form, then in the "big" (financial) form,
# simplified and traditional.
CHINESE_DIGIS = "零一二三四五六七八九"
BIG_CHINESE_DIGIS_SIMPLIFIED = "零壹贰叁肆伍陆柒捌玖"
BIG_CHINESE_DIGIS_TRADITIONAL = "零壹貳參肆伍陸柒捌玖"
# Unit characters: ten, hundred, thousand, ten-thousand.
SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = "十百千万"
SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = "拾佰仟萬"
# Units of 亿 and larger; their powers of ten depend on the numbering type.
LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "亿兆京垓秭穰沟涧正载"
LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = "億兆京垓秭穰溝澗正載"
SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "十百千万"
SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = "拾佰仟萬"
# Alternative spoken forms of 0, 1 and 2.
ZERO_ALT = "〇"
ONE_ALT = "幺"
TWO_ALTS = ["两", "兩"]
# Math words as [simplified, traditional] pairs.
POSITIVE = ["正", "正"]
NEGATIVE = ["负", "負"]
POINT = ["点", "點"]
# PLUS = [u'加', u'加']
# SIL = [u'杠', u'槓']
# Filler characters that the command-line driver can strip from the text.
FILLER_CHARS = ["呃", "啊"]
# Regex alternation of words in which "儿" is part of the word itself and
# must be kept when erhua is removed.
ER_WHITELIST = (
    "(儿女|儿子|儿孙|女儿|儿媳|妻儿|"
    "胎儿|婴儿|新生儿|婴幼儿|幼儿|少儿|小儿|儿歌|儿童|儿科|托儿所|孤儿|"
    "儿戏|儿化|台儿庄|鹿儿岛|正儿八经|吊儿郎当|生儿育女|托儿带女|养儿防老|痴儿呆女|"
    "佳儿佳妇|儿怜兽扰|儿无常父|儿不嫌母丑|儿行千里母担忧|儿大不由爷|苏乞儿)"
)
# Chinese numbering system types
NUMBERING_TYPES = ["low", "mid", "high"]
# Regex alternation of currency names.
CURRENCY_NAMES = (
    "(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|"
    "里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)"
)
# Regex alternation of currency units, with optional magnitude prefixes.
CURRENCY_UNITS = (
    "((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)"
)
# Regex alternation of common quantifiers (measure words) that may follow a number.
COM_QUANTIFIERS = (
    "(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|"
    "砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|"
    "针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|"
    "毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|"
    "盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|"
    "纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块)"
)
# punctuation information are based on Zhon project (https://github.com/tsroten/zhon.git)
#
# The characters are spelled as \u escapes on purpose: these lists hold the
# FULLWIDTH / CJK punctuation marks, and a Unicode-normalising copy of the
# file folds them into ASCII (which turns the non-stop list into an empty
# string followed by a "#" comment). ASCII punctuation is handled separately
# through string.punctuation by the command-line driver.

# Sentence-final marks: fullwidth "!" and "?", halfwidth and regular ideographic full stop.
CHINESE_PUNC_STOP = "\uff01\uff1f\uff61\u3002"
CHINESE_PUNC_NON_STOP = (
    # fullwidth variants of  " # $ % & ' ( ) * + , - / : ; < = > @ [ \ ] ^ _ ` { | } ~
    "\uff02\uff03\uff04\uff05\uff06\uff07\uff08\uff09\uff0a\uff0b\uff0c\uff0d"
    "\uff0f\uff1a\uff1b\uff1c\uff1d\uff1e\uff20\uff3b\uff3c\uff3d\uff3e\uff3f"
    "\uff40\uff5b\uff5c\uff5d\uff5e"
    # fullwidth white parentheses, halfwidth corner brackets, halfwidth ideographic comma
    "\uff5f\uff60\uff62\uff63\uff64"
    # CJK symbols and brackets
    "\u3001\u3003\u300b\u300c\u300d\u300e\u300f\u3010\u3011"
    "\u3014\u3015\u3016\u3017\u3018\u3019\u301a\u301b\u301c\u301d\u301e\u301f"
    "\u3030\u303e\u303f"
    # dashes, curly quotation marks, ellipsis, hyphenation point, wavy low line
    "\u2013\u2014\u2018\u2019\u201b\u201c\u201d\u201e\u201f\u2026\u2027\ufe4f"
)
CHINESE_PUNC_LIST = CHINESE_PUNC_STOP + CHINESE_PUNC_NON_STOP
# ================================================================================ #
# basic class
# ================================================================================ #
class ChineseChar(object):
    """
    A Chinese character.

    Every character carries a simplified and a traditional form,
    e.g. simplified = '负', traditional = '負'; either form can be
    selected when converting.
    """

    def __init__(self, simplified, traditional):
        self.simplified = simplified
        self.traditional = traditional

    def __str__(self):
        # __str__ must return a str: fall back to "" (not None, which makes
        # str()/format() raise TypeError) when neither form is set.
        return self.simplified or self.traditional or ""

    def __repr__(self):
        return self.__str__()
class ChineseNumberUnit(ChineseChar):
    """
    A Chinese number unit (place-value character).

    Besides the simplified/traditional forms, each unit also has "big"
    (capital) forms, e.g. '陆' and '陸'. ``power`` is the exponent of ten
    the unit stands for.
    """

    def __init__(self, power, simplified, traditional, big_s, big_t):
        super(ChineseNumberUnit, self).__init__(simplified, traditional)
        self.power = power
        self.big_s = big_s
        self.big_t = big_t

    def __str__(self):
        return "10^{}".format(self.power)

    @classmethod
    def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False):
        """
        Build the unit found at position ``index`` of a unit table.

        ``value`` is a (simplified, traditional) pair. Small units get
        power ``index + 1`` and use the traditional character for both big
        forms; larger units get a power that depends on ``numbering_type``.
        Raises ValueError for an unknown numbering type.
        """
        if small_unit:
            exponent = index + 1
            big_from_traditional = True
        else:
            big_from_traditional = False
            if numbering_type == NUMBERING_TYPES[0]:
                exponent = index + 8
            elif numbering_type == NUMBERING_TYPES[1]:
                exponent = (index + 2) * 4
            elif numbering_type == NUMBERING_TYPES[2]:
                exponent = pow(2, index + 3)
            else:
                raise ValueError(
                    "Counting type should be in {0} ({1} provided).".format(
                        NUMBERING_TYPES, numbering_type
                    )
                )

        simplified_char = value[0]
        traditional_char = value[1]
        return ChineseNumberUnit(
            power=exponent,
            simplified=simplified_char,
            traditional=traditional_char,
            big_s=traditional_char if big_from_traditional else simplified_char,
            big_t=traditional_char,
        )
class ChineseNumberDigit(ChineseChar):
    """
    A Chinese digit character: its numeric value plus the simplified,
    traditional, big and optional alternative written forms.
    """

    def __init__(self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None):
        super(ChineseNumberDigit, self).__init__(simplified, traditional)
        self.value = value
        self.big_s, self.big_t = big_s, big_t
        self.alt_s, self.alt_t = alt_s, alt_t

    def __str__(self):
        return str(self.value)

    @classmethod
    def create(cls, i, v):
        """Build digit ``i`` from ``v`` = (simplified, traditional, big_s, big_t)."""
        simplified_form = v[0]
        traditional_form = v[1]
        big_simplified = v[2]
        big_traditional = v[3]
        return ChineseNumberDigit(
            i, simplified_form, traditional_form, big_simplified, big_traditional
        )
class ChineseMath(ChineseChar):
    """
    A Chinese math word (sign or decimal point) together with its ASCII
    symbol and an optional callable implementing the operation.
    """

    def __init__(self, simplified, traditional, symbol, expression=None):
        super(ChineseMath, self).__init__(simplified, traditional)
        # Math words have no separate capital form: the big forms mirror
        # the plain ones.
        self.big_s, self.big_t = simplified, traditional
        self.symbol, self.expression = symbol, expression
# Short aliases for the character classes, used throughout the rest of the file.
CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath


class NumberSystem(object):
    """
    Chinese number system.

    An empty container; create_system() attaches the ``units``, ``digits``
    and ``math`` attributes to an instance.
    """

    pass
class MathSymbol(object):
    """
    Math symbols used by the Chinese number system
    (simplified / traditional), e.g.
    positive = ['正', '正']
    negative = ['负', '負']
    point = ['点', '點']
    """

    def __init__(self, positive, negative, point):
        self.positive, self.negative, self.point = positive, negative, point

    def __iter__(self):
        # Iterate over the stored symbols in attribute-definition order.
        yield from self.__dict__.values()
# class OtherSymbol(object):
# """
# 其他符号
# """
#
# def __init__(self, sil):
# self.sil = sil
#
# def __iter__(self):
# for v in self.__dict__.values():
# yield v
# ================================================================================ #
# basic utils
# ================================================================================ #
def create_system(numbering_type=NUMBERING_TYPES[1]):
    """
    Create and return the number system for the given numbering type
    (default: mid).

    NUMBERING_TYPES = ['low', 'mid', 'high']: Chinese numbering system types
        low:  '兆' = '亿' * '十' = $10^{9}$,  '京' = '兆' * '十', etc.
        mid:  '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc.
        high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc.
    """
    # units of '亿' and larger
    larger_units = [
        CNU.create(position, pair, numbering_type, False)
        for position, pair in enumerate(
            zip(
                LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED,
                LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL,
            )
        )
    ]

    # units '十, 百, 千, 万'
    smaller_units = [
        CNU.create(position, pair, small_unit=True)
        for position, pair in enumerate(
            zip(
                SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED,
                SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL,
            )
        )
    ]

    # digits 0-9 with their big forms; 0, 1 and 2 also get alternative forms
    digits = [
        CND.create(number, forms)
        for number, forms in enumerate(
            zip(
                CHINESE_DIGIS,
                CHINESE_DIGIS,
                BIG_CHINESE_DIGIS_SIMPLIFIED,
                BIG_CHINESE_DIGIS_TRADITIONAL,
            )
        )
    ]
    alternatives = (
        (0, ZERO_ALT, ZERO_ALT),
        (1, ONE_ALT, ONE_ALT),
        (2, TWO_ALTS[0], TWO_ALTS[1]),
    )
    for number, alt_simplified, alt_traditional in alternatives:
        digits[number].alt_s = alt_simplified
        digits[number].alt_t = alt_traditional

    # math symbols
    positive_cn = CM(POSITIVE[0], POSITIVE[1], "+", lambda x: x)
    negative_cn = CM(NEGATIVE[0], NEGATIVE[1], "-", lambda x: -x)
    point_cn = CM(POINT[0], POINT[1], ".", lambda x, y: float(str(x) + "." + str(y)))

    system = NumberSystem()
    system.units = smaller_units + larger_units
    system.digits = digits
    system.math = MathSymbol(positive_cn, negative_cn, point_cn)
    return system
def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):
    """
    Convert a Chinese numeral string to a string of Arabic digits.

    The string is split at the first decimal-point word found (if any);
    the integer part is evaluated through its digit/unit symbols and the
    decimal part is rendered digit by digit. Returns "<int>.<dec>" when a
    decimal part exists, otherwise "<int>".
    """

    def get_symbol(char, system):
        # Look the character up among units, then digits, then math symbols.
        # Falls through (returning None) when nothing matches.
        for u in system.units:
            if char in [u.traditional, u.simplified, u.big_s, u.big_t]:
                return u
        for d in system.digits:
            if char in [d.traditional, d.simplified, d.big_s, d.big_t, d.alt_s, d.alt_t]:
                return d
        for m in system.math:
            if char in [m.traditional, m.simplified]:
                return m

    def string2symbols(chinese_string, system):
        # Split into integer/decimal parts at the decimal-point word and map
        # every character to its symbol.
        # NOTE(review): split(p) is unpacked into exactly two names, so an
        # input containing the point word more than once raises ValueError.
        int_string, dec_string = chinese_string, ""
        for p in [system.math.point.simplified, system.math.point.traditional]:
            if p in chinese_string:
                int_string, dec_string = chinese_string.split(p)
                break
        return [get_symbol(c, system) for c in int_string], [
            get_symbol(c, system) for c in dec_string
        ]

    def correct_symbols(integer_symbols, system):
        """
        一百八 to 一百八十
        一亿一千三百万 to 一亿 一千万 三百万
        """
        # A leading '十' (power 1) gets an implicit leading digit 1.
        if integer_symbols and isinstance(integer_symbols[0], CNU):
            if integer_symbols[0].power == 1:
                integer_symbols = [system.digits[1]] + integer_symbols

        # A trailing digit directly after a unit gets the next lower unit appended.
        if len(integer_symbols) > 1:
            if isinstance(integer_symbols[-1], CND) and isinstance(integer_symbols[-2], CNU):
                integer_symbols.append(CNU(integer_symbols[-2].power - 1, None, None, None, None))

        result = []
        unit_count = 0
        for s in integer_symbols:
            if isinstance(s, CND):
                result.append(s)
                unit_count = 0
            elif isinstance(s, CNU):
                current_unit = CNU(s.power, None, None, None, None)
                unit_count += 1

            # NOTE(review): nesting level of this if/elif is reconstructed
            # (SOURCE lost its indentation) — confirm against upstream.
            if unit_count == 1:
                result.append(current_unit)
            elif unit_count > 1:
                # Consecutive units: add the new unit's power to every
                # already-collected unit with a smaller power.
                for i in range(len(result)):
                    if (
                        isinstance(result[-i - 1], CNU)
                        and result[-i - 1].power < current_unit.power
                    ):
                        result[-i - 1] = CNU(
                            result[-i - 1].power + current_unit.power, None, None, None, None
                        )
        return result

    def compute_value(integer_symbols):
        """
        Compute the value.
        When current unit is larger than previous unit, current unit * all previous units will be used as all previous units.
        e.g. '两千万' = 2000 * 10000 not 2000 + 10000
        """
        value = [0]
        last_power = 0
        for s in integer_symbols:
            if isinstance(s, CND):
                value[-1] = s.value
            elif isinstance(s, CNU):
                value[-1] *= pow(10, s.power)
                if s.power > last_power:
                    value[:-1] = list(map(lambda v: v * pow(10, s.power), value[:-1]))
                    last_power = s.power
                value.append(0)
        return sum(value)

    system = create_system(numbering_type)
    int_part, dec_part = string2symbols(chinese_string, system)
    int_part = correct_symbols(int_part, system)
    int_str = str(compute_value(int_part))
    # NOTE(review): a decimal character that is not a known digit yields a
    # symbol without a usable .value here — confirm inputs are pre-validated.
    dec_str = "".join([str(d.value) for d in dec_part])
    if dec_part:
        return "{0}.{1}".format(int_str, dec_str)
    else:
        return int_str
def num2chn(
    number_string,
    numbering_type=NUMBERING_TYPES[1],
    big=False,
    traditional=False,
    alt_zero=False,
    alt_one=False,
    alt_two=True,
    use_zeros=True,
    use_units=True,
):
    """
    Convert a string of Arabic digits (optionally with one '.') to Chinese.

    big / traditional select which written form of each symbol is emitted;
    alt_zero / alt_one replace the zero / one character with its alternative
    form; alt_two substitutes the alternative form of 2 in front of certain
    units; use_units=False spells the integer part digit by digit.
    Raises ValueError when the input contains more than one '.'.
    """

    def get_value(value_string, use_zeros=True):
        striped_string = value_string.lstrip("0")

        # record nothing if all zeros
        if not striped_string:
            return []

        # record one digits
        elif len(striped_string) == 1:
            if use_zeros and len(value_string) != len(striped_string):
                return [system.digits[0], system.digits[int(striped_string)]]
            else:
                return [system.digits[int(striped_string)]]

        # recursively record multiple digits
        else:
            # Largest unit whose power is smaller than the number of digits.
            result_unit = next(u for u in reversed(system.units) if u.power < len(striped_string))
            result_string = value_string[: -result_unit.power]
            return (
                get_value(result_string)
                + [result_unit]
                + get_value(striped_string[-result_unit.power :])
            )

    system = create_system(numbering_type)

    int_dec = number_string.split(".")
    if len(int_dec) == 1:
        int_string = int_dec[0]
        dec_string = ""
    elif len(int_dec) == 2:
        int_string = int_dec[0]
        dec_string = int_dec[1]
    else:
        raise ValueError(
            "invalid input num string with more than one dot: {}".format(number_string)
        )

    if use_units and len(int_string) > 1:
        result_symbols = get_value(int_string)
    else:
        result_symbols = [system.digits[int(c)] for c in int_string]
    dec_symbols = [system.digits[int(c)] for c in dec_string]
    if dec_string:
        result_symbols += [system.math.point] + dec_symbols

    if alt_two:
        # Digit 2 whose plain forms are the alternative characters.
        liang = CND(
            2,
            system.digits[2].alt_s,
            system.digits[2].alt_t,
            system.digits[2].big_s,
            system.digits[2].big_t,
        )
        for i, v in enumerate(result_symbols):
            if isinstance(v, CND) and v.value == 2:
                next_symbol = result_symbols[i + 1] if i < len(result_symbols) - 1 else None
                previous_symbol = result_symbols[i - 1] if i > 0 else None
                # Swap only when the 2 is followed by a unit other than the
                # power-1 unit and is not preceded by a digit or the power-1 unit.
                if isinstance(next_symbol, CNU) and isinstance(previous_symbol, (CNU, type(None))):
                    if next_symbol.power != 1 and (
                        (previous_symbol is None) or (previous_symbol.power != 1)
                    ):
                        result_symbols[i] = liang

    # if big is True, '两' will not be used and `alt_two` has no impact on output
    if big:
        attr_name = "big_"
        if traditional:
            attr_name += "t"
        else:
            attr_name += "s"
    else:
        if traditional:
            attr_name = "traditional"
        else:
            attr_name = "simplified"

    result = "".join([getattr(s, attr_name) for s in result_symbols])

    # if not use_zeros:
    #     result = result.strip(getattr(system.digits[0], attr_name))

    if alt_zero:
        result = result.replace(getattr(system.digits[0], attr_name), system.digits[0].alt_s)

    if alt_one:
        result = result.replace(getattr(system.digits[1], attr_name), system.digits[1].alt_s)

    # A result that starts with the decimal-point word gets a leading zero.
    for i, p in enumerate(POINT):
        if result.startswith(p):
            return CHINESE_DIGIS[0] + result

    # ^10, 11, .., 19
    # Drop the leading "one" when it is immediately followed by the ten unit.
    if (
        len(result) >= 2
        and result[1]
        in [
            SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0],
            SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0],
        ]
        and result[0]
        in [CHINESE_DIGIS[1], BIG_CHINESE_DIGIS_SIMPLIFIED[1], BIG_CHINESE_DIGIS_TRADITIONAL[1]]
    ):
        result = result[1:]

    return result
# ================================================================================ #
# different types of rewriters
# ================================================================================ #
class Cardinal:
    """
    CARDINAL rewriter: converts between a cardinal number written in
    Arabic digits and its Chinese reading.
    """

    def __init__(self, cardinal=None, chntext=None):
        self.cardinal, self.chntext = cardinal, chntext

    def chntext2cardinal(self):
        """Return the stored Chinese text converted to digits."""
        digits = chn2num(self.chntext)
        return digits

    def cardinal2chntext(self):
        """Return the stored digit string converted to Chinese text."""
        spoken = num2chn(self.cardinal)
        return spoken
class Digit:
    """
    DIGIT rewriter: reads a digit sequence one digit at a time.
    """

    def __init__(self, digit=None, chntext=None):
        self.digit, self.chntext = digit, chntext

    # def chntext2digit(self):
    #     return chn2num(self.chntext)

    def digit2chntext(self):
        """Return the digits spelled one by one, without units or alternative 2."""
        options = {"alt_two": False, "use_units": False}
        return num2chn(self.digit, **options)
class TelePhone:
    """
    TELEPHONE rewriter: reads a phone number digit by digit.
    """

    def __init__(self, telephone=None, raw_chntext=None, chntext=None):
        self.telephone, self.raw_chntext, self.chntext = telephone, raw_chntext, chntext

    # def chntext2telephone(self):
    #     sil_parts = self.raw_chntext.split('<SIL>')
    #     self.telephone = '-'.join([
    #         str(chn2num(p)) for p in sil_parts
    #     ])
    #     return self.telephone

    def telephone2chntext(self, fixed=False):
        """
        Convert the stored number to Chinese text and return it.

        Fixed-line numbers are split on '-' and joined with '<SIL>' in
        raw_chntext; other numbers have '+' stripped from both ends, are
        split on whitespace and joined with '<SP>'. chntext is raw_chntext
        with the marker removed.
        """
        if fixed:
            marker = "<SIL>"
            pieces = self.telephone.split("-")
        else:
            marker = "<SP>"
            pieces = self.telephone.strip("+").split()

        spoken = [num2chn(piece, alt_two=False, use_units=False) for piece in pieces]
        self.raw_chntext = marker.join(spoken)
        self.chntext = self.raw_chntext.replace(marker, "")
        return self.chntext
class Fraction:
    """
    FRACTION rewriter: converts between "numerator/denominator" and the
    Chinese "<denominator>分之<numerator>" reading.
    """

    def __init__(self, fraction=None, chntext=None):
        self.fraction, self.chntext = fraction, chntext

    def chntext2fraction(self):
        """Return the stored Chinese fraction as "numerator/denominator"."""
        # The Chinese form names the denominator first.
        below, above = self.chntext.split("分之")
        top = chn2num(above)
        bottom = chn2num(below)
        return top + "/" + bottom

    def fraction2chntext(self):
        """Return the stored "numerator/denominator" string as Chinese text."""
        above, below = self.fraction.split("/")
        bottom = num2chn(below)
        top = num2chn(above)
        return bottom + "分之" + top
class Date:
    """
    DATE rewriter: converts a date such as "1999年2月20日" to Chinese text.

    The year is read digit by digit; month and day are read as cardinals.
    The reverse conversion (Chinese text to date) is not implemented.
    """

    def __init__(self, date=None, chntext=None):
        self.date = date
        self.chntext = chntext

    def date2chntext(self):
        """Convert self.date, store the result in self.chntext and return it."""
        date = self.date
        try:
            year, other = date.strip().split("年", 1)
            year = Digit(digit=year).digit2chntext() + "年"
        except ValueError:
            # No year part: everything is month/day.
            other = date
            year = ""
        if other:
            try:
                month, day = other.strip().split("月", 1)
                month = Cardinal(cardinal=month).cardinal2chntext() + "月"
            except ValueError:
                # No month part: what is left after the year is the day.
                # (Using the full date here would feed the year back in.)
                day = other.strip()
                month = ""
            if day:
                # The last character is the day marker; keep it verbatim.
                day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1]
        else:
            month = ""
            day = ""
        chntext = year + month + day
        self.chntext = chntext
        return self.chntext
class Money:
    """
    MONEY rewriter: replaces every number inside a money expression with
    its Chinese cardinal reading, leaving the currency words untouched.
    """

    def __init__(self, money=None, chntext=None):
        self.money = money
        self.chntext = chntext

    # def chntext2money(self):
    #     return self.money

    def money2chntext(self):
        """
        Convert the numbers in self.money and return the result.

        self.chntext is only updated when at least one number was found.
        """
        pattern = re.compile(r"(\d+(\.\d+)?)")
        # Substitute each match exactly where it occurs. Chained
        # str.replace() calls would also rewrite the same digits inside
        # later numbers (e.g. the "1" of "12" in "1元12").
        converted, count = pattern.subn(
            lambda m: Cardinal(cardinal=m.group(1)).cardinal2chntext(), self.money
        )
        if count:
            self.chntext = converted
        return self.chntext
class Percentage:
    """
    PERCENTAGE rewriter: converts between "N%" and "百分之N".
    """

    def __init__(self, percentage=None, chntext=None):
        self.percentage = percentage
        self.chntext = chntext

    def chntext2percentage(self):
        """Return the stored Chinese percentage as "<number>%"."""
        prefix = "百分之"
        text = self.chntext.strip()
        # Remove the prefix as a whole. str.strip(prefix) treats it as a set
        # of characters and also eats trailing 百/分/之, so "百分之一百"
        # would lose its final 百.
        if text.startswith(prefix):
            text = text[len(prefix) :]
        return chn2num(text) + "%"

    def percentage2chntext(self):
        """Return the stored "<number>%" string as Chinese text."""
        return "百分之" + num2chn(self.percentage.strip().strip("%"))
def remove_erhua(text, er_whitelist):
    """
    Remove the erhua suffix "儿" from the text while keeping words from the
    whitelist intact:
    他女儿在那边儿 -> 他女儿在那边

    ``er_whitelist`` is a regular expression matching the words in which
    "儿" must be preserved.
    """
    keep_pattern = re.compile(er_whitelist)
    kept_parts = []
    while True:
        er_match = re.search("儿", text)
        if not er_match:
            break
        er_start, er_end = er_match.span()

        protected = keep_pattern.search(text)
        if protected and protected.span()[0] <= er_start:
            # The first "儿" belongs to a whitelisted word: keep everything
            # up to the end of that word.
            cut = protected.span()[1]
            kept_parts.append(text[0:cut])
            text = text[cut:]
        else:
            # Plain erhua: keep the text before it and drop the character.
            kept_parts.append(text[0:er_start])
            text = text[er_end:]
    return "".join(kept_parts) + text
# ================================================================================ #
# NSW Normalizer
# ================================================================================ #
class NSWNormalizer:
    """
    Non-standard-word normalizer: rewrites dates, money, phone numbers,
    fractions, percentages and plain numbers in a sentence as Chinese text.
    """

    def __init__(self, raw_text):
        # Sentinels guarantee a non-digit character on both sides of the
        # text, which several patterns below require around a number.
        self.raw_text = "^" + raw_text + "$"
        self.norm_text = ""

    def _particular(self):
        """Turn '二' between two runs of Latin letters back into '2' (e.g. O2O, B2C)."""
        text = self.norm_text
        pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))")
        matchers = pattern.findall(text)
        if matchers:
            # print('particular')
            for matcher in matchers:
                text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1)
        self.norm_text = text
        return self.norm_text

    @staticmethod
    def _rewrite(text, pattern, convert):
        """
        For every match of ``pattern`` in the original ``text``, replace the
        first occurrence of capture group 1 with ``convert(group 1)``.
        """
        for match in re.finditer(pattern, text):
            target = match.group(1)
            text = text.replace(target, convert(target), 1)
        return text

    def normalize(self):
        """Run all rewriting rules in order and return the normalized text."""
        rewrite = self._rewrite
        text = self.raw_text

        # dates
        text = rewrite(
            text,
            r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)",
            lambda s: Date(date=s).date2chntext(),
        )

        # money
        text = rewrite(
            text,
            r"\D+((\d+(\.\d+)?)[多余几]?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)",
            lambda s: Money(money=s).money2chntext(),
        )

        # mobile phone numbers
        # http://www.jihaoba.com/news/show/13680
        # China Mobile:  139, 138, 137, 136, 135, 134, 159, 158, 157, 150, 151, 152,
        #                188, 187, 182, 183, 184, 178, 198
        # China Unicom:  130, 131, 132, 156, 155, 186, 185, 176
        # China Telecom: 133, 153, 189, 180, 181, 177
        text = rewrite(
            text,
            r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D",
            lambda s: TelePhone(telephone=s).telephone2chntext(),
        )

        # fixed-line phone numbers
        text = rewrite(
            text,
            r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D",
            lambda s: TelePhone(telephone=s).telephone2chntext(fixed=True),
        )

        # fractions
        text = rewrite(text, r"(\d+/\d+)", lambda s: Fraction(fraction=s).fraction2chntext())

        # percentages: fold the fullwidth percent sign (U+FF05) into ASCII
        # first. Written as an escape so that it cannot silently degrade
        # into a no-op replace("%", "%").
        text = text.replace("\uff05", "%")
        text = rewrite(
            text, r"(\d+(\.\d+)?%)", lambda s: Percentage(percentage=s).percentage2chntext()
        )

        # plain number followed by a quantifier
        text = rewrite(
            text,
            r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS,
            lambda s: Cardinal(cardinal=s).cardinal2chntext(),
        )

        # numeric identifiers (4 to 32 digits): read digit by digit
        text = rewrite(text, r"(\d{4,32})", lambda s: Digit(digit=s).digit2chntext())

        # remaining plain numbers
        text = rewrite(text, r"(\d+(\.\d+)?)", lambda s: Cardinal(cardinal=s).cardinal2chntext())

        self.norm_text = text
        self._particular()

        # Remove exactly the two sentinels added in __init__. lstrip("^") /
        # rstrip("$") would also eat "^" / "$" characters that belong to the
        # input text itself.
        return self.norm_text[1:-1]
def nsw_test_case(raw_text):
    """Print one input sentence and its normalized form, followed by a blank line."""
    print("I:" + raw_text)
    normalized = NSWNormalizer(raw_text).normalize()
    print("O:" + normalized)
    print("")
def nsw_test():
    """Run the normalizer over a fixed list of sample sentences and print the results."""
    samples = (
        "固话:0595-23865596或23880880。",
        "固话:0595-23865596或23880880。",
        "手机:+86 19859213959或15659451527。",
        "分数:32477/76391。",
        "百分数:80.03%。",
        "编号:31520181154418。",
        "纯数:2983.07克或12345.60米。",
        "日期:1999年2月20日或09年3月15号。",
        "金钱:12块5,34.5元,20.1万",
        "特殊:O2O或B2C。",
        "3456万吨",
        "2938个",
        "938",
        "今天吃了115个小笼包231个馒头",
        "有62%的概率",
    )
    for sample in samples:
        nsw_test_case(sample)
def str2bool(value):
    """
    argparse ``type=`` converter for boolean options.

    ``type=bool`` turns every non-empty string (including "False") into
    True; this accepts the usual spellings instead. The empty string is
    still treated as False. Raises argparse.ArgumentTypeError otherwise.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


if __name__ == "__main__":
    # nsw_test()

    p = argparse.ArgumentParser()
    p.add_argument("ifile", help="input filename, assume utf-8 encoding")
    p.add_argument("ofile", help="output filename")
    p.add_argument("--to_upper", action="store_true", help="convert to upper case")
    p.add_argument("--to_lower", action="store_true", help="convert to lower case")
    p.add_argument(
        "--has_key", action="store_true", help="input text has Kaldi's key as first field."
    )
    p.add_argument(
        "--remove_fillers", type=str2bool, default=True, help='remove filler chars such as "呃, 啊"'
    )
    p.add_argument(
        "--remove_erhua", type=str2bool, default=True, help='remove erhua chars such as "这儿"'
    )
    p.add_argument(
        "--log_interval", type=int, default=10000, help="log interval in number of processed lines"
    )
    args = p.parse_args()

    # The two case options are mutually exclusive; check once, up front.
    if args.to_upper and args.to_lower:
        sys.stderr.write("text norm: to_upper OR to_lower?\n")
        sys.exit(1)

    # Every CN and EN punctuation mark is replaced by a space; the table does
    # not depend on the line, so build it once.
    old_chars = CHINESE_PUNC_LIST + string.punctuation
    punc_table = str.maketrans(old_chars, " " * len(old_chars))

    n = 0
    # "with" closes both files even when normalization raises.
    with codecs.open(args.ifile, "r", "utf8") as ifile, codecs.open(
        args.ofile, "w+", "utf8"
    ) as ofile:
        for l in ifile:
            key = ""
            text = ""
            if args.has_key:
                cols = l.split(maxsplit=1)
                if not cols:
                    # A blank line has no key to carry over; skip it.
                    sys.stderr.write("text norm: skipping line without a key.\n")
                    continue
                key = cols[0]
                if len(cols) == 2:
                    text = cols[1].strip()
            else:
                text = l.strip()

            # cases
            if args.to_upper:
                text = text.upper()
            if args.to_lower:
                text = text.lower()

            # Filler chars removal
            if args.remove_fillers:
                for ch in FILLER_CHARS:
                    text = text.replace(ch, "")

            if args.remove_erhua:
                text = remove_erhua(text, ER_WHITELIST)

            # NSW(Non-Standard-Word) normalization
            text = NSWNormalizer(text).normalize()

            # Punctuations removal
            text = text.translate(punc_table)

            if args.has_key:
                ofile.write(key + "\t" + text + "\n")
            else:
                ofile.write(text + "\n")

            n += 1
            if n % args.log_interval == 0:
                sys.stderr.write("text norm: {} lines done.\n".format(n))

    sys.stderr.write("text norm: {} lines done in total.\n".format(n))
# Transformer Result
## Training Config
- Feature info: 80-dim fbank features, global CMVN, speed perturbation (0.9, 1.0, 1.1), SpecAugment
- Train info: lr 5e-4, batch_size 25000, 2 GPUs (Tesla V100), acc_grad 1, 50 epochs
- Train config: conf/train_asr_transformer.yaml
- LM config: LM was not used
- Model size: 46M
## Results (CER)
| testset | CER(%) |
|:-----------:|:------:|
| dev | 4.97 |
| test | 5.37 |
\ No newline at end of file
# This is an example that demonstrates how to configure a model file.
# You can modify the configuration according to your own requirements.
# to print the register_table:
# from funasr.register import tables
# tables.print()
# network architecture
model: Transformer
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
# encoder
encoder: TransformerEncoder
encoder_conf:
output_size: 256 # dimension of attention
attention_heads: 4
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: conv2d # encoder architecture type
normalize_before: true
# decoder
decoder: TransformerDecoder
decoder_conf:
attention_heads: 4
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
# frontend related
frontend: WavFrontend
frontend_conf:
fs: 16000
window: hamming
n_mels: 80
frame_length: 25
frame_shift: 10
lfr_m: 1
lfr_n: 1
specaug: SpecAug
specaug_conf:
apply_time_warp: true
time_warp_window: 5
time_warp_mode: bicubic
apply_freq_mask: true
freq_mask_width_range:
- 0
- 30
num_freq_mask: 2
apply_time_mask: true
time_mask_width_range:
- 0
- 40
num_time_mask: 2
train_conf:
accum_grad: 1
grad_clip: 5
max_epoch: 150
keep_nbest_models: 10
log_interval: 50
optim: adam
optim_conf:
lr: 0.002
scheduler: warmuplr
scheduler_conf:
warmup_steps: 30000
dataset: AudioDataset
dataset_conf:
index_ds: IndexDSJsonl
batch_sampler: EspnetStyleBatchSampler
batch_type: length # example or length
batch_size: 25000 # if batch_type is example, batch_size is the number of samples; if length, batch_size is source_token_len + target_token_len
max_token_length: 2048 # samples with source_token_len + target_token_len > max_token_length are filtered out
buffer_size: 1024
shuffle: True
num_workers: 4
preprocessor_speech: SpeechPreprocessSpeedPerturb
preprocessor_speech_conf:
speed_perturb: [0.9, 1.0, 1.1]
tokenizer: CharTokenizer
tokenizer_conf:
unk_symbol: <unk>
ctc_conf:
dropout_rate: 0.0
ctc_type: builtin
reduce: true
ignore_nan_grad: true
normalize: null
../paraformer/demo_infer.sh
\ No newline at end of file
../paraformer/demo_train_or_finetune.sh
\ No newline at end of file
#!/bin/bash
# Copyright 2017 Xingyu Na
# Apache 2.0
#
# Prepare the AISHELL train/dev/test file lists.
# For each split this writes wav.scp and text under
# <output-path>/data/{train,dev,test}; intermediate files are kept under
# <output-path>/data/local/{train,dev,test}.
# Requires utils/filter_scp.pl relative to the current directory.

#. ./path.sh || exit 1;

if [ $# != 3 ]; then
  echo "Usage: $0 <audio-path> <text-path> <output-path>"
  echo " $0 /export/a05/xna/data/data_aishell/wav /export/a05/xna/data/data_aishell/transcript data"
  exit 1;
fi

aishell_audio_dir=$1
aishell_text=$2/aishell_transcript_v0.8.txt
output_dir=$3

# data directory check (done before anything is created on disk)
if [ ! -d "$aishell_audio_dir" ] || [ ! -f "$aishell_text" ]; then
  echo "Error: $0 requires an existing audio directory ($aishell_audio_dir)" >&2
  echo "       and transcript file ($aishell_text)" >&2
  exit 1;
fi

train_dir=$output_dir/data/local/train
dev_dir=$output_dir/data/local/dev
test_dir=$output_dir/data/local/test
tmp_dir=$output_dir/data/local/tmp

mkdir -p "$train_dir" "$dev_dir" "$test_dir" "$tmp_dir" || exit 1

# find wav audio file for train, dev and test resp.
find "$aishell_audio_dir" -iname "*.wav" > "$tmp_dir/wav.flist"
n=$(wc -l < "$tmp_dir/wav.flist")
# A different count only produces a warning; processing continues.
[ "$n" -ne 141925 ] && \
  echo "Warning: expected 141925 data files, found $n"

grep -i "wav/train" "$tmp_dir/wav.flist" > "$train_dir/wav.flist" || exit 1;
grep -i "wav/dev" "$tmp_dir/wav.flist" > "$dev_dir/wav.flist" || exit 1;
grep -i "wav/test" "$tmp_dir/wav.flist" > "$test_dir/wav.flist" || exit 1;

rm -r "$tmp_dir"

# Transcriptions preparation
for dir in "$train_dir" "$dev_dir" "$test_dir"; do
  echo "Preparing $dir transcriptions"
  # utterance id = wav file name without directory and .wav suffix
  sed -e 's/\.wav//' "$dir/wav.flist" | awk -F '/' '{print $NF}' > "$dir/utt.list"
  paste -d' ' "$dir/utt.list" "$dir/wav.flist" > "$dir/wav.scp_all"
  # keep only transcripts that have audio, then only audio that has a transcript
  utils/filter_scp.pl -f 1 "$dir/utt.list" "$aishell_text" > "$dir/transcripts.txt"
  awk '{print $1}' "$dir/transcripts.txt" > "$dir/utt.list"
  utils/filter_scp.pl -f 1 "$dir/utt.list" "$dir/wav.scp_all" | sort -u > "$dir/wav.scp"
  sort -u "$dir/transcripts.txt" > "$dir/text"
done

mkdir -p "$output_dir/data/train" "$output_dir/data/dev" "$output_dir/data/test"

for f in wav.scp text; do
  cp "$train_dir/$f" "$output_dir/data/train/$f" || exit 1;
  cp "$dev_dir/$f" "$output_dir/data/dev/$f" || exit 1;
  cp "$test_dir/$f" "$output_dir/data/test/$f" || exit 1;
done

echo "$0: AISHELL data preparation succeeded"
exit 0;
#!/usr/bin/env bash

# Copyright   2014  Johns Hopkins University (author: Daniel Povey)
#             2017  Xingyu Na
# Apache 2.0
#
# Download <url-base>/<corpus-part>.tgz into <data-base> (unless an archive
# of a known size is already there), un-tar it and, for data_aishell, also
# extract the per-speaker wav archives. A .complete marker file is written
# once everything has been extracted, so that a re-run is a no-op.

remove_archive=false

if [ "$1" == --remove-archive ]; then
  remove_archive=true
  shift
fi

if [ $# -ne 3 ]; then
  echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
  echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell"
  echo "With --remove-archive it will remove the archive after successfully un-tarring it."
  echo "<corpus-part> can be one of: data_aishell, resource_aishell."
  # Stop here: continuing with missing arguments only produces confusing
  # follow-up errors.
  exit 1;
fi

data=$1
url=$2
part=$3

if [ ! -d "$data" ]; then
  echo "$0: no such directory $data"
  exit 1;
fi

# Work with an absolute path: the script changes directory below, after
# which a relative <data-base> would no longer point to the same place.
data=$(cd "$data" && pwd) || exit 1

part_ok=false
list="data_aishell resource_aishell"
for x in $list; do
  if [ "$part" == "$x" ]; then part_ok=true; fi
done
if ! $part_ok; then
  echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
  exit 1;
fi

if [ -z "$url" ]; then
  echo "$0: empty URL base."
  exit 1;
fi

if [ -f "$data/$part/.complete" ]; then
  echo "$0: data part $part was already successfully extracted, nothing to do."
  exit 0;
fi

# sizes of the archive files in bytes.
sizes="15582913665 1246920"

if [ -f "$data/$part.tgz" ]; then
  size=$(/bin/ls -l "$data/$part.tgz" | awk '{print $5}')
  size_ok=false
  for s in $sizes; do if [ "$s" == "$size" ]; then size_ok=true; fi; done
  if ! $size_ok; then
    echo "$0: removing existing file $data/$part.tgz because its size in bytes $size"
    echo "does not equal the size of one of the archives."
    rm "$data/$part.tgz"
  else
    echo "$data/$part.tgz exists and appears to be complete."
  fi
fi

if [ ! -f "$data/$part.tgz" ]; then
  if ! command -v wget >/dev/null; then
    echo "$0: wget is not installed."
    exit 1;
  fi
  full_url=$url/$part.tgz
  echo "$0: downloading data from $full_url. This may take some time, please be patient."

  cd "$data" || exit 1
  if ! wget --no-check-certificate "$full_url"; then
    echo "$0: error executing wget $full_url"
    exit 1;
  fi
fi

cd "$data" || exit 1

if ! tar -xvzf "$part.tgz"; then
  echo "$0: error un-tarring archive $data/$part.tgz"
  exit 1;
fi

if [ "$part" == "data_aishell" ]; then
  cd "$data/$part/wav" || exit 1
  for wav in ./*.tar.gz; do
    # With no matching file the glob stays literal; skip it.
    [ -e "$wav" ] || continue
    echo "Extracting wav from $wav"
    if ! tar -zxf "$wav"; then
      echo "$0: error extracting $wav"
      exit 1;
    fi
    rm "$wav"
  done
fi

# Mark the part as complete only after every archive has been extracted, so
# that an interrupted run is not mistaken for a finished one.
touch "$data/$part/.complete"

echo "$0: Successfully downloaded and un-tarred $data/$part.tgz"

if $remove_archive; then
  echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied."
  rm "$data/$part.tgz"
fi

exit 0;
#!/usr/bin/env bash
# AISHELL Transformer recipe: configuration section.
# The variables below are plain assignments placed before
# utils/parse_options.sh is sourced, which handles command-line options.

CUDA_VISIBLE_DEVICES="0,1"

# general configuration
feats_dir="../DATA" # feature output directory
exp_dir=$(pwd)
lang=zh
token_type=char
stage=0
stop_stage=5

# feature configuration
nj=32

inference_device="cuda" #"cpu"
inference_checkpoint="model.pt.avg10"
inference_scp="wav.scp"
inference_batch_size=1

# data
raw_data=../raw_data
data_url=www.openslr.org/resources/33

# exp tag
tag="exp1"
workspace=$(pwd)

master_port=12345

. utils/parse_options.sh || exit 1;

# Strict mode: exit on
# -e 'error', -u 'undefined variable', -o pipefail 'error in pipeline'.
set -e
set -u
set -o pipefail

train_set=train
valid_set=dev
test_sets="dev test"

config=transformer_12e_6d_2048_256.yaml
model_dir="baseline_$(basename "${config}" .yaml)_${lang}_${token_type}_${tag}"
# Stage -1: download and unpack the corpus audio and its resources.
if [ "${stage}" -le -1 ] && [ "${stop_stage}" -ge -1 ]; then
  echo "stage -1: Data Download"
  mkdir -p "${raw_data}"
  local/download_and_untar.sh "${raw_data}" "${data_url}" data_aishell
  local/download_and_untar.sh "${raw_data}" "${data_url}" resource_aishell
fi
# Stage 0: build wav.scp/text for every split, re-tokenize the transcripts
# and convert each split to a jsonl dataset file.
if [ "${stage}" -le 0 ] && [ "${stop_stage}" -ge 0 ]; then
  echo "stage 0: Data preparation"
  # Data preparation
  local/aishell_data_prep.sh "${raw_data}/data_aishell/wav" "${raw_data}/data_aishell/transcript" "${feats_dir}"
  for x in train dev test; do
    split_dir="${feats_dir}/data/${x}"
    cp "${split_dir}/text" "${split_dir}/text.org"
    # keep the key column, remove every space from the transcript itself
    paste -d " " <(cut -f 1 -d" " "${split_dir}/text.org") \
      <(cut -f 2- -d" " "${split_dir}/text.org" | tr -d " ") \
      > "${split_dir}/text"
    utils/text2token.py -n 1 -s 1 "${split_dir}/text" > "${split_dir}/text.org"
    mv "${split_dir}/text.org" "${split_dir}/text"

    # convert wav.scp text to jsonl
    scp_file_list_arg="++scp_file_list='[\"${split_dir}/wav.scp\",\"${split_dir}/text\"]'"
    python ../../../funasr/datasets/audio_datasets/scp2jsonl.py \
      ++data_type_list='["source", "target"]' \
      ++jsonl_file_out="${split_dir}/audio_datasets.jsonl" \
      "${scp_file_list_arg}"
  done
fi
# Stage 1: compute CMVN statistics over the training set.
if [ "${stage}" -le 1 ] && [ "${stop_stage}" -ge 1 ]; then
  echo "stage 1: Feature and CMVN Generation"
  # The last argument must NOT end with a line continuation: a trailing
  # backslash would join the closing "fi" onto this command and leave the
  # "if" unterminated.
  python ../../../funasr/bin/compute_audio_cmvn.py \
    --config-path "${workspace}/conf" \
    --config-name "${config}" \
    ++train_data_set_list="${feats_dir}/data/${train_set}/audio_datasets.jsonl" \
    ++cmvn_file="${feats_dir}/data/${train_set}/cmvn.json"
fi
token_list="${feats_dir}/data/${lang}_token_list/${token_type}/tokens.txt"
echo "dictionary: ${token_list}"
# Stage 2: write the token list: <blank>, <s>, </s>, then every distinct
# token of the training text in sorted order, then <unk>.
if [ "${stage}" -le 2 ] && [ "${stop_stage}" -ge 2 ]; then
  echo "stage 2: Dictionary Preparation"
  mkdir -p "${feats_dir}/data/${lang}_token_list/${token_type}/"

  echo "make a dictionary"
  echo "<blank>" > "${token_list}"
  echo "<s>" >> "${token_list}"
  echo "</s>" >> "${token_list}"
  # drop the key column, put one token per line, de-duplicate and remove
  # blank lines
  utils/text2token.py -s 1 -n 1 --space "" "${feats_dir}/data/${train_set}/text" \
    | cut -f 2- -d" " \
    | tr " " "\n" \
    | sort -u \
    | grep -a -v -e '^\s*$' >> "${token_list}"
  echo "<unk>" >> "${token_list}"
fi
# LM Training Stage
# Placeholder: this stage only prints its banner; no LM training command is run.
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "stage 3: LM Training"
fi
# ASR Training Stage
# Launches one training process per GPU listed in CUDA_VISIBLE_DEVICES and
# sends all training output to a time-stamped log file in the model directory.
if [ "${stage}" -le 4 ] && [ "${stop_stage}" -ge 4 ]; then
  echo "stage 4: ASR Training"

  mkdir -p "${exp_dir}/exp/${model_dir}"
  current_time=$(date "+%Y-%m-%d_%H-%M")
  log_file="${exp_dir}/exp/${model_dir}/train.log.txt.${current_time}"
  echo "log_file: ${log_file}"

  export CUDA_VISIBLE_DEVICES
  # number of GPUs = number of comma-separated ids
  gpu_num=$(echo "${CUDA_VISIBLE_DEVICES}" | awk -F "," '{print NF}')
  torchrun \
    --nnodes 1 \
    --nproc_per_node "${gpu_num}" \
    --master_port "${master_port}" \
    ../../../funasr/bin/train.py \
    --config-path "${workspace}/conf" \
    --config-name "${config}" \
    ++train_data_set_list="${feats_dir}/data/${train_set}/audio_datasets.jsonl" \
    ++valid_data_set_list="${feats_dir}/data/${valid_set}/audio_datasets.jsonl" \
    ++tokenizer_conf.token_list="${token_list}" \
    ++frontend_conf.cmvn_file="${feats_dir}/data/${train_set}/am.mvn" \
    ++output_dir="${exp_dir}/exp/${model_dir}" &> "${log_file}"
fi
# Testing Stage
# For every set in ${test_sets}: split the key file into ${nj} shards, run
# funasr/bin/inference.py on each shard in a background job, merge the
# per-job 1best_recog files, then post-process and score the result.
if [ "${stage}" -le 5 ] && [ "${stop_stage}" -ge 5 ]; then
  echo "stage 5: Inference"

  if [ "${inference_device}" = "cuda" ]; then
    # One job per comma-separated entry of CUDA_VISIBLE_DEVICES.
    nj=$(echo "${CUDA_VISIBLE_DEVICES}" | awk -F "," '{print NF}')
  else
    # Non-CUDA run: keep the configured ${nj}, force batch size 1 and give
    # every job the device id -1.
    inference_batch_size=1
    CUDA_VISIBLE_DEVICES=""
    for JOB in $(seq "${nj}"); do
      CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES}-1,"
    done
  fi

  # ${test_sets} is a whitespace-separated list; word splitting is intended.
  for dset in ${test_sets}; do
    inference_dir="${exp_dir}/exp/${model_dir}/inference-${inference_checkpoint}/${dset}"
    _logdir="${inference_dir}/logdir"
    echo "inference_dir: ${inference_dir}"

    mkdir -p "${_logdir}"
    data_dir="${feats_dir}/data/${dset}"
    key_file="${data_dir}/${inference_scp}"

    # One shard file per job, kept in an array so paths survive intact.
    split_scps=()
    for JOB in $(seq "${nj}"); do
      split_scps+=("${_logdir}/keys.${JOB}.scp")
    done
    utils/split_scp.pl "${key_file}" "${split_scps[@]}"

    # Intentional word splitting: turn the comma-separated id list into an
    # array indexed by job number.
    # shellcheck disable=SC2206
    gpuid_list_array=(${CUDA_VISIBLE_DEVICES//,/ })

    for JOB in $(seq "${nj}"); do
      {
        id=$((JOB-1))
        gpuid=${gpuid_list_array[$id]}
        # Runs in a background subshell, so the parent's device list is
        # left untouched for the next test set.
        export CUDA_VISIBLE_DEVICES="${gpuid}"

        python ../../../funasr/bin/inference.py \
          --config-path="${exp_dir}/exp/${model_dir}" \
          --config-name="config.yaml" \
          ++init_param="${exp_dir}/exp/${model_dir}/${inference_checkpoint}" \
          ++tokenizer_conf.token_list="${token_list}" \
          ++frontend_conf.cmvn_file="${feats_dir}/data/${train_set}/am.mvn" \
          ++input="${_logdir}/keys.${JOB}.scp" \
          ++output_dir="${inference_dir}/${JOB}" \
          ++device="${inference_device}" \
          ++ncpu=1 \
          ++disable_log=true \
          ++batch_size="${inference_batch_size}" &> "${_logdir}/log.${JOB}.txt"
      } &
    done
    # Barrier: all decoding jobs must finish before their outputs are merged.
    wait

    mkdir -p "${inference_dir}/1best_recog"
    for f in token score text; do
      # Probe job 1's output explicitly instead of depending on whatever
      # value ${JOB} still holds from the loop above.
      if [ -f "${inference_dir}/1/1best_recog/${f}" ]; then
        for JOB in $(seq "${nj}"); do
          cat "${inference_dir}/${JOB}/1best_recog/${f}"
        done | sort -k1 > "${inference_dir}/1best_recog/${f}"
      fi
    done

    echo "Computing WER ..."
    python utils/postprocess_text_zh.py "${inference_dir}/1best_recog/text" "${inference_dir}/1best_recog/text.proc"
    python utils/postprocess_text_zh.py "${data_dir}/text" "${inference_dir}/1best_recog/text.ref"
    python utils/compute_wer.py "${inference_dir}/1best_recog/text.ref" "${inference_dir}/1best_recog/text.proc" "${inference_dir}/1best_recog/text.cer"
    tail -n 3 "${inference_dir}/1best_recog/text.cer"
  done
fi
\ No newline at end of file
../paraformer/utils
\ No newline at end of file
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
from funasr import AutoModel
# Load the model once; its local path is used below to locate the examples.
model = AutoModel(model="iic/speech_whisper-large_lid_multilingual_pytorch")

# File names looked up under the model's `examples/` directory.
multilingual_wavs = [
    "example_zh-CN.mp3",
    "example_en.mp3",
    "example_ja.mp3",
    "example_ko.mp3",
]


def _detect(wav_id):
    """Run the model on one example clip and return the raw result."""
    wav_file = f"{model.model_path}/examples/{wav_id}"
    return model.generate(input=wav_file, data_type="sound", inference_clip_length=250)


for wav_id in multilingual_wavs:
    print("detect sample {}: {}".format(wav_id, _detect(wav_id)))
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
# Build the ModelScope pipeline once and reuse it for every clip.
inference_pipeline = pipeline(
    task=Tasks.auto_speech_recognition,
    model="iic/speech_whisper-large_lid_multilingual_pytorch",
)

# Remote example clips, fetched by URL.
multilingual_wavs = [
    "https://www.modelscope.cn/api/v1/models/iic/speech_whisper-large_lid_multilingual_pytorch/repo?Revision=master&FilePath=examples/example_zh-CN.mp3",
    "https://www.modelscope.cn/api/v1/models/iic/speech_whisper-large_lid_multilingual_pytorch/repo?Revision=master&FilePath=examples/example_en.mp3",
    "https://www.modelscope.cn/api/v1/models/iic/speech_whisper-large_lid_multilingual_pytorch/repo?Revision=master&FilePath=examples/example_ja.mp3",
    "https://www.modelscope.cn/api/v1/models/iic/speech_whisper-large_lid_multilingual_pytorch/repo?Revision=master&FilePath=examples/example_ko.mp3",
]

# The generator is lazy, so each clip is processed and printed before the
# next one is requested.
results = (
    inference_pipeline(input=wav, inference_clip_length=250) for wav in multilingual_wavs
)
for rec_result in results:
    print(rec_result)
{
"train_micro_batch_size_per_gpu": 1,
"gradient_accumulation_steps": 1,
"steps_per_print": 100,
"gradient_clipping": 5,
"fp16": {
"enabled": false,
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 16,
"loss_scale_window": 1000,
"hysteresis": 2,
"consecutive_hysteresis": false,
"min_loss_scale": 1
},
"bf16": {
"enabled": true
},
"zero_force_ds_cpu_optimizer": false,
"zero_optimization": {
"stage": 0,
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"contiguous_gradients": true,
"round_robin_gradients": true
}
}
{
"train_micro_batch_size_per_gpu": 1,
"gradient_accumulation_steps": 1,
"steps_per_print": 100,
"gradient_clipping": 5,
"fp16": {
"enabled": false,
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 16,
"loss_scale_window": 1000,
"hysteresis": 2,
"consecutive_hysteresis": false,
"min_loss_scale": 1
},
"bf16": {
"enabled": true
},
"zero_force_ds_cpu_optimizer": false,
"zero_optimization": {
"stage": 1,
"offload_optimizer": {
"device": "none",
"pin_memory": true
},
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"contiguous_gradients" : true
}
}
{
"train_micro_batch_size_per_gpu": 1,
"gradient_accumulation_steps": 1,
"steps_per_print": 100,
"gradient_clipping": 5,
"fp16": {
"enabled": false,
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 16,
"loss_scale_window": 1000,
"hysteresis": 2,
"consecutive_hysteresis": false,
"min_loss_scale": 1
},
"bf16": {
"enabled": true
},
"zero_force_ds_cpu_optimizer": false,
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "none",
"pin_memory": true
},
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"overlap_comm": false,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"contiguous_gradients" : true
}
}
{
"train_micro_batch_size_per_gpu": 1,
"gradient_accumulation_steps": 1,
"steps_per_print": 100,
"gradient_clipping": 5,
"fp16": {
"enabled": false,
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 16,
"loss_scale_window": 1000,
"hysteresis": 2,
"consecutive_hysteresis": false,
"min_loss_scale": 1
},
"bf16": {
"enabled": true
},
"zero_force_ds_cpu_optimizer": false,
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "none",
"pin_memory": true
},
"offload_param": {
"device": "none",
"pin_memory": true
},
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"contiguous_gradients" : true,
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_prefetch_bucket_size": 5e8,
"stage3_param_persistence_threshold": 1e5
}
}
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
from funasr import AutoModel
# Model ids handed to AutoModel: the main model plus its VAD and punctuation
# companions. The speaker model is left disabled.
_model_ids = dict(
    model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
    vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    punc_model="iic/punc_ct-transformer_cn-en-common-vocab471067-large",
    # spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
)
model = AutoModel(**_model_ids)

# Batching options forwarded unchanged to generate().
_batching = {"batch_size_s": 300, "batch_size_threshold_s": 60}

res = model.generate(
    input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_vad_punc_example.wav",
    **_batching,
)
print(res)
# Model ids passed to the inference entry point below.
spk_model="iic/speech_campplus_sv_zh-cn_16k-common"
#punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
punc_model="iic/punc_ct-transformer_cn-en-common-vocab471067-large"
vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch"
model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch"

# Run inference on the example recording on CPU; results are written under
# ./outputs/debug.
python funasr/bin/inference.py \
  +model="${model}" \
  +vad_model="${vad_model}" \
  +punc_model="${punc_model}" \
  +spk_model="${spk_model}" \
  +input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_vad_punc_example.wav" \
  +output_dir="./outputs/debug" \
  +device="cpu" \
  +batch_size_s=300 \
  +batch_size_threshold_s=60
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
# method1, export a model from the model hub
from funasr import AutoModel
# Options forwarded to model.export(): TorchScript output, no quantization.
_export_options = {"type": "torchscript", "quantize": False}

# Instantiate the model on CPU before exporting it.
model = AutoModel(
    device="cpu",
    model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
)

res = model.export(**_export_options)
print(res)
# # method2, export a model from a local path
# from funasr import AutoModel
# model = AutoModel(
# model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
# device="cpu",
# )
# res = model.export(type="onnx", quantize=False)
# print(res)
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
# Ask Hydra for its full error output (HYDRA_FULL_ERROR) if a run fails.
export HYDRA_FULL_ERROR=1

# Export one model to ONNX on CPU without quantization.
# Arguments: $1 - model hub id or local model directory
export_onnx() {
  python -m funasr.bin.export \
    ++model="$1" \
    ++type="onnx" \
    ++quantize=false \
    ++device="cpu"
}

# method1, export a model from the model hub
model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
export_onnx "${model}"

# method2, export a model from a local path
model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
export_onnx "${model}"
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment