Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
29051439
Unverified
Commit
29051439
authored
Dec 12, 2025
by
Lei Wang
Committed by
GitHub
Dec 12, 2025
Browse files
[Lint] Phaseout Yapf format and embrace ruff format (#1417)
parent
e84b24bc
Changes
467
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
78 additions
and
120 deletions
+78
-120
tilelang/transform/simplify.py
tilelang/transform/simplify.py
+0
-1
tilelang/utils/deprecated.py
tilelang/utils/deprecated.py
+3
-5
tilelang/utils/language.py
tilelang/utils/language.py
+8
-11
tilelang/utils/sparse.py
tilelang/utils/sparse.py
+17
-32
tilelang/utils/target.py
tilelang/utils/target.py
+2
-3
tilelang/utils/tensor.py
tilelang/utils/tensor.py
+27
-41
version_provider.py
version_provider.py
+21
-27
No files found.
tilelang/transform/simplify.py
View file @
29051439
...
...
@@ -51,7 +51,6 @@ def _Simplify(stmt: PrimFunc | IRModule, inline_let: bool = False) -> PrimFunc |
# Decorator to simplify the output of a function
def
simplify_prim_func
(
func
:
Callable
)
->
Callable
:
def
wrapper
(
*
args
,
**
kwargs
):
stmt
:
PrimFunc
|
IRModule
=
(
func
)(
*
args
,
**
kwargs
)
return
_Simplify
(
stmt
)
...
...
tilelang/utils/deprecated.py
View file @
29051439
def
deprecated_warning
(
method_name
:
str
,
new_method_name
:
str
,
phaseout_version
:
str
=
None
):
"""A function to indicate that a method is deprecated
"""
"""A function to indicate that a method is deprecated"""
import
warnings
# pylint: disable=import-outside-toplevel, import-error
warnings
.
warn
(
f
"
{
method_name
}
is deprecated, use
{
new_method_name
}
instead"
+
(
f
" and will be removed in
{
phaseout_version
}
"
if
phaseout_version
else
""
),
f
"
{
method_name
}
is deprecated, use
{
new_method_name
}
instead"
+
(
f
" and will be removed in
{
phaseout_version
}
"
if
phaseout_version
else
""
),
DeprecationWarning
,
stacklevel
=
2
,
)
...
...
@@ -30,7 +29,6 @@ def deprecated(
import
functools
# pylint: disable=import-outside-toplevel
def
_deprecate
(
func
):
@
functools
.
wraps
(
func
)
def
_wrapper
(
*
args
,
**
kwargs
):
deprecated_warning
(
method_name
,
new_method_name
,
phaseout_version
)
...
...
tilelang/utils/language.py
View file @
29051439
...
...
@@ -24,8 +24,7 @@ def _get_buffer(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion) ->
elif
isinstance
(
buffer_or_load_or_region
,
(
tir
.
BufferLoad
,
tir
.
BufferRegion
)):
return
buffer_or_load_or_region
.
buffer
else
:
raise
TypeError
(
f
"Expected Buffer, BufferLoad, or BufferRegion, got
{
type
(
buffer_or_load_or_region
)
}
"
)
raise
TypeError
(
f
"Expected Buffer, BufferLoad, or BufferRegion, got
{
type
(
buffer_or_load_or_region
)
}
"
)
def
is_global
(
buffer
:
Buffer
|
BufferLoad
|
BufferRegion
)
->
bool
:
...
...
@@ -153,14 +152,12 @@ def retrieve_func_from_module(ir_module: IRModule) -> PrimFunc:
"""
if
not
isinstance
(
ir_module
,
IRModule
):
raise
ValueError
(
"Not supported type: "
,
type
(
ir_module
))
assert
len
(
ir_module
.
get_global_vars
())
==
1
,
(
"The optimized module should only have one global variable for default schedule."
)
assert
len
(
ir_module
.
get_global_vars
())
==
1
,
"The optimized module should only have one global variable for default schedule."
func
=
list
(
ir_module
.
functions
.
values
())[
0
]
return
func
def
get_buffer_region_from_load
(
buffer_load
:
tir
.
BufferLoad
,
extents
:
list
[
PrimExpr
]
|
None
=
None
)
->
tir
.
BufferRegion
|
None
:
def
get_buffer_region_from_load
(
buffer_load
:
tir
.
BufferLoad
,
extents
:
list
[
PrimExpr
]
|
None
=
None
)
->
tir
.
BufferRegion
|
None
:
"""
Get the buffer region from a buffer load.
...
...
@@ -193,9 +190,9 @@ def get_buffer_region_from_load(buffer_load: tir.BufferLoad,
return
None
def
to_buffer_region
(
obj
:
Buffer
|
BufferLoad
|
BufferRegion
|
tir
.
Var
,
access_type
:
str
=
"rw"
,
extents
:
list
[
PrimExpr
]
|
None
=
None
)
->
PrimExpr
|
BufferRegion
:
def
to_buffer_region
(
obj
:
Buffer
|
BufferLoad
|
BufferRegion
|
tir
.
Var
,
access_type
:
str
=
"rw"
,
extents
:
list
[
PrimExpr
]
|
None
=
None
)
->
PrimExpr
|
BufferRegion
:
"""
Convert to/from the tl.region representation.
...
...
@@ -203,6 +200,7 @@ def to_buffer_region(obj: Buffer | BufferLoad | BufferRegion | tir.Var,
- tl.region Call -> returns the decoded BufferRegion for analysis
"""
from
tilelang.language.frame
import
has_let_value
,
get_let_value
if
isinstance
(
obj
,
tir
.
Var
)
and
has_let_value
(
obj
):
obj
=
get_let_value
(
obj
)
# Encode into tl.region call (when extents is provided), otherwise return BufferRegion for analysis
...
...
@@ -279,8 +277,7 @@ def retrieve_stride(obj: Buffer | BufferRegion | BufferLoad) -> list:
return
strides
def
retrive_ptr_from_buffer_region
(
buffer_or_load_or_region
:
Buffer
|
BufferLoad
|
BufferRegion
,
access_type
:
str
=
"r"
)
->
PrimExpr
:
def
retrive_ptr_from_buffer_region
(
buffer_or_load_or_region
:
Buffer
|
BufferLoad
|
BufferRegion
,
access_type
:
str
=
"r"
)
->
PrimExpr
:
if
isinstance
(
buffer_or_load_or_region
,
Buffer
):
return
buffer_or_load_or_region
.
access_ptr
(
access_type
)
elif
isinstance
(
buffer_or_load_or_region
,
BufferLoad
):
...
...
tilelang/utils/sparse.py
View file @
29051439
...
...
@@ -15,7 +15,7 @@ os.makedirs(_CACHE_DIR, exist_ok=True)
def
_get_cached_lib
():
name
=
'
compress_lib
'
name
=
"
compress_lib
"
if
os
.
path
.
exists
(
os
.
path
.
join
(
_CACHE_DIR
,
f
"
{
name
}
.so"
)):
try
:
...
...
@@ -32,24 +32,22 @@ def _get_cached_lib():
name
=
name
,
sources
=
[
compress_util
],
extra_cuda_cflags
=
[
'
-O2
'
,
'
-std=c++17
'
,
'
-lineinfo
'
,
f
'
-I
{
env
.
CUTLASS_INCLUDE_DIR
}
'
,
f
'
-I
{
env
.
CUTLASS_INCLUDE_DIR
}
/../tools/util/include
'
,
'
-arch=sm_90
'
,
"
-O2
"
,
"
-std=c++17
"
,
"
-lineinfo
"
,
f
"
-I
{
env
.
CUTLASS_INCLUDE_DIR
}
"
,
f
"
-I
{
env
.
CUTLASS_INCLUDE_DIR
}
/../tools/util/include
"
,
"
-arch=sm_90
"
,
],
build_directory
=
_CACHE_DIR
,
)
def
compress_sm90
(
A
:
torch
.
Tensor
,
block_k
:
int
,
transposed
:
bool
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
def
compress_sm90
(
A
:
torch
.
Tensor
,
block_k
:
int
,
transposed
:
bool
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
block_k
>
128
:
block_k
=
128
# Ref: https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl#L145-L146
warnings
.
warn
(
f
"block_k
{
block_k
}
is too large, set to 128 for sm90 compression."
,
stacklevel
=
2
)
warnings
.
warn
(
f
"block_k
{
block_k
}
is too large, set to 128 for sm90 compression."
,
stacklevel
=
2
)
# Load the library (will use cache if available)
compress_lib
=
_get_cached_lib
()
...
...
@@ -60,8 +58,9 @@ def compress_sm80(A: torch.Tensor, transposed: bool) -> tuple[torch.Tensor, torc
try
:
from
torch.sparse
import
to_sparse_semi_structured
,
SparseSemiStructuredTensor
except
ImportError
as
err
:
raise
ImportError
(
"SparseSemiStructuredTensor is not available in this version of PyTorch. "
"Please install a compatible version."
)
from
err
raise
ImportError
(
"SparseSemiStructuredTensor is not available in this version of PyTorch. Please install a compatible version."
)
from
err
orig_val
=
SparseSemiStructuredTensor
.
_FORCE_CUTLASS
try
:
SparseSemiStructuredTensor
.
_FORCE_CUTLASS
=
True
...
...
@@ -73,10 +72,7 @@ def compress_sm80(A: torch.Tensor, transposed: bool) -> tuple[torch.Tensor, torc
SparseSemiStructuredTensor
.
_FORCE_CUTLASS
=
orig_val
def
compress
(
A
:
torch
.
Tensor
,
transposed
:
bool
,
arch
:
str
|
None
=
None
,
**
kwargs
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
def
compress
(
A
:
torch
.
Tensor
,
transposed
:
bool
,
arch
:
str
|
None
=
None
,
**
kwargs
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Compress a tensor using the appropriate method based on the CUDA architecture.
"""
...
...
@@ -101,11 +97,10 @@ def compress(A: torch.Tensor,
A_sp
=
A_sp
.
t
().
contiguous
()
return
A_sp
,
E
else
:
raise
ValueError
(
f
"Unsupported CUDA compute version:
{
compute_version
}
. "
"Supported versions are sm_80 and sm_90."
)
raise
ValueError
(
f
"Unsupported CUDA compute version:
{
compute_version
}
. Supported versions are sm_80 and sm_90."
)
def
randn_semi_sparse
(
M
:
int
,
K
:
int
,
dtype
=
torch
.
float16
,
device
=
'
cuda
'
,
transposed
:
bool
=
False
):
def
randn_semi_sparse
(
M
:
int
,
K
:
int
,
dtype
=
torch
.
float16
,
device
=
"
cuda
"
,
transposed
:
bool
=
False
):
"""
Generate a random semi-sparse tensor. The generated tensor will have 2:4 sparsity along the K dimension.
Args:
...
...
@@ -127,13 +122,7 @@ def randn_semi_sparse(M: int, K: int, dtype=torch.float16, device='cuda', transp
return
tensor
.
to
(
dtype
)
# dtype like float8 might not have randn kernel
def
randint_semi_sparse
(
M
:
int
,
K
:
int
,
low
:
int
,
high
:
int
,
dtype
=
torch
.
int32
,
device
=
'cuda'
,
transposed
:
bool
=
False
):
def
randint_semi_sparse
(
M
:
int
,
K
:
int
,
low
:
int
,
high
:
int
,
dtype
=
torch
.
int32
,
device
=
"cuda"
,
transposed
:
bool
=
False
):
"""
Generate a random semi-sparse integer tensor. The generated tensor will have 2:4 sparsity along the K dimension.
Args:
...
...
@@ -157,11 +146,7 @@ def randint_semi_sparse(M: int,
return
tensor
def
arange_semi_sparse
(
M
:
int
,
K
:
int
,
dtype
=
torch
.
float16
,
device
=
'cuda'
,
transposed
:
bool
=
False
):
def
arange_semi_sparse
(
M
:
int
,
K
:
int
,
dtype
=
torch
.
float16
,
device
=
"cuda"
,
transposed
:
bool
=
False
):
"""
Generate a semi-sparse tensor with values from 0 to M*K-1. The generated tensor will have 2:4 sparsity along the K dimension.
Args:
...
...
tilelang/utils/target.py
View file @
29051439
...
...
@@ -56,11 +56,10 @@ def check_metal_availability() -> bool:
if
not
mac_release
:
return
False
# todo: check torch version?
return
arch
==
'
arm64
'
return
arch
==
"
arm64
"
def
determine_target
(
target
:
str
|
Target
|
Literal
[
"auto"
]
=
"auto"
,
return_object
:
bool
=
False
)
->
str
|
Target
:
def
determine_target
(
target
:
str
|
Target
|
Literal
[
"auto"
]
=
"auto"
,
return_object
:
bool
=
False
)
->
str
|
Target
:
"""
Determine the appropriate target for compilation (CUDA, HIP, or manual selection).
...
...
tilelang/utils/tensor.py
View file @
29051439
"""The profiler and convert to torch utils"""
from
enum
import
Enum
import
torch
from
tvm
import
tir
...
...
@@ -17,7 +18,7 @@ def is_float8_dtype(dtype: torch.dtype) -> bool:
def
fp8_remove_negative_zeros_
(
tensor
:
torch
.
Tensor
):
assert
is_float8_dtype
(
tensor
.
dtype
),
"Input tensor must be of float8 dtype"
bits
=
tensor
.
view
(
torch
.
uint8
)
zeros_mask
=
(
tensor
==
0
)
zeros_mask
=
tensor
==
0
bits
[
zeros_mask
]
=
0x00
...
...
@@ -33,26 +34,21 @@ class TensorSupplyType(Enum):
def
map_torch_type
(
intype
:
str
)
->
torch
.
dtype
:
if
intype
==
"float8_e4m3"
:
assert
hasattr
(
torch
,
"float8_e4m3fn"
),
\
"torch.float8_e4m3fn is not supported in this version of torch"
\
"Please upgrade torch >= 2.1.0"
assert
hasattr
(
torch
,
"float8_e4m3fn"
),
"torch.float8_e4m3fn is not supported in this version of torchPlease upgrade torch >= 2.1.0"
return
torch
.
float8_e4m3fn
elif
intype
==
"float8_e5m2"
:
assert
hasattr
(
torch
,
"float8_e5m2"
),
\
"torch.float8_e5m2 is not supported in this version of torch"
\
"Please upgrade torch >= 2.1.0"
assert
hasattr
(
torch
,
"float8_e5m2"
),
"torch.float8_e5m2 is not supported in this version of torchPlease upgrade torch >= 2.1.0"
return
torch
.
float8_e5m2
elif
intype
==
"e4m3fnuz_float8"
:
assert
hasattr
(
torch
,
"float8_e4m3fnuz"
),
\
"torch.float8_e4m3fnuz is not supported in this version of torch
"
\
"Please upgrade torch >= 2.2.0"
assert
hasattr
(
torch
,
"float8_e4m3fnuz"
),
(
"torch.float8_e4m3fnuz is not supported in this version of torch
Please upgrade torch >= 2.2.0"
)
return
torch
.
float8_e4m3fnuz
else
:
return
getattr
(
torch
,
intype
)
def
get_tensor_supply
(
supply_type
:
TensorSupplyType
=
TensorSupplyType
.
Integer
):
from
tilelang.engine.param
import
KernelParam
from
.device
import
get_current_device
...
...
@@ -63,7 +59,8 @@ def get_tensor_supply(supply_type: TensorSupplyType = TensorSupplyType.Integer):
if
hasattr
(
param
,
"shape"
)
and
not
param
.
shape
:
raise
ValueError
(
f
"TensorType must have a shape, but got
{
type
(
param
)
}
, "
"likely you are trying to generate a random tensor with a dynamic symbolic shape."
)
"likely you are trying to generate a random tensor with a dynamic symbolic shape."
)
# Check if with dynamic symbolic shape
for
shape
in
param
.
shape
:
...
...
@@ -81,8 +78,7 @@ def get_tensor_supply(supply_type: TensorSupplyType = TensorSupplyType.Integer):
if
is_unsigned
:
return
torch
.
randint
(
low
=
0
,
high
=
3
,
size
=
shape
,
device
=
device
,
dtype
=
dtype
)
elif
is_float8
:
return
torch
.
randint
(
low
=-
128
,
high
=
128
,
size
=
shape
,
device
=
device
,
dtype
=
torch
.
int8
).
to
(
dtype
)
return
torch
.
randint
(
low
=-
128
,
high
=
128
,
size
=
shape
,
device
=
device
,
dtype
=
torch
.
int8
).
to
(
dtype
)
elif
is_boolean
:
return
torch
.
randint
(
low
=
0
,
high
=
2
,
size
=
shape
,
device
=
device
,
dtype
=
dtype
)
elif
dtype
in
{
torch
.
float16
,
torch
.
float32
,
torch
.
bfloat16
}:
...
...
@@ -103,18 +99,15 @@ def get_tensor_supply(supply_type: TensorSupplyType = TensorSupplyType.Integer):
if
is_unsigned
:
return
torch
.
randint
(
low
=
0
,
high
=
3
,
size
=
shape
,
device
=
device
,
dtype
=
dtype
)
elif
is_float8
:
return
torch
.
randint
(
low
=-
128
,
high
=
128
,
size
=
shape
,
device
=
device
,
dtype
=
torch
.
int8
).
to
(
dtype
)
return
torch
.
randint
(
low
=-
128
,
high
=
128
,
size
=
shape
,
device
=
device
,
dtype
=
torch
.
int8
).
to
(
dtype
)
elif
is_boolean
:
return
torch
.
randint
(
low
=
0
,
high
=
2
,
size
=
shape
,
device
=
device
,
dtype
=
dtype
)
else
:
return
torch
.
randint
(
low
=-
2
,
high
=
3
,
size
=
shape
,
device
=
device
,
dtype
=
dtype
)
elif
supply_type
==
TensorSupplyType
.
Uniform
:
return
torch
.
empty
(
*
shape
,
device
=
device
,
dtype
=
torch
.
float32
).
uniform_
(
-
1.0
,
1.0
).
to
(
dtype
)
return
torch
.
empty
(
*
shape
,
device
=
device
,
dtype
=
torch
.
float32
).
uniform_
(
-
1.0
,
1.0
).
to
(
dtype
)
elif
supply_type
==
TensorSupplyType
.
Normal
:
return
torch
.
empty
(
*
shape
,
device
=
device
,
dtype
=
torch
.
float32
).
normal_
(
-
1.0
,
1.0
).
to
(
dtype
)
return
torch
.
empty
(
*
shape
,
device
=
device
,
dtype
=
torch
.
float32
).
normal_
(
-
1.0
,
1.0
).
to
(
dtype
)
elif
supply_type
==
TensorSupplyType
.
Randn
:
return
torch
.
randn
(
*
shape
,
device
=
device
).
to
(
dtype
)
elif
supply_type
==
TensorSupplyType
.
Zero
:
...
...
@@ -150,9 +143,7 @@ def _compare_attributes(
"""
def
raise_mismatch_error
(
attribute_name
:
str
,
actual_value
,
expected_value
):
raise
AssertionError
(
f
"The values for attribute '
{
attribute_name
}
' do not match:
{
actual_value
}
!=
{
expected_value
}
."
)
raise
AssertionError
(
f
"The values for attribute '
{
attribute_name
}
' do not match:
{
actual_value
}
!=
{
expected_value
}
."
)
if
actual
.
shape
!=
expected
.
shape
:
raise_mismatch_error
(
"shape"
,
actual
.
shape
,
expected
.
shape
)
...
...
@@ -163,7 +154,7 @@ def _compare_attributes(
if
actual
.
layout
!=
expected
.
layout
:
if
check_layout
:
raise_mismatch_error
(
"layout"
,
actual
.
layout
,
expected
.
layout
)
elif
(
actual
.
layout
==
torch
.
strided
and
check_stride
and
actual
.
stride
()
!=
expected
.
stride
()
)
:
elif
actual
.
layout
==
torch
.
strided
and
check_stride
and
actual
.
stride
()
!=
expected
.
stride
():
raise_mismatch_error
(
"stride()"
,
actual
.
stride
(),
expected
.
stride
())
if
check_device
and
actual
.
device
!=
expected
.
device
:
raise_mismatch_error
(
"device"
,
actual
.
device
,
expected
.
device
)
...
...
@@ -171,8 +162,7 @@ def _compare_attributes(
raise_mismatch_error
(
"dtype"
,
actual
.
dtype
,
expected
.
dtype
)
def
_equalize_attributes
(
actual
:
torch
.
Tensor
,
expected
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
def
_equalize_attributes
(
actual
:
torch
.
Tensor
,
expected
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Equalizes some attributes of two tensors for value comparison.
If ``actual`` and ``expected`` are ...
- ... not on the same :attr:`~torch.Tensor.device`, they are moved CPU memory.
...
...
@@ -210,7 +200,7 @@ def _equalize_attributes(actual: torch.Tensor,
if
actual
.
layout
!=
expected
.
layout
:
# These checks are needed, since Tensor.to_dense() fails on tensors that are already strided
actual
=
actual
.
to_dense
()
if
actual
.
layout
!=
torch
.
strided
else
actual
expected
=
(
expected
.
to_dense
()
if
expected
.
layout
!=
torch
.
strided
else
expected
)
expected
=
expected
.
to_dense
()
if
expected
.
layout
!=
torch
.
strided
else
expected
return
actual
,
expected
...
...
@@ -254,12 +244,8 @@ def torch_assert_close(
"""
_compare_attributes
(
tensor_a
,
tensor_b
,
check_device
=
check_device
,
check_dtype
=
check_dtype
,
check_layout
=
check_layout
,
check_stride
=
check_stride
)
tensor_a
,
tensor_b
,
check_device
=
check_device
,
check_dtype
=
check_dtype
,
check_layout
=
check_layout
,
check_stride
=
check_stride
)
tensor_a
,
tensor_b
=
_equalize_attributes
(
tensor_a
,
tensor_b
)
mismatched
=
~
torch
.
isclose
(
tensor_a
,
tensor_b
,
rtol
=
rtol
,
atol
=
atol
,
equal_nan
=
equal_nan
)
...
...
@@ -276,8 +262,7 @@ def torch_assert_close(
# Print debug information about the mismatch
if
verbose
:
print
(
f
"Number of mismatched elements:
{
num_mismatched
}
/
{
total_elements
}
"
f
"(allowed:
{
max_allowed_mismatched
}
)"
)
print
(
f
"Number of mismatched elements:
{
num_mismatched
}
/
{
total_elements
}
(allowed:
{
max_allowed_mismatched
}
)"
)
# If there are mismatched elements, print the first mismatch
if
num_mismatched
>
0
:
...
...
@@ -289,9 +274,9 @@ def torch_assert_close(
b_val
=
tensor_b
.
reshape
(
-
1
)[
flat_idx
].
item
()
abs_diff
=
abs
(
a_val
-
b_val
)
rel_diff
=
abs_diff
/
(
abs
(
b_val
)
+
1e-12
)
mismatch_info
=
(
f
"
\n
First mismatch at index
{
idx
}
: "
f
"lhs=
{
a_val
:.
6
f
}
, rhs=
{
b_val
:.
6
f
}
,
"
f
"abs_diff=
{
abs_diff
:.
6
f
}
, rel_diff=
{
rel_diff
:.
6
f
}
"
)
mismatch_info
=
(
f
"
\n
First mismatch at index
{
idx
}
: lhs=
{
a_val
:.
6
f
}
, rhs=
{
b_val
:.
6
f
}
, abs_diff=
{
abs_diff
:.
6
f
}
, rel_diff=
{
rel_diff
:.
6
f
}
"
)
else
:
mismatch_info
=
""
...
...
@@ -304,6 +289,7 @@ def torch_assert_close(
f
"
\n
Greatest absolute difference:
{
diff
.
max
().
item
()
}
, "
f
"Greatest relative difference:
{
(
diff
/
(
torch
.
abs
(
tensor_b
)
+
1e-12
)).
max
().
item
()
}
"
f
"
\n
{
base_name
}
:
{
tensor_a
}
"
f
"
\n
{
ref_name
}
:
{
tensor_b
}
"
)
f
"
\n
{
ref_name
}
:
{
tensor_b
}
"
)
else
:
return
True
version_provider.py
View file @
29051439
...
...
@@ -8,29 +8,26 @@ from functools import lru_cache
ROOT
=
Path
(
__file__
).
parent
base_version
=
(
ROOT
/
'
VERSION
'
).
read_text
().
strip
()
base_version
=
(
ROOT
/
"
VERSION
"
).
read_text
().
strip
()
# When installing a sdist,
# the installed version needs to match the sdist version,
# so pip will complain when we install `tilelang-0.1.6.post2+gitxxxx.tar.gz`.
# To workaround that, when building sdist,
# we do not add version label and use a file to store the git hash instead.
git_pin
=
ROOT
/
'
.git_commit.txt
'
git_pin
=
ROOT
/
"
.git_commit.txt
"
def
_read_cmake_bool
(
i
:
str
|
None
,
default
=
False
):
if
i
is
None
:
return
default
return
i
.
lower
()
not
in
(
'0'
,
'
false
'
,
'
off
'
,
'
no
'
,
'n'
,
''
)
return
i
.
lower
()
not
in
(
"0"
,
"
false
"
,
"
off
"
,
"
no
"
,
"n"
,
""
)
@
lru_cache
(
maxsize
=
1
)
def
get_git_commit_id
()
->
str
|
None
:
"""Get the current git commit hash by running git in the current file's directory."""
r
=
subprocess
.
run
([
'git'
,
'rev-parse'
,
'HEAD'
],
cwd
=
ROOT
,
capture_output
=
True
,
encoding
=
'utf-8'
)
r
=
subprocess
.
run
([
"git"
,
"rev-parse"
,
"HEAD"
],
cwd
=
ROOT
,
capture_output
=
True
,
encoding
=
"utf-8"
)
if
r
.
returncode
==
0
:
_git
=
r
.
stdout
.
strip
()
git_pin
.
write_text
(
_git
)
...
...
@@ -41,51 +38,48 @@ def get_git_commit_id() -> str | None:
return
None
def
dynamic_metadata
(
field
:
str
,
settings
:
dict
[
str
,
object
]
|
None
=
None
,
)
->
str
:
assert
field
==
'version'
def
dynamic_metadata
(
field
:
str
,
settings
:
dict
[
str
,
object
]
|
None
=
None
)
->
str
:
assert
field
==
"version"
version
=
base_version
# generate git version for sdist
get_git_commit_id
()
if
not
_read_cmake_bool
(
os
.
environ
.
get
(
'
NO_VERSION_LABEL
'
)):
if
not
_read_cmake_bool
(
os
.
environ
.
get
(
"
NO_VERSION_LABEL
"
)):
exts
=
[]
backend
=
None
if
_read_cmake_bool
(
os
.
environ
.
get
(
'
NO_TOOLCHAIN_VERSION
'
)):
if
_read_cmake_bool
(
os
.
environ
.
get
(
"
NO_TOOLCHAIN_VERSION
"
)):
pass
elif
platform
.
system
()
==
'
Darwin
'
:
elif
platform
.
system
()
==
"
Darwin
"
:
# only on macosx_11_0_arm64, not necessary
# backend = 'metal'
pass
elif
_read_cmake_bool
(
os
.
environ
.
get
(
'
USE_ROCM
'
,
''
)):
backend
=
'
rocm
'
elif
'
USE_CUDA
'
in
os
.
environ
and
not
_read_cmake_bool
(
os
.
environ
.
get
(
'
USE_CUDA
'
)):
backend
=
'
cpu
'
elif
_read_cmake_bool
(
os
.
environ
.
get
(
"
USE_ROCM
"
,
""
)):
backend
=
"
rocm
"
elif
"
USE_CUDA
"
in
os
.
environ
and
not
_read_cmake_bool
(
os
.
environ
.
get
(
"
USE_CUDA
"
)):
backend
=
"
cpu
"
else
:
# cuda
# Read nvcc version from env.
# This is not exactly how it should be,
# but works for now if building in a nvidia/cuda image.
if
cuda_version
:
=
os
.
environ
.
get
(
'
CUDA_VERSION
'
):
major
,
minor
,
*
_
=
cuda_version
.
split
(
'.'
)
backend
=
f
'
cu
{
major
}{
minor
}
'
if
cuda_version
:
=
os
.
environ
.
get
(
"
CUDA_VERSION
"
):
major
,
minor
,
*
_
=
cuda_version
.
split
(
"."
)
backend
=
f
"
cu
{
major
}{
minor
}
"
else
:
backend
=
'
cuda
'
backend
=
"
cuda
"
if
backend
:
exts
.
append
(
backend
)
if
_read_cmake_bool
(
os
.
environ
.
get
(
'
NO_GIT_VERSION
'
)):
if
_read_cmake_bool
(
os
.
environ
.
get
(
"
NO_GIT_VERSION
"
)):
pass
elif
git_hash
:
=
get_git_commit_id
():
exts
.
append
(
f
'
git
{
git_hash
[:
8
]
}
'
)
exts
.
append
(
f
"
git
{
git_hash
[:
8
]
}
"
)
else
:
exts
.
append
(
'
gitunknown
'
)
exts
.
append
(
"
gitunknown
"
)
if
exts
:
version
+=
'+'
+
'.'
.
join
(
exts
)
version
+=
"+"
+
"."
.
join
(
exts
)
return
version
...
...
Prev
1
…
20
21
22
23
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment