Commit 7aa34977, authored by Lei Wang and committed by LeiWang1999
Browse files

[Refactor] Implement thread-local storage for FrameStack in frame.py and kernel.py (#352)

* [Refactor] Implement thread-local storage for FrameStack in frame.py and kernel.py

- Replaced global FrameStack instances with thread-local storage to prevent cross-thread interference.
- Introduced `_get_let_stack` and `_get_current_stack` functions to manage thread-local FrameStack instances in LetFrame and KernelLaunchFrame classes.
- Updated all relevant methods to utilize the new thread-local stacks, ensuring thread safety in frame management.

* lint fix
parent 8dc1d7df
......@@ -7,6 +7,7 @@ from tvm import DataType
from tvm.script.ir_builder.tir.frame import TIRFrame
from collections import deque
from typing import Optional
import threading
class FrameStack:
......@@ -95,8 +96,15 @@ class FrameStack:
return bool(self._stack)
# Global stack for LetFrame instances
_let_frame_stack = FrameStack()
# Keep the stack in thread-local storage so that each thread gets its own
# instance; this avoids cross-thread interference.
_local_let = threading.local()
def _get_let_stack() -> FrameStack:
    """Return the LetFrame stack belonging to the calling thread.

    The stack is looked up on the thread-local ``_local_let`` object. If the
    current thread has not stored one yet, a new ``FrameStack`` is created,
    cached under ``let_frame_stack`` and returned.

    Returns:
        FrameStack: The stack stored on ``_local_let`` for this thread.
    """
    try:
        return _local_let.let_frame_stack
    except AttributeError:
        # First access from this thread: create and cache a fresh stack.
        fresh_stack = FrameStack()
        _local_let.let_frame_stack = fresh_stack
        return fresh_stack
@_register_object("script.ir_builder.tir.LetFrame")
......@@ -125,7 +133,7 @@ class LetFrame(TIRFrame):
self.value = BufferRegion(self.value.buffer,
[Range(x.base, x.lanes) for x in indices])
_let_frame_stack.push(self)
_get_let_stack().push(self)
return self.var
def __exit__(self, ptype, value, trace):
......@@ -136,8 +144,9 @@ class LetFrame(TIRFrame):
value: Exception value if an exception occurred
trace: Exception traceback if an exception occurred
"""
if _let_frame_stack.top() is self:
_let_frame_stack.pop()
stack = _get_let_stack()
if stack.top() is self:
stack.pop()
super().__exit__(ptype, value, trace)
@classmethod
......@@ -150,7 +159,7 @@ class LetFrame(TIRFrame):
Raises:
IndexError: If there are no active let frames
"""
return _let_frame_stack.top()
return _get_let_stack().top()
@staticmethod
def get_value(var: Var):
......@@ -162,7 +171,7 @@ class LetFrame(TIRFrame):
Returns:
The value bound to the variable, or None if not found
"""
return _let_frame_stack.get_value(var)
return _get_let_stack().get_value(var)
@staticmethod
def has_value(var: Var) -> bool:
......@@ -174,7 +183,7 @@ class LetFrame(TIRFrame):
Returns:
bool: True if the variable has a binding, False otherwise
"""
return _let_frame_stack.has_value(var)
return _get_let_stack().has_value(var)
def has_let_value(var: Var) -> bool:
......@@ -186,7 +195,7 @@ def has_let_value(var: Var) -> bool:
Returns:
bool: True if the variable has a binding, False otherwise
"""
return _let_frame_stack.has_value(var)
return _get_let_stack().has_value(var)
def get_let_value(var: Var) -> Optional[PrimExpr]:
......@@ -198,4 +207,4 @@ def get_let_value(var: Var) -> Optional[PrimExpr]:
Returns:
Optional[PrimExpr]: The bound value if found, None otherwise
"""
return _let_frame_stack.get_value(var)
return _get_let_stack().get_value(var)
......@@ -7,6 +7,7 @@ from tvm.tir import Var
from tvm.script.ir_builder.tir.frame import TIRFrame, BlockFrame
from tvm._ffi import register_object
from tilelang import _ffi_api
import threading
class FrameStack:
......@@ -40,6 +41,10 @@ class FrameStack:
return self._stack[-1]
raise IndexError(f"{self.__class__.__name__} is empty")
def size(self):
    """Return how many items are currently held in the stack."""
    item_count = len(self._stack)
    return item_count
def __len__(self):
    """Return the item count of the stack, enabling ``len(stack)``."""
    item_count = len(self._stack)
    return item_count
......@@ -52,8 +57,15 @@ class FrameStack:
return bool(self._stack)
# Use our new FrameStack instead of a plain list or deque
_kernel_launch_frame_stack = FrameStack()
# Keep the stack in thread-local storage so that each thread gets its own
# instance; this avoids cross-thread interference.
_local = threading.local()
def _get_current_stack() -> FrameStack:
    """Return the KernelLaunchFrame stack belonging to the calling thread.

    The stack is looked up on the thread-local ``_local`` object. If the
    current thread has not stored one yet, a new ``FrameStack`` is created,
    cached under ``kernel_launch_frame_stack`` and returned.

    Returns:
        FrameStack: The stack stored on ``_local`` for this thread.
    """
    try:
        return _local.kernel_launch_frame_stack
    except AttributeError:
        # First access from this thread: create and cache a fresh stack.
        fresh_stack = FrameStack()
        _local.kernel_launch_frame_stack = fresh_stack
        return fresh_stack
@register_object("tl.KernelLaunchFrame")
......@@ -70,7 +82,7 @@ class KernelLaunchFrame(TIRFrame):
block dimension), or a list of Vars otherwise.
"""
super().__enter__()
_kernel_launch_frame_stack.push(self)
_get_current_stack().push(self)
# If we have exactly 5 frames, return the single iter_var.var.
if len(self.frames) == 5:
return self.frames[0].iter_var.var
......@@ -94,9 +106,9 @@ class KernelLaunchFrame(TIRFrame):
Exits the KernelLaunchFrame scope and pops this frame from the stack,
but only if it's indeed the topmost frame.
"""
# Check if this frame is the current top before popping.
if _kernel_launch_frame_stack.top() is self:
_kernel_launch_frame_stack.pop()
stack = _get_current_stack()
if stack.top() is self:
stack.pop()
super().__exit__(ptype, value, trace)
@classmethod
......@@ -105,7 +117,8 @@ class KernelLaunchFrame(TIRFrame):
Returns the topmost (current) KernelLaunchFrame from the stack if it exists,
or None if the stack is empty.
"""
return _kernel_launch_frame_stack.top()
stack = _get_current_stack()
return stack.top() if stack else None
def get_block_extent(self, dim: int) -> int:
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment