Commit b75857fb authored by chenzk's avatar chenzk
Browse files

v1.0

parents
# -*- coding: utf-8 -*-
"""Basic constants.

Chinese digit / numbering-unit / symbol character tables used by the
number-conversion utilities in this package.
"""
__author__ = "Zhiyang Zhou <zyzhou@stu.xmu.edu.cn>"
__data__ = "2019-05-02"  # NOTE(review): likely a typo for __date__
# Ordinary (lowercase) digits 0-9.
CHINESE_DIGIS = "零一二三四五六七八九"
# Uppercase (anti-fraud / banker's) digits 0-9, simplified and traditional.
BIG_CHINESE_DIGIS_SIMPLIFIED = "零壹贰叁肆伍陆柒捌玖"
BIG_CHINESE_DIGIS_TRADITIONAL = "零壹貳參肆伍陸柒捌玖"
# Units 10, 100, 1000, 10000 — plain forms vs uppercase forms.
SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = "十百千万"
SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = "拾佰仟萬"
# Units of 亿 (10^8) and above; how they scale depends on NUMBERING_TYPES.
LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "亿兆京垓秭穰沟涧正载"
LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = "億兆京垓秭穰溝澗正載"
# Units 10..10^4 used when composing/parsing ordinary numbers.
SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "十百千万"
SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = "拾佰仟萬"
# Alternate character forms: 〇 for zero, 幺 for one, 两/兩 for two.
ZERO_ALT = "〇"
ONE_ALT = "幺"
TWO_ALTS = ["两", "兩"]  # [simplified, traditional]
# [simplified, traditional] pairs for sign markers and the decimal point.
POSITIVE = ["正", "正"]
NEGATIVE = ["负", "負"]
POINT = ["点", "點"]
# PLUS = [u'加', u'加']
# SIL = [u'杠', u'槓']
# Chinese numbering system types (scaling scheme for units >= 兆).
NUMBERING_TYPES = ["low", "mid", "high"]
# -*- coding: utf-8 -*-
"""基本方法
创建中文数字系统 方法
中文字符串 <=> 数字串 方法
数字串 <=> 中文字符串 方法
"""
__author__ = "Zhiyang Zhou <zyzhou@stu.xmu.edu.cn>"
__data__ = "2019-05-02"
from fish_speech.text.chn_text_norm.basic_class import *
from fish_speech.text.chn_text_norm.basic_constant import *
def create_system(numbering_type=NUMBERING_TYPES[1]):
    """Instantiate a Chinese number system for the given numbering type.

    NUMBERING_TYPES = ['low', 'mid', 'high']:
        low:  '兆' = '亿' * '十' = 10^9,  '京' = '兆' * '十', etc.
        mid:  '兆' = '亿' * '万' = 10^12, '京' = '兆' * '万', etc.
        high: '兆' = '亿' * '亿' = 10^16, '京' = '兆' * '兆', etc.

    Returns the populated NumberSystem (defaults to 'mid').
    """
    # Units of '亿' and larger; their scaling depends on numbering_type.
    big_units = []
    for power, pair in enumerate(
        zip(
            LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED,
            LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL,
        )
    ):
        big_units.append(CNU.create(power, pair, numbering_type, False))
    # Units '十, 百, 千, 万'.
    small_units = []
    for power, pair in enumerate(
        zip(
            SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED,
            SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL,
        )
    ):
        small_units.append(CNU.create(power, pair, small_unit=True))
    # Digits 0-9 in lowercase/uppercase, simplified/traditional forms.
    digits = []
    for value, forms in enumerate(
        zip(
            CHINESE_DIGIS,
            CHINESE_DIGIS,
            BIG_CHINESE_DIGIS_SIMPLIFIED,
            BIG_CHINESE_DIGIS_TRADITIONAL,
        )
    ):
        digits.append(CND.create(value, forms))
    # Alternate written forms: 〇 for 0, 幺 for 1, 两/兩 for 2.
    digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT
    digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT
    digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1]
    # Math symbols: signs and the decimal point.
    plus_sign = CM(POSITIVE[0], POSITIVE[1], "+", lambda x: x)
    minus_sign = CM(NEGATIVE[0], NEGATIVE[1], "-", lambda x: -x)
    dot_sign = CM(POINT[0], POINT[1], ".", lambda x, y: float(str(x) + "." + str(y)))
    number_system = NumberSystem()
    number_system.units = small_units + big_units
    number_system.digits = digits
    number_system.math = MathSymbol(plus_sign, minus_sign, dot_sign)
    return number_system
def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):
    """Convert a Chinese numeral string into an Arabic digit string.

    Args:
        chinese_string: Chinese numerals, optionally with a decimal part
            introduced by 点/點.
        numbering_type: one of NUMBERING_TYPES; controls how units of 兆
            and above scale (see create_system).

    Returns:
        str: the numeric value, e.g. '10403.805'.
    """

    def get_symbol(char, system):
        # Map a single character to its CNU (unit), CND (digit) or CM
        # (math symbol) object; implicitly returns None for unknown chars.
        for u in system.units:
            if char in [u.traditional, u.simplified, u.big_s, u.big_t]:
                return u
        for d in system.digits:
            if char in [
                d.traditional,
                d.simplified,
                d.big_s,
                d.big_t,
                d.alt_s,
                d.alt_t,
            ]:
                return d
        for m in system.math:
            if char in [m.traditional, m.simplified]:
                return m

    def string2symbols(chinese_string, system):
        # Split integer/decimal halves on the decimal-point character
        # (simplified or traditional) and map every char to a symbol.
        int_string, dec_string = chinese_string, ""
        for p in [system.math.point.simplified, system.math.point.traditional]:
            if p in chinese_string:
                int_string, dec_string = chinese_string.split(p)
                break
        return [get_symbol(c, system) for c in int_string], [
            get_symbol(c, system) for c in dec_string
        ]

    def correct_symbols(integer_symbols, system):
        """Normalize elided and compound unit forms, e.g.
        一百八 to 一百八十
        一亿一千三百万 to 一亿 一千万 三百万
        """
        # A leading power-1 unit implies a leading 一 (十五 -> 一十五).
        if integer_symbols and isinstance(integer_symbols[0], CNU):
            if integer_symbols[0].power == 1:
                integer_symbols = [system.digits[1]] + integer_symbols
        # A trailing digit right after a unit inherits the next lower
        # power (一百八 -> 一百八十).
        if len(integer_symbols) > 1:
            if isinstance(integer_symbols[-1], CND) and isinstance(
                integer_symbols[-2], CNU
            ):
                integer_symbols.append(
                    CNU(integer_symbols[-2].power - 1, None, None, None, None)
                )
        result = []
        unit_count = 0
        for s in integer_symbols:
            if isinstance(s, CND):
                result.append(s)
                unit_count = 0
            elif isinstance(s, CNU):
                current_unit = CNU(s.power, None, None, None, None)
                unit_count += 1
                if unit_count == 1:
                    result.append(current_unit)
                elif unit_count > 1:
                    # Consecutive units multiply (千万 = 10^3 * 10^4):
                    # fold this unit's power into every smaller unit
                    # already emitted.
                    for i in range(len(result)):
                        if (
                            isinstance(result[-i - 1], CNU)
                            and result[-i - 1].power < current_unit.power
                        ):
                            result[-i - 1] = CNU(
                                result[-i - 1].power + current_unit.power,
                                None,
                                None,
                                None,
                                None,
                            )
        return result

    def compute_value(integer_symbols):
        """
        Compute the value.
        When current unit is larger than previous unit, current unit * all previous units will be used as all previous units.
        e.g. '两千万' = 2000 * 10000 not 2000 + 10000
        """
        value = [0]
        last_power = 0
        for s in integer_symbols:
            if isinstance(s, CND):
                value[-1] = s.value
            elif isinstance(s, CNU):
                value[-1] *= pow(10, s.power)
                if s.power > last_power:
                    # A larger unit scales everything accumulated so far.
                    value[:-1] = list(map(lambda v: v * pow(10, s.power), value[:-1]))
                last_power = s.power
                value.append(0)
        return sum(value)

    system = create_system(numbering_type)
    int_part, dec_part = string2symbols(chinese_string, system)
    int_part = correct_symbols(int_part, system)
    int_str = str(compute_value(int_part))
    # Decimal digits are read positionally, one digit per character.
    dec_str = "".join([str(d.value) for d in dec_part])
    if dec_part:
        return "{0}.{1}".format(int_str, dec_str)
    else:
        return int_str
def num2chn(
    number_string,
    numbering_type=NUMBERING_TYPES[1],
    big=False,
    traditional=False,
    alt_zero=False,
    alt_one=False,
    alt_two=True,
    use_zeros=True,
    use_units=True,
):
    """Convert a digit string into Chinese numerals.

    Args:
        number_string: digits, optionally with a single decimal point.
        numbering_type: scaling scheme for units 兆 and above.
        big: use uppercase (anti-fraud) digit forms.
        traditional: emit traditional instead of simplified characters.
        alt_zero / alt_one: use 〇 / 幺 instead of 零 / 一.
        alt_two: use 两/兩 for 2 where idiomatic.
        use_zeros: keep an explicit 零 for skipped positions.
        use_units: emit positional units (十百千万…); when False the
            digits are read out one by one.

    Returns:
        str: the Chinese reading of number_string.

    Raises:
        ValueError: if number_string contains more than one '.'.
    """

    def get_value(value_string, use_zeros=True):
        # Recursively turn an integer digit string into digit/unit symbols.
        striped_string = value_string.lstrip("0")
        # record nothing if all zeros
        if not striped_string:
            return []
        # record one digits
        elif len(striped_string) == 1:
            if use_zeros and len(value_string) != len(striped_string):
                return [system.digits[0], system.digits[int(striped_string)]]
            else:
                return [system.digits[int(striped_string)]]
        # recursively record multiple digits
        else:
            # Largest unit strictly smaller than the number's magnitude.
            result_unit = next(
                u for u in reversed(system.units) if u.power < len(striped_string)
            )
            result_string = value_string[: -result_unit.power]
            return (
                get_value(result_string)
                + [result_unit]
                + get_value(striped_string[-result_unit.power :])
            )

    system = create_system(numbering_type)
    int_dec = number_string.split(".")
    if len(int_dec) == 1:
        int_string = int_dec[0]
        dec_string = ""
    elif len(int_dec) == 2:
        int_string = int_dec[0]
        dec_string = int_dec[1]
    else:
        raise ValueError(
            "invalid input num string with more than one dot: {}".format(number_string)
        )
    if use_units and len(int_string) > 1:
        result_symbols = get_value(int_string)
    else:
        # Single digit, or digit-by-digit mode: no positional units.
        result_symbols = [system.digits[int(c)] for c in int_string]
    dec_symbols = [system.digits[int(c)] for c in dec_string]
    if dec_string:
        result_symbols += [system.math.point] + dec_symbols
    if alt_two:
        # Replace 二 with 两 when it directly precedes a unit of 百 or
        # larger and is not itself preceded by a 十 unit.
        liang = CND(
            2,
            system.digits[2].alt_s,
            system.digits[2].alt_t,
            system.digits[2].big_s,
            system.digits[2].big_t,
        )
        for i, v in enumerate(result_symbols):
            if isinstance(v, CND) and v.value == 2:
                next_symbol = (
                    result_symbols[i + 1] if i < len(result_symbols) - 1 else None
                )
                previous_symbol = result_symbols[i - 1] if i > 0 else None
                if isinstance(next_symbol, CNU) and isinstance(
                    previous_symbol, (CNU, type(None))
                ):
                    if next_symbol.power != 1 and (
                        (previous_symbol is None) or (previous_symbol.power != 1)
                    ):
                        result_symbols[i] = liang
    # if big is True, '两' will not be used and `alt_two` has no impact on output
    if big:
        attr_name = "big_"
        if traditional:
            attr_name += "t"
        else:
            attr_name += "s"
    else:
        if traditional:
            attr_name = "traditional"
        else:
            attr_name = "simplified"
    result = "".join([getattr(s, attr_name) for s in result_symbols])
    # if not use_zeros:
    # result = result.strip(getattr(system.digits[0], attr_name))
    if alt_zero:
        result = result.replace(
            getattr(system.digits[0], attr_name), system.digits[0].alt_s
        )
    if alt_one:
        result = result.replace(
            getattr(system.digits[1], attr_name), system.digits[1].alt_s
        )
    # Prefix 零 when the result starts with the decimal point (e.g. '.5').
    for i, p in enumerate(POINT):
        if result.startswith(p):
            return CHINESE_DIGIS[0] + result
    # ^10, 11, .., 19
    if (
        len(result) >= 2
        and result[1]
        in [
            SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0],
            SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0],
        ]
        and result[0]
        in [
            CHINESE_DIGIS[1],
            BIG_CHINESE_DIGIS_SIMPLIFIED[1],
            BIG_CHINESE_DIGIS_TRADITIONAL[1],
        ]
    ):
        # Drop the leading 一 of 一十X (一十五 -> 十五).
        result = result[1:]
    return result
if __name__ == "__main__":
    # Demo / smoke test of chn2num and num2chn.
    all_chinese_number_string = (
        CHINESE_DIGIS
        + BIG_CHINESE_DIGIS_SIMPLIFIED
        + BIG_CHINESE_DIGIS_TRADITIONAL
        + LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED
        + LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL
        + SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED
        + SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL
        + ZERO_ALT
        + ONE_ALT
        + "".join(TWO_ALTS + POSITIVE + NEGATIVE + POINT)
    )
    print("num:", chn2num("一万零四百零三点八零五"))
    print("num:", chn2num("一亿六点三"))
    print("num:", chn2num("一亿零六点三"))
    print("num:", chn2num("两千零一亿六点三"))
    # print('num:', chn2num('一零零八六'))
    print("txt:", num2chn("10260.03", alt_zero=True))
    print("txt:", num2chn("20037.090", numbering_type="low", traditional=True))
    print("txt:", num2chn("100860001.77", numbering_type="high", big=True))
    # Bug fix: the original passed use_lzeros=True / use_rzeros=True, which
    # are not parameters of num2chn and raised a TypeError at runtime.
    print(
        "txt:",
        num2chn(
            "059523810880",
            alt_one=True,
            alt_two=False,
            use_units=False,
        ),
    )
    print(all_chinese_number_string)
# -*- coding: utf-8 -*-
"""CARDINAL类 (包含小数DECIMAL类)
纯数 <=> 中文字符串 方法
中文字符串 <=> 纯数 方法
"""
__author__ = "Zhiyang Zhou <zyzhou@stu.xmu.edu.cn>"
__data__ = "2019-05-03"
from fish_speech.text.chn_text_norm.basic_util import *
class Cardinal:
    """
    CARDINAL class: plain (possibly decimal) numbers <-> Chinese text.
    """

    def __init__(self, cardinal=None, chntext=None):
        self.cardinal = cardinal  # digit string, e.g. "21357.230"
        self.chntext = chntext  # Chinese reading of the number

    def chntext2cardinal(self):
        """Chinese numeral text -> digit string."""
        return chn2num(self.chntext)

    def cardinal2chntext(self):
        """Digit string -> Chinese numeral text."""
        return num2chn(self.cardinal)


if __name__ == "__main__":
    # demo
    print(Cardinal(cardinal="21357.230").cardinal2chntext())
# -*- coding: utf-8 -*-
"""DATE类
日期 <=> 中文字符串 方法
中文字符串 <=> 日期 方法
"""
__author__ = "Zhiyang Zhou <zyzhou@stu.xmu.edu.cn>"
__data__ = "2019-05-07"
from fish_speech.text.chn_text_norm.cardinal import Cardinal
from fish_speech.text.chn_text_norm.digit import Digit
class Date:
    """
    DATE class: convert a date string like '1999年2月20日' into Chinese
    numeral text (year digits read one by one, month/day as cardinals).
    """

    def __init__(self, date=None, chntext=None):
        self.date = date
        self.chntext = chntext

    # def chntext2date(self):
    #     chntext = self.chntext
    #     try:
    #         year, other = chntext.strip().split('年', maxsplit=1)
    #         year = Digit(chntext=year).digit2chntext() + '年'
    #     except ValueError:
    #         other = chntext
    #         year = ''
    #     if other:
    #         try:
    #             month, day = other.strip().split('月', maxsplit=1)
    #             month = Cardinal(chntext=month).chntext2cardinal() + '月'
    #         except ValueError:
    #             day = chntext
    #             month = ''
    #         if day:
    #             day = Cardinal(chntext=day[:-1]).chntext2cardinal() + day[-1]
    #     else:
    #         month = ''
    #         day = ''
    #     date = year + month + day
    #     self.date = date
    #     return self.date

    def date2chntext(self):
        """Read self.date (e.g. '09年3月16日') out in Chinese."""
        date = self.date
        try:
            year, other = date.strip().split("年", maxsplit=1)
            # Year digits are read one by one (e.g. 09 -> 零九).
            year = Digit(digit=year).digit2chntext() + "年"
        except ValueError:
            # No '年' present: the whole string is month/day.
            other = date
            year = ""
        if other:
            try:
                month, day = other.strip().split("月", maxsplit=1)
                month = Cardinal(cardinal=month).cardinal2chntext() + "月"
            except ValueError:
                # NOTE(review): on a missing '月' this uses the full `date`
                # (not `other`) as the day part — looks like a latent bug;
                # confirm the intended behavior before changing.
                day = date
                month = ""
            if day:
                # Last char is the suffix 日/号; read the rest as a number.
                day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1]
        else:
            month = ""
            day = ""
        chntext = year + month + day
        self.chntext = chntext
        return self.chntext


if __name__ == "__main__":
    # demo
    print(Date(date="09年3月16日").date2chntext())
# -*- coding: utf-8 -*-
"""DIGIT类
数字串 <=> 中文字符串 方法
中文字符串 <=> 数字串 方法
"""
__author__ = "Zhiyang Zhou <zyzhou@stu.xmu.edu.cn>"
__data__ = "2019-05-03"
from fish_speech.text.chn_text_norm.basic_util import *
class Digit:
    """
    DIGIT class: digit strings read character by character (IDs, years).
    """

    def __init__(self, digit=None, chntext=None):
        self.digit = digit  # digit string, e.g. "2016"
        self.chntext = chntext  # Chinese reading

    def digit2chntext(self):
        """Read the digits one by one, without positional units or 两."""
        return num2chn(self.digit, alt_two=False, use_units=False)


if __name__ == "__main__":
    # demo
    print(Digit(digit="2016").digit2chntext())
# -*- coding: utf-8 -*-
"""FRACTION类
分数 <=> 中文字符串 方法
中文字符串 <=> 分数 方法
"""
__author__ = "Zhiyang Zhou <zyzhou@stu.xmu.edu.cn>"
__data__ = "2019-05-03"
from fish_speech.text.chn_text_norm.basic_util import *
class Fraction:
    """
    FRACTION class: 'num/den' strings <-> 'den分之num' Chinese text.
    """

    def __init__(self, fraction=None, chntext=None):
        self.fraction = fraction  # e.g. "2135/7230"
        self.chntext = chntext  # e.g. "七千二百三十分之二千一百三十五"

    def chntext2fraction(self):
        """'X分之Y' -> 'Y/X' with both sides converted to digits."""
        denominator, numerator = self.chntext.split("分之")
        return "{}/{}".format(chn2num(numerator), chn2num(denominator))

    def fraction2chntext(self):
        """'Y/X' -> 'X分之Y' with both sides converted to Chinese."""
        numerator, denominator = self.fraction.split("/")
        return "{}分之{}".format(num2chn(denominator), num2chn(numerator))


if __name__ == "__main__":
    # demo
    print(Fraction(fraction="2135/7230").fraction2chntext())
    print(Fraction(chntext="五百八十一分之三百六十九").chntext2fraction())
# -*- coding: utf-8 -*-
"""MONEY类
金钱 <=> 中文字符串 方法
中文字符串 <=> 金钱 方法
"""
import re
__author__ = "Zhiyang Zhou <zyzhou@stu.xmu.edu.cn>"
__data__ = "2019-05-08"
from fish_speech.text.chn_text_norm.cardinal import Cardinal
class Money:
    """
    MONEY class: read money expressions (e.g. '21.5万元') in Chinese.
    """

    def __init__(self, money=None, chntext=None):
        self.money = money  # raw expression, digits plus unit characters
        self.chntext = chntext  # Chinese reading

    def money2chntext(self):
        """Replace every numeric amount in self.money with its Chinese reading."""
        text = self.money
        amount_pattern = re.compile(r"(\d+(\.\d+)?)")
        for groups in amount_pattern.findall(text):
            amount = groups[0]
            text = text.replace(amount, Cardinal(cardinal=amount).cardinal2chntext())
        self.chntext = text
        return self.chntext


if __name__ == "__main__":
    # demo
    print(Money(money="21.5万元").money2chntext())
    print(Money(money="230块5毛").money2chntext())
# -*- coding: utf-8 -*-
"""PERCENTAGE类
百分数 <=> 中文字符串 方法
中文字符串 <=> 百分数 方法
"""
__author__ = "Zhiyang Zhou <zyzhou@stu.xmu.edu.cn>"
__data__ = "2019-05-06"
from fish_speech.text.chn_text_norm.basic_util import *
class Percentage:
    """
    PERCENTAGE class: percentage strings <-> '百分之…' Chinese text.
    """

    def __init__(self, percentage=None, chntext=None):
        self.percentage = percentage  # e.g. "65.3%"
        self.chntext = chntext  # e.g. "百分之六十五点三"

    def chntext2percentage(self):
        """'百分之X' -> 'X%'.

        Bug fix: the original used str.strip("百分之"), which strips any of
        the characters 百/分/之 from BOTH ends — e.g. '百分之一百' lost its
        trailing 百 and came out as 1% instead of 100%. removeprefix only
        drops the literal leading marker.
        """
        return chn2num(self.chntext.strip().removeprefix("百分之")) + "%"

    def percentage2chntext(self):
        """'X%' -> '百分之X'."""
        return "百分之" + num2chn(self.percentage.strip().strip("%"))


if __name__ == "__main__":
    # demo
    print(Percentage(chntext="百分之五十六点零三").chntext2percentage())
    print(Percentage(percentage="65.3%").percentage2chntext())
# -*- coding: utf-8 -*-
"""TELEPHONE类
电话号码 <=> 中文字符串 方法
中文字符串 <=> 电话号码 方法
"""
__author__ = "Zhiyang Zhou <zyzhou@stu.xmu.edu.cn>"
__data__ = "2019-05-03"
from fish_speech.text.chn_text_norm.basic_util import *
class TelePhone:
    """
    TELEPHONE class: read phone numbers digit by digit in Chinese.
    """

    def __init__(self, telephone=None, raw_chntext=None, chntext=None):
        self.telephone = telephone  # e.g. "0595-23980880"
        self.raw_chntext = raw_chntext  # reading with pause markers kept
        self.chntext = chntext  # reading with pause markers removed

    def telephone2chntext(self, fixed=False):
        """Convert self.telephone to Chinese.

        fixed=True: landline — split on '-' and join with <SIL> markers.
        fixed=False: mobile — drop a leading '+', split on whitespace and
        join with <SP> markers. Markers are stripped from the returned text
        but preserved in self.raw_chntext.
        """
        if fixed:
            parts = self.telephone.split("-")
            marker = "<SIL>"
        else:
            parts = self.telephone.strip("+").split()
            marker = "<SP>"
        self.raw_chntext = marker.join(
            num2chn(part, alt_two=False, use_units=False) for part in parts
        )
        self.chntext = self.raw_chntext.replace(marker, "")
        return self.chntext


if __name__ == "__main__":
    # demo
    print(TelePhone(telephone="0595-23980880").telephone2chntext())
# -*- coding: utf-8 -*-
"""
TEXT类
"""
__author__ = "Zhiyang Zhou <zyzhou@stu.xmu.edu.cn>"
__data__ = "2019-05-03"
import re
from fish_speech.text.chn_text_norm.cardinal import Cardinal
from fish_speech.text.chn_text_norm.date import Date
from fish_speech.text.chn_text_norm.digit import Digit
from fish_speech.text.chn_text_norm.fraction import Fraction
from fish_speech.text.chn_text_norm.money import Money
from fish_speech.text.chn_text_norm.percentage import Percentage
from fish_speech.text.chn_text_norm.telephone import TelePhone
# Currency names that may follow an amount.
CURRENCY_NAMES = (
    "(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|"
    "里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)"
)
# Magnitude and currency unit suffixes (亿/万/… optionally followed by 元/块, plus 角/毛/分).
CURRENCY_UNITS = "((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)"
# Common measure words (quantifiers) that may follow a plain number.
COM_QUANTIFIERS = (
    "(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|"
    "砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|"
    "针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|"
    "毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|"
    "盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|"
    "纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|人|抽)"
)
class Text:
    """
    Text normalizer: rewrites dates, money, phone numbers, fractions,
    percentages, quantified numbers, digit IDs and plain numbers inside
    raw_text into Chinese character form.
    """

    def __init__(self, raw_text, norm_text=None):
        # '^'/'$' sentinels give the \D-anchored patterns below a non-digit
        # neighbour at both ends; they are stripped again in normalize().
        self.raw_text = "^" + raw_text + "$"
        self.norm_text = norm_text

    def _particular(self):
        # Undo over-normalization of letter-2-letter tokens: 'O二O' -> 'O2O'.
        text = self.norm_text
        pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))")
        matchers = pattern.findall(text)
        if matchers:
            # print('particular')
            for matcher in matchers:
                text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1)
        self.norm_text = text
        return self.norm_text

    def normalize(self):
        """Apply all normalization passes in priority order and return the result."""
        text = self.raw_text
        # Normalize dates (e.g. 1999年2月20日, 09年3月15号).
        pattern = re.compile(
            r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)"
        )
        matchers = pattern.findall(text)
        if matchers:
            # print('date')
            for matcher in matchers:
                text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1)
        # Normalize money amounts.
        # NOTE(review): "(\d" is a non-raw string; '\d' has no string-escape
        # meaning so it survives intact today, but a raw string would be safer.
        pattern = re.compile(
            r"\D+((\d+(\.\d+)?)[多余几]?"
            + CURRENCY_UNITS
            + "(\d"
            + CURRENCY_UNITS
            + "?)?)"
        )
        matchers = pattern.findall(text)
        if matchers:
            # print('money')
            for matcher in matchers:
                text = text.replace(
                    matcher[0], Money(money=matcher[0]).money2chntext(), 1
                )
        # Normalize mobile phone numbers.
        # Prefix list source: http://www.jihaoba.com/news/show/13680
        # China Mobile: 139/138/137/136/135/134/159/158/157/150/151/152/188/187/182/183/184/178/198
        # China Unicom: 130/131/132/156/155/186/185/176
        # China Telecom: 133/153/189/180/181/177
        pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D")
        matchers = pattern.findall(text)
        if matchers:
            # print('telephone')
            for matcher in matchers:
                text = text.replace(
                    matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1
                )
        # Normalize landline numbers (read with <SIL> pauses at the dash).
        pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D")
        matchers = pattern.findall(text)
        if matchers:
            # print('fixed telephone')
            for matcher in matchers:
                text = text.replace(
                    matcher[0],
                    TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True),
                    1,
                )
        # Normalize fractions ('32477/76391' -> '…分之…').
        pattern = re.compile(r"(\d+/\d+)")
        matchers = pattern.findall(text)
        if matchers:
            # print('fraction')
            for matcher in matchers:
                text = text.replace(
                    matcher, Fraction(fraction=matcher).fraction2chntext(), 1
                )
        # Normalize percentages; first fold fullwidth '%' into ASCII '%'.
        text = text.replace("%", "%")
        pattern = re.compile(r"(\d+(\.\d+)?%)")
        matchers = pattern.findall(text)
        if matchers:
            # print('percentage')
            for matcher in matchers:
                text = text.replace(
                    matcher[0],
                    Percentage(percentage=matcher[0]).percentage2chntext(),
                    1,
                )
        # Normalize number + measure word (e.g. '12块', '2983.07克').
        pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS)
        matchers = pattern.findall(text)
        if matchers:
            # print('cardinal+quantifier')
            for matcher in matchers:
                text = text.replace(
                    matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1
                )
        # Normalize long digit IDs (4-32 digits, read one digit at a time).
        pattern = re.compile(r"(\d{4,32})")
        matchers = pattern.findall(text)
        if matchers:
            # print('digit')
            for matcher in matchers:
                text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1)
        # Normalize any remaining plain numbers.
        pattern = re.compile(r"(\d+(\.\d+)?)")
        matchers = pattern.findall(text)
        if matchers:
            # print('cardinal')
            for matcher in matchers:
                text = text.replace(
                    matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1
                )
        self.norm_text = text
        self._particular()
        # Strip the sentinels added in __init__.
        return self.norm_text.lstrip("^").rstrip("$")
if __name__ == "__main__":
    # Demo of the full normalization pipeline, one case per pass.
    print(Text(raw_text="固话:0595-23865596或23880880。").normalize())
    print(Text(raw_text="手机:+86 19859213959或15659451527。").normalize())
    print(Text(raw_text="分数:32477/76391。").normalize())
    print(Text(raw_text="百分数:80.03%。").normalize())
    print(Text(raw_text="编号:31520181154418。").normalize())
    print(Text(raw_text="纯数:2983.07克或12345.60米。").normalize())
    print(Text(raw_text="日期:1999年2月20日或09年3月15号。").normalize())
    print(Text(raw_text="金钱:12块5,34.5元,20.1万").normalize())
    print(Text(raw_text="特殊:O2O或B2C。").normalize())
import re
# Curly quotes -> ASCII equivalents.
SYMBOLS_MAPPING = {
    "‘": "'",
    "’": "'",
}

# Matches any key of SYMBOLS_MAPPING.
REPLACE_SYMBOL_REGEX = re.compile(
    "|".join(re.escape(p) for p in SYMBOLS_MAPPING.keys())
)

# Common emoji code-point blocks.
EMOJI_REGEX = re.compile(
    "["
    "\U0001F600-\U0001F64F"  # emoticons
    "\U0001F300-\U0001F5FF"  # symbols & pictographs
    "\U0001F680-\U0001F6FF"  # transport & map symbols
    "\U0001F1E0-\U0001F1FF"  # flags (iOS)
    "]+",
    flags=re.UNICODE,
)


def clean_text(text):
    """Trim, normalize curly quotes, drop emoji, and collapse comma runs."""
    cleaned = text.strip()
    cleaned = REPLACE_SYMBOL_REGEX.sub(lambda m: SYMBOLS_MAPPING[m.group()], cleaned)
    cleaned = EMOJI_REGEX.sub(r"", cleaned)
    # Collapse runs of ASCII/fullwidth commas to the run's first comma.
    # (The original comment also mentioned periods, but the pattern only
    # ever covered commas.)
    cleaned = re.sub(r"[,]{2,}", lambda m: m.group()[0], cleaned)
    return cleaned
import re
import string
from fish_speech.text.clean import clean_text
def utf_8_len(text: str):
    """Return the length of *text* in UTF-8 bytes (not characters)."""
    return len(bytes(text, "utf-8"))
def break_text(texts, length, splits: set):
    """Yield each text unchanged if it fits in `length` UTF-8 bytes;
    otherwise cut it after every character found in `splits` (the split
    character stays at the end of its piece)."""
    for text in texts:
        if len(text.encode("utf-8")) <= length:
            yield text
            continue
        buf = ""
        for ch in text:
            buf += ch
            if ch in splits:
                yield buf
                buf = ""
        # Trailing remainder without a split character.
        if buf:
            yield buf
def break_text_by_length(texts, length):
    """Yield each text unchanged if it fits in `length` UTF-8 bytes;
    otherwise cut it as soon as the accumulated piece reaches `length`."""
    for text in texts:
        if len(text.encode("utf-8")) <= length:
            yield text
            continue
        piece = ""
        for ch in text:
            piece += ch
            if len(piece.encode("utf-8")) >= length:
                yield piece
                piece = ""
        # Short trailing remainder.
        if piece:
            yield piece
def add_cleaned(curr, segments):
    """Strip `curr` and append it to `segments`, unless it is empty or
    consists solely of whitespace and punctuation."""
    stripped = curr.strip()
    has_content = any(
        not c.isspace() and c not in string.punctuation for c in stripped
    )
    if stripped and has_content:
        segments.append(stripped)
def protect_float(text):
    """Rewrite decimals like '3.14' as '<3_f_14>' so that splitting on '.'
    cannot break them apart; inverse of unprotect_float."""
    return re.sub(r"(\d+)\.(\d+)", lambda m: f"<{m.group(1)}_f_{m.group(2)}>", text)
def unprotect_float(text):
    """Restore '<3_f_14>' markers back to '3.14'; inverse of protect_float."""
    return re.sub(r"<(\d+)_f_(\d+)>", lambda m: m.group(1) + "." + m.group(2), text)
def split_text(text, length):
    """Split `text` into cleaned segments of at most `length` UTF-8 bytes.

    Splitting priority:
      1. sentence enders ('.', '!', '?' and fullwidth forms) — floats are
         protected first so '3.14' is never cut;
      2. commas (ASCII and fullwidth);
      3. spaces;
      4. hard cut at `length` bytes.
    Adjacent pieces are then greedily re-merged while they still fit.
    """
    pieces = [clean_text(text)]
    pieces = map(protect_float, pieces)
    pieces = break_text(pieces, length, {".", "!", "?", "。", "!", "?"})
    pieces = map(unprotect_float, pieces)
    pieces = break_text(pieces, length, {",", ","})
    pieces = break_text(pieces, length, {" "})
    pieces = list(break_text_by_length(pieces, length))

    # Greedy merge of consecutive pieces up to the byte budget.
    segments = []
    buf = ""
    for piece in pieces:
        if utf_8_len(buf) + utf_8_len(piece) <= length:
            buf += piece
        else:
            add_cleaned(buf, segments)
            buf = piece
    if buf:
        add_cleaned(buf, segments)

    return segments
if __name__ == "__main__":
    # Test the split_text function
    text = "This is a test sentence. This is another test sentence. And a third one."
    assert split_text(text, 50) == [
        "This is a test sentence.",
        "This is another test sentence. And a third one.",
    ]
    # Floats must survive sentence splitting.
    assert split_text("a,aaaaaa3.14", 10) == ["a,", "aaaaaa3.14"]
    # Whitespace-only input produces no segments.
    assert split_text(" ", 10) == []
    assert split_text("a", 10) == ["a"]
    text = "This is a test sentence with only commas, and no dots, and no exclamation marks, and no question marks, and no newlines."
    assert split_text(text, 50) == [
        "This is a test sentence with only commas,",
        "and no dots, and no exclamation marks,",
        "and no question marks, and no newlines.",
    ]
    text = "This is a test sentence This is a test sentence This is a test sentence. This is a test sentence, This is a test sentence, This is a test sentence."
    # First half split at " ", second half split at ","
    assert split_text(text, 50) == [
        "This is a test sentence This is a test sentence",
        "This is a test sentence. This is a test sentence,",
        "This is a test sentence, This is a test sentence.",
    ]
    text = "这是一段很长的中文文本,而且没有句号,也没有感叹号,也没有问号,也没有换行符。"
    # NOTE(review): the last expected segment ends with an ASCII '.' while
    # the input ends with '。' — verify this assertion actually holds.
    assert split_text(text, 50) == [
        "这是一段很长的中文文本,",
        "而且没有句号,也没有感叹号,",
        "也没有问号,也没有换行符.",
    ]
import base64
import json
import logging
from pathlib import Path
import tiktoken
logger = logging.getLogger(__name__)
# This is a modified version of the default pattern from GPT-4o, that better handles punctuations.
FISH_TIKTOKEN_PATTERN = "|".join(
    [
        r"(?i:'s|'t|'re|'ve|'m|'ll|'d)",  # English contractions (case-insensitive)
        r"\p{P}",  # a single punctuation mark
        r"[^\r\n\p{L}\p{N}]?\p{L}+",  # a word, optionally led by one symbol
        r"\p{N}",  # a single digit
        r" ?[^\s\p{L}\p{N}]+[\r\n]*",  # runs of other symbols
        r"\s*[\r\n]+",  # newlines with surrounding whitespace
        # NOTE(review): GPT-style patterns use r"\s+(?!\S)" (a negative
        # lookahead). Here the '?' is escaped, so this alternative matches
        # whitespace followed by a literal "?!" — almost certainly a typo,
        # but fixing it would change tokenization for already-trained
        # models; confirm before touching.
        r"\s+(\?!\S)",
        r"\s+",  # any remaining whitespace
    ]
)
# encode() slices input into chunks of this many characters.
TIKTOKEN_MAX_ENCODE_CHARS = 400_000
# Special tokens.
BOS_TOKEN = "<|begin_of_text|>"
EOS_TOKEN = "<|end_of_text|>"
PAD_TOKEN = "<|pad|>"
IM_START_TOKEN = "<|im_start|>"
IM_END_TOKEN = "<|im_end|>"
MODALITY_TEXT_TOKEN = "<|text|>"
MODALITY_VOICE_TOKEN = "<|voice|>"
MODALITY_INTERLEAVE_TOKEN = "<|interleave|>"
MODALITY_TOKENS = {
    "text": MODALITY_TEXT_TOKEN,
    "voice": MODALITY_VOICE_TOKEN,
    "interleave": MODALITY_INTERLEAVE_TOKEN,
}
# Four reserved placeholder tokens <|placeholder:0|> .. <|placeholder:3|>.
PLACEHOLDER_TOKEN = [""] * 4
for i in range(4):
    PLACEHOLDER_TOKEN[i] = f"<|placeholder:{i}|>"
# 1024 semantic codebook tokens <|semantic:0|> .. <|semantic:1023|>.
SEMANTIC_TOKEN_TEMPLATE = "<|semantic:{i}|>"
SEMANTIC_TOKENS = [SEMANTIC_TOKEN_TEMPLATE.format(i=i) for i in range(1024)]
# Warning: when you add a new special token, you should only add it to the end of the list.
ALL_SPECIAL_TOKENS = [
    BOS_TOKEN,
    EOS_TOKEN,
    PAD_TOKEN,
    IM_START_TOKEN,
    IM_END_TOKEN,
    PLACEHOLDER_TOKEN[0],
    PLACEHOLDER_TOKEN[1],
    PLACEHOLDER_TOKEN[2],
    PLACEHOLDER_TOKEN[3],
    MODALITY_TEXT_TOKEN,
    MODALITY_VOICE_TOKEN,
    MODALITY_INTERLEAVE_TOKEN,
    *SEMANTIC_TOKENS,
]
class FishTokenizer:
def __init__(self, model_path: str) -> None:
mergeable_ranks = self.load_tiktoken_bpe(model_path)
special_token_begin = len(mergeable_ranks)
self.all_special_tokens_with_ids = {
token: special_token_begin + i for i, token in enumerate(ALL_SPECIAL_TOKENS)
}
self.semantic_id_to_token_id = {
i: self.all_special_tokens_with_ids[token]
for i, token in enumerate(SEMANTIC_TOKENS)
}
self.semantic_begin_id = self.all_special_tokens_with_ids[SEMANTIC_TOKENS[0]]
self.semantic_end_id = self.all_special_tokens_with_ids[SEMANTIC_TOKENS[-1]]
self.tkt_model = tiktoken.core.Encoding(
name=Path(model_path).stem,
pat_str=FISH_TIKTOKEN_PATTERN,
mergeable_ranks=mergeable_ranks,
special_tokens=self.all_special_tokens_with_ids,
)
@staticmethod
def load_tiktoken_bpe(tiktoken_bpe_file: str) -> dict[bytes, int]:
data = {}
for line in open(tiktoken_bpe_file).read().splitlines():
if not line:
continue
token, rank = line.split()
data[base64.b64decode(token)] = int(rank)
return data
def get_token_id(self, token: str) -> int:
return self.all_special_tokens_with_ids[token]
def encode(self, s: str, allowed_special: bool | set[str] = True) -> list[int]:
assert isinstance(s, str)
subs = []
for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS):
subs.append(s[i : i + TIKTOKEN_MAX_ENCODE_CHARS])
if allowed_special is True:
allowed_special = self.tkt_model.special_tokens_set
elif allowed_special is False:
allowed_special = set()
return sum(
self.tkt_model.encode_batch(
subs, allowed_special=allowed_special, disallowed_special=set()
),
start=[],
)
def decode(self, tokens: list[int]) -> str:
return self.tkt_model.decode(tokens)
def save_pretrained(self, path: str):
path = Path(path)
path.mkdir(parents=True, exist_ok=True)
with open(path / "tokenizer.tiktoken", "w") as f:
for token, rank in self.tkt_model._mergeable_ranks.items():
f.write(f"{base64.b64encode(token).decode()} {rank}\n")
with open(path / "special_tokens.json", "w") as f:
json.dump(
self.all_special_tokens_with_ids,
f,
indent=2,
ensure_ascii=False,
)
@staticmethod
def from_pretrained(path: str):
return FishTokenizer(Path(path) / "tokenizer.tiktoken")
if __name__ == "__main__":
    # Smoke test: needs the local tokenizer data file; round-trips through
    # save_pretrained / from_pretrained, then shows a per-token decode.
    tokenizer = FishTokenizer("data/mpacks/v1.4-pretrain/tokenizer.all.tiktoken")
    tokenizer.save_pretrained("checkpoints/fish-speech-0.5B")
    tokenizer = FishTokenizer.from_pretrained("checkpoints/fish-speech-0.5B")
    print(
        [
            tokenizer.decode([i])
            for i in tokenizer.encode(f"{BOS_TOKEN}你好,世界!{EOS_TOKEN}")
        ]
    )
import os

# Disable libuv for torch.distributed's TCPStore.
# NOTE(review): presumably a platform/back-compat workaround — confirm.
os.environ["USE_LIBUV"] = "0"
import sys
from typing import Optional

import hydra
import lightning as L
import pyrootutils
import torch
from lightning import Callback, LightningDataModule, LightningModule, Trainer
from lightning.pytorch.loggers import Logger
from lightning.pytorch.strategies import DDPStrategy
from omegaconf import DictConfig, OmegaConf

# Drop SLURM variables before Lightning reads them.
# NOTE(review): inferred intent is to avoid Lightning's SLURM
# environment auto-detection — confirm.
os.environ.pop("SLURM_NTASKS", None)
os.environ.pop("SLURM_JOB_NAME", None)
os.environ.pop("SLURM_NTASKS_PER_NODE", None)
# register eval resolver and root
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
# Allow TF32 on Ampere GPUs
torch.set_float32_matmul_precision("high")
torch.backends.cudnn.allow_tf32 = True
# register eval resolver
OmegaConf.register_new_resolver("eval", eval)
import fish_speech.utils as utils

log = utils.RankedLogger(__name__, rank_zero_only=True)
@utils.task_wrapper
def train(cfg: DictConfig) -> tuple[dict, dict]:
    """Trains the model. Can additionally evaluate on a testset, using best weights obtained during
    training.
    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
    failure. Useful for multiruns, saving info about the crash, etc.
    Args:
        cfg (DictConfig): Configuration composed by Hydra.
    Returns:
        Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
    """  # noqa: E501
    # set seed for random number generators in pytorch, numpy and python.random
    if cfg.get("seed"):
        L.seed_everything(cfg.seed, workers=False)
    if cfg.get("deterministic"):
        torch.use_deterministic_algorithms(True)
    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)
    log.info("Instantiating callbacks...")
    callbacks: list[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))
    log.info("Instantiating loggers...")
    logger: list[Logger] = utils.instantiate_loggers(cfg.get("logger"))
    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(
        cfg.trainer,
        callbacks=callbacks,
        logger=logger,
    )
    # Everything the task_wrapper / logging utilities may need.
    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }
    if logger:
        log.info("Logging hyperparameters!")
        utils.log_hyperparameters(object_dict)
    if cfg.get("train"):
        log.info("Starting training!")
        ckpt_path = cfg.get("ckpt_path")
        auto_resume = False
        # Auto-resume: the newest checkpoint in ckpt_dir takes precedence
        # over any explicitly configured cfg.ckpt_path.
        resume_ckpt_path = utils.get_latest_checkpoint(cfg.paths.ckpt_dir)
        if resume_ckpt_path is not None:
            ckpt_path = resume_ckpt_path
            auto_resume = True
        if ckpt_path is not None:
            log.info(f"Resuming from checkpoint: {ckpt_path}")
        # resume weights only is disabled for auto-resume
        if cfg.get("resume_weights_only") and auto_resume is False:
            log.info("Resuming weights only!")
            ckpt = torch.load(ckpt_path, map_location=model.device)
            if "state_dict" in ckpt:
                ckpt = ckpt["state_dict"]
            # strict=False: log (rather than fail on) key mismatches.
            err = model.load_state_dict(ckpt, strict=False)
            log.info(f"Error loading state dict: {err}")
            # Weights were loaded manually, so don't pass the path to fit().
            ckpt_path = None
        trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
    train_metrics = trainer.callback_metrics
    if cfg.get("test"):
        log.info("Starting testing!")
        # Prefer the best checkpoint tracked by the checkpoint callback.
        ckpt_path = trainer.checkpoint_callback.best_model_path
        if ckpt_path == "":
            log.warning("Best ckpt not found! Using current weights for testing...")
            ckpt_path = cfg.get("ckpt_path")
        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
        log.info(f"Best ckpt path: {ckpt_path}")
    test_metrics = trainer.callback_metrics
    # merge train and test metrics
    metric_dict = {**train_metrics, **test_metrics}
    return metric_dict, object_dict
@hydra.main(
    version_base="1.3", config_path="./configs", config_name="llama_pretrain.yaml"
)
def main(cfg: DictConfig) -> Optional[float]:
    """Hydra entry point: run training with the composed config.

    NOTE(review): annotated Optional[float] (the convention for hydra
    sweeps) but currently returns None; returning an optimized metric
    from train() may have been intended — confirm.
    """
    # train the model
    train(cfg)


if __name__ == "__main__":
    main()
from .braceexpand import braceexpand
from .context import autocast_exclude_mps
from .file import get_latest_checkpoint
from .instantiators import instantiate_callbacks, instantiate_loggers
from .logger import RankedLogger
from .logging_utils import log_hyperparameters
from .rich_utils import enforce_tags, print_config_tree
from .utils import extras, get_metric_value, set_seed, task_wrapper
# Public API of the fish_speech.utils package.
__all__ = [
    "enforce_tags",
    "extras",
    "get_metric_value",
    "RankedLogger",
    "instantiate_callbacks",
    "instantiate_loggers",
    "log_hyperparameters",
    "print_config_tree",
    "task_wrapper",
    "braceexpand",
    "get_latest_checkpoint",
    "autocast_exclude_mps",
    "set_seed",
]
"""
Bash-style brace expansion
Copied from: https://github.com/trendels/braceexpand/blob/main/src/braceexpand/__init__.py
License: MIT
"""
import re
import string
from itertools import chain, product
from typing import Iterable, Iterator, Optional
__all__ = ["braceexpand", "alphabet", "UnbalancedBracesError"]
class UnbalancedBracesError(ValueError):
    """Raised when a pattern contains unbalanced '{' / '}' braces."""
# Domain for character ranges like '{a..z}': uppercase A-Z followed by a-z.
alphabet = string.ascii_uppercase + string.ascii_lowercase

# Integer range expression, e.g. '1..10' or '-3..3..2' (optional increment).
int_range_re = re.compile(r"^(-?\d+)\.\.(-?\d+)(?:\.\.-?(\d+))?$")
# Character range expression, e.g. 'a..e' or 'a..g..2' (optional increment).
char_range_re = re.compile(r"^([A-Za-z])\.\.([A-Za-z])(?:\.\.-?(\d+))?$")
# A backslash escape; substitution with group 1 strips the backslash.
escape_re = re.compile(r"\\(.)")
def braceexpand(pattern: str, escape: bool = True) -> Iterator[str]:
    """braceexpand(pattern) -> iterator over generated strings
    Returns an iterator over the strings resulting from brace expansion
    of pattern. This function implements Brace Expansion as described in
    bash(1), with the following limitations:
    * A pattern containing unbalanced braces will raise an
    UnbalancedBracesError exception. In bash, unbalanced braces will either
    be partly expanded or ignored.
    * A mixed-case character range like '{Z..a}' or '{a..Z}' will not
    include the characters '[]^_`' between 'Z' and 'a'.
    When escape is True (the default), characters in pattern can be
    prefixed with a backslash to cause them not to be interpreted as
    special characters for brace expansion (such as '{', '}', ',').
    To pass through a literal backslash, double it ('\\\\').
    When escape is False, backslashes in pattern have no special
    meaning and will be preserved in the output.
    Examples:
    >>> from braceexpand import braceexpand
    # Integer range
    >>> list(braceexpand('item{1..3}'))
    ['item1', 'item2', 'item3']
    # Character range
    >>> list(braceexpand('{a..c}'))
    ['a', 'b', 'c']
    # Sequence
    >>> list(braceexpand('index.html{,.backup}'))
    ['index.html', 'index.html.backup']
    # Nested patterns
    >>> list(braceexpand('python{2.{5..7},3.{2,3}}'))
    ['python2.5', 'python2.6', 'python2.7', 'python3.2', 'python3.3']
    # Prefixing an integer with zero causes all numbers to be padded to
    # the same width.
    >>> list(braceexpand('{07..10}'))
    ['07', '08', '09', '10']
    # An optional increment can be specified for ranges.
    >>> list(braceexpand('{a..g..2}'))
    ['a', 'c', 'e', 'g']
    # Ranges can go in both directions.
    >>> list(braceexpand('{4..1}'))
    ['4', '3', '2', '1']
    # Numbers can be negative
    >>> list(braceexpand('{2..-1}'))
    ['2', '1', '0', '-1']
    # Unbalanced braces raise an exception.
    >>> list(braceexpand('{1{2,3}'))
    Traceback (most recent call last):
    ...
    UnbalancedBracesError: Unbalanced braces: '{1{2,3}'
    # By default, the backslash is the escape character.
    >>> list(braceexpand(r'{1\\{2,3}'))
    ['1{2', '3']
    # Setting 'escape' to False disables backslash escaping.
    >>> list(braceexpand(r'\\{1,2}', escape=False))
    ['\\\\1', '\\\\2']
    """
    # parse_pattern validates the braces eagerly (it may raise
    # UnbalancedBracesError right here), then yields expanded strings lazily.
    expanded = parse_pattern(pattern, escape)
    if escape:
        # Strip the escaping backslashes from each generated string.
        return (escape_re.sub(r"\1", candidate) for candidate in expanded)
    return (candidate for candidate in expanded)
def parse_pattern(pattern: str, escape: bool) -> Iterator[str]:
    """Split *pattern* into literal pieces and brace expressions, then yield
    every string in the cross-product of the expanded pieces.

    Raises UnbalancedBracesError when '{' and '}' do not pair up.
    """
    pieces: list[Iterable[str]] = []
    seg_start = 0  # start of the current literal / expression segment
    depth = 0  # current brace nesting depth
    i = 0
    while i < len(pattern):
        ch = pattern[i]
        if escape and ch == "\\":
            i += 2  # skip the backslash and the escaped character
            continue
        if ch == "{":
            if depth == 0 and i > seg_start:
                # Flush the literal text that precedes this brace group.
                pieces.append([pattern[seg_start:i]])
                seg_start = i
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                inner = pattern[seg_start + 1 : i]
                expanded = parse_expression(inner, escape)
                if expanded is None:  # not a range or sequence
                    # Keep the braces literally, but expand what's inside.
                    pieces.extend([["{"], parse_pattern(inner, escape), ["}"]])
                else:
                    pieces.append(expanded)
                seg_start = i + 1  # skip the closing brace
        i += 1
    if depth != 0:  # unbalanced braces
        raise UnbalancedBracesError("Unbalanced braces: '%s'" % pattern)
    if seg_start < i:
        # Trailing literal text after the last brace group.
        pieces.append([pattern[seg_start:]])
    return ("".join(parts) for parts in product(*pieces))
def parse_expression(expr: str, escape: bool) -> Optional[Iterable[str]]:
    """Interpret *expr* as an integer range, a character range, or a comma
    sequence — in that order. Returns None if it is none of these."""
    if m := int_range_re.match(expr):
        return make_int_range(*m.groups())
    if m := char_range_re.match(expr):
        return make_char_range(*m.groups())
    return parse_sequence(expr, escape)
def parse_sequence(seq: str, escape: bool) -> Optional[Iterator[str]]:
    """Split *seq* on top-level commas and expand each part.

    Returns a chained iterator over all parts, or None when *seq* contains
    no top-level comma (i.e. it is not a sequence at all).
    """
    parts: list[Iterable[str]] = []
    seg_start = 0  # start of the current comma-separated part
    depth = 0  # brace nesting depth; commas inside braces don't split
    i = 0
    while i < len(seq):
        ch = seq[i]
        if escape and ch == "\\":
            i += 2  # skip the backslash and the escaped character
            continue
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
        elif ch == "," and depth == 0:
            parts.append(parse_pattern(seq[seg_start:i], escape))
            seg_start = i + 1  # skip the comma
        i += 1
    if depth != 0:
        raise UnbalancedBracesError
    if not parts:
        return None
    # The part after the last comma (may be the empty string).
    parts.append(parse_pattern(seq[seg_start:], escape))
    return chain(*parts)
def make_int_range(left: str, right: str, incr: Optional[str] = None) -> Iterator[str]:
    """Expand an integer range expression such as '1..3', '07..10' or '1..9..2'.

    Args:
        left: Left endpoint as a decimal string (may be negative or
            zero-padded).
        right: Right endpoint, same format.
        incr: Optional step as captured by ``int_range_re`` (the sign was
            stripped by the regex; direction comes from the endpoints).

    Returns:
        Iterator of decimal strings from ``left`` to ``right`` inclusive,
        descending when ``left > right``. If either non-zero endpoint has a
        leading zero, every result is zero-padded to the wider endpoint's
        width (bash behavior).
    """
    # Padding is requested by a leading '0'/'-0' on a non-zero endpoint.
    if any(s.startswith(("0", "-0")) for s in (left, right) if s not in ("0", "-0")):
        padding = max(len(left), len(right))
    else:
        padding = 0
    step = (int(incr) or 1) if incr else 1  # treat an explicit 0 step as 1
    start = int(left)
    end = int(right)
    # Both directions include the right endpoint, hence end+1 / end-1.
    r = range(start, end + 1, step) if start < end else range(start, end - 1, -step)
    fmt = "%0{}d".format(padding)
    return (fmt % i for i in r)
def make_char_range(left: str, right: str, incr: Optional[str] = None) -> str:
    """Expand a character range expression such as 'a..e' or 'a..g..2'.

    The range runs over ``alphabet`` (A-Z followed by a-z); descending
    endpoints yield a reversed range. ``incr`` is an absolute step size.
    """
    step = (int(incr) or 1) if incr else 1  # treat an explicit 0 step as 1
    lo = alphabet.index(left)
    hi = alphabet.index(right)
    if lo < hi:
        return alphabet[lo : hi + 1 : step]
    # Reversed range: an endpoint at index 0 must become -len(alphabet) so
    # the negative-step slice can still reach and include index 0.
    hi = hi or -len(alphabet)
    return alphabet[lo : hi - 1 : -step]
if __name__ == "__main__":
    import doctest
    import sys

    # Run this module's doctests; exit nonzero when any of them fail.
    failure_count, _ = doctest.testmod(optionflags=doctest.IGNORE_EXCEPTION_DETAIL)
    if failure_count:
        sys.exit(1)
from contextlib import nullcontext
import torch
def autocast_exclude_mps(
device_type: str, dtype: torch.dtype
) -> nullcontext | torch.autocast:
return (
nullcontext()
if torch.backends.mps.is_available()
else torch.autocast(device_type, dtype)
)
import os
from pathlib import Path
from typing import Union
from loguru import logger
from natsort import natsorted
# Audio file suffixes recognized when scanning directories (see list_files).
AUDIO_EXTENSIONS = {
    ".mp3",
    ".wav",
    ".flac",
    ".ogg",
    ".m4a",
    ".wma",
    ".aac",
    ".aiff",
    ".aif",
    ".aifc",
}

# Recognized video container suffixes.
VIDEO_EXTENSIONS = {
    ".mp4",
    ".avi",
}
def get_latest_checkpoint(path: Path | str) -> Path | None:
# Find the latest checkpoint
ckpt_dir = Path(path)
if ckpt_dir.exists() is False:
return None
ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime)
if len(ckpts) == 0:
return None
return ckpts[-1]
def audio_to_bytes(file_path):
    """Read an audio file and return its raw bytes.

    Returns None when *file_path* is empty/None or does not exist.
    """
    if not file_path or not Path(file_path).exists():
        return None
    with open(file_path, "rb") as audio_file:
        return audio_file.read()
def read_ref_text(ref_text):
    """Resolve a reference text argument.

    If *ref_text* is the path of an existing file, return that file's
    UTF-8 contents; otherwise treat it as literal text and return it as-is.
    """
    candidate = Path(ref_text)
    if candidate.is_file():
        return candidate.read_text(encoding="utf-8")
    return ref_text
def list_files(
    path: Union[Path, str],
    extensions: set[str] = set(),
    recursive: bool = False,
    sort: bool = True,
) -> list[Path]:
    """List files in a directory.

    Args:
        path (Path): Path to the directory.
        extensions (set, optional): Extensions to filter. Defaults to an
            empty set (which matches nothing).
        recursive (bool, optional): Whether to search recursively. Defaults to False.
        sort (bool, optional): Whether to sort the files (natural order). Defaults to True.

    Returns:
        list: List of files.

    Raises:
        FileNotFoundError: If the directory does not exist.
    """
    if isinstance(path, str):
        path = Path(path)

    if not path.exists():
        raise FileNotFoundError(f"Directory {path} does not exist.")

    # Bug fix: `recursive` used to be ignored (rglob was always used);
    # honor it by choosing between glob and rglob.
    glob = path.rglob if recursive else path.glob
    files = [file for ext in extensions for file in glob(f"*{ext}")]

    if sort:
        files = natsorted(files)

    return files
def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]:
"""
Load a Bert-VITS2 style filelist.
"""
files = set()
results = []
count_duplicated, count_not_found = 0, 0
LANGUAGE_TO_LANGUAGES = {
"zh": ["zh", "en"],
"jp": ["jp", "en"],
"en": ["en"],
}
with open(path, "r", encoding="utf-8") as f:
for line in f.readlines():
splits = line.strip().split("|", maxsplit=3)
if len(splits) != 4:
logger.warning(f"Invalid line: {line}")
continue
filename, speaker, language, text = splits
file = Path(filename)
language = language.strip().lower()
if language == "ja":
language = "jp"
assert language in ["zh", "jp", "en"], f"Invalid language {language}"
languages = LANGUAGE_TO_LANGUAGES[language]
if file in files:
logger.warning(f"Duplicated file: {file}")
count_duplicated += 1
continue
if not file.exists():
logger.warning(f"File not found: {file}")
count_not_found += 1
continue
results.append((file, speaker, languages, text))
if count_duplicated > 0:
logger.warning(f"Total duplicated files: {count_duplicated}")
if count_not_found > 0:
logger.warning(f"Total files not found: {count_not_found}")
return results
from typing import List
import hydra
from omegaconf import DictConfig
from pytorch_lightning import Callback
from pytorch_lightning.loggers import Logger
from .logger import RankedLogger
log = RankedLogger(__name__, rank_zero_only=True)
def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
    """Instantiate callbacks from a Hydra config.

    Args:
        callbacks_cfg: Mapping of callback name -> config; every entry that
            carries a ``_target_`` key is instantiated via Hydra.

    Returns:
        The instantiated callbacks (empty when the config is empty/None).

    Raises:
        TypeError: If ``callbacks_cfg`` is truthy but not a ``DictConfig``.
    """
    callbacks: List[Callback] = []

    if not callbacks_cfg:
        # Message made consistent with instantiate_loggers ("Skipping...").
        log.warning("No callback configs found! Skipping...")
        return callbacks

    if not isinstance(callbacks_cfg, DictConfig):
        raise TypeError("Callbacks config must be a DictConfig!")

    for _, cb_conf in callbacks_cfg.items():
        # Only entries that look like Hydra instantiation specs are used.
        if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
            log.info(f"Instantiating callback <{cb_conf._target_}>")
            callbacks.append(hydra.utils.instantiate(cb_conf))

    return callbacks
def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
    """Instantiate experiment loggers from a Hydra config.

    Every entry of ``logger_cfg`` that carries a ``_target_`` key is
    instantiated via Hydra; an empty/None config yields an empty list.
    """
    loggers: List[Logger] = []

    if not logger_cfg:
        log.warning("No logger configs found! Skipping...")
        return loggers

    if not isinstance(logger_cfg, DictConfig):
        raise TypeError("Logger config must be a DictConfig!")

    for _, lg_conf in logger_cfg.items():
        if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
            log.info(f"Instantiating logger <{lg_conf._target_}>")
            loggers.append(hydra.utils.instantiate(lg_conf))

    return loggers
import logging
from typing import Mapping, Optional
from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only
class RankedLogger(logging.LoggerAdapter):
    """A multi-GPU-friendly python command line logger.

    Prefixes every message with the process rank and can restrict emission
    to a single rank (rank zero by default).
    """

    def __init__(
        self,
        name: str = __name__,
        rank_zero_only: bool = True,
        extra: Optional[Mapping[str, object]] = None,
    ) -> None:
        """Initializes a multi-GPU-friendly python command line logger that logs on all processes
        with their rank prefixed in the log message.

        :param name: The name of the logger. Default is ``__name__``.
        :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is ``True``.
        :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.
        """
        logger = logging.getLogger(name)
        super().__init__(logger=logger, extra=extra)
        self.rank_zero_only = rank_zero_only

    def log(
        self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs
    ) -> None:
        """Delegate a log call to the underlying logger, after prefixing its message with the rank
        of the process it's being logged from. If `'rank'` is provided, then the log will only
        occur on that rank/process.

        NOTE(review): because ``rank`` sits before ``*args``, a positional
        %-style format argument forwarded by ``LoggerAdapter`` helpers
        (e.g. ``log.info("x=%s", x)``) would be captured as ``rank`` —
        verify call sites pass format args accordingly.

        :param level: The level to log at. Look at `logging.__init__.py` for more information.
        :param msg: The message to log.
        :param rank: The rank to log at.
        :param args: Additional args to pass to the underlying logging function.
        :param kwargs: Any additional keyword args to pass to the underlying logging function.
        """
        if self.isEnabledFor(level):
            msg, kwargs = self.process(msg, kwargs)
            # `rank_zero_only.rank` is set externally by the Lightning rank
            # utilities; without it we cannot prefix or gate by rank.
            current_rank = getattr(rank_zero_only, "rank", None)
            if current_rank is None:
                raise RuntimeError(
                    "The `rank_zero_only.rank` needs to be set before use"
                )
            msg = rank_prefixed_message(msg, current_rank)
            if self.rank_zero_only:
                # Only rank 0 emits, regardless of the `rank` argument.
                if current_rank == 0:
                    self.logger.log(level, msg, *args, **kwargs)
            else:
                if rank is None:
                    # No target rank given: every process emits.
                    self.logger.log(level, msg, *args, **kwargs)
                elif current_rank == rank:
                    # Emit only on the requested rank.
                    self.logger.log(level, msg, *args, **kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment