Unverified Commit d2c687b3 authored by Sai-Suraj-27's avatar Sai-Suraj-27 Committed by GitHub
Browse files

Updated `ruff` to the latest version (#31926)

* Updated ruff version and fixed the required code according to the latest version.

* Updated ruff version and fixed the required code according to the latest version.

* Added noqa directive to ignore 1 error shown by ruff
parent 9cf4f2aa
...@@ -418,7 +418,7 @@ class TestTheRest(TestCasePlus): ...@@ -418,7 +418,7 @@ class TestTheRest(TestCasePlus):
with CaptureStdout() as cs: with CaptureStdout() as cs:
args = parser.parse_args(args) args = parser.parse_args(args)
assert False, "--help is expected to sys.exit" assert False, "--help is expected to sys.exit"
assert excinfo.type == SystemExit assert excinfo.type is SystemExit
expected = lightning_base.arg_to_scheduler_metavar expected = lightning_base.arg_to_scheduler_metavar
assert expected in cs.out, "--help is expected to list the supported schedulers" assert expected in cs.out, "--help is expected to list the supported schedulers"
...@@ -429,7 +429,7 @@ class TestTheRest(TestCasePlus): ...@@ -429,7 +429,7 @@ class TestTheRest(TestCasePlus):
with CaptureStderr() as cs: with CaptureStderr() as cs:
args = parser.parse_args(args) args = parser.parse_args(args)
assert False, "invalid argument is expected to sys.exit" assert False, "invalid argument is expected to sys.exit"
assert excinfo.type == SystemExit assert excinfo.type is SystemExit
expected = f"invalid choice: '{unsupported_param}'" expected = f"invalid choice: '{unsupported_param}'"
assert expected in cs.err, f"should have bailed on invalid choice of scheduler {unsupported_param}" assert expected in cs.err, f"should have bailed on invalid choice of scheduler {unsupported_param}"
......
...@@ -157,7 +157,7 @@ _deps = [ ...@@ -157,7 +157,7 @@ _deps = [
"rhoknp>=1.1.0,<1.3.1", "rhoknp>=1.1.0,<1.3.1",
"rjieba", "rjieba",
"rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1", "rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1",
"ruff==0.4.4", "ruff==0.5.1",
"sacrebleu>=1.4.12,<2.0.0", "sacrebleu>=1.4.12,<2.0.0",
"sacremoses", "sacremoses",
"safetensors>=0.4.1", "safetensors>=0.4.1",
......
...@@ -63,7 +63,7 @@ deps = { ...@@ -63,7 +63,7 @@ deps = {
"rhoknp": "rhoknp>=1.1.0,<1.3.1", "rhoknp": "rhoknp>=1.1.0,<1.3.1",
"rjieba": "rjieba", "rjieba": "rjieba",
"rouge-score": "rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1", "rouge-score": "rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1",
"ruff": "ruff==0.4.4", "ruff": "ruff==0.5.1",
"sacrebleu": "sacrebleu>=1.4.12,<2.0.0", "sacrebleu": "sacrebleu>=1.4.12,<2.0.0",
"sacremoses": "sacremoses", "sacremoses": "sacremoses",
"safetensors": "safetensors>=0.4.1", "safetensors": "safetensors>=0.4.1",
......
...@@ -164,7 +164,7 @@ class HfArgumentParser(ArgumentParser): ...@@ -164,7 +164,7 @@ class HfArgumentParser(ArgumentParser):
) )
if type(None) not in field.type.__args__: if type(None) not in field.type.__args__:
# filter `str` in Union # filter `str` in Union
field.type = field.type.__args__[0] if field.type.__args__[1] == str else field.type.__args__[1] field.type = field.type.__args__[0] if field.type.__args__[1] is str else field.type.__args__[1]
origin_type = getattr(field.type, "__origin__", field.type) origin_type = getattr(field.type, "__origin__", field.type)
elif bool not in field.type.__args__: elif bool not in field.type.__args__:
# filter `NoneType` in Union (except for `Union[bool, NoneType]`) # filter `NoneType` in Union (except for `Union[bool, NoneType]`)
......
...@@ -90,7 +90,7 @@ def dtype_byte_size(dtype): ...@@ -90,7 +90,7 @@ def dtype_byte_size(dtype):
4 4
``` ```
""" """
if dtype == bool: if dtype is bool:
return 1 / 8 return 1 / 8
bit_search = re.search(r"[^\d](\d+)$", dtype.name) bit_search = re.search(r"[^\d](\d+)$", dtype.name)
if bit_search is None: if bit_search is None:
......
...@@ -398,7 +398,7 @@ class TransformerBlock(nn.Module): ...@@ -398,7 +398,7 @@ class TransformerBlock(nn.Module):
if output_attentions: if output_attentions:
sa_output, sa_weights = sa_output # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length) sa_output, sa_weights = sa_output # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
else: # To handle these `output_attentions` or `output_hidden_states` cases returning tuples else: # To handle these `output_attentions` or `output_hidden_states` cases returning tuples
if type(sa_output) != tuple: if type(sa_output) is not tuple:
raise TypeError(f"sa_output must be a tuple but it is {type(sa_output)} type") raise TypeError(f"sa_output must be a tuple but it is {type(sa_output)} type")
sa_output = sa_output[0] sa_output = sa_output[0]
......
...@@ -304,7 +304,7 @@ class FlaxTransformerBlock(nn.Module): ...@@ -304,7 +304,7 @@ class FlaxTransformerBlock(nn.Module):
if output_attentions: if output_attentions:
sa_output, sa_weights = sa_output sa_output, sa_weights = sa_output
else: else:
assert type(sa_output) == tuple assert type(sa_output) is tuple
sa_output = sa_output[0] sa_output = sa_output[0]
sa_output = self.sa_layer_norm(sa_output + hidden_states) sa_output = self.sa_layer_norm(sa_output + hidden_states)
......
...@@ -343,7 +343,7 @@ class Rotation: ...@@ -343,7 +343,7 @@ class Rotation:
Returns: Returns:
The indexed rotation The indexed rotation
""" """
if type(index) != tuple: if type(index) is not tuple:
index = (index,) index = (index,)
if self._rot_mats is not None: if self._rot_mats is not None:
...@@ -827,7 +827,7 @@ class Rigid: ...@@ -827,7 +827,7 @@ class Rigid:
Returns: Returns:
The indexed tensor The indexed tensor
""" """
if type(index) != tuple: if type(index) is not tuple:
index = (index,) index = (index,)
return Rigid( return Rigid(
......
...@@ -68,7 +68,7 @@ class MarkupLMFeatureExtractor(FeatureExtractionMixin): ...@@ -68,7 +68,7 @@ class MarkupLMFeatureExtractor(FeatureExtractionMixin):
for element in html_code.descendants: for element in html_code.descendants:
if isinstance(element, bs4.element.NavigableString): if isinstance(element, bs4.element.NavigableString):
if type(element.parent) != bs4.element.Tag: if type(element.parent) is not bs4.element.Tag:
continue continue
text_in_this_tag = html.unescape(element).strip() text_in_this_tag = html.unescape(element).strip()
......
...@@ -2550,7 +2550,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel): ...@@ -2550,7 +2550,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
generation_config.validate() generation_config.validate()
self._validate_model_kwargs(model_kwargs.copy()) self._validate_model_kwargs(model_kwargs.copy())
if model_kwargs.get("encoder_outputs") is not None and type(model_kwargs["encoder_outputs"]) == tuple: if model_kwargs.get("encoder_outputs") is not None and type(model_kwargs["encoder_outputs"]) is tuple:
# wrap the unconditional outputs as a BaseModelOutput for compatibility with the rest of generate # wrap the unconditional outputs as a BaseModelOutput for compatibility with the rest of generate
model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=model_kwargs["encoder_outputs"][0]) model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=model_kwargs["encoder_outputs"][0])
......
...@@ -254,7 +254,7 @@ def reissue_pt_warnings(caught_warnings): ...@@ -254,7 +254,7 @@ def reissue_pt_warnings(caught_warnings):
# Reissue warnings that are not the SAVE_STATE_WARNING # Reissue warnings that are not the SAVE_STATE_WARNING
if len(caught_warnings) > 1: if len(caught_warnings) > 1:
for w in caught_warnings: for w in caught_warnings:
if w.category != UserWarning or w.message != SAVE_STATE_WARNING: if w.category is not UserWarning or w.message != SAVE_STATE_WARNING:
warnings.warn(w.message, w.category) warnings.warn(w.message, w.category)
......
...@@ -198,7 +198,7 @@ Action: ...@@ -198,7 +198,7 @@ Action:
) )
agent.run("What is 2 multiplied by 3.6452?") agent.run("What is 2 multiplied by 3.6452?")
assert len(agent.logs) == 7 assert len(agent.logs) == 7
assert type(agent.logs[-1]["error"]) == AgentMaxIterationsError assert type(agent.logs[-1]["error"]) is AgentMaxIterationsError
@require_torch @require_torch
def test_init_agent_with_different_toolsets(self): def test_init_agent_with_different_toolsets(self):
......
...@@ -214,7 +214,7 @@ recur_fibo(6)""" ...@@ -214,7 +214,7 @@ recur_fibo(6)"""
def test_access_attributes(self): def test_access_attributes(self):
code = "integer = 1\nobj_class = integer.__class__\nobj_class" code = "integer = 1\nobj_class = integer.__class__\nobj_class"
result = evaluate_python_code(code, {}, state={}) result = evaluate_python_code(code, {}, state={})
assert result == int assert result is int
def test_list_comprehension(self): def test_list_comprehension(self):
code = "sentence = 'THESEAGULL43'\nmeaningful_sentence = '-'.join([char.lower() for char in sentence if char.isalpha()])" code = "sentence = 'THESEAGULL43'\nmeaningful_sentence = '-'.join([char.lower() for char in sentence if char.isalpha()])"
...@@ -591,7 +591,7 @@ except ValueError as e: ...@@ -591,7 +591,7 @@ except ValueError as e:
code = "type_a = float(2); type_b = str; type_c = int" code = "type_a = float(2); type_b = str; type_c = int"
state = {} state = {}
result = evaluate_python_code(code, {"float": float, "str": str, "int": int}, state=state) result = evaluate_python_code(code, {"float": float, "str": str, "int": int}, state=state)
assert result == int assert result is int
def test_tuple_id(self): def test_tuple_id(self):
code = """ code = """
......
...@@ -56,7 +56,7 @@ class RoFormerTokenizationTest(TokenizerTesterMixin, unittest.TestCase): ...@@ -56,7 +56,7 @@ class RoFormerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
exp_tokens = [22943, 21332, 34431, 45904, 117, 306, 1231, 1231, 2653, 33994, 1266, 100] exp_tokens = [22943, 21332, 34431, 45904, 117, 306, 1231, 1231, 2653, 33994, 1266, 100]
self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), exp_tokens) self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), exp_tokens)
def test_rust_tokenizer(self): def test_rust_tokenizer(self): # noqa: F811
tokenizer = self.get_rust_tokenizer() tokenizer = self.get_rust_tokenizer()
input_text, output_text = self.get_chinese_input_output_texts() input_text, output_text = self.get_chinese_input_output_texts()
tokens = tokenizer.tokenize(input_text) tokens = tokenizer.tokenize(input_text)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment