Unverified commit 5c869bc7, authored by Kurisu and committed by GitHub
Browse files

[Refactor] Reopen #794: Fix lowering bug when a buffer store is not guarded by any tile op (#817)

* [Refactor] Rewrite the AddWrapper pass using ir_transform
PyStmtExprVisitor and PyStmtExprMutator seem buggy

* fix lint error
parent 8b005226
from tvm.tir import PyStmtExprMutator, PyStmtExprVisitor, BufferStore, For, AttrStmt, Block, ForKind, IterVar, Var, PrimFunc from tvm.tir import BufferStore, For, AttrStmt, ForKind, Var, PrimFunc
from tvm.tir.functor import mutator, visitor from tvm.tir.stmt_functor import ir_transform, post_order_visit
from tvm.tir.transform import prim_func_pass from tvm.tir.transform import prim_func_pass
@visitor def AddWrapperForSingleBufStore():
class FindVarUse(PyStmtExprVisitor):
def __init__(self):
self.used_var = set()
def visit_var_(self, op: Var):
self.used_var.add(op)
super().visit_var_(op)
@mutator def pass_fn(func: PrimFunc, mod, ctx):
class AddWrapperForSingleStoreMutator(PyStmtExprMutator): pfor = 0
''' thread_binding_var = set()
Add a dummy parallel for loop to wrap the single buffer store
Condition:
1. not inside a parallel for loop
2. no custom thread binding, i.e. threadIdx.x, blockIdx.x
'''
def __init__(self): def get_used_var(op):
self.inside_pfor = 0 used_var = set()
self.thread_binding_var = set()
def visit_block_(self, op: Block): def visit_fn(x):
super().visit_block_(op) if isinstance(x, Var):
return op used_var.add(x)
def visit_attr_stmt_(self, op: AttrStmt): post_order_visit(op, visit_fn)
if op.attr_key == 'thread_extent': return used_var
iter_var: IterVar = op.node
self.thread_binding_var.add(iter_var.var)
super().visit_attr_stmt_(op)
return op
def visit_for_(self, op: For): def is_tile_op_for(op: For):
pfor = op.kind == ForKind.PARALLEL or 'num_stages' in op.annotations return op.kind == ForKind.PARALLEL or 'num_stages' in op.annotations
self.inside_pfor += pfor
super().visit_for_(op)
self.inside_pfor -= pfor
return op
def visit_buffer_store_(self, op: BufferStore): def pre_visit(stmt):
# This pass runs after LetInline, we find var inside the stmt nonlocal pfor
fv = FindVarUse() if isinstance(stmt, AttrStmt) and stmt.attr_key == 'thread_extent':
fv.visit_stmt(op) thread_binding_var.add(stmt.node.var)
used_binding = fv.used_var.intersection(self.thread_binding_var) if isinstance(stmt, For):
if not self.inside_pfor and len(used_binding) == 0: pfor += is_tile_op_for(stmt)
return For(Var("_", "int"), 0, 1, ForKind.PARALLEL, op)
else:
super().visit_buffer_store_(op)
return op
def post_visit(stmt):
nonlocal pfor
if isinstance(stmt, For):
pfor -= is_tile_op_for(stmt)
if isinstance(stmt, BufferStore):
used_var = get_used_var(stmt)
used_binding = used_var.intersection(thread_binding_var)
if not pfor and len(used_binding) == 0:
return For(Var("_", "int"), 0, 1, ForKind.PARALLEL, stmt)
def AddWrapperForSingleBufStore(): new_body = ir_transform(func.body, pre_visit, post_visit)
def pass_fn(func: PrimFunc, mod, ctx):
mut = AddWrapperForSingleStoreMutator()
new_body = mut.visit_stmt(func.body)
return func.with_body(new_body) return func.with_body(new_body)
return prim_func_pass(pass_fn, opt_level=0) return prim_func_pass(pass_fn, opt_level=0)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment