Commit ab9c00af authored by yangzhong

init submission
# Copyright (c) 2019 Shigeki Karita
# 2020 Mobvoi Inc (Binbin Zhang)
# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multi-Head Attention layer definition."""
import math
from typing import Tuple
import torch
from torch import nn
class MultiHeadedAttention(nn.Module):
"""Multi-Head Attention layer.
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
"""
def __init__(self, n_head: int, n_feat: int, dropout_rate: float):
"""Construct an MultiHeadedAttention object."""
super().__init__()
assert n_feat % n_head == 0
# We assume d_v always equals d_k
self.d_k = n_feat // n_head
self.h = n_head
self.linear_q = nn.Linear(n_feat, n_feat)
self.linear_k = nn.Linear(n_feat, n_feat)
self.linear_v = nn.Linear(n_feat, n_feat)
self.linear_out = nn.Linear(n_feat, n_feat)
self.dropout = nn.Dropout(p=dropout_rate)
def forward_qkv(
self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Transform query, key and value.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
Returns:
torch.Tensor: Transformed query tensor, size
(#batch, n_head, time1, d_k).
torch.Tensor: Transformed key tensor, size
(#batch, n_head, time2, d_k).
torch.Tensor: Transformed value tensor, size
(#batch, n_head, time2, d_k).
"""
n_batch = query.size(0)
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
q = q.transpose(1, 2) # (batch, head, time1, d_k)
k = k.transpose(1, 2) # (batch, head, time2, d_k)
v = v.transpose(1, 2) # (batch, head, time2, d_k)
return q, k, v
def forward_attention(
self, value: torch.Tensor, scores: torch.Tensor,
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
) -> torch.Tensor:
"""Compute attention context vector.
Args:
value (torch.Tensor): Transformed value, size
(#batch, n_head, time2, d_k).
scores (torch.Tensor): Attention score, size
(#batch, n_head, time1, time2).
mask (torch.Tensor): Mask, size (#batch, 1, time2) or
(#batch, time1, time2), (0, 0, 0) means fake mask.
Returns:
torch.Tensor: Transformed value (#batch, time1, d_model)
weighted by the attention score (#batch, time1, time2).
"""
n_batch = value.size(0)
# NOTE(xcsong): When will `if mask.size(2) > 0` be True?
# 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
# 1st chunk to ease the onnx export.]
# 2. pytorch training
if mask.size(2) > 0:  # time2 > 0
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
# For last chunk, time2 might be larger than scores.size(-1)
mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
scores = scores.masked_fill(mask, -float('inf'))
attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0) # (batch, head, time1, time2)
# NOTE(xcsong): When will `if mask.size(2) > 0` be False?
# 1. onnx(16/-1, -1/-1, 16/0)
# 2. jit (16/-1, -1/-1, 16/0, 16/4)
else:
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
p_attn = self.dropout(attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
self.h * self.d_k)
) # (batch, time1, d_model)
return self.linear_out(x) # (batch, time1, d_model)
def forward(self, query: torch.Tensor, key: torch.Tensor,
value: torch.Tensor,
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
pos_emb: torch.Tensor = torch.empty(0),
cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute scaled dot product attention.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2).
1.When applying cross attention between decoder and encoder,
the batch padding mask for input is in (#batch, 1, T) shape.
2.When applying self attention of encoder,
the mask is in (#batch, T, T) shape.
3.When applying self attention of decoder,
the mask is in (#batch, L, L) shape.
4.If the different position in decoder see different block
of the encoder, such as Mocha, the passed in mask could be
in (#batch, L, T) shape. But there is no such case in current
Wenet.
cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
Returns:
torch.Tensor: Output tensor (#batch, time1, d_model).
torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
"""
q, k, v = self.forward_qkv(query, key, value)
# NOTE(xcsong):
# when exporting an onnx model, for the 1st chunk, we feed
# cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
# or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
# In all modes, `if cache.size(0) > 0` will always be `True`
# and we will always do splitting and
# concatenation (this simplifies onnx export). Note that
# it's OK to concat & split zero-shaped tensors (see code below).
# when exporting a jit model, for the 1st chunk, we always feed
# cache(0, 0, 0, 0) since jit supports dynamic if-branches.
# >>> a = torch.ones((1, 2, 0, 4))
# >>> b = torch.ones((1, 2, 3, 4))
# >>> c = torch.cat((a, b), dim=2)
# >>> torch.equal(b, c) # True
# >>> d = torch.split(a, 2, dim=-1)
# >>> torch.equal(d[0], d[1]) # True
if cache.size(0) > 0:
key_cache, value_cache = torch.split(
cache, cache.size(-1) // 2, dim=-1)
k = torch.cat([key_cache, k], dim=2)
v = torch.cat([value_cache, v], dim=2)
# NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
# non-trivial to calculate `next_cache_start` here.
new_cache = torch.cat((k, v), dim=-1)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask), new_cache
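# Usage sketch (illustrative values only; assumes n_feat=256, n_head=4):
# >>> mha = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.0)
# >>> x = torch.randn(2, 10, 256)
# >>> mask = torch.ones(2, 1, 10, dtype=torch.bool)
# >>> out, cache = mha(x, x, x, mask)
# >>> out.shape # torch.Size([2, 10, 256])
# >>> cache.shape # torch.Size([2, 4, 10, 128]), i.e. (batch, head, time2, d_k * 2)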
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
"""Multi-Head Attention layer with relative position encoding.
Paper: https://arxiv.org/abs/1901.02860
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
"""
def __init__(self, n_head, n_feat, dropout_rate):
"""Construct an RelPositionMultiHeadedAttention object."""
super().__init__(n_head, n_feat, dropout_rate)
# linear transformation for positional encoding
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
# these two learnable bias are used in matrix c and matrix d
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
torch.nn.init.xavier_uniform_(self.pos_bias_u)
torch.nn.init.xavier_uniform_(self.pos_bias_v)
def rel_shift(self, x, zero_triu: bool = False):
"""Compute relative positinal encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, size).
zero_triu (bool): If true, return the lower triangular part of
the matrix.
Returns:
torch.Tensor: Output tensor.
"""
zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
device=x.device,
dtype=x.dtype)
x_padded = torch.cat([zero_pad, x], dim=-1)
x_padded = x_padded.view(x.size()[0],
x.size()[1],
x.size(3) + 1, x.size(2))
x = x_padded[:, :, 1:].view_as(x)
if zero_triu:
ones = torch.ones((x.size(2), x.size(3)))
x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
return x
def forward(self, query: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
pos_emb: torch.Tensor = torch.empty(0),
cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2), (0, 0, 0) means fake mask.
pos_emb (torch.Tensor): Positional embedding tensor
(#batch, time2, size).
cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
Returns:
torch.Tensor: Output tensor (#batch, time1, d_model).
torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
"""
q, k, v = self.forward_qkv(query, key, value)
q = q.transpose(1, 2) # (batch, time1, head, d_k)
# NOTE(xcsong):
# when exporting an onnx model, for the 1st chunk, we feed
# cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
# or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
# In all modes, `if cache.size(0) > 0` will always be `True`
# and we will always do splitting and
# concatenation (this simplifies onnx export). Note that
# it's OK to concat & split zero-shaped tensors (see code below).
# when exporting a jit model, for the 1st chunk, we always feed
# cache(0, 0, 0, 0) since jit supports dynamic if-branches.
# >>> a = torch.ones((1, 2, 0, 4))
# >>> b = torch.ones((1, 2, 3, 4))
# >>> c = torch.cat((a, b), dim=2)
# >>> torch.equal(b, c) # True
# >>> d = torch.split(a, 2, dim=-1)
# >>> torch.equal(d[0], d[1]) # True
if cache.size(0) > 0:
key_cache, value_cache = torch.split(
cache, cache.size(-1) // 2, dim=-1)
k = torch.cat([key_cache, k], dim=2)
v = torch.cat([value_cache, v], dim=2)
# NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
# non-trivial to calculate `next_cache_start` here.
new_cache = torch.cat((k, v), dim=-1)
n_batch_pos = pos_emb.size(0)
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
p = p.transpose(1, 2) # (batch, head, time1, d_k)
# (batch, head, time1, d_k)
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
# (batch, head, time1, d_k)
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
# compute attention score
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
# compute matrix b and matrix d
# (batch, head, time1, time2)
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
# Remove rel_shift since it is useless in speech recognition,
# and it requires special attention for streaming.
# matrix_bd = self.rel_shift(matrix_bd)
scores = (matrix_ac + matrix_bd) / math.sqrt(
self.d_k) # (batch, head, time1, time2)
return self.forward_attention(v, scores, mask), new_cache
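# Usage sketch (illustrative shapes; pos_emb is the tensor produced by
# RelPositionalEncoding, here faked with random values):
# >>> attn = RelPositionMultiHeadedAttention(4, 256, 0.0)
# >>> x = torch.randn(2, 10, 256)
# >>> pos_emb = torch.randn(1, 10, 256)
# >>> mask = torch.ones(2, 1, 10, dtype=torch.bool)
# >>> out, cache = attn(x, x, x, mask, pos_emb)
# >>> out.shape # torch.Size([2, 10, 256])
# >>> cache.shape # torch.Size([2, 4, 10, 128])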
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Positonal Encoding Module."""
import math
from typing import Tuple, Union
import torch
import torch.nn.functional as F
class PositionalEncoding(torch.nn.Module):
"""Positional encoding.
:param int d_model: embedding dim
:param float dropout_rate: dropout rate
:param int max_len: maximum input length
PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
"""
def __init__(self,
d_model: int,
dropout_rate: float,
max_len: int = 5000,
reverse: bool = False):
"""Construct an PositionalEncoding object."""
super().__init__()
self.d_model = d_model
self.xscale = math.sqrt(self.d_model)
self.dropout = torch.nn.Dropout(p=dropout_rate)
self.max_len = max_len
pe = torch.zeros(self.max_len, self.d_model)
position = torch.arange(0, self.max_len).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2) *
-(math.log(10000.0) / self.d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self,
x: torch.Tensor,
offset: Union[int, torch.Tensor] = 0) \
-> Tuple[torch.Tensor, torch.Tensor]:
"""Add positional encoding.
Args:
x (torch.Tensor): Input. Its shape is (batch, time, ...)
offset (int, torch.tensor): position offset
Returns:
torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
torch.Tensor: for compatibility to RelPositionalEncoding
"""
self.pe = self.pe.to(x.device)
pos_emb = self.position_encoding(offset, x.size(1), False)
x = x * self.xscale + pos_emb
return self.dropout(x), self.dropout(pos_emb)
def position_encoding(self, offset: Union[int, torch.Tensor], size: int,
apply_dropout: bool = True) -> torch.Tensor:
""" For getting encoding in a streaming fashion
Attention!!!!!
we apply dropout only once at the whole utterance level in a none
streaming way, but will call this function several times with
increasing input size in a streaming scenario, so the dropout will
be applied several times.
Args:
offset (int or torch.tensor): start offset
size (int): required size of position encoding
Returns:
torch.Tensor: Corresponding encoding
"""
# How to subscript a Union type:
# https://github.com/pytorch/pytorch/issues/69434
if isinstance(offset, int):
assert offset + size < self.max_len
pos_emb = self.pe[:, offset:offset + size]
elif isinstance(offset, torch.Tensor) and offset.dim() == 0: # scalar
assert offset + size < self.max_len
pos_emb = self.pe[:, offset:offset + size]
else: # for batched streaming decoding on GPU
assert torch.max(offset) + size < self.max_len
index = offset.unsqueeze(1) + \
torch.arange(0, size).to(offset.device) # B X T
flag = index > 0
# remove negative offset
index = index * flag
pos_emb = F.embedding(index, self.pe[0]) # B X T X d_model
if apply_dropout:
pos_emb = self.dropout(pos_emb)
return pos_emb
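# Usage sketch (illustrative, d_model=256): non-streaming forward vs. a
# streaming lookup at a later offset.
# >>> pe = PositionalEncoding(d_model=256, dropout_rate=0.0)
# >>> y, pos_emb = pe(torch.randn(2, 10, 256))
# >>> y.shape, pos_emb.shape # (torch.Size([2, 10, 256]), torch.Size([1, 10, 256]))
# >>> pe.position_encoding(offset=16, size=4, apply_dropout=False).shape # torch.Size([1, 4, 256])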
class RelPositionalEncoding(PositionalEncoding):
"""Relative positional encoding module.
See : Appendix B in https://arxiv.org/abs/1901.02860
Args:
d_model (int): Embedding dimension.
dropout_rate (float): Dropout rate.
max_len (int): Maximum input length.
"""
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
"""Initialize class."""
super().__init__(d_model, dropout_rate, max_len, reverse=True)
def forward(self,
x: torch.Tensor,
offset: Union[int, torch.Tensor] = 0) \
-> Tuple[torch.Tensor, torch.Tensor]:
"""Compute positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
torch.Tensor: Positional embedding tensor (1, time, `*`).
"""
self.pe = self.pe.to(x.device)
x = x * self.xscale
pos_emb = self.position_encoding(offset, x.size(1), False)
return self.dropout(x), self.dropout(pos_emb)
class NoPositionalEncoding(torch.nn.Module):
""" No position encoding
"""
def __init__(self, d_model: int, dropout_rate: float):
super().__init__()
self.d_model = d_model
self.dropout = torch.nn.Dropout(p=dropout_rate)
def forward(self,
x: torch.Tensor,
offset: Union[int, torch.Tensor] = 0) \
-> Tuple[torch.Tensor, torch.Tensor]:
""" Just return zero vector for interface compatibility
"""
pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device)
return self.dropout(x), pos_emb
def position_encoding(
self, offset: Union[int, torch.Tensor], size: int) -> torch.Tensor:
return torch.zeros(1, size, self.d_model)
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Subsampling layer definition."""
from typing import Tuple, Union
import torch
class BaseSubsampling(torch.nn.Module):
def __init__(self):
super().__init__()
self.right_context = 0
self.subsampling_rate = 1
def position_encoding(self, offset: Union[int, torch.Tensor],
size: int) -> torch.Tensor:
return self.pos_enc.position_encoding(offset, size)
class LinearNoSubsampling(BaseSubsampling):
"""Linear transform the input without subsampling
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
"""
def __init__(self, idim: int, odim: int, dropout_rate: float,
pos_enc_class: torch.nn.Module):
"""Construct an linear object."""
super().__init__()
self.out = torch.nn.Sequential(
torch.nn.Linear(idim, odim),
torch.nn.LayerNorm(odim, eps=1e-5),
torch.nn.Dropout(dropout_rate),
)
self.pos_enc = pos_enc_class
self.right_context = 0
self.subsampling_rate = 1
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
offset: Union[int, torch.Tensor] = 0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Input x.
Args:
x (torch.Tensor): Input tensor (#batch, time, idim).
x_mask (torch.Tensor): Input mask (#batch, 1, time).
Returns:
torch.Tensor: linear input tensor (#batch, time', odim),
where time' = time .
torch.Tensor: linear input mask (#batch, 1, time'),
where time' = time .
"""
x = self.out(x)
x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask
class Conv2dSubsampling3(BaseSubsampling):
"""Convolutional 2D subsampling (to 1/3 length).
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
"""
def __init__(self, idim: int, odim: int, dropout_rate: float,
pos_enc_class: torch.nn.Module):
"""Construct an Conv2dSubsampling3 object."""
super().__init__()
self.conv = torch.nn.Sequential(
torch.nn.Conv2d(1, odim, 5, 3),
torch.nn.ReLU()
)
self.out = torch.nn.Sequential(
torch.nn.Linear(odim * ((idim - 2) // 3), odim))
self.pos_enc = pos_enc_class
# The right context for every conv layer is computed by:
# (kernel_size - 1) * frame_rate_of_this_layer
self.subsampling_rate = 3
# 4 = (5 - 1) * 1
self.right_context = 4
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
offset: Union[int, torch.Tensor] = 0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Subsample x.
Args:
x (torch.Tensor): Input tensor (#batch, time, idim).
x_mask (torch.Tensor): Input mask (#batch, 1, time).
Returns:
torch.Tensor: Subsampled tensor (#batch, time', odim),
where time' = time // 3.
torch.Tensor: Subsampled mask (#batch, 1, time'),
where time' = time // 3.
torch.Tensor: positional encoding
"""
x = x.unsqueeze(1) # (b, c=1, t, f)
x = self.conv(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
x, pos_emb = self.pos_enc(x, offset)
# NOTE: the stride-3 conv with kernel 5 keeps floor((time - 5) / 3) + 1
# frames, so subsample the mask with the same pattern (cf. the second
# conv layer of Conv2dSubsampling6 below).
return x, pos_emb, x_mask[:, :, 4::3]
class Conv2dSubsampling2(BaseSubsampling):
"""Convolutional 2D subsampling (to 1/2 length).
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
"""
def __init__(self, idim: int, odim: int, dropout_rate: float,
pos_enc_class: torch.nn.Module):
"""Construct an Conv2dSubsampling4 object."""
super().__init__()
self.conv = torch.nn.Sequential(
torch.nn.Conv2d(1, odim, 3, 2),
torch.nn.ReLU(),
)
self.out = torch.nn.Sequential(
torch.nn.Linear(odim * ((idim - 1) // 2), odim))
self.pos_enc = pos_enc_class
# The right context for every conv layer is computed by:
# (kernel_size - 1) * frame_rate_of_this_layer
self.subsampling_rate = 2
# 2 = (3 - 1) * 1
self.right_context = 2
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
offset: Union[int, torch.Tensor] = 0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Subsample x.
Args:
x (torch.Tensor): Input tensor (#batch, time, idim).
x_mask (torch.Tensor): Input mask (#batch, 1, time).
Returns:
torch.Tensor: Subsampled tensor (#batch, time', odim),
where time' = time // 2.
torch.Tensor: Subsampled mask (#batch, 1, time'),
where time' = time // 2.
torch.Tensor: positional encoding
"""
x = x.unsqueeze(1) # (b, c=1, t, f)
x = self.conv(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask[:, :, 2::2]
class Conv2dSubsampling4(BaseSubsampling):
"""Convolutional 2D subsampling (to 1/4 length).
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
"""
def __init__(self, idim: int, odim: int, dropout_rate: float,
pos_enc_class: torch.nn.Module):
"""Construct an Conv2dSubsampling4 object."""
super().__init__()
self.conv = torch.nn.Sequential(
torch.nn.Conv2d(1, odim, 3, 2),
torch.nn.ReLU(),
torch.nn.Conv2d(odim, odim, 3, 2),
torch.nn.ReLU(),
)
self.out = torch.nn.Sequential(
torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim))
self.pos_enc = pos_enc_class
# The right context for every conv layer is computed by:
# (kernel_size - 1) * frame_rate_of_this_layer
self.subsampling_rate = 4
# 6 = (3 - 1) * 1 + (3 - 1) * 2
self.right_context = 6
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
offset: Union[int, torch.Tensor] = 0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Subsample x.
Args:
x (torch.Tensor): Input tensor (#batch, time, idim).
x_mask (torch.Tensor): Input mask (#batch, 1, time).
Returns:
torch.Tensor: Subsampled tensor (#batch, time', odim),
where time' = time // 4.
torch.Tensor: Subsampled mask (#batch, 1, time'),
where time' = time // 4.
torch.Tensor: positional encoding
"""
x = x.unsqueeze(1) # (b, c=1, t, f)
x = self.conv(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2]
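# Usage sketch (illustrative, idim=80, odim=256; assumes a PositionalEncoding
# instance from the embedding module above): 4x time reduction.
# >>> subsample = Conv2dSubsampling4(80, 256, 0.0, PositionalEncoding(256, 0.0))
# >>> feats = torch.randn(2, 100, 80)
# >>> feats_mask = torch.ones(2, 1, 100, dtype=torch.bool)
# >>> out, pos_emb, out_mask = subsample(feats, feats_mask)
# >>> out.shape # torch.Size([2, 24, 256])
# >>> out_mask.shape # torch.Size([2, 1, 24])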
class Conv2dSubsampling6(BaseSubsampling):
"""Convolutional 2D subsampling (to 1/6 length).
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
pos_enc (torch.nn.Module): Custom position encoding layer.
"""
def __init__(self, idim: int, odim: int, dropout_rate: float,
pos_enc_class: torch.nn.Module):
"""Construct an Conv2dSubsampling6 object."""
super().__init__()
self.conv = torch.nn.Sequential(
torch.nn.Conv2d(1, odim, 3, 2),
torch.nn.ReLU(),
torch.nn.Conv2d(odim, odim, 5, 3),
torch.nn.ReLU(),
)
self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3),
odim)
self.pos_enc = pos_enc_class
# 10 = (3 - 1) * 1 + (5 - 1) * 2
self.subsampling_rate = 6
self.right_context = 10
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
offset: Union[int, torch.Tensor] = 0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Subsample x.
Args:
x (torch.Tensor): Input tensor (#batch, time, idim).
x_mask (torch.Tensor): Input mask (#batch, 1, time).
Returns:
torch.Tensor: Subsampled tensor (#batch, time', odim),
where time' = time // 6.
torch.Tensor: Subsampled mask (#batch, 1, time'),
where time' = time // 6.
torch.Tensor: positional encoding
"""
x = x.unsqueeze(1) # (b, c, t, f)
x = self.conv(x)
b, c, t, f = x.size()
x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask[:, :, 2::2][:, :, 4::3]
class Conv2dSubsampling8(BaseSubsampling):
"""Convolutional 2D subsampling (to 1/8 length).
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
"""
def __init__(self, idim: int, odim: int, dropout_rate: float,
pos_enc_class: torch.nn.Module):
"""Construct an Conv2dSubsampling8 object."""
super().__init__()
self.conv = torch.nn.Sequential(
torch.nn.Conv2d(1, odim, 3, 2),
torch.nn.ReLU(),
torch.nn.Conv2d(odim, odim, 3, 2),
torch.nn.ReLU(),
torch.nn.Conv2d(odim, odim, 3, 2),
torch.nn.ReLU(),
)
self.linear = torch.nn.Linear(
odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim)
self.pos_enc = pos_enc_class
self.subsampling_rate = 8
# 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4
self.right_context = 14
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
offset: Union[int, torch.Tensor] = 0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Subsample x.
Args:
x (torch.Tensor): Input tensor (#batch, time, idim).
x_mask (torch.Tensor): Input mask (#batch, 1, time).
Returns:
torch.Tensor: Subsampled tensor (#batch, time', odim),
where time' = time // 8.
torch.Tensor: Subsampled mask (#batch, 1, time'),
where time' = time // 8.
torch.Tensor: positional encoding
"""
x = x.unsqueeze(1) # (b, c, t, f)
x = self.conv(x)
b, c, t, f = x.size()
x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2]
from typing import Optional, Tuple
import torch
import torch.nn as nn
from indextts.gpt.conformer.attention import (MultiHeadedAttention,
RelPositionMultiHeadedAttention)
from indextts.gpt.conformer.embedding import (NoPositionalEncoding,
PositionalEncoding,
RelPositionalEncoding)
from indextts.gpt.conformer.subsampling import (Conv2dSubsampling2,
Conv2dSubsampling4,
Conv2dSubsampling6,
Conv2dSubsampling8,
LinearNoSubsampling)
from indextts.utils.common import make_pad_mask
class PositionwiseFeedForward(torch.nn.Module):
"""Positionwise feed forward layer.
FeedForward is applied on each position of the sequence.
The output dim is the same as the input dim.
Args:
idim (int): Input dimension.
hidden_units (int): The number of hidden units.
dropout_rate (float): Dropout rate.
activation (torch.nn.Module): Activation function
"""
def __init__(self,
idim: int,
hidden_units: int,
dropout_rate: float,
activation: torch.nn.Module = torch.nn.ReLU()):
"""Construct a PositionwiseFeedForward object."""
super(PositionwiseFeedForward, self).__init__()
self.w_1 = torch.nn.Linear(idim, hidden_units)
self.activation = activation
self.dropout = torch.nn.Dropout(dropout_rate)
self.w_2 = torch.nn.Linear(hidden_units, idim)
def forward(self, xs: torch.Tensor) -> torch.Tensor:
"""Forward function.
Args:
xs: input tensor (B, L, D)
Returns:
output tensor, (B, L, D)
"""
return self.w_2(self.dropout(self.activation(self.w_1(xs))))
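# Usage sketch (illustrative dims): the output keeps the input shape.
# >>> ffn = PositionwiseFeedForward(idim=256, hidden_units=1024, dropout_rate=0.0)
# >>> ffn(torch.randn(2, 10, 256)).shape # torch.Size([2, 10, 256])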
class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model."""
def __init__(self,
channels: int,
kernel_size: int = 15,
activation: nn.Module = nn.ReLU(),
bias: bool = True):
"""Construct an ConvolutionModule object.
Args:
channels (int): The number of channels of conv layers.
kernel_size (int): Kernel size of conv layers.
causal (int): Whether use causal convolution or not
"""
super().__init__()
self.pointwise_conv1 = nn.Conv1d(
channels,
2 * channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
# self.lorder is used to distinguish whether it's a causal convolution:
# if self.lorder > 0, it's a causal convolution and the input will be
# padded with self.lorder frames on the left in forward;
# otherwise it's a symmetrical convolution.
# kernel_size should be an odd number for non-causal convolution.
assert (kernel_size - 1) % 2 == 0
padding = (kernel_size - 1) // 2
self.lorder = 0
self.depthwise_conv = nn.Conv1d(
channels,
channels,
kernel_size,
stride=1,
padding=padding,
groups=channels,
bias=bias,
)
self.use_layer_norm = True
self.norm = nn.LayerNorm(channels)
self.pointwise_conv2 = nn.Conv1d(
channels,
channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
self.activation = activation
def forward(
self,
x: torch.Tensor,
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
cache: torch.Tensor = torch.zeros((0, 0, 0)),
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute convolution module.
Args:
x (torch.Tensor): Input tensor (#batch, time, channels).
mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
(0, 0, 0) means fake mask.
cache (torch.Tensor): left context cache, it is only
used in causal convolution (#batch, channels, cache_t),
(0, 0, 0) means fake cache.
Returns:
torch.Tensor: Output tensor (#batch, time, channels).
"""
# exchange the temporal dimension and the feature dimension
x = x.transpose(1, 2) # (#batch, channels, time)
# mask batch padding
if mask_pad.size(2) > 0: # time > 0
x.masked_fill_(~mask_pad, 0.0)
if self.lorder > 0:
if cache.size(2) == 0: # cache_t == 0
x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
else:
assert cache.size(0) == x.size(0) # equal batch
assert cache.size(1) == x.size(1) # equal channel
x = torch.cat((cache, x), dim=2)
assert (x.size(2) > self.lorder)
new_cache = x[:, :, -self.lorder:]
else:
# It's better we just return None if no cache is required,
# However, for JIT export, here we just fake one tensor instead of
# None.
new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
# GLU mechanism
x = self.pointwise_conv1(x) # (batch, 2*channel, time)
x = nn.functional.glu(x, dim=1) # (batch, channel, time)
# 1D Depthwise Conv
x = self.depthwise_conv(x)
if self.use_layer_norm:
x = x.transpose(1, 2)
x = self.activation(self.norm(x))
if self.use_layer_norm:
x = x.transpose(1, 2)
x = self.pointwise_conv2(x)
# mask batch padding
if mask_pad.size(2) > 0: # time > 0
x.masked_fill_(~mask_pad, 0.0)
return x.transpose(1, 2), new_cache
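# Usage sketch (illustrative; default non-causal module, kernel_size=15):
# >>> conv = ConvolutionModule(channels=256, kernel_size=15)
# >>> y, new_cache = conv(torch.randn(2, 10, 256))
# >>> y.shape # torch.Size([2, 10, 256])
# >>> new_cache.shape # torch.Size([0, 0, 0]) (fake cache, no causal left context)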
class ConformerEncoderLayer(nn.Module):
"""Encoder layer module.
Args:
size (int): Input dimension.
self_attn (torch.nn.Module): Self-attention module instance.
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
instance can be used as the argument.
feed_forward (torch.nn.Module): Feed-forward module instance.
`PositionwiseFeedForward` instance can be used as the argument.
feed_forward_macaron (torch.nn.Module): Additional feed-forward module
instance.
`PositionwiseFeedForward` instance can be used as the argument.
conv_module (torch.nn.Module): Convolution module instance.
`ConvolutionModule` instance can be used as the argument.
dropout_rate (float): Dropout rate.
normalize_before (bool):
True: use layer_norm before each sub-block.
False: use layer_norm after each sub-block.
concat_after (bool): Whether to concat attention layer's input and
output.
True: x -> x + linear(concat(x, att(x)))
False: x -> x + att(x)
"""
def __init__(
self,
size: int,
self_attn: torch.nn.Module,
feed_forward: Optional[nn.Module] = None,
feed_forward_macaron: Optional[nn.Module] = None,
conv_module: Optional[nn.Module] = None,
dropout_rate: float = 0.1,
normalize_before: bool = True,
concat_after: bool = False,
):
"""Construct an EncoderLayer object."""
super().__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.feed_forward_macaron = feed_forward_macaron
self.conv_module = conv_module
self.norm_ff = nn.LayerNorm(size, eps=1e-5) # for the FNN module
self.norm_mha = nn.LayerNorm(size, eps=1e-5) # for the MHA module
if feed_forward_macaron is not None:
self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5)
self.ff_scale = 0.5
else:
self.ff_scale = 1.0
if self.conv_module is not None:
self.norm_conv = nn.LayerNorm(size,
eps=1e-5) # for the CNN module
self.norm_final = nn.LayerNorm(
size, eps=1e-5) # for the final output of the block
self.dropout = nn.Dropout(dropout_rate)
self.size = size
self.normalize_before = normalize_before
self.concat_after = concat_after
if self.concat_after:
self.concat_linear = nn.Linear(size + size, size)
else:
self.concat_linear = nn.Identity()
def forward(
self,
x: torch.Tensor,
mask: torch.Tensor,
pos_emb: torch.Tensor,
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute encoded features.
Args:
x (torch.Tensor): (#batch, time, size)
mask (torch.Tensor): Mask tensor for the input (#batch, time, time),
(0, 0, 0) means fake mask.
pos_emb (torch.Tensor): positional encoding, must not be None
for ConformerEncoderLayer.
mask_pad (torch.Tensor): batch padding mask used for conv module.
(#batch, 1, time), (0, 0, 0) means fake mask.
att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
(#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
cnn_cache (torch.Tensor): Convolution cache in conformer layer
(#batch=1, size, cache_t2)
Returns:
torch.Tensor: Output tensor (#batch, time, size).
torch.Tensor: Mask tensor (#batch, time, time).
torch.Tensor: att_cache tensor,
(#batch=1, head, cache_t1 + time, d_k * 2).
torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
"""
# whether to use macaron style
if self.feed_forward_macaron is not None:
residual = x
if self.normalize_before:
x = self.norm_ff_macaron(x)
x = residual + self.ff_scale * self.dropout(
self.feed_forward_macaron(x))
if not self.normalize_before:
x = self.norm_ff_macaron(x)
# multi-headed self-attention module
residual = x
if self.normalize_before:
x = self.norm_mha(x)
x_att, new_att_cache = self.self_attn(
x, x, x, mask, pos_emb, att_cache)
if self.concat_after:
x_concat = torch.cat((x, x_att), dim=-1)
x = residual + self.concat_linear(x_concat)
else:
x = residual + self.dropout(x_att)
if not self.normalize_before:
x = self.norm_mha(x)
# convolution module
# Fake new cnn cache here, and then change it in conv_module
new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
if self.conv_module is not None:
residual = x
if self.normalize_before:
x = self.norm_conv(x)
x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
x = residual + self.dropout(x)
if not self.normalize_before:
x = self.norm_conv(x)
# feed forward module
residual = x
if self.normalize_before:
x = self.norm_ff(x)
x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm_ff(x)
if self.conv_module is not None:
x = self.norm_final(x)
return x, mask, new_att_cache, new_cnn_cache
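# Usage sketch (illustrative; absolute-position attention, so pos_emb is unused
# and passed as zeros):
# >>> layer = ConformerEncoderLayer(
# ...     256,
# ...     MultiHeadedAttention(4, 256, 0.0),
# ...     PositionwiseFeedForward(256, 1024, 0.0),
# ...     None,
# ...     ConvolutionModule(256, 15),
# ...     dropout_rate=0.0)
# >>> x = torch.randn(2, 10, 256)
# >>> mask = torch.ones(2, 10, 10, dtype=torch.bool)
# >>> y, y_mask, _, _ = layer(x, mask, torch.zeros(1, 10, 256))
# >>> y.shape # torch.Size([2, 10, 256])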
class BaseEncoder(torch.nn.Module):
def __init__(
self,
input_size: int,
output_size: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.0,
input_layer: str = "conv2d",
pos_enc_layer_type: str = "abs_pos",
normalize_before: bool = True,
concat_after: bool = False,
):
"""
Args:
input_size (int): input dim
output_size (int): dimension of attention
attention_heads (int): the number of heads of multi head attention
linear_units (int): the hidden units number of position-wise feed
forward
num_blocks (int): the number of encoder blocks
dropout_rate (float): dropout rate
input_layer (str): input layer type.
optional [linear, conv2d2, conv2d, conv2d6, conv2d8]
pos_enc_layer_type (str): Encoder positional encoding layer type.
optional [abs_pos, rel_pos, no_pos]
normalize_before (bool):
True: use layer_norm before each sub-block of a layer.
False: use layer_norm after each sub-block of a layer.
concat_after (bool): whether to concat attention layer's input
and output.
True: x -> x + linear(concat(x, att(x)))
False: x -> x + att(x)
"""
super().__init__()
self._output_size = output_size
if pos_enc_layer_type == "abs_pos":
pos_enc_class = PositionalEncoding
elif pos_enc_layer_type == "rel_pos":
pos_enc_class = RelPositionalEncoding
elif pos_enc_layer_type == "no_pos":
pos_enc_class = NoPositionalEncoding
else:
raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
if input_layer == "linear":
subsampling_class = LinearNoSubsampling
elif input_layer == "conv2d2":
subsampling_class = Conv2dSubsampling2
elif input_layer == "conv2d":
subsampling_class = Conv2dSubsampling4
elif input_layer == "conv2d6":
subsampling_class = Conv2dSubsampling6
elif input_layer == "conv2d8":
subsampling_class = Conv2dSubsampling8
else:
raise ValueError("unknown input_layer: " + input_layer)
self.embed = subsampling_class(
input_size,
output_size,
dropout_rate,
pos_enc_class(output_size, dropout_rate),
)
self.normalize_before = normalize_before
self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
def output_size(self) -> int:
return self._output_size
def forward(
self,
xs: torch.Tensor,
xs_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Embed positions in tensor.
Args:
xs: padded input tensor (B, T, D)
xs_lens: input length (B)
decoding_chunk_size: decoding chunk size for dynamic chunk
0: default for training, use random dynamic chunk.
<0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set.
num_decoding_left_chunks: number of left chunks, this is for decoding,
the chunk size is decoding_chunk_size.
>=0: use num_decoding_left_chunks
<0: use all left chunks
Returns:
encoder output tensor xs, and subsampled masks
xs: padded output tensor (B, T' ~= T/subsample_rate, D)
masks: torch.Tensor batch padding mask after subsample
(B, 1, T' ~= T/subsample_rate)
"""
T = xs.size(1)
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
xs, pos_emb, masks = self.embed(xs, masks)
chunk_masks = masks
mask_pad = masks # (B, 1, T/subsample_rate)
for layer in self.encoders:
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
if self.normalize_before:
xs = self.after_norm(xs)
# Here we assume the mask is not changed in encoder layers, so just
# return the masks before encoder layers, and the masks will be used
# for cross attention with decoder later
return xs, masks
class ConformerEncoder(BaseEncoder):
"""Conformer encoder module."""
def __init__(
self,
input_size: int,
output_size: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.0,
input_layer: str = "conv2d",
pos_enc_layer_type: str = "rel_pos",
normalize_before: bool = True,
concat_after: bool = False,
macaron_style: bool = False,
use_cnn_module: bool = True,
cnn_module_kernel: int = 15,
):
"""Construct ConformerEncoder
Args:
input_size to use_dynamic_chunk, see in BaseEncoder
positionwise_conv_kernel_size (int): Kernel size of positionwise
conv1d layer.
macaron_style (bool): Whether to use macaron style for
positionwise layer.
selfattention_layer_type (str): Encoder attention layer type,
the parameter has no effect now, it's just for configure
compatibility.
activation_type (str): Encoder activation function type.
use_cnn_module (bool): Whether to use convolution module.
cnn_module_kernel (int): Kernel size of convolution module.
causal (bool): whether to use causal convolution or not.
"""
super().__init__(input_size, output_size, attention_heads,
linear_units, num_blocks, dropout_rate,
input_layer, pos_enc_layer_type, normalize_before,
concat_after)
activation = torch.nn.SiLU()
# self-attention module definition
if pos_enc_layer_type != "rel_pos":
encoder_selfattn_layer = MultiHeadedAttention
else:
encoder_selfattn_layer = RelPositionMultiHeadedAttention
encoder_selfattn_layer_args = (
attention_heads,
output_size,
dropout_rate,
)
# feed-forward module definition
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (
output_size,
linear_units,
dropout_rate,
activation,
)
# convolution module definition
convolution_layer = ConvolutionModule
convolution_layer_args = (output_size,
cnn_module_kernel,
activation,)
self.encoders = torch.nn.ModuleList([
ConformerEncoderLayer(
output_size,
encoder_selfattn_layer(*encoder_selfattn_layer_args),
positionwise_layer(*positionwise_layer_args),
positionwise_layer(
*positionwise_layer_args) if macaron_style else None,
convolution_layer(
*convolution_layer_args) if use_cnn_module else None,
dropout_rate,
normalize_before,
concat_after,
) for _ in range(num_blocks)
])
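# Usage sketch (illustrative; 100-dim features, default conv2d 4x subsampling,
# num_blocks reduced to 2 only to keep the example small):
# >>> encoder = ConformerEncoder(input_size=100, output_size=256, num_blocks=2)
# >>> feats = torch.randn(2, 64, 100)
# >>> feat_lens = torch.tensor([64, 50])
# >>> out, masks = encoder(feats, feat_lens)
# >>> out.shape # torch.Size([2, 15, 256])
# >>> masks.shape # torch.Size([2, 1, 15])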
from typing import (Any, Dict, Final, Iterable, List, Literal, Mapping, Optional,
Protocol, Set, Tuple, TypedDict, TypeVar, Union, Sequence)
import numpy as np
import torch
from torch import nn
from transformers import BatchFeature
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed.parallel_state import (
get_pp_group, get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.model_executor.models.interfaces import SupportsPP
from vllm.model_executor.models.utils import (
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix,
merge_multimodal_embeddings
)
from vllm.model_executor.models.gpt2 import GPT2Block #, GPT2MLP, GPT2Attention
from vllm.model_executor.models.interfaces import SupportsMultiModal, MultiModalEmbeddings
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargsItems
from vllm.multimodal.processing import (BaseMultiModalProcessor, PromptReplacement,
BaseProcessingInfo, PromptInsertion,
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.multimodal.parse import (MultiModalDataParser, DictEmbeddingItems,
ModalityDataItems, MultiModalDataItems)
# from vllm.model_executor.models.utils import merge_multimodal_embeddings
PLACEHOLDER_TOKEN = "!"
PLACEHOLDER_TOKEN_ID = 0
class GPT2TTSProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
# Declare that we support the 'audio' modality
return {"audio": None}
class GPT2TTSDummyInputsBuilder(BaseDummyInputsBuilder[GPT2TTSProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_audios = mm_counts.get("audio", 0)
return PLACEHOLDER_TOKEN * num_audios
def get_dummy_mm_data(self, seq_len: int, mm_counts: Mapping[str, int]) -> Dict[str, Any]:
num_items = mm_counts.get("audio", 0)
if num_items == 0:
return {}
config = self.info.get_hf_config()
dummy_seq_len = 1024
dummy_embed = torch.rand(
(dummy_seq_len, config.n_embd),
dtype=torch.float16,
)
return {"audio": {"audio_embeds": [dummy_embed] * num_items}}
class GPT2TTSDataParser(MultiModalDataParser):
"""
这个解析器重写了处理 'audio' 模态的方法。
"""
def _parse_audio_data(
self,
data: Union[Dict[str, torch.Tensor], Any],
) -> Optional[ModalityDataItems[Any, Any]]:
"""
当 vLLM 看到 "audio" 这个 key 时,会调用这个函数。
'data' 参数是 "audio" key 对应的值。
"""
# The expected value is a dict, e.g. {"audio_embeds": tensor}
if isinstance(data, dict):
return DictEmbeddingItems(
data,
modality="audio",
required_fields={"audio_embeds"},
# This factory tells vLLM how to map keys in the dict to arguments of
# the model's forward function; here "audio_embeds" maps to the
# parameter also named "audio_embeds".
fields_factory=lambda hf_inputs: dict(
audio_embeds=MultiModalFieldConfig.batched("audio")
),
)
# If "audio" was passed but is not in the expected dict format, raise an error
raise TypeError(f"For 'audio' modality, expected a dict like {{'audio_embeds': tensor}}, but got {type(data)}")
class GPT2TTSMultiModalProcessor(BaseMultiModalProcessor[GPT2TTSProcessingInfo]):
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
audio_embeds=MultiModalFieldConfig.batched("audio"),
)
def _get_data_parser(self) -> MultiModalDataParser:
return GPT2TTSDataParser()
def _get_prompt_updates(
self,
mm_items: "MultiModalDataItems",
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargsItems,
) -> List[PromptUpdate]:
out_mm_data = out_mm_kwargs.get_data()
def get_replacement(item_idx: int):
# Get the embedding from the processed data via the 'audio_embeds' key
embeds = out_mm_data["audio_embeds"][item_idx]
num_features = embeds.shape[0] # sequence length
# Build a placeholder token sequence; its length must match the embedding
return PromptUpdateDetails.select_token_id(
[PLACEHOLDER_TOKEN_ID] * num_features, PLACEHOLDER_TOKEN_ID
)
return [
PromptReplacement(
modality="audio",
target=PLACEHOLDER_TOKEN, # [PLACEHOLDER_TOKEN_ID],
replacement=get_replacement,
)
]
@support_torch_compile
class GPT2Model(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
assert not config.add_cross_attention
assert not config.scale_attn_by_inverse_layer_idx
assert not config.reorder_and_upcast_attn
self.embed_dim = config.n_embd
# self.wte = VocabParallelEmbedding(config.vocab_size,
# self.embed_dim,
# quant_config=quant_config,
# prefix=f"{prefix}.wte")
# self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.start_layer, self.end_layer, self.h = make_layers(
config.n_layer,
lambda prefix: GPT2Block(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.h")
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.n_embd))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.wte(input_ids)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor],
) -> Union[torch.Tensor, IntermediateTensors]:
# if get_pp_group().is_first_rank:
# if inputs_embeds is None:
# inputs_embeds = self.get_input_embeddings(input_ids)
# position_embeds = self.wpe(position_ids)
# hidden_states = inputs_embeds + position_embeds
# else:
# assert intermediate_tensors is not None
# hidden_states = intermediate_tensors["hidden_states"]
hidden_states = inputs_embeds
for layer in self.h[self.start_layer:self.end_layer]:
hidden_states = layer(hidden_states)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.ln_f(hidden_states)
return hidden_states
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if ".attn.bias" in name or ".attn.masked_bias" in name:
# Skip attention mask.
# NOTE: "c_attn.bias" should not be skipped.
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
# HF's GPT-2 implementation uses Conv1D instead of Linear.
# Because of this, we need to transpose the weights.
# Note(zhuohan): the logic below might break quantized models.
for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
if conv1d_weight_name not in name:
continue
if not name.endswith(".weight"):
continue
loaded_weight = loaded_weight.t()
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class LearnedPositionEmbeddings(nn.Module):
def __init__(self, seq_len, model_dim, init=.02):
super().__init__()
self.emb = nn.Embedding(seq_len, model_dim)
self.emb.weight.data.normal_(mean=0.0, std=init)
def forward(self, x):
sl = x.shape[1]
return self.emb(torch.arange(0, sl, device=x.device))
def get_fixed_embedding(self, ind, dev):
return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
@MULTIMODAL_REGISTRY.register_processor(GPT2TTSMultiModalProcessor,
info=GPT2TTSProcessingInfo,
dummy_inputs=GPT2TTSDummyInputsBuilder)
class GPT2TTSModel(nn.Module, SupportsPP, SupportsMultiModal):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.transformer = GPT2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "transformer"))
self.text_pos_embedding = LearnedPositionEmbeddings(self.config.n_positions, self.config.n_embd)
with torch.no_grad():
self.text_pos_embedding.emb.weight[0].zero_()
self.audio_emb = nn.Embedding(self.config.vocab_size, self.config.n_embd)
self.final_norm = nn.LayerNorm(self.config.n_embd, bias=True)
self.lm_head = ParallelLMHead(self.config.vocab_size,
self.config.n_embd,
quant_config=quant_config,
prefix=f"{prefix}.lm_head",
bias=True)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
# Implement the SupportsMultiModal interface methods
def get_language_model(self) -> torch.nn.Module:
return self.transformer
def get_multimodal_embeddings(
self,
**kwargs: object,
) -> MultiModalEmbeddings:
# Extract our embeddings from kwargs
audio_embeds = kwargs.get("audio_embeds")
processed_embeds = []
for embed in audio_embeds:
# Check for a redundant 3D tensor with a leading dimension of 1
if embed.dim() == 3 and embed.shape[0] == 1:
# Remove the extra batch dimension so it becomes 2D
processed_embeds.append(embed.squeeze(0))
elif embed.dim() == 2:
# If it is already a 2D tensor, add it directly
processed_embeds.append(embed)
else:
# For unexpected dimensions, raise an error to aid debugging
raise ValueError(
"Expected audio embeddings to be 2D or 3D with a "
f"leading dimension of 1, but got shape: {embed.shape}")
return processed_embeds
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
# # This method is now used to merge text and multimodal embeddings.
# # In our prefill scenario, input_ids are fake; we only care about multimodal_embeddings.
# if multimodal_embeddings is not None: # and len(multimodal_embeddings) > 0
# # Assume there is a single multimodal input and it is the complete embedding we want.
# # If there are several, they would need to be concatenated.
# # Note: vLLM's merge_multimodal_embeddings is meant for replacing placeholder tokens,
# # while in our case the whole input is the embedding, so we would return it directly.
# return torch.cat(multimodal_embeddings, dim=0)
# # For the decode stage, go through the normal embedding lookup.
# return self.audio_emb(input_ids)
inputs_embeds = self.audio_emb(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
PLACEHOLDER_TOKEN_ID)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
# **kwargs receives the data from get_multimodal_embeddings
**kwargs,
) -> Union[torch.Tensor, IntermediateTensors]:
# assert inputs_embeds is not None
positions = torch.clamp(positions, min=0)
pos_emb = self.text_pos_embedding.emb(positions)
# kusuriuri: We must use += here, otherwise the computed result will be wrong
inputs_embeds += pos_emb
transformer_output = self.transformer(
input_ids=None,
position_ids=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds
)
# if get_pp_group().is_last_rank:
transformer_output = self.final_norm(transformer_output)
return transformer_output
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if ".attn.bias" in name or ".attn.masked_bias" in name:
continue
if ".wte" in name:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
if conv1d_weight_name not in name:
continue
if not name.endswith(".weight"):
continue
loaded_weight = loaded_weight.t()
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
# try:
weight_loader(param, loaded_weight)
# except:
# print("weight_loader", name)
# raise AssertionError()
loaded_params.add(name)
# Make sure the embedding at position 0 is still an all-zero vector after loading weights.
with torch.no_grad():
self.text_pos_embedding.emb.weight[0].zero_()
return loaded_params
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Config, GPT2LMHeadModel, LogitsProcessorList
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.utils.model_parallel_utils import (assert_device_map,
get_device_map)
from transformers import GPT2Config, GPT2Model
from indextts.gpt.conformer_encoder import ConformerEncoder
from indextts.gpt.perceiver import PerceiverResampler
from indextts.utils.arch_util import AttentionBlock
from indextts.utils.typical_sampling import TypicalLogitsWarper
def null_position_embeddings(range, dim):
return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
class GPT2InferenceModel(GPT2LMHeadModel):
def __init__(self, config, gpt, text_pos_emb, audio_emb, norm, linear):
super().__init__(config)
self.transformer = gpt
self.text_pos_embedding = text_pos_emb
self.audio_emb = audio_emb
self.final_norm = norm
self.lm_head = linear
def store_mel_emb(self, mel_emb):
self.cached_mel_emb = mel_emb
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs
):
assert input_ids is None or input_ids.shape[1] == 1
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# # Create embedding
mel_len = self.cached_mel_emb.shape[1]
if input_ids is not None: # and input_ids.shape[1] == 1
inputs_embeds = self.audio_emb(input_ids)
inputs_embeds = inputs_embeds + self.text_pos_embedding.get_fixed_embedding(
attention_mask.shape[1] - mel_len, attention_mask.device
)
transformer_outputs = self.transformer(
inputs_embeds=inputs_embeds,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
# Set device for model parallelism
if self.model_parallel:
if torch.backends.mps.is_available():
self.to(self.transformer.first_device)
else:
torch.cuda.set_device(self.transformer.first_device)
hidden_states = hidden_states.to(self.lm_head.weight.device)
lm_logits = self.lm_head(self.final_norm(hidden_states))
if not return_dict:
return (lm_logits,) + transformer_outputs[1:]
return CausalLMOutputWithCrossAttentions(
loss=None,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.last_hidden_state,
attentions=transformer_outputs.attentions,
cross_attentions=transformer_outputs.cross_attentions,
)
class LearnedPositionEmbeddings(nn.Module):
def __init__(self, seq_len, model_dim, init=.02):
super().__init__()
self.emb = nn.Embedding(seq_len, model_dim)
# Initializing this way is standard for GPT-2
self.emb.weight.data.normal_(mean=0.0, std=init)
def forward(self, x):
sl = x.shape[1]
return self.emb(torch.arange(0, sl, device=x.device))
def get_fixed_embedding(self, ind, dev):
return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
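# Usage sketch (illustrative dims):
# >>> pos = LearnedPositionEmbeddings(seq_len=404, model_dim=512)
# >>> pos(torch.zeros(2, 10, dtype=torch.long)).shape # torch.Size([10, 512]), broadcast over batch
# >>> pos.get_fixed_embedding(3, 'cpu').shape # torch.Size([1, 1, 512])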
class UnifiedVoice(nn.Module):
def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1,
mel_length_compression=1024, number_text_tokens=256,
start_text_token=0, stop_text_token=1, number_mel_codes=8194, start_mel_token=8192, stop_mel_token=8193,
types=1, activation_function=None,
condition_num_latent=32, condition_module=None, **kwargs):
"""
Args:
layers: Number of layers in transformer stack.
model_dim: Operating dimensions of the transformer
heads: Number of transformer heads. model_dim must be divisible by heads; model_dim // 64 is recommended.
max_text_tokens: Maximum number of text tokens that will be encountered by model.
max_mel_tokens: Maximum number of MEL tokens that will be encountered by model.
max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s).
mel_length_compression: The factor between <number_input_samples> and <mel_tokens>. Used to compute MEL code padding given wav input length.
number_text_tokens:
start_text_token:
stop_text_token:
number_mel_codes:
start_mel_token:
stop_mel_token:
checkpointing:
"""
super().__init__()
self.number_text_tokens = number_text_tokens
self.start_text_token = start_text_token
self.stop_text_token = stop_text_token
self.number_mel_codes = number_mel_codes
self.start_mel_token = start_mel_token
self.stop_mel_token = stop_mel_token
self.layers = layers
self.heads = heads
self.max_mel_tokens = max_mel_tokens
self.max_text_tokens = max_text_tokens
self.model_dim = model_dim
self.max_conditioning_inputs = max_conditioning_inputs
self.mel_length_compression = mel_length_compression
self.cond_num = condition_num_latent
self.cond_mask_pad = nn.ConstantPad1d((self.cond_num, 0), True)
self.conditioning_encoder = ConformerEncoder(input_size=100,
output_size=condition_module['output_size'],
linear_units=condition_module['linear_units'],
attention_heads=condition_module['attention_heads'],
num_blocks=condition_module['num_blocks'],
input_layer=condition_module['input_layer'])
self.perceiver_encoder = PerceiverResampler(model_dim, dim_context=condition_module['output_size'],
ff_mult=condition_module['perceiver_mult'],
heads=condition_module['attention_heads'],
num_latents=self.cond_num)
self.text_embedding = nn.Embedding(self.number_text_tokens * types + 1, model_dim)
self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
max_mel_seq_len = self.max_mel_tokens + 2 + self.max_conditioning_inputs
max_text_seq_len = self.max_text_tokens + 2
gpt_config = GPT2Config(vocab_size=256, # Unused.
n_positions=max_mel_seq_len + max_text_seq_len,
n_ctx=max_mel_seq_len + max_text_seq_len,
n_embd=model_dim,
n_layer=layers,
n_head=heads,
activation_function=activation_function or "gelu_new",
gradient_checkpointing=False,
use_cache=True)
self.gpt = GPT2Model(gpt_config)
# Override the built in positional embeddings
del self.gpt.wpe
self.gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
# Built-in token embeddings are unused.
del self.gpt.wte
self.mel_pos_embedding, self.text_pos_embedding = LearnedPositionEmbeddings(max_mel_seq_len, model_dim), LearnedPositionEmbeddings(max_text_seq_len, model_dim)
self.mel_solo_embedding = 0
self.text_solo_embedding = 0
self.final_norm = nn.LayerNorm(model_dim)
self.text_head = nn.Linear(model_dim, self.number_text_tokens * types + 1)
self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
# Initialize the embeddings per the GPT-2 scheme
embeddings = [self.text_embedding, self.mel_embedding]
for module in embeddings:
module.weight.data.normal_(mean=0.0, std=.02)
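# Illustrative construction (a hedged sketch; the concrete values below are
# hypothetical and only show which keys `condition_module` is read with in
# __init__ above):
#
#   condition_module = {
#       "output_size": 512, "linear_units": 2048, "attention_heads": 8,
#       "num_blocks": 6, "input_layer": "conv2d", "perceiver_mult": 4,
#   }
#   model = UnifiedVoice(layers=8, model_dim=512, heads=8,
#                        condition_module=condition_module)
#   model.post_init_gpt2_config(kv_cache=True)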
def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=False, half=False, activation_function=None):
seq_length = self.max_mel_tokens + self.max_text_tokens + 2
gpt_config = GPT2Config(
vocab_size=self.number_mel_codes,
n_positions=self.max_mel_tokens + 2 + self.max_conditioning_inputs,
n_ctx=seq_length,
n_embd=self.model_dim,
n_layer=self.layers,
n_head=self.heads,
gradient_checkpointing=False,
activation_function=activation_function or "gelu_new",
use_cache=True,
bos_token_id=self.start_mel_token,
eos_token_id=self.stop_mel_token,
)
self.inference_model = GPT2InferenceModel(
gpt_config,
self.gpt,
self.mel_pos_embedding,
self.mel_embedding,
self.final_norm,
self.mel_head,
# kv_cache=kv_cache,
)
self.inference_model = self.inference_model.eval()
# self.gpt.wte = self.mel_embedding
def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
inp = F.pad(input, (1, 0), value=start_token)
tar = F.pad(input, (0, 1), value=stop_token)
return inp, tar
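# Example of the alignment above (a sketch): for input = [t1, t2, t3],
#   inp = [start_token, t1, t2, t3]   (what the model is fed)
#   tar = [t1, t2, t3, stop_token]    (next-token targets, shifted by one)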
def get_conditioning(self, speech_conditioning_input, cond_mel_lengths=None):
speech_conditioning_input, mask = self.conditioning_encoder(speech_conditioning_input.transpose(1, 2),
cond_mel_lengths) # (b, s, d), (b, 1, s)
conds_mask = self.cond_mask_pad(mask.squeeze(1))
conds = self.perceiver_encoder(speech_conditioning_input, conds_mask) # (b, 32, d)
return conds
def inference_speech(self, speech_conditioning_latent, text_inputs, cond_mel_lengths=None, input_tokens=None, num_return_sequences=1,
max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
text_inputs, _ = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
# speech_conditioning_latent = self.get_conditioning(speech_conditioning_latent, cond_mel_lengths)
# conds = speech_conditioning_latent
emb = torch.cat([speech_conditioning_latent, text_emb], dim=1)
self.inference_model.store_mel_emb(emb)
trunc_index = emb.shape[1] + 1
mel_start_emb = self.mel_embedding(torch.full((emb.shape[0], 1,), fill_value=self.start_mel_token, dtype=torch.long, device=text_inputs.device))
mel_start_emb = mel_start_emb + self.mel_pos_embedding(mel_start_emb)
inputs_embeds = torch.cat([emb, mel_start_emb], dim=1)
logits_processor = LogitsProcessorList()
max_length = trunc_index + self.max_mel_tokens - 1 if max_generate_length is None else trunc_index + max_generate_length
gen = self.inference_model.generate(inputs_embeds=inputs_embeds, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token,
eos_token_id=self.stop_mel_token,
return_dict_in_generate=True, output_hidden_states=True,
max_length=max_length, logits_processor=logits_processor,
num_return_sequences=num_return_sequences, **hf_generate_kwargs)
codes = gen.sequences[:, 1:]
latent = torch.cat(gen.hidden_states, dim=1)
latent = latent[:, trunc_index:-1]
latent = self.final_norm(latent)
return codes, latent # [:, trunc_index:]
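# Hedged usage sketch (hypothetical names; `cond_mel` is a (1, 100, frames)
# conditioning mel as consumed by get_conditioning above, `text_ids` a (1, L)
# LongTensor of text tokens):
#
#   conds = model.get_conditioning(cond_mel, torch.tensor([cond_mel.shape[-1]]))
#   codes, latent = model.inference_speech(conds, text_ids,
#                                          do_sample=True, top_p=0.8,
#                                          num_return_sequences=1)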
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
# import transformers
# from transformers import GPT2Config, LogitsProcessorList
# from indextts.gpt.transformers_gpt2 import GPT2PreTrainedModel, GPT2Model
from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.utils.model_parallel_utils import (assert_device_map,
get_device_map)
from indextts.gpt.conformer_encoder import ConformerEncoder
from indextts.gpt.perceiver import PerceiverResampler
from indextts.utils.arch_util import AttentionBlock
from indextts.utils.typical_sampling import TypicalLogitsWarper
def null_position_embeddings(range, dim):
return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
class ResBlock(nn.Module):
"""
Basic residual convolutional block that uses GroupNorm.
"""
def __init__(self, chan):
super().__init__()
self.net = nn.Sequential(
nn.Conv1d(chan, chan, kernel_size=3, padding=1),
nn.GroupNorm(chan // 8, chan),
nn.ReLU(),
nn.Conv1d(chan, chan, kernel_size=3, padding=1),
nn.GroupNorm(chan // 8, chan)
)
def forward(self, x):
return F.relu(self.net(x) + x)
class GPT2InferenceModel(GPT2PreTrainedModel):
def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear, kv_cache=False):
super().__init__(config)
# Note: the argument named `text_pos_emb` here actually represents the mel position embedding
self.transformer = gpt
self.text_pos_embedding = text_pos_emb
self.audio_emb = embeddings
self.final_norm = norm
self.lm_head = linear # nn.Sequential(norm, linear)
self.kv_cache = kv_cache
# Model parallel
self.model_parallel = False
self.device_map = None
self.cached_mel_emb = None
def parallelize(self, device_map=None):
self.device_map = (
get_device_map(len(self.transformer.h), range(max(1, torch.cuda.device_count())))
if device_map is None
else device_map
)
assert_device_map(self.device_map, len(self.transformer.h))
self.transformer.parallelize(self.device_map)
self.lm_head = self.lm_head.to(self.transformer.first_device)
self.model_parallel = True
def deparallelize(self):
self.transformer.deparallelize()
self.transformer = self.transformer.to("cpu")
self.lm_head = self.lm_head.to("cpu")
self.model_parallel = False
torch.cuda.empty_cache()
if torch.backends.mps.is_available():
torch.mps.empty_cache()
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def store_mel_emb(self, mel_emb):
self.cached_mel_emb = mel_emb
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None) # usually None
if not self.kv_cache:
past_key_values = None
# only last token for inputs_ids if past is defined in kwargs
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 0)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
return {
"input_ids": input_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
}
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
assert self.cached_mel_emb is not None
assert inputs_embeds is None # Not supported by this inference model.
assert labels is None # Training not supported by this inference model.
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# Create embedding
mel_len = self.cached_mel_emb.shape[1]
if input_ids.shape[1] != 1:
text_inputs = input_ids[:, mel_len:]
text_emb = self.audio_emb(text_inputs)
text_emb = text_emb + self.text_pos_embedding(text_emb)
if self.cached_mel_emb.shape[0] != text_emb.shape[0]:
mel_emb = self.cached_mel_emb.repeat_interleave(
text_emb.shape[0] // self.cached_mel_emb.shape[0], 0
)
else: # this outcome only occurs once per loop in most cases
mel_emb = self.cached_mel_emb
emb = torch.cat([mel_emb, text_emb], dim=1)
else:
emb = self.audio_emb(input_ids)
emb = emb + self.text_pos_embedding.get_fixed_embedding(
attention_mask.shape[1] - mel_len, attention_mask.device
)
transformer_outputs = self.transformer(
inputs_embeds=emb,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
# Set device for model parallelism
if self.model_parallel:
if torch.backends.mps.is_available():
self.to(self.transformer.first_device)
else:
torch.cuda.set_device(self.transformer.first_device)
hidden_states = hidden_states.to(self.lm_head.weight.device)
lm_logits = self.lm_head(self.final_norm(hidden_states))
if not return_dict:
return (lm_logits,) + transformer_outputs[1:]
return CausalLMOutputWithCrossAttentions(
loss=None,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
cross_attentions=transformer_outputs.cross_attentions,
)
@staticmethod
def _reorder_cache(past, beam_idx):
"""
This function is used to re-order the :obj:`past_key_values` cache if
:meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
"""
return tuple(
tuple(
past_state.index_select(0, beam_idx.to(past_state.device))
for past_state in layer_past
)
for layer_past in past
)
class ConditioningEncoder(nn.Module):
def __init__(self,
spec_dim,
embedding_dim,
attn_blocks=6,
num_attn_heads=4,
do_checkpointing=False,
mean=False):
super().__init__()
attn = []
self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
for a in range(attn_blocks):
attn.append(AttentionBlock(embedding_dim, num_attn_heads))
self.attn = nn.Sequential(*attn)
self.dim = embedding_dim
self.do_checkpointing = do_checkpointing
self.mean = mean
def forward(self, x):
h = self.init(x)
h = self.attn(h)
if self.mean:
return h.mean(dim=2)
else:
return h
# return h[:, :, 0]
class LearnedPositionEmbeddings(nn.Module):
def __init__(self, seq_len, model_dim, init=.02):
super().__init__()
self.emb = nn.Embedding(seq_len, model_dim)
# Initializing this way is standard for GPT-2
self.emb.weight.data.normal_(mean=0.0, std=init)
def forward(self, x):
sl = x.shape[1]
return self.emb(torch.arange(0, sl, device=x.device))
def get_fixed_embedding(self, ind, dev):
return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):
"""
GPT-2 implemented by the HuggingFace library.
"""
from transformers import GPT2Config, GPT2Model
gpt_config = GPT2Config(vocab_size=256, # Unused.
n_positions=max_mel_seq_len + max_text_seq_len,
n_ctx=max_mel_seq_len + max_text_seq_len,
n_embd=model_dim,
n_layer=layers,
n_head=heads,
gradient_checkpointing=checkpointing,
use_cache=not checkpointing)
gpt = GPT2Model(gpt_config)
# Override the built in positional embeddings
del gpt.wpe
gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
# Built-in token embeddings are unused.
del gpt.wte
return gpt, LearnedPositionEmbeddings(max_mel_seq_len, model_dim), LearnedPositionEmbeddings(max_text_seq_len, model_dim), \
None, None
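# Note on the deletions above: positional information is supplied by the two
# LearnedPositionEmbeddings returned here (one for mel, one for text), so wpe is
# replaced with an all-zeros stub, and token embeddings are always passed in via
# `inputs_embeds`, so the built-in wte is dropped entirely.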
class MelEncoder(nn.Module):
def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2):
super().__init__()
self.channels = channels
self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels // 4, kernel_size=3, padding=1),
nn.Sequential(*[ResBlock(channels // 4) for _ in range(resblocks_per_reduction)]),
nn.Conv1d(channels // 4, channels // 2, kernel_size=3, stride=2, padding=1),
nn.GroupNorm(channels // 16, channels // 2),
nn.ReLU(),
nn.Sequential(*[ResBlock(channels // 2) for _ in range(resblocks_per_reduction)]),
nn.Conv1d(channels // 2, channels, kernel_size=3, stride=2, padding=1),
nn.GroupNorm(channels // 8, channels),
nn.ReLU(),
nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
)
self.reduction = 4
def forward(self, x):
for e in self.encoder:
x = e(x)
return x.permute(0, 2, 1)
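# The two stride-2 convolutions above downsample the time axis by 4x in total
# (hence self.reduction = 4); the final permute returns (b, frames // 4, channels)
# so the output can be consumed like a sequence of token embeddings.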
class UnifiedVoice(nn.Module):
def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1,
mel_length_compression=1024, number_text_tokens=256,
start_text_token=0, stop_text_token=1, number_mel_codes=8194, start_mel_token=8192, stop_mel_token=8193,
train_solo_embeddings=False, use_mel_codes_as_input=True,
checkpointing=True, types=1,
condition_num_latent=32, condition_type="perceiver", condition_module=None, emo_condition_module=None):
"""
Args:
layers: Number of layers in transformer stack.
model_dim: Operating dimensions of the transformer
heads: Number of transformer heads. model_dim must be divisible by heads; model_dim // 64 is recommended.
max_text_tokens: Maximum number of text tokens that will be encountered by model.
max_mel_tokens: Maximum number of MEL tokens that will be encountered by model.
max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s).
mel_length_compression: The factor between <number_input_samples> and <mel_tokens>. Used to compute MEL code padding given wav input length.
number_text_tokens:
start_text_token:
stop_text_token:
number_mel_codes:
start_mel_token:
stop_mel_token:
train_solo_embeddings:
use_mel_codes_as_input:
checkpointing:
condition_type: perceiver, gst or default encoder
"""
super().__init__()
self.number_text_tokens = number_text_tokens
self.start_text_token = start_text_token
self.stop_text_token = stop_text_token
self.number_mel_codes = number_mel_codes
self.start_mel_token = start_mel_token
self.stop_mel_token = stop_mel_token
self.layers = layers
self.heads = heads
self.max_mel_tokens = max_mel_tokens
self.max_text_tokens = max_text_tokens
self.model_dim = model_dim
self.max_conditioning_inputs = max_conditioning_inputs
self.mel_length_compression = mel_length_compression
self.condition_type = condition_type
self.cond_num = condition_num_latent
self.cond_mask_pad = nn.ConstantPad1d((self.cond_num, 0), True)
self.emo_cond_mask_pad = nn.ConstantPad1d((1, 0), True)
# use conformer_perceiver
self.conditioning_encoder = ConformerEncoder(input_size=1024,
output_size=condition_module['output_size'],
linear_units=condition_module['linear_units'],
attention_heads=condition_module['attention_heads'],
num_blocks=condition_module['num_blocks'],
input_layer=condition_module['input_layer'])
self.perceiver_encoder = PerceiverResampler(model_dim, dim_context=condition_module['output_size'],
ff_mult=condition_module['perceiver_mult'],
heads=condition_module['attention_heads'],
num_latents=self.cond_num)
self.emo_conditioning_encoder = ConformerEncoder(input_size=1024,
output_size=emo_condition_module['output_size'],
linear_units=emo_condition_module['linear_units'],
attention_heads=emo_condition_module['attention_heads'],
num_blocks=emo_condition_module['num_blocks'],
input_layer=emo_condition_module['input_layer'])
self.emo_perceiver_encoder = PerceiverResampler(1024, dim_context=emo_condition_module['output_size'],
ff_mult=emo_condition_module['perceiver_mult'],
heads=emo_condition_module['attention_heads'],
num_latents=1)
self.text_embedding = nn.Embedding(self.number_text_tokens * types + 1, model_dim)
self.emo_layer = nn.Linear(model_dim, model_dim)
self.emovec_layer = nn.Linear(1024, model_dim)
if use_mel_codes_as_input:
self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
else:
self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens + 2 + self.max_conditioning_inputs,
self.max_text_tokens + 2, checkpointing)
if train_solo_embeddings:
self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
else:
self.mel_solo_embedding = 0
self.text_solo_embedding = 0
self.final_norm = nn.LayerNorm(model_dim)
self.text_head = nn.Linear(model_dim, self.number_text_tokens * types + 1)
self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
self.speed_emb = nn.Embedding(2, model_dim)
self.speed_emb.weight.data.normal_(mean=0.0, std=0.0)
# Initialize the embeddings per the GPT-2 scheme
embeddings = [self.text_embedding]
if use_mel_codes_as_input:
embeddings.append(self.mel_embedding)
for module in embeddings:
module.weight.data.normal_(mean=0.0, std=.02)
def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=False, half=False):
seq_length = self.max_mel_tokens + self.max_text_tokens + 2
gpt_config = GPT2Config(
vocab_size=self.number_mel_codes,
n_positions=self.max_mel_tokens + 2 + self.max_conditioning_inputs, # seq_length,
n_ctx=seq_length,
n_embd=self.model_dim,
n_layer=self.layers,
n_head=self.heads,
gradient_checkpointing=False,
use_cache=True,
bos_token_id=self.start_mel_token,
eos_token_id=self.stop_mel_token,
)
self.inference_model = GPT2InferenceModel(
gpt_config,
self.gpt,
self.mel_pos_embedding,
self.mel_embedding,
self.final_norm,
self.mel_head,
kv_cache=kv_cache,
)
if use_deepspeed and half and torch.cuda.is_available():
import deepspeed
self.ds_engine = deepspeed.init_inference(model=self.inference_model,
mp_size=1,
replace_with_kernel_inject=True,
dtype=torch.float16)
self.inference_model = self.ds_engine.module.eval()
elif use_deepspeed and torch.cuda.is_available():
import deepspeed
self.ds_engine = deepspeed.init_inference(model=self.inference_model,
mp_size=1,
replace_with_kernel_inject=True,
dtype=torch.float32)
self.inference_model = self.ds_engine.module.eval()
else:
self.inference_model = self.inference_model.eval()
# self.inference_model = PrunedGPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
self.gpt.wte = self.mel_embedding
def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
inp = F.pad(input, (1, 0), value=start_token)
tar = F.pad(input, (0, 1), value=stop_token)
return inp, tar
def set_mel_padding(self, mel_input_tokens, mel_lengths):
"""
Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in
that audio clip, reformats the tokens with STOP_MEL_TOKEN in place of the zero padding. This is required
preformatting to create a working TTS model.
"""
for b in range(len(mel_lengths)):
# Due to the convolutional nature of how these tokens are generated,
# it would be best if the model predicts a token past the actual last token.
actual_end = mel_lengths[b]
if actual_end < mel_input_tokens.shape[-1]:
mel_input_tokens[b, actual_end:] = self.stop_mel_token
return mel_input_tokens
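# Example (a sketch): with stop_mel_token = 8193, mel_lengths = [3] and
#   mel_input_tokens = [[12, 7, 41, 0, 0]]
# the method returns [[12, 7, 41, 8193, 8193]]: everything past the true length
# is replaced with the stop token.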
def set_text_padding(self, text_input_tokens, text_lengths):
"""
Given text tokens that are derived from padded text inputs and the actual lengths of each batch element,
reformats the tokens with stop_text_token in place of the zero padding. This is required
preformatting to create a working TTS model.
"""
for b in range(len(text_lengths)):
# Everything past the true text length is replaced with the stop token.
actual_end = text_lengths[b]
if actual_end < text_input_tokens.shape[-1]:
text_input_tokens[b, actual_end:] = self.stop_text_token
return text_input_tokens
def get_logits(self, speech_conditioning_inputs, first_inputs, first_head, second_inputs=None, second_head=None, get_attns=False, return_latent=False):
if second_inputs is not None:
emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1)
else:
emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1)
gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
if get_attns:
return gpt_out.attentions
offset = speech_conditioning_inputs.shape[1]
enc = gpt_out.last_hidden_state[:, offset:]
enc = self.final_norm(enc)
if return_latent:
return enc[:, :first_inputs.shape[1]], enc[:, -second_inputs.shape[1]:]
first_logits = enc[:, :first_inputs.shape[1]]
first_logits = first_head(first_logits)
first_logits = first_logits.permute(0, 2, 1)
if second_inputs is not None:
second_logits = enc[:, -second_inputs.shape[1]:]
second_logits = second_head(second_logits)
second_logits = second_logits.permute(0, 2, 1)
return first_logits, second_logits
else:
return first_logits
def get_conditioning(self, speech_conditioning_input, cond_mel_lengths=None):
speech_conditioning_input, mask = self.conditioning_encoder(speech_conditioning_input.transpose(1, 2),
cond_mel_lengths) # (b, s, d), (b, 1, s)
# conds_mask = torch.cat([torch.ones((mask.shape[0], self.cond_num), dtype=torch.bool), mask.squeeze(1)], dim=1)
conds_mask = self.cond_mask_pad(mask.squeeze(1))
conds = self.perceiver_encoder(speech_conditioning_input, conds_mask) # (b, 32, d)
return conds
def get_emo_conditioning(self, speech_conditioning_input, cond_mel_lengths=None):
speech_conditioning_input, mask = self.emo_conditioning_encoder(speech_conditioning_input.transpose(1, 2),
cond_mel_lengths) # (b, s, d), (b, 1, s)
conds_mask = self.emo_cond_mask_pad(mask.squeeze(1))
conds = self.emo_perceiver_encoder(speech_conditioning_input, conds_mask) # (b, 1, d)
return conds.squeeze(1)
def forward(self, speech_conditioning_latent, text_inputs, text_lengths, mel_codes, mel_codes_lengths, emo_speech_conditioning_latent,
cond_mel_lengths=None, emo_cond_mel_lengths=None, emo_vec=None, use_speed=None, do_spk_cond=False):
"""
Forward pass that conditions the GPT stack on a speaker latent plus an emotion latent and runs text and mel tokens through it.
speech_conditioning_latent: speaker conditioning features (encoded via get_conditioning() when do_spk_cond is True, otherwise assumed to be a precomputed (b, 32, dim) latent)
text_inputs: long tensor, (b,t)
text_lengths: long tensor, (b,)
mel_codes: long tensor, (b,m)
mel_codes_lengths: long tensor, (b,)
emo_speech_conditioning_latent: features for the emotion reference, used to derive emo_vec when emo_vec is None
Despite calling get_logits(), this forward returns only the mel-aligned latents (see the note at the return statement); the text/mel heads are not applied.
"""
if do_spk_cond:
speech_conditioning_latent = self.get_conditioning(speech_conditioning_latent.transpose(1,2), cond_mel_lengths)
else:
speech_conditioning_latent = speech_conditioning_latent
if emo_vec is None:
emo_vec_syn_ori = self.get_emo_conditioning(emo_speech_conditioning_latent.transpose(1,2), emo_cond_mel_lengths)
emo_vec_syn = self.emovec_layer(emo_vec_syn_ori)
emo_vec = self.emo_layer(emo_vec_syn)
text_inputs = self.set_text_padding(text_inputs, text_lengths)
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
mel_codes = self.set_mel_padding(mel_codes, mel_codes_lengths)
mel_codes = F.pad(mel_codes, (0, 1), value=self.stop_mel_token)
duration_emb = self.speed_emb(torch.zeros_like(use_speed))
duration_emb_half = self.speed_emb(torch.ones_like(use_speed))
conds = torch.cat((speech_conditioning_latent + emo_vec.unsqueeze(1), duration_emb_half.unsqueeze(1), duration_emb.unsqueeze(1)), 1)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
mel_emb = self.mel_embedding(mel_codes)
mel_emb = mel_emb + self.mel_pos_embedding(mel_codes)
text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=False, return_latent=True)
return mel_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
def prepare_gpt_inputs(
self,
conditional_latents: torch.Tensor,
text_inputs: torch.Tensor,
):
"""
Prepare the inputs for the GPT2InferenceModel to generate.
Args:
conditional_latents: (b, n_cond, dim) conditioning embeddings (speaker latents from `get_conditioning()` plus speed embeddings)
text_inputs: (b, L)
Returns:
input_ids: (b, s+1) the input ids for the GPT2InferenceModel.generate()
inputs_embeds: (b, s+1, dim) the input embeddings for the GPT2InferenceModel.forward()
attention_mask: (b, s+1) the attention mask for the GPT2InferenceModel.generate()
"""
b, L = text_inputs.shape[:2]
device = text_inputs.device
single_cond = conditional_latents.ndim == 3 and conditional_latents.shape[0] == 1
if not single_cond:
assert conditional_latents.shape[0] == b, f"batch size mismatch: {conditional_latents.shape[0]} vs {b}"
batched_mel_emb = []
attention_masks = []
target_len = conditional_latents.shape[1] + L + 2
for i in range(b):
valid_mask = (text_inputs[i] != self.stop_text_token) & (text_inputs[i] != self.start_text_token)
text_input = text_inputs[i][valid_mask]
text_input = F.pad(text_input, (1, 0), value=self.start_text_token)
text_input = F.pad(text_input, (0, 1), value=self.stop_text_token)
text_input_pos = torch.arange(0, text_input.size(-1), device=device)
text_emb = self.text_embedding(text_input) + self.text_pos_embedding.emb(text_input_pos)
# concatenate [conditional latents][text embeddings]
conds_text_emb = [
conditional_latents.squeeze(0) if single_cond else conditional_latents[i],
text_emb,
]
# +1 for the start_mel_token
attention_mask = torch.ones(target_len+1, dtype=torch.long, device=device)
# check this text input is padded
padding: int = L + 2 - text_input.size(-1)
# pad left of [cond][text] -> [pad][cond][text]
if padding > 0:
pad = torch.zeros((padding, conditional_latents.size(-1)), dtype=text_emb.dtype, device=device) # [p, dim]
conds_text_emb.insert(0, pad)
attention_mask[:padding] = 0
mel_emb = torch.cat(conds_text_emb) #[s, dim]
assert mel_emb.shape[0] == target_len, f"mel_emb.shape: {mel_emb.shape}, target_len: {target_len}"
batched_mel_emb.append(mel_emb)
attention_masks.append(attention_mask)
# [b, s, dim]
batched_mel_emb = torch.stack(batched_mel_emb, dim=0)
# [b, s+1]
attention_mask = torch.stack(attention_masks, dim=0)
# [b, s+1]
fake_inputs = torch.ones(
(
batched_mel_emb.shape[0],
batched_mel_emb.shape[1] + 1, # +1 for the start_mel_token
),
dtype=torch.long,
device=device,
)
fake_inputs[:, -1] = self.start_mel_token
return fake_inputs, batched_mel_emb, attention_mask
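# Resulting per-item layout (left-padded so all rows share the same length):
#   inputs_embeds[i]  = [zero pad x p][conditioning latents][start_text ... stop_text]
#   attention_mask[i] = [0 x p][1 ... 1]            (one extra 1 for start_mel_token)
#   input_ids[i]      = placeholder ones, with the last position = start_mel_token.
# The caller stores inputs_embeds via store_mel_emb(); GPT2InferenceModel.forward()
# then embeds only the trailing start_mel_token id and concatenates it behind the
# cached embeddings, so generation begins right after the start token.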
def inference_speech(self, speech_condition, text_inputs, emo_speech_condition=None, cond_lengths=None, emo_cond_lengths=None, emo_vec=None, use_speed=False, input_tokens=None, num_return_sequences=1,
max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
"""
Args:
speech_condition: (b, d, frames) or (d, frames)
text_inputs: (b, L)
cond_lengths: lengths of the conditioning features in shape (b,) or (1,)
input_tokens: additional tokens for generation in shape (b, s) or (s,)
max_generate_length: limit the number of generated tokens
hf_generate_kwargs: kwargs for `GPT2InferenceModel.generate(**hf_generate_kwargs)`
"""
if speech_condition.ndim == 2:
speech_condition = speech_condition.unsqueeze(0)
if emo_speech_condition is None:
emo_speech_condition = speech_condition
if cond_lengths is None:
cond_lengths = torch.tensor([speech_condition.shape[-1]], device=speech_condition.device)
if emo_cond_lengths is None:
emo_cond_lengths = torch.tensor([emo_speech_condition.shape[-1]], device=speech_condition.device)
speech_conditioning_latent = self.get_conditioning(speech_condition.transpose(1,2), cond_lengths)
if emo_vec is None:
print('compute emo vec')
emo_vec = self.get_emo_conditioning(emo_speech_condition.transpose(1,2), emo_cond_lengths)
emo_vec = self.emovec_layer(emo_vec)
emo_vec = self.emo_layer(emo_vec)
else:
print('Use the specified emotion vector')
tmp = torch.zeros(text_inputs.size(0)).to(text_inputs.device)
duration_emb = self.speed_emb(torch.zeros_like(tmp).long())
duration_emb_half = self.speed_emb(torch.ones_like(tmp).long())
conds_latent = torch.cat((speech_conditioning_latent + emo_vec.unsqueeze(1), duration_emb_half.unsqueeze(1), duration_emb.unsqueeze(1)), 1)
input_ids, inputs_embeds, attention_mask = self.prepare_gpt_inputs(conds_latent, text_inputs)
self.inference_model.store_mel_emb(inputs_embeds)
if input_tokens is None:
inputs = input_ids
else:
if input_tokens.ndim == 1:
input_tokens = input_tokens.unsqueeze(0)
assert num_return_sequences % input_tokens.shape[0] == 0, \
"The num_return_sequences must be divisible by the batch number of input_tokens"
assert num_return_sequences % text_inputs.shape[0] == 0, \
"The num_return_sequences must be divisible by the batch number of text_inputs"
b = num_return_sequences // input_ids.shape[0]
if b > 1:
input_ids = input_ids.repeat(b, 1)
attention_mask = attention_mask.repeat(b, 1)
input_tokens = input_tokens.repeat(num_return_sequences // input_tokens.shape[0], 1)
inputs = torch.cat([input_ids, input_tokens], dim=1)
attention_mask = F.pad(attention_mask, (0, input_tokens.shape[1]), value=1)
trunc_index = inputs.shape[1]
logits_processor = LogitsProcessorList()
if typical_sampling:
# employ custom typical sampling
if not (typical_mass > 0.0 and typical_mass < 1.0):
raise ValueError(f"`typical_mass` has to be a float > 0 and < 1, but is {typical_mass}")
min_tokens_to_keep = 2 if hf_generate_kwargs.get("num_beams", 1) > 1 else 1
logits_processor.append(TypicalLogitsWarper(mass=typical_mass, min_tokens_to_keep=min_tokens_to_keep))
max_length = (trunc_index + self.max_mel_tokens - 1) if max_generate_length is None else trunc_index + max_generate_length
output = self.inference_model.generate(inputs,
bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token,
eos_token_id=self.stop_mel_token, attention_mask=attention_mask,
max_length=max_length, logits_processor=logits_processor,
num_return_sequences=num_return_sequences,
**hf_generate_kwargs)
if isinstance(output, torch.Tensor):
return output[:, trunc_index:], speech_conditioning_latent
# GenerateOutput
output.sequences = output.sequences[:, trunc_index:]
return output, speech_conditioning_latent
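# Hedged usage sketch (sampling values and names are hypothetical; `cond_feat`
# follows the (b, d, frames) layout described in the docstring above):
#
#   codes, spk_latent = model.inference_speech(
#       cond_feat, text_ids,
#       do_sample=True, top_p=0.8, top_k=30, temperature=1.0,
#       num_return_sequences=1, max_generate_length=600,
#   )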
def get_emovec(self, emo_speech_conditioning_latent, emo_cond_lengths):
emo_vec_syn_ori = self.get_emo_conditioning(emo_speech_conditioning_latent.transpose(1,2), emo_cond_lengths)
emo_vec_syn = self.emovec_layer(emo_vec_syn_ori)
emo_vec = self.emo_layer(emo_vec_syn)
return emo_vec
def merge_emovec(self, speech_conditioning_latent, emo_speech_conditioning_latent, cond_lengths, emo_cond_lengths, alpha = 1.0):
emo_vec = self.get_emovec(emo_speech_conditioning_latent, emo_cond_lengths)
base_vec = self.get_emovec(speech_conditioning_latent, cond_lengths)
out = base_vec + alpha * (emo_vec - base_vec)
return out
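# `alpha` linearly interpolates the two emotion vectors: alpha = 0.0 keeps the
# speaker reference's own emotion, alpha = 1.0 fully adopts the emotion
# reference's vector, and intermediate values blend the two.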
import uuid
import os
import functools
import patch_vllm # ⚠️ Monkey Patch, do not delete this line
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, GPT2Config, GPT2LMHeadModel, GPT2Model, LogitsProcessorList
from indextts.gpt.conformer_encoder import ConformerEncoder
from indextts.gpt.perceiver import PerceiverResampler
from indextts.gpt.index_tts_gpt2_vllm_v1 import PLACEHOLDER_TOKEN, PLACEHOLDER_TOKEN_ID
from vllm import AsyncLLMEngine, SamplingParams, TokensPrompt
from vllm.v1.engine.async_llm import AsyncLLM
def null_position_embeddings(range, dim):
return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
class LearnedPositionEmbeddings(nn.Module):
def __init__(self, seq_len, model_dim, init=.02):
super().__init__()
self.emb = nn.Embedding(seq_len, model_dim)
# Initializing this way is standard for GPT-2
self.emb.weight.data.normal_(mean=0.0, std=init)
def forward(self, x):
sl = x.shape[1]
return self.emb(torch.arange(0, sl, device=x.device))
def get_fixed_embedding(self, ind, dev):
return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
class UnifiedVoice(nn.Module):
def __init__(self, vllm_model,
layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1,
mel_length_compression=1024, number_text_tokens=256,
start_text_token=0, stop_text_token=1, number_mel_codes=8194, start_mel_token=8192, stop_mel_token=8193,
types=1, activation_function=None,
model_dir=None,
condition_num_latent=32, condition_module=None, **kwargs):
"""
Args:
layers: Number of layers in transformer stack.
model_dim: Operating dimensions of the transformer
heads: Number of transformer heads. model_dim must be divisible by heads; model_dim // 64 is recommended.
max_text_tokens: Maximum number of text tokens that will be encountered by model.
max_mel_tokens: Maximum number of MEL tokens that will be encountered by model.
max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s).
mel_length_compression: The factor between <number_input_samples> and <mel_tokens>. Used to compute MEL code padding given wav input length.
number_text_tokens:
start_text_token:
stop_text_token:
number_mel_codes:
start_mel_token:
stop_mel_token:
checkpointing:
"""
super().__init__()
self.number_text_tokens = number_text_tokens
self.start_text_token = start_text_token
self.stop_text_token = stop_text_token
self.number_mel_codes = number_mel_codes
self.start_mel_token = start_mel_token
self.stop_mel_token = stop_mel_token
self.layers = layers
self.heads = heads
self.max_mel_tokens = max_mel_tokens
self.max_text_tokens = max_text_tokens
self.model_dim = model_dim
self.max_conditioning_inputs = max_conditioning_inputs
self.mel_length_compression = mel_length_compression
self.cond_num = condition_num_latent
self.cond_mask_pad = nn.ConstantPad1d((self.cond_num, 0), True)
self.conditioning_encoder = ConformerEncoder(input_size=100,
output_size=condition_module['output_size'],
linear_units=condition_module['linear_units'],
attention_heads=condition_module['attention_heads'],
num_blocks=condition_module['num_blocks'],
input_layer=condition_module['input_layer'])
self.perceiver_encoder = PerceiverResampler(model_dim, dim_context=condition_module['output_size'],
ff_mult=condition_module['perceiver_mult'],
heads=condition_module['attention_heads'],
num_latents=self.cond_num)
self.text_embedding = nn.Embedding(self.number_text_tokens * types + 1, model_dim)
self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
max_mel_seq_len = self.max_mel_tokens + 2 + self.max_conditioning_inputs
max_text_seq_len = self.max_text_tokens + 2
gpt_config = GPT2Config(vocab_size=256, # Unused.
n_positions=max_mel_seq_len + max_text_seq_len,
n_ctx=max_mel_seq_len + max_text_seq_len,
n_embd=model_dim,
n_layer=layers,
n_head=heads,
activation_function=activation_function or "gelu_new",
gradient_checkpointing=False,
use_cache=True)
self.gpt = GPT2Model(gpt_config)
# self.gpt = AutoModelForCausalLM.from_pretrained(
# os.path.join(model_dir, "gpt"),
# # torch_dtype=torch.float16,
# # device_map="auto",
# # trust_remote_code=True, # enable if a custom model class is required
# ).transformer
del self.gpt.wpe
self.gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
del self.gpt.wte
# self.gpt = self.gpt.to("cuda") # .to(torch.float16)
# self.gpt.eval()
self.mel_pos_embedding, self.text_pos_embedding = LearnedPositionEmbeddings(max_mel_seq_len, model_dim), LearnedPositionEmbeddings(max_text_seq_len, model_dim)
self.mel_solo_embedding = 0
self.text_solo_embedding = 0
self.final_norm = nn.LayerNorm(model_dim) # , dtype=torch.float16
self.text_head = nn.Linear(model_dim, self.number_text_tokens * types + 1)
self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
# Initialize the embeddings per the GPT-2 scheme
embeddings = [self.text_embedding, self.mel_embedding]
for module in embeddings:
module.weight.data.normal_(mean=0.0, std=.02)
self.llm: AsyncLLM = vllm_model
self.sampling_params = SamplingParams(
temperature=1.0,
top_p=0.8,
top_k=30, # 5, 30
repetition_penalty=10.0, # 8.0
max_tokens=768, # 605
)
def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
inp = F.pad(input, (1, 0), value=start_token)
tar = F.pad(input, (0, 1), value=stop_token)
return inp, tar
def get_conditioning(self, speech_conditioning_input, cond_mel_lengths=None):
speech_conditioning_input, mask = self.conditioning_encoder(speech_conditioning_input.transpose(1, 2),
cond_mel_lengths) # (b, s, d), (b, 1, s)
conds_mask = self.cond_mask_pad(mask.squeeze(1))
conds = self.perceiver_encoder(speech_conditioning_input, conds_mask) # (b, 32, d)
return conds
async def inference_speech(self, speech_conditioning_latent, text_inputs, cond_mel_lengths=None):
text_inputs, _ = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
# speech_conditioning_latent = self.get_conditioning(speech_conditioning_latent, cond_mel_lengths)
emb = torch.cat([speech_conditioning_latent, text_emb], dim=1)
# trunc_index = emb.shape[1] + 1
mel_start_emb = self.mel_embedding(torch.full((emb.shape[0], 1,), fill_value=self.start_mel_token, dtype=torch.long, device=text_inputs.device))
mel_start_emb = mel_start_emb + self.mel_pos_embedding(mel_start_emb)
inputs_embeds = torch.cat([emb, mel_start_emb], dim=1)
# fake_inputs = [idx for idx in range(inputs_embeds.shape[1])]
fake_inputs = PLACEHOLDER_TOKEN * 1 # [PLACEHOLDER_TOKEN_ID]
multi_modal_data = {"audio": {"audio_embeds": [inputs_embeds.squeeze(0).cpu()]}}
tokens_prompt = TokensPrompt(prompt=fake_inputs, multi_modal_data=multi_modal_data)
# tokens_prompt = TokensPrompt(prompt_token_ids=fake_inputs, multi_modal_data=multi_modal_data)
output_generator = self.llm.generate(tokens_prompt, sampling_params=self.sampling_params, request_id=uuid.uuid4().hex)
# latent = []
async for output in output_generator:
# latent.append(output.hidden_states.clone())
pass
codes = output.outputs[0].token_ids[:-2]
# latent = torch.cat(latent[:-2], dim=0).unsqueeze(0)
# # latent = self.final_norm(latent.float())
# latent = latent.float()
# print("codes", len(codes), codes)
# print("latent", latent.shape, latent)
return codes # , latent
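# Hedged usage sketch (hypothetical names; must be awaited inside an asyncio event
# loop since vLLM's AsyncLLM streams results asynchronously):
#
#   async def synthesize(model, conds, text_ids):
#       codes = await model.inference_speech(conds, text_ids)
#       return torch.tensor(codes, dtype=torch.long).unsqueeze(0)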
def set_mel_padding(self, mel_input_tokens, mel_lengths):
"""
Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in
that audio clip, reformats the tokens with STOP_MEL_TOKEN in place of the zero padding. This is required
preformatting to create a working TTS model.
"""
for b in range(len(mel_lengths)):
# Due to the convolutional nature of how these tokens are generated,
# it would be best if the model predicts a token past the actual last token.
actual_end = mel_lengths[b]
if actual_end < mel_input_tokens.shape[-1]:
mel_input_tokens[b, actual_end:] = self.stop_mel_token
return mel_input_tokens
def set_text_padding(self, text_input_tokens, text_lengths):
"""
Given text tokens that are derived from padded text inputs and the actual lengths of each batch element,
reformats the tokens with stop_text_token in place of the zero padding. This is required
preformatting to create a working TTS model.
"""
for b in range(len(text_lengths)):
# Everything past the true text length is replaced with the stop token.
actual_end = text_lengths[b]
if actual_end < text_input_tokens.shape[-1]:
text_input_tokens[b, actual_end:] = self.stop_text_token
return text_input_tokens
def forward(self, speech_conditioning_latent, text_inputs, text_lengths, mel_codes, wav_lengths,
cond_mel_lengths=None, types=None, text_first=True, raw_mels=None, return_attentions=False,
return_latent=True, clip_inputs=False):
# Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
# mel_codes_lengths = torch.div(wav_lengths, self.mel_length_compression, rounding_mode='trunc')
mel_codes_lengths = torch.ceil(wav_lengths / self.mel_length_compression).long() + 1
mel_codes = self.set_mel_padding(mel_codes, mel_codes_lengths)
text_inputs = self.set_text_padding(text_inputs, text_lengths)
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
mel_codes = F.pad(mel_codes, (0, 1), value=self.stop_mel_token)
conds = speech_conditioning_latent
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
mel_inp = mel_codes
mel_emb = self.mel_embedding(mel_inp)
mel_emb = mel_emb + self.mel_pos_embedding(mel_codes)
emb = torch.cat([conds, text_emb, mel_emb], dim=1)
gpt_out = self.gpt(inputs_embeds=emb, return_dict=True)
offset = conds.shape[1]
enc = gpt_out.last_hidden_state[:, offset:]
enc = self.final_norm(enc)
return enc[:, -mel_emb.shape[1]:][:, :-2]
import time
import uuid
import os
import functools
from loguru import logger
import patch_vllm # ⚠️ Monkey Patch, do not delete this line
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Model, LogitsProcessorList
from indextts.gpt.conformer_encoder import ConformerEncoder
from indextts.gpt.perceiver import PerceiverResampler
from indextts.gpt.index_tts_gpt2_vllm_v1 import PLACEHOLDER_TOKEN, PLACEHOLDER_TOKEN_ID
from vllm import AsyncLLMEngine, SamplingParams, TokensPrompt
from vllm.v1.engine.async_llm import AsyncLLM
def null_position_embeddings(range, dim):
return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
class LearnedPositionEmbeddings(nn.Module):
def __init__(self, seq_len, model_dim, init=.02):
super().__init__()
self.emb = nn.Embedding(seq_len, model_dim)
# Initializing this way is standard for GPT-2
self.emb.weight.data.normal_(mean=0.0, std=init)
def forward(self, x):
sl = x.shape[1]
return self.emb(torch.arange(0, sl, device=x.device))
def get_fixed_embedding(self, ind, dev):
return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
class UnifiedVoice(nn.Module):
def __init__(self, vllm_model,
layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1,
mel_length_compression=1024, number_text_tokens=256,
start_text_token=0, stop_text_token=1, number_mel_codes=8194, start_mel_token=8192, stop_mel_token=8193,
train_solo_embeddings=False, use_mel_codes_as_input=True,
checkpointing=True, types=1,
condition_num_latent=32, condition_type="perceiver", condition_module=None, emo_condition_module=None):
"""
Args:
layers: Number of layers in transformer stack.
model_dim: Operating dimensions of the transformer
heads: Number of transformer heads. model_dim must be divisible by heads; model_dim // 64 is recommended.
max_text_tokens: Maximum number of text tokens that will be encountered by model.
max_mel_tokens: Maximum number of MEL tokens that will be encountered by model.
max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s).
mel_length_compression: The factor between <number_input_samples> and <mel_tokens>. Used to compute MEL code padding given wav input length.
number_text_tokens:
start_text_token:
stop_text_token:
number_mel_codes:
start_mel_token:
stop_mel_token:
checkpointing:
"""
super().__init__()
self.number_text_tokens = number_text_tokens
self.start_text_token = start_text_token
self.stop_text_token = stop_text_token
self.number_mel_codes = number_mel_codes
self.start_mel_token = start_mel_token
self.stop_mel_token = stop_mel_token
self.layers = layers
self.heads = heads
self.max_mel_tokens = max_mel_tokens
self.max_text_tokens = max_text_tokens
self.model_dim = model_dim
self.max_conditioning_inputs = max_conditioning_inputs
self.mel_length_compression = mel_length_compression
self.cond_num = condition_num_latent
self.cond_mask_pad = nn.ConstantPad1d((self.cond_num, 0), True)
self.emo_cond_mask_pad = nn.ConstantPad1d((1, 0), True)
self.conditioning_encoder = ConformerEncoder(input_size=1024,
output_size=condition_module['output_size'],
linear_units=condition_module['linear_units'],
attention_heads=condition_module['attention_heads'],
num_blocks=condition_module['num_blocks'],
input_layer=condition_module['input_layer'])
self.perceiver_encoder = PerceiverResampler(model_dim, dim_context=condition_module['output_size'],
ff_mult=condition_module['perceiver_mult'],
heads=condition_module['attention_heads'],
num_latents=self.cond_num)
self.emo_conditioning_encoder = ConformerEncoder(input_size=1024,
output_size=emo_condition_module['output_size'],
linear_units=emo_condition_module['linear_units'],
attention_heads=emo_condition_module['attention_heads'],
num_blocks=emo_condition_module['num_blocks'],
input_layer=emo_condition_module['input_layer'])
self.emo_perceiver_encoder = PerceiverResampler(1024, dim_context=emo_condition_module['output_size'],
ff_mult=emo_condition_module['perceiver_mult'],
heads=emo_condition_module['attention_heads'],
num_latents=1)
self.text_embedding = nn.Embedding(self.number_text_tokens * types + 1, model_dim)
self.emo_layer = nn.Linear(model_dim, model_dim)
self.emovec_layer = nn.Linear(1024, model_dim)
self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
max_mel_seq_len = self.max_mel_tokens + 2 + self.max_conditioning_inputs
max_text_seq_len = self.max_text_tokens + 2
gpt_config = GPT2Config(vocab_size=256, # Unused.
n_positions=max_mel_seq_len + max_text_seq_len,
n_ctx=max_mel_seq_len + max_text_seq_len,
n_embd=model_dim,
n_layer=layers,
n_head=heads,
gradient_checkpointing=False,
use_cache=True)
self.gpt = GPT2Model(gpt_config)
# Override the built in positional embeddings
del self.gpt.wpe
self.gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
# Built-in token embeddings are unused.
del self.gpt.wte
self.mel_pos_embedding, self.text_pos_embedding = LearnedPositionEmbeddings(max_mel_seq_len, model_dim), LearnedPositionEmbeddings(max_text_seq_len, model_dim)
self.mel_solo_embedding = 0
self.text_solo_embedding = 0
self.final_norm = nn.LayerNorm(model_dim) # , dtype=torch.float16
self.text_head = nn.Linear(model_dim, self.number_text_tokens * types + 1)
self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
self.speed_emb = nn.Embedding(2, model_dim)
self.speed_emb.weight.data.normal_(mean=0.0, std=0.0)
# Initialize the embeddings per the GPT-2 scheme
embeddings = [self.text_embedding, self.mel_embedding]
for module in embeddings:
module.weight.data.normal_(mean=0.0, std=.02)
self.llm: AsyncLLM = vllm_model
self.sampling_params = SamplingParams(
temperature=1.0,
top_p=0.8,
top_k=30, # 5, 30
repetition_penalty=10.0, # 8.0
max_tokens=2048, # 605
stop_token_ids=[self.stop_mel_token],
include_stop_str_in_output=True,
)
def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
inp = F.pad(input, (1, 0), value=start_token)
tar = F.pad(input, (0, 1), value=stop_token)
return inp, tar
def set_mel_padding(self, mel_input_tokens, mel_lengths):
"""
Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in
that audio clip, reformats the tokens with STOP_MEL_TOKEN in place of the zero padding. This is required
preformatting to create a working TTS model.
"""
for b in range(len(mel_lengths)):
# Due to the convolutional nature of how these tokens are generated,
# it would be best if the model predicts a token past the actual last token.
actual_end = mel_lengths[b]
if actual_end < mel_input_tokens.shape[-1]:
mel_input_tokens[b, actual_end:] = self.stop_mel_token
return mel_input_tokens
def set_text_padding(self, text_input_tokens, text_lengths):
"""
Given text tokens that are derived from padded text inputs and the actual lengths of each batch element,
reformats the tokens with stop_text_token in place of the zero padding. This is required
preformatting to create a working TTS model.
"""
for b in range(len(text_lengths)):
# Everything past the true text length is replaced with the stop token.
actual_end = text_lengths[b]
if actual_end < text_input_tokens.shape[-1]:
text_input_tokens[b, actual_end:] = self.stop_text_token
return text_input_tokens
def get_conditioning(self, speech_conditioning_input, cond_mel_lengths=None):
speech_conditioning_input, mask = self.conditioning_encoder(speech_conditioning_input.transpose(1, 2),
cond_mel_lengths) # (b, s, d), (b, 1, s)
conds_mask = self.cond_mask_pad(mask.squeeze(1))
conds = self.perceiver_encoder(speech_conditioning_input, conds_mask) # (b, 32, d)
return conds
def get_emo_conditioning(self, speech_conditioning_input, cond_mel_lengths=None):
speech_conditioning_input, mask = self.emo_conditioning_encoder(speech_conditioning_input.transpose(1, 2),
cond_mel_lengths) # (b, s, d), (b, 1, s)
conds_mask = self.emo_cond_mask_pad(mask.squeeze(1))
conds = self.emo_perceiver_encoder(speech_conditioning_input, conds_mask) # (b, 1, d)
return conds.squeeze(1)
async def inference_speech(self, speech_condition, text_inputs, emo_speech_condition=None, cond_lengths=None, emo_cond_lengths=None, emo_vec=None, use_speed=False):
if speech_condition.ndim == 2:
speech_condition = speech_condition.unsqueeze(0)
if emo_speech_condition is None:
emo_speech_condition = speech_condition
if cond_lengths is None:
cond_lengths = torch.tensor([speech_condition.shape[-1]], device=speech_condition.device)
if emo_cond_lengths is None:
emo_cond_lengths = torch.tensor([emo_speech_condition.shape[-1]], device=speech_condition.device)
speech_conditioning_latent = self.get_conditioning(speech_condition.transpose(1,2), cond_lengths)
if emo_vec is None:
logger.info('compute emo vec')
emo_vec = self.get_emo_conditioning(emo_speech_condition.transpose(1,2), emo_cond_lengths)
emo_vec = self.emovec_layer(emo_vec)
emo_vec = self.emo_layer(emo_vec)
else:
logger.info('Use the specified emotion vector')
tmp = torch.zeros(text_inputs.size(0)).to(text_inputs.device)
duration_emb = self.speed_emb(torch.zeros_like(tmp).long())
duration_emb_half = self.speed_emb(torch.ones_like(tmp).long())
conds_latent = torch.cat((speech_conditioning_latent + emo_vec.unsqueeze(1), duration_emb_half.unsqueeze(1), duration_emb.unsqueeze(1)), 1)
text_inputs, _ = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
emb = torch.cat([conds_latent, text_emb], dim=1)
mel_start_emb = self.mel_embedding(torch.full((emb.shape[0], 1,), fill_value=self.start_mel_token, dtype=torch.long, device=text_inputs.device))
mel_start_emb = mel_start_emb + self.mel_pos_embedding(mel_start_emb)
inputs_embeds = torch.cat([emb, mel_start_emb], dim=1)
fake_inputs = PLACEHOLDER_TOKEN * 1 # [PLACEHOLDER_TOKEN_ID]
multi_modal_data = {"audio": {"audio_embeds": [inputs_embeds.squeeze(0).cpu()]}}
tokens_prompt = TokensPrompt(prompt=fake_inputs, multi_modal_data=multi_modal_data)
# tokens_prompt = TokensPrompt(prompt_token_ids=fake_inputs, multi_modal_data=multi_modal_data)
request_id = uuid.uuid4().hex
output_generator = self.llm.generate(tokens_prompt, sampling_params=self.sampling_params, request_id=request_id)
gpt_stt = time.time()
prefill_flag = True
async for output in output_generator:
if prefill_flag:
logger.info(f"[{request_id}] [prefill time: {(time.time() - gpt_stt):.4f}]")
gpt_stt = time.time()
prefill_flag = False
logger.info(f"[{request_id}] [decode time: {(time.time() - gpt_stt):.4f}] [decode len: {len(output.outputs[0].token_ids)}]")
codes = output.outputs[0].token_ids[:-2]
codes = torch.tensor(codes, device=text_inputs.device, dtype=torch.long).unsqueeze(0)
return codes, speech_conditioning_latent
def forward(self, speech_conditioning_latent, text_inputs, text_lengths, mel_codes, mel_codes_lengths, emo_speech_conditioning_latent,
cond_mel_lengths=None, emo_cond_mel_lengths=None, emo_vec=None, use_speed=None, do_spk_cond=False):
# TODO: Note that speech_conditioning_latent.transpose(1,2) here differs from v1; only a single reference audio is supported for now, just to get it running.
if do_spk_cond:
speech_conditioning_latent = self.get_conditioning(speech_conditioning_latent.transpose(1,2), cond_mel_lengths)
else:
speech_conditioning_latent = speech_conditioning_latent
if emo_vec is None:
emo_vec_syn_ori = self.get_emo_conditioning(emo_speech_conditioning_latent.transpose(1,2), emo_cond_mel_lengths)
emo_vec_syn = self.emovec_layer(emo_vec_syn_ori)
emo_vec = self.emo_layer(emo_vec_syn)
text_inputs = self.set_text_padding(text_inputs, text_lengths)
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
mel_codes = self.set_mel_padding(mel_codes, mel_codes_lengths)
mel_codes = F.pad(mel_codes, (0, 1), value=self.stop_mel_token)
duration_emb = self.speed_emb(torch.zeros_like(use_speed))
duration_emb_half = self.speed_emb(torch.ones_like(use_speed))
conds = torch.cat((speech_conditioning_latent + emo_vec.unsqueeze(1), duration_emb_half.unsqueeze(1), duration_emb.unsqueeze(1)), 1)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
mel_emb = self.mel_embedding(mel_codes)
mel_emb = mel_emb + self.mel_pos_embedding(mel_codes)
emb = torch.cat([conds, text_emb, mel_emb], dim=1)
gpt_out = self.gpt(inputs_embeds=emb, return_dict=True)
offset = conds.shape[1]
enc = gpt_out.last_hidden_state[:, offset:]
enc = self.final_norm(enc)
return enc[:, -mel_emb.shape[1]:][:, :-2]
def get_emovec(self, emo_speech_conditioning_latent, emo_cond_lengths):
emo_vec_syn_ori = self.get_emo_conditioning(emo_speech_conditioning_latent.transpose(1,2), emo_cond_lengths)
emo_vec_syn = self.emovec_layer(emo_vec_syn_ori)
emo_vec = self.emo_layer(emo_vec_syn)
return emo_vec
def merge_emovec(self, speech_conditioning_latent, emo_speech_conditioning_latent, cond_lengths, emo_cond_lengths, alpha = 1.0):
emo_vec = self.get_emovec(emo_speech_conditioning_latent, emo_cond_lengths)
base_vec = self.get_emovec(speech_conditioning_latent, cond_lengths)
out = base_vec + alpha * (emo_vec - base_vec)
return out
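# Usage sketch (added, hypothetical call; not executed): merge_emovec linearly interpolates
# between the speaker's own emotion vector and the emotion-reference vector,
# out = base_vec + alpha * (emo_vec - base_vec), so alpha=0.0 keeps the speaker's base
# emotion and alpha=1.0 fully adopts the emotion reference, e.g.
#   emovec = gpt.merge_emovec(spk_cond, emo_cond, spk_lens, emo_lens, alpha=0.7)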
# Adapted from https://github.com/lucidrains/naturalspeech2-pytorch/blob/659bec7f7543e7747e809e950cc2f84242fbeec7/naturalspeech2_pytorch/naturalspeech2_pytorch.py#L532
from collections import namedtuple
from functools import wraps
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from packaging import version
from torch import einsum, nn
def exists(val):
return val is not None
def once(fn):
called = False
@wraps(fn)
def inner(x):
nonlocal called
if called:
return
called = True
return fn(x)
return inner
print_once = once(print)
# main class
class Attend(nn.Module):
def __init__(self, dropout=0.0, causal=False, use_flash=False):
super().__init__()
self.dropout = dropout
self.attn_dropout = nn.Dropout(dropout)
self.causal = causal
self.register_buffer("mask", None, persistent=False)
self.use_flash = use_flash
assert not (
use_flash and version.parse(torch.__version__) < version.parse("2.0.0")
), "in order to use flash attention, you must be using pytorch 2.0 or above"
# determine efficient attention configs for cuda and cpu
self.config = namedtuple("EfficientAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"])
self.cpu_config = self.config(True, True, True)
self.cuda_config = None
if not torch.cuda.is_available() or not use_flash:
return
device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
if device_properties.major == 8 and device_properties.minor == 0:
print_once("A100 GPU detected, using flash attention if input tensor is on cuda")
self.cuda_config = self.config(True, False, False)
else:
print_once("Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda")
self.cuda_config = self.config(False, True, True)
def get_mask(self, n, device):
if exists(self.mask) and self.mask.shape[-1] >= n:
return self.mask[:n, :n]
mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
self.register_buffer("mask", mask, persistent=False)
return mask
def flash_attn(self, q, k, v, mask=None):
_, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda
# Recommended for multi-query single-key-value attention by Tri Dao
# kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
if k.ndim == 3:
k = rearrange(k, "b ... -> b 1 ...").expand_as(q)
if v.ndim == 3:
v = rearrange(v, "b ... -> b 1 ...").expand_as(q)
# Check if mask exists and expand to compatible shape
# The mask is B L, so it would have to be expanded to B H N L
if exists(mask):
mask = rearrange(mask, "b j -> b 1 1 j")
mask = mask.expand(-1, heads, q_len, -1)
# Check if there is a compatible device for flash attention
config = self.cuda_config if is_cuda else self.cpu_config
# pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
with torch.backends.cuda.sdp_kernel(**config._asdict()):
out = F.scaled_dot_product_attention(
q, k, v, attn_mask=mask, dropout_p=self.dropout if self.training else 0.0, is_causal=self.causal
)
return out
def forward(self, q, k, v, mask=None):
"""
einstein notation
b - batch
h - heads
n, i, j - sequence length (base sequence length, source, target)
d - feature dimension
"""
n, device = q.shape[-2], q.device
scale = q.shape[-1] ** -0.5
if self.use_flash:
return self.flash_attn(q, k, v, mask=mask)
kv_einsum_eq = "b j d" if k.ndim == 3 else "b h j d"
# similarity
sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale
# key padding mask
if exists(mask):
mask = rearrange(mask, "b j -> b 1 1 j")
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
# causal mask
if self.causal:
causal_mask = self.get_mask(n, device)
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
# attention
attn = sim.softmax(dim=-1)
attn = self.attn_dropout(attn)
# aggregate values
out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v)
return out
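# Minimal usage sketch for Attend (added; assumed shapes, kept as a comment so the module
# stays import-safe):
#   attend = Attend(dropout=0.0, causal=False, use_flash=False)
#   q = torch.randn(2, 8, 128, 64)                  # (batch, heads, q_len, dim_head)
#   k = torch.randn(2, 8, 256, 64)                  # (batch, heads, k_len, dim_head)
#   v = torch.randn(2, 8, 256, 64)
#   mask = torch.ones(2, 256, dtype=torch.bool)     # key padding mask (batch, k_len)
#   out = attend(q, k, v, mask=mask)                # -> (2, 8, 128, 64)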
def Sequential(*mods):
return nn.Sequential(*filter(exists, mods))
def exists(x):
return x is not None
def default(val, d):
if exists(val):
return val
return d() if callable(d) else d
class RMSNorm(nn.Module):
def __init__(self, dim, scale=True, dim_cond=None):
super().__init__()
self.cond = exists(dim_cond)
self.to_gamma_beta = nn.Linear(dim_cond, dim * 2) if self.cond else None
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(dim)) if scale else None
def forward(self, x, cond=None):
gamma = default(self.gamma, 1)
out = F.normalize(x, dim=-1) * self.scale * gamma
if not self.cond:
return out
assert exists(cond)
gamma, beta = self.to_gamma_beta(cond).chunk(2, dim=-1)
gamma, beta = map(lambda t: rearrange(t, "b d -> b 1 d"), (gamma, beta))
return out * gamma + beta
class CausalConv1d(nn.Conv1d):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
(kernel_size,) = self.kernel_size
(dilation,) = self.dilation
(stride,) = self.stride
assert stride == 1
self.causal_padding = dilation * (kernel_size - 1)
def forward(self, x):
causal_padded_x = F.pad(x, (self.causal_padding, 0), value=0.0)
return super().forward(causal_padded_x)
class GEGLU(nn.Module):
def forward(self, x):
x, gate = x.chunk(2, dim=-1)
return F.gelu(gate) * x
def FeedForward(dim, mult=4, causal_conv=False):
dim_inner = int(dim * mult * 2 / 3)
conv = None
if causal_conv:
conv = nn.Sequential(
Rearrange("b n d -> b d n"),
CausalConv1d(dim_inner, dim_inner, 3),
Rearrange("b d n -> b n d"),
)
return Sequential(nn.Linear(dim, dim_inner * 2), GEGLU(), conv, nn.Linear(dim_inner, dim))
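# Sketch of the GEGLU feed-forward above (added; assumed dims): the hidden width is
# int(dim * mult * 2 / 3) so the gated variant keeps roughly the parameter count of a
# plain 4x MLP, e.g.
#   ff = FeedForward(dim=512, mult=4)               # hidden width 1365
#   y = ff(torch.randn(2, 100, 512))                # -> (2, 100, 512)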
class PerceiverResampler(nn.Module):
def __init__(
self,
dim,
depth=2,
dim_context=None,
num_latents=32,
dim_head=64,
heads=8,
ff_mult=4,
use_flash_attn=False,
):
super().__init__()
dim_context = default(dim_context, dim)
self.proj_context = nn.Linear(dim_context, dim) if dim_context != dim else nn.Identity()
self.latents = nn.Parameter(torch.randn(num_latents, dim))
nn.init.normal_(self.latents, std=0.02)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList(
[
Attention(
dim=dim,
dim_head=dim_head,
heads=heads,
use_flash=use_flash_attn,
cross_attn_include_queries=True,
),
FeedForward(dim=dim, mult=ff_mult),
]
)
)
self.norm = RMSNorm(dim)
def forward(self, x, mask=None):
batch = x.shape[0]
x = self.proj_context(x)
latents = repeat(self.latents, "n d -> b n d", b=batch)
for attn, ff in self.layers:
latents = attn(latents, x, mask=mask) + latents
latents = ff(latents) + latents
return self.norm(latents)
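# Usage sketch for PerceiverResampler (added; assumed shapes): it compresses a
# variable-length conditioning sequence into a fixed number of latents via cross-attention, e.g.
#   resampler = PerceiverResampler(dim=512, dim_context=1024, num_latents=32)
#   feats = torch.randn(2, 300, 1024)               # (batch, frames, dim_context)
#   cond = resampler(feats)                         # -> (2, 32, 512)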
class Attention(nn.Module):
def __init__(
self,
dim,
*,
dim_context=None,
causal=False,
dim_head=64,
heads=8,
dropout=0.0,
use_flash=False,
cross_attn_include_queries=False,
):
super().__init__()
self.scale = dim_head**-0.5
self.heads = heads
self.cross_attn_include_queries = cross_attn_include_queries
dim_inner = dim_head * heads
dim_context = default(dim_context, dim)
self.attend = Attend(causal=causal, dropout=dropout, use_flash=use_flash)
self.to_q = nn.Linear(dim, dim_inner, bias=False)
self.to_kv = nn.Linear(dim_context, dim_inner * 2, bias=False)
self.to_out = nn.Linear(dim_inner, dim, bias=False)
def forward(self, x, context=None, mask=None):
h, has_context = self.heads, exists(context)
context = default(context, x)
if has_context and self.cross_attn_include_queries:
context = torch.cat((x, context), dim=-2)
q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim=-1))
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
out = self.attend(q, k, v, mask=mask)
out = rearrange(out, "b h n d -> b n (h d)")
return self.to_out(out)
import os
import re
import time
from subprocess import CalledProcessError
import traceback
from typing import List
import numpy as np
import sentencepiece as spm
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from omegaconf import OmegaConf
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
from indextts.BigVGAN.models import BigVGAN as Generator
from indextts.gpt.model import UnifiedVoice
from indextts.utils.checkpoint import load_checkpoint
from indextts.utils.feature_extractors import MelSpectrogramFeatures
from indextts.utils.front import TextNormalizer, TextTokenizer
class IndexTTS:
def __init__(
self, cfg_path="checkpoints/config.yaml", model_dir="checkpoints", is_fp16=True, device=None, use_cuda_kernel=None,
):
"""
Args:
cfg_path (str): path to the config file.
model_dir (str): path to the model directory.
is_fp16 (bool): whether to use fp16.
device (str): device to use (e.g., 'cuda:0', 'cpu'). If None, it will be set automatically based on the availability of CUDA or MPS.
use_cuda_kernel (None | bool): whether to use BigVGan custom fused activation CUDA kernel, only for CUDA device.
"""
if device is not None:
self.device = device
self.is_fp16 = False if device == "cpu" else is_fp16
self.use_cuda_kernel = use_cuda_kernel is not None and use_cuda_kernel and device.startswith("cuda")
elif torch.cuda.is_available():
self.device = "cuda:0"
self.is_fp16 = is_fp16
self.use_cuda_kernel = use_cuda_kernel is None or use_cuda_kernel
elif hasattr(torch, "mps") and torch.backends.mps.is_available():
self.device = "mps"
self.is_fp16 = False # float16 on MPS adds overhead compared to float32
self.use_cuda_kernel = False
else:
self.device = "cpu"
self.is_fp16 = False
self.use_cuda_kernel = False
print(">> Be patient, it may take a while to run in CPU mode.")
self.cfg = OmegaConf.load(cfg_path)
self.model_dir = model_dir
self.dtype = torch.float16 if self.is_fp16 else None
self.stop_mel_token = self.cfg.gpt.stop_mel_token
self.gpt = UnifiedVoice(**self.cfg.gpt)
self.gpt_path = os.path.join(self.model_dir, self.cfg.gpt_checkpoint)
load_checkpoint(self.gpt, self.gpt_path)
self.gpt = self.gpt.to(self.device)
if self.is_fp16:
self.gpt.eval().half()
else:
self.gpt.eval()
print(">> GPT weights restored from:", self.gpt_path)
if self.use_cuda_kernel:
# preload the CUDA kernel for BigVGAN
try:
from indextts.BigVGAN.alias_free_activation.cuda import load
anti_alias_activation_cuda = load.load()
print(">> Preload custom CUDA kernel for BigVGAN", anti_alias_activation_cuda)
except Exception as ex:
traceback.print_exc()
print(">> Failed to load custom CUDA kernel for BigVGAN. Falling back to torch.")
self.use_cuda_kernel = False
self.bigvgan = Generator(self.cfg.bigvgan, use_cuda_kernel=self.use_cuda_kernel)
self.bigvgan_path = os.path.join(self.model_dir, self.cfg.bigvgan_checkpoint)
vocoder_dict = torch.load(self.bigvgan_path, map_location="cpu")
self.bigvgan.load_state_dict(vocoder_dict["generator"])
self.bigvgan = self.bigvgan.to(self.device)
# remove weight norm on eval mode
self.bigvgan.remove_weight_norm()
self.bigvgan.eval()
print(">> bigvgan weights restored from:", self.bigvgan_path)
self.bpe_path = os.path.join(self.model_dir, self.cfg.dataset["bpe_model"])
self.normalizer = TextNormalizer()
self.normalizer.load()
print(">> TextNormalizer loaded")
self.tokenizer = TextTokenizer(self.bpe_path, self.normalizer)
print(">> bpe model loaded from:", self.bpe_path)
# cache the reference-audio mel:
self.cache_audio_prompt = None
self.cache_cond_mel = None
# Gradio progress display reference (optional)
self.gr_progress = None
def remove_long_silence(self, codes: torch.Tensor, latent: torch.Tensor, silent_token=52, max_consecutive=30):
code_lens = []
codes_list = []
device = codes.device
dtype = codes.dtype
isfix = False
for i in range(0, codes.shape[0]):
code = codes[i]
if self.cfg.gpt.stop_mel_token not in code:
code_lens.append(len(code))
len_ = len(code)
else:
# len_ = code.cpu().tolist().index(8193)+1
len_ = (code == self.stop_mel_token).nonzero(as_tuple=False)[0] + 1
len_ = len_ - 2
count = torch.sum(code == silent_token).item()
if count > max_consecutive:
code = code.cpu().tolist()
ncode = []
n = 0
for k in range(0, len_):
if code[k] != silent_token:
ncode.append(code[k])
n = 0
elif code[k] == silent_token and n < 10:
ncode.append(code[k])
n += 1
# if (k == 0 and code[k] == 52) or (code[k] == 52 and code[k-1] == 52):
# n += 1
len_ = len(ncode)
ncode = torch.LongTensor(ncode)
codes_list.append(ncode.to(device, dtype=dtype))
isfix = True
# codes[i] = self.stop_mel_token
# codes[i, 0:len_] = ncode
else:
codes_list.append(codes[i])
code_lens.append(len_)
codes = pad_sequence(codes_list, batch_first=True) if isfix else codes[:, :-2]
code_lens = torch.LongTensor(code_lens).to(device, dtype=dtype)
return codes, code_lens
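# Worked example for remove_long_silence (added; hypothetical token values): with
# silent_token=52 and max_consecutive=30, a sentence whose codes contain more than 30
# occurrences of token 52 in total is rewritten so that at most 10 consecutive 52s survive
# before the next non-silent token; sentences below that count are left untouched.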
def _set_gr_progress(self, value, desc):
if self.gr_progress is not None:
self.gr_progress(value, desc=desc)
# original inference mode
def infer(self, audio_prompt, text, output_path, verbose=False):
print(">> start inference...")
self._set_gr_progress(0, "start inference...")
if verbose:
print(f"origin text:{text}")
start_time = time.perf_counter()
# only regenerate cond_mel when the reference audio changes, to speed things up
if self.cache_cond_mel is None or self.cache_audio_prompt != audio_prompt:
audio, sr = torchaudio.load(audio_prompt)
audio = torch.mean(audio, dim=0, keepdim=True)
if audio.shape[0] > 1:
audio = audio[0].unsqueeze(0)
audio = torchaudio.transforms.Resample(sr, 24000)(audio)
cond_mel = MelSpectrogramFeatures()(audio).to(self.device)
cond_mel_frame = cond_mel.shape[-1]
if verbose:
print(f"cond_mel shape: {cond_mel.shape}", "dtype:", cond_mel.dtype)
self.cache_audio_prompt = audio_prompt
self.cache_cond_mel = cond_mel
else:
cond_mel = self.cache_cond_mel
cond_mel_frame = cond_mel.shape[-1]
pass
auto_conditioning = cond_mel
text_tokens_list = self.tokenizer.tokenize(text)
sentences = self.tokenizer.split_sentences(text_tokens_list)
if verbose:
print("text token count:", len(text_tokens_list))
print("sentences count:", len(sentences))
print(*sentences, sep="\n")
top_p = 0.8
top_k = 30
temperature = 1.0
autoregressive_batch_size = 1
length_penalty = 0.0
num_beams = 1
repetition_penalty = 10.0
max_mel_tokens = 600
sampling_rate = 24000
# lang = "EN"
# lang = "ZH"
wavs = []
gpt_gen_time = 0
bigvgan_time = 0
speech_conditioning_latent = self.gpt.get_conditioning(
auto_conditioning.half(),
torch.tensor([auto_conditioning.shape[-1]], device=self.device)
)
for sent in sentences:
text_tokens = self.tokenizer.convert_tokens_to_ids(sent)
text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=self.device).unsqueeze(0)
# text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary.
# text_tokens = F.pad(text_tokens, (1, 0), value=0)
# text_tokens = F.pad(text_tokens, (0, 1), value=1)
if verbose:
print(text_tokens)
print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}")
# debug tokenizer
text_token_syms = self.tokenizer.convert_ids_to_tokens(text_tokens[0].tolist())
print("text_token_syms is same as sentence tokens", text_token_syms == sent)
# text_len = torch.IntTensor([text_tokens.size(1)], device=text_tokens.device)
# print(text_len)
m_start_time = time.perf_counter()
with torch.no_grad():
with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype):
codes, latent = self.gpt.inference_speech(speech_conditioning_latent, text_tokens,
cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]],
device=text_tokens.device),
# text_lengths=text_len,
do_sample=True,
top_p=top_p,
top_k=top_k,
temperature=temperature,
num_return_sequences=autoregressive_batch_size,
length_penalty=length_penalty,
num_beams=num_beams,
repetition_penalty=repetition_penalty,
max_generate_length=max_mel_tokens)
gpt_gen_time += time.perf_counter() - m_start_time
# code_lens = torch.tensor([codes.shape[-1]], device=codes.device, dtype=codes.dtype)
# if verbose:
# print(codes, type(codes))
# print(f"codes shape: {codes.shape}, codes type: {codes.dtype}")
# print(f"code len: {code_lens}")
# # remove ultra-long silence if it exists
# # temporarily fix the long silence bug.
# codes, code_lens = self.remove_long_silence(codes, latent, silent_token=52, max_consecutive=30)
# if verbose:
# print(codes, type(codes))
# print(f"fix codes shape: {codes.shape}, codes type: {codes.dtype}")
# print(f"code len: {code_lens}")
m_start_time = time.perf_counter()
wav, _ = self.bigvgan(latent, auto_conditioning.transpose(1, 2))
bigvgan_time += time.perf_counter() - m_start_time
wav = wav.squeeze(1)
wav = torch.clamp(32767 * wav, -32767.0, 32767.0)
print(f"wav shape: {wav.shape}", "min:", wav.min(), "max:", wav.max())
# wavs.append(wav[:, :-512])
wavs.append(wav.cpu()) # to cpu before saving
end_time = time.perf_counter()
wav = torch.cat(wavs, dim=1)
wav_length = wav.shape[-1] / sampling_rate
print(f">> Reference audio length: {cond_mel_frame * 256 / sampling_rate:.2f} seconds")
print(f">> gpt_gen_time: {gpt_gen_time:.2f} seconds")
print(f">> bigvgan_time: {bigvgan_time:.2f} seconds")
print(f">> Total inference time: {end_time - start_time:.2f} seconds")
print(f">> Generated audio length: {wav_length:.2f} seconds")
print(f">> RTF: {(end_time - start_time) / wav_length:.4f}")
# save audio
wav = wav.cpu() # to cpu
if output_path:
# save the audio directly to the given path
if os.path.isfile(output_path):
os.remove(output_path)
print(">> remove old wav file:", output_path)
if os.path.dirname(output_path) != "":
os.makedirs(os.path.dirname(output_path), exist_ok=True)
torchaudio.save(output_path, wav.type(torch.int16), sampling_rate)
print(">> wav file saved to:", output_path)
return output_path
else:
# return in the format expected by Gradio
wav_data = wav.type(torch.int16)
wav_data = wav_data.numpy().T
return (sampling_rate, wav_data)
if __name__ == "__main__":
prompt_wav="test_data/input.wav"
#text="晕 XUAN4 是 一 种 GAN3 觉"
#text='大家好,我现在正在bilibili 体验 ai 科技,说实话,来之前我绝对想不到!AI技术已经发展到这样匪夷所思的地步了!'
text="There is a vehicle arriving in dock number 7?"
tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", is_fp16=True, use_cuda_kernel=False)
tts.infer(audio_prompt=prompt_wav, text=text, output_path="gen.wav", verbose=True)
import os
import re
import time
from subprocess import CalledProcessError
import traceback
from typing import List
import numpy as np
import sentencepiece as spm
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from omegaconf import OmegaConf
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
from indextts.BigVGAN.models import BigVGAN as Generator
from indextts.gpt.model_vllm import UnifiedVoice
from indextts.utils.checkpoint import load_checkpoint
from indextts.utils.feature_extractors import MelSpectrogramFeatures
from indextts.utils.front import TextNormalizer, TextTokenizer
import matplotlib.pyplot as plt
# def fade_in_out(wav, fade_in=int(24000*0.05), fade_out=int(24000*0.05)):
# wav = wav.astype(np.float32)
# print("wav", np.abs(wav).max(), np.abs(wav).mean(), np.abs(wav).min())
# if fade_in > 0:
# wav[:fade_in] *= np.linspace(0, 1, fade_in)[:, None]
# if fade_out > 0:
# wav[-fade_out:] *= np.linspace(1, 0, fade_out)[:, None]
# wav = np.clip(wav, -32768, 32767).astype(np.int16)
# wav = np.concatenate([np.zeros((int(0.4 * 24000), 1)), wav], axis=0).astype(np.int16)
# return wav
def trim_and_pad_silence(wav_data, threshold=1000, min_silence=int(24000*0.4)):
# # 1. trim leading silence
# abs_data = np.abs(wav_data).flatten()
# first_non_silent = np.argmax(abs_data >= threshold) # index of the first sample >= threshold
# wav_data = wav_data[max(0, first_non_silent-int(24000*0.1)):] # slice off the leading silence, keep the tail
# 2. handle trailing silence
abs_trimmed = np.abs(wav_data).flatten()
last_non_silent = len(abs_trimmed) - np.argmax(abs_trimmed[::-1] >= threshold) # index just past the last sample >= threshold
# compute the length of the trailing silence
back_silence_length = len(wav_data) - last_non_silent
if back_silence_length < min_silence:
pad_length = min_silence - back_silence_length
padded = np.vstack([wav_data, np.zeros((pad_length, 1))]) # pad with zeros
else:
padded = wav_data
return padded.astype(np.int16)
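# Usage sketch (added; assumes a 24 kHz int16 column vector of shape (num_samples, 1)):
#   wav_np = wav.type(torch.int16).numpy().T
#   wav_np = trim_and_pad_silence(wav_np, threshold=1000, min_silence=int(24000 * 0.4))
# If the trailing quiet segment (|sample| < threshold) is shorter than min_silence,
# zeros are appended so the clip ends with at least ~0.4 s of silence.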
class IndexTTS:
def __init__(
self, model_dir="checkpoints", is_fp16=True, device=None, use_cuda_kernel=None, gpu_memory_utilization=0.25
):
"""
Args:
model_dir (str): path to the model directory; config.yaml is loaded from it.
is_fp16 (bool): whether to use fp16.
device (str): device to use (e.g., 'cuda:0', 'cpu'). If None, it will be set automatically based on the availability of CUDA or MPS.
use_cuda_kernel (None | bool): whether to use BigVGan custom fused activation CUDA kernel, only for CUDA device.
gpu_memory_utilization (float): fraction of GPU memory reserved for the vLLM engine.
"""
if device is not None:
self.device = device
self.is_fp16 = False if device == "cpu" else is_fp16
self.use_cuda_kernel = use_cuda_kernel is not None and use_cuda_kernel and device.startswith("cuda")
elif torch.cuda.is_available():
self.device = "cuda:0"
self.is_fp16 = is_fp16
self.use_cuda_kernel = use_cuda_kernel is None or use_cuda_kernel
elif hasattr(torch, "mps") and torch.backends.mps.is_available():
self.device = "mps"
self.is_fp16 = False # float16 on MPS adds overhead compared to float32
self.use_cuda_kernel = False
else:
self.device = "cpu"
self.is_fp16 = False
self.use_cuda_kernel = False
print(">> Be patient, it may take a while to run in CPU mode.")
cfg_path = os.path.join(model_dir, "config.yaml")
self.cfg = OmegaConf.load(cfg_path)
self.model_dir = model_dir
self.dtype = torch.float16 if self.is_fp16 else None
self.stop_mel_token = self.cfg.gpt.stop_mel_token
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM
vllm_dir = os.path.join(model_dir, "gpt")
engine_args = AsyncEngineArgs(
model=vllm_dir,
tensor_parallel_size=1,
dtype="auto",
gpu_memory_utilization=gpu_memory_utilization,
# enforce_eager=True,
)
indextts_vllm = AsyncLLM.from_engine_args(engine_args)
self.gpt = UnifiedVoice(indextts_vllm, **self.cfg.gpt, model_dir=model_dir)
self.gpt_path = os.path.join(self.model_dir, self.cfg.gpt_checkpoint)
load_checkpoint(self.gpt, self.gpt_path)
self.gpt = self.gpt.to(self.device)
# if self.is_fp16:
# self.gpt.eval().half()
# else:
# self.gpt.eval()
self.gpt.eval()
print(">> GPT weights restored from:", self.gpt_path)
if self.use_cuda_kernel:
# preload the CUDA kernel for BigVGAN
try:
from indextts.BigVGAN.alias_free_activation.cuda import load
anti_alias_activation_cuda = load.load()
print(">> Preload custom CUDA kernel for BigVGAN", anti_alias_activation_cuda)
except Exception as ex:
traceback.print_exc()
print(">> Failed to load custom CUDA kernel for BigVGAN. Falling back to torch.")
self.use_cuda_kernel = False
self.bigvgan = Generator(self.cfg.bigvgan, use_cuda_kernel=self.use_cuda_kernel)
self.bigvgan_path = os.path.join(self.model_dir, self.cfg.bigvgan_checkpoint)
vocoder_dict = torch.load(self.bigvgan_path, map_location="cpu")
self.bigvgan.load_state_dict(vocoder_dict["generator"])
self.bigvgan = self.bigvgan.to(self.device)
# remove weight norm on eval mode
self.bigvgan.remove_weight_norm()
self.bigvgan.eval()
print(">> bigvgan weights restored from:", self.bigvgan_path)
self.bpe_path = os.path.join(self.model_dir, "bpe.model") # self.cfg.dataset["bpe_model"]
self.normalizer = TextNormalizer()
self.normalizer.load()
print(">> TextNormalizer loaded")
self.tokenizer = TextTokenizer(self.bpe_path, self.normalizer)
print(">> bpe model loaded from:", self.bpe_path)
self.speaker_dict = {}
def remove_long_silence(self, codes: list, latent: torch.Tensor, max_consecutive=15, silent_token=52):
assert latent.dim() == 3 and latent.size(0) == 1, "Latent should be (1, seq_len, dim)"
seq_len, dim = latent.size(1), latent.size(2)
# print("latent", latent.shape)
if self.stop_mel_token in codes:
try:
stop_idx = codes.index(self.stop_mel_token)
valid_len = max(stop_idx - 1, 0) # keep everything up to the position before the stop token
except ValueError:
valid_len = len(codes)
else:
valid_len = len(codes)
valid_codes = codes[:min(valid_len, len(codes))]
valid_latent = latent[0, :seq_len] # keep the dimensions compatible
keep_indices = []
silence_counter = 0
for idx, token in enumerate(valid_codes):
if token == silent_token:
silence_counter += 1
else:
silence_counter = 0
if silence_counter <= max_consecutive:
keep_indices.append(idx)
filtered_latent = valid_latent[keep_indices].unsqueeze(0) # [1, new_seq, dim]
# print("filtered_latent", filtered_latent.shape)
return filtered_latent
async def infer(self, audio_prompt: List[str], text, output_path=None, verbose=False, seed=None):
print(">> start inference...")
start_time = time.perf_counter()
auto_conditioning = []
for ap_ in audio_prompt:
audio, sr = torchaudio.load(ap_)
audio = torch.mean(audio, dim=0, keepdim=True)
if audio.shape[0] > 1:
audio = audio[0].unsqueeze(0)
audio = torchaudio.transforms.Resample(sr, 24000)(audio)
cond_mel = MelSpectrogramFeatures()(audio).to(self.device)
# cond_mel_frame = cond_mel.shape[-1]
auto_conditioning.append(cond_mel)
text_tokens_list = self.tokenizer.tokenize(text)
sentences = self.tokenizer.split_sentences(text_tokens_list)
sampling_rate = 24000
# lang = "EN"
# lang = "ZH"
wavs = []
gpt_gen_time = 0
bigvgan_time = 0
speech_conditioning_latent = []
for cond_mel in auto_conditioning:
speech_conditioning_latent_ = self.gpt.get_conditioning(
cond_mel, # .half()
torch.tensor([cond_mel.shape[-1]], device=self.device)
)
speech_conditioning_latent.append(speech_conditioning_latent_)
speech_conditioning_latent = torch.stack(speech_conditioning_latent).sum(dim=0)
speech_conditioning_latent = speech_conditioning_latent / len(auto_conditioning)
for sent in sentences:
text_tokens = self.tokenizer.convert_tokens_to_ids(sent)
text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=self.device).unsqueeze(0)
m_start_time = time.perf_counter()
with torch.no_grad():
# set the seed for the sampling params
if seed is not None:
self.gpt.sampling_params.seed = int(seed)
else:
self.gpt.sampling_params.seed = None
codes = await self.gpt.inference_speech(
speech_conditioning_latent,
text_tokens,
# cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], device=text_tokens.device)
)
gpt_gen_time += time.perf_counter() - m_start_time
# # remove ultra-long silence if it exists
# # temporarily fix the long silence bug.
# latent = self.remove_long_silence(codes, latent)
codes = torch.tensor(codes, dtype=torch.long, device=self.device).unsqueeze(0)
code_lens = torch.tensor([codes.shape[-1]], device=codes.device, dtype=codes.dtype)
latent = self.gpt(speech_conditioning_latent, text_tokens,
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes,
code_lens*self.gpt.mel_length_compression,
cond_mel_lengths=torch.tensor([speech_conditioning_latent.shape[-1]], device=text_tokens.device),
return_latent=True, clip_inputs=False)
m_start_time = time.perf_counter()
wav, _ = self.bigvgan(latent, [ap_.transpose(1, 2) for ap_ in auto_conditioning])
bigvgan_time += time.perf_counter() - m_start_time
wav = wav.squeeze(1)
wav = torch.clamp(32767 * wav, -32767.0, 32767.0)
print(f"wav shape: {wav.shape}", "min:", wav.min(), "max:", wav.max())
# wavs.append(wav[:, :-512])
wavs.append(wav.cpu()) # to cpu before saving
torch.cuda.empty_cache()
end_time = time.perf_counter()
wav = torch.cat(wavs, dim=1)
wav_length = wav.shape[-1] / sampling_rate
print(f">> gpt_gen_time: {gpt_gen_time:.2f} seconds")
print(f">> bigvgan_time: {bigvgan_time:.2f} seconds")
print(f">> Total inference time: {end_time - start_time:.2f} seconds")
print(f">> Generated audio length: {wav_length:.2f} seconds")
print(f">> RTF: {(end_time - start_time) / wav_length:.4f}")
# save audio
wav = wav.cpu() # to cpu
if output_path:
# save the audio directly to the given path
if os.path.isfile(output_path):
os.remove(output_path)
print(">> remove old wav file:", output_path)
if os.path.dirname(output_path) != "":
os.makedirs(os.path.dirname(output_path), exist_ok=True)
torchaudio.save(output_path, wav.type(torch.int16), sampling_rate)
print(">> wav file saved to:", output_path)
return output_path
else:
# return in the format expected by Gradio
wav_data = wav.type(torch.int16)
wav_data = wav_data.numpy().T
wav_data = trim_and_pad_silence(wav_data)
return (sampling_rate, wav_data)
async def infer_with_ref_audio_embed(self, speaker: str, text):
start_time = time.perf_counter()
text = text.replace("嗯", "EN4")
text = text.replace("嘿", "HEI1")
text = text.replace("嗨", "HAI4")
text = text.replace("哈哈", "HA1HA1")
sampling_rate = 24000
auto_conditioning = self.speaker_dict[speaker]["auto_conditioning"]
text_tokens_list = self.tokenizer.tokenize(text)
sentences = self.tokenizer.split_sentences(text_tokens_list)
wavs = []
gpt_gen_time = 0
bigvgan_time = 0
speech_conditioning_latent = self.speaker_dict[speaker]["speech_conditioning_latent"]
for sent in sentences:
text_tokens = self.tokenizer.convert_tokens_to_ids(sent)
text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=self.device).unsqueeze(0)
m_start_time = time.perf_counter()
with torch.no_grad():
codes = await self.gpt.inference_speech(
speech_conditioning_latent,
text_tokens,
# cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], device=text_tokens.device)
)
gpt_gen_time += time.perf_counter() - m_start_time
# # remove ultra-long silence if it exists
# # temporarily fix the long silence bug.
# latent = self.remove_long_silence(codes, latent)
codes = torch.tensor(codes, dtype=torch.long, device=self.device).unsqueeze(0)
code_lens = torch.tensor([codes.shape[-1]], device=codes.device, dtype=codes.dtype)
latent = self.gpt(speech_conditioning_latent, text_tokens,
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes,
code_lens*self.gpt.mel_length_compression,
cond_mel_lengths=torch.tensor([speech_conditioning_latent.shape[-1]], device=text_tokens.device),
return_latent=True, clip_inputs=False)
m_start_time = time.perf_counter()
wav, _ = self.bigvgan(latent, [ap_.transpose(1, 2) for ap_ in auto_conditioning])
bigvgan_time += time.perf_counter() - m_start_time
wav = wav.squeeze(1)
wav = torch.clamp(32767 * wav, -32767.0, 32767.0)
# wavs.append(wav[:, :-512])
wavs.append(wav) # kept on device; moved to cpu after concatenation below
torch.cuda.empty_cache()
end_time = time.perf_counter()
wav = torch.cat(wavs, dim=1)
# wav_length = wav.shape[-1] / sampling_rate
# # print(f">> Total inference time: {end_time - start_time:.2f} seconds")
# print(f">> gpt_gen_time: {gpt_gen_time:.2f} seconds")
# print(f">> bigvgan_time: {bigvgan_time:.2f} seconds")
# print(f">> Total inference time: {end_time - start_time:.2f} seconds")
# print(f">> Generated audio length: {wav_length:.2f} seconds")
# print(f">> RTF: {(end_time - start_time) / wav_length:.4f}")
# save audio
wav = wav.cpu() # to cpu
wav_data = wav.type(torch.int16)
wav_data = wav_data.numpy().T
wav_data = trim_and_pad_silence(wav_data)
return (sampling_rate, wav_data)
@torch.no_grad()
def registry_speaker(self, speaker: str, audio_paths: List[str]):
auto_conditioning = []
for ap_ in audio_paths:
audio, sr = torchaudio.load(ap_)
audio = torch.mean(audio, dim=0, keepdim=True)
if audio.shape[0] > 1:
audio = audio[0].unsqueeze(0)
audio = torchaudio.transforms.Resample(sr, 24000)(audio)
cond_mel = MelSpectrogramFeatures()(audio).to(self.device)
# cond_mel_frame = cond_mel.shape[-1]
auto_conditioning.append(cond_mel)
speech_conditioning_latent = []
for cond_mel in auto_conditioning:
speech_conditioning_latent_ = self.gpt.get_conditioning(
cond_mel, # .half()
torch.tensor([cond_mel.shape[-1]], device=self.device)
)
speech_conditioning_latent.append(speech_conditioning_latent_)
speech_conditioning_latent = torch.stack(speech_conditioning_latent).sum(dim=0)
speech_conditioning_latent = speech_conditioning_latent / len(auto_conditioning)
self.speaker_dict[speaker] = {
"auto_conditioning": auto_conditioning,
"speech_conditioning_latent": speech_conditioning_latent
}
print(f"Speaker: {speaker} registered")
import os
import random
import re
import time
import traceback
from typing import List
import uuid
import librosa
import torch
import torchaudio
# from torch.nn.utils.rnn import pad_sequence
from omegaconf import OmegaConf
from tqdm import tqdm
from transformers import SeamlessM4TFeatureExtractor
from transformers import AutoTokenizer
from modelscope import AutoModelForCausalLM
import safetensors
from loguru import logger
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
from indextts.BigVGAN.models import BigVGAN as Generator
from indextts.gpt.model_vllm_v2 import UnifiedVoice
from indextts.utils.checkpoint import load_checkpoint
from indextts.utils.feature_extractors import MelSpectrogramFeatures
from indextts.utils.maskgct_utils import build_semantic_model, build_semantic_codec
from indextts.utils.front import TextNormalizer, TextTokenizer
from indextts.s2mel.modules.commons import load_checkpoint2, MyModel
from indextts.s2mel.modules.bigvgan import bigvgan
from indextts.s2mel.modules.campplus.DTDNN import CAMPPlus
from indextts.s2mel.modules.audio import mel_spectrogram
import torch.nn.functional as F
from vllm import SamplingParams, TokensPrompt
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM
class IndexTTS2:
def __init__(
self, model_dir="checkpoints", is_fp16=False, device=None, use_cuda_kernel=None, gpu_memory_utilization=0.25, qwenemo_gpu_memory_utilization=0.10
):
"""
Args:
model_dir (str): path to the model directory; config.yaml is loaded from it.
is_fp16 (bool): whether to use fp16.
device (str): device to use (e.g., 'cuda:0', 'cpu'). If None, it will be set automatically based on the availability of CUDA or MPS.
use_cuda_kernel (None | bool): whether to use BigVGan custom fused activation CUDA kernel, only for CUDA device.
gpu_memory_utilization (float): fraction of GPU memory reserved for the GPT vLLM engine.
qwenemo_gpu_memory_utilization (float): fraction of GPU memory reserved for the Qwen emotion model's vLLM engine.
"""
if device is not None:
self.device = device
self.is_fp16 = False if device == "cpu" else is_fp16
self.use_cuda_kernel = use_cuda_kernel is not None and use_cuda_kernel and device.startswith("cuda")
elif torch.cuda.is_available():
self.device = "cuda:0"
self.is_fp16 = is_fp16
self.use_cuda_kernel = use_cuda_kernel is None or use_cuda_kernel
elif hasattr(torch, "mps") and torch.backends.mps.is_available():
self.device = "mps"
self.is_fp16 = False # float16 on MPS adds overhead compared to float32
self.use_cuda_kernel = False
else:
self.device = "cpu"
self.is_fp16 = False
self.use_cuda_kernel = False
logger.info(">> Be patient, it may take a while to run in CPU mode.")
cfg_path = os.path.join(model_dir, "config.yaml")
self.cfg = OmegaConf.load(cfg_path)
self.model_dir = model_dir
self.dtype = torch.float16 if self.is_fp16 else None
self.stop_mel_token = self.cfg.gpt.stop_mel_token
vllm_dir = os.path.join(model_dir, "gpt")
engine_args = AsyncEngineArgs(
model=vllm_dir,
tensor_parallel_size=1,
dtype="auto",
gpu_memory_utilization=gpu_memory_utilization,
# enforce_eager=True,
)
indextts_vllm = AsyncLLM.from_engine_args(engine_args)
self.qwen_emo = QwenEmotion(
os.path.join(self.model_dir, self.cfg.qwen_emo_path),
gpu_memory_utilization=qwenemo_gpu_memory_utilization,
)
self.gpt = UnifiedVoice(indextts_vllm, **self.cfg.gpt)
self.gpt_path = os.path.join(self.model_dir, self.cfg.gpt_checkpoint)
load_checkpoint(self.gpt, self.gpt_path)
self.gpt = self.gpt.to(self.device)
# if self.is_fp16:
# self.gpt.eval().half()
# else:
# self.gpt.eval()
self.gpt.eval()
logger.info(f">> GPT weights restored from: {self.gpt_path}")
if self.use_cuda_kernel:
# preload the CUDA kernel for BigVGAN
try:
from indextts.BigVGAN.alias_free_activation.cuda import load
anti_alias_activation_cuda = load.load()
logger.info(f">> Preload custom CUDA kernel for BigVGAN {anti_alias_activation_cuda}")
except Exception as ex:
traceback.print_exc()
logger.info(">> Failed to load custom CUDA kernel for BigVGAN. Falling back to torch.")
self.use_cuda_kernel = False
self.extract_features = SeamlessM4TFeatureExtractor.from_pretrained(
# "facebook/w2v-bert-2.0"
os.path.join(self.model_dir, "w2v-bert-2.0")
)
self.semantic_model, self.semantic_mean, self.semantic_std = build_semantic_model(
os.path.join(self.model_dir, self.cfg.w2v_stat),
os.path.join(self.model_dir, "w2v-bert-2.0")
)
self.semantic_model = self.semantic_model.to(self.device)
self.semantic_model.eval()
self.semantic_mean = self.semantic_mean.to(self.device)
self.semantic_std = self.semantic_std.to(self.device)
semantic_codec = build_semantic_codec(self.cfg.semantic_codec)
semantic_code_ckpt = os.path.join(self.model_dir, "semantic_codec/model.safetensors")
safetensors.torch.load_model(semantic_codec, semantic_code_ckpt)
self.semantic_codec = semantic_codec.to(self.device)
self.semantic_codec.eval()
logger.info('>> semantic_codec weights restored from: {}'.format(semantic_code_ckpt))
s2mel_path = os.path.join(self.model_dir, self.cfg.s2mel_checkpoint)
s2mel = MyModel(self.cfg.s2mel, use_gpt_latent=True)
s2mel, _, _, _ = load_checkpoint2(
s2mel,
None,
s2mel_path,
load_only_params=True,
ignore_modules=[],
is_distributed=False,
)
self.s2mel = s2mel.to(self.device)
self.s2mel.models['cfm'].estimator.setup_caches(max_batch_size=1, max_seq_length=8192)
self.s2mel.eval()
logger.info(f">> s2mel weights restored from: {s2mel_path}")
# load campplus_model
# campplus_ckpt_path = hf_hub_download(
# "funasr/campplus", filename="campplus_cn_common.bin", cache_dir=os.path.join(self.model_dir, "campplus")
# )
campplus_ckpt_path = os.path.join(self.model_dir, "campplus/campplus_cn_common.bin")
campplus_model = CAMPPlus(feat_dim=80, embedding_size=192)
campplus_model.load_state_dict(torch.load(campplus_ckpt_path, map_location="cpu"))
self.campplus_model = campplus_model.to(self.device)
self.campplus_model.eval()
logger.info(f">> campplus_model weights restored from: {campplus_ckpt_path}")
bigvgan_name = self.cfg.vocoder.name
# self.bigvgan = bigvgan.BigVGAN.from_pretrained(bigvgan_name, use_cuda_kernel=False, cache_dir=os.path.join(self.model_dir, "bigvgan"))
self.bigvgan = bigvgan.BigVGAN.from_pretrained(os.path.join(self.model_dir, "bigvgan"))
self.bigvgan = self.bigvgan.to(self.device)
self.bigvgan.remove_weight_norm()
self.bigvgan.eval()
logger.info(f">> bigvgan weights restored from: {bigvgan_name}")
self.bpe_path = os.path.join(self.model_dir, "bpe.model") # self.cfg.dataset["bpe_model"]
self.normalizer = TextNormalizer()
self.normalizer.load()
logger.info(">> TextNormalizer loaded")
self.tokenizer = TextTokenizer(self.bpe_path, self.normalizer)
logger.info(f">> bpe model loaded from: {self.bpe_path}")
emo_matrix = torch.load(os.path.join(self.model_dir, self.cfg.emo_matrix))
self.emo_matrix = emo_matrix.to(self.device)
self.emo_num = list(self.cfg.emo_num)
spk_matrix = torch.load(os.path.join(self.model_dir, self.cfg.spk_matrix))
self.spk_matrix = spk_matrix.to(self.device)
self.emo_matrix = torch.split(self.emo_matrix, self.emo_num)
self.spk_matrix = torch.split(self.spk_matrix, self.emo_num)
mel_fn_args = {
"n_fft": self.cfg.s2mel['preprocess_params']['spect_params']['n_fft'],
"win_size": self.cfg.s2mel['preprocess_params']['spect_params']['win_length'],
"hop_size": self.cfg.s2mel['preprocess_params']['spect_params']['hop_length'],
"num_mels": self.cfg.s2mel['preprocess_params']['spect_params']['n_mels'],
"sampling_rate": self.cfg.s2mel["preprocess_params"]["sr"],
"fmin": self.cfg.s2mel['preprocess_params']['spect_params'].get('fmin', 0),
"fmax": None if self.cfg.s2mel['preprocess_params']['spect_params'].get('fmax', "None") == "None" else 8000,
"center": False
}
self.mel_fn = lambda x: mel_spectrogram(x, **mel_fn_args)
self.speaker_dict = {}
@torch.no_grad()
def get_emb(self, input_features, attention_mask):
vq_emb = self.semantic_model(
input_features=input_features,
attention_mask=attention_mask,
output_hidden_states=True,
)
feat = vq_emb.hidden_states[17] # (B, T, C)
feat = (feat - self.semantic_mean) / self.semantic_std
return feat
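# Note (added): layer 17 of the w2v-bert hidden states is taken as the semantic feature and
# normalized with the precomputed dataset statistics, feat = (h17 - mean) / std, shape (B, T, C),
# before being quantized by the semantic codec.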
def insert_interval_silence(self, wavs, sampling_rate=22050, interval_silence=200):
"""
Insert silences between sentences.
wavs: List[torch.tensor]
"""
if not wavs or interval_silence <= 0:
return wavs
# get channel_size
channel_size = wavs[0].size(0)
# get silence tensor
sil_dur = int(sampling_rate * interval_silence / 1000.0)
sil_tensor = torch.zeros(channel_size, sil_dur)
wavs_list = []
for i, wav in enumerate(wavs):
wavs_list.append(wav)
if i < len(wavs) - 1:
wavs_list.append(sil_tensor)
return wavs_list
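# Worked numbers (added): with sampling_rate=22050 and interval_silence=200 ms,
# sil_dur = int(22050 * 200 / 1000) = 4410 zero samples are inserted between consecutive
# sentence waveforms; nothing is appended after the last sentence.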
async def infer(self, spk_audio_prompt, text, output_path,
emo_audio_prompt=None, emo_alpha=1.0,
emo_vector=None,
use_emo_text=False, emo_text=None, use_random=False, interval_silence=200,
verbose=False, max_text_tokens_per_sentence=120, **generation_kwargs):
logger.info(">> start inference...")
start_time = time.perf_counter()
if use_emo_text:
emo_audio_prompt = None
emo_alpha = 1.0
# assert emo_audio_prompt is None
# assert emo_alpha == 1.0
if emo_text is None:
emo_text = text
emo_dict, content = await self.qwen_emo.inference(emo_text)
# logger.info(emo_dict)
emo_vector = list(emo_dict.values())
if emo_vector is not None:
emo_audio_prompt = None
emo_alpha = 1.0
# assert emo_audio_prompt is None
# assert emo_alpha == 1.0
if emo_audio_prompt is None:
emo_audio_prompt = spk_audio_prompt
emo_alpha = 1.0
# assert emo_alpha == 1.0
audio, sr = librosa.load(spk_audio_prompt)
audio = torch.tensor(audio).unsqueeze(0)
audio_22k = torchaudio.transforms.Resample(sr, 22050)(audio)
audio_16k = torchaudio.transforms.Resample(sr, 16000)(audio)
inputs = self.extract_features(audio_16k, sampling_rate=16000, return_tensors="pt")
input_features = inputs["input_features"]
attention_mask = inputs["attention_mask"]
input_features = input_features.to(self.device)
attention_mask = attention_mask.to(self.device)
spk_cond_emb = self.get_emb(input_features, attention_mask)
_, S_ref = self.semantic_codec.quantize(spk_cond_emb)
ref_mel = self.mel_fn(audio_22k.to(spk_cond_emb.device).float())
ref_target_lengths = torch.LongTensor([ref_mel.size(2)]).to(ref_mel.device)
feat = torchaudio.compliance.kaldi.fbank(audio_16k.to(ref_mel.device),
num_mel_bins=80,
dither=0,
sample_frequency=16000)
feat = feat - feat.mean(dim=0, keepdim=True) # feat2: a second filterbank-energy feature, e.g. [922, 80]
style = self.campplus_model(feat.unsqueeze(0)) # global style embedding of the reference audio, e.g. [1, 192]
prompt_condition = self.s2mel.models['length_regulator'](S_ref,
ylens=ref_target_lengths,
n_quantizers=3,
f0=None)[0]
if emo_vector is not None:
weight_vector = torch.tensor(emo_vector).to(self.device)
if use_random:
random_index = [random.randint(0, x - 1) for x in self.emo_num]
else:
random_index = [find_most_similar_cosine(style, tmp) for tmp in self.spk_matrix]
emo_matrix = [tmp[index].unsqueeze(0) for index, tmp in zip(random_index, self.emo_matrix)]
emo_matrix = torch.cat(emo_matrix, 0)
emovec_mat = weight_vector.unsqueeze(1) * emo_matrix
emovec_mat = torch.sum(emovec_mat, 0)
emovec_mat = emovec_mat.unsqueeze(0)
emo_audio, _ = librosa.load(emo_audio_prompt, sr=16000)
emo_inputs = self.extract_features(emo_audio, sampling_rate=16000, return_tensors="pt")
emo_input_features = emo_inputs["input_features"]
emo_attention_mask = emo_inputs["attention_mask"]
emo_input_features = emo_input_features.to(self.device)
emo_attention_mask = emo_attention_mask.to(self.device)
emo_cond_emb = self.get_emb(emo_input_features, emo_attention_mask)
text_tokens_list = self.tokenizer.tokenize(text)
sentences = self.tokenizer.split_sentences(text_tokens_list, max_text_tokens_per_sentence)
if verbose:
print("text_tokens_list:", text_tokens_list)
print("sentences count:", len(sentences))
print("max_text_tokens_per_sentence:", max_text_tokens_per_sentence)
print(*sentences, sep="\n")
sampling_rate = 22050
wavs = []
gpt_gen_time = 0
gpt_forward_time = 0
s2mel_time = 0
bigvgan_time = 0
has_warned = False
for sent in sentences:
text_tokens = self.tokenizer.convert_tokens_to_ids(sent)
text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=self.device).unsqueeze(0)
if verbose:
print(text_tokens)
print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}")
# debug tokenizer
text_token_syms = self.tokenizer.convert_ids_to_tokens(text_tokens[0].tolist())
print("text_token_syms is same as sentence tokens", text_token_syms == sent)
m_start_time = time.perf_counter()
with torch.no_grad():
emovec = self.gpt.merge_emovec(
spk_cond_emb,
emo_cond_emb,
torch.tensor([spk_cond_emb.shape[-1]], device=text_tokens.device),
torch.tensor([emo_cond_emb.shape[-1]], device=text_tokens.device),
alpha=emo_alpha
)
if emo_vector is not None:
emovec = emovec_mat + (1 - torch.sum(weight_vector)) * emovec
# emovec = emovec_mat
codes, speech_conditioning_latent = await self.gpt.inference_speech(
spk_cond_emb,
text_tokens,
emo_cond_emb,
cond_lengths=torch.tensor([spk_cond_emb.shape[-1]], device=text_tokens.device),
emo_cond_lengths=torch.tensor([emo_cond_emb.shape[-1]], device=text_tokens.device),
emo_vec=emovec,
)
gpt_gen_time += time.perf_counter() - m_start_time
# if not has_warned and (codes[:, -1] != self.stop_mel_token).any():
# warnings.warn(
# f"WARN: generation stopped due to exceeding `max_mel_tokens` ({self.cfg.gpt.max_mel_tokens}). "
# f"Current output shape: {codes.shape}. "
# f"Input text tokens: {text_tokens.shape[1]}. "
# f"Consider reducing `max_text_tokens_per_sentence`({max_text_tokens_per_sentence}) or increasing `max_mel_tokens`.",
# category=RuntimeWarning
# )
# has_warned = True
# codes = torch.tensor(codes, dtype=torch.long, device=self.device).unsqueeze(0)
code_lens = torch.tensor([codes.shape[-1]], device=codes.device, dtype=codes.dtype)
code_lens = []
for code in codes:
if self.stop_mel_token not in code:
# code_lens.append(len(code))
code_len = len(code)
else:
len_ = (code == self.stop_mel_token).nonzero(as_tuple=False)[0] + 1
code_len = len_ - 1
code_lens.append(code_len)
codes = codes[:, :code_len]
code_lens = torch.LongTensor(code_lens)
code_lens = code_lens.to(self.device)
if verbose:
print(codes, type(codes))
print(f"fix codes shape: {codes.shape}, codes type: {codes.dtype}")
print(f"code len: {code_lens}")
m_start_time = time.perf_counter()
use_speed = torch.zeros(spk_cond_emb.size(0)).to(spk_cond_emb.device).long()
# latent = self.gpt(speech_conditioning_latent, text_tokens,
# torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes,
# code_lens*self.gpt.mel_length_compression,
# cond_mel_lengths=torch.tensor([speech_conditioning_latent.shape[-1]], device=text_tokens.device),
# return_latent=True, clip_inputs=False)
latent = self.gpt(
speech_conditioning_latent,
text_tokens,
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device),
codes,
torch.tensor([codes.shape[-1]], device=text_tokens.device),
emo_cond_emb,
cond_mel_lengths=torch.tensor([spk_cond_emb.shape[-1]], device=text_tokens.device),
emo_cond_mel_lengths=torch.tensor([emo_cond_emb.shape[-1]], device=text_tokens.device),
emo_vec=emovec,
use_speed=use_speed,
)
gpt_forward_time += time.perf_counter() - m_start_time
dtype = None
with torch.amp.autocast(text_tokens.device.type, enabled=dtype is not None, dtype=dtype):
m_start_time = time.perf_counter()
diffusion_steps = 25
inference_cfg_rate = 0.7
latent = self.s2mel.models['gpt_layer'](latent)
S_infer = self.semantic_codec.quantizer.vq2emb(codes.unsqueeze(1))
S_infer = S_infer.transpose(1, 2)
S_infer = S_infer + latent
target_lengths = (code_lens * 1.72).long()
cond = self.s2mel.models['length_regulator'](S_infer,
ylens=target_lengths,
n_quantizers=3,
f0=None)[0]
cat_condition = torch.cat([prompt_condition, cond], dim=1)
vc_target = self.s2mel.models['cfm'].inference(cat_condition,
torch.LongTensor([cat_condition.size(1)]).to(
cond.device),
ref_mel, style, None, diffusion_steps,
inference_cfg_rate=inference_cfg_rate)
vc_target = vc_target[:, :, ref_mel.size(-1):]
s2mel_time += time.perf_counter() - m_start_time
m_start_time = time.perf_counter()
wav = self.bigvgan(vc_target.float()).squeeze().unsqueeze(0)
bigvgan_time += time.perf_counter() - m_start_time
wav = wav.squeeze(1)
wav = torch.clamp(32767 * wav, -32767.0, 32767.0)
if verbose:
print(f"wav shape: {wav.shape}", "min:", wav.min(), "max:", wav.max())
# wavs.append(wav[:, :-512])
# logger.error(f"time per token: {wav.shape[-1] / sampling_rate / codes.shape[-1]}, {wav.shape[-1] / sampling_rate / vc_target.shape[-1]}")
wavs.append(wav.cpu()) # to cpu before saving
end_time = time.perf_counter()
wavs = self.insert_interval_silence(wavs, sampling_rate=sampling_rate, interval_silence=interval_silence)
wav = torch.cat(wavs, dim=1)
wav_length = wav.shape[-1] / sampling_rate
logger.info(f">> gpt_gen_time: {gpt_gen_time:.2f} seconds")
logger.info(f">> gpt_forward_time: {gpt_forward_time:.2f} seconds")
logger.info(f">> s2mel_time: {s2mel_time:.2f} seconds")
logger.info(f">> bigvgan_time: {bigvgan_time:.2f} seconds")
logger.info(f">> Total inference time: {end_time - start_time:.2f} seconds")
logger.info(f">> Generated audio length: {wav_length:.2f} seconds")
logger.info(f">> RTF: {(end_time - start_time) / wav_length:.4f}")
# save audio
wav = wav.cpu() # to cpu
if output_path:
# save the audio directly to the given path
if os.path.isfile(output_path):
os.remove(output_path)
logger.info(f">> remove old wav file: {output_path}")
if os.path.dirname(output_path) != "":
os.makedirs(os.path.dirname(output_path), exist_ok=True)
torchaudio.save(output_path, wav.type(torch.int16), sampling_rate)
logger.info(f">> wav file saved to: {output_path}")
return output_path
else:
# return in the format expected by Gradio
wav_data = wav.type(torch.int16)
wav_data = wav_data.numpy().T
return (sampling_rate, wav_data)
def find_most_similar_cosine(query_vector, matrix):
query_vector = query_vector.float()
matrix = matrix.float()
similarities = F.cosine_similarity(query_vector, matrix, dim=1)
most_similar_index = torch.argmax(similarities)
return most_similar_index
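# Usage sketch (added; assumed shapes): pick the emotion/speaker prototype closest to the
# reference style embedding by cosine similarity, e.g.
#   style = torch.randn(1, 192)                     # CAMPPlus style embedding
#   protos = torch.randn(50, 192)                   # one candidate matrix from self.spk_matrix
#   idx = find_most_similar_cosine(style, protos)   # 0-dim LongTensor index into protos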
class QwenEmotion:
def __init__(self, model_dir, gpu_memory_utilization=0.1):
self.model_dir = model_dir
self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
# self.model = AutoModelForCausalLM.from_pretrained(
# self.model_dir,
# torch_dtype="float16", # "auto"
# # device_map="auto"
# )
# self.model = self.model.to("cuda")
engine_args = AsyncEngineArgs(
model=model_dir,
tensor_parallel_size=1,
dtype="auto",
gpu_memory_utilization=gpu_memory_utilization,
max_model_len=2048,
)
self.model = AsyncLLM.from_engine_args(engine_args)
self.prompt = "文本情感分类"
self.convert_dict = {
"愤怒": "angry",
"高兴": "happy",
"恐惧": "fear",
"反感": "hate",
"悲伤": "sad",
"低落": "low",
"惊讶": "surprise",
"自然": "neutral",
}
self.backup_dict = {"happy": 0, "angry": 0, "sad": 0, "fear": 0, "hate": 0, "low": 0, "surprise": 0,
"neutral": 1.0}
self.max_score = 1.2
self.min_score = 0.0
def convert(self, content):
content = content.replace("\n", " ")
content = content.replace(" ", "")
content = content.replace("{", "")
content = content.replace("}", "")
content = content.replace('"', "")
parts = content.strip().split(',')
# print(parts)
parts_dict = {}
desired_order = ["高兴", "愤怒", "悲伤", "恐惧", "反感", "低落", "惊讶", "自然"]
for part in parts:
key_value = part.strip().split(':')
if len(key_value) == 2:
parts_dict[key_value[0].strip()] = part
# reorder according to the desired emotion order
ordered_parts = [parts_dict[key] for key in desired_order if key in parts_dict]
parts = ordered_parts
if len(parts) != len(self.convert_dict):
return self.backup_dict
emotion_dict = {}
for part in parts:
key_value = part.strip().split(':')
if len(key_value) == 2:
try:
key = self.convert_dict[key_value[0].strip()]
value = float(key_value[1].strip())
value = max(self.min_score, min(self.max_score, value))
emotion_dict[key] = value
except Exception:
continue
for key in self.backup_dict:
if key not in emotion_dict:
emotion_dict[key] = 0.0
if sum(emotion_dict.values()) <= 0:
return self.backup_dict
return emotion_dict
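# Illustrative conversion (added; hypothetical model output):
#   content = '{"高兴": 0.8, "愤怒": 0.0, "悲伤": 0.1, "恐惧": 0.0, "反感": 0.0, "低落": 0.0, "惊讶": 0.2, "自然": 0.0}'
#   convert(content) -> {"happy": 0.8, "angry": 0.0, "sad": 0.1, "fear": 0.0,
#                        "hate": 0.0, "low": 0.0, "surprise": 0.2, "neutral": 0.0}
# Malformed or incomplete outputs fall back to backup_dict (all zeros, neutral=1.0).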
async def inference(self, text_input):
messages = [
{"role": "system", "content": f"{self.prompt}"},
{"role": "user", "content": f"{text_input}"}
]
text = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)
model_inputs = self.tokenizer(text)["input_ids"]
# model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
# conduct text completion
# generated_ids = self.model.generate(
# **model_inputs,
# max_new_tokens=32768,
# pad_token_id=self.tokenizer.eos_token_id
# )
# output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
sampling_params = SamplingParams(
max_tokens=2048, # 32768
)
tokens_prompt = TokensPrompt(prompt_token_ids=model_inputs)
output_generator = self.model.generate(tokens_prompt, sampling_params=sampling_params, request_id=uuid.uuid4().hex)
async for output in output_generator:
pass
output_ids = output.outputs[0].token_ids[:-2]
# parse the thinking content
try:
# rindex finding 151668 (</think>)
index = len(output_ids) - output_ids[::-1].index(151668)
except ValueError:
index = 0
content = self.tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")
emotion_dict = self.convert(content)
return emotion_dict, content
__version__ = "1.0.0"
# preserved here for legacy reasons
__model_version__ = "latest"
import audiotools
audiotools.ml.BaseModel.INTERN += ["dac.**"]
audiotools.ml.BaseModel.EXTERN += ["einops"]
from . import nn
from . import model
from . import utils
from .model import DAC
from .model import DACFile
import sys
import argbind
from dac.utils import download
from dac.utils.decode import decode
from dac.utils.encode import encode
STAGES = ["encode", "decode", "download"]
def run(stage: str):
"""Run stages.
Parameters
----------
stage : str
Stage to run
"""
if stage not in STAGES:
raise ValueError(f"Unknown command: {stage}. Allowed commands are {STAGES}")
stage_fn = globals()[stage]
if stage == "download":
stage_fn()
return
stage_fn()
if __name__ == "__main__":
group = sys.argv.pop(1)
args = argbind.parse_args(group=group)
with argbind.scope(args):
run(group)
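# Usage sketch (added; hedged, the exact flags are exposed by the argbind-decorated
# encode/decode functions and may differ across dac versions):
#   python -m dac download                # fetch pretrained weights
#   python -m dac encode [--flags ...]    # encode audio with the codec
#   python -m dac decode [--flags ...]    # decode previously encoded files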