chenpangpang / transformers · commit 5e8c8eb5

Apply ruff flake8-comprehensions (#21694)

Unverified commit, authored Feb 22, 2023 by Aaron Gokaslan and committed via GitHub on Feb 22, 2023. Parent commit: df06fb1f.
Showing 20 of 230 changed files, with 92 additions and 92 deletions.
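Every hunk on this page is the same kind of mechanical cleanup: the flake8-comprehensions rules (ruff's C4xx family, enabled in pyproject.toml below) flag literal-wrapping calls such as dict(), set([...]) and list(map(lambda ...)) that can be replaced by literals or comprehensions. A minimal, self-contained sketch of the rewrite patterns that recur in this commit; the variable names and values are invented for illustration, not taken from the diff:

    # C402: dict() over a generator of key/value pairs -> dict comprehension
    id2label = {"0": "NEGATIVE", "1": "POSITIVE"}
    assert dict((int(k), v) for k, v in id2label.items()) == {int(k): v for k, v in id2label.items()}

    # C403: set() over a list comprehension -> set comprehension
    keys = ["bert.embeddings.word_embeddings", "bert.encoder.layer.0", "classifier.weight"]
    assert set([k.split(".")[0] for k in keys]) == {k.split(".")[0] for k in keys}

    # C408: dict(a=..., b=...) -> dict literal
    assert dict(cache_dir=None, force_download=False) == {"cache_dir": None, "force_download": False}

    # C417: list(map(lambda ...)) -> list comprehension
    words = ["Hello", "WORLD"]
    assert list(map(lambda x: x.lower(), words)) == [x.lower() for x in words]

In every case the two forms are equivalent; the literal or comprehension just drops a needless function call (and, for set([...]), an intermediate list).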
pyproject.toml (+1, -1)
src/transformers/benchmark/benchmark_utils.py (+3, -3)
src/transformers/configuration_utils.py (+2, -2)
src/transformers/deepspeed.py (+7, -7)
src/transformers/feature_extraction_sequence_utils.py (+1, -1)
src/transformers/generation/beam_constraints.py (+2, -2)
src/transformers/generation/logits_process.py (+1, -1)
src/transformers/image_utils.py (+1, -1)
src/transformers/integrations.py (+3, -3)
src/transformers/keras_callbacks.py (+1, -1)
src/transformers/modelcard.py (+4, -4)
src/transformers/modeling_flax_pytorch_utils.py (+7, -7)
src/transformers/modeling_flax_utils.py (+14, -14)
src/transformers/modeling_tf_pytorch_utils.py (+2, -2)
src/transformers/modeling_tf_utils.py (+19, -19)
src/transformers/modeling_utils.py (+17, -17)
src/transformers/models/beit/modeling_flax_beit.py (+1, -1)
src/transformers/models/bertweet/tokenization_bertweet.py (+2, -2)
src/transformers/models/big_bird/tokenization_big_bird_fast.py (+1, -1)
src/transformers/models/biogpt/tokenization_biogpt.py (+3, -3)
pyproject.toml
@@ -5,7 +5,7 @@ target-version = ['py37']
 [tool.ruff]
 # Never enforce `E501` (line length violations).
 ignore = ["E501", "E741", "W605"]
-select = ["E", "F", "I", "W"]
+select = ["C", "E", "F", "I", "W"]
 line-length = 119
 # Ignore import violations in all `__init__.py` files.
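This one-line change is what drives the rest of the commit: adding "C" to ruff's select list enables the C-prefixed rules repo-wide, which at the time chiefly meant the flake8-comprehensions (C4xx) checks. Most of the hunks below look like the output of ruff's autofix for those rules (its --select and --fix flags); the exact invocation used is not recorded in the commit, so treat that as an assumption.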
src/transformers/benchmark/benchmark_utils.py
@@ -557,9 +557,9 @@ def stop_memory_tracing(
             cumulative_memory_dict[frame][2] += cpu_gpu_mem_inc

         cumulative_memory = sorted(
-            list(cumulative_memory_dict.items()), key=lambda x: x[1][2], reverse=True
+            cumulative_memory_dict.items(), key=lambda x: x[1][2], reverse=True
         )  # order by the total CPU + GPU memory increase
-        cumulative_memory = list(
+        cumulative_memory = [
             MemoryState(
                 frame=frame,
                 cpu=Memory(cpu_mem_inc),
@@ -567,7 +567,7 @@ def stop_memory_tracing(
                 cpu_gpu=Memory(cpu_gpu_mem_inc),
             )
             for frame, (cpu_mem_inc, gpu_mem_inc, cpu_gpu_mem_inc) in cumulative_memory
-        )
+        ]
         memory_curr_trace = sorted(memory_curr_trace, key=lambda x: x.cpu_gpu.bytes, reverse=True)
src/transformers/configuration_utils.py
@@ -324,7 +324,7 @@ class PretrainedConfig(PushToHubMixin):
                     f"You passed along `num_labels={num_labels}` with an incompatible id to label map: "
                     f"{self.id2label}. The number of labels wil be overwritten to {self.num_labels}."
                 )
-            self.id2label = dict((int(key), value) for key, value in self.id2label.items())
+            self.id2label = {int(key): value for key, value in self.id2label.items()}
             # Keys are always strings in JSON so convert ids to int here.
         else:
             self.num_labels = kwargs.pop("num_labels", 2)
@@ -696,7 +696,7 @@ class PretrainedConfig(PushToHubMixin):
         config = cls(**config_dict)

         if hasattr(config, "pruned_heads"):
-            config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())
+            config.pruned_heads = {int(key): value for key, value in config.pruned_heads.items()}

         # Update config with kwargs if needed
         if "num_labels" in kwargs and "id2label" in kwargs:
src/transformers/deepspeed.py
@@ -367,13 +367,13 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inf
     # keep for quick debug:
     # from pprint import pprint; pprint(config)

-    kwargs = dict(
-        model=model,
-        model_parameters=model_parameters,
-        config_params=config,
-        optimizer=optimizer,
-        lr_scheduler=lr_scheduler,
-    )
+    kwargs = {
+        "model": model,
+        "model_parameters": model_parameters,
+        "config_params": config,
+        "optimizer": optimizer,
+        "lr_scheduler": lr_scheduler,
+    }

     deepspeed_engine, optimizer, _, lr_scheduler = deepspeed.initialize(**kwargs)
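A note on this C408 rewrite, which also applies to the cached_file_kwargs blocks further down: dict(model=model, ...) and {"model": model, ...} build identical mappings, but the literal skips the global lookup of the dict name and also allows keys that are not valid Python identifiers. Since the dict is immediately unpacked into deepspeed.initialize(**kwargs), behavior is unchanged. A quick sanity check (sketch; the names are placeholders, not the real trainer objects):

    model, optimizer = object(), object()
    assert dict(model=model, optimizer=optimizer) == {"model": model, "optimizer": optimizer}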
src/transformers/feature_extraction_sequence_utils.py
@@ -188,7 +188,7 @@ class SequenceFeatureExtractor(FeatureExtractionMixin):
         truncated_inputs = []
         for i in range(batch_size):
-            inputs = dict((k, v[i]) for k, v in processed_features.items())
+            inputs = {k: v[i] for k, v in processed_features.items()}
             # truncation
             inputs_slice = self._truncate(
                 inputs,
src/transformers/generation/beam_constraints.py
@@ -208,12 +208,12 @@ class DisjunctiveTrie:
         """
         self.max_height = max([len(one) for one in nested_token_ids])

-        root = dict()
+        root = {}
         for token_ids in nested_token_ids:
             level = root
             for tidx, token_id in enumerate(token_ids):
                 if token_id not in level:
-                    level[token_id] = dict()
+                    level[token_id] = {}
                 level = level[token_id]
src/transformers/generation/logits_process.py
@@ -951,7 +951,7 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
         # timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
         for k in range(input_ids.shape[0]):
-            seq = [t for t in input_ids[k, self.begin_index :].tolist()]
+            seq = list(input_ids[k, self.begin_index :].tolist())
             last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.timestamp_begin
             penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.timestamp_begin
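Note that this fix runs in the opposite direction from most of the others: an identity comprehension like [t for t in seq] is flagged as an unnecessary comprehension (C416), and the preferred form is the list() constructor. Here even the list() call is arguably redundant, since .tolist() already returns a list, but the autofix only rewrites the comprehension and does not reason about its operand. A sketch of the equivalence:

    vals = [1, 2, 3]
    assert [t for t in vals] == list(vals) == vals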
src/transformers/image_utils.py
@@ -115,7 +115,7 @@ def make_list_of_images(images, expected_ndims: int = 3) -> List[ImageInput]:
     if is_valid_image(images):
         if images.ndim == expected_ndims + 1:
             # Batch of images
-            images = [image for image in images]
+            images = list(images)
         elif images.ndim == expected_ndims:
             # Single image
             images = [images]
src/transformers/integrations.py
@@ -365,7 +365,7 @@ def run_hp_search_sigopt(trainer, n_trials: int, direction: str, **kwargs) -> Be
             name="huggingface-tune",
             type="offline",
             parameters=trainer.hp_space(None),
-            metrics=[dict(name="objective", objective=direction, strategy="optimize")],
+            metrics=[{"name": "objective", "objective": direction, "strategy": "optimize"}],
             parallel_bandwidth=1,
             budget=n_trials,
         )
@@ -402,7 +402,7 @@ def run_hp_search_sigopt(trainer, n_trials: int, direction: str, **kwargs) -> Be
         experiment = conn.experiments().create(
             name="huggingface-tune",
             parameters=trainer.hp_space(None),
-            metrics=[dict(name="objective", objective=direction, strategy="optimize")],
+            metrics=[{"name": "objective", "objective": direction, "strategy": "optimize"}],
             parallel_bandwidth=1,
             observation_budget=n_trials,
             project="huggingface",
@@ -425,7 +425,7 @@ def run_hp_search_sigopt(trainer, n_trials: int, direction: str, **kwargs) -> Be
             metrics = trainer.evaluate()
             trainer.objective = trainer.compute_objective(metrics)
-            values = [dict(name="objective", value=trainer.objective)]
+            values = [{"name": "objective", "value": trainer.objective}]
             obs = conn.experiments(experiment.id).observations().create(suggestion=suggestion.id, values=values)
             logger.info(f"[suggestion_id, observation_id]: [{suggestion.id}, {obs.id}]")
             experiment = conn.experiments(experiment.id).fetch()
src/transformers/keras_callbacks.py
@@ -162,7 +162,7 @@ class KerasMetricCallback(Callback):
     def _postprocess_predictions_or_labels(self, inputs):
         if isinstance(inputs[0], dict):
-            outputs = dict()
+            outputs = {}
             for key in inputs[0].keys():
                 outputs[key] = self._concatenate_batches([batch[key] for batch in inputs])
             # If it's a dict with only one key, just return the array
src/transformers/modelcard.py
@@ -677,7 +677,7 @@ class TrainingSummary:
             _, eval_lines, eval_results = parse_keras_history(keras_history)
         else:
             eval_lines = []
-            eval_results = dict()
+            eval_results = {}
         hyperparameters = extract_hyperparameters_from_keras(model)

         return cls(
@@ -706,7 +706,7 @@ def parse_keras_history(logs):
         # This looks like a `History` object
         if not hasattr(logs, "epoch"):
             # This history looks empty, return empty results
-            return None, [], dict()
+            return None, [], {}
         logs.history["epoch"] = logs.epoch
         logs = logs.history
     else:
@@ -716,7 +716,7 @@ def parse_keras_history(logs):
     lines = []
     for i in range(len(logs["epoch"])):
         epoch_dict = {log_key: log_value_list[i] for log_key, log_value_list in logs.items()}
-        values = dict()
+        values = {}
         for k, v in epoch_dict.items():
             if k.startswith("val_"):
                 k = "validation_" + k[4:]
@@ -797,7 +797,7 @@ def parse_log_history(log_history):
 def extract_hyperparameters_from_keras(model):
     import tensorflow as tf

-    hyperparameters = dict()
+    hyperparameters = {}
     if hasattr(model, "optimizer") and model.optimizer is not None:
         hyperparameters["optimizer"] = model.optimizer.get_config()
     else:
src/transformers/modeling_flax_pytorch_utils.py
@@ -76,7 +76,7 @@ def rename_key_and_reshape_tensor(
     def is_key_or_prefix_key_in_dict(key: Tuple[str]) -> bool:
         """Checks if `key` of `(prefix,) + key` is in random_flax_state_dict"""
-        return len(set(random_flax_state_dict) & set([key, (model_prefix,) + key])) > 0
+        return len(set(random_flax_state_dict) & {key, (model_prefix,) + key}) > 0

     # layer norm
     renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
@@ -122,10 +122,10 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
     flax_state_dict = {}

     load_model_with_head_into_base_model = (model_prefix not in flax_model.params) and (
-        model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()])
+        model_prefix in {k.split(".")[0] for k in pt_state_dict.keys()}
     )
     load_base_model_into_model_with_head = (model_prefix in flax_model.params) and (
-        model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()])
+        model_prefix not in {k.split(".")[0] for k in pt_state_dict.keys()}
     )

     # Need to change some parameters name to match Flax names
@@ -179,10 +179,10 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
     random_flax_state_dict = flatten_dict(flax_model.params)

     load_model_with_head_into_base_model = (model_prefix not in flax_model.params) and (
-        model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()])
+        model_prefix in {k.split(".")[0] for k in pt_state_dict.keys()}
     )
     load_base_model_into_model_with_head = (model_prefix in flax_model.params) and (
-        model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()])
+        model_prefix not in {k.split(".")[0] for k in pt_state_dict.keys()}
     )

     # Need to change some parameters name to match Flax names
     for pt_key, pt_tensor in pt_state_dict.items():
@@ -267,10 +267,10 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state):
     pt_model_dict = pt_model.state_dict()

     load_model_with_head_into_base_model = (pt_model.base_model_prefix in flax_state) and (
-        pt_model.base_model_prefix not in set([k.split(".")[0] for k in pt_model_dict.keys()])
+        pt_model.base_model_prefix not in {k.split(".")[0] for k in pt_model_dict.keys()}
     )
     load_base_model_into_model_with_head = (pt_model.base_model_prefix not in flax_state) and (
-        pt_model.base_model_prefix in set([k.split(".")[0] for k in pt_model_dict.keys()])
+        pt_model.base_model_prefix in {k.split(".")[0] for k in pt_model_dict.keys()}
     )

     # keep track of unexpected & missing keys
src/transformers/modeling_flax_utils.py
@@ -440,7 +440,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
         """

         # Load the index
-        state_sharded_dict = dict()
+        state_sharded_dict = {}

         for shard_file in shard_files:
             # load using msgpack utils
@@ -708,19 +708,19 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
             filename = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME
             try:
                 # Load from URL or cache if already cached
-                cached_file_kwargs = dict(
-                    cache_dir=cache_dir,
-                    force_download=force_download,
-                    proxies=proxies,
-                    resume_download=resume_download,
-                    local_files_only=local_files_only,
-                    use_auth_token=use_auth_token,
-                    user_agent=user_agent,
-                    revision=revision,
-                    subfolder=subfolder,
-                    _raise_exceptions_for_missing_entries=False,
-                    _commit_hash=commit_hash,
-                )
+                cached_file_kwargs = {
+                    "cache_dir": cache_dir,
+                    "force_download": force_download,
+                    "proxies": proxies,
+                    "resume_download": resume_download,
+                    "local_files_only": local_files_only,
+                    "use_auth_token": use_auth_token,
+                    "user_agent": user_agent,
+                    "revision": revision,
+                    "subfolder": subfolder,
+                    "_raise_exceptions_for_missing_entries": False,
+                    "_commit_hash": commit_hash,
+                }
                 resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)

                 # Since we set _raise_exceptions_for_missing_entries=False, we don't get an expection but a None
src/transformers/modeling_tf_pytorch_utils.py
@@ -258,7 +258,7 @@ def load_pytorch_state_dict_in_tf2_model(
     symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights
     tf_loaded_numel = 0
     weight_value_tuples = []
-    all_pytorch_weights = set(list(pt_state_dict.keys()))
+    all_pytorch_weights = set(pt_state_dict.keys())
     missing_keys = []
     for symbolic_weight in symbolic_weights:
         sw_name = symbolic_weight.name
@@ -425,7 +425,7 @@ def load_tf2_state_dict_in_pytorch_model(pt_model, tf_state_dict, allow_missing_
             )
             tf_weights_map[pt_name] = (tf_weight, transpose)

-    all_tf_weights = set(list(tf_weights_map.keys()))
+    all_tf_weights = set(tf_weights_map.keys())
     loaded_pt_weights_data_ptr = {}
     missing_keys_pt = []
     for pt_weight_name, pt_weight in current_pt_params_dict.items():
src/transformers/modeling_tf_utils.py
@@ -584,7 +584,7 @@ def input_processing(func, config, **kwargs):
     if "kwargs" in output:
         del output["kwargs"]

-    cast_output = dict()
+    cast_output = {}
     for key, val in output.items():
         if isinstance(val, tf.Tensor) and val.dtype == tf.int64:
             cast_output[key] = tf.cast(val, tf.int32)
@@ -737,7 +737,7 @@ def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, s
     # Since TF adds the name of the class to its weights, and uses the index and not the name of the layer to load
     # the weight, we have to get rid of the first prefix of the name of the layer.
     model_keys = set()
-    model_layer_map = dict()
+    model_layer_map = {}
     for i, k in enumerate(model.weights):
         if "model." in k.name or len(k.name.split("/")) == 1:
             layer_name = k.name
@@ -901,10 +901,10 @@ def load_tf_weights_from_h5(model, resolved_archive_file, ignore_mismatched_size
         )

     # Find the missing layers from the high level list of layers
-    missing_layers = list(set([layer.name for layer in model.layers]) - saved_h5_model_layers_name)
+    missing_layers = list({layer.name for layer in model.layers} - saved_h5_model_layers_name)
     # Find the unexpected layers from the high level list of layers
-    unexpected_layers = list(saved_h5_model_layers_name - set([layer.name for layer in model.layers]))
+    unexpected_layers = list(saved_h5_model_layers_name - {layer.name for layer in model.layers})
     saved_weight_names_set = set()
     symbolic_weights_names = set()
     weight_value_tuples = []
@@ -1349,7 +1349,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         else:
             collate_fn = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="np")
         if collate_fn_args is None:
-            collate_fn_args = dict()
+            collate_fn_args = {}

         if not isinstance(dataset, datasets.Dataset):
             raise TypeError("Dataset argument should be a datasets.Dataset!")
@@ -1471,7 +1471,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         elif "mc_labels" in arg_names:
             return {"labels": "logits", "mc_labels": "mc_logits"}
         else:
-            return dict()
+            return {}

     def train_step(self, data):
         """
@@ -2613,19 +2613,19 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
             try:
                 # Load from URL or cache if already cached
-                cached_file_kwargs = dict(
-                    cache_dir=cache_dir,
-                    force_download=force_download,
-                    proxies=proxies,
-                    resume_download=resume_download,
-                    local_files_only=local_files_only,
-                    use_auth_token=use_auth_token,
-                    user_agent=user_agent,
-                    revision=revision,
-                    subfolder=subfolder,
-                    _raise_exceptions_for_missing_entries=False,
-                    _commit_hash=commit_hash,
-                )
+                cached_file_kwargs = {
+                    "cache_dir": cache_dir,
+                    "force_download": force_download,
+                    "proxies": proxies,
+                    "resume_download": resume_download,
+                    "local_files_only": local_files_only,
+                    "use_auth_token": use_auth_token,
+                    "user_agent": user_agent,
+                    "revision": revision,
+                    "subfolder": subfolder,
+                    "_raise_exceptions_for_missing_entries": False,
+                    "_commit_hash": commit_hash,
+                }
                 resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)

                 # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
src/transformers/modeling_utils.py
@@ -1271,7 +1271,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 len(encoder_modules) > 0
             ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"

-            all_encoder_weights = set([module_name + "/" + sub_name for sub_name in encoder_modules.keys()])
+            all_encoder_weights = {module_name + "/" + sub_name for sub_name in encoder_modules.keys()}
             encoder_layer_pos = 0
             for name, module in decoder_modules.items():
                 if name.isdigit():
@@ -2304,19 +2304,19 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             try:
                 # Load from URL or cache if already cached
-                cached_file_kwargs = dict(
-                    cache_dir=cache_dir,
-                    force_download=force_download,
-                    proxies=proxies,
-                    resume_download=resume_download,
-                    local_files_only=local_files_only,
-                    use_auth_token=use_auth_token,
-                    user_agent=user_agent,
-                    revision=revision,
-                    subfolder=subfolder,
-                    _raise_exceptions_for_missing_entries=False,
-                    _commit_hash=commit_hash,
-                )
+                cached_file_kwargs = {
+                    "cache_dir": cache_dir,
+                    "force_download": force_download,
+                    "proxies": proxies,
+                    "resume_download": resume_download,
+                    "local_files_only": local_files_only,
+                    "use_auth_token": use_auth_token,
+                    "user_agent": user_agent,
+                    "revision": revision,
+                    "subfolder": subfolder,
+                    "_raise_exceptions_for_missing_entries": False,
+                    "_commit_hash": commit_hash,
+                }
                 resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)

                 # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
@@ -2474,7 +2474,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         if is_sharded:
             loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
         else:
-            loaded_state_dict_keys = [k for k in state_dict.keys()]
+            loaded_state_dict_keys = list(state_dict.keys())
         if low_cpu_mem_usage or use_keep_in_fp32_modules:
             state_dict = None
@@ -3046,12 +3046,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         return model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs

     def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False):
-        module_keys = set([".".join(key.split(".")[:-1]) for key in names])
+        module_keys = {".".join(key.split(".")[:-1]) for key in names}

         # torch.nn.ParameterList is a special case where two parameter keywords
         # are appended to the module name, *e.g.* bert.special_embeddings.0
         module_keys = module_keys.union(
-            set([".".join(key.split(".")[:-2]) for key in names if len(key) > 0 and key[-1].isdigit()])
+            {".".join(key.split(".")[:-2]) for key in names if len(key) > 0 and key[-1].isdigit()}
         )

         retrieved_modules = []
src/transformers/models/beit/modeling_flax_beit.py
@@ -555,7 +555,7 @@ class FlaxBeitEncoder(nn.Module):
             )

         # stochastic depth decay rule
-        drop_path_rates = [x for x in np.linspace(0, self.config.drop_path_rate, self.config.num_hidden_layers)]
+        drop_path_rates = list(np.linspace(0, self.config.drop_path_rate, self.config.num_hidden_layers))
         self.layer = FlaxBeitLayerCollection(
             self.config,
             window_size=self.window_size,
src/transformers/models/bertweet/tokenization_bertweet.py
@@ -318,7 +318,7 @@ class BertweetTokenizer(PreTrainedTokenizer):
         split_tokens = []
         words = re.findall(r"\S+\n?", text)
         for token in words:
-            split_tokens.extend([t for t in self.bpe(token).split(" ")])
+            split_tokens.extend(list(self.bpe(token).split(" ")))
         return split_tokens

     def normalizeTweet(self, tweet):
@@ -726,7 +726,7 @@ class TweetTokenizer:
         words = WORD_RE.findall(safe_text)
         # Possibly alter the case, but avoid changing emoticons like :D into :d:
         if not self.preserve_case:
-            words = list(map((lambda x: x if EMOTICON_RE.search(x) else x.lower()), words))
+            words = [x if EMOTICON_RE.search(x) else x.lower() for x in words]
         return words
src/transformers/models/big_bird/tokenization_big_bird_fast.py
@@ -202,7 +202,7 @@ class BigBirdTokenizerFast(PreTrainedTokenizerFast):
                     "You should not supply a second sequence if the provided sequence of "
                     "ids is already formatted with special tokens for the model."
                 )
-            return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
+            return [1 if x in [self.sep_token_id, self.cls_token_id] else 0 for x in token_ids_0]

         if token_ids_1 is None:
             return [1] + ([0] * len(token_ids_0)) + [1]
src/transformers/models/biogpt/tokenization_biogpt.py
@@ -132,8 +132,8 @@ class BioGptTokenizer(PreTrainedTokenizer):
         self.lang = "en"
         self.sm = sacremoses
         # cache of sm.MosesTokenizer instance
-        self.cache_moses_tokenizer = dict()
-        self.cache_moses_detokenizer = dict()
+        self.cache_moses_tokenizer = {}
+        self.cache_moses_detokenizer = {}

         """ Initialisation"""
         with open(vocab_file, encoding="utf-8") as vocab_handle:
@@ -221,7 +221,7 @@ class BioGptTokenizer(PreTrainedTokenizer):
         split_tokens = []
         for token in text:
             if token:
-                split_tokens.extend([t for t in self.bpe(token).split(" ")])
+                split_tokens.extend(list(self.bpe(token).split(" ")))
         return split_tokens