Unverified Commit cd681e63 authored by Kuris's avatar Kuris Committed by GitHub
Browse files

[Fix] Fix memory leak bug (#1281)

* add typing stub for tir.ir

* remove idents

* minor update

* [Refactor] add numpy conversion for dtype

* fix lint error

* remove unused np.float_ in dtype conversion

* fix type in np.int_

* fix typo

* minor fix

* remove debug files

* fix memory leak bug

* fix lint error

* add comments

* fix lint error

* remove duplicated, because tilelang doesn't dependent deprecated
parent 4c8b9ada
import tilelang.language as T
import tilelang.testing
import torch
import weakref
import gc
def test_tilelang_capture():
    """Check that building a jitted kernel does not keep a local tensor alive.

    A tensor is created in this function, a weak reference to it is taken,
    and a dummy kernel is built through ``tilelang.jit``. Once the only
    strong reference held here is deleted and a garbage collection has been
    forced, the weak reference is expected to be dead.
    """

    @tilelang.jit(
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        },)
    def get_dummy_kernel():

        @T.prim_func
        def dummy_kernel(a: T.Tensor[(1,), T.float32],):
            with T.Kernel(1) as _:
                a[0] = 1

        return dummy_kernel

    # The tensor whose lifetime is being observed, plus a weak handle on it.
    a = torch.randn(1, 1024)
    tensor_ref = weakref.ref(a)

    # Keep the built kernel referenced for the rest of the test, so anything
    # it holds on to stays reachable while the assertion below runs.
    _kernel = get_dummy_kernel()

    # Drop the only strong reference owned by this function, then force a
    # collection (with cache flushes on either side of it).
    del a
    torch.cuda.empty_cache()
    gc.collect()
    torch.cuda.empty_cache()

    survivor = tensor_ref()
    assert survivor is None, "A is not garbage collected"
    # Debugging hint: objgraph can show what still refers to the tensor.
    # if survivor is not None:
    #     objgraph.show_backrefs([survivor], max_depth=5)
# When this file is executed directly, hand control to tilelang.testing.main().
if __name__ == '__main__':
    tilelang.testing.main()
...@@ -248,8 +248,9 @@ class BaseBuilder: ...@@ -248,8 +248,9 @@ class BaseBuilder:
class DSLMutator(ast.NodeTransformer): class DSLMutator(ast.NodeTransformer):
def __init__(self): def __init__(self, closure_names: list[str]):
self.tmp_counter = 0 self.tmp_counter = 0
self.closure_names = closure_names
def get_tmp(self) -> str: def get_tmp(self) -> str:
name = f"__{self.tmp_counter}" name = f"__{self.tmp_counter}"
...@@ -494,9 +495,11 @@ class DSLMutator(ast.NodeTransformer): ...@@ -494,9 +495,11 @@ class DSLMutator(ast.NodeTransformer):
node.body = stmts + node.body node.body = stmts + node.body
node.decorator_list.clear() node.decorator_list.clear()
return quote1( return quote1(
f"def {node.name}(__tb):\n" f"def make_closure({', '.join(self.closure_names)}):\n"
" range = __tb.override('range')\n" f" def {node.name}(__tb):\n"
" pass\n" " range = __tb.override('range')\n"
" pass\n"
f" return {node.name}\n"
f" return {node.name}", f" return {node.name}",
passes=[node], passes=[node],
) )
...@@ -595,7 +598,29 @@ def mutate(func: Callable[_P, _T]) -> IRGenerator[_P, _T]: ...@@ -595,7 +598,29 @@ def mutate(func: Callable[_P, _T]) -> IRGenerator[_P, _T]:
tree = utils.get_ast(func) tree = utils.get_ast(func)
filename = inspect.getsourcefile(func) or inspect.getfile(func) filename = inspect.getsourcefile(func) or inspect.getfile(func)
tree = DSLMutator().visit(tree) nonlocals = utils.get_func_nonlocals(func)
fn = utils.get_compiled_object(tree, func.__name__, filename,
utils.inspect_function_capture(func)) # DSLMutator generates a function named `make_closure`
# it accepts all nonlocal names as parameters, and returns the mutated function
# this is because we must separate the closure namespace from the global namespace
# if we directly inject closure variables into the global namespace,
# we have to build a new `globals` dict, and that dict holds references to everything in the original globalns
# which causes a memory leak, because the objects in the original globalns cannot be freed
# ```py
# a = 123
# def foo():
# x = foo.__globals__ # OK, globals are maintained by python
# x = {**foo.__globals__, } # Not OK: globals are copied, and the original globals cannot be freed
# def bar(): x
# return bar
# ```
tree = DSLMutator(nonlocals.keys()).visit(tree)
make_closure = utils.get_compiled_object(
tree,
'make_closure',
filename,
func.__globals__, # use the original globalns
)
fn = make_closure(**nonlocals)
return IRGenerator(gen=fn, source=ast.unparse(tree)) return IRGenerator(gen=fn, source=ast.unparse(tree))
...@@ -18,6 +18,7 @@ try: ...@@ -18,6 +18,7 @@ try:
except ImportError: # Python < 3.11 for Self, < 3.10 for ParamSpec except ImportError: # Python < 3.11 for Self, < 3.10 for ParamSpec
from typing_extensions import ParamSpec, Self from typing_extensions import ParamSpec, Self
from . import dtypes as dt from . import dtypes as dt
from . import utils
import threading import threading
import logging import logging
...@@ -593,22 +594,27 @@ def get_type_hints(func): ...@@ -593,22 +594,27 @@ def get_type_hints(func):
# Build eval namespaces from function globals plus captured closure variables # Build eval namespaces from function globals plus captured closure variables
# This lets annotations reference symbols like `n`, `h`, or dtype vars # This lets annotations reference symbols like `n`, `h`, or dtype vars
# defined in the outer scope of a nested function. # defined in the outer scope of a nested function.
globalns = dict(getattr(func, '__globals__', {})) globalns = func.__globals__
localns = dict(globalns) # Here we add nonlocals into localns, to capture the parameters declared in the parent function
try: # ```py
freevars = getattr(func.__code__, 'co_freevars', ()) # def foo():
cells = getattr(func, '__closure__', ()) or () # n = 128 # n is nonlocal
closure_bindings = { # def bar(
name: cell.cell_contents for name, cell in zip(freevars, cells) if name not in localns # A: T.Tensor(n, T.float32) # we add nonlocal in its eval context
} # ):
if closure_bindings: # for i in range(n): ...
localns.update(closure_bindings) # ```
# Also update globals so ForwardRef eval sees them uniformly #
globalns.update(closure_bindings) # This is incomplete and buggy
except Exception: # the only buggy scenario is when the function body doesn't use the parameters
# Be permissive: absence or access issues with closure shouldn't crash # but such define-no-use scenario is very rare in writing kernels
pass #
# ```py
# def foo():
# n = 128
# def bar(A: T.Tensor((n,), T.float32)):
# ... # empty function, do not use `n`
localns = utils.get_func_nonlocals(func)
for name, value in annot.items(): for name, value in annot.items():
if name == 'return': if name == 'return':
continue continue
...@@ -618,8 +624,10 @@ def get_type_hints(func): ...@@ -618,8 +624,10 @@ def get_type_hints(func):
if value is None: if value is None:
value = type(None) value = type(None)
if isinstance(value, str): if isinstance(value, str):
# Handle simple dtype aliases like T.float32 appearing as strings # if the annotation is a string, it can be: (i) a T.float32-like annotation, (ii) a ForwardRef object
# Evaluate directly only when it matches known dtypes # typing doesn't handle (i), it will try to interpret T.float32
# typing sees T.float32 as str('float32'); there is no object named `float32`, so it raises a NameError
# here we manually interpret it to return T.float32 object
try: try:
_, v = value.split('.', maxsplit=1) _, v = value.split('.', maxsplit=1)
except ValueError: except ValueError:
...@@ -631,7 +639,9 @@ def get_type_hints(func): ...@@ -631,7 +639,9 @@ def get_type_hints(func):
except Exception: except Exception:
pass pass
value = ForwardRef(value, is_argument=True, is_class=False) value = ForwardRef(value, is_argument=True, is_class=False)
hints[name] = _eval_type(value, globalns=globalns, localns=localns) hints[name] = _eval_type(value, globalns=globalns, localns=localns)
else:
hints[name] = value
return hints return hints
......
...@@ -53,26 +53,6 @@ def get_func_nonlocals(func): ...@@ -53,26 +53,6 @@ def get_func_nonlocals(func):
return nonlocal_vars return nonlocal_vars
def inspect_function_capture(func: Callable) -> dict[str, Any]:
    """Collect the global and non-local variables visible to a function.

    Parameters
    ----------
    func : Callable
        The function to inspect.

    Returns
    -------
    res : Dict[str, Any]
        A newly built mapping of variable names to values. It starts as a
        copy of ``func.__globals__`` and is then overlaid with the result of
        ``get_func_nonlocals(func)``, so a non-local takes precedence over a
        global of the same name.
    """
    # Copy the globals so the function's own namespace dict is not mutated.
    merged = dict(func.__globals__)  # type: ignore
    # Non-locals are applied last and therefore win on name clashes.
    merged.update(get_func_nonlocals(func))
    return merged
def get_ast(func: Callable): def get_ast(func: Callable):
_, start = inspect.getsourcelines(func) _, start = inspect.getsourcelines(func)
filename = inspect.getsourcefile(func) or inspect.getfile(func) filename = inspect.getsourcefile(func) or inspect.getfile(func)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment