ModelZoo / VITA_pytorch / Commits

Commit 112bf76b, authored Oct 31, 2024 by chenzk

v1.0

Pipeline #1826 canceled with stages · Changes 171 · Pipelines 1
Showing 20 changed files with 2015 additions and 0 deletions (+2015, -0):

- VITA/model/multimodal_encoder/siglip/__pycache__/siglip_encoder.cpython-310.pyc (+0, -0)
- VITA/model/multimodal_encoder/siglip/siglip_encoder.py (+149, -0)
- VITA/model/multimodal_encoder/whale/__pycache__/adapter.cpython-310.pyc (+0, -0)
- VITA/model/multimodal_encoder/whale/__pycache__/cmvn.cpython-310.pyc (+0, -0)
- VITA/model/multimodal_encoder/whale/__pycache__/init_model.cpython-310.pyc (+0, -0)
- VITA/model/multimodal_encoder/whale/__pycache__/utils.cpython-310.pyc (+0, -0)
- VITA/model/multimodal_encoder/whale/adapter.py (+136, -0)
- VITA/model/multimodal_encoder/whale/cmvn.py (+89, -0)
- VITA/model/multimodal_encoder/whale/init_model.py (+178, -0)
- VITA/model/multimodal_encoder/whale/module/component/__pycache__/mamba.cpython-310.pyc (+0, -0)
- VITA/model/multimodal_encoder/whale/module/component/__pycache__/subsampling.cpython-310.pyc (+0, -0)
- VITA/model/multimodal_encoder/whale/module/component/__pycache__/transformer.cpython-310.pyc (+0, -0)
- VITA/model/multimodal_encoder/whale/module/component/mamba.py (+131, -0)
- VITA/model/multimodal_encoder/whale/module/component/subsampling.py (+74, -0)
- VITA/model/multimodal_encoder/whale/module/component/transformer.py (+428, -0)
- VITA/model/multimodal_encoder/whale/module/encoder/__pycache__/encoder.cpython-310.pyc (+0, -0)
- VITA/model/multimodal_encoder/whale/module/encoder/encoder.py (+171, -0)
- VITA/model/multimodal_encoder/whale/module/layer/__pycache__/attention.cpython-310.pyc (+0, -0)
- VITA/model/multimodal_encoder/whale/module/layer/attention.py (+571, -0)
- VITA/model/multimodal_encoder/whale/module/layer/conv1d.py (+88, -0)
VITA/model/multimodal_encoder/siglip/__pycache__/siglip_encoder.cpython-310.pyc (new file, mode 100644)

File added (binary).
VITA/model/multimodal_encoder/siglip/siglip_encoder.py (new file, mode 100644)
```python
import torch
import torch.nn as nn
from transformers import SiglipImageProcessor, SiglipVisionConfig, SiglipVisionModel

from vita.util.s2wrapper import forward as multiscale_forward


class SiglipVisionTower(nn.Module):
    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False
        self.vision_tower_name = vision_tower
        self.select_layer = -2

        if not delay_load:
            self.load_model()
        else:
            self.cfg_only = SiglipVisionConfig.from_pretrained(self.vision_tower_name)

    def load_model(self):
        self.image_processor = SiglipImageProcessor.from_pretrained(self.vision_tower_name)
        self.image_processor.crop_size = self.image_processor.size
        self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name)
        self.vision_tower.requires_grad_(False)

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        image_features = image_forward_outs.hidden_states[self.select_layer]
        return image_features

    @torch.no_grad()
    def forward(self, images):
        if type(images) is list:
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(
                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
                    output_hidden_states=True,
                )
                image_feature = self.feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_tower(
                images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
            )
            image_features = self.feature_select(image_forward_outs).to(images.dtype)

        return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        if self.is_loaded:
            return self.vision_tower.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2


class SiglipVisionTowerS2(SiglipVisionTower):
    def __init__(self, vision_tower, args, delay_load=False):
        self.s2_scales = getattr(args, "s2_scales", "384,768,1152")
        self.s2_scales = list(map(int, self.s2_scales.split(",")))
        self.s2_scales.sort()
        self.s2_split_size = self.s2_scales[0]
        self.s2_image_size = self.s2_scales[-1]

        super().__init__(vision_tower, args, delay_load)

        self.multiscale_forward = multiscale_forward

        if not delay_load:
            self.image_processor.size["height"] = self.image_processor.size["width"] = self.s2_image_size
            self.image_processor.crop_size["height"] = self.image_processor.crop_size["width"] = self.s2_image_size

    def load_model(self):
        self.image_processor = SiglipImageProcessor.from_pretrained(self.vision_tower_name)
        self.image_processor.crop_size = self.image_processor.size
        self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name)
        self.vision_tower.requires_grad_(False)

        self.image_processor.size["height"] = self.image_processor.size["width"] = self.s2_image_size
        self.image_processor.crop_size["height"] = self.image_processor.crop_size["width"] = self.s2_image_size

        self.is_loaded = True

    @torch.no_grad()
    def forward_feature(self, images):
        image_forward_outs = self.vision_tower(
            images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
        )
        image_features = self.feature_select(image_forward_outs).to(images.dtype)
        return image_features

    @torch.no_grad()
    def forward(self, images):
        if type(images) is list:
            image_features = []
            for image in images:
                image_feature = self.multiscale_forward(
                    self.forward_feature,
                    image.unsqueeze(0),
                    img_sizes=self.s2_scales,
                    max_split_size=self.s2_split_size,
                )
                image_features.append(image_feature)
        else:
            image_features = self.multiscale_forward(
                self.forward_feature,
                images,
                img_sizes=self.s2_scales,
                max_split_size=self.s2_split_size,
            )

        return image_features

    @property
    def hidden_size(self):
        return self.config.hidden_size * len(self.s2_scales)
```
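For orientation, a minimal usage sketch of `SiglipVisionTower` follows. The checkpoint name and the empty `args` namespace are illustrative assumptions, not values taken from this commit; loading the weights requires `transformers` and network access.

```python
import types

import torch

# Hypothetical SigLIP checkpoint for illustration; VITA's own config chooses the real one.
tower = SiglipVisionTower(
    "google/siglip-so400m-patch14-384", args=types.SimpleNamespace(), delay_load=False
)

images = torch.randn(2, 3, 384, 384)  # dummy batch at the processor resolution
feats = tower(images)                 # hidden states of layer -2
print(feats.shape)                    # (2, num_patches, hidden_size)
print(tower.num_patches, tower.hidden_size)
```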
VITA/model/multimodal_encoder/whale/__pycache__/adapter.cpython-310.pyc (new file, mode 100644)

File added (binary).

VITA/model/multimodal_encoder/whale/__pycache__/cmvn.cpython-310.pyc (new file, mode 100644)

File added (binary).

VITA/model/multimodal_encoder/whale/__pycache__/init_model.cpython-310.pyc (new file, mode 100644)

File added (binary).

VITA/model/multimodal_encoder/whale/__pycache__/utils.cpython-310.pyc (new file, mode 100644)

File added (binary).
VITA/model/multimodal_encoder/whale/adapter.py (new file, mode 100644)
```python
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence


class CNNAdapter(torch.nn.Module):
    def __init__(
        self,
        enc_out_dim: int = 512,
        llm_embed_dim: int = 4096,
        kernel_size: int = 5,
    ):
        super().__init__()

        self.left_padding1 = nn.ConstantPad1d((kernel_size - 1, 0), 0.0)
        self.conv1d1 = nn.Conv1d(enc_out_dim, 2 * enc_out_dim, kernel_size, 1, 0)
        self.bn1 = nn.BatchNorm1d(2 * enc_out_dim, eps=1e-3, momentum=0.99)
        self.relu1 = nn.ReLU()

        self.left_padding2 = nn.ConstantPad1d((kernel_size - 1, 0), 0.0)
        self.conv1d2 = nn.Conv1d(2 * enc_out_dim, 4 * enc_out_dim, kernel_size, 1, 0)
        self.bn2 = nn.BatchNorm1d(4 * enc_out_dim, eps=1e-3, momentum=0.99)
        self.relu2 = nn.ReLU()

        self.project = nn.Linear(4 * enc_out_dim, llm_embed_dim)

    def forward(self, x, mask_pad):
        """
        x: B, T, enc_out_dim
        mask: (B, T) or (B, 1, T)
        """
        x = x.transpose(1, 2)  # B, channels, T
        # mask batch padding
        if mask_pad.size(2) > 0:  # time > 0
            x.masked_fill_(~mask_pad, 0.0)

        x = self.left_padding1(x)
        x = self.conv1d1(x)
        x = self.bn1(x)
        x = self.relu1(x)

        x = self.left_padding2(x)
        x = self.conv1d2(x)
        x = self.bn2(x)
        x = self.relu2(x)

        x = x.transpose(1, 2)
        x = self.project(x)

        return x, mask_pad


class LinearAdapter(torch.nn.Module):
    def __init__(
        self,
        enc_out_dim: int = 512,
        llm_embed_dim: int = 4096,
    ):
        super().__init__()

        self.adpter = torch.nn.Linear(enc_out_dim, llm_embed_dim)

    def forward(self, x, mask_pad):
        return self.adpter(x), mask_pad


class CNNSubsampling(torch.nn.Module):
    def __init__(
        self,
        enc_out_dim: int = 512,
        llm_embed_dim: int = 4096,
        kernel_size: int = 5,
        activation_func: str = "relu",
        norm: str = "batch",
    ):
        super().__init__()

        if enc_out_dim * 4 < llm_embed_dim:
            self.left_padding1 = nn.ConstantPad1d((kernel_size - 1, 0), 0.0)
            self.conv1d1 = nn.Conv1d(enc_out_dim, 2 * enc_out_dim, kernel_size, 1, 0)
            self.bn1 = nn.BatchNorm1d(2 * enc_out_dim, eps=1e-3, momentum=0.99)
            self.relu1 = nn.ReLU()

            self.left_padding2 = nn.ConstantPad1d((0, kernel_size - 1), 0.0)
            self.conv1d2 = nn.Conv1d(2 * enc_out_dim, 4 * enc_out_dim, kernel_size, 2, 0)
            self.bn2 = nn.BatchNorm1d(4 * enc_out_dim, eps=1e-3, momentum=0.99)
            self.relu2 = nn.ReLU()

            self.project = nn.Linear(4 * enc_out_dim, llm_embed_dim)
            self.cnn_num = 2
        else:
            self.left_padding2 = nn.ConstantPad1d((0, kernel_size - 1), 0.0)
            self.conv1d2 = nn.Conv1d(enc_out_dim, 2 * enc_out_dim, kernel_size, 2, 0)
            if norm == "batch":
                self.bn2 = nn.BatchNorm1d(2 * enc_out_dim, eps=1e-3, momentum=0.99)
            elif norm == "layer":
                self.bn2 = nn.LayerNorm(2 * enc_out_dim, eps=1e-3)
            if activation_func == "gelu":
                self.relu2 = nn.GELU()
            else:
                self.relu2 = nn.ReLU()

            self.project = nn.Linear(2 * enc_out_dim, llm_embed_dim)
            self.cnn_num = 1

    def forward(self, x, mask_pad):
        """
        x: B, T, enc_out_dim
        mask: (B, T) or (B, 1, T)
        """
        x = x.transpose(1, 2)  # B, channels, T
        # mask batch padding
        if mask_pad.size(2) > 0:  # time > 0
            x.masked_fill_(~mask_pad, 0.0)

        if self.cnn_num == 2:
            x = self.left_padding1(x)
            x = self.conv1d1(x)
            x = self.bn1(x)
            x = self.relu1(x)

        x = self.left_padding2(x)
        x = self.conv1d2(x)
        if isinstance(self.bn2, nn.LayerNorm):
            x = x.transpose(1, 2)
        x = self.bn2(x)
        if isinstance(self.bn2, nn.LayerNorm):
            x = x.transpose(1, 2)
        x = self.relu2(x)

        x = x.transpose(1, 2)
        x = self.project(x)

        return x, mask_pad[:, :, 0::2]
```
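A quick shape check of `CNNSubsampling` with dummy sizes chosen for illustration: with `enc_out_dim * 4 < llm_embed_dim` the two-conv branch is built, the stride-2 second convolution halves the time axis, and the returned mask is downsampled to match via `mask_pad[:, :, 0::2]`.

```python
import torch

adapter = CNNSubsampling(enc_out_dim=512, llm_embed_dim=4096, kernel_size=5)
adapter.eval()  # put BatchNorm into eval mode for a deterministic run

x = torch.randn(2, 16, 512)                    # (B, T, enc_out_dim)
mask = torch.ones(2, 1, 16, dtype=torch.bool)  # (B, 1, T), no padding

with torch.no_grad():
    y, y_mask = adapter(x, mask)

print(y.shape)       # torch.Size([2, 8, 4096]): time halved, projected to LLM width
print(y_mask.shape)  # torch.Size([2, 1, 8])
```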
VITA/model/multimodal_encoder/whale/cmvn.py (new file, mode 100644)
```python
import json
import logging
import math
import sys

import numpy as np
import torch


class GlobalCMVN(torch.nn.Module):
    def __init__(self, mean: torch.Tensor, istd: torch.Tensor, norm_var: bool = True):
        """
        Args:
            mean (torch.Tensor): mean stats
            istd (torch.Tensor): inverse std, i.e. 1.0 / std
        """
        super().__init__()
        assert mean.shape == istd.shape
        self.norm_var = norm_var
        # The buffer can be accessed from this module using self.mean
        self.register_buffer("mean", mean)
        self.register_buffer("istd", istd)

    def forward(self, x: torch.Tensor):
        """
        Args:
            x (torch.Tensor): (batch, max_len, feat_dim)
        Returns:
            (torch.Tensor): normalized feature
        """
        x = x - self.mean
        if self.norm_var:
            x = x * self.istd
        return x


def load_cmvn_json(json_cmvn_file):
    with open(json_cmvn_file) as f:
        cmvn_json = json.load(f)

    avg = cmvn_json["mean_stat"]
    var = cmvn_json["var_stat"]
    count = cmvn_json["frame_num"]
    for i in range(len(avg)):
        avg[i] /= count
        var[i] = var[i] / count - avg[i] * avg[i]
        if var[i] < 1.0e-20:
            var[i] = 1.0e-20
        var[i] = 1.0 / math.sqrt(var[i])
    cmvn = np.array([avg, var])
    return cmvn


def load_cmvn_kaldi(kaldi_cmvn_file):
    avg = []
    var = []
    with open(kaldi_cmvn_file, "r") as file:
        # kaldi binary file start with '\0B'
        if file.read(2) == "\0B":
            logging.error("kaldi cmvn binary file is not supported, please ")
            sys.exit(1)
        file.seek(0)
        arr = file.read().split()
        assert arr[0] == "["
        assert arr[-2] == "0"
        assert arr[-1] == "]"
        feat_dim = int((len(arr) - 2 - 2) / 2)
        for i in range(1, feat_dim + 1):
            avg.append(float(arr[i]))
        count = float(arr[feat_dim + 1])
        for i in range(feat_dim + 2, 2 * feat_dim + 2):
            var.append(float(arr[i]))

    for i in range(len(avg)):
        avg[i] /= count
        var[i] = var[i] / count - avg[i] * avg[i]
        if var[i] < 1.0e-20:
            var[i] = 1.0e-20
        var[i] = 1.0 / math.sqrt(var[i])
    cmvn = np.array([avg, var])
    return cmvn


def load_cmvn(filename, is_json):
    if is_json:
        file = load_cmvn_json(filename)
    else:
        file = load_cmvn_kaldi(filename)
    return file[0], file[1]
```
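To make the statistics format concrete, here is a sketch of the JSON path through `load_cmvn`: `mean_stat` and `var_stat` are per-dimension sums of x and x² over `frame_num` frames, and the function returns the per-dimension mean and the inverse standard deviation 1/sqrt(var). The numbers below are fabricated for illustration.

```python
import json
import tempfile

import torch

# Two feature dims over 4 frames (made-up sums).
stats = {"mean_stat": [4.0, 8.0], "var_stat": [8.0, 24.0], "frame_num": 4}
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(stats, f)
    path = f.name

mean, istd = load_cmvn(path, is_json=True)
# dim 0: mean = 1.0, var = 8/4 - 1 = 1.0, istd = 1.0
# dim 1: mean = 2.0, var = 24/4 - 4 = 2.0, istd = 1/sqrt(2) ~ 0.707
print(mean, istd)

cmvn = GlobalCMVN(torch.from_numpy(mean).float(), torch.from_numpy(istd).float())
x = torch.randn(1, 10, 2)  # (batch, frames, feat_dim)
print(cmvn(x).shape)       # same shape, normalized values
```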
VITA/model/multimodal_encoder/whale/init_model.py (new file, mode 100644)
```python
# Copyright (c) 2022 Binbin Zhang (binbzha@qq.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Tuple

import torch
from torch import nn
import torchaudio
import torchaudio.compliance.kaldi as kaldi

from .adapter import CNNAdapter, CNNSubsampling, LinearAdapter
from .cmvn import GlobalCMVN, load_cmvn
from .module.encoder.encoder import whaleEncoder


class audioEncoderProcessor:
    def __init__(
        self,
        dataset_conf: dict = None,
    ):
        self.dataset_conf = dataset_conf

    def process(self, wav_path):
        try:
            print("#################", wav_path)
            waveform, sample_rate = torchaudio.load(wav_path)
        except Exception as e:
            print(f"cannot open {wav_path}!!!!!!!!!!!!!!!!")
        if sample_rate != self.dataset_conf["resample_conf"]["resample_rate"]:
            # sample_rate = self.dataset_conf['resample_conf']['resample_rate']
            waveform = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=self.dataset_conf["resample_conf"]["resample_rate"]
            )(waveform)

        waveform = waveform * (1 << 15)
        # Only keep key, feat, label
        mat = kaldi.fbank(
            waveform,
            num_mel_bins=self.dataset_conf["fbank_conf"]["num_mel_bins"],
            frame_length=self.dataset_conf["fbank_conf"]["frame_length"],
            frame_shift=self.dataset_conf["fbank_conf"]["frame_shift"],
            dither=self.dataset_conf["fbank_conf"]["dither"],
            energy_floor=0.0,
            sample_frequency=sample_rate,
        )
        attn_mask = torch.ones(mat.shape[0])
        attn_mask = attn_mask[2::2][2::2][0::2]
        return mat, attn_mask.shape[0]


class audioEncoder(torch.nn.Module):
    def __init__(
        self,
        encoder: torch.nn.Module,
        llm_path: str,
        freeze_llm: bool = True,
        enc_out_dim: int = 512,
        llm_embed_dim: int = 4096,
        kernel_size: int = 3,
        IGNORE_ID: int = -100,
        adpter_type: str = "cnn",
        add_audio_bos_eos: bool = False,
        task_num: int = 10,
        task_before_audio: bool = False,
        task_type: str = "prompt",
        freeze_encoder: bool = False,
        freeze_adpter: bool = False,
        activation_func: str = "relu",
        norm: str = "batch",
        chat_template=None,
    ):
        super().__init__()
        self.encoder = encoder

        self.enc_out_dim = enc_out_dim
        self.llm_embed_dim = llm_embed_dim
        self.IGNORE_ID = IGNORE_ID
        self.add_audio_bos_eos = add_audio_bos_eos
        self.task_before_audio = task_before_audio
        self.task_type = task_type
        self.freeze_encoder = freeze_encoder
        self.freeze_adpter = freeze_adpter

        if adpter_type == "cnn":
            self.adpter = CNNAdapter(enc_out_dim, llm_embed_dim, kernel_size)
        elif adpter_type == "linear":
            self.adpter = LinearAdapter(enc_out_dim, llm_embed_dim)
        elif adpter_type == "subsampling":
            self.adpter = CNNSubsampling(
                enc_out_dim, llm_embed_dim, kernel_size, activation_func, norm
            )

        if self.freeze_encoder:
            self.encoder.eval()
            for (name, param) in self.encoder.named_parameters():
                param.requires_grad = False
        if self.freeze_adpter:
            self.adpter.eval()
            for (name, param) in self.adpter.named_parameters():
                param.requires_grad = False

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ) -> Dict[str, Optional[torch.Tensor]]:
        speech = speech.to(next(self.parameters()).dtype)
        # 1. Encoder
        encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
        inputs_embeds, encoder_mask = self.adpter(encoder_out, encoder_mask)  # B, T, D
        attention_mask = encoder_mask.squeeze(1)  # B, T

        assert inputs_embeds.size(1) == attention_mask.size(1)

        # audio bos/eos
        if self.add_audio_bos_eos:
            inputs_embeds, attention_mask, target = self._add_bos_eos(
                "audio", "/audio", inputs_embeds, attention_mask, None
            )

        outputs = {
            "inputs_embeds": inputs_embeds,
            "attention_mask": attention_mask,
        }

        return outputs

    def _add_bos_eos(self, bos, eos, inputs_embeds, attention_mask, target=None):
        B = len(inputs_embeds)
        bos_embed = self.task_embeddings(
            torch.full([B, 1], self.task_ids[bos]).to(inputs_embeds.device)
        )  # B, 1, D
        eos_embed = self.task_embeddings(
            torch.full([B, 1], self.task_ids[eos]).to(inputs_embeds.device)
        )  # B, 1, D
        bos_eos_target = torch.full([B, 2], self.IGNORE_ID).to(inputs_embeds.device)  # B, 2
        bos_eos_mask = torch.full([B, 1], True).to(inputs_embeds.device)  # B, 1

        inputs_embeds = torch.cat((bos_embed, inputs_embeds), 1)  # B, (1+T), D
        inputs_embeds = torch.cat((inputs_embeds, eos_embed), 1)  # B, (1+T+1), D
        attention_mask = torch.cat((bos_eos_mask, attention_mask), 1)  # B, (1+T)
        attention_mask = torch.cat((attention_mask, bos_eos_mask), 1)  # B, (1+T+1)
        if target is not None:
            target = torch.cat((target, bos_eos_target), 1)  # B, (T+2), D

        return inputs_embeds, attention_mask, target


def init_model(configs):
    if configs["cmvn_file"] is not None:
        mean, istd = load_cmvn(configs["cmvn_file"], configs["is_json_cmvn"])
        global_cmvn = GlobalCMVN(torch.from_numpy(mean).float(), torch.from_numpy(istd).float())
    else:
        global_cmvn = None

    input_dim = configs["input_dim"]

    encoder = whaleEncoder(input_dim, global_cmvn=global_cmvn, **configs["encoder_conf"])
    model = audioEncoder(encoder=encoder, **configs["model_conf"])

    processor = audioEncoderProcessor(dataset_conf=configs["dataset_conf"])
    model.audio_processor = processor

    return model
```
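Reading `init_model` together with `audioEncoderProcessor` and `whaleEncoder`, the expected `configs` dictionary has roughly the shape below. The keys are taken from the lookups in this file; the concrete values are illustrative guesses, since the real ones live in the model's config file.

```python
configs = {
    "cmvn_file": "path/to/cmvn.json",  # or None to skip GlobalCMVN
    "is_json_cmvn": True,
    "input_dim": 80,  # feature dim fed to whaleEncoder
    "encoder_conf": {
        # forwarded as whaleEncoder(input_dim, global_cmvn=..., **encoder_conf)
        "overview_conf": {"encoder-layer-config": "subsampling-transformer"},
        "para_conf": {"subsampling": {}, "transformer": {}},
    },
    "model_conf": {
        # forwarded as audioEncoder(encoder=..., **model_conf)
        "llm_path": "path/to/llm",
        "enc_out_dim": 512,
        "llm_embed_dim": 4096,
        "adpter_type": "subsampling",
    },
    "dataset_conf": {
        "resample_conf": {"resample_rate": 16000},
        "fbank_conf": {
            "num_mel_bins": 80,
            "frame_length": 25,
            "frame_shift": 10,
            "dither": 0.0,
        },
    },
}

# model = init_model(configs)
# mat, num_frames = model.audio_processor.process("sample.wav")
```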
VITA/model/multimodal_encoder/whale/module/component/__pycache__/mamba.cpython-310.pyc (new file, mode 100644)

File added (binary).

VITA/model/multimodal_encoder/whale/module/component/__pycache__/subsampling.cpython-310.pyc (new file, mode 100644)

File added (binary).

VITA/model/multimodal_encoder/whale/module/component/__pycache__/transformer.cpython-310.pyc (new file, mode 100644)

File added (binary).
VITA/model/multimodal_encoder/whale/module/component/mamba.py (new file, mode 100644)
"""Encoder self-attention layer definition."""
import
math
import
pdb
from
functools
import
partial
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
vita.model.multimodal_encoder.whale.utils
import
IGNORE_ID
,
strtobool
try
:
from
mamba_ssm.modules.mamba_simple
import
Mamba
,
Block
from
mamba_ssm.models.mixer_seq_simple
import
_init_weights
from
mamba_ssm.ops.triton.layernorm
import
RMSNorm
except
ImportError
:
print
(
"Please install mamba_ssm to use MambaSSM component."
)
class
MambaBlock
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
n_layer
=
1
,
d_state
=
16
,
d_conv
=
4
,
expand
=
4
,
bidirectional
=
False
):
super
(
MambaBlock
,
self
).
__init__
()
self
.
forward_blocks
=
nn
.
ModuleList
([])
self
.
forward_norm_f
=
RMSNorm
(
in_channels
,
eps
=
1e-5
)
for
i
in
range
(
n_layer
):
self
.
forward_blocks
.
append
(
Block
(
in_channels
,
mixer_cls
=
partial
(
Mamba
,
layer_idx
=
i
,
d_state
=
d_state
,
d_conv
=
d_conv
,
expand
=
expand
),
norm_cls
=
partial
(
RMSNorm
,
eps
=
1e-5
),
fused_add_norm
=
True
,
residual_in_fp32
=
True
,
)
)
if
bidirectional
:
self
.
backward_blocks
=
nn
.
ModuleList
([])
for
i
in
range
(
n_layer
):
self
.
backward_blocks
.
append
(
Block
(
in_channels
,
mixer_cls
=
partial
(
Mamba
,
layer_idx
=
i
,
d_state
=
d_state
,
d_conv
=
d_conv
,
expand
=
expand
),
norm_cls
=
partial
(
RMSNorm
,
eps
=
1e-5
),
fused_add_norm
=
True
,
residual_in_fp32
=
True
,
)
)
self
.
backward_norm_f
=
RMSNorm
(
in_channels
,
eps
=
1e-5
)
else
:
self
.
backward_blocks
=
None
self
.
apply
(
partial
(
_init_weights
,
n_layer
=
n_layer
))
def
forward
(
self
,
input
):
for_residual
=
None
forward_f
=
input
.
clone
()
for
block
in
self
.
forward_blocks
:
forward_f
,
for_residual
=
block
(
forward_f
,
for_residual
,
inference_params
=
None
)
residual
=
(
forward_f
+
for_residual
)
if
for_residual
is
not
None
else
forward_f
residual
=
self
.
forward_norm_f
(
residual
)
if
self
.
backward_blocks
is
not
None
:
back_residual
=
None
backward_f
=
torch
.
flip
(
input
,
[
1
])
for
block
in
self
.
backward_blocks
:
backward_f
,
back_residual
=
block
(
backward_f
,
back_residual
,
inference_params
=
None
)
back_residual
=
(
(
backward_f
+
back_residual
)
if
back_residual
is
not
None
else
backward_f
)
back_residual
=
torch
.
flip
(
back_residual
,
[
1
])
back_residual
=
self
.
backward_norm_f
(
back_residual
)
residual
=
torch
.
cat
([
residual
,
back_residual
],
-
1
)
return
residual
class
MambaSSM
(
torch
.
nn
.
Module
):
@
staticmethod
def
add_arguments
(
group
):
"""Add TDNN common arguments."""
group
.
add_argument
(
"--mamba-num-layers"
,
default
=
4
,
type
=
int
,
help
=
"Output dim of MambaSSM."
)
group
.
add_argument
(
"--mamba-input-dim"
,
default
=
256
,
type
=
int
,
help
=
"Input dim of MambaSSM."
)
group
.
add_argument
(
"--mamba-output-dim"
,
default
=
256
,
type
=
int
,
help
=
"Output dim of MambaSSM."
)
group
.
add_argument
(
"--mamba-d-state"
,
default
=
16
,
type
=
int
,
help
=
"d-state of MambaSSM."
)
group
.
add_argument
(
"--mamba-d-conv"
,
default
=
4
,
type
=
int
,
help
=
"d-conv of MambaSSM."
)
group
.
add_argument
(
"--mamba-expand"
,
default
=
4
,
type
=
int
,
help
=
"expand of MambaSSM."
)
return
group
def
__init__
(
self
,
args
):
"""Construct an Encoder object."""
super
(
MambaSSM
,
self
).
__init__
()
self
.
mamb_num_layers
=
args
.
mamba_num_layers
self
.
mamba_input_dim
=
args
.
mamba_input_dim
self
.
mamba_output_dim
=
args
.
mamba_output_dim
self
.
mamba_d_state
=
args
.
mamba_d_state
self
.
mamba_d_conv
=
args
.
mamba_d_conv
self
.
mamba_expand
=
args
.
mamba_expand
self
.
mamba
=
MambaBlock
(
self
.
mamba_input_dim
,
self
.
mamb_num_layers
,
self
.
mamba_d_state
,
self
.
mamba_d_conv
,
self
.
mamba_expand
,
)
@
torch
.
jit
.
unused
def
forward
(
self
,
xs
,
ilens
=
None
,
masks
=
None
):
"""Embed positions in tensor.
:param torch.Tensor xs: input tensor
:param torch.Tensor masks: input mask
:return: position embedded tensor and mask
:rtype Tuple[torch.Tensor, torch.Tensor]:
"""
xs_out
=
self
.
mamba
(
xs
)
return
xs_out
.
to
(
xs
.
dtype
),
ilens
,
masks
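A minimal shape sketch for `MambaBlock`, assuming `mamba_ssm` is installed and a CUDA device is available (its fused kernels are GPU-only):

```python
import torch

block = MambaBlock(in_channels=256, n_layer=2, d_state=16, d_conv=4, expand=4).cuda()

x = torch.randn(2, 50, 256, device="cuda")  # (B, T, C)
y = block(x)
print(y.shape)  # torch.Size([2, 50, 256]); the last dim doubles with bidirectional=True
```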
VITA/model/multimodal_encoder/whale/module/component/subsampling.py (new file, mode 100644)
```python
import torch
from typing import Tuple, Union


class BaseSubsampling(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.subsampling_rate = 1
        self.right_context = 0

    def position_encoding(self, offset: Union[int, torch.Tensor], size: int) -> torch.Tensor:
        return self.pos_enc.position_encoding(offset, size)


class Conv2dSubsampling4(BaseSubsampling):
    """Convolutional 2D subsampling (to 1/4 length).
    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
    """

    def __init__(self, idim: int, odim: int, dropout_rate: float):
        """Construct a Conv2dSubsampling4 object."""
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
        )
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
        )
        self.right_context = 6
        self.subsampling_rate = 4

    def forward(
        self, x: torch.Tensor, x_mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        x = x.unsqueeze(1)  # (b, c=1, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        return x, x_mask[:, :, 2::2][:, :, 2::2]


class Subsampling(torch.nn.Module):
    @staticmethod
    def add_arguments(group):
        """Add Subsampling common arguments."""
        group.add_argument("--subsampling-rate", default=4, type=int)
        group.add_argument("--subsampling-input-dim", default=256, type=int)
        group.add_argument("--subsampling-output-dim", default=256, type=int)
        group.add_argument("--subsampling-dropout-rate", default=0.1, type=float)
        return group

    def __init__(self, args):
        super().__init__()
        self.subsampling_rate = args.subsampling_rate
        self.subsampling_input_dim = args.subsampling_input_dim
        self.subsampling_output_dim = args.subsampling_output_dim
        self.subsampling_dropout_rate = args.subsampling_dropout_rate

        if self.subsampling_rate == 4:
            self.core = Conv2dSubsampling4(
                self.subsampling_input_dim,
                self.subsampling_output_dim,
                self.subsampling_dropout_rate,
            )

    def forward(self, xs, ilens, masks):
        xs, masks = self.core(xs, masks)
        ilens = masks.squeeze(1).sum(1)
        return xs, ilens, masks
```
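`Conv2dSubsampling4` applies two unpadded stride-2 3x3 convolutions, so the time and frequency axes each shrink as n -> (n - 1) // 2 per stage; 100 input frames come out as ((100 - 1) // 2 - 1) // 2 = 24. A quick check with dummy sizes chosen for illustration:

```python
import torch

sub = Conv2dSubsampling4(idim=80, odim=256, dropout_rate=0.1)

x = torch.randn(2, 100, 80)                     # (B, T, idim) fbank frames
mask = torch.ones(2, 1, 100, dtype=torch.bool)  # (B, 1, T)

y, y_mask = sub(x, mask)
print(y.shape)       # torch.Size([2, 24, 256])
print(y_mask.shape)  # torch.Size([2, 1, 24]): mask sliced twice with [:, :, 2::2]
```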
VITA/model/multimodal_encoder/whale/module/component/transformer.py (new file, mode 100644)
"""Encoder self-attention layer definition."""
import
math
import
pdb
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
vita.model.multimodal_encoder.whale.module.layer.attention
import
(
Conv1dLinear
,
MultiHeadedAttention
,
MultiLayeredConv1d
,
PositionalEncoding
,
PositionwiseFeedForward
,
RelPositionalEncoding
,
)
# from vita.model.multimodal_encoder.whale.module.component.utils import *
from
vita.model.multimodal_encoder.whale.utils
import
IGNORE_ID
,
add_optional_chunk_mask
,
strtobool
def
repeat
(
N
,
fn
):
"""Repeat module N times.
:param int N: repeat time
:param function fn: function to generate module
:return: repeated modules
:rtype: MultiSequential
"""
return
MultiSequential
(
*
[
fn
(
n
)
for
n
in
range
(
N
)])
class
MultiSequential
(
torch
.
nn
.
Sequential
):
"""Multi-input multi-output torch.nn.Sequential."""
def
forward
(
self
,
x
,
masks
,
pos_emb
):
"""Repeat."""
for
m
in
self
:
x
,
masks
,
pos_emb
=
m
(
x
,
masks
,
pos_emb
)
return
x
,
masks
,
pos_emb
@
torch
.
jit
.
export
def
infer
(
self
,
x
,
pos_emb
,
buffer
,
buffer_index
,
buffer_out
):
# type: (Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]
"""Repeat."""
for
m
in
self
:
x
,
pos_emb
,
buffer
,
buffer_index
,
buffer_out
=
m
.
infer
(
x
,
pos_emb
,
buffer
,
buffer_index
,
buffer_out
)
return
x
,
pos_emb
,
buffer
,
buffer_index
,
buffer_out
@
torch
.
jit
.
export
def
infer_hidden
(
self
,
x
,
pos_emb
,
buffer
,
buffer_index
,
buffer_out
,
hidden_out
):
# type: (Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]
"""Repeat."""
for
m
in
self
:
x
,
pos_emb
,
buffer
,
buffer_index
,
buffer_out
=
m
.
infer
(
x
,
pos_emb
,
buffer
,
buffer_index
,
buffer_out
)
hidden_out
.
append
(
x
)
return
x
,
pos_emb
,
buffer
,
buffer_index
,
buffer_out
,
hidden_out
class
TransformerLayer
(
nn
.
Module
):
"""Transformer layer module.
:param int size: input dim
:param self_attn: self attention module
:param feed_forward: feed forward module
:param float dropout_rate: dropout rate
:param bool normalize_before: whether to use layer_norm before the first block
:param bool concat_after: whether to concat attention layer's input and output
if True, additional linear will be applied. i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied. i.e. x -> x + att(x)
"""
def
__init__
(
self
,
size
,
self_attn
,
feed_forward
,
dropout_rate
,
normalize_before
=
True
,
concat_after
=
False
):
"""Construct an TransformerLayer object."""
super
(
TransformerLayer
,
self
).
__init__
()
self
.
self_attn
=
self_attn
self
.
feed_forward
=
feed_forward
self
.
norm1
=
torch
.
nn
.
LayerNorm
(
size
)
self
.
norm2
=
torch
.
nn
.
LayerNorm
(
size
)
self
.
dropout
=
nn
.
Dropout
(
dropout_rate
)
self
.
size
=
size
self
.
normalize_before
=
normalize_before
self
.
concat_after
=
concat_after
if
self
.
concat_after
:
self
.
concat_linear
=
nn
.
Linear
(
size
+
size
,
size
)
else
:
self
.
concat_linear
=
nn
.
Identity
()
@
torch
.
jit
.
unused
def
forward
(
self
,
x
,
mask
,
pos_emb
):
"""Compute encoded features.
:param torch.Tensor x: encoded source features (batch, max_time_in, size)
:param torch.Tensor mask: mask for x (batch, max_time_in)
:rtype: Tuple[torch.Tensor, torch.Tensor]
"""
residual
=
x
if
self
.
normalize_before
:
x
=
self
.
norm1
(
x
)
if
self
.
concat_after
:
x_concat
=
torch
.
cat
((
x
,
self
.
self_attn
(
x
,
x
,
x
,
mask
,
pos_emb
)),
dim
=-
1
)
x
=
residual
+
self
.
concat_linear
(
x_concat
)
else
:
x
=
residual
+
self
.
dropout
(
self
.
self_attn
(
x
,
x
,
x
,
mask
,
pos_emb
))
if
not
self
.
normalize_before
:
x
=
self
.
norm1
(
x
)
residual
=
x
if
self
.
normalize_before
:
x
=
self
.
norm2
(
x
)
x
=
residual
+
self
.
dropout
(
self
.
feed_forward
(
x
))
if
not
self
.
normalize_before
:
x
=
self
.
norm2
(
x
)
return
x
,
mask
,
pos_emb
@
torch
.
jit
.
export
def
infer
(
self
,
x
,
pos_emb
,
buffer
,
buffer_index
,
buffer_out
):
# type: (Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]
residual
=
x
.
clone
()
if
self
.
normalize_before
:
x
=
self
.
norm1
(
x
)
if
self
.
concat_after
:
x_att
,
buffer
,
buffer_index
,
buffer_out
=
self
.
self_attn
.
infer
(
x
,
x
,
x
,
pos_emb
,
buffer
,
buffer_index
,
buffer_out
)
x_concat
=
torch
.
cat
((
x
,
x_att
),
dim
=-
1
)
x
=
residual
+
self
.
concat_linear
(
x_concat
)
else
:
x_att
,
buffer
,
buffer_index
,
buffer_out
=
self
.
self_attn
.
infer
(
x
,
x
,
x
,
pos_emb
,
buffer
,
buffer_index
,
buffer_out
)
x
=
residual
+
x_att
if
not
self
.
normalize_before
:
x
=
self
.
norm1
(
x
)
residual
=
x
.
clone
()
if
self
.
normalize_before
:
x
=
self
.
norm2
(
x
)
x_feed
,
buffer
,
buffer_index
,
buffer_out
=
self
.
feed_forward
.
infer
(
x
,
buffer
,
buffer_index
,
buffer_out
)
x
=
residual
+
x_feed
if
not
self
.
normalize_before
:
x
=
self
.
norm2
(
x
)
return
x
,
pos_emb
,
buffer
,
buffer_index
,
buffer_out
class
Transformer
(
torch
.
nn
.
Module
):
@
staticmethod
def
add_arguments
(
group
):
"""Add TDNN common arguments."""
group
.
add_argument
(
"--transformer-input-dim"
,
default
=
256
,
type
=
int
,
help
=
"Input dim of Transformer."
)
group
.
add_argument
(
"--transformer-output-dim"
,
default
=
4
,
type
=
int
,
help
=
"Output dim of Transformer."
)
group
.
add_argument
(
"--transformer-attention-dim"
,
default
=
256
,
type
=
int
,
help
=
"Dimention of attention."
)
group
.
add_argument
(
"--transformer-attention-heads"
,
default
=
4
,
type
=
int
,
help
=
"The number of heads of multi head attention."
,
)
group
.
add_argument
(
"--transformer-linear-units"
,
default
=
1024
,
type
=
int
,
help
=
"The number of units of position-wise feed forward."
,
)
group
.
add_argument
(
"--transformer-num-blocks"
,
default
=
6
,
type
=
int
,
help
=
"The number of attention blocks."
)
group
.
add_argument
(
"--transformer-dropout-rate"
,
default
=
0.1
,
type
=
float
,
help
=
"Dropout rate in Transformer."
,
)
group
.
add_argument
(
"--transformer-attention-dropout-rate"
,
default
=
0.0
,
type
=
float
,
help
=
"Dropout rate in attention."
,
)
group
.
add_argument
(
"--transformer-positional-dropout-rate"
,
default
=
0.1
,
type
=
float
,
help
=
"Dropout rate after adding positional encoding."
,
)
group
.
add_argument
(
"--transformer-input-layer"
,
default
=
"linear"
,
type
=
str
,
help
=
"Type of input layer"
)
group
.
add_argument
(
"--transformer-pos-enc-class"
,
default
=
"abs-enc"
,
type
=
str
,
help
=
""
)
group
.
add_argument
(
"--transformer-normalize-before"
,
default
=
True
,
type
=
strtobool
,
help
=
"Whether to use layer-norm before the first block."
,
)
group
.
add_argument
(
"--transformer-concat-after"
,
default
=
False
,
type
=
strtobool
,
help
=
"Whether to concat attention layer's input and output."
,
)
group
.
add_argument
(
"--transformer-positionwise-layer-type"
,
default
=
"linear"
,
type
=
str
,
help
=
"Linear of conv1d."
,
)
group
.
add_argument
(
"--transformer-positionwise-conv-kernel_size"
,
default
=
1
,
type
=
int
,
help
=
"Kernel size of positionwise conv1d layer."
,
)
group
.
add_argument
(
"--transformer-chunk_size"
,
default
=-
1
,
type
=
int
,
help
=
""
)
group
.
add_argument
(
"--transformer-left_chunks"
,
default
=-
1
,
type
=
int
,
help
=
""
)
group
.
add_argument
(
"--transformer-dynamic-chunks"
,
default
=
True
,
type
=
strtobool
,
help
=
""
)
return
group
def
__init__
(
self
,
args
,
input_dim
=
None
,
output_dim
=
None
,
attention_dim
=
None
,
attention_heads
=
None
,
linear_units
=
None
,
num_blocks
=
None
,
dropout_rate
=
None
,
positional_dropout_rate
=
None
,
attention_dropout_rate
=
None
,
input_layer
=
None
,
pos_enc_class
=
None
,
normalize_before
=
None
,
concat_after
=
None
,
positionwise_layer_type
=
None
,
positionwise_conv_kernel_size
=
None
,
chunk_size
=
None
,
left_chunks
=
None
,
):
"""Construct an Encoder object."""
super
(
Transformer
,
self
).
__init__
()
if
args
is
None
:
self
.
input_dim
=
input_dim
self
.
output_dim
=
output_dim
self
.
attention_dim
=
attention_dim
self
.
attention_heads
=
attention_heads
self
.
linear_units
=
linear_units
self
.
num_blocks
=
num_blocks
self
.
dropout_rate
=
dropout_rate
self
.
positional_dropout_rate
=
positional_dropout_rate
self
.
attention_dropout_rate
=
attention_dropout_rate
self
.
input_layer
=
input_layer
self
.
pos_enc_class
=
pos_enc_class
self
.
normalize_before
=
normalize_before
self
.
concat_after
=
concat_after
self
.
positionwise_layer_type
=
positionwise_layer_type
self
.
positionwise_conv_kernel_size
=
positionwise_conv_kernel_size
self
.
chunk_size
=
chunk_size
self
.
left_chunks
=
left_chunks
else
:
self
.
input_dim
=
args
.
transformer_input_dim
self
.
output_dim
=
args
.
transformer_output_dim
self
.
attention_dim
=
args
.
transformer_attention_dim
self
.
attention_heads
=
args
.
transformer_attention_heads
self
.
linear_units
=
args
.
transformer_linear_units
self
.
num_blocks
=
args
.
transformer_num_blocks
self
.
dropout_rate
=
args
.
transformer_dropout_rate
self
.
positional_dropout_rate
=
args
.
transformer_positional_dropout_rate
self
.
attention_dropout_rate
=
args
.
transformer_attention_dropout_rate
self
.
input_layer
=
args
.
transformer_input_layer
self
.
pos_enc_class
=
args
.
transformer_pos_enc_class
self
.
normalize_before
=
args
.
transformer_normalize_before
self
.
concat_after
=
args
.
transformer_concat_after
self
.
positionwise_layer_type
=
args
.
transformer_positionwise_layer_type
self
.
positionwise_conv_kernel_size
=
args
.
transformer_positionwise_conv_kernel_size
self
.
chunk_size
=
args
.
transformer_chunk_size
self
.
left_chunks
=
args
.
transformer_left_chunks
self
.
transformer_dynamic_chunks
=
args
.
transformer_dynamic_chunks
if
self
.
pos_enc_class
==
"abs-enc"
:
pos_enc_args
=
(
self
.
attention_dim
,
self
.
positional_dropout_rate
)
pos_enc_class
=
PositionalEncoding
elif
self
.
pos_enc_class
==
"rel-enc"
:
pos_enc_args
=
(
self
.
attention_dim
,
self
.
positional_dropout_rate
,
self
.
chunk_size
,
self
.
left_chunks
,
)
pos_enc_class
=
RelPositionalEncoding
if
self
.
input_layer
==
"linear"
:
self
.
embed
=
torch
.
nn
.
Sequential
(
torch
.
nn
.
Linear
(
self
.
input_dim
,
self
.
attention_dim
),
torch
.
nn
.
LayerNorm
(
self
.
attention_dim
),
torch
.
nn
.
Dropout
(
self
.
dropout_rate
),
torch
.
nn
.
ReLU
(),
)
elif
self
.
input_layer
==
"embed"
:
self
.
embed
=
torch
.
nn
.
Sequential
(
torch
.
nn
.
Embedding
(
self
.
input_dim
,
self
.
attention_dim
,
padding_idx
=
IGNORE_ID
)
)
elif
self
.
input_layer
==
"none"
:
self
.
embed
=
torch
.
nn
.
Sequential
(
torch
.
nn
.
Identity
())
else
:
raise
ValueError
(
"unknown input_layer: "
+
self
.
input_layer
)
self
.
pe
=
pos_enc_class
(
*
pos_enc_args
)
self
.
embed_layer_num
=
len
(
self
.
embed
)
if
self
.
positionwise_layer_type
==
"linear"
:
positionwise_layer
=
PositionwiseFeedForward
positionwise_layer_args
=
(
self
.
attention_dim
,
self
.
linear_units
,
self
.
dropout_rate
)
elif
self
.
positionwise_layer_type
==
"conv1d"
:
positionwise_layer
=
MultiLayeredConv1d
positionwise_layer_args
=
(
self
.
attention_dim
,
self
.
linear_units
,
self
.
positionwise_conv_kernel_size
,
self
.
dropout_rate
,
)
elif
self
.
positionwise_layer_type
==
"conv1d-linear"
:
positionwise_layer
=
Conv1dLinear
positionwise_layer_args
=
(
self
.
attention_dim
,
self
.
linear_units
,
self
.
positionwise_conv_kernel_size
,
self
.
dropout_rate
,
)
else
:
raise
NotImplementedError
(
"Support only linear or conv1d."
)
self
.
encoders
=
repeat
(
self
.
num_blocks
,
lambda
lnum
:
TransformerLayer
(
self
.
attention_dim
,
MultiHeadedAttention
(
self
.
attention_heads
,
self
.
attention_dim
,
self
.
attention_dropout_rate
,
self
.
chunk_size
,
self
.
left_chunks
,
self
.
pos_enc_class
,
),
positionwise_layer
(
*
positionwise_layer_args
),
self
.
dropout_rate
,
self
.
normalize_before
,
self
.
concat_after
,
),
)
if
self
.
normalize_before
:
self
.
after_norm
=
torch
.
nn
.
LayerNorm
(
self
.
attention_dim
)
@
torch
.
jit
.
unused
def
forward
(
self
,
xs
,
ilens
=
None
,
masks
=
None
):
"""Embed positions in tensor.
:param torch.Tensor xs: input tensor
:param torch.Tensor masks: input mask
:return: position embedded tensor and mask
:rtype Tuple[torch.Tensor, torch.Tensor]:
"""
if
self
.
transformer_dynamic_chunks
==
True
:
# and self.training:
chunk_masks
=
add_optional_chunk_mask
(
xs
,
masks
,
True
,
True
,
0
,
0
,
-
1
)
else
:
chunk_masks
=
add_optional_chunk_mask
(
xs
,
masks
,
False
,
False
,
self
.
chunk_size
,
self
.
chunk_size
,
self
.
left_chunks
).
to
(
xs
.
device
)
xs
=
self
.
embed
(
xs
)
xs
,
pos_emb
=
self
.
pe
(
xs
)
xs
,
chunk_masks
,
pos_emb
=
self
.
encoders
(
xs
,
chunk_masks
,
pos_emb
)
if
self
.
normalize_before
:
xs
=
self
.
after_norm
(
xs
)
return
xs
,
ilens
,
masks
@
torch
.
jit
.
export
def
infer
(
self
,
xs
,
buffer
,
buffer_index
,
buffer_out
):
xs
=
self
.
embed
(
xs
)
# pe_index = buffer[buffer_index: buffer_index + 1].reshape([1]).to(torch.int64)
# xs, pos_emb, pe_index[0] = self.pe.infer(xs, pe_index[0])
# buffer_out.append(pe_index.reshape(-1).to(torch.float32))
# buffer_index = buffer_index + 1
xs
,
pos_emb
,
_
=
self
.
pe
.
infer
(
xs
,
0
)
xs
,
pos_emb
,
buffer
,
buffer_index
,
buffer_out
=
self
.
encoders
.
infer
(
xs
,
pos_emb
,
buffer
,
buffer_index
,
buffer_out
)
if
self
.
normalize_before
:
xs
=
self
.
after_norm
(
xs
)
return
xs
,
buffer
,
buffer_index
,
buffer_out
@
torch
.
jit
.
export
def
infer_hidden
(
self
,
xs
,
buffer
,
buffer_index
,
buffer_out
,
hidden_out
):
xs
=
self
.
embed
(
xs
)
# pe_index = buffer[buffer_index: buffer_index + 1].reshape([1]).to(torch.int64)
# xs, pos_emb, pe_index[0] = self.pe.infer(xs, pe_index[0])
# buffer_out.append(pe_index.reshape(-1).to(torch.float32))
# buffer_index = buffer_index + 1
xs
,
pos_emb
,
_
=
self
.
pe
.
infer
(
xs
,
0
)
xs
,
pos_emb
,
buffer
,
buffer_index
,
buffer_out
,
hidden_out
=
self
.
encoders
.
infer_hidden
(
xs
,
pos_emb
,
buffer
,
buffer_index
,
buffer_out
,
hidden_out
)
if
self
.
normalize_before
:
xs
=
self
.
after_norm
(
xs
)
return
xs
,
buffer
,
buffer_index
,
buffer_out
,
hidden_out
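Because `Transformer` pulls its hyperparameters from an argparse namespace, it can be constructed standalone from the declared defaults, as sketched below (this assumes the `vita` package is importable so the layer and mask utilities resolve):

```python
import argparse

import torch

parser = argparse.ArgumentParser()
Transformer.add_arguments(parser)
args, _ = parser.parse_known_args([])  # take the declared defaults

# Defaults: 256-dim attention, 4 heads, 6 blocks, "linear" input layer, dynamic chunks.
encoder = Transformer(args)

xs = torch.randn(2, 50, 256)                    # (B, T, input_dim)
masks = torch.ones(2, 1, 50, dtype=torch.bool)  # (B, 1, T)
ys, ilens, masks = encoder(xs, None, masks)
print(ys.shape)  # torch.Size([2, 50, 256])
```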
VITA/model/multimodal_encoder/whale/module/encoder/__pycache__/encoder.cpython-310.pyc (new file, mode 100644)

File added (binary).
VITA/model/multimodal_encoder/whale/module/encoder/encoder.py (new file, mode 100644)
```python
import argparse
import logging
import sys
import time
from typing import Dict, Optional, Tuple

import numpy as np
import six
import torch

from vita.model.multimodal_encoder.whale.module.component.mamba import MambaSSM
from vita.model.multimodal_encoder.whale.module.component.subsampling import Subsampling
from vita.model.multimodal_encoder.whale.module.component.transformer import Transformer
from vita.model.multimodal_encoder.whale.utils import make_pad_mask


def add_encoder_args(group):
    """Add Encoder common arguments."""
    group.add_argument(
        "--encoder-layer-config",
        type=str,
        default="tdnn-dtc",
        help="Layer config of encoder. Format layername-layername-..., default(conv1d-fsmn-rnn)",
    )
    group.add_argument(
        "--encoder-input-dim",
        type=int,
        default=256,
        help="Input dim of encoder. Must equal to the input dim of the first Component (default=40)",
    )
    group.add_argument(
        "--encoder-output-dim",
        type=int,
        default=256,
        help="Output dim of encoder. Must equal to the output dim of the last Component! (default=256)",
    )
    # Add args of all kinds of components.
    # If you add a new component, DO NOT forget to add args to add_component_args func.
    group = Transformer.add_arguments(group)
    group = Subsampling.add_arguments(group)
    group = MambaSSM.add_arguments(group)
    return group


def assign_args_from_dict(args, dict, prefix_key=None):
    if prefix_key is not None:
        dict = dict[prefix_key]
    for k, v in dict.items():
        k_args = k.replace("-", "_")
        if hasattr(args, k_args):
            setattr(args, k_args, dict[k])
    return args


class whaleEncoder(torch.nn.Module):
    def __init__(self, input_dim, overview_conf=None, para_conf=None, global_cmvn=None):
        super(whaleEncoder, self).__init__()
        parser = argparse.ArgumentParser()
        add_encoder_args(parser)
        args, _ = parser.parse_known_args()

        assign_args_from_dict(args, overview_conf)
        # assign_args_from_dict(args, para_conf)

        self.config = args.encoder_layer_config.split("-")
        encoder_input_dim = args.encoder_input_dim
        encoder_output_dim = args.encoder_output_dim

        prev_output_dim = encoder_input_dim
        prev_component_name = "encoder"
        self.enc = torch.nn.ModuleList([])
        for name in self.config:
            assign_args_from_dict(args, para_conf[name])
            if len(name.split("_")) == 2:
                name = name.split("_")[0]
            elif len(name.split("_")) == 1:
                name = name
            else:
                logging.error("WRONG CONFIG! {} is not valid".format(name))
                sys.exit()

            if name == "transformer":
                self.enc.append(Transformer(args))
            elif name == "subsampling":
                self.enc.append(Subsampling(args))
            elif name == "mamba":
                self.enc.append(MambaSSM(args))
            else:
                print("{} is not supported now!".format(name))
                return NotImplemented

            component_input_dim = getattr(args, name + "_input_dim")
            if component_input_dim != prev_output_dim:
                # This is the first layer
                logging.error(
                    "WRONG CONFIG! --{}-output-dim ({}) does not equal to --{}-input-dim ({})".format(
                        prev_component_name, prev_output_dim, name, component_input_dim
                    )
                )
                sys.exit()
            prev_output_dim = getattr(args, name + "_output_dim")
            prev_component_name = name

        self.global_cmvn = global_cmvn
        if prev_output_dim != encoder_output_dim:
            logging.error(
                "WRONG CONFIG! --{}-output-dim ({}) does not equal to --{}-output-dim ({}, the last component)".format(
                    "encoder", encoder_output_dim, name, prev_output_dim
                )
            )
            sys.exit()
        self._output_size = encoder_output_dim

        num_params = sum(p.numel() for p in self.parameters())
        print("the number of whale encoder params: {}M".format(num_params / 1024 / 1024))

    def output_size(self) -> int:
        return self._output_size

    @torch.jit.unused
    def forward(self, xs, ilens, decoding_chunk_size=None, num_decoding_left_chunks=None):
        # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tuple[Tensor, Optional[List[int]], Optional[Tensor]]
        """Encoder forward
        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :return: batch of hidden state sequences (B, Tmax, eprojs)
        :rtype: torch.Tensor
        """
        if decoding_chunk_size is not None and num_decoding_left_chunks is not None:
            for layer in self.enc:
                if hasattr(layer, "chunk_size"):
                    layer.chunk_size = decoding_chunk_size
                if hasattr(layer, "left_chunks"):
                    layer.left_chunks = num_decoding_left_chunks
                if hasattr(layer, "transformer_dynamic_chunks"):
                    layer.transformer_dynamic_chunks = False

        assert (len(xs.shape)) == 3
        T = xs.size(1)
        masks = ~make_pad_mask(ilens, T).unsqueeze(1)  # (B, 1, T)
        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)
        for module in self.enc:
            xs, ilens, masks = module(xs, ilens, masks)
        return xs, masks

    @torch.jit.export
    def infer(self, xs_pad, buffer, buffer_index, buffer_out):
        if self.global_cmvn is not None:
            xs_pad = self.global_cmvn(xs_pad)
        for module in self.enc:
            xs_pad, buffer, buffer_index, buffer_out = module.infer(
                xs_pad, buffer, buffer_index, buffer_out
            )
        return xs_pad, buffer, buffer_index, buffer_out

    @torch.jit.export
    def infer_hidden(self, xs_pad, buffer, buffer_index, buffer_out, hidden_out):
        if self.global_cmvn is not None:
            xs_pad = self.global_cmvn(xs_pad)
        for module in self.enc:
            xs_pad, buffer, buffer_index, buffer_out, hidden_out = module.infer_hidden(
                xs_pad, buffer, buffer_index, buffer_out, hidden_out
            )
        return xs_pad, buffer, buffer_index, buffer_out, hidden_out

    @torch.jit.ignore(drop=True)
    def get_extra_loss(self) -> Dict[str, torch.Tensor]:
        return None
```
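Putting the components together, `whaleEncoder` reads `overview_conf` and `para_conf` through the argparse flags defined above; the chained dims must line up, so each component's input dim equals the previous one's output dim. A sketch with illustrative values:

```python
import torch

overview_conf = {
    "encoder-layer-config": "subsampling-transformer",
    "encoder-input-dim": 80,
    "encoder-output-dim": 256,
}
para_conf = {
    "subsampling": {
        "subsampling-rate": 4,
        "subsampling-input-dim": 80,
        "subsampling-output-dim": 256,
    },
    "transformer": {
        "transformer-input-dim": 256,
        "transformer-output-dim": 256,
    },
}

encoder = whaleEncoder(80, overview_conf=overview_conf, para_conf=para_conf)

xs = torch.randn(2, 100, 80)     # (B, T, fbank_dim)
ilens = torch.tensor([100, 80])  # valid lengths per utterance
ys, masks = encoder(xs, ilens)
print(ys.shape)  # roughly (2, 24, 256) after 1/4 time subsampling
```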
VITA/model/multimodal_encoder/whale/module/layer/__pycache__/attention.cpython-310.pyc (new file, mode 100644)

File added (binary).
VITA/model/multimodal_encoder/whale/module/layer/attention.py (new file, mode 100644)
import
math
import
pdb
import
numpy
import
torch
import
torch.nn
as
nn
class
PositionalEncoding
(
torch
.
nn
.
Module
):
"""Positional encoding.
:param int d_model: embedding dim
:param float dropout_rate: dropout rate
:param int max_len: maximum input length
PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
"""
def
__init__
(
self
,
d_model
:
int
,
dropout_rate
:
float
,
max_len
:
int
=
1500
,
reverse
:
bool
=
False
):
"""Construct an PositionalEncoding object."""
super
().
__init__
()
self
.
d_model
=
d_model
self
.
xscale
=
math
.
sqrt
(
self
.
d_model
)
self
.
dropout
=
torch
.
nn
.
Dropout
(
p
=
dropout_rate
)
self
.
max_len
=
max_len
self
.
pe
=
torch
.
zeros
(
self
.
max_len
,
self
.
d_model
)
position
=
torch
.
arange
(
0
,
self
.
max_len
,
dtype
=
torch
.
float32
).
unsqueeze
(
1
)
div_term
=
torch
.
exp
(
torch
.
arange
(
0
,
self
.
d_model
,
2
,
dtype
=
torch
.
float32
)
*
-
(
math
.
log
(
10000.0
)
/
self
.
d_model
)
)
self
.
pe
[:,
0
::
2
]
=
torch
.
sin
(
position
*
div_term
)
self
.
pe
[:,
1
::
2
]
=
torch
.
cos
(
position
*
div_term
)
self
.
pe
=
self
.
pe
.
unsqueeze
(
0
)
def
forward
(
self
,
x
:
torch
.
Tensor
,
offset
:
int
=
0
):
"""Add positional encoding.
Args:
x (torch.Tensor): Input. Its shape is (batch, time, ...)
offset (int): position offset
Returns:
torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
torch.Tensor: for compatibility to RelPositionalEncoding
"""
assert
offset
+
x
.
size
(
1
)
<
self
.
max_len
self
.
pe
=
self
.
pe
.
to
(
x
.
device
)
pos_emb
=
self
.
pe
[:,
offset
:
offset
+
x
.
size
(
1
)]
x
=
x
*
self
.
xscale
+
pos_emb
return
self
.
dropout
(
x
),
self
.
dropout
(
pos_emb
)
def
position_encoding
(
self
,
offset
:
int
,
size
:
int
):
"""For getting encoding in a streaming fashion
Attention!!!!!
we apply dropout only once at the whole utterance level in a none
streaming way, but will call this function several times with
increasing input size in a streaming scenario, so the dropout will
be applied several times.
Args:
offset (int): start offset
size (int): requried size of position encoding
Returns:
torch.Tensor: Corresponding encoding
"""
assert
offset
+
size
<
self
.
max_len
return
self
.
dropout
(
self
.
pe
[:,
offset
:
offset
+
size
])
class
RelPositionalEncoding
(
PositionalEncoding
):
"""Relative positional encoding module.
See : Appendix B in https://arxiv.org/abs/1901.02860
Args:
d_model (int): Embedding dimension.
dropout_rate (float): Dropout rate.
max_len (int): Maximum input length.
"""
def
__init__
(
self
,
d_model
:
int
,
dropout_rate
:
float
,
chunk_size
:
int
,
left_chunks
:
int
,
max_len
:
int
=
5000
,
):
"""Initialize class."""
super
().
__init__
(
d_model
,
dropout_rate
,
max_len
,
reverse
=
True
)
self
.
chunk_size
=
chunk_size
self
.
left_chunks
=
left_chunks
self
.
full_chunk_size
=
(
self
.
left_chunks
+
1
)
*
self
.
chunk_size
self
.
div_term
=
torch
.
exp
(
torch
.
arange
(
0
,
self
.
d_model
,
2
,
dtype
=
torch
.
float32
)
*
-
(
math
.
log
(
10000.0
)
/
self
.
d_model
)
)
self
.
max_len
=
self
.
chunk_size
*
(
max_len
//
self
.
chunk_size
)
-
self
.
full_chunk_size
@
torch
.
jit
.
export
def
forward
(
self
,
x
:
torch
.
Tensor
,
offset
:
int
=
0
):
"""Compute positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
torch.Tensor: Positional embedding tensor (1, time, `*`).
"""
self
.
pe
=
self
.
pe
.
to
(
x
.
device
)
x
=
x
*
self
.
xscale
pos_emb
=
self
.
pe
[:,
offset
:
offset
+
x
.
size
(
1
)]
return
self
.
dropout
(
x
),
self
.
dropout
(
pos_emb
)
@
torch
.
jit
.
export
def
infer
(
self
,
xs
,
pe_index
):
# type: (Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor]
pe_index
=
pe_index
%
self
.
max_len
xs
=
xs
*
self
.
xscale
pe
=
torch
.
zeros
(
self
.
full_chunk_size
,
self
.
d_model
)
position
=
torch
.
arange
(
pe_index
,
pe_index
+
self
.
full_chunk_size
,
dtype
=
torch
.
float32
).
unsqueeze
(
1
)
pe
[:,
0
::
2
]
=
torch
.
sin
(
position
*
self
.
div_term
)
pe
[:,
1
::
2
]
=
torch
.
cos
(
position
*
self
.
div_term
)
pos_emb
=
pe
.
unsqueeze
(
0
)
pe_index
=
pe_index
+
self
.
chunk_size
return
xs
,
pos_emb
,
pe_index
class
PositionwiseFeedForward
(
torch
.
nn
.
Module
):
"""Positionwise feed forward layer.
:param int idim: input dimenstion
:param int hidden_units: number of hidden units
:param float dropout_rate: dropout rate
"""
def
__init__
(
self
,
idim
,
hidden_units
,
dropout_rate
):
"""Construct an PositionwiseFeedForward object."""
super
(
PositionwiseFeedForward
,
self
).
__init__
()
self
.
w_1
=
torch
.
nn
.
Linear
(
idim
,
hidden_units
)
self
.
w_2
=
torch
.
nn
.
Linear
(
hidden_units
,
idim
)
self
.
dropout
=
torch
.
nn
.
Dropout
(
dropout_rate
)
def
forward
(
self
,
x
):
"""Forward funciton."""
return
self
.
w_2
(
self
.
dropout
(
torch
.
relu
(
self
.
w_1
(
x
))))
@
torch
.
jit
.
export
def
infer
(
self
,
xs
,
buffer
,
buffer_index
,
buffer_out
):
# type: (Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor]
return
self
.
w_2
(
torch
.
relu
(
self
.
w_1
(
xs
))),
buffer
,
buffer_index
,
buffer_out
class
MultiLayeredConv1d
(
torch
.
nn
.
Module
):
"""Multi-layered conv1d for Transformer block.
This is a module of multi-leyered conv1d designed
to replace positionwise feed-forward network
in Transformer block, which is introduced in
`FastSpeech: Fast, Robust and Controllable Text to Speech`_.
.. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
https://arxiv.org/pdf/1905.09263.pdf
"""
def
__init__
(
self
,
in_chans
,
hidden_chans
,
kernel_size
,
dropout_rate
):
"""Initialize MultiLayeredConv1d module.
Args:
in_chans (int): Number of input channels.
hidden_chans (int): Number of hidden channels.
kernel_size (int): Kernel size of conv1d.
dropout_rate (float): Dropout rate.
"""
super
(
MultiLayeredConv1d
,
self
).
__init__
()
self
.
w_1
=
torch
.
nn
.
Conv1d
(
in_chans
,
hidden_chans
,
kernel_size
,
stride
=
1
,
padding
=
(
kernel_size
-
1
)
//
2
,
)
self
.
w_2
=
torch
.
nn
.
Conv1d
(
hidden_chans
,
in_chans
,
kernel_size
,
stride
=
1
,
padding
=
(
kernel_size
-
1
)
//
2
,
)
self
.
dropout
=
torch
.
nn
.
Dropout
(
dropout_rate
)
@
torch
.
jit
.
unused
def
forward
(
self
,
x
):
"""Calculate forward propagation.
Args:
x (Tensor): Batch of input tensors (B, ..., in_chans).
Returns:
Tensor: Batch of output tensors (B, ..., hidden_chans).
"""
x
=
torch
.
relu
(
self
.
w_1
(
x
.
transpose
(
-
1
,
1
))).
transpose
(
-
1
,
1
)
return
self
.
w_2
(
self
.
dropout
(
x
).
transpose
(
-
1
,
1
)).
transpose
(
-
1
,
1
)
class
Conv1dLinear
(
torch
.
nn
.
Module
):
"""Conv1D + Linear for Transformer block.
A variant of MultiLayeredConv1d, which replaces second conv-layer to linear.
"""
def
__init__
(
self
,
in_chans
,
hidden_chans
,
kernel_size
,
dropout_rate
):
"""Initialize Conv1dLinear module.
Args:
in_chans (int): Number of input channels.
hidden_chans (int): Number of hidden channels.
kernel_size (int): Kernel size of conv1d.
dropout_rate (float): Dropout rate.
"""
super
(
Conv1dLinear
,
self
).
__init__
()
self
.
lorder
=
kernel_size
-
1
self
.
left_padding
=
nn
.
ConstantPad1d
((
self
.
lorder
,
0
),
0.0
)
self
.
w_1
=
torch
.
nn
.
Sequential
(
torch
.
nn
.
Conv1d
(
in_chans
,
in_chans
,
kernel_size
,
stride
=
1
,
padding
=
0
,
groups
=
in_chans
),
torch
.
nn
.
Conv1d
(
in_chans
,
hidden_chans
,
1
,
padding
=
0
),
)
self
.
w_2
=
torch
.
nn
.
Linear
(
hidden_chans
,
in_chans
)
self
.
dropout
=
torch
.
nn
.
Dropout
(
dropout_rate
)
self
.
in_chans
=
in_chans
# cnn_buffer = 1, in_chans, self.lorder
self
.
buffer_size
=
1
*
self
.
in_chans
*
self
.
lorder
@
torch
.
jit
.
unused
def
forward
(
self
,
x
):
"""Calculate forward propagation.
Args:
x (Tensor): Batch of input tensors (B, ..., in_chans).
Returns:
Tensor: Batch of output tensors (B, ..., hidden_chans).
"""
x
=
torch
.
relu
(
self
.
w_1
(
self
.
left_padding
(
x
.
transpose
(
-
1
,
1
)))).
transpose
(
-
1
,
1
)
return
self
.
w_2
(
self
.
dropout
(
x
))
@
torch
.
jit
.
export
def
infer
(
self
,
x
,
buffer
,
buffer_index
,
buffer_out
):
# type: (Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor]
x
=
x
.
transpose
(
-
1
,
1
)
cnn_buffer
=
buffer
[
buffer_index
:
buffer_index
+
self
.
buffer_size
].
reshape
(
[
1
,
self
.
in_chans
,
self
.
lorder
]
)
x
=
torch
.
cat
([
cnn_buffer
,
x
],
dim
=
2
)
buffer_out
.
append
(
x
[:,
:,
-
self
.
lorder
:].
reshape
(
-
1
))
buffer_index
=
buffer_index
+
self
.
buffer_size
x
=
self
.
w_1
(
x
)
x
=
torch
.
relu
(
x
).
transpose
(
-
1
,
1
)
x
=
self
.
w_2
(
x
)
return
x
,
buffer
,
buffer_index
,
buffer_out
class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer.

    :param int n_head: the number of heads
    :param int n_feat: the number of features
    :param float dropout_rate: dropout rate
    """

    def __init__(self, n_head, n_feat, dropout_rate, chunk_size, left_chunks, pos_enc_class):
        """Construct a MultiHeadedAttention object."""
        super(MultiHeadedAttention, self).__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.dropout = nn.Dropout(p=dropout_rate)
        # self.min_value = float(numpy.finfo(torch.tensor(0, dtype=torch.float16).numpy().dtype).min)
        self.min_value = float(torch.finfo(torch.float16).min)
        # chunk parameters
        if chunk_size > 0 and left_chunks > 0:
            # for streaming mode
            self.buffersize = chunk_size * left_chunks
            self.left_chunk_size = chunk_size * left_chunks
        else:
            # for non-streaming mode
            self.buffersize = 1
            self.left_chunk_size = 1
        self.chunk_size = chunk_size
        # encoding setup
        if pos_enc_class == "rel-enc":
            self.rel_enc = True
            self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
            # these two learnable biases are used in matrix c and matrix d
            # as described in https://arxiv.org/abs/1901.02860 Section 3.3
            self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
            self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
            torch.nn.init.xavier_uniform_(self.pos_bias_u)
            torch.nn.init.xavier_uniform_(self.pos_bias_v)
        else:
            self.rel_enc = False
            self.linear_pos = nn.Identity()
            self.pos_bias_u = torch.tensor([0])
            self.pos_bias_v = torch.tensor([0])
        # buffer
        # key_buffer = 1, self.h, self.buffersize, self.d_k
        self.key_buffer_size = 1 * self.h * self.buffersize * self.d_k
        # value_buffer = 1, self.h, self.buffersize, self.d_k
        self.value_buffer_size = 1 * self.h * self.buffersize * self.d_k
        if self.chunk_size > 0:
            # buffer_mask_size = 1, self.h, self.chunk_size, self.buffersize
            self.buffer_mask_size = 1 * self.h * self.chunk_size * self.buffersize
            # self.buffer_mask = torch.ones([1, self.h, self.chunk_size, self.buffersize], dtype=torch.bool)
        else:
            self.buffer_mask = torch.ones([1, self.h, 1, 1], dtype=torch.bool)
    @torch.jit.unused
    def rel_shift(self, x, zero_triu: bool = False):
        """Compute relative positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, size).
            zero_triu (bool): If true, return the lower triangular part of
                the matrix.
        Returns:
            torch.Tensor: Output tensor.
        """
        zero_pad = torch.zeros(
            (x.size()[0], x.size()[1], x.size()[2], 1), device=x.device, dtype=x.dtype
        )
        x_padded = torch.cat([zero_pad, x], dim=-1)
        x_padded = x_padded.view(x.size()[0], x.size()[1], x.size(3) + 1, x.size(2))
        x = x_padded[:, :, 1:].view_as(x)
        if zero_triu:
            ones = torch.ones((x.size(2), x.size(3)))
            x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
        return x
    @torch.jit.export
    def forward(self, query, key, value, mask=None, pos_emb=torch.tensor(1.0)):
        # type: (Tensor, Tensor, Tensor, Optional[Tensor], Tensor) -> Tensor
        """Compute 'Scaled Dot Product Attention'.

        :param torch.Tensor query: (batch, time1, size)
        :param torch.Tensor key: (batch, time2, size)
        :param torch.Tensor value: (batch, time2, size)
        :param torch.Tensor mask: (batch, time1, time2)
        :param torch.Tensor pos_emb: positional embedding (used when rel_enc is True)
        :return torch.Tensor: attended and transformed `value` (batch, time1, d_model)
            weighted by the query-key attention (batch, head, time1, time2)
        """
        n_batch = query.size(0)
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head, time2, d_k)

        if self.rel_enc:
            q = q.transpose(1, 2)  # (batch, time1, head, d_k)
            n_batch_pos = pos_emb.size(0)
            p = self.linear_pos(pos_emb.to(query.dtype)).view(n_batch_pos, -1, self.h, self.d_k)
            p = p.transpose(1, 2)  # (batch, head, time1, d_k)
            # (batch, head, time1, d_k)
            q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
            # (batch, head, time1, d_k)
            q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
            # compute attention score
            # first compute matrix a and matrix c
            # as described in https://arxiv.org/abs/1901.02860 Section 3.3
            # (batch, head, time1, time2)
            matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
            # compute matrix b and matrix d
            # (batch, head, time1, time2)
            matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
            # Remove rel_shift since it is useless in speech recognition,
            # and it requires special attention for streaming.
            # matrix_bd = self.rel_shift(matrix_bd)
            scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)  # (batch, head, time1, time2)
        else:
            scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)  # (batch, head, time1, time2)

        if mask is not None:
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, time1, time2)
            scores = scores.masked_fill(mask, self.min_value)
            attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)  # (batch, head, time1, time2)
        else:
            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(attn)
        x = torch.matmul(p_attn, v)  # (batch, head, time1, d_k)
        x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)  # (batch, time1, d_model)
        return self.linear_out(x)  # (batch, time1, d_model)
    @torch.jit.export
    def infer(self, query, key, value, pos_emb, buffer, buffer_index, buffer_out):
        # type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, List[Tensor]) -> Tuple[Tensor, Tensor, Tensor, List[Tensor]]
        n_batch = query.size(0)
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k).transpose(1, 2)  # (batch, head, len_q, d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k).transpose(1, 2)  # (batch, head, len_k, d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k).transpose(1, 2)  # (batch, head, len_v, d_k)

        key_value_buffer = buffer[
            buffer_index : buffer_index + self.key_buffer_size + self.value_buffer_size
        ].reshape([1, self.h, self.buffersize * 2, self.d_k])
        key_buffer = torch.cat([key_value_buffer[:, :, : self.buffersize, :], k], dim=2)
        value_buffer = torch.cat([key_value_buffer[:, :, self.buffersize :, :], v], dim=2)
        buffer_out.append(
            torch.cat(
                [key_buffer[:, :, self.chunk_size :, :], value_buffer[:, :, self.chunk_size :, :]],
                dim=2,
            ).reshape(-1)
        )
        buffer_index = buffer_index + self.key_buffer_size + self.value_buffer_size

        if self.rel_enc:
            q = q.transpose(1, 2)  # (batch, time1, head, d_k)
            n_batch_pos = pos_emb.size(0)
            p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
            p = p.transpose(1, 2)  # (batch, head, time1, d_k)
            # (batch, head, time1, d_k)
            q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
            # (batch, head, time1, d_k)
            q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
            # compute attention score
            # first compute matrix a and matrix c
            # as described in https://arxiv.org/abs/1901.02860 Section 3.3
            # (batch, head, time1, time2)
            matrix_ac = torch.matmul(q_with_bias_u, key_buffer.transpose(-2, -1))
            # compute matrix b and matrix d
            # (batch, head, time1, time2)
            matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
            # Remove rel_shift since it is useless in speech recognition,
            # and it requires special attention for streaming.
            # matrix_bd = self.rel_shift(matrix_bd)
            scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)  # (batch, head, time1, time2)
        else:
            scores = torch.matmul(q, key_buffer.transpose(-2, -1)) / math.sqrt(self.d_k)  # (batch, head, len_q, buffersize)

        attn = torch.softmax(scores, dim=-1)
        x = torch.matmul(attn, value_buffer)  # (batch, head, len_q, d_k)
        x = x.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k)  # (batch, time1, d_model)
        return self.linear_out(x), buffer, buffer_index, buffer_out  # (batch, time1, d_model)
    @torch.jit.export
    def infer_mask(self, query, key, value, mask, buffer, buffer_index, buffer_out, is_static):
        n_batch = query.size(0)
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k).transpose(1, 2)  # (batch, head, len_q, d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k).transpose(1, 2)  # (batch, head, len_k, d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k).transpose(1, 2)  # (batch, head, len_v, d_k)

        if is_static:
            key_buffer = k
            value_buffer = v
        else:
            key_value_buffer = buffer[
                buffer_index : buffer_index + self.key_buffer_size + self.value_buffer_size
            ].reshape([1, self.h, self.buffersize * 2, self.d_k])
            key_buffer = torch.cat([key_value_buffer[:, :, : self.buffersize, :], k], dim=2)
            value_buffer = torch.cat([key_value_buffer[:, :, self.buffersize :, :], v], dim=2)
            buffer_out.append(
                torch.cat(
                    [
                        key_buffer[:, :, self.chunk_size :, :],
                        value_buffer[:, :, self.chunk_size :, :],
                    ],
                    dim=2,
                ).reshape(-1)
            )
            buffer_index = buffer_index + self.key_buffer_size + self.value_buffer_size

        scores = torch.matmul(q, key_buffer.transpose(-2, -1)) / math.sqrt(self.d_k)  # (batch, head, len_q, buffersize)
        if mask is not None:
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, time1, time2)
            scores = scores.masked_fill(mask, self.min_value)
            attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)  # (batch, head, time1, time2)
        else:
            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
        x = torch.matmul(attn, value_buffer)  # (batch, head, len_q, d_k)
        x = x.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k)  # (batch, time1, d_model)
        return self.linear_out(x), buffer_index, buffer_out  # (batch, time1, d_model)
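
# Two minimal usage sketches, assuming eager (non-scripted) execution and the
# module-level imports above; helper names and parameter values are illustrative
# only, and any pos_enc_class other than "rel-enc" disables relative encoding.
def _demo_attention_forward():
    # Non-streaming self-attention: chunk_size = 0 selects the full-context path.
    mha = MultiHeadedAttention(
        n_head=4, n_feat=64, dropout_rate=0.0, chunk_size=0, left_chunks=0, pos_enc_class="abs"
    )
    x = torch.randn(2, 10, 64)                      # (batch, time, n_feat)
    mask = torch.ones(2, 10, 10, dtype=torch.bool)  # attend everywhere
    assert mha(x, x, x, mask).shape == (2, 10, 64)


def _demo_attention_streaming():
    # Streaming inference: the flat `buffer` caches keys and values for
    # `buffersize` past frames, and `buffer_out` collects the refreshed cache
    # to feed into the next call.
    mha = MultiHeadedAttention(
        n_head=4, n_feat=64, dropout_rate=0.0, chunk_size=2, left_chunks=2, pos_enc_class="abs"
    ).eval()
    buffer = torch.zeros(mha.key_buffer_size + mha.value_buffer_size)
    chunk = torch.randn(1, 2, 64)                   # one chunk of chunk_size frames
    y, buffer, index, buffer_out = mha.infer(
        chunk, chunk, chunk, torch.tensor(1.0), buffer, torch.tensor(0), []
    )
    assert y.shape == (1, 2, 64)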
class SoftAttention(nn.Module):
    def __init__(self, in_dim, hidden_dim):
        super(SoftAttention, self).__init__()
        self.q = torch.nn.Parameter(torch.rand([hidden_dim]), requires_grad=True)
        self.wb = nn.Linear(in_dim, hidden_dim)
        self.min_value = float(numpy.finfo(torch.tensor(0, dtype=torch.float32).numpy().dtype).min)
        # buffer
        self.window_size = 50
        self.buffer_in = torch.zeros([1, self.window_size, in_dim], dtype=torch.float32)
        self.buffer = torch.zeros([1, self.window_size], dtype=torch.float32)
        self.buffer[:, :] = float(
            numpy.finfo(torch.tensor(0, dtype=torch.float32).numpy().dtype).min
        )

    @torch.jit.unused
    def forward(self, x, mask=None):
        hidden = torch.tanh(self.wb(x))  # B T D
        hidden = torch.einsum("btd,d->bt", hidden, self.q)
        score = torch.softmax(hidden, dim=-1)  # B T
        if mask is not None:
            score = score.masked_fill(mask, 0.0)
        output = torch.einsum("bt,btd->bd", score, x)
        return output

    @torch.jit.export
    def infer(self, x):
        # type: (Tensor) -> Tensor
        hidden = torch.tanh(self.wb(x))  # B T D
        hidden = torch.einsum("btd,d->bt", hidden, self.q)
        size = hidden.shape[1]
        output = torch.zeros([size, x.shape[-1]])
        for i in range(size):
            self.buffer = torch.cat([self.buffer, hidden[:, i : i + 1]], dim=-1)
            self.buffer = self.buffer[:, 1:]
            score = torch.softmax(self.buffer, dim=-1)  # B T
            self.buffer_in = torch.cat([self.buffer_in, x[:, i : i + 1, :]], dim=1)
            self.buffer_in = self.buffer_in[:, 1:]
            output[i : i + 1] = torch.einsum("bt,btd->bd", score, self.buffer_in)
        return output
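
# A minimal pooling sketch, assuming eager execution and the module-level
# imports: SoftAttention collapses a (batch, time, in_dim) sequence into one
# (batch, in_dim) vector with learned attention weights. Values are
# illustrative only.
def _demo_soft_attention():
    pool = SoftAttention(in_dim=32, hidden_dim=16)
    x = torch.randn(2, 20, 32)
    assert pool(x).shape == (2, 32)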
VITA/model/multimodal_encoder/whale/module/layer/conv1d.py
0 → 100644
View file @
112bf76b
import torch
import torch.nn as nn
import torch.nn.functional as F


class Conv1dLayer(nn.Module):
    def __init__(
        self,
        input_dim,
        output_dim,
        kernel_size,
        stride,
        causal_conv,
        dilation,
        dropout_rate,
        residual=True,
    ):
        super(Conv1dLayer, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.causal_conv = causal_conv
        if causal_conv:
            self.lorder = (kernel_size - 1) * self.dilation
            self.left_padding = nn.ConstantPad1d((self.lorder, 0), 0.0)
        else:
            assert (kernel_size - 1) % 2 == 0
            self.lorder = ((kernel_size - 1) // 2) * self.dilation
            self.left_padding = nn.ConstantPad1d((self.lorder, self.lorder), 0.0)
        self.conv1d = nn.Conv1d(
            self.input_dim, self.output_dim, self.kernel_size, self.stride, 0, self.dilation
        )
        self.bn = nn.BatchNorm1d(self.output_dim, eps=1e-3, momentum=0.99)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout_rate)
        self.residual = residual
        if self.input_dim != self.output_dim:
            self.residual = False
        # buffer = 1, self.input_dim, self.lorder
        # note: lorder is redefined here for the streaming cache size;
        # left_padding above was already built with the padding value.
        self.lorder = (kernel_size - 1) * self.dilation - (self.stride - 1)
        self.buffer_size = 1 * self.input_dim * self.lorder
        self.x_data_chache_size = self.lorder
        self.x_data_buffer_size = self.input_dim * self.x_data_chache_size
    @torch.jit.unused
    def forward(self, x):
        x_data = x
        x = self.left_padding(x)
        x = self.conv1d(x)
        x = self.bn(x)
        if self.stride == 1 and self.residual:
            x = self.relu(x + x_data)
        else:
            x = self.relu(x)
        x = self.dropout(x)
        return x
    @torch.jit.export
    def infer(self, x, buffer, buffer_index, buffer_out):
        # type: (Tensor, Tensor, Tensor, List[Tensor]) -> Tuple[Tensor, Tensor, Tensor, List[Tensor]]
        x_data = x.clone()
        cnn_buffer = buffer[buffer_index : buffer_index + self.buffer_size].reshape(
            [1, self.input_dim, self.lorder]
        )
        x = torch.cat([cnn_buffer, x], dim=2)
        buffer_out.append(x[:, :, -self.lorder :].reshape(-1))
        buffer_index = buffer_index + self.buffer_size
        x = self.conv1d(x)
        x = self.bn(x)
        if self.stride == 1 and self.residual:
            x_data_cnn_buffer = buffer[
                buffer_index : buffer_index + self.x_data_buffer_size
            ].reshape([1, self.input_dim, self.x_data_chache_size])
            x_data = torch.cat([x_data_cnn_buffer, x_data], dim=2)
            buffer_out.append(x_data[:, :, -self.x_data_chache_size :].reshape(-1))
            buffer_index = buffer_index + self.x_data_buffer_size
            x_data = x_data[:, :, : -self.x_data_chache_size]
            x = self.relu(x + x_data)
        else:
            x = self.relu(x)
        return x, buffer, buffer_index, buffer_out
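
# A minimal shape sketch, assuming eager execution with BatchNorm in eval mode:
# a causal Conv1dLayer with stride 1 preserves the time dimension and applies a
# residual connection when input_dim == output_dim. Values are illustrative only.
def _demo_conv1d_layer():
    layer = Conv1dLayer(
        input_dim=8, output_dim=8, kernel_size=3, stride=1,
        causal_conv=True, dilation=1, dropout_rate=0.0,
    ).eval()
    x = torch.randn(2, 8, 40)  # (batch, channels, time)
    assert layer(x).shape == (2, 8, 40)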