Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
b10d49b2
Unverified
Commit
b10d49b2
authored
Nov 13, 2025
by
Jiaxing Ding
Committed by
GitHub
Nov 13, 2025
Browse files
[AMD] enable amd ci test & fix bug & fix dockerfile (#1244)
parent
468b1b70
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
52 additions
and
49 deletions
+52
-49
.github/workflows/ci.yml
.github/workflows/ci.yml
+1
-1
docker/Dockerfile.cu118
docker/Dockerfile.cu118
+1
-1
docker/Dockerfile.cu120
docker/Dockerfile.cu120
+1
-1
docker/Dockerfile.cu121
docker/Dockerfile.cu121
+1
-1
docker/Dockerfile.cu123
docker/Dockerfile.cu123
+1
-1
docker/Dockerfile.cu124
docker/Dockerfile.cu124
+1
-1
docker/Dockerfile.cu125
docker/Dockerfile.cu125
+1
-1
docker/Dockerfile.cu126
docker/Dockerfile.cu126
+1
-1
docker/Dockerfile.cu128
docker/Dockerfile.cu128
+1
-1
docker/Dockerfile.rocm
docker/Dockerfile.rocm
+1
-1
testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py
testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py
+3
-9
testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py
testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py
+21
-3
tilelang/intrinsics/mfma_macro_generator.py
tilelang/intrinsics/mfma_macro_generator.py
+18
-27
No files found.
.github/workflows/ci.yml
View file @
b10d49b2
...
@@ -379,7 +379,7 @@ jobs:
...
@@ -379,7 +379,7 @@ jobs:
pytest --verbose --color=yes --durations=0 --showlocals --cache-clear
pytest --verbose --color=yes --durations=0 --showlocals --cache-clear
)
)
"${PYTEST[@]}" --maxfail=3 --numprocesses=4 \
"${PYTEST[@]}" --maxfail=3 --numprocesses=4 \
./python/amd
/test_tilelang_test_amd.py
./python/amd
# Apple Metal tests
# Apple Metal tests
-
name
:
Run Metal tests with Python ${{ matrix.python-version }} (${{ matrix.runner.toolkit }})
-
name
:
Run Metal tests with Python ${{ matrix.python-version }} (${{ matrix.runner.toolkit }})
...
...
docker/Dockerfile.cu118
View file @
b10d49b2
...
@@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c
...
@@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
&& cd TileLang &&
./install_cuda.sh
&& cd TileLang &&
USE_CUDA=1 pip install -e . -v
CMD bash
CMD bash
docker/Dockerfile.cu120
View file @
b10d49b2
...
@@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c
...
@@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
&& cd TileLang &&
./install_cuda.sh
&& cd TileLang &&
USE_CUDA=1 pip install -e . -v
CMD bash
CMD bash
docker/Dockerfile.cu121
View file @
b10d49b2
...
@@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c
...
@@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
&& cd TileLang &&
./install_cuda.sh
&& cd TileLang &&
USE_CUDA=1 pip install -e . -v
CMD bash
CMD bash
docker/Dockerfile.cu123
View file @
b10d49b2
...
@@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c
...
@@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
&& cd TileLang &&
./install_cuda.sh
&& cd TileLang &&
USE_CUDA=1 pip install -e . -v
CMD bash
CMD bash
docker/Dockerfile.cu124
View file @
b10d49b2
...
@@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c
...
@@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
&& cd TileLang &&
./install_cuda.sh
&& cd TileLang &&
USE_CUDA=1 pip install -e . -v
CMD bash
CMD bash
docker/Dockerfile.cu125
View file @
b10d49b2
...
@@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c
...
@@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
&& cd TileLang &&
./install_cuda.sh
&& cd TileLang &&
USE_CUDA=1 pip install -e . -v
CMD bash
CMD bash
docker/Dockerfile.cu126
View file @
b10d49b2
...
@@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c
...
@@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
&& cd TileLang &&
./install_cuda.sh
&& cd TileLang &&
USE_CUDA=1 pip install -e . -v
CMD bash
CMD bash
docker/Dockerfile.cu128
View file @
b10d49b2
...
@@ -26,6 +26,6 @@ RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev z
...
@@ -26,6 +26,6 @@ RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev z
RUN pip install cython
RUN pip install cython
RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
&& cd TileLang &&
cmake -S . -B build -DUSE_CUDA=ON && cmake --build build
-
j
&& cd TileLang &&
USE_CUDA=1 pip install -e .
-
v
CMD bash
CMD bash
docker/Dockerfile.rocm
View file @
b10d49b2
...
@@ -22,7 +22,7 @@ RUN conda run -n py_3.10 conda install pip cmake -y && \
...
@@ -22,7 +22,7 @@ RUN conda run -n py_3.10 conda install pip cmake -y && \
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main tilelang && \
RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main tilelang && \
conda run -n py_3.10 bash -c "cd tilelang &&
./install_rocm.sh
"
conda run -n py_3.10 bash -c "cd tilelang &&
USE_ROCM=1 pip install -e . -v
"
RUN conda init bash
RUN conda init bash
...
...
testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py
View file @
b10d49b2
...
@@ -22,15 +22,6 @@ def tl_matmul(
...
@@ -22,15 +22,6 @@ def tl_matmul(
b_transposed
=
True
,
b_transposed
=
True
,
k_pack
=
1
,
k_pack
=
1
,
):
):
assert
in_dtype
in
[
"float16"
,
"int8"
,
],
"Currently only float16 and int8 are supported"
assert
out_dtype
in
[
"float16"
,
"float32"
,
"int32"
,
],
"Currently only float16, float32 and int32 are supported"
micro_size_x
=
micro_size_y
=
micro_size_k
=
16
micro_size_x
=
micro_size_y
=
micro_size_k
=
16
...
@@ -190,6 +181,9 @@ def assert_tl_matmul_correctness(M,
...
@@ -190,6 +181,9 @@ def assert_tl_matmul_correctness(M,
if
in_dtype
==
"int8"
:
if
in_dtype
==
"int8"
:
A
=
torch
.
randint
(
-
128
,
127
,
A_shape
,
device
=
"cuda"
,
dtype
=
torch
.
int8
)
A
=
torch
.
randint
(
-
128
,
127
,
A_shape
,
device
=
"cuda"
,
dtype
=
torch
.
int8
)
B
=
torch
.
randint
(
-
128
,
127
,
B_shape
,
device
=
"cuda"
,
dtype
=
torch
.
int8
)
B
=
torch
.
randint
(
-
128
,
127
,
B_shape
,
device
=
"cuda"
,
dtype
=
torch
.
int8
)
elif
in_dtype
==
"float8_e4m3fnuz"
:
A
=
torch
.
rand
(
A_shape
,
device
=
"cuda"
,
dtype
=
torch
.
float16
).
to
(
getattr
(
torch
,
in_dtype
))
B
=
torch
.
rand
(
B_shape
,
device
=
"cuda"
,
dtype
=
torch
.
float16
).
to
(
getattr
(
torch
,
in_dtype
))
else
:
else
:
A
=
torch
.
rand
(
A_shape
,
device
=
"cuda"
,
dtype
=
getattr
(
torch
,
in_dtype
))
A
=
torch
.
rand
(
A_shape
,
device
=
"cuda"
,
dtype
=
getattr
(
torch
,
in_dtype
))
B
=
torch
.
rand
(
B_shape
,
device
=
"cuda"
,
dtype
=
getattr
(
torch
,
in_dtype
))
B
=
torch
.
rand
(
B_shape
,
device
=
"cuda"
,
dtype
=
getattr
(
torch
,
in_dtype
))
...
...
testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py
View file @
b10d49b2
...
@@ -217,6 +217,9 @@ def assert_tl_matmul_correctness(M,
...
@@ -217,6 +217,9 @@ def assert_tl_matmul_correctness(M,
if
in_dtype
==
"int8"
:
if
in_dtype
==
"int8"
:
A
=
torch
.
randint
(
-
128
,
127
,
A_shape
,
device
=
"cuda"
,
dtype
=
torch
.
int8
)
A
=
torch
.
randint
(
-
128
,
127
,
A_shape
,
device
=
"cuda"
,
dtype
=
torch
.
int8
)
B
=
torch
.
randint
(
-
128
,
127
,
B_shape
,
device
=
"cuda"
,
dtype
=
torch
.
int8
)
B
=
torch
.
randint
(
-
128
,
127
,
B_shape
,
device
=
"cuda"
,
dtype
=
torch
.
int8
)
elif
in_dtype
==
"float8_e4m3fnuz"
:
A
=
torch
.
rand
(
A_shape
,
device
=
"cuda"
,
dtype
=
torch
.
float16
).
to
(
getattr
(
torch
,
in_dtype
))
B
=
torch
.
rand
(
B_shape
,
device
=
"cuda"
,
dtype
=
torch
.
float16
).
to
(
getattr
(
torch
,
in_dtype
))
else
:
else
:
A
=
torch
.
rand
(
A_shape
,
device
=
"cuda"
,
dtype
=
getattr
(
torch
,
in_dtype
))
A
=
torch
.
rand
(
A_shape
,
device
=
"cuda"
,
dtype
=
getattr
(
torch
,
in_dtype
))
B
=
torch
.
rand
(
B_shape
,
device
=
"cuda"
,
dtype
=
getattr
(
torch
,
in_dtype
))
B
=
torch
.
rand
(
B_shape
,
device
=
"cuda"
,
dtype
=
getattr
(
torch
,
in_dtype
))
...
@@ -264,11 +267,11 @@ def assert_tl_matmul_correctness(M,
...
@@ -264,11 +267,11 @@ def assert_tl_matmul_correctness(M,
@
tilelang
.
testing
.
requires_rocm
@
tilelang
.
testing
.
requires_rocm
def
test_assert_tl_matmul
():
def
test_assert_tl_matmul
():
assert_tl_matmul_correctness
(
assert_tl_matmul_correctness
(
256
,
256
,
256
,
"int8"
,
"int32"
,
accum_dtype
=
"int32"
,
b_preshuffle
=
True
)
256
,
256
,
512
,
"int8"
,
"int32"
,
accum_dtype
=
"int32"
,
b_preshuffle
=
True
)
assert_tl_matmul_correctness
(
assert_tl_matmul_correctness
(
256
,
256
,
256
,
"int8"
,
"int32"
,
accum_dtype
=
"int32"
,
b_preshuffle
=
True
)
256
,
256
,
512
,
"int8"
,
"int32"
,
accum_dtype
=
"int32"
,
b_preshuffle
=
True
)
assert_tl_matmul_correctness
(
assert_tl_matmul_correctness
(
256
,
256
,
256
,
"int8"
,
"int32"
,
b_transposed
=
False
,
accum_dtype
=
"int32"
,
b_preshuffle
=
True
)
256
,
256
,
512
,
"int8"
,
"int32"
,
b_transposed
=
False
,
accum_dtype
=
"int32"
,
b_preshuffle
=
True
)
assert_tl_matmul_correctness
(
assert_tl_matmul_correctness
(
256
,
256
,
512
,
"int8"
,
"int32"
,
accum_dtype
=
"int32"
,
k_pack
=
2
,
b_preshuffle
=
True
)
256
,
256
,
512
,
"int8"
,
"int32"
,
accum_dtype
=
"int32"
,
k_pack
=
2
,
b_preshuffle
=
True
)
...
@@ -283,6 +286,21 @@ def test_assert_tl_matmul():
...
@@ -283,6 +286,21 @@ def test_assert_tl_matmul():
k_pack
=
2
,
k_pack
=
2
,
b_preshuffle
=
True
)
b_preshuffle
=
True
)
assert_tl_matmul_correctness
(
256
,
256
,
512
,
"float8_e4m3fnuz"
,
"float32"
,
b_preshuffle
=
True
)
assert_tl_matmul_correctness
(
256
,
256
,
512
,
"float8_e4m3fnuz"
,
"float32"
,
b_transposed
=
False
,
b_preshuffle
=
True
)
assert_tl_matmul_correctness
(
256
,
256
,
512
,
"float8_e4m3fnuz"
,
"float32"
,
k_pack
=
2
,
b_preshuffle
=
True
)
assert_tl_matmul_correctness
(
256
,
256
,
512
,
"float8_e4m3fnuz"
,
"float32"
,
k_pack
=
2
,
b_transposed
=
False
,
b_preshuffle
=
True
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
tilelang
.
testing
.
main
()
tilelang
.
testing
.
main
()
tilelang/intrinsics/mfma_macro_generator.py
View file @
b10d49b2
...
@@ -374,8 +374,6 @@ class MatrixCoreIntrinEmitter:
...
@@ -374,8 +374,6 @@ class MatrixCoreIntrinEmitter:
a_local_stride
:
PrimExpr
=
k_inner
*
warp_rows
*
local_size_a
if
a_is_fragment
else
0
a_local_stride
:
PrimExpr
=
k_inner
*
warp_rows
*
local_size_a
if
a_is_fragment
else
0
b_local_stride
:
PrimExpr
=
k_inner
*
warp_cols
*
local_size_b
if
b_is_fragment
else
0
b_local_stride
:
PrimExpr
=
k_inner
*
warp_cols
*
local_size_b
if
b_is_fragment
else
0
print
(
a_local_stride
,
b_local_stride
)
@
T
.
macro
@
T
.
macro
def
_warp_mfma
(
A_local_buf
,
B_local_buf
,
C_local_buf
):
def
_warp_mfma
(
A_local_buf
,
B_local_buf
,
C_local_buf
):
for
kp
,
i
,
j
in
T
.
grid
(
k_pack
,
warp_rows
,
warp_cols
):
for
kp
,
i
,
j
in
T
.
grid
(
k_pack
,
warp_rows
,
warp_cols
):
...
@@ -678,34 +676,27 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
...
@@ -678,34 +676,27 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
is_m_first
:
bool
|
None
=
False
,
is_m_first
:
bool
|
None
=
False
,
a_preshuffle
:
bool
|
None
=
False
,
a_preshuffle
:
bool
|
None
=
False
,
b_preshuffle
:
bool
|
None
=
False
,
b_preshuffle
:
bool
|
None
=
False
,
thread_var
:
Var
|
None
=
None
,
):
):
super
().
__init__
(
self
.
a_dtype
=
a_dtype
a_dtype
=
a_dtype
,
self
.
b_dtype
=
b_dtype
b_dtype
=
b_dtype
,
self
.
accum_dtype
=
accum_dtype
accum_dtype
=
accum_dtype
,
self
.
a_transposed
=
a_transposed
a_transposed
=
a_transposed
,
self
.
b_transposed
=
b_transposed
b_transposed
=
b_transposed
,
# Hint Information
block_row_warps
=
block_row_warps
,
self
.
block_row_warps
=
block_row_warps
block_col_warps
=
block_col_warps
,
self
.
block_col_warps
=
block_col_warps
warp_row_tiles
=
warp_row_tiles
,
self
.
warp_row_tiles
=
warp_row_tiles
warp_col_tiles
=
warp_col_tiles
,
self
.
warp_col_tiles
=
warp_col_tiles
chunk
=
chunk
,
self
.
chunk
=
chunk
reduce_k
=
reduce_k
,
self
.
_initialize_k_dim
(
a_dtype
)
num_elems_per_byte
=
num_elems_per_byte
,
self
.
_initialize_abbrev
(
a_dtype
,
b_dtype
,
accum_dtype
)
k_pack
=
k_pack
,
self
.
_initialize_local_size
(
self
.
M_DIM
,
self
.
N_DIM
,
self
.
k_dim
,
self
.
WARP_SIZE
)
is_m_first
=
is_m_first
,
self
.
_initialize_mfma_prefix
(
self
.
k_dim
)
thread_var
=
thread_var
,
self
.
_initialize_micro_size
(
self
.
M_DIM
,
self
.
N_DIM
,
self
.
k_dim
)
)
self
.
_initialize_k_pack
(
k_pack
)
self
.
_initialize_is_m_first
(
is_m_first
)
self
.
_initialize_preshuffle
(
a_preshuffle
,
b_preshuffle
)
self
.
_initialize_preshuffle
(
a_preshuffle
,
b_preshuffle
)
self
.
warp_rows
=
warp_row_tiles
//
self
.
micro_size_x
self
.
warp_cols
=
warp_col_tiles
//
self
.
micro_size_y
self
.
reduce_k
=
reduce_k
self
.
threads
=
(
self
.
WARP_SIZE
*
(
block_row_warps
*
block_col_warps
)
*
reduce_k
)
self
.
num_elems_per_byte
=
num_elems_per_byte
def
_initialize_preshuffle
(
self
,
a_preshuffle
:
bool
,
b_preshuffle
:
bool
):
def
_initialize_preshuffle
(
self
,
a_preshuffle
:
bool
,
b_preshuffle
:
bool
):
if
a_preshuffle
is
not
None
:
if
a_preshuffle
is
not
None
:
self
.
a_preshuffle
=
a_preshuffle
self
.
a_preshuffle
=
a_preshuffle
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment