transformers · Commit 5e8c8eb5 (unverified)

Authored Feb 22, 2023 by Aaron Gokaslan; committed by GitHub on Feb 22, 2023
Parent: df06fb1f

Apply ruff flake8-comprehensions (#21694)

Changes: 230 files in the full commit; this page shows 20 changed files with 105 additions and 105 deletions (+105 −105)
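All of the edits below are mechanical rewrites driven by the flake8-comprehensions rule set, applied here via ruff. As a hedged orientation sketch (variable names are illustrative only, not taken from the diff), these are the before/after patterns the commit applies:

    # C408: unnecessary dict() call -> dict literal
    opts = dict(retries=3, verbose=True)                 # before
    opts = {"retries": 3, "verbose": True}               # after

    # C413: redundant list() around sorted(); sorted() already returns a list
    names = list(sorted(["b", "a"]))                     # before
    names = sorted(["b", "a"])                           # after

    # C414: redundant list() inside sorted(); sorted() takes any iterable
    unique = sorted(list({3, 1, 2}))                     # before
    unique = sorted({3, 1, 2})                           # after

    # C416: identity comprehension -> call the constructor directly
    items = [x for x in range(5)]                        # before
    items = list(range(5))                               # after

    # C417: map(lambda ...) -> comprehension / generator expression
    ok = any(map(lambda s: len(s) == 3, ["abc", "de"]))  # before
    ok = any(len(s) == 3 for s in ["abc", "de"])         # after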
examples/research_projects/seq2seq-distillation/_test_seq2seq_examples.py (+49 −49)
examples/research_projects/seq2seq-distillation/_test_seq2seq_examples_multi_gpu.py (+20 −20)
examples/research_projects/seq2seq-distillation/finetune.py (+6 −6)
examples/research_projects/seq2seq-distillation/make_student.py (+5 −5)
examples/research_projects/seq2seq-distillation/run_eval.py (+1 −1)
examples/research_projects/seq2seq-distillation/utils.py (+1 −1)
examples/research_projects/tapex/wikisql_utils.py (+2 −2)
examples/research_projects/visual_bert/extracting_data.py (+1 −1)
examples/research_projects/visual_bert/modeling_frcnn.py (+1 −1)
examples/research_projects/vqgan-clip/VQGAN_CLIP.py (+2 −2)
examples/research_projects/vqgan-clip/loaders.py (+1 −1)
examples/research_projects/wav2vec2/test_wav2vec2_deepspeed.py (+1 −1)
examples/research_projects/xtreme-s/run_xtreme_s.py (+1 −1)
examples/tensorflow/benchmarking/plot_csv_file.py (+3 −3)
examples/tensorflow/image-classification/run_image_classification.py (+1 −1)
examples/tensorflow/language-modeling/run_clm.py (+1 −1)
examples/tensorflow/language-modeling/run_mlm.py (+1 −1)
examples/tensorflow/question-answering/run_qa.py (+1 −1)
examples/tensorflow/text-classification/run_glue.py (+3 −3)
examples/tensorflow/text-classification/run_text_classification.py (+4 −4)
examples/research_projects/seq2seq-distillation/_test_seq2seq_examples.py

@@ -145,18 +145,18 @@ class TestSummarizationDistiller(TestCasePlus):
         assert not failures, f"The following models could not be loaded through AutoConfig: {failures}"

     def test_distill_no_teacher(self):
-        updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True)
+        updates = {"student_encoder_layers": 2, "student_decoder_layers": 1, "no_teacher": True}
         self._test_distiller_cli(updates)

     def test_distill_checkpointing_with_teacher(self):
-        updates = dict(
-            student_encoder_layers=2,
-            student_decoder_layers=1,
-            max_epochs=4,
-            val_check_interval=0.25,
-            alpha_hid=2.0,
-            model_name_or_path="IGNORE_THIS_IT_DOESNT_GET_USED",
-        )
+        updates = {
+            "student_encoder_layers": 2,
+            "student_decoder_layers": 1,
+            "max_epochs": 4,
+            "val_check_interval": 0.25,
+            "alpha_hid": 2.0,
+            "model_name_or_path": "IGNORE_THIS_IT_DOESNT_GET_USED",
+        }
         model = self._test_distiller_cli(updates, check_contents=False)

         ckpts = list(Path(model.output_dir).glob("*.ckpt"))

@@ -193,19 +193,19 @@ class TestSummarizationDistiller(TestCasePlus):
         self.assertEqual(nll_loss, model_computed_loss)

     def test_distill_mbart(self):
-        updates = dict(
-            student_encoder_layers=2,
-            student_decoder_layers=1,
-            num_train_epochs=4,
-            val_check_interval=0.25,
-            alpha_hid=2.0,
-            task="translation",
-            model_name_or_path="IGNORE_THIS_IT_DOESNT_GET_USED",
-            tokenizer_name=MBART_TINY,
-            teacher=MBART_TINY,
-            src_lang="en_XX",
-            tgt_lang="ro_RO",
-        )
+        updates = {
+            "student_encoder_layers": 2,
+            "student_decoder_layers": 1,
+            "num_train_epochs": 4,
+            "val_check_interval": 0.25,
+            "alpha_hid": 2.0,
+            "task": "translation",
+            "model_name_or_path": "IGNORE_THIS_IT_DOESNT_GET_USED",
+            "tokenizer_name": MBART_TINY,
+            "teacher": MBART_TINY,
+            "src_lang": "en_XX",
+            "tgt_lang": "ro_RO",
+        }
         model = self._test_distiller_cli(updates, check_contents=False)
         assert model.model.config.model_type == "mbart"

@@ -217,39 +217,39 @@ class TestSummarizationDistiller(TestCasePlus):
         self.assertEqual(len(transformer_ckpts), 2)

     def test_distill_t5(self):
-        updates = dict(
-            student_encoder_layers=1,
-            student_decoder_layers=1,
-            alpha_hid=2.0,
-            teacher=T5_TINY,
-            model_name_or_path=T5_TINY,
-            tokenizer_name=T5_TINY,
-        )
+        updates = {
+            "student_encoder_layers": 1,
+            "student_decoder_layers": 1,
+            "alpha_hid": 2.0,
+            "teacher": T5_TINY,
+            "model_name_or_path": T5_TINY,
+            "tokenizer_name": T5_TINY,
+        }
         self._test_distiller_cli(updates)

     def test_distill_different_base_models(self):
-        updates = dict(
-            teacher=T5_TINY,
-            student=T5_TINIER,
-            model_name_or_path=T5_TINIER,
-            tokenizer_name=T5_TINIER,
-        )
+        updates = {
+            "teacher": T5_TINY,
+            "student": T5_TINIER,
+            "model_name_or_path": T5_TINIER,
+            "tokenizer_name": T5_TINIER,
+        }
         self._test_distiller_cli(updates)

     def _test_distiller_cli(self, updates, check_contents=True):
-        default_updates = dict(
-            label_smoothing=0.0,
-            early_stopping_patience=-1,
-            train_batch_size=1,
-            eval_batch_size=2,
-            max_epochs=2,
-            alpha_mlm=0.2,
-            alpha_ce=0.8,
-            do_predict=True,
-            model_name_or_path="sshleifer/tinier_bart",
-            teacher=CHEAP_ARGS["model_name_or_path"],
-            val_check_interval=0.5,
-        )
+        default_updates = {
+            "label_smoothing": 0.0,
+            "early_stopping_patience": -1,
+            "train_batch_size": 1,
+            "eval_batch_size": 2,
+            "max_epochs": 2,
+            "alpha_mlm": 0.2,
+            "alpha_ce": 0.8,
+            "do_predict": True,
+            "model_name_or_path": "sshleifer/tinier_bart",
+            "teacher": CHEAP_ARGS["model_name_or_path"],
+            "val_check_interval": 0.5,
+        }
         default_updates.update(updates)
         args_d: dict = CHEAP_ARGS.copy()
         tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
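As an aside, the dict-literal form is not purely stylistic: dict(...) performs a global name lookup plus a function call, while {...} compiles to a single map-building bytecode instruction. A small, hedged way to observe the gap (absolute timings depend on your interpreter; this snippet is illustrative, not part of the commit):

    import timeit

    # dict(a=1, b=2) looks up the name "dict" and calls it;
    # the literal is built directly by the compiler.
    print(timeit.timeit("dict(a=1, b=2)", number=1_000_000))
    print(timeit.timeit("{'a': 1, 'b': 2}", number=1_000_000))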
examples/research_projects/seq2seq-distillation/_test_seq2seq_examples_multi_gpu.py

@@ -98,29 +98,29 @@ class TestSummarizationDistillerMultiGPU(TestCasePlus):
     @require_torch_multi_gpu
     def test_multi_gpu(self):
-        updates = dict(
-            no_teacher=True,
-            freeze_encoder=True,
-            gpus=2,
-            overwrite_output_dir=True,
-            sortish_sampler=True,
-        )
+        updates = {
+            "no_teacher": True,
+            "freeze_encoder": True,
+            "gpus": 2,
+            "overwrite_output_dir": True,
+            "sortish_sampler": True,
+        }
         self._test_distiller_cli_fork(updates, check_contents=False)

     def _test_distiller_cli_fork(self, updates, check_contents=True):
-        default_updates = dict(
-            label_smoothing=0.0,
-            early_stopping_patience=-1,
-            train_batch_size=1,
-            eval_batch_size=2,
-            max_epochs=2,
-            alpha_mlm=0.2,
-            alpha_ce=0.8,
-            do_predict=True,
-            model_name_or_path="sshleifer/tinier_bart",
-            teacher=CHEAP_ARGS["model_name_or_path"],
-            val_check_interval=0.5,
-        )
+        default_updates = {
+            "label_smoothing": 0.0,
+            "early_stopping_patience": -1,
+            "train_batch_size": 1,
+            "eval_batch_size": 2,
+            "max_epochs": 2,
+            "alpha_mlm": 0.2,
+            "alpha_ce": 0.8,
+            "do_predict": True,
+            "model_name_or_path": "sshleifer/tinier_bart",
+            "teacher": CHEAP_ARGS["model_name_or_path"],
+            "val_check_interval": 0.5,
+        }
         default_updates.update(updates)
         args_d: dict = CHEAP_ARGS.copy()
         tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
examples/research_projects/seq2seq-distillation/finetune.py

@@ -74,11 +74,11 @@ class SummarizationModule(BaseTransformer):
         self.model_type = self.config.model_type
         self.vocab_size = self.config.tgt_vocab_size if self.model_type == "fsmt" else self.config.vocab_size

-        self.dataset_kwargs: dict = dict(
-            data_dir=self.hparams.data_dir,
-            max_source_length=self.hparams.max_source_length,
-            prefix=self.model.config.prefix or "",
-        )
+        self.dataset_kwargs: dict = {
+            "data_dir": self.hparams.data_dir,
+            "max_source_length": self.hparams.max_source_length,
+            "prefix": self.model.config.prefix or "",
+        }
         n_observations_per_split = {
             "train": self.hparams.n_train,
             "val": self.hparams.n_val,

@@ -433,7 +433,7 @@ def main(args, model=None) -> SummarizationModule:
         return model
     model.hparams.test_checkpoint = ""
-    checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
+    checkpoints = sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True))
     if checkpoints:
         model.hparams.test_checkpoint = checkpoints[-1]
         trainer.resume_from_checkpoint = checkpoints[-1]
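The checkpoint-globbing change above is the C413 pattern: sorted() always materializes its input into a new list, so the outer list() only made a redundant copy. A minimal standalone sketch of the same idiom (the "output" directory is hypothetical):

    import glob
    import os

    # sorted() already returns a list, so no list() wrapper is needed.
    checkpoints = sorted(glob.glob(os.path.join("output", "*.ckpt"), recursive=True))
    latest = checkpoints[-1] if checkpoints else None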
examples/research_projects/seq2seq-distillation/make_student.py

@@ -171,11 +171,11 @@ def create_student_by_copying_alternating_layers(
     logger.info(
         f"Copied encoder layers {e_layers_to_copy} and decoder layers {d_layers_to_copy}. Saving them to {save_path}"
     )
-    student.config.init_metadata = dict(
-        teacher_type=teacher.config.model_type,
-        copied_encoder_layers=e_layers_to_copy,
-        copied_decoder_layers=d_layers_to_copy,
-    )
+    student.config.init_metadata = {
+        "teacher_type": teacher.config.model_type,
+        "copied_encoder_layers": e_layers_to_copy,
+        "copied_decoder_layers": d_layers_to_copy,
+    }
     student.save_pretrained(save_path)
     # Save information about copying for easier reproducibility
examples/research_projects/seq2seq-distillation/run_eval.py

@@ -63,7 +63,7 @@ def generate_summaries_or_translations(
     fout.close()
     runtime = int(time.time() - start_time)  # seconds
     n_obs = len(examples)
-    return dict(n_obs=n_obs, runtime=runtime, seconds_per_sample=round(runtime / n_obs, 4))
+    return {"n_obs": n_obs, "runtime": runtime, "seconds_per_sample": round(runtime / n_obs, 4)}


 def datetime_now():
examples/research_projects/seq2seq-distillation/utils.py

@@ -437,7 +437,7 @@ def pickle_save(obj, path):
 def flatten_list(summary_ids: List[List]):
-    return [x for x in itertools.chain.from_iterable(summary_ids)]
+    return list(itertools.chain.from_iterable(summary_ids))


 def save_git_info(folder_path: str) -> None:
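The flatten_list change is the C416 rule: [x for x in iterable] is an identity comprehension, and list(iterable) expresses the same thing directly without a per-element loop. A self-contained example of the resulting idiom:

    import itertools

    nested = [[1, 2], [3], [4, 5]]
    # chain.from_iterable lazily yields 1, 2, 3, 4, 5;
    # list() consumes it directly instead of copying via a comprehension.
    flat = list(itertools.chain.from_iterable(nested))
    assert flat == [1, 2, 3, 4, 5]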
examples/research_projects/tapex/wikisql_utils.py

@@ -30,7 +30,7 @@ EMPTY_ANSWER_AGG = "none"
 def _split_thousands(delimiter, value):
     split = value.split(delimiter)
-    return len(split) > 1 and any(map(lambda x: len(x) == 3, split))
+    return len(split) > 1 and any((len(x) == 3 for x in split))


 def convert_to_float(value):

@@ -123,7 +123,7 @@ _TOKENIZER = re.compile(r"\w+|[^\w\s]+", re.UNICODE | re.MULTILINE | re.DOTALL)
 def _normalize_for_match(x):
-    return [t for t in _TOKENIZER.findall(x.lower())]
+    return list(_TOKENIZER.findall(x.lower()))


 def _compare(operator, src, tgt):
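The two hunks above are C417 and C416 respectively: map with a lambda, and an identity comprehension, each replaced by the direct form. Behaviour is preserved because map() and a generator expression are both lazy, so any() still short-circuits on the first match. A toy illustration of the rewritten _split_thousands logic (input string is made up):

    # Does a delimiter-split value look like groups of thousands?
    split = "1,234,567".split(",")
    assert len(split) > 1 and any(len(x) == 3 for x in split)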
examples/research_projects/visual_bert/extracting_data.py

@@ -61,7 +61,7 @@ class Extract:
         assert outputfile is not None and not os.path.isfile(outputfile), f"{outputfile}"
         if subset_list is not None:
             with open(os.path.realpath(subset_list)) as f:
-                self.subset_list = set(map(lambda x: self._vqa_file_split()[0], tryload(f)))
+                self.subset_list = {self._vqa_file_split()[0] for x in tryload(f)}
         else:
             self.subset_list = None
examples/research_projects/visual_bert/modeling_frcnn.py

@@ -1095,7 +1095,7 @@ class ROIPooler(nn.Module):
         Returns:
             A tensor of shape(N*B, Channels, output_size, output_size)
         """
-        x = [v for v in feature_maps.values()]
+        x = list(feature_maps.values())
         num_level_assignments = len(self.level_poolers)
         assert len(x) == num_level_assignments and len(boxes) == x[0].size(0)
examples/research_projects/vqgan-clip/VQGAN_CLIP.py

@@ -99,7 +99,7 @@ class VQGAN_CLIP(nn.Module):
             output_path = "./animation.gif"
         if input_path is None:
             input_path = self.save_path
-        paths = list(sorted(glob(input_path + "/*")))
+        paths = sorted(glob(input_path + "/*"))
         if not len(paths):
             raise ValueError(
                 "No images found in save path, aborting (did you pass save_intermediate=True to the generate"

@@ -178,7 +178,7 @@ class VQGAN_CLIP(nn.Module):
             wandb.init(reinit=True, project="face-editor")
             wandb.config.update({"Positive Prompts": positive_prompts})
             wandb.config.update({"Negative Prompts": negative_prompts})
-            wandb.config.update(dict(lr=self.lr, iterations=self.iterations))
+            wandb.config.update({"lr": self.lr, "iterations": self.iterations})
         if image_path:
            image = Image.open(image_path)
            image = image.resize((256, 256))
examples/research_projects/vqgan-clip/loaders.py

@@ -47,7 +47,7 @@ def get_obj_from_str(string, reload=False):
 def instantiate_from_config(config):
     if "target" not in config:
         raise KeyError("Expected key `target` to instantiate.")
-    return get_obj_from_str(config["target"])(**config.get("params", dict()))
+    return get_obj_from_str(config["target"])(**config.get("params", {}))


 def load_model_from_config(config, sd, gpu=True, eval_mode=True):
...
examples/research_projects/wav2vec2/test_wav2vec2_deepspeed.py
View file @
5e8c8eb5
...
...
@@ -51,7 +51,7 @@ from transformers.trainer_utils import set_seed # noqa
set_seed
(
42
)
models
=
dict
(
base
=
"patrickvonplaten/wav2vec2_tiny_random"
,
robust
=
"patrickvonplaten/wav2vec2_tiny_random_robust"
)
models
=
{
"
base
"
:
"patrickvonplaten/wav2vec2_tiny_random"
,
"
robust
"
:
"patrickvonplaten/wav2vec2_tiny_random_robust"
}
ZERO2
=
"zero2"
ZERO3
=
"zero3"
...
...
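One further property of the literal form that this hunk illustrates: dict(key=value) only accepts keys that are valid Python identifiers, whereas the literal places no restriction on key strings. A toy example (the model names are placeholders, not repository identifiers):

    models = {"base": "tiny-model-a", "robust": "tiny-model-b"}
    # A key such as "robust-v2" could never be written as a dict() keyword:
    models["robust-v2"] = "tiny-model-c"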
examples/research_projects/xtreme-s/run_xtreme_s.py

@@ -400,7 +400,7 @@ def create_vocabulary_from_data(
         | (set(vocabs["predict"]["vocab"][0]) if "predict" in vocabs else set())
     )

-    vocab_dict = {v: k for k, v in enumerate(sorted(list(vocab_set)))}
+    vocab_dict = {v: k for k, v in enumerate(sorted(vocab_set))}

     # replace white space with delimiter token
     if word_delimiter_token is not None:
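The vocabulary hunk is the C414 pattern: sorted() accepts any iterable, so converting the set to a list first only adds an intermediate copy. A sketch with a toy vocabulary (contents invented for illustration):

    vocab_set = {"a", "c", "b", "|"}
    # Enumerate over the sorted iterable directly; no list() needed.
    vocab_dict = {v: k for k, v in enumerate(sorted(vocab_set))}
    assert vocab_dict == {"a": 0, "b": 1, "c": 2, "|": 3}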
examples/tensorflow/benchmarking/plot_csv_file.py

@@ -83,7 +83,7 @@ def can_convert_to_float(string):
 class Plot:
     def __init__(self, args):
         self.args = args
-        self.result_dict = defaultdict(lambda: dict(bsz=[], seq_len=[], result={}))
+        self.result_dict = defaultdict(lambda: {"bsz": [], "seq_len": [], "result": {}})

         with open(self.args.csv_file, newline="") as csv_file:
             reader = csv.DictReader(csv_file)

@@ -116,8 +116,8 @@ class Plot:
             axis.set_major_formatter(ScalarFormatter())

         for model_name_idx, model_name in enumerate(self.result_dict.keys()):
-            batch_sizes = sorted(list(set(self.result_dict[model_name]["bsz"])))
-            sequence_lengths = sorted(list(set(self.result_dict[model_name]["seq_len"])))
+            batch_sizes = sorted(set(self.result_dict[model_name]["bsz"]))
+            sequence_lengths = sorted(set(self.result_dict[model_name]["seq_len"]))
             results = self.result_dict[model_name]["result"]

             (x_axis_array, inner_loop_array) = (
examples/tensorflow/image-classification/run_image_classification.py

@@ -300,7 +300,7 @@ def main():
     # Prepare label mappings.
     # We'll include these in the model's config to get human readable labels in the Inference API.
     labels = dataset["train"].features["labels"].names
-    label2id, id2label = dict(), dict()
+    label2id, id2label = {}, {}
     for i, label in enumerate(labels):
         label2id[label] = str(i)
         id2label[str(i)] = label
examples/tensorflow/language-modeling/run_clm.py

@@ -600,7 +600,7 @@ def main():
         if training_args.output_dir is not None:
             output_eval_file = os.path.join(training_args.output_dir, "all_results.json")
-            results_dict = dict()
+            results_dict = {}
             results_dict["train_loss"] = train_loss
             results_dict["train_perplexity"] = train_perplexity
             results_dict["eval_loss"] = validation_loss
examples/tensorflow/language-modeling/run_mlm.py

@@ -623,7 +623,7 @@ def main():
         if training_args.output_dir is not None:
             output_eval_file = os.path.join(training_args.output_dir, "all_results.json")
-            results_dict = dict()
+            results_dict = {}
             results_dict["train_loss"] = train_loss
             results_dict["train_perplexity"] = train_perplexity
             results_dict["eval_loss"] = validation_loss
examples/tensorflow/question-answering/run_qa.py

@@ -464,7 +464,7 @@ def main():
         return tokenized_examples

-    processed_datasets = dict()
+    processed_datasets = {}
     if training_args.do_train:
         if "train" not in datasets:
             raise ValueError("--do_train requires a train dataset")
examples/tensorflow/text-classification/run_glue.py

@@ -310,12 +310,12 @@ def main():
     if config.label2id != PretrainedConfig(num_labels=num_labels).label2id and not is_regression:
         # Some have all caps in their config, some don't.
         label_name_to_id = {k.lower(): v for k, v in config.label2id.items()}
-        if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
+        if sorted(label_name_to_id.keys()) == sorted(label_list):
             label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)}
         else:
             logger.warning(
                 "Your model seems to have been trained with labels, but they don't match the dataset: ",
-                f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
+                f"model labels: {sorted(label_name_to_id.keys())}, dataset labels: {sorted(label_list)}."
                 "\nIgnoring the model labels as a result.",
             )
             label_to_id = {label: i for i, label in enumerate(label_list)}

@@ -383,7 +383,7 @@ def main():
         dataset_options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
         num_replicas = training_args.strategy.num_replicas_in_sync

-    tf_data = dict()
+    tf_data = {}
     max_samples = {
         "train": data_args.max_train_samples,
         "validation": data_args.max_eval_samples,
examples/tensorflow/text-classification/run_text_classification.py

@@ -343,13 +343,13 @@ def main():
     if "train" in datasets:
         if not is_regression and config.label2id != PretrainedConfig(num_labels=num_labels).label2id:
             label_name_to_id = config.label2id
-            if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
+            if sorted(label_name_to_id.keys()) == sorted(label_list):
                 label_to_id = label_name_to_id  # Use the model's labels
             else:
                 logger.warning(
                     "Your model seems to have been trained with labels, but they don't match the dataset: ",
-                    f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels:"
-                    f" {list(sorted(label_list))}.\nIgnoring the model labels as a result.",
+                    f"model labels: {sorted(label_name_to_id.keys())}, dataset labels:"
+                    f" {sorted(label_list)}.\nIgnoring the model labels as a result.",
                 )
                 label_to_id = {v: i for i, v in enumerate(label_list)}
         elif not is_regression:

@@ -411,7 +411,7 @@ def main():
         dataset_options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
         num_replicas = training_args.strategy.num_replicas_in_sync

-    tf_data = dict()
+    tf_data = {}
     max_samples = {
         "train": data_args.max_train_samples,
         "validation": data_args.max_val_samples,