Commit a1a3e2e6 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Dev][Language] Separate Base AST with Sugar Syntax (#9)

* update readme

* center and resize benchmark figures

* replace svg figures with PNG

* replace svg figures with png

* remove svg, and add png

* Add Base AST Components

* lint fix
parent d191af8d
......@@ -14,7 +14,7 @@ column_limit = 100
indent_width = 4
[tool.codespell]
ignore-words-list = "nd, te, ist, LOD, offen"
ignore-words-list = "nd, te, ist, LOD, offen, NotIn"
skip = [
"build",
"3rdparty",
......
......@@ -3,8 +3,8 @@
"""The language interface for tl programs."""
from typing import Optional
from tvm.script import tir as T
from tvm.script.parser.tir import *
from .parser import *
# from tvm.script.parser.tir import *
from tilelang.layout import Layout, Fragment # noqa: F401
from .parallel import Parallel # noqa: F401
from .pipeline import Pipelined # noqa: F401
......@@ -36,16 +36,16 @@ def use_swizzle(panel_size: int, order: str = "row", enable: bool = True):
# The panel size is the number of threads in a warp
# Use to improve the L2 Cache Locality
device_func = ("rasterization2DRow" if order == "row" else "rasterization2DColumn")
return T.attr(None, "threadblock_swizzle_pattern",
f"tl::{device_func}<{panel_size}>") if enable else None
return attr(None, "threadblock_swizzle_pattern",
f"tl::{device_func}<{panel_size}>") if enable else None
def annotate_layout(layout_map):
# layout_map is a dictionary of buffer to layout
layout_map = {buffer.data: layout for buffer, layout in layout_map.items()}
return T.block_attr({"layout_map": layout_map})
return block_attr({"layout_map": layout_map})
def import_source(source: Optional[str] = None):
# source is the source code to be imported
return T.block_attr({"pragma_import_c": source}) if source is not None else None
return block_attr({"pragma_import_c": source}) if source is not None else None
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# This file is modified from the original version,
# which is part of the TVM project (https://tvm.apache.org/).
"""Package tvm.script.ir_builder.tir"""
from .ir import * # noqa: F401
from .ir import boolean as bool # noqa: F401
from .ir import buffer as Buffer # noqa: F401
from tvm.script.ir_builder.tir import frame # noqa: F401
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# This file is modified from the original version,
# which is part of the TVM project (https://tvm.apache.org/).
"""FFI APIs"""
import tvm._ffi
tvm._ffi._init_api("script.ir_builder.tir", __name__) # pylint: disable=protected-access
This diff is collapsed.
......@@ -78,7 +78,8 @@ class KernelLaunchFrame(TIRFrame):
return self.frames[0].iter_var.var
last_block_frame = self.frames[-1]
assert isinstance(last_block_frame, BlockFrame), "Last frame must be a block frame"
assert isinstance(last_block_frame,
BlockFrame), f"Last frame must be a block frame, got {last_block_frame}"
maybe_cpu = last_block_frame.annotations.get("tilelang.is_cpu_kernel_frame", False)
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# This file is modified from the original version,
# which is part of the TVM project (https://tvm.apache.org/).
# ruff: noqa
"""The tir parser"""
from typing import TYPE_CHECKING
from ..ast import * # pylint: disable=redefined-builtin
from ..ast import ir as _tir
from . import operation as _operation
from . import parser as _parser
from .entry import Buffer, Ptr
if TYPE_CHECKING:
# pylint: disable=invalid-name
# Define prim_func and make it type check as static method
# so most tvmscript won't trigger pylint error here.
prim_func = staticmethod
else:
from .entry import macro, prim_func
__all__ = _tir.__all__ + ["Buffer", "Ptr", "bool", "prim_func", "macro"]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# This file is modified from the original version,
# which is part of the TVM project (https://tvm.apache.org/).
# ruff: noqa
"""The entry point of TVM parser for tir."""
import inspect
from typing import Callable, Optional, Union
from tvm.ir.base import deprecated
from tvm.tir import Buffer, PrimFunc
from ..ast import buffer, ptr
from tvm.script.parser._core import parse, scan_macro, utils
from tvm.script.parser.core.parser import Parser, ScriptMacro
def prim_func(func: Optional[Callable] = None,
private: bool = False,
check_well_formed=True) -> Union[PrimFunc, Callable]:
"""The parsing method for tir prim func, by using `@prim_func` as decorator.
Parameters
----------
func : Callable
The function to be parsed as prim func.
(Listed as optional to allow the decorator to be used
without arguments, like `@prim_func`,
or with an argument, `@prim_func(private=True)`)
private : bool, optional
Whether the function should be treated as private.
A private function has no global symbol attribute;
if the function is not private, it will have a global symbol
matching the function name.
Returns
-------
res : Union[PrimFunc, Callable]
The parsed tir prim func.
"""
# pylint: disable=unused-argument
# (private will be used in the parser, but not immediately)
# need to capture this var outside the wrapper because the wrapper
# adds to the stack
outer_stack = inspect.stack()
def decorator_wrapper(func):
if not inspect.isfunction(func):
raise TypeError(f"Expect a function, but got: {func}")
if utils.is_defined_in_class(outer_stack, func):
return func
f = parse(func, utils.inspect_function_capture(func), check_well_formed=check_well_formed)
setattr(f, "__name__", func.__name__)
return f
if func is not None:
# no optional args given => use wrapper directly
return decorator_wrapper(func)
else:
# if there is an optional arg given, return a new decorator
# that will then be invoked
setattr(decorator_wrapper, "dispatch_token", "tir")
return decorator_wrapper
setattr(prim_func, "dispatch_token", "tir")
# Semantics of TIR macros:
# - Function that is decorated with @T.macro can have any parameters that
# follow Python syntax, i.e. positional, keyword, etc. Type annotations
# are not required, but are allowed.
# - Macro use follows the same syntax as a function call.
# For `macro_name(arg1, arg2, arg3, ...)`, the values are substituted into
# the body of the macro, and the body with the substituted values is then
# inserted at the point where the call to the macro is located.
class TIRMacro(ScriptMacro):
"""Specialization of the ScriptMacro class for TIR."""
def parse_macro(self, parser: Parser) -> None:
macro_def = self.get_macro_def()
parser.visit_body(macro_def.body)
def macro(*args, hygienic: bool = True) -> Callable:
"""Decorator for macro definitions.
Parameters
----------
hygienic: bool
Specifies whether the macro is hygienic or not.
A macro is hygienic if all symbols used in the macro's body are resolved
to values from the location of the macro definition. A non-hygienic macro
will have its symbols resolved to values at the time of the macro's use.
Example:
```
import tvm
from tvm.script import tir as T
x_value = 128
@T.macro(hygienic=True)
def static_capture(A, B):
B[()] = A[x_value] ### x_value binds to 128
@T.macro(hygienic=False)
def dynamic_capture(A, B):
B[()] = A[x_value] ### x_value will bind at the time of use
@T.prim_func
def use1(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None:
for x_value in T.serial(10):
static_capture(A, B) ### Produces B[()] = A[128]
@T.prim_func
def use2(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None:
for x_value in T.serial(10):
dynamic_capture(A, B) ### Produces B[()] = A[x_value]
```
"""
def _decorator(func: Callable) -> TIRMacro:
source, closure_vars = scan_macro(func, utils.inspect_function_capture(func))
obj = TIRMacro(source, closure_vars, func, hygienic)
obj.__name__ = func.__name__
return obj
if len(args) == 0:
return _decorator
if len(args) == 1 and inspect.isfunction(args[0]):
return _decorator(args[0])
raise ValueError(
"Invalid use of T.macro. Usage: @T.macro, @T.macro(), @T.macro(hygienic=[True|False])")
class BufferProxy:
"""Buffer proxy class for constructing tir buffer."""
def __call__(
self,
shape,
dtype="float32",
data=None,
strides=None,
elem_offset=None,
scope="global",
align=0,
offset_factor=0,
buffer_type="",
axis_separators=None,
) -> Buffer:
return buffer(
shape,
dtype=dtype,
data=data,
strides=strides,
elem_offset=elem_offset,
scope=scope,
align=align,
offset_factor=offset_factor,
buffer_type=buffer_type,
axis_separators=axis_separators,
)
@deprecated("T.Buffer[...]", "T.Buffer(...)")
def __getitem__(self, keys) -> Buffer:
if not isinstance(keys, tuple):
return self(keys)
if len(keys) >= 2 and not isinstance(keys[1], str):
return self(keys)
return self(*keys) # type: ignore[attr-defined] # pylint: disable=no-member
class PtrProxy:
"""Ptr proxy class for constructing tir pointer."""
@deprecated("T.Ptr(...)", "T.handle(...)")
def __call__(self, dtype, storage_scope="global"):
if callable(dtype):
dtype = dtype().dtype
return ptr(dtype, storage_scope) # type: ignore[attr-defined] # pylint: disable=no-member
@deprecated("T.Ptr[...]", "T.handle(...)")
def __getitem__(self, keys):
if not isinstance(keys, tuple):
return self(keys)
return self(*keys)
Buffer = BufferProxy() # pylint: disable=invalid-name
Ptr = PtrProxy() # pylint: disable=invalid-name
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# This file is modified from the original version,
# which is part of the TVM project (https://tvm.apache.org/).
"""The tir expression operation registration"""
from typing import Type
from tvm import tir
from tvm._ffi.runtime_ctypes import DataType, DataTypeCode
from tvm.tir import IntImm
from tvm.tir.expr import FloatImm
from tvm.script.parser._core import OpMethod, doc, register_op
def _register_expr_op(ty: Type): # pylint: disable=invalid-name
ty._dispatch_type = ty # pylint: disable=protected-access
def _and(a, b):
if isinstance(a, bool):
a = IntImm("bool", a)
if isinstance(b, bool):
b = IntImm("bool", b)
if DataType(a.dtype).lanes > 1 or DataType(b.dtype).lanes > 1:
return a & b
else:
return tir.And(a, b)
def _or(a, b):
if isinstance(a, bool):
a = IntImm("bool", a)
if isinstance(b, bool):
b = IntImm("bool", b)
if DataType(a.dtype).lanes > 1 or DataType(b.dtype).lanes > 1:
return a | b
else:
return tir.Or(a, b)
def _get_type_str(dtype: str):
if DataType(dtype).lanes == 1:
return dtype
index = dtype.find("x")
return dtype[0:index]
def _auto_broadcast(a, b, op):
if isinstance(a, int):
if hasattr(b, "dtype"):
if (DataType(b.dtype).type_code == DataTypeCode.INT or
DataType(b.dtype).type_code == DataTypeCode.UINT):
a = IntImm(_get_type_str(b.dtype), a)
elif DataType(b.dtype).type_code == DataTypeCode.FLOAT:
a = FloatImm(_get_type_str(b.dtype), a)
elif isinstance(b, float):
a = FloatImm("float32", a)
else:
a = IntImm("int32", a)
elif isinstance(a, float):
if DataType(b.dtype).type_code == DataTypeCode.FLOAT:
a = FloatImm(_get_type_str(b.dtype), a)
else:
a = FloatImm("float32", a)
assert isinstance(a, tir.PrimExpr), "Operand should be a PrimExpr."
if isinstance(b, int):
if (DataType(a.dtype).type_code == DataTypeCode.INT or
DataType(a.dtype).type_code == DataTypeCode.UINT):
b = IntImm(_get_type_str(a.dtype), b)
elif DataType(a.dtype).type_code == DataTypeCode.FLOAT:
b = FloatImm(_get_type_str(a.dtype), b)
elif isinstance(b, float):
b = FloatImm(_get_type_str(a.dtype), b)
if DataType(a.dtype).lanes == DataType(b.dtype).lanes:
return op(a, b)
elif DataType(a.dtype).lanes == 1 and DataType(a.dtype).lanes != DataType(b.dtype).lanes:
broadcast_a = tir.Broadcast(a, DataType(b.dtype).lanes)
return op(broadcast_a, b)
elif DataType(b.dtype).lanes == 1 and DataType(a.dtype).lanes != DataType(b.dtype).lanes:
broadcast_b = tir.Broadcast(b, DataType(a.dtype).lanes)
return op(a, broadcast_b)
else:
raise TypeError("do not know how to deal with it.")
def _eq(a, b):
return _auto_broadcast(a, b, tir.EQ)
def _ne(a, b):
return _auto_broadcast(a, b, tir.NE)
def _lt(a, b):
return _auto_broadcast(a, b, tir.LT)
def _le(a, b):
return _auto_broadcast(a, b, tir.LE)
def _gt(a, b):
return _auto_broadcast(a, b, tir.GT)
def _ge(a, b):
return _auto_broadcast(a, b, tir.GE)
def r(op: Type, i: int, m: OpMethod): # pylint: disable=invalid-name
register_op(ty, op, i)(m)
for i in [0, 1]:
# Case 1. binop
# doc.Add <-- is overloaded
# doc.Sub <-- is overloaded
# doc.Mult <-- is overloaded
# doc.Div <-- is overloaded
# doc.FloorDiv <-- is overloaded
# doc.Mod <-- is overloaded
# doc.LShift <-- is overloaded
# doc.RShift <-- is overloaded
# doc.BitOr <-- is overloaded
# doc.BitXor <-- is overloaded
# doc.BitAnd <-- is overloaded
# doc.MatMult <-- not implemented
# doc.Pow <-- not implemented
# Case 2. cmpop
r(doc.Eq, i, _eq)
r(doc.NotEq, i, _ne)
r(doc.Lt, i, _lt)
r(doc.LtE, i, _le)
r(doc.Gt, i, _gt)
r(doc.GtE, i, _ge)
# doc.Is <-- not implemented
# doc.IsNot <-- not implemented
# doc.In <-- not implemented
# doc.NotIn <-- not implemented
# Case 3. boolop
r(doc.And, i, _and)
r(doc.Or, i, _or)
for i in [0]:
# Case 4. unaryop
# doc.Invert <-- is overloaded
r(doc.Not, i, tir.Not)
# doc.UAdd <-- is overloaded
# doc.USub <-- is overloaded
_register_expr_op(tir.PrimExpr)
_register_expr_op(tir.IterVar)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# This file is modified from the original version,
# which is part of the TVM project (https://tvm.apache.org/).
# ruff: noqa
"""The base parser for tir"""
import contextlib
from functools import partial
from typing import Any
import tvm
from tvm.ir import GlobalVar, PrimType
from tvm.tir import Buffer, IterVar, PrimExpr, Var
from tvm.script.ir_builder import ir as I
from .. import ast as T
from tvm.script.ir_builder.base import IRBuilder
from tvm.script.ir_builder.base import IRBuilderFrame as Frame
from tvm.script.parser._core import Parser, dispatch, doc
def bind_with_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any:
"""Value binding methods when parsing with statement.
e.g. binding i, j, k with T.grid(128, 128, 128), when parsing
with T.grid(128, 128, 18) as i, j, k.
Parameters
----------
self : Parser
The current parser.
node : doc.expr
The doc AST expression node for error reporting.
var_name : str
The variable name.
value : Any
The value to be bound with.
Returns
-------
res : Any
The bound value.
"""
if isinstance(value, (list, tuple)):
for i, v in enumerate(value):
bind_with_value(self, node, f"{var_name}_{i}", v)
return value
elif isinstance(value, (Buffer, Var)):
IRBuilder.name(var_name, value)
return value
else:
self.report_error(node, f"Do not know how to bind type: {type(value)} in with statement")
raise NotImplementedError
def bind_for_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any:
"""Value binding methods when parsing for statement.
e.g. binding i, j, k with T.grid(128, 128, 128), when parsing
for i, j, k in T.grid(128, 128, 128).
Parameters
----------
self : Parser
The current parser.
node : doc.expr
The doc AST expression node for error reporting.
var_name : str
The variable name.
value : Any
The value to be bound with.
Returns
-------
res : Any
The bound value.
"""
if isinstance(value, (list, tuple, tvm.ir.Array)):
for i, v in enumerate(value):
bind_for_value(self, node, f"{var_name}_{i}", v)
return value
elif isinstance(value, Var):
IRBuilder.name(var_name, value)
return value
else:
self.report_error(node, f"Do not know how to bind type: {type(value)} in for statement")
raise NotImplementedError
def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any:
"""Value binding methods when parsing assign statement.
e.g. binding vi, vj, vk with T.axis.remap("SSR", [i, j, k]), when parsing
vi, vj, vk = T.axis.remap("SSR", [i, j, k]).
Parameters
----------
self : Parser
The current parser.
node : doc.expr
The doc AST expression node for error reporting.
var_name : str
The variable name.
value : Any
The value to be bound with.
Returns
-------
res : Any
The bound value.
"""
if isinstance(value, T.meta_var):
return value.value
elif isinstance(value, (list, tuple)):
for i, v in enumerate(value):
bind_assign_value(self, node, f"{var_name}_{i}", v)
return value
elif isinstance(value, Frame):
value.add_callback(partial(value.__exit__, None, None, None))
res = value.__enter__()
IRBuilder.name(var_name, res)
return res
elif isinstance(value, (Buffer, IterVar)) or (isinstance(value, Var) and
not self.var_table.exist(value)):
IRBuilder.name(var_name, value)
return value
else:
value = tvm.runtime.convert(value)
frame = T.LetStmt(value)
var = frame.var
IRBuilder.name(var_name, var)
frame.add_callback(partial(frame.__exit__, None, None, None))
frame.__enter__()
return var
def find_decorator_annotation(node: doc.FunctionDef, annotation: str, default: bool = True) -> bool:
"""
Check the value of given annotation (argument name) in the prim_func decorator.
Returns the value of the annotation if present, otherwise giving the default value.
"""
# look for the named argument in the prim_func decorator
for dec in node.decorator_list:
if not isinstance(dec, doc.Call) or dec.func.attr != "prim_func":
continue
for keyword in dec.keywords:
if keyword.arg == annotation:
return keyword.value.value
return default
@dispatch.register(token="tir", type_name="For")
def visit_for(self: Parser, node: doc.For) -> None:
"""The for visiting method for tir.
Parameters
----------
self : Parser
The visiting parser.
node : doc.For
The doc AST for node.
"""
for_frame = self.eval_expr(node.iter)
if not isinstance(for_frame, T.frame.ForFrame):
self.report_error(
node.iter,
"Expect the for loop to be one of the following: "
"range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding",
)
with self.var_table.with_frame():
with for_frame as iters:
self.eval_assign(target=node.target, source=iters, bind_value=bind_for_value)
self.visit_body(node.body)
@dispatch.register(token="tir", type_name="While")
def visit_while(self: Parser, node: doc.While) -> None:
"""The while visiting method for tir.
Parameters
----------
self : Parser
The visiting parser.
node : doc.While
The doc AST while node.
"""
with self.var_table.with_frame():
cond = self.eval_expr(node.test)
with T.While(cond):
self.visit_body(node.body)
@dispatch.register(token="tir", type_name="Assign")
def visit_assign(self: Parser, node: doc.Assign) -> None:
"""The assign visiting method for tir.
Parameters
----------
self : Parser
The visiting parser.
node : doc.Assign
The doc AST assign node.
"""
if len(node.targets) != 1:
self.report_error(node, "Consequential assignments like 'a = b = c' are not supported.")
lhs = node.targets[0]
if isinstance(node.value, doc.Subscript):
check_slices = []
if isinstance(node.value.slice, doc.Slice):
check_slices = [node.value.slice]
elif isinstance(node.value.slice, doc.Tuple):
for p in node.value.slice.elts:
if isinstance(p, doc.Slice):
check_slices.append(p)
for s in check_slices:
if not s.step and s.upper and s.lower:
s.step = doc.Constant(
1,
None,
1,
1,
s.upper.lineno,
s.upper.end_col_offset + 1,
s.upper.lineno,
s.upper.end_col_offset + 2,
)
rhs = self.eval_expr(node.value)
if isinstance(lhs, doc.Subscript):
if isinstance(lhs.slice, doc.Tuple):
indices = []
for index in lhs.slice.elts:
indices.append(self.eval_expr(index))
else:
indices = self.eval_expr(lhs.slice)
T.buffer_store(self.eval_expr(lhs.value), rhs, indices)
else:
self.eval_assign(target=lhs, source=rhs, bind_value=bind_assign_value)
@dispatch.register(token="tir", type_name="AugAssign")
def visit_aug_assign(self: Parser, node: doc.AugAssign) -> None:
"""The augmented assign visiting method for tir.
Parameters
----------
self : Parser
The visiting parser.
node : doc.AugAssign
The doc AST augmented assign node.
"""
lhs_pos = (
node.target.lineno,
node.target.col_offset,
node.target.end_lineno,
node.target.end_col_offset,
)
rhs_pos = (
node.value.lineno,
node.value.col_offset,
node.value.end_lineno,
node.value.end_col_offset,
)
node.target.ctx = doc.Load(*lhs_pos)
with self.var_table.with_frame():
lhs_name = "__tvm_tmp_value_aug_assign_lhs"
rhs_name = "__tvm_tmp_value_aug_assign_rhs"
lhs_expr = self.eval_expr(node.target)
rhs_expr = self.eval_expr(node.value)
self.var_table.add(lhs_name, lhs_expr)
self.var_table.add(rhs_name, rhs_expr)
op = doc.BinOp(
doc.Name(lhs_name, doc.Load(*lhs_pos), *lhs_pos),
node.op,
doc.Name(rhs_name, doc.Load(*rhs_pos), *rhs_pos),
*lhs_pos,
)
rhs = self.eval_expr(op)
lhs = node.target
lhs.ctx = doc.Store(*lhs_pos)
if isinstance(lhs, doc.Subscript):
if isinstance(lhs.slice, doc.Tuple):
indices = []
for index in lhs.slice.elts:
indices.append(self.eval_expr(index))
else:
indices = [self.eval_expr(lhs.slice)]
T.buffer_store(self.eval_expr(lhs.value), rhs, indices)
else:
self.eval_assign(target=lhs, source=rhs, bind_value=bind_assign_value)
@dispatch.register(token="tir", type_name="AnnAssign")
def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None:
"""The annotated assign visiting method for tir.
Parameters
----------
self : Parser
The visiting parser.
node : doc.AnnAssign
The doc AST annotated assign node.
"""
lhs = node.target
rhs = self.eval_expr(node.value)
ann_var = self.visit_tvm_annotation(node.annotation)
if not isinstance(ann_var, Var):
self.report_error(node.annotation, "Annotation should be Var")
self.eval_assign(target=lhs, source=ann_var, bind_value=bind_assign_value)
frame = T.LetStmt(rhs, var=ann_var)
frame.add_callback(partial(frame.__exit__, None, None, None))
frame.__enter__()
@dispatch.register(token="tir", type_name="With")
def visit_with(self: Parser, node: doc.With) -> None:
"""The with visiting method for tir.
Parameters
----------
self : Parser
The visiting parser.
node : doc.With
The doc AST with node.
"""
with contextlib.ExitStack() as stack:
stack.enter_context(self.var_table.with_frame())
for item in node.items:
frame = self.eval_expr(item.context_expr)
if not isinstance(frame, Frame):
self.report_error(item.context_expr,
"Invalid context expression in the with-statement.")
rhs = stack.enter_context(frame)
if item.optional_vars is not None:
self.eval_assign(target=item.optional_vars, source=rhs, bind_value=bind_with_value)
self.visit_body(node.body)
@dispatch.register(token="tir", type_name="FunctionDef")
def visit_function_def(self: Parser, node: doc.FunctionDef) -> None:
"""The function definition visiting method for tir.
Parameters
----------
self : Parser
The visiting parser.
node : doc.FunctionDef
The doc AST function definition node.
"""
supplied_annotation = self.function_annotations
func_annotation = supplied_annotation.get(node.name, {})
privacy = find_decorator_annotation(node, "private", default=False)
self.function_annotations = None
with self.var_table.with_frame():
self.var_table.add("range", T.serial)
with T.prim_func(is_private=privacy):
T.func_name(node.name)
if node.returns is not None:
ret_type = self.eval_expr(node.returns)
if callable(ret_type):
ret_type = PrimType(ret_type().dtype)
T.func_ret(ret_type)
with self.with_dispatch_token("tir"):
# TODO: handle different types of arguments:
# - vararg: arg | None
# - kwonlyargs: list[arg]
# - kw_defaults: list[expr | None]
# - kwarg: arg | None
# - defaults: list[expr]
# - posonlyargs: list[arg]
for arg in node.args.args:
if arg.annotation is None:
self.report_error(arg, "Type annotation required for function parameters.")
try:
ann = self.eval_expr(arg.annotation)
if callable(ann):
ann = ann()
except Exception: # pylint: disable=broad-except
ann = func_annotation.get(arg.arg, None)
if ann is None:
raise
param = T.arg(arg.arg, ann)
self.var_table.add(arg.arg, param)
self.visit_body(node.body)
self.function_annotations = supplied_annotation
@dispatch.register(token="tir", type_name="tvm_annotation")
def visit_tvm_annotation(self: Parser, node: doc.expr):
"""The TVM annotation visiting method for tir.
Parameters
----------
self : Parser
The visiting parser.
node : doc.expr
The doc AST expr node.
"""
annotation = self.eval_expr(node)
if callable(annotation):
annotation = annotation()
return annotation
@dispatch.register(token="tir", type_name="Expr")
def visit_expr_stmt(self: Parser, node: doc.Expr) -> None:
"""The expr statement visiting method for tir.
Parameters
----------
self : Parser
The visiting parser.
node : doc.Expr
The doc AST Expr node.
"""
res = self.eval_expr(node.value)
if res is None:
pass
elif isinstance(res, Frame):
res.add_callback(partial(res.__exit__, None, None, None))
res.__enter__()
elif isinstance(res, PrimExpr):
T.evaluate(res)
elif isinstance(res, (int, bool)):
T.evaluate(tvm.tir.const(res))
elif isinstance(res, (tvm.relay.Call, tvm.relax.Call)) and not res.args:
# Using GlobalVar.__call__ with no arguments is ambiguous, as
# each IR has a different function Call representation. If
# this occurs, convert to the TIR representation.
T.evaluate(tvm.tir.call_tir(res.op))
elif isinstance(res, str):
# Ignore docstrings
pass
else:
self.report_error(node, f"Parsing resulted in unexpected type {type(res)}")
@dispatch.register(token="tir", type_name="If")
def visit_if(self: Parser, node: doc.If) -> None:
"""The if visiting method for tir.
Parameters
----------
self : Parser
The visiting parser.
node : doc.If
The doc AST if node.
"""
with self.var_table.with_frame():
condition = self.eval_expr(node.test)
if isinstance(condition, bool):
if condition:
self.visit_body(node.body)
elif node.orelse:
self.visit_body(node.orelse)
else:
with T.If(self.eval_expr(node.test)):
with T.Then():
with self.var_table.with_frame():
self.visit_body(node.body)
if node.orelse:
with T.Else():
with self.var_table.with_frame():
self.visit_body(node.orelse)
@dispatch.register(token="tir", type_name="Assert")
def visit_assert(self: Parser, node: doc.Assert) -> None:
"""The assert visiting method for tir.
Parameters
----------
self : Parser
The visiting parser.
node : doc.Assert
The doc AST assert node.
"""
cond = self.eval_expr(node.test)
msg = self.eval_expr(node.msg)
frame = T.Assert(cond, msg)
frame.add_callback(partial(frame.__exit__, None, None, None))
frame.__enter__()
@dispatch.register(token="tir", type_name="Return")
def visit_return(self: Parser, node: doc.Return) -> None:
"""The return visiting method for tir.
Parameters
----------
self : Parser
The visiting parser.
node : doc.Return
The doc AST return node.
"""
value = self.eval_expr(node.value)
T.evaluate(tvm.tir.ret(value))
@dispatch.register(token="tir", type_name="tvm_declare_function")
def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar:
"""The function declaration step for tir
Parameters
----------
self : Parser
The visiting parser.
node : doc.Return
The doc AST return node.
"""
supplied_annotation = self.function_annotations
func_annotation = supplied_annotation.get(node.name, {})
ret_type = None
with self.var_table.with_frame():
if node.returns is not None:
ret_type = self.eval_expr(node.returns)
if callable(ret_type):
ret_type = PrimType(ret_type().dtype)
arg_annotations = []
for arg in node.args.args:
if arg.annotation is None:
self.report_error(arg, "Type annotation required for function parameters.")
try:
ann = self.eval_expr(arg.annotation)
if callable(ann):
ann = ann()
except Exception: # pylint: disable=broad-except
ann = func_annotation.get(arg.arg, None)
if ann is None:
raise
IRBuilder.name(arg.arg, ann)
arg_annotations.append(ann)
func_signature = tvm.tir.PrimFunc(arg_annotations, None, ret_type=ret_type)
return I.decl_function(node.name, func_signature)
......@@ -38,7 +38,10 @@ class ConvertTorch:
)
ins_idx = 0
args = []
device = torch.cuda.current_device()
# use the device of the first input tensor if available
device = ins[0].device if len(ins) > 0 else torch.cuda.current_device()
for i in range(len(self.params)):
if i in self.result_idx:
dtype = torch.__getattribute__(str(self.params[i].dtype))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment