Unverified Commit bc9623fc authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Py38] Revert typing and parser updates for Python 3.8 compatibility (#850)

* Update submodule TVM to commit 872e32c1 and adjust type hints in nvcc.py and utils.py for compatibility with Python typing standards.

* Update requirements.txt to specify ml_dtypes without a version constraint, indicating that versions greater than 0.5.1 are needed for fp4 support.
parent 8cc2ab22
Subproject commit 6051f6dbdd741be340f47f944cd433f04ed18a8d Subproject commit 872e32c16d5bd0826b60f73f55af9e694d86a5a1
...@@ -4,6 +4,8 @@ numpy>=1.23.5 ...@@ -4,6 +4,8 @@ numpy>=1.23.5
tqdm>=4.62.3 tqdm>=4.62.3
typing_extensions>=4.10.0 typing_extensions>=4.10.0
cloudpickle cloudpickle
ml_dtypes>=0.5.3 # mldtypes should be greater than 0.5.1
# if you want to enable fp4
ml_dtypes
psutil psutil
torch torch
...@@ -6,6 +6,7 @@ from __future__ import absolute_import as _abs ...@@ -6,6 +6,7 @@ from __future__ import absolute_import as _abs
import os import os
import subprocess import subprocess
import warnings import warnings
from typing import Tuple
from tilelang.env import CUDA_HOME from tilelang.env import CUDA_HOME
import tvm.ffi import tvm.ffi
...@@ -298,7 +299,7 @@ def get_target_compute_version(target=None): ...@@ -298,7 +299,7 @@ def get_target_compute_version(target=None):
"Try specifying it by adding '-arch=sm_xx' to your target.") "Try specifying it by adding '-arch=sm_xx' to your target.")
def parse_compute_version(compute_version) -> tuple[int, int]: def parse_compute_version(compute_version) -> Tuple[int, int]:
"""Parse compute capability string to divide major and minor version """Parse compute capability string to divide major and minor version
Parameters Parameters
......
from tilelang import tvm as tvm from tilelang import tvm as tvm
from typing import List
from tvm.tir import PrimExpr from tvm.tir import PrimExpr
def index_to_coordinates(index, shape) -> list[PrimExpr]: def index_to_coordinates(index, shape) -> List[PrimExpr]:
""" """
Convert a flat (linear) index into multi-dimensional coordinates for a given shape. Convert a flat (linear) index into multi-dimensional coordinates for a given shape.
...@@ -13,7 +14,7 @@ def index_to_coordinates(index, shape) -> list[PrimExpr]: ...@@ -13,7 +14,7 @@ def index_to_coordinates(index, shape) -> list[PrimExpr]:
shape (Sequence[int]): The extents of each dimension (length >= 1). shape (Sequence[int]): The extents of each dimension (length >= 1).
Returns: Returns:
list[PrimExpr]: Coordinates for each dimension in the same order as `shape`. List[PrimExpr]: Coordinates for each dimension in the same order as `shape`.
""" """
coordinates = [] coordinates = []
dims = len(shape) dims = len(shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment