"...composable_kernel.git" did not exist on "f1ed4c5e369dad5f8ac8b78c99ce503f4bd130d6"
Unverified Commit 5c869bc7 authored by Kurisu's avatar Kurisu Committed by GitHub
Browse files

[Refactor] Reopen #794 Fix lower bug when buffer store is not guarded by any tile op (#817)

* [Refactor] Rewrite AddWrapper pass by ir_transform
PyStmtExprVisitor and PyStmtExprMutator seem buggy

* fix lint error
parent 8b005226
from tvm.tir import PyStmtExprMutator, PyStmtExprVisitor, BufferStore, For, AttrStmt, Block, ForKind, IterVar, Var, PrimFunc
from tvm.tir.functor import mutator, visitor
from tvm.tir import BufferStore, For, AttrStmt, ForKind, Var, PrimFunc
from tvm.tir.stmt_functor import ir_transform, post_order_visit
from tvm.tir.transform import prim_func_pass
@visitor
class FindVarUse(PyStmtExprVisitor):
def __init__(self):
self.used_var = set()
def visit_var_(self, op: Var):
self.used_var.add(op)
super().visit_var_(op)
def AddWrapperForSingleBufStore():
@mutator
class AddWrapperForSingleStoreMutator(PyStmtExprMutator):
'''
Add a dummy parallel for loop to wrap the single buffer store
Condition:
1. not inside a parallel for loop
2. no custom thread binding, i.e. threadIdx.x, blockIdx.x
'''
def pass_fn(func: PrimFunc, mod, ctx):
pfor = 0
thread_binding_var = set()
def __init__(self):
self.inside_pfor = 0
self.thread_binding_var = set()
def get_used_var(op):
used_var = set()
def visit_block_(self, op: Block):
super().visit_block_(op)
return op
def visit_fn(x):
if isinstance(x, Var):
used_var.add(x)
def visit_attr_stmt_(self, op: AttrStmt):
if op.attr_key == 'thread_extent':
iter_var: IterVar = op.node
self.thread_binding_var.add(iter_var.var)
super().visit_attr_stmt_(op)
return op
post_order_visit(op, visit_fn)
return used_var
def visit_for_(self, op: For):
pfor = op.kind == ForKind.PARALLEL or 'num_stages' in op.annotations
self.inside_pfor += pfor
super().visit_for_(op)
self.inside_pfor -= pfor
return op
def is_tile_op_for(op: For):
return op.kind == ForKind.PARALLEL or 'num_stages' in op.annotations
def visit_buffer_store_(self, op: BufferStore):
# This pass runs after LetInline, we find var inside the stmt
fv = FindVarUse()
fv.visit_stmt(op)
used_binding = fv.used_var.intersection(self.thread_binding_var)
if not self.inside_pfor and len(used_binding) == 0:
return For(Var("_", "int"), 0, 1, ForKind.PARALLEL, op)
else:
super().visit_buffer_store_(op)
return op
def pre_visit(stmt):
nonlocal pfor
if isinstance(stmt, AttrStmt) and stmt.attr_key == 'thread_extent':
thread_binding_var.add(stmt.node.var)
if isinstance(stmt, For):
pfor += is_tile_op_for(stmt)
def post_visit(stmt):
nonlocal pfor
if isinstance(stmt, For):
pfor -= is_tile_op_for(stmt)
if isinstance(stmt, BufferStore):
used_var = get_used_var(stmt)
used_binding = used_var.intersection(thread_binding_var)
if not pfor and len(used_binding) == 0:
return For(Var("_", "int"), 0, 1, ForKind.PARALLEL, stmt)
def AddWrapperForSingleBufStore():
new_body = ir_transform(func.body, pre_visit, post_visit)
def pass_fn(func: PrimFunc, mod, ctx):
mut = AddWrapperForSingleStoreMutator()
new_body = mut.visit_stmt(func.body)
return func.with_body(new_body)
return prim_func_pass(pass_fn, opt_level=0)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment