ModelZoo / SpeechT5_pytorch / Commits

Commit 12c90639, authored Sep 28, 2024 by "change"
Commit message: init
Parent: 417b607b
Changes: 350 files in this commit; the 20 changed files shown below account for 3651 additions and 0 deletions (+3651, -0).
SpeechLM/speechlm/models/fasttext2unit.py (+226, -0)
SpeechLM/speechlm/models/speechlm.py (+720, -0)
SpeechLM/speechlm/models/speechlm_ctcasr.py (+56, -0)
SpeechLM/speechlm/models/speechlm_st.py (+268, -0)
SpeechLM/speechlm/modules/__init__.py (+23, -0)
SpeechLM/speechlm/modules/learned_positional_embedding.py (+68, -0)
SpeechLM/speechlm/modules/multihead_attention.py (+348, -0)
SpeechLM/speechlm/modules/relative_pos_enc.py (+35, -0)
SpeechLM/speechlm/modules/transformer_decoder.py (+544, -0)
SpeechLM/speechlm/modules/transformer_encoder.py (+403, -0)
SpeechLM/speechlm/modules/transformer_layer.py (+329, -0)
SpeechLM/speechlm/modules/w2v_encoder.py (+283, -0)
SpeechLM/speechlm/scripts/pretrain_speechlm/base_speechlmh.sh (+43, -0)
SpeechLM/speechlm/scripts/pretrain_speechlm/base_speechlmp.sh (+43, -0)
SpeechLM/speechlm/scripts/pretrain_speechlm/large_speechlmp.sh (+44, -0)
SpeechLM/speechlm/scripts/tokenizer_fastT2U/generate.sh (+42, -0)
SpeechLM/speechlm/scripts/tokenizer_fastT2U/infer.sh (+41, -0)
SpeechLM/speechlm/scripts/tokenizer_fastT2U/train_s_5e-4.sh (+39, -0)
SpeechLM/speechlm/scripts/tune_speechlm_asr/finetune_base_ctc.sh (+48, -0)
SpeechLM/speechlm/scripts/tune_speechlm_asr/finetune_large_ctc.sh (+48, -0)

SpeechLM/speechlm/models/fasttext2unit.py (new file, 0 → 100644)

# ----------------------------------------------------------------------------
# SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data (https://arxiv.org/abs/2209.15329)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechLM
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
import logging

import torch
from fairseq import utils
from fairseq.models import (
    FairseqEncoderModel,
    register_model,
    register_model_architecture,
)
from fairseq.models.text_to_speech import fastspeech2

logger = logging.getLogger(__name__)


class VarianceAdaptor(fastspeech2.VarianceAdaptor):
    def __init__(self, args):
        super().__init__(args)
        self.use_pitch = args.use_pitch
        self.use_energe = args.use_energe

    def forward(
        self,
        x,
        padding_mask,
        durations=None,
        pitches=None,
        energies=None,
        d_factor=1.0,
        p_factor=1.0,
        e_factor=1.0,
    ):
        # x: B x T x C
        log_dur_out = self.duration_predictor(x)
        dur_out = torch.clamp(
            torch.round((torch.exp(log_dur_out) - 1) * d_factor).long(), min=0
        )
        dur_out.masked_fill_(padding_mask, 0)

        if self.use_pitch:
            pitch_out, pitch_emb = self.get_pitch_emb(x, pitches, p_factor)
            x = x + pitch_emb
        else:
            pitch_out = None

        if self.use_energe:
            energy_out, energy_emb = self.get_energy_emb(x, energies, e_factor)
            x = x + energy_emb
        else:
            energy_out = None

        x, out_lens = self.length_regulator(
            x, dur_out if durations is None else durations
        )

        return x, out_lens, log_dur_out, pitch_out, energy_out


class FastSpeech2Encoder(fastspeech2.FastSpeech2Encoder):
    def __init__(self, args, src_dict, embed_speaker):
        super().__init__(args, src_dict, embed_speaker)
        self.var_adaptor = VarianceAdaptor(args)
        self.apply(fastspeech2.model_init)


@register_model("fasttext2unit")
class FastText2UnitModel(FairseqEncoderModel):
    """
    Implementation for https://arxiv.org/abs/2006.04558
    """

    NON_AUTOREGRESSIVE = True

    @staticmethod
    def add_args(parser):
        parser.add_argument("--dropout", type=float)
        parser.add_argument("--output-frame-dim", type=int)
        parser.add_argument("--speaker-embed-dim", type=int)
        # FFT blocks
        parser.add_argument("--fft-hidden-dim", type=int)
        parser.add_argument("--fft-kernel-size", type=int)
        parser.add_argument("--attention-dropout", type=float)
        parser.add_argument("--encoder-layers", type=int)
        parser.add_argument("--encoder-embed-dim", type=int)
        parser.add_argument("--encoder-attention-heads", type=int)
        parser.add_argument("--decoder-layers", type=int)
        parser.add_argument("--decoder-embed-dim", type=int)
        parser.add_argument("--decoder-attention-heads", type=int)
        # variance predictor
        parser.add_argument("--var-pred-n-bins", type=int)
        parser.add_argument("--var-pred-hidden-dim", type=int)
        parser.add_argument("--var-pred-kernel-size", type=int)
        parser.add_argument("--var-pred-dropout", type=float)
        # postnet
        parser.add_argument("--add-postnet", action="store_true")
        parser.add_argument("--postnet-dropout", type=float)
        parser.add_argument("--postnet-layers", type=int)
        parser.add_argument("--postnet-conv-dim", type=int)
        parser.add_argument("--postnet-conv-kernel-size", type=int)
        # pitch & energe
        parser.add_argument("--use-pitch", action="store_true")
        parser.add_argument("--use-energe", action="store_true")

    def __init__(self, encoder, args, src_dict):
        super().__init__(encoder)
        self._num_updates = 0

    @classmethod
    def build_model(cls, args, task):
        embed_speaker = task.get_speaker_embeddings(args)
        if args.output_frame_dim == -1:
            args.output_frame_dim = len(task.tgt_dict)
        encoder = FastSpeech2Encoder(args, task.src_dict, embed_speaker)
        return cls(encoder, args, task.src_dict)

    def set_num_updates(self, num_updates):
        super().set_num_updates(num_updates)
        self._num_updates = num_updates

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        logits = net_output[0]
        if log_probs:
            return utils.log_softmax(logits.float(), dim=-1)
        else:
            return utils.softmax(logits.float(), dim=-1)


@register_model_architecture("fasttext2unit", "fasttext2unit_s")
def base_architecture(args):
    args.dropout = getattr(args, "dropout", 0.2)
    args.output_frame_dim = getattr(args, "output_frame_dim", -1)
    args.speaker_embed_dim = getattr(args, "speaker_embed_dim", 256)
    # FFT blocks
    args.fft_hidden_dim = getattr(args, "fft_hidden_dim", 1024)
    args.fft_kernel_size = getattr(args, "fft_kernel_size", 9)
    args.attention_dropout = getattr(args, "attention_dropout", 0.0)
    args.encoder_layers = getattr(args, "encoder_layers", 4)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 2)
    args.decoder_layers = getattr(args, "decoder_layers", 4)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 256)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 2)
    # variance predictor
    args.var_pred_n_bins = getattr(args, "var_pred_n_bins", 256)
    args.var_pred_hidden_dim = getattr(args, "var_pred_hidden_dim", 256)
    args.var_pred_kernel_size = getattr(args, "var_pred_kernel_size", 3)
    args.var_pred_dropout = getattr(args, "var_pred_dropout", 0.5)
    # postnet
    args.add_postnet = getattr(args, "add_postnet", False)
    args.postnet_dropout = getattr(args, "postnet_dropout", 0.5)
    args.postnet_layers = getattr(args, "postnet_layers", 5)
    args.postnet_conv_dim = getattr(args, "postnet_conv_dim", 512)
    args.postnet_conv_kernel_size = getattr(args, "postnet_conv_kernel_size", 5)
    # pitch & energe
    args.use_pitch = getattr(args, "use_pitch", False)
    args.use_energe = getattr(args, "use_energe", False)


@register_model_architecture("fasttext2unit", "fasttext2unit_m")
def base_architecture(args):
    args.dropout = getattr(args, "dropout", 0.2)
    args.output_frame_dim = getattr(args, "output_frame_dim", -1)
    args.speaker_embed_dim = getattr(args, "speaker_embed_dim", 256)
    # FFT blocks
    args.fft_hidden_dim = getattr(args, "fft_hidden_dim", 1024)
    args.fft_kernel_size = getattr(args, "fft_kernel_size", 9)
    args.attention_dropout = getattr(args, "attention_dropout", 0.0)
    args.encoder_layers = getattr(args, "encoder_layers", 6)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 2)
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 256)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 2)
    # variance predictor
    args.var_pred_n_bins = getattr(args, "var_pred_n_bins", 256)
    args.var_pred_hidden_dim = getattr(args, "var_pred_hidden_dim", 256)
    args.var_pred_kernel_size = getattr(args, "var_pred_kernel_size", 3)
    args.var_pred_dropout = getattr(args, "var_pred_dropout", 0.5)
    # postnet
    args.add_postnet = getattr(args, "add_postnet", False)
    args.postnet_dropout = getattr(args, "postnet_dropout", 0.5)
    args.postnet_layers = getattr(args, "postnet_layers", 5)
    args.postnet_conv_dim = getattr(args, "postnet_conv_dim", 512)
    args.postnet_conv_kernel_size = getattr(args, "postnet_conv_kernel_size", 5)
    # pitch & energe
    args.use_pitch = getattr(args, "use_pitch", False)
    args.use_energe = getattr(args, "use_energe", False)


@register_model_architecture("fasttext2unit", "fasttext2unit_l")
def base_architecture(args):
    args.dropout = getattr(args, "dropout", 0.2)
    args.output_frame_dim = getattr(args, "output_frame_dim", -1)
    args.speaker_embed_dim = getattr(args, "speaker_embed_dim", 256)
    # FFT blocks
    args.fft_hidden_dim = getattr(args, "fft_hidden_dim", 1536)
    args.fft_kernel_size = getattr(args, "fft_kernel_size", 9)
    args.attention_dropout = getattr(args, "attention_dropout", 0.1)
    args.encoder_layers = getattr(args, "encoder_layers", 6)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 384)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 6)
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 384)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 6)
    # variance predictor
    args.var_pred_n_bins = getattr(args, "var_pred_n_bins", 256)
    args.var_pred_hidden_dim = getattr(args, "var_pred_hidden_dim", 256)
    args.var_pred_kernel_size = getattr(args, "var_pred_kernel_size", 3)
    args.var_pred_dropout = getattr(args, "var_pred_dropout", 0.5)
    # postnet
    args.add_postnet = getattr(args, "add_postnet", False)
    args.postnet_dropout = getattr(args, "postnet_dropout", 0.5)
    args.postnet_layers = getattr(args, "postnet_layers", 5)
    args.postnet_conv_dim = getattr(args, "postnet_conv_dim", 512)
    args.postnet_conv_kernel_size = getattr(args, "postnet_conv_kernel_size", 5)
    # pitch & energe
    args.use_pitch = getattr(args, "use_pitch", False)
    args.use_energe = getattr(args, "use_energe", False)

SpeechLM/speechlm/models/speechlm.py (new file, 0 → 100644)

# ----------------------------------------------------------------------------
# SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data (https://arxiv.org/abs/2209.15329)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechLM
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
import logging
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils, checkpoint_utils
from fairseq.data.data_utils import compute_mask_indices
from fairseq.data.dictionary import Dictionary
from fairseq.dataclass import ChoiceEnum
from fairseq.models import BaseFairseqModel, register_model
from fairseq.models.transformer import Embedding
from fairseq.file_io import PathManager
from torch import Tensor

from fairseq.models.wav2vec.wav2vec2 import ConvFeatureExtractionModel
from fairseq.modules import GradMultiply, LayerNorm
from fairseq.tasks.hubert_pretraining import (
    HubertPretrainingConfig,
    HubertPretrainingTask,
)
from fairseq.models.hubert import HubertConfig
from fairseq.models.transformer import TransformerConfig

from speechlm.modules.w2v_encoder import TransformerEncoder
from speechlm.modules.transformer_encoder import TransformerEncoderBase

logger = logging.getLogger(__name__)

EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"])
MASKING_DISTRIBUTION_CHOICES = ChoiceEnum(["static", "uniform", "normal", "poisson"])


@dataclass
class SpeechlmConfig(HubertConfig):
    use_rel_pos_enc: bool = field(
        default=False,
        metadata={"help": "whether to use relative positional encoding"},
    )
    scaling_for_att: float = field(
        default=1.0,
        metadata={"help": "scaling for attention weights to prevent overflow issue (for large model)"},
    )

    # unit encoder-decoder
    text_transformer: TransformerConfig = TransformerConfig()
    add_unit_encoder: bool = field(
        default=False,
        metadata={"help": "add unit encoder"},
    )
    add_decoder: bool = field(
        default=False,
        metadata={"help": "add decoder"},
    )
    add_text_ctc: bool = field(
        default=False,
        metadata={"help": "add_text_ctc head"},
    )
    text_ctc_conv_kernel: int = field(
        default=2,
        metadata={"help": "text_ctc_conv kernel size"},
    )
    mask_u2t: bool = field(
        default=True,
        metadata={"help": "mask the unit input in unit-to-text task"},
    )
    compute_mum: bool = field(
        default=False,
        metadata={"help": "compute MLM loss in unit-to-text task"},
    )

    # embedding mixing
    mix_with_unit: bool = field(
        default=True,
        metadata={"help": "mix with the unit embeddings"},
    )
    use_pred_unit: bool = field(
        default=False,
        metadata={"help": "use the embeddings of predicted units"},
    )
    l2_embedding: bool = field(
        default=False,
        metadata={"help": "compute l2 loss between unit embedding and unit hidden state"},
    )

    # Finetune related
    encoder_dict_size: int = field(
        default=-1,
        metadata={"help": "text encoder dictionary dimension"},
    )
    decoder_dict_size: int = field(
        default=-1,
        metadata={"help": "decoder dictionary dimension"},
    )


@register_model("speechlm", dataclass=SpeechlmConfig)
class SpeechlmModel(BaseFairseqModel):
    def __init__(
        self,
        cfg: SpeechlmConfig,
        task_cfg: HubertPretrainingConfig,
        dictionaries: List[Dictionary],
        unit_dictionary: Dictionary = None,
        text_tgt_dictionary: Dictionary = None,
    ) -> None:
        super().__init__()
        logger.info(f"SpeechlmModel Config: {cfg}")

        feature_enc_layers = eval(cfg.conv_feature_layers)  # noqa
        self.embed = feature_enc_layers[-1][0]

        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=cfg.extractor_mode,
            conv_bias=cfg.conv_bias,
        )
        feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
        self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / task_cfg.sample_rate

        self.post_extract_proj = (
            nn.Linear(self.embed, cfg.encoder_embed_dim)
            if self.embed != cfg.encoder_embed_dim
            else None
        )

        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
        self.mask_other = cfg.mask_other
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space

        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space

        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)

        self.feature_grad_mult = cfg.feature_grad_mult
        self.logit_temp = cfg.logit_temp
        self.skip_masked = cfg.skip_masked
        self.skip_nomask = cfg.skip_nomask

        final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim

        self.mask_emb = nn.Parameter(
            torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
        )

        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

        self.target_glu = None
        if cfg.target_glu:
            self.target_glu = nn.Sequential(
                nn.Linear(final_dim, final_dim * 2), nn.GLU()
            )

        self.final_dim = final_dim
        assert len(dictionaries) <= 2, f"Only support <=2 kinds of targets, get {len(dictionaries)} dictionaries"
        if len(dictionaries) == 1:
            dictionaries = [dictionaries[0], dictionaries[0]]
        self.final_proj_list = nn.ModuleList([
            nn.Linear(cfg.encoder_embed_dim, final_dim) for _ in dictionaries
        ])
        self.num_classes = [len(d) for d in dictionaries]
        self.label_embs_list = nn.ParameterList([
            nn.Parameter(torch.FloatTensor(n, final_dim)) for n in self.num_classes
        ])
        for i in range(len(self.num_classes)):
            nn.init.uniform_(self.label_embs_list[i])

        ### build unit encoder:
        self.mask_u2t = cfg.mask_u2t
        self.compute_mum = cfg.compute_mum
        self.add_text_ctc = cfg.add_text_ctc
        self.text_ctc_conv_kernel = cfg.text_ctc_conv_kernel
        self.padding_idx = unit_dictionary.pad()
        self.unit_mask_idx = unit_dictionary.index("<mask>")

        self.add_unit_encoder = cfg.add_unit_encoder
        self.mix_with_unit = cfg.mix_with_unit
        self.use_pred_unit = cfg.use_pred_unit
        self.l2_embedding = cfg.l2_embedding
        if self.add_unit_encoder:
            assert len(unit_dictionary) == self.num_classes[0], f"unit_dictionary: {len(unit_dictionary)}, self.num_classes[0]: {self.num_classes[0]}"
            ### build unit pre-net, and shared with hubert label_embs if needed (default: False)
            self.unit_embed_tokens = self.build_embedding(
                unit_dictionary,
                cfg.text_transformer.encoder.embed_dim,
            )
            if self.final_dim == cfg.text_transformer.encoder.embed_dim:
                logger.info("Share label_embs[0] with unit_embed_tokens ...")
                nn.init.uniform_(self.unit_embed_tokens.weight)
                self.label_embs_list[0] = self.unit_embed_tokens.weight

            ### build unit encoder
            self.unit_encoder = TransformerEncoderBase(
                cfg.text_transformer,
                unit_dictionary,
                self.unit_embed_tokens,
                use_rel_pos_enc=cfg.use_rel_pos_enc,
                scaling_for_att=cfg.scaling_for_att,
            )

            ### build text ctc head
            if self.add_text_ctc:
                conv = nn.Conv1d(
                    cfg.text_transformer.encoder.embed_dim,
                    cfg.text_transformer.encoder.embed_dim,
                    self.text_ctc_conv_kernel,
                    stride=self.text_ctc_conv_kernel // 2,
                    bias=False,
                    padding=self.text_ctc_conv_kernel // 2,
                )
                nn.init.kaiming_normal_(conv.weight)
                self.unit_encoder_ctc_head = nn.Sequential(
                    Rotate3D(),
                    conv,
                    nn.Dropout(p=0.1),
                    nn.Sequential(
                        Rotate3D(),
                        Rotate3D(),
                        LayerNorm(cfg.text_transformer.encoder.embed_dim),
                    ),
                    nn.GELU(),
                    nn.Linear(cfg.text_transformer.encoder.embed_dim, len(text_tgt_dictionary)),
                )

        ### build unit2text decoder, not available for now
        self.add_decoder = cfg.add_decoder

    def build_embedding(self, dictionary, embed_dim):
        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()
        return Embedding(num_embeddings, embed_dim, padding_idx)

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, cfg: SpeechlmConfig, task: HubertPretrainingTask):
        """Build a new model instance."""
        unit_dictionary = getattr(task, "text_src_dictionary", None)
        text_tgt_dictionary = getattr(task, "text_dictionary", None)
        model = SpeechlmModel(cfg, task.cfg, task.dictionaries, unit_dictionary, text_tgt_dictionary)
        return model

    def apply_mask(self, x, padding_mask, target_list):
        B, T, C = x.shape
        if self.mask_prob > 0:
            mask_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_prob,
                self.mask_length,
                self.mask_selection,
                self.mask_other,
                min_masks=2,
                no_overlap=self.no_mask_overlap,
                min_space=self.mask_min_space,
            )
            mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x[mask_indices] = self.mask_emb
        else:
            mask_indices = None

        if self.mask_channel_prob > 0:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0

        return x, mask_indices

    def forward_features(self, source: torch.Tensor) -> torch.Tensor:
        if self.feature_grad_mult > 0:
            features = self.feature_extractor(source)
            if self.feature_grad_mult != 1.0:
                features = GradMultiply.apply(features, self.feature_grad_mult)
        else:
            with torch.no_grad():
                features = self.feature_extractor(source)
        return features

    def forward_targets(
        self,
        features: torch.Tensor,
        target_list: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Trim features to ensure labels exist and then get aligned labels
        feat_tsz = features.size(2)
        targ_tsz = min([t.size(1) for t in target_list])
        if self.feat2tar_ratio * feat_tsz > targ_tsz:
            feat_tsz = int(targ_tsz / self.feat2tar_ratio)
            features = features[..., :feat_tsz]
        target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
        target_inds += np.random.choice(int(self.feat2tar_ratio))
        target_list = [t[:, target_inds.long()] for t in target_list]
        return features, target_list

    def forward_padding_mask(
        self,
        features: torch.Tensor,
        padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
        padding_mask = padding_mask.all(-1)
        return padding_mask

    def get_normalized_probs(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, sample)
        lprobs.batch_first = True
        return lprobs

    def downsample_ctc_padding_mask(self, padding_mask):
        """
        padding_mask: (B, T)
        """
        stride = self.text_ctc_conv_kernel // 2
        return padding_mask[:, ::stride]

    def compute_pred(self, proj_x, label_embs):
        if self.target_glu:
            label_embs = self.target_glu(label_embs)
        x = F.normalize(proj_x.float(), dim=-1)                  # (S, D)
        label_embs = F.normalize(label_embs.float(), dim=-1)     # (C, D)
        logits = torch.matmul(x, label_embs.T).type_as(proj_x)   # (S, C)
        logits /= self.logit_temp
        return logits

    def compute_hubert_logits(self, x, target, proj, label_embs, padding_mask, mask_indices):
        if not self.skip_masked:
            masked_indices = torch.logical_and(~padding_mask, mask_indices)
            proj_x_m = proj(x[masked_indices])
            logit_m_list = [(self.compute_pred(proj_x_m, label_embs), target[masked_indices])]
        else:
            logit_m_list = [None]

        if not self.skip_nomask:
            nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
            proj_x_u = proj(x[nomask_indices])
            logit_u_list = [(self.compute_pred(proj_x_u, label_embs), target[nomask_indices])]
        else:
            logit_u_list = [None]

        return logit_m_list, logit_u_list

    def convert_embeddings(
        self,
        x,
        padding_mask,
        target=None,
        mask_indices=None,
        mix_with_unit=False,
        use_pred_unit=False,
        l2_embedding=False,
        remask=False,
    ):
        """
        1. Mix with units if needed (default: True)
        2. Prepare for unit_encoder inputs
        Inputs:
            x, (B, T, D)
        Return:
            src_tokens, (B, T)
            soft_embeddings, (B, T, D)
            l2_loss, a loss
        """
        soft_embeddings = self.final_proj_list[0](x) if x.size(-1) == self.final_dim else x
        if padding_mask is None:
            padding_mask = soft_embeddings.new_zeros(soft_embeddings.size(0), soft_embeddings.size(1), dtype=torch.long)
        if use_pred_unit:
            src_tokens = self.compute_pred(self.final_proj_list[0](x), self.label_embs_list[0]).argmax(dim=-1)
            src_tokens[padding_mask] = self.padding_idx
        elif target is not None:
            src_tokens = target
        else:
            src_tokens = padding_mask.long()

        if l2_embedding | mix_with_unit:
            unit_embeddings = self.unit_embed_tokens(src_tokens)  # (B, T, D)

        l2_loss = 0
        if l2_embedding:
            if mask_indices is not None:
                l2_loss = (soft_embeddings - unit_embeddings)[mask_indices].float().pow(2).mean(dim=-1)
                scale = unit_embeddings[mask_indices].float().pow(2).sum(dim=-1)
            else:
                l2_loss = (soft_embeddings - unit_embeddings).float().pow(2).mean(dim=-1)
                scale = unit_embeddings.float().pow(2).sum(dim=-1)
            l2_loss = (l2_loss / scale).mean()

        if mix_with_unit:
            B, T, D = x.shape
            selected_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_prob / 2,
                self.mask_length // 2,
                self.mask_selection,
                self.mask_other,
                min_masks=2,
                no_overlap=self.no_mask_overlap,
                min_space=self.mask_min_space,
            )
            selected_indices = torch.from_numpy(selected_indices).to(x.device)
            if mask_indices is not None:
                if remask:
                    remask_indices = torch.logical_and(selected_indices, mask_indices)
                    soft_embeddings[remask_indices] = self.mask_emb
                swap_indices = torch.logical_and(selected_indices, ~mask_indices)
            else:
                swap_indices = selected_indices
            soft_embeddings[swap_indices] = unit_embeddings[swap_indices]

        soft_embeddings = soft_embeddings * (1 - padding_mask.unsqueeze(-1).type_as(x))
        return src_tokens, soft_embeddings, l2_loss

    def forward(
        self,
        source: torch.Tensor = None,
        src_tokens: torch.Tensor = None,
        src_lengths: torch.Tensor = None,
        target_list: Optional[List[torch.Tensor]] = None,
        padding_mask: Optional[torch.Tensor] = None,
        mask: bool = True,
        features_only: bool = False,
        output_layer: Optional[int] = None,
    ) -> Dict[str, torch.Tensor]:
        assert source is not None or src_tokens is not None
        if source is not None:
            return self.forward_speech(
                source=source,
                target_list=target_list,
                padding_mask=padding_mask,
                mask=mask,
                features_only=features_only,
                output_layer=output_layer,
            )
        else:
            return self.forward_text(
                src_tokens=src_tokens,
                src_lengths=src_lengths,
                mask=self.mask_u2t,
                output_layer=output_layer,
            )

    def forward_speech(
        self,
        source: torch.Tensor = None,
        target_list: Optional[List[torch.Tensor]] = None,
        padding_mask: Optional[torch.Tensor] = None,
        mask: bool = True,
        features_only: bool = False,
        output_layer: Optional[int] = None,
    ) -> Dict[str, torch.Tensor]:
        """output layer is 1-based"""
        features = self.forward_features(source)
        if target_list is not None:
            features, target_list = self.forward_targets(features, target_list)

        features_pen = features.float().pow(2).mean()

        features = features.transpose(1, 2)
        features = self.layer_norm(features)
        unmasked_features = features.clone()

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(features, padding_mask)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        features = self.dropout_input(features)
        unmasked_features = self.dropout_features(unmasked_features)

        if mask:
            x, mask_indices = self.apply_mask(features, padding_mask, target_list)
        else:
            x = features
            mask_indices = None

        # feature: (B, T, D), float
        # target: (B, T), long
        # x: (B, T, D), float
        # padding_mask: (B, T), bool
        # mask_indices: (B, T), bool
        x, _ = self.encoder(
            x,
            padding_mask=padding_mask,
            layer=None if output_layer is None else output_layer - 1,
        )

        if features_only:
            return {"x": x, "padding_mask": padding_mask, "features": features}

        logit_m_list, logit_u_list = self.compute_hubert_logits(
            x,
            target_list[0],
            self.final_proj_list[0],
            self.label_embs_list[0],
            padding_mask,
            mask_indices,
        )

        result = {
            "logit_m_list": logit_m_list,
            "logit_u_list": logit_u_list,
            "padding_mask": padding_mask,
            "features_pen": features_pen,
        }

        if self.add_unit_encoder:
            src_tokens, x_emb, l2_loss = self.convert_embeddings(
                x,
                padding_mask,
                target_list[0],
                mask_indices=mask_indices,
                mix_with_unit=self.mix_with_unit,
                use_pred_unit=self.use_pred_unit,
                l2_embedding=self.l2_embedding,
            )
            encoder_out = self.unit_encoder(src_tokens, token_embeddings=x_emb)

            result['encoder_out'] = encoder_out['encoder_out']  # [(T, B, D)]
            result['encoder_padding_mask'] = encoder_out['encoder_padding_mask']  # [(B, T)]
            if self.l2_embedding:
                result['embedding_l2_loss'] = l2_loss

            code_logit_m_list, code_logit_u_list = self.compute_hubert_logits(
                encoder_out['encoder_out'][0].transpose(0, 1),
                target_list[-1],
                self.final_proj_list[-1],
                self.label_embs_list[-1],
                padding_mask,
                mask_indices,
            )
            result['logit_m_list'] += code_logit_m_list
            result['logit_u_list'] += code_logit_u_list
        return result

    def forward_text(
        self,
        src_tokens: torch.Tensor = None,
        src_lengths: torch.Tensor = None,
        target_list: Optional[List[torch.Tensor]] = None,
        mask: bool = True,
        output_layer: Optional[int] = None,
    ) -> Dict[str, torch.Tensor]:
        assert self.add_unit_encoder, f"Can not forward unit-text branch without unit_encoder!"

        padding_mask = src_tokens == self.padding_idx
        unit_embeddings = self.unit_embed_tokens(src_tokens)
        if mask:
            unit_embeddings, mask_indices = self.apply_mask(unit_embeddings, padding_mask, [src_tokens])
        else:
            ### If already applied mask on src_tokens, then the target_list should contains many padding_idx
            mask_indices = target_list[-1] != self.padding_idx
            unit_embeddings[mask_indices] = self.mask_emb

        encoder_out = self.unit_encoder(
            src_tokens,
            token_embeddings=unit_embeddings,
            return_all_hiddens=output_layer is not None,
        )

        result = {}
        result["encoder_out"] = encoder_out["encoder_out"]
        result["encoder_states"] = encoder_out["encoder_states"]
        result["padding_mask"] = padding_mask

        if self.compute_mum:
            code_logit_m_list, code_logit_u_list = self.compute_hubert_logits(
                encoder_out["encoder_out"].transpose(0, 1),
                target_list[-1],
                self.final_proj_list[-1],
                self.label_embs_list[-1],
                padding_mask,
                mask_indices,
            )
            result["logit_m_list"] = code_logit_m_list
            result["logit_u_list"] = code_logit_u_list

        if self.add_text_ctc:
            result["encoder_out_ctc"] = [self.unit_encoder_ctc_head(x) for x in encoder_out['encoder_out']]
            result["encoder_padding_mask"] = [
                self.downsample_ctc_padding_mask(padding_mask) for padding_mask in encoder_out['encoder_padding_mask']
            ]
        return result

    def extract_features(
        self,
        source: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        mask: bool = False,
        ret_conv: bool = False,
        output_layer: Optional[int] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract features for only speech input"""
        res = self.forward(
            source,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
            output_layer=output_layer,
        )

        x = res["x"]  # B x T x D
        padding_mask = res["padding_mask"]

        if self.add_unit_encoder:
            src_tokens, x, _ = self.convert_embeddings(
                x,
                padding_mask,
                mix_with_unit=False,
                use_pred_unit=False,
            )
            encoder_out = self.unit_encoder(
                src_tokens,
                token_embeddings=x,
                return_all_hiddens=output_layer is not None,
            )
            res["x"] = encoder_out['encoder_out'][0].transpose(0, 1)  # (B, T, D)

        feature = res["features"] if ret_conv else res["x"]
        if output_layer is not None:
            feature = encoder_out['encoder_states']

        return feature, padding_mask

    def get_logits(self, net_output, is_masked=True):
        if is_masked:
            logits_list = net_output["logit_m_list"]
        else:
            logits_list = net_output["logit_u_list"]
        logits_list = [x[0].float() for x in logits_list if x is not None]
        return logits_list

    def get_targets(self, net_output, is_masked=True):
        if is_masked:
            logits_list = net_output["logit_m_list"]
        else:
            logits_list = net_output["logit_u_list"]
        targets_list = [x[1].long() for x in logits_list if x is not None]
        return targets_list

    def get_extra_losses(self, net_output):
        extra_losses = []
        names = []

        if "features_pen" in net_output:
            extra_losses.append(net_output["features_pen"])
            names.append("features_pen")

        if "embedding_l2_loss" in net_output:
            extra_losses.append(net_output["embedding_l2_loss"])
            names.append("embedding_l2_loss")

        return extra_losses, names

    def remove_pretraining_modules(self, step2=False):
        self.target_glu = None

    def load_checkpoint(self, checkpoint: str):
        if not PathManager.exists(checkpoint):
            raise IOError("Model file not found: {}".format(checkpoint))
        state = checkpoint_utils.load_checkpoint_to_cpu(checkpoint)
        return state


class Rotate3D(nn.Module):
    """
    (T, B, D) --> (B, D, T) --> (D, T, B) --> (T, B, D)
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.permute(1, 2, 0)
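
The unit_encoder_ctc_head above relies on Rotate3D to cycle tensor dimensions so that Conv1d (which expects (B, D, T)) and LayerNorm (which normalizes the last dimension) can be chained inside one nn.Sequential while the encoder output stays time-major. A small sketch, independent of the repo, showing the permutation cycle:

import torch

def rotate3d(x):
    # same permutation as Rotate3D.forward: (T, B, D) -> (B, D, T) -> (D, T, B) -> (T, B, D)
    return x.permute(1, 2, 0)

x = torch.randn(7, 2, 4)   # (T, B, D)
once = rotate3d(x)         # (B, D, T), the layout Conv1d expects
twice = rotate3d(once)     # (D, T, B)
thrice = rotate3d(twice)   # back to (T, B, D)
print(once.shape, twice.shape, thrice.shape)
assert torch.equal(thrice, x)  # three rotations are the identity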

SpeechLM/speechlm/models/speechlm_ctcasr.py (new file, 0 → 100644)

# ----------------------------------------------------------------------------
# SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data (https://arxiv.org/abs/2209.15329)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechLM
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
from dataclasses import dataclass

from fairseq.models import BaseFairseqModel, register_model
from fairseq.tasks import FairseqTask
from fairseq.models.hubert import HubertAsrConfig, HubertCtc, HubertEncoder


@dataclass
class SpeechLMCtcConfig(HubertAsrConfig):
    pass


@register_model("speechlm_ctc", dataclass=SpeechLMCtcConfig)
class SpeechLMCtc(HubertCtc):
    def __init__(self, cfg: SpeechLMCtcConfig, w2v_encoder: BaseFairseqModel):
        super().__init__(cfg, w2v_encoder)

    @classmethod
    def build_model(cls, cfg: SpeechLMCtcConfig, task: FairseqTask):
        """Build a new model instance."""
        w2v_encoder = SpeechLMEncoder(cfg, task)
        return cls(cfg, w2v_encoder)


class SpeechLMEncoder(HubertEncoder):
    def __init__(self, cfg: HubertAsrConfig, task):
        super().__init__(cfg, task)
        if (task.target_dictionary is not None) and (
            hasattr(self.w2v_model, "unit_encoder_ctc_head")
        ):
            self.proj = self.w2v_model.unit_encoder_ctc_head
            self.conv_ctc_proj = True
        else:
            self.conv_ctc_proj = False

    def forward(self, source, padding_mask, tbc=True, **kwargs):
        results = super().forward(
            source,
            padding_mask,
            tbc,
            **kwargs,
        )
        if self.conv_ctc_proj:
            padding_mask = self.w2v_model.downsample_ctc_padding_mask(results["padding_mask"])
            results["encoder_padding_mask"] = padding_mask
            results["padding_mask"] = padding_mask
        return results
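
When conv_ctc_proj is set, the encoder above reuses the pretrained model's downsample_ctc_padding_mask so the padding mask matches the CTC head, whose Conv1d strides over time by text_ctc_conv_kernel // 2. A toy, repo-independent sketch of that slicing (the kernel size 4 here is chosen only to make the stride visible; the config default is 2, i.e. stride 1 and no downsampling):

import torch

def downsample_padding_mask(padding_mask, conv_kernel):
    # same slicing as SpeechlmModel.downsample_ctc_padding_mask: keep every `stride`-th frame
    stride = conv_kernel // 2
    return padding_mask[:, ::stride]

# one utterance of length 6 padded to 8 frames
mask = torch.tensor([[False] * 6 + [True] * 2])
print(downsample_padding_mask(mask, conv_kernel=4))
# tensor([[False, False, False,  True]])  -> 8 frames reduced to 4 with stride 2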

SpeechLM/speechlm/models/speechlm_st.py (new file, 0 → 100644)

# ----------------------------------------------------------------------------
# SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data (https://arxiv.org/abs/2209.15329)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechLM
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
import contextlib
import torch
import torch.nn as nn

from argparse import Namespace
from dataclasses import dataclass, field
from typing import Any

from fairseq import checkpoint_utils, tasks, utils
from fairseq.models import FairseqEncoderDecoderModel, register_model
from fairseq.models.fairseq_decoder import FairseqDecoder
from fairseq.models.fairseq_encoder import FairseqEncoder
from fairseq.tasks import FairseqTask
from fairseq.dataclass import ChoiceEnum
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models.hubert import HubertAsrConfig

from speechlm.modules.transformer_decoder import TransformerDecoderScriptable


@dataclass
class SpeechLMS2TConfig(HubertAsrConfig):
    activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
        default="gelu", metadata={"help": "activation function to use"}
    )
    use_rel_pos_enc: bool = field(
        default=True,
        metadata={"help": "whether to use relative positional encoding for decoder"},
    )
    encoder_embed_dim: int = field(
        default=768, metadata={"help": "encoder embedding dimension, used for enc-dec att"}
    )

    decoder_embed_dim: int = field(
        default=768, metadata={"help": "decoder embedding dimension"}
    )
    decoder_output_dim: int = field(
        default=768, metadata={"help": "decoder output dimension"}
    )
    decoder_ffn_embed_dim: int = field(
        default=3072, metadata={"help": "decoder embedding dimension for FFN"}
    )
    decoder_layers: int = field(default=6, metadata={"help": "num of decoder layers"})
    decoder_layerdrop: float = field(
        default=0.0, metadata={"help": "decoder layerdrop chance"}
    )
    decoder_attention_heads: int = field(
        default=12, metadata={"help": "num decoder attention heads"}
    )
    decoder_learned_pos: bool = field(
        default=False,
        metadata={"help": "use learned positional embeddings in the decoder"},
    )
    decoder_normalize_before: bool = field(
        default=False, metadata={"help": "apply layernorm before each decoder block"}
    )
    no_token_positional_embeddings: bool = field(
        default=False,
        metadata={"help": "if set, disables positional embeddings (outside self attention)"},
    )
    decoder_dropout: float = field(
        default=0.0, metadata={"help": "dropout probability in the decoder"}
    )
    decoder_attention_dropout: float = field(
        default=0.0,
        metadata={"help": "dropout probability for attention weights inside the decoder"},
    )
    decoder_activation_dropout: float = field(
        default=0.0,
        metadata={"help": "dropout probability after activation in FFN inside the decoder"},
    )
    share_decoder_input_output_embed: bool = field(
        default=False, metadata={"help": "share decoder input and output embeddings"}
    )

    ### the following config is only for the compatibility to fairseq speech_to_text task
    input_feat_per_channel: Any = None
    input_channels: Any = None
    speaker_to_id: Any = None


@register_model("speechlm_st_legacy", dataclass=SpeechLMS2TConfig)
class SpeechLMS2T(FairseqEncoderDecoderModel):
    def __init__(self, cfg: SpeechLMS2TConfig, encoder: FairseqEncoder, decoder: FairseqDecoder):
        super().__init__(encoder, decoder)
        self.cfg = cfg

    def upgrade_state_dict_named(self, state_dict, name):
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, cfg: SpeechLMS2TConfig, task: FairseqTask):
        """Build a new model instance."""

        def build_embedding(dictionary, embed_dim):
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            return Embedding(num_embeddings, embed_dim, padding_idx)

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        encoder = SpeechLMEncoder(cfg, task)
        assert cfg.encoder_embed_dim == encoder.w2v_model.encoder.embedding_dim
        decoder_embed_tokens = build_embedding(tgt_dict, cfg.decoder_embed_dim)
        decoder = TransformerDecoderScriptable(cfg, tgt_dict, decoder_embed_tokens)
        return cls(cfg, encoder, decoder)


class SpeechLMEncoder(FairseqEncoder):
    """
    Modified from fairseq.models.hubert.hubert_asr.HubertEncoder
    1. make it compatible with fairseq speech_to_text task
    2. make it compatible with encoder-decoder model
    """

    def __init__(self, cfg: HubertAsrConfig, task):
        self.apply_mask = cfg.apply_mask

        arg_overrides = {
            "dropout": cfg.dropout,
            "activation_dropout": cfg.activation_dropout,
            "dropout_input": cfg.dropout_input,
            "attention_dropout": cfg.attention_dropout,
            "mask_length": cfg.mask_length,
            "mask_prob": cfg.mask_prob,
            "mask_selection": cfg.mask_selection,
            "mask_other": cfg.mask_other,
            "no_mask_overlap": cfg.no_mask_overlap,
            "mask_channel_length": cfg.mask_channel_length,
            "mask_channel_prob": cfg.mask_channel_prob,
            "mask_channel_selection": cfg.mask_channel_selection,
            "mask_channel_other": cfg.mask_channel_other,
            "no_mask_channel_overlap": cfg.no_mask_channel_overlap,
            "encoder_layerdrop": cfg.layerdrop,
            "feature_grad_mult": cfg.feature_grad_mult,
        }

        if cfg.w2v_args is None:
            state = checkpoint_utils.load_checkpoint_to_cpu(cfg.w2v_path, arg_overrides)
            w2v_args = state.get("cfg", None)
            if w2v_args is None:
                w2v_args = convert_namespace_to_omegaconf(state["args"])
            cfg.w2v_args = w2v_args
        else:
            state = None
            w2v_args = cfg.w2v_args
            if isinstance(w2v_args, Namespace):
                cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(w2v_args)

        assert task.data_cfg.standardize_audio() == w2v_args.task.normalize, (
            "Fine-tuning works best when data normalization is the same. "
            "Please check that --normalize is set or unset for "
            "both pre-training and here"
        )

        w2v_args.task.data = cfg.data
        pretrain_task = tasks.setup_task(w2v_args.task)
        if state is not None and "task_state" in state:
            # This will load the stored "dictionaries" object
            pretrain_task.load_state_dict(state["task_state"])
        else:
            pretrain_task.load_state_dict(task.state_dict())

        model = pretrain_task.build_model(w2v_args.model, from_checkpoint=True)
        if state is not None and not cfg.no_pretrained_weights:
            # set strict=False because we omit some modules
            model.load_state_dict(state["model"], strict=False)

        model.remove_pretraining_modules()

        super().__init__(pretrain_task.source_dictionary)

        d = w2v_args.model.encoder_embed_dim

        self.w2v_model = model

        self.final_dropout = nn.Dropout(cfg.final_dropout)
        self.freeze_finetune_updates = cfg.freeze_finetune_updates
        self.num_updates = 0

    def set_num_updates(self, num_updates):
        """Set the number of parameters updates."""
        super().set_num_updates(num_updates)
        self.num_updates = num_updates

    def forward(self, src_tokens=None, src_lengths=None, **kwargs):
        w2v_args = {
            "source": src_tokens,
            "padding_mask": lengths_to_padding_mask(src_lengths),
            "mask": self.apply_mask and self.training,
        }

        ft = self.freeze_finetune_updates <= self.num_updates

        with torch.no_grad() if not ft else contextlib.ExitStack():
            x, padding_mask = self.w2v_model.extract_features(**w2v_args)

            # B x T x C -> T x B x C
            x = x.transpose(0, 1)

        x = self.final_dropout(x)

        return {
            "encoder_out": [x],  # T x B x C
            "encoder_padding_mask": [padding_mask],  # B x T
            "padding_mask": [padding_mask],
        }

    def forward_torchscript(self, net_input):
        """A TorchScript-compatible version of forward.

        Encoders which use additional arguments may want to override
        this method for TorchScript compatibility.
        """
        _net_input = {
            "source": net_input["src_tokens"],
            "padding_mask": lengths_to_padding_mask(net_input["src_lengths"]),
            "mask": False,
        }
        x, padding_mask = self.w2v_model.extract_features(**_net_input)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        encoder_out = {
            "encoder_out": [x],
            "encoder_padding_mask": [padding_mask],
        }
        return encoder_out

    def reorder_encoder_out(self, encoder_out, new_order):
        if encoder_out["encoder_out"] is not None:
            encoder_out["encoder_out"] = [
                x.index_select(1, new_order) for x in encoder_out["encoder_out"]
            ]
        if encoder_out["encoder_padding_mask"] is not None:
            encoder_out["encoder_padding_mask"] = [
                x.index_select(0, new_order) for x in encoder_out["encoder_padding_mask"]
            ]
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return None

    def upgrade_state_dict_named(self, state_dict, name):
        return state_dict


def Embedding(num_embeddings, embedding_dim, padding_idx):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5)
    nn.init.constant_(m.weight[padding_idx], 0)
    return m


def Linear(in_features, out_features, bias=True):
    m = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(m.weight)
    if bias:
        nn.init.constant_(m.bias, 0.0)
    return m
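
reorder_encoder_out above is what fairseq's beam search calls when it duplicates or reorders hypotheses: encoder states are stored time-major (T x B x C), so the batch dimension is reordered with index_select on dim 1, while the (B x T) padding mask is reordered on dim 0. A standalone sketch of that reordering, with made-up shapes for illustration:

import torch

def reorder(encoder_out, new_order):
    # encoder_out["encoder_out"]: list of (T, B, C); "encoder_padding_mask": list of (B, T)
    return {
        "encoder_out": [x.index_select(1, new_order) for x in encoder_out["encoder_out"]],
        "encoder_padding_mask": [x.index_select(0, new_order) for x in encoder_out["encoder_padding_mask"]],
    }

T, B, C = 5, 3, 2
out = {
    "encoder_out": [torch.arange(T * B * C, dtype=torch.float).view(T, B, C)],
    "encoder_padding_mask": [torch.zeros(B, T, dtype=torch.bool)],
}
new_order = torch.tensor([2, 2, 0])  # e.g. two beams from sentence 2, one from sentence 0
reordered = reorder(out, new_order)
print(reordered["encoder_out"][0].shape)           # torch.Size([5, 3, 2])
print(reordered["encoder_padding_mask"][0].shape)  # torch.Size([3, 5])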

SpeechLM/speechlm/modules/__init__.py (new file, 0 → 100644)

# --------------------------------------------------------
# The YiTrans End-to-End Speech Translation System for IWSLT 2022 Offline Shared Task (https://arxiv.org/abs/2206.05777)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/YiTrans
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------
from .multihead_attention import MultiheadAttention
from .relative_pos_enc import RelativePositionalEncoding
from .transformer_layer import TransformerEncoderLayerBase, TransformerDecoderLayerBase
from .w2v_encoder import TransformerEncoder, TransformerSentenceEncoderLayer
from .learned_positional_embedding import LearnedPositionalEmbedding

__all__ = [
    "MultiheadAttention",
    "RelativePositionalEncoding",
    "TransformerEncoderLayerBase",
    "TransformerDecoderLayerBase",
    "TransformerEncoder",
    "TransformerSentenceEncoderLayer",
]

SpeechLM/speechlm/modules/learned_positional_embedding.py (new file, 0 → 100644)

# --------------------------------------------------------
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------
"""
Modified from https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/learned_positional_embedding.py
1. Add clamping if the input length exceeds the max-source-tokens
"""
from typing import Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from torch import Tensor


class LearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    Padding ids are ignored by either offsetting based on padding_idx
    or by setting padding_idx to None and ensuring that the appropriate
    position ids are passed to the forward function.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.onnx_trace = False
        if self.padding_idx is not None:
            self.max_positions = self.num_embeddings - self.padding_idx - 1
        else:
            self.max_positions = self.num_embeddings

    def forward(
        self,
        input: Tensor,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        positions: Optional[Tensor] = None,
    ):
        """Input is expected to be of size [bsz x seqlen]."""
        assert (positions is None) or (
            self.padding_idx is None
        ), "If positions is pre-computed then padding_idx should not be set."
        if positions is None:
            if incremental_state is not None:
                # positions is the same for every token when decoding a single step
                # Without the int() cast, it doesn't work in some cases when exporting to ONNX
                positions = torch.zeros(
                    (1, 1), device=input.device, dtype=input.dtype
                ).fill_(int(self.padding_idx + input.size(1)))
            else:
                positions = utils.make_positions(
                    input, self.padding_idx, onnx_trace=self.onnx_trace
                )
            positions = torch.clamp(positions, max=self.padding_idx + self.max_positions)
        return F.embedding(
            positions,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
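
The clamping added in forward above is what keeps inputs longer than the learned table from indexing out of range: any position past padding_idx + max_positions is mapped to the last learned embedding. A toy sketch, using a simplified stand-in for fairseq's utils.make_positions (the numbers are illustrative only):

import torch

def simple_make_positions(tokens, padding_idx):
    # simplified stand-in for fairseq.utils.make_positions:
    # non-pad tokens get positions padding_idx+1, padding_idx+2, ...
    mask = tokens.ne(padding_idx)
    return torch.cumsum(mask, dim=1) * mask + padding_idx

padding_idx, max_positions = 1, 4            # the embedding table would cover positions 2..5
tokens = torch.tensor([[7, 7, 7, 7, 7, 7]])  # 6 tokens, longer than max_positions
positions = simple_make_positions(tokens, padding_idx)
clamped = torch.clamp(positions, max=padding_idx + max_positions)
print(positions)  # tensor([[2, 3, 4, 5, 6, 7]])
print(clamped)    # tensor([[2, 3, 4, 5, 5, 5]])  -- overflow reuses the last position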

SpeechLM/speechlm/modules/multihead_attention.py (new file, 0 → 100644)

# --------------------------------------------------------
# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------
from
typing
import
Dict
,
Optional
,
Tuple
import
torch
import
torch.nn.functional
as
F
from
fairseq
import
utils
from
torch
import
Tensor
from
fairseq.modules
import
MultiheadAttention
as
FairseqMultiheadAttention
class
MultiheadAttention
(
FairseqMultiheadAttention
):
"""Multi-headed attention.
See "Attention Is All You Need" for more details.
"""
def
__init__
(
self
,
embed_dim
,
num_heads
,
kdim
=
None
,
vdim
=
None
,
dropout
=
0.0
,
bias
=
True
,
add_bias_kv
=
False
,
add_zero_attn
=
False
,
self_attention
=
False
,
encoder_decoder_attention
=
False
,
q_noise
=
0.0
,
qn_block_size
=
8
,
scaling_for_att
=
1.0
):
super
().
__init__
(
embed_dim
,
num_heads
,
kdim
,
vdim
,
dropout
,
bias
,
add_bias_kv
,
add_zero_attn
,
self_attention
,
encoder_decoder_attention
,
q_noise
,
qn_block_size
,
)
self
.
scaling_for_att
=
scaling_for_att
def
forward
(
self
,
query
,
key
:
Optional
[
Tensor
],
value
:
Optional
[
Tensor
],
key_padding_mask
:
Optional
[
Tensor
]
=
None
,
incremental_state
:
Optional
[
Dict
[
str
,
Dict
[
str
,
Optional
[
Tensor
]]]]
=
None
,
need_weights
:
bool
=
True
,
static_kv
:
bool
=
False
,
attn_mask
:
Optional
[
Tensor
]
=
None
,
before_softmax
:
bool
=
False
,
need_head_weights
:
bool
=
False
,
position_bias
:
Optional
[
Tensor
]
=
None
,
)
->
Tuple
[
Tensor
,
Optional
[
Tensor
]]:
"""Input shape: Time x Batch x Channel
Args:
key_padding_mask (ByteTensor, optional): mask to exclude
keys that are pads, of shape `(batch, src_len)`, where
padding elements are indicated by 1s.
need_weights (bool, optional): return the attention weights,
averaged over heads (default: False).
attn_mask (ByteTensor, optional): typically used to
implement causal attention, where the mask prevents the
attention from looking forward in time (default: None).
before_softmax (bool, optional): return the raw attention
weights and values before the attention softmax.
need_head_weights (bool, optional): return the attention
weights for each head. Implies *need_weights*. Default:
return the average attention weights over all heads.
"""
        if need_head_weights:
            need_weights = True

        is_tpu = query.device.type == "xla"

        tgt_len, bsz, embed_dim = query.size()
        src_len = tgt_len
        assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        if key is not None:
            src_len, key_bsz, _ = key.size()
            if not torch.jit.is_scripting():
                assert key_bsz == bsz
                assert value is not None
                assert src_len, bsz == value.shape[:2]

        if (
            not self.onnx_trace
            and not is_tpu  # don't use PyTorch version on TPUs
            and incremental_state is None
            and not static_kv
            # A workaround for quantization to work. Otherwise JIT compilation
            # treats bias in linear module as method.
            and not torch.jit.is_scripting()
            and position_bias is None
        ):
            assert key is not None and value is not None
            return F.multi_head_attention_forward(
                query, key, value,
                self.embed_dim, self.num_heads,
                torch.empty([0]),
                torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout_module.p,
                self.out_proj.weight, self.out_proj.bias,
                self.training or self.dropout_module.apply_during_inference,
                key_padding_mask, need_weights, attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj.weight,
                k_proj_weight=self.k_proj.weight,
                v_proj_weight=self.v_proj.weight,
            )

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)
        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q *= self.scaling
        q *= (1 / self.scaling_for_att)

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)],
                    dim=1,
                )

        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None
                prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
                src_len = k.size(1)
            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                assert _prev_value is not None
                prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)
            prev_key_padding_mask: Optional[Tensor] = None
            if "prev_key_padding_mask" in saved_state:
                prev_key_padding_mask = saved_state["prev_key_padding_mask"]
            assert k is not None and v is not None
            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
                key_padding_mask=key_padding_mask,
                prev_key_padding_mask=prev_key_padding_mask,
                batch_size=bsz,
                src_len=k.size(1),
                static_kv=static_kv,
            )

            saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_key_padding_mask"] = key_padding_mask
            # In this branch incremental_state is never None
            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state, saved_state)
        assert k is not None
        assert k.size(1) == src_len

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            assert v is not None
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)],
                    dim=1,
                )

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

        if position_bias is not None:
            ## first order
            ## position_bias: [241, 241, 64]
            #print ("attn_weights: ", attn_weights.size()) # [492, 241, 241]
            reshape_q = q.contiguous().view(bsz * self.num_heads, -1, self.head_dim).transpose(0, 1)  #[241, 492, 64]
            #print ("reshape_q: ", reshape_q.size())
            B = torch.matmul(reshape_q, position_bias.transpose(-2, -1))
            #print ("B: ", B.size()) ## [241, 492, 241]
            #B = B.transpose(0, 1).view(bsz, self.num_heads, position_bias.size(0), position_bias.size(1))
            B = B.transpose(0, 1).view(bsz * self.num_heads, position_bias.size(0), position_bias.size(1))
            #print ("B 2: ", B.size())
            attn_weights += B

        attn_weights *= self.scaling_for_att
        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            if self.onnx_trace:
                attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            if not is_tpu:
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                    float("-inf"),
                )
            else:
                attn_weights = attn_weights.transpose(0, 2)
                attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
                attn_weights = attn_weights.transpose(0, 2)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if self.scaling_for_att > 1.0:
            attn_weights = attn_weights - attn_weights.detach().max(dim=-1, keepdim=True)[0]

        if before_softmax:
            return attn_weights, v

        attn_weights_float = utils.softmax(attn_weights, dim=-1, onnx_trace=self.onnx_trace)
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = self.dropout_module(attn_weights)

        assert v is not None
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        if self.onnx_trace and attn.size(1) == 1:
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
        else:
            attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)
        attn_weights: Optional[Tensor] = None
        if need_weights:
            attn_weights = attn_weights_float.view(
                bsz, self.num_heads, tgt_len, src_len
            ).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)

        return attn, attn_weights
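
A minimal usage sketch of the attention module above (not part of the commit): it assumes the speechlm package from this commit is importable, that the constructor arguments not shown keep their fairseq defaults, and that `scaling_for_att` defaults to 1.0; the dimensions and variable names are illustrative only.

import torch
from speechlm.modules.multihead_attention import MultiheadAttention
from speechlm.modules.relative_pos_enc import RelativePositionalEncoding

T, B, C, H = 10, 2, 768, 12                      # seq_len, batch, embed_dim, heads (illustrative)
attn = MultiheadAttention(C, H, dropout=0.1, self_attention=True)
pos_emb = RelativePositionalEncoding(C // H, maxlen=160)

x = torch.randn(T, B, C)                         # forward() expects (seq_len, batch, embed_dim)
pos_seq = torch.arange(T)
pos_seq = pos_seq[:, None] - pos_seq[None, :]    # relative offsets i - j, clipped inside pos_emb
pos_k, _ = pos_emb(pos_seq)                      # (T, T, head_dim) key bias, fed as position_bias

out, weights = attn(query=x, key=x, value=x, position_bias=pos_k, need_weights=True)
print(out.shape)                                 # torch.Size([10, 2, 768])

Passing `position_bias` forces the non-fused code path above, which is where the first-order relative bias is added to the attention logits.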
SpeechLM/speechlm/modules/relative_pos_enc.py
0 → 100644
# --------------------------------------------------------
# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------
import torch


class RelativePositionalEncoding(torch.nn.Module):
    def __init__(self, d_model, maxlen=1000, embed_v=False):
        super(RelativePositionalEncoding, self).__init__()

        self.d_model = d_model
        self.maxlen = maxlen
        self.pe_k = torch.nn.Embedding(2 * maxlen, d_model)
        if embed_v:
            self.pe_v = torch.nn.Embedding(2 * maxlen, d_model)
        self.embed_v = embed_v

    def forward(self, pos_seq, incremental_state=None):
        pos_seq[pos_seq < -self.maxlen] = -self.maxlen
        pos_seq[pos_seq >= self.maxlen] = self.maxlen - 1
        pos_seq = pos_seq + self.maxlen

        if incremental_state is not None:
            pos_seq = pos_seq[-1:]

        if self.embed_v:
            return self.pe_k(pos_seq), self.pe_v(pos_seq)
        else:
            return self.pe_k(pos_seq), None
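
A small standalone sketch (illustrative, not from the commit) of what the module above returns: offsets beyond ±maxlen are clipped in place and shifted into embedding indices, so a (T, T) offset matrix yields a (T, T, d_model) key table, and the value table is only returned when embed_v=True. It assumes relative_pos_enc.py is importable as shown.

import torch
from speechlm.modules.relative_pos_enc import RelativePositionalEncoding

pos_emb = RelativePositionalEncoding(d_model=64, maxlen=8)
pos_seq = torch.arange(12)
pos_seq = pos_seq[:, None] - pos_seq[None, :]   # offsets in [-11, 11], deliberately beyond maxlen=8
pe_k, pe_v = pos_emb(pos_seq)                   # offsets are clipped in place before the lookup
print(pe_k.shape, pe_v)                         # torch.Size([12, 12, 64]) None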
SpeechLM/speechlm/modules/transformer_decoder.py
0 → 100644
# --------------------------------------------------------
# The YiTrans End-to-End Speech Translation System for IWSLT 2022 Offline Shared Task (https://arxiv.org/abs/2206.05777)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/YiTrans
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------
"""
Modified from https://github.com/facebookresearch/fairseq/blob/main/fairseq/models/transformer/transformer_decoder.py
"""
import math
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
from fairseq import utils
from fairseq.distributed import fsdp_wrap
from fairseq.models import FairseqIncrementalDecoder
from fairseq.models.transformer import TransformerConfig
from fairseq.modules import (
    AdaptiveSoftmax,
    BaseLayer,
    FairseqDropout,
    LayerDropModuleList,
    LayerNorm,
    PositionalEmbedding,
    SinusoidalPositionalEmbedding,
)
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
from torch import Tensor

from speechlm.modules import transformer_layer
from speechlm.modules.relative_pos_enc import RelativePositionalEncoding


# rewrite name for backward compatibility in `make_generation_fast_`
def module_name_fordropout(module_name: str) -> str:
    if module_name == "TransformerDecoderBase":
        return "TransformerDecoder"
    else:
        return module_name


class TransformerDecoderBase(FairseqIncrementalDecoder):
    """
    Transformer decoder consisting of *cfg.decoder.layers* layers. Each layer
    is a :class:`TransformerDecoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(
        self,
        cfg,
        dictionary,
        embed_tokens,
        no_encoder_attn=False,
        output_projection=None,
        use_rel_pos_enc=False,
    ):
        self.cfg = cfg
        super().__init__(dictionary)
        self.register_buffer("version", torch.Tensor([3]))
        self._future_mask = torch.empty(0)

        self.dropout_module = FairseqDropout(
            cfg.dropout, module_name=module_name_fordropout(self.__class__.__name__)
        )
        self.decoder_layerdrop = cfg.decoder.layerdrop
        self.share_input_output_embed = cfg.share_decoder_input_output_embed

        input_embed_dim = embed_tokens.embedding_dim
        embed_dim = cfg.decoder.embed_dim
        self.embed_dim = embed_dim
        self.output_embed_dim = cfg.decoder.output_dim

        self.padding_idx = embed_tokens.padding_idx
        self.max_target_positions = cfg.max_target_positions

        self.embed_tokens = embed_tokens

        self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt(embed_dim)

        if not cfg.adaptive_input and cfg.quant_noise.pq > 0:
            self.quant_noise = apply_quant_noise_(
                nn.Linear(embed_dim, embed_dim, bias=False),
                cfg.quant_noise.pq,
                cfg.quant_noise.pq_block_size,
            )
        else:
            self.quant_noise = None

        self.project_in_dim = (
            Linear(input_embed_dim, embed_dim, bias=False)
            if embed_dim != input_embed_dim
            else None
        )
        self.embed_positions = (
            PositionalEmbedding(
                self.max_target_positions,
                embed_dim,
                self.padding_idx,
                learned=cfg.decoder.learned_pos,
            )
            if not cfg.no_token_positional_embeddings
            else None
        )
        if cfg.layernorm_embedding:
            self.layernorm_embedding = LayerNorm(embed_dim, export=cfg.export)
        else:
            self.layernorm_embedding = None

        self.cross_self_attention = cfg.cross_self_attention

        if self.decoder_layerdrop > 0.0:
            self.layers = LayerDropModuleList(p=self.decoder_layerdrop)
        else:
            self.layers = nn.ModuleList([])
        self.use_rel_pos_enc = use_rel_pos_enc
        self.layers.extend(
            [self.build_decoder_layer(cfg, no_encoder_attn) for _ in range(cfg.decoder.layers)]
        )
        self.num_layers = len(self.layers)

        if cfg.decoder.normalize_before and not cfg.no_decoder_final_norm:
            self.layer_norm = LayerNorm(embed_dim, export=cfg.export)
        else:
            self.layer_norm = None

        self.project_out_dim = (
            Linear(embed_dim, self.output_embed_dim, bias=False)
            if embed_dim != self.output_embed_dim and not cfg.tie_adaptive_weights
            else None
        )

        self.adaptive_softmax = None
        self.output_projection = output_projection
        if self.output_projection is None:
            self.build_output_projection(cfg, dictionary, embed_tokens)
        if self.use_rel_pos_enc:
            self.pos_emb = RelativePositionalEncoding(embed_dim // cfg.decoder.attention_heads, 24)

    def build_output_projection(self, cfg, dictionary, embed_tokens):
        if cfg.adaptive_softmax_cutoff is not None:
            self.adaptive_softmax = AdaptiveSoftmax(
                len(dictionary),
                self.output_embed_dim,
                utils.eval_str_list(cfg.adaptive_softmax_cutoff, type=int),
                dropout=cfg.adaptive_softmax_dropout,
                adaptive_inputs=embed_tokens if cfg.tie_adaptive_weights else None,
                factor=cfg.adaptive_softmax_factor,
                tie_proj=cfg.tie_adaptive_proj,
            )
        elif self.share_input_output_embed:
            self.output_projection = nn.Linear(
                self.embed_tokens.weight.shape[1],
                self.embed_tokens.weight.shape[0],
                bias=False,
            )
            self.output_projection.weight = self.embed_tokens.weight
        else:
            self.output_projection = nn.Linear(self.output_embed_dim, len(dictionary), bias=False)
            nn.init.normal_(
                self.output_projection.weight, mean=0, std=self.output_embed_dim ** -0.5
            )
        num_base_layers = cfg.base_layers
        for i in range(num_base_layers):
            self.layers.insert(
                ((i + 1) * cfg.decoder.layers) // (num_base_layers + 1),
                BaseLayer(cfg),
            )

    def build_decoder_layer(self, cfg, no_encoder_attn=False):
        layer = transformer_layer.TransformerDecoderLayerBase(
            cfg, no_encoder_attn, has_relative_attention_bias=self.use_rel_pos_enc
        )
        checkpoint = cfg.checkpoint_activations
        if checkpoint:
            offload_to_cpu = cfg.offload_activations
            layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
        # if we are checkpointing, enforce that FSDP always wraps the
        # checkpointed layer, regardless of layer size
        min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0
        layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
        return layer

    def forward(
        self,
        prev_output_tokens,
        encoder_out: Optional[Dict[str, List[Tensor]]] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        features_only: bool = False,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
        src_lengths: Optional[Any] = None,
        return_all_hiddens: bool = False,
    ):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (optional): output from the encoder, used for
                encoder-side attention, should be of size T x B x C
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`
            features_only (bool, optional): only return features without
                applying output layer (default: False).
            full_context_alignment (bool, optional): don't apply
                auto-regressive mask to self-attention (default: False).

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """

        x, extra = self.extract_features(
            prev_output_tokens,
            encoder_out=encoder_out,
            incremental_state=incremental_state,
            full_context_alignment=full_context_alignment,
            alignment_layer=alignment_layer,
            alignment_heads=alignment_heads,
        )

        if not features_only:
            x = self.output_layer(x)
        return x, extra

    def extract_features(
        self,
        prev_output_tokens,
        encoder_out: Optional[Dict[str, List[Tensor]]],
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        return self.extract_features_scriptable(
            prev_output_tokens,
            encoder_out,
            incremental_state,
            full_context_alignment,
            alignment_layer,
            alignment_heads,
        )

    """
    A scriptable subclass of this class has an extract_features method and calls
    super().extract_features, but super() is not supported in torchscript. A copy of
    this function is made to be used in the subclass instead.
    """

    def extract_features_scriptable(
        self,
        prev_output_tokens,
        encoder_out: Optional[Dict[str, List[Tensor]]],
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        """
        Similar to *forward* but only return features.

        Includes several features from "Jointly Learning to Align and
        Translate with Transformer Models" (Garg et al., EMNLP 2019).

        Args:
            full_context_alignment (bool, optional): don't apply
                auto-regressive mask to self-attention (default: False).
            alignment_layer (int, optional): return mean alignment over
                heads at this layer (default: last layer).
            alignment_heads (int, optional): only average alignment over
                this many heads (default: all heads).

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        bs, slen = prev_output_tokens.size()
        if alignment_layer is None:
            alignment_layer = self.num_layers - 1

        enc: Optional[Tensor] = None
        padding_mask: Optional[Tensor] = None
        if encoder_out is not None and len(encoder_out["encoder_out"]) > 0:
            enc = encoder_out["encoder_out"][0]
            assert (
                enc.size()[1] == bs
            ), f"Expected enc.shape == (t, {bs}, c) got {enc.shape}"
        if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0:
            padding_mask = encoder_out["encoder_padding_mask"][0]

        # embed positions
        positions = None
        if self.embed_positions is not None:
            positions = self.embed_positions(
                prev_output_tokens, incremental_state=incremental_state
            )

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.quant_noise is not None:
            x = self.quant_noise(x)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions

        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)

        x = self.dropout_module(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        if self.use_rel_pos_enc:
            pos_seq = torch.arange(0, slen).long().to(x.device)
            pos_seq = pos_seq[:, None] - pos_seq[None, :]
            pos_k, _ = self.pos_emb(pos_seq, incremental_state)
        else:
            pos_k = None

        self_attn_padding_mask: Optional[Tensor] = None
        if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
            self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)

        # decoder layers
        attn: Optional[Tensor] = None
        inner_states: List[Optional[Tensor]] = [x]
        for idx, layer in enumerate(self.layers):
            if incremental_state is None and not full_context_alignment:
                self_attn_mask = self.buffered_future_mask(x)
            else:
                self_attn_mask = None

            x, layer_attn, _ = layer(
                x,
                enc,
                padding_mask,
                incremental_state,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask,
                need_attn=bool((idx == alignment_layer)),
                need_head_weights=bool((idx == alignment_layer)),
                pos_bias=pos_k,
            )
            inner_states.append(x)
            if layer_attn is not None and idx == alignment_layer:
                attn = layer_attn.float().to(x)

        if attn is not None:
            if alignment_heads is not None:
                attn = attn[:alignment_heads]

            # average probabilities over heads
            attn = attn.mean(dim=0)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.project_out_dim is not None:
            x = self.project_out_dim(x)

        return x, {"attn": [attn], "inner_states": inner_states}

    def output_layer(self, features):
        """Project features to the vocabulary size."""
        if self.adaptive_softmax is None:
            # project back to size of vocabulary
            return self.output_projection(features)
        else:
            return features

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        if self.embed_positions is None:
            return self.max_target_positions
        return min(self.max_target_positions, self.embed_positions.max_positions)

    def buffered_future_mask(self, tensor):
        dim = tensor.size(0)
        # self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround.
        if (
            self._future_mask.size(0) == 0
            or (not self._future_mask.device == tensor.device)
            or self._future_mask.size(0) < dim
        ):
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(torch.zeros([dim, dim])), 1
            )
        self._future_mask = self._future_mask.to(tensor)
        return self._future_mask[:dim, :dim]

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            weights_key = "{}.embed_positions.weights".format(name)
            if weights_key in state_dict:
                del state_dict[weights_key]
            state_dict[
                "{}.embed_positions._float_tensor".format(name)
            ] = torch.FloatTensor(1)

        if f"{name}.output_projection.weight" not in state_dict:
            if self.share_input_output_embed:
                embed_out_key = f"{name}.embed_tokens.weight"
            else:
                embed_out_key = f"{name}.embed_out"
            if embed_out_key in state_dict:
                state_dict[f"{name}.output_projection.weight"] = state_dict[embed_out_key]
                if not self.share_input_output_embed:
                    del state_dict[embed_out_key]

        for i in range(self.num_layers):
            # update layer norms
            layer_norm_map = {
                "0": "self_attn_layer_norm",
                "1": "encoder_attn_layer_norm",
                "2": "final_layer_norm",
            }
            for old, new in layer_norm_map.items():
                for m in ("weight", "bias"):
                    k = "{}.layers.{}.layer_norms.{}.{}".format(name, i, old, m)
                    if k in state_dict:
                        state_dict["{}.layers.{}.{}.{}".format(name, i, new, m)] = state_dict[k]
                        del state_dict[k]

        version_key = "{}.version".format(name)
        if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])

        return state_dict


def Linear(in_features, out_features, bias=True):
    m = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(m.weight)
    if bias:
        nn.init.constant_(m.bias, 0.0)
    return m


class TransformerDecoderBaseScriptable(TransformerDecoderBase):
    def extract_features(
        self,
        prev_output_tokens,
        encoder_out: Optional[Dict[str, List[Tensor]]] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        # call scriptable method from parent class
        x, _ = self.extract_features_scriptable(
            prev_output_tokens,
            encoder_out,
            incremental_state,
            full_context_alignment,
            alignment_layer,
            alignment_heads,
        )
        return x, None


class TransformerDecoder(TransformerDecoderBase):
    def __init__(
        self,
        args,
        dictionary,
        embed_tokens,
        no_encoder_attn=False,
        output_projection=None,
    ):
        self.args = args
        super().__init__(
            TransformerConfig.from_namespace(args),
            dictionary,
            embed_tokens,
            no_encoder_attn=no_encoder_attn,
            output_projection=output_projection,
            use_rel_pos_enc=getattr(args, "use_rel_pos_enc", False),
        )

    def build_output_projection(self, args, dictionary, embed_tokens):
        super().build_output_projection(
            TransformerConfig.from_namespace(args), dictionary, embed_tokens
        )

    def build_decoder_layer(self, args, no_encoder_attn=False):
        return super().build_decoder_layer(
            TransformerConfig.from_namespace(args), no_encoder_attn=no_encoder_attn
        )


class TransformerDecoderScriptable(TransformerDecoder):
    def extract_features(
        self,
        prev_output_tokens,
        encoder_out: Optional[Dict[str, List[Tensor]]] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        # call scriptable method from parent class
        x, _ = self.extract_features_scriptable(
            prev_output_tokens,
            encoder_out,
            incremental_state,
            full_context_alignment,
            alignment_layer,
            alignment_heads,
        )
        return x, None
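
The two per-step inputs that extract_features_scriptable builds for every decoder layer, the causal future mask and the relative-offset matrix passed as pos_bias, are easy to visualize in isolation. A tiny pure-PyTorch sketch for a 4-token prefix (illustrative only, mirroring buffered_future_mask and the use_rel_pos_enc branch above):

import torch

slen = 4
# causal mask: strictly upper triangle is -inf, everything else 0 (as in buffered_future_mask)
future_mask = torch.triu(torch.full((slen, slen), float("-inf")), 1)
# relative offsets i - j, later clipped and embedded by RelativePositionalEncoding
pos_seq = torch.arange(slen)
rel_offsets = pos_seq[:, None] - pos_seq[None, :]
print(future_mask)
print(rel_offsets)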
SpeechLM/speechlm/modules/transformer_encoder.py
0 → 100644
# ----------------------------------------------------------------------------
# SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data (https://arxiv.org/abs/2209.15329)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechLM
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
import math
from typing import Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.distributed import fsdp_wrap
from fairseq.models import FairseqEncoder
from fairseq.modules import (
    FairseqDropout,
    LayerDropModuleList,
    LayerNorm,
    SinusoidalPositionalEmbedding,
)
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
from torch import Tensor
from fairseq.models.transformer import TransformerConfig

from speechlm.modules import transformer_layer, LearnedPositionalEmbedding
from speechlm.modules.relative_pos_enc import RelativePositionalEncoding


# rewrite name for backward compatibility in `make_generation_fast_`
def module_name_fordropout(module_name: str) -> str:
    if module_name == "TransformerEncoderBase":
        return "TransformerEncoder"
    else:
        return module_name


class TransformerEncoderBase(FairseqEncoder):
    """
    Transformer encoder consisting of *cfg.encoder.layers* layers. Each layer
    is a :class:`TransformerEncoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding
    """

    def __init__(self, cfg, dictionary, embed_tokens, use_rel_pos_enc=False, scaling_for_att=1.0):
        self.cfg = cfg
        super().__init__(dictionary)
        self.register_buffer("version", torch.Tensor([3]))

        self.dropout_module = FairseqDropout(
            cfg.dropout, module_name=module_name_fordropout(self.__class__.__name__)
        )
        self.encoder_layerdrop = cfg.encoder.layerdrop

        embed_dim = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx
        self.max_source_positions = cfg.max_source_positions

        self.embed_tokens = embed_tokens

        self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt(embed_dim)

        self.embed_positions = (
            PositionalEmbedding(
                cfg.max_source_positions,
                embed_dim,
                self.padding_idx,
                learned=cfg.encoder.learned_pos,
            )
            if not cfg.no_token_positional_embeddings
            else None
        )
        if cfg.layernorm_embedding:
            self.layernorm_embedding = LayerNorm(embed_dim, export=cfg.export)
        else:
            self.layernorm_embedding = None

        if not cfg.adaptive_input and cfg.quant_noise.pq > 0:
            self.quant_noise = apply_quant_noise_(
                nn.Linear(embed_dim, embed_dim, bias=False),
                cfg.quant_noise.pq,
                cfg.quant_noise.pq_block_size,
            )
        else:
            self.quant_noise = None

        if self.encoder_layerdrop > 0.0:
            self.layers = LayerDropModuleList(p=self.encoder_layerdrop)
        else:
            self.layers = nn.ModuleList([])
        self.use_rel_pos_enc = use_rel_pos_enc
        self.scaling_for_att = scaling_for_att
        self.layers.extend(
            [self.build_encoder_layer(cfg) for i in range(cfg.encoder.layers)]
        )
        self.num_layers = len(self.layers)

        if cfg.encoder.normalize_before:
            self.layer_norm = LayerNorm(embed_dim, export=cfg.export)
        else:
            self.layer_norm = None
        if self.use_rel_pos_enc:
            self.pos_emb = RelativePositionalEncoding(embed_dim // cfg.encoder.attention_heads, 160)

    def build_encoder_layer(self, cfg):
        layer = transformer_layer.TransformerEncoderLayerBase(
            cfg, has_relative_attention_bias=self.use_rel_pos_enc, scaling_for_att=self.scaling_for_att
        )
        checkpoint = cfg.checkpoint_activations
        if checkpoint:
            offload_to_cpu = cfg.offload_activations
            layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
        # if we are checkpointing, enforce that FSDP always wraps the
        # checkpointed layer, regardless of layer size
        min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0
        layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
        return layer

    def forward_embedding(self, src_tokens, token_embedding: Optional[torch.Tensor] = None):
        # embed tokens and positions
        if token_embedding is None:
            token_embedding = self.embed_tokens(src_tokens)
        x = embed = self.embed_scale * token_embedding
        if self.embed_positions is not None:
            x = embed + self.embed_positions(src_tokens)
        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)
        x = self.dropout_module(x)
        if self.quant_noise is not None:
            x = self.quant_noise(x)
        return x, embed

    def forward(
        self,
        src_tokens,
        src_lengths: Optional[torch.Tensor] = None,
        return_all_hiddens: bool = False,
        token_embeddings: Optional[torch.Tensor] = None,
        uniformity_layers: Optional[List[int]] = None,
    ):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).
            token_embeddings (torch.Tensor, optional): precomputed embeddings
                default `None` will recompute embeddings

        Returns:
            dict:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - **encoder_embedding** (Tensor): the (scaled) embedding lookup
                  of shape `(batch, src_len, embed_dim)`
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only populated if *return_all_hiddens* is True.
        """
        return self.forward_scriptable(
            src_tokens, src_lengths, return_all_hiddens, token_embeddings, uniformity_layers
        )

    # TorchScript doesn't support super() method so that the scriptable Subclass
    # can't access the base class model in Torchscript.
    # Current workaround is to add a helper function with different name and
    # call the helper function from scriptable Subclass.
    def forward_scriptable(
        self,
        src_tokens,
        src_lengths: Optional[torch.Tensor] = None,
        return_all_hiddens: bool = False,
        token_embeddings: Optional[torch.Tensor] = None,
        uniformity_layers: Optional[List[int]] = None,
    ):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).
            token_embeddings (torch.Tensor, optional): precomputed embeddings
                default `None` will recompute embeddings

        Returns:
            dict:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - **encoder_embedding** (Tensor): the (scaled) embedding lookup
                  of shape `(batch, src_len, embed_dim)`
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only populated if *return_all_hiddens* is True.
        """
        # compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)
        has_pads = src_tokens.device.type == "xla" or encoder_padding_mask.any()

        x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)

        # account for padding while computing the representation
        if has_pads:
            x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        if self.use_rel_pos_enc:
            x_len = x.shape[0]
            pos_seq = torch.arange(0, x_len).long().to(x.device)
            pos_seq = pos_seq[:, None] - pos_seq[None, :]
            pos_k, pos_v = self.pos_emb(pos_seq)
        else:
            pos_k = None

        encoder_states = []
        uniformity_hiddens = []

        if return_all_hiddens:
            encoder_states.append(x)

        if uniformity_layers is not None and 0 in uniformity_layers:
            x = F.normalize(x.float(), dim=-1).type_as(x)
            uniformity_hiddens.append(x)

        # encoder layers
        for i, layer in enumerate(self.layers):
            x = layer(
                x,
                encoder_padding_mask=encoder_padding_mask if has_pads else None,
                pos_bias=pos_k,
            )
            if uniformity_layers is not None and i + 1 in uniformity_layers:
                x = F.normalize(x.float(), dim=-1).type_as(x)
                uniformity_hiddens.append(x)
            if return_all_hiddens:
                assert encoder_states is not None
                encoder_states.append(x)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in
        # `forward` so we use a dictionary instead.
        # TorchScript does not support mixed values so the values are all lists.
        # The empty list is equivalent to None.
        src_lengths = (
            src_tokens.ne(self.padding_idx)
            .sum(dim=1, dtype=torch.int32)
            .reshape(-1, 1)
            .contiguous()
        )
        return {
            "encoder_out": [x],  # T x B x C
            "encoder_padding_mask": [encoder_padding_mask],  # B x T
            "encoder_embedding": [encoder_embedding],  # B x T x C
            "encoder_states": encoder_states,  # List[T x B x C]
            "uniformity_hiddens": uniformity_hiddens,  # List[T x B x C]
            "src_tokens": [],
            "src_lengths": [src_lengths],
        }

    @torch.jit.export
    def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order):
        """
        Reorder encoder output according to *new_order*.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order*
        """
        if len(encoder_out["encoder_out"]) == 0:
            new_encoder_out = []
        else:
            new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)]
        if len(encoder_out["encoder_padding_mask"]) == 0:
            new_encoder_padding_mask = []
        else:
            new_encoder_padding_mask = [
                encoder_out["encoder_padding_mask"][0].index_select(0, new_order)
            ]
        if len(encoder_out["encoder_embedding"]) == 0:
            new_encoder_embedding = []
        else:
            new_encoder_embedding = [
                encoder_out["encoder_embedding"][0].index_select(0, new_order)
            ]

        if len(encoder_out["src_tokens"]) == 0:
            src_tokens = []
        else:
            src_tokens = [(encoder_out["src_tokens"][0]).index_select(0, new_order)]

        if len(encoder_out["src_lengths"]) == 0:
            src_lengths = []
        else:
            src_lengths = [(encoder_out["src_lengths"][0]).index_select(0, new_order)]

        encoder_states = encoder_out["encoder_states"]
        if len(encoder_states) > 0:
            for idx, state in enumerate(encoder_states):
                encoder_states[idx] = state.index_select(1, new_order)

        return {
            "encoder_out": new_encoder_out,  # T x B x C
            "encoder_padding_mask": new_encoder_padding_mask,  # B x T
            "encoder_embedding": new_encoder_embedding,  # B x T x C
            "encoder_states": encoder_states,  # List[T x B x C]
            "src_tokens": src_tokens,  # B x T
            "src_lengths": src_lengths,  # B x 1
        }

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        if self.embed_positions is None:
            return self.max_source_positions
        return min(self.max_source_positions, self.embed_positions.max_positions)

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            weights_key = "{}.embed_positions.weights".format(name)
            if weights_key in state_dict:
                print("deleting {0}".format(weights_key))
                del state_dict[weights_key]
            state_dict[
                "{}.embed_positions._float_tensor".format(name)
            ] = torch.FloatTensor(1)
        for i in range(self.num_layers):
            # update layer norms
            self.layers[i].upgrade_state_dict_named(
                state_dict, "{}.layers.{}".format(name, i)
            )

        version_key = "{}.version".format(name)
        if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])
        return state_dict


class TransformerEncoder(TransformerEncoderBase):
    def __init__(self, args, dictionary, embed_tokens):
        self.args = args
        super().__init__(
            TransformerConfig.from_namespace(args),
            dictionary,
            embed_tokens,
            use_rel_pos_enc=getattr(args, "use_rel_pos_enc", False),
            scaling_for_att=getattr(args, "scaling_for_att", 1.0),
        )

    def build_encoder_layer(self, args):
        return super().build_encoder_layer(
            TransformerConfig.from_namespace(args),
        )


def PositionalEmbedding(
    num_embeddings: int,
    embedding_dim: int,
    padding_idx: int,
    learned: bool = False,
):
    if learned:
        # if padding_idx is specified then offset the embedding ids by
        # this index and adjust num_embeddings appropriately
        # TODO: The right place for this offset would be inside
        # LearnedPositionalEmbedding. Move this there for a cleaner implementation.
        if padding_idx is not None:
            num_embeddings = num_embeddings + padding_idx + 1
        m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx)
        nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
        if padding_idx is not None:
            nn.init.constant_(m.weight[padding_idx], 0)
    else:
        m = SinusoidalPositionalEmbedding(
            embedding_dim,
            padding_idx,
            init_size=num_embeddings + padding_idx + 1,
        )
    return m
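
For readers less familiar with fairseq generation, a short illustrative sketch of what reorder_encoder_out does to the dictionary returned by forward: time-major tensors are gathered along batch dimension 1, batch-major ones along dimension 0. Shapes and the reorder index are made up for the example.

import torch

T, B, C = 5, 3, 8
encoder_out = {
    "encoder_out": [torch.randn(T, B, C)],                          # T x B x C
    "encoder_padding_mask": [torch.zeros(B, T, dtype=torch.bool)],  # B x T
}
new_order = torch.tensor([2, 0, 1])  # e.g. beams re-ranked by the search
reordered = encoder_out["encoder_out"][0].index_select(1, new_order)               # still (T, B, C)
reordered_mask = encoder_out["encoder_padding_mask"][0].index_select(0, new_order)  # still (B, T)
print(reordered.shape, reordered_mask.shape)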
SpeechLM/speechlm/modules/transformer_layer.py
0 → 100644
# --------------------------------------------------------
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------
"""
Modified from https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/transformer_layer.py
https://github.com/microsoft/SpeechT5/blob/main/Speech2C/speech2c/models/modules/transformer_decoder_layer.py
"""
from typing import Dict, List, Optional

import torch
from torch import Tensor
from fairseq.modules import LayerNorm
from speechlm.modules.multihead_attention import MultiheadAttention
from fairseq.modules.transformer_layer import TransformerEncoderLayerBase as FairseqTransformerEncoderLayerBase
from fairseq.modules.transformer_layer import TransformerDecoderLayerBase as FairseqTransformerDecoderLayerBase


class TransformerEncoderLayerBase(FairseqTransformerEncoderLayerBase):
    """Encoder layer block.

    In the original paper each operation (multi-head attention or FFN) is
    postprocessed with: `dropout -> add residual -> layernorm`. In the
    tensor2tensor code they suggest that learning is more robust when
    preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *cfg.encoder.normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
    """

    def __init__(self, cfg, has_relative_attention_bias=False, scaling_for_att=1.0):
        self.scaling_for_att = scaling_for_att
        super().__init__(cfg)
        if has_relative_attention_bias:
            self.norm_k = LayerNorm(self.embed_dim // cfg.encoder.attention_heads)

    def build_self_attention(self, embed_dim, cfg, scaling_for_att=1.0):
        return MultiheadAttention(
            embed_dim,
            cfg.encoder.attention_heads,
            dropout=cfg.attention_dropout,
            self_attention=True,
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
            scaling_for_att=self.scaling_for_att,
        )

    def forward(
        self,
        x,
        encoder_padding_mask: Optional[Tensor],
        attn_mask: Optional[Tensor] = None,
        pos_bias=None,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, seq_len)` where padding elements are indicated by ``1``.
            attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`,
                where `tgt_len` is the length of output and `src_len` is the
                length of input, though here both are equal to `seq_len`.
                `attn_mask[tgt_i, src_j] = 1` means that when calculating the
                embedding for `tgt_i`, we exclude (mask out) `src_j`. This is
                useful for strided self-attention.

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        # anything in original attn_mask = 1, becomes -1e8
        # anything in original attn_mask = 0, becomes 0
        # Note that we cannot use -inf here, because at some edge cases,
        # the attention weight (before softmax) for some padded element in query
        # will become -inf, which results in NaN in model parameters
        if attn_mask is not None:
            attn_mask = attn_mask.masked_fill(
                attn_mask.to(torch.bool), -1e8 if x.dtype == torch.float32 else -1e4
            )

        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        if pos_bias is not None:
            pos_bias = self.norm_k(pos_bias)
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=encoder_padding_mask,
            need_weights=False,
            attn_mask=attn_mask,
            position_bias=pos_bias,
        )
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        x = self.activation_dropout_module(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        return x


class TransformerDecoderLayerBase(FairseqTransformerDecoderLayerBase):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *cfg.decoder.normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(
        self,
        cfg,
        no_encoder_attn=False,
        add_bias_kv=False,
        add_zero_attn=False,
        has_relative_attention_bias=False,
        scaling_for_att=1.0,
    ):
        self.scaling_for_att = scaling_for_att
        super().__init__(cfg, no_encoder_attn, add_bias_kv, add_zero_attn)
        if has_relative_attention_bias:
            self.norm_k = LayerNorm(self.embed_dim // cfg.decoder.attention_heads)

    def build_self_attention(self, embed_dim, cfg, add_bias_kv=False, add_zero_attn=False):
        return MultiheadAttention(
            embed_dim,
            cfg.decoder.attention_heads,
            dropout=cfg.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=not cfg.cross_self_attention,
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
            scaling_for_att=self.scaling_for_att,
        )

    def build_encoder_attention(self, embed_dim, cfg):
        return MultiheadAttention(
            embed_dim,
            cfg.decoder.attention_heads,
            kdim=cfg.encoder.embed_dim,
            vdim=cfg.encoder.embed_dim,
            dropout=cfg.attention_dropout,
            encoder_decoder_attention=True,
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
            scaling_for_att=self.scaling_for_att,
        )

    def forward(
        self,
        x,
        encoder_out: Optional[torch.Tensor] = None,
        encoder_padding_mask: Optional[torch.Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        prev_self_attn_state: Optional[List[torch.Tensor]] = None,
        prev_attn_state: Optional[List[torch.Tensor]] = None,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_attn: bool = False,
        need_head_weights: bool = False,
        pos_bias=None,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor, optional): binary
                ByteTensor of shape `(batch, src_len)` where padding
                elements are indicated by ``1``.
            need_attn (bool, optional): return attention weights
            need_head_weights (bool, optional): return attention weights
                for each head (default: return average over heads).

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        if need_head_weights:
            need_attn = True

        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        if pos_bias is not None:
            pos_bias = self.norm_k(pos_bias)
        if prev_self_attn_state is not None:
            prev_key, prev_value = prev_self_attn_state[:2]
            saved_state: Dict[str, Optional[Tensor]] = {
                "prev_key": prev_key,
                "prev_value": prev_value,
            }
            if len(prev_self_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
            assert incremental_state is not None
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
        if self.cross_self_attention and not (
            incremental_state is not None
            and _self_attn_input_buffer is not None
            and "prev_key" in _self_attn_input_buffer
        ):
            if self_attn_mask is not None:
                assert encoder_out is not None
                self_attn_mask = torch.cat(
                    (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
                )
            if self_attn_padding_mask is not None:
                if encoder_padding_mask is None:
                    assert encoder_out is not None
                    encoder_padding_mask = self_attn_padding_mask.new_zeros(
                        encoder_out.size(1), encoder_out.size(0)
                    )
                self_attn_padding_mask = torch.cat(
                    (encoder_padding_mask, self_attn_padding_mask), dim=1
                )
            assert encoder_out is not None
            y = torch.cat((encoder_out, x), dim=0)
        else:
            y = x

        x, attn = self.self_attn(
            query=x,
            key=y,
            value=y,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
            position_bias=pos_bias,
        )
        if self.c_attn is not None:
            tgt_len, bsz = x.size(0), x.size(1)
            x = x.view(tgt_len, bsz, self.nh, self.head_dim)
            x = torch.einsum("tbhd,h->tbhd", x, self.c_attn)
            x = x.reshape(tgt_len, bsz, self.embed_dim)
        if self.attn_ln is not None:
            x = self.attn_ln(x)
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        if self.encoder_attn is not None and encoder_out is not None:
            residual = x
            if self.normalize_before:
                x = self.encoder_attn_layer_norm(x)
            if prev_attn_state is not None:
                prev_key, prev_value = prev_attn_state[:2]
                saved_state: Dict[str, Optional[Tensor]] = {
                    "prev_key": prev_key,
                    "prev_value": prev_value,
                }
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                assert incremental_state is not None
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)

            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=need_attn or (not self.training and self.need_attn),
                need_head_weights=need_head_weights,
            )
            x = self.dropout_module(x)
            x = self.residual_connection(x, residual)
            if not self.normalize_before:
                x = self.encoder_attn_layer_norm(x)

        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)

        x = self.activation_fn(self.fc1(x))
        x = self.activation_dropout_module(x)
        if self.ffn_layernorm is not None:
            x = self.ffn_layernorm(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        if self.w_resid is not None:
            residual = torch.mul(self.w_resid, residual)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        if self.onnx_trace and incremental_state is not None:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            assert saved_state is not None
            if self_attn_padding_mask is not None:
                self_attn_state = [
                    saved_state["prev_key"],
                    saved_state["prev_value"],
                    saved_state["prev_key_padding_mask"],
                ]
            else:
                self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
            return x, attn, self_attn_state
        return x, attn, None

    def make_generation_fast_(self, need_attn: bool = False, **kwargs):
        self.need_attn = need_attn
0 → 100644
# --------------------------------------------------------
# The YiTrans End-to-End Speech Translation System for IWSLT 2022 Offline Shared Task (https://arxiv.org/abs/2206.05777)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/YiTrans
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------
"""
wav2vec encoder adding relative position bias, modified from
https://github.com/microsoft/SpeechT5/blob/main/Speech2C/speech2c/models/modules/transformer_encoder.py
https://github.com/facebookresearch/fairseq/blob/main/fairseq/models/wav2vec/wav2vec2.py
"""
import
math
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
fairseq
import
utils
from
fairseq.dataclass
import
ChoiceEnum
from
fairseq.modules
import
(
LayerNorm
,
SamePad
,
)
from
fairseq.modules.checkpoint_activations
import
checkpoint_wrapper
from
fairseq.modules.transformer_sentence_encoder
import
init_bert_params
from
fairseq.utils
import
index_put
from
fairseq.distributed
import
fsdp_wrap
from
fairseq.models.wav2vec.utils
import
pad_to_multiple
## reload multi-head attition with rel-pos-bias
from
fairseq.models.wav2vec.wav2vec2
import
TransformerEncoder
as
W2vTransformerEncoder
from
speechlm.modules.relative_pos_enc
import
RelativePositionalEncoding
from
speechlm.modules.multihead_attention
import
MultiheadAttention
EXTRACTOR_MODE_CHOICES
=
ChoiceEnum
([
"default"
,
"layer_norm"
])
MASKING_DISTRIBUTION_CHOICES
=
ChoiceEnum
([
"static"
,
"uniform"
,
"normal"
,
"poisson"
])
class
TransformerEncoder
(
W2vTransformerEncoder
):
def
__init__
(
self
,
args
):
super
().
__init__
(
args
)
self
.
dropout
=
args
.
dropout
self
.
embedding_dim
=
args
.
encoder_embed_dim
self
.
required_seq_len_multiple
=
args
.
required_seq_len_multiple
self
.
use_rel_pos_enc
=
getattr
(
args
,
"use_rel_pos_enc"
,
False
)
self
.
pos_conv
=
nn
.
Conv1d
(
self
.
embedding_dim
,
self
.
embedding_dim
,
kernel_size
=
args
.
conv_pos
,
padding
=
args
.
conv_pos
//
2
,
groups
=
args
.
conv_pos_groups
,
)
dropout
=
0
std
=
math
.
sqrt
((
4
*
(
1.0
-
dropout
))
/
(
args
.
conv_pos
*
self
.
embedding_dim
))
nn
.
init
.
normal_
(
self
.
pos_conv
.
weight
,
mean
=
0
,
std
=
std
)
nn
.
init
.
constant_
(
self
.
pos_conv
.
bias
,
0
)
self
.
pos_conv
=
nn
.
utils
.
weight_norm
(
self
.
pos_conv
,
name
=
"weight"
,
dim
=
2
)
self
.
pos_conv
=
nn
.
Sequential
(
self
.
pos_conv
,
SamePad
(
args
.
conv_pos
),
nn
.
GELU
())
layers
=
[]
for
_
in
range
(
args
.
encoder_layers
):
layer
=
TransformerSentenceEncoderLayer
(
embedding_dim
=
self
.
embedding_dim
,
ffn_embedding_dim
=
args
.
encoder_ffn_embed_dim
,
num_attention_heads
=
args
.
encoder_attention_heads
,
dropout
=
self
.
dropout
,
attention_dropout
=
args
.
attention_dropout
,
activation_dropout
=
args
.
activation_dropout
,
activation_fn
=
args
.
activation_fn
,
layer_norm_first
=
args
.
layer_norm_first
,
has_relative_attention_bias
=
self
.
use_rel_pos_enc
,
)
if
args
.
checkpoint_activations
:
layer
=
fsdp_wrap
(
layer
)
layer
=
checkpoint_wrapper
(
layer
)
layers
.
append
(
layer
)
self
.
layers
=
nn
.
ModuleList
(
layers
)
self
.
layer_norm_first
=
args
.
layer_norm_first
self
.
layer_norm
=
LayerNorm
(
self
.
embedding_dim
)
self
.
layerdrop
=
args
.
encoder_layerdrop
if
self
.
use_rel_pos_enc
:
self
.
pos_emb
=
RelativePositionalEncoding
(
args
.
encoder_embed_dim
//
args
.
encoder_attention_heads
,
160
)
self
.
apply
(
init_bert_params
)
def
forward
(
self
,
x
,
padding_mask
=
None
,
layer
=
None
):
x
,
layer_results
=
self
.
extract_features
(
x
,
padding_mask
,
layer
)
if
self
.
layer_norm_first
and
layer
is
None
:
x
=
self
.
layer_norm
(
x
)
return
x
,
layer_results
def
extract_features
(
self
,
x
,
padding_mask
=
None
,
tgt_layer
=
None
):
if
padding_mask
is
not
None
:
x
=
index_put
(
x
,
padding_mask
,
0
)
x_conv
=
self
.
pos_conv
(
x
.
transpose
(
1
,
2
))
x_conv
=
x_conv
.
transpose
(
1
,
2
)
x
=
x
+
x_conv
if
not
self
.
layer_norm_first
:
x
=
self
.
layer_norm
(
x
)
# pad to the sequence length dimension
x
,
pad_length
=
pad_to_multiple
(
x
,
self
.
required_seq_len_multiple
,
dim
=-
2
,
value
=
0
)
if
pad_length
>
0
and
padding_mask
is
None
:
padding_mask
=
x
.
new_zeros
((
x
.
size
(
0
),
x
.
size
(
1
)),
dtype
=
torch
.
bool
)
padding_mask
[:,
-
pad_length
:]
=
True
else
:
padding_mask
,
_
=
pad_to_multiple
(
padding_mask
,
self
.
required_seq_len_multiple
,
dim
=-
1
,
value
=
True
)
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
# B x T x C -> T x B x C
x
=
x
.
transpose
(
0
,
1
)
if
self
.
use_rel_pos_enc
:
x_len
=
x
.
shape
[
0
]
pos_seq
=
torch
.
arange
(
0
,
x_len
).
long
().
to
(
x
.
device
)
pos_seq
=
pos_seq
[:,
None
]
-
pos_seq
[
None
,
:]
pos_k
,
pos_v
=
self
.
pos_emb
(
pos_seq
)
else
:
pos_k
=
None
layer_results
=
[]
r
=
None
for
i
,
layer
in
enumerate
(
self
.
layers
):
dropout_probability
=
np
.
random
.
random
()
if
not
self
.
training
or
(
dropout_probability
>
self
.
layerdrop
):
x
,
z
=
layer
(
x
,
self_attn_padding_mask
=
padding_mask
,
need_weights
=
False
,
pos_bias
=
pos_k
)
if
tgt_layer
is
not
None
:
# unpad if needed
if
pad_length
>
0
:
layer_results
.
append
(
(
x
[:
-
pad_length
],
z
[:,
:
-
pad_length
,
:
-
pad_length
]
if
z
is
not
None
else
z
,
)
)
else
:
layer_results
.
append
((
x
,
z
))
if
i
==
tgt_layer
:
r
=
x
break
if
r
is
not
None
:
x
=
r
# T x B x C -> B x T x C
x
=
x
.
transpose
(
0
,
1
)
# undo paddding
if
pad_length
>
0
:
x
=
x
[:,
:
-
pad_length
]
return
x
,
layer_results
class TransformerSentenceEncoderLayer(nn.Module):
    """
    Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
    models.
    """

    def __init__(
        self,
        embedding_dim: float = 768,
        ffn_embedding_dim: float = 3072,
        num_attention_heads: float = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        activation_fn: str = "relu",
        layer_norm_first: bool = False,
        has_relative_attention_bias: bool = False,
    ) -> None:

        super().__init__()
        # Initialize parameters
        self.embedding_dim = embedding_dim
        self.dropout = dropout
        self.activation_dropout = activation_dropout

        # Initialize blocks
        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.self_attn = MultiheadAttention(
            self.embedding_dim,
            num_attention_heads,
            dropout=attention_dropout,
            self_attention=True,
        )

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(self.activation_dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.layer_norm_first = layer_norm_first

        # layer norm associated with the self attention layer
        self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
        self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)

        # layer norm associated with the position wise feed-forward NN
        self.final_layer_norm = LayerNorm(self.embedding_dim)

        if has_relative_attention_bias:
            self.norm_k = LayerNorm(self.embedding_dim // num_attention_heads)

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: torch.Tensor = None,
        self_attn_padding_mask: torch.Tensor = None,
        need_weights: bool = False,
        att_args=None,
        pos_bias=None,
    ):
        """
        LayerNorm is applied either before or after the self-attention/ffn
        modules similar to the original Transformer implementation.
        """
        residual = x

        if self.layer_norm_first:
            x = self.self_attn_layer_norm(x)
            if pos_bias is not None:
                pos_bias = self.norm_k(pos_bias)
            x, attn = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                attn_mask=self_attn_mask,
                position_bias=pos_bias,
            )
            x = self.dropout1(x)
            x = residual + x

            residual = x
            x = self.final_layer_norm(x)
            x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)
            x = self.dropout3(x)
            x = residual + x
        else:
            x, attn = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                position_bias=pos_bias,
            )

            x = self.dropout1(x)
            x = residual + x

            x = self.self_attn_layer_norm(x)

            residual = x
            x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)
            x = self.dropout3(x)
            x = residual + x
            x = self.final_layer_norm(x)

        return x, attn
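The `layer_norm_first` flag above toggles between the pre-LN and post-LN residual orderings. The toy layer below, written with `torch.nn.MultiheadAttention` and without the relative position bias, is only a sketch of that control flow under those simplifying assumptions, not the SpeechLM module itself.

# Toy pre-LN / post-LN residual block; assumes torch.nn.MultiheadAttention
# in place of the project's attention module and omits the position bias.
import torch
import torch.nn as nn

class TinyEncoderLayer(nn.Module):
    def __init__(self, dim=16, heads=4, layer_norm_first=False):
        super().__init__()
        self.layer_norm_first = layer_norm_first
        self.attn = nn.MultiheadAttention(dim, heads)   # expects T x B x C input
        self.ffn = nn.Sequential(nn.Linear(dim, 4 * dim), nn.ReLU(), nn.Linear(4 * dim, dim))
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        if self.layer_norm_first:
            # pre-LN: normalize first, then attention/FFN, then add the residual
            h = self.norm1(x)
            x = x + self.attn(h, h, h)[0]
            x = x + self.ffn(self.norm2(x))
        else:
            # post-LN: attention/FFN, add the residual, then normalize
            x = self.norm1(x + self.attn(x, x, x)[0])
            x = self.norm2(x + self.ffn(x))
        return x

x = torch.randn(10, 2, 16)                              # T x B x C, as in the layer above
print(TinyEncoderLayer(layer_norm_first=True)(x).shape)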
SpeechLM/speechlm/scripts/pretrain_speechlm/base_speechlmh.sh
0 → 100644
View file @
12c90639
# ####################################
#       SpeechLM-H Base model        #
# ####################################
[ $# -lt 2 ] && echo "Usage: $0 <data_dir> <text_data_dir> [mount=${PWD}] [world_size=32] [update_freq=1]" && exit 1
[ ${PWD##*/} != SpeechLM ] && echo "Error: dir not match! Switch to SpeechLM/ and run it again!" && exit 1

DATA_DIR=$1
TEXT_DATA_DIR=$2
mount=$3
world_size=$4
update_freq=$5
[ -z $mount ] && mount=${PWD}
[ -z $world_size ] && world_size=32
[ -z $update_freq ] && update_freq=1

CODE_ROOT=${PWD}
MODEL_DIR="${mount}/exp/pretrain/base_speechlmh_${world_size}gpu_${update_freq}accum"
[ -d $MODEL_DIR ] || mkdir -p $MODEL_DIR

python $CODE_ROOT/fairseq/fairseq_cli/hydra_train.py \
    --config-dir $CODE_ROOT/speechlm/config/pretrain \
    --config-name speechlm_base_librispeech \
    common.user_dir=$CODE_ROOT/speechlm \
    \
    task.labels='["km"]' \
    model.label_rate=50 \
    task.data=$DATA_DIR \
    task.label_dir=$DATA_DIR \
    task.text_cfg.text_data=$TEXT_DATA_DIR \
    \
    dataset.train_subset=\"train_960+train_text.km-ltr\" \
    dataset.valid_subset=\"dev_clean+dev_clean.km-ltr\" \
    dataset.num_workers=0 \
    dataset.max_tokens=1400000 \
    distributed_training.distributed_world_size=${world_size} \
    optimization.update_freq=[${update_freq}] \
    \
    common.tensorboard_logdir=$MODEL_DIR \
    checkpoint.save_dir=$MODEL_DIR \
    hydra.run.dir=$MODEL_DIR \
    hydra.job.name=pretrain

# data_dir="/mnt/default/v-ziqzhang/data/stbert/data/librispeech/hubert_release_iter2_layer9_kmeans/local"
# text_data_dir="/mnt/default/v-ziqzhang/dataset/LibriLM/from_fastT2U/bin-idx"
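As a rough sanity check of the batch-size settings above — assuming the speech side feeds 16 kHz waveform samples, so `dataset.max_tokens` counts audio samples — the effective batch per optimizer update works out as in the illustrative arithmetic below.

# Illustrative arithmetic only, assuming max_tokens counts 16 kHz waveform samples.
max_tokens = 1_400_000          # per-GPU batch, from dataset.max_tokens
world_size = 32                 # GPUs, from distributed_world_size
update_freq = 1                 # gradient accumulation steps

samples_per_update = max_tokens * world_size * update_freq
seconds_per_update = samples_per_update / 16_000
print(f"{samples_per_update:,} samples ~= {seconds_per_update / 60:.1f} minutes of audio per update")
# -> 44,800,000 samples ~= 46.7 minutes of audio per update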
SpeechLM/speechlm/scripts/pretrain_speechlm/base_speechlmp.sh
0 → 100644
View file @
12c90639
# ####################################
#       SpeechLM-P Base model        #
# ####################################
[ $# -lt 2 ] && echo "Usage: $0 <data_dir> <text_data_dir> [mount=${PWD}] [world_size=32] [update_freq=1]" && exit 1
[ ${PWD##*/} != SpeechLM ] && echo "Error: dir not match! Switch to SpeechLM/ and run it again!" && exit 1

DATA_DIR=$1
TEXT_DATA_DIR=$2
mount=$3
world_size=$4
update_freq=$5
[ -z $mount ] && mount=${PWD}
[ -z $world_size ] && world_size=32
[ -z $update_freq ] && update_freq=1

CODE_ROOT=${PWD}
MODEL_DIR="${mount}/exp/pretrain/base_speechlmp_${world_size}gpu_${update_freq}accum"
[ -d $MODEL_DIR ] || mkdir -p $MODEL_DIR

python $CODE_ROOT/fairseq/fairseq_cli/hydra_train.py \
    --config-dir $CODE_ROOT/speechlm/config/pretrain \
    --config-name speechlm_base_librispeech \
    common.user_dir=$CODE_ROOT/speechlm \
    \
    task.labels='["phn"]' \
    model.label_rate=100 \
    task.data=$DATA_DIR \
    task.label_dir=$DATA_DIR \
    task.text_cfg.text_data=$TEXT_DATA_DIR \
    \
    dataset.train_subset=\"train_960+train_text.phn-ltr\" \
    dataset.valid_subset=\"dev_clean+dev_clean.phn-ltr\" \
    dataset.num_workers=0 \
    dataset.max_tokens=1400000 \
    distributed_training.distributed_world_size=${world_size} \
    optimization.update_freq=[${update_freq}] \
    \
    common.tensorboard_logdir=$MODEL_DIR \
    checkpoint.save_dir=$MODEL_DIR \
    hydra.run.dir=$MODEL_DIR \
    hydra.job.name=pretrain

# data_dir="/stdblob/users/v-ziqzhang/dataset/LibriLM/phn2char_sanych/tri4b_mono_label"
# text_data_dir="/stdblob/users/v-ziqzhang/dataset/LibriLM/phn2char_sanych/filt2k_sil025_m5std25_sil14_spn32/bin-idx"
SpeechLM/speechlm/scripts/pretrain_speechlm/large_speechlmp.sh
0 → 100644
View file @
12c90639
# ####################################
#      SpeechLM-P Large model        #
# ####################################
[ $# -lt 2 ] && echo "Usage: $0 <data_dir> <text_data_dir> [mount=${PWD}] [world_size=32] [update_freq=4]" && exit 1
[ ${PWD##*/} != SpeechLM ] && echo "Error: dir not match! Switch to SpeechLM/ and run it again!" && exit 1

DATA_DIR=$1
TEXT_DATA_DIR=$2
mount=$3
world_size=$4
update_freq=$5
[ -z $mount ] && mount=${PWD}
[ -z $world_size ] && world_size=32
[ -z $update_freq ] && update_freq=4

CODE_ROOT=${PWD}
MODEL_DIR="${mount}/exp/pretrain/large_speechlmp_${world_size}gpu_${update_freq}accum"
[ -d $MODEL_DIR ] || mkdir -p $MODEL_DIR

python $CODE_ROOT/fairseq/fairseq_cli/hydra_train.py \
    --config-dir $CODE_ROOT/speechlm/config/pretrain \
    --config-name speechlm_large_librilight \
    common.user_dir=$CODE_ROOT/speechlm \
    \
    task.labels='["phn"]' \
    model.label_rate=50 \
    task.data=$DATA_DIR \
    task.label_dir=$DATA_DIR \
    task.text_cfg.text_data=$TEXT_DATA_DIR \
    \
    dataset.train_subset=\"train_60k+train_text.phn-ltr\" \
    dataset.valid_subset=\"dev_clean+dev_clean.phn-ltr\" \
    dataset.num_workers=1 \
    dataset.max_tokens=900000 \
    distributed_training.distributed_world_size=${world_size} \
    optimization.update_freq=[${update_freq}] \
    \
    common.fp16_scale_tolerance=0.1 \
    common.tensorboard_logdir=$MODEL_DIR \
    checkpoint.save_dir=$MODEL_DIR \
    hydra.run.dir=$MODEL_DIR \
    hydra.job.name=pretrain

# data_dir="/stdblob/users/v-ziqzhang/dataset/librilight/chunkdata"
# text_data_dir="/stdblob/users/v-ziqzhang/dataset/LibriLM/phn2char_sanych/filt2k_sil025_m5std25_sil14_spn32/bin-idx"
SpeechLM/speechlm/scripts/tokenizer_fastT2U/generate.sh
0 → 100644
View file @
12c90639
#####################################
#      Fast Text2Unit Model         #
#####################################
[ $# -lt 2 ] && echo "Usage: $0 <model_path> <gen_set> [outdir={gen_set%/*}]" && exit 0
[ ${PWD##*/} != SpeechLM ] && echo "Error: dir not match! Switch to SpeechLM/ and run it again!" && exit 1

model_path=$1
src_dir=${model_path%/*}
cpt=${model_path##*/}
cpt=${cpt%.*}

gen_set=$2
outdir=$3
DATA_DIR=${gen_set%/*}
gen_set=${gen_set##*/}
[ -z $outdir ] && outdir=${DATA_DIR}

CODE_ROOT=${PWD}
nj=4

for rank in $(seq 0 $((nj-1))); do
    results_path=$outdir/pseudo_${gen_set}/${rank}
    [ ! -d $results_path ] && mkdir -p $results_path
    echo "$model_path" > $results_path/model.record

    python $CODE_ROOT/speechlm/generate_unit.py $DATA_DIR \
        --user-dir $CODE_ROOT/speechlm \
        --config-yaml config_generate.yaml \
        --path ${model_path} \
        --task fast_text_to_unit \
        --gen-subset $gen_set \
        \
        --beam 1 \
        --max-tokens 10000 \
        --results-path $results_path \
        --scoring sacrebleu \
        --skip-invalid-size-inputs-valid-test \
        --distributed-world-size $nj --distributed-rank ${rank} \
        &
done
wait
SpeechLM/speechlm/scripts/tokenizer_fastT2U/infer.sh
0 → 100644
View file @
12c90639
#####################################
#      Fast Text2Unit Model         #
#####################################
[ $# -lt 2 ] && echo "Usage: $0 <model_path> <gen_set> " && exit 0
[ ${PWD##*/} != SpeechLM ] && echo "Error: dir not match! Switch to SpeechLM/ and run it again!" && exit 1

model_path=$1
src_dir=${model_path%/*}
cpt=${model_path##*/}
cpt=${cpt%.*}

gen_set=$2
DATA_DIR=${gen_set%/*}
gen_set=${gen_set##*/}
outdir=$src_dir/decode_${cpt}

CODE_ROOT=${PWD}

for subset in ${gen_set//,/ }; do
    results_path=$outdir/phone2unit_${subset}
    [ ! -d $results_path ] && mkdir -p $results_path

    python $CODE_ROOT/speechlm/generate_unit.py $DATA_DIR \
        --user-dir $CODE_ROOT/speechlm \
        --config-yaml config.yaml \
        --path ${model_path} \
        --task fast_text_to_unit \
        --gen-subset $subset \
        \
        --beam 1 \
        --max-tokens 10000 \
        --results-path $results_path \
        --scoring sacrebleu

    echo $results_path
    tail -n 1 $results_path/generate-*.txt
    sleep 1s
done

# --distributed-world-size 1000 --distributed-rank 0 \
SpeechLM/speechlm/scripts/tokenizer_fastT2U/train_s_5e-4.sh
0 → 100644
View file @
12c90639
#####################################
#      Fast Text2Unit Model         #
#####################################
[ $# -lt 1 ] && echo "Usage: $0 <data_dir> [mount] [world_size=4] [update_freq=1]" && exit 0
[ ${PWD##*/} != SpeechLM ] && echo "Error: dir not match! Switch to SpeechLM/ and run it again!" && exit 1

DATA_DIR=$1
mount=$2
world_size=$3
update_freq=$4
[ -z $mount ] && mount=${PWD}
[ -z $world_size ] && world_size=4
[ -z $update_freq ] && update_freq=1

CODE_ROOT=${PWD}
MODEL_DIR="$mount/exp/fast_text2unit/small_lr5e-4_tristage_ls0.1_${world_size}gpu_${update_freq}accum"
[ -d $MODEL_DIR ] || mkdir -p $MODEL_DIR

fairseq-train ${DATA_DIR} --save-dir ${MODEL_DIR} \
    --config-yaml config.yaml \
    --user-dir $CODE_ROOT/speechlm \
    --train-subset train_100 --valid-subset dev_clean \
    --num-workers 4 --max-tokens 20000 \
    --distributed-world-size ${world_size} --update-freq ${update_freq} \
    \
    --task fast_text_to_unit --criterion fasttext2unit_criterion --arch fasttext2unit_s \
    --label-smoothing 0.1 \
    \
    --clip-norm 5.0 --n-frames-per-step 1 \
    --dropout 0.1 --attention-dropout 0.1 \
    --optimizer adam --lr 5e-4 --lr-scheduler tri_stage --phase-ratio [0.3,0.0,0.7] --max-update 10000 \
    --seed 1 --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \
    \
    --save-interval 2 \
    --tensorboard-logdir ${MODEL_DIR} \
    --fp16 --find-unused-parameters \
    | tee ${MODEL_DIR}/train.log

# DATA_DIR=/mnt/default/v-ziqzhang/dataset/librispeech_phone2unit/phone2unit
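The `--lr-scheduler tri_stage --phase-ratio [0.3,0.0,0.7]` setting splits the 10k updates into 30% linear warmup to the peak LR, no hold phase, and 70% decay toward a small final LR. The sketch below approximates that shape for intuition; the `final_scale` constant and the zero-start warmup are assumptions for the illustration, not fairseq's exact tri_stage defaults.

# Rough sketch of the LR shape requested above; not fairseq's tri_stage implementation.
import math

def tri_stage_lr(step, peak_lr=5e-4, max_update=10000,
                 phase_ratio=(0.3, 0.0, 0.7), final_scale=0.01):
    warmup = int(phase_ratio[0] * max_update)
    hold = int(phase_ratio[1] * max_update)
    decay = int(phase_ratio[2] * max_update)
    if step < warmup:                      # linear warmup to the peak LR
        return peak_lr * step / max(1, warmup)
    if step < warmup + hold:               # constant hold phase (empty here)
        return peak_lr
    # exponential decay from peak_lr toward final_scale * peak_lr
    t = min(step - warmup - hold, decay)
    return peak_lr * math.exp(math.log(final_scale) * t / max(1, decay))

for s in (0, 1500, 3000, 6500, 10000):
    print(s, f"{tri_stage_lr(s):.2e}")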
SpeechLM/speechlm/scripts/tune_speechlm_asr/finetune_base_ctc.sh
0 → 100644
View file @
12c90639
# ####################################
#        SpeechLM Base model         #
# ####################################
[ $# -lt 3 ] && echo "Usage: $0 <model_path> <data_dir> <cpt_tag> [mount=${PWD}] [world_size=8] [update_freq=1]" && exit 1
[ ${PWD##*/} != SpeechLM ] && echo "Error: dir not match! Switch to SpeechLM/ and run it again!" && exit 1

w2v_path=$1
DATA_DIR=$2
cpt=$3
mount=$4
world_size=$5
update_freq=$6
[ -z $mount ] && mount=${PWD}
[ -z $world_size ] && world_size=8
[ -z $update_freq ] && update_freq=1

CODE_ROOT=${PWD}

exp_name=${w2v_path%/*}
exp_name=${exp_name##*/}
MODEL_DIR="${mount}/exp/finetune_asr/$exp_name/ctc30k_from_${cpt}_bz1.6m_lr1e-5"
[ -d $MODEL_DIR ] || mkdir -p $MODEL_DIR

python $CODE_ROOT/fairseq/fairseq_cli/hydra_train.py \
    --config-dir $CODE_ROOT/speechlm/config/finetune \
    --config-name speechlm_base_100h \
    common.user_dir=$CODE_ROOT/speechlm \
    \
    task.data=$DATA_DIR \
    task.label_dir=$DATA_DIR \
    model.w2v_path=${w2v_path} \
    \
    optimization.lr=[0.00001] \
    optimization.max_update=30000 \
    dataset.max_tokens=1600000 \
    optimization.update_freq=[${update_freq}] \
    distributed_training.distributed_world_size=${world_size} \
    \
    dataset.train_subset="train_clean_100" \
    dataset.valid_subset="dev_other" \
    \
    common.tensorboard_logdir=$MODEL_DIR \
    checkpoint.save_dir=$MODEL_DIR \
    hydra.run.dir=$MODEL_DIR \
    hydra.job.name=${exp_name}

# model_path=/mnt/default/v-ziqzhang/data/speechulm/exp/base/base_speechlmp_32gpu_1accum/checkpoint_298_400000.pt
# data_dir=/home/v-ziqzhang/dataset/LibriSpeech/asr
SpeechLM/speechlm/scripts/tune_speechlm_asr/finetune_large_ctc.sh
0 → 100644
View file @
12c90639
# ####################################
#       SpeechLM Large model         #
# ####################################
[ $# -lt 3 ] && echo "Usage: $0 <model_path> <data_dir> <cpt_tag> [mount=${PWD}] [world_size=8] [update_freq=4]" && exit 1
[ ${PWD##*/} != SpeechLM ] && echo "Error: dir not match! Switch to SpeechLM/ and run it again!" && exit 1

w2v_path=$1
DATA_DIR=$2
cpt=$3
mount=$4
world_size=$5
update_freq=$6
[ -z $mount ] && mount=${PWD}
[ -z $world_size ] && world_size=8
[ -z $update_freq ] && update_freq=4

CODE_ROOT=${PWD}

exp_name=${w2v_path%/*}
exp_name=${exp_name##*/}
MODEL_DIR="${mount}/exp/finetune_asr/$exp_name/ctc200k_from_${cpt}_bz3.6m_lr1e-5"
[ -d $MODEL_DIR ] || mkdir -p $MODEL_DIR

python $CODE_ROOT/fairseq/fairseq_cli/hydra_train.py \
    --config-dir $CODE_ROOT/speechlm/config/finetune \
    --config-name speechlm_large_960h \
    common.user_dir=$CODE_ROOT/speechlm \
    \
    task.data=$DATA_DIR \
    task.label_dir=$DATA_DIR \
    model.w2v_path=${w2v_path} \
    \
    optimization.lr=[0.00001] \
    optimization.max_update=200000 \
    dataset.max_tokens=900000 \
    optimization.update_freq=[${update_freq}] \
    distributed_training.distributed_world_size=${world_size} \
    \
    dataset.train_subset="train_960" \
    dataset.valid_subset="dev_other" \
    \
    common.tensorboard_logdir=$MODEL_DIR \
    checkpoint.save_dir=$MODEL_DIR \
    hydra.run.dir=$MODEL_DIR \
    hydra.job.name=${exp_name}

# model_path=/mnt/default/v-ziqzhang/data/speechulm/exp/large/large_speechlmp_32gpu_4accum/checkpoint_31_400000.pt
# data_dir=/home/v-ziqzhang/dataset/LibriSpeech/asr