Commit 0181d721 authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

Remove debug print statement from OptimizeForTarget function and enhance...

Remove debug print statement from OptimizeForTarget function and enhance library loading mechanism in Cython adapter. Implemented file locking during cache access and added checks for library size before loading. Introduced temporary file handling for safer compilation of Cython JIT adapter. (#377)
parent 0997c333
...@@ -87,7 +87,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: ...@@ -87,7 +87,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.AnnotateDeviceRegions()(mod) mod = tilelang.transform.AnnotateDeviceRegions()(mod)
mod = tir.transform.SplitHostDevice()(mod) mod = tir.transform.SplitHostDevice()(mod)
mod = tir.transform.MergeSharedMemoryAllocations()(mod) mod = tir.transform.MergeSharedMemoryAllocations()(mod)
print(mod)
mod = tilelang.transform.MakePackedAPI()(mod) mod = tilelang.transform.MakePackedAPI()(mod)
mod = tir.transform.LowerDeviceKernelLaunch()(mod) mod = tir.transform.LowerDeviceKernelLaunch()(mod)
......
...@@ -60,13 +60,23 @@ def get_cached_lib(source_code: str) -> Tuple[Optional[ctypes.CDLL], Path]: ...@@ -60,13 +60,23 @@ def get_cached_lib(source_code: str) -> Tuple[Optional[ctypes.CDLL], Path]:
"""Try to load cached library or return None if not found.""" """Try to load cached library or return None if not found."""
code_hash = hashlib.sha256(source_code.encode()).hexdigest() code_hash = hashlib.sha256(source_code.encode()).hexdigest()
cache_path = get_cache_dir() / f"{code_hash}.so" cache_path = get_cache_dir() / f"{code_hash}.so"
if cache_path.exists(): lock_file = cache_path.with_suffix('.lock')
import fcntl
with open(lock_file, 'w') as lock:
fcntl.flock(lock.fileno(), fcntl.LOCK_EX)
try: try:
return ctypes.CDLL(str(cache_path)), cache_path if cache_path.exists():
except Exception as e: try:
logger.error(f"Failed to load cached library: {e}") if cache_path.stat().st_size > 1024:
return ctypes.CDLL(str(cache_path)), cache_path
else:
cache_path.unlink() # 删除不完整文件
except Exception as e:
logger.error(f"Failed to load cached library: {e}")
return None, cache_path
return None, cache_path return None, cache_path
return None, cache_path finally:
fcntl.flock(lock.fileno(), fcntl.LOCK_UN)
# read the cython_wrapper.pyx file # read the cython_wrapper.pyx file
...@@ -96,21 +106,30 @@ with open(cython_wrapper_path, "r") as f: ...@@ -96,21 +106,30 @@ with open(cython_wrapper_path, "r") as f:
if need_compile: if need_compile:
logger.info("Compiling cython jit adapter...") logger.info("Compiling cython jit adapter...")
with open(md5_path, "w") as f: temp_path = cache_dir / f"temp_{code_hash}.so"
f.write(code_hash)
# compile the cython_wrapper.pyx file into .cpp
cython = get_cython_compiler()
if cython is None:
raise Exception("Cython is not installed, please install it first.")
os.system(f"{cython} {cython_wrapper_path} --cplus -o {source_path}")
# compile the .cpp file into .so
python_include_path = sysconfig.get_path("include")
cc = get_cplus_compiler()
command = f"{cc} -shared -pthread -fPIC -fwrapv -O2 -Wall -fno-strict-aliasing -I{python_include_path} {source_path} -o {library_path}"
try: try:
with open(md5_path, "w") as f:
f.write(code_hash)
# 使用临时文件进行编译
cython = get_cython_compiler()
if cython is None:
raise Exception("Cython is not installed, please install it first.")
os.system(f"{cython} {cython_wrapper_path} --cplus -o {source_path}")
python_include_path = sysconfig.get_path("include")
cc = get_cplus_compiler()
command = f"{cc} -shared -pthread -fPIC -fwrapv -O2 -Wall -fno-strict-aliasing -I{python_include_path} {source_path} -o {temp_path}"
os.system(command) os.system(command)
# 原子替换操作
temp_path.rename(library_path)
except Exception as e: except Exception as e:
if temp_path.exists():
temp_path.unlink()
raise Exception(f"Failed to compile cython jit adapter: {e}") from e raise Exception(f"Failed to compile cython jit adapter: {e}") from e
finally:
if lock_file.exists():
lock_file.unlink()
# add the .so file to the sys.path # add the .so file to the sys.path
cache_dir_str = str(cache_dir) cache_dir_str = str(cache_dir)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment