Unverified Commit 4c722e9e authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

fix regexes with escape sequence (#17943)

parent 7c4c6f60
...@@ -78,9 +78,9 @@ def get_relative_imports(module_file): ...@@ -78,9 +78,9 @@ def get_relative_imports(module_file):
content = f.read() content = f.read()
# Imports of the form `import .xxx` # Imports of the form `import .xxx`
relative_imports = re.findall("^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE) relative_imports = re.findall(r"^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE)
# Imports of the form `from .xxx import yyy` # Imports of the form `from .xxx import yyy`
relative_imports += re.findall("^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE) relative_imports += re.findall(r"^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
# Unique-ify # Unique-ify
return list(set(relative_imports)) return list(set(relative_imports))
...@@ -122,9 +122,9 @@ def check_imports(filename): ...@@ -122,9 +122,9 @@ def check_imports(filename):
content = f.read() content = f.read()
# Imports of the form `import xxx` # Imports of the form `import xxx`
imports = re.findall("^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE) imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
# Imports of the form `from xxx import yyy` # Imports of the form `from xxx import yyy`
imports += re.findall("^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE) imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
# Only keep the top-level module # Only keep the top-level module
imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")] imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
......
...@@ -219,7 +219,7 @@ def dtype_byte_size(dtype): ...@@ -219,7 +219,7 @@ def dtype_byte_size(dtype):
""" """
if dtype == torch.bool: if dtype == torch.bool:
return 1 / 8 return 1 / 8
bit_search = re.search("[^\d](\d+)$", str(dtype)) bit_search = re.search(r"[^\d](\d+)$", str(dtype))
if bit_search is None: if bit_search is None:
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.") raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
bit_size = int(bit_search.groups()[0]) bit_size = int(bit_search.groups()[0])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment