chenpangpang / parler-tts · Commits

Commit 3170ac02, authored Apr 24, 2024 by Dan Lyth

    adding eval.py and simple train.py, re-instating run_parler_tts_training.py

Parent: 09df5026
Showing 3 changed files with 1063 additions and 314 deletions.
parler_tts/eval.py                     +35    −0
training/run_parler_tts_training.py    +1018  −0
training/train.py                      +10    −314
parler_tts/eval.py (new file, mode 0 → 100644)
import torch
import evaluate
from transformers import AutoModel, AutoProcessor, pipeline


def clap_similarity(clap_model_name_or_path, texts, audios, device):
    clap = AutoModel.from_pretrained(clap_model_name_or_path)
    clap_processor = AutoProcessor.from_pretrained(clap_model_name_or_path)
    clap_inputs = clap_processor(text=texts, audios=audios, padding=True, return_tensors="pt").to(device)
    clap.to(device)
    with torch.no_grad():
        text_features = clap.get_text_features(
            clap_inputs["input_ids"], attention_mask=clap_inputs.get("attention_mask", None)
        )
        audio_features = clap.get_audio_features(clap_inputs["input_features"])

    cosine_sim = torch.nn.functional.cosine_similarity(audio_features, text_features, dim=1, eps=1e-8)

    clap.to("cpu")
    clap_inputs.to("cpu")
    return cosine_sim.mean().to("cpu")


def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_size, sampling_rate):
    metric = evaluate.load("wer")
    asr_pipeline = pipeline(model=asr_model_name_or_path, device=device)

    transcriptions = asr_pipeline(
        [{"raw": audio, "sampling_rate": sampling_rate} for audio in audios],
        batch_size=int(per_device_eval_batch_size),
    )

    word_error = 100 * metric.compute(
        predictions=[t["text"].lower() for t in transcriptions],
        references=[t.lower() for t in prompts],
    )

    return word_error, [t["text"] for t in transcriptions]
\ No newline at end of file
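
For orientation, here is a minimal sketch of how these two metrics could be invoked from an evaluation loop. The checkpoint names ("laion/clap-htsat-unfused", "openai/whisper-tiny") and the dummy inputs are illustrative assumptions, not part of this commit:

import numpy as np

from parler_tts.eval import clap_similarity, wer

# Hypothetical inputs: two 2-second silent clips at 44.1 kHz (placeholders only).
sampling_rate = 44_100
audios = [np.zeros(2 * sampling_rate, dtype=np.float32) for _ in range(2)]
descriptions = ["A calm female voice.", "A fast male voice."]
prompts = ["Hello world.", "Parler speaking."]

# CLAP similarity between the text descriptions and the (generated) audio.
clap_score = clap_similarity("laion/clap-htsat-unfused", descriptions, audios, device="cpu")

# Word error rate of an ASR transcription of the audio against the prompts.
word_error, transcriptions = wer(
    "openai/whisper-tiny",
    prompts,
    audios,
    device="cpu",
    per_device_eval_batch_size=2,
    sampling_rate=sampling_rate,
)
print(f"CLAP similarity: {clap_score:.3f} | WER: {word_error:.1f}%")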
training/run_parler_tts_training.py (new file, mode 0 → 100644)

This diff is collapsed (1,018 additions).
training/train.py
@@ -63,7 +63,7 @@ from parler_tts import (
 from parler_tts.utils import get_last_checkpoint, rotate_checkpoints, log_pred, log_metric
 from parler_tts.arguments import ModelArguments, DataTrainingArguments, ParlerTTSTrainingArguments
-from parler_tts.data import load_multiple_datasets, DataCollatorParlerTTSWithPadding, DataCollatorEncodecWithPadding
+from parler_tts.data import DataCollatorParlerTTSWithPadding, DataCollatorEncodecWithPadding

 logger = logging.getLogger(__name__)

@@ -104,7 +104,7 @@ def main():
     padding = "max_length" if data_args.pad_to_max_length else "longest"

-    ####### A. Preparation
+    # Accelerator preparation
     kwargs_handlers = [InitProcessGroupKwargs(timeout=timedelta(minutes=60))]

     if training_args.torch_compile:
         # TODO(YL): add more compile modes?

@@ -182,7 +182,7 @@ def main():
     set_seed(training_args.seed)
     num_workers = data_args.preprocessing_num_workers

-    # 1. First, lett's instantiate the feature extractor, tokenizers and model
+    # 1. First, let's instantiate the feature extractor (DAC), tokenizers and model
     # Note for distributed training, the .from_pretrained methods guarantee that only
     # one local process can concurrently download model & vocab.

@@ -222,79 +222,7 @@ def main():
     description_tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True

     # 2. Now, let's load the dataset
+    # TODO add MDS dataset loading here
-    if data_args.save_to_disk is not None:
-        os.makedirs(data_args.save_to_disk, exist_ok=True)
-
-    # assume that the dataset has been saved to `save_to_disk` if the latter is not empty
-    dataset_was_precomputed = len(os.listdir(data_args.save_to_disk)) > 0
-    if dataset_was_precomputed:
-        vectorized_datasets = datasets.load_from_disk(data_args.save_to_disk)
-    else:
-        raw_datasets = DatasetDict()
-
-        columns_to_keep = {
-            "target_audio_column_name": data_args.target_audio_column_name,
-            "prompt_column_name": data_args.prompt_column_name,
-        }
-        if data_args.description_column_name is not None:
-            columns_to_keep["description_column_name"] = data_args.description_column_name
-
-        if training_args.do_train:
-            raw_datasets["train"] = load_multiple_datasets(
-                accelerator,
-                data_args.train_dataset_name,
-                data_args.train_dataset_config_name,
-                metadata_dataset_names=data_args.train_metadata_dataset_name,
-                splits=data_args.train_split_name,
-                dataset_samples=data_args.train_dataset_samples,
-                seed=training_args.seed,
-                cache_dir=model_args.cache_dir,
-                num_proc=data_args.preprocessing_num_workers,
-                id_column_name=data_args.id_column_name,
-                columns_to_keep=columns_to_keep.values(),
-                prompt_column_name=data_args.prompt_column_name,
-                audio_column_name=data_args.target_audio_column_name,
-                sampling_rate=sampling_rate,
-                logger=logger,
-                # streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
-            )
-
-            for key in columns_to_keep:
-                if columns_to_keep[key] not in raw_datasets["train"].column_names:
-                    raise ValueError(
-                        f"--{key} '{columns_to_keep[key]}' not found in dataset '{data_args.train_dataset_name}'."
-                        f" Make sure to set `--{key}` to the correct audio column - one of"
-                        f" {', '.join(raw_datasets['train'].column_names)}."
-                    )
-
-            if data_args.max_train_samples is not None:
-                raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
-
-        if training_args.do_eval:
-            raw_datasets["eval"] = load_multiple_datasets(
-                accelerator,
-                data_args.eval_dataset_name if data_args.eval_dataset_name else data_args.train_dataset_name,
-                data_args.eval_dataset_config_name if data_args.eval_dataset_config_name else data_args.train_dataset_config_name,
-                metadata_dataset_names=data_args.eval_metadata_dataset_name,
-                splits=data_args.eval_split_name,
-                cache_dir=model_args.cache_dir,
-                num_proc=data_args.preprocessing_num_workers,
-                id_column_name=data_args.id_column_name,
-                columns_to_keep=columns_to_keep.values(),
-                prompt_column_name=data_args.prompt_column_name,
-                audio_column_name=data_args.target_audio_column_name,
-                sampling_rate=sampling_rate,
-                logger=logger,
-                # streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
-            )
-
-            if data_args.max_eval_samples is not None:
-                raw_datasets["eval"] = (
-                    raw_datasets["eval"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
-                )

     # 3. Next, let's load the config.
     config = ParlerTTSConfig.from_pretrained(

@@ -330,250 +258,17 @@ def main():
         model.gradient_checkpointing_enable()

     # 4. Now we preprocess the datasets including loading the audio, resampling and normalization
-    # Thankfully, `datasets` takes care of automatically loading and resampling the audio,
-    # so that we just need to set the correct target sampling rate and normalize the input
-    # via the `feature_extractor`
+    # TODO add MDS dataset preprocessing here (only thing we'll need is the delay pattern)

     # derive max & min input length for sample rate & max duration
     sampling_rate = feature_extractor.sampling_rate
-    max_target_length = data_args.max_duration_in_seconds * sampling_rate
-    min_target_length = data_args.min_duration_in_seconds * sampling_rate
-    target_audio_column_name = data_args.target_audio_column_name
-    description_column_name = data_args.description_column_name
-    prompt_column_name = data_args.prompt_column_name
-    feature_extractor_input_name = feature_extractor.model_input_names[0]
-    audio_encoder_pad_token_id = config.decoder.pad_token_id
-    audio_encoder_eos_token_id = config.decoder.eos_token_id
-    audio_encoder_bos_token_id = model.generation_config.decoder_start_token_id
-    max_length = model.generation_config.max_length
-    num_codebooks = model.decoder.config.num_codebooks
-    bandwidth = model_args.bandwidth

     # Freeze Encoders
     model.freeze_encoders(model_args.freeze_text_encoder)
+    # TODO check this implementation
-
-    # Test all gather - used for warmout and avoiding timeout
-    test_tensor = torch.tensor([accelerator.process_index], device=accelerator.device)
-    gathered_tensor = accelerator.gather(test_tensor)
-    print("gathered_tensor", gathered_tensor)
-    accelerator.wait_for_everyone()
-
-    if not dataset_was_precomputed:
-        # Filter on text length
-        if description_column_name is not None and data_args.max_text_length is not None:
-            with accelerator.main_process_first():
-                # filter description that is shorter than max_text_length
-                raw_datasets = raw_datasets.filter(
-                    lambda x: len(x) < data_args.max_text_length,
-                    num_proc=num_workers,
-                    input_columns=[description_column_name],
-                )
-
-        # Preprocessing the dataset.
-        # We need to tokenize the texts.
-        def pass_through_processors(description, prompt):
-            batch = {}
-
-            batch["input_ids"] = description_tokenizer(description.strip())["input_ids"]
-            batch["prompt_input_ids"] = prompt_tokenizer(prompt.strip())["input_ids"]
-
-            return batch
-
-        with accelerator.main_process_first():
-            # this is a trick to avoid to rewrite the entire audio column which takes ages
-            vectorized_datasets = raw_datasets.map(
-                pass_through_processors,
-                remove_columns=next(iter(raw_datasets.values())).column_names,
-                input_columns=[description_column_name, prompt_column_name],
-                num_proc=num_workers,
-                desc="preprocess datasets",
-            )
-
-        # We use Accelerate to perform distributed inference
-        # T5 doesn't support fp16
-        autocast_kwargs = AutocastKwargs(enabled=(mixed_precision != "fp16"))
-
-        # Now we encode the audio labels with encodec.
-        ####### B. Encode audio
-
-        logger.info("*** Encode target audio with encodec ***")
-
-        # no need to prepare audio_decoder because used for inference without mixed precision
-        # see: https://huggingface.co/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator.prepare
-        if training_args.torch_compile:
-            audio_decoder = accelerator.prepare_model(model.audio_encoder, evaluation_mode=True)
-        else:
-            audio_decoder = model.audio_encoder
-
-        encoder_data_collator = DataCollatorEncodecWithPadding(
-            feature_extractor,
-            audio_column_name=target_audio_column_name,
-            feature_extractor_input_name=feature_extractor_input_name,
-            max_length=max_target_length,
-            padding=padding,
-        )
-
-        def apply_audio_decoder(batch):
-            len_audio = batch.pop("len_audio")
-            audio_decoder.to(batch["input_values"].device).eval()
-            with torch.no_grad():
-                labels = audio_decoder.encode(**batch, bandwidth=bandwidth)["audio_codes"]
-            output = {}
-            output["len_audio"] = len_audio
-            # (1, bsz, codebooks, seq_len) -> (bsz, seq_len, codebooks)
-            output["labels"] = labels.squeeze(0).transpose(1, 2)
-            output["ratio"] = torch.ones_like(len_audio) * labels.shape[-1] / len_audio.max()
-            return output
-
-        for split in vectorized_datasets:
-            data_loader = DataLoader(
-                raw_datasets[split],
-                batch_size=training_args.audio_encoder_per_device_batch_size,
-                collate_fn=encoder_data_collator,
-                num_workers=training_args.dataloader_num_workers,
-                pin_memory=True,
-            )
-            data_loader = accelerator.prepare(data_loader)
-
-            all_generated_labels = []
-            all_lens = []
-            for batch in tqdm(data_loader, disable=not accelerator.is_local_main_process):
-                generate_labels = apply_audio_decoder(batch)
-                generate_labels = accelerator.pad_across_processes(generate_labels, dim=1, pad_index=0)
-                generate_labels = accelerator.gather_for_metrics(generate_labels)
-
-                if accelerator.is_main_process:
-                    lab = generate_labels["labels"].cpu().transpose(1, 2).to(torch.int16)
-                    rat = generate_labels["ratio"].cpu().squeeze()
-                    lens = generate_labels["len_audio"].cpu().squeeze()
-                    lab = [l[:, : int(ratio * length)] for (l, ratio, length) in zip(lab, rat, lens)]
-
-                    all_generated_labels.extend(lab)
-                    all_lens.extend(lens)
-
-            # (1, codebooks, seq_len) where seq_len=1
-            bos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_bos_token_id
-
-            if accelerator.is_main_process:
-                tmp_labels = Dataset.from_dict({"labels": all_generated_labels, "target_length": all_lens})
-                tmp_labels.save_to_disk(
-                    os.path.join(data_args.temporary_save_to_disk, split),
-                    num_proc=1 if split == "eval" else data_args.preprocessing_num_workers,
-                )
-            accelerator.wait_for_everyone()
-            del all_generated_labels
-
-            tmp_labels = datasets.load_from_disk(os.path.join(data_args.temporary_save_to_disk, split))
-            with accelerator.main_process_first():
-                vectorized_datasets[split] = concatenate_datasets([vectorized_datasets[split], tmp_labels], axis=1)
-
-            def postprocess_dataset(labels):
-                # (1, codebooks, seq_len)
-                labels = torch.tensor(labels).unsqueeze(0)
-                # add bos
-                labels = torch.cat([bos_labels, labels], dim=-1)
-
-                labels, delay_pattern_mask = build_delay_pattern_mask(
-                    labels,
-                    bos_token_id=audio_encoder_bos_token_id,
-                    pad_token_id=audio_encoder_eos_token_id,
-                    max_length=labels.shape[-1] + num_codebooks,
-                    num_codebooks=num_codebooks,
-                )
-
-                # the first ids of the delay pattern mask are precisely labels, we use the rest of the labels mask
-                # to take care of EOS
-                # we want labels to look like this:
-                #  - [B, a, b, E, E, E, E]
-                #  - [B, B, c, d, E, E, E]
-                #  - [B, B, B, e, f, E, E]
-                #  - [B, B, B, B, g, h, E]
-                labels = torch.where(delay_pattern_mask == -1, audio_encoder_eos_token_id, delay_pattern_mask)
-
-                # the first timestamp is associated to a row full of BOS, let's get rid of it
-                # we also remove the last timestampts (full of PAD)
-                output = {"labels": labels[:, 1:]}
-                return output
-
-            with accelerator.main_process_first():
-                vectorized_datasets[split] = vectorized_datasets[split].map(
-                    postprocess_dataset,
-                    num_proc=data_args.preprocessing_num_workers,  # this one is resource consuming if many processor.
-                    input_columns=["labels"],
-                    desc="Postprocessing labeling",
-                )
-
-        accelerator.free_memory()
-        del generate_labels, all_lens
-
-        with accelerator.main_process_first():
-            # NOTE: filtering is done at the end because in the `datasets` library, caching audio files is done after most operations
-            # caching audio files is time and disk-space consuming, so we want to avoid it at all costs, especially for large (>1Kh) audio datasets.
-            # That's also why we avoid to concat the processed datasets (vectorized_datasets) with the audio column present in raw_datasets.
-            def is_audio_in_length_range(length):
-                return length > min_target_length and length < max_target_length
-
-            # filter data that is shorter than min_target_length
-            vectorized_datasets = vectorized_datasets.filter(
-                is_audio_in_length_range,
-                num_proc=num_workers,
-                input_columns=["target_length"],
-            )
-
-        if description_column_name is not None and data_args.max_description_token_length is not None:
-            with accelerator.main_process_first():
-                # filter description that is shorter than max_text_length
-                vectorized_datasets = vectorized_datasets.filter(
-                    lambda x: len(x) < data_args.max_description_token_length,
-                    num_proc=num_workers,
-                    input_columns=["input_ids"],
-                )
-
-        if data_args.max_prompt_token_length is not None:
-            with accelerator.main_process_first():
-                # filter description that is shorter than max_text_length
-                vectorized_datasets = vectorized_datasets.filter(
-                    lambda x: len(x) < data_args.max_prompt_token_length,
-                    num_proc=num_workers,
-                    input_columns=["prompt_input_ids"],
-                )
-
-    if data_args.save_to_disk is not None and not dataset_was_precomputed:
-        if accelerator.is_main_process:
-            vectorized_datasets.save_to_disk(
-                data_args.save_to_disk,
-                num_proc=min(data_args.preprocessing_num_workers, len(vectorized_datasets["eval"]) - 1),
-            )
-        logger.info(f"Dataset saved at {data_args.save_to_disk}")
-
-    audio_max_length = None
-    if training_args.torch_compile:
-        audio_max_length = max(vectorized_datasets["train"]["target_length"])
-        with accelerator.main_process_first():
-            max_sample = vectorized_datasets["train"].filter(
-                lambda x: x == audio_max_length,
-                num_proc=num_workers,
-                input_columns=["target_length"],
-            )
-        audio_max_length = torch.tensor(max_sample[0]["labels"]).shape[1]
-
-    # for large datasets it is advised to run the preprocessing on a
-    # single machine first with ``args.preprocessing_only`` since there will mostly likely
-    # be a timeout when running the script in distributed mode.
-    # In a second step ``args.preprocessing_only`` can then be set to `False` to load the
-    # cached dataset
-    if data_args.preprocessing_only and data_args.save_to_disk is None:
-        raise ValueError(
-            "`preprocessing_only=True` but `save_to_disk` is not set. The latter should indicates where to save the dataset locally."
-        )
-    elif data_args.preprocessing_only:
-        logger.info(f"Data preprocessing finished. Files save at {data_args.save_to_disk}")
-        return

     # 6. Next, we can prepare the training.
-    # Let's use word CLAP similary and WER metrics as our evaluation metrics,
+    # Let's use word CLAP similary and WER metrics as our evaluation metrics
+    # TODO move this to seperate file

     # Define evaluation metrics during training, *i.e.* CLAP similarity
     clap = AutoModel.from_pretrained(model_args.clap_model_name_or_path)

@@ -630,7 +325,7 @@ def main():
     if training_args.max_steps < 0:
         num_epochs = int(training_args.num_train_epochs)
         steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)
+        # TODO fix this missing variable
         total_train_steps = steps_per_epoch * num_epochs
     elif training_args.max_steps > 0:
         logger.info("max_steps is given, it will override any value given in num_train_epochs")

@@ -673,7 +368,7 @@ def main():
         padding=padding,
         prompt_max_length=data_args.max_prompt_token_length,
         description_max_length=data_args.max_description_token_length,
         audio_max_length=audio_max_length,
+        # TODO add this variable
     )

     # Prepare everything with accelerate

@@ -869,6 +564,7 @@ def main():
         resume_step = None

     for batch in train_dataloader:
+        breakpoint()
         with accelerator.accumulate(model):
             loss, train_metric = train_step(batch, accelerator, autocast_kwargs)
             accelerator.backward(loss)
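
The delay-pattern comments in the removed postprocessing block above are easiest to see with a toy example. The sketch below illustrates the pattern those comments describe and is not the library's build_delay_pattern_mask implementation: codebook k is shifted right by k steps, with BOS tokens (B) filling the leading positions and EOS/PAD tokens (E) the trailing ones.

import torch

# Toy illustration of the delay pattern (assumed shapes; not the ParlerTTS code).
def toy_delay_pattern(codes: torch.Tensor, bos: int, eos: int) -> torch.Tensor:
    # codes: (num_codebooks, seq_len) audio codes for one sample
    num_codebooks, seq_len = codes.shape
    out = torch.full((num_codebooks, seq_len + num_codebooks + 1), eos, dtype=codes.dtype)
    for k in range(num_codebooks):
        out[k, : k + 1] = bos                        # k+1 leading BOS tokens on codebook k
        out[k, k + 1 : k + 1 + seq_len] = codes[k]   # codebook k delayed by k steps
    return out

codes = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])  # 4 codebooks, 2 timesteps
print(toy_delay_pattern(codes, bos=0, eos=9))
# tensor([[0, 1, 2, 9, 9, 9, 9],
#         [0, 0, 3, 4, 9, 9, 9],
#         [0, 0, 0, 5, 6, 9, 9],
#         [0, 0, 0, 0, 7, 8, 9]])

With B=0 and E=9, the printed rows reproduce the [B, a, b, E, E, E, E] / [B, B, c, d, E, E, E] layout from the removed comments.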