Commit a7785cc6 authored by Sugon_ldc's avatar Sugon_ldc
Browse files

delete soft link

parent 9a2a05ca
# Copyright (c) 2022 Ximalaya Inc. (authors: Yuguang Yang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from Squeezeformer(https://github.com/kssteven418/Squeezeformer)
# Squeezeformer(https://github.com/upskyy/Squeezeformer)
# NeMo(https://github.com/NVIDIA/NeMo)
"""DepthwiseConv2dSubsampling4 and TimeReductionLayer definition."""
import torch
import torch.nn as nn
import torch.nn.functional as F
from wenet.transformer.subsampling import BaseSubsampling
from typing import Tuple
from wenet.squeezeformer.conv2d import Conv2dValid
class DepthwiseConv2dSubsampling4(BaseSubsampling):
    """Convolutional 2D subsampling module (reduces time length to 1/4).

    Two stride-2 convolutions are applied; the second one is optionally
    depthwise (``dw_stride=True``), following the Squeezeformer design.

    Args:
        idim (int): Input channel dimension (1 for raw fbank features).
        odim (int): Output channel dimension.
        pos_enc_class (nn.Module): positional encoding module.
        dw_stride (bool): if True, the second conv is depthwise.
        input_size (int): filter-bank feature dimension.
        input_dropout_rate (float): dropout after the output projection.
        init_weights (bool): if True, uniformly initialise the projection.
    """

    def __init__(
            self, idim: int, odim: int,
            pos_enc_class: torch.nn.Module,
            dw_stride: bool = False,
            input_size: int = 80,
            input_dropout_rate: float = 0.1,
            init_weights: bool = True
    ):
        super(DepthwiseConv2dSubsampling4, self).__init__()
        self.idim = idim
        self.odim = odim
        # First 3x3 stride-2 conv (named "pw" in the reference code, but it
        # is a full convolution over all input channels).
        self.pw_conv = nn.Conv2d(in_channels=idim, out_channels=odim,
                                 kernel_size=3, stride=2)
        self.act1 = nn.ReLU()
        # Second 3x3 stride-2 conv, depthwise when dw_stride is set.
        self.dw_conv = nn.Conv2d(in_channels=odim, out_channels=odim,
                                 kernel_size=3, stride=2,
                                 groups=odim if dw_stride else 1)
        self.act2 = nn.ReLU()
        self.pos_enc = pos_enc_class
        # Frequency dimension after two stride-2 convs without padding.
        freq_after = ((input_size - 1) // 2 - 1) // 2
        self.input_proj = nn.Sequential(
            nn.Linear(odim * freq_after, odim),
            nn.Dropout(p=input_dropout_rate),
        )
        if init_weights:
            # Uniform init bound as in the Squeezeformer reference code.
            linear_max = (odim * input_size / 4) ** -0.5
            for key in ('0.weight', '0.bias'):
                torch.nn.init.uniform_(self.input_proj.state_dict()[key],
                                       -linear_max, linear_max)
        self.subsampling_rate = 4
        # 6 = (3 - 1) * 1 + (3 - 1) * 2
        self.right_context = 6

    def forward(
            self,
            x: torch.Tensor,
            x_mask: torch.Tensor,
            offset: int = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Subsample the input by a factor of 4 along the time axis.

        Args:
            x: input features, shape (batch, time, freq).
            x_mask: mask, shape (batch, 1, time).
            offset: positional-encoding offset.

        Returns:
            (subsampled features, positional embedding, subsampled mask).
        """
        x = x.unsqueeze(1)  # (b, c=1, t, f)
        x = self.act1(self.pw_conv(x))
        x = self.act2(self.dw_conv(x))
        batch, chans, time, freq = x.size()
        x = x.permute(0, 2, 1, 3).contiguous().view(batch, time, chans * freq)
        x, pos_emb = self.pos_enc(x, offset)
        x = self.input_proj(x)
        return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2]
class TimeReductionLayer1D(nn.Module):
    """Squeezeformer 1-D time-reduction layer (modified from NeMo).

    Downsamples the audio along the time axis by ``stride`` using a
    depthwise Conv1d followed by a pointwise (1x1) Conv1d.

    Args:
        channel (int): input feature dimension.
        out_dim (int): output feature dimension.
        kernel_size (int): depthwise conv kernel size.
        stride (int): time-axis downsampling factor.
    """

    def __init__(self, channel: int, out_dim: int,
                 kernel_size: int = 5, stride: int = 2):
        super(TimeReductionLayer1D, self).__init__()
        self.channel = channel
        self.out_dim = out_dim
        self.kernel_size = kernel_size
        self.stride = stride
        # Pad so the conv output covers roughly ceil(T / stride) frames.
        self.padding = max(0, self.kernel_size - self.stride)
        self.dw_conv = nn.Conv1d(
            in_channels=channel,
            out_channels=channel,
            kernel_size=kernel_size,
            stride=stride,
            padding=self.padding,
            groups=channel,
        )
        self.pw_conv = nn.Conv1d(
            in_channels=channel, out_channels=out_dim,
            kernel_size=1, stride=1, padding=0, groups=1,
        )
        self.init_weights()

    def init_weights(self):
        """Uniform init; bounds follow the Squeezeformer reference code."""
        dw_max = self.kernel_size ** -0.5
        pw_max = self.channel ** -0.5
        # Same order as the original implementation (keeps RNG stream stable).
        for tensor, bound in ((self.dw_conv.weight, dw_max),
                              (self.dw_conv.bias, dw_max),
                              (self.pw_conv.weight, pw_max),
                              (self.pw_conv.bias, pw_max)):
            torch.nn.init.uniform_(tensor, -bound, bound)

    def forward(self, xs, xs_lens: torch.Tensor,
                mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
                mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
                ):
        """Downsample xs in time; subsample masks and halve xs_lens.

        Args:
            xs: input, [B, T, C].
            xs_lens: valid lengths, [B].
            mask: attention mask, [B, T, T].
            mask_pad: padding mask, [B, 1, T].

        Returns:
            (xs, xs_lens, mask, mask_pad) with the time axis downsampled.
        """
        xs = xs.transpose(1, 2)  # [B, C, T]
        # Zero out padded frames before convolving.
        xs = xs.masked_fill(mask_pad.eq(0), 0.0)
        xs = self.pw_conv(self.dw_conv(xs))
        xs = xs.transpose(1, 2)  # [B, T, C]
        batch, time, dim = xs.size()
        mask = mask[:, ::self.stride, ::self.stride]
        mask_pad = mask_pad[:, :, ::self.stride]
        target_len = mask_pad.size(-1)
        # Trim or zero-pad so the time axis matches the subsampled mask.
        # (F.pad is avoided for JIT export compatibility.)
        if target_len < time:
            xs = xs[:, :target_len - time, :].contiguous()
        else:
            filler = torch.zeros(batch, target_len - time, dim,
                                 device=xs.device)
            xs = torch.cat([xs, filler], dim=1)
        xs_lens = torch.div(xs_lens + 1, 2, rounding_mode='trunc')
        return xs, xs_lens, mask, mask_pad
class TimeReductionLayer2D(nn.Module):
    """Squeezeformer 2-D time-reduction layer.

    Downsamples the time axis by ``stride`` using a depthwise
    ``Conv2dValid`` over time followed by a 1x1 pointwise ``Conv2dValid``,
    then aligns the output length with the halved sequence lengths.

    Args:
        kernel_size (int): depthwise conv kernel size along time.
        stride (int): time-axis downsampling factor.
        encoder_dim (int): feature dimension of the encoder output.
    """
    def __init__(
            self, kernel_size: int = 5, stride: int = 2, encoder_dim: int = 256):
        super(TimeReductionLayer2D, self).__init__()
        self.encoder_dim = encoder_dim
        self.kernel_size = kernel_size
        # Depthwise conv over the time axis only (kernel (k, 1)).
        # NOTE(review): valid_trigy=True presumably selects "valid"-style
        # trimming on one axis inside Conv2dValid — confirm against
        # wenet/squeezeformer/conv2d.py.
        self.dw_conv = Conv2dValid(
            in_channels=encoder_dim,
            out_channels=encoder_dim,
            kernel_size=(kernel_size, 1),
            stride=stride,
            valid_trigy=True
        )
        # 1x1 pointwise conv mixing channels; no "valid" trimming.
        self.pw_conv = Conv2dValid(
            in_channels=encoder_dim,
            out_channels=encoder_dim,
            kernel_size=1,
            stride=1,
            valid_trigx=False,
            valid_trigy=False,
        )
        self.kernel_size = kernel_size
        self.stride = stride
        self.init_weights()
    def init_weights(self):
        # Uniform init; bounds follow the Squeezeformer reference code.
        dw_max = self.kernel_size ** -0.5
        pw_max = self.encoder_dim ** -0.5
        torch.nn.init.uniform_(self.dw_conv.weight, -dw_max, dw_max)
        torch.nn.init.uniform_(self.dw_conv.bias, -dw_max, dw_max)
        torch.nn.init.uniform_(self.pw_conv.weight, -pw_max, pw_max)
        torch.nn.init.uniform_(self.pw_conv.bias, -pw_max, pw_max)
    def forward(
            self, xs: torch.Tensor, xs_lens: torch.Tensor,
            mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
            mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Downsample xs ([B, T, D]) in time; subsample masks and lengths.

        Returns (xs, xs_lens, mask, mask_pad) with the time axis halved
        (mask subsampling is hard-coded to stride 2).
        """
        # Zero padded frames, then add a dummy spatial axis: [B, T, 1, D].
        xs = xs.masked_fill(mask_pad.transpose(1, 2).eq(0), 0.0)
        xs = xs.unsqueeze(2)
        # Right-pad time so the strided conv covers the final frames.
        padding1 = self.kernel_size - self.stride
        xs = F.pad(xs, (0, 0, 0, 0, 0, padding1, 0, 0),
                   mode='constant', value=0.)
        # Conv expects channels first: permute to [B, D, T, 1] and back.
        xs = self.dw_conv(xs.permute(0, 3, 1, 2))
        xs = self.pw_conv(xs).permute(0, 3, 2, 1).squeeze(1).contiguous()
        tmp_length = xs.size(1)
        xs_lens = torch.div(xs_lens + 1, 2, rounding_mode='trunc')
        # Zero-pad time so it is at least the longest halved length.
        padding2 = max(0, (xs_lens.max() - tmp_length).data.item())
        batch_size, hidden = xs.size(0), xs.size(-1)
        dummy_pad = torch.zeros(batch_size, padding2, hidden, device=xs.device)
        xs = torch.cat([xs, dummy_pad], dim=1)
        mask = mask[:, ::2, ::2]
        mask_pad = mask_pad[:, :, ::2]
        return xs, xs_lens, mask, mask_pad
class TimeReductionLayerStream(nn.Module):
    """Streaming Squeezeformer time-reduction layer.

    Downsamples the audio by ``stride`` in the time dimension without any
    convolution padding, so each output frame depends only on current and
    past input frames.

    Args:
        channel (int): input feature dimension.
        out_dim (int): output feature dimension.
        kernel_size (int): depthwise conv kernel size.
        stride (int): time-axis downsampling factor.
    """

    def __init__(self, channel: int, out_dim: int,
                 kernel_size: int = 1, stride: int = 2):
        super(TimeReductionLayerStream, self).__init__()
        self.channel = channel
        self.out_dim = out_dim
        self.kernel_size = kernel_size
        self.stride = stride
        # padding=0: keeps the layer causal / chunk-friendly.
        self.dw_conv = nn.Conv1d(
            in_channels=channel,
            out_channels=channel,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            groups=channel,
        )
        self.pw_conv = nn.Conv1d(
            in_channels=channel, out_channels=out_dim,
            kernel_size=1, stride=1, padding=0, groups=1,
        )
        self.init_weights()

    def init_weights(self):
        """Uniform init; bounds follow the Squeezeformer reference code."""
        dw_max = self.kernel_size ** -0.5
        pw_max = self.channel ** -0.5
        # Same order as the original implementation (keeps RNG stream stable).
        for tensor, bound in ((self.dw_conv.weight, dw_max),
                              (self.dw_conv.bias, dw_max),
                              (self.pw_conv.weight, pw_max),
                              (self.pw_conv.bias, pw_max)):
            torch.nn.init.uniform_(tensor, -bound, bound)

    def forward(self, xs, xs_lens: torch.Tensor,
                mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
                mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
                ):
        """Downsample xs in time; subsample masks and halve xs_lens."""
        xs = xs.transpose(1, 2)  # [B, C, T]
        # Zero out padded frames before convolving.
        xs = xs.masked_fill(mask_pad.eq(0), 0.0)
        xs = self.pw_conv(self.dw_conv(xs))
        xs = xs.transpose(1, 2)  # [B, T, C]
        batch, time, dim = xs.size()
        mask = mask[:, ::self.stride, ::self.stride]
        mask_pad = mask_pad[:, :, ::self.stride]
        target_len = mask_pad.size(-1)
        # Align the time axis with the subsampled mask; F.pad is avoided so
        # the module stays exportable via JIT.
        if target_len < time:
            xs = xs[:, :target_len - time, :].contiguous()
        else:
            filler = torch.zeros(batch, target_len - time, dim,
                                 device=xs.device)
            xs = torch.cat([xs, filler], dim=1)
        xs_lens = torch.div(xs_lens + 1, 2, rounding_mode='trunc')
        return xs, xs_lens, mask, mask_pad
from typing import Optional
import torch
from torch import nn
from typeguard import check_argument_types
from wenet.utils.common import get_activation
class TransducerJoint(torch.nn.Module):
    """Transducer joint network: fuse encoder and predictor outputs.

    Optionally projects both inputs to ``join_dim`` first
    (``prejoin_linear``) and/or applies an extra linear after fusion
    (``postjoin_linear``); fusion itself is element-wise addition.
    """

    def __init__(self,
                 voca_size: int,
                 enc_output_size: int,
                 pred_output_size: int,
                 join_dim: int,
                 prejoin_linear: bool = True,
                 postjoin_linear: bool = False,
                 joint_mode: str = 'add',
                 activation: str = "tanh"):
        assert check_argument_types()
        # TODO(Mddct): support 'concat' mode in the future
        assert joint_mode in ['add']
        super().__init__()
        # NOTE(review): the misspelled attribute name "activatoin" is kept
        # deliberately for compatibility with existing callers/checkpoints.
        self.activatoin = get_activation(activation)
        self.prejoin_linear = prejoin_linear
        self.postjoin_linear = postjoin_linear
        self.joint_mode = joint_mode
        if not self.prejoin_linear and not self.postjoin_linear:
            # With no projection anywhere, all three dims must agree.
            assert enc_output_size == pred_output_size == join_dim
        # Optional pre-join projections; Optional[...] for torchscript.
        self.enc_ffn: Optional[nn.Linear] = None
        self.pred_ffn: Optional[nn.Linear] = None
        if self.prejoin_linear:
            self.enc_ffn = nn.Linear(enc_output_size, join_dim)
            self.pred_ffn = nn.Linear(pred_output_size, join_dim)
        # Optional post-join projection; Optional[...] for torchscript.
        self.post_ffn: Optional[nn.Linear] = None
        if self.postjoin_linear:
            self.post_ffn = nn.Linear(join_dim, join_dim)
        self.ffn_out = nn.Linear(join_dim, voca_size)

    def forward(self, enc_out: torch.Tensor, pred_out: torch.Tensor):
        """Combine encoder and predictor outputs into joint logits.

        Args:
            enc_out (torch.Tensor): [B, T, E]
            pred_out (torch.Tensor): [B, U, P]
        Return:
            [B, T, U, V]
        """
        if (self.prejoin_linear and self.enc_ffn is not None
                and self.pred_ffn is not None):
            enc_out = self.enc_ffn(enc_out)  # [B,T,E] -> [B,T,V]
            pred_out = self.pred_ffn(pred_out)
        # Broadcast-add over the time and label axes.
        enc_out = enc_out.unsqueeze(2)  # [B,T,V] -> [B,T,1,V]
        pred_out = pred_out.unsqueeze(1)  # [B,U,V] -> [B,1,U,V]
        # TODO(Mddct): concat joint
        _ = self.joint_mode
        out = enc_out + pred_out  # [B,T,U,V]
        if self.postjoin_linear and self.post_ffn is not None:
            out = self.post_ffn(out)
        out = self.ffn_out(self.activatoin(out))
        return out
from typing import List, Optional, Tuple
import torch
from torch import nn
from typeguard import check_argument_types
from wenet.utils.common import get_activation, get_rnn
def ApplyPadding(input, padding, pad_value) -> torch.Tensor:
    """Replace padded positions of ``input`` with ``pad_value``.

    Positions where ``padding`` is 1 take the corresponding ``pad_value``;
    positions where it is 0 keep ``input``. All three arguments broadcast
    together.

    Args:
        input: [bs, max_time_step, dim]
        padding: [bs, max_time_step]
    """
    keep = 1 - padding
    return input * keep + pad_value * padding
class PredictorBase(torch.nn.Module):
    """Abstract interface shared by all transducer predictors.

    NOTE(Mddct): abc.ABC could be used here, but the class is kept as a
    plain nn.Module for simplicity.
    """

    def __init__(self) -> None:
        super().__init__()

    def init_state(self,
                   batch_size: int,
                   device: torch.device,
                   method: str = "zero") -> List[torch.Tensor]:
        """Create an initial decoding state; subclasses must override."""
        _ = (batch_size, device, method)
        raise NotImplementedError("this is a base precictor")

    def batch_to_cache(self,
                       cache: List[torch.Tensor]) -> List[List[torch.Tensor]]:
        """Split a batched cache per utterance; subclasses must override."""
        _ = cache
        raise NotImplementedError("this is a base precictor")

    def cache_to_batch(self,
                       cache: List[List[torch.Tensor]]) -> List[torch.Tensor]:
        """Merge per-utterance caches into a batch; subclasses must override."""
        _ = cache
        raise NotImplementedError("this is a base precictor")

    def forward(
        self,
        input: torch.Tensor,
        cache: Optional[List[torch.Tensor]] = None,
    ):
        """Training-time forward; subclasses must override."""
        _ = (input, cache)
        raise NotImplementedError("this is a base precictor")

    def forward_step(
        self, input: torch.Tensor, padding: torch.Tensor,
        cache: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Single-step forward for inference; subclasses must override."""
        _ = (input, padding, cache)
        raise NotImplementedError("this is a base precictor")
class RNNPredictor(PredictorBase):
    """RNN (LSTM/GRU) based transducer label predictor.

    Previous tokens are embedded and fed through a multi-layer RNN; the
    RNN hidden states [state_m, state_c] form the decoding cache.
    """
    def __init__(self,
                 voca_size: int,
                 embed_size: int,
                 output_size: int,
                 embed_dropout: float,
                 hidden_size: int,
                 num_layers: int,
                 bias: bool = True,
                 rnn_type: str = "lstm",
                 dropout: float = 0.1) -> None:
        assert check_argument_types()
        super().__init__()
        self.n_layers = num_layers
        self.hidden_size = hidden_size
        # disable rnn base out projection
        self.embed = nn.Embedding(voca_size, embed_size)
        self.dropout = nn.Dropout(embed_dropout)
        # NOTE(Mddct): rnn base from torch not support layer norm
        # will add layer norm and prune value in cell and layer
        # ref: https://github.com/Mddct/neural-lm/blob/main/models/gru_cell.py
        self.rnn = get_rnn(rnn_type=rnn_type)(input_size=embed_size,
                                              hidden_size=hidden_size,
                                              num_layers=num_layers,
                                              bias=bias,
                                              batch_first=True,
                                              dropout=dropout)
        # Projects the RNN hidden state to the predictor output size.
        self.projection = nn.Linear(hidden_size, output_size)
    def forward(
        self,
        input: torch.Tensor,
        cache: Optional[List[torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Training-time forward over a whole token sequence.

        Args:
            input (torch.Tensor): [batch, max_time].
            cache : rnn predictor cache[0] == state_m
                    cache[1] == state_c
        Returns:
            output: [batch, max_time, output_size]
        """
        # NOTE(Mddct): we don't use pack input format
        embed = self.embed(input)  # [batch, max_time, emb_size]
        embed = self.dropout(embed)
        states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
        if cache is None:
            # No cache given (training): start from a fresh zero state.
            state = self.init_state(batch_size=input.size(0),
                                    device=input.device)
            states = (state[0], state[1])
        else:
            assert len(cache) == 2
            states = (cache[0], cache[1])
        out, (m, c) = self.rnn(embed, states)
        out = self.projection(out)
        # NOTE(Mddct): Although we don't use staate in transducer
        # training forward, we need make it right for padding value
        # so we create forward_step for infering, forward for training
        _, _ = m, c
        return out
    def batch_to_cache(self,
                       cache: List[torch.Tensor]) -> List[List[torch.Tensor]]:
        """Split a batched RNN state into per-utterance states.

        Args:
            cache: [state_m, state_c]
                state_ms: [1*n_layers, bs, ...]
                state_cs: [1*n_layers, bs, ...]
        Returns:
            new_cache: [[state_m_1, state_c_1], [state_m_2, state_c_2]...]
        """
        assert len(cache) == 2
        state_ms = cache[0]
        state_cs = cache[1]
        assert state_ms.size(1) == state_cs.size(1)
        new_cache: List[List[torch.Tensor]] = []
        # Split along the batch axis (dim=1 for RNN states).
        for state_m, state_c in zip(torch.split(state_ms, 1, dim=1),
                                    torch.split(state_cs, 1, dim=1)):
            new_cache.append([state_m, state_c])
        return new_cache
    def cache_to_batch(self,
                       cache: List[List[torch.Tensor]]) -> List[torch.Tensor]:
        """Merge per-utterance RNN states back into a batch.

        Args:
            cache : [[state_m_1, state_c_1], [state_m_1, state_c_1]...]
        Returns:
            new_cache: [state_ms, state_cs],
                state_ms: [1*n_layers, bs, ...]
                state_cs: [1*n_layers, bs, ...]
        """
        state_ms = torch.cat([states[0] for states in cache], dim=1)
        state_cs = torch.cat([states[1] for states in cache], dim=1)
        return [state_ms, state_cs]
    def init_state(
        self,
        batch_size: int,
        device: torch.device,
        method: str = "zero",
    ) -> List[torch.Tensor]:
        """Return zero-filled [state_m, state_c] for ``batch_size`` utts."""
        assert batch_size > 0
        # TODO(Mddct): xavier init method
        _ = method
        return [
            torch.zeros(1 * self.n_layers,
                        batch_size,
                        self.hidden_size,
                        device=device),
            torch.zeros(1 * self.n_layers,
                        batch_size,
                        self.hidden_size,
                        device=device)
        ]
    def forward_step(
        self, input: torch.Tensor, padding: torch.Tensor,
        cache: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Single-step forward for inference.

        Args:
            input (torch.Tensor): [batch_size, time_step=1]
            padding (torch.Tensor): [batch_size,1], 1 is padding value
            cache : rnn predictor cache[0] == state_m
                    cache[1] == state_c
        """
        assert len(cache) == 2
        state_m, state_c = cache[0], cache[1]
        embed = self.embed(input)  # [batch, 1, emb_size]
        embed = self.dropout(embed)
        out, (m, c) = self.rnn(embed, (state_m, state_c))
        out = self.projection(out)
        # Keep the previous state for padded positions so padding does not
        # corrupt the cache ([1, bs, 1] broadcast over the state tensors).
        m = ApplyPadding(m, padding.unsqueeze(0), state_m)
        c = ApplyPadding(c, padding.unsqueeze(0), state_c)
        return (out, [m, c])
class EmbeddingPredictor(PredictorBase):
    """Embedding predictor

    Described in:
    https://arxiv.org/pdf/2109.07513.pdf

    Pipeline: embed -> position-weighted sum over a fixed context window
    -> proj -> layer norm -> activation (swish by default).
    """
    def __init__(self,
                 voca_size: int,
                 embed_size: int,
                 embed_dropout: float,
                 n_head: int,
                 history_size: int = 2,
                 activation: str = "swish",
                 bias: bool = False,
                 layer_norm_epsilon: float = 1e-5) -> None:
        assert check_argument_types()
        super().__init__()
        # multi head
        self.num_heads = n_head
        self.embed_size = embed_size
        # Context window = current token + history_size previous tokens.
        self.context_size = history_size + 1
        # Per-head positional weights, stored as a Linear so the weight is
        # a registered parameter; reshaped in forward.
        self.pos_embed = torch.nn.Linear(embed_size * self.context_size,
                                         self.num_heads,
                                         bias=bias)
        self.embed = nn.Embedding(voca_size, self.embed_size)
        self.embed_dropout = nn.Dropout(p=embed_dropout)
        self.ffn = nn.Linear(self.embed_size, self.embed_size)
        self.norm = nn.LayerNorm(self.embed_size, eps=layer_norm_epsilon)
        # NOTE(review): misspelled attribute name kept as-is; other modules
        # in this file use the same spelling.
        self.activatoin = get_activation(activation)
    def init_state(self,
                   batch_size: int,
                   device: torch.device,
                   method: str = "zero") -> List[torch.Tensor]:
        """Return a zero history of shape [bs, context_size - 1, embed]."""
        assert batch_size > 0
        _ = method
        return [
            torch.zeros(batch_size,
                        self.context_size - 1,
                        self.embed_size,
                        device=device),
        ]
    def batch_to_cache(self,
                       cache: List[torch.Tensor]) -> List[List[torch.Tensor]]:
        """
        Args:
            cache : [history]
                history: [bs, ...]
        Returns:
            new_ache : [[history_1], [history_2], [history_3]...]
        """
        assert len(cache) == 1
        cache_0 = cache[0]
        history: List[List[torch.Tensor]] = []
        for h in torch.split(cache_0, 1, dim=0):
            history.append([h])
        return history
    def cache_to_batch(self,
                       cache: List[List[torch.Tensor]]) -> List[torch.Tensor]:
        """
        Args:
            cache : [[history_1], [history_2], [history3]...]
        Returns:
            new_caceh: [history],
                history: [bs, ...]
        """
        history = torch.cat([h[0] for h in cache], dim=0)
        return [history]
    def forward(self,
                input: torch.Tensor,
                cache: Optional[List[torch.Tensor]] = None):
        """ forward for training
        """
        input = self.embed(input)  # [bs, seq_len, embed]
        input = self.embed_dropout(input)
        if cache is None:
            # Fresh zero history when no cache is supplied.
            zeros = self.init_state(input.size(0), device=input.device)[0]
        else:
            assert len(cache) == 1
            zeros = cache[0]
        input = torch.cat((zeros, input),
                          dim=1)  # [bs, context_size-1 + seq_len, embed]
        # Sliding windows of context_size tokens, one per output position.
        input = input.unfold(1, self.context_size, 1).permute(
            0, 1, 3, 2)  # [bs, seq_len, context_size, embed]
        # multi head pos: [n_head, embed, context_size]
        multi_head_pos = self.pos_embed.weight.view(self.num_heads,
                                                    self.embed_size,
                                                    self.context_size)
        # broadcast dot attenton
        input_expand = input.unsqueeze(
            2)  # [bs, seq_len, 1, context_size, embed]
        multi_head_pos = multi_head_pos.permute(
            0, 2, 1)  # [num_heads, context_size, embed]
        # [bs, seq_len, num_heads, context_size, embed]
        weight = input_expand * multi_head_pos
        weight = weight.sum(dim=-1, keepdim=False).unsqueeze(
            3)  # [bs, seq_len, num_heads, 1, context_size]
        output = weight.matmul(input_expand).squeeze(
            dim=3)  # [bs, seq_len, num_heads, embed]
        # Average over heads and context positions.
        output = output.sum(dim=2)  # [bs, seq_len, embed]
        output = output / (self.num_heads * self.context_size)
        output = self.ffn(output)
        output = self.norm(output)
        output = self.activatoin(output)
        return output
    def forward_step(
        self,
        input: torch.Tensor,
        padding: torch.Tensor,
        cache: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """ forward step for inference
        Args:
            input (torch.Tensor): [batch_size, time_step=1]
            padding (torch.Tensor): [batch_size,1], 1 is padding value
            cache: for embedding predictor, cache[0] == history
        """
        assert input.size(1) == 1
        assert len(cache) == 1
        history = cache[0]
        assert history.size(1) == self.context_size - 1
        input = self.embed(input)  # [bs, 1, embed]
        input = self.embed_dropout(input)
        # Single context window: history plus the new token.
        context_input = torch.cat((history, input), dim=1)
        input_expand = context_input.unsqueeze(1).unsqueeze(
            2)  # [bs, 1, 1, context_size, embed]
        # multi head pos: [n_head, embed, context_size]
        multi_head_pos = self.pos_embed.weight.view(self.num_heads,
                                                    self.embed_size,
                                                    self.context_size)
        multi_head_pos = multi_head_pos.permute(
            0, 2, 1)  # [num_heads, context_size, embed]
        # [bs, 1, num_heads, context_size, embed]
        weight = input_expand * multi_head_pos
        weight = weight.sum(dim=-1, keepdim=False).unsqueeze(
            3)  # [bs, 1, num_heads, 1, context_size]
        output = weight.matmul(input_expand).squeeze(
            dim=3)  # [bs, 1, num_heads, embed]
        output = output.sum(dim=2)  # [bs, 1, embed]
        output = output / (self.num_heads * self.context_size)
        output = self.ffn(output)
        output = self.norm(output)
        output = self.activatoin(output)
        # Slide the history window one token forward.
        new_cache = context_input[:, 1:, :]
        # TODO(Mddct): we need padding new_cache in future
        # new_cache = ApplyPadding(history, padding, new_cache)
        return (output, [new_cache])
class ConvPredictor(PredictorBase):
    """Convolutional transducer label predictor.

    Embeds the last ``history_size`` tokens plus the current one and mixes
    them with a depthwise Conv1d over the context window, followed by
    layer norm and an activation.
    """

    def __init__(self,
                 voca_size: int,
                 embed_size: int,
                 embed_dropout: float,
                 history_size: int = 2,
                 activation: str = "relu",
                 bias: bool = False,
                 layer_norm_epsilon: float = 1e-5) -> None:
        assert check_argument_types()
        super().__init__()
        assert history_size >= 0
        self.embed_size = embed_size
        self.context_size = history_size + 1
        self.embed = nn.Embedding(voca_size, self.embed_size)
        self.embed_dropout = nn.Dropout(p=embed_dropout)
        # Depthwise conv over the context window; no padding because the
        # history cache supplies the left context explicitly.
        self.conv = nn.Conv1d(in_channels=embed_size,
                              out_channels=embed_size,
                              kernel_size=self.context_size,
                              padding=0,
                              groups=embed_size,
                              bias=bias)
        self.norm = nn.LayerNorm(embed_size, eps=layer_norm_epsilon)
        # NOTE(review): misspelled attribute name kept for compatibility.
        self.activatoin = get_activation(activation)

    def init_state(self,
                   batch_size: int,
                   device: torch.device,
                   method: str = "zero") -> List[torch.Tensor]:
        """Return a zero history of shape [bs, context_size - 1, embed]."""
        assert batch_size > 0
        assert method == "zero"
        zeros = torch.zeros(batch_size,
                            self.context_size - 1,
                            self.embed_size,
                            device=device)
        return [zeros]

    def cache_to_batch(self,
                       cache: List[List[torch.Tensor]]) -> List[torch.Tensor]:
        """Merge per-utterance histories into one batched history tensor.

        Args:
            cache : [[history_1], [history_2], [history3]...]
        Returns:
            [history] with history of shape [bs, ...]
        """
        batched = torch.cat([entry[0] for entry in cache], dim=0)
        return [batched]

    def batch_to_cache(self,
                       cache: List[torch.Tensor]) -> List[List[torch.Tensor]]:
        """Split the batched history into one single-utterance cache each.

        Args:
            cache : [history] with history of shape [bs, ...]
        Returns:
            [[history_1], [history_2], [history_3]...]
        """
        assert len(cache) == 1
        per_utt: List[List[torch.Tensor]] = []
        for piece in torch.split(cache[0], 1, dim=0):
            per_utt.append([piece])
        return per_utt

    def forward(self,
                input: torch.Tensor,
                cache: Optional[List[torch.Tensor]] = None):
        """Training-time forward: prepend history (zeros when absent),
        then conv -> layer norm -> activation.
        """
        embedded = self.embed_dropout(self.embed(input))  # [bs, seq, embed]
        if cache is None:
            history = self.init_state(embedded.size(0),
                                      device=embedded.device)[0]
        else:
            assert len(cache) == 1
            history = cache[0]
        # [bs, context_size - 1 + seq_len, embed]
        padded = torch.cat((history, embedded), dim=1)
        out = self.conv(padded.permute(0, 2, 1)).permute(0, 2, 1)
        return self.activatoin(self.norm(out))

    def forward_step(
        self, input: torch.Tensor, padding: torch.Tensor,
        cache: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Single-step forward for inference.

        Args:
            input (torch.Tensor): [batch_size, time_step=1]
            padding (torch.Tensor): [batch_size,1], 1 is padding value
            cache: cache[0] == history
        """
        assert input.size(1) == 1
        assert len(cache) == 1
        history = cache[0]
        assert history.size(1) == self.context_size - 1
        embedded = self.embed_dropout(self.embed(input))  # [bs, 1, embed]
        window = torch.cat((history, embedded), dim=1)
        out = self.conv(window.permute(0, 2, 1)).permute(0, 2, 1)
        out = self.activatoin(self.norm(out))
        # Slide the history window one token forward.
        new_history = window[:, 1:, :]
        # TODO(Mddct): apply padding in future
        return (out, [new_history])
from typing import List
import torch
def basic_greedy_search(
    model: torch.nn.Module,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    n_steps: int = 64,
) -> List[List[int]]:
    """Greedy (best-path) decoding for a transducer model, batch size 1.

    Args:
        model: transducer model exposing ``blank``, ``predictor`` and
            ``joint``.
        encoder_out: encoder output, [1, T, E].
        encoder_out_lens: number of valid encoder frames.
        n_steps: max non-blank symbols emitted per encoder frame.

    Returns:
        A single-element list holding the decoded token id sequence.
    """
    # fake padding
    padding = torch.zeros(1, 1).to(encoder_out.device)
    # sos
    # BUGFIX: create the first predictor input on the same device as the
    # encoder output; previously it stayed on CPU and decoding failed when
    # the model ran on GPU.
    pred_input_step = torch.tensor([model.blank],
                                   device=encoder_out.device).reshape(1, 1)
    cache = model.predictor.init_state(1,
                                       method="zero",
                                       device=encoder_out.device)
    new_cache: List[torch.Tensor] = []
    t = 0
    hyps = []
    prev_out_nblk = True
    pred_out_step = None
    per_frame_max_noblk = n_steps
    per_frame_noblk = 0
    while t < encoder_out_lens:
        encoder_out_step = encoder_out[:, t:t + 1, :]  # [1, 1, E]
        if prev_out_nblk:
            # Re-run the predictor only after a non-blank emission.
            step_outs = model.predictor.forward_step(pred_input_step, padding,
                                                     cache)  # [1, 1, P]
            pred_out_step, new_cache = step_outs[0], step_outs[1]
        joint_out_step = model.joint(encoder_out_step,
                                     pred_out_step)  # [1,1,v]
        joint_out_probs = joint_out_step.log_softmax(dim=-1)
        joint_out_max = joint_out_probs.argmax(dim=-1).squeeze()  # []
        if joint_out_max != model.blank:
            hyps.append(joint_out_max.item())
            prev_out_nblk = True
            per_frame_noblk = per_frame_noblk + 1
            pred_input_step = joint_out_max.reshape(1, 1)
            # Commit the predictor state only on a non-blank emission.
            cache = new_cache
        if joint_out_max == model.blank or per_frame_noblk >= per_frame_max_noblk:
            if joint_out_max == model.blank:
                prev_out_nblk = False
            # TODO(Mddct): make t in chunk for streamming
            # or t should't be too lang to predict none blank
            t = t + 1
            per_frame_noblk = 0
    return [hyps]
from typing import List, Tuple
import torch
from wenet.utils.common import log_add
class Sequence():
    """One prefix-beam-search hypothesis.

    Attributes:
        hyp: token ids decoded so far.
        score: accumulated log-probability of this hypothesis.
        cache: predictor state associated with this hypothesis.
    """
    # __slots__ keeps instances small; beam search creates many of them.
    __slots__ = {'hyp', 'score', 'cache'}

    def __init__(
        self,
        hyp: List[torch.Tensor],
        score,
        cache: List[torch.Tensor],
    ):
        self.hyp = hyp
        self.score = score
        self.cache = cache
class PrefixBeamSearch():
    """Prefix beam search for a transducer model, with CTC shallow fusion.

    Holds references to the model's encoder, predictor, joint and CTC
    heads; see also ``wenet.transducer.transducer.beam_search``.
    """
    def __init__(self, encoder, predictor, joint, ctc, blank):
        self.encoder = encoder
        self.predictor = predictor
        self.joint = joint
        self.ctc = ctc
        self.blank = blank
    def forward_decoder_one_step(
        self, encoder_x: torch.Tensor, pre_t: torch.Tensor,
        cache: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Run predictor + joint for one frame over the whole beam.

        Args:
            encoder_x: one encoder frame, [beam, 1, E].
            pre_t: previous token per beam entry, [beam].
            cache: batched predictor state.
        Returns:
            (log-probs [beam, 1, 1, vocab], new batched predictor state)
        """
        padding = torch.zeros(pre_t.size(0), 1, device=encoder_x.device)
        pre_t, new_cache = self.predictor.forward_step(pre_t.unsqueeze(-1),
                                                       padding, cache)
        x = self.joint(encoder_x, pre_t)  # [beam, 1, 1, vocab]
        x = x.log_softmax(dim=-1)
        return x, new_cache
    def prefix_beam_search(self,
                           speech: torch.Tensor,
                           speech_lengths: torch.Tensor,
                           decoding_chunk_size: int = -1,
                           beam_size: int = 5,
                           num_decoding_left_chunks: int = -1,
                           simulate_streaming: bool = False,
                           ctc_weight: float = 0.3,
                           transducer_weight: float = 0.7):
        """prefix beam search
        also see wenet.transducer.transducer.beam_search

        Args:
            speech: input features, (B=1, maxlen, feat_dim).
            speech_lengths: valid lengths, (B=1,).
            decoding_chunk_size / num_decoding_left_chunks: chunk options
                forwarded to the encoder.
            beam_size: number of hypotheses kept per step.
            ctc_weight / transducer_weight: shallow-fusion weights.
        Returns:
            (final beam as a list of Sequence, encoder output)
        """
        assert speech.shape[0] == speech_lengths.shape[0]
        assert decoding_chunk_size != 0
        device = speech.device
        batch_size = speech.shape[0]
        assert batch_size == 1
        # 1. Encoder
        encoder_out, _ = self.encoder(
            speech, speech_lengths, decoding_chunk_size,
            num_decoding_left_chunks)  # (B, maxlen, encoder_dim)
        maxlen = encoder_out.size(1)
        ctc_probs = self.ctc.log_softmax(encoder_out).squeeze(0)
        beam_init: List[Sequence] = []
        # 2. init beam using Sequence to save beam unit
        cache = self.predictor.init_state(1, method="zero", device=device)
        beam_init.append(Sequence(hyp=[self.blank], score=0.0, cache=cache))
        # 3. start decoding (notice: we use breathwise first searching)
        # !!!! In this decoding method: one frame do not output multi units. !!!!
        # !!!! Experiments show that this strategy has little impact !!!!
        for i in range(maxlen):
            # 3.1 building input
            # decoder taking the last token to predict the next token
            input_hyp = [s.hyp[-1] for s in beam_init]
            input_hyp_tensor = torch.tensor(input_hyp,
                                            dtype=torch.int,
                                            device=device)
            # building statement from beam
            cache_batch = self.predictor.cache_to_batch(
                [s.cache for s in beam_init])
            # build score tensor to do torch.add() function
            scores = torch.tensor([s.score for s in beam_init]).to(device)
            # 3.2 forward decoder
            logp, new_cache = self.forward_decoder_one_step(
                encoder_out[:, i, :].unsqueeze(1),
                input_hyp_tensor,
                cache_batch,
            )  # logp: (N, 1, 1, vocab_size)
            logp = logp.squeeze(1).squeeze(1)  # logp: (N, vocab_size)
            new_cache = self.predictor.batch_to_cache(new_cache)
            # 3.3 shallow fusion for transducer score
            # and ctc score where we can also add the LM score
            logp = torch.log(
                torch.add(transducer_weight * torch.exp(logp),
                          ctc_weight * torch.exp(ctc_probs[i].unsqueeze(0))))
            # 3.4 first beam prune
            top_k_logp, top_k_index = logp.topk(beam_size)  # (N, N)
            scores = torch.add(scores.unsqueeze(1), top_k_logp)
            # 3.5 generate new beam (N*N)
            beam_A = []
            for j in range(len(beam_init)):
                # update seq
                base_seq = beam_init[j]
                for t in range(beam_size):
                    # blank: only update the score; hyp and cache unchanged
                    if top_k_index[j, t] == self.blank:
                        new_seq = Sequence(hyp=base_seq.hyp.copy(),
                                           score=scores[j, t].item(),
                                           cache=base_seq.cache)
                        beam_A.append(new_seq)
                    # other unit: update hyp score statement and last
                    else:
                        hyp_new = base_seq.hyp.copy()
                        hyp_new.append(top_k_index[j, t].item())
                        new_seq = Sequence(hyp=hyp_new,
                                           score=scores[j, t].item(),
                                           cache=new_cache[j])
                        beam_A.append(new_seq)
            # 3.6 prefix fusion: merge identical hypotheses (log-sum-exp)
            fusion_A = [beam_A[0]]
            for j in range(1, len(beam_A)):
                s1 = beam_A[j]
                if_do_append = True
                for t in range(len(fusion_A)):
                    # notice: A_ can not fusion with A
                    if s1.hyp == fusion_A[t].hyp:
                        fusion_A[t].score = log_add(
                            [fusion_A[t].score, s1.score])
                        if_do_append = False
                        break
                if if_do_append:
                    fusion_A.append(s1)
            # 4. second pruned: keep the best beam_size hypotheses
            fusion_A.sort(key=lambda x: x.score, reverse=True)
            beam_init = fusion_A[:beam_size]
        return beam_init, encoder_out
from typing import Dict, List, Optional, Tuple, Union
import torch
import torchaudio
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from typeguard import check_argument_types
from wenet.transducer.predictor import PredictorBase
from wenet.transducer.search.greedy_search import basic_greedy_search
from wenet.transducer.search.prefix_beam_search import PrefixBeamSearch
from wenet.transformer.asr_model import ASRModel
from wenet.transformer.ctc import CTC
from wenet.transformer.decoder import BiTransformerDecoder, TransformerDecoder
from wenet.transformer.label_smoothing_loss import LabelSmoothingLoss
from wenet.utils.common import (IGNORE_ID, add_blank, add_sos_eos,
reverse_pad_list)
class Transducer(ASRModel):
    """Transducer-ctc-attention hybrid Encoder-Predictor-Decoder model.

    Wraps an acoustic encoder, a label predictor and a joint network and
    trains them with ``torchaudio.functional.rnnt_loss``, optionally
    interpolated with auxiliary CTC and attention-decoder losses via
    ``transducer_weight`` / ``ctc_weight`` / ``attention_weight``.
    """

    def __init__(
        self,
        vocab_size: int,
        blank: int,
        encoder: nn.Module,
        predictor: PredictorBase,
        joint: nn.Module,
        attention_decoder: Optional[Union[TransformerDecoder,
                                          BiTransformerDecoder]] = None,
        ctc: Optional[CTC] = None,
        ctc_weight: float = 0,
        ignore_id: int = IGNORE_ID,
        reverse_weight: float = 0.0,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        transducer_weight: float = 1.0,
        attention_weight: float = 0.0,
    ) -> None:
        assert check_argument_types()
        # NOTE(review): exact float equality on a sum of three floats is
        # fragile (e.g. 0.1 + 0.2 + 0.7 != 1.0 in binary floating point);
        # math.isclose would be safer — confirm configs only ever use
        # exactly-representable weight combinations.
        assert attention_weight + ctc_weight + transducer_weight == 1.0
        super().__init__(vocab_size, encoder, attention_decoder, ctc,
                         ctc_weight, ignore_id, reverse_weight, lsm_weight,
                         length_normalized_loss)

        self.blank = blank
        self.transducer_weight = transducer_weight
        # Remaining weight is assigned to the optional attention decoder loss.
        self.attention_decoder_weight = 1 - self.transducer_weight - self.ctc_weight

        self.predictor = predictor
        self.joint = joint
        # Lazily-constructed prefix beam search helper; built in init_bs().
        self.bs: Optional[PrefixBeamSearch] = None

        # Note(Mddct): decoder also means predictor in transducer,
        # but here decoder is attention decoder
        # Rebuild the attention criterion only when an attention decoder
        # actually exists (the base class always creates one).
        del self.criterion_att
        if attention_decoder is not None:
            self.criterion_att = LabelSmoothingLoss(
                size=vocab_size,
                padding_idx=ignore_id,
                smoothing=lsm_weight,
                normalize_length=length_normalized_loss,
            )

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Frontend + Encoder + predictor + joint + loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)

        Returns:
            Dict with 'loss' (the weighted total used for backprop) plus
            the unweighted components 'loss_rnnt', 'loss_att' and
            'loss_ctc' ('loss_att'/'loss_ctc' are None when disabled).
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] ==
                text_lengths.shape[0]), (speech.shape, speech_lengths.shape,
                                         text.shape, text_lengths.shape)

        # Encoder
        encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)

        # predictor: prepend blank so the predictor has a start symbol
        ys_in_pad = add_blank(text, self.blank, self.ignore_id)
        predictor_out = self.predictor(ys_in_pad)

        # joint
        joint_out = self.joint(encoder_out, predictor_out)

        # NOTE(Mddct): some loss implementation require pad valid is zero
        # torch.int32 rnnt_loss required
        # Map ignore_id padding to 0 and cast to the int32 dtype that
        # torchaudio's rnnt_loss expects for targets/lengths.
        rnnt_text = text.to(torch.int64)
        rnnt_text = torch.where(rnnt_text == self.ignore_id, 0,
                                rnnt_text).to(torch.int32)
        rnnt_text_lengths = text_lengths.to(torch.int32)
        encoder_out_lens = encoder_out_lens.to(torch.int32)
        loss = torchaudio.functional.rnnt_loss(joint_out,
                                               rnnt_text,
                                               encoder_out_lens,
                                               rnnt_text_lengths,
                                               blank=self.blank,
                                               reduction="mean")
        loss_rnnt = loss

        loss = self.transducer_weight * loss
        # optional attention decoder
        loss_att: Optional[torch.Tensor] = None
        if self.attention_decoder_weight != 0.0 and self.decoder is not None:
            loss_att, _ = self._calc_att_loss(encoder_out, encoder_mask, text,
                                              text_lengths)

        # optional ctc
        loss_ctc: Optional[torch.Tensor] = None
        if self.ctc_weight != 0.0 and self.ctc is not None:
            loss_ctc = self.ctc(encoder_out, encoder_out_lens, text,
                                text_lengths)
        else:
            # NOTE(review): redundant — loss_ctc is already None here.
            loss_ctc = None

        if loss_ctc is not None:
            loss = loss + self.ctc_weight * loss_ctc.sum()
        if loss_att is not None:
            loss = loss + self.attention_decoder_weight * loss_att.sum()
        # NOTE: 'loss' must be in dict
        return {
            'loss': loss,
            'loss_att': loss_att,
            'loss_ctc': loss_ctc,
            'loss_rnnt': loss_rnnt,
        }

    def init_bs(self) -> None:
        """Lazily build the shared PrefixBeamSearch helper (idempotent)."""
        if self.bs is None:
            self.bs = PrefixBeamSearch(self.encoder, self.predictor,
                                       self.joint, self.ctc, self.blank)

    def _cal_transducer_score(
        self,
        encoder_out: torch.Tensor,
        encoder_mask: torch.Tensor,
        hyps_lens: torch.Tensor,
        hyps_pad: torch.Tensor,
    ) -> torch.Tensor:
        """Score each padded hypothesis with the transducer loss.

        Args:
            encoder_out: encoder output repeated per hypothesis
                (presumably (beam_size, T, D) — see callers).
            encoder_mask: (beam_size, 1, T) boolean mask.
            hyps_lens: (beam_size,) hypothesis lengths.
            hyps_pad: (beam_size, max_hyps_len) padded with ignore_id.

        Returns:
            Per-hypothesis negated rnnt loss (reduction='none'), so
            higher is better; used as a rescoring score.
        """
        # ignore id -> blank, add blank at head
        hyps_pad_blank = add_blank(hyps_pad, self.blank, self.ignore_id)
        xs_in_lens = encoder_mask.squeeze(1).sum(1).int()

        # 1. Forward predictor
        predictor_out = self.predictor(hyps_pad_blank)
        # 2. Forward joint
        joint_out = self.joint(encoder_out, predictor_out)
        # ignore_id padding -> 0, int32 as required by rnnt_loss targets
        rnnt_text = hyps_pad.to(torch.int64)
        rnnt_text = torch.where(rnnt_text == self.ignore_id, 0,
                                rnnt_text).to(torch.int32)
        # 3. Compute transducer loss
        loss_td = torchaudio.functional.rnnt_loss(joint_out,
                                                  rnnt_text,
                                                  xs_in_lens,
                                                  hyps_lens.int(),
                                                  blank=self.blank,
                                                  reduction='none')
        return loss_td * -1

    def _cal_attn_score(
        self,
        encoder_out: torch.Tensor,
        encoder_mask: torch.Tensor,
        hyps_pad: torch.Tensor,
        hyps_lens: torch.Tensor,
    ):
        """Compute attention-decoder log-probabilities for rescoring.

        Returns:
            Tuple of (decoder_out, r_decoder_out) as numpy arrays of
            log-softmax scores, shape (beam_size, max_hyps_len, vocab_size).
            r_decoder_out is only meaningful for a bitransformer decoder
            with reverse_weight > 0.
        """
        # (beam_size, max_hyps_len)
        ori_hyps_pad = hyps_pad

        # td_score = loss_td * -1
        hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
        hyps_lens = hyps_lens + 1  # Add <sos> at begining

        # used for right to left decoder
        r_hyps_pad = reverse_pad_list(ori_hyps_pad, hyps_lens, self.ignore_id)
        r_hyps_pad, _ = add_sos_eos(r_hyps_pad, self.sos, self.eos,
                                    self.ignore_id)
        decoder_out, r_decoder_out, _ = self.decoder(
            encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad,
            self.reverse_weight)  # (beam_size, max_hyps_len, vocab_size)
        decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
        # Move to CPU/numpy for the pure-Python rescoring loop.
        decoder_out = decoder_out.cpu().numpy()
        # r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a
        # conventional transformer decoder.
        r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)
        r_decoder_out = r_decoder_out.cpu().numpy()
        return decoder_out, r_decoder_out

    def beam_search(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        decoding_chunk_size: int = -1,
        beam_size: int = 5,
        num_decoding_left_chunks: int = -1,
        simulate_streaming: bool = False,
        ctc_weight: float = 0.3,
        transducer_weight: float = 0.7,
    ):
        """beam search

        Args:
            speech (torch.Tensor): (batch=1, max_len, feat_dim)
            speech_length (torch.Tensor): (batch, )
            beam_size (int): beam size for beam search
            decoding_chunk_size (int): decoding chunk for dynamic chunk
                trained model.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
                0: used for training, it's prohibited here
            simulate_streaming (bool): whether do encoder forward in a
                streaming fashion
            ctc_weight (float): ctc probability weight in transducer
                prefix beam search.
                final_prob = ctc_weight * ctc_prob + transducer_weight * transducer_prob
            transducer_weight (float): transducer probability weight in
                prefix beam search

        Returns:
            Tuple of (best hypothesis token list without the leading
            blank, its score).  NOTE(review): the original docstring said
            List[List[int]], which does not match the actual return.
        """
        self.init_bs()
        beam, _ = self.bs.prefix_beam_search(
            speech,
            speech_lengths,
            decoding_chunk_size,
            beam_size,
            num_decoding_left_chunks,
            simulate_streaming,
            ctc_weight,
            transducer_weight,
        )
        # beam[0] is the best hypothesis; hyp[0] is the leading blank.
        return beam[0].hyp[1:], beam[0].score

    def transducer_attention_rescoring(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        beam_size: int,
        decoding_chunk_size: int = -1,
        num_decoding_left_chunks: int = -1,
        simulate_streaming: bool = False,
        reverse_weight: float = 0.0,
        ctc_weight: float = 0.0,
        attn_weight: float = 0.0,
        transducer_weight: float = 0.0,
        search_ctc_weight: float = 1.0,
        search_transducer_weight: float = 0.0,
        beam_search_type: str = 'transducer') -> List[List[int]]:
        """beam search

        Args:
            speech (torch.Tensor): (batch=1, max_len, feat_dim)
            speech_length (torch.Tensor): (batch, )
            beam_size (int): beam size for beam search
            decoding_chunk_size (int): decoding chunk for dynamic chunk
                trained model.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
                0: used for training, it's prohibited here
            simulate_streaming (bool): whether do encoder forward in a
                streaming fashion
            ctc_weight (float): ctc probability weight using in rescoring.
                rescore_prob = ctc_weight * ctc_prob +
                               transducer_weight * (transducer_loss * -1) +
                               attn_weight * attn_prob
            attn_weight (float): attn probability weight using in rescoring.
            transducer_weight (float): transducer probability weight using in
                rescoring
            search_ctc_weight (float): ctc weight using
                in rnnt beam search (seeing in self.beam_search)
            search_transducer_weight (float): transducer weight using
                in rnnt beam search (seeing in self.beam_search)
            beam_search_type (str): 'transducer' or 'ctc'; any other value
                leaves hyps undefined and raises NameError below —
                NOTE(review): consider validating explicitly.

        Returns:
            Tuple of (best hypothesis token list, its rescored score).
            NOTE(review): annotated as List[List[int]] but actually
            returns (hyps[best_index], best_score).
        """
        assert speech.shape[0] == speech_lengths.shape[0]
        assert decoding_chunk_size != 0
        if reverse_weight > 0.0:
            # decoder should be a bitransformer decoder if reverse_weight > 0.0
            assert hasattr(self.decoder, 'right_decoder')
        device = speech.device
        batch_size = speech.shape[0]
        # For attention rescoring we only support batch_size=1
        assert batch_size == 1

        # encoder_out: (1, maxlen, encoder_dim), len(hyps) = beam_size
        self.init_bs()
        if beam_search_type == 'transducer':
            beam, encoder_out = self.bs.prefix_beam_search(
                speech,
                speech_lengths,
                decoding_chunk_size=decoding_chunk_size,
                beam_size=beam_size,
                num_decoding_left_chunks=num_decoding_left_chunks,
                ctc_weight=search_ctc_weight,
                transducer_weight=search_transducer_weight,
            )
            beam_score = [s.score for s in beam]
            # Drop the leading blank from each hypothesis.
            hyps = [s.hyp[1:] for s in beam]
        elif beam_search_type == 'ctc':
            hyps, encoder_out = self._ctc_prefix_beam_search(
                speech,
                speech_lengths,
                beam_size=beam_size,
                decoding_chunk_size=decoding_chunk_size,
                num_decoding_left_chunks=num_decoding_left_chunks,
                simulate_streaming=simulate_streaming)
            # _ctc_prefix_beam_search returns (tokens, score) pairs.
            beam_score = [hyp[1] for hyp in hyps]
            hyps = [hyp[0] for hyp in hyps]
        assert len(hyps) == beam_size

        # build hyps and encoder output
        hyps_pad = pad_sequence([
            torch.tensor(hyp, device=device, dtype=torch.long) for hyp in hyps
        ], True, self.ignore_id)  # (beam_size, max_hyps_len)
        hyps_lens = torch.tensor([len(hyp) for hyp in hyps],
                                 device=device,
                                 dtype=torch.long)  # (beam_size,)

        # Tile the single utterance's encoder output across the beam.
        encoder_out = encoder_out.repeat(beam_size, 1, 1)
        encoder_mask = torch.ones(beam_size,
                                  1,
                                  encoder_out.size(1),
                                  dtype=torch.bool,
                                  device=device)

        # 2.1 calculate transducer score
        td_score = self._cal_transducer_score(
            encoder_out,
            encoder_mask,
            hyps_lens,
            hyps_pad,
        )
        # 2.2 calculate attention score
        decoder_out, r_decoder_out = self._cal_attn_score(
            encoder_out,
            encoder_mask,
            hyps_pad,
            hyps_lens,
        )

        # Only use decoder score for rescoring
        best_score = -float('inf')
        best_index = 0
        for i, hyp in enumerate(hyps):
            # Left-to-right attention score: sum of per-token log-probs
            # plus the <eos> log-prob after the last token.
            score = 0.0
            for j, w in enumerate(hyp):
                score += decoder_out[i][j][w]
            score += decoder_out[i][len(hyp)][self.eos]
            td_s = td_score[i]
            # add right to left decoder score
            if reverse_weight > 0:
                r_score = 0.0
                for j, w in enumerate(hyp):
                    r_score += r_decoder_out[i][len(hyp) - j - 1][w]
                r_score += r_decoder_out[i][len(hyp)][self.eos]
                score = score * (1 - reverse_weight) + r_score * reverse_weight
            # add ctc score
            # Interpolate attention, search (ctc) and transducer scores.
            score = score * attn_weight + \
                beam_score[i] * ctc_weight + \
                td_s * transducer_weight
            if score > best_score:
                best_score = score
                best_index = i

        return hyps[best_index], best_score

    def greedy_search(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        decoding_chunk_size: int = -1,
        num_decoding_left_chunks: int = -1,
        simulate_streaming: bool = False,
        n_steps: int = 64,
    ) -> List[List[int]]:
        """ greedy search

        Args:
            speech (torch.Tensor): (batch=1, max_len, feat_dim)
            speech_length (torch.Tensor): (batch, )
            beam_size (int): beam size for beam search
            decoding_chunk_size (int): decoding chunk for dynamic chunk
                trained model.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
                0: used for training, it's prohibited here
            simulate_streaming (bool): whether do encoder forward in a
                streaming fashion
            n_steps (int): maximum symbols emitted per encoder frame
                (passed through to basic_greedy_search — TODO confirm).

        Returns:
            List[List[int]]: best path result
        """
        # TODO(Mddct): batch decode
        assert speech.size(0) == 1
        assert speech.shape[0] == speech_lengths.shape[0]
        assert decoding_chunk_size != 0
        # TODO(Mddct): forward chunk by chunk
        _ = simulate_streaming  # unused for now; kept for API symmetry
        # Let's assume B = batch_size
        encoder_out, encoder_mask = self.encoder(
            speech,
            speech_lengths,
            decoding_chunk_size,
            num_decoding_left_chunks,
        )
        encoder_out_lens = encoder_mask.squeeze(1).sum()
        hyps = basic_greedy_search(self,
                                   encoder_out,
                                   encoder_out_lens,
                                   n_steps=n_steps)

        return hyps

    @torch.jit.export
    def forward_encoder_chunk(
        self,
        xs: torch.Tensor,
        offset: int,
        required_cache_size: int,
        att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
        cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """TorchScript export: run one streaming encoder chunk.

        Thin wrapper over self.encoder.forward_chunk for runtime use.
        """
        return self.encoder.forward_chunk(xs, offset, required_cache_size,
                                          att_cache, cnn_cache)

    @torch.jit.export
    def forward_predictor_step(
        self, xs: torch.Tensor, cache: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """TorchScript export: one predictor step with external cache."""
        assert len(cache) == 2
        # fake padding
        padding = torch.zeros(1, 1)
        return self.predictor.forward_step(xs, padding, cache)

    @torch.jit.export
    def forward_joint_step(self, enc_out: torch.Tensor,
                           pred_out: torch.Tensor) -> torch.Tensor:
        """TorchScript export: fuse one encoder/predictor output pair."""
        return self.joint(enc_out, pred_out)

    @torch.jit.export
    def forward_predictor_init_state(self) -> List[torch.Tensor]:
        """TorchScript export: initial predictor cache for batch=1 on CPU."""
        return self.predictor.init_state(1, device=torch.device("cpu"))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment