chenpangpang / transformers · Commits

Unverified commit 5e8c8eb5, authored Feb 22, 2023 by Aaron Gokaslan, committed by GitHub on Feb 22, 2023
Apply ruff flake8-comprehensions (#21694)
Parent: df06fb1f
Changes: 230
Showing 20 changed files with 92 additions and 92 deletions (+92, -92).
pyproject.toml (+1, -1)
src/transformers/benchmark/benchmark_utils.py (+3, -3)
src/transformers/configuration_utils.py (+2, -2)
src/transformers/deepspeed.py (+7, -7)
src/transformers/feature_extraction_sequence_utils.py (+1, -1)
src/transformers/generation/beam_constraints.py (+2, -2)
src/transformers/generation/logits_process.py (+1, -1)
src/transformers/image_utils.py (+1, -1)
src/transformers/integrations.py (+3, -3)
src/transformers/keras_callbacks.py (+1, -1)
src/transformers/modelcard.py (+4, -4)
src/transformers/modeling_flax_pytorch_utils.py (+7, -7)
src/transformers/modeling_flax_utils.py (+14, -14)
src/transformers/modeling_tf_pytorch_utils.py (+2, -2)
src/transformers/modeling_tf_utils.py (+19, -19)
src/transformers/modeling_utils.py (+17, -17)
src/transformers/models/beit/modeling_flax_beit.py (+1, -1)
src/transformers/models/bertweet/tokenization_bertweet.py (+2, -2)
src/transformers/models/big_bird/tokenization_big_bird_fast.py (+1, -1)
src/transformers/models/biogpt/tokenization_biogpt.py (+3, -3)
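For context, the flake8-comprehensions family (ruff's C4 rules, which this commit enables for the repository) flags calls to dict(), list(), set() and map() where a literal or comprehension expresses the same thing more directly. The short sketch below is illustrative only (the names and data are invented, not taken from the diff); each assertion checks that a "before" form and its "after" rewrite produce the same value, matching the patterns that recur in the files listed above.

    # Illustrative equivalences behind the rewrites in this commit (invented data).
    labels = {"0": "negative", "1": "positive"}

    # dict((k, v) for ...)             ->  {k: v for ...}
    assert dict((int(k), v) for k, v in labels.items()) == {int(k): v for k, v in labels.items()}

    # set([... for ...])               ->  {... for ...}
    assert set([k for k in labels.keys()]) == {k for k in labels.keys()}

    # list(map(lambda x: ..., xs))     ->  [... for x in xs]
    assert list(map(lambda v: v.upper(), labels.values())) == [v.upper() for v in labels.values()]

    # [x for x in xs]                  ->  list(xs), and dict() -> {}
    assert [v for v in labels.values()] == list(labels.values())
    assert dict() == {}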
pyproject.toml

@@ -5,7 +5,7 @@ target-version = ['py37']
 [tool.ruff]
 # Never enforce `E501` (line length violations).
 ignore = ["E501", "E741", "W605"]
-select = ["E", "F", "I", "W"]
+select = ["C", "E", "F", "I", "W"]
 line-length = 119
 # Ignore import violations in all `__init__.py` files.
src/transformers/benchmark/benchmark_utils.py

@@ -557,9 +557,9 @@ def stop_memory_tracing(
         cumulative_memory_dict[frame][2] += cpu_gpu_mem_inc

     cumulative_memory = sorted(
-        list(cumulative_memory_dict.items()), key=lambda x: x[1][2], reverse=True
+        cumulative_memory_dict.items(), key=lambda x: x[1][2], reverse=True
     )  # order by the total CPU + GPU memory increase
-    cumulative_memory = list(
+    cumulative_memory = [
         MemoryState(
             frame=frame,
             cpu=Memory(cpu_mem_inc),
@@ -567,7 +567,7 @@ def stop_memory_tracing(
             cpu_gpu=Memory(cpu_gpu_mem_inc),
         )
         for frame, (cpu_mem_inc, gpu_mem_inc, cpu_gpu_mem_inc) in cumulative_memory
-    )
+    ]

     memory_curr_trace = sorted(memory_curr_trace, key=lambda x: x.cpu_gpu.bytes, reverse=True)
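The first change above works because sorted() accepts any iterable, so wrapping dict.items() in list() before sorting is redundant. A minimal check with made-up data (not the real tracing records):

    # Hypothetical per-frame (cpu, gpu, cpu_gpu) increments, keyed by frame name.
    usage = {"frame_a": (1, 2, 3), "frame_b": (4, 5, 9), "frame_c": (0, 1, 1)}
    direct = sorted(usage.items(), key=lambda x: x[1][2], reverse=True)
    via_list = sorted(list(usage.items()), key=lambda x: x[1][2], reverse=True)
    assert direct == via_list  # same ordering either way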
src/transformers/configuration_utils.py

@@ -324,7 +324,7 @@ class PretrainedConfig(PushToHubMixin):
                     f"You passed along `num_labels={num_labels}` with an incompatible id to label map: "
                     f"{self.id2label}. The number of labels wil be overwritten to {self.num_labels}."
                 )
-            self.id2label = dict((int(key), value) for key, value in self.id2label.items())
+            self.id2label = {int(key): value for key, value in self.id2label.items()}
             # Keys are always strings in JSON so convert ids to int here.
         else:
             self.num_labels = kwargs.pop("num_labels", 2)
@@ -696,7 +696,7 @@ class PretrainedConfig(PushToHubMixin):
         config = cls(**config_dict)

         if hasattr(config, "pruned_heads"):
-            config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())
+            config.pruned_heads = {int(key): value for key, value in config.pruned_heads.items()}

         # Update config with kwargs if needed
         if "num_labels" in kwargs and "id2label" in kwargs:
src/transformers/deepspeed.py

@@ -367,13 +367,13 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inf
     # keep for quick debug:
     # from pprint import pprint; pprint(config)

-    kwargs = dict(
-        model=model,
-        model_parameters=model_parameters,
-        config_params=config,
-        optimizer=optimizer,
-        lr_scheduler=lr_scheduler,
-    )
+    kwargs = {
+        "model": model,
+        "model_parameters": model_parameters,
+        "config_params": config,
+        "optimizer": optimizer,
+        "lr_scheduler": lr_scheduler,
+    }

     deepspeed_engine, optimizer, _, lr_scheduler = deepspeed.initialize(**kwargs)
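The dict(model=model, ...) call and the literal {"model": model, ...} build the same dictionary whenever the keys are plain identifiers, so this rewrite (like the similar cached_file_kwargs rewrites later in the diff) is purely stylistic. A small sketch with placeholder objects rather than the real DeepSpeed arguments:

    # Placeholder stand-ins for the real model/optimizer objects.
    model, optimizer, lr_scheduler = object(), object(), None
    kwargs_call = dict(model=model, optimizer=optimizer, lr_scheduler=lr_scheduler)
    kwargs_literal = {"model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler}
    assert kwargs_call == kwargs_literal  # either form can be expanded with **kwargs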
src/transformers/feature_extraction_sequence_utils.py

@@ -188,7 +188,7 @@ class SequenceFeatureExtractor(FeatureExtractionMixin):
         truncated_inputs = []
         for i in range(batch_size):
-            inputs = dict((k, v[i]) for k, v in processed_features.items())
+            inputs = {k: v[i] for k, v in processed_features.items()}
             # truncation
             inputs_slice = self._truncate(
                 inputs,
src/transformers/generation/beam_constraints.py

@@ -208,12 +208,12 @@ class DisjunctiveTrie:
         """
         self.max_height = max([len(one) for one in nested_token_ids])

-        root = dict()
+        root = {}
         for token_ids in nested_token_ids:
             level = root
             for tidx, token_id in enumerate(token_ids):
                 if token_id not in level:
-                    level[token_id] = dict()
+                    level[token_id] = {}
                 level = level[token_id]
src/transformers/generation/logits_process.py

@@ -951,7 +951,7 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
        # timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
         for k in range(input_ids.shape[0]):
-            seq = [t for t in input_ids[k, self.begin_index :].tolist()]
+            seq = list(input_ids[k, self.begin_index :].tolist())
             last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.timestamp_begin
             penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.timestamp_begin
src/transformers/image_utils.py

@@ -115,7 +115,7 @@ def make_list_of_images(images, expected_ndims: int = 3) -> List[ImageInput]:
    if is_valid_image(images):
        if images.ndim == expected_ndims + 1:
            # Batch of images
-            images = [image for image in images]
+            images = list(images)
        elif images.ndim == expected_ndims:
            # Single image
            images = [images]
src/transformers/integrations.py

@@ -365,7 +365,7 @@ def run_hp_search_sigopt(trainer, n_trials: int, direction: str, **kwargs) -> Be
            name="huggingface-tune",
            type="offline",
            parameters=trainer.hp_space(None),
-            metrics=[dict(name="objective", objective=direction, strategy="optimize")],
+            metrics=[{"name": "objective", "objective": direction, "strategy": "optimize"}],
            parallel_bandwidth=1,
            budget=n_trials,
        )
@@ -402,7 +402,7 @@ def run_hp_search_sigopt(trainer, n_trials: int, direction: str, **kwargs) -> Be
        experiment = conn.experiments().create(
            name="huggingface-tune",
            parameters=trainer.hp_space(None),
-            metrics=[dict(name="objective", objective=direction, strategy="optimize")],
+            metrics=[{"name": "objective", "objective": direction, "strategy": "optimize"}],
            parallel_bandwidth=1,
            observation_budget=n_trials,
            project="huggingface",
@@ -425,7 +425,7 @@ def run_hp_search_sigopt(trainer, n_trials: int, direction: str, **kwargs) -> Be
            metrics = trainer.evaluate()
            trainer.objective = trainer.compute_objective(metrics)

-            values = [dict(name="objective", value=trainer.objective)]
+            values = [{"name": "objective", "value": trainer.objective}]
            obs = conn.experiments(experiment.id).observations().create(suggestion=suggestion.id, values=values)
            logger.info(f"[suggestion_id, observation_id]: [{suggestion.id}, {obs.id}]")
            experiment = conn.experiments(experiment.id).fetch()
src/transformers/keras_callbacks.py

@@ -162,7 +162,7 @@ class KerasMetricCallback(Callback):
    def _postprocess_predictions_or_labels(self, inputs):
        if isinstance(inputs[0], dict):
-            outputs = dict()
+            outputs = {}
            for key in inputs[0].keys():
                outputs[key] = self._concatenate_batches([batch[key] for batch in inputs])
            # If it's a dict with only one key, just return the array
src/transformers/modelcard.py

@@ -677,7 +677,7 @@ class TrainingSummary:
            _, eval_lines, eval_results = parse_keras_history(keras_history)
        else:
            eval_lines = []
-            eval_results = dict()
+            eval_results = {}
        hyperparameters = extract_hyperparameters_from_keras(model)

        return cls(
@@ -706,7 +706,7 @@ def parse_keras_history(logs):
        # This looks like a `History` object
        if not hasattr(logs, "epoch"):
            # This history looks empty, return empty results
-            return None, [], dict()
+            return None, [], {}

        logs.history["epoch"] = logs.epoch
        logs = logs.history
    else:
@@ -716,7 +716,7 @@ def parse_keras_history(logs):
    lines = []
    for i in range(len(logs["epoch"])):
        epoch_dict = {log_key: log_value_list[i] for log_key, log_value_list in logs.items()}
-        values = dict()
+        values = {}
        for k, v in epoch_dict.items():
            if k.startswith("val_"):
                k = "validation_" + k[4:]
@@ -797,7 +797,7 @@ def parse_log_history(log_history):
def extract_hyperparameters_from_keras(model):
    import tensorflow as tf

-    hyperparameters = dict()
+    hyperparameters = {}
    if hasattr(model, "optimizer") and model.optimizer is not None:
        hyperparameters["optimizer"] = model.optimizer.get_config()
    else:
src/transformers/modeling_flax_pytorch_utils.py

@@ -76,7 +76,7 @@ def rename_key_and_reshape_tensor(
    def is_key_or_prefix_key_in_dict(key: Tuple[str]) -> bool:
        """Checks if `key` of `(prefix,) + key` is in random_flax_state_dict"""
-        return len(set(random_flax_state_dict) & set([key, (model_prefix,) + key])) > 0
+        return len(set(random_flax_state_dict) & {key, (model_prefix,) + key}) > 0

    # layer norm
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
@@ -122,10 +122,10 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
    flax_state_dict = {}

    load_model_with_head_into_base_model = (model_prefix not in flax_model.params) and (
-        model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()])
+        model_prefix in {k.split(".")[0] for k in pt_state_dict.keys()}
    )
    load_base_model_into_model_with_head = (model_prefix in flax_model.params) and (
-        model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()])
+        model_prefix not in {k.split(".")[0] for k in pt_state_dict.keys()}
    )

    # Need to change some parameters name to match Flax names
@@ -179,10 +179,10 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
    random_flax_state_dict = flatten_dict(flax_model.params)
    load_model_with_head_into_base_model = (model_prefix not in flax_model.params) and (
-        model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()])
+        model_prefix in {k.split(".")[0] for k in pt_state_dict.keys()}
    )
    load_base_model_into_model_with_head = (model_prefix in flax_model.params) and (
-        model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()])
+        model_prefix not in {k.split(".")[0] for k in pt_state_dict.keys()}
    )

    # Need to change some parameters name to match Flax names
    for pt_key, pt_tensor in pt_state_dict.items():
@@ -267,10 +267,10 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state):
    pt_model_dict = pt_model.state_dict()

    load_model_with_head_into_base_model = (pt_model.base_model_prefix in flax_state) and (
-        pt_model.base_model_prefix not in set([k.split(".")[0] for k in pt_model_dict.keys()])
+        pt_model.base_model_prefix not in {k.split(".")[0] for k in pt_model_dict.keys()}
    )
    load_base_model_into_model_with_head = (pt_model.base_model_prefix not in flax_state) and (
-        pt_model.base_model_prefix in set([k.split(".")[0] for k in pt_model_dict.keys()])
+        pt_model.base_model_prefix in {k.split(".")[0] for k in pt_model_dict.keys()}
    )

    # keep track of unexpected & missing keys
src/transformers/modeling_flax_utils.py

@@ -440,7 +440,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
        """
        # Load the index
-        state_sharded_dict = dict()
+        state_sharded_dict = {}

        for shard_file in shard_files:
            # load using msgpack utils
@@ -708,19 +708,19 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
            filename = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME
            try:
                # Load from URL or cache if already cached
-                cached_file_kwargs = dict(
-                    cache_dir=cache_dir,
-                    force_download=force_download,
-                    proxies=proxies,
-                    resume_download=resume_download,
-                    local_files_only=local_files_only,
-                    use_auth_token=use_auth_token,
-                    user_agent=user_agent,
-                    revision=revision,
-                    subfolder=subfolder,
-                    _raise_exceptions_for_missing_entries=False,
-                    _commit_hash=commit_hash,
-                )
+                cached_file_kwargs = {
+                    "cache_dir": cache_dir,
+                    "force_download": force_download,
+                    "proxies": proxies,
+                    "resume_download": resume_download,
+                    "local_files_only": local_files_only,
+                    "use_auth_token": use_auth_token,
+                    "user_agent": user_agent,
+                    "revision": revision,
+                    "subfolder": subfolder,
+                    "_raise_exceptions_for_missing_entries": False,
+                    "_commit_hash": commit_hash,
+                }
                resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)

                # Since we set _raise_exceptions_for_missing_entries=False, we don't get an expection but a None
src/transformers/modeling_tf_pytorch_utils.py

@@ -258,7 +258,7 @@ def load_pytorch_state_dict_in_tf2_model(
    symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights
    tf_loaded_numel = 0
    weight_value_tuples = []
-    all_pytorch_weights = set(list(pt_state_dict.keys()))
+    all_pytorch_weights = set(pt_state_dict.keys())
    missing_keys = []
    for symbolic_weight in symbolic_weights:
        sw_name = symbolic_weight.name
@@ -425,7 +425,7 @@ def load_tf2_state_dict_in_pytorch_model(pt_model, tf_state_dict, allow_missing_
        )
        tf_weights_map[pt_name] = (tf_weight, transpose)

-    all_tf_weights = set(list(tf_weights_map.keys()))
+    all_tf_weights = set(tf_weights_map.keys())
    loaded_pt_weights_data_ptr = {}
    missing_keys_pt = []
    for pt_weight_name, pt_weight in current_pt_params_dict.items():
src/transformers/modeling_tf_utils.py

@@ -584,7 +584,7 @@ def input_processing(func, config, **kwargs):
    if "kwargs" in output:
        del output["kwargs"]

-    cast_output = dict()
+    cast_output = {}
    for key, val in output.items():
        if isinstance(val, tf.Tensor) and val.dtype == tf.int64:
            cast_output[key] = tf.cast(val, tf.int32)
@@ -737,7 +737,7 @@ def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, s
    # Since TF adds the name of the class to its weights, and uses the index and not the name of the layer to load
    # the weight, we have to get rid of the first prefix of the name of the layer.
    model_keys = set()
-    model_layer_map = dict()
+    model_layer_map = {}
    for i, k in enumerate(model.weights):
        if "model." in k.name or len(k.name.split("/")) == 1:
            layer_name = k.name
@@ -901,10 +901,10 @@ def load_tf_weights_from_h5(model, resolved_archive_file, ignore_mismatched_size
        )

        # Find the missing layers from the high level list of layers
-        missing_layers = list(set([layer.name for layer in model.layers]) - saved_h5_model_layers_name)
+        missing_layers = list({layer.name for layer in model.layers} - saved_h5_model_layers_name)
        # Find the unexpected layers from the high level list of layers
-        unexpected_layers = list(saved_h5_model_layers_name - set([layer.name for layer in model.layers]))
+        unexpected_layers = list(saved_h5_model_layers_name - {layer.name for layer in model.layers})
        saved_weight_names_set = set()
        symbolic_weights_names = set()
        weight_value_tuples = []
@@ -1349,7 +1349,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
        else:
            collate_fn = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="np")
        if collate_fn_args is None:
-            collate_fn_args = dict()
+            collate_fn_args = {}

        if not isinstance(dataset, datasets.Dataset):
            raise TypeError("Dataset argument should be a datasets.Dataset!")
@@ -1471,7 +1471,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
        elif "mc_labels" in arg_names:
            return {"labels": "logits", "mc_labels": "mc_logits"}
        else:
-            return dict()
+            return {}

    def train_step(self, data):
        """
@@ -2613,19 +2613,19 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
            try:
                # Load from URL or cache if already cached
-                cached_file_kwargs = dict(
-                    cache_dir=cache_dir,
-                    force_download=force_download,
-                    proxies=proxies,
-                    resume_download=resume_download,
-                    local_files_only=local_files_only,
-                    use_auth_token=use_auth_token,
-                    user_agent=user_agent,
-                    revision=revision,
-                    subfolder=subfolder,
-                    _raise_exceptions_for_missing_entries=False,
-                    _commit_hash=commit_hash,
-                )
+                cached_file_kwargs = {
+                    "cache_dir": cache_dir,
+                    "force_download": force_download,
+                    "proxies": proxies,
+                    "resume_download": resume_download,
+                    "local_files_only": local_files_only,
+                    "use_auth_token": use_auth_token,
+                    "user_agent": user_agent,
+                    "revision": revision,
+                    "subfolder": subfolder,
+                    "_raise_exceptions_for_missing_entries": False,
+                    "_commit_hash": commit_hash,
+                }
                resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)

                # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
src/transformers/modeling_utils.py

@@ -1271,7 +1271,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                    len(encoder_modules) > 0
                ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"

-                all_encoder_weights = set([module_name + "/" + sub_name for sub_name in encoder_modules.keys()])
+                all_encoder_weights = {module_name + "/" + sub_name for sub_name in encoder_modules.keys()}
                encoder_layer_pos = 0
                for name, module in decoder_modules.items():
                    if name.isdigit():
@@ -2304,19 +2304,19 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
            try:
                # Load from URL or cache if already cached
-                cached_file_kwargs = dict(
-                    cache_dir=cache_dir,
-                    force_download=force_download,
-                    proxies=proxies,
-                    resume_download=resume_download,
-                    local_files_only=local_files_only,
-                    use_auth_token=use_auth_token,
-                    user_agent=user_agent,
-                    revision=revision,
-                    subfolder=subfolder,
-                    _raise_exceptions_for_missing_entries=False,
-                    _commit_hash=commit_hash,
-                )
+                cached_file_kwargs = {
+                    "cache_dir": cache_dir,
+                    "force_download": force_download,
+                    "proxies": proxies,
+                    "resume_download": resume_download,
+                    "local_files_only": local_files_only,
+                    "use_auth_token": use_auth_token,
+                    "user_agent": user_agent,
+                    "revision": revision,
+                    "subfolder": subfolder,
+                    "_raise_exceptions_for_missing_entries": False,
+                    "_commit_hash": commit_hash,
+                }
                resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)

                # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
@@ -2474,7 +2474,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
        if is_sharded:
            loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
        else:
-            loaded_state_dict_keys = [k for k in state_dict.keys()]
+            loaded_state_dict_keys = list(state_dict.keys())

        if low_cpu_mem_usage or use_keep_in_fp32_modules:
            state_dict = None
@@ -3046,12 +3046,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
        return model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs

    def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False):
-        module_keys = set([".".join(key.split(".")[:-1]) for key in names])
+        module_keys = {".".join(key.split(".")[:-1]) for key in names}

        # torch.nn.ParameterList is a special case where two parameter keywords
        # are appended to the module name, *e.g.* bert.special_embeddings.0
        module_keys = module_keys.union(
-            set([".".join(key.split(".")[:-2]) for key in names if len(key) > 0 and key[-1].isdigit()])
+            {".".join(key.split(".")[:-2]) for key in names if len(key) > 0 and key[-1].isdigit()}
        )

        retrieved_modules = []
src/transformers/models/beit/modeling_flax_beit.py

@@ -555,7 +555,7 @@ class FlaxBeitEncoder(nn.Module):
            )

        # stochastic depth decay rule
-        drop_path_rates = [x for x in np.linspace(0, self.config.drop_path_rate, self.config.num_hidden_layers)]
+        drop_path_rates = list(np.linspace(0, self.config.drop_path_rate, self.config.num_hidden_layers))
        self.layer = FlaxBeitLayerCollection(
            self.config,
            window_size=self.window_size,
src/transformers/models/bertweet/tokenization_bertweet.py

@@ -318,7 +318,7 @@ class BertweetTokenizer(PreTrainedTokenizer):
        split_tokens = []
        words = re.findall(r"\S+\n?", text)
        for token in words:
-            split_tokens.extend([t for t in self.bpe(token).split(" ")])
+            split_tokens.extend(list(self.bpe(token).split(" ")))
        return split_tokens

    def normalizeTweet(self, tweet):
@@ -726,7 +726,7 @@ class TweetTokenizer:
        words = WORD_RE.findall(safe_text)
        # Possibly alter the case, but avoid changing emoticons like :D into :d:
        if not self.preserve_case:
-            words = list(map((lambda x: x if EMOTICON_RE.search(x) else x.lower()), words))
+            words = [x if EMOTICON_RE.search(x) else x.lower() for x in words]
        return words
src/transformers/models/big_bird/tokenization_big_bird_fast.py

@@ -202,7 +202,7 @@ class BigBirdTokenizerFast(PreTrainedTokenizerFast):
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formatted with special tokens for the model."
                )
-            return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
+            return [1 if x in [self.sep_token_id, self.cls_token_id] else 0 for x in token_ids_0]

        if token_ids_1 is None:
            return [1] + ([0] * len(token_ids_0)) + [1]
src/transformers/models/biogpt/tokenization_biogpt.py

@@ -132,8 +132,8 @@ class BioGptTokenizer(PreTrainedTokenizer):
        self.lang = "en"
        self.sm = sacremoses

        # cache of sm.MosesTokenizer instance
-        self.cache_moses_tokenizer = dict()
-        self.cache_moses_detokenizer = dict()
+        self.cache_moses_tokenizer = {}
+        self.cache_moses_detokenizer = {}

        """ Initialisation"""
        with open(vocab_file, encoding="utf-8") as vocab_handle:
@@ -221,7 +221,7 @@ class BioGptTokenizer(PreTrainedTokenizer):
        split_tokens = []
        for token in text:
            if token:
-                split_tokens.extend([t for t in self.bpe(token).split(" ")])
+                split_tokens.extend(list(self.bpe(token).split(" ")))
        return split_tokens