Commit 518a5f4d authored by hly's avatar hly
Browse files

import aicc-master-dev

parent c2a1b310
# This workflow will:
# - Create a new Github release
# - Build wheels for supported architectures
# - Deploy the wheels to the Github release
# - Release the static code to PyPi
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
name: Build wheels and deploy
on:
create:
tags:
- v*
jobs:
setup_release:
name: Create Release
runs-on: ubuntu-latest
steps:
- name: Get the tag version
id: extract_branch
run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
shell: bash
- name: Create Release
id: create_release
uses: actions/create-release@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
tag_name: ${{ steps.extract_branch.outputs.branch }}
release_name: ${{ steps.extract_branch.outputs.branch }}
build_wheels:
name: Build Wheel
needs: setup_release
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
# Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the
# manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
os: [ubuntu-20.04]
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
torch-version: ['2.0.1', '2.1.2', '2.2.2', '2.3.1', '2.4.0.dev20240514']
cuda-version: ['11.8.0', '12.3.2']
# We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
# Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
# Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
# when building without C++11 ABI and using it on nvcr images.
cxx11_abi: ['FALSE', 'TRUE']
exclude:
# see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
# Pytorch < 2.2 does not support Python 3.12
- torch-version: '2.0.1'
python-version: '3.12'
- torch-version: '2.1.2'
python-version: '3.12'
# Pytorch <= 2.0 only supports CUDA <= 11.8
- torch-version: '2.0.1'
cuda-version: '12.3.2'
steps:
- name: Checkout
uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Set CUDA and PyTorch versions
run: |
echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
echo "MATRIX_TORCH_VERSION=$(echo ${{ matrix.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV
- name: Free up disk space
if: ${{ runner.os == 'Linux' }}
# https://github.com/easimon/maximize-build-space/blob/master/action.yml
# https://github.com/easimon/maximize-build-space/tree/test-report
run: |
sudo rm -rf /usr/share/dotnet
sudo rm -rf /opt/ghc
sudo rm -rf /opt/hostedtoolcache/CodeQL
- name: Set up swap space
if: runner.os == 'Linux'
uses: pierotofy/set-swap-space@v1.0
with:
swap-size-gb: 10
- name: Install CUDA ${{ matrix.cuda-version }}
if: ${{ matrix.cuda-version != 'cpu' }}
uses: Jimver/cuda-toolkit@v0.2.14
id: cuda-toolkit
with:
cuda: ${{ matrix.cuda-version }}
linux-local-args: '["--toolkit"]'
# default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1
# method: ${{ (matrix.cuda-version == '11.8.0' || matrix.cuda-version == '12.1.0') && 'network' || 'local' }}
method: 'network'
# We need the cuda libraries (e.g. cuSparse, cuSolver) for compiling PyTorch extensions,
# not just nvcc
# sub-packages: '["nvcc"]'
- name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }}
run: |
pip install --upgrade pip
# If we don't install before installing Pytorch, we get error for torch 2.0.1
# ERROR: Could not find a version that satisfies the requirement setuptools>=40.8.0 (from versions: none)
pip install lit
# For some reason torch 2.2.0 on python 3.12 errors saying no setuptools
pip install setuptools
# We want to figure out the CUDA version to download pytorch
# e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116
# see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
# This code is ugly, maybe there's a better way to do this.
export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
minv = {'2.0': 117, '2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118}[env['MATRIX_TORCH_VERSION']]; \
maxv = {'2.0': 118, '2.1': 121, '2.2': 121, '2.3': 121, '2.4': 121}[env['MATRIX_TORCH_VERSION']]; \
print(max(min(int(env['MATRIX_CUDA_VERSION']), maxv), minv))" \
)
if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then
pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
else
pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}
fi
nvcc --version
python --version
python -c "import torch; print('PyTorch:', torch.__version__)"
python -c "import torch; print('CUDA:', torch.version.cuda)"
python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)"
shell:
bash
- name: Build wheel
run: |
# We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6
# https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810
# However this still fails so I'm using a newer version of setuptools
pip install setuptools==68.0.0
pip install ninja packaging wheel
export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
# Limit MAX_JOBS otherwise the github runner goes OOM
# CUDA 11.8 can compile with 2 jobs, but CUDA 12.3 goes OOM
MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "123" ] && echo 1 || echo 2) FLASH_ATTENTION_FORCE_BUILD="TRUE" FLASH_ATTENTION_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist
tmpname=cu${MATRIX_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }}
wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
ls dist/*whl |xargs -I {} mv {} dist/${wheel_name}
echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
- name: Log Built Wheels
run: |
ls dist
- name: Get the tag version
id: extract_branch
run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
- name: Get Release with tag
id: get_current_release
uses: joutvhu/get-release@v1
with:
tag_name: ${{ steps.extract_branch.outputs.branch }}
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Upload Release Asset
id: upload_release_asset
uses: actions/upload-release-asset@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ steps.get_current_release.outputs.upload_url }}
asset_path: ./dist/${{env.wheel_name}}
asset_name: ${{env.wheel_name}}
asset_content_type: application/*
publish_package:
name: Publish package
needs: [build_wheels]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Install dependencies
run: |
pip install ninja packaging setuptools wheel twine
# We don't want to download anything CUDA-related here
pip install torch --index-url https://download.pytorch.org/whl/cpu
- name: Build core package
env:
FLASH_ATTENTION_SKIP_CUDA_BUILD: "TRUE"
run: |
python setup.py sdist --dist-dir=dist
- name: Deploy
env:
TWINE_USERNAME: "__token__"
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
run: |
python -m twine upload dist/*
......@@ -28,6 +28,7 @@ venv
# benchmarks/
*.log
test_results/*.csv
# tests/
csrc/*/*.hip
......
cutlass @ 7d49e6c7
Subproject commit 7d49e6c7e2f8896c47f586706e67e1fb215529dc
......@@ -16,6 +16,7 @@
#include "static_switch.h"
#ifdef HAS_HG_DISPATCH
#include <utility>
#include <vector>
// Symbols defined in libflash_attention.so (HG), linked when HAS_HG_DISPATCH is set.
std::vector<at::Tensor>
......@@ -31,7 +32,11 @@ hg_fwd_bshd(at::Tensor &q,
int window_size_right,
const float softcap,
const bool return_softmax,
c10::optional<at::Generator> gen_);
c10::optional<at::Generator> gen_,
c10::optional<at::Tensor> q_descale_,
c10::optional<at::Tensor> k_descale_,
c10::optional<at::Tensor> v_descale_,
const bool is_bf16_output);
std::vector<at::Tensor>
hg_fwd_bhsd(at::Tensor &q,
......@@ -46,12 +51,16 @@ hg_fwd_bhsd(at::Tensor &q,
int window_size_right,
const float softcap,
const bool return_softmax,
c10::optional<at::Generator> gen_);
c10::optional<at::Generator> gen_,
c10::optional<at::Tensor> q_descale_,
c10::optional<at::Tensor> k_descale_,
c10::optional<at::Tensor> v_descale_,
const bool is_bf16_output);
std::vector<at::Tensor>
hg_varlen_fwd_bshd(const at::Tensor &q,
const at::Tensor &k,
const at::Tensor &v,
hg_varlen_fwd_bshd(at::Tensor &q,
at::Tensor &k,
at::Tensor &v,
c10::optional<at::Tensor> &out_,
const at::Tensor &cu_seqlens_q,
const at::Tensor &cu_seqlens_k,
......@@ -67,7 +76,11 @@ hg_varlen_fwd_bshd(const at::Tensor &q,
int window_size_right,
const float softcap,
const bool return_softmax,
c10::optional<at::Generator> gen_);
c10::optional<at::Generator> gen_,
c10::optional<at::Tensor> q_descale_,
c10::optional<at::Tensor> k_descale_,
c10::optional<at::Tensor> v_descale_,
const bool is_bf16_output);
std::vector<at::Tensor>
hg_fwd_kvcache_bshd(at::Tensor &q,
......@@ -124,6 +137,7 @@ hg_prefix_prefill_varlen_fwd(
c10::optional<at::Tensor> scales_q_,
c10::optional<at::Tensor> scales_k_,
c10::optional<at::Tensor> scales_v_,
c10::optional<at::Tensor> s_aux_,
const bool is_bf16_output);
std::vector<at::Tensor>
......@@ -147,7 +161,27 @@ hg_prefix_decode_varlen_fwd(
int window_size_right,
const float softcap,
const bool return_softmax,
const int layout);
const int layout,
c10::optional<at::Tensor> scales_q_,
c10::optional<at::Tensor> scales_k_,
c10::optional<at::Tensor> scales_v_,
c10::optional<at::Tensor> s_aux_,
const bool is_bf16_output);
std::vector<at::Tensor>
hg_fwd_kvcache_mla(
at::Tensor &q_all,
at::Tensor &kvcache,
c10::optional<const at::Tensor> &vcache_,
const int headdim_v,
const at::Tensor &seqlens_k,
const at::Tensor &block_table,
const float softmax_scale,
const bool is_causal,
const c10::optional<const at::Tensor> &tile_scheduler_metadata,
const c10::optional<const at::Tensor> &num_splits,
c10::optional<at::Tensor> &out_,
int max_seqlen_k);
std::vector<at::Tensor>
hg_bwd_bshd(const at::Tensor &dout,
......@@ -229,6 +263,55 @@ static const bool enable_hg_varlen = get_env_("FLASH_ATTENTION_ENABLE_HG_VARLEN"
#ifdef HAS_HG_DISPATCH
static inline bool is_gfx92a() {
return get_device_name() == "gfx92a";
}
static inline bool is_hg_f16bf16(const at::ScalarType dtype) {
return dtype == torch::kFloat16 || dtype == torch::kBFloat16;
}
static inline bool is_hg_legacy_head_dim(const int head_size_qk, const int head_size_v) {
return (head_size_qk == 64 && head_size_v == 64)
|| (head_size_qk == 128 && head_size_v == 128)
|| (head_size_qk == 192 && head_size_v == 128);
}
static inline bool is_hg_regular_head_dim(const int head_size_qk, const int head_size_v) {
return (head_size_qk == head_size_v
&& (head_size_qk == 32 || head_size_qk == 64 || head_size_qk == 96
|| head_size_qk == 128 || head_size_qk == 160 || head_size_qk == 192
|| head_size_qk == 224 || head_size_qk == 256 || head_size_qk == 512))
|| (head_size_qk == 192 && head_size_v == 128);
}
static inline bool is_hg_regular_varlen_head_dim(const int head_size_qk, const int head_size_v) {
return is_hg_regular_head_dim(head_size_qk, head_size_v) && head_size_qk <= 256;
}
static inline bool is_hg_prefix_head_dim(const int head_size_qk, const int head_size_v) {
return (head_size_qk == 128 && head_size_v == 128)
|| (head_size_qk == 192 && head_size_v == 128)
|| (head_size_qk == 192 && head_size_v == 192)
|| (head_size_qk == 256 && head_size_v == 256);
}
static inline bool is_hg_pa_head_dim(const int head_size_qk, const int head_size_v) {
return is_hg_regular_varlen_head_dim(head_size_qk, head_size_v)
|| (head_size_qk == 576 && head_size_v == 512);
}
static inline std::vector<at::Tensor> make_hg_varlen_fwd_result(std::vector<at::Tensor> hg_result) {
std::vector<at::Tensor> result(8);
if (!hg_result.empty()) {
result[0] = hg_result[0];
}
if (hg_result.size() > 1) {
result[5] = hg_result[1];
}
return result;
}
static inline bool can_use_hg_dense_fwd(
const at::ScalarType q_dtype,
const int head_size_qk,
......@@ -242,17 +325,25 @@ static inline bool can_use_hg_dense_fwd(
const int seqlen_k,
const int window_size_left,
const int window_size_right) {
return get_device_name() == "gfx938"
const auto device_name = get_device_name();
if (device_name == "gfx92a") {
return p_dropout == 0.f
&& !alibi_slopes_.has_value()
&& !s_aux_.has_value()
&& skip_softmax_threshold_scale_factor <= 0.f
&& is_hg_f16bf16(q_dtype)
&& is_hg_regular_head_dim(head_size_qk, head_size_v)
&& seqlen_k > 0;
}
return device_name == "gfx938"
&& p_dropout == 0.f
&& !alibi_slopes_.has_value()
&& !s_aux_.has_value()
&& skip_softmax_threshold_scale_factor <= 0.f
&& (!is_causal || seqlen_q == seqlen_k)
&& window_size_left < 0 && window_size_right <= 0
&& (q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16)
&& ((head_size_qk == 64 && head_size_v == 64)
|| (head_size_qk == 128 && head_size_v == 128)
|| (head_size_qk == 192 && head_size_v == 128))
&& is_hg_f16bf16(q_dtype)
&& is_hg_legacy_head_dim(head_size_qk, head_size_v)
&& seqlen_k > 0;
}
......@@ -271,7 +362,21 @@ static inline bool can_use_hg_varlen_fwd(
const int max_seqlen_k,
const int window_size_left,
const int window_size_right) {
return get_device_name() == "gfx938"
const auto device_name = get_device_name();
if (device_name == "gfx92a") {
return p_dropout == 0.f
&& !paged_kv
&& !leftpad_k_.has_value()
&& !alibi_slopes_.has_value()
&& !s_aux_.has_value()
&& !q_descale_.has_value()
&& !k_descale_.has_value()
&& !v_descale_.has_value()
&& is_hg_f16bf16(q_dtype)
&& is_hg_regular_varlen_head_dim(head_size_qk, head_size_v)
&& max_seqlen_k > 0;
}
return device_name == "gfx938"
&& p_dropout == 0.f
&& !paged_kv
&& !leftpad_k_.has_value()
......@@ -281,13 +386,69 @@ static inline bool can_use_hg_varlen_fwd(
&& !k_descale_.has_value()
&& !v_descale_.has_value()
&& window_size_left < 0 && window_size_right < 0
&& (q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16)
&& ((head_size_qk == 64 && head_size_v == 64)
|| (head_size_qk == 128 && head_size_v == 128)
|| (head_size_qk == 192 && head_size_v == 128))
&& is_hg_f16bf16(q_dtype)
&& is_hg_legacy_head_dim(head_size_qk, head_size_v)
&& max_seqlen_k > 0;
}
static inline bool can_use_hg_92a_prefix_fwd(
const at::ScalarType q_dtype,
const bool paged_kv,
const c10::optional<at::Tensor> &seqused_k,
const c10::optional<const at::Tensor> &leftpad_k_,
const c10::optional<at::Tensor> &alibi_slopes_,
const c10::optional<at::Tensor> &q_descale_,
const c10::optional<at::Tensor> &k_descale_,
const c10::optional<at::Tensor> &v_descale_,
const c10::optional<at::Tensor> &s_aux_,
const float p_dropout,
const int head_size_qk,
const int head_size_v,
const int max_seqlen_q,
const int max_seqlen_k,
const int page_block_size) {
return is_gfx92a()
&& paged_kv
&& seqused_k.has_value()
&& !leftpad_k_.has_value()
&& !alibi_slopes_.has_value()
&& !s_aux_.has_value()
&& !q_descale_.has_value()
&& !k_descale_.has_value()
&& !v_descale_.has_value()
&& p_dropout == 0.f
&& is_hg_f16bf16(q_dtype)
&& is_hg_prefix_head_dim(head_size_qk, head_size_v)
&& max_seqlen_q > 0
&& max_seqlen_k > 0
&& page_block_size == 128;
}
static inline bool can_use_hg_92a_kvcache_fwd(
const at::ScalarType q_dtype,
const bool paged_kv,
const c10::optional<const at::Tensor> &seqlens_k_,
const c10::optional<const at::Tensor> &cache_batch_idx_,
const c10::optional<const at::Tensor> &leftpad_k_,
const c10::optional<at::Tensor> &alibi_slopes_,
const c10::optional<at::Tensor> &s_aux_,
const int head_size_qk,
const int head_size_v,
const int max_seqlen_k,
const int page_block_size) {
return is_gfx92a()
&& paged_kv
&& seqlens_k_.has_value()
&& !cache_batch_idx_.has_value()
&& !leftpad_k_.has_value()
&& !alibi_slopes_.has_value()
&& !s_aux_.has_value()
&& is_hg_f16bf16(q_dtype)
&& is_hg_pa_head_dim(head_size_qk, head_size_v)
&& max_seqlen_k > 0
&& page_block_size == 128;
}
static inline bool can_use_hg_dense_bwd(
const at::ScalarType q_dtype,
const c10::optional<at::Tensor> &alibi_slopes_,
......@@ -304,10 +465,8 @@ static inline bool can_use_hg_dense_bwd(
&& !alibi_slopes_.has_value()
&& (!is_causal || seqlen_q == seqlen_k)
&& window_size_left < 0 && window_size_right < 0
&& (q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16)
&& ((head_size_qk == 64 && head_size_v == 64)
|| (head_size_qk == 128 && head_size_v == 128)
|| (head_size_qk == 192 && head_size_v == 128))
&& is_hg_f16bf16(q_dtype)
&& is_hg_legacy_head_dim(head_size_qk, head_size_v)
&& seqlen_q > 0
&& seqlen_k > 0;
}
......@@ -327,10 +486,8 @@ static inline bool can_use_hg_varlen_bwd(
&& p_dropout == 0.f
&& !alibi_slopes_.has_value()
&& window_size_left < 0 && window_size_right < 0
&& (q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16)
&& ((head_size_qk == 64 && head_size_v == 64)
|| (head_size_qk == 128 && head_size_v == 128)
|| (head_size_qk == 192 && head_size_v == 128))
&& is_hg_f16bf16(q_dtype)
&& is_hg_legacy_head_dim(head_size_qk, head_size_v)
&& total_q > 0
&& total_k > 0
&& max_seqlen_k > 0;
......@@ -979,14 +1136,18 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
? hg_fwd_bhsd(q, k, v, out_, alibi_slopes_,
p_dropout, softmax_scale, is_causal,
window_size_left, window_size_right,
softcap, false /*return_softmax*/, gen_)
softcap, false /*return_softmax*/, gen_,
c10::nullopt, c10::nullopt, c10::nullopt, false /*is_bf16_output*/)
: hg_fwd_bshd(q, k, v, out_, alibi_slopes_,
p_dropout, softmax_scale, is_causal,
window_size_left, window_size_right,
softcap, false /*return_softmax*/, gen_);
softcap, false /*return_softmax*/, gen_,
c10::nullopt, c10::nullopt, c10::nullopt, false /*is_bf16_output*/);
hg_result.push_back(at::Tensor());
return hg_result;
}
TORCH_CHECK(!is_gfx92a(),
"gfx92a HG dispatch supports fp16/bf16 dense BSHD forward only; this dense configuration is not supported");
#endif
if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
......@@ -1231,6 +1392,10 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16 || q_dtype == torch::kFloat8_e4m3fn,
"FlashAttention only support fp16 and bf16 or fp8 data type");
#ifdef HAS_HG_DISPATCH
TORCH_CHECK(!is_gfx92a() || is_hg_f16bf16(q.scalar_type()),
"gfx92a HG dispatch supports fp16/bf16 forward only; fp8/int8 varlen forward is not supported");
#endif
// if (q_dtype == torch::kBFloat16) {
// TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
// }
......@@ -1300,6 +1465,40 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
q_descale_, k_descale_, v_descale_, s_aux_);
#ifdef HAS_HG_DISPATCH
if (!use_varlen_tiny_dim64
&& can_use_hg_92a_prefix_fwd(
q.scalar_type(), paged_KV, seqused_k, leftpad_k_, alibi_slopes_,
q_descale_, k_descale_, v_descale_, s_aux_,
p_dropout, head_size_og, head_size_value, max_seqlen_q, max_seqlen_k,
page_block_size)) {
if (print_param || print_hg_path) {
printf("[flash_attn] HG PATH gfx92a prefix %s q=(%d,%d,%d) k=(%d,%d,%d,%d) v=(%d,%d,%d,%d) batch_size=%d max_seqlen_q=%d max_seqlen_k=%d\n",
max_seqlen_q > 16 ? "prefill" : "decode",
(int)q.size(0), (int)q.size(1), (int)q.size(2),
(int)k.size(0), (int)k.size(1), (int)k.size(2), (int)k.size(3),
(int)v.size(0), (int)v.size(1), (int)v.size(2), (int)v.size(3),
batch_size, max_seqlen_q, max_seqlen_k);
}
at::Tensor seqused_k_tensor = seqused_k.value();
c10::optional<at::Tensor> hg_cu_seqlens_k = c10::nullopt;
auto hg_result = max_seqlen_q > 16
? hg_prefix_prefill_varlen_fwd(
q, k, v, out_, cu_seqlens_q, hg_cu_seqlens_k, seqused_k_tensor,
alibi_slopes_, block_table, max_seqlen_q, max_seqlen_k, 0.f,
softmax_scale, zero_tensors, is_causal, window_size_left,
window_size_right, softcap, return_softmax, 1 /*bshd*/,
c10::nullopt, c10::nullopt, c10::nullopt, s_aux_,
false /*is_bf16_output*/)
: hg_prefix_decode_varlen_fwd(
q, k, v, out_, cu_seqlens_q, hg_cu_seqlens_k, seqused_k_tensor,
alibi_slopes_, block_table, max_seqlen_q, max_seqlen_k, 0.f,
softmax_scale, zero_tensors, is_causal, window_size_left,
window_size_right, softcap, return_softmax, 1 /*bshd*/,
q_descale_, k_descale_, v_descale_, s_aux_,
false /*is_bf16_output*/);
return make_hg_varlen_fwd_result(std::move(hg_result));
}
if (!use_varlen_tiny_dim64
&& can_use_hg_varlen_fwd(
q.scalar_type(), paged_KV, leftpad_k_, alibi_slopes_,
......@@ -1317,14 +1516,17 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
// matches HG's bshd layout semantics even though batch and seqlen are
// flattened into a single leading dimension.
// HG kernel does not support S_dmask output (ReturnSoftmaxConst=false).
auto hg_result = hg_varlen_fwd_bshd(q, k, v, out_,
auto hg_result = hg_varlen_fwd_bshd(q, const_cast<at::Tensor &>(k), const_cast<at::Tensor &>(v), out_,
cu_seqlens_q, cu_seqlens_k, seqused_k, alibi_slopes_,
max_seqlen_q, max_seqlen_k,
p_dropout, softmax_scale, zero_tensors, is_causal,
window_size_left, window_size_right,
softcap, false /*return_softmax*/, gen_);
softcap, false /*return_softmax*/, gen_,
q_descale_, k_descale_, v_descale_, false /*is_bf16_output*/);
return hg_result;
}
TORCH_CHECK(!is_gfx92a(),
"gfx92a HG dispatch supports fp16/bf16 dense, standard varlen, and BSHD paged/prefix forward only; this varlen configuration is not supported");
#endif
if (seqlenq_ngroups_swapped) {
......@@ -1957,6 +2159,9 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
"FlashAttention only support fp16 and bf16 data type");
#ifdef HAS_HG_DISPATCH
TORCH_CHECK(!is_gfx92a(), "gfx92a HG dispatch supports forward only; backward is not supported");
#endif
// if (q_dtype == torch::kBFloat16) {
// TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
// }
......@@ -2250,6 +2455,9 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
"FlashAttention only support fp16 and bf16 data type");
#ifdef HAS_HG_DISPATCH
TORCH_CHECK(!is_gfx92a(), "gfx92a HG dispatch supports forward only; varlen backward is not supported");
#endif
// if (q_dtype == torch::kBFloat16) {
// TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
// }
......@@ -2575,6 +2783,39 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
if (is_causal) { window_size_right = 0; }
#ifdef HAS_HG_DISPATCH
if (can_use_hg_92a_kvcache_fwd(
q.scalar_type(), paged_KV, seqlens_k_, cache_batch_idx_, leftpad_k_,
alibi_slopes_, s_aux_, head_size_og, head_size_value, seqlen_k,
page_block_size)) {
int hg_window_size_left = window_size_left;
int hg_window_size_right = window_size_right;
if (hg_window_size_left >= seqlen_k) { hg_window_size_left = -1; }
if (hg_window_size_right >= seqlen_k) { hg_window_size_right = -1; }
if (print_param || print_hg_path) {
printf("[flash_attn] HG PATH gfx92a kvcache bshd q=(%d,%d,%d,%d) kcache=(%d,%d,%d,%d) vcache=(%d,%d,%d,%d) max_seqlen_k=%d\n",
(int)q.size(0), (int)q.size(1), (int)q.size(2), (int)q.size(3),
(int)kcache.size(0), (int)kcache.size(1), (int)kcache.size(2), (int)kcache.size(3),
(int)vcache.size(0), (int)vcache.size(1), (int)vcache.size(2), (int)vcache.size(3),
seqlen_k);
}
c10::optional<const at::Tensor> hg_seqlens_q = c10::nullopt;
c10::optional<at::Tensor> hg_scores_raw = c10::nullopt;
c10::optional<at::Tensor> hg_tmp_output = c10::nullopt;
auto hg_result = hg_fwd_kvcache_bshd(
q, kcache, vcache, k_, v_, hg_seqlens_q, seqlens_k_, seqlen_k,
rotary_cos_, rotary_sin_, cache_batch_idx_, leftpad_k_, block_table_,
alibi_slopes_, out_, softmax_scale, is_causal, hg_window_size_left,
hg_window_size_right, softcap, is_rotary_interleaved, num_splits,
hg_scores_raw, hg_tmp_output, c10::nullopt, c10::nullopt, c10::nullopt,
false /*is_bf16_output*/);
TORCH_CHECK(!hg_result.empty(), "gfx92a HG kvcache dispatch returned no tensors");
return {hg_result[0], at::Tensor()};
}
TORCH_CHECK(!is_gfx92a(),
"gfx92a HG dispatch supports fp16/bf16 BSHD paged KV-cache forward only; this kvcache configuration is not supported");
#endif
// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
// H/t Daniel Haziza
// 测试会core dump,暂时写死
......@@ -2881,6 +3122,10 @@ vllm_mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x n
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
"FlashAttention only support fp16 and bf16 data type");
#ifdef HAS_HG_DISPATCH
TORCH_CHECK(!is_gfx92a(),
"gfx92a HG dispatch does not support vLLM KV-cache layout; use BSHD paged KV-cache via fwd_kvcache/hg_fwd_kvcache_bshd");
#endif
// if (q_dtype == torch::kBFloat16) {
// TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
// }
......@@ -3237,6 +3482,9 @@ vllm_mha_varlen_fwd_kv_fp8(at::Tensor &q, // total_q x num_heads x head_size, t
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kFloat16,
"FlashAttention only support fp16 and bf16 data type");
#ifdef HAS_HG_DISPATCH
TORCH_CHECK(!is_gfx92a(), "gfx92a HG dispatch supports fp16/bf16 forward only; fp8/int8 vLLM varlen forward is not supported");
#endif
// if (q_dtype == torch::kBFloat16) {
// TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
// }
......@@ -3598,6 +3846,10 @@ vllm_mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16 || q_dtype == torch::kFloat8_e4m3fn || q_dtype == torch::kFloat8_e5m2,
"FlashAttention only support fp16 and bf16 or fp8 e4m3 or fp8 e5m2 data type");
#ifdef HAS_HG_DISPATCH
TORCH_CHECK(!is_gfx92a(),
"gfx92a HG dispatch does not support vLLM varlen layout; use BSHD paged/prefix HG dispatch");
#endif
// if (q_dtype == torch::kBFloat16) {
// TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
// }
......@@ -4464,7 +4716,7 @@ TORCH_LIBRARY_IMPL(flash_attn2_c_op, CUDA, m) {
return std::make_tuple(results[0], results[1]);
});
}
at::Tensor mean_pool_fast(const at::Tensor &input,int blk,const c10::optional<at::Tensor> &mean);
// ============================================================================
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
......@@ -4479,6 +4731,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("vllm_mha_varlen_fwd_kv_fp8", &vllm_mha_varlen_fwd_kv_fp8, "Forward pass, with KV-cache");
#ifdef HAS_HG_DISPATCH
m.def("hg_fwd_kvcache_bshd", &hg_fwd_kvcache_bshd, "HG forward pass, with KV-cache");
m.def("hg_fwd_kvcache_mla", &hg_fwd_kvcache_mla, "HG forward pass, with FlashMLA KV-cache");
m.def("hg_prefix_prefill_varlen_fwd", &hg_prefix_prefill_varlen_fwd, "HG prefix prefill forward pass (variable length)");
m.def("hg_prefix_decode_varlen_fwd", &hg_prefix_decode_varlen_fwd, "HG prefix decode forward pass (variable length)");
#endif
......@@ -4489,7 +4742,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("varlen_bwd_attnmask", &mha_varlen_bwd_attnmask, "Backward pass (variable length), with explicit attention mask");
m.def("paged_attention", &paged_attention, "Forward pass, with KV-cache");
m.def("fwd_sparse", &mha_fwd_sparse, "Forward sparse pass");
m.def("fwd_sparse_mean_pool_fast", &mean_pool_fast, "before mha_fwd_sparse");
m.def("varlen_fwd_sparse", &mha_varlen_fwd_sparse, "Forward pass sparse (variable length)");
m.def("varlen_fwd_unified", &unified2D_attention_fwd, "Forward pass unified attn (variable length && block table)");
}
......@@ -386,6 +386,66 @@ struct Dropout {
}
}
template <bool encode_dropout_in_sign_bit = false, typename Engine, typename Layout>
__forceinline__ __device__ void apply_dropout_trans_dim64_opt(
Tensor<Engine, Layout> &tensor_,
int block_row_start, int block_col_start, int block_row_stride)
{
Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_dropout(tensor_.layout()));
using T = typename Engine::value_type;
auto encode_dropout = [](bool keep, T val) {
if constexpr (encode_dropout_in_sign_bit) {
return keep ? val : -val;
} else {
return keep ? val : T(0);
}
};
const int lane_id = threadIdx.x % 64;
const int col_idx_offset = block_col_start + (threadIdx.x / 64) * 16 + lane_id % 16;
extern __shared__ char smem_[];
uint8_t *p_rand_8 = reinterpret_cast<uint8_t *>(smem_ + 16384);
// write
int row_ = (threadIdx.x % 16) + (threadIdx.x / 64) * 16;
int col_ = (lane_id / 16) * 16;
// read
const int read_row = (lane_id / 16) * 4;
const int lane_group = (lane_id % 16) / 4;
const int lane_offset = lane_id % 4;
const int read_col = (threadIdx.x / 64) * 4 + lane_group * 16 + lane_offset;
// padding stride
// constexpr int RAND_STRIDE = 64 + 4;
constexpr int RAND_STRIDE = 64;
for (int i = 0; i < size<1>(tensor); ++i) {
const int row_idx_base = block_row_start + i * block_row_stride + (lane_id / 16) * 4;
uint2 rowcol = make_uint2(col_idx_offset, row_idx_base);
uint4 random_uint4 = flash::philox(seed, reinterpret_cast<unsigned long long &>(rowcol), offset);
uint8_t (&rnd_8)[16] = reinterpret_cast<uint8_t (&)[16]>(random_uint4);
*reinterpret_cast<uint4*>(&p_rand_8[row_ * RAND_STRIDE + col_]) = random_uint4;
// __syncthreads();
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier \n\t");
__builtin_amdgcn_sched_barrier(0);
#pragma unroll
for (int j = 0; j < size<2>(tensor); ++j) {
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
const int rand_read_row = read_row + j * 16 + mi;
const uint8_t t_rand = p_rand_8[(rand_read_row) * RAND_STRIDE + read_col];
tensor(mi, i, j) =
encode_dropout(t_rand <= p_dropout_in_uint8_t, tensor(mi, i, j));
}
}
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_barrier \n\t");
__builtin_amdgcn_sched_barrier(0);
}
}
};
......
......@@ -206,6 +206,7 @@ struct Flash_fwd_params : public Qkv_params {
int * __restrict__ mm_prefix_range_ptr;
int max_mm_ranges = 0;
bool use_alibi_sqrt = false;
int se_balance_cnt;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
......
......@@ -6139,14 +6139,22 @@ inline __device__ void compute_dq_1rowblock_16x64_dim256_prefetch(const Params &
Tensor dP_sum = make_fragment_like(lse);
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) { dP_sum(mi) = gdPsum(get<0>(taccScS_row(mi))); }
constexpr static int K_BUFF_SIZE = 4;
// __syncthreads();
int n_block = n_block_max - 1;
s_waitcnt<0>();
if constexpr(!Is_even_K) {
#pragma unroll
for (int i = 0; i < 3; i++) {
lds_direct_copy<Is_even_K, Is_even_MN>(gK, sK, i, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
}
} else {
lds_direct_copy_even_k_dim256<0, K_BUFF_SIZE, Is_even_MN>(gK, sK, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy_even_k_dim256<1, K_BUFF_SIZE, Is_even_MN>(gK, sK, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy_even_k_dim256<2, K_BUFF_SIZE, Is_even_MN>(gK, sK, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
 
 
flash::Dropout dropout(params.rng_state[0], params.rng_state[1], params.p_dropout_in_uint8_t,
......@@ -6165,43 +6173,75 @@ inline __device__ void compute_dq_1rowblock_16x64_dim256_prefetch(const Params &
for (; n_block >= n_block_min; --n_block) {
Tensor acc_s_ori = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_N, MMA_N)
clear(acc_s_ori);
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN>(gK, sK, 3, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<3, K_BUFF_SIZE, Is_even_MN>(gK, sK, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_s_ori, tSrQ, tSrK, tSsK, tiled_mma_sdp, smem_tiled_copy_KV, smem_thr_copy_KV, 0);
s_barrier();
 
lds_direct_copy<Is_even_K, Is_even_MN>(0, gK, sK, 4, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN>(gK, sK, 4, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<4, K_BUFF_SIZE, Is_even_MN>(gK, sK, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_s_ori, tSrQ, tSrK, tSsK, tiled_mma_sdp, smem_tiled_copy_KV, smem_thr_copy_KV, 1);
s_barrier();
 
lds_direct_copy<Is_even_K, Is_even_MN>(1, gK, sK, 5, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN>(gK, sK, 5, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<5, K_BUFF_SIZE, Is_even_MN>(gK, sK, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_s_ori, tSrQ, tSrK, tSsK, tiled_mma_sdp, smem_tiled_copy_KV, smem_thr_copy_KV, 2);
s_barrier();
 
lds_direct_copy<Is_even_K, Is_even_MN>(2, gK, sK, 6, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN>(gK, sK, 6, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<6, K_BUFF_SIZE, Is_even_MN>(gK, sK, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_s_ori, tSrQ, tSrK, tSsK, tiled_mma_sdp, smem_tiled_copy_KV, smem_thr_copy_KV, 3);
s_barrier();
 
lds_direct_copy<Is_even_K, Is_even_MN>(3, gK, sK, 7, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN>(gK, sK, 7, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<7, K_BUFF_SIZE, Is_even_MN>(gK, sK, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_s_ori, tSrQ, tSrK, tSsK, tiled_mma_sdp, smem_tiled_copy_KV, smem_thr_copy_KV, 4, 0);
s_barrier();
 
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN>(gV, sV, 0, params.v_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<0, K_BUFF_SIZE, Is_even_MN>(gV, sV, params.v_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_s_ori, tSrQ, tSrK, tSsK, tiled_mma_sdp, smem_tiled_copy_KV, smem_thr_copy_KV, 5, 1);
s_barrier();
 
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN>(gV, sV, 1, params.v_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<1, K_BUFF_SIZE, Is_even_MN>(gV, sV, params.v_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_s_ori, tSrQ, tSrK, tSsK, tiled_mma_sdp, smem_tiled_copy_KV, smem_thr_copy_KV, 6, 2);
s_barrier();
 
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN>(gV, sV, 2, params.v_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<2, K_BUFF_SIZE, Is_even_MN>(gV, sV, params.v_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_s_ori, tSrQ, tSrK, tSsK, tiled_mma_sdp, smem_tiled_copy_KV, smem_thr_copy_KV, 7, 3);
s_barrier();
......@@ -6303,42 +6343,71 @@ inline __device__ void compute_dq_1rowblock_16x64_dim256_prefetch(const Params &
Tensor acc_dp_ori = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_N, MMA_N)
clear(acc_dp_ori);
 
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN>(gV, sV, 3, params.v_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<3, K_BUFF_SIZE, Is_even_MN>(gV, sV, params.v_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_dp_ori, tdPrdO, tdPrV, tdPsV, tiled_mma_sdp, smem_tiled_copy_KV, smem_thr_copy_KV, 0);
s_barrier();
 
lds_direct_copy<Is_even_K, Is_even_MN>(0, gV, sV, 4, params.v_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN>(gV, sV, 4, params.v_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<4, K_BUFF_SIZE, Is_even_MN>(gV, sV, params.v_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_dp_ori, tdPrdO, tdPrV, tdPsV, tiled_mma_sdp, smem_tiled_copy_KV, smem_thr_copy_KV, 1);
s_barrier();
 
lds_direct_copy<Is_even_K, Is_even_MN>(1, gV, sV, 5, params.v_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN>(gV, sV, 5, params.v_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<5, K_BUFF_SIZE, Is_even_MN>(gV, sV, params.v_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_dp_ori, tdPrdO, tdPrV, tdPsV, tiled_mma_sdp, smem_tiled_copy_KV, smem_thr_copy_KV, 2);
s_barrier();
 
lds_direct_copy<Is_even_K, Is_even_MN>(2, gV, sV, 6, params.v_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN>(gV, sV, 6, params.v_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<6, K_BUFF_SIZE, Is_even_MN>(gV, sV, params.v_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_dp_ori, tdPrdO, tdPrV, tdPsV, tiled_mma_sdp, smem_tiled_copy_KV, smem_thr_copy_KV, 3);
s_barrier();
 
lds_direct_copy<Is_even_K, Is_even_MN>(3, gV, sV, 7, params.v_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN>(gV, sV, 7, params.v_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<7, K_BUFF_SIZE, Is_even_MN>(gV, sV, params.v_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_dp_ori, tdPrdO, tdPrV, tdPsV, tiled_mma_sdp, smem_tiled_copy_KV, smem_thr_copy_KV, 4, 0);
s_barrier();
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(0, 0, gK, sKt, 0, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<0, K_BUFF_SIZE, Is_even_MN, _16x256, 0>(gK, sKt, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_dp_ori, tdPrdO, tdPrV, tdPsV, tiled_mma_sdp, smem_tiled_copy_KV, smem_thr_copy_KV, 5, 1);
s_barrier();
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(0, 1, gK, sKt, 1, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<1, K_BUFF_SIZE, Is_even_MN, _16x256, 0>(gK, sKt, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_dp_ori, tdPrdO, tdPrV, tdPsV, tiled_mma_sdp, smem_tiled_copy_KV, smem_thr_copy_KV, 6, 2);
s_barrier();
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(0, 2, gK, sKt, 2, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<2, K_BUFF_SIZE, Is_even_MN, _16x256, 0>(gK, sKt, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_dp_ori, tdPrdO, tdPrV, tdPsV, tiled_mma_sdp, smem_tiled_copy_KV, smem_thr_copy_KV, 7, 3);
s_barrier();
......@@ -6353,64 +6422,152 @@ inline __device__ void compute_dq_1rowblock_16x64_dim256_prefetch(const Params &
return p * (!Is_dropout || p >= 0 ? dp - d : d);
};
#if 1
{
using __float2 = __attribute__((ext_vector_type(2))) float;
static_assert(decltype(size<1>(dS))::value == 16);
#pragma unroll
for (int mi = 0; mi < size<0>(dS); ++mi) {
#pragma unroll
for (int ni = 0; ni < size<1>(dS); ++ni) {
float scaled_ds = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi));
if constexpr (Is_softcap) { scaled_ds *= dtanh(mi, ni); }
dS(mi, ni) = scaled_ds;
const float d = dP_sum(mi);
__float2 d_pair = {-d, -d};
for (int ni = 0; ni < 4; ni ++) {
__float2 scores_pair_0 = {scores(mi, ni), scores(mi, ni + 4)};
__float2 scores_pair_1 = {scores(mi, ni + 8), scores(mi, ni + 12)};
__float2 dS_pair_0;
__float2 dS_pair_1;
if constexpr (!Is_dropout)
{
dS_pair_0 = {dS(mi, ni), dS(mi, ni + 4)};
dS_pair_1 = {dS(mi, ni + 8), dS(mi, ni + 12)};
dS_pair_0 = __builtin_hcu_pk_add_f32(dS_pair_0, d_pair);
dS_pair_1 = __builtin_hcu_pk_add_f32(dS_pair_1, d_pair);
}
else
{
dS_pair_0 = {
!Is_dropout || scores_pair_0.x >= 0 ? dS(mi, ni) - d : d,
!Is_dropout || scores_pair_0.y >= 0 ? dS(mi, ni + 4) - d : d
};
dS_pair_1 = {
!Is_dropout || scores_pair_1.x >= 0 ? dS(mi, ni + 8) - d : d,
!Is_dropout || scores_pair_1.y >= 0 ? dS(mi, ni + 12) - d : d
};
}
__float2 scaled_ds_0 = __builtin_hcu_pk_mul_f32(scores_pair_0, dS_pair_0);
__float2 scaled_ds_1 = __builtin_hcu_pk_mul_f32(scores_pair_1, dS_pair_1);
dS(mi, ni) = scaled_ds_0.x;
dS(mi, ni + 4) = scaled_ds_0.y;
dS(mi, ni + 8) = scaled_ds_1.x;
dS(mi, ni + 12) = scaled_ds_1.y;
}
// #pragma unroll
// for (int ni = 0; ni < size<1>(dS); ni += 2) {
// const float d = dP_sum(mi);
// __float2 scores_pair = {scores(mi, ni), scores(mi, ni + 1)};
// __float2 dS_pair;
// if constexpr (!Is_dropout)
// {
// __float2 d_pair = {-d, -d};
// dS_pair = {dS(mi, ni), dS(mi, ni + 1)};
// dS_pair = __builtin_hcu_pk_add_f32(dS_pair, d_pair);
// }
// else
// {
// dS_pair = {
// !Is_dropout || scores_pair.x >= 0 ? dS(mi, ni) - d : d,
// !Is_dropout || scores_pair.y >= 0 ? dS(mi, ni + 1) - d : d
// };
// }
// __float2 scaled_ds = __builtin_hcu_pk_mul_f32(scores_pair, dS_pair);
// dS(mi, ni) = scaled_ds.x;
// dS(mi, ni + 1) = scaled_ds.y;
// }
}
}
// #pragma unroll
// for (int mi = 0; mi < size<0>(dS); ++mi) {
// #pragma unroll
// for (int ni = 0; ni < size<1>(dS); ++ni) {
// float scaled_ds = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi));
// if constexpr (Is_softcap) { scaled_ds *= dtanh(mi, ni); }
// dS(mi, ni) = scaled_ds;
// }
// }
#endif
 
Tensor dS_reshaped = make_tensor(dS.data(), acc_dp.layout());
Tensor tdQrdS = flash::convert_type<Element>(dS_reshaped);
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(0, 3, gK, sKt, 3, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<3, K_BUFF_SIZE, Is_even_MN, _16x256, 0>(gK, sKt, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs_ds_read_m32x16<0>(acc_dq_0_127, tdQrdS, tdQrKt, tdQsKt, tiled_mma_dq, smem_tiled_copy_Kt, smem_thr_copy_Kt);
flash::gemm_k_rs_ds_read_m32x16_alt<0>(acc_dq_0_127, tdQrdS, tdQrKt, tdQsKt, tiled_mma_dq, smem_tiled_copy_Kt, smem_thr_copy_Kt);
s_barrier();
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(1, 0, gK, sKt, 0, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<0, K_BUFF_SIZE, Is_even_MN, _16x256, 1>(gK, sKt, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs_ds_read_m32x16<1>(acc_dq_0_127, tdQrdS, tdQrKt, tdQsKt, tiled_mma_dq, smem_tiled_copy_Kt, smem_thr_copy_Kt);
flash::gemm_k_rs_ds_read_m32x16_alt<1>(acc_dq_0_127, tdQrdS, tdQrKt, tdQsKt, tiled_mma_dq, smem_tiled_copy_Kt, smem_thr_copy_Kt);
s_barrier();
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(1, 1, gK, sKt, 1, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<1, K_BUFF_SIZE, Is_even_MN, _16x256, 1>(gK, sKt, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs_ds_read_m32x16<2>(acc_dq_0_127, tdQrdS, tdQrKt, tdQsKt, tiled_mma_dq, smem_tiled_copy_Kt, smem_thr_copy_Kt);
flash::gemm_k_rs_ds_read_m32x16_alt<2>(acc_dq_0_127, tdQrdS, tdQrKt, tdQsKt, tiled_mma_dq, smem_tiled_copy_Kt, smem_thr_copy_Kt);
s_barrier();
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(1, 2, gK, sKt, 2, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<2, K_BUFF_SIZE, Is_even_MN, _16x256, 1>(gK, sKt, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs_ds_read_m32x16<3>(acc_dq_0_127, tdQrdS, tdQrKt, tdQsKt, tiled_mma_dq, smem_tiled_copy_Kt, smem_thr_copy_Kt);
flash::gemm_k_rs_ds_read_m32x16_alt<3>(acc_dq_0_127, tdQrdS, tdQrKt, tdQsKt, tiled_mma_dq, smem_tiled_copy_Kt, smem_thr_copy_Kt);
s_barrier();
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(1, 3, gK, sKt, 3, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<3, K_BUFF_SIZE, Is_even_MN, _16x256, 1>(gK, sKt, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs_ds_read_m32x16<0>(acc_dq_128_256, tdQrdS, tdQrKt, tdQsKt, tiled_mma_dq, smem_tiled_copy_Kt, smem_thr_copy_Kt);
flash::gemm_k_rs_ds_read_m32x16_alt<0>(acc_dq_128_256, tdQrdS, tdQrKt, tdQsKt, tiled_mma_dq, smem_tiled_copy_Kt, smem_thr_copy_Kt);
s_barrier();
 
s_waitcnt<2>();
flash::gemm_k_rs_ds_read_m32x16<1>(acc_dq_128_256, tdQrdS, tdQrKt, tdQsKt, tiled_mma_dq, smem_tiled_copy_Kt, smem_thr_copy_Kt);
flash::gemm_k_rs_ds_read_m32x16_alt<1>(acc_dq_128_256, tdQrdS, tdQrKt, tdQsKt, tiled_mma_dq, smem_tiled_copy_Kt, smem_thr_copy_Kt);
s_barrier();
 
s_waitcnt<1>();
flash::gemm_k_rs_ds_read_m32x16<2>(acc_dq_128_256, tdQrdS, tdQrKt, tdQsKt, tiled_mma_dq, smem_tiled_copy_Kt, smem_thr_copy_Kt);
flash::gemm_k_rs_ds_read_m32x16_alt<2>(acc_dq_128_256, tdQrdS, tdQrKt, tdQsKt, tiled_mma_dq, smem_tiled_copy_Kt, smem_thr_copy_Kt);
s_barrier();
 
s_waitcnt<0>();
flash::gemm_k_rs_ds_read_m32x16<3>(acc_dq_128_256, tdQrdS, tdQrKt, tdQsKt, tiled_mma_dq, smem_tiled_copy_Kt, smem_thr_copy_Kt);
flash::gemm_k_rs_ds_read_m32x16_alt<3>(acc_dq_128_256, tdQrdS, tdQrKt, tdQsKt, tiled_mma_dq, smem_tiled_copy_Kt, smem_thr_copy_Kt);
s_barrier();
 
if (n_block > n_block_min) {
gV.data() = gV.data() + (-int(kBlockN * params.v_row_stride));
gK.data() = gK.data() + (-int(kBlockN * params.k_row_stride));
if constexpr(!Is_even_K) {
#pragma unroll
for (int i = 0; i < 3; i ++) {
lds_direct_copy<Is_even_K>(gK, sK, i, params.k_row_stride, params.d);
}
} else {
lds_direct_copy_even_k_dim256<0, K_BUFF_SIZE, true>(gK, sK, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy_even_k_dim256<1, K_BUFF_SIZE, true>(gK, sK, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy_even_k_dim256<2, K_BUFF_SIZE, true>(gK, sK, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
}
}
 
......@@ -6447,10 +6604,31 @@ inline __device__ void compute_dq_1rowblock_16x64_dim256_prefetch(const Params &
{
if (Is_even_MN || get<0>(tdQcdQ(0, m, 0)) < binfo.actual_seqlen_q - m_block * kBlockM)
{
int row = (m*kBlockM + warpId) * 16 + (laneId % 16);
#pragma unroll
for (int k = 0; k < size<2>(taccdQrdQ); k++)
{
const int col_id = get<1>(tdQcdQ(0, 0, k));
if constexpr (Is_even_K)
{
int col = (laneId / 16) * 2 + k * 32;
for (int ei = 0; ei < 4; ++ei)
{
using __float2 = __attribute__((ext_vector_type(2))) float;
__float2 scale_softmax_rp_dropoutx2 = {params.scale_softmax_rp_dropout, params.scale_softmax_rp_dropout};
__float2 acc_dqx2 = {acc_dq(ei, m, k), acc_dq(ei + 4, m, k)};
__float2 resx2 = __builtin_hcu_pk_mul_f32(acc_dqx2, scale_softmax_rp_dropoutx2);
using result_type = cutlass::Array<Element, 2>;
result_type res;
res[0] = flash::convert_type<Element>(resx2[0]);
res[1] = flash::convert_type<Element>(resx2[1]);
*(result_type*)(&gdQ(row, col)) = res;
col += 8;
}
}
else
{
for (int i = 0; i < size<0>(taccdQrdQ); i++)
{
if (Is_even_K || col_id + i * 4 < params.d) {
......@@ -6460,6 +6638,7 @@ inline __device__ void compute_dq_1rowblock_16x64_dim256_prefetch(const Params &
}
}
}
}
 
#elif 0
#pragma unroll
......@@ -8124,7 +8303,7 @@ inline __device__ void compute_dq_1rowblock_16x64_dim64_prefetch(const Params &p
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) { dP_sum(mi) = gdPsum(get<0>(taccScS_row(mi))); }
 
// __syncthreads();
__syncthreads();
int n_block = n_block_max - 1;
 
 
......@@ -8138,11 +8317,22 @@ inline __device__ void compute_dq_1rowblock_16x64_dim64_prefetch(const Params &p
// {
// lds_direct_copy<Is_even_K, /*Is_even_MN=*/Is_even_MN>(gK, sK, i, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
// }
if constexpr(Is_even_K)
{
lds_direct_copy_even_k<0, /*Is_even_MN=*/Is_even_MN>(gK, sK, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy_even_k<1, /*Is_even_MN=*/Is_even_MN>(gK, sK, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy_even_k<0, /*Is_even_MN=*/Is_even_MN>(gV, sV, params.v_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy_even_k<1, /*Is_even_MN=*/Is_even_MN>(gV, sV, params.v_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
else
{
lds_direct_copy<Is_even_K, /*Is_even_MN=*/Is_even_MN>(gK, sK, 0, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, /*Is_even_MN=*/Is_even_MN>(gK, sK, 1, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
 
lds_direct_copy<Is_even_K, /*Is_even_MN=*/Is_even_MN>(gV, sV, 0, params.v_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, /*Is_even_MN=*/Is_even_MN>(gV, sV, 1, params.v_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
}
 
 
flash::Dropout dropout(params.rng_state[0], params.rng_state[1], params.p_dropout_in_uint8_t,
......@@ -8181,11 +8371,21 @@ inline __device__ void compute_dq_1rowblock_16x64_dim64_prefetch(const Params &p
flash::gemm_k_rs(acc_s_ori, tSrQ, tSrK, tSsK, tiled_mma_sdp, smem_tiled_copy_KV, smem_thr_copy_KV, 1);
 
asm volatile("s_barrier");
if constexpr(Is_even_K)
{
lds_direct_copy_even_k<0, Is_even_MN, _16x64_64>(gK, sKt, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy_even_k<1, Is_even_MN, _16x64_64>(gK, sKt, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy_even_k<2, Is_even_MN, _16x64_64>(gK, sKt, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy_even_k<3, Is_even_MN, _16x64_64>(gK, sKt, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
else
{
lds_direct_copy<Is_even_K, Is_even_MN, _16x64_64>(gK, sKt, 0, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _16x64_64>(gK, sKt, 1, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _16x64_64>(gK, sKt, 2, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _16x64_64>(gK, sKt, 3, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
}
asm volatile("s_barrier");
 
Tensor acc_s = make_tensor(acc_s_ori.data(), flash::convert_layout_acc(acc_s_ori.layout()));
......@@ -8339,14 +8539,22 @@ inline __device__ void compute_dq_1rowblock_16x64_dim64_prefetch(const Params &p
if (n_block > n_block_min) {
gV.data() = gV.data() + (-int(kBlockN * params.v_row_stride));
gK.data() = gK.data() + (-int(kBlockN * params.k_row_stride));
if constexpr (Is_even_K)
{
lds_direct_copy_even_k<0, /*Is_even_MN=*/true>(gK, sK, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy_even_k<1, /*Is_even_MN=*/true>(gK, sK, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy_even_k<0, /*Is_even_MN=*/true>(gV, sV, params.v_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy_even_k<1, /*Is_even_MN=*/true>(gV, sV, params.v_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
else
{
lds_direct_copy<Is_even_K, /*Is_even_MN=*/true>(gK, sK, 0, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, /*Is_even_MN=*/true>(gK, sK, 1, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, /*Is_even_MN=*/true>(gV, sV, 0, params.v_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, /*Is_even_MN=*/true>(gV, sV, 1, params.v_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
}
}
 
}
......@@ -8413,7 +8621,7 @@ inline __device__ void compute_dq_seqq_parallel_16x64_prefetch(const Params &par
}
#else
int m_block = blockIdx.x;
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
using Element = typename Kernel_traits::Element;
 
if constexpr (Kernel_traits::kHeadDim == 192 && Kernel_traits::kHeadDimV == 128)
......@@ -8465,17 +8673,31 @@ inline __device__ void compute_dq_seqq_parallel_16x64_prefetch(const Params &par
}
 
if constexpr (Kernel_traits::kHeadDim == 64) {
compute_dq_1rowblock_16x64_dim64_prefetch<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params, bidb, bidh, m_block);
#ifndef NO_CAUSAL_OPT
if constexpr (Is_causal)
if constexpr(Is_causal)
{
const int bidbh = blockIdx.x + blockIdx.z * params.se_balance_cnt;
const int bidb = bidbh / params.h;
const int bidh = bidbh % params.h;
const int num_blocks = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
if (num_blocks - m_block - 1 != m_block)
{
compute_dq_1rowblock_16x64_dim64_prefetch<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params, bidb, bidh, num_blocks - m_block - 1);
const int m_block = num_blocks - 1 - blockIdx.y;
compute_dq_1rowblock_16x64_dim64_prefetch<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params, bidb, bidh, m_block);
}
else
{
compute_dq_1rowblock_16x64_dim64_prefetch<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params, bidb, bidh, m_block);
}
#endif
// #ifndef NO_CAUSAL_OPT
// if constexpr (Is_causal)
// {
// const int num_blocks = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
// if (num_blocks - m_block - 1 != m_block)
// {
// compute_dq_1rowblock_16x64_dim64_prefetch<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params, bidb, bidh, num_blocks - m_block - 1);
// }
// }
// #endif
return;
}
......@@ -10761,6 +10983,16 @@ inline __device__ void compute_dk_dv_trans_1colblock_16x64_dim64_prefetch(const
constexpr int kdP_loops = size<2>(tdPsdO);
constexpr int kdK_loops = size<2>(tdKsQt);
// static_assert(kStages <= kS_loops && kStages <= kdV_loops && kStages <= kdP_loops && kStages <= kdK_loops, "kStages is error");
if constexpr(Is_even_K) {
lds_direct_copy_even_k<0, Is_even_MN>(gQ, sQ, params.q_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy_even_k<1, Is_even_MN>(gQ, sQ, params.q_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy_even_k<0, Is_even_MN, _16x64_64>(gdO, sdOt, params.do_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy_even_k<1, Is_even_MN, _16x64_64>(gdO, sdOt, params.do_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy_even_k<2, Is_even_MN, _16x64_64>(gdO, sdOt, params.do_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy_even_k<3, Is_even_MN, _16x64_64>(gdO, sdOt, params.do_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
} else {
lds_direct_copy<Is_even_K, Is_even_MN>(gQ, sQ, 0, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy<Is_even_K, Is_even_MN>(gQ, sQ, 1, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
 
......@@ -10769,6 +11001,8 @@ inline __device__ void compute_dk_dv_trans_1colblock_16x64_dim64_prefetch(const
lds_direct_copy<Is_even_K, Is_even_MN, _16x64_64>(gdO, sdOt, 2, params.do_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy<Is_even_K, Is_even_MN, _16x64_64>(gdO, sdOt, 3, params.do_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
 
}
#pragma unroll
for (; m_block >= m_block_min; m_block--) {
 
......@@ -10924,7 +11158,7 @@ inline __device__ void compute_dk_dv_trans_1colblock_16x64_dim64_prefetch(const
int block_row_idx = row_idx_offset_;
int block_col_idx = m_block * kBlockM;
if constexpr (kHeadDim==64){
dropout.template apply_dropout_trans_opt</*encode_dropout_in_sign_bit=*/true>(
dropout.template apply_dropout_trans_dim64_opt</*encode_dropout_in_sign_bit=*/true>(
acc_s, n_block * kBlockN, m_block * kBlockM, kNWarps * 16
);
}else{
......@@ -10938,9 +11172,14 @@ inline __device__ void compute_dk_dv_trans_1colblock_16x64_dim64_prefetch(const
Tensor rP = !Is_dropout
? flash::convert_type<Element>(acc_s)
: flash::convert_type_relu<Element>(acc_s);
if constexpr(Is_even_K) {
lds_direct_copy_even_k<0, Is_even_MN>(gdO, sdO, params.do_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy_even_k<1, Is_even_MN>(gdO, sdO, params.do_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
} else {
lds_direct_copy<Is_even_K, Is_even_MN>(gdO, sdO, 0, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy<Is_even_K, Is_even_MN>(gdO, sdO, 1, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
}
 
asm volatile("s_waitcnt vmcnt(5) \n s_barrier");
flash::gemm_k_rs_ds_read_m32x16_alt<0>(acc_dv, rP, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
......@@ -10951,11 +11190,17 @@ inline __device__ void compute_dk_dv_trans_1colblock_16x64_dim64_prefetch(const
asm volatile("s_waitcnt vmcnt(2) \n s_barrier");
flash::gemm_k_rs_ds_read_m32x16_alt<3>(acc_dv, rP, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
asm volatile("s_barrier");
if constexpr(Is_even_K) {
lds_direct_copy_even_k<0, Is_even_MN, _16x64_64>(gQ, sQt, params.q_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy_even_k<1, Is_even_MN, _16x64_64>(gQ, sQt, params.q_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy_even_k<2, Is_even_MN, _16x64_64>(gQ, sQt, params.q_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy_even_k<3, Is_even_MN, _16x64_64>(gQ, sQt, params.q_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
} else {
lds_direct_copy<Is_even_K, Is_even_MN, _16x64_64>(gQ, sQt, 0, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy<Is_even_K, Is_even_MN, _16x64_64>(gQ, sQt, 1, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy<Is_even_K, Is_even_MN, _16x64_64>(gQ, sQt, 2, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy<Is_even_K, Is_even_MN, _16x64_64>(gQ, sQt, 3, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
}
// return;
Tensor acc_dp_ori = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockN>, Int<kBlockM>>{}); // (MMA=4, MMA_N, MMA_N)
clear(acc_dp_ori);
......@@ -11027,6 +11272,15 @@ inline __device__ void compute_dk_dv_trans_1colblock_16x64_dim64_prefetch(const
gQ.data() = gQ.data() + (-int(kBlockM * params.q_row_stride));
gdO.data() = gdO.data() + (-int(kBlockM * params.do_row_stride));
 
if constexpr(Is_even_K) {
lds_direct_copy_even_k<0, true>(gQ, sQ, params.q_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy_even_k<1, true>(gQ, sQ, params.q_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy_even_k<0, true, _16x64_64>(gdO, sdOt, params.do_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy_even_k<1, true, _16x64_64>(gdO, sdOt, params.do_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy_even_k<2, true, _16x64_64>(gdO, sdOt, params.do_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy_even_k<3, true, _16x64_64>(gdO, sdOt, params.do_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
} else {
lds_direct_copy<Is_even_K, true>(gQ, sQ, 0, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy<Is_even_K, true>(gQ, sQ, 1, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
......@@ -11034,6 +11288,7 @@ inline __device__ void compute_dk_dv_trans_1colblock_16x64_dim64_prefetch(const
lds_direct_copy<Is_even_K, true, _16x64_64>(gdO, sdOt, 1, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy<Is_even_K, true, _16x64_64>(gdO, sdOt, 2, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy<Is_even_K, true, _16x64_64>(gdO, sdOt, 3, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
}
 
}
......@@ -12528,7 +12783,7 @@ __builtin_amdgcn_s_barrier();
 
#if 1
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, typename Params>
inline __device__ void compute_dk_dv_trans_1colblock_16x64_dim256_prefetch(const Params &params, const int bidb, const int bidh, const int n_block) {
inline __device__ void compute_dk_trans_1colblock_16x64_dim256_prefetch(const Params &params, const int bidb, const int bidh, const int n_block) {
 
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
......@@ -12694,9 +12949,9 @@ inline __device__ void compute_dk_dv_trans_1colblock_16x64_dim256_prefetch(const
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKcdK, tdKpdK, binfo.actual_seqlen_k - n_block * kBlockN
);
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdVcdV, tdVpdV, binfo.actual_seqlen_k - n_block * kBlockN
);
// flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
// gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdVcdV, tdVpdV, binfo.actual_seqlen_k - n_block * kBlockN
// );
return;
}
......@@ -12715,18 +12970,18 @@ inline __device__ void compute_dk_dv_trans_1colblock_16x64_dim256_prefetch(const
flash::Dropout dropout(params.rng_state[0], params.rng_state[1], params.p_dropout_in_uint8_t,
bidb, bidh, tidx, params.h);
 
Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDimV>>{});
// Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDimV>>{});
Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{});
 
Tensor acc_dv_split = local_tile(acc_dv, Shape<Int<8>, Int<1>, Int<kHeadDimV / 32 / 2>>{}, make_coord(0, 0, _));
// Tensor acc_dv_split = local_tile(acc_dv, Shape<Int<8>, Int<1>, Int<kHeadDimV / 32 / 2>>{}, make_coord(0, 0, _));
Tensor acc_dk_split = local_tile(acc_dk, Shape<Int<8>, Int<1>, Int<kHeadDim / 32 / 2>>{}, make_coord(0, 0, _));
 
auto acc_dv_0_128 = acc_dv_split(_, _, _, 0);
auto acc_dv_128_256 = acc_dv_split(_, _, _, 1);
// auto acc_dv_0_128 = acc_dv_split(_, _, _, 0);
// auto acc_dv_128_256 = acc_dv_split(_, _, _, 1);
 
auto acc_dk_0_128 = acc_dk_split(_, _, _, 0);
auto acc_dk_128_256 = acc_dk_split(_, _, _, 1);
clear(acc_dv);
// clear(acc_dv);
clear(acc_dk);
Tensor taccScS_row = taccScS(_, 0, _);
......@@ -12736,15 +12991,23 @@ inline __device__ void compute_dk_dv_trans_1colblock_16x64_dim256_prefetch(const
const int row = (laneId / 16) * 4 + (mi % 4) + (mi / 4) * 16;
lse(mi) = Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY;
}
constexpr static int K_BUFF_SIZE = 4;
const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
flash::Alibi<Is_causal> alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q);
s_waitcnt<0>();
if constexpr(!Is_even_K) {
#pragma unroll
for (int i = 0; i < 3; ++i) {
lds_direct_copy<Is_even_K, Is_even_MN>(gQ, sQ, i, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
}
} else {
lds_direct_copy_even_k_dim256<0, K_BUFF_SIZE, Is_even_MN>(gQ, sQ, params.q_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy_even_k_dim256<1, K_BUFF_SIZE, Is_even_MN>(gQ, sQ, params.q_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy_even_k_dim256<2, K_BUFF_SIZE, Is_even_MN>(gQ, sQ, params.q_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
}
// wangaq debug
// s_waitcnt<0>();
// if (thread0() && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
......@@ -12763,43 +13026,70 @@ inline __device__ void compute_dk_dv_trans_1colblock_16x64_dim256_prefetch(const
 
Tensor acc_s_ori = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockN>, Int<kBlockM>>{});
clear(acc_s_ori);
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN>(gQ, sQ, 3, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
} else {
lds_direct_copy_even_k_dim256<3, K_BUFF_SIZE, Is_even_MN>(gQ, sQ, params.q_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_s_ori, tSrK, tSrQ, tSsQ, tiled_mma_sdp, smem_tiled_copy_QdO, smem_thr_copy_QdO, 0);
s_barrier();
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN>(0, gQ, sQ, 4, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
} else {
lds_direct_copy_even_k_dim256<4, K_BUFF_SIZE, Is_even_MN>(gQ, sQ, params.q_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_s_ori, tSrK, tSrQ, tSsQ, tiled_mma_sdp, smem_tiled_copy_QdO, smem_thr_copy_QdO, 1);
s_barrier();
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN>(1, gQ, sQ, 5, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
} else {
lds_direct_copy_even_k_dim256<5, K_BUFF_SIZE, Is_even_MN>(gQ, sQ, params.q_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_s_ori, tSrK, tSrQ, tSsQ, tiled_mma_sdp, smem_tiled_copy_QdO, smem_thr_copy_QdO, 2);
s_barrier();
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN>(2, gQ, sQ, 6, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
} else {
lds_direct_copy_even_k_dim256<6, K_BUFF_SIZE, Is_even_MN>(gQ, sQ, params.q_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_s_ori, tSrK, tSrQ, tSsQ, tiled_mma_sdp, smem_tiled_copy_QdO, smem_thr_copy_QdO, 3);
s_barrier();
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN>(3, gQ, sQ, 7, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
} else {
lds_direct_copy_even_k_dim256<7, K_BUFF_SIZE, Is_even_MN>(gQ, sQ, params.q_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_s_ori, tSrK, tSrQ, tSsQ, tiled_mma_sdp, smem_tiled_copy_QdO, smem_thr_copy_QdO, 4, 0);
s_barrier();
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN>(gdO, sdO, 0, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
} else {
lds_direct_copy_even_k_dim256<0, K_BUFF_SIZE, Is_even_MN>(gdO, sdO, params.do_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
}
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(0, 0, gdO, sdOt, 0, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
s_waitcnt<3>();
flash::gemm_k_rs(acc_s_ori, tSrK, tSrQ, tSsQ, tiled_mma_sdp, smem_tiled_copy_QdO, smem_thr_copy_QdO, 5, 1);
s_barrier();
 
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(0, 1, gdO, sdOt, 1, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN>(gdO, sdO, 1, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
} else {
lds_direct_copy_even_k_dim256<1, K_BUFF_SIZE, Is_even_MN>(gdO, sdO, params.do_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_s_ori, tSrK, tSrQ, tSsQ, tiled_mma_sdp, smem_tiled_copy_QdO, smem_thr_copy_QdO, 6, 2);
s_barrier();
 
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(0, 2, gdO, sdOt, 2, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN>(gdO, sdO, 2, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
} else {
lds_direct_copy_even_k_dim256<2, K_BUFF_SIZE, Is_even_MN>(gdO, sdO, params.do_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_s_ori, tSrK, tSrQ, tSsQ, tiled_mma_sdp, smem_tiled_copy_QdO, smem_thr_copy_QdO, 7, 3);
s_barrier();
......@@ -12959,171 +13249,657 @@ inline __device__ void compute_dk_dv_trans_1colblock_16x64_dim256_prefetch(const
? flash::convert_type<Element>(acc_s)
: flash::convert_type_relu<Element>(acc_s);
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(0, 3, gdO, sdOt, 3, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
s_waitcnt<3>();
flash::gemm_k_rs_ds_read_m32x16<0>(acc_dv_0_128, rP, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
s_barrier();
 
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(1, 0, gdO, sdOt, 0, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
Tensor acc_dp_ori = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockN>, Int<kBlockM>>{}); // (MMA=4, MMA_N, MMA_N)
clear(acc_dp_ori);
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN>(gdO, sdO, 3, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
} else {
lds_direct_copy_even_k_dim256<3, K_BUFF_SIZE, Is_even_MN>(gdO, sdO, params.do_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
}
s_waitcnt<3>();
flash::gemm_k_rs_ds_read_m32x16<1>(acc_dv_0_128, rP, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
flash::gemm_k_rs(acc_dp_ori, tdPrV, tdPrdO, tdPsdO, tiled_mma_sdp, smem_tiled_copy_QdO, smem_thr_copy_QdO, 0);
s_barrier();
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN>(0, gdO, sdO, 4, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
} else {
lds_direct_copy_even_k_dim256<4, K_BUFF_SIZE, Is_even_MN>(gdO, sdO, params.do_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_dp_ori, tdPrV, tdPrdO, tdPsdO, tiled_mma_sdp, smem_tiled_copy_QdO, smem_thr_copy_QdO, 1);
s_barrier();
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN>(1, gdO, sdO, 5, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
} else {
lds_direct_copy_even_k_dim256<5, K_BUFF_SIZE, Is_even_MN>(gdO, sdO, params.do_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_dp_ori, tdPrV, tdPrdO, tdPsdO, tiled_mma_sdp, smem_tiled_copy_QdO, smem_thr_copy_QdO, 2);
s_barrier();
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN>(2, gdO, sdO, 6, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
} else {
lds_direct_copy_even_k_dim256<6, K_BUFF_SIZE, Is_even_MN>(gdO, sdO, params.do_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_dp_ori, tdPrV, tdPrdO, tdPsdO, tiled_mma_sdp, smem_tiled_copy_QdO, smem_thr_copy_QdO, 3);
s_barrier();
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN>(3, gdO, sdO, 7, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
} else {
lds_direct_copy_even_k_dim256<7, K_BUFF_SIZE, Is_even_MN>(gdO, sdO, params.do_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_dp_ori, tdPrV, tdPrdO, tdPsdO, tiled_mma_sdp, smem_tiled_copy_QdO, smem_thr_copy_QdO, 4, 0);
s_barrier();
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(0, 0, gQ, sQt, 0, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
} else {
lds_direct_copy_even_k_dim256<0, K_BUFF_SIZE, Is_even_MN, _16x256, 0>(gQ, sQt, params.q_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_dp_ori, tdPrV, tdPrdO, tdPsdO, tiled_mma_sdp, smem_tiled_copy_QdO, smem_thr_copy_QdO, 5, 1);
s_barrier();
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(0, 1, gQ, sQt, 1, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
} else {
lds_direct_copy_even_k_dim256<1, K_BUFF_SIZE, Is_even_MN, _16x256, 0>(gQ, sQt, params.q_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_dp_ori, tdPrV, tdPrdO, tdPsdO, tiled_mma_sdp, smem_tiled_copy_QdO, smem_thr_copy_QdO, 6, 2);
s_barrier();
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(0, 2, gQ, sQt, 2, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
} else {
lds_direct_copy_even_k_dim256<2, K_BUFF_SIZE, Is_even_MN, _16x256, 0>(gQ, sQt, params.q_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_dp_ori, tdPrV, tdPrdO, tdPsdO, tiled_mma_sdp, smem_tiled_copy_QdO, smem_thr_copy_QdO, 7, 3);
s_barrier();
Tensor acc_dp = make_tensor(acc_dp_ori.data(), convert_layout_acc(acc_dp_ori.layout()));
Tensor dS = make_tensor(acc_dp.data(), scores_trans.layout());
auto pointwise_mult = [](float p, float dp, float d) {
return p * (!Is_dropout || p >= 0 ? dp - d : d);
};
#pragma unroll
for (int mi = 0; mi < size<0>(dS); ++mi) {
#pragma unroll
for (int ni = 0; ni < size<1>(dS); ++ni) {
float scaled_ds = pointwise_mult(scores_trans(mi, ni), dS(mi, ni), dP_sum(mi));
if constexpr (Is_softcap) { scaled_ds *= dtanh_trans(mi, ni); }
dS(mi, ni) = scaled_ds;
}
}
Tensor tdKrdSt = flash::convert_type<Element>(acc_dp);
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(0, 3, gQ, sQt, 3, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
} else {
lds_direct_copy_even_k_dim256<3, K_BUFF_SIZE, Is_even_MN, _16x256, 0>(gQ, sQt, params.q_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
}
s_waitcnt<3>();
flash::gemm_k_rs_ds_read_m32x16_alt<0>(acc_dk_0_128, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
s_barrier();
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(1, 0, gQ, sQt, 0, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
} else {
lds_direct_copy_even_k_dim256<0, K_BUFF_SIZE, Is_even_MN, _16x256, 1>(gQ, sQt, params.q_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
}
s_waitcnt<3>();
flash::gemm_k_rs_ds_read_m32x16_alt<1>(acc_dk_0_128, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
s_barrier();
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(1, 1, gQ, sQt, 1, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
} else {
lds_direct_copy_even_k_dim256<1, K_BUFF_SIZE, Is_even_MN, _16x256, 1>(gQ, sQt, params.q_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
}
s_waitcnt<3>();
flash::gemm_k_rs_ds_read_m32x16_alt<2>(acc_dk_0_128, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
s_barrier();
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(1, 2, gQ, sQt, 2, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
} else {
lds_direct_copy_even_k_dim256<2, K_BUFF_SIZE, Is_even_MN, _16x256, 1>(gQ, sQt, params.q_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
}
s_waitcnt<3>();
flash::gemm_k_rs_ds_read_m32x16_alt<3>(acc_dk_0_128, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
s_barrier();
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(1, 3, gQ, sQt, 3, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
} else {
lds_direct_copy_even_k_dim256<3, K_BUFF_SIZE, Is_even_MN, _16x256, 1>(gQ, sQt, params.q_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
}
s_waitcnt<3>();
flash::gemm_k_rs_ds_read_m32x16_alt<0>(acc_dk_128_256, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
s_barrier();
s_waitcnt<2>();
flash::gemm_k_rs_ds_read_m32x16_alt<1>(acc_dk_128_256, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
s_barrier();
s_waitcnt<1>();
flash::gemm_k_rs_ds_read_m32x16_alt<2>(acc_dk_128_256, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
s_barrier();
s_waitcnt<0>();
flash::gemm_k_rs_ds_read_m32x16_alt<3>(acc_dk_128_256, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
s_barrier();
if (m_block > m_block_min) {
gQ.data() = gQ.data() + (-int(kBlockM * params.q_row_stride));
gdO.data() = gdO.data() + (-int(kBlockM * params.do_row_stride));
if constexpr(!Is_even_K) {
#pragma unroll
for (int i = 0; i < 3; ++i) {
lds_direct_copy<Is_even_K, true>(gQ, sQ, i, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
}
} else {
lds_direct_copy_even_k_dim256<0, K_BUFF_SIZE, true>(gQ, sQ, params.q_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy_even_k_dim256<1, K_BUFF_SIZE, true>(gQ, sQ, params.q_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy_even_k_dim256<2, K_BUFF_SIZE, true>(gQ, sQ, params.q_row_stride, binfo.actual_seqlen_q - m_block * kBlockM);
}
}
#endif
}
#if 0
#else
const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
+ n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
+ n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.dk_row_stride, _1{}));
Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
Shape<Int<kBlockN>, Int<kHeadDimV>>{},
make_stride(params.dv_row_stride, _1{}));
int row, col;
#pragma unroll
for (int mi = 0; mi < size<1>(acc_dk); ++mi) {
row = (mi*kNWarps + warpId) * 16 + (laneId % 16);
if (Is_even_MN || row < binfo.actual_seqlen_k - n_block * kBlockN) {
#pragma unroll
for (int ni = 0; ni < size<2>(acc_dk); ++ni) {
col = (laneId / 16) * 2 + ni * 32;
for (int ei = 0; ei < 4; ++ei) {
using result_type = cutlass::Array<Element, 2>;
if constexpr (Is_even_K)
{
result_type res;
res[0] = flash::convert_type<Element>(acc_dk(ei, mi, ni) * params.scale_softmax_rp_dropout);
res[1] = flash::convert_type<Element>(acc_dk(ei + 4, mi, ni) * params.scale_softmax_rp_dropout);
*(result_type*)(&gdK(row, col)) = res;
// res[0] = flash::convert_type<Element>(!Is_dropout ? acc_dv(ei, mi, ni) : acc_dv(ei, mi, ni) * params.rp_dropout);
// res[1] = flash::convert_type<Element>(!Is_dropout ? acc_dv(ei + 4, mi, ni) : acc_dv(ei + 4, mi, ni) * params.rp_dropout);
// *(result_type*)(&gdV(row, col)) = res;
}
col += 8;
}
}
}
}
#endif
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, typename Params>
inline __device__ void compute_dv_trans_1colblock_16x64_dim256_prefetch(const Params &params, const int bidb, const int bidh, const int n_block) {
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;
extern __shared__ char smem_[];
const int tidx = threadIdx.x;
const int warpId = tidx / 64;
const int laneId = tidx % 64;
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
constexpr int kNWarps = Kernel_traits::kNWarps;
constexpr int kStages = Kernel_traits::kStages;
const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
if (n_block * kBlockN >= binfo.actual_seqlen_k) return;
int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM);
if constexpr (Is_local) {
m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left, kBlockM));
}
const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
+ (m_block_max - 1) * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)
+ n_block * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
+ n_block * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb)
+ (m_block_max - 1) * kBlockM * params.do_row_stride + bidh * params.do_head_stride;
const index_t row_offset_lse = (params.unpadded_lse? bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb): (bidb * params.h + bidh) * params.seqlen_q) + (m_block_max - 1) * kBlockM;
// Regarding 128 * params.b see a comment in mha_varlen_bwd about padding of dq_accum and softmax_d
const index_t row_offset_dpsum = (params.unpadded_lse? bidh * (params.total_q + 128 * params.b) + binfo.q_offset(params.seqlen_q_rounded, 1, bidb) + 128 * bidb: (bidb * params.h + bidh) * params.seqlen_q_rounded) + (m_block_max - 1) * kBlockM;
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.q_row_stride, _1{}));
Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.k_row_stride, _1{}));
Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
Shape<Int<kBlockN>, Int<kHeadDimV>>{},
make_stride(params.v_row_stride, _1{}));
Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
Shape<Int<kBlockM>, Int<kHeadDimV>>{},
make_stride(params.do_row_stride, _1{}));
Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
Shape<Int<kBlockM>>{}, Stride<_1>{});
Tensor gdPsum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dsoftmax_sum) + row_offset_dpsum),
Shape<Int<kBlockM>>{}, Stride<_1>{});
Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)), typename Kernel_traits::SmemLayoutQdOGemm0{});
Tensor sQt = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQdOGemm1transposed{});
Tensor sQtSplit = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQdOtransSplit{});
Tensor sdO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQdOGemm0{});
Tensor sdOt = make_tensor(sdO.data(), typename Kernel_traits::SmemLayoutQdOGemm1transposed{});
Tensor sdOtSplit = make_tensor(sdO.data(), typename Kernel_traits::SmemLayoutQdOtransSplit{});
// S/dP
typename Kernel_traits::TiledMmaSdP tiled_mma_sdp;
auto thr_mma_sdp = tiled_mma_sdp.get_thread_slice(tidx);
Tensor tSrK = thr_mma_sdp.partition_fragment_A(gK);
Tensor tSrQ = thr_mma_sdp.partition_fragment_B(gQ);
Tensor tdPrV = thr_mma_sdp.partition_fragment_A(gV);
Tensor tdPrdO = thr_mma_sdp.partition_fragment_B(gdO);
// dV/dK
typename Kernel_traits::TiledMmadKV tiled_mma_dkv;
auto thr_mma_dkv = tiled_mma_dkv.get_thread_slice(tidx);
Tensor tdVrdO = thr_mma_dkv.partition_fragment_B(sQt);
Tensor tdKrQt = thr_mma_dkv.partition_fragment_B(sdOt);
//
// Copy Atom retiling
//
// S/dP
auto gmem_tiled_copy_KV = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp);
auto gmem_thr_copy_KV = gmem_tiled_copy_KV.get_thread_slice(tidx);
Tensor tSgK = gmem_thr_copy_KV.partition_S(gK);
Tensor tdPgV = gmem_thr_copy_KV.partition_S(gV);
// auto smem_tiled_copy_QdO = make_tiled_copy_B(Copy_Atom<GFX928_DS_READ_B128, Element>{}, tiled_mma_sdp);
auto smem_tiled_copy_QdO = make_tiled_copy_B(Copy_Atom<DefaultCopy, Element>{}, tiled_mma_sdp);
auto smem_thr_copy_QdO = smem_tiled_copy_QdO.get_thread_slice(tidx);
Tensor tSsQ = smem_thr_copy_QdO.partition_S(sQ);
Tensor tdPsdO = smem_thr_copy_QdO.partition_S(sdO);
// dV/dK
auto smem_tiled_copy_QdOt = make_tiled_copy_B(Copy_Atom<GFX928_DS_READ_DS_M32x16_B16, Element>{}, tiled_mma_dkv);
auto smem_thr_copy_QdOt = smem_tiled_copy_QdOt.get_thread_slice(tidx);
Tensor tdVsdOt8x64 = smem_thr_copy_QdOt.partition_S(sdOtSplit);
Tensor tdVsdOt = make_tensor(tdVsdOt8x64.data(), convert_layout_B_rowcol_<_16x128, 4>(tdVsdOt8x64.layout()));
Tensor tdKsQt8x64 = smem_thr_copy_QdOt.partition_S(sQtSplit);
Tensor tdKsQt = make_tensor(tdKsQt8x64.data(), convert_layout_B_rowcol_<_16x128, 4>(tdKsQt8x64.layout()));
//
// PREDICATES
//
Tensor cK = make_identity_tensor(make_shape(size<0>(gK), size<1>(gK))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor cV = make_identity_tensor(make_shape(size<0>(gV), size<1>(gV))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
Tensor tKcK = gmem_thr_copy_KV.partition_D(cK);
Tensor tVcV = gmem_thr_copy_KV.partition_D(cV);
// Allocate predicate tensors for k
Tensor tKpK = make_tensor<bool>(make_shape(size<2>(tSgK)));
Tensor tVpV = make_tensor<bool>(make_shape(size<2>(tdPgV)));
// Set predicates for k bounds
if (!Is_even_K) {
#pragma unroll
for (int k = 0; k < size(tKpK); ++k) { tKpK(k) = get<1>(tKcK(0, 0, k)) < params.d; }
#pragma unroll
for (int k = 0; k < size(tVpV); ++k) { tVpV(k) = get<1>(tVcV(0, 0, k)) < params.d_value; }
}
 
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(1, 1, gdO, sdOt, 1, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
s_waitcnt<3>();
flash::gemm_k_rs_ds_read_m32x16<2>(acc_dv_0_128, rP, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
s_barrier();
int m_block = m_block_max - 1;
int m_block_min = (!Is_causal && !Is_local)
? 0
: std::max(0, (n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right) / kBlockM);
 
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(1, 2, gdO, sdOt, 2, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
s_waitcnt<3>();
flash::gemm_k_rs_ds_read_m32x16<3>(acc_dv_0_128, rP, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
s_barrier();
if ((Is_local || !Is_even_MN) && m_block < m_block_min) {
const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
+ n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
+ n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.dk_row_stride, _1{}));
Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
Shape<Int<kBlockN>, Int<kHeadDimV>>{},
make_stride(params.dv_row_stride, _1{}));
typename Kernel_traits::GmemTiledCopydKV gmem_tiled_copy_dKV;
auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));
clear(tdKrdK);
clear(tdVrdV);
Tensor cdK = make_identity_tensor(make_shape(size<0>(gdK), size<1>(gdK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
Tensor cdV = make_identity_tensor(make_shape(size<0>(gdV), size<1>(gdV))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
Tensor tdKcdK = gmem_thr_copy_dKV.partition_D(cdK);
Tensor tdVcdV = gmem_thr_copy_dKV.partition_D(cdV);
Tensor tdKpdK = make_tensor<bool>(make_shape(size<2>(tdKcdK)));
Tensor tdVpdV = make_tensor<bool>(make_shape(size<2>(tdVcdV)));
#pragma unroll
for (int k = 0; k < size(tdKpdK); ++k) { tdKpdK(k) = get<1>(tdKcdK(0, 0, k)) < params.d; }
#pragma unroll
for (int k = 0; k < size(tdVpdV); ++k) { tdVpdV(k) = get<1>(tdVcdV(0, 0, k)) < params.d_value; }
// Clear_OOB_K must be false since we don't want to write zeros to gmem
// flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
// gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKcdK, tdKpdK, binfo.actual_seqlen_k - n_block * kBlockN
// );
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdVcdV, tdVpdV, binfo.actual_seqlen_k - n_block * kBlockN
);
return;
}
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(1, 3, gdO, sdOt, 3, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
s_waitcnt<3>();
flash::gemm_k_rs_ds_read_m32x16<0>(acc_dv_128_256, rP, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
s_barrier();
lds_direct_copy<Is_even_K, Is_even_MN>(gdO, sdO, 0, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
s_waitcnt<3>();
flash::gemm_k_rs_ds_read_m32x16<1>(acc_dv_128_256, rP, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
s_barrier();
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_KV, tSgK, tSrK, tKcK, tKpK, binfo.actual_seqlen_k - n_block * kBlockN
);
 
lds_direct_copy<Is_even_K, Is_even_MN>(gdO, sdO, 1, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
s_waitcnt<3>();
flash::gemm_k_rs_ds_read_m32x16<2>(acc_dv_128_256, rP, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
s_barrier();
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_KV, tdPgV, tdPrV, tVcV, tVpV, binfo.actual_seqlen_k - n_block * kBlockN
);
 
lds_direct_copy<Is_even_K, Is_even_MN>(gdO, sdO, 2, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
s_waitcnt<3>();
flash::gemm_k_rs_ds_read_m32x16<3>(acc_dv_128_256, rP, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
s_barrier();
Tensor caccS = make_identity_tensor(Shape<Int<kBlockN>, Int<kBlockM>>{}); // (BLK_N,BLK_M) -> (blk_n,blk_m)
Tensor taccScS = thr_mma_sdp.partition_C(caccS);
// return;
Tensor acc_dp_ori = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockN>, Int<kBlockM>>{}); // (MMA=4, MMA_N, MMA_N)
clear(acc_dp_ori);
flash::Dropout dropout(params.rng_state[0], params.rng_state[1], params.p_dropout_in_uint8_t,
bidb, bidh, tidx, params.h);
 
lds_direct_copy<Is_even_K, Is_even_MN>(gdO, sdO, 3, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDimV>>{});
// Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{});
Tensor acc_dv_split = local_tile(acc_dv, Shape<Int<8>, Int<1>, Int<kHeadDimV / 32 / 2>>{}, make_coord(0, 0, _));
// Tensor acc_dk_split = local_tile(acc_dk, Shape<Int<8>, Int<1>, Int<kHeadDim / 32 / 2>>{}, make_coord(0, 0, _));
auto acc_dv_0_128 = acc_dv_split(_, _, _, 0);
auto acc_dv_128_256 = acc_dv_split(_, _, _, 1);
// auto acc_dk_0_128 = acc_dk_split(_, _, _, 0);
// auto acc_dk_128_256 = acc_dk_split(_, _, _, 1);
clear(acc_dv);
// clear(acc_dk);
Tensor taccScS_row = taccScS(_, 0, _);
Tensor lse = make_tensor<ElementAccum>(Shape<Int<decltype(size(taccScS_row))::value>>{});
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = (laneId / 16) * 4 + (mi % 4) + (mi / 4) * 16;
lse(mi) = Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY;
}
const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
flash::Alibi<Is_causal> alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q);
s_waitcnt<0>();
#pragma unroll
for (int i = 0; i < 3; ++i) {
lds_direct_copy<Is_even_K, Is_even_MN>(gQ, sQ, i, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
}
#pragma unroll
for (; m_block >= m_block_min; m_block--) {
Tensor acc_s_ori = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockN>, Int<kBlockM>>{});
clear(acc_s_ori);
lds_direct_copy<Is_even_K, Is_even_MN>(gQ, sQ, 3, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
s_waitcnt<3>();
flash::gemm_k_rs(acc_dp_ori, tdPrV, tdPrdO, tdPsdO, tiled_mma_sdp, smem_tiled_copy_QdO, smem_thr_copy_QdO, 0);
flash::gemm_k_rs(acc_s_ori, tSrK, tSrQ, tSsQ, tiled_mma_sdp, smem_tiled_copy_QdO, smem_thr_copy_QdO, 0);
s_barrier();
lds_direct_copy<Is_even_K, Is_even_MN>(0, gdO, sdO, 4, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy<Is_even_K, Is_even_MN>(0, gQ, sQ, 4, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
s_waitcnt<3>();
flash::gemm_k_rs(acc_dp_ori, tdPrV, tdPrdO, tdPsdO, tiled_mma_sdp, smem_tiled_copy_QdO, smem_thr_copy_QdO, 1);
flash::gemm_k_rs(acc_s_ori, tSrK, tSrQ, tSsQ, tiled_mma_sdp, smem_tiled_copy_QdO, smem_thr_copy_QdO, 1);
s_barrier();
lds_direct_copy<Is_even_K, Is_even_MN>(1, gdO, sdO, 5, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy<Is_even_K, Is_even_MN>(1, gQ, sQ, 5, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
s_waitcnt<3>();
flash::gemm_k_rs(acc_dp_ori, tdPrV, tdPrdO, tdPsdO, tiled_mma_sdp, smem_tiled_copy_QdO, smem_thr_copy_QdO, 2);
flash::gemm_k_rs(acc_s_ori, tSrK, tSrQ, tSsQ, tiled_mma_sdp, smem_tiled_copy_QdO, smem_thr_copy_QdO, 2);
s_barrier();
lds_direct_copy<Is_even_K, Is_even_MN>(2, gdO, sdO, 6, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy<Is_even_K, Is_even_MN>(2, gQ, sQ, 6, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
s_waitcnt<3>();
flash::gemm_k_rs(acc_dp_ori, tdPrV, tdPrdO, tdPsdO, tiled_mma_sdp, smem_tiled_copy_QdO, smem_thr_copy_QdO, 3);
flash::gemm_k_rs(acc_s_ori, tSrK, tSrQ, tSsQ, tiled_mma_sdp, smem_tiled_copy_QdO, smem_thr_copy_QdO, 3);
s_barrier();
lds_direct_copy<Is_even_K, Is_even_MN>(3, gdO, sdO, 7, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy<Is_even_K, Is_even_MN>(3, gQ, sQ, 7, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
s_waitcnt<3>();
flash::gemm_k_rs(acc_dp_ori, tdPrV, tdPrdO, tdPsdO, tiled_mma_sdp, smem_tiled_copy_QdO, smem_thr_copy_QdO, 4, 0);
flash::gemm_k_rs(acc_s_ori, tSrK, tSrQ, tSsQ, tiled_mma_sdp, smem_tiled_copy_QdO, smem_thr_copy_QdO, 4, 0);
s_barrier();
 
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(0, 0, gQ, sQt, 0, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(0, 0, gdO, sdOt, 0, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
s_waitcnt<3>();
flash::gemm_k_rs(acc_dp_ori, tdPrV, tdPrdO, tdPsdO, tiled_mma_sdp, smem_tiled_copy_QdO, smem_thr_copy_QdO, 5, 1);
flash::gemm_k_rs(acc_s_ori, tSrK, tSrQ, tSsQ, tiled_mma_sdp, smem_tiled_copy_QdO, smem_thr_copy_QdO, 5, 1);
s_barrier();
 
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(0, 1, gQ, sQt, 1, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(0, 1, gdO, sdOt, 1, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
s_waitcnt<3>();
flash::gemm_k_rs(acc_dp_ori, tdPrV, tdPrdO, tdPsdO, tiled_mma_sdp, smem_tiled_copy_QdO, smem_thr_copy_QdO, 6, 2);
flash::gemm_k_rs(acc_s_ori, tSrK, tSrQ, tSsQ, tiled_mma_sdp, smem_tiled_copy_QdO, smem_thr_copy_QdO, 6, 2);
s_barrier();
 
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(0, 2, gQ, sQt, 2, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(0, 2, gdO, sdOt, 2, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
s_waitcnt<3>();
flash::gemm_k_rs(acc_dp_ori, tdPrV, tdPrdO, tdPsdO, tiled_mma_sdp, smem_tiled_copy_QdO, smem_thr_copy_QdO, 7, 3);
flash::gemm_k_rs(acc_s_ori, tSrK, tSrQ, tSsQ, tiled_mma_sdp, smem_tiled_copy_QdO, smem_thr_copy_QdO, 7, 3);
s_barrier();
Tensor acc_dp = make_tensor(acc_dp_ori.data(), convert_layout_acc(acc_dp_ori.layout()));
Tensor dS = make_tensor(acc_dp.data(), scores_trans.layout());
Tensor acc_s = make_tensor(acc_s_ori.data(), convert_layout_acc(acc_s_ori.layout()));
// wangaq debug
// __syncthreads();
// if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
// printf("dP_sum tid:%d m_block:%d %10.4f %10.4f %10.4f %10.4f\n", tidx, m_block, dP_sum(0), dP_sum(1), dP_sum(2), dP_sum(3));
// float * tmp = reinterpret_cast<float*>(acc_dp.data());
// printf("dP tid:%d m_block:%d %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
// // printf("lse tid:%d m_block:%d %10.4f %10.4f %10.4f %10.4f\n", tidx, m_block, lse(0), lse(1), lse(2), lse(3));
// float * tmp = reinterpret_cast<float*>(acc_s.data());
// printf("acc_s tid:%d m_block:%d %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
// "%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f\n", tidx, m_block,
// tmp[0], tmp[1], tmp[2], tmp[3],
// tmp[4], tmp[5], tmp[6], tmp[7],
// tmp[8], tmp[9], tmp[10], tmp[11],
// tmp[12], tmp[13], tmp[14], tmp[15]
// tmp[0], tmp[1], tmp[2], tmp[3], tmp[4], tmp[5], tmp[6], tmp[7],
// tmp[8], tmp[9], tmp[10], tmp[11], tmp[12], tmp[13], tmp[14], tmp[15]
// );
// }
 
auto pointwise_mult = [](float p, float dp, float d) {
return p * (!Is_dropout || p >= 0 ? dp - d : d);
};
Tensor scores_trans = make_tensor(acc_s.data(), flash::convert_trans_layout_acc_rowcol(acc_s.layout()));
if constexpr (Is_softcap) {
flash::apply_softcap(acc_s, params.softcap);
}
[[maybe_unused]] Tensor dtanh_trans = make_tensor_like(scores_trans);
if constexpr (Is_softcap) {
flash::calculate_dtanh(scores_trans, dtanh_trans, params.softcap);
}
#if 1
if constexpr (Has_alibi) {
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
const int wave_id = tidx / 64;
const int col_idx_offset = m_block * kBlockM;
const int wave_id_to_row_block_id = wave_id;
const int warp_row_stride = 16;
const int row_idx_offset_in_block = (tidx & (warp_row_stride - 1)) + (wave_id_to_row_block_id << 4);
const int row_idx_offset_ = n_block * kBlockN + row_idx_offset_in_block;
alibi.apply_alibi_trans(scores, col_idx_offset, row_idx_offset_, kNWarps * 16);
}
#endif
#if 1
if constexpr(!Is_causal && !Is_local) {
if (!Is_even_MN && (m_block + 1) * kBlockM >= binfo.actual_seqlen_q) {
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
const int warp_id = tidx / 64;
// 实际上是row
const int col_idx_offset_ = m_block * kBlockM;
flash::apply_mask_trans(scores, binfo.actual_seqlen_q, col_idx_offset_);
}
} else if constexpr(Is_causal) {
// Putting this causal masking right after acc_s is *much* slower for some reason.
// TD [2023-08-16]: We need the 2nd condition because if seqlen_q is long and seqlen_k is short
// (e.g., 256 and 2), the 2nd block of seqlen_q (from 128 to 255), we're not doing causal masking.
// But we still want to mask out elements beyond actual_seqlen_k.
// if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k
// || (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) {
// const int warp_id = tidx / 64;
// flash::apply_mask_causal(scores, n_block * kBlockN + (warp_id / AtomLayoutMS) * MMA_N_SdP * 16,
// binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
// binfo.actual_seqlen_q,
// AtomLayoutMS * 16);
// }
if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k)
{
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
const int wave_id = (tidx >> 6);
const int wave_id_to_row_block_id = wave_id;
const int warp_row_stride = 16;
const int row_idx_offset_in_block = (tidx & (warp_row_stride - 1)) + (wave_id_to_row_block_id << 4);
const int row_idx_offset_ = n_block * kBlockN + row_idx_offset_in_block;
flash::apply_mask_causal_trans(
scores,
m_block * kBlockM,
binfo.actual_seqlen_k,
row_idx_offset_,
binfo.actual_seqlen_q,
kNWarps * 16
);
}
} else if constexpr(Is_local) {
if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right
|| (m_block + 1) * kBlockM >= n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left) {
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
const int wave_id = (tidx >> 6);
const int wave_id_to_row_block_id = wave_id;
const int warp_row_stride = 16;
const int row_idx_offset_in_block = (tidx & (warp_row_stride - 1)) + (wave_id_to_row_block_id << 4);
const int row_idx_offset_ = n_block * kBlockN + row_idx_offset_in_block;
flash::apply_mask_local_trans(
scores,
m_block * kBlockM,
binfo.actual_seqlen_k,
row_idx_offset_,
binfo.actual_seqlen_q,
kNWarps * 16,
params.window_size_left, params.window_size_right
);
}
}
#endif
#if 1
flash::scale_apply_exp2</*scale_max=*/false>(scores_trans, lse, params.scale_softmax_log2);
Tensor dP_sum = make_fragment_like(lse);
 
#pragma unroll
for (int mi = 0; mi < size<0>(dS); ++mi) {
for (int mi = 0; mi < size(lse); ++mi) {
const int row = (laneId / 16) * 4 + (mi % 4) + (mi / 4) * 16;
dP_sum(mi) = gdPsum(row);
}
if (m_block > m_block_min) {
gdPsum.data() = gdPsum.data() + (-int(kBlockM));
gLSE.data() = gLSE.data() + (-int(kBlockM));
#pragma unroll
for (int ni = 0; ni < size<1>(dS); ++ni) {
float scaled_ds = pointwise_mult(scores_trans(mi, ni), dS(mi, ni), dP_sum(mi));
if constexpr (Is_softcap) { scaled_ds *= dtanh_trans(mi, ni); }
dS(mi, ni) = scaled_ds;
for (int mi = 0; mi < size(lse); ++mi) {
const int row = (laneId / 16) * 4 + (mi % 4) + (mi / 4) * 16;
lse(mi) = gLSE(row);
}
}
 
// wangaq debug
// __syncthreads();
// if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
// float * tmp = reinterpret_cast<float*>(acc_dp.data());
// printf("dS tid:%d m_block:%d %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f "
// "%10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f %10.4f\n", tidx, m_block,
// tmp[0], tmp[1], tmp[2], tmp[3],
// tmp[4], tmp[5], tmp[6], tmp[7],
// tmp[8], tmp[9], tmp[10], tmp[11],
// tmp[12], tmp[13], tmp[14], tmp[15]
// );
// }
if constexpr (Is_dropout) {
const int warp_id = tidx / 64;
const int wave_id = (tidx >> 6);
const int wave_id_to_row_block_id = wave_id;
const int warp_row_stride = 16;
const int row_idx_offset_in_block = (tidx & (warp_row_stride - 1)) + (wave_id_to_row_block_id << 4);
const int row_idx_offset_ = (kHeadDim == 128) ? (n_block * kBlockN) : (n_block * kBlockN + row_idx_offset_in_block);
int block_row_idx = row_idx_offset_;
int block_col_idx = m_block * kBlockM;
if constexpr (kHeadDim==128){
dropout.template apply_dropout_trans_opt</*encode_dropout_in_sign_bit=*/true>(
acc_s, n_block * kBlockN, m_block * kBlockM, kNWarps * 16
);
}else{
dropout.template apply_dropout_trans</*encode_dropout_in_sign_bit=*/true>(
acc_s, block_row_idx, block_col_idx, kNWarps * 16
);
}
}
 
Tensor tdKrdSt = flash::convert_type<Element>(acc_dp);
Tensor rP = !Is_dropout
? flash::convert_type<Element>(acc_s)
: flash::convert_type_relu<Element>(acc_s);
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(0, 3, gQ, sQt, 3, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(0, 3, gdO, sdOt, 3, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
s_waitcnt<3>();
flash::gemm_k_rs_ds_read_m32x16<0>(acc_dk_0_128, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
flash::gemm_k_rs_ds_read_m32x16_alt<0>(acc_dv_0_128, rP, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
s_barrier();
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(1, 0, gQ, sQt, 0, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(1, 0, gdO, sdOt, 0, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
s_waitcnt<3>();
flash::gemm_k_rs_ds_read_m32x16<1>(acc_dk_0_128, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
flash::gemm_k_rs_ds_read_m32x16_alt<1>(acc_dv_0_128, rP, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
s_barrier();
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(1, 1, gQ, sQt, 1, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(1, 1, gdO, sdOt, 1, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
s_waitcnt<3>();
flash::gemm_k_rs_ds_read_m32x16<2>(acc_dk_0_128, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
flash::gemm_k_rs_ds_read_m32x16_alt<2>(acc_dv_0_128, rP, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
s_barrier();
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(1, 2, gQ, sQt, 2, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(1, 2, gdO, sdOt, 2, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
s_waitcnt<3>();
flash::gemm_k_rs_ds_read_m32x16<3>(acc_dk_0_128, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
flash::gemm_k_rs_ds_read_m32x16_alt<3>(acc_dv_0_128, rP, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
s_barrier();
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(1, 3, gQ, sQt, 3, params.q_row_stride, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(1, 3, gdO, sdOt, 3, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
s_waitcnt<3>();
flash::gemm_k_rs_ds_read_m32x16<0>(acc_dk_128_256, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
flash::gemm_k_rs_ds_read_m32x16_alt<0>(acc_dv_128_256, rP, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
s_barrier();
 
// lds_direct_copy<Is_even_K, Is_even_MN>(gdO, sdO, 0, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
s_waitcnt<2>();
flash::gemm_k_rs_ds_read_m32x16<1>(acc_dk_128_256, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
flash::gemm_k_rs_ds_read_m32x16_alt<1>(acc_dv_128_256, rP, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
s_barrier();
 
// lds_direct_copy<Is_even_K, Is_even_MN>(gdO, sdO, 1, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
s_waitcnt<1>();
flash::gemm_k_rs_ds_read_m32x16<2>(acc_dk_128_256, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
flash::gemm_k_rs_ds_read_m32x16_alt<2>(acc_dv_128_256, rP, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
s_barrier();
 
// lds_direct_copy<Is_even_K, Is_even_MN>(gdO, sdO, 2, params.do_row_stride, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM);
s_waitcnt<0>();
flash::gemm_k_rs_ds_read_m32x16<3>(acc_dk_128_256, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
flash::gemm_k_rs_ds_read_m32x16_alt<3>(acc_dv_128_256, rP, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
s_barrier();
 
if (m_block > m_block_min) {
......@@ -13156,100 +13932,6 @@ inline __device__ void compute_dk_dv_trans_1colblock_16x64_dim256_prefetch(const
// }
 
#if 0
if constexpr(Is_dropout) {
#pragma unroll
for (int i = 0; i < size(acc_dv); ++i) { acc_dv(i) *= params.rp_dropout; }
}
#pragma unroll
for (int i = 0; i < size(acc_dk); ++i) { acc_dk(i) *= params.scale_softmax_rp_dropout; }
// Convert acc_dv from fp32 to fp16
Tensor rdK = flash::convert_type<Element>(acc_dk);
Tensor rdV = flash::convert_type<Element>(acc_dv);
// __syncthreads();
Tensor sdK = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K)
Tensor sdV = make_tensor(sdK.data() + size(sdK), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K)
// Partition sdV and sdK to match the accumulator partitioning
auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv);
auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(tidx);
Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK); // ((Atom,AtomNum), MMA_N, MMA_N)
Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK); // ((Atom,AtomNum),PIPE_M,PIPE_N)
Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV); // ((Atom,AtomNum), MMA_N, MMA_N)
Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV); // ((Atom,AtomNum),PIPE_M,PIPE_N)
// We need syncthreads here since we're writing to the same location as sK and sV.
// Without syncthreads, some thread might modify the location of sK while another thread
// is reading it for dQ gemm, leading to a race condition.
// If Is_last, there's already a __syncthreads() at the end of the loop.
// if constexpr(!Is_last) { __syncthreads(); }
cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
+ n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
+ n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.dk_row_stride, _1{}));
Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.dv_row_stride, _1{}));
typename Kernel_traits::GmemTiledCopydKV gmem_tiled_copy_dKV;
auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
Tensor tdKsdK = gmem_thr_copy_dKV.partition_S(sdK); // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
Tensor tdVsdV = gmem_thr_copy_dKV.partition_S(sdV); // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
__syncthreads();
Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
cute::copy(gmem_tiled_copy_dKV, tdKsdK, tdKrdK);
Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));
Tensor cdKV = make_identity_tensor(make_shape(size<0>(sdK), size<1>(sdK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKgdK)));
__builtin_amdgcn_s_barrier();
#pragma unroll
for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
// Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
);
__builtin_amdgcn_s_barrier();
cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
__syncthreads();
cute::copy(gmem_tiled_copy_dKV, tdVsdV, tdVrdV);
__builtin_amdgcn_s_barrier();
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
);
__builtin_amdgcn_s_barrier();
#elif 0
const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
+ n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
Shape<Int<kBlockN>, Int<kHeadDimV>>{},
make_stride(params.dv_row_stride, _1{}));
_bwd_store_dk_dv<Kernel_traits, decltype(acc_dv), decltype(sQ), decltype(gdV), Element,
typename Kernel_traits::SmemLayoutdVStore, Is_even_MN, Is_even_K>(
acc_dv, sQ, tidx, gdV, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN);
__syncthreads();
const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
+ n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.dk_row_stride, _1{}));
_bwd_store_dk_dv<Kernel_traits, decltype(acc_dk), decltype(sQ), decltype(gdK), Element,
typename Kernel_traits::SmemLayoutdKStore, Is_even_MN, Is_even_K>(
acc_dk, sQ, tidx, gdK, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
#else
 
const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
......@@ -13264,66 +13946,31 @@ __builtin_amdgcn_s_barrier();
make_stride(params.dv_row_stride, _1{}));
int row, col;
if constexpr (size<1>(acc_dk) == size<1>(acc_dv) && size<2>(acc_dk) == size<2>(acc_dv)) {
#pragma unroll
for (int mi = 0; mi < size<1>(acc_dk); ++mi) {
row = (mi*kNWarps + warpId) * 16 + (laneId % 16);
if (Is_even_MN || row < binfo.actual_seqlen_k - n_block * kBlockN) {
#pragma unroll
for (int ni = 0; ni < size<2>(acc_dk); ++ni) {
col = (laneId / 16) + ni * 32;
#pragma unroll
for (int ei = 0; ei < size<0>(acc_dk); ++ei) {
if (Is_even_K || col < params.d) {
gdK(row, col) = flash::convert_type<Element>(acc_dk(ei, mi, ni) * params.scale_softmax_rp_dropout);
gdV(row, col) = flash::convert_type<Element>(!Is_dropout ? acc_dv(ei, mi, ni) : acc_dv(ei, mi, ni) * params.rp_dropout );
}
col += 4;
}
}
}
}
} else {
#pragma unroll
for (int mi = 0; mi < size<1>(acc_dk); ++mi) {
row = (mi*kNWarps + warpId) * 16 + (laneId % 16);
if (Is_even_MN || row < binfo.actual_seqlen_k - n_block * kBlockN) {
#pragma unroll
for (int ni = 0; ni < size<2>(acc_dk); ++ni) {
col = (laneId / 16) + ni * 32;
#pragma unroll
for (int ei = 0; ei < size<0>(acc_dk); ++ei) {
if (Is_even_K || col < params.d) {
gdK(row, col) = flash::convert_type<Element>(acc_dk(ei, mi, ni) * params.scale_softmax_rp_dropout);
}
col += 4;
}
}
}
}
#pragma unroll
for (int mi = 0; mi < size<1>(acc_dv); ++mi) {
row = (mi*kNWarps + warpId) * 16 + (laneId % 16);
if (Is_even_MN || row < binfo.actual_seqlen_k - n_block * kBlockN) {
#pragma unroll
for (int ni = 0; ni < size<2>(acc_dv); ++ni) {
col = (laneId / 16) + ni * 32;
#pragma unroll
for (int ei = 0; ei < size<0>(acc_dv); ++ei) {
if (Is_even_K || col < params.d) {
gdV(row, col) = flash::convert_type<Element>(!Is_dropout ? acc_dv(ei, mi, ni) : acc_dv(ei, mi, ni) * params.rp_dropout);
}
col += 4;
col = (laneId / 16) * 2 + ni * 32;
for (int ei = 0; ei < 4; ++ei) {
using result_type = cutlass::Array<Element, 2>;
if constexpr (Is_even_K)
{
result_type res;
// res[0] = flash::convert_type<Element>(acc_dk(ei, mi, ni) * params.scale_softmax_rp_dropout);
// res[1] = flash::convert_type<Element>(acc_dk(ei + 4, mi, ni) * params.scale_softmax_rp_dropout);
// *(result_type*)(&gdK(row, col)) = res;
res[0] = flash::convert_type<Element>(!Is_dropout ? acc_dv(ei, mi, ni) : acc_dv(ei, mi, ni) * params.rp_dropout);
res[1] = flash::convert_type<Element>(!Is_dropout ? acc_dv(ei + 4, mi, ni) : acc_dv(ei + 4, mi, ni) * params.rp_dropout);
*(result_type*)(&gdV(row, col)) = res;
}
col += 8;
}
}
}
}
#endif
}
#endif
 
......@@ -15899,33 +16546,68 @@ inline __device__ void compute_dk_dv_trans_16x64_prefetch(const Params &params)
}
}
} else if constexpr (Kernel_traits::kHeadDim == 64) {
if constexpr(Is_causal)
{
const int bidbh = blockIdx.x + blockIdx.z * params.se_balance_cnt;
const int bidb = bidbh / params.h;
const int bidh = bidbh % params.h;
const int n_block = blockIdx.y;
compute_dk_dv_trans_1colblock_16x64_dim64_prefetch<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params, bidb, bidh, n_block);
}
else
{
compute_dk_dv_trans_1colblock_16x64_dim64_prefetch<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params, bidb, bidh, n_block);
}
}
#if 1
else if constexpr (Kernel_traits::kHeadDim == 256) {
compute_dk_dv_trans_1colblock_16x64_dim256_prefetch<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params, bidb, bidh, n_block);
else if constexpr (Kernel_traits::kHeadDim == 512) {
compute_dk_dv_trans_1colblock_16x64_dim512_prefetch<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params, bidb, bidh, n_block);
if constexpr (Is_causal)
{
const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
if (num_n_block - n_block - 1 != num_n_block) {
compute_dk_dv_trans_1colblock_16x64_dim256_prefetch<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params, bidb, bidh, num_n_block - n_block - 1);
compute_dk_dv_trans_1colblock_16x64_dim512_prefetch<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params, bidb, bidh, num_n_block - n_block - 1);
}
}
}
#endif
}
 
#if 1
else if constexpr (Kernel_traits::kHeadDim == 512) {
compute_dk_dv_trans_1colblock_16x64_dim512_prefetch<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params, bidb, bidh, n_block);
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, typename Params>
inline __device__ void compute_dk_trans_16x64_prefetch(const Params &params) {
const int bidb = blockIdx.z;
// The block index for the head.
const int bidh = blockIdx.y;
const int n_block = blockIdx.x;
using Element = typename Kernel_traits::Element;
compute_dk_trans_1colblock_16x64_dim256_prefetch<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params, bidb, bidh, n_block);
if constexpr (Is_causal)
{
const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
if (num_n_block - n_block - 1 != num_n_block) {
compute_dk_dv_trans_1colblock_16x64_dim512_prefetch<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params, bidb, bidh, num_n_block - n_block - 1);
// compute_dk_trans_1colblock_16x64_dim256_prefetch<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params, bidb, bidh, num_n_block - n_block - 1);
compute_dk_trans_1colblock_16x64_dim256_prefetch<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params, bidb, bidh, num_n_block - n_block - 1);
}
}
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, typename Params>
inline __device__ void compute_dv_trans_16x64_prefetch(const Params &params) {
const int bidb = blockIdx.z;
// The block index for the head.
const int bidh = blockIdx.y;
const int n_block = blockIdx.x;
using Element = typename Kernel_traits::Element;
compute_dv_trans_1colblock_16x64_dim256_prefetch<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params, bidb, bidh, n_block);
if constexpr (Is_causal)
{
const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
if (num_n_block - n_block - 1 != num_n_block) {
compute_dv_trans_1colblock_16x64_dim256_prefetch<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params, bidb, bidh, num_n_block - n_block - 1);
// compute_dv_trans_1colblock_16x64_dim256_prefetch<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params, bidb, bidh, num_n_block - n_block - 1);
}
}
#endif
}
 
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, typename Params>
......
......@@ -80,6 +80,22 @@ DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dk_dv_trans_16x64_prefetch, bool Is_dropo
FLASH_UNSUPPORTED_ARCH
#endif
}
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dk_trans_16x64_prefetch, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
#if defined(ARCH_SUPPORTS_FLASH)
static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
flash::compute_dk_trans_16x64_prefetch<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
}
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dv_trans_16x64_prefetch, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
#if defined(ARCH_SUPPORTS_FLASH)
static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
flash::compute_dv_trans_16x64_prefetch<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
}
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dk_dv_trans_16x64_mla_prefetch, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
#if defined(ARCH_SUPPORTS_FLASH)
static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
......@@ -357,6 +373,32 @@ void run_flash_bwd_separate_seqk_parallel_trans(Flash_bwd_params &params, cudaSt
#endif
}
static inline int calc_se_balance_cnt(int b_h_num)
{
int se_balance_cnt = 1;
if(b_h_num % 13 == 0){
se_balance_cnt = 13;
} else if(b_h_num % 9 == 0){
se_balance_cnt = 9;
} else if(b_h_num % 8 == 0){
se_balance_cnt = 8;
} else if(b_h_num % 7 ==0){
se_balance_cnt = 7;
} else if(b_h_num % 6 ==0){
se_balance_cnt = 6;
} else if(b_h_num % 5 ==0){
se_balance_cnt = 5;
} else if(b_h_num % 4 ==0){
se_balance_cnt = 4;
} else if(b_h_num % 3 ==0){
se_balance_cnt = 3;
} else if(b_h_num % 2 ==0){
se_balance_cnt = 2;
} else {
se_balance_cnt = 1;
}
return se_balance_cnt;
}
template<typename Kernel_traits, typename Kernel_trans_traits, bool Is_dropout, bool Is_causal>
void run_flash_bwd_separate_prefetch(Flash_bwd_params &params, cudaStream_t stream) {
......@@ -370,19 +412,32 @@ void run_flash_bwd_separate_prefetch(Flash_bwd_params &params, cudaStream_t stre
const int num_n_block = (Is_causal && Kernel_trans_traits::kHeadDim != 64) ? (non_causal_num_n_block + 1 ) >> 1 :
non_causal_num_n_block;
const int non_causal_num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
const int num_m_block = Is_causal ? (non_causal_num_m_block + 1 ) >> 1 :
const int num_m_block = (Is_causal && Kernel_trans_traits::kHeadDim != 64) ? (non_causal_num_m_block + 1 ) >> 1 :
non_causal_num_m_block;
#endif
dim3 grid_m(num_m_block, params.h, params.b);
dim3 grid_m_do((params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM, params.b, params.h);
dim3 grid_n(num_n_block, params.h, params.b);
if constexpr (Kernel_trans_traits::kHeadDim == 64 && Is_causal)
{
int b_h_num = params.b * params.h;
params.se_balance_cnt = calc_se_balance_cnt(b_h_num);
grid_n.x = params.se_balance_cnt;
grid_n.y = num_n_block;
grid_n.z = (params.h * params.b/params.se_balance_cnt);
grid_m.x = params.se_balance_cnt;
grid_m.y = num_m_block;
grid_m.z = (params.h * params.b/params.se_balance_cnt);
}
flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<grid_m_do, Kernel_traits::kNThreads, 0, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_q % Kernel_traits::kBlockM == 0 && params.seqlen_k % Kernel_traits::kBlockN == 0;
const bool is_even_MN_trans = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_q % Kernel_trans_traits::kBlockM == 0 && params.seqlen_k % Kernel_trans_traits::kBlockN == 0;
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
constexpr int smem_size_dropout = Kernel_trans_traits::kBlockM * Kernel_trans_traits::kBlockN;
constexpr int smem_size_dropout = Kernel_trans_traits::kHeadDim == 64 ? 4096 : Kernel_trans_traits::kBlockM * Kernel_trans_traits::kBlockN;
constexpr int smem_size_dk_dv = Kernel_trans_traits::kSmemPrefetchSize;
constexpr int smem_size_dk_dv_total = (Kernel_trans_traits::kHeadDim == 128 || Kernel_trans_traits::kHeadDim == 64) ? (smem_size_dk_dv + smem_size_dropout) : (smem_size_dk_dv);
constexpr int smem_size_dq = Kernel_traits::kSmemPrefetchSize;
......@@ -397,6 +452,27 @@ void run_flash_bwd_separate_prefetch(Flash_bwd_params &params, cudaStream_t stre
SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
// constexpr static bool Is_softcap = false;
BOOL_SWITCH(is_even_MN_trans, IsEvenMNTransConst, [&] {
if constexpr (Kernel_trans_traits::kHeadDim == 256) {
auto kernel = &flash_bwd_dv_trans_16x64_prefetch<
Kernel_trans_traits,
Is_dropout && !Is_softcap, Is_causal,
Is_local && !Is_causal,
Has_alibi,
IsEvenMNTransConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128,
IsEvenKConst,
Is_softcap>;
kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dk_dv_total, stream>>>(params);
auto kernel_dk = &flash_bwd_dk_trans_16x64_prefetch<
Kernel_trans_traits,
Is_dropout && !Is_softcap, Is_causal,
Is_local && !Is_causal,
Has_alibi,
IsEvenMNTransConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128,
IsEvenKConst,
Is_softcap>;
kernel_dk<<<grid_n, Kernel_traits::kNThreads, smem_size_dk_dv_total, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
auto kernel = &flash_bwd_dk_dv_trans_16x64_prefetch<
Kernel_trans_traits,
Is_dropout && !Is_softcap, Is_causal,
......@@ -407,6 +483,7 @@ void run_flash_bwd_separate_prefetch(Flash_bwd_params &params, cudaStream_t stre
Is_softcap>;
kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dk_dv_total, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
});
auto kernel_dq = flash_bwd_dq_loop_16x64_prefetch_seqq_parallel_kernel<
Kernel_traits,
......@@ -559,9 +636,9 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
// }
// // printf("max_smem_per_block = %d\n", max_smem_per_block);
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938")
if (get_device_name() == "gfx936" || get_device_name() == "gfx938"|| get_device_name() == "gfx92a")
{
using kernel_trans_traits = Flash_bwd_kernel_trans_16x64_prefetch_traits<Headdim, /*kBlockM_*/64, /*kBlockN_*/Is_dropout ? 64 : 128, /*kNWarps_*/4, T, 3>;
using kernel_trans_traits = Flash_bwd_kernel_trans_16x64_prefetch_traits<Headdim, /*kBlockM_*/64, /*kBlockN_*/Is_dropout ? 128 : 128, /*kNWarps_*/4, T, 3>;
using kernel_traits = Flash_bwd_kernel_dq_16x64_prefetch_traits<Headdim, /*kBlockM_*/128, /*kBlockN_*/64, /*kNWarps_*/4,
/*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/1, /*AtomLayoutMdQ*/4, /*Is_V_in_regs_*/false,
/*No_double_buffer_*/true, /*Is_Q_in_regs_*/false, /*Share_Q_K_smem_*/true, T, 3>;
......@@ -588,7 +665,7 @@ template<typename T, bool Is_causal>
void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 96;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
using kernel_trans_traits = Flash_bwd_kernel_trans_16x64_prefetch_traits_dim96<Headdim, /*kBlockM_*/64, /*kBlockN_*/64, /*kNWarps_*/4, T, 3>;
using kernel_traits = Flash_bwd_kernel_dq_16x64_prefetch_traits_dim96<Headdim, /*kBlockM_*/128, /*kBlockN_*/64, /*kNWarps_*/4,
/*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/1, /*AtomLayoutMdQ*/4, T, 3>;
......@@ -617,7 +694,7 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 128;
// printf("max_smem_per_block = %d\n", max_smem_per_block);
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938"){
if (get_device_name() == "gfx936" || get_device_name() == "gfx938"|| get_device_name() == "gfx92a"){
using kernel_trans_traits = Flash_bwd_kernel_trans_16x64_prefetch_traits<Headdim, /*kBlockM_*/64, /*kBlockN_*/64, /*kNWarps_*/4, T, 3>;
// using kernel_traits = Flash_bwd_kernel_dq_16x64_prefetch_traits<Headdim, /*kBlockM_*/Is_dropout ? 64 : 128, /*kBlockN_*/64, /*kNWarps_*/4,
using kernel_traits = Flash_bwd_kernel_dq_16x64_prefetch_traits<Headdim, /*kBlockM_*/Is_dropout ? (Is_causal ? 64 : 128) : 128, /*kBlockN_*/64, /*kNWarps_*/4,
......@@ -686,7 +763,7 @@ void run_mha_bwd_hdim192_hdim128(Flash_bwd_params &params, cudaStream_t stream)
#if 1
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
// using kernel_trans_traits = Flash_bwd_kernel_trans_16x64_prefetch_traits<Headdim, /*kBlockM_*/64, /*kBlockN_*/64, /*kNWarps_*/4, T, 3, 128>;
using kernel_trans_traits = Flash_bwd_kernel_trans_16x64_prefetch_mla_traits<Headdim, /*kBlockM_*/64, /*kBlockN_*/64, /*kNWarps_*/4, T, 3, 128>;
using kernel_traits = Flash_bwd_kernel_dq_16x64_prefetch_traits<Headdim, /*kBlockM_*/64, /*kBlockN_*/64, /*kNWarps_*/4,
......@@ -782,7 +859,7 @@ template<typename T, bool Is_causal>
void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 256;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
// printf("%s:%d\n", __FILE__, __LINE__);
using kernel_traits = Flash_bwd_kernel_dq_16x64_prefetch_traits_dim256<Headdim, 64, 64, 4, T, 3>;
using kernel_trans_traits = Flash_bwd_kernel_trans_16x64_prefetch_traits_dim256<Headdim, 64, 64, 4, T, 3>;
......@@ -810,7 +887,7 @@ template<typename T, bool Is_causal>
void run_mha_bwd_hdim512(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 512;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
// printf("%s:%d\n", __FILE__, __LINE__);
using kernel_traits = Flash_bwd_kernel_dq_16x64_prefetch_traits_dim512<Headdim, 64, 64, 4, T, 3>;
using kernel_trans_traits = Flash_bwd_kernel_trans_16x64_prefetch_traits_dim512<Headdim, 64, 64, 4, T, 3>;
......
......@@ -2324,7 +2324,7 @@ inline __device__ void compute_attn_1rowblock_16x64_prefetch(const Params &param
 
// Prologue
flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_Q, tSgQ, tGrQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM);
// __syncthreads();
int n_block = n_block_max - 1;
// We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
Tensor acc_o = partition_fragment_C(tiled_mma_for_gemm1, Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // MMA, MMA_M, MMA_K
......@@ -2355,17 +2355,17 @@ inline __device__ void compute_attn_1rowblock_16x64_prefetch(const Params &param
#pragma unroll
for (int i = 0; i < k0_loops - kStages; ++i) {
lds_direct_copy<Is_even_K, Is_even_MN>(gK, sK, kStages + i, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
S_WAITCNT;
s_waitcnt<3>();
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, i);
S_BARRIER;
s_barrier();
}
 
#pragma unroll
for (int i = 0; i < kStages; ++i) { // tail kStages
lds_direct_copy<Is_even_K, Is_even_MN, _16x128>(gV, sV, i, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN);
S_WAITCNT;
s_waitcnt<3>();
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, k0_loops - kStages + i);
S_BARRIER;
s_barrier();
}
 
......@@ -2429,20 +2429,20 @@ inline __device__ void compute_attn_1rowblock_16x64_prefetch(const Params &param
}
 
lds_direct_copy<Is_even_K, Is_even_MN, _16x128>(gV, sV, kStages + 0, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN);
S_WAITCNT;
s_waitcnt<3>();
flash::gemm_k_rs_ds_read_m32x16<0>(acc_o, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
asm volatile("s_waitcnt vmcnt(2) \n s_barrier");
s_waitcnt<2>();
flash::gemm_k_rs_ds_read_m32x16<1>(acc_o, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
// S_BARRIER;
// k = 2
asm volatile("s_waitcnt vmcnt(1) \n s_barrier");
s_waitcnt<1>();
flash::gemm_k_rs_ds_read_m32x16<2>(acc_o, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
// S_BARRIER;
// k = 3
asm volatile("s_waitcnt vmcnt(0) \n s_barrier");
s_waitcnt<0>();
flash::gemm_k_rs_ds_read_m32x16<3>(acc_o, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
S_BARRIER;
s_barrier();
 
if (n_block > n_block_min) {
gK.data() = gK.data() + (-int(kBlockN * params.k_row_stride));
......@@ -2468,17 +2468,17 @@ inline __device__ void compute_attn_1rowblock_16x64_prefetch(const Params &param
#pragma unroll
for (int i = 0; i < k0_loops - kStages; ++i) {
lds_direct_copy<Is_even_K>(gK, sK, kStages + i, params.k_row_stride, params.d);
S_WAITCNT;
s_waitcnt<3>();
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, i);
S_BARRIER;
s_barrier();
}
 
#pragma unroll
for (int i = 0; i < kStages; ++i) { // tail kStages
lds_direct_copy<Is_even_K, /*Is_even_MN=*/true, _16x128>(gV, sV, i, params.v_row_stride, params.d_value);
S_WAITCNT;
s_waitcnt<3>();
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, k0_loops - kStages + i);
S_BARRIER;
s_barrier();
}
// __builtin_amdgcn_sched_barrier(1);
......@@ -2538,20 +2538,20 @@ inline __device__ void compute_attn_1rowblock_16x64_prefetch(const Params &param
}
 
lds_direct_copy<Is_even_K, /*Is_even_MN=*/true, _16x128>(gV, sV, kStages + 0, params.v_row_stride, params.d_value);
S_WAITCNT;
s_waitcnt<3>();
flash::gemm_k_rs_ds_read_m32x16<0>(acc_o, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
 
asm volatile("s_waitcnt vmcnt(2) \n s_barrier");
s_waitcnt<2>();
flash::gemm_k_rs_ds_read_m32x16<1>(acc_o, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
// S_BARRIER;
// k = 2
asm volatile("s_waitcnt vmcnt(1) \n s_barrier");
s_waitcnt<1>();
flash::gemm_k_rs_ds_read_m32x16<2>(acc_o, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
// S_BARRIER;
// k = 3
asm volatile("s_waitcnt vmcnt(0) \n s_barrier");
s_waitcnt<0>();
flash::gemm_k_rs_ds_read_m32x16<3>(acc_o, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
S_BARRIER;
s_barrier();
 
if (n_block > n_block_min) {
gK.data() = gK.data() + (-int(kBlockN * params.k_row_stride));
......@@ -3048,8 +3048,8 @@ inline __device__ void compute_attn_1rowblock_16x64_prefetch_fp8(const Params &p
for (int ni = 0; ni < size<2>(acc_o); ++ni) {
col = (laneId / 16)*4 + ni * 32;
{
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o(0, mi, ni), 0, acc_o(1, mi, ni), 0);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o(2, mi, ni), 0, acc_o(3, mi, ni), 0);
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(acc_o(0, mi, ni), acc_o(1, mi, ni), false);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(acc_o(2, mi, ni), acc_o(3, mi, ni), false);
auto res0 = reinterpret_cast<result_type const &>(d0);
auto res1 = reinterpret_cast<result_type const &>(d1);
 
......@@ -3062,8 +3062,8 @@ inline __device__ void compute_attn_1rowblock_16x64_prefetch_fp8(const Params &p
}
{
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o(4, mi, ni), 0, acc_o(5, mi, ni), 0);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o(6, mi, ni), 0, acc_o(7, mi, ni), 0);
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(acc_o(4, mi, ni), acc_o(5, mi, ni), false);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(acc_o(6, mi, ni), acc_o(7, mi, ni), false);
auto res0 = reinterpret_cast<result_type const &>(d0);
auto res1 = reinterpret_cast<result_type const &>(d1);
gO(row, col) = res0[0];
......@@ -3691,6 +3691,18 @@ inline __device__ void compute_attn_1rowblock_16x64_dim64_prefetch(const Params
constexpr int k0_loops = size<2>(tSsK);
constexpr int k1_loops = size<2>(tOsVt);
static_assert(k0_loops == 2 && k1_loops == 4);
if constexpr(Is_even_K)
{
lds_direct_copy_even_k<0, /*Is_even_MN=*/Is_even_MN>(gK, sK, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy_even_k<1, /*Is_even_MN=*/Is_even_MN>(gK, sK, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy_even_k<0, Is_even_MN, _16x64_64>(gV, sV, params.v_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy_even_k<1, Is_even_MN, _16x64_64>(gV, sV, params.v_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy_even_k<2, Is_even_MN, _16x64_64>(gV, sV, params.v_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy_even_k<3, Is_even_MN, _16x64_64>(gV, sV, params.v_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
else
{
#pragma unroll
for (int i = 0; i < k0_loops; ++i) {
lds_direct_copy<Is_even_K, Is_even_MN>(gK, sK, i, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
......@@ -3699,6 +3711,7 @@ inline __device__ void compute_attn_1rowblock_16x64_dim64_prefetch(const Params
for (int i = 0; i < k1_loops; ++i) {
lds_direct_copy<Is_even_K, Is_even_MN, _16x64_64>(gV, sV, i, params.v_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
}
}
 
#if 1
#pragma unroll
......@@ -3792,6 +3805,18 @@ inline __device__ void compute_attn_1rowblock_16x64_dim64_prefetch(const Params
gK.data() = gK.data() + (-int(kBlockN * params.k_row_stride));
gV.data() = gV.data() + (-int(kBlockN * params.v_row_stride));
 
if constexpr(Is_even_K)
{
lds_direct_copy_even_k<0, /*Is_even_MN=*/true>(gK, sK, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy_even_k<1, /*Is_even_MN=*/true>(gK, sK, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy_even_k<0, true, _16x64_64>(gV, sV, params.v_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy_even_k<1, true, _16x64_64>(gV, sV, params.v_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy_even_k<2, true, _16x64_64>(gV, sV, params.v_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy_even_k<3, true, _16x64_64>(gV, sV, params.v_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
else
{
lds_direct_copy<Is_even_K>(gK, sK, 0, params.k_row_stride, params.d);
lds_direct_copy<Is_even_K>(gK, sK, 1, params.k_row_stride, params.d);
......@@ -3800,7 +3825,7 @@ inline __device__ void compute_attn_1rowblock_16x64_dim64_prefetch(const Params
lds_direct_copy<Is_even_K, /*Is_even_MN=*/true, _16x64_64>(gV, sV, 1, params.v_row_stride, params.d);
lds_direct_copy<Is_even_K, /*Is_even_MN=*/true, _16x64_64>(gV, sV, 2, params.v_row_stride, params.d);
lds_direct_copy<Is_even_K, /*Is_even_MN=*/true, _16x64_64>(gV, sV, 3, params.v_row_stride, params.d);
}
}
 
if (n_masking_steps > 1 && n_block <= n_block_min) {
......@@ -3898,6 +3923,18 @@ inline __device__ void compute_attn_1rowblock_16x64_dim64_prefetch(const Params
gK.data() = gK.data() + (-int(kBlockN * params.k_row_stride));
gV.data() = gV.data() + (-int(kBlockN * params.v_row_stride));
 
if constexpr(Is_even_K)
{
lds_direct_copy_even_k<0, /*Is_even_MN=*/true>(gK, sK, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy_even_k<1, /*Is_even_MN=*/true>(gK, sK, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy_even_k<0, true, _16x64_64>(gV, sV, params.v_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy_even_k<1, true, _16x64_64>(gV, sV, params.v_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy_even_k<2, true, _16x64_64>(gV, sV, params.v_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy_even_k<3, true, _16x64_64>(gV, sV, params.v_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
else
{
lds_direct_copy<Is_even_K>(gK, sK, 0, params.k_row_stride, params.d);
lds_direct_copy<Is_even_K>(gK, sK, 1, params.k_row_stride, params.d);
......@@ -3906,6 +3943,7 @@ inline __device__ void compute_attn_1rowblock_16x64_dim64_prefetch(const Params
lds_direct_copy<Is_even_K, /*Is_even_MN=*/true, _16x64_64>(gV, sV, 1, params.v_row_stride, params.d);
lds_direct_copy<Is_even_K, /*Is_even_MN=*/true, _16x64_64>(gV, sV, 2, params.v_row_stride, params.d);
lds_direct_copy<Is_even_K, /*Is_even_MN=*/true, _16x64_64>(gV, sV, 3, params.v_row_stride, params.d);
}
 
}
}
......@@ -4153,11 +4191,11 @@ inline __device__ void compute_attn_1rowblock_16x64_dim256_prefetch(const Params
int n_block = n_block_max - 1;
Tensor acc_o = partition_fragment_C(tiled_mma_for_gemm1, Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // MMA, MMA_M, MMA_K
Tensor acc_o_split = local_tile(acc_o, Shape<Int<8>, Int<1>, Int<kHeadDimV / 32 / 2>>{}, make_coord(0, 0, _));
Tensor acc_o_split = local_tile(acc_o, Shape<Int<8>, Int<kBlockM/64>, Int<kHeadDimV / 32 / 2>>{}, make_coord(0, 0, _));
auto acc_o_temp0 = acc_o_split(_, _, _, 0);
auto acc_o_temp1 = acc_o_split(_, _, _, 1);
clear(acc_o);
constexpr static int K_BUFF_SIZE = 4;
flash::Softmax<size<1>(acc_o)> softmax;
 
const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
......@@ -4169,10 +4207,16 @@ inline __device__ void compute_attn_1rowblock_16x64_dim256_prefetch(const Params
s_waitcnt<0>();
__syncthreads();
if constexpr(!Is_even_K) {
#pragma unroll
for (int i = 0; i < 3; ++i) {
for (int i = 0; i < 3; i++) {
lds_direct_copy<Is_even_K, Is_even_MN>(gK, sK, i, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
}
} else {
lds_direct_copy_even_k_dim256<0, K_BUFF_SIZE, Is_even_MN>(gK, sK, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy_even_k_dim256<1, K_BUFF_SIZE, Is_even_MN>(gK, sK, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy_even_k_dim256<2, K_BUFF_SIZE, Is_even_MN>(gK, sK, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
 
#if 1
#pragma unroll
......@@ -4181,42 +4225,73 @@ inline __device__ void compute_attn_1rowblock_16x64_dim256_prefetch(const Params
clear(acc_s_ori);
s_barrier();
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN>(gK, sK, 3, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<3, K_BUFF_SIZE, Is_even_MN>(gK, sK, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 0);
s_barrier();
 
lds_direct_copy<Is_even_K, Is_even_MN>(0, gK, sK, 4, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN>(gK, sK, 4, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<4, K_BUFF_SIZE, Is_even_MN>(gK, sK, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 1);
s_barrier();
 
lds_direct_copy<Is_even_K, Is_even_MN>(1, gK, sK, 5, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN>(gK, sK, 5, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<5, K_BUFF_SIZE, Is_even_MN>(gK, sK, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 2);
s_barrier();
 
lds_direct_copy<Is_even_K, Is_even_MN>(2, gK, sK, 6, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN>(gK, sK, 6, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<6, K_BUFF_SIZE, Is_even_MN>(gK, sK, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 3);
s_barrier();
 
lds_direct_copy<Is_even_K, Is_even_MN>(3, gK, sK, 7, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN>(gK, sK, 7, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<7, K_BUFF_SIZE, Is_even_MN>(gK, sK, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 4, 0);
s_barrier();
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(0, 0, gV, sV, 0, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<0, K_BUFF_SIZE, Is_even_MN, _16x256, 0>(gV, sV, params.v_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 5, 1);
s_barrier();
 
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(0, 1, gV, sV, 1, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<1, K_BUFF_SIZE, Is_even_MN, _16x256, 0>(gV, sV, params.v_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 6, 2);
s_barrier();
 
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(0, 2, gV, sV, 2, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<2, K_BUFF_SIZE, Is_even_MN, _16x256, 0>(gV, sV, params.v_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 7, 3);
s_barrier();
......@@ -4272,50 +4347,76 @@ inline __device__ void compute_attn_1rowblock_16x64_dim256_prefetch(const Params
}
}
 
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(0, 3, gV, sV, 3, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<3, K_BUFF_SIZE, Is_even_MN, _16x256, 0>(gV, sV, params.v_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs_ds_read_m32x16<0>(acc_o_temp0, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
flash::gemm_k_rs_ds_read_m32x16_alt<0>(acc_o_temp0, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
s_barrier();
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(1, 0, gV, sV, 0, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<0, K_BUFF_SIZE, Is_even_MN, _16x256, 1>(gV, sV, params.v_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs_ds_read_m32x16<1>(acc_o_temp0, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
flash::gemm_k_rs_ds_read_m32x16_alt<1>(acc_o_temp0, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
s_barrier();
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(1, 1, gV, sV, 1, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<1, K_BUFF_SIZE, Is_even_MN, _16x256, 1>(gV, sV, params.v_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs_ds_read_m32x16<2>(acc_o_temp0, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
flash::gemm_k_rs_ds_read_m32x16_alt<2>(acc_o_temp0, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
s_barrier();
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(1, 2, gV, sV, 2, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<2, K_BUFF_SIZE, Is_even_MN, _16x256, 1>(gV, sV, params.v_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs_ds_read_m32x16<3>(acc_o_temp0, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
flash::gemm_k_rs_ds_read_m32x16_alt<3>(acc_o_temp0, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
s_barrier();
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, Is_even_MN, _16x256>(1, 3, gV, sV, 3, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<3, K_BUFF_SIZE, Is_even_MN, _16x256, 1>(gV, sV, params.v_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs_ds_read_m32x16<0>(acc_o_temp1, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
flash::gemm_k_rs_ds_read_m32x16_alt<0>(acc_o_temp1, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
s_barrier();
 
s_waitcnt<2>();
flash::gemm_k_rs_ds_read_m32x16<1>(acc_o_temp1, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
flash::gemm_k_rs_ds_read_m32x16_alt<1>(acc_o_temp1, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
s_barrier();
 
s_waitcnt<1>();
flash::gemm_k_rs_ds_read_m32x16<2>(acc_o_temp1, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
flash::gemm_k_rs_ds_read_m32x16_alt<2>(acc_o_temp1, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
s_barrier();
 
s_waitcnt<0>();
flash::gemm_k_rs_ds_read_m32x16<3>(acc_o_temp1, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
flash::gemm_k_rs_ds_read_m32x16_alt<3>(acc_o_temp1, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
s_barrier();
 
if (n_block > n_block_min) {
gK.data() = gK.data() + (-int(kBlockN * params.k_row_stride));
gV.data() = gV.data() + (-int(kBlockN * params.v_row_stride));
if constexpr(!Is_even_K) {
#pragma unroll
for (int i = 0; i < 3; ++i) {
for (int i = 0; i < 3; i ++) {
lds_direct_copy<Is_even_K>(gK, sK, i, params.k_row_stride, params.d);
}
} else {
lds_direct_copy_even_k_dim256<0, K_BUFF_SIZE, true>(gK, sK, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy_even_k_dim256<1, K_BUFF_SIZE, true>(gK, sK, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy_even_k_dim256<2, K_BUFF_SIZE, true>(gK, sK, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
}
 
if (n_masking_steps > 1 && n_block <= n_block_min) {
......@@ -4332,42 +4433,74 @@ inline __device__ void compute_attn_1rowblock_16x64_dim256_prefetch(const Params
clear(acc_s_ori);
s_barrier();
lds_direct_copy<Is_even_K>(gK, sK, 3, params.k_row_stride, params.d);
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, true>(gK, sK, 3, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<3, K_BUFF_SIZE, true>(gK, sK, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 0);
s_barrier();
 
lds_direct_copy<Is_even_K>(0, gK, sK, 4, params.k_row_stride, params.d);
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, true>(gK, sK, 4, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<4, K_BUFF_SIZE, true>(gK, sK, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 1);
s_barrier();
 
lds_direct_copy<Is_even_K>(1, gK, sK, 5, params.k_row_stride, params.d);
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, true>(gK, sK, 5, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<5, K_BUFF_SIZE, true>(gK, sK, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 2);
s_barrier();
 
lds_direct_copy<Is_even_K>(2, gK, sK, 6, params.k_row_stride, params.d);
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, true>(gK, sK, 6, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<6, K_BUFF_SIZE, true>(gK, sK, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 3);
s_barrier();
 
lds_direct_copy<Is_even_K>(3, gK, sK, 7, params.k_row_stride, params.d);
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, true>(gK, sK, 7, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<7, K_BUFF_SIZE, true>(gK, sK, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 4, 0);
s_barrier();
 
lds_direct_copy<Is_even_K, true, _16x256>(0, 0, gV, sV, 0, params.v_row_stride, params.d_value);
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, true, _16x256>(0, 0, gV, sV, 0, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<0, K_BUFF_SIZE, true, _16x256, 0>(gV, sV, params.v_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 5, 1);
s_barrier();
 
lds_direct_copy<Is_even_K, true, _16x256>(0, 1, gV, sV, 1, params.v_row_stride, params.d_value);
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, true, _16x256>(0, 1, gV, sV, 1, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<1, K_BUFF_SIZE, true, _16x256, 0>(gV, sV, params.v_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 6, 2);
s_barrier();
 
lds_direct_copy<Is_even_K, true, _16x256>(0, 2, gV, sV, 2, params.v_row_stride, params.d_value);
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, true, _16x256>(0, 2, gV, sV, 2, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<2, K_BUFF_SIZE, true, _16x256, 0>(gV, sV, params.v_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 7, 3);
s_barrier();
......@@ -4430,51 +4563,77 @@ inline __device__ void compute_attn_1rowblock_16x64_dim256_prefetch(const Params
}
}
lds_direct_copy<Is_even_K, true, _16x256>(0, 3, gV, sV, 3, params.v_row_stride, params.d_value);
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, true, _16x256>(0, 3, gV, sV, 3, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<3, K_BUFF_SIZE, true, _16x256, 0>(gV, sV, params.v_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs_ds_read_m32x16<0>(acc_o_temp0, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
flash::gemm_k_rs_ds_read_m32x16_alt<0>(acc_o_temp0, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
s_barrier();
lds_direct_copy<Is_even_K, true, _16x256>(1, 0, gV, sV, 0, params.v_row_stride, params.d_value);
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, true, _16x256>(1, 0, gV, sV, 0, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<0, K_BUFF_SIZE, true, _16x256, 1>(gV, sV, params.v_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs_ds_read_m32x16<1>(acc_o_temp0, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
flash::gemm_k_rs_ds_read_m32x16_alt<1>(acc_o_temp0, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
s_barrier();
lds_direct_copy<Is_even_K, true, _16x256>(1, 1, gV, sV, 1, params.v_row_stride, params.d_value);
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, true, _16x256>(1, 1, gV, sV, 1, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<1, K_BUFF_SIZE, true, _16x256, 1>(gV, sV, params.v_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs_ds_read_m32x16<2>(acc_o_temp0, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
flash::gemm_k_rs_ds_read_m32x16_alt<2>(acc_o_temp0, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
s_barrier();
lds_direct_copy<Is_even_K, true, _16x256>(1, 2, gV, sV, 2, params.v_row_stride, params.d_value);
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, true, _16x256>(1, 2, gV, sV, 2, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<2, K_BUFF_SIZE, true, _16x256, 1>(gV, sV, params.v_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs_ds_read_m32x16<3>(acc_o_temp0, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
flash::gemm_k_rs_ds_read_m32x16_alt<3>(acc_o_temp0, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
s_barrier();
lds_direct_copy<Is_even_K, true, _16x256>(1, 3, gV, sV, 3, params.v_row_stride, params.d_value);
if constexpr(!Is_even_K) {
lds_direct_copy<Is_even_K, true, _16x256>(1, 3, gV, sV, 3, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
lds_direct_copy_even_k_dim256<3, K_BUFF_SIZE, true, _16x256, 1>(gV, sV, params.v_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
s_waitcnt<3>();
flash::gemm_k_rs_ds_read_m32x16<0>(acc_o_temp1, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
flash::gemm_k_rs_ds_read_m32x16_alt<0>(acc_o_temp1, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
s_barrier();
 
s_waitcnt<2>();
flash::gemm_k_rs_ds_read_m32x16<1>(acc_o_temp1, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
flash::gemm_k_rs_ds_read_m32x16_alt<1>(acc_o_temp1, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
s_barrier();
 
s_waitcnt<1>();
flash::gemm_k_rs_ds_read_m32x16<2>(acc_o_temp1, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
flash::gemm_k_rs_ds_read_m32x16_alt<2>(acc_o_temp1, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
s_barrier();
 
s_waitcnt<0>();
flash::gemm_k_rs_ds_read_m32x16<3>(acc_o_temp1, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
flash::gemm_k_rs_ds_read_m32x16_alt<3>(acc_o_temp1, rP, tOrVt, tOsVt, tiled_mma_for_gemm1, smem_tiled_copy_V, smem_thr_copy_V);
s_barrier();
 
 
if (n_block > n_block_min) {
gK.data() = gK.data() + (-int(kBlockN * params.k_row_stride));
gV.data() = gV.data() + (-int(kBlockN * params.v_row_stride));
if constexpr(!Is_even_K) {
#pragma unroll
for (int i = 0; i < 3; ++i) {
for (int i = 0; i < 3; i ++) {
lds_direct_copy<Is_even_K>(gK, sK, i, params.k_row_stride, params.d);
}
} else {
lds_direct_copy_even_k_dim256<0, K_BUFF_SIZE, true>(gK, sK, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy_even_k_dim256<1, K_BUFF_SIZE, true>(gK, sK, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy_even_k_dim256<2, K_BUFF_SIZE, true>(gK, sK, params.k_row_stride, binfo.actual_seqlen_k - n_block * kBlockN);
}
}
}
#endif
......@@ -4516,19 +4675,32 @@ inline __device__ void compute_attn_1rowblock_16x64_dim256_prefetch(const Params
// asm volatile("v_cmpx_lt_i32 exec, %0, %1":: "v"(row), "v"(qo_len) :);
#pragma unroll
for (int ni = 0; ni < size<2>(acc_o); ++ni) {
if constexpr (Is_even_K) {
col = (laneId / 16) * 2 + ni * 32;
using result_type = cutlass::Array<Element, 2>;
for (int ei = 0; ei < 4; ++ei)
{
result_type res;
res[0] = flash::convert_type<Element>(acc_o(ei, mi, ni));
res[1] = flash::convert_type<Element>(acc_o(ei + 4, mi, ni));
*(result_type*)(&gO(row, col)) = res;
col += 8;
}
} else {
#pragma unroll
for (int ei = 0; ei < size<0>(acc_o); ++ei) {
col = (laneId / 16) + ni * 32 + ei * 4;
// wangaq debug
// printf("bidx:%d bidy:%d bidz:%d tid:%d mi:%d ni:%d ei:%d row:%d col:%d acc_o:%10.4f\n",
// blockIdx.x, blockIdx.y, blockIdx.z, tidx, mi, ni, ei, row, col, acc_o(ei, mi, ni));
if (Is_even_K || col < params.d_value) {
if (col < params.d_value) {
gO(row, col) = flash::convert_type<Element>(acc_o(ei, mi, ni));
}
// else
// gO(row, col) = Element(0.0);
}
}
}
// asm volatile("s_mov_b64 exec, 0xFFFFFFFFFFFFFFFF");
}
}
......@@ -6618,8 +6790,8 @@ inline __device__ void compute_attn_1rowblock_16x64_mla_prefetch_fp8(const Param
for (int ni = 0; ni < size<2>(acc_o); ++ni) {
col = (laneId / 16)*4 + ni * 32;
{
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o(0, mi, ni), 0, acc_o(1, mi, ni), 0);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o(2, mi, ni), 0, acc_o(3, mi, ni), 0);
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(acc_o(0, mi, ni), acc_o(1, mi, ni), false);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(acc_o(2, mi, ni), acc_o(3, mi, ni), false);
auto res0 = reinterpret_cast<result_type const &>(d0);
auto res1 = reinterpret_cast<result_type const &>(d1);
 
......@@ -6632,8 +6804,8 @@ inline __device__ void compute_attn_1rowblock_16x64_mla_prefetch_fp8(const Param
}
{
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o(4, mi, ni), 0, acc_o(5, mi, ni), 0);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o(6, mi, ni), 0, acc_o(7, mi, ni), 0);
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(acc_o(4, mi, ni), acc_o(5, mi, ni), false);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(acc_o(6, mi, ni), acc_o(7, mi, ni), false);
auto res0 = reinterpret_cast<result_type const &>(d0);
auto res1 = reinterpret_cast<result_type const &>(d1);
gO(row, col) = res0[0];
......@@ -8122,16 +8294,16 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
{
int cur_block_table;
const int *cur_block_table_ptr = block_table + (n_block);
// cur_block_table = block_table[n_block - 1];
// cur_block_table = block_table[n_block];
asm volatile("s_load_dword %1, %0, 0x0\n\t"
"s_waitcnt lgkmcnt(0)\n\t":
"+s"(cur_block_table_ptr),
"=s"(cur_block_table));
index_t offset_k = cur_block_table * params.k_batch_stride;
index_t offset_v = cur_block_table * params.v_batch_stride;
// index_t offset_v = cur_block_table * params.v_batch_stride;
gK.data() = gK_data + (offset_k);
gV.data() = gV_data + (offset_v);
gV_tail.data() = gV_tail_data + (offset_v);
gV.data() = gV_data + (offset_k);
gV_tail.data() = gV_tail_data + (offset_k);
}
__builtin_amdgcn_sched_barrier(0);
lds_direct_copy<Is_even_K, Is_even_MN, _64x32, 0, false>(gK, sK, 0, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
......@@ -8202,7 +8374,6 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
asm volatile("s_waitcnt vmcnt(7) \n s_barrier");
__builtin_amdgcn_sched_barrier(0);
if (!Is_even_MN && Is_need_pad && masking_step == 0) {
__builtin_amdgcn_sched_barrier(0);
flash::gemm_k_rs_pad_ws<Element>(acc_o_ori, rP, tOrV, tSsV, tiled_mma_gemm1, smem_tiled_copy_V, smem_thr_copy_V, 0, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(6) \n s_barrier");
flash::gemm_k_rs_pad_ws<Element>(acc_o_ori, rP, tOrV, tSsV, tiled_mma_gemm1, smem_tiled_copy_V, smem_thr_copy_V, 1, binfo.actual_seqlen_k - n_block * kBlockN);
......@@ -8219,9 +8390,7 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
asm volatile("s_waitcnt vmcnt(0) \n s_barrier");
flash::gemm_k_rs_pad_ws<Element>(acc_o_tail_ori, rP, tOrV, tSsV_tail, tiled_mma_gemm1, smem_tiled_copy_V, smem_thr_copy_V, 3, binfo.actual_seqlen_k - n_block * kBlockN);
S_BARRIER;
__builtin_amdgcn_sched_barrier(0);
} else {
__builtin_amdgcn_sched_barrier(0);
flash::gemm_k_rs(acc_o_ori, rP, tOrV, tSsV, tiled_mma_gemm1, smem_tiled_copy_V, smem_thr_copy_V, 0, 0);
asm volatile("s_waitcnt vmcnt(6) \n s_barrier");
flash::gemm_k_rs(acc_o_ori, rP, tOrV, tSsV, tiled_mma_gemm1, smem_tiled_copy_V, smem_thr_copy_V, 1, 1);
......@@ -8238,13 +8407,10 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
asm volatile("s_waitcnt vmcnt(0) \n s_barrier");
flash::gemm_k_rs(acc_o_tail_ori, rP, tOrV, tSsV_tail, tiled_mma_gemm1, smem_tiled_copy_V, smem_thr_copy_V, 3, 3);
S_BARRIER;
__builtin_amdgcn_sched_barrier(0);
}
}
 
if (n_block > 0) {
// gK.data() = gK.data() + (-int(kBlockN * params.k_row_stride));
// gV.data() = gV.data() + (-int(kBlockN * params.v_row_stride));
int cur_block_table;
const int *cur_block_table_ptr = block_table + (n_block - 1);
// cur_block_table = block_table[n_block - 1];
......@@ -8253,10 +8419,10 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
"+s"(cur_block_table_ptr),
"=s"(cur_block_table));
index_t offset_k = cur_block_table * params.k_batch_stride;
index_t offset_v = cur_block_table * params.v_batch_stride;
// index_t offset_v = cur_block_table * params.v_batch_stride;
gK.data() = gK_data + (offset_k);
gV.data() = gV_data + (offset_v);
gV_tail.data() = gV_tail_data + (offset_v);
gV.data() = gV_data + (offset_k);
gV_tail.data() = gV_tail_data + (offset_k);
 
lds_direct_copy<Is_even_K, true, _64x32, 0, false>(gK, sK, 0, params.k_row_stride, params.d);
lds_direct_copy<Is_even_K, true, _64x32, 0, false>(gK, sK, 1, params.k_row_stride, params.d);
......@@ -8274,7 +8440,6 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
Tensor acc_s_ori = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
clear(acc_s_ori);
{
__builtin_amdgcn_sched_barrier(0);
lds_direct_copy<Is_even_K, true, _64x32, 0, false>(gK, sK, 3, params.k_row_stride, params.d);
S_WAITCNT;
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 0, 0);
......@@ -8296,7 +8461,6 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
asm volatile("s_waitcnt vmcnt(6) \n s_barrier");
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 3, 3);
S_BARRIER;
__builtin_amdgcn_sched_barrier(0);
}
 
Tensor acc_s = make_tensor(acc_s_ori.data(), convert_layout_acc(acc_s_ori.layout()));
......@@ -8320,7 +8484,6 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
Tensor rP = flash::convert_type<Element>(acc_s);
{
__builtin_amdgcn_sched_barrier(0);
S_BARRIER;
lds_direct_copy<Is_even_K, true, _64x16, 0, false>(gV_tail, sV_tail, 2, params.v_row_stride, params.d);
lds_direct_copy<Is_even_K, true, _64x16, 0, false>(gV_tail, sV_tail, 3, params.v_row_stride, params.d);
......@@ -8342,22 +8505,21 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
asm volatile("s_waitcnt vmcnt(0) \n s_barrier");
flash::gemm_k_rs(acc_o_tail_ori, rP, tOrV, tSsV_tail, tiled_mma_gemm1, smem_tiled_copy_V, smem_thr_copy_V, 3, 3);
S_BARRIER;
__builtin_amdgcn_sched_barrier(0);
}
 
if (n_block > 0) {
int cur_block_table;
const int *cur_block_table_ptr = block_table + (n_block - 1);
// cur_block_table = block_table[n_block - 1];
// cur_block_table = block_table[n_block-1];
asm volatile("s_load_dword %1, %0, 0x0\n\t"
"s_waitcnt lgkmcnt(0)\n\t":
"+s"(cur_block_table_ptr),
"=s"(cur_block_table));
index_t offset_k = cur_block_table * params.k_batch_stride;
index_t offset_v = cur_block_table * params.v_batch_stride;
// index_t offset_v = cur_block_table * params.v_batch_stride;
gK.data() = gK_data + (offset_k);
gV.data() = gV_data + (offset_v);
gV_tail.data() = gV_tail_data + (offset_v);
gV.data() = gV_data + (offset_k);
gV_tail.data() = gV_tail_data + (offset_k);
#pragma unroll
for (int i = 0; i < kStages; ++i) {
......@@ -8594,17 +8756,16 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
"+s"(cur_block_table_ptr),
"=s"(cur_block_table));
index_t offset_k = cur_block_table * params.k_batch_stride;
index_t offset_v = cur_block_table * params.v_batch_stride;
// index_t offset_v = cur_block_table * params.v_batch_stride;
gK.data() = gK_data + (offset_k);
gV.data() = gV_data + (offset_v);
gV_tail.data() = gV_tail_data + (offset_v);
gV.data() = gV_data + (offset_k);
gV_tail.data() = gV_tail_data + (offset_k);
}
 
float q_descale = params.q_descale_ptr == nullptr ? 1.0f : params.q_descale_ptr[0];
float k_descale = params.k_descale_ptr == nullptr ? 1.0f : params.k_descale_ptr[0];
float v_descale = params.v_descale_ptr == nullptr ? 1.0f : params.v_descale_ptr[0];
 
const float scale_softmax_log2 = params.scale_softmax_log2*q_descale*k_descale;
const float scale_softmax = params.scale_softmax*q_descale*k_descale;
 
......@@ -8623,7 +8784,9 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
asm volatile("s_waitcnt vmcnt(1) \n s_barrier");;
// flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 0, 0);
Tensor tGrQ_ = recast<uint_byte_t<16>>(tGrQ);
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, (tSsK(_, _, 0)), (tCrK_copy_view(_, _, 0)));
__builtin_amdgcn_sched_barrier(0);
cute::gemm(tiled_mma, tGrQ_(_, _, 0), tSrK(_, _, 0), acc_s_ori);
S_BARRIER;
lds_direct_copy_fp8<Is_even_K, Is_even_MN, _64x32, 0, false>(gV, sV, 0, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN);
......@@ -8766,10 +8929,10 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
"+s"(cur_block_table_ptr),
"=s"(cur_block_table));
index_t offset_k = cur_block_table * params.k_batch_stride;
index_t offset_v = cur_block_table * params.v_batch_stride;
// index_t offset_v = cur_block_table * params.v_batch_stride;
gK.data() = gK_data + (offset_k);
gV.data() = gV_data + (offset_v);
gV_tail.data() = gV_tail_data + (offset_v);
gV.data() = gV_data + (offset_k);
gV_tail.data() = gV_tail_data + (offset_k);
} else {
gK.data() = gK.data() + (-int(kBlockN * params.k_row_stride));
gV.data() = gV.data() + (-int(kBlockN * params.v_row_stride));
......@@ -8864,10 +9027,10 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
"+s"(cur_block_table_ptr),
"=s"(cur_block_table));
index_t offset_k = cur_block_table * params.k_batch_stride;
index_t offset_v = cur_block_table * params.v_batch_stride;
// index_t offset_v = cur_block_table * params.v_batch_stride;
gK.data() = gK_data + (offset_k);
gV.data() = gV_data + (offset_v);
gV_tail.data() = gV_tail_data + (offset_v);
gV.data() = gV_data + (offset_k);
gV_tail.data() = gV_tail_data + (offset_k);
} else {
gK.data() = gK.data() + (-int(kBlockN * params.k_row_stride));
gV.data() = gV.data() + (-int(kBlockN * params.v_row_stride));
......@@ -8919,8 +9082,8 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
 
if (Is_even_K || col < params.d_value) {
if constexpr (std::is_same_v<ElementO, cutlass::bfloat16_t>) {
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o(ei, mi, ni), 0, acc_o(ei+1, mi, ni), 0);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o(ei+2, mi, ni), 0, acc_o(ei+3, mi, ni), 0);
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(acc_o(ei, mi, ni), acc_o(ei+1, mi, ni), false);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(acc_o(ei+2, mi, ni), acc_o(ei+3, mi, ni), false);
auto res0 = reinterpret_cast<result_type const &>(d0);
auto res1 = reinterpret_cast<result_type const &>(d1);
gO(row, col) = res0[0];
......@@ -8928,8 +9091,8 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
gO(row, col + 2) = res1[0];
gO(row, col + 3) = res1[1];
} else {
auto d0 = __builtin_hcu_cvt_pk_f16_f32(0, acc_o(ei, mi, ni), 0, acc_o(ei+1, mi, ni), 0,0);
auto d1 = __builtin_hcu_cvt_pk_f16_f32(0, acc_o(ei+2, mi, ni), 0, acc_o(ei+3, mi, ni), 0,0);
auto d0 = __builtin_hcu_cvt_pk_f16_f32(acc_o(ei, mi, ni), acc_o(ei+1, mi, ni), false, 0);
auto d1 = __builtin_hcu_cvt_pk_f16_f32(acc_o(ei+2, mi, ni), acc_o(ei+3, mi, ni), false, 0);
auto res0 = reinterpret_cast<result_type const &>(d0);
auto res1 = reinterpret_cast<result_type const &>(d1);
gO(row, col) = res0[0];
......@@ -8949,8 +9112,8 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
 
if (Is_even_K || col < params.d_value) {
if constexpr (std::is_same_v<ElementO, cutlass::bfloat16_t>) {
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o_tail(ei, mi, ni), 0, acc_o_tail(ei+1, mi, ni), 0);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o_tail(ei+2, mi, ni), 0, acc_o_tail(ei+3, mi, ni), 0);
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(acc_o_tail(ei, mi, ni), acc_o_tail(ei+1, mi, ni), false);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(acc_o_tail(ei+2, mi, ni), acc_o_tail(ei+3, mi, ni), false);
auto res0 = reinterpret_cast<result_type const &>(d0);
auto res1 = reinterpret_cast<result_type const &>(d1);
gO(row, col) = res0[0];
......@@ -8958,8 +9121,8 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
gO(row, col + 2) = res1[0];
gO(row, col + 3) = res1[1];
} else {
auto d0 = __builtin_hcu_cvt_pk_f16_f32(0, acc_o_tail(ei, mi, ni), 0, acc_o_tail(ei+1, mi, ni), 0,0);
auto d1 = __builtin_hcu_cvt_pk_f16_f32(0, acc_o_tail(ei+2, mi, ni), 0, acc_o_tail(ei+3, mi, ni), 0,0);
auto d0 = __builtin_hcu_cvt_pk_f16_f32(acc_o_tail(ei, mi, ni), acc_o_tail(ei+1, mi, ni), false, 0);
auto d1 = __builtin_hcu_cvt_pk_f16_f32(acc_o_tail(ei+2, mi, ni), acc_o_tail(ei+3, mi, ni), false, 0);
auto res0 = reinterpret_cast<result_type const &>(d0);
auto res1 = reinterpret_cast<result_type const &>(d1);
gO(row, col) = res0[0];
......@@ -9544,8 +9707,8 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
 
if (Is_even_K || col < params.d_value) {
if constexpr (std::is_same_v<ElementO, cutlass::bfloat16_t>) {
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o0(ei, mi, ni), 0, acc_o0(ei+1, mi, ni), 0);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o0(ei+2, mi, ni), 0, acc_o0(ei+3, mi, ni), 0);
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(acc_o0(ei, mi, ni), acc_o0(ei+1, mi, ni), false);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(acc_o0(ei+2, mi, ni), acc_o0(ei+3, mi, ni), false);
auto res0 = reinterpret_cast<result_type const &>(d0);
auto res1 = reinterpret_cast<result_type const &>(d1);
gO(row, col) = res0[0];
......@@ -9553,8 +9716,8 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
gO(row, col + 2) = res1[0];
gO(row, col + 3) = res1[1];
} else {
auto d0 = __builtin_hcu_cvt_pk_f16_f32(0, acc_o0(ei, mi, ni), 0, acc_o0(ei+1, mi, ni), 0,0);
auto d1 = __builtin_hcu_cvt_pk_f16_f32(0, acc_o0(ei+2, mi, ni), 0, acc_o0(ei+3, mi, ni), 0,0);
auto d0 = __builtin_hcu_cvt_pk_f16_f32(acc_o0(ei, mi, ni), acc_o0(ei+1, mi, ni), false, 0);
auto d1 = __builtin_hcu_cvt_pk_f16_f32(acc_o0(ei+2, mi, ni), acc_o0(ei+3, mi, ni), false, 0);
auto res0 = reinterpret_cast<result_type const &>(d0);
auto res1 = reinterpret_cast<result_type const &>(d1);
gO(row, col) = res0[0];
......@@ -9574,8 +9737,8 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
 
if (Is_even_K || col < params.d_value) {
if constexpr (std::is_same_v<ElementO, cutlass::bfloat16_t>) {
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o1(ei, mi, ni), 0, acc_o1(ei+1, mi, ni), 0);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o1(ei+2, mi, ni), 0, acc_o1(ei+3, mi, ni), 0);
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(acc_o1(ei, mi, ni), acc_o1(ei+1, mi, ni), false);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(acc_o1(ei+2, mi, ni), acc_o1(ei+3, mi, ni), false);
auto res0 = reinterpret_cast<result_type const &>(d0);
auto res1 = reinterpret_cast<result_type const &>(d1);
gO(row, col) = res0[0];
......@@ -9583,8 +9746,8 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
gO(row, col + 2) = res1[0];
gO(row, col + 3) = res1[1];
} else {
auto d0 = __builtin_hcu_cvt_pk_f16_f32(0, acc_o1(ei, mi, ni), 0, acc_o1(ei+1, mi, ni), 0,0);
auto d1 = __builtin_hcu_cvt_pk_f16_f32(0, acc_o1(ei+2, mi, ni), 0, acc_o1(ei+3, mi, ni), 0,0);
auto d0 = __builtin_hcu_cvt_pk_f16_f32(acc_o1(ei, mi, ni), acc_o1(ei+1, mi, ni), false, 0);
auto d1 = __builtin_hcu_cvt_pk_f16_f32(acc_o1(ei+2, mi, ni), acc_o1(ei+3, mi, ni), false, 0);
auto res0 = reinterpret_cast<result_type const &>(d0);
auto res1 = reinterpret_cast<result_type const &>(d1);
gO(row, col) = res0[0];
......@@ -9604,8 +9767,8 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
 
if (Is_even_K || col < params.d_value) {
if constexpr (std::is_same_v<ElementO, cutlass::bfloat16_t>) {
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o2(ei, mi, ni), 0, acc_o2(ei+1, mi, ni), 0);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o2(ei+2, mi, ni), 0, acc_o2(ei+3, mi, ni), 0);
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(acc_o2(ei, mi, ni), acc_o2(ei+1, mi, ni), false);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(acc_o2(ei+2, mi, ni), acc_o2(ei+3, mi, ni), false);
auto res0 = reinterpret_cast<result_type const &>(d0);
auto res1 = reinterpret_cast<result_type const &>(d1);
gO(row, col) = res0[0];
......@@ -9613,8 +9776,8 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
gO(row, col + 2) = res1[0];
gO(row, col + 3) = res1[1];
} else {
auto d0 = __builtin_hcu_cvt_pk_f16_f32(0, acc_o2(ei, mi, ni), 0, acc_o2(ei+1, mi, ni), 0,0);
auto d1 = __builtin_hcu_cvt_pk_f16_f32(0, acc_o2(ei+2, mi, ni), 0, acc_o2(ei+3, mi, ni), 0,0);
auto d0 = __builtin_hcu_cvt_pk_f16_f32(acc_o2(ei, mi, ni), acc_o2(ei+1, mi, ni), false, 0);
auto d1 = __builtin_hcu_cvt_pk_f16_f32(acc_o2(ei+2, mi, ni), acc_o2(ei+3, mi, ni), false, 0);
auto res0 = reinterpret_cast<result_type const &>(d0);
auto res1 = reinterpret_cast<result_type const &>(d1);
gO(row, col) = res0[0];
......@@ -9796,22 +9959,23 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
"+s"(cur_block_table_ptr),
"=s"(cur_block_table));
index_t offset_k = cur_block_table * params.k_batch_stride;
index_t offset_v = cur_block_table * params.v_batch_stride;
gK.data() = gK_data + (offset_k);
gV0.data() = gV0_data + (offset_v);
gV1.data() = gV1_data + (offset_v);
gV2.data() = gV2_data + (offset_v);
gV3.data() = gV3_data + (offset_v);
gV0.data() = gV0_data + (offset_k);
gV1.data() = gV1_data + (offset_k);
gV2.data() = gV2_data + (offset_k);
gV3.data() = gV3_data + (offset_k);
 
}
 
float q_descale = params.q_descale_ptr == nullptr ? 1.0f : params.q_descale_ptr[0];
float k_descale = params.k_descale_ptr == nullptr ? 1.0f : params.k_descale_ptr[0];
float v_descale = params.v_descale_ptr == nullptr ? 1.0f : params.v_descale_ptr[0];
// float q_descale = params.q_descale_ptr == nullptr ? 1.0f : params.q_descale_ptr[0];
// float k_descale = params.k_descale_ptr == nullptr ? 1.0f : params.k_descale_ptr[0];
// float v_descale = params.v_descale_ptr == nullptr ? 1.0f : params.v_descale_ptr[0];
 
 
const float scale_softmax_log2 = params.scale_softmax_log2*q_descale*k_descale;
const float scale_softmax = params.scale_softmax*q_descale*k_descale;
// const float scale_softmax_log2 = params.scale_softmax_log2;
// const float scale_softmax = params.scale_softmax;
constexpr float scale_softmax =0.0625;
constexpr float scale_softmax_log2=scale_softmax*1.4426950408889634;
Tensor tCrK_copy_view = smem_thr_copy_K.retile_D(tSrK);
Tensor tCrV_copy_view = smem_thr_copy_V.retile_D(tOrV);
......@@ -10078,12 +10242,11 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
"+s"(cur_block_table_ptr),
"=s"(cur_block_table));
index_t offset_k = cur_block_table * params.k_batch_stride;
index_t offset_v = cur_block_table * params.v_batch_stride;
gK.data() = gK_data + (offset_k);
gV0.data() = gV0_data + (offset_v);
gV1.data() = gV1_data + (offset_v);
gV2.data() = gV2_data + (offset_v);
gV3.data() = gV3_data + (offset_v);
gV0.data() = gV0_data + (offset_k);
gV1.data() = gV1_data + (offset_k);
gV2.data() = gV2_data + (offset_k);
gV3.data() = gV3_data + (offset_k);
 
} else {
gK.data() = gK.data() + (-int(kBlockN * params.k_row_stride));
......@@ -10206,12 +10369,11 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
"+s"(cur_block_table_ptr),
"=s"(cur_block_table));
index_t offset_k = cur_block_table * params.k_batch_stride;
index_t offset_v = cur_block_table * params.v_batch_stride;
gK.data() = gK_data + (offset_k);
gV0.data() = gV0_data + (offset_v);
gV1.data() = gV1_data + (offset_v);
gV2.data() = gV2_data + (offset_v);
gV3.data() = gV3_data + (offset_v);
gV0.data() = gV0_data + (offset_k);
gV1.data() = gV1_data + (offset_k);
gV2.data() = gV2_data + (offset_k);
gV3.data() = gV3_data + (offset_k);
} else {
gK.data() = gK.data() + (-int(kBlockN * params.k_row_stride));
gV0.data() = gV0.data() + (-int(kBlockN * params.v_row_stride));
......@@ -10226,7 +10388,7 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
// // Epilogue
#if 1
 
Tensor lse = softmax.template normalize_softmax_lse_fp8</*Is_dropout=*/false, Split>(acc_o0, acc_o1, acc_o2, acc_o3, scale_softmax, v_descale);
Tensor lse = softmax.template normalize_softmax_lse_fp8</*Is_dropout=*/false, Split>(acc_o0, acc_o1, acc_o2, acc_o3, scale_softmax, 1.0f);
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
const index_t row_offset_lseaccum = (Split || !params.unpadded_lse ?
......@@ -10264,8 +10426,8 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
 
if (Is_even_K || col < params.d_value) {
if constexpr (std::is_same_v<ElementO, cutlass::bfloat16_t>) {
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o0(ei, mi, ni), 0, acc_o0(ei+1, mi, ni), 0);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o0(ei+2, mi, ni), 0, acc_o0(ei+3, mi, ni), 0);
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(acc_o0(ei, mi, ni), acc_o0(ei+1, mi, ni), false);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(acc_o0(ei+2, mi, ni), acc_o0(ei+3, mi, ni), false);
auto res0 = reinterpret_cast<result_type const &>(d0);
auto res1 = reinterpret_cast<result_type const &>(d1);
gO(row, col) = res0[0];
......@@ -10273,8 +10435,8 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
gO(row, col + 2) = res1[0];
gO(row, col + 3) = res1[1];
} else {
auto d0 = __builtin_hcu_cvt_pk_f16_f32(0, acc_o0(ei, mi, ni), 0, acc_o0(ei+1, mi, ni), 0,0);
auto d1 = __builtin_hcu_cvt_pk_f16_f32(0, acc_o0(ei+2, mi, ni), 0, acc_o0(ei+3, mi, ni), 0,0);
auto d0 = __builtin_hcu_cvt_pk_f16_f32(acc_o0(ei, mi, ni), acc_o0(ei+1, mi, ni), false, 0);
auto d1 = __builtin_hcu_cvt_pk_f16_f32(acc_o0(ei+2, mi, ni), acc_o0(ei+3, mi, ni), false, 0);
auto res0 = reinterpret_cast<result_type const &>(d0);
auto res1 = reinterpret_cast<result_type const &>(d1);
gO(row, col) = res0[0];
......@@ -10294,8 +10456,8 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
 
if (Is_even_K || col < params.d_value) {
if constexpr (std::is_same_v<ElementO, cutlass::bfloat16_t>) {
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o1(ei, mi, ni), 0, acc_o1(ei+1, mi, ni), 0);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o1(ei+2, mi, ni), 0, acc_o1(ei+3, mi, ni), 0);
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(acc_o1(ei, mi, ni), acc_o1(ei+1, mi, ni), false);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(acc_o1(ei+2, mi, ni), acc_o1(ei+3, mi, ni), false);
auto res0 = reinterpret_cast<result_type const &>(d0);
auto res1 = reinterpret_cast<result_type const &>(d1);
gO(row, col) = res0[0];
......@@ -10303,8 +10465,8 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
gO(row, col + 2) = res1[0];
gO(row, col + 3) = res1[1];
} else {
auto d0 = __builtin_hcu_cvt_pk_f16_f32(0, acc_o1(ei, mi, ni), 0, acc_o1(ei+1, mi, ni), 0,0);
auto d1 = __builtin_hcu_cvt_pk_f16_f32(0, acc_o1(ei+2, mi, ni), 0, acc_o1(ei+3, mi, ni), 0,0);
auto d0 = __builtin_hcu_cvt_pk_f16_f32(acc_o1(ei, mi, ni), acc_o1(ei+1, mi, ni), false, 0);
auto d1 = __builtin_hcu_cvt_pk_f16_f32(acc_o1(ei+2, mi, ni), acc_o1(ei+3, mi, ni), false, 0);
auto res0 = reinterpret_cast<result_type const &>(d0);
auto res1 = reinterpret_cast<result_type const &>(d1);
gO(row, col) = res0[0];
......@@ -10324,8 +10486,8 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
 
if (Is_even_K || col < params.d_value) {
if constexpr (std::is_same_v<ElementO, cutlass::bfloat16_t>) {
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o2(ei, mi, ni), 0, acc_o2(ei+1, mi, ni), 0);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o2(ei+2, mi, ni), 0, acc_o2(ei+3, mi, ni), 0);
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(acc_o2(ei, mi, ni), acc_o2(ei+1, mi, ni), false);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(acc_o2(ei+2, mi, ni), acc_o2(ei+3, mi, ni), false);
auto res0 = reinterpret_cast<result_type const &>(d0);
auto res1 = reinterpret_cast<result_type const &>(d1);
gO(row, col) = res0[0];
......@@ -10333,8 +10495,8 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
gO(row, col + 2) = res1[0];
gO(row, col + 3) = res1[1];
} else {
auto d0 = __builtin_hcu_cvt_pk_f16_f32(0, acc_o2(ei, mi, ni), 0, acc_o2(ei+1, mi, ni), 0,0);
auto d1 = __builtin_hcu_cvt_pk_f16_f32(0, acc_o2(ei+2, mi, ni), 0, acc_o2(ei+3, mi, ni), 0,0);
auto d0 = __builtin_hcu_cvt_pk_f16_f32(acc_o2(ei, mi, ni), acc_o2(ei+1, mi, ni), false, 0);
auto d1 = __builtin_hcu_cvt_pk_f16_f32(acc_o2(ei+2, mi, ni), acc_o2(ei+3, mi, ni), false, 0);
auto res0 = reinterpret_cast<result_type const &>(d0);
auto res1 = reinterpret_cast<result_type const &>(d1);
gO(row, col) = res0[0];
......@@ -10354,8 +10516,8 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
 
if (Is_even_K || col < params.d_value) {
if constexpr (std::is_same_v<ElementO, cutlass::bfloat16_t>) {
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o3(ei, mi, ni), 0, acc_o3(ei+1, mi, ni), 0);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o3(ei+2, mi, ni), 0, acc_o3(ei+3, mi, ni), 0);
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(acc_o3(ei, mi, ni), acc_o3(ei+1, mi, ni), false);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(acc_o3(ei+2, mi, ni), acc_o3(ei+3, mi, ni), false);
auto res0 = reinterpret_cast<result_type const &>(d0);
auto res1 = reinterpret_cast<result_type const &>(d1);
gO(row, col) = res0[0];
......@@ -10363,8 +10525,8 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
gO(row, col + 2) = res1[0];
gO(row, col + 3) = res1[1];
} else {
auto d0 = __builtin_hcu_cvt_pk_f16_f32(0, acc_o3(ei, mi, ni), 0, acc_o3(ei+1, mi, ni), 0,0);
auto d1 = __builtin_hcu_cvt_pk_f16_f32(0, acc_o3(ei+2, mi, ni), 0, acc_o3(ei+3, mi, ni), 0,0);
auto d0 = __builtin_hcu_cvt_pk_f16_f32(acc_o3(ei, mi, ni), acc_o3(ei+1, mi, ni), false, 0);
auto d1 = __builtin_hcu_cvt_pk_f16_f32(acc_o3(ei+2, mi, ni), acc_o3(ei+3, mi, ni), false, 0);
auto res0 = reinterpret_cast<result_type const &>(d0);
auto res1 = reinterpret_cast<result_type const &>(d1);
gO(row, col) = res0[0];
......@@ -12492,12 +12654,12 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
"+s"(cur_block_table_ptr),
"=s"(cur_block_table));
index_t offset_k = cur_block_table * params.k_batch_stride;
index_t offset_v = cur_block_table * params.v_batch_stride;
// index_t offset_v = cur_block_table * params.v_batch_stride;
gK.data() = gK_data + (offset_k);
gV0.data() = gV0_data + (offset_v);
gV1.data() = gV1_data + (offset_v);
gV2.data() = gV2_data + (offset_v);
gV3.data() = gV3_data + (offset_v);
gV0.data() = gV0_data + (offset_k);
gV1.data() = gV1_data + (offset_k);
gV2.data() = gV2_data + (offset_k);
gV3.data() = gV3_data + (offset_k);
}
 
#pragma unroll
......@@ -12637,12 +12799,12 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
"+s"(cur_block_table_ptr),
"=s"(cur_block_table));
index_t offset_k = cur_block_table * params.k_batch_stride;
index_t offset_v = cur_block_table * params.v_batch_stride;
// index_t offset_v = cur_block_table * params.v_batch_stride;
gK.data() = gK_data + (offset_k);
gV0.data() = gV0_data + (offset_v);
gV1.data() = gV1_data + (offset_v);
gV2.data() = gV2_data + (offset_v);
gV3.data() = gV3_data + (offset_v);
gV0.data() = gV0_data + (offset_k);
gV1.data() = gV1_data + (offset_k);
gV2.data() = gV2_data + (offset_k);
gV3.data() = gV3_data + (offset_k);
} else {
gK.data() = gK.data() + (-int(kBlockN * params.k_row_stride));
gV0.data() = gV0.data() + (-int(kBlockN * params.v_row_stride));
......@@ -12785,12 +12947,12 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
"+s"(cur_block_table_ptr),
"=s"(cur_block_table));
index_t offset_k = cur_block_table * params.k_batch_stride;
index_t offset_v = cur_block_table * params.v_batch_stride;
// index_t offset_v = cur_block_table * params.v_batch_stride;
gK.data() = gK_data + (offset_k);
gV0.data() = gV0_data + (offset_v);
gV1.data() = gV1_data + (offset_v);
gV2.data() = gV2_data + (offset_v);
gV3.data() = gV3_data + (offset_v);
gV0.data() = gV0_data + (offset_k);
gV1.data() = gV1_data + (offset_k);
gV2.data() = gV2_data + (offset_k);
gV3.data() = gV3_data + (offset_k);
} else {
gK.data() = gK.data() + (-int(kBlockN * params.k_row_stride));
gV0.data() = gV0.data() + (-int(kBlockN * params.v_row_stride));
......@@ -14504,8 +14666,8 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
 
if (Is_even_K || col < params.d_value) {
if constexpr (std::is_same_v<ElementO, cutlass::bfloat16_t>) {
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o(ei, mi, ni), 0, acc_o(ei+1, mi, ni), 0);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o(ei+2, mi, ni), 0, acc_o(ei+3, mi, ni), 0);
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(acc_o(ei, mi, ni), acc_o(ei+1, mi, ni), false);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(acc_o(ei+2, mi, ni), acc_o(ei+3, mi, ni), false);
auto res0 = reinterpret_cast<result_type const &>(d0);
auto res1 = reinterpret_cast<result_type const &>(d1);
gO(row, col) = res0[0];
......@@ -14513,8 +14675,8 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
gO(row, col + 2) = res1[0];
gO(row, col + 3) = res1[1];
} else {
auto d0 = __builtin_hcu_cvt_pk_f16_f32(0, acc_o(ei, mi, ni), 0, acc_o(ei+1, mi, ni), 0,0);
auto d1 = __builtin_hcu_cvt_pk_f16_f32(0, acc_o(ei+2, mi, ni), 0, acc_o(ei+3, mi, ni), 0,0);
auto d0 = __builtin_hcu_cvt_pk_f16_f32(acc_o(ei, mi, ni), acc_o(ei+1, mi, ni), false, 0);
auto d1 = __builtin_hcu_cvt_pk_f16_f32(acc_o(ei+2, mi, ni), acc_o(ei+3, mi, ni), false, 0);
auto res0 = reinterpret_cast<result_type const &>(d0);
auto res1 = reinterpret_cast<result_type const &>(d1);
gO(row, col) = res0[0];
......@@ -15104,11 +15266,11 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_gfx928(
{
int cur_block_table;
const int *cur_block_table_ptr = block_table + (n_block);
// cur_block_table = block_table[n_block - 1];
asm volatile("s_load_dword %1, %0, 0x0\n\t"
"s_waitcnt lgkmcnt(0)\n\t":
"+s"(cur_block_table_ptr),
"=s"(cur_block_table));
cur_block_table = block_table[n_block - 1];
// asm volatile("s_load_dword %1, %0, 0x0\n\t"
// "s_waitcnt lgkmcnt(0)\n\t":
// "+s"(cur_block_table_ptr),
// "=s"(cur_block_table));
index_t offset_k = cur_block_table * params.k_batch_stride;
index_t offset_v = cur_block_table * params.v_batch_stride;
tKgK.data() = tKgK_data + (offset_k);
......@@ -15165,11 +15327,11 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_gfx928(
{
int cur_block_table;
const int *cur_block_table_ptr = block_table + (n_block - 1);
// cur_block_table = block_table[n_block - 1];
asm volatile("s_load_dword %1, %0, 0x0\n\t"
"s_waitcnt lgkmcnt(0)\n\t":
"+s"(cur_block_table_ptr),
"=s"(cur_block_table));
cur_block_table = block_table[n_block - 1];
// asm volatile("s_load_dword %1, %0, 0x0\n\t"
// "s_waitcnt lgkmcnt(0)\n\t":
// "+s"(cur_block_table_ptr),
// "=s"(cur_block_table));
index_t offset_k = cur_block_table * params.k_batch_stride;
index_t offset_v = cur_block_table * params.v_batch_stride;
tKgK.data() = tKgK_data + (offset_k);
......@@ -15233,11 +15395,11 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_gfx928(
{
int cur_block_table;
const int *cur_block_table_ptr = block_table + (n_block - 1);
// cur_block_table = block_table[n_block - 1];
asm volatile("s_load_dword %1, %0, 0x0\n\t"
"s_waitcnt lgkmcnt(0)\n\t":
"+s"(cur_block_table_ptr),
"=s"(cur_block_table));
cur_block_table = block_table[n_block - 1];
// asm volatile("s_load_dword %1, %0, 0x0\n\t"
// "s_waitcnt lgkmcnt(0)\n\t":
// "+s"(cur_block_table_ptr),
// "=s"(cur_block_table));
index_t offset_k = cur_block_table * params.k_batch_stride;
index_t offset_v = cur_block_table * params.v_batch_stride;
tKgK.data() = tKgK_data + (offset_k);
......@@ -16048,7 +16210,7 @@ inline __device__ void compute_attn_16x64_prefetch_fp8(const Params &params) {
index_t binfo_v_offset = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb);
index_t binfo_o_offset = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb);
__syncthreads();
#if (defined(__gfx938__))
#if (defined(__gfx938__)||defined(__gfx92a__))
if constexpr (Kernel_traits::kHeadDim == 192 && Kernel_traits::kHeadDimV == 128){
flash::compute_attn_1rowblock_16x64_mla_prefetch_fp8<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params, bidb, bidh, m_block, binfo_q_offset, binfo_k_offset, binfo_v_offset, binfo_o_offset, binfo);
} else if constexpr (Kernel_traits::kHeadDim == 128){
......@@ -16092,7 +16254,7 @@ inline __device__ void compute_attn_splitkv_16x64_vllm_kvcache_prefetch(const Pa
const int m_block = Is_causal?gridDim.z - 1 - blockIdx.z:blockIdx.x;
const int bidb = Is_causal?blockIdx.y:blockIdx.z;
const int bidh = Is_causal?blockIdx.x:blockIdx.y;
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
if constexpr (Kernel_traits::kHeadDim == 64){
flash::compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetch_dim64<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV, Has_block_table>(params, bidb, bidh, m_block);
}else if constexpr (Kernel_traits::kHeadDim == 128) {
......@@ -16128,7 +16290,7 @@ inline __device__ void compute_attn_splitkv_16x64_vllm_kvcache_prefetch_fp8(cons
const int m_block = Is_causal?gridDim.z - 1 - blockIdx.z:blockIdx.x;
const int bidb = Is_causal?blockIdx.y:blockIdx.z;
const int bidh = Is_causal?blockIdx.x:blockIdx.y;
#if (defined(__gfx938__))
#if (defined(__gfx938__)||defined(__gfx92a__))
if constexpr (Kernel_traits::kHeadDim == 64){
flash::compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetch_fp8_dim64<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV, Has_block_table>(params, bidb, bidh, m_block);
}else if constexpr (Kernel_traits::kHeadDim == 128) {
......@@ -16152,7 +16314,7 @@ inline __device__ void compute_attn_unified_16x64_prefetch(const Params &params)
const int num_n_splits = Split ? gridDim.y : 1;
 
 
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
if constexpr (Kernel_traits::kHeadDim == 256) {
flash::compute_attn_1rowblock_unified_16x64_prefetch_dim256<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV, Has_block_table, Use_alibi_sqrt, Use_qq_bias, Use_mm_prefix>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
if constexpr(Is_causal)
......
......@@ -292,6 +292,7 @@ template<typename Kernel_traits, bool Is_causal>
void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
static_assert(Kernel_traits::Is_Q_in_regs, "SplitKV implementation must support Is_Q_in_regs");
static_assert(Kernel_traits::Share_Q_K_smem, "SplitKV implementation must support Share_Q_K_smem");
params.num_splits = 1;
// params.num_splits大于1的时候,输出值是float类型,是大于Q的。这里改动的本质原因是q与kv共享lds导致的
const size_t smem_size = params.num_splits > 1 ? std::max(Kernel_traits::kSmemQSize * 2, Kernel_traits::kSmemSize) : Kernel_traits::kSmemSize;
// printf("smem_size = %d\n", smem_size);
......@@ -317,33 +318,6 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
});
});
});
// printf(" run_flash_splitkv_fwd params.num_splits = %d\n", params.num_splits);
if (params.num_splits > 1) {
// We want kBlockM to be as small as possible for more parallelism.
// With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4.
// If headdim is divisible by 64, then we set kBlockM = 8, etc.
constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 32 : (Kernel_traits::kHeadDim % 64 == 0 ? 32 : 32);
dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM);
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
if (params.num_splits <= 2) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
}
else if (params.num_splits <= 4) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 2, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 8) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 3, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 16) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 4, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 32) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 5, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 64) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 6, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 128) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 7, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}
}
template<typename Kernel_traits, typename Combine_Kernel_traits, bool Is_causal>
......@@ -364,6 +338,7 @@ void run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch(Flash_fwd_params &params,
BOOL_SWITCH(params.num_splits > 1, Split, [&] {
ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
constexpr static bool IsEvenMNConst = false;
constexpr static bool IsEvenKConst = true;
// constexpr static bool Is_local = false;
constexpr static bool Is_softcap = false;
......@@ -433,6 +408,7 @@ void run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch_fp8(Flash_fwd_params &par
BOOL_SWITCH(params.num_splits > 1, Split, [&] {
ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
constexpr static bool IsEvenMNConst = false;
constexpr static bool IsEvenKConst = true;
// constexpr static bool Is_local = false;
constexpr static bool Is_softcap = false;
......@@ -580,7 +556,7 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream)
bool is_small = (params.seqlen_q <= 64||params.h * params.b * mblocks< 4*sm_count);
if (params.is_vllm_kvcache) {
if constexpr(Headdim == 64) {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<64, kBlockM, kBlockN, 4, false, false, T, 64>;
if (is_small) {
using prefetch_kernel_traits = Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits_dim64<64, 64, 64, 4, T, 1, 64>;
......@@ -601,7 +577,7 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream)
}
}
}else if constexpr(Headdim == 128) {
if (get_device_name() == "gfx936"||get_device_name() == "gfx938") {
if (get_device_name() == "gfx936"||get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
assert(params.knew_ptr == nullptr && params.block_table != nullptr);
using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<128, kBlockM, kBlockN, 4, false, false, T, 128>;
if (is_small) {
......@@ -623,13 +599,13 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream)
}
}
}else if constexpr(Headdim == 192) {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<192, kBlockM, kBlockN, 4, false, false, T, 192>;
using prefetch_kernel_traits = Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits_dim192<192, 64, 64, 4, T, 3, 192>;
run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch<prefetch_kernel_traits, combine_kernel_traits, Is_causal>(params, stream);
}
}else if constexpr(Headdim == 256) {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<256, kBlockM, kBlockN, 4, false, false, T, 256>;
using prefetch_kernel_traits = Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits_dim256<256, 64, 64, 4, T, 3, 256>;
run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch<prefetch_kernel_traits, combine_kernel_traits, Is_causal>(params, stream);
......@@ -704,7 +680,7 @@ void run_mha_fwd_splitkv_dispatch_fp8(Flash_fwd_params &params, cudaStream_t str
// printf("kBlockM = %d, kBlockN = %d", kBlockM, kBlockN);
#ifndef FLASHATTENTION_DISABLE_SPLITKV
if constexpr(Headdim == 64) {
if (get_device_name() == "gfx938") {
if (get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
if (params.is_vllm_kvcache) {
using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<64, kBlockM, kBlockN, 4, false, false, TO, 64>;
if (params.seqlen_q < 64) {
......@@ -717,7 +693,7 @@ void run_mha_fwd_splitkv_dispatch_fp8(Flash_fwd_params &params, cudaStream_t str
}
}
}else if constexpr(Headdim == 128) {
if (get_device_name() == "gfx938") {
if (get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
if (params.is_vllm_kvcache) {
using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<128, kBlockM, kBlockN, 4, false, false, TO, 128>;
if (params.seqlen_q < 64) {
......@@ -730,7 +706,7 @@ void run_mha_fwd_splitkv_dispatch_fp8(Flash_fwd_params &params, cudaStream_t str
}
}
}else if constexpr(Headdim == 192) {
if (get_device_name() == "gfx938") {
if (get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
if (params.is_vllm_kvcache) {
using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<192, kBlockM, kBlockN, 4, false, false, TO, 192>;
if (params.seqlen_q < 64) {
......@@ -743,7 +719,7 @@ void run_mha_fwd_splitkv_dispatch_fp8(Flash_fwd_params &params, cudaStream_t str
}
}
}else if constexpr(Headdim == 256) {
if (get_device_name() == "gfx938") {
if (get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
if (params.is_vllm_kvcache) {
using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<256, kBlockM, kBlockN, 4, false, false, TO, 256>;
if (params.seqlen_q < 64) {
......@@ -764,7 +740,7 @@ void run_mha_fwd_unified_dispatch(Flash_fwd_params &params, cudaStream_t stream)
constexpr static int kBlockM = 64;
constexpr static int kBlockN = Headdim <= 128 ? 64 : (Headdim % 64 == 0 ? 32 : 64);
if constexpr(Headdim == 256) {
if (get_device_name() == "gfx938" || get_device_name() == "gfx936") {
if (get_device_name() == "gfx938" || get_device_name() == "gfx936"|| get_device_name() == "gfx92a") {
using prefetch_kernel_traits = Flash_fwd_kernel_16x64_splitkv_prefetch_unified_traits_dim256<256, 64, 64, 4, T, 3, 256>;
using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<256, kBlockM, kBlockN, 4, false, false, T, 256>;
......@@ -799,7 +775,7 @@ void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
#if 0
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 256, 64, 4, false, /*Share_Q_K_smem_=*/true, T>, Is_dropout, Is_causal>(params, stream);
#else
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
int mblocks = (params.seqlen_q + 64 - 1) / 64;
if (params.seqlen_q <= 64||params.h * params.b * mblocks< 4*sm_count) {
run_flash_fwd_16x64_prefetch<Flash_fwd_kernel_16x64_prefetch_traits_dim64<Headdim, 64, 64, 4, T>, Is_dropout, Is_causal>(params, stream);
......@@ -817,7 +793,7 @@ template<typename T, bool Is_causal>
void run_mha_fwd_padding_mask_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 64;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
run_flash_fwd_16x64_prefetch_padding_mask<Flash_fwd_kernel_16x64_prefetch_traits_dim64<Headdim, 128, 64, 4, T>, Is_dropout, Is_causal>(params, stream);
}
else {
......@@ -830,7 +806,7 @@ template<typename T, bool Is_causal>
void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 96;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
int mblocks = (params.seqlen_q + 64 - 1) / 64;
if(params.seqlen_q <= 64||params.h * params.b * mblocks< 4*sm_count){
run_flash_fwd_16x64_prefetch<Flash_fwd_kernel_16x64_prefetch_traits_dim96<Headdim, 64, 64, 4, T, 3>, Is_dropout, Is_causal>(params, stream);
......@@ -849,7 +825,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 128;
// printf("run_mha_fwd_hdim128\n");
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
int mblocks = (params.seqlen_q + 64 - 1) / 64;
if (params.seqlen_q <= 64||params.h * params.b * mblocks< 4*sm_count) {
run_flash_fwd_16x64_prefetch<Flash_fwd_kernel_16x64_prefetch_traits<Headdim, 64, 64, 4, T, 3, Is_skip_softmax>, Is_dropout, Is_causal>(params, stream);
......@@ -869,7 +845,7 @@ void run_mha_fwd_hdim128_fp8(Flash_fwd_params &params, cudaStream_t stream) {
using T_out = std::conditional_t<!Is_FP8, T, cutlass::bfloat16_t>;
// printf("run_mha_fwd_hdim128\n");
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (get_device_name() == "gfx938") {
if (get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
int mblocks = (params.seqlen_q + 64 - 1) / 64;
if (params.seqlen_q <= 64||params.h * params.b * mblocks< 4*sm_count) {
run_flash_fwd_16x64_prefetch_fp8<Flash_fwd_kernel_16x64_prefetch_traits_fp8<Headdim, 64, 64, 4, T,T_out, 3>, Is_dropout, Is_causal>(params, stream);
......@@ -904,7 +880,7 @@ void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T, bool Is_causal>
void run_mha_fwd_hdim192_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
run_flash_fwd_16x64_prefetch<Flash_fwd_kernel_16x64_prefetch_mla_traits</*Headdim*/192, 128, 64, 4, T, 3, /*HeaddimV*/128>, Is_dropout, Is_causal>(params, stream);
}
else {
......@@ -919,7 +895,7 @@ void run_mha_fwd_hdim192_hdim128_fp8(Flash_fwd_params &params, cudaStream_t stre
static constexpr bool Is_FP8 = cute::is_same_v<T, cutlass::float_e4m3_t> || cute::is_same_v<T, cutlass::float_e5m2_t>;
using T_out = std::conditional_t<!Is_FP8, T, cutlass::bfloat16_t>;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (get_device_name() == "gfx938") {
if (get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
run_flash_fwd_16x64_prefetch_fp8<Flash_fwd_kernel_16x64_prefetch_mla_traits_fp8</*Headdim*/192, 128, 64, 4, T,T_out, 3, /*HeaddimV*/128>, Is_dropout, Is_causal>(params, stream);
}
else {
......@@ -941,8 +917,8 @@ void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 256;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// constexpr static int Is_dropout = false;
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") {
run_flash_fwd_16x64_prefetch<Flash_fwd_kernel_16x64_prefetch_traits_dim256<Headdim, 64, 64, 4, T, 3>, Is_dropout, Is_causal>(params, stream);
if (get_device_name() == "gfx936" || get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
run_flash_fwd_16x64_prefetch<Flash_fwd_kernel_16x64_prefetch_traits_dim256<Headdim, 128, 64, 4, T, 3>, Is_dropout, Is_causal>(params, stream);
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
}
......@@ -954,7 +930,7 @@ void run_mha_fwd_hdim512(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 512;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// constexpr static int Is_dropout = false;
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
run_flash_fwd_16x64_prefetch<Flash_fwd_kernel_16x64_prefetch_traits_dim512<Headdim, 64, 64, 4, T, 3>, Is_dropout, Is_causal>(params, stream);
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
......
......@@ -1854,8 +1854,8 @@ inline __device__ void sparse_attn_1rowblock_sla_fp8(const Params &params, const
for (int ni = 0; ni < size<2>(acc_o); ++ni) {
col = (laneId / 16)*4 + ni * 32;
{
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o(0, mi, ni), 0, acc_o(1, mi, ni), 0);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o(2, mi, ni), 0, acc_o(3, mi, ni), 0);
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(acc_o(0, mi, ni), acc_o(1, mi, ni), false);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(acc_o(2, mi, ni), acc_o(3, mi, ni), false);
auto res0 = reinterpret_cast<result_type const &>(d0);
auto res1 = reinterpret_cast<result_type const &>(d1);
......@@ -1868,8 +1868,8 @@ inline __device__ void sparse_attn_1rowblock_sla_fp8(const Params &params, const
}
{
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o(4, mi, ni), 0, acc_o(5, mi, ni), 0);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o(6, mi, ni), 0, acc_o(7, mi, ni), 0);
auto d0 = __builtin_hcu_cvt_pk_bf16_f32(acc_o(4, mi, ni), acc_o(5, mi, ni), false);
auto d1 = __builtin_hcu_cvt_pk_bf16_f32(acc_o(6, mi, ni), acc_o(7, mi, ni), false);
auto res0 = reinterpret_cast<result_type const &>(d0);
auto res1 = reinterpret_cast<result_type const &>(d1);
gO(row, col) = res0[0];
......@@ -1934,7 +1934,7 @@ inline __device__ void compute_sparse_attn_sla_fp8(const Params &params) {
const int bidb = blockIdx.z;
// The block index for the head.
const int bidh = blockIdx.y;
#if defined(__gfx938__)
#if defined(__gfx938__) ||defined(__gfx92a__)
flash::sparse_attn_1rowblock_sla_fp8<Kernel_traits, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
#endif
}
......
......@@ -128,7 +128,7 @@ void run_mha_fwd_sparse_hdim64(Flash_fwd_params_sparse &params, cudaStream_t str
template<typename T>
void run_mha_fwd_sparse_sla_hdim64(Flash_fwd_params_sparse &params, cudaStream_t stream) {
constexpr static int Headdim = 64;
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
if (params.seqlen_q <= 2048)
run_flash_sparse_sla_fwd<Flash_fwd_kernel_16x64_prefetch_traits_dim64<Headdim, 64, 64, 4, T>>(params, stream);
else
......@@ -155,7 +155,7 @@ void run_mha_fwd_sparse_hdim128(Flash_fwd_params_sparse &params, cudaStream_t st
template<typename T>
void run_mha_fwd_sparse_sla_hdim128(Flash_fwd_params_sparse &params, cudaStream_t stream) {
constexpr static int Headdim = 128;
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
if (params.seqlen_q <= 2048)
run_flash_sparse_sla_fwd<Flash_fwd_kernel_16x64_prefetch_traits<Headdim, 64, 64, 4, T, 3>>(params, stream);
else
......@@ -168,7 +168,7 @@ void run_mha_fwd_sparse_sla_hdim128_fp8(Flash_fwd_params_sparse &params, cudaStr
constexpr static int Headdim = 128;
static constexpr bool Is_FP8 = cute::is_same_v<T, cutlass::float_e4m3_t> || cute::is_same_v<T, cutlass::float_e5m2_t>;
using T_out = std::conditional_t<!Is_FP8, T, cutlass::bfloat16_t>;
if (get_device_name() == "gfx938") {
if (get_device_name() == "gfx938"|| get_device_name() == "gfx92a") {
// int num_blocks_64 = params.h * params.b * ((params.seqlen_q + 64 - 1) / 64);//3
// int num_blocks_128 = params.h * params.b * ((params.seqlen_q + 128 - 1) / 128);//2
// if ((num_blocks_64 <= sm_count || (num_blocks_128 / sm_count == 1 && num_blocks_128 % sm_count > 1 && (num_blocks_64 + sm_count - 1) / sm_count <= 3) || force_blockm64) && !force_blockm128) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment