Commit 518a5f4d authored by hly's avatar hly
Browse files

import aicc-master-dev

parent c2a1b310
# This workflow will:
# - Create a new Github release
# - Build wheels for supported architectures
# - Deploy the wheels to the Github release
# - Release the static code to PyPi
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
name: Build wheels and deploy
on:
create:
tags:
- v*
jobs:
setup_release:
name: Create Release
runs-on: ubuntu-latest
steps:
- name: Get the tag version
id: extract_branch
run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
shell: bash
- name: Create Release
id: create_release
uses: actions/create-release@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
tag_name: ${{ steps.extract_branch.outputs.branch }}
release_name: ${{ steps.extract_branch.outputs.branch }}
build_wheels:
name: Build Wheel
needs: setup_release
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
# Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the
# manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
os: [ubuntu-20.04]
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
torch-version: ['2.0.1', '2.1.2', '2.2.2', '2.3.1', '2.4.0.dev20240514']
cuda-version: ['11.8.0', '12.3.2']
# We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
# Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
# Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
# when building without C++11 ABI and using it on nvcr images.
cxx11_abi: ['FALSE', 'TRUE']
exclude:
# see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
# Pytorch < 2.2 does not support Python 3.12
- torch-version: '2.0.1'
python-version: '3.12'
- torch-version: '2.1.2'
python-version: '3.12'
# Pytorch <= 2.0 only supports CUDA <= 11.8
- torch-version: '2.0.1'
cuda-version: '12.3.2'
steps:
- name: Checkout
uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Set CUDA and PyTorch versions
run: |
echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
echo "MATRIX_TORCH_VERSION=$(echo ${{ matrix.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV
- name: Free up disk space
if: ${{ runner.os == 'Linux' }}
# https://github.com/easimon/maximize-build-space/blob/master/action.yml
# https://github.com/easimon/maximize-build-space/tree/test-report
run: |
sudo rm -rf /usr/share/dotnet
sudo rm -rf /opt/ghc
sudo rm -rf /opt/hostedtoolcache/CodeQL
- name: Set up swap space
if: runner.os == 'Linux'
uses: pierotofy/set-swap-space@v1.0
with:
swap-size-gb: 10
- name: Install CUDA ${{ matrix.cuda-version }}
if: ${{ matrix.cuda-version != 'cpu' }}
uses: Jimver/cuda-toolkit@v0.2.14
id: cuda-toolkit
with:
cuda: ${{ matrix.cuda-version }}
linux-local-args: '["--toolkit"]'
# default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1
# method: ${{ (matrix.cuda-version == '11.8.0' || matrix.cuda-version == '12.1.0') && 'network' || 'local' }}
method: 'network'
# We need the cuda libraries (e.g. cuSparse, cuSolver) for compiling PyTorch extensions,
# not just nvcc
# sub-packages: '["nvcc"]'
- name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }}
run: |
pip install --upgrade pip
# If we don't install before installing Pytorch, we get error for torch 2.0.1
# ERROR: Could not find a version that satisfies the requirement setuptools>=40.8.0 (from versions: none)
pip install lit
# For some reason torch 2.2.0 on python 3.12 errors saying no setuptools
pip install setuptools
# We want to figure out the CUDA version to download pytorch
# e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116
# see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
# This code is ugly, maybe there's a better way to do this.
export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
minv = {'2.0': 117, '2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118}[env['MATRIX_TORCH_VERSION']]; \
maxv = {'2.0': 118, '2.1': 121, '2.2': 121, '2.3': 121, '2.4': 121}[env['MATRIX_TORCH_VERSION']]; \
print(max(min(int(env['MATRIX_CUDA_VERSION']), maxv), minv))" \
)
if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then
pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
else
pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}
fi
nvcc --version
python --version
python -c "import torch; print('PyTorch:', torch.__version__)"
python -c "import torch; print('CUDA:', torch.version.cuda)"
python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)"
shell:
bash
- name: Build wheel
run: |
# We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6
# https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810
# However this still fails so I'm using a newer version of setuptools
pip install setuptools==68.0.0
pip install ninja packaging wheel
export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
# Limit MAX_JOBS otherwise the github runner goes OOM
# CUDA 11.8 can compile with 2 jobs, but CUDA 12.3 goes OOM
MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "123" ] && echo 1 || echo 2) FLASH_ATTENTION_FORCE_BUILD="TRUE" FLASH_ATTENTION_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist
tmpname=cu${MATRIX_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }}
wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
ls dist/*whl |xargs -I {} mv {} dist/${wheel_name}
echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
- name: Log Built Wheels
run: |
ls dist
- name: Get the tag version
id: extract_branch
run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
- name: Get Release with tag
id: get_current_release
uses: joutvhu/get-release@v1
with:
tag_name: ${{ steps.extract_branch.outputs.branch }}
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Upload Release Asset
id: upload_release_asset
uses: actions/upload-release-asset@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ steps.get_current_release.outputs.upload_url }}
asset_path: ./dist/${{env.wheel_name}}
asset_name: ${{env.wheel_name}}
asset_content_type: application/*
publish_package:
name: Publish package
needs: [build_wheels]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Install dependencies
run: |
pip install ninja packaging setuptools wheel twine
# We don't want to download anything CUDA-related here
pip install torch --index-url https://download.pytorch.org/whl/cpu
- name: Build core package
env:
FLASH_ATTENTION_SKIP_CUDA_BUILD: "TRUE"
run: |
python setup.py sdist --dist-dir=dist
- name: Deploy
env:
TWINE_USERNAME: "__token__"
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
run: |
python -m twine upload dist/*
...@@ -28,6 +28,7 @@ venv ...@@ -28,6 +28,7 @@ venv
# benchmarks/ # benchmarks/
*.log *.log
test_results/*.csv
# tests/ # tests/
csrc/*/*.hip csrc/*/*.hip
...@@ -36,4 +37,4 @@ csrc/*/*hip.cuh ...@@ -36,4 +37,4 @@ csrc/*/*hip.cuh
csrc/*/*hip.cpp csrc/*/*hip.cpp
csrc/flash_attn/src/*.hip csrc/flash_attn/src/*.hip
csrc/flash_attn/src/*hip.h csrc/flash_attn/src/*hip.h
csrc/flash_attn/*hip.cpp csrc/flash_attn/*hip.cpp
\ No newline at end of file
cutlass @ 7d49e6c7
Subproject commit 7d49e6c7e2f8896c47f586706e67e1fb215529dc
This diff is collapsed.
...@@ -386,6 +386,66 @@ struct Dropout { ...@@ -386,6 +386,66 @@ struct Dropout {
} }
} }
template <bool encode_dropout_in_sign_bit = false, typename Engine, typename Layout>
__forceinline__ __device__ void apply_dropout_trans_dim64_opt(
Tensor<Engine, Layout> &tensor_,
int block_row_start, int block_col_start, int block_row_stride)
{
Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_dropout(tensor_.layout()));
using T = typename Engine::value_type;
auto encode_dropout = [](bool keep, T val) {
if constexpr (encode_dropout_in_sign_bit) {
return keep ? val : -val;
} else {
return keep ? val : T(0);
}
};
const int lane_id = threadIdx.x % 64;
const int col_idx_offset = block_col_start + (threadIdx.x / 64) * 16 + lane_id % 16;
extern __shared__ char smem_[];
uint8_t *p_rand_8 = reinterpret_cast<uint8_t *>(smem_ + 16384);
// write
int row_ = (threadIdx.x % 16) + (threadIdx.x / 64) * 16;
int col_ = (lane_id / 16) * 16;
// read
const int read_row = (lane_id / 16) * 4;
const int lane_group = (lane_id % 16) / 4;
const int lane_offset = lane_id % 4;
const int read_col = (threadIdx.x / 64) * 4 + lane_group * 16 + lane_offset;
// padding stride
// constexpr int RAND_STRIDE = 64 + 4;
constexpr int RAND_STRIDE = 64;
for (int i = 0; i < size<1>(tensor); ++i) {
const int row_idx_base = block_row_start + i * block_row_stride + (lane_id / 16) * 4;
uint2 rowcol = make_uint2(col_idx_offset, row_idx_base);
uint4 random_uint4 = flash::philox(seed, reinterpret_cast<unsigned long long &>(rowcol), offset);
uint8_t (&rnd_8)[16] = reinterpret_cast<uint8_t (&)[16]>(random_uint4);
*reinterpret_cast<uint4*>(&p_rand_8[row_ * RAND_STRIDE + col_]) = random_uint4;
// __syncthreads();
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier \n\t");
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for (int j = 0; j < size<2>(tensor); ++j) {
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
const int rand_read_row = read_row + j * 16 + mi;
const uint8_t t_rand = p_rand_8[(rand_read_row) * RAND_STRIDE + read_col];
tensor(mi, i, j) =
encode_dropout(t_rand <= p_dropout_in_uint8_t, tensor(mi, i, j));
}
}
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_barrier \n\t");
__builtin_amdgcn_sched_barrier(0);
}
}
}; };
......
...@@ -206,6 +206,7 @@ struct Flash_fwd_params : public Qkv_params { ...@@ -206,6 +206,7 @@ struct Flash_fwd_params : public Qkv_params {
int * __restrict__ mm_prefix_range_ptr; int * __restrict__ mm_prefix_range_ptr;
int max_mm_ranges = 0; int max_mm_ranges = 0;
bool use_alibi_sqrt = false; bool use_alibi_sqrt = false;
int se_balance_cnt;
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
......
This diff is collapsed.
...@@ -80,6 +80,22 @@ DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dk_dv_trans_16x64_prefetch, bool Is_dropo ...@@ -80,6 +80,22 @@ DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dk_dv_trans_16x64_prefetch, bool Is_dropo
FLASH_UNSUPPORTED_ARCH FLASH_UNSUPPORTED_ARCH
#endif #endif
} }
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dk_trans_16x64_prefetch, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
#if defined(ARCH_SUPPORTS_FLASH)
static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
flash::compute_dk_trans_16x64_prefetch<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
}
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dv_trans_16x64_prefetch, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
#if defined(ARCH_SUPPORTS_FLASH)
static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
flash::compute_dv_trans_16x64_prefetch<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
}
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dk_dv_trans_16x64_mla_prefetch, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) { DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dk_dv_trans_16x64_mla_prefetch, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
#if defined(ARCH_SUPPORTS_FLASH) #if defined(ARCH_SUPPORTS_FLASH)
static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
...@@ -357,6 +373,32 @@ void run_flash_bwd_separate_seqk_parallel_trans(Flash_bwd_params &params, cudaSt ...@@ -357,6 +373,32 @@ void run_flash_bwd_separate_seqk_parallel_trans(Flash_bwd_params &params, cudaSt
#endif #endif
} }
static inline int calc_se_balance_cnt(int b_h_num)
{
int se_balance_cnt = 1;
if(b_h_num % 13 == 0){
se_balance_cnt = 13;
} else if(b_h_num % 9 == 0){
se_balance_cnt = 9;
} else if(b_h_num % 8 == 0){
se_balance_cnt = 8;
} else if(b_h_num % 7 ==0){
se_balance_cnt = 7;
} else if(b_h_num % 6 ==0){
se_balance_cnt = 6;
} else if(b_h_num % 5 ==0){
se_balance_cnt = 5;
} else if(b_h_num % 4 ==0){
se_balance_cnt = 4;
} else if(b_h_num % 3 ==0){
se_balance_cnt = 3;
} else if(b_h_num % 2 ==0){
se_balance_cnt = 2;
} else {
se_balance_cnt = 1;
}
return se_balance_cnt;
}
template<typename Kernel_traits, typename Kernel_trans_traits, bool Is_dropout, bool Is_causal> template<typename Kernel_traits, typename Kernel_trans_traits, bool Is_dropout, bool Is_causal>
void run_flash_bwd_separate_prefetch(Flash_bwd_params &params, cudaStream_t stream) { void run_flash_bwd_separate_prefetch(Flash_bwd_params &params, cudaStream_t stream) {
...@@ -370,19 +412,32 @@ void run_flash_bwd_separate_prefetch(Flash_bwd_params &params, cudaStream_t stre ...@@ -370,19 +412,32 @@ void run_flash_bwd_separate_prefetch(Flash_bwd_params &params, cudaStream_t stre
const int num_n_block = (Is_causal && Kernel_trans_traits::kHeadDim != 64) ? (non_causal_num_n_block + 1 ) >> 1 : const int num_n_block = (Is_causal && Kernel_trans_traits::kHeadDim != 64) ? (non_causal_num_n_block + 1 ) >> 1 :
non_causal_num_n_block; non_causal_num_n_block;
const int non_causal_num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; const int non_causal_num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
const int num_m_block = Is_causal ? (non_causal_num_m_block + 1 ) >> 1 : const int num_m_block = (Is_causal && Kernel_trans_traits::kHeadDim != 64) ? (non_causal_num_m_block + 1 ) >> 1 :
non_causal_num_m_block; non_causal_num_m_block;
#endif #endif
dim3 grid_m(num_m_block, params.h, params.b); dim3 grid_m(num_m_block, params.h, params.b);
dim3 grid_m_do((params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM, params.b, params.h); dim3 grid_m_do((params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM, params.b, params.h);
dim3 grid_n(num_n_block, params.h, params.b); dim3 grid_n(num_n_block, params.h, params.b);
if constexpr (Kernel_trans_traits::kHeadDim == 64 && Is_causal)
{
int b_h_num = params.b * params.h;
params.se_balance_cnt = calc_se_balance_cnt(b_h_num);
grid_n.x = params.se_balance_cnt;
grid_n.y = num_n_block;
grid_n.z = (params.h * params.b/params.se_balance_cnt);
grid_m.x = params.se_balance_cnt;
grid_m.y = num_m_block;
grid_m.z = (params.h * params.b/params.se_balance_cnt);
}
flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<grid_m_do, Kernel_traits::kNThreads, 0, stream>>>(params); flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<grid_m_do, Kernel_traits::kNThreads, 0, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK(); C10_CUDA_KERNEL_LAUNCH_CHECK();
const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_q % Kernel_traits::kBlockM == 0 && params.seqlen_k % Kernel_traits::kBlockN == 0; const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_q % Kernel_traits::kBlockM == 0 && params.seqlen_k % Kernel_traits::kBlockN == 0;
const bool is_even_MN_trans = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_q % Kernel_trans_traits::kBlockM == 0 && params.seqlen_k % Kernel_trans_traits::kBlockN == 0; const bool is_even_MN_trans = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_q % Kernel_trans_traits::kBlockM == 0 && params.seqlen_k % Kernel_trans_traits::kBlockN == 0;
const bool is_even_K = params.d == Kernel_traits::kHeadDim; const bool is_even_K = params.d == Kernel_traits::kHeadDim;
constexpr int smem_size_dropout = Kernel_trans_traits::kBlockM * Kernel_trans_traits::kBlockN; constexpr int smem_size_dropout = Kernel_trans_traits::kHeadDim == 64 ? 4096 : Kernel_trans_traits::kBlockM * Kernel_trans_traits::kBlockN;
constexpr int smem_size_dk_dv = Kernel_trans_traits::kSmemPrefetchSize; constexpr int smem_size_dk_dv = Kernel_trans_traits::kSmemPrefetchSize;
constexpr int smem_size_dk_dv_total = (Kernel_trans_traits::kHeadDim == 128 || Kernel_trans_traits::kHeadDim == 64) ? (smem_size_dk_dv + smem_size_dropout) : (smem_size_dk_dv); constexpr int smem_size_dk_dv_total = (Kernel_trans_traits::kHeadDim == 128 || Kernel_trans_traits::kHeadDim == 64) ? (smem_size_dk_dv + smem_size_dropout) : (smem_size_dk_dv);
constexpr int smem_size_dq = Kernel_traits::kSmemPrefetchSize; constexpr int smem_size_dq = Kernel_traits::kSmemPrefetchSize;
...@@ -397,16 +452,38 @@ void run_flash_bwd_separate_prefetch(Flash_bwd_params &params, cudaStream_t stre ...@@ -397,16 +452,38 @@ void run_flash_bwd_separate_prefetch(Flash_bwd_params &params, cudaStream_t stre
SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
// constexpr static bool Is_softcap = false; // constexpr static bool Is_softcap = false;
BOOL_SWITCH(is_even_MN_trans, IsEvenMNTransConst, [&] { BOOL_SWITCH(is_even_MN_trans, IsEvenMNTransConst, [&] {
auto kernel = &flash_bwd_dk_dv_trans_16x64_prefetch< if constexpr (Kernel_trans_traits::kHeadDim == 256) {
Kernel_trans_traits, auto kernel = &flash_bwd_dv_trans_16x64_prefetch<
Is_dropout && !Is_softcap, Is_causal, Kernel_trans_traits,
Is_local && !Is_causal, Is_dropout && !Is_softcap, Is_causal,
Has_alibi, Is_local && !Is_causal,
IsEvenMNTransConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, Has_alibi,
IsEvenKConst, IsEvenMNTransConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128,
Is_softcap>; IsEvenKConst,
kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dk_dv_total, stream>>>(params); Is_softcap>;
C10_CUDA_KERNEL_LAUNCH_CHECK(); kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dk_dv_total, stream>>>(params);
auto kernel_dk = &flash_bwd_dk_trans_16x64_prefetch<
Kernel_trans_traits,
Is_dropout && !Is_softcap, Is_causal,
Is_local && !Is_causal,
Has_alibi,
IsEvenMNTransConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128,
IsEvenKConst,
Is_softcap>;
kernel_dk<<<grid_n, Kernel_traits::kNThreads, smem_size_dk_dv_total, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
auto kernel = &flash_bwd_dk_dv_trans_16x64_prefetch<
Kernel_trans_traits,
Is_dropout && !Is_softcap, Is_causal,
Is_local && !Is_causal,
Has_alibi,
IsEvenMNTransConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128,
IsEvenKConst,
Is_softcap>;
kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dk_dv_total, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}); });
auto kernel_dq = flash_bwd_dq_loop_16x64_prefetch_seqq_parallel_kernel< auto kernel_dq = flash_bwd_dq_loop_16x64_prefetch_seqq_parallel_kernel<
Kernel_traits, Kernel_traits,
...@@ -559,9 +636,9 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) { ...@@ -559,9 +636,9 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
// } // }
// // printf("max_smem_per_block = %d\n", max_smem_per_block); // // printf("max_smem_per_block = %d\n", max_smem_per_block);
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") if (get_device_name() == "gfx936" || get_device_name() == "gfx938"|| get_device_name() == "gfx92a")
{ {
using kernel_trans_traits = Flash_bwd_kernel_trans_16x64_prefetch_traits<Headdim, /*kBlockM_*/64, /*kBlockN_*/Is_dropout ? 64 : 128, /*kNWarps_*/4, T, 3>; using kernel_trans_traits = Flash_bwd_kernel_trans_16x64_prefetch_traits<Headdim, /*kBlockM_*/64, /*kBlockN_*/Is_dropout ? 128 : 128, /*kNWarps_*/4, T, 3>;
using kernel_traits = Flash_bwd_kernel_dq_16x64_prefetch_traits<Headdim, /*kBlockM_*/128, /*kBlockN_*/64, /*kNWarps_*/4, using kernel_traits = Flash_bwd_kernel_dq_16x64_prefetch_traits<Headdim, /*kBlockM_*/128, /*kBlockN_*/64, /*kNWarps_*/4,
/*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/1, /*AtomLayoutMdQ*/4, /*Is_V_in_regs_*/false, /*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/1, /*AtomLayoutMdQ*/4, /*Is_V_in_regs_*/false,
/*No_double_buffer_*/true, /*Is_Q_in_regs_*/false, /*Share_Q_K_smem_*/true, T, 3>; /*No_double_buffer_*/true, /*Is_Q_in_regs_*/false, /*Share_Q_K_smem_*/true, T, 3>;
...@@ -588,7 +665,7 @@ template<typename T, bool Is_causal> ...@@ -588,7 +665,7 @@ template<typename T, bool Is_causal>
void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) { void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 96; constexpr static int Headdim = 96;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") { if (get_device_name() == "gfx936" || get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
using kernel_trans_traits = Flash_bwd_kernel_trans_16x64_prefetch_traits_dim96<Headdim, /*kBlockM_*/64, /*kBlockN_*/64, /*kNWarps_*/4, T, 3>; using kernel_trans_traits = Flash_bwd_kernel_trans_16x64_prefetch_traits_dim96<Headdim, /*kBlockM_*/64, /*kBlockN_*/64, /*kNWarps_*/4, T, 3>;
using kernel_traits = Flash_bwd_kernel_dq_16x64_prefetch_traits_dim96<Headdim, /*kBlockM_*/128, /*kBlockN_*/64, /*kNWarps_*/4, using kernel_traits = Flash_bwd_kernel_dq_16x64_prefetch_traits_dim96<Headdim, /*kBlockM_*/128, /*kBlockN_*/64, /*kNWarps_*/4,
/*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/1, /*AtomLayoutMdQ*/4, T, 3>; /*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/1, /*AtomLayoutMdQ*/4, T, 3>;
...@@ -617,7 +694,7 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) { ...@@ -617,7 +694,7 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 128; constexpr static int Headdim = 128;
// printf("max_smem_per_block = %d\n", max_smem_per_block); // printf("max_smem_per_block = %d\n", max_smem_per_block);
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938"){ if (get_device_name() == "gfx936" || get_device_name() == "gfx938"|| get_device_name() == "gfx92a"){
using kernel_trans_traits = Flash_bwd_kernel_trans_16x64_prefetch_traits<Headdim, /*kBlockM_*/64, /*kBlockN_*/64, /*kNWarps_*/4, T, 3>; using kernel_trans_traits = Flash_bwd_kernel_trans_16x64_prefetch_traits<Headdim, /*kBlockM_*/64, /*kBlockN_*/64, /*kNWarps_*/4, T, 3>;
// using kernel_traits = Flash_bwd_kernel_dq_16x64_prefetch_traits<Headdim, /*kBlockM_*/Is_dropout ? 64 : 128, /*kBlockN_*/64, /*kNWarps_*/4, // using kernel_traits = Flash_bwd_kernel_dq_16x64_prefetch_traits<Headdim, /*kBlockM_*/Is_dropout ? 64 : 128, /*kBlockN_*/64, /*kNWarps_*/4,
using kernel_traits = Flash_bwd_kernel_dq_16x64_prefetch_traits<Headdim, /*kBlockM_*/Is_dropout ? (Is_causal ? 64 : 128) : 128, /*kBlockN_*/64, /*kNWarps_*/4, using kernel_traits = Flash_bwd_kernel_dq_16x64_prefetch_traits<Headdim, /*kBlockM_*/Is_dropout ? (Is_causal ? 64 : 128) : 128, /*kBlockN_*/64, /*kNWarps_*/4,
...@@ -686,7 +763,7 @@ void run_mha_bwd_hdim192_hdim128(Flash_bwd_params &params, cudaStream_t stream) ...@@ -686,7 +763,7 @@ void run_mha_bwd_hdim192_hdim128(Flash_bwd_params &params, cudaStream_t stream)
#if 1 #if 1
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") { if (get_device_name() == "gfx936" || get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
// using kernel_trans_traits = Flash_bwd_kernel_trans_16x64_prefetch_traits<Headdim, /*kBlockM_*/64, /*kBlockN_*/64, /*kNWarps_*/4, T, 3, 128>; // using kernel_trans_traits = Flash_bwd_kernel_trans_16x64_prefetch_traits<Headdim, /*kBlockM_*/64, /*kBlockN_*/64, /*kNWarps_*/4, T, 3, 128>;
using kernel_trans_traits = Flash_bwd_kernel_trans_16x64_prefetch_mla_traits<Headdim, /*kBlockM_*/64, /*kBlockN_*/64, /*kNWarps_*/4, T, 3, 128>; using kernel_trans_traits = Flash_bwd_kernel_trans_16x64_prefetch_mla_traits<Headdim, /*kBlockM_*/64, /*kBlockN_*/64, /*kNWarps_*/4, T, 3, 128>;
using kernel_traits = Flash_bwd_kernel_dq_16x64_prefetch_traits<Headdim, /*kBlockM_*/64, /*kBlockN_*/64, /*kNWarps_*/4, using kernel_traits = Flash_bwd_kernel_dq_16x64_prefetch_traits<Headdim, /*kBlockM_*/64, /*kBlockN_*/64, /*kNWarps_*/4,
...@@ -782,7 +859,7 @@ template<typename T, bool Is_causal> ...@@ -782,7 +859,7 @@ template<typename T, bool Is_causal>
void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) { void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 256; constexpr static int Headdim = 256;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") { if (get_device_name() == "gfx936" || get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
// printf("%s:%d\n", __FILE__, __LINE__); // printf("%s:%d\n", __FILE__, __LINE__);
using kernel_traits = Flash_bwd_kernel_dq_16x64_prefetch_traits_dim256<Headdim, 64, 64, 4, T, 3>; using kernel_traits = Flash_bwd_kernel_dq_16x64_prefetch_traits_dim256<Headdim, 64, 64, 4, T, 3>;
using kernel_trans_traits = Flash_bwd_kernel_trans_16x64_prefetch_traits_dim256<Headdim, 64, 64, 4, T, 3>; using kernel_trans_traits = Flash_bwd_kernel_trans_16x64_prefetch_traits_dim256<Headdim, 64, 64, 4, T, 3>;
...@@ -810,7 +887,7 @@ template<typename T, bool Is_causal> ...@@ -810,7 +887,7 @@ template<typename T, bool Is_causal>
void run_mha_bwd_hdim512(Flash_bwd_params &params, cudaStream_t stream) { void run_mha_bwd_hdim512(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 512; constexpr static int Headdim = 512;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") { if (get_device_name() == "gfx936" || get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
// printf("%s:%d\n", __FILE__, __LINE__); // printf("%s:%d\n", __FILE__, __LINE__);
using kernel_traits = Flash_bwd_kernel_dq_16x64_prefetch_traits_dim512<Headdim, 64, 64, 4, T, 3>; using kernel_traits = Flash_bwd_kernel_dq_16x64_prefetch_traits_dim512<Headdim, 64, 64, 4, T, 3>;
using kernel_trans_traits = Flash_bwd_kernel_trans_16x64_prefetch_traits_dim512<Headdim, 64, 64, 4, T, 3>; using kernel_trans_traits = Flash_bwd_kernel_trans_16x64_prefetch_traits_dim512<Headdim, 64, 64, 4, T, 3>;
......
This diff is collapsed.
...@@ -292,6 +292,7 @@ template<typename Kernel_traits, bool Is_causal> ...@@ -292,6 +292,7 @@ template<typename Kernel_traits, bool Is_causal>
void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) { void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
static_assert(Kernel_traits::Is_Q_in_regs, "SplitKV implementation must support Is_Q_in_regs"); static_assert(Kernel_traits::Is_Q_in_regs, "SplitKV implementation must support Is_Q_in_regs");
static_assert(Kernel_traits::Share_Q_K_smem, "SplitKV implementation must support Share_Q_K_smem"); static_assert(Kernel_traits::Share_Q_K_smem, "SplitKV implementation must support Share_Q_K_smem");
params.num_splits = 1;
// params.num_splits大于1的时候,输出值是float类型,是大于Q的。这里改动的本质原因是q与kv共享lds导致的 // params.num_splits大于1的时候,输出值是float类型,是大于Q的。这里改动的本质原因是q与kv共享lds导致的
const size_t smem_size = params.num_splits > 1 ? std::max(Kernel_traits::kSmemQSize * 2, Kernel_traits::kSmemSize) : Kernel_traits::kSmemSize; const size_t smem_size = params.num_splits > 1 ? std::max(Kernel_traits::kSmemQSize * 2, Kernel_traits::kSmemSize) : Kernel_traits::kSmemSize;
// printf("smem_size = %d\n", smem_size); // printf("smem_size = %d\n", smem_size);
...@@ -317,33 +318,6 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) { ...@@ -317,33 +318,6 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
}); });
}); });
}); });
// printf(" run_flash_splitkv_fwd params.num_splits = %d\n", params.num_splits);
if (params.num_splits > 1) {
// We want kBlockM to be as small as possible for more parallelism.
// With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4.
// If headdim is divisible by 64, then we set kBlockM = 8, etc.
constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 32 : (Kernel_traits::kHeadDim % 64 == 0 ? 32 : 32);
dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM);
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
if (params.num_splits <= 2) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
}
else if (params.num_splits <= 4) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 2, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 8) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 3, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 16) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 4, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 32) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 5, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 64) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 6, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 128) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 7, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}
} }
template<typename Kernel_traits, typename Combine_Kernel_traits, bool Is_causal> template<typename Kernel_traits, typename Combine_Kernel_traits, bool Is_causal>
...@@ -364,6 +338,7 @@ void run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch(Flash_fwd_params &params, ...@@ -364,6 +338,7 @@ void run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch(Flash_fwd_params &params,
BOOL_SWITCH(params.num_splits > 1, Split, [&] { BOOL_SWITCH(params.num_splits > 1, Split, [&] {
ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
constexpr static bool IsEvenMNConst = false;
constexpr static bool IsEvenKConst = true; constexpr static bool IsEvenKConst = true;
// constexpr static bool Is_local = false; // constexpr static bool Is_local = false;
constexpr static bool Is_softcap = false; constexpr static bool Is_softcap = false;
...@@ -433,6 +408,7 @@ void run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch_fp8(Flash_fwd_params &par ...@@ -433,6 +408,7 @@ void run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch_fp8(Flash_fwd_params &par
BOOL_SWITCH(params.num_splits > 1, Split, [&] { BOOL_SWITCH(params.num_splits > 1, Split, [&] {
ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
constexpr static bool IsEvenMNConst = false;
constexpr static bool IsEvenKConst = true; constexpr static bool IsEvenKConst = true;
// constexpr static bool Is_local = false; // constexpr static bool Is_local = false;
constexpr static bool Is_softcap = false; constexpr static bool Is_softcap = false;
...@@ -580,7 +556,7 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) ...@@ -580,7 +556,7 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream)
bool is_small = (params.seqlen_q <= 64||params.h * params.b * mblocks< 4*sm_count); bool is_small = (params.seqlen_q <= 64||params.h * params.b * mblocks< 4*sm_count);
if (params.is_vllm_kvcache) { if (params.is_vllm_kvcache) {
if constexpr(Headdim == 64) { if constexpr(Headdim == 64) {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") { if (get_device_name() == "gfx936" || get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<64, kBlockM, kBlockN, 4, false, false, T, 64>; using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<64, kBlockM, kBlockN, 4, false, false, T, 64>;
if (is_small) { if (is_small) {
using prefetch_kernel_traits = Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits_dim64<64, 64, 64, 4, T, 1, 64>; using prefetch_kernel_traits = Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits_dim64<64, 64, 64, 4, T, 1, 64>;
...@@ -601,7 +577,7 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) ...@@ -601,7 +577,7 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream)
} }
} }
}else if constexpr(Headdim == 128) { }else if constexpr(Headdim == 128) {
if (get_device_name() == "gfx936"||get_device_name() == "gfx938") { if (get_device_name() == "gfx936"||get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
assert(params.knew_ptr == nullptr && params.block_table != nullptr); assert(params.knew_ptr == nullptr && params.block_table != nullptr);
using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<128, kBlockM, kBlockN, 4, false, false, T, 128>; using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<128, kBlockM, kBlockN, 4, false, false, T, 128>;
if (is_small) { if (is_small) {
...@@ -623,13 +599,13 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) ...@@ -623,13 +599,13 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream)
} }
} }
}else if constexpr(Headdim == 192) { }else if constexpr(Headdim == 192) {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") { if (get_device_name() == "gfx936" || get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<192, kBlockM, kBlockN, 4, false, false, T, 192>; using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<192, kBlockM, kBlockN, 4, false, false, T, 192>;
using prefetch_kernel_traits = Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits_dim192<192, 64, 64, 4, T, 3, 192>; using prefetch_kernel_traits = Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits_dim192<192, 64, 64, 4, T, 3, 192>;
run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch<prefetch_kernel_traits, combine_kernel_traits, Is_causal>(params, stream); run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch<prefetch_kernel_traits, combine_kernel_traits, Is_causal>(params, stream);
} }
}else if constexpr(Headdim == 256) { }else if constexpr(Headdim == 256) {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") { if (get_device_name() == "gfx936" || get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<256, kBlockM, kBlockN, 4, false, false, T, 256>; using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<256, kBlockM, kBlockN, 4, false, false, T, 256>;
using prefetch_kernel_traits = Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits_dim256<256, 64, 64, 4, T, 3, 256>; using prefetch_kernel_traits = Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits_dim256<256, 64, 64, 4, T, 3, 256>;
run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch<prefetch_kernel_traits, combine_kernel_traits, Is_causal>(params, stream); run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch<prefetch_kernel_traits, combine_kernel_traits, Is_causal>(params, stream);
...@@ -704,7 +680,7 @@ void run_mha_fwd_splitkv_dispatch_fp8(Flash_fwd_params &params, cudaStream_t str ...@@ -704,7 +680,7 @@ void run_mha_fwd_splitkv_dispatch_fp8(Flash_fwd_params &params, cudaStream_t str
// printf("kBlockM = %d, kBlockN = %d", kBlockM, kBlockN); // printf("kBlockM = %d, kBlockN = %d", kBlockM, kBlockN);
#ifndef FLASHATTENTION_DISABLE_SPLITKV #ifndef FLASHATTENTION_DISABLE_SPLITKV
if constexpr(Headdim == 64) { if constexpr(Headdim == 64) {
if (get_device_name() == "gfx938") { if (get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
if (params.is_vllm_kvcache) { if (params.is_vllm_kvcache) {
using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<64, kBlockM, kBlockN, 4, false, false, TO, 64>; using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<64, kBlockM, kBlockN, 4, false, false, TO, 64>;
if (params.seqlen_q < 64) { if (params.seqlen_q < 64) {
...@@ -717,7 +693,7 @@ void run_mha_fwd_splitkv_dispatch_fp8(Flash_fwd_params &params, cudaStream_t str ...@@ -717,7 +693,7 @@ void run_mha_fwd_splitkv_dispatch_fp8(Flash_fwd_params &params, cudaStream_t str
} }
} }
}else if constexpr(Headdim == 128) { }else if constexpr(Headdim == 128) {
if (get_device_name() == "gfx938") { if (get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
if (params.is_vllm_kvcache) { if (params.is_vllm_kvcache) {
using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<128, kBlockM, kBlockN, 4, false, false, TO, 128>; using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<128, kBlockM, kBlockN, 4, false, false, TO, 128>;
if (params.seqlen_q < 64) { if (params.seqlen_q < 64) {
...@@ -730,7 +706,7 @@ void run_mha_fwd_splitkv_dispatch_fp8(Flash_fwd_params &params, cudaStream_t str ...@@ -730,7 +706,7 @@ void run_mha_fwd_splitkv_dispatch_fp8(Flash_fwd_params &params, cudaStream_t str
} }
} }
}else if constexpr(Headdim == 192) { }else if constexpr(Headdim == 192) {
if (get_device_name() == "gfx938") { if (get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
if (params.is_vllm_kvcache) { if (params.is_vllm_kvcache) {
using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<192, kBlockM, kBlockN, 4, false, false, TO, 192>; using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<192, kBlockM, kBlockN, 4, false, false, TO, 192>;
if (params.seqlen_q < 64) { if (params.seqlen_q < 64) {
...@@ -743,7 +719,7 @@ void run_mha_fwd_splitkv_dispatch_fp8(Flash_fwd_params &params, cudaStream_t str ...@@ -743,7 +719,7 @@ void run_mha_fwd_splitkv_dispatch_fp8(Flash_fwd_params &params, cudaStream_t str
} }
} }
}else if constexpr(Headdim == 256) { }else if constexpr(Headdim == 256) {
if (get_device_name() == "gfx938") { if (get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
if (params.is_vllm_kvcache) { if (params.is_vllm_kvcache) {
using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<256, kBlockM, kBlockN, 4, false, false, TO, 256>; using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<256, kBlockM, kBlockN, 4, false, false, TO, 256>;
if (params.seqlen_q < 64) { if (params.seqlen_q < 64) {
...@@ -764,7 +740,7 @@ void run_mha_fwd_unified_dispatch(Flash_fwd_params &params, cudaStream_t stream) ...@@ -764,7 +740,7 @@ void run_mha_fwd_unified_dispatch(Flash_fwd_params &params, cudaStream_t stream)
constexpr static int kBlockM = 64; constexpr static int kBlockM = 64;
constexpr static int kBlockN = Headdim <= 128 ? 64 : (Headdim % 64 == 0 ? 32 : 64); constexpr static int kBlockN = Headdim <= 128 ? 64 : (Headdim % 64 == 0 ? 32 : 64);
if constexpr(Headdim == 256) { if constexpr(Headdim == 256) {
if (get_device_name() == "gfx938" || get_device_name() == "gfx936") { if (get_device_name() == "gfx938" || get_device_name() == "gfx936"|| get_device_name() == "gfx92a") {
using prefetch_kernel_traits = Flash_fwd_kernel_16x64_splitkv_prefetch_unified_traits_dim256<256, 64, 64, 4, T, 3, 256>; using prefetch_kernel_traits = Flash_fwd_kernel_16x64_splitkv_prefetch_unified_traits_dim256<256, 64, 64, 4, T, 3, 256>;
using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<256, kBlockM, kBlockN, 4, false, false, T, 256>; using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<256, kBlockM, kBlockN, 4, false, false, T, 256>;
...@@ -799,7 +775,7 @@ void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) { ...@@ -799,7 +775,7 @@ void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
#if 0 #if 0
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 256, 64, 4, false, /*Share_Q_K_smem_=*/true, T>, Is_dropout, Is_causal>(params, stream); run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 256, 64, 4, false, /*Share_Q_K_smem_=*/true, T>, Is_dropout, Is_causal>(params, stream);
#else #else
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") { if (get_device_name() == "gfx936" || get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
int mblocks = (params.seqlen_q + 64 - 1) / 64; int mblocks = (params.seqlen_q + 64 - 1) / 64;
if (params.seqlen_q <= 64||params.h * params.b * mblocks< 4*sm_count) { if (params.seqlen_q <= 64||params.h * params.b * mblocks< 4*sm_count) {
run_flash_fwd_16x64_prefetch<Flash_fwd_kernel_16x64_prefetch_traits_dim64<Headdim, 64, 64, 4, T>, Is_dropout, Is_causal>(params, stream); run_flash_fwd_16x64_prefetch<Flash_fwd_kernel_16x64_prefetch_traits_dim64<Headdim, 64, 64, 4, T>, Is_dropout, Is_causal>(params, stream);
...@@ -817,7 +793,7 @@ template<typename T, bool Is_causal> ...@@ -817,7 +793,7 @@ template<typename T, bool Is_causal>
void run_mha_fwd_padding_mask_hdim64(Flash_fwd_params &params, cudaStream_t stream) { void run_mha_fwd_padding_mask_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 64; constexpr static int Headdim = 64;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") { if (get_device_name() == "gfx936" || get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
run_flash_fwd_16x64_prefetch_padding_mask<Flash_fwd_kernel_16x64_prefetch_traits_dim64<Headdim, 128, 64, 4, T>, Is_dropout, Is_causal>(params, stream); run_flash_fwd_16x64_prefetch_padding_mask<Flash_fwd_kernel_16x64_prefetch_traits_dim64<Headdim, 128, 64, 4, T>, Is_dropout, Is_causal>(params, stream);
} }
else { else {
...@@ -830,7 +806,7 @@ template<typename T, bool Is_causal> ...@@ -830,7 +806,7 @@ template<typename T, bool Is_causal>
void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) { void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 96; constexpr static int Headdim = 96;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") { if (get_device_name() == "gfx936" || get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
int mblocks = (params.seqlen_q + 64 - 1) / 64; int mblocks = (params.seqlen_q + 64 - 1) / 64;
if(params.seqlen_q <= 64||params.h * params.b * mblocks< 4*sm_count){ if(params.seqlen_q <= 64||params.h * params.b * mblocks< 4*sm_count){
run_flash_fwd_16x64_prefetch<Flash_fwd_kernel_16x64_prefetch_traits_dim96<Headdim, 64, 64, 4, T, 3>, Is_dropout, Is_causal>(params, stream); run_flash_fwd_16x64_prefetch<Flash_fwd_kernel_16x64_prefetch_traits_dim96<Headdim, 64, 64, 4, T, 3>, Is_dropout, Is_causal>(params, stream);
...@@ -849,7 +825,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) { ...@@ -849,7 +825,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 128; constexpr static int Headdim = 128;
// printf("run_mha_fwd_hdim128\n"); // printf("run_mha_fwd_hdim128\n");
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") { if (get_device_name() == "gfx936" || get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
int mblocks = (params.seqlen_q + 64 - 1) / 64; int mblocks = (params.seqlen_q + 64 - 1) / 64;
if (params.seqlen_q <= 64||params.h * params.b * mblocks< 4*sm_count) { if (params.seqlen_q <= 64||params.h * params.b * mblocks< 4*sm_count) {
run_flash_fwd_16x64_prefetch<Flash_fwd_kernel_16x64_prefetch_traits<Headdim, 64, 64, 4, T, 3, Is_skip_softmax>, Is_dropout, Is_causal>(params, stream); run_flash_fwd_16x64_prefetch<Flash_fwd_kernel_16x64_prefetch_traits<Headdim, 64, 64, 4, T, 3, Is_skip_softmax>, Is_dropout, Is_causal>(params, stream);
...@@ -869,7 +845,7 @@ void run_mha_fwd_hdim128_fp8(Flash_fwd_params &params, cudaStream_t stream) { ...@@ -869,7 +845,7 @@ void run_mha_fwd_hdim128_fp8(Flash_fwd_params &params, cudaStream_t stream) {
using T_out = std::conditional_t<!Is_FP8, T, cutlass::bfloat16_t>; using T_out = std::conditional_t<!Is_FP8, T, cutlass::bfloat16_t>;
// printf("run_mha_fwd_hdim128\n"); // printf("run_mha_fwd_hdim128\n");
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (get_device_name() == "gfx938") { if (get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
int mblocks = (params.seqlen_q + 64 - 1) / 64; int mblocks = (params.seqlen_q + 64 - 1) / 64;
if (params.seqlen_q <= 64||params.h * params.b * mblocks< 4*sm_count) { if (params.seqlen_q <= 64||params.h * params.b * mblocks< 4*sm_count) {
run_flash_fwd_16x64_prefetch_fp8<Flash_fwd_kernel_16x64_prefetch_traits_fp8<Headdim, 64, 64, 4, T,T_out, 3>, Is_dropout, Is_causal>(params, stream); run_flash_fwd_16x64_prefetch_fp8<Flash_fwd_kernel_16x64_prefetch_traits_fp8<Headdim, 64, 64, 4, T,T_out, 3>, Is_dropout, Is_causal>(params, stream);
...@@ -904,7 +880,7 @@ void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) { ...@@ -904,7 +880,7 @@ void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T, bool Is_causal> template<typename T, bool Is_causal>
void run_mha_fwd_hdim192_hdim128(Flash_fwd_params &params, cudaStream_t stream) { void run_mha_fwd_hdim192_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") { if (get_device_name() == "gfx936" || get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
run_flash_fwd_16x64_prefetch<Flash_fwd_kernel_16x64_prefetch_mla_traits</*Headdim*/192, 128, 64, 4, T, 3, /*HeaddimV*/128>, Is_dropout, Is_causal>(params, stream); run_flash_fwd_16x64_prefetch<Flash_fwd_kernel_16x64_prefetch_mla_traits</*Headdim*/192, 128, 64, 4, T, 3, /*HeaddimV*/128>, Is_dropout, Is_causal>(params, stream);
} }
else { else {
...@@ -919,7 +895,7 @@ void run_mha_fwd_hdim192_hdim128_fp8(Flash_fwd_params &params, cudaStream_t stre ...@@ -919,7 +895,7 @@ void run_mha_fwd_hdim192_hdim128_fp8(Flash_fwd_params &params, cudaStream_t stre
static constexpr bool Is_FP8 = cute::is_same_v<T, cutlass::float_e4m3_t> || cute::is_same_v<T, cutlass::float_e5m2_t>; static constexpr bool Is_FP8 = cute::is_same_v<T, cutlass::float_e4m3_t> || cute::is_same_v<T, cutlass::float_e5m2_t>;
using T_out = std::conditional_t<!Is_FP8, T, cutlass::bfloat16_t>; using T_out = std::conditional_t<!Is_FP8, T, cutlass::bfloat16_t>;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (get_device_name() == "gfx938") { if (get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
run_flash_fwd_16x64_prefetch_fp8<Flash_fwd_kernel_16x64_prefetch_mla_traits_fp8</*Headdim*/192, 128, 64, 4, T,T_out, 3, /*HeaddimV*/128>, Is_dropout, Is_causal>(params, stream); run_flash_fwd_16x64_prefetch_fp8<Flash_fwd_kernel_16x64_prefetch_mla_traits_fp8</*Headdim*/192, 128, 64, 4, T,T_out, 3, /*HeaddimV*/128>, Is_dropout, Is_causal>(params, stream);
} }
else { else {
...@@ -941,8 +917,8 @@ void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) { ...@@ -941,8 +917,8 @@ void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 256; constexpr static int Headdim = 256;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// constexpr static int Is_dropout = false; // constexpr static int Is_dropout = false;
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") { if (get_device_name() == "gfx936" || get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
run_flash_fwd_16x64_prefetch<Flash_fwd_kernel_16x64_prefetch_traits_dim256<Headdim, 64, 64, 4, T, 3>, Is_dropout, Is_causal>(params, stream); run_flash_fwd_16x64_prefetch<Flash_fwd_kernel_16x64_prefetch_traits_dim256<Headdim, 128, 64, 4, T, 3>, Is_dropout, Is_causal>(params, stream);
} else { } else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, true, T>, Is_dropout, Is_causal>(params, stream); run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
} }
...@@ -954,7 +930,7 @@ void run_mha_fwd_hdim512(Flash_fwd_params &params, cudaStream_t stream) { ...@@ -954,7 +930,7 @@ void run_mha_fwd_hdim512(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 512; constexpr static int Headdim = 512;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// constexpr static int Is_dropout = false; // constexpr static int Is_dropout = false;
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") { if (get_device_name() == "gfx936" || get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
run_flash_fwd_16x64_prefetch<Flash_fwd_kernel_16x64_prefetch_traits_dim512<Headdim, 64, 64, 4, T, 3>, Is_dropout, Is_causal>(params, stream); run_flash_fwd_16x64_prefetch<Flash_fwd_kernel_16x64_prefetch_traits_dim512<Headdim, 64, 64, 4, T, 3>, Is_dropout, Is_causal>(params, stream);
} else { } else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, true, T>, Is_dropout, Is_causal>(params, stream); run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
......
...@@ -1854,8 +1854,8 @@ inline __device__ void sparse_attn_1rowblock_sla_fp8(const Params &params, const ...@@ -1854,8 +1854,8 @@ inline __device__ void sparse_attn_1rowblock_sla_fp8(const Params &params, const
for (int ni = 0; ni < size<2>(acc_o); ++ni) { for (int ni = 0; ni < size<2>(acc_o); ++ni) {
col = (laneId / 16)*4 + ni * 32; col = (laneId / 16)*4 + ni * 32;
{ {
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o(0, mi, ni), 0, acc_o(1, mi, ni), 0); auto d0 = __builtin_hcu_cvt_pk_bf16_f32(acc_o(0, mi, ni), acc_o(1, mi, ni), false);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o(2, mi, ni), 0, acc_o(3, mi, ni), 0); auto d1 = __builtin_hcu_cvt_pk_bf16_f32(acc_o(2, mi, ni), acc_o(3, mi, ni), false);
auto res0 = reinterpret_cast<result_type const &>(d0); auto res0 = reinterpret_cast<result_type const &>(d0);
auto res1 = reinterpret_cast<result_type const &>(d1); auto res1 = reinterpret_cast<result_type const &>(d1);
...@@ -1868,8 +1868,8 @@ inline __device__ void sparse_attn_1rowblock_sla_fp8(const Params &params, const ...@@ -1868,8 +1868,8 @@ inline __device__ void sparse_attn_1rowblock_sla_fp8(const Params &params, const
} }
{ {
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o(4, mi, ni), 0, acc_o(5, mi, ni), 0); auto d0 = __builtin_hcu_cvt_pk_bf16_f32(acc_o(4, mi, ni), acc_o(5, mi, ni), false);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o(6, mi, ni), 0, acc_o(7, mi, ni), 0); auto d1 = __builtin_hcu_cvt_pk_bf16_f32(acc_o(6, mi, ni), acc_o(7, mi, ni), false);
auto res0 = reinterpret_cast<result_type const &>(d0); auto res0 = reinterpret_cast<result_type const &>(d0);
auto res1 = reinterpret_cast<result_type const &>(d1); auto res1 = reinterpret_cast<result_type const &>(d1);
gO(row, col) = res0[0]; gO(row, col) = res0[0];
...@@ -1934,7 +1934,7 @@ inline __device__ void compute_sparse_attn_sla_fp8(const Params &params) { ...@@ -1934,7 +1934,7 @@ inline __device__ void compute_sparse_attn_sla_fp8(const Params &params) {
const int bidb = blockIdx.z; const int bidb = blockIdx.z;
// The block index for the head. // The block index for the head.
const int bidh = blockIdx.y; const int bidh = blockIdx.y;
#if defined(__gfx938__) #if defined(__gfx938__) ||defined(__gfx92a__)
flash::sparse_attn_1rowblock_sla_fp8<Kernel_traits, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block); flash::sparse_attn_1rowblock_sla_fp8<Kernel_traits, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
#endif #endif
} }
......
...@@ -128,7 +128,7 @@ void run_mha_fwd_sparse_hdim64(Flash_fwd_params_sparse &params, cudaStream_t str ...@@ -128,7 +128,7 @@ void run_mha_fwd_sparse_hdim64(Flash_fwd_params_sparse &params, cudaStream_t str
template<typename T> template<typename T>
void run_mha_fwd_sparse_sla_hdim64(Flash_fwd_params_sparse &params, cudaStream_t stream) { void run_mha_fwd_sparse_sla_hdim64(Flash_fwd_params_sparse &params, cudaStream_t stream) {
constexpr static int Headdim = 64; constexpr static int Headdim = 64;
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") { if (get_device_name() == "gfx936" || get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
if (params.seqlen_q <= 2048) if (params.seqlen_q <= 2048)
run_flash_sparse_sla_fwd<Flash_fwd_kernel_16x64_prefetch_traits_dim64<Headdim, 64, 64, 4, T>>(params, stream); run_flash_sparse_sla_fwd<Flash_fwd_kernel_16x64_prefetch_traits_dim64<Headdim, 64, 64, 4, T>>(params, stream);
else else
...@@ -155,7 +155,7 @@ void run_mha_fwd_sparse_hdim128(Flash_fwd_params_sparse &params, cudaStream_t st ...@@ -155,7 +155,7 @@ void run_mha_fwd_sparse_hdim128(Flash_fwd_params_sparse &params, cudaStream_t st
template<typename T> template<typename T>
void run_mha_fwd_sparse_sla_hdim128(Flash_fwd_params_sparse &params, cudaStream_t stream) { void run_mha_fwd_sparse_sla_hdim128(Flash_fwd_params_sparse &params, cudaStream_t stream) {
constexpr static int Headdim = 128; constexpr static int Headdim = 128;
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") { if (get_device_name() == "gfx936" || get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
if (params.seqlen_q <= 2048) if (params.seqlen_q <= 2048)
run_flash_sparse_sla_fwd<Flash_fwd_kernel_16x64_prefetch_traits<Headdim, 64, 64, 4, T, 3>>(params, stream); run_flash_sparse_sla_fwd<Flash_fwd_kernel_16x64_prefetch_traits<Headdim, 64, 64, 4, T, 3>>(params, stream);
else else
...@@ -168,7 +168,7 @@ void run_mha_fwd_sparse_sla_hdim128_fp8(Flash_fwd_params_sparse &params, cudaStr ...@@ -168,7 +168,7 @@ void run_mha_fwd_sparse_sla_hdim128_fp8(Flash_fwd_params_sparse &params, cudaStr
constexpr static int Headdim = 128; constexpr static int Headdim = 128;
static constexpr bool Is_FP8 = cute::is_same_v<T, cutlass::float_e4m3_t> || cute::is_same_v<T, cutlass::float_e5m2_t>; static constexpr bool Is_FP8 = cute::is_same_v<T, cutlass::float_e4m3_t> || cute::is_same_v<T, cutlass::float_e5m2_t>;
using T_out = std::conditional_t<!Is_FP8, T, cutlass::bfloat16_t>; using T_out = std::conditional_t<!Is_FP8, T, cutlass::bfloat16_t>;
if (get_device_name() == "gfx938") { if (get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
// int num_blocks_64 = params.h * params.b * ((params.seqlen_q + 64 - 1) / 64);//3 // int num_blocks_64 = params.h * params.b * ((params.seqlen_q + 64 - 1) / 64);//3
// int num_blocks_128 = params.h * params.b * ((params.seqlen_q + 128 - 1) / 128);//2 // int num_blocks_128 = params.h * params.b * ((params.seqlen_q + 128 - 1) / 128);//2
// if ((num_blocks_64 <= sm_count || (num_blocks_128 / sm_count == 1 && num_blocks_128 % sm_count > 1 && (num_blocks_64 + sm_count - 1) / sm_count <= 3) || force_blockm64) && !force_blockm128) { // if ((num_blocks_64 <= sm_count || (num_blocks_128 / sm_count == 1 && num_blocks_128 % sm_count > 1 && (num_blocks_64 + sm_count - 1) / sm_count <= 3) || force_blockm64) && !force_blockm128) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment