Commit 7aa34977 authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Refactor] Implement thread-local storage for FrameStack in frame.py and kernel.py (#352)

* [Refactor] Implement thread-local storage for FrameStack in frame.py and kernel.py

- Replaced global FrameStack instances with thread-local storage to prevent cross-thread interference.
- Introduced `_get_let_stack` and `_get_current_stack` functions to manage thread-local FrameStack instances in LetFrame and KernelLaunchFrame classes.
- Updated all relevant methods to utilize the new thread-local stacks, ensuring thread safety in frame management.

* lint fix
parent 8dc1d7df
...@@ -7,6 +7,7 @@ from tvm import DataType ...@@ -7,6 +7,7 @@ from tvm import DataType
from tvm.script.ir_builder.tir.frame import TIRFrame from tvm.script.ir_builder.tir.frame import TIRFrame
from collections import deque from collections import deque
from typing import Optional from typing import Optional
import threading
class FrameStack: class FrameStack:
...@@ -95,8 +96,15 @@ class FrameStack: ...@@ -95,8 +96,15 @@ class FrameStack:
return bool(self._stack) return bool(self._stack)
# Global stack for LetFrame instances # Use thread local to store the stack
_let_frame_stack = FrameStack() # This is to avoid the cross-thread interference
_local_let = threading.local()
def _get_let_stack() -> FrameStack:
if not hasattr(_local_let, "let_frame_stack"):
_local_let.let_frame_stack = FrameStack()
return _local_let.let_frame_stack
@_register_object("script.ir_builder.tir.LetFrame") @_register_object("script.ir_builder.tir.LetFrame")
...@@ -125,7 +133,7 @@ class LetFrame(TIRFrame): ...@@ -125,7 +133,7 @@ class LetFrame(TIRFrame):
self.value = BufferRegion(self.value.buffer, self.value = BufferRegion(self.value.buffer,
[Range(x.base, x.lanes) for x in indices]) [Range(x.base, x.lanes) for x in indices])
_let_frame_stack.push(self) _get_let_stack().push(self)
return self.var return self.var
def __exit__(self, ptype, value, trace): def __exit__(self, ptype, value, trace):
...@@ -136,8 +144,9 @@ class LetFrame(TIRFrame): ...@@ -136,8 +144,9 @@ class LetFrame(TIRFrame):
value: Exception value if an exception occurred value: Exception value if an exception occurred
trace: Exception traceback if an exception occurred trace: Exception traceback if an exception occurred
""" """
if _let_frame_stack.top() is self: stack = _get_let_stack()
_let_frame_stack.pop() if stack.top() is self:
stack.pop()
super().__exit__(ptype, value, trace) super().__exit__(ptype, value, trace)
@classmethod @classmethod
...@@ -150,7 +159,7 @@ class LetFrame(TIRFrame): ...@@ -150,7 +159,7 @@ class LetFrame(TIRFrame):
Raises: Raises:
IndexError: If there are no active let frames IndexError: If there are no active let frames
""" """
return _let_frame_stack.top() return _get_let_stack().top()
@staticmethod @staticmethod
def get_value(var: Var): def get_value(var: Var):
...@@ -162,7 +171,7 @@ class LetFrame(TIRFrame): ...@@ -162,7 +171,7 @@ class LetFrame(TIRFrame):
Returns: Returns:
The value bound to the variable, or None if not found The value bound to the variable, or None if not found
""" """
return _let_frame_stack.get_value(var) return _get_let_stack().get_value(var)
@staticmethod @staticmethod
def has_value(var: Var) -> bool: def has_value(var: Var) -> bool:
...@@ -174,7 +183,7 @@ class LetFrame(TIRFrame): ...@@ -174,7 +183,7 @@ class LetFrame(TIRFrame):
Returns: Returns:
bool: True if the variable has a binding, False otherwise bool: True if the variable has a binding, False otherwise
""" """
return _let_frame_stack.has_value(var) return _get_let_stack().has_value(var)
def has_let_value(var: Var) -> bool: def has_let_value(var: Var) -> bool:
...@@ -186,7 +195,7 @@ def has_let_value(var: Var) -> bool: ...@@ -186,7 +195,7 @@ def has_let_value(var: Var) -> bool:
Returns: Returns:
bool: True if the variable has a binding, False otherwise bool: True if the variable has a binding, False otherwise
""" """
return _let_frame_stack.has_value(var) return _get_let_stack().has_value(var)
def get_let_value(var: Var) -> Optional[PrimExpr]: def get_let_value(var: Var) -> Optional[PrimExpr]:
...@@ -198,4 +207,4 @@ def get_let_value(var: Var) -> Optional[PrimExpr]: ...@@ -198,4 +207,4 @@ def get_let_value(var: Var) -> Optional[PrimExpr]:
Returns: Returns:
Optional[PrimExpr]: The bound value if found, None otherwise Optional[PrimExpr]: The bound value if found, None otherwise
""" """
return _let_frame_stack.get_value(var) return _get_let_stack().get_value(var)
...@@ -7,6 +7,7 @@ from tvm.tir import Var ...@@ -7,6 +7,7 @@ from tvm.tir import Var
from tvm.script.ir_builder.tir.frame import TIRFrame, BlockFrame from tvm.script.ir_builder.tir.frame import TIRFrame, BlockFrame
from tvm._ffi import register_object from tvm._ffi import register_object
from tilelang import _ffi_api from tilelang import _ffi_api
import threading
class FrameStack: class FrameStack:
...@@ -40,6 +41,10 @@ class FrameStack: ...@@ -40,6 +41,10 @@ class FrameStack:
return self._stack[-1] return self._stack[-1]
raise IndexError(f"{self.__class__.__name__} is empty") raise IndexError(f"{self.__class__.__name__} is empty")
def size(self):
"""Returns the number of items in the stack."""
return len(self._stack)
def __len__(self): def __len__(self):
"""Returns the number of items in the stack.""" """Returns the number of items in the stack."""
return len(self._stack) return len(self._stack)
...@@ -52,8 +57,15 @@ class FrameStack: ...@@ -52,8 +57,15 @@ class FrameStack:
return bool(self._stack) return bool(self._stack)
# Use our new FrameStack instead of a plain list or deque # Use thread local to store the stack
_kernel_launch_frame_stack = FrameStack() # This is to avoid the cross-thread interference
_local = threading.local()
def _get_current_stack() -> FrameStack:
if not hasattr(_local, "kernel_launch_frame_stack"):
_local.kernel_launch_frame_stack = FrameStack()
return _local.kernel_launch_frame_stack
@register_object("tl.KernelLaunchFrame") @register_object("tl.KernelLaunchFrame")
...@@ -70,7 +82,7 @@ class KernelLaunchFrame(TIRFrame): ...@@ -70,7 +82,7 @@ class KernelLaunchFrame(TIRFrame):
block dimension), or a list of Vars otherwise. block dimension), or a list of Vars otherwise.
""" """
super().__enter__() super().__enter__()
_kernel_launch_frame_stack.push(self) _get_current_stack().push(self)
# If we have exactly 5 frames, return the single iter_var.var. # If we have exactly 5 frames, return the single iter_var.var.
if len(self.frames) == 5: if len(self.frames) == 5:
return self.frames[0].iter_var.var return self.frames[0].iter_var.var
...@@ -94,9 +106,9 @@ class KernelLaunchFrame(TIRFrame): ...@@ -94,9 +106,9 @@ class KernelLaunchFrame(TIRFrame):
Exits the KernelLaunchFrame scope and pops this frame from the stack, Exits the KernelLaunchFrame scope and pops this frame from the stack,
but only if it's indeed the topmost frame. but only if it's indeed the topmost frame.
""" """
# Check if this frame is the current top before popping. stack = _get_current_stack()
if _kernel_launch_frame_stack.top() is self: if stack.top() is self:
_kernel_launch_frame_stack.pop() stack.pop()
super().__exit__(ptype, value, trace) super().__exit__(ptype, value, trace)
@classmethod @classmethod
...@@ -105,7 +117,8 @@ class KernelLaunchFrame(TIRFrame): ...@@ -105,7 +117,8 @@ class KernelLaunchFrame(TIRFrame):
Returns the topmost (current) KernelLaunchFrame from the stack if it exists, Returns the topmost (current) KernelLaunchFrame from the stack if it exists,
or None if the stack is empty. or None if the stack is empty.
""" """
return _kernel_launch_frame_stack.top() stack = _get_current_stack()
return stack.top() if stack else None
def get_block_extent(self, dim: int) -> int: def get_block_extent(self, dim: int) -> int:
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment