"tests/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "5588725e8e7be497839432e5328c596169385f16"
Unverified Commit 53be59dc authored by danielhua23, committed by GitHub

[AMD] Enable FA2 fwd on AMD MI300X (#1406)

* enable FA2 on AMD MI300X

* make lint happy
parent 0eb33f28
docker/Dockerfile.rocm:
@@ -9,25 +9,43 @@ ENV DEBIAN_FRONTEND=noninteractive
 RUN apt-get update && apt-get install -y --no-install-recommends \
     build-essential git wget \
     libgtest-dev libprotobuf-dev protobuf-compiler libgflags-dev libsqlite3-dev llvm-dev \
+    rocm-dev rocm-libs hip-dev hipblas-dev rocblas-dev \
     && apt-get clean autoclean && rm -rf /var/lib/apt/lists/{cache,log} /tmp/* /var/tmp/*
 
 ENV PATH="/opt/conda/bin:${PATH}"
 ENV LIBGL_ALWAYS_INDIRECT=1
+ENV USE_ROCM=1
+ENV USE_CUDA=0
+ENV ROCM_HOME=/opt/rocm
+ENV HIP_PLATFORM=amd
+ENV PYTORCH_ROCM_ARCH="gfx90a;gfx942"
 
 RUN conda run -n py_3.10 conda install pip cmake -y && \
     conda run -n py_3.10 conda install -c conda-forge libstdcxx-ng=12 -y && \
     conda clean --all
 
-RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
+RUN apt-get update && apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev && \
+    apt-get clean autoclean && rm -rf /var/lib/apt/lists/{cache,log} /tmp/* /var/tmp/*
 
-RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main tilelang && \
-    mv /opt/conda/envs/py_3.10/compiler_compat /opt/conda/envs/py_3.10/compiler_compat.bak || true && \
-    conda run -n py_3.10 bash -c "pip install 'numpy<2.0' --force-reinstall && cd tilelang && USE_ROCM=1 pip install -e . -v"
+# Copy local tilelang directory instead of cloning from git
+# Build from tilelang root: docker build -f docker/Dockerfile.rocm -t mi300:latest .
+COPY . /root/tilelang
+
+RUN mv /opt/conda/envs/py_3.10/compiler_compat /opt/conda/envs/py_3.10/compiler_compat.bak || true && \
+    conda run -n py_3.10 bash -c "export USE_ROCM=1 USE_CUDA=0 && pip install 'numpy<2.0' --force-reinstall" && \
+    conda run -n py_3.10 bash -c "cd /root/tilelang && \
+        # Backup and modify pyproject.toml to remove torch from dependencies \
+        cp pyproject.toml pyproject.toml.bak && \
+        sed -i '/^[[:space:]]*\"torch/d' pyproject.toml && \
+        # Install tilelang with all dependencies except torch \
+        USE_ROCM=1 USE_CUDA=0 pip install -e . -v && \
+        # Restore original pyproject.toml \
+        mv pyproject.toml.bak pyproject.toml"
 
 RUN conda init bash && \
     echo "conda activate py_3.10" >> /root/.bashrc
 
 SHELL ["/bin/bash", "-l", "-c"]
 ENTRYPOINT ["/bin/bash", "--login", "-i"]
\ No newline at end of file
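Once the image is built (docker build -f docker/Dockerfile.rocm -t mi300:latest ., per the comment in the Dockerfile), a quick sanity check from inside the container confirms that the preinstalled PyTorch is a ROCm build and can see the accelerator. A minimal sketch; the session is hypothetical, but torch.version.hip and the torch.cuda calls are the standard way a ROCm build of PyTorch reports its devices (its 'cuda' namespace maps to HIP):

    import torch

    print(torch.version.hip)              # ROCm/HIP version string; None on a CUDA build
    print(torch.cuda.is_available())      # True when a HIP device is visible
    print(torch.cuda.get_device_name(0))  # e.g. an MI300X (gfx942) part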
FA2 forward example (Python):
@@ -8,6 +8,21 @@ import argparse
 from functools import partial
 
+
+# Custom supply function to ensure tensors are created on GPU
+def supply_tensors_gpu(params):
+    """Supply function that creates tensors on GPU for ROCm/HIP."""
+    tensors = []
+    for param in params:
+        if hasattr(param, 'shape') and hasattr(param, 'dtype'):
+            # Force creation on GPU device
+            shape = [int(s) for s in param.shape]
+            tensor = torch.randn(shape, dtype=param.dtype, device='cuda')
+            tensors.append(tensor)
+        else:
+            tensors.append(param)
+    return tensors
+
+
 def ref_program(Q, K, V, is_causal, groups=1):
     assert Q.size(
         2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}"
@@ -63,7 +78,7 @@ def get_configs():
     return valid_configs
 
-@tilelang.autotune(configs=get_configs(), cache_input_tensors=True)
+@tilelang.autotune(configs=get_configs(), cache_input_tensors=True, supply_prog=supply_tensors_gpu)
 @tilelang.jit(out_idx=[3])
 def fast_flashattn(
     batch,
...
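The supply hook exists because, as its comment notes, autotuning inputs must be created on the GPU under ROCm/HIP. Its contract is simple: anything exposing shape and dtype becomes a device tensor, everything else passes through unchanged. A self-contained sketch of that contract (requires a visible GPU); FakeParam is a hypothetical stand-in for the kernel-parameter objects tilelang passes in, which similarly expose shape and dtype:

    import torch
    from dataclasses import dataclass

    @dataclass
    class FakeParam:
        shape: tuple
        dtype: torch.dtype

    def supply_tensors_gpu(params):  # same logic as the hook in the diff above
        tensors = []
        for param in params:
            if hasattr(param, 'shape') and hasattr(param, 'dtype'):
                shape = [int(s) for s in param.shape]
                tensors.append(torch.randn(shape, dtype=param.dtype, device='cuda'))
            else:
                tensors.append(param)
        return tensors

    out = supply_tensors_gpu([FakeParam((1, 1024, 8, 64), torch.float16), 1.0])
    assert out[0].is_cuda and out[1] == 1.0  # tensors land on the GPU; scalars pass through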
TL intrinsic registrations (C++):
@@ -33,6 +33,7 @@ TVM_REGISTER_OP("tl.pow_of_int")
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kPure))
     .set_attr<TScriptPrinterName>("TScriptPrinterName", "pow_of_int")
+    .set_attr<FLowerIntrinsic>("hip.FLowerIntrinsic", pow_of_int_op)
     .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", pow_of_int_op);
 
 PrimExpr infinity_op(PrimExpr args) {
@@ -59,7 +60,8 @@ TVM_REGISTER_OP("tl.infinity")
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kPure))
     .set_attr<TScriptPrinterName>("TScriptPrinterName", "infinity")
-    .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", infinity_op);
+    .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", infinity_op)
+    .set_attr<FLowerIntrinsic>("hip.FLowerIntrinsic", infinity_op);
 
 } // namespace tl
 } // namespace tvm
HIP codegen, float constant printing (C++):
@@ -1190,9 +1190,9 @@ inline void PrintConst(const FloatImmNode *op, std::ostream &os,
       if (op->value < 0) {
         temp << "-";
       }
-      temp << ((op->dtype.bits() == 32) ? "HIPRT_INF_F" : "HIPRT_INF");
+      temp << ((op->dtype.bits() == 32) ? "HUGE_VALF" : "HUGE_VAL");
     } else if (std::isnan(op->value)) {
-      temp << ((op->dtype.bits() == 32) ? "HIPRT_NAN_F" : "HIPRT_NAN");
+      temp << ((op->dtype.bits() == 32) ? "NAN" : "NAN");
     } else {
       temp << std::scientific << op->value;
       if (op->dtype.bits() == 32)
...