"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "24d5ad1dcceee736ac829ec316fa3320e4df0064"
Unverified commit de231889 authored by Kashif Rasul, committed by GitHub
Browse files

[warnings] fix E721 warnings (#32223)

fix E721 warnings
parent 9b9a54e6
...@@ -162,7 +162,7 @@ class AssistedCandidateGenerator(CandidateGenerator): ...@@ -162,7 +162,7 @@ class AssistedCandidateGenerator(CandidateGenerator):
self.generation_config.min_length = 0 self.generation_config.min_length = 0
self.generation_config.min_new_tokens = None self.generation_config.min_new_tokens = None
for processor in self.logits_processor: for processor in self.logits_processor:
if type(processor) == MinLengthLogitsProcessor: if isinstance(processor, MinLengthLogitsProcessor):
raise ValueError( raise ValueError(
"Passing `MinLengthLogitsProcessor` when using `assisted_generation is disabled. " "Passing `MinLengthLogitsProcessor` when using `assisted_generation is disabled. "
"Please pass in `min_length` into `.generate()` instead" "Please pass in `min_length` into `.generate()` instead"
......
...@@ -1599,7 +1599,7 @@ class FlaxBartForSequenceClassificationModule(nn.Module): ...@@ -1599,7 +1599,7 @@ class FlaxBartForSequenceClassificationModule(nn.Module):
eos_mask = jnp.where(input_ids == self.config.eos_token_id, 1, 0) eos_mask = jnp.where(input_ids == self.config.eos_token_id, 1, 0)
# The first condition is necessary to overcome jax._src.errors.ConcretizationTypeError during JIT compilation # The first condition is necessary to overcome jax._src.errors.ConcretizationTypeError during JIT compilation
if type(eos_mask) != jax.interpreters.partial_eval.DynamicJaxprTracer: if not isinstance(eos_mask, jax.interpreters.partial_eval.DynamicJaxprTracer):
if len(jnp.unique(eos_mask.sum(1))) > 1: if len(jnp.unique(eos_mask.sum(1))) > 1:
raise ValueError("All examples must have the same number of <eos> tokens.") raise ValueError("All examples must have the same number of <eos> tokens.")
......
...@@ -356,7 +356,7 @@ class ChunkSizeTuner: ...@@ -356,7 +356,7 @@ class ChunkSizeTuner:
def _compare_arg_caches(self, ac1: Iterable, ac2: Iterable) -> bool: def _compare_arg_caches(self, ac1: Iterable, ac2: Iterable) -> bool:
consistent = True consistent = True
for a1, a2 in zip(ac1, ac2): for a1, a2 in zip(ac1, ac2):
assert type(ac1) == type(ac2) assert type(ac1) is type(ac2)
if isinstance(ac1, (list, tuple)): if isinstance(ac1, (list, tuple)):
consistent &= self._compare_arg_caches(a1, a2) consistent &= self._compare_arg_caches(a1, a2)
elif isinstance(ac1, dict): elif isinstance(ac1, dict):
......
...@@ -1635,7 +1635,7 @@ class FlaxMBartForSequenceClassificationModule(nn.Module): ...@@ -1635,7 +1635,7 @@ class FlaxMBartForSequenceClassificationModule(nn.Module):
eos_mask = jnp.where(input_ids == self.config.eos_token_id, 1, 0) eos_mask = jnp.where(input_ids == self.config.eos_token_id, 1, 0)
# The first condition is necessary to overcome jax._src.errors.ConcretizationTypeError during JIT compilation # The first condition is necessary to overcome jax._src.errors.ConcretizationTypeError during JIT compilation
if type(eos_mask) != jax.interpreters.partial_eval.DynamicJaxprTracer: if not isinstance(eos_mask, jax.interpreters.partial_eval.DynamicJaxprTracer):
if len(jnp.unique(eos_mask.sum(1))) > 1: if len(jnp.unique(eos_mask.sum(1))) > 1:
raise ValueError("All examples must have the same number of <eos> tokens.") raise ValueError("All examples must have the same number of <eos> tokens.")
......
...@@ -128,7 +128,7 @@ def nested_concat(tensors, new_tensors, padding_index=-100): ...@@ -128,7 +128,7 @@ def nested_concat(tensors, new_tensors, padding_index=-100):
""" """
if not (isinstance(tensors, torch.Tensor) and isinstance(new_tensors, torch.Tensor)): if not (isinstance(tensors, torch.Tensor) and isinstance(new_tensors, torch.Tensor)):
assert ( assert (
type(tensors) == type(new_tensors) type(tensors) is type(new_tensors)
), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}." ), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
if isinstance(tensors, (list, tuple)): if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_concat(t, n, padding_index=padding_index) for t, n in zip(tensors, new_tensors)) return type(tensors)(nested_concat(t, n, padding_index=padding_index) for t, n in zip(tensors, new_tensors))
......
...@@ -74,7 +74,7 @@ def _parse_type_hint(hint: str) -> Dict: ...@@ -74,7 +74,7 @@ def _parse_type_hint(hint: str) -> Dict:
elif origin is Union: elif origin is Union:
# Recurse into each of the subtypes in the Union, except None, which is handled separately at the end # Recurse into each of the subtypes in the Union, except None, which is handled separately at the end
subtypes = [_parse_type_hint(t) for t in args if t != type(None)] subtypes = [_parse_type_hint(t) for t in args if t is not type(None)]
if len(subtypes) == 1: if len(subtypes) == 1:
# A single non-null type can be expressed directly # A single non-null type can be expressed directly
return_dict = subtypes[0] return_dict = subtypes[0]
......
...@@ -214,7 +214,7 @@ def _is_tf_symbolic_tensor(x): ...@@ -214,7 +214,7 @@ def _is_tf_symbolic_tensor(x):
# the `is_symbolic_tensor` predicate is only available starting with TF 2.14 # the `is_symbolic_tensor` predicate is only available starting with TF 2.14
if hasattr(tf, "is_symbolic_tensor"): if hasattr(tf, "is_symbolic_tensor"):
return tf.is_symbolic_tensor(x) return tf.is_symbolic_tensor(x)
return type(x) == tf.Tensor return isinstance(x, tf.Tensor)
def is_tf_symbolic_tensor(x): def is_tf_symbolic_tensor(x):
......
...@@ -684,10 +684,10 @@ class IBertModelIntegrationTest(unittest.TestCase): ...@@ -684,10 +684,10 @@ class IBertModelIntegrationTest(unittest.TestCase):
# Recursively convert all the `quant_mode` attributes as `True` # Recursively convert all the `quant_mode` attributes as `True`
if hasattr(model, "quant_mode"): if hasattr(model, "quant_mode"):
model.quant_mode = True model.quant_mode = True
elif type(model) == nn.Sequential: elif isinstance(model, nn.Sequential):
for n, m in model.named_children(): for n, m in model.named_children():
self.quantize(m) self.quantize(m)
elif type(model) == nn.ModuleList: elif isinstance(model, nn.ModuleList):
for n in model: for n in model:
self.quantize(n) self.quantize(n)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment