Commit e6e33f1a authored by chenzk's avatar chenzk
Browse files

v1.0

parents
Pipeline #2698 canceled with stages
# MIT License
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# Copyright (c) [2023] [Meta Platforms, Inc. and affiliates.]
# Copyright (c) [2025] [Ziyue Jiang]
# SPDX-License-Identifier: MIT
# This file has been modified by Ziyue Jiang on 2025/03/19
# Original file was released under MIT, with the full license text # available at https://github.com/facebookresearch/encodec/blob/gh-pages/LICENSE.
# This modified file is released under the same license.
"""Convolutional layers wrappers and utilities."""
import math
import typing as tp
import warnings
import einops
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm, weight_norm
CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
'time_layer_norm', 'layer_norm', 'time_group_norm'])
def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
assert norm in CONV_NORMALIZATIONS
if norm == 'weight_norm':
return weight_norm(module)
elif norm == 'spectral_norm':
return spectral_norm(module)
else:
return module
def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
assert norm in CONV_NORMALIZATIONS
if norm == 'layer_norm':
assert isinstance(module, nn.modules.conv._ConvNd)
return ConvLayerNorm(module.out_channels, **norm_kwargs)
elif norm == 'time_group_norm':
if causal:
raise ValueError("GroupNorm doesn't support causal evaluation.")
assert isinstance(module, nn.modules.conv._ConvNd)
return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
else:
return nn.Identity()
def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
padding_total: int = 0) -> int:
length = x.shape[-1]
n_frames = (length - kernel_size + padding_total) / stride + 1
ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
return ideal_length - length
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
length = x.shape[-1]
padding_left, padding_right = paddings
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
if mode == 'reflect':
max_pad = max(padding_left, padding_right)
extra_pad = 0
if length <= max_pad:
extra_pad = max_pad - length + 1
x = F.pad(x, (0, extra_pad))
padded = F.pad(x, paddings, mode, value)
end = padded.shape[-1] - extra_pad
return padded[..., :end]
else:
return F.pad(x, paddings, mode, value)
class ConvLayerNorm(nn.LayerNorm):
def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
super().__init__(normalized_shape, **kwargs)
def forward(self, x):
x = einops.rearrange(x, 'b ... t -> b t ...')
x = super().forward(x)
x = einops.rearrange(x, 'b t ... -> b ... t')
return
class NormConv1d(nn.Module):
def __init__(self, *args, causal: bool = False, norm: str = 'none',
norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
super().__init__()
self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
self.norm_type = norm
def forward(self, x):
x = self.conv(x)
x = self.norm(x)
return x
class SConv1d(nn.Module):
def __init__(self, in_channels: int, out_channels: int,
kernel_size: int, stride: int = 1, dilation: int = 1,
groups: int = 1, bias: bool = True, causal: bool = False,
norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
pad_mode: str = 'reflect'):
super().__init__()
# warn user on unusual setup between dilation and stride
if stride > 1 and dilation > 1:
warnings.warn('SConv1d has been initialized with stride > 1 and dilation > 1'
f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).')
self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
dilation=dilation, groups=groups, bias=bias, causal=causal,
norm=norm, norm_kwargs=norm_kwargs)
self.causal = causal
self.pad_mode = pad_mode
def forward(self, x):
B, C, T = x.shape
kernel_size = self.conv.conv.kernel_size[0]
stride = self.conv.conv.stride[0]
dilation = self.conv.conv.dilation[0]
kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations
padding_total = kernel_size - stride
extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
if self.causal:
# Left padding for causal
x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
else:
# Asymmetric padding required for odd strides
padding_right = padding_total // 2
padding_left = padding_total - padding_right
x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
return self.conv(x)
\ No newline at end of file
# MIT License
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# Copyright (c) [2023] [Meta Platforms, Inc. and affiliates.]
# Copyright (c) [2025] [Ziyue Jiang]
# SPDX-License-Identifier: MIT
# This file has been modified by Ziyue Jiang on 2025/03/19
# Original file was released under MIT, with the full license text # available at https://github.com/facebookresearch/encodec/blob/gh-pages/LICENSE.
# This modified file is released under the same license.
"""LSTM layers module."""
from torch import nn
class SLSTM(nn.Module):
"""
LSTM without worrying about the hidden state, nor the layout of the data.
Expects input as convolutional layout.
"""
def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
super().__init__()
self.skip = skip
self.lstm = nn.LSTM(dimension, dimension, num_layers)
# 修改transpose顺序
def forward(self, x):
x1 = x.permute(2, 0, 1)
y, _ = self.lstm(x1)
y = y.permute(1, 2, 0)
if self.skip:
y = y + x
return y
# MIT License
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# Copyright (c) [2023] [Meta Platforms, Inc. and affiliates.]
# Copyright (c) [2025] [Ziyue Jiang]
# SPDX-License-Identifier: MIT
# This file has been modified by Ziyue Jiang on 2025/03/19
# Original file was released under MIT, with the full license text # available at https://github.com/facebookresearch/encodec/blob/gh-pages/LICENSE.
# This modified file is released under the same license.
"""Encodec SEANet-based encoder and decoder implementation."""
import typing as tp
import numpy as np
import torch.nn as nn
from .conv import SConv1d
from .lstm import SLSTM
class SEANetResnetBlock(nn.Module):
def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1],
activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False,
pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True):
super().__init__()
assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations'
act = getattr(nn, activation)
hidden = dim // compress
block = []
for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
in_chs = dim if i == 0 else hidden
out_chs = dim if i == len(kernel_sizes) - 1 else hidden
block += [
act(**activation_params),
SConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation,
norm=norm, norm_kwargs=norm_params,
causal=causal, pad_mode=pad_mode),
]
self.block = nn.Sequential(*block)
self.shortcut: nn.Module
if true_skip:
self.shortcut = nn.Identity()
else:
self.shortcut = SConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params,
causal=causal, pad_mode=pad_mode)
def forward(self, x):
return self.shortcut(x) + self.block(x)
class SEANetEncoder(nn.Module):
def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 1,
ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
pad_mode: str = 'reflect', true_skip: bool = False, compress: int = 2, lstm: int = 2):
super().__init__()
self.channels = channels
self.dimension = dimension
self.n_filters = n_filters
self.ratios = list(reversed(ratios))
del ratios
self.n_residual_layers = n_residual_layers
self.hop_length = np.prod(self.ratios)
act = getattr(nn, activation)
mult = 1
model: tp.List[nn.Module] = [
SConv1d(channels, mult * n_filters, kernel_size, norm=norm, norm_kwargs=norm_params,
causal=causal, pad_mode=pad_mode)
]
# Downsample to raw audio scale
for i, ratio in enumerate(self.ratios):
# Add residual layers
for j in range(n_residual_layers):
model += [
SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1],
dilations=[dilation_base ** j, 1],
norm=norm, norm_params=norm_params,
activation=activation, activation_params=activation_params,
causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
# Add downsampling layers
model += [
act(**activation_params),
SConv1d(mult * n_filters, mult * n_filters * 2,
kernel_size=ratio * 2, stride=ratio,
norm=norm, norm_kwargs=norm_params,
causal=causal, pad_mode=pad_mode),
]
mult *= 2
if lstm:
model += [SLSTM(mult * n_filters, num_layers=lstm)]
model += [
act(**activation_params),
SConv1d(mult * n_filters, dimension, last_kernel_size, norm=norm, norm_kwargs=norm_params,
causal=causal, pad_mode=pad_mode)
]
self.model = nn.Sequential(*model)
def forward(self, x):
return self.model(x)
\ No newline at end of file
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
def mel2token_to_dur(mel2token, T_txt=None, max_dur=None):
is_torch = isinstance(mel2token, torch.Tensor)
has_batch_dim = True
if not is_torch:
mel2token = torch.LongTensor(mel2token)
if T_txt is None:
T_txt = mel2token.max()
if len(mel2token.shape) == 1:
mel2token = mel2token[None, ...]
has_batch_dim = False
B, _ = mel2token.shape
dur = mel2token.new_zeros(B, T_txt + 1).scatter_add(1, mel2token, torch.ones_like(mel2token))
dur = dur[:, 1:]
if max_dur is not None:
dur = dur.clamp(max=max_dur)
if not is_torch:
dur = dur.numpy()
if not has_batch_dim:
dur = dur[0]
return dur
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os
import subprocess
import numpy as np
from scipy.io import wavfile
import pyloudnorm as pyln
from pydub import AudioSegment
def to_wav_bytes(wav, sr, norm=False):
wav = wav.astype(float)
if norm:
meter = pyln.Meter(sr) # create BS.1770 meter
loudness = meter.integrated_loudness(wav)
wav = pyln.normalize.loudness(wav, loudness, -18.0)
if np.abs(wav).max() >= 1:
wav = wav / np.abs(wav).max() * 0.95
wav = wav * 32767
bytes_io = io.BytesIO()
wavfile.write(bytes_io, sr, wav.astype(np.int16))
return bytes_io.getvalue()
def save_wav(wav_bytes, path):
with open(path[:-4] + '.wav', 'wb') as file:
file.write(wav_bytes)
if path[-4:] == '.mp3':
to_mp3(path[:-4])
def to_mp3(out_path):
if out_path[-4:] == '.wav':
out_path = out_path[:-4]
subprocess.check_call(
f'ffmpeg -threads 1 -loglevel error -i "{out_path}.wav" -vn -b:a 192k -y -hide_banner -async 1 "{out_path}.mp3"',
shell=True, stdin=subprocess.PIPE)
subprocess.check_call(f'rm -f "{out_path}.wav"', shell=True)
def convert_to_wav(wav_path):
# Check if the file exists
if not os.path.exists(wav_path):
print(f"The file '{wav_path}' does not exist.")
return
# Check if the file already has a .wav extension
if not wav_path.endswith(".wav"):
# Define the output path with a .wav extension
out_path = os.path.splitext(wav_path)[0] + ".wav"
# Load the audio file using pydub and convert it to WAV
audio = AudioSegment.from_file(wav_path)
audio.export(out_path, format="wav")
print(f"Converted '{wav_path}' to '{out_path}'")
def convert_to_wav_bytes(audio_binary):
# Load the audio binary using pydub and convert it to WAV
audio = AudioSegment.from_file(io.BytesIO(audio_binary))
wav_bytes = io.BytesIO()
audio.export(wav_bytes, format="wav")
wav_bytes.seek(0)
return wav_bytes
''' Smoothly combine audio segments using crossfade transitions." '''
def combine_audio_segments(segments, crossfade_duration=0.16, sr=24000):
window_length = int(sr * crossfade_duration)
hanning_window = np.hanning(2 * window_length)
# Combine
for i, segment in enumerate(segments):
if i == 0:
combined_audio = segment
else:
overlap = combined_audio[-window_length:] * hanning_window[window_length:] + segment[:window_length] * hanning_window[:window_length]
combined_audio = np.concatenate(
[combined_audio[:-window_length], overlap, segment[window_length:]]
)
return combined_audio
\ No newline at end of file
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import torch
LINE_COLORS = ['w', 'r', 'orange', 'k', 'cyan', 'm', 'b', 'lime', 'g', 'brown', 'navy']
def spec_to_figure(spec, vmin=None, vmax=None, title='', f0s=None, dur_info=None, figsize=(12, 6)):
if isinstance(spec, torch.Tensor):
spec = spec.cpu().numpy()
H = spec.shape[1] // 2
fig = plt.figure(figsize=figsize)
plt.title(title)
plt.pcolor(spec.T, vmin=vmin, vmax=vmax)
if dur_info is not None:
assert isinstance(dur_info, dict)
txt = dur_info['txt']
dur_gt = dur_info['dur_gt']
if isinstance(dur_gt, torch.Tensor):
dur_gt = dur_gt.cpu().numpy()
dur_gt = np.cumsum(dur_gt).astype(int)
for i in range(len(dur_gt)):
shift = (i % 8) + 1
plt.text(dur_gt[i], shift * 4, txt[i])
plt.vlines(dur_gt[i], 0, H // 2, colors='b') # blue is gt
plt.xlim(0, dur_gt[-1])
if 'dur_pred' in dur_info:
dur_pred = dur_info['dur_pred']
if isinstance(dur_pred, torch.Tensor):
dur_pred = dur_pred.cpu().numpy()
dur_pred = np.cumsum(dur_pred).astype(int)
for i in range(len(dur_pred)):
shift = (i % 8) + 1
plt.text(dur_pred[i], H + shift * 4, txt[i])
plt.vlines(dur_pred[i], H, H * 1.5, colors='r') # red is pred
plt.xlim(0, max(dur_gt[-1], dur_pred[-1]))
if f0s is not None:
ax = plt.gca()
ax2 = ax.twinx()
# ax.set_xticks()
if not isinstance(f0s, dict):
f0s = {'f0': f0s}
for i, (k, f0) in enumerate(f0s.items()):
if f0 is not None:
if isinstance(f0, torch.Tensor):
f0 = f0.cpu().numpy()
ax2.plot(
np.arange(len(f0)) + 0.5, f0, label=k, c=LINE_COLORS[i], linewidth=1, alpha=0.5)
ax2.set_ylim(0, 1000)
ax2.legend()
return fig
def align_to_figure(align, dur_info):
if isinstance(align, torch.Tensor):
align = align.cpu().numpy()
H = align.shape[1]
fig = plt.figure(figsize=(12, 6))
plt.pcolor(align.T, vmin=0, vmax=1)
if dur_info is not None:
assert isinstance(dur_info, dict)
txt = dur_info['txt']
dur_gt = dur_info['dur_gt']
if isinstance(dur_gt, torch.Tensor):
dur_gt = dur_gt.cpu().numpy()
dur_gt = np.cumsum(dur_gt).astype(int) // 2
for i in range(len(dur_gt)):
plt.text(dur_gt[i], i, txt[i], color='red')
plt.vlines(dur_gt[i], 0, H, colors='b') # blue is gt
# plt.xlim(0, dur_gt[-1])
return fig
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import glob
import os
import re
import subprocess
import traceback
import torch
from torch.nn.parallel import DistributedDataParallel
import torch.distributed as dist
@contextlib.contextmanager
def dist_load(path):
if not dist.is_initialized() or dist.get_world_size() == 1 or os.path.realpath(path).startswith('/dev/shm'):
yield path
else:
from tts.utils.commons.hparams import hparams
from tts.utils.commons.trainer import LOCAL_RANK
tmpdir = '/dev/shm'
assert len(os.path.basename(path)) > 0
shm_ckpt_path = f'{tmpdir}/{hparams["exp_name"]}/{os.path.basename(path)}'
if LOCAL_RANK == 0:
subprocess.check_call(
f'mkdir -p {os.path.dirname(shm_ckpt_path)}; '
f'cp -Lr {path} {shm_ckpt_path}', shell=True)
dist.barrier()
yield shm_ckpt_path
dist.barrier()
if LOCAL_RANK == 0:
subprocess.check_call(f'rm -rf {shm_ckpt_path}', shell=True)
def torch_load_dist(path, map_location='cpu'):
with dist_load(path) as tmp_path:
checkpoint = torch.load(tmp_path, map_location=map_location)
return checkpoint
def get_last_checkpoint(work_dir, steps=None):
checkpoint = None
last_ckpt_path = None
ckpt_paths = get_all_ckpts(work_dir, steps)
if len(ckpt_paths) > 0:
last_ckpt_path = ckpt_paths[0]
checkpoint = torch_load_dist(last_ckpt_path, map_location='cpu')
return checkpoint, last_ckpt_path
def get_all_ckpts(work_dir, steps=None):
if steps is None or steps == 0:
ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_*.ckpt'
else:
ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_{steps}.ckpt'
return sorted(glob.glob(ckpt_path_pattern),
key=lambda x: -int(re.findall('.*steps\_(\d+)\.ckpt', x)[0]))
def load_ckpt(cur_model, ckpt_base_dir, model_name='model', force=True, strict=True,
silent=False, load_opt=False, opts=None, steps=None, checkpoint=None, ckpt_path='', delete_unmatch=True):
if checkpoint is None:
if os.path.isfile(ckpt_base_dir):
base_dir = os.path.dirname(ckpt_base_dir)
ckpt_path = ckpt_base_dir
checkpoint = torch_load_dist(ckpt_base_dir, map_location='cpu')
else:
base_dir = ckpt_base_dir
if load_opt:
checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir, steps)
else:
ckpt_path = f'{ckpt_base_dir}/model_only_last.ckpt'
if os.path.exists(ckpt_path):
checkpoint = torch_load_dist(ckpt_path, map_location='cpu')
else:
checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir, steps)
if checkpoint is not None:
state_dict_all = {
k.replace('module.', '').replace('_orig_mod.', ''): v for k, v in checkpoint["state_dict"].items()}
if not isinstance(cur_model, list):
cur_models = [cur_model]
model_names = [model_name]
else:
cur_models = cur_model
model_names = model_name
for model_name, cur_model in zip(model_names, cur_models):
if isinstance(cur_model, DistributedDataParallel):
cur_model = cur_model.module
device = next(cur_model.parameters()).device
if '.' not in model_name:
state_dict = state_dict_all[model_name]
else:
base_model_name = model_name.split('.')[0]
rest_model_name = model_name[len(base_model_name) + 1:]
state_dict = {
k[len(rest_model_name) + 1:]: v for k, v in state_dict_all[base_model_name].items()
if k.startswith(f'{rest_model_name}.')}
state_dict = {k.replace('module.', '').replace('_orig_mod.', ''): v for k, v in state_dict.items()}
if not strict and delete_unmatch:
try:
cur_model.load_state_dict(state_dict, strict=True)
if not silent:
print(f"| loaded '{model_name}' from '{ckpt_path}' with strict=True.")
except:
cur_model_state_dict = cur_model.state_dict()
cur_model_state_dict = {k.replace('module.', '').replace('_orig_mod.', ''): v for k, v in
cur_model_state_dict.items()}
unmatched_keys = []
for key, param in state_dict.items():
if key in cur_model_state_dict:
new_param = cur_model_state_dict[key]
if new_param.shape != param.shape:
unmatched_keys.append(key)
print("| Unmatched keys: ", key, "cur model: ", new_param.shape,
"ckpt model: ", param.shape)
for key in unmatched_keys:
del state_dict[key]
load_results = cur_model.load_state_dict(state_dict, strict=strict)
cur_model.to(device)
if not silent:
print(f"| loaded '{model_name}' from '{ckpt_path}'.")
missing_keys, unexpected_keys = load_results.missing_keys, load_results.unexpected_keys
print(f"| Missing keys: {len(missing_keys)}, Unexpected keys: {len(unexpected_keys)}")
if load_opt:
optimizer_states = checkpoint['optimizer_states']
assert len(opts) == len(optimizer_states)
for optimizer, opt_state in zip(opts, optimizer_states):
opt_state = {k.replace('_orig_mod.', ''): v for k, v in opt_state.items()}
if optimizer is None:
return
try:
optimizer.load_state_dict(opt_state)
for i, state in enumerate(optimizer.state.values()):
for k, v in state.items():
if isinstance(v, torch.Tensor):
state[k] = v.to(device)
except ValueError:
print(f"| WARMING: optimizer {optimizer} parameters not match !!!")
return checkpoint.get('global_step', 0)
else:
e_msg = f"| ckpt not found in {base_dir}."
if force:
assert False, e_msg
else:
print(e_msg)
def load_with_size_mismatch(model, state_dict, prefix=""):
current_model_dict = model.state_dict()
cm_keys = current_model_dict.keys()
mismatch_keys = {k.replace(prefix, "") for k, v in state_dict.items() if k.replace(prefix, "") in cm_keys and v.size() != current_model_dict[k.replace(prefix, "")].size()}
new_state_dict = {k.replace(prefix, ""): v for k, v in state_dict.items() if k.replace(prefix, "") in cm_keys and v.size() == current_model_dict[k.replace(prefix, "")].size()}
missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
print(f"| mismatch keys: ", mismatch_keys)
if len(missing_keys) > 0:
print(f"| missing_keys in dit: {missing_keys}")
if len(unexpected_keys) > 0:
print(f"| unexpected_keys in dit: {unexpected_keys}")
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
import re
import yaml
global_print_hparams = True
hparams = {}
class Args:
def __init__(self, **kwargs):
for k, v in kwargs.items():
self.__setattr__(k, v)
def override_config(old_config: dict, new_config: dict):
if new_config.get('__replace', False):
old_config.clear()
for k, v in new_config.items():
if isinstance(v, dict) and k in old_config:
override_config(old_config[k], new_config[k])
else:
old_config[k] = v
def traverse_dict(d, func, ctx):
for k in list(d.keys()):
v = d[k]
if isinstance(v, dict):
traverse_dict(v, func, ctx)
else:
d[k] = func(v, ctx)
def parse_config(v, context=None):
if context is None:
context = {}
if isinstance(v, str):
if v.startswith('^'):
return load_config(v[1:], [], set())
match = re.match(r"\${(.*)}", v)
if match:
expression = match.group(1)
return eval(expression, {}, context)
return v
def remove_meta_key(d):
for k in list(d.keys()):
v = d[k]
if isinstance(v, dict):
remove_meta_key(v)
else:
if k[:2] == '__':
del d[k]
def load_config(config_fn, config_chains, loaded_configs):
# deep first inheritance and avoid the second visit of one node
if not os.path.exists(config_fn):
print(f"| WARN: {config_fn} not exist.", )
return {}
with open(config_fn) as f:
hparams_ = yaml.safe_load(f)
loaded_configs.add(config_fn)
if 'base_config' in hparams_:
ret_hparams = {}
if not isinstance(hparams_['base_config'], list):
hparams_['base_config'] = [hparams_['base_config']]
for c in hparams_['base_config']:
if c.startswith('.'):
c = f'{os.path.dirname(config_fn)}/{c}'
c = os.path.normpath(c)
if c not in loaded_configs:
override_config(ret_hparams, load_config(c, config_chains, loaded_configs))
override_config(ret_hparams, hparams_)
else:
ret_hparams = hparams_
config_chains.append(config_fn)
return ret_hparams
def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, global_hparams=True):
if config == '' and exp_name == '':
parser = argparse.ArgumentParser(description='')
parser.add_argument('--config', type=str, default='',
help='location of the data corpus')
parser.add_argument('--exp_name', type=str, default='', help='exp_name')
parser.add_argument('-hp', '--hparams', type=str, default='',
help='location of the data corpus')
parser.add_argument('--infer', action='store_true', help='infer')
parser.add_argument('--validate', action='store_true', help='validate')
parser.add_argument('--reset', action='store_true', help='reset hparams')
parser.add_argument('--remove', action='store_true', help='remove old ckpt')
parser.add_argument('--debug', action='store_true', help='debug')
parser.add_argument('--start_rank', type=int, default=-1,
help='the start rank id for DDP, keep 0 when single-machine multi-GPU')
parser.add_argument('--world_size', type=int, default=-1,
help='the total number of GPU used across all machines, keep -1 for single-machine multi-GPU')
parser.add_argument('--init_method', type=str, default='tcp', help='method to init ddp, use tcp or file')
parser.add_argument('--master_addr', type=str, default='', help='')
parser.add_argument('--ddp_dir', type=str, default='', help='')
args, unknown = parser.parse_known_args()
if print_hparams:
print("| set_hparams Unknow hparams: ", unknown)
else:
args = Args(config=config, exp_name=exp_name, hparams=hparams_str,
infer=False, validate=False, reset=False, debug=False, remove=False,
start_rank=-1, world_size=-1, init_method='tcp', ddp_dir='', master_addr='')
global hparams
assert args.config != '' or args.exp_name != ''
if args.config != '':
assert os.path.exists(args.config), f"{args.config} not exists"
saved_hparams = {}
args_work_dir = ''
if args.exp_name != '':
args_work_dir = f'{args.exp_name}'
ckpt_config_path = f'{args_work_dir}/config.yaml'
if os.path.exists(ckpt_config_path):
with open(ckpt_config_path) as f:
saved_hparams_ = yaml.safe_load(f)
if saved_hparams_ is not None:
saved_hparams.update(saved_hparams_)
hparams_ = {}
config_chains = []
if args.config != '':
hparams_.update(load_config(args.config, config_chains, set()))
if len(config_chains) > 1 and print_hparams:
print('| Hparams chains: ', config_chains)
if not args.reset:
hparams_.update(saved_hparams)
traverse_dict(hparams_, parse_config, hparams_)
hparams_['work_dir'] = args_work_dir
# Support config overriding in command line. Support list type config overriding.
# Examples: --hparams="a=1,b.c=2,d=[1 1 1]"
if args.hparams != "":
for new_hparam in args.hparams.split(","):
k, v = new_hparam.split("=")
v = v.strip("\'\" ")
config_node = hparams_
for k_ in k.split(".")[:-1]:
config_node = config_node[k_]
k = k.split(".")[-1]
if k in config_node:
if v in ['True', 'False'] or type(config_node[k]) in [bool, list, dict]:
if type(config_node[k]) == list:
v = v.replace(" ", ",").replace('^', "\"")
if '|' in v:
tp = type(config_node[k][0]) if len(config_node[k]) else str
config_node[k] = [tp(x) for x in v.split("|") if x != '']
continue
config_node[k] = eval(v)
else:
config_node[k] = type(config_node[k])(v)
else:
config_node[k] = v
try:
config_node[k] = float(v)
except:
pass
try:
config_node[k] = int(v)
except:
pass
if v.lower() in ['false', 'true']:
config_node[k] = v.lower() == 'true'
if args_work_dir != '' and not args.infer:
os.makedirs(hparams_['work_dir'], exist_ok=True)
hparams_['infer'] = args.infer
hparams_['debug'] = args.debug
hparams_['validate'] = args.validate
hparams_['exp_name'] = args.exp_name
hparams_['start_rank'] = args.start_rank # useful for multi-machine training
hparams_['world_size'] = args.world_size
hparams_['init_method'] = args.init_method
hparams_['ddp_dir'] = args.ddp_dir
hparams_['master_addr'] = args.master_addr
remove_meta_key(hparams_)
global global_print_hparams
if global_hparams:
hparams.clear()
hparams.update(hparams_)
if print_hparams and global_print_hparams and global_hparams:
print('| Hparams: ', json.dumps(hparams_, indent=2, sort_keys=True))
# for i, (k, v) in enumerate(sorted(hparams_.items())):
# print(f"\033[;33;m{k}\033[0m: {v}, ", end="\n" if i % 5 == 4 else "")
global_print_hparams = False
return hparams_
\ No newline at end of file
{"phone": ["C0a", "C0ai", "C0air", "C0an", "C0ang", "C0angr", "C0anr", "C0ao", "C0aor", "C0ar", "C0b", "C0c", "C0ch", "C0d", "C0e", "C0ei", "C0eir", "C0en", "C0eng", "C0engr", "C0enr", "C0er", "C0f", "C0g", "C0h", "C0i", "C0ia", "C0ian", "C0iang", "C0iangr", "C0ianr", "C0iao", "C0iaor", "C0iar", "C0ie", "C0ier", "C0ii", "C0iii", "C0iiir", "C0iir", "C0in", "C0ing", "C0ingr", "C0inr", "C0io", "C0iong", "C0iongr", "C0iou", "C0iour", "C0ir", "C0j", "C0k", "C0l", "C0m", "C0n", "C0ng", "C0o", "C0ong", "C0ongr", "C0or", "C0ou", "C0our", "C0p", "C0q", "C0r", "C0s", "C0sh", "C0t", "C0u", "C0ua", "C0uai", "C0uair", "C0uan", "C0uang", "C0uangr", "C0uanr", "C0uar", "C0uei", "C0ueir", "C0uen", "C0ueng", "C0uengr", "C0uenr", "C0uo", "C0uor", "C0ur", "C0v", "C0van", "C0vanr", "C0ve", "C0ver", "C0vn", "C0vnr", "C0vr", "C0x", "C0z", "C0zh", "C0_", "E0aa", "E0ae", "E0ah", "E0ao", "E0aw", "E0ax", "E0ay", "E0b", "E0ch", "E0d", "E0dh", "E0eh", "E0ehr", "E0er", "E0ey", "E0f", "E0g", "E0hh", "E0ih", "E0iy", "E0iyr", "E0jh", "E0k", "E0l", "E0m", "E0n", "E0ng", "E0oh", "E0ow", "E0oy", "E0p", "E0r", "E0s", "E0sh", "E0t", "E0th", "E0uh", "E0uw", "E0uwr", "E0v", "E0w", "E0y", "E0z", "E0zh", "sil", "…", "、", "。", "《", "》", "【", "】", "!", """, "#", "$", "%", "'", "''", "(", ")", "*", ",", ":", ";", "?", "\", "^", "_", "`", "{", "}", "~"], "tone": ["0", "1", "10", "11", "12", "13", "15", "17", "2", "3", "4", "5", "6", "7", "8", "9"], "wordCategory": ["0", "B", "E", "M", "S"], "prosody": ["0", "1", "2", "3", "4"], "focus": ["0", "1"], "intonation": ["0", "1", "2"], "phraseAccent": ["0", "H-", "L-"], "boundaryTone": ["0", "H%", "L%"], "accentType": ["!H*", "0", "H*", "L*", "L*+H", "L+H*"]}
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
def map_phone_to_tokendict(item, pad_bos_eos=True):
# Merge Chinese phone and tone (Original dict ends at 173, i.e., ph_dict_size=173). 146~173 is punctuations.
phone = item['txt_token'].clone()
merged_phone = item['txt_token'].clone()
tone_tmp = item['tone'].clone()
# In tone_dict, tone_1 is 4, tone_2 is 11, tone_3 is 12, tone_4 is 13, tone_5 is 14, tone_6 is 15
tone_tmp[tone_tmp==4] = 1
tone_tmp[tone_tmp==11] = 2
tone_tmp[tone_tmp==12] = 3
tone_tmp[tone_tmp==13] = 4
tone_tmp[tone_tmp==14] = 5
tone_tmp[tone_tmp==15] = 6
# Chinese phones lie in 3~100 in the phone_dict, we map them to 200~788
ch_phone_idx = (phone >= 3) & (phone <= 100)
merged_phone[ch_phone_idx] = (merged_phone[ch_phone_idx] - 3) * 6 + 200 + tone_tmp[ch_phone_idx]
if pad_bos_eos:
merged_phone = F.pad(merged_phone, (1, 0), mode='constant', value=798)
merged_phone = F.pad(merged_phone, (0, 1), mode='constant', value=799)
return merged_phone
def split_ph_timestamp(ph_timestamp):
''' Input: ph_timestamp, shape [T] '''
# Map the timestamp of each phone back to its original frame-level lengths
ph_timestamp[ph_timestamp >= 800] -= 800
ph_list = []
tone_list = []
dur_list = []
cur_timestamp = 0
for idx, item in enumerate(ph_timestamp):
if idx % 2 == 0:
# Map Chinese phones back to its original phone_dict
if (200 <= item <= 788):
ph = (item - 200 - 1) // 6 + 3
tone = (item - 200 - 1) % 6 + 1
if tone == 1:
tone = 4
else:
tone = tone + 9
# Set English tone to '3'
else:
ph = item
tone = 3
ph_list.append(ph)
tone_list.append(tone)
else:
dur_list.append((item - cur_timestamp))
cur_timestamp = item
assert len(ph_list) == len(dur_list), f"{len(ph_list)}, {len(dur_list)}"
ph_seq, tone_seq, dur_seq = torch.LongTensor(ph_list), torch.LongTensor(tone_list), torch.LongTensor(dur_list)
return ph_seq, tone_seq, dur_seq, ph_timestamp[-1]
def split_ph(ph_seq):
''' Input: ph_timestamp, shape [T] '''
ph_list = []
tone_list = []
for idx, item in enumerate(ph_seq):
# Map Chinese phones back to its original phone_dict
if (200 <= item <= 788):
ph = (item - 200 - 1) // 6 + 3
tone = (item - 200 - 1) % 6 + 1
if tone == 1:
tone = 4
else:
tone = tone + 9
# Set English tone to '3'
else:
ph = item
tone = 3
ph_list.append(ph)
tone_list.append(tone)
assert len(ph_list) == len(tone_list)
ph_seq, tone_seq = torch.LongTensor(ph_list), torch.LongTensor(tone_list)
return ph_seq, tone_seq
\ No newline at end of file
# -*- coding: utf-8 -*-
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
def chunk_text_chinese(text, limit=60):
# 中文字符匹配
chinese_pattern = re.compile(r'[\u4e00-\u9fff]')
# 标点符号匹配
punctuation = ",。!?;:,\.!?;"
result = [] # 存储断句结果
current_chunk = [] # 当前片段
chinese_count = 0 # 中文字符计数
i = 0
while i < len(text):
char = text[i]
current_chunk.append(char)
if chinese_pattern.match(char):
chinese_count += 1
if chinese_count >= limit: # 达到限制字符数
# 从当前位置往前找最近的标点符号
for j in range(len(current_chunk) - 1, -1, -1):
if current_chunk[j] in punctuation:
result.append(''.join(current_chunk[:j + 1]))
current_chunk = current_chunk[j + 1:]
chinese_count = sum(1 for c in current_chunk if chinese_pattern.match(c))
break
else:
# 如果前面没有标点符号,则继续找后面的标点符号
for k in range(i + 1, len(text)):
if text[k] in punctuation:
result.append(''.join(current_chunk)+text[i+1:k+1])
current_chunk = []
chinese_count = 0
i = k
break
i+=1
# 添加最后剩余的部分
if current_chunk:
result.append(''.join(current_chunk))
return result
def chunk_text_english(text, max_chars=130):
"""
Splits the input text into chunks, each with a maximum number of characters.
Args:
text (str): The text to be split.
max_chars (int): The maximum number of characters per chunk.
Returns:
List[str]: A list of text chunks.
"""
chunks = []
current_chunk = ""
# Split the text into sentences based on punctuation followed by whitespace
sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])", text)
for sentence in sentences:
if len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars:
current_chunk += sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
else:
if current_chunk:
chunks.append(current_chunk.strip())
current_chunk = sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
if current_chunk:
chunks.append(current_chunk.strip())
return chunks
def chunk_text_chinesev2(text, limit=60, look_ahead_limit=30):
"""
将中文文本分成多个块,优先确保每个块以句号、感叹号或问号结尾,
其次考虑逗号等其他标点符号,避免在无标点处断句
参数:
text: 要分块的文本
limit: 每个块的中文字符数限制
look_ahead_limit: 向后查找的最大字符数限制
返回:
分块后的文本列表
"""
# 中文字符匹配
chinese_pattern = re.compile(r'[\u4e00-\u9fff]')
# 分级定义标点符号(优先级从高到低)
primary_end_marks = "。.!!??" # 首选:句号、感叹号、问号
secondary_end_marks = ",,;;:" # 次选:逗号、分号、冒号
tertiary_end_marks = "、…—-~~" # 再次:顿号、省略号、破折号等
result = [] # 存储断句结果
current_chunk = [] # 当前片段
chinese_count = 0 # 中文字符计数
i = 0
while i < len(text):
char = text[i]
current_chunk.append(char)
if chinese_pattern.match(char):
chinese_count += 1
if chinese_count >= limit: # 达到字符数限制,需要寻找断句点
found_end = False
# 依次尝试不同优先级的断句策略
# 1. 向后查找首选标点
for k in range(1, min(look_ahead_limit, len(text) - i)):
next_char = text[i + k]
if next_char in primary_end_marks:
result.append(''.join(current_chunk) + text[i+1:i+k+1])
current_chunk = []
chinese_count = 0
i = i + k
found_end = True
break
if not found_end:
# 2. 向前查找首选标点
for j in range(len(current_chunk) - 1, -1, -1):
if current_chunk[j] in primary_end_marks:
result.append(''.join(current_chunk[:j + 1]))
current_chunk = current_chunk[j + 1:]
chinese_count = sum(1 for c in current_chunk if chinese_pattern.match(c))
found_end = True
break
if not found_end:
# 3. 向后查找次选标点
for k in range(1, min(look_ahead_limit, len(text) - i)):
next_char = text[i + k]
if next_char in secondary_end_marks:
result.append(''.join(current_chunk) + text[i+1:i+k+1])
current_chunk = []
chinese_count = 0
i = i + k
found_end = True
break
if not found_end:
# 4. 向前查找次选标点
for j in range(len(current_chunk) - 1, -1, -1):
if current_chunk[j] in secondary_end_marks:
result.append(''.join(current_chunk[:j + 1]))
current_chunk = current_chunk[j + 1:]
chinese_count = sum(1 for c in current_chunk if chinese_pattern.match(c))
found_end = True
break
if not found_end:
# 5. 向后查找三级标点
for k in range(1, min(look_ahead_limit, len(text) - i)):
next_char = text[i + k]
if next_char in tertiary_end_marks:
result.append(''.join(current_chunk) + text[i+1:i+k+1])
current_chunk = []
chinese_count = 0
i = i + k
found_end = True
break
if not found_end:
# 6. 向前查找三级标点
for j in range(len(current_chunk) - 1, -1, -1):
if current_chunk[j] in tertiary_end_marks:
result.append(''.join(current_chunk[:j + 1]))
current_chunk = current_chunk[j + 1:]
chinese_count = sum(1 for c in current_chunk if chinese_pattern.match(c))
found_end = True
break
if not found_end:
# 万不得已,在此处断句(这种情况很少见,因为汉语文本中通常会有标点)
result.append(''.join(current_chunk))
current_chunk = []
chinese_count = 0
i += 1
# 添加最后剩余的部分
if current_chunk:
result.append(''.join(current_chunk))
return result
if __name__ == '__main__':
print(chunk_text_chinese("哇塞!家人们,你们太好运了。我居然发现了一个宝藏零食大礼包,简直适合所有人的口味!有香辣的,让你舌尖跳舞;有盐焗的,咸香可口;还有五香的,香气四溢。就连怀孕的姐妹都吃得津津有味!整整三十包啊!什么手撕蟹柳、辣子鸡、嫩豆干、手撕素肉、鹌鹑蛋、小肉枣肠、猪肉腐、魔芋、魔芋丝等等,应有尽有。香辣土豆爽辣过瘾,各种素肉嚼劲十足,鹌鹑蛋营养美味,真的太多太多啦,...家人们,现在价格太划算了,赶紧下单。"))
print(chunk_text_english("Washington CNN When President Donald Trump declared in the House Chamber this week that executives at the nation’s top automakers were “so excited” about their prospects amid his new tariff regime, it did not entirely reflect the conversation he’d held with them earlier that day."))
text = "欢迎收听《TED Talks Daily》,在这里,我们每天为您带来新思想,激发您的好奇心。我是您的主持人,Elise Hugh。当我们去看医生时,医生会评估我们的身体健康状况,检查我们的生命体征,可能还会关注我们的胆固醇水平,确保我们整体处于健康状态。医生可能还会通过一系列问题来检查我们的心理健康。然而,人际交往专家Casley Killam指出,我们在理解健康时忽略了一个关键指标,那就是我们的社会健康。在2024年的演讲中,她解释了为什么人际关系如此重要,以及忽视它可能带来的代价。几年前,我认识的一位女士,我们暂且称她为Maya,在短时间内经历了许多重大变化。她结婚了,和丈夫因工作搬到了一个陌生的城市,在那里她谁也不认识。她开始了一份在家办公的新工作,同时还要应对父亲新确诊的痴呆症。为了应对这些变化带来的压力,Maya加倍关注自己的身心健康。她几乎每天都锻炼,吃健康的食物,每周去看一次心理医生。这些措施确实有帮助,她的身体变得更加强壮,心理也更具韧性,但效果有限。她仍然感到困扰,经常在半夜失眠,白天感到注意力不集中,缺乏动力。Maya做了医生通常建议我们做的所有事情来保持身心健康,但似乎还缺少些什么。如果我告诉你,Maya所缺少的东西,也是全球数十亿人所缺少的,甚至可能也是你所缺少的呢?如果我告诉你,缺乏它会削弱我们为保持健康所做的其他努力,甚至可能缩短你的寿命呢?我研究这个问题已经超过十年,我发现,我们传统上对健康的理解是不完整的。通过将健康主要视为身体和心理的健康,我们忽略了我认为是我们这个时代最大的挑战和机遇——社会健康。身体健康关乎我们的身体,心理健康关乎我们的思想,而社会健康则关乎我们的人际关系。如果你以前没有听说过这个词,那是因为它还没有进入主流词汇,但它同样重要。Maya在她的新家还没有归属感。她不再亲自见到她的家人、朋友或同事,她经常一连几周只和丈夫共度时光。她的故事告诉我们,如果我们只照顾身体和心理,而不关注人际关系,我们就无法完全健康,无法真正茁壮成长。与Maya类似,全球有数亿人连续几周不与任何朋友或家人交谈。全球范围内,有四分之一的人感到孤独。20%的成年人觉得他们没有任何人可以求助。想想看,你遇到的每五个人中,可能有一个人觉得自己孤立无援。这不仅令人心碎,也是一场公共卫生危机。"
for res in chunk_text_chinesev2(text):
print(res)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment