OpenDAS / tilelang · Commits

Commit 53be59dc (unverified)
Authored Dec 11, 2025 by danielhua23; committed by GitHub on Dec 11, 2025
Parent: 0eb33f28

[AMD] Enable FA2 fwd on AMD MI300X (#1406)

* enable FA2 on AMD MI300X
* make lint happy
Changes: 4 changed files with 44 additions and 9 deletions (+44 −9)
docker/Dockerfile.rocm                       +23 −5
examples/amd/example_amd_flash_attn_fwd.py   +16 −1
src/op/math.cc                                +3 −1
src/target/codegen_hip.cc                     +2 −2
docker/Dockerfile.rocm

@@ -9,21 +9,39 @@ ENV DEBIAN_FRONTEND=noninteractive
 RUN apt-get update && apt-get install -y --no-install-recommends \
     build-essential git wget \
     libgtest-dev libprotobuf-dev protobuf-compiler libgflags-dev libsqlite3-dev llvm-dev \
+    rocm-dev rocm-libs hip-dev hipblas-dev rocblas-dev \
     && apt-get clean autoclean && rm -rf /var/lib/apt/lists/{cache,log} /tmp/* /var/tmp/*

 ENV PATH="/opt/conda/bin:${PATH}"
 ENV LIBGL_ALWAYS_INDIRECT=1
+ENV USE_ROCM=1
+ENV USE_CUDA=0
+ENV ROCM_HOME=/opt/rocm
+ENV HIP_PLATFORM=amd
+ENV PYTORCH_ROCM_ARCH="gfx90a;gfx942"

 RUN conda run -n py_3.10 conda install pip cmake -y && \
     conda run -n py_3.10 conda install -c conda-forge libstdcxx-ng=12 -y && \
     conda clean --all

-RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
+RUN apt-get update && apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev && \
+    apt-get clean autoclean && rm -rf /var/lib/apt/lists/{cache,log} /tmp/* /var/tmp/*

-RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main tilelang && \
-    mv /opt/conda/envs/py_3.10/compiler_compat /opt/conda/envs/py_3.10/compiler_compat.bak || true && \
-    conda run -n py_3.10 bash -c "pip install 'numpy<2.0' --force-reinstall && cd tilelang && USE_ROCM=1 pip install -e . -v"
+# Copy local tilelang directory instead of cloning from git
+# Build from tilelang root: docker build -f docker/Dockerfile.rocm -t mi300:latest .
+COPY . /root/tilelang
+
+RUN mv /opt/conda/envs/py_3.10/compiler_compat /opt/conda/envs/py_3.10/compiler_compat.bak || true && \
+    conda run -n py_3.10 bash -c "export USE_ROCM=1 USE_CUDA=0 && pip install 'numpy<2.0' --force-reinstall" && \
+    conda run -n py_3.10 bash -c "cd /root/tilelang && \
+    # Backup and modify pyproject.toml to remove torch from dependencies \
+    cp pyproject.toml pyproject.toml.bak && \
+    sed -i '/^[[:space:]]*\"torch/d' pyproject.toml && \
+    # Install tilelang with all dependencies except torch \
+    USE_ROCM=1 USE_CUDA=0 pip install -e . -v && \
+    # Restore original pyproject.toml \
+    mv pyproject.toml.bak pyproject.toml"

 RUN conda init bash && \
     echo "conda activate py_3.10" >> /root/.bashrc
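The rebuilt image bakes the ROCm toolchain into the base layers, pins the GPU architectures via PYTORCH_ROCM_ARCH, and installs tilelang from the local checkout with torch stripped from pyproject.toml so pip cannot drag in a CUDA wheel. A minimal sketch for sanity-checking the result inside a running container, assuming a ROCm build of PyTorch was installed separately (the image itself does not ship one):

    # sanity_check.py -- hedged sketch, not part of this commit
    import torch
    import tilelang

    # torch.version.hip is non-None only on a ROCm/HIP build of PyTorch.
    print("HIP runtime:", torch.version.hip)
    # On ROCm, HIP devices are surfaced through torch's CUDA API.
    print("GPU visible:", torch.cuda.is_available())
    print("tilelang at:", tilelang.__file__)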
examples/amd/example_amd_flash_attn_fwd.py

@@ -8,6 +8,21 @@ import argparse
 from functools import partial


+# Custom supply function to ensure tensors are created on GPU
+def supply_tensors_gpu(params):
+    """Supply function that creates tensors on GPU for ROCm/HIP."""
+    tensors = []
+    for param in params:
+        if hasattr(param, 'shape') and hasattr(param, 'dtype'):
+            # Force creation on GPU device
+            shape = [int(s) for s in param.shape]
+            tensor = torch.randn(shape, dtype=param.dtype, device='cuda')
+            tensors.append(tensor)
+        else:
+            tensors.append(param)
+    return tensors
+
+
 def ref_program(Q, K, V, is_causal, groups=1):
     assert Q.size(2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}"

@@ -63,7 +78,7 @@ def get_configs():
     return valid_configs


-@tilelang.autotune(configs=get_configs(), cache_input_tensors=True)
+@tilelang.autotune(configs=get_configs(), cache_input_tensors=True, supply_prog=supply_tensors_gpu)
 @tilelang.jit(out_idx=[3])
 def fast_flashattn(
     batch,
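Here `supply_prog` overrides how the autotuner materializes input tensors during tuning; the default supply path could otherwise hand the kernel host tensors under ROCm, so the helper forces `device='cuda'` (which maps to HIP devices on ROCm builds of PyTorch). A quick standalone check of the helper, using a hypothetical stand-in for the parameter descriptors tilelang passes (anything exposing .shape and .dtype works):

    import torch
    from example_amd_flash_attn_fwd import supply_tensors_gpu  # this file

    class FakeParam:
        # Hypothetical stand-in: only .shape and .dtype matter to the helper.
        def __init__(self, shape, dtype):
            self.shape, self.dtype = shape, dtype

    out = supply_tensors_gpu([FakeParam((2, 8, 64), torch.float16), "not-a-tensor"])
    assert out[0].device.type == 'cuda'  # a HIP device on ROCm
    assert out[1] == "not-a-tensor"      # non-tensor params pass through unchanged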
src/op/math.cc

@@ -33,6 +33,7 @@ TVM_REGISTER_OP("tl.pow_of_int")
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kPure))
     .set_attr<TScriptPrinterName>("TScriptPrinterName", "pow_of_int")
+    .set_attr<FLowerIntrinsic>("hip.FLowerIntrinsic", pow_of_int_op)
     .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", pow_of_int_op);

 PrimExpr infinity_op(PrimExpr args) {

@@ -59,7 +60,8 @@ TVM_REGISTER_OP("tl.infinity")
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kPure))
     .set_attr<TScriptPrinterName>("TScriptPrinterName", "infinity")
-    .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", infinity_op);
+    .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", infinity_op)
+    .set_attr<FLowerIntrinsic>("hip.FLowerIntrinsic", infinity_op);

 } // namespace tl
 } // namespace tvm
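With these two registrations, `tl.pow_of_int` and `tl.infinity` carry a HIP lowering rule alongside the existing CUDA one, so HIP codegen no longer hits an unlowered call. A hedged sketch of how the registration could be checked from Python, assuming a build in which importing tilelang registers the tl.* ops:

    import tvm
    import tilelang  # importing is assumed to register the tl.* ops

    for name in ("tl.infinity", "tl.pow_of_int"):
        op = tvm.ir.Op.get(name)
        # After this commit, both lowering attributes should be present.
        print(name,
              "hip:", op.get_attr("hip.FLowerIntrinsic") is not None,
              "cuda:", op.get_attr("cuda.FLowerIntrinsic") is not None)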
src/target/codegen_hip.cc

@@ -1190,9 +1190,9 @@ inline void PrintConst(const FloatImmNode *op, std::ostream &os,
     if (op->value < 0) {
       temp << "-";
     }
-    temp << ((op->dtype.bits() == 32) ? "HIPRT_INF_F" : "HIPRT_INF");
+    temp << ((op->dtype.bits() == 32) ? "HUGE_VALF" : "HUGE_VAL");
   } else if (std::isnan(op->value)) {
-    temp << ((op->dtype.bits() == 32) ? "HIPRT_NAN_F" : "HIPRT_NAN");
+    temp << ((op->dtype.bits() == 32) ? "NAN" : "NAN");
   } else {
     temp << std::scientific << op->value;
     if (op->dtype.bits() == 32)
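This swaps the HIPRT_* macros, which are not reliably defined in current HIP headers, for the standard math.h constants HUGE_VALF/HUGE_VAL and NAN. The FA2 forward kernel hits this path when it initializes the softmax running max to -infinity; a hedged way to confirm the fix, assuming the example's entry point and tilelang's get_kernel_source() accessor (the argument values below are placeholders, not the example's real defaults):

    from example_amd_flash_attn_fwd import fast_flashattn

    kernel = fast_flashattn(1, 1024, 1024, 8, 128, True, 1)  # placeholder args
    src = kernel.get_kernel_source()
    assert "HUGE_VALF" in src          # new, well-defined constant
    assert "HIPRT_INF_F" not in src    # old, undefined macro is gone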