Unverified Commit 8e1b88f3 authored by alex_xiao, committed by GitHub

[CI][AMD] Add AMD GPU CI and fix some related bugs (#694)



* [Enhancement] Refactor buffer index handling for improved precision and clarity (#668)

- Enhanced buffer index handling to address precision issues by removing redundant operations.
- Streamlined the logic for determining buffer overlaps, ensuring more accurate conflict detection.
- Updated related documentation to reflect changes in buffer management practices.
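
The overlap check mentioned above boils down to comparing half-open index intervals. A minimal sketch of the idea in plain Python (illustrative only; `regions_conflict` is a hypothetical name, not a function in the tilelang codebase):

    def regions_conflict(start_a, extent_a, start_b, extent_b):
        # Two buffer regions [start, start + extent) can only conflict if the
        # half-open intervals overlap; comparing them directly avoids the
        # redundant index arithmetic that caused the precision issues.
        return start_a < start_b + extent_b and start_b < start_a + extent_a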

* Remove obsolete test script for AMD example, streamlining the examples directory.

* Remove unused dtype_size variable in AMD example script to streamline code.

* Add input configuration file and update AMD example script for enhanced flexibility

- Introduced a new input.txt file for configurable parameters.
- Modified the example_amd_flash_attn_fwd.py script to allow for a wider range of configurations, including additional options for num_stages, enable_rasterization, and k_pack.
- Streamlined the main function for better clarity and organization.
- Added a new test script to facilitate running the example with specified parameters.

* Remove input configuration file and obsolete test script; enhance AMD example with swizzle layout annotations

- Deleted input.txt and test.sh files as they are no longer needed.
- Updated example_amd_flash_attn_fwd.py to include swizzle layout annotations for shared memory, improving bank conflict avoidance.
- Reintroduced swizzle usage in the kernel for better performance.

* Refactor AMD example script for FlashAttention-2

- Updated function names for clarity, changing `get_v2_configs` to `get_configs` and `fast_flashattn_v2` to `fast_flashattn`.
- Streamlined the main function by renaming `main_v2` to `main` and adjusting the corresponding calls.
- Removed outdated comments and improved code organization for better readability.

* Refactor formatting in AMD FlashAttention example script

- Improved code readability by adjusting line breaks and indentation in the `fast_flashattn` function.
- Streamlined the `main` function parameter formatting for consistency.
- Removed unnecessary blank lines to enhance overall code organization.

* Update example_amd_flash_attn_fwd.py

* Update AMD FlashAttention example and TVM submodule

- Added a new example script `example_amd_flash_attn_fwd_k_block.py` for FlashAttention with K-blocking support.
- Enhanced `example_amd_flash_attn_fwd.py` by expanding configuration options for block sizes and threads.
- Updated the TVM submodule to the latest commit for improved functionality.
- Introduced a new test script `test.sh` to facilitate running the new example with specified parameters.

* Add CI workflow for automated format checking and testing

- Introduced a new GitHub Actions workflow in `amd_ci.yml` to automate format checks and testing for pull requests.
- The workflow includes steps for setting up a Python environment, running format checks, and executing tests.
- Removed obsolete example script `example_amd_flash_attn_fwd_k_block.py` and test script `test.sh` to streamline the examples directory.

* Rename CI workflow from "CI" to "AMD CI" for clarity and specificity.

* Update AMD CI workflow to include copying PyTorch, TorchVision, and Torchaudio packages to the virtual environment for improved dependency management.

* Update AMD CI workflow to install pytest directly instead of using requirements-test.txt

* Update AMD CI workflow to remove 'flash-attn' from requirements and install dependencies from requirements-test.txt

* Refactor AMD CI workflow to enhance clarity in removing 'flash-attn' from requirements-test.txt before installation

* Remove Torchaudio package copying from AMD CI workflow to streamline dependency management.

* Refactor AMD CI workflow to remove the format-check job and streamline the build-test process by directly copying PyTorch and TorchVision packages to the virtual environment.

* Add installation of ROCm in AMD CI workflow

- Included a step to execute the `install_rocm.sh` script for improved setup.
- Removed unnecessary blank line for better readability in the workflow script.

* Remove installation step for ROCm in AMD CI workflow to simplify the setup process.

* Update AMD CI workflow to run specific test file with verbose output instead of all tests.

* Add new tilelang built-in operations for AMD architecture

- Introduced `tvm_mfma`, `tvm_mfma_store`, `tvm_rdna_wmma`, and `tvm_rdna_wmma_store` built-in operations to enhance support for matrix multiplication and storage in tilelang.
- Each operation is configured with the appropriate number of inputs and marked as opaque in terms of call effects.

* Enhance autotuner configurations and GEMM operations in AMD example

- Updated block sizes and num_split_q parameters in `get_configs` for improved autotuning.
- Modified `T.gemm` calls in `fast_flashattn` to utilize `GemmWarpPolicy.FullRow`, optimizing performance for matrix multiplications.
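
For reference, the call pattern this commit introduces (it appears verbatim in the example diff further down) is roughly:

    from tilelang.primitives.gemm.base import GemmWarpPolicy

    T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, k_pack=k_pack,
           policy=GemmWarpPolicy.FullRow)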

* Update autotuner configurations in AMD example for enhanced performance

- Refined block sizes, thread counts, and added new parameters in `get_configs` to optimize autotuning.
- Adjusted `fast_flashattn` function to incorporate new parameters for panel size and coalesced widths, improving memory access patterns.
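
Both knobs surface directly in the kernel, as the diff below shows: `panel_size` feeds the rasterization swizzle, and the coalesced widths set the vector width of the global-to-shared copies. A rough sketch (the `V[...]` slice is elided here; the Q/K copies use `qk_coalesced_width` analogously):

    T.use_swizzle(panel_size, enable=enable_rasterization)       # panel_size drawn from [7, 8, 9, 10]
    T.copy(V[...], V_shared, coalesced_width=v_coalesced_width)  # vector width of the V copy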

* Enhance autotuner configurations and memory handling in AMD example

- Expanded block sizes and thread counts in `get_configs` for improved autotuning capabilities.
- Updated `fast_flashattn` to utilize a new shared memory allocation strategy, optimizing memory access patterns during GEMM operations.

* Refine autotuner configurations and memory usage in AMD example

- Reduced block sizes and adjusted thread counts in `get_configs` for optimized autotuning.
- Updated `fast_flashattn` to utilize register fragments for accumulation, minimizing LDS usage and enhancing performance during GEMM operations.
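
Concretely, the diff below swaps the shared-memory staging buffer for the exponentiated score tile for a register fragment, so that tile never round-trips through LDS:

    # before: scores staged through shared memory
    P_shared = T.alloc_shared([block_M, block_N], dtype)
    T.copy(acc_s, P_shared)
    T.gemm(P_shared, V_shared, acc_o)

    # after: scores kept in registers and cast in place before the second GEMM
    acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
    T.copy(acc_s, acc_s_cast)
    T.gemm(acc_s_cast, V_shared, acc_o, policy=GemmWarpPolicy.FullRow)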

* Update autotuner configurations in AMD example for enhanced performance

- Expanded block sizes and thread counts in `get_configs` to improve autotuning capabilities.
- Adjusted `num_split_q` and `v_coalesced_width` parameters for better optimization during GEMM operations.

* Enhance autotuner configurations and GEMM operations in AMD example

- Expanded thread counts in `get_configs` to include higher values for improved autotuning.
- Updated `fast_flashattn` to adjust accumulation logic and ensure proper handling of causal conditions, optimizing performance during matrix multiplications.

* Update AMD CI workflow and remove obsolete test script

- Modified the CI workflow to run on runners labeled self-hosted, amd, and gpu.
- Deleted the outdated `test.sh` script from the examples directory, streamlining the project structure.

* Remove TVM subproject from 3rdparty directory

* Refactor configuration generation and accumulation logic in AMD example

- Reformatted the `get_configs` function for improved readability by aligning parameters.
- Adjusted the `fast_flashattn` function to enhance clarity in the conditional logic for accumulation, ensuring better handling of causal conditions.

* Enhance AMD CI workflow with additional logging and setup steps

- Added echo statements to provide feedback during the CI process, indicating when the environment is running on an AMD GPU, copying necessary packages, and installing requirements.
- Improved clarity in the workflow by explicitly stating when the project is being installed and when tests are being executed.

* Comment out package copying in AMD CI workflow to prevent potential issues during environment setup

* Update AMD CI workflow to install nightly versions of PyTorch and remove obsolete package copying steps

* Enhance BuildTileLangHIP function by adding whitespace for improved readability

* Refactor kTVMGridConstant definition for clarity and remove unnecessary comment

* Update TVM subproject to latest commit a64a5926a6e59f5417ef2501f9d88b467337cf6a

* lint fix

* Update AMD CI workflow to use requirements-rocm.txt for dependency installation

* fix ci

* Remove dependency on format-check from AMD CI workflow

* fix ci

* fix ci

* fix ci

* Remove format-check job from AMD CI workflow

* Add torch to requirements-rocm.txt and remove explicit pip install commands from AMD CI workflow

* Add dependency on format-check job in AMD CI workflow

* Add format-check job to AMD CI workflow

* Update format-check job in AMD CI workflow to run on self-hosted environment

* Enhance format-check job in AMD CI workflow with improved Python environment setup and automatic commit of lint changes

* Update amd_ci.yml

---------
Co-authored-by: xinxyxiao <xinyxiao@amd.com>
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
parent d0742860
name: CI Test on AMD

on: [pull_request]

env:
  PYTHON_VERSION: '3.12'
  VENV_DIR: tilelang_ci
  PYTORCH_INDEX_URL: https://download.pytorch.org/whl/nightly/rocm6.3/

jobs:
  format-check:
    runs-on: [self-hosted, amd, gpu]
    permissions:
      contents: write
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: ${{ env.PYTHON_VERSION }}

      - name: Ensure venv (local & persistent)
        run: |
          set -e
          REQS_HASH=$(sha256sum requirements-test.txt 2>/dev/null | awk '{print $1}' || echo "no_requirements")
          MARKER="${{ runner.tool_cache }}/.venv_marker_${{ env.PYTHON_VERSION }}_${REQS_HASH:0:8}"
          if [[ -f "$MARKER" ]] && [[ -f "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate" ]]; then
            echo "venv exists and hash matches – reuse it"
          else
            echo "venv stale or missing – recreating"
            rm -rf "${{ runner.tool_cache }}/${{ env.VENV_DIR }}" "$MARKER"
            python -m venv "${{ runner.tool_cache }}/${{ env.VENV_DIR }}"
            # shellcheck source=/dev/null
            source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
            python -m pip install --upgrade pip --no-user
            [[ -f requirements-test.txt ]] && \
              PIP_NO_BUILD_ISOLATION=1 pip install -r requirements-test.txt --no-user
            pip install flash_attn==2.5.8 --no-user --no-build-isolation
            touch "$MARKER"
          fi

      - name: Run format check
        run: |
          source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
          if ! output=$(./format.sh 2>&1); then
            echo "------------------------------------"
            echo "message:"
            echo "$output"
            printf '%s\n' "$output" | grep "Please review and stage the changes."
            echo "------------------------------------"
            exit 1
          fi

      - name: Commit and Push Changes
        uses: stefanzweifel/git-auto-commit-action@v5
        with:
          commit_message: "lint"

  build-test-amd:
    runs-on: [self-hosted, amd, gpu]
    needs: format-check
    permissions:
      contents: read
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
          repository: ${{ github.event.pull_request.head.repo.full_name }}
          ref: ${{ github.event.pull_request.head.ref }}

      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: ${{ env.PYTHON_VERSION }}

      - name: Ensure venv (local & persistent)
        run: |
          echo "Running on AMD GPU"
          set -e
          REQS_HASH=$(sha256sum requirements-rocm.txt | cut -d ' ' -f 1)
          MARKER="${{ runner.tool_cache }}/.venv_marker_${{ env.PYTHON_VERSION }}_${REQS_HASH:0:8}"
          echo "Installing requirements"
          if [[ -f "$MARKER" ]] && [[ -f "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate" ]]; then
            echo "venv exists and hash matches – reuse it"
          else
            echo "venv stale or missing – recreating"
            rm -rf "${{ runner.tool_cache }}/${{ env.VENV_DIR }}" "$MARKER"
            python -m venv "${{ runner.tool_cache }}/${{ env.VENV_DIR }}"
            source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
            python -m pip install --upgrade pip --no-user
            if [[ -f requirements-rocm.txt ]]; then
              pip install --pre torch torchvision torchaudio --index-url ${{ env.PYTORCH_INDEX_URL }}
              PIP_NO_BUILD_ISOLATION=1 pip install -r requirements-rocm.txt
            fi
            USE_ROCM=True pip install . --no-user
            touch "$MARKER"
          fi

      - name: Install project (wheel form)
        run: |
          echo "Installing project (wheel form)"
          source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
          USE_ROCM=True pip install . --no-user

      - name: Run tests
        run: |
          echo "Running tests"
          source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
          cd testing/python/amd
          unset PYTHONPATH
          python -m pytest -v test_tilelang_test_amd.py
Subproject commit 5a433cc1af4a6d859cdf2b62c7c5ab28bf5836ea
Subproject commit a64a5926a6e59f5417ef2501f9d88b467337cf6a
@@ -2,6 +2,7 @@ import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
from tilelang.primitives.gemm.base import GemmWarpPolicy
import itertools
import argparse
from functools import partial
@@ -29,18 +30,24 @@ def ref_program(Q, K, V, is_causal, groups=1):
def get_configs():
"""Generates configurations for the autotuner, tailored for FA-2 style parallelism."""
block_M = [64, 128, 256]
block_N = [32, 64, 128]
threads = [128, 256, 512]
num_split_q = [32, 64, 128]
num_stages = [0, 1, 2]
enable_rasterization = [True, False]
k_pack = [1, 2]
block_M = [32, 64, 128, 256]
block_N = [32, 64, 128, 256]
threads = [64, 128, 192, 256, 512, 1024]
num_split_q = [32, 64, 128, 256, 256]
num_stages = [0]
enable_rasterization = [True]
k_pack = [2]
panel_size = [7, 8, 9, 10]
qk_coalesced_width = [8]
v_coalesced_width = [4]
valid_configs = []
for m, n, s, t, stages, r, k in itertools.product(block_M, block_N, num_split_q, threads,
num_stages, enable_rasterization, k_pack):
for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product(block_M, block_N, num_split_q,
threads, num_stages,
enable_rasterization, k_pack,
panel_size, qk_coalesced_width,
v_coalesced_width):
valid_configs.append({
"block_M": m,
"block_N": n,
@@ -48,7 +55,10 @@ def get_configs():
"threads": t,
"num_stages": stages,
"enable_rasterization": r,
"k_pack": k
"k_pack": k,
"panel_size": p,
"qk_coalesced_width": qkw,
"v_coalesced_width": vw,
})
valid_configs.append({
'block_M': 64,
@@ -57,7 +67,10 @@ def get_configs():
'threads': 256,
'num_stages': 1,
'enable_rasterization': True,
'k_pack': 2
'k_pack': 2,
'panel_size': 64,
'qk_coalesced_width': 8,
'v_coalesced_width': 8,
})
return valid_configs
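
# Editorial note (not part of the diff): every combination of the lists above is
# emitted via itertools.product, so widening them grows the search space quickly:
#   4 (block_M) * 4 (block_N) * 5 (num_split_q; 256 is listed twice) * 6 (threads)
#   * 1 * 1 * 1 (num_stages, rasterization, k_pack) * 4 (panel_size) * 1 * 1 = 1920
# generated configs (1536 unique), plus the hand-picked config appended above.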
@@ -78,6 +91,9 @@ def fast_flashattn(
num_stages: int,
enable_rasterization: bool,
k_pack: int,
panel_size: int,
qk_coalesced_width: int,
v_coalesced_width: int,
):
scale = (1.0 / dim)**0.5 * 1.44269504
head_kv = heads // groups
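
# Editorial note (not part of the diff): scale folds log2(e) = 1.44269504 into
# 1/sqrt(dim), so the later T.exp2(acc_s[i, j] * scale - m_i[i] * scale) computes
# 2^((acc_s - m_i) * log2(e) / sqrt(dim)) = exp((acc_s - m_i) / sqrt(dim)),
# i.e. the usual max-subtracted softmax exponent via the cheaper exp2 instruction.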
@@ -86,8 +102,8 @@
dtype = "float16"
accum_dtype = "float"
v_vec_size = 4
vec_size = 4 * k_pack
vec_size = qk_coalesced_width
v_vec_size = v_coalesced_width
@T.prim_func
def main(
@@ -97,7 +113,7 @@
Output: T.Tensor(q_shape, dtype),
):
with T.Kernel(num_split_q, batch * heads, threads=threads) as (b_split, byz_combined):
T.use_swizzle(10, enable=enable_rasterization)
T.use_swizzle(panel_size, enable=enable_rasterization)
bz = byz_combined // heads
by = byz_combined % heads
@@ -105,9 +121,9 @@
num_q_blocks = T.ceildiv(seq_len, block_M)
bx = T.alloc_var("int32")
bx[0] = b_split
bx = b_split
with T.While(bx[0] < num_q_blocks):
with T.While(bx < num_q_blocks):
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
m_i = T.alloc_fragment([block_M], accum_dtype)
l_i = T.alloc_fragment([block_M], accum_dtype)
@@ -115,13 +131,14 @@
T.fill(m_i, -T.infinity(accum_dtype))
T.fill(l_i, 0)
current_bx = bx[0]
current_bx = bx
q_block_offset = current_bx * block_M
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
P_shared = T.alloc_shared([block_M, block_N], dtype)
# Use register fragment for P instead of shared memory to reduce LDS usage
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
m_prev = T.alloc_fragment([block_M], accum_dtype)
@@ -135,6 +152,8 @@
loop_end_k = T.ceildiv(q_block_offset + block_M,
block_N) if is_causal else T.ceildiv(seq_len, block_N)
row_sum = T.alloc_fragment([block_M], accum_dtype)
for k in T.Pipelined(loop_end_k, num_stages=num_stages):
kv_idx = k * block_N
@@ -147,13 +166,20 @@
V_shared,
coalesced_width=v_vec_size)
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, k_pack=k_pack)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j,
acc_s[i, j], -T.infinity(acc_s.dtype))
acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(
Q_shared,
K_shared,
acc_s,
transpose_B=True,
k_pack=k_pack,
policy=GemmWarpPolicy.FullRow,
)
T.copy(m_i, m_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
@@ -169,15 +195,14 @@
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - m_i[i] * scale)
row_sum = T.alloc_fragment([block_M], accum_dtype)
T.reduce_sum(acc_s, row_sum, dim=1)
for i in T.Parallel(block_M):
l_i[i] += row_sum[i]
T.copy(acc_s, P_shared)
T.sync_threads()
# Cast acc_s (accum_dtype) to dtype in registers and directly GEMM with V
T.copy(acc_s, acc_s_cast)
T.gemm(P_shared, V_shared, acc_o)
T.gemm(acc_s_cast, V_shared, acc_o, policy=GemmWarpPolicy.FullRow)
l_inv = T.alloc_fragment([block_M], accum_dtype)
for i in T.Parallel(block_M):
@@ -187,7 +212,7 @@
for i, j in T.Parallel(block_M, dim):
Output[bz, q_block_offset + i, by, j] = acc_o[i, j] * l_inv[i]
bx[0] = current_bx + num_split_q
bx = current_bx + num_split_q
return main
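
A quick way to sanity-check the kernel numerics against PyTorch, sketched under the assumption that tensors follow the [batch, seq_len, heads, head_dim] layout implied by the Output indexing above (`torch_reference` is a hypothetical helper, not part of this file):

    import torch
    import torch.nn.functional as F

    def torch_reference(Q, K, V, is_causal, groups=1):
        # Q: [B, S, H, D]; K/V: [B, S, H // groups, D] (grouped-query attention).
        k = K.repeat_interleave(groups, dim=2)
        v = V.repeat_interleave(groups, dim=2)
        out = F.scaled_dot_product_attention(
            Q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=is_causal)
        # Compare against the tilelang kernel's Output with torch.testing.assert_close.
        return out.transpose(1, 2)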
......
# lint requirements
-r requirements-lint.txt
# build requirements
Cython
cmake>=3.26
# runtime requirements
cffi
cpplint
Cython
docutils
dtlib
numpy>=1.23.5
pytest>=6.2.4
pytest_xdist>=2.2.1
packaging>=21.0
PyYAML
tqdm>=4.62.3
typing_extensions>=4.10.0
requests
cloudpickle
ml_dtypes
psutil
torch
tabulate
wheel
setuptools
einops
scipy
tornado
@@ -141,6 +141,24 @@ TIR_DEFINE_TL_BUILTIN(tl_gemm_sp)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(tvm_mfma).set_num_inputs(12).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(tvm_mfma_store)
.set_num_inputs(6)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(tvm_rdna_wmma)
.set_num_inputs(12)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(tvm_rdna_wmma_store)
.set_num_inputs(6)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(tl_shuffle_elect)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
......
@@ -4,7 +4,7 @@
#include "codegen_hip.h"
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/ffi/function.h>
#include <tvm/tir/index_map.h>
#include <tvm/tir/op.h>
@@ -882,7 +882,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
this->PrintExpr(op->args[i * 2 + 1], os);
os << "]" << ((i < 3) ? ", " : ")");
}
} else if (op->op.same_as(builtin::tvm_mfma())) {
} else if (op->op.same_as(tl::tvm_mfma())) {
// arg 0: prefix: {otype}_16x16x16{itype}
// arg 1: A layout: row/col
// arg 2: B layout: row/col
......
@@ -8,6 +8,11 @@
#include "codegen_hip.h"
#include "runtime/rocm/rocm_module.h"
#include <tvm/ffi/function.h>
#ifndef kTVMGridConstant
#define kTVMGridConstant 130
#endif
namespace tvm {
namespace codegen {
@@ -44,7 +49,6 @@ ExtractFuncInfo(const IRModule &mod) {
}
runtime::Module BuildTileLangHIP(IRModule mod, Target target) {
using tvm::runtime::Registry;
bool output_ssa = false;
CodeGenTileLangHIP cg;
cg.Init(output_ssa);
@@ -59,23 +63,28 @@ runtime::Module BuildTileLangHIP(IRModule mod, Target target) {
}
std::string code = cg.Finish();
if (const auto *f = Registry::Get("tilelang_callback_hip_postproc")) {
code = (*f)(code, target).operator std::string();
// Use the new FFI API to get registered functions
using ffi::Function;
if (auto f = Function::GetGlobal("tilelang_callback_hip_postproc")) {
code = (*f)(code, target).cast<std::string>();
}
std::string fmt = "ptx";
std::string ptx;
if (const auto *f = Registry::Get("tilelang_callback_hip_compile")) {
ptx = (*f)(code, target).operator std::string();
if (auto f = Function::GetGlobal("tilelang_callback_hip_compile")) {
ptx = (*f)(code, target).cast<std::string>();
if (ptx[0] != '/')
fmt = "hsaco";
} else {
ICHECK(false) << "tilelang_callback_hip_compile is not set";
}
return ROCMModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code, std::string());
}
runtime::Module BuildTileLangHIPWithoutCompile(IRModule mod, Target target) {
using tvm::runtime::Registry;
bool output_ssa = false;
CodeGenTileLangHIP cg;
cg.Init(output_ssa);
@@ -90,12 +99,17 @@ runtime::Module BuildTileLangHIPWithoutCompile(IRModule mod, Target target) {
}
std::string code = cg.Finish();
if (const auto *f = Registry::Get("tilelang_callback_hip_postproc")) {
code = (*f)(code, target).operator std::string();
// Use the new FFI API to get registered functions
using ffi::Function;
if (auto f = Function::GetGlobal("tilelang_callback_hip_postproc")) {
code = (*f)(code, target).cast<std::string>();
}
return ROCMModuleCreate("ptx", "fmt", ExtractFuncInfo(mod), code,
std::string());
}
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
@@ -105,4 +119,4 @@ TVM_FFI_STATIC_INIT_BLOCK({
});
} // namespace codegen
} // namespace tvm
} // namespace tvm
\ No newline at end of file