"...composable_kernel_rocm.git" did not exist on "c3725a45b84b9597b51a0f3cf694dfc419d6bc4d"
Unverified Commit 596342c2 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Support for Windows in check_copies (#7316)

parent 89edf504
...@@ -40,7 +40,7 @@ def find_code_in_transformers(object_name): ...@@ -40,7 +40,7 @@ def find_code_in_transformers(object_name):
f"`object_name` should begin with the name of a module of transformers but got {object_name}." f"`object_name` should begin with the name of a module of transformers but got {object_name}."
) )
with open(os.path.join(TRANSFORMERS_PATH, f"{module}.py"), "r") as f: with open(os.path.join(TRANSFORMERS_PATH, f"{module}.py"), "r", encoding="utf-8") as f:
lines = f.readlines() lines = f.readlines()
# Now let's find the class / func in the code! # Now let's find the class / func in the code!
...@@ -80,10 +80,10 @@ def blackify(code): ...@@ -80,10 +80,10 @@ def blackify(code):
code = f"class Bla:\n{code}" code = f"class Bla:\n{code}"
with tempfile.TemporaryDirectory() as d: with tempfile.TemporaryDirectory() as d:
fname = os.path.join(d, "tmp.py") fname = os.path.join(d, "tmp.py")
with open(fname, "w") as f: with open(fname, "w", encoding="utf-8") as f:
f.write(code) f.write(code)
os.system(f"black -q --line-length 119 --target-version py35 {fname}") os.system(f"black -q --line-length 119 --target-version py35 {fname}")
with open(fname, "r") as f: with open(fname, "r", encoding="utf-8") as f:
result = f.read() result = f.read()
return result[len("class Bla:\n") :] if has_indent else result return result[len("class Bla:\n") :] if has_indent else result
...@@ -94,7 +94,7 @@ def is_copy_consistent(filename, overwrite=False): ...@@ -94,7 +94,7 @@ def is_copy_consistent(filename, overwrite=False):
Return the differences or overwrites the content depending on `overwrite`. Return the differences or overwrites the content depending on `overwrite`.
""" """
with open(filename) as f: with open(filename, "r", encoding="utf-8") as f:
lines = f.readlines() lines = f.readlines()
found_diff = False found_diff = False
line_index = 0 line_index = 0
...@@ -152,7 +152,7 @@ def is_copy_consistent(filename, overwrite=False): ...@@ -152,7 +152,7 @@ def is_copy_consistent(filename, overwrite=False):
if overwrite and found_diff: if overwrite and found_diff:
# Warn the user a file has been modified. # Warn the user a file has been modified.
print(f"Detected changes, rewriting {filename}.") print(f"Detected changes, rewriting {filename}.")
with open(filename, "w") as f: with open(filename, "w", encoding="utf-8") as f:
f.writelines(lines) f.writelines(lines)
return not found_diff return not found_diff
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment