Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
29051439
Unverified
Commit
29051439
authored
Dec 12, 2025
by
Lei Wang
Committed by
GitHub
Dec 12, 2025
Browse files
[Lint] Phaseout Yapf format and embrace ruff format (#1417)
parent
e84b24bc
Changes
467
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
78 additions
and
120 deletions
+78
-120
tilelang/transform/simplify.py
tilelang/transform/simplify.py
+0
-1
tilelang/utils/deprecated.py
tilelang/utils/deprecated.py
+3
-5
tilelang/utils/language.py
tilelang/utils/language.py
+8
-11
tilelang/utils/sparse.py
tilelang/utils/sparse.py
+17
-32
tilelang/utils/target.py
tilelang/utils/target.py
+2
-3
tilelang/utils/tensor.py
tilelang/utils/tensor.py
+27
-41
version_provider.py
version_provider.py
+21
-27
No files found.
tilelang/transform/simplify.py
View file @
29051439
...
...
@@ -51,7 +51,6 @@ def _Simplify(stmt: PrimFunc | IRModule, inline_let: bool = False) -> PrimFunc |
# Decorator to simplify the output of a function
def
simplify_prim_func
(
func
:
Callable
)
->
Callable
:
def
wrapper
(
*
args
,
**
kwargs
):
stmt
:
PrimFunc
|
IRModule
=
(
func
)(
*
args
,
**
kwargs
)
return
_Simplify
(
stmt
)
...
...
tilelang/utils/deprecated.py
View file @
29051439
def
deprecated_warning
(
method_name
:
str
,
new_method_name
:
str
,
phaseout_version
:
str
=
None
):
"""A function to indicate that a method is deprecated
"""
"""A function to indicate that a method is deprecated"""
import
warnings
# pylint: disable=import-outside-toplevel, import-error
warnings
.
warn
(
f
"
{
method_name
}
is deprecated, use
{
new_method_name
}
instead"
+
(
f
" and will be removed in
{
phaseout_version
}
"
if
phaseout_version
else
""
),
f
"
{
method_name
}
is deprecated, use
{
new_method_name
}
instead"
+
(
f
" and will be removed in
{
phaseout_version
}
"
if
phaseout_version
else
""
),
DeprecationWarning
,
stacklevel
=
2
,
)
...
...
@@ -30,7 +29,6 @@ def deprecated(
import
functools
# pylint: disable=import-outside-toplevel
def
_deprecate
(
func
):
@
functools
.
wraps
(
func
)
def
_wrapper
(
*
args
,
**
kwargs
):
deprecated_warning
(
method_name
,
new_method_name
,
phaseout_version
)
...
...
tilelang/utils/language.py
View file @
29051439
...
...
@@ -24,8 +24,7 @@ def _get_buffer(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion) ->
elif
isinstance
(
buffer_or_load_or_region
,
(
tir
.
BufferLoad
,
tir
.
BufferRegion
)):
return
buffer_or_load_or_region
.
buffer
else
:
raise
TypeError
(
f
"Expected Buffer, BufferLoad, or BufferRegion, got
{
type
(
buffer_or_load_or_region
)
}
"
)
raise
TypeError
(
f
"Expected Buffer, BufferLoad, or BufferRegion, got
{
type
(
buffer_or_load_or_region
)
}
"
)
def
is_global
(
buffer
:
Buffer
|
BufferLoad
|
BufferRegion
)
->
bool
:
...
...
@@ -153,14 +152,12 @@ def retrieve_func_from_module(ir_module: IRModule) -> PrimFunc:
"""
if
not
isinstance
(
ir_module
,
IRModule
):
raise
ValueError
(
"Not supported type: "
,
type
(
ir_module
))
assert
len
(
ir_module
.
get_global_vars
())
==
1
,
(
"The optimized module should only have one global variable for default schedule."
)
assert
len
(
ir_module
.
get_global_vars
())
==
1
,
"The optimized module should only have one global variable for default schedule."
func
=
list
(
ir_module
.
functions
.
values
())[
0
]
return
func
def
get_buffer_region_from_load
(
buffer_load
:
tir
.
BufferLoad
,
extents
:
list
[
PrimExpr
]
|
None
=
None
)
->
tir
.
BufferRegion
|
None
:
def
get_buffer_region_from_load
(
buffer_load
:
tir
.
BufferLoad
,
extents
:
list
[
PrimExpr
]
|
None
=
None
)
->
tir
.
BufferRegion
|
None
:
"""
Get the buffer region from a buffer load.
...
...
@@ -193,9 +190,9 @@ def get_buffer_region_from_load(buffer_load: tir.BufferLoad,
return
None
def
to_buffer_region
(
obj
:
Buffer
|
BufferLoad
|
BufferRegion
|
tir
.
Var
,
access_type
:
str
=
"rw"
,
extents
:
list
[
PrimExpr
]
|
None
=
None
)
->
PrimExpr
|
BufferRegion
:
def
to_buffer_region
(
obj
:
Buffer
|
BufferLoad
|
BufferRegion
|
tir
.
Var
,
access_type
:
str
=
"rw"
,
extents
:
list
[
PrimExpr
]
|
None
=
None
)
->
PrimExpr
|
BufferRegion
:
"""
Convert to/from the tl.region representation.
...
...
@@ -203,6 +200,7 @@ def to_buffer_region(obj: Buffer | BufferLoad | BufferRegion | tir.Var,
- tl.region Call -> returns the decoded BufferRegion for analysis
"""
from
tilelang.language.frame
import
has_let_value
,
get_let_value
if
isinstance
(
obj
,
tir
.
Var
)
and
has_let_value
(
obj
):
obj
=
get_let_value
(
obj
)
# Encode into tl.region call (when extents is provided), otherwise return BufferRegion for analysis
...
...
@@ -279,8 +277,7 @@ def retrieve_stride(obj: Buffer | BufferRegion | BufferLoad) -> list:
return
strides
def
retrive_ptr_from_buffer_region
(
buffer_or_load_or_region
:
Buffer
|
BufferLoad
|
BufferRegion
,
access_type
:
str
=
"r"
)
->
PrimExpr
:
def
retrive_ptr_from_buffer_region
(
buffer_or_load_or_region
:
Buffer
|
BufferLoad
|
BufferRegion
,
access_type
:
str
=
"r"
)
->
PrimExpr
:
if
isinstance
(
buffer_or_load_or_region
,
Buffer
):
return
buffer_or_load_or_region
.
access_ptr
(
access_type
)
elif
isinstance
(
buffer_or_load_or_region
,
BufferLoad
):
...
...
tilelang/utils/sparse.py
View file @
29051439
...
...
@@ -15,7 +15,7 @@ os.makedirs(_CACHE_DIR, exist_ok=True)
def
_get_cached_lib
():
name
=
'
compress_lib
'
name
=
"
compress_lib
"
if
os
.
path
.
exists
(
os
.
path
.
join
(
_CACHE_DIR
,
f
"
{
name
}
.so"
)):
try
:
...
...
@@ -32,24 +32,22 @@ def _get_cached_lib():
name
=
name
,
sources
=
[
compress_util
],
extra_cuda_cflags
=
[
'
-O2
'
,
'
-std=c++17
'
,
'
-lineinfo
'
,
f
'
-I
{
env
.
CUTLASS_INCLUDE_DIR
}
'
,
f
'
-I
{
env
.
CUTLASS_INCLUDE_DIR
}
/../tools/util/include
'
,
'
-arch=sm_90
'
,
"
-O2
"
,
"
-std=c++17
"
,
"
-lineinfo
"
,
f
"
-I
{
env
.
CUTLASS_INCLUDE_DIR
}
"
,
f
"
-I
{
env
.
CUTLASS_INCLUDE_DIR
}
/../tools/util/include
"
,
"
-arch=sm_90
"
,
],
build_directory
=
_CACHE_DIR
,
)
def
compress_sm90
(
A
:
torch
.
Tensor
,
block_k
:
int
,
transposed
:
bool
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
def
compress_sm90
(
A
:
torch
.
Tensor
,
block_k
:
int
,
transposed
:
bool
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
block_k
>
128
:
block_k
=
128
# Ref: https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl#L145-L146
warnings
.
warn
(
f
"block_k
{
block_k
}
is too large, set to 128 for sm90 compression."
,
stacklevel
=
2
)
warnings
.
warn
(
f
"block_k
{
block_k
}
is too large, set to 128 for sm90 compression."
,
stacklevel
=
2
)
# Load the library (will use cache if available)
compress_lib
=
_get_cached_lib
()
...
...
@@ -60,8 +58,9 @@ def compress_sm80(A: torch.Tensor, transposed: bool) -> tuple[torch.Tensor, torc
try
:
from
torch.sparse
import
to_sparse_semi_structured
,
SparseSemiStructuredTensor
except
ImportError
as
err
:
raise
ImportError
(
"SparseSemiStructuredTensor is not available in this version of PyTorch. "
"Please install a compatible version."
)
from
err
raise
ImportError
(
"SparseSemiStructuredTensor is not available in this version of PyTorch. Please install a compatible version."
)
from
err
orig_val
=
SparseSemiStructuredTensor
.
_FORCE_CUTLASS
try
:
SparseSemiStructuredTensor
.
_FORCE_CUTLASS
=
True
...
...
@@ -73,10 +72,7 @@ def compress_sm80(A: torch.Tensor, transposed: bool) -> tuple[torch.Tensor, torc
SparseSemiStructuredTensor
.
_FORCE_CUTLASS
=
orig_val
def
compress
(
A
:
torch
.
Tensor
,
transposed
:
bool
,
arch
:
str
|
None
=
None
,
**
kwargs
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
def
compress
(
A
:
torch
.
Tensor
,
transposed
:
bool
,
arch
:
str
|
None
=
None
,
**
kwargs
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Compress a tensor using the appropriate method based on the CUDA architecture.
"""
...
...
@@ -101,11 +97,10 @@ def compress(A: torch.Tensor,
A_sp
=
A_sp
.
t
().
contiguous
()
return
A_sp
,
E
else
:
raise
ValueError
(
f
"Unsupported CUDA compute version:
{
compute_version
}
. "
"Supported versions are sm_80 and sm_90."
)
raise
ValueError
(
f
"Unsupported CUDA compute version:
{
compute_version
}
. Supported versions are sm_80 and sm_90."
)
def
randn_semi_sparse
(
M
:
int
,
K
:
int
,
dtype
=
torch
.
float16
,
device
=
'
cuda
'
,
transposed
:
bool
=
False
):
def
randn_semi_sparse
(
M
:
int
,
K
:
int
,
dtype
=
torch
.
float16
,
device
=
"
cuda
"
,
transposed
:
bool
=
False
):
"""
Generate a random semi-sparse tensor. The generated tensor will have 2:4 sparsity along the K dimension.
Args:
...
...
@@ -127,13 +122,7 @@ def randn_semi_sparse(M: int, K: int, dtype=torch.float16, device='cuda', transp
return
tensor
.
to
(
dtype
)
# dtype like float8 might not have randn kernel
def
randint_semi_sparse
(
M
:
int
,
K
:
int
,
low
:
int
,
high
:
int
,
dtype
=
torch
.
int32
,
device
=
'cuda'
,
transposed
:
bool
=
False
):
def
randint_semi_sparse
(
M
:
int
,
K
:
int
,
low
:
int
,
high
:
int
,
dtype
=
torch
.
int32
,
device
=
"cuda"
,
transposed
:
bool
=
False
):
"""
Generate a random semi-sparse integer tensor. The generated tensor will have 2:4 sparsity along the K dimension.
Args:
...
...
@@ -157,11 +146,7 @@ def randint_semi_sparse(M: int,
return
tensor
def
arange_semi_sparse
(
M
:
int
,
K
:
int
,
dtype
=
torch
.
float16
,
device
=
'cuda'
,
transposed
:
bool
=
False
):
def
arange_semi_sparse
(
M
:
int
,
K
:
int
,
dtype
=
torch
.
float16
,
device
=
"cuda"
,
transposed
:
bool
=
False
):
"""
Generate a semi-sparse tensor with values from 0 to M*K-1. The generated tensor will have 2:4 sparsity along the K dimension.
Args:
...
...
tilelang/utils/target.py
View file @
29051439
...
...
@@ -56,11 +56,10 @@ def check_metal_availability() -> bool:
if
not
mac_release
:
return
False
# todo: check torch version?
return
arch
==
'
arm64
'
return
arch
==
"
arm64
"
def
determine_target
(
target
:
str
|
Target
|
Literal
[
"auto"
]
=
"auto"
,
return_object
:
bool
=
False
)
->
str
|
Target
:
def
determine_target
(
target
:
str
|
Target
|
Literal
[
"auto"
]
=
"auto"
,
return_object
:
bool
=
False
)
->
str
|
Target
:
"""
Determine the appropriate target for compilation (CUDA, HIP, or manual selection).
...
...
tilelang/utils/tensor.py
View file @
29051439
"""The profiler and convert to torch utils"""
from
enum
import
Enum
import
torch
from
tvm
import
tir
...
...
@@ -17,7 +18,7 @@ def is_float8_dtype(dtype: torch.dtype) -> bool:
def
fp8_remove_negative_zeros_
(
tensor
:
torch
.
Tensor
):
assert
is_float8_dtype
(
tensor
.
dtype
),
"Input tensor must be of float8 dtype"
bits
=
tensor
.
view
(
torch
.
uint8
)
zeros_mask
=
(
tensor
==
0
)
zeros_mask
=
tensor
==
0
bits
[
zeros_mask
]
=
0x00
...
...
@@ -33,26 +34,21 @@ class TensorSupplyType(Enum):
def
map_torch_type
(
intype
:
str
)
->
torch
.
dtype
:
if
intype
==
"float8_e4m3"
:
assert
hasattr
(
torch
,
"float8_e4m3fn"
),
\
"torch.float8_e4m3fn is not supported in this version of torch"
\
"Please upgrade torch >= 2.1.0"
assert
hasattr
(
torch
,
"float8_e4m3fn"
),
"torch.float8_e4m3fn is not supported in this version of torchPlease upgrade torch >= 2.1.0"
return
torch
.
float8_e4m3fn
elif
intype
==
"float8_e5m2"
:
assert
hasattr
(
torch
,
"float8_e5m2"
),
\
"torch.float8_e5m2 is not supported in this version of torch"
\
"Please upgrade torch >= 2.1.0"
assert
hasattr
(
torch
,
"float8_e5m2"
),
"torch.float8_e5m2 is not supported in this version of torchPlease upgrade torch >= 2.1.0"
return
torch
.
float8_e5m2
elif
intype
==
"e4m3fnuz_float8"
:
assert
hasattr
(
torch
,
"float8_e4m3fnuz"
),
\
"torch.float8_e4m3fnuz is not supported in this version of torch
"
\
"Please upgrade torch >= 2.2.0"
assert
hasattr
(
torch
,
"float8_e4m3fnuz"
),
(
"torch.float8_e4m3fnuz is not supported in this version of torch
Please upgrade torch >= 2.2.0"
)
return
torch
.
float8_e4m3fnuz
else
:
return
getattr
(
torch
,
intype
)
def
get_tensor_supply
(
supply_type
:
TensorSupplyType
=
TensorSupplyType
.
Integer
):
from
tilelang.engine.param
import
KernelParam
from
.device
import
get_current_device
...
...
@@ -63,7 +59,8 @@ def get_tensor_supply(supply_type: TensorSupplyType = TensorSupplyType.Integer):
if
hasattr
(
param
,
"shape"
)
and
not
param
.
shape
:
raise
ValueError
(
f
"TensorType must have a shape, but got
{
type
(
param
)
}
, "
"likely you are trying to generate a random tensor with a dynamic symbolic shape."
)
"likely you are trying to generate a random tensor with a dynamic symbolic shape."
)
# Check if with dynamic symbolic shape
for
shape
in
param
.
shape
:
...
...
@@ -81,8 +78,7 @@ def get_tensor_supply(supply_type: TensorSupplyType = TensorSupplyType.Integer):
if
is_unsigned
:
return
torch
.
randint
(
low
=
0
,
high
=
3
,
size
=
shape
,
device
=
device
,
dtype
=
dtype
)
elif
is_float8
:
return
torch
.
randint
(
low
=-
128
,
high
=
128
,
size
=
shape
,
device
=
device
,
dtype
=
torch
.
int8
).
to
(
dtype
)
return
torch
.
randint
(
low
=-
128
,
high
=
128
,
size
=
shape
,
device
=
device
,
dtype
=
torch
.
int8
).
to
(
dtype
)
elif
is_boolean
:
return
torch
.
randint
(
low
=
0
,
high
=
2
,
size
=
shape
,
device
=
device
,
dtype
=
dtype
)
elif
dtype
in
{
torch
.
float16
,
torch
.
float32
,
torch
.
bfloat16
}:
...
...
@@ -103,18 +99,15 @@ def get_tensor_supply(supply_type: TensorSupplyType = TensorSupplyType.Integer):
if
is_unsigned
:
return
torch
.
randint
(
low
=
0
,
high
=
3
,
size
=
shape
,
device
=
device
,
dtype
=
dtype
)
elif
is_float8
:
return
torch
.
randint
(
low
=-
128
,
high
=
128
,
size
=
shape
,
device
=
device
,
dtype
=
torch
.
int8
).
to
(
dtype
)
return
torch
.
randint
(
low
=-
128
,
high
=
128
,
size
=
shape
,
device
=
device
,
dtype
=
torch
.
int8
).
to
(
dtype
)
elif
is_boolean
:
return
torch
.
randint
(
low
=
0
,
high
=
2
,
size
=
shape
,
device
=
device
,
dtype
=
dtype
)
else
:
return
torch
.
randint
(
low
=-
2
,
high
=
3
,
size
=
shape
,
device
=
device
,
dtype
=
dtype
)
elif
supply_type
==
TensorSupplyType
.
Uniform
:
return
torch
.
empty
(
*
shape
,
device
=
device
,
dtype
=
torch
.
float32
).
uniform_
(
-
1.0
,
1.0
).
to
(
dtype
)
return
torch
.
empty
(
*
shape
,
device
=
device
,
dtype
=
torch
.
float32
).
uniform_
(
-
1.0
,
1.0
).
to
(
dtype
)
elif
supply_type
==
TensorSupplyType
.
Normal
:
return
torch
.
empty
(
*
shape
,
device
=
device
,
dtype
=
torch
.
float32
).
normal_
(
-
1.0
,
1.0
).
to
(
dtype
)
return
torch
.
empty
(
*
shape
,
device
=
device
,
dtype
=
torch
.
float32
).
normal_
(
-
1.0
,
1.0
).
to
(
dtype
)
elif
supply_type
==
TensorSupplyType
.
Randn
:
return
torch
.
randn
(
*
shape
,
device
=
device
).
to
(
dtype
)
elif
supply_type
==
TensorSupplyType
.
Zero
:
...
...
@@ -150,9 +143,7 @@ def _compare_attributes(
"""
def
raise_mismatch_error
(
attribute_name
:
str
,
actual_value
,
expected_value
):
raise
AssertionError
(
f
"The values for attribute '
{
attribute_name
}
' do not match:
{
actual_value
}
!=
{
expected_value
}
."
)
raise
AssertionError
(
f
"The values for attribute '
{
attribute_name
}
' do not match:
{
actual_value
}
!=
{
expected_value
}
."
)
if
actual
.
shape
!=
expected
.
shape
:
raise_mismatch_error
(
"shape"
,
actual
.
shape
,
expected
.
shape
)
...
...
@@ -163,7 +154,7 @@ def _compare_attributes(
if
actual
.
layout
!=
expected
.
layout
:
if
check_layout
:
raise_mismatch_error
(
"layout"
,
actual
.
layout
,
expected
.
layout
)
elif
(
actual
.
layout
==
torch
.
strided
and
check_stride
and
actual
.
stride
()
!=
expected
.
stride
()
)
:
elif
actual
.
layout
==
torch
.
strided
and
check_stride
and
actual
.
stride
()
!=
expected
.
stride
():
raise_mismatch_error
(
"stride()"
,
actual
.
stride
(),
expected
.
stride
())
if
check_device
and
actual
.
device
!=
expected
.
device
:
raise_mismatch_error
(
"device"
,
actual
.
device
,
expected
.
device
)
...
...
@@ -171,8 +162,7 @@ def _compare_attributes(
raise_mismatch_error
(
"dtype"
,
actual
.
dtype
,
expected
.
dtype
)
def
_equalize_attributes
(
actual
:
torch
.
Tensor
,
expected
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
def
_equalize_attributes
(
actual
:
torch
.
Tensor
,
expected
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Equalizes some attributes of two tensors for value comparison.
If ``actual`` and ``expected`` are ...
- ... not on the same :attr:`~torch.Tensor.device`, they are moved CPU memory.
...
...
@@ -210,7 +200,7 @@ def _equalize_attributes(actual: torch.Tensor,
if
actual
.
layout
!=
expected
.
layout
:
# These checks are needed, since Tensor.to_dense() fails on tensors that are already strided
actual
=
actual
.
to_dense
()
if
actual
.
layout
!=
torch
.
strided
else
actual
expected
=
(
expected
.
to_dense
()
if
expected
.
layout
!=
torch
.
strided
else
expected
)
expected
=
expected
.
to_dense
()
if
expected
.
layout
!=
torch
.
strided
else
expected
return
actual
,
expected
...
...
@@ -254,12 +244,8 @@ def torch_assert_close(
"""
_compare_attributes
(
tensor_a
,
tensor_b
,
check_device
=
check_device
,
check_dtype
=
check_dtype
,
check_layout
=
check_layout
,
check_stride
=
check_stride
)
tensor_a
,
tensor_b
,
check_device
=
check_device
,
check_dtype
=
check_dtype
,
check_layout
=
check_layout
,
check_stride
=
check_stride
)
tensor_a
,
tensor_b
=
_equalize_attributes
(
tensor_a
,
tensor_b
)
mismatched
=
~
torch
.
isclose
(
tensor_a
,
tensor_b
,
rtol
=
rtol
,
atol
=
atol
,
equal_nan
=
equal_nan
)
...
...
@@ -276,8 +262,7 @@ def torch_assert_close(
# Print debug information about the mismatch
if
verbose
:
print
(
f
"Number of mismatched elements:
{
num_mismatched
}
/
{
total_elements
}
"
f
"(allowed:
{
max_allowed_mismatched
}
)"
)
print
(
f
"Number of mismatched elements:
{
num_mismatched
}
/
{
total_elements
}
(allowed:
{
max_allowed_mismatched
}
)"
)
# If there are mismatched elements, print the first mismatch
if
num_mismatched
>
0
:
...
...
@@ -289,9 +274,9 @@ def torch_assert_close(
b_val
=
tensor_b
.
reshape
(
-
1
)[
flat_idx
].
item
()
abs_diff
=
abs
(
a_val
-
b_val
)
rel_diff
=
abs_diff
/
(
abs
(
b_val
)
+
1e-12
)
mismatch_info
=
(
f
"
\n
First mismatch at index
{
idx
}
: "
f
"lhs=
{
a_val
:.
6
f
}
, rhs=
{
b_val
:.
6
f
}
,
"
f
"abs_diff=
{
abs_diff
:.
6
f
}
, rel_diff=
{
rel_diff
:.
6
f
}
"
)
mismatch_info
=
(
f
"
\n
First mismatch at index
{
idx
}
: lhs=
{
a_val
:.
6
f
}
, rhs=
{
b_val
:.
6
f
}
, abs_diff=
{
abs_diff
:.
6
f
}
, rel_diff=
{
rel_diff
:.
6
f
}
"
)
else
:
mismatch_info
=
""
...
...
@@ -304,6 +289,7 @@ def torch_assert_close(
f
"
\n
Greatest absolute difference:
{
diff
.
max
().
item
()
}
, "
f
"Greatest relative difference:
{
(
diff
/
(
torch
.
abs
(
tensor_b
)
+
1e-12
)).
max
().
item
()
}
"
f
"
\n
{
base_name
}
:
{
tensor_a
}
"
f
"
\n
{
ref_name
}
:
{
tensor_b
}
"
)
f
"
\n
{
ref_name
}
:
{
tensor_b
}
"
)
else
:
return
True
version_provider.py
View file @
29051439
...
...
@@ -8,29 +8,26 @@ from functools import lru_cache
ROOT
=
Path
(
__file__
).
parent
base_version
=
(
ROOT
/
'
VERSION
'
).
read_text
().
strip
()
base_version
=
(
ROOT
/
"
VERSION
"
).
read_text
().
strip
()
# When installing a sdist,
# the installed version needs to match the sdist version,
# so pip will complain when we install `tilelang-0.1.6.post2+gitxxxx.tar.gz`.
# To workaround that, when building sdist,
# we do not add version label and use a file to store the git hash instead.
git_pin
=
ROOT
/
'
.git_commit.txt
'
git_pin
=
ROOT
/
"
.git_commit.txt
"
def
_read_cmake_bool
(
i
:
str
|
None
,
default
=
False
):
if
i
is
None
:
return
default
return
i
.
lower
()
not
in
(
'0'
,
'
false
'
,
'
off
'
,
'
no
'
,
'n'
,
''
)
return
i
.
lower
()
not
in
(
"0"
,
"
false
"
,
"
off
"
,
"
no
"
,
"n"
,
""
)
@
lru_cache
(
maxsize
=
1
)
def
get_git_commit_id
()
->
str
|
None
:
"""Get the current git commit hash by running git in the current file's directory."""
r
=
subprocess
.
run
([
'git'
,
'rev-parse'
,
'HEAD'
],
cwd
=
ROOT
,
capture_output
=
True
,
encoding
=
'utf-8'
)
r
=
subprocess
.
run
([
"git"
,
"rev-parse"
,
"HEAD"
],
cwd
=
ROOT
,
capture_output
=
True
,
encoding
=
"utf-8"
)
if
r
.
returncode
==
0
:
_git
=
r
.
stdout
.
strip
()
git_pin
.
write_text
(
_git
)
...
...
@@ -41,51 +38,48 @@ def get_git_commit_id() -> str | None:
return
None
def
dynamic_metadata
(
field
:
str
,
settings
:
dict
[
str
,
object
]
|
None
=
None
,
)
->
str
:
assert
field
==
'version'
def
dynamic_metadata
(
field
:
str
,
settings
:
dict
[
str
,
object
]
|
None
=
None
)
->
str
:
assert
field
==
"version"
version
=
base_version
# generate git version for sdist
get_git_commit_id
()
if
not
_read_cmake_bool
(
os
.
environ
.
get
(
'
NO_VERSION_LABEL
'
)):
if
not
_read_cmake_bool
(
os
.
environ
.
get
(
"
NO_VERSION_LABEL
"
)):
exts
=
[]
backend
=
None
if
_read_cmake_bool
(
os
.
environ
.
get
(
'
NO_TOOLCHAIN_VERSION
'
)):
if
_read_cmake_bool
(
os
.
environ
.
get
(
"
NO_TOOLCHAIN_VERSION
"
)):
pass
elif
platform
.
system
()
==
'
Darwin
'
:
elif
platform
.
system
()
==
"
Darwin
"
:
# only on macosx_11_0_arm64, not necessary
# backend = 'metal'
pass
elif
_read_cmake_bool
(
os
.
environ
.
get
(
'
USE_ROCM
'
,
''
)):
backend
=
'
rocm
'
elif
'
USE_CUDA
'
in
os
.
environ
and
not
_read_cmake_bool
(
os
.
environ
.
get
(
'
USE_CUDA
'
)):
backend
=
'
cpu
'
elif
_read_cmake_bool
(
os
.
environ
.
get
(
"
USE_ROCM
"
,
""
)):
backend
=
"
rocm
"
elif
"
USE_CUDA
"
in
os
.
environ
and
not
_read_cmake_bool
(
os
.
environ
.
get
(
"
USE_CUDA
"
)):
backend
=
"
cpu
"
else
:
# cuda
# Read nvcc version from env.
# This is not exactly how it should be,
# but works for now if building in a nvidia/cuda image.
if
cuda_version
:
=
os
.
environ
.
get
(
'
CUDA_VERSION
'
):
major
,
minor
,
*
_
=
cuda_version
.
split
(
'.'
)
backend
=
f
'
cu
{
major
}{
minor
}
'
if
cuda_version
:
=
os
.
environ
.
get
(
"
CUDA_VERSION
"
):
major
,
minor
,
*
_
=
cuda_version
.
split
(
"."
)
backend
=
f
"
cu
{
major
}{
minor
}
"
else
:
backend
=
'
cuda
'
backend
=
"
cuda
"
if
backend
:
exts
.
append
(
backend
)
if
_read_cmake_bool
(
os
.
environ
.
get
(
'
NO_GIT_VERSION
'
)):
if
_read_cmake_bool
(
os
.
environ
.
get
(
"
NO_GIT_VERSION
"
)):
pass
elif
git_hash
:
=
get_git_commit_id
():
exts
.
append
(
f
'
git
{
git_hash
[:
8
]
}
'
)
exts
.
append
(
f
"
git
{
git_hash
[:
8
]
}
"
)
else
:
exts
.
append
(
'
gitunknown
'
)
exts
.
append
(
"
gitunknown
"
)
if
exts
:
version
+=
'+'
+
'.'
.
join
(
exts
)
version
+=
"+"
+
"."
.
join
(
exts
)
return
version
...
...
Prev
1
…
20
21
22
23
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment