Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
de231889
Unverified
Commit
de231889
authored
Jul 25, 2024
by
Kashif Rasul
Committed by
GitHub
Jul 25, 2024
Browse files
[warnings] fix E721 warnings (#32223)
fix E721 warnings
parent
9b9a54e6
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
9 additions
and
9 deletions
+9
-9
src/transformers/generation/candidate_generator.py
src/transformers/generation/candidate_generator.py
+1
-1
src/transformers/models/bart/modeling_flax_bart.py
src/transformers/models/bart/modeling_flax_bart.py
+1
-1
src/transformers/models/esm/openfold_utils/chunk_utils.py
src/transformers/models/esm/openfold_utils/chunk_utils.py
+1
-1
src/transformers/models/mbart/modeling_flax_mbart.py
src/transformers/models/mbart/modeling_flax_mbart.py
+1
-1
src/transformers/trainer_pt_utils.py
src/transformers/trainer_pt_utils.py
+1
-1
src/transformers/utils/chat_template_utils.py
src/transformers/utils/chat_template_utils.py
+1
-1
src/transformers/utils/generic.py
src/transformers/utils/generic.py
+1
-1
tests/models/ibert/test_modeling_ibert.py
tests/models/ibert/test_modeling_ibert.py
+2
-2
No files found.
src/transformers/generation/candidate_generator.py
View file @
de231889
...
@@ -162,7 +162,7 @@ class AssistedCandidateGenerator(CandidateGenerator):
...
@@ -162,7 +162,7 @@ class AssistedCandidateGenerator(CandidateGenerator):
self
.
generation_config
.
min_length
=
0
self
.
generation_config
.
min_length
=
0
self
.
generation_config
.
min_new_tokens
=
None
self
.
generation_config
.
min_new_tokens
=
None
for
processor
in
self
.
logits_processor
:
for
processor
in
self
.
logits_processor
:
if
typ
e
(
processor
)
==
MinLengthLogitsProcessor
:
if
isinstanc
e
(
processor
,
MinLengthLogitsProcessor
)
:
raise
ValueError
(
raise
ValueError
(
"Passing `MinLengthLogitsProcessor` when using `assisted_generation is disabled. "
"Passing `MinLengthLogitsProcessor` when using `assisted_generation is disabled. "
"Please pass in `min_length` into `.generate()` instead"
"Please pass in `min_length` into `.generate()` instead"
...
...
src/transformers/models/bart/modeling_flax_bart.py
View file @
de231889
...
@@ -1599,7 +1599,7 @@ class FlaxBartForSequenceClassificationModule(nn.Module):
...
@@ -1599,7 +1599,7 @@ class FlaxBartForSequenceClassificationModule(nn.Module):
eos_mask
=
jnp
.
where
(
input_ids
==
self
.
config
.
eos_token_id
,
1
,
0
)
eos_mask
=
jnp
.
where
(
input_ids
==
self
.
config
.
eos_token_id
,
1
,
0
)
# The first condition is necessary to overcome jax._src.errors.ConcretizationTypeError during JIT compilation
# The first condition is necessary to overcome jax._src.errors.ConcretizationTypeError during JIT compilation
if
typ
e
(
eos_mask
)
!=
jax
.
interpreters
.
partial_eval
.
DynamicJaxprTracer
:
if
not
isinstanc
e
(
eos_mask
,
jax
.
interpreters
.
partial_eval
.
DynamicJaxprTracer
)
:
if
len
(
jnp
.
unique
(
eos_mask
.
sum
(
1
)))
>
1
:
if
len
(
jnp
.
unique
(
eos_mask
.
sum
(
1
)))
>
1
:
raise
ValueError
(
"All examples must have the same number of <eos> tokens."
)
raise
ValueError
(
"All examples must have the same number of <eos> tokens."
)
...
...
src/transformers/models/esm/openfold_utils/chunk_utils.py
View file @
de231889
...
@@ -356,7 +356,7 @@ class ChunkSizeTuner:
...
@@ -356,7 +356,7 @@ class ChunkSizeTuner:
def
_compare_arg_caches
(
self
,
ac1
:
Iterable
,
ac2
:
Iterable
)
->
bool
:
def
_compare_arg_caches
(
self
,
ac1
:
Iterable
,
ac2
:
Iterable
)
->
bool
:
consistent
=
True
consistent
=
True
for
a1
,
a2
in
zip
(
ac1
,
ac2
):
for
a1
,
a2
in
zip
(
ac1
,
ac2
):
assert
type
(
ac1
)
==
type
(
ac2
)
assert
type
(
ac1
)
is
type
(
ac2
)
if
isinstance
(
ac1
,
(
list
,
tuple
)):
if
isinstance
(
ac1
,
(
list
,
tuple
)):
consistent
&=
self
.
_compare_arg_caches
(
a1
,
a2
)
consistent
&=
self
.
_compare_arg_caches
(
a1
,
a2
)
elif
isinstance
(
ac1
,
dict
):
elif
isinstance
(
ac1
,
dict
):
...
...
src/transformers/models/mbart/modeling_flax_mbart.py
View file @
de231889
...
@@ -1635,7 +1635,7 @@ class FlaxMBartForSequenceClassificationModule(nn.Module):
...
@@ -1635,7 +1635,7 @@ class FlaxMBartForSequenceClassificationModule(nn.Module):
eos_mask
=
jnp
.
where
(
input_ids
==
self
.
config
.
eos_token_id
,
1
,
0
)
eos_mask
=
jnp
.
where
(
input_ids
==
self
.
config
.
eos_token_id
,
1
,
0
)
# The first condition is necessary to overcome jax._src.errors.ConcretizationTypeError during JIT compilation
# The first condition is necessary to overcome jax._src.errors.ConcretizationTypeError during JIT compilation
if
typ
e
(
eos_mask
)
!=
jax
.
interpreters
.
partial_eval
.
DynamicJaxprTracer
:
if
not
isinstanc
e
(
eos_mask
,
jax
.
interpreters
.
partial_eval
.
DynamicJaxprTracer
)
:
if
len
(
jnp
.
unique
(
eos_mask
.
sum
(
1
)))
>
1
:
if
len
(
jnp
.
unique
(
eos_mask
.
sum
(
1
)))
>
1
:
raise
ValueError
(
"All examples must have the same number of <eos> tokens."
)
raise
ValueError
(
"All examples must have the same number of <eos> tokens."
)
...
...
src/transformers/trainer_pt_utils.py
View file @
de231889
...
@@ -128,7 +128,7 @@ def nested_concat(tensors, new_tensors, padding_index=-100):
...
@@ -128,7 +128,7 @@ def nested_concat(tensors, new_tensors, padding_index=-100):
"""
"""
if
not
(
isinstance
(
tensors
,
torch
.
Tensor
)
and
isinstance
(
new_tensors
,
torch
.
Tensor
)):
if
not
(
isinstance
(
tensors
,
torch
.
Tensor
)
and
isinstance
(
new_tensors
,
torch
.
Tensor
)):
assert
(
assert
(
type
(
tensors
)
==
type
(
new_tensors
)
type
(
tensors
)
is
type
(
new_tensors
)
),
f
"Expected `tensors` and `new_tensors` to have the same type but found
{
type
(
tensors
)
}
and
{
type
(
new_tensors
)
}
."
),
f
"Expected `tensors` and `new_tensors` to have the same type but found
{
type
(
tensors
)
}
and
{
type
(
new_tensors
)
}
."
if
isinstance
(
tensors
,
(
list
,
tuple
)):
if
isinstance
(
tensors
,
(
list
,
tuple
)):
return
type
(
tensors
)(
nested_concat
(
t
,
n
,
padding_index
=
padding_index
)
for
t
,
n
in
zip
(
tensors
,
new_tensors
))
return
type
(
tensors
)(
nested_concat
(
t
,
n
,
padding_index
=
padding_index
)
for
t
,
n
in
zip
(
tensors
,
new_tensors
))
...
...
src/transformers/utils/chat_template_utils.py
View file @
de231889
...
@@ -74,7 +74,7 @@ def _parse_type_hint(hint: str) -> Dict:
...
@@ -74,7 +74,7 @@ def _parse_type_hint(hint: str) -> Dict:
elif
origin
is
Union
:
elif
origin
is
Union
:
# Recurse into each of the subtypes in the Union, except None, which is handled separately at the end
# Recurse into each of the subtypes in the Union, except None, which is handled separately at the end
subtypes
=
[
_parse_type_hint
(
t
)
for
t
in
args
if
t
!=
type
(
None
)]
subtypes
=
[
_parse_type_hint
(
t
)
for
t
in
args
if
t
is
not
type
(
None
)]
if
len
(
subtypes
)
==
1
:
if
len
(
subtypes
)
==
1
:
# A single non-null type can be expressed directly
# A single non-null type can be expressed directly
return_dict
=
subtypes
[
0
]
return_dict
=
subtypes
[
0
]
...
...
src/transformers/utils/generic.py
View file @
de231889
...
@@ -214,7 +214,7 @@ def _is_tf_symbolic_tensor(x):
...
@@ -214,7 +214,7 @@ def _is_tf_symbolic_tensor(x):
# the `is_symbolic_tensor` predicate is only available starting with TF 2.14
# the `is_symbolic_tensor` predicate is only available starting with TF 2.14
if
hasattr
(
tf
,
"is_symbolic_tensor"
):
if
hasattr
(
tf
,
"is_symbolic_tensor"
):
return
tf
.
is_symbolic_tensor
(
x
)
return
tf
.
is_symbolic_tensor
(
x
)
return
type
(
x
)
==
tf
.
Tensor
return
isinstance
(
x
,
tf
.
Tensor
)
def
is_tf_symbolic_tensor
(
x
):
def
is_tf_symbolic_tensor
(
x
):
...
...
tests/models/ibert/test_modeling_ibert.py
View file @
de231889
...
@@ -684,10 +684,10 @@ class IBertModelIntegrationTest(unittest.TestCase):
...
@@ -684,10 +684,10 @@ class IBertModelIntegrationTest(unittest.TestCase):
# Recursively convert all the `quant_mode` attributes as `True`
# Recursively convert all the `quant_mode` attributes as `True`
if
hasattr
(
model
,
"quant_mode"
):
if
hasattr
(
model
,
"quant_mode"
):
model
.
quant_mode
=
True
model
.
quant_mode
=
True
elif
typ
e
(
model
)
==
nn
.
Sequential
:
elif
isinstanc
e
(
model
,
nn
.
Sequential
)
:
for
n
,
m
in
model
.
named_children
():
for
n
,
m
in
model
.
named_children
():
self
.
quantize
(
m
)
self
.
quantize
(
m
)
elif
typ
e
(
model
)
==
nn
.
ModuleList
:
elif
isinstanc
e
(
model
,
nn
.
ModuleList
)
:
for
n
in
model
:
for
n
in
model
:
self
.
quantize
(
n
)
self
.
quantize
(
n
)
else
:
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment