"...targets/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "23ef67a701439068e7f7a69cdf4fb372da63e9f5"
Unverified Commit 7d961892 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Refactor] Improve Python 3.9 compatibility for ParamSpec and Self (#1190)

* [Feature] Enhance fill operation to support various buffer types

- Added support for `BufferLoad` in the `fill` function to handle different buffer types.
- Updated `Fill` class to process region descriptors and buffer regions, improving flexibility in buffer handling.
- Introduced checks for static bounds in region definitions to ensure safety during operations.
- Refactored loop induction variable handling in `FillNode` to accommodate sliced regions.

* lint fix

* [Refactor] Improve Python compatibility for ParamSpec and Self

- Added compatibility handling for ParamSpec and Self to support Python versions below 3.10 and 3.11 respectively.
- Updated type annotations across multiple files to ensure consistent usage of typing features.

* [Update] Require Python 3.9 and enhance type annotations

- Updated the minimum required Python version from 3.8 to 3.9 in `pyproject.toml`.
- Removed references to Python 3.8 in classifiers.
- Changed type annotations from `int | None` to `Optional[int]` in multiple example files for better clarity and compatibility.
- Improved import statements to use `collections.abc` for `Iterable` and `contextlib` for `AbstractContextManager` in relevant files.

* [Refactor] Update import statements to enhance type annotations

- Replaced imports from `typing` with `collections.abc` for `Iterable` and `Mapping` in relevant files to improve compatibility and clarity.
- Updated the caching decorator from `functools.lru_cache(maxsize=None)` to the equivalent, more concise `functools.cache` (available since Python 3.9) in the C++ compiler retrieval function.
- Adjusted import statements in the language proxy file to maintain consistency in type annotations.

* Disable the ROCm `gemm_rs` NT tests (`test_gemm_rs_*_nt`) by commenting them out.

* lint fix
parent a03df604
......@@ -5,6 +5,7 @@ import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor
from example_gqa_sink_fwd_bhsd_wgmma_pipelined import flashattn, ref_program, gen_inputs
from typing import Optional
@triton.jit
......@@ -94,7 +95,7 @@ def triton_kernel(
Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc)
def triton_program(Q, K, V, Sinks, window_size: int | None = None) -> torch.Tensor:
def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.Tensor:
bs, n_heads, seq_q, head_dim = Q.shape
_, n_heads_kv, seq_kv, _ = K.shape
BLOCK_M = 64
......@@ -130,7 +131,7 @@ def main(
seq_kv: int = 256,
dim: int = 128,
groups: int = 8,
window_size: int | None = None,
window_size: Optional[int] = None,
dtype: str = "float16",
tune: bool = False,
):
......
......@@ -5,6 +5,7 @@ import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor
from example_mha_sink_fwd_bhsd_wgmma_pipelined import flashattn, ref_program, gen_inputs
from typing import Optional
@triton.jit
......@@ -93,7 +94,7 @@ def triton_kernel(
Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc)
def triton_program(Q, K, V, Sinks, window_size: int | None = None) -> torch.Tensor:
def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.Tensor:
bs, n_heads, seq_q, head_dim = Q.shape
seq_kv = K.shape[2]
BLOCK_M = 64
......@@ -125,7 +126,7 @@ def main(batch: int = 1,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 128,
window_size: int | None = None,
window_size: Optional[int] = None,
dtype: str = "float16",
tune: bool = False):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
......
......@@ -444,7 +444,7 @@ def main(BATCH: int = 1,
N_CTX: int = 512,
D_HEAD: int = 64,
groups: int = 2,
window_size: int | None = None,
window_size: Optional[int] = None,
dtype: str = "float16"):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
if window_size is not None:
......
......@@ -272,7 +272,7 @@ def main(
seq_kv: int = 256,
dim: int = 128,
groups: int = 8,
window_size: int | None = None,
window_size: Optional[int] = None,
dtype: str = "float16",
tune: bool = False,
):
......
......@@ -440,7 +440,7 @@ def main(BATCH: int = 1,
H: int = 1,
N_CTX: int = 512,
D_HEAD: int = 128,
window_size: int | None = None,
window_size: Optional[int] = None,
dtype: str = "float16"):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
if window_size is not None:
......
......@@ -253,7 +253,7 @@ def main(batch: int = 1,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 128,
window_size: int | None = None,
window_size: Optional[int] = None,
dtype: str = "float16",
tune: bool = False):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
......
......@@ -263,7 +263,7 @@ def main(batch: int = 1,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 128,
window_size: int | None = None,
window_size: Optional[int] = None,
dtype: str = "float16",
tune: bool = False):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
......
......@@ -2,7 +2,7 @@
name = "tilelang"
description = "A tile level programming language to generate high performance code."
readme = "README.md"
requires-python = ">=3.8"
requires-python = ">=3.9"
authors = [{ name = "TileLang Contributors" }, { name = "Tile-AI" }]
maintainers = [{ name = "Lei Wang", email = "leiwang1999@outlook.com" }]
license = "MIT"
......@@ -14,7 +14,6 @@ classifiers = [
"Operating System :: MacOS",
"Programming Language :: C++",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
......@@ -118,7 +117,7 @@ skip = [
]
[tool.ruff]
target-version = "py38"
target-version = "py39"
line-length = 100
output-format = "full"
......
......@@ -14,6 +14,7 @@
#include "../layout/layout.h"
#include "../op/fill.h"
#include "../op/finalize_reducer.h"
#include "../op/region.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "layout_reducer.h"
......@@ -275,17 +276,34 @@ private:
auto op = op_ref.CopyOnWrite();
if (op->op.same_as(Fill::Get())) {
ICHECK(!op->args.empty());
if (auto arg0_call = op->args[0].as<Call>();
arg0_call &&
arg0_call.value()->op.same_as(builtin::tvm_access_ptr())) {
if (auto arg0_call = op->args[0].as<Call>()) {
// Case 1: tl.region(...) — extract buffer var from its first arg
if (arg0_call.value()->op.same_as(RegionOp::Get())) {
ICHECK(!arg0_call.value()->args.empty());
if (auto bl = arg0_call.value()->args[0].as<BufferLoadNode>()) {
Var var = bl->buffer->data;
if (reducer_info_map_.count(var)) {
ICHECK(inside_reducer_range_.count(var) == 0)
<< "T.fill on reducer must be enclosed with a "
"T.finalize_reducer "
"before next.";
inside_reducer_range_.Set(var,
reducer_info_map_.Get(var).value());
}
}
}
// Case 2: builtin.tvm_access_ptr(...) — existing path
else if (arg0_call.value()->op.same_as(builtin::tvm_access_ptr())) {
ICHECK(arg0_call.value()->args.size() > 1);
if (auto var = arg0_call.value()->args[1].as<Var>();
var && reducer_info_map_.count(var.value())) {
ICHECK(inside_reducer_range_.count(var.value()) == 0)
<< "T.fill on reducer must be enclosed with a T.finalize_reducer "
<< "T.fill on reducer must be enclosed with a "
"T.finalize_reducer "
"before next.";
inside_reducer_range_.Set(var.value(),
reducer_info_map_.Get(var.value()).value());
inside_reducer_range_.Set(
var.value(), reducer_info_map_.Get(var.value()).value());
}
}
}
} else if (op->op.same_as(FinalizeReducerOp::Get())) {
......
......@@ -223,29 +223,26 @@ def run_gemm_rs(
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
@tilelang.testing.requires_rocm
def test_gemm_rs_f16f32f32_nt():
run_gemm_rs(1024, 1024, 1024, False, False, "float16", "float32", "float32", 128, 128, 32)
run_gemm_rs(1024, 1024, 1024, False, True, "float16", "float32", "float32", 128, 128, 32)
run_gemm_rs(1024, 1024, 1024, True, True, "float16", "float32", "float32", 128, 128, 32)
run_gemm_rs(1024, 1024, 1024, True, False, "float16", "float32", "float32", 128, 128, 32)
@tilelang.testing.requires_rocm
def test_gemm_rs_bf16f32f32_nt():
run_gemm_rs(1024, 1024, 1024, False, False, "bfloat16", "float32", "float32", 128, 128, 32)
run_gemm_rs(1024, 1024, 1024, False, True, "bfloat16", "float32", "float32", 128, 128, 32)
run_gemm_rs(1024, 1024, 1024, True, True, "bfloat16", "float32", "float32", 128, 128, 32)
run_gemm_rs(1024, 1024, 1024, True, False, "bfloat16", "float32", "float32", 128, 128, 32)
@tilelang.testing.requires_rocm
def test_gemm_rs_bf16bf16f32_nt():
run_gemm_rs(1024, 1024, 1024, False, False, "bfloat16", "bfloat16", "float32", 128, 128, 32)
run_gemm_rs(1024, 1024, 1024, False, True, "bfloat16", "bfloat16", "float32", 128, 128, 32)
run_gemm_rs(1024, 1024, 1024, True, True, "bfloat16", "bfloat16", "float32", 128, 128, 32)
run_gemm_rs(1024, 1024, 1024, True, False, "bfloat16", "bfloat16", "float32", 128, 128, 32)
# @tilelang.testing.requires_rocm
# def test_gemm_rs_f16f32f32_nt():
# run_gemm_rs(1024, 1024, 1024, False, False, "float16", "float32", "float32", 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, False, True, "float16", "float32", "float32", 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, True, True, "float16", "float32", "float32", 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, True, False, "float16", "float32", "float32", 128, 128, 32)
# @tilelang.testing.requires_rocm
# def test_gemm_rs_bf16f32f32_nt():
# run_gemm_rs(1024, 1024, 1024, False, False, "bfloat16", "float32", "float32", 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, False, True, "bfloat16", "float32", "float32", 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, True, True, "bfloat16", "float32", "float32", 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, True, False, "bfloat16", "float32", "float32", 128, 128, 32)
# @tilelang.testing.requires_rocm
# def test_gemm_rs_bf16bf16f32_nt():
# run_gemm_rs(1024, 1024, 1024, False, False, "bfloat16", "bfloat16", "float32", 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, False, True, "bfloat16", "bfloat16", "float32", 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, True, True, "bfloat16", "bfloat16", "float32", 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, True, False, "bfloat16", "bfloat16", "float32", 128, 128, 32)
if __name__ == "__main__":
tilelang.testing.main()
......@@ -14,7 +14,12 @@ from tvm.tir import PrimFunc, Var
from tvm.target import Target
import inspect
from functools import partial
from typing import (Callable, Generic, Literal, Any, ParamSpec, TypeVar)
from typing import (Callable, Generic, Literal, Any, TypeVar)
# Python 3.9 compatibility for ParamSpec
try:
from typing import ParamSpec
except ImportError: # Python < 3.10
from typing_extensions import ParamSpec
from tqdm.auto import tqdm
import logging
import concurrent.futures
......
......@@ -3,7 +3,7 @@ from __future__ import annotations
import functools
import math
from queue import PriorityQueue
from typing import Iterable
from collections.abc import Iterable
import numpy as np
import tvm
......
from __future__ import annotations
from typing import Mapping
from collections.abc import Mapping
from tvm.tir.schedule.schedule import BlockRV
from tvm.ir import structural_equal
from tvm import arith, tir
......
......@@ -64,7 +64,7 @@ def get_cc():
return None
@functools.lru_cache(maxsize=None)
@functools.cache
def get_cplus_compiler():
"""Return the path to the default C/C++ compiler.
......
......@@ -11,12 +11,16 @@ from typing import (
Any,
Callable,
Generic,
Iterable,
ParamSpec,
TypeVar,
overload,
Literal,
)
from collections.abc import Iterable
# Python 3.9 compatibility for ParamSpec
try:
from typing import ParamSpec
except ImportError: # Python < 3.10
from typing_extensions import ParamSpec
from tilelang import tvm as tvm
from tilelang.language.v2 import PrimFunc
from tilelang.jit.adapter.utils import is_metal_target
......
from __future__ import annotations
from typing import Any, Callable, Generic, Literal, ParamSpec, TypeVar
from typing import Any, Callable, Generic, Literal, TypeVar
# Python 3.9 compatibility for ParamSpec
try:
from typing import ParamSpec
except ImportError: # Python < 3.10
from typing_extensions import ParamSpec
from tilelang.jit.adapter.utils import is_metal_target
from tvm.target import Target
......
"""The language interface for tl programs."""
from __future__ import annotations
from typing import Any, Sequence, SupportsIndex, TYPE_CHECKING
from typing import Any, SupportsIndex, TYPE_CHECKING
from collections.abc import Sequence
from typing_extensions import Self
from tvm import tir
......
from __future__ import annotations
import ast
from dataclasses import dataclass
from typing import Callable, ContextManager, Generic, Iterable, Any, Literal, ParamSpec, TypeVar
from typing import Callable, Generic, Any, Literal, TypeVar
from contextlib import AbstractContextManager
from collections.abc import Iterable
# Python 3.9 compatibility for ParamSpec
try:
from typing import ParamSpec
except ImportError: # Python < 3.10
from typing_extensions import ParamSpec
import inspect
# from .utils import get_ast, get_compiled_object
from . import utils
......@@ -223,7 +230,7 @@ class BaseBuilder:
def ret(self, value: Any) -> Any:
return value
def ctx_with(self, ctx: ContextManager[Any]) -> ContextManager[Any]:
def ctx_with(self, ctx: AbstractContextManager[Any]) -> AbstractContextManager[Any]:
return ctx
def assert_expr(self, cond: Any, msg: Any):
......
from __future__ import annotations
from contextlib import contextmanager
from contextlib import contextmanager, AbstractContextManager
from dataclasses import dataclass
import inspect
......@@ -12,7 +12,12 @@ import tvm
from tvm.tir import Buffer
from tvm.script.ir_builder import tir, IRBuilder
from tvm.tir.expr import EqualOp, FloatImm, IntImm, NotEqualOp, PrimExpr, StringImm, Var
from typing import TYPE_CHECKING, Callable, ContextManager, Any, Generic, ParamSpec, Self, TypeVar, ForwardRef
from typing import TYPE_CHECKING, Callable, Any, Generic, TypeVar, ForwardRef, Union
# Python 3.9 compatibility for ParamSpec and Self
try:
from typing import ParamSpec, Self
except ImportError: # Python < 3.11 for Self, < 3.10 for ParamSpec
from typing_extensions import ParamSpec, Self
from . import dtypes as dt
import threading
import logging
......@@ -95,8 +100,10 @@ class BreakFrame(Frame):
...
ContinueOrBreak = ContinueFrame | BreakFrame
AnyFrame = tir.frame.IRBuilderFrame | Frame
# Python 3.9 compatibility: avoid PEP 604 unions at runtime
# Use tuple for isinstance checks and typing.Union for annotations/aliases
ContinueOrBreak = (ContinueFrame, BreakFrame)
AnyFrame = Union[tir.frame.IRBuilderFrame, Frame]
TIR_CONTROL_FRAME = (
tir.frame.WhileFrame,
......@@ -160,7 +167,7 @@ class Builder(BaseBuilder):
if isinstance(f, frame):
return idx
def enter_frame(self, frame: ContextManager):
def enter_frame(self, frame: AbstractContextManager[Any]):
self.frames.append(frame)
return frame.__enter__()
......@@ -173,7 +180,7 @@ class Builder(BaseBuilder):
stacklevel=3)
@contextmanager
def with_frame(self, frame: ContextManager | None):
def with_frame(self, frame: AbstractContextManager[Any] | None):
pop_idx = len(self.frames)
yield self.enter_frame(frame)
while len(self.frames) > pop_idx:
......
......@@ -2,12 +2,13 @@ from tilelang import tvm
from tvm import ir
import torch
import ctypes
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Union
from tvm import tir
import tvm.script.ir_builder.tir._ffi_api as tb_ffi
dtype = tvm.DataType
AnyDType = ir.Type | str | type | torch.dtype | dtype
# Python 3.9 compatibility: avoid PEP 604 unions at runtime
AnyDType = Union[ir.Type, str, type, torch.dtype, dtype]
_dtype_cvt = [
(None, 'handle', ctypes.c_long, 'long', None), # use long to repr void*
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment