"src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "fe995d05dbb16d1783694dbd79d2b68bc76fc886"
Unverified Commit b1922518 authored by Yichen Yan's avatar Yichen Yan Committed by GitHub
Browse files

[Minor] Remove from __future__ import annotations for python 3.8 (#1273)

parent 220c3236
"""Annotation helpers exposed on the TileLang language surface.""" """Annotation helpers exposed on the TileLang language surface."""
from __future__ import annotations
from typing import Callable from typing import Callable
from tilelang.layout import Layout from tilelang.layout import Layout
......
"""The language interface for tl programs.""" """The language interface for tl programs."""
from __future__ import annotations from __future__ import annotations
from typing import Literal from typing import Literal
from tilelang import language as T from tilelang import language as T
from tilelang.utils.language import ( from tilelang.utils.language import (
......
"""The language interface for tl programs.""" """The language interface for tl programs."""
from __future__ import annotations from __future__ import annotations
import tilelang.language as T import tilelang.language as T
from tvm.tir import PrimExpr, Buffer, op from tvm.tir import PrimExpr, Buffer, op
from .atomic import atomic_max, atomic_min, atomic_add, atomic_addx2, atomic_addx4, atomic_load, atomic_store # noqa: F401 from .atomic import atomic_max, atomic_min, atomic_add, atomic_addx2, atomic_addx4, atomic_load, atomic_store # noqa: F401
......
"""The language interface for tl programs.""" """The language interface for tl programs."""
from __future__ import annotations from __future__ import annotations
from tilelang.primitives.gemm.base import GemmWarpPolicy from tilelang.primitives.gemm.base import GemmWarpPolicy
import tilelang.language as T import tilelang.language as T
from tvm import tir from tvm import tir
......
"""The language interface for tl programs.""" """The language interface for tl programs."""
from __future__ import annotations from __future__ import annotations
from tvm import tir from tvm import tir
from tilelang.language import has_let_value, get_let_value from tilelang.language import has_let_value, get_let_value
from tilelang.utils.language import get_buffer_region_from_load from tilelang.utils.language import get_buffer_region_from_load
......
"""Override the LetFrame to print a message when entering the frame.""" """Override the LetFrame to print a message when entering the frame."""
from __future__ import annotations from __future__ import annotations
from tvm.ffi import register_object as _register_object from tvm.ffi import register_object as _register_object
from tvm.tir import Var, PrimExpr, BufferLoad, BufferRegion from tvm.tir import Var, PrimExpr, BufferLoad, BufferRegion
from tvm.ir import Range from tvm.ir import Range
......
"""The language interface for tl programs.""" """The language interface for tl programs."""
from __future__ import annotations from __future__ import annotations
from tilelang.primitives.gemm.base import GemmWarpPolicy from tilelang.primitives.gemm.base import GemmWarpPolicy
import tilelang.language as T import tilelang.language as T
from tvm import tir from tvm import tir
......
"""The language interface for tl programs.""" """The language interface for tl programs."""
from __future__ import annotations from __future__ import annotations
from collections import deque from collections import deque
from tvm import tir from tvm import tir
from tvm.tir import Var from tvm.tir import Var
......
"""The language interface for tl programs.""" """The language interface for tl programs."""
from __future__ import annotations from __future__ import annotations
from typing import Any from typing import Any
from tvm import tir from tvm import tir
from tvm.tir import IntImm from tvm.tir import IntImm
......
"""TVMScript parser overrides tailored for TileLang.""" """TVMScript parser overrides tailored for TileLang."""
from __future__ import annotations
from functools import partial from functools import partial
from tvm.script.ir_builder import tir as T from tvm.script.ir_builder import tir as T
......
...@@ -17,8 +17,6 @@ ...@@ -17,8 +17,6 @@
# This file is modified from the original version, # This file is modified from the original version,
# which is part of the TVM project (https://tvm.apache.org/). # which is part of the TVM project (https://tvm.apache.org/).
"""The tir expression operation registration""" """The tir expression operation registration"""
from __future__ import annotations
from tvm import tir from tvm import tir
from tvm.ffi.runtime_ctypes import DataType, DataTypeCode from tvm.ffi.runtime_ctypes import DataType, DataTypeCode
from tvm.tir import IntImm from tvm.tir import IntImm
......
"""The language interface for tl programs.""" """The language interface for tl programs."""
from __future__ import annotations from __future__ import annotations
from typing import Any, SupportsIndex, TYPE_CHECKING from typing import Any, SupportsIndex, TYPE_CHECKING
from collections.abc import Sequence from collections.abc import Sequence
from typing_extensions import Self from typing_extensions import Self
......
"""The language interface for tl programs.""" """The language interface for tl programs."""
from __future__ import annotations from __future__ import annotations
from tvm import tir from tvm import tir
from tilelang.language import copy, macro, alloc_shared, alloc_fragment from tilelang.language import copy, macro, alloc_shared, alloc_fragment
from tilelang.language.utils import buffer_to_tile_region from tilelang.language.utils import buffer_to_tile_region
......
from __future__ import annotations
import tvm.script.ir_builder.tir.ir as _ir import tvm.script.ir_builder.tir.ir as _ir
from tvm.script.ir_builder.tir import frame from tvm.script.ir_builder.tir import frame
from tvm.tir import PrimExpr from tvm.tir import PrimExpr
......
from __future__ import annotations
from tilelang import tvm as tvm from tilelang import tvm as tvm
from tvm import tir from tvm import tir
from tvm.tir import PrimExpr, Buffer, BufferLoad, op from tvm.tir import PrimExpr, Buffer, BufferLoad, op
......
from __future__ import annotations from __future__ import annotations
from contextlib import contextmanager, AbstractContextManager from contextlib import contextmanager, AbstractContextManager
from dataclasses import dataclass from dataclasses import dataclass
import inspect import inspect
......
"""The language interface for tl programs.""" """The language interface for tl programs."""
from __future__ import annotations
from tvm.script.ir_builder.tir.frame import TIRFrame from tvm.script.ir_builder.tir.frame import TIRFrame
from tvm.ffi import register_object from tvm.ffi import register_object
from tilelang import _ffi_api from tilelang import _ffi_api
......
"""Wrapping Layouts.""" """Wrapping Layouts."""
# pylint: disable=invalid-name, unsupported-binary-operation # pylint: disable=invalid-name, unsupported-binary-operation
from __future__ import annotations
import tvm import tvm
import tvm_ffi import tvm_ffi
from tvm.ir import Range from tvm.ir import Range
...@@ -124,7 +122,7 @@ class Fragment(Layout): ...@@ -124,7 +122,7 @@ class Fragment(Layout):
def repeat(self, def repeat(self,
repeats, repeats,
repeat_on_thread: bool = False, repeat_on_thread: bool = False,
lower_dim_first: bool = True) -> Fragment: lower_dim_first: bool = True) -> 'Fragment':
""" """
Returns a new Fragment that repeats the iteration space a given number of times. Returns a new Fragment that repeats the iteration space a given number of times.
...@@ -144,7 +142,7 @@ class Fragment(Layout): ...@@ -144,7 +142,7 @@ class Fragment(Layout):
""" """
return _ffi_api.Fragment_repeat(self, repeats, repeat_on_thread, lower_dim_first) return _ffi_api.Fragment_repeat(self, repeats, repeat_on_thread, lower_dim_first)
def replicate(self, replicate: int) -> Fragment: def replicate(self, replicate: int) -> 'Fragment':
""" """
Replicate the Fragment across a new thread dimension. Replicate the Fragment across a new thread dimension.
...@@ -160,7 +158,7 @@ class Fragment(Layout): ...@@ -160,7 +158,7 @@ class Fragment(Layout):
""" """
return _ffi_api.Fragment_replicate(self, replicate) return _ffi_api.Fragment_replicate(self, replicate)
def condense_rep_var(self) -> Fragment: def condense_rep_var(self) -> 'Fragment':
""" """
Condense or fold the replicate variable into the existing iteration space. Condense or fold the replicate variable into the existing iteration space.
This operation may be used to reduce dimensionality if the replicate variable This operation may be used to reduce dimensionality if the replicate variable
...@@ -207,7 +205,7 @@ class Fragment(Layout): ...@@ -207,7 +205,7 @@ class Fragment(Layout):
""" """
return f"Fragment<{self.get_input_shape()}->{self.get_output_shape()}, thread={self.thread}, index={self.index}>" return f"Fragment<{self.get_input_shape()}->{self.get_output_shape()}, thread={self.thread}, index={self.index}>"
def is_equal(self, other: Fragment) -> bool: def is_equal(self, other: 'Fragment') -> bool:
""" """
Check if the current fragment is equal to another fragment. Check if the current fragment is equal to another fragment.
""" """
......
"""Wrapping Layouts.""" """Wrapping Layouts."""
# pylint: disable=invalid-name, unsupported-binary-operation # pylint: disable=invalid-name, unsupported-binary-operation
from __future__ import annotations from __future__ import annotations
import tvm import tvm
import tilelang.language as T import tilelang.language as T
import warnings import warnings
......
"""Wrapping Layouts.""" """Wrapping Layouts."""
# pylint: disable=invalid-name, unsupported-binary-operation # pylint: disable=invalid-name, unsupported-binary-operation
from __future__ import annotations
import tvm_ffi import tvm_ffi
from tvm.ir import Node, Range from tvm.ir import Node, Range
from tvm.tir import IterVar, Var, PrimExpr, IndexMap from tvm.tir import IterVar, Var, PrimExpr, IndexMap
...@@ -122,7 +120,7 @@ class Layout(Node): ...@@ -122,7 +120,7 @@ class Layout(Node):
# Map the provided indices using the constructed index mapping # Map the provided indices using the constructed index mapping
return index_map.map_indices(indices) return index_map.map_indices(indices)
def inverse(self) -> Layout: def inverse(self) -> 'Layout':
""" """
Compute the inverse of the current layout transformation. Compute the inverse of the current layout transformation.
...@@ -133,7 +131,7 @@ class Layout(Node): ...@@ -133,7 +131,7 @@ class Layout(Node):
""" """
return _ffi_api.Layout_inverse(self) return _ffi_api.Layout_inverse(self)
def is_equal(self, other: Layout) -> bool: def is_equal(self, other: 'Layout') -> bool:
""" """
Check if the current layout is equal to another layout. Check if the current layout is equal to another layout.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment