chenpangpang / parler-tts

Commit 0f6d59d4
Authored Feb 26, 2024 by Yoach Lacombe
Commit message: latest changes
Parent: 11fcc066
Showing 6 changed files, with 739 additions and 109 deletions:
example_configs/librispeech_tts_r.json          +18   -12
example_configs/librispeech_tts_r_dummy.json    +74    -0
init_dummy_model.py                             +10   -10
run_stable_speech_training.py                  +573   -52
stable_speech/configuration_stable_speech.py     +3    -3
stable_speech/modeling_stable_speech.py         +61   -32
example_configs/librispeech_tts_r.json

@@ -24,11 +24,11 @@
     "description_column_name": "text_description",
     "prompt_column_name": "text",
-    "max_train_samples": 1000,
-    "max_eval_samples": 200,
+    "max_train_samples": 20,
+    "max_eval_samples": 10,
-    "max_duration_in_seconds": 20,
+    "max_duration_in_seconds": 30,
     "min_duration_in_seconds": 1.0,
     "add_audio_samples_to_wandb": true,
@@ -36,30 +36,36 @@
     "preprocessing_num_workers": 1,
-    "pad_token_id": 2049,
+    "pad_token_id": 2050,
     "decoder_start_token_id": 2048,
     "do_train": true,
-    "num_train_epochs": 1,
+    "num_train_epochs": 120,
     "gradient_accumulation_steps": 1,
-    "gradient_checkpointing": true,
-    "per_device_train_batch_size": 8,
-    "learning_rate": 1e-6,
+    "gradient_checkpointing": false,
+    "per_device_train_batch_size": 2,
+    "learning_rate": 1e-3,
     "adam_beta1": 0.9,
-    "adam_beta2": 0.95,
+    "adam_beta2": 0.999,
     "weight_decay": 0.1,
-    "logging_steps": 25,
+    "lr_scheduler_type": "cosine",
+    "warmup_ratio": 0.1,
+    "logging_steps": 1,
     "freeze_text_encoder": true,
     "do_eval": true,
     "predict_with_generate": true,
     "include_inputs_for_metrics": true,
-    "evaluation_strategy": "epoch",
+    "evaluation_strategy": "steps",
+    "eval_steps": 600,
     "per_device_eval_batch_size": 8,
     "generation_max_length": 400,
-    "fp16": true,
+    "fp16": false,
     "seed": 456,
     "dataloader_num_workers": 8
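Taken together, these edits swap a long, conservative schedule (1 epoch at lr 1e-6 with fp16 and gradient checkpointing) for a short, aggressive debugging-style run (120 epochs over 20 samples at lr 1e-3 in fp32, with step-based evaluation). A minimal sketch for inspecting the updated values, assuming the script is run from the repository root (this helper is not part of the commit):

import json

# Load the training config touched by this commit and print the new schedule.
with open("example_configs/librispeech_tts_r.json") as f:
    cfg = json.load(f)

for key in ("max_train_samples", "num_train_epochs", "learning_rate",
            "lr_scheduler_type", "warmup_ratio", "evaluation_strategy",
            "eval_steps", "fp16"):
    print(f"{key} = {cfg[key]}")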
example_configs/librispeech_tts_r_dummy.json (new file, mode 100644)

{
    "model_name_or_path": "/home/yoach/dataspeech/artefacts/tiny-model/",
    "feature_extractor_name": "facebook/encodec_24khz",
    "description_tokenizer_name": "t5-base",
    "prompt_tokenizer_name": "t5-base",
    "push_to_hub": false,
    "hub_model_id": "stable-speech-mini",
    "report_to": ["wandb"],
    "overwrite_output_dir": true,
    "output_dir": "/home/yoach/dataspeech/artefacts/training/",
    "train_dataset_name": "blabble-io/libritts_r",
    "train_metadata_dataset_name": "stable-speech/libritts-r-tags-and-text-generated",
    "train_dataset_config_name": "clean",
    "train_split_name": "train.clean.360",
    "eval_dataset_name": "blabble-io/libritts_r",
    "eval_metadata_dataset_name": "stable-speech/libritts-r-tags-and-text-generated",
    "eval_dataset_config_name": "clean",
    "eval_split_name": "train.clean.360",
    "target_audio_column_name": "audio",
    "description_column_name": "text_description",
    "prompt_column_name": "text",
    "max_train_samples": 12,
    "max_eval_samples": 12,
    "max_duration_in_seconds": 30,
    "min_duration_in_seconds": 1.0,
    "add_audio_samples_to_wandb": true,
    "id_column_name": "id",
    "preprocessing_num_workers": 1,
    "pad_token_id": 2050,
    "decoder_start_token_id": 2048,
    "do_train": true,
    "num_train_epochs": 20,
    "gradient_accumulation_steps": 1,
    "gradient_checkpointing": false,
    "per_device_train_batch_size": 3,
    "learning_rate": 1e-3,
    "adam_beta1": 0.9,
    "adam_beta2": 0.999,
    "weight_decay": 0.1,
    "lr_scheduler_type": "cosine",
    "warmup_ratio": 0.1,
    "freeze_text_encoder": true,
    "do_eval": true,
    "predict_with_generate": true,
    "include_inputs_for_metrics": true,
    "evaluation_strategy": "steps",
    "eval_steps": 10,
    "per_device_eval_batch_size": 3,
    "generation_max_length": 400,
    "do_sample": true,
    "logging_steps": 15,
    "dtype": "float32",
    "seed": 456,
    "dataloader_num_workers": 8
}
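This dummy config pairs with the tiny model written by init_dummy_model.py below; in particular, its token ids must agree with the saved model's generation config. A quick consistency check, assuming the author's local artefact path exists (a sketch, not part of the commit):

import json
from transformers import GenerationConfig

with open("example_configs/librispeech_tts_r_dummy.json") as f:
    cfg = json.load(f)

# The tiny model is saved by init_dummy_model.py at model_name_or_path.
gen_cfg = GenerationConfig.from_pretrained(cfg["model_name_or_path"])
assert gen_cfg.pad_token_id == cfg["pad_token_id"] == 2050
assert gen_cfg.decoder_start_token_id == cfg["decoder_start_token_id"] == 2048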
init_dummy_model.py

@@ -4,16 +4,16 @@ from transformers import AutoConfig
 decoder_config = StableSpeechDecoderConfig(
     max_position_embeddings=2048,
-    num_hidden_layers=2,
-    ffn_dim=256,
-    num_attention_heads=4,
+    num_hidden_layers=4,
+    ffn_dim=512,
+    num_attention_heads=8,
     layerdrop=0.0,
     use_cache=True,
     activation_function="gelu",
-    hidden_size=256,
-    dropout=0.1,
-    attention_dropout=0.1,
-    activation_dropout=0.1,
+    hidden_size=512,
+    dropout=0.0,
+    attention_dropout=0.0,
+    activation_dropout=0.0,
 )
 # TODO: ?? how to make it stop ?
@@ -35,12 +35,12 @@ model = StableSpeechForConditionalGeneration.from_sub_models_pretrained(
 # set the appropriate bos/pad token ids
 model.generation_config.decoder_start_token_id = 2048
-model.generation_config.pad_token_id = 2049
+model.generation_config.pad_token_id = 2050
 model.generation_config.eos_token_id = 2049
 # set other default generation config params
 model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
-model.generation_config.do_sample = True
-model.generation_config.guidance_scale = 3.0
+model.generation_config.do_sample = False  # True
+model.generation_config.guidance_scale = 1  # 3.0
 model.save_pretrained("/home/yoach/dataspeech/artefacts/tiny-model/")
\ No newline at end of file
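The max_length line above budgets 30 seconds of audio in decoder steps. Assuming the facebook/encodec_24khz feature extractor named in the configs (75 codebook frames per second, i.e. 24000 Hz with a 320-sample hop), the arithmetic works out as:

# Sketch of the max_length computation; frame_rate = 75 is the EnCodec 24 kHz
# value, matching model.audio_encoder.config.frame_rate in the script above.
frame_rate = 75
max_length = int(30 * frame_rate)  # 30 s of audio
print(max_length)                  # 2250 decoder steps per codebook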
run_stable_speech_training.py

(diff collapsed: +573 additions, -52 deletions)
stable_speech/configuration_stable_speech.py

@@ -38,7 +38,7 @@ class StableSpeechDecoderConfig(PretrainedConfig):
     Args:
-        vocab_size (`int`, *optional*, defaults to 2049):
+        vocab_size (`int`, *optional*, defaults to 2050):
             Vocabulary size of the StableSpeechDecoder model. Defines the number of different tokens that can be
             represented by the `inputs_ids` passed when calling [`StableSpeechDecoder`].
         hidden_size (`int`, *optional*, defaults to 1024):
@@ -81,7 +81,7 @@ class StableSpeechDecoderConfig(PretrainedConfig):
     def __init__(
         self,
-        vocab_size=2049,  # vocab size = 2048 (encodec vocab size) + 1 (eos token)
+        vocab_size=2050,  # vocab size = 2048 (encodec vocab size) + 2 (bos, eos)
         max_position_embeddings=2048,
         num_hidden_layers=24,
         ffn_dim=4096,
@@ -96,7 +96,7 @@ class StableSpeechDecoderConfig(PretrainedConfig):
         initializer_factor=0.02,
         scale_embedding=False,
         num_codebooks=4,
-        pad_token_id=2049,
+        pad_token_id=2050,
         bos_token_id=2048,
         eos_token_id=2049,
         tie_word_embeddings=False,
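The change makes the token-id layout explicit: EnCodec codes occupy 0-2047, bos doubles as the decoder start token, and pad moves one slot past eos. An illustrative summary (assumptions drawn only from this diff and the embedding change in modeling_stable_speech.py below):

# Token-id layout implied by this commit (illustration, not repository code):
ENCODEC_CODES = range(0, 2048)  # raw codebook entries
BOS_TOKEN_ID = 2048             # also the decoder_start_token_id
EOS_TOKEN_ID = 2049
PAD_TOKEN_ID = 2050             # one past vocab_size; the decoder embeds it
                                # via embed_dim = vocab_size + 1
VOCAB_SIZE = 2050               # 2048 codes + 2 (bos, eos)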
stable_speech/modeling_stable_speech.py

@@ -659,7 +659,8 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
         self.num_codebooks = config.num_codebooks
         self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
-        embed_dim = config.vocab_size + 1  # TODO: not right dim
+        embed_dim = config.vocab_size + 1  # + 1 for pad token id
         self.embed_tokens = nn.ModuleList(
             [nn.Embedding(embed_dim, config.hidden_size) for _ in range(config.num_codebooks)]
         )
@@ -981,6 +982,7 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
+        label_delay_pattern_mask: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
@@ -991,7 +993,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
             Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
             `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
             are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
+        # TODO: delay_pattern_mask
         Returns:
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -1019,17 +1022,36 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         loss = None
         if labels is not None:
-            loss = torch.zeros([], device=self.device)
             # since encoder hidden states have concatenated to hidden states, take the last hidden states corresponding to labels
             logits = lm_logits[:, :, -labels.shape[1]:]
             loss_fct = CrossEntropyLoss()
+            loss = torch.zeros([], device=self.device)
-            # per codebook cross-entropy
-            # -100 labels are ignored
-            # (bsz, vocab_size, seq_len, num_codebooks), (bsz, seq_len, num_codebooks)
             labels = labels.masked_fill(labels == self.config.bos_token_id, -100)
+            labels = labels.masked_fill(labels == self.config.pad_token_id, -100)
-            loss = loss_fct(logits.transpose(1, 3), labels)
+            # loss = loss_fct(logits.transpose(1,3), labels)
+            # -100 labels are ignored
+            # TODO: probably no need for label_delay_pattern_mask
+            # mask = label_delay_pattern_mask[:, :labels.shape[1]]
+            # mask = (labels != self.generation_config.bos_token_id)&(labels != -100)
+            mask = (labels != -100)
+            # per codebook cross-entropy
+            for codebook in range(self.config.num_codebooks):
+                codebook_logits = logits[:, codebook].contiguous().view(-1, logits.shape[-1])
+                codebook_mask = mask[..., codebook].contiguous().view(-1)
+                codebook_labels = labels[..., codebook].contiguous().view(-1)
+                codebook_loss = loss_fct(codebook_logits[codebook_mask], codebook_labels[codebook_mask])
+                loss += codebook_loss
+            loss = loss / self.config.num_codebooks
             # (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size)
             lm_logits = lm_logits.reshape(-1, *lm_logits.shape[2:])
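The rewritten loss replaces a single cross-entropy over all codebooks with a per-codebook average that skips -100 positions. A self-contained sketch of the same computation, under the assumed shapes logits (bsz, num_codebooks, seq_len, vocab) and labels (bsz, seq_len, num_codebooks):

import torch
from torch.nn import CrossEntropyLoss

def per_codebook_loss(logits, labels, num_codebooks):
    # logits: (bsz, num_codebooks, seq_len, vocab); labels: (bsz, seq_len, num_codebooks)
    loss_fct = CrossEntropyLoss()
    loss = torch.zeros([], device=logits.device)
    mask = labels != -100  # -100 marks positions (e.g. bos/pad) excluded from the loss
    for codebook in range(num_codebooks):
        cb_logits = logits[:, codebook].reshape(-1, logits.shape[-1])  # (bsz*seq, vocab)
        cb_mask = mask[..., codebook].reshape(-1)                      # (bsz*seq,)
        cb_labels = labels[..., codebook].reshape(-1)
        loss += loss_fct(cb_logits[cb_mask], cb_labels[cb_mask])
    return loss / num_codebooks

# toy shapes: batch 2, 4 codebooks, 8 frames, vocab 2051
logits = torch.randn(2, 4, 8, 2051)
labels = torch.randint(0, 2048, (2, 8, 4))
labels[:, 0] = -100  # pretend the first frame is all bos
print(per_codebook_loss(logits, labels, num_codebooks=4))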
@@ -1066,8 +1088,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         if delay_pattern_mask is None:
             input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
                 input_ids,
-                bos_token_id=self.generation_config.decoder_start_token_id,
-                eos_token_id=self.generation_config.eos_token_id,
+                bos_token_id=self.generation_config.bos_token_id,
+                pad_token_id=self.generation_config.pad_token_id,
                 max_length=self.generation_config.max_length,
             )
@@ -1111,21 +1133,21 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         }
     # Ignore copy
-    def build_delay_pattern_mask(self, input_ids: torch.LongTensor, bos_token_id: int, eos_token_id: int, max_length: int = None):
+    def build_delay_pattern_mask(self, input_ids: torch.LongTensor, bos_token_id: int, pad_token_id: int, max_length: int = None):
         """Build a delayed pattern mask to the input_ids. Each codebook is offset by the previous codebook by
         one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there
         are 4 codebooks and a max sequence length of 8, we have the delayed pattern mask of shape `(codebooks, seq_len)`:
-        - [B, -1, -1, -1, -1, E, E, E]
-        - [B, B, -1, -1, -1, -1, E, E]
-        - [B, B, B, -1, -1, -1, -1, E]
+        - [B, -1, -1, -1, -1, P, P, P]
+        - [B, B, -1, -1, -1, -1, P, P]
+        - [B, B, B, -1, -1, -1, -1, P]
         - [B, B, B, B, -1, -1, -1, -1]
-        where B is the BOS token id, E is the EOS token id and -1 indicates that the token is valid for prediction. If we include
+        where P is the special padding token id and -1 indicates that the token is valid for prediction. If we include
         a prompt (decoder input ids), the -1 positions indicate where new tokens should be predicted. Otherwise, the
         mask is set to the value in the prompt:
-        - [B, a, b, -1, -1, E, E, E]
-        - [B, B, c, d, -1, -1, E, E]
-        - [B, B, B, e, f, -1, -1, E]
+        - [B, a, b, -1, -1, P, P, P]
+        - [B, B, c, d, -1, -1, P, P]
+        - [B, B, B, e, f, -1, -1, P]
         - [B, B, B, B, g, h, -1, -1]
         where a-h indicate the input prompt (decoder input ids) that are offset by 1. Now, we only override the -1
         tokens in our prediction.
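For concreteness, the B/P layout in the docstring can be generated directly. This toy snippet (not the library function; token ids taken from the configs in this commit) reproduces the 4-codebook, length-8 pattern:

import torch

B, P = 2048, 2050  # bos and pad ids used in this commit
num_codebooks, max_length = 4, 8
pattern = torch.full((num_codebooks, max_length), -1)
for k in range(num_codebooks):
    pattern[k, : k + 1] = B  # k+1 leading bos tokens for codebook k
    trailing = num_codebooks - 1 - k  # pads so every codebook ends together
    if trailing:
        pattern[k, -trailing:] = P
print(pattern)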
@@ -1159,7 +1181,7 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         bos_mask = ~(bos_delay_pattern).to(input_ids.device)
         eos_mask = ~(eos_delay_pattern).to(input_ids.device)
         mask = ~(bos_delay_pattern + eos_delay_pattern).to(input_ids.device)
-        input_ids = mask * input_ids_shifted + ~bos_mask * bos_token_id + ~eos_mask * eos_token_id
+        input_ids = mask * input_ids_shifted + ~bos_mask * bos_token_id + ~eos_mask * pad_token_id
         # find the first position to start generating - this is the first place we have the -1 token
         # and will always be in the first codebook (since it has no codebook offset)
@@ -1339,8 +1361,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         # Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Stable Speech)
         input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
             input_ids,
-            bos_token_id=generation_config.decoder_start_token_id,
-            eos_token_id=generation_config.eos_token_id,
+            bos_token_id=generation_config.bos_token_id,
+            pad_token_id=generation_config.pad_token_id,
             max_length=generation_config.max_length,
         )
@@ -1846,6 +1868,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         prompt_attention_mask: Optional[torch.LongTensor] = None,  # TODO: add to docstrings
         prompt_hidden_states: Optional[torch.FloatTensor] = None,  # TODO: add to docstrings
         labels: Optional[torch.LongTensor] = None,
+        label_delay_pattern_mask: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
@@ -1928,9 +1951,10 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         # TODO: verify prompt_attention_mask
         if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
             # TODO: verify it does what's expected
-            decoder_input_ids = shift_tokens_right(
-                labels.transpose(1, 2), self.config.pad_token_id, self.config.decoder_start_token_id
-            )
+            decoder_input_ids = shift_tokens_right(
+                labels, self.config.pad_token_id, self.config.decoder_start_token_id
+            ).transpose(1, 2)
         elif decoder_input_ids is None and decoder_inputs_embeds is None:
             audio_encoder_outputs = self.audio_encoder(
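The reordering matters because the shift has to happen along the time axis. Assuming the module-local shift_tokens_right shifts along the last dimension (a variant with that contract is sketched below; the real helper may differ) and labels arrive as (bsz, num_codebooks, seq_len), shifting first and transposing second produces decoder inputs shifted in time rather than across codebooks:

import torch

def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
    # Assumed contract: shift one step along the LAST dimension,
    # inserting the decoder start token at position 0.
    shifted = input_ids.new_zeros(input_ids.shape)
    shifted[..., 1:] = input_ids[..., :-1].clone()
    shifted[..., 0] = decoder_start_token_id
    shifted.masked_fill_(shifted == -100, pad_token_id)
    return shifted

labels = torch.randint(0, 2048, (2, 4, 10))  # (bsz, num_codebooks, seq_len)
# new order: shift over seq_len first, then transpose to (bsz, seq_len, num_codebooks)
decoder_input_ids = shift_tokens_right(labels, 2050, 2048).transpose(1, 2)
print(decoder_input_ids.shape)  # torch.Size([2, 10, 4])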
@@ -1967,6 +1991,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
             past_key_values=past_key_values,
             return_dict=return_dict,
             labels=labels,
+            label_delay_pattern_mask=label_delay_pattern_mask,
             **kwargs_decoder,
         )
@@ -2005,8 +2030,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         if decoder_delay_pattern_mask is None:
             decoder_input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
                 decoder_input_ids,
-                bos_token_id=self.generation_config.decoder_start_token_id,
-                eos_token_id=self.generation_config.eos_token_id,
+                bos_token_id=self.generation_config.bos_token_id,
+                pad_token_id=self.generation_config.pad_token_id,
                 max_length=self.generation_config.max_length,
             )
@@ -2197,7 +2222,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         return model_kwargs
     def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
-        return shift_tokens_right(labels.transpose(1, 2), self.config.pad_token_id, self.config.decoder_start_token_id)
+        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id).transpose(1, 2)
     def resize_token_embeddings(self, *args, **kwargs):
         # TODO: now it's possible with prompt_embeddings
@@ -2435,8 +2460,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Stable Speech)
         input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
             input_ids,
-            bos_token_id=generation_config.decoder_start_token_id,
-            eos_token_id=generation_config.eos_token_id,
+            bos_token_id=generation_config.bos_token_id,
+            pad_token_id=generation_config.pad_token_id,
             max_length=generation_config.max_length,
         )
         # stash the delay mask so that we don't have to recompute in each forward pass
@@ -2540,7 +2565,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"])
         # revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
-        output_ids = output_ids[(model_kwargs["decoder_delay_pattern_mask"] != generation_config.bos_token_id) & (
-            model_kwargs["decoder_delay_pattern_mask"] != generation_config.eos_token_id)].reshape(
+        # TODO: probably won't work...
+        output_ids = output_ids[(model_kwargs["decoder_delay_pattern_mask"] != generation_config.bos_token_id) & (
+            model_kwargs["decoder_delay_pattern_mask"] != generation_config.pad_token_id)].reshape(
             batch_size, self.decoder.num_codebooks, -1)
@@ -2556,17 +2582,20 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
             output_values = self.audio_encoder.decode(
                 output_ids,
                 audio_scales=audio_scales,
-            ).audio_values
+            ).audio_values.squeeze(1)
         else:
             output_values = []
             for sample_id in range(batch_size):
                 sample = output_ids[:, sample_id]
-                sample_mask = (((sample == generation_config.bos_token_id) | (sample == generation_config.eos_token_id)).sum(dim=(0, 1)) == 0)
-                sample = sample[:, :, sample_mask]
-                sample = self.audio_encoder.decode(sample[None, ...], [audio_scales[sample_id]]).audio_values
-                output_values.append(sample.transpose(0, 2))
+                sample_mask = (((sample == generation_config.bos_token_id) | (sample == generation_config.eos_token_id) | (sample == generation_config.pad_token_id)).sum(dim=(0, 1)) == 0)
+                if sample_mask.sum() > 0:
+                    sample = sample[:, :, sample_mask]
+                    sample = self.audio_encoder.decode(sample[None, ...], [audio_scales[sample_id]]).audio_values
+                    output_values.append(sample.transpose(0, 2))
+                else:
+                    output_values.append(torch.zeros((1, 1, 1)).to(self.device))
             # TODO: we should keep track of output length as well. Not really straightfoward tbh
-            output_values = torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0).transpose(1, 2).squeeze(-1)
+            output_values = torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0).transpose(1, 2).squeeze(-1).squeeze(1)
         if generation_config.return_dict_in_generate:
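The per-sample branch now drops every frame in which any codebook still holds a special token, and only decodes if something survives. A toy illustration of the sample_mask semantics (token ids from the configs; shapes assumed as in the loop above):

import torch

bos, eos, pad = 2048, 2049, 2050
# one decoded sample: (1, num_codebooks=2, seq_len=4)
sample = torch.tensor([[[bos, 5, 7, pad],
                        [bos, bos, 9, 11]]])
special = (sample == bos) | (sample == eos) | (sample == pad)
sample_mask = special.sum(dim=(0, 1)) == 0  # keep frames with no special token anywhere
print(sample_mask)                # tensor([False, False,  True, False])
print(sample[:, :, sample_mask])  # only fully valid frames reach the audio decoder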