Unverified commit 53be59dc authored by danielhua23, committed by GitHub

[AMD] Enable FA2 fwd on AMD MI300X (#1406)

* enable FA2 on AMD MI300X

* make lint happy
parent 0eb33f28
@@ -9,25 +9,43 @@ ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential git wget \
    libgtest-dev libprotobuf-dev protobuf-compiler libgflags-dev libsqlite3-dev llvm-dev \
    rocm-dev rocm-libs hip-dev hipblas-dev rocblas-dev \
    && apt-get clean autoclean && rm -rf /var/lib/apt/lists/{cache,log} /tmp/* /var/tmp/*
ENV PATH="/opt/conda/bin:${PATH}"
ENV LIBGL_ALWAYS_INDIRECT=1
ENV USE_ROCM=1
ENV USE_CUDA=0
ENV ROCM_HOME=/opt/rocm
ENV HIP_PLATFORM=amd
ENV PYTORCH_ROCM_ARCH="gfx90a;gfx942"
RUN conda run -n py_3.10 conda install pip cmake -y && \
    conda run -n py_3.10 conda install -c conda-forge libstdcxx-ng=12 -y && \
    conda clean --all
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
RUN apt-get update && apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev && \
    apt-get clean autoclean && rm -rf /var/lib/apt/lists/{cache,log} /tmp/* /var/tmp/*
RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main tilelang && \
    mv /opt/conda/envs/py_3.10/compiler_compat /opt/conda/envs/py_3.10/compiler_compat.bak || true && \
    conda run -n py_3.10 bash -c "pip install 'numpy<2.0' --force-reinstall && cd tilelang && USE_ROCM=1 pip install -e . -v"
# Copy local tilelang directory instead of cloning from git
# Build from tilelang root: docker build -f docker/Dockerfile.rocm -t mi300:latest .
COPY . /root/tilelang
RUN mv /opt/conda/envs/py_3.10/compiler_compat /opt/conda/envs/py_3.10/compiler_compat.bak || true && \
conda run -n py_3.10 bash -c "export USE_ROCM=1 USE_CUDA=0 && pip install 'numpy<2.0' --force-reinstall" && \
conda run -n py_3.10 bash -c "cd /root/tilelang && \
# Backup and modify pyproject.toml to remove torch from dependencies \
cp pyproject.toml pyproject.toml.bak && \
sed -i '/^[[:space:]]*\"torch/d' pyproject.toml && \
# Install tilelang with all dependencies except torch \
USE_ROCM=1 USE_CUDA=0 pip install -e . -v && \
# Restore original pyproject.toml \
mv pyproject.toml.bak pyproject.toml"
RUN conda init bash && \
echo "conda activate py_3.10" >> /root/.bashrc
SHELL ["/bin/bash", "-l", "-c"]
ENTRYPOINT ["/bin/bash", "--login", "-i"]
\ No newline at end of file
ENTRYPOINT ["/bin/bash", "--login", "-i"]
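
As the comment above notes, the image is built from the tilelang root with docker build -f docker/Dockerfile.rocm -t mi300:latest . A quick way to verify the result is a short PyTorch smoke test run inside the py_3.10 environment of a container started with the GPU devices mapped in (e.g. --device=/dev/kfd --device=/dev/dri). This is a sketch, assuming a ROCm build of PyTorch, where the torch.cuda API is backed by HIP:

import torch

# ROCm builds of PyTorch populate torch.version.hip instead of torch.version.cuda.
print(torch.version.hip)              # ROCm/HIP version string
print(torch.cuda.is_available())      # True when /dev/kfd and /dev/dri are visible
print(torch.cuda.get_device_name(0))  # the MI300X device name
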
@@ -8,6 +8,21 @@ import argparse
from functools import partial
# Custom supply function to ensure tensors are created on GPU
def supply_tensors_gpu(params):
    """Supply function that creates tensors on GPU for ROCm/HIP."""
    tensors = []
    for param in params:
        if hasattr(param, 'shape') and hasattr(param, 'dtype'):
            # Force creation on GPU device
            shape = [int(s) for s in param.shape]
            tensor = torch.randn(shape, dtype=param.dtype, device='cuda')
            tensors.append(tensor)
        else:
            tensors.append(param)
    return tensors


def ref_program(Q, K, V, is_causal, groups=1):
    assert Q.size(
        2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}"
@@ -63,7 +78,7 @@ def get_configs():
    return valid_configs


@tilelang.autotune(configs=get_configs(), cache_input_tensors=True)
@tilelang.autotune(configs=get_configs(), cache_input_tensors=True, supply_prog=supply_tensors_gpu)
@tilelang.jit(out_idx=[3])
def fast_flashattn(
    batch,
......
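
The supply function above relies only on parameters exposing shape and dtype attributes; everything else is passed through untouched, which is what lets scalar arguments coexist with tensor parameters. A standalone check of that contract, using a SimpleNamespace as a purely illustrative stand-in for tilelang's real parameter objects:

import types

import torch

# Stand-in for a kernel parameter; supply_tensors_gpu reads only .shape and .dtype.
param = types.SimpleNamespace(shape=(2, 128, 8, 64), dtype=torch.float16)

out = supply_tensors_gpu([param, 42])
assert out[0].device.type == 'cuda' and out[0].dtype == torch.float16
assert out[1] == 42  # non-tensor arguments pass through unchanged
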
@@ -33,6 +33,7 @@ TVM_REGISTER_OP("tl.pow_of_int")
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kPure))
    .set_attr<TScriptPrinterName>("TScriptPrinterName", "pow_of_int")
    .set_attr<FLowerIntrinsic>("hip.FLowerIntrinsic", pow_of_int_op)
    .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", pow_of_int_op);
PrimExpr infinity_op(PrimExpr args) {
@@ -59,7 +60,8 @@ TVM_REGISTER_OP("tl.infinity")
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kPure))
    .set_attr<TScriptPrinterName>("TScriptPrinterName", "infinity")
    .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", infinity_op);
    .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", infinity_op)
    .set_attr<FLowerIntrinsic>("hip.FLowerIntrinsic", infinity_op);
} // namespace tl
} // namespace tvm
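
Both ops reuse the same lowering callback for both backends: TVM's intrinsic-lowering pass resolves the FLowerIntrinsic attribute under a target-kind prefix ("cuda." or "hip."), so without the added "hip.FLowerIntrinsic" registrations, tl.pow_of_int and tl.infinity would be left unlowered on ROCm even though the callbacks themselves are target-agnostic.
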
@@ -1190,9 +1190,9 @@ inline void PrintConst(const FloatImmNode *op, std::ostream &os,
    if (op->value < 0) {
      temp << "-";
    }
    temp << ((op->dtype.bits() == 32) ? "HIPRT_INF_F" : "HIPRT_INF");
    temp << ((op->dtype.bits() == 32) ? "HUGE_VALF" : "HUGE_VAL");
  } else if (std::isnan(op->value)) {
    temp << ((op->dtype.bits() == 32) ? "HIPRT_NAN_F" : "HIPRT_NAN");
    temp << ((op->dtype.bits() == 32) ? "NAN" : "NAN");
  } else {
    temp << std::scientific << op->value;
    if (op->dtype.bits() == 32)
......
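
HUGE_VALF, HUGE_VAL, and NAN are standard C <math.h> macros that HIP device code can consume directly, which is presumably why they replace the HIPRT_* names here; both arms of the NaN ternary now print the same NAN macro, since it converts cleanly to a 64-bit NaN in the double case.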