Commit 34e67b1e authored by zhangshao's avatar zhangshao
Browse files

first commit

parents
Pipeline #3582 failed with stages
in 0 seconds
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template void run_mha_fwd_splitkv_dispatch_kv_fp8<cutlass::bfloat16_t, cutlass::float_e5m2_t,64, true>(Flash_fwd_params &params, cudaStream_t stream);
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template void run_mha_fwd_splitkv_dispatch_kv_fp8<cutlass::bfloat16_t, cutlass::float_e5m2_t,64, false>(Flash_fwd_params &params, cudaStream_t stream);
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template void run_mha_fwd_splitkv_dispatch_kv_fp8<cutlass::half_t, cutlass::float_e5m2_t,64, true>(Flash_fwd_params &params, cudaStream_t stream);
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template void run_mha_fwd_splitkv_dispatch_kv_fp8<cutlass::half_t, cutlass::float_e5m2_t,64, false>(Flash_fwd_params &params, cudaStream_t stream);
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream);
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream);
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream);
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream);
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template void run_mha_fwd_unified_dispatch<cutlass::bfloat16_t, 256, true>(Flash_fwd_params &params, cudaStream_t stream);
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template void run_mha_fwd_unified_dispatch<cutlass::bfloat16_t, 256, false>(Flash_fwd_params &params, cudaStream_t stream);
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template void run_mha_fwd_unified_dispatch<cutlass::half_t, 256, true>(Flash_fwd_params &params, cudaStream_t stream);
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template void run_mha_fwd_unified_dispatch<cutlass::half_t, 256, false>(Flash_fwd_params &params, cudaStream_t stream);
#pragma once
#include "flash.h"
// namespace FLASH_NAMESPACE {
// Parameter bundle for the sparse forward kernels: everything inherited from
// Flash_fwd_params plus the sparsity metadata and dynamic-skip controls below.
// No member declared here has a default initializer, so the "default" values
// quoted in the comments are not applied by this struct; the caller has to set them.
struct Flash_fwd_params_sparse : public Flash_fwd_params {
// For sparse attention
// NOTE(review): the four pointers below look like per-row block/column index
// tables for a block-sparse + column-sparse pattern -- confirm their layout and
// memory space against the kernels that read them.
const int* block_count;
const int* block_offset;
const int* column_count;
const int* column_index;
// NOTE(review): presumably the number of row blocks and the per-row capacity of
// the block_offset (NNZ_S) and column_index (NNZ_V) tables -- confirm with the caller.
int NUM_ROWS;
int NNZ_S;
int NNZ_V;
// Dynamic PV skip optimization parameters
float pv_threshold; // Threshold for skipping P@V computation (default: 50.0, matching SpargeAttn)
bool enable_dynamic_skip; // Whether to enable dynamic skip (default: true)
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Sparse forward launcher, specialized per element type T, head dimension and
// causal flag. Only declared here; the definitions live in other translation units.
template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_sparse_(Flash_fwd_params_sparse &params, cudaStream_t stream);
// "sla" variant of the sparse forward launcher (no causal template flag).
// NOTE(review): what "sla" stands for is not visible from this header -- confirm.
template<typename T, int Headdim> void run_mha_fwd_sparse_sla_(Flash_fwd_params_sparse &params, cudaStream_t stream);
// FP8 flavour of the "sla" sparse forward launcher.
// NOTE(review): which operands are FP8 is not visible from this declaration -- confirm.
template<typename T, int Headdim> void run_mha_fwd_sparse_sla_fp8_(Flash_fwd_params_sparse &params, cudaStream_t stream);
// } // namespace FLASH_NAMESPACE
\ No newline at end of file
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#include "flash.h"
#include <cmath>
#include <limits>
#include <type_traits>
namespace {
// Number of lanes that cooperate on one (batch, head) pair; also the number of
// head-dimension elements handled, one per lane.
constexpr int kWaveSize = 64;
// Number of (batch, head) pairs processed by one thread block (one wave each).
constexpr int kPairsPerBlock = 4;
// Threads per block = lanes per wave * waves per block.
constexpr int kThreadsPerBlock = kWaveSize * kPairsPerBlock;
// Converts a float to the 16-bit storage type used for kernel I/O.
//   storage_t == _Float16 : plain conversion to half precision.
//   otherwise (uint16_t)  : the value is stored as raw bfloat16 bits, i.e. the
//                           upper 16 bits of the IEEE-754 binary32 encoding.
// The bfloat16 path rounds to nearest, ties to even, and keeps NaN a NaN.
// The previous "bits += 0x8000" rounding carried NaNs with a large payload
// (e.g. 0x7FFF8000) into the sign bit, turning them into -0.0, and rounded
// ties upward instead of to even.
template<typename storage_t>
static __device__ inline void from_float(storage_t &out, float value) {
    if constexpr (std::is_same_v<storage_t, _Float16>) {
        out = static_cast<_Float16>(value);
    } else {
        union {
            uint32_t int32;
            float fp32;
        } u = {0};
        u.fp32 = value;
        const uint32_t bits = u.int32;
        if ((bits & 0x7fffffffu) > 0x7f800000u) {
            // NaN: truncate, but force a mantissa bit so the result cannot
            // collapse to +/-Inf or wrap around to zero.
            out = static_cast<uint16_t>((bits >> 16) | 0x0040u);
        } else {
            // Round to nearest even: add 0x7FFF, plus one more when the bit
            // that becomes the new LSB is already set.
            const uint32_t rounding_bias = 0x7fffu + ((bits >> 16) & 1u);
            out = static_cast<uint16_t>((bits + rounding_bias) >> 16);
        }
    }
}
// Widens a 16-bit storage value to float.
//   storage_t == _Float16 : ordinary half -> float conversion.
//   otherwise             : the value is treated as raw bfloat16 bits and placed
//                           in the upper half of a 32-bit float encoding.
template<typename storage_t>
static __device__ inline float to_float(storage_t value) {
    if constexpr (!std::is_same_v<storage_t, _Float16>) {
        union {
            uint32_t int32;
            float fp32;
        } widened = {static_cast<uint32_t>(value) << 16};
        return widened.fp32;
    } else {
        return static_cast<float>(value);
    }
}
// Sums `value` across lanes with an XOR-shuffle exchange, halving the partner
// distance from kWaveSize / 2 down to 1. Every participating lane returns the
// same total.
// NOTE(review): relies on the mask-less __shfl_xor spanning kWaveSize lanes on
// the target -- confirm for the compiler/architecture in use.
static __device__ inline float wave_allreduce_sum(float value) {
    int partner_distance = kWaveSize / 2;
    while (partner_distance >= 1) {
        value += __shfl_xor(value, partner_distance);
        partner_distance /= 2;
    }
    return value;
}
// Forward attention for very short variable-length sequences with head dim 64.
//
// Work mapping (from the index math below): a block holds kPairsPerBlock waves of
// kWaveSize threads; each wave handles one (batch, query head) pair, and each lane
// owns one element along the head dimension (all q/k/v/o offsets end in "+ lane").
//
// Only the first 4 query rows and the first 4 key/value rows of a sequence are
// read or written -- every row/column loop stops at 4.
// NOTE(review): callers presumably guarantee q_len <= 4 and k_len <= 4 -- confirm.
//
// The causal mask (aligned so the last query row sees every key) is applied
// unconditionally; no causal flag in params is consulted.
//
// storage_t selects the 16-bit element encoding understood by to_float/from_float.
// Outputs: params.o_ptr (attention output) and params.softmax_lse_ptr (float
// log-sum-exp per row, indexed as [head][total_q]).
template<typename storage_t>
__global__ __launch_bounds__(kThreadsPerBlock)
void flash_varlen_fwd_tiny_hdim64_kernel(const Flash_fwd_params params) {
// Wave index inside the block and lane index inside the wave.
const int wave_idx = threadIdx.x / kWaveSize;
const int lane = threadIdx.x % kWaveSize;
// Flat (batch, head) pair handled by this wave.
const int pair_idx = blockIdx.x * kPairsPerBlock + wave_idx;
const int total_pairs = params.b * params.h;
// Tail guard: the last block may contain waves without work. pair_idx depends
// only on the wave, so a whole wave leaves together before any shuffle.
if (pair_idx >= total_pairs) {
return;
}
const int batch_idx = pair_idx / params.h;
const int q_head_idx = pair_idx % params.h;
// Several query heads share one key/value head (grouped-query style mapping).
const int kv_head_idx = q_head_idx / params.h_h_k_ratio;
// Query rows of this batch entry live in [q_start, q_end) of the packed layout.
const int q_start = params.cu_seqlens_q[batch_idx];
const int q_end = params.cu_seqlens_q[batch_idx + 1];
const int q_len = q_end - q_start;
if (q_len <= 0) {
return;
}
const int k_start = params.cu_seqlens_k[batch_idx];
// cu_seqlens_k holds either cumulative offsets or per-batch lengths.
// NOTE(review): in the non-cumulative case k_start above is a length, not a row
// offset -- confirm this kernel is only launched with cumulative seqlens.
int k_len = params.is_seqlens_k_cumulative
? params.cu_seqlens_k[batch_idx + 1] - k_start
: params.cu_seqlens_k[batch_idx];
// An explicit "used length" array overrides the derived key length.
if (params.seqused_k != nullptr) {
k_len = params.seqused_k[batch_idx];
}
// Clamp negative lengths to "no keys".
k_len = k_len < 0 ? 0 : k_len;
auto *out_ptr = reinterpret_cast<storage_t *>(params.o_ptr);
auto *lse_ptr = reinterpret_cast<float *>(params.softmax_lse_ptr);
const auto *q_ptr = reinterpret_cast<const storage_t *>(params.q_ptr);
const auto *k_ptr = reinterpret_cast<const storage_t *>(params.k_ptr);
const auto *v_ptr = reinterpret_cast<const storage_t *>(params.v_ptr);
// No keys at all: write a zero output row and an LSE of +infinity for each of
// the (up to 4) query rows, then finish.
if (k_len == 0) {
storage_t zero_value;
from_float(zero_value, 0.f);
const int64_t lse_base = static_cast<int64_t>(q_head_idx) * params.total_q;
for (int row = 0; row < 4; ++row) {
if (row >= q_len) {
break;
}
const int64_t out_offset =
static_cast<int64_t>(q_start + row) * params.o_row_stride
+ static_cast<int64_t>(q_head_idx) * params.o_head_stride
+ lane;
out_ptr[out_offset] = zero_value;
// One lane per wave writes the per-row scalar.
if (lane == 0) {
lse_ptr[lse_base + q_start + row] = std::numeric_limits<float>::infinity();
}
}
return;
}
// Per-lane slices of up to 4 Q, K and V rows; rows that do not exist stay 0.
float q_reg[4] = {0.f, 0.f, 0.f, 0.f};
float k_reg[4] = {0.f, 0.f, 0.f, 0.f};
float v_reg[4] = {0.f, 0.f, 0.f, 0.f};
#pragma unroll
for (int row = 0; row < 4; ++row) {
if (row < q_len) {
const int64_t q_offset =
static_cast<int64_t>(q_start + row) * params.q_row_stride
+ static_cast<int64_t>(q_head_idx) * params.q_head_stride
+ lane;
q_reg[row] = to_float(q_ptr[q_offset]);
}
if (row < k_len) {
const int64_t k_offset =
static_cast<int64_t>(k_start + row) * params.k_row_stride
+ static_cast<int64_t>(kv_head_idx) * params.k_head_stride
+ lane;
const int64_t v_offset =
static_cast<int64_t>(k_start + row) * params.v_row_stride
+ static_cast<int64_t>(kv_head_idx) * params.v_head_stride
+ lane;
k_reg[row] = to_float(k_ptr[k_offset]);
v_reg[row] = to_float(v_ptr[v_offset]);
}
}
// probs first holds the scaled scores, later the normalized probabilities.
float probs[4][4];
float row_lse[4];
// Scaled dot products: each lane contributes one head-dim term and the wave
// reduction leaves the full sum in every lane. All 16 reductions run
// unconditionally so every lane of the wave takes part in each shuffle.
#pragma unroll
for (int row = 0; row < 4; ++row) {
#pragma unroll
for (int col = 0; col < 4; ++col) {
float score = q_reg[row] * k_reg[col] * params.scale_softmax;
score = wave_allreduce_sum(score);
probs[row][col] = score;
}
}
// Masked softmax per query row.
#pragma unroll
for (int row = 0; row < 4; ++row) {
// Last key column visible to this row under the causal mask.
const int causal_limit = row + k_len - q_len;
float row_max = -std::numeric_limits<float>::infinity();
bool has_valid_key = false;
#pragma unroll
for (int col = 0; col < 4; ++col) {
// A score counts only if the row and column exist and the causal mask allows it.
const bool valid = row < q_len && col < k_len && col <= causal_limit;
if (!valid) {
probs[row][col] = -std::numeric_limits<float>::infinity();
continue;
}
has_valid_key = true;
row_max = fmaxf(row_max, probs[row][col]);
}
// Fully masked (or non-existent) row: zero probabilities, LSE = +infinity.
if (!has_valid_key) {
row_lse[row] = std::numeric_limits<float>::infinity();
#pragma unroll
for (int col = 0; col < 4; ++col) {
probs[row][col] = 0.f;
}
continue;
}
float row_sum = 0.f;
#pragma unroll
for (int col = 0; col < 4; ++col) {
const float score = probs[row][col];
// Masked entries were marked with -infinity above; they contribute nothing.
if (score == -std::numeric_limits<float>::infinity()) {
probs[row][col] = 0.f;
continue;
}
// Subtract the row maximum before exponentiating for numerical stability.
const float prob = expf(score - row_max);
probs[row][col] = prob;
row_sum += prob;
}
// row_sum >= 1 here because the maximum entry contributes expf(0).
const float inv_row_sum = 1.f / row_sum;
row_lse[row] = row_max + logf(row_sum);
#pragma unroll
for (int col = 0; col < 4; ++col) {
probs[row][col] *= inv_row_sum;
}
}
// Output = probabilities times V; each lane accumulates its own head-dim element.
float out_accum[4] = {0.f, 0.f, 0.f, 0.f};
#pragma unroll
for (int row = 0; row < 4; ++row) {
#pragma unroll
for (int col = 0; col < 4; ++col) {
out_accum[row] += probs[row][col] * v_reg[col];
}
}
// Write back the existing query rows: output element per lane, LSE from lane 0.
const int64_t lse_base = static_cast<int64_t>(q_head_idx) * params.total_q;
#pragma unroll
for (int row = 0; row < 4; ++row) {
if (row >= q_len) {
break;
}
storage_t out_value;
from_float(out_value, out_accum[row]);
const int64_t out_offset =
static_cast<int64_t>(q_start + row) * params.o_row_stride
+ static_cast<int64_t>(q_head_idx) * params.o_head_stride
+ lane;
out_ptr[out_offset] = out_value;
if (lane == 0) {
lse_ptr[lse_base + q_start + row] = row_lse[row];
}
}
}
// Launches flash_varlen_fwd_tiny_hdim64_kernel on `stream` with one wave per
// (batch, head) pair, kPairsPerBlock waves per block. Does nothing when
// params.b * params.h is zero. The launch is asynchronous and is not checked
// for errors here.
template<typename storage_t>
void run_mha_varlen_tiny_fwd_dim64_(Flash_fwd_params &params, cudaStream_t stream) {
    const int num_pairs = params.b * params.h;
    if (num_pairs != 0) {
        // Ceil-divide so a partially filled last block still covers the tail pairs.
        const dim3 grid_dim((num_pairs + kPairsPerBlock - 1) / kPairsPerBlock);
        const dim3 block_dim(kThreadsPerBlock);
        flash_varlen_fwd_tiny_hdim64_kernel<storage_t><<<grid_dim, block_dim, 0, stream>>>(params);
    }
}
} // namespace
// Public entry point for the tiny varlen forward path: selects the 16-bit
// storage encoding from params.is_bf16 (raw uint16_t bits for bf16, _Float16
// otherwise) and forwards to the templated launcher.
void run_mha_varlen_tiny_fwd_dim64(Flash_fwd_params &params, cudaStream_t stream) {
    if (!params.is_bf16) {
        run_mha_varlen_tiny_fwd_dim64_<_Float16>(params, stream);
        return;
    }
    run_mha_varlen_tiny_fwd_dim64_<uint16_t>(params, stream);
}
# Copied from Driss Guessous's PR in PyTorch: https://github.com/pytorch/pytorch/pull/105602
# This file is run to generate the kernel instantiations for the flash_attn kernels
# They are written to several files in order to speed up compilation
import argparse
import itertools
from dataclasses import dataclass
from pathlib import Path
from typing import Iterator, List, Optional
# Short dtype tag (used in generated file names) -> C++ element type substituted
# into the kernel templates below.
DTYPE_MAP = {
"fp16": "cutlass::half_t",
"bf16": "cutlass::bfloat16_t",
}
# Architecture tags appended to generated file names.
SM = [80] # only sm80 is listed
# Head dimensions for which one file per (dtype, direction[, causal]) is generated.
HEAD_DIMENSIONS = [32, 64, 96, 128, 160, 192, 224, 256]
# Causal flag values, kept as the C++ literals that are pasted into the templates.
IS_CAUSAL = ["false", "true"]
# Source text of a forward-kernel file: a full specialization of run_mha_fwd_ that
# forwards to the per-head-dim launcher. Doubled braces are str.format escapes.
KERNEL_IMPL_TEMPLATE_FWD = """#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params &params, cudaStream_t stream) {{
run_mha_fwd_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}>(params, stream);
}}
"""
# Source text of a split-KV forward file: an explicit instantiation of the
# split-KV dispatch function.
KERNEL_IMPL_TEMPLATE_FWD_SPLIT = """#include "flash_fwd_launch_template.h"
template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params &params, cudaStream_t stream);
"""
# Source text of a backward-kernel file: a full specialization of run_mha_bwd_
# (no causal template parameter).
KERNEL_IMPL_TEMPLATE_BWD = """#include "flash_bwd_launch_template.h"
template<>
void run_mha_bwd_<{DTYPE}, {HEAD_DIM}>(Flash_bwd_params &params, cudaStream_t stream) {{
run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream);
}}
"""
@dataclass
class Kernel:
    """Description of one generated ``.cu`` file holding a single kernel instantiation.

    Attributes:
        sm: Architecture number used in the file name suffix (``sm{sm}``).
        dtype: Short dtype tag; must be a key of ``DTYPE_MAP``.
        head_dim: Head dimension substituted into the template and file name.
        is_causal: The C++ literal ``"true"`` or ``"false"``. It is pasted
            verbatim into the template and compared against ``"true"`` when
            building the file name, so it is a string, not a Python bool.
        direction: ``"fwd"``, ``"bwd"``; any other value (``"fwd_split"`` in
            practice) selects the split-KV forward template.
    """

    sm: int
    dtype: str
    head_dim: int
    is_causal: str
    direction: str

    @property
    def template(self) -> str:
        """Return the C++ source text for this kernel (without the file prelude)."""
        if self.direction == "fwd":
            return KERNEL_IMPL_TEMPLATE_FWD.format(
                DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, IS_CAUSAL=self.is_causal
            )
        elif self.direction == "bwd":
            # The backward template has no causal parameter.
            return KERNEL_IMPL_TEMPLATE_BWD.format(
                DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim
            )
        else:
            return KERNEL_IMPL_TEMPLATE_FWD_SPLIT.format(
                DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, IS_CAUSAL=self.is_causal
            )

    @property
    def filename(self) -> str:
        """Return the output file name; causal kernels get a ``causal_`` tag."""
        causal_tag = "causal_" if self.is_causal == "true" else ""
        return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}_{causal_tag}sm{self.sm}.cu"
def get_all_kernels() -> Iterator[Kernel]:
    """Yield one ``Kernel`` per file to generate.

    Forward and split-KV forward kernels are produced for every combination of
    dtype, head dimension, causal flag and SM; backward kernels for every
    combination of dtype, head dimension and SM (their causal flag is fixed to
    ``"false"``, which keeps the ``causal_`` tag out of the file name).

    This is a generator, hence the ``Iterator`` return annotation.
    """
    for direction in ["fwd", "fwd_split"]:
        for dtype, head_dim, is_causal, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, IS_CAUSAL, SM):
            yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, is_causal=is_causal, direction=direction)
    for dtype, head_dim, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SM):
        yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, is_causal="false", direction="bwd")
def write_kernel(kernel: Kernel, autogen_dir: Path) -> None:
    """Write one generated source file into ``autogen_dir``.

    The file is named ``kernel.filename`` and contains the fixed copyright /
    auto-generation header, a blank line, and then ``kernel.template``.
    An existing file with the same name is overwritten.
    """
    header_lines = (
        "// Copyright (c) 2023, Tri Dao.\n"
        "// Splitting the different head dimensions to different files to speed up compilation.\n"
        '// This file is auto-generated. See "generate_kernels.py"\n'
        "\n"
    )
    destination = autogen_dir / kernel.filename
    destination.write_text(header_lines + kernel.template)
def main(output_dir: Optional[str]) -> None:
    """Generate every kernel file.

    Args:
        output_dir: Directory to write into. ``None`` selects the directory
            that contains this script.
    """
    target_dir = Path(__file__).parent if output_dir is None else Path(output_dir)
    for kernel in get_all_kernels():
        write_kernel(kernel, target_dir)
if __name__ == "__main__":
    # Command-line entry point: parse the optional output directory and generate.
    parser = argparse.ArgumentParser(
        prog="generate_kernels",
        description="Generate the flash_attention kernels template instantiations",
    )
    # Optional output directory. The help text names the real default: main()
    # falls back to the script's own directory, not the current working directory.
    parser.add_argument(
        "-o",
        "--output_dir",
        required=False,
        help="Where to generate the kernels; "
        "defaults to the directory containing this script",
    )
    args = parser.parse_args()
    main(args.output_dir)
This source diff could not be displayed because it is too large. You can view the blob instead.
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#pragma once
#include <cute/tensor.hpp>
namespace flash {
using namespace cute;
// Sets every score whose key column is >= max_seqlen_k to -INFINITY.
// tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N)).
// Column owned by a lane, as computed below:
//   col_idx_offset_ + lane / 16 + repeat * 16 + element * 4
// where lane is the thread index modulo 64.
template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
                                           const int col_idx_offset_ = 0) {
    static_assert(Layout::rank == 2, "Only support 2D Tensor");
    constexpr int kRepeatStride = 16;   // column step between repeats along N
    constexpr int kElementStride = 4;   // column step between a lane's own elements
    const int lane = threadIdx.x % 64;
    const int first_col = col_idx_offset_ + lane / 16;
    #pragma unroll
    for (int rep = 0; rep < size<1, 1>(tensor); ++rep) {
        #pragma unroll
        for (int elem = 0; elem < size<1, 0>(tensor); ++elem) {
            const int col = first_col + rep * kRepeatStride + elem * kElementStride;
            if (col < max_seqlen_k) {
                continue;
            }
            // Original author's note: without the "make_coord" we get wrong results.
            #pragma unroll
            for (int row = 0; row < size<0>(tensor); ++row) {
                tensor(row, make_coord(elem, rep)) = -INFINITY;
            }
        }
    }
}
// Same out-of-range key masking as apply_mask, for the layout in which a lane's
// elements are adjacent columns.
// tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N)).
// Column owned by a lane, as computed below:
//   col_idx_offset_ + (lane / 16) * 4 + repeat * 16 + element
// where lane is the thread index modulo 64.
template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_mask_continuous(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
                                                      const int col_idx_offset_ = 0) {
    static_assert(Layout::rank == 2, "Only support 2D Tensor");
    constexpr int kRepeatStride = 16;   // column step between repeats along N
    constexpr int kElementStride = 1;   // a lane's elements are consecutive columns
    const int lane = threadIdx.x % 64;
    const int first_col = col_idx_offset_ + (lane / 16) * 4;
    #pragma unroll
    for (int rep = 0; rep < size<1, 1>(tensor); ++rep) {
        #pragma unroll
        for (int elem = 0; elem < size<1, 0>(tensor); ++elem) {
            const int col = first_col + rep * kRepeatStride + elem * kElementStride;
            if (col < max_seqlen_k) {
                continue;
            }
            // Original author's note: without the "make_coord" we get wrong results.
            #pragma unroll
            for (int row = 0; row < size<0>(tensor); ++row) {
                tensor(row, make_coord(elem, rep)) = -INFINITY;
            }
        }
    }
}
// Sliding-window mask. For the row at row_idx_offset + mi * warp_row_stride, a
// column stays visible only inside [left limit, right limit):
//   left  = max(0, row + max_seqlen_k - max_seqlen_q - window_size_left)
//   right = min(max_seqlen_k, row + 1 + max_seqlen_k - max_seqlen_q + window_size_right)
// Everything else is set to -INFINITY. With HasWSLeft == false the left limit
// is not enforced.
// tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N)); a lane's columns are
//   col_idx_offset_ + lane / 16 + repeat * 16 + element * 4   (lane = threadIdx.x % 64).
template <bool HasWSLeft=true, typename Engine, typename Layout>
__forceinline__ __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
                                                 const int max_seqlen_k, const int row_idx_offset,
                                                 const int max_seqlen_q, const int warp_row_stride,
                                                 const int window_size_left, const int window_size_right) {
    static_assert(Layout::rank == 2, "Only support 2D Tensor");
    constexpr int kRepeatStride = 16;
    constexpr int kElementStride = 4;
    const int lane = threadIdx.x % 64;
    const int first_col = col_idx_offset_ + lane / 16;
    #pragma unroll
    for (int mi = 0; mi < size<0>(tensor); ++mi) {
        const int row = row_idx_offset + mi * warp_row_stride;
        const int left_limit = std::max(0, row + max_seqlen_k - max_seqlen_q - window_size_left);
        const int right_limit = std::min(max_seqlen_k, row + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
        #pragma unroll
        for (int rep = 0; rep < size<1, 1>(tensor); ++rep) {
            #pragma unroll
            for (int elem = 0; elem < size<1, 0>(tensor); ++elem) {
                const int col = first_col + rep * kRepeatStride + elem * kElementStride;
                const bool visible = col < right_limit && (!HasWSLeft || col >= left_limit);
                if (!visible) {
                    tensor(mi, make_coord(elem, rep)) = -INFINITY;
                }
            }
        }
    }
}
// Sliding-window mask for the layout in which a lane's elements are adjacent
// columns. Window limits per row are the same as in apply_mask_local:
//   left  = max(0, row + max_seqlen_k - max_seqlen_q - window_size_left)
//   right = min(max_seqlen_k, row + 1 + max_seqlen_k - max_seqlen_q + window_size_right)
// Columns outside [left, right) become -INFINITY; the left limit is skipped
// when HasWSLeft == false.
// tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N)); a lane's columns are
//   col_idx_offset_ + (lane / 16) * 4 + repeat * 16 + element   (lane = threadIdx.x % 64).
template <bool HasWSLeft=true, typename Engine, typename Layout>
__forceinline__ __device__ void apply_mask_local_continuous(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
                                                            const int max_seqlen_k, const int row_idx_offset,
                                                            const int max_seqlen_q, const int warp_row_stride,
                                                            const int window_size_left, const int window_size_right) {
    static_assert(Layout::rank == 2, "Only support 2D Tensor");
    constexpr int kRepeatStride = 16;
    constexpr int kElementStride = 1;
    const int lane = threadIdx.x % 64;
    const int first_col = col_idx_offset_ + (lane / 16) * 4;
    #pragma unroll
    for (int mi = 0; mi < size<0>(tensor); ++mi) {
        const int row = row_idx_offset + mi * warp_row_stride;
        const int left_limit = std::max(0, row + max_seqlen_k - max_seqlen_q - window_size_left);
        const int right_limit = std::min(max_seqlen_k, row + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
        #pragma unroll
        for (int rep = 0; rep < size<1, 1>(tensor); ++rep) {
            #pragma unroll
            for (int elem = 0; elem < size<1, 0>(tensor); ++elem) {
                const int col = first_col + rep * kRepeatStride + elem * kElementStride;
                const bool visible = col < right_limit && (!HasWSLeft || col >= left_limit);
                if (!visible) {
                    tensor(mi, make_coord(elem, rep)) = -INFINITY;
                }
            }
        }
    }
}
// Causal mask expressed through apply_mask_local: no left window limit
// (HasWSLeft = false, window_size_left passed as -1) and a right window of 0,
// so each row sees keys up to and including its own aligned position.
template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
                                                  const int max_seqlen_k, const int row_idx_offset,
                                                  const int max_seqlen_q, const int warp_row_stride) {
    constexpr int kUnboundedLeftWindow = -1;
    constexpr int kNoRightWindow = 0;
    apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset,
                                          max_seqlen_q, warp_row_stride,
                                          kUnboundedLeftWindow, kNoRightWindow);
}
// Causal mask for the adjacent-column layout, expressed through
// apply_mask_local_continuous: no left window limit (HasWSLeft = false,
// window_size_left passed as -1) and a right window of 0.
template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_mask_causal_continuous(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
                                                             const int max_seqlen_k, const int row_idx_offset,
                                                             const int max_seqlen_q, const int warp_row_stride) {
    constexpr int kUnboundedLeftWindow = -1;
    constexpr int kNoRightWindow = 0;
    apply_mask_local_continuous</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset,
                                                     max_seqlen_q, warp_row_stride,
                                                     kUnboundedLeftWindow, kNoRightWindow);
}
// Bounds mask for a tensor whose columns are compared against max_seqlen_q:
// every column index >= max_seqlen_q is set to -INFINITY.
// NOTE(review): the name and the max_seqlen_q bound suggest this is used on a
// transposed score tile (columns = query positions) -- confirm with the caller.
// tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N)); a lane's columns are
//   col_idx_offset_ + (lane / 16) * 4 + repeat * 16 + element   (lane = threadIdx.x % 64).
template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_mask_trans(Tensor<Engine, Layout> &tensor, const int max_seqlen_q,
                                                 const int col_idx_offset_ = 0) {
    static_assert(Layout::rank == 2, "Only support 2D Tensor");
    constexpr int kRepeatStride = 16;
    constexpr int kElementStride = 1;
    const int lane = threadIdx.x % 64;
    const int first_col = col_idx_offset_ + (lane / 16) * 4;
    #pragma unroll
    for (int rep = 0; rep < size<1, 1>(tensor); ++rep) {
        #pragma unroll
        for (int elem = 0; elem < size<1, 0>(tensor); ++elem) {
            const int col = first_col + rep * kRepeatStride + elem * kElementStride;
            if (col < max_seqlen_q) {
                continue;
            }
            // Original author's note: without the "make_coord" we get wrong results.
            #pragma unroll
            for (int row = 0; row < size<0>(tensor); ++row) {
                tensor(row, make_coord(elem, rep)) = -INFINITY;
            }
        }
    }
}
// Window / causal mask for the "trans" tile layout.
// tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N)); a lane's columns are
//   col_idx_offset_ + (lane / 16) * 4 + repeat * 16 + element   (lane = threadIdx.x % 64)
// and row mi maps to row_idx_offset + mi * warp_row_stride.
//
// HasWSLeft == true: for each column, rows outside [up, down) are set to -INFINITY, with
//   up   = max(0, col + max_seqlen_k - max_seqlen_q - window_size_left)
//   down = min(max_seqlen_k, col + 1 + max_seqlen_k - max_seqlen_q + window_size_right)
//
// HasWSLeft == false: for each row, an element is set to -INFINITY when
//   col + 1 < max(0, row + max_seqlen_q - max_seqlen_k - window_size_left);
//   window_size_right is not used on this path.
//
// NOTE(review): the two branches use different seqlen offsets (k - q vs q - k);
// presumably intentional for the transposed tile -- confirm against the caller.
template <bool HasWSLeft=true, typename Engine, typename Layout>
__forceinline__ __device__ void apply_mask_local_trans(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
                                                       const int max_seqlen_k, const int row_idx_offset,
                                                       const int max_seqlen_q, const int warp_row_stride,
                                                       const int window_size_left, const int window_size_right) {
    static_assert(Layout::rank == 2, "Only support 2D Tensor");
    const int lane_id = threadIdx.x % 64;
    const int col_idx_offset = col_idx_offset_ + (lane_id / 16) * 4;
    const int stride_between_each_repeat = 16;
    const int stride_between_each_thread = 1;
    if constexpr (HasWSLeft) {
        #pragma unroll
        for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
            const int col_idx_base = col_idx_offset + nj * stride_between_each_repeat;
            #pragma unroll
            for (int j = 0; j < size<1, 0>(tensor); ++j) {
                const int col_idx = col_idx_base + j * stride_between_each_thread;
                // Row range that stays visible for this column.
                const int row_idx_limit_up = std::max(0, col_idx + max_seqlen_k - max_seqlen_q - window_size_left);
                const int row_idx_limit_down = std::min(max_seqlen_k, col_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
                #pragma unroll
                for (int mi = 0; mi < size<0>(tensor); ++mi) {
                    const int row_idx = row_idx_offset + mi * warp_row_stride;
                    if (row_idx < row_idx_limit_up || row_idx >= row_idx_limit_down) {
                        tensor(mi, make_coord(j, nj)) = -INFINITY;
                    }
                }
            }
        }
    } else {
        #pragma unroll
        for (int mi = 0; mi < size<0>(tensor); ++mi) {
            const int row_idx = row_idx_offset + mi * warp_row_stride;
            // Only the lower column bound is enforced on this path.
            const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_q - max_seqlen_k - window_size_left);
            #pragma unroll
            for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
                const int col_idx_base = col_idx_offset + nj * stride_between_each_repeat;
                #pragma unroll
                for (int j = 0; j < size<1, 0>(tensor); ++j) {
                    const int col_idx = col_idx_base + j * stride_between_each_thread;
                    if (col_idx + 1 < col_idx_limit_left) {
                        tensor(mi, make_coord(j, nj)) = -INFINITY;
                    }
                }
            }
        }
    }
}
// Causal mask for the "trans" tile layout: forwards to apply_mask_local_trans
// with HasWSLeft = false, window_size_left = -1 and window_size_right = 0.
template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_mask_causal_trans(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
                                                        const int max_seqlen_k, const int row_idx_offset,
                                                        const int max_seqlen_q, const int warp_row_stride) {
    constexpr int kUnboundedLeftWindow = -1;
    constexpr int kNoRightWindow = 0;
    apply_mask_local_trans</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset,
                                                max_seqlen_q, warp_row_stride,
                                                kUnboundedLeftWindow, kNoRightWindow);
}
template <bool Is_causal, bool Is_local, bool Has_alibi>
struct Mask {
const int max_seqlen_k, max_seqlen_q;
const int window_size_left, window_size_right;
const float alibi_slope;
// Stores the sequence lengths, window sizes and ALiBi slope used by the
// masking methods. When Has_alibi is false the stored slope is forced to 0
// regardless of the argument. The zero is written as a float literal so the
// initializer does not round-trip through double in device code.
__forceinline__ __device__ Mask(const int max_seqlen_k, const int max_seqlen_q,
                                const int window_size_left, const int window_size_right,
                                const float alibi_slope=0.f)
    : max_seqlen_k(max_seqlen_k)
    , max_seqlen_q(max_seqlen_q)
    , window_size_left(window_size_left)
    , window_size_right(window_size_right)
    , alibi_slope(!Has_alibi ? 0.0f : alibi_slope) {
}
// Applies ALiBi bias and/or masking to one accumulator tile in place.
//   Causal_mask: whether this particular iteration needs causal masking.
//   Is_even_MN:  when false, columns >= max_seqlen_k are masked as well.
// tensor_ must be a rank-3 accumulator whose first mode has size 4; it is viewed
// as (nrow, ncol) through flash::convert_layout_acc_rowcol. Nothing is done when
// none of Has_alibi / Causal_mask / Is_local / !Is_even_MN applies.
// Column owned by a lane (lane = threadIdx.x & 63), as computed below:
//   col_idx_offset_ + lane / 16 + repeat * 16 + element * 4
// i.e. lanes 0-15 start at +0, 16-31 at +1, 32-47 at +2 and 48-63 at +3.
// Row mi maps to row_idx_offset + mi * warp_row_stride.
template <bool Causal_mask=false, bool Is_even_MN=true, typename Engine, typename Layout>
__forceinline__ __device__ void apply_mask(Tensor<Engine, Layout> &tensor_,
                                           const int col_idx_offset_,
                                           const int row_idx_offset,
                                           const int warp_row_stride) {
    static_assert(!(Causal_mask && Is_local), "Cannot be both causal and local");
    static_assert(Layout::rank == 3, "Only support 3D Tensor");
    static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4");
    static constexpr bool Need_masking = Has_alibi || Causal_mask || Is_local || !Is_even_MN;
    if constexpr (Need_masking) {
        // View the accumulator as a (row, column) tensor.
        Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_rowcol(tensor_.layout()));
        // Row indices are only needed for non-causal ALiBi, local or causal masking.
        static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask;
        constexpr int kLanesPerColumnStep = 16;  // 16 consecutive lanes share a starting column
        constexpr int kRepeatStride = 16;        // column step between repeats along N
        constexpr int kElementStride = 4;        // column step between a lane's own elements
        const int lane = threadIdx.x & 63;
        const int first_col = col_idx_offset_ + lane / kLanesPerColumnStep;
        if constexpr (Col_idx_only) {
            #pragma unroll
            for (int rep = 0; rep < size<1, 1>(tensor); ++rep) {
                const int rep_col = first_col + rep * kRepeatStride;
                #pragma unroll
                for (int elem = 0; elem < size<1, 0>(tensor); ++elem) {
                    const int col = rep_col + elem * kElementStride;
                    #pragma unroll
                    for (int mi = 0; mi < size<0>(tensor); ++mi) {
                        // No causal, no local
                        if constexpr (Has_alibi) {
                            tensor(mi, make_coord(elem, rep)) += alibi_slope * col;
                        }
                        if constexpr (!Is_even_MN) {
                            if (col >= max_seqlen_k) { tensor(mi, make_coord(elem, rep)) = -INFINITY; }
                        }
                    }
                }
            }
        } else {
            #pragma unroll
            for (int mi = 0; mi < size<0>(tensor); ++mi) {
                const int row = row_idx_offset + mi * warp_row_stride;
                const int left_limit = std::max(0, row + max_seqlen_k - max_seqlen_q - window_size_left);
                const int right_limit = std::min(max_seqlen_k, row + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
                #pragma unroll
                for (int rep = 0; rep < size<1, 1>(tensor); ++rep) {
                    const int rep_col = first_col + rep * kRepeatStride;
                    #pragma unroll
                    for (int elem = 0; elem < size<1, 0>(tensor); ++elem) {
                        const int col = rep_col + elem * kElementStride;
                        if constexpr (Has_alibi) {
                            if constexpr (Is_causal) {
                                tensor(mi, make_coord(elem, rep)) += alibi_slope * col;
                            } else {
                                tensor(mi, make_coord(elem, rep)) -= alibi_slope * abs(row + max_seqlen_k - max_seqlen_q - col);
                            }
                        }
                        if constexpr (Causal_mask) {
                            if (col >= right_limit) {
                                tensor(mi, make_coord(elem, rep)) = -INFINITY;
                            }
                        }
                        if constexpr (Is_local) {
                            if (col >= right_limit || col < left_limit) {
                                tensor(mi, make_coord(elem, rep)) = -INFINITY;
                            }
                        }
                        // Causal and local masking already cover the uneven-MN case,
                        // because right_limit is capped at max_seqlen_k.
                        if constexpr (!Causal_mask && !Is_local && !Is_even_MN) {
                            if (col >= max_seqlen_k) {
                                tensor(mi, make_coord(elem, rep)) = -INFINITY;
                            }
                        }
                    }
                }
            }
        }
    }
}
// Adds the ALiBi bias and applies causal / local / key-length masking to one fragment of
// attention scores held in `tensor_` (rank-3 accumulator whose first mode has size 4).
//
// Column mapping used below: with lane = threadIdx.x % 64, a lane's first column is
// col_idx_offset_ + 4 * (lane / 16); the inner column mode advances by 1 and the outer
// column mode by 16. Row `mi` of the reshaped fragment is row_idx_offset + mi * warp_row_stride.
//
// Masked entries are overwritten with -INFINITY. The function is a no-op when none of
// ALiBi, causal masking, local masking or uneven sequence lengths is enabled.
template <bool Causal_mask=false, bool Is_even_MN=true, typename Engine, typename Layout>
__forceinline__ __device__ void apply_mask_continuous(Tensor<Engine, Layout> &tensor_,
                                                      const int col_idx_offset_,
                                                      const int row_idx_offset,
                                                      const int warp_row_stride) {
    static_assert(!(Causal_mask && Is_local), "Cannot be both causal and local");
    static_assert(Layout::rank == 3, "Only support 3D Tensor");
    static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4");
    static constexpr bool Need_masking = Has_alibi || Causal_mask || Is_local || !Is_even_MN;
    if constexpr (!Need_masking) {
        return;
    } else {
        // View the accumulator through the (row, (col_inner, col_outer)) layout.
        Tensor scores = make_tensor(tensor_.data(), flash::convert_layout_acc_rowcol(tensor_.layout()));
        // Row indices are only required for non-causal ALiBi, local or causal masking.
        static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask;
        // Lanes 0-15 start at +0, 16-31 at +4, 32-47 at +8, 48-63 at +12.
        const int lane = threadIdx.x % 64;
        const int first_col = col_idx_offset_ + (lane / 16) * 4;
        if constexpr (Col_idx_only) {
            #pragma unroll
            for (int rep = 0; rep < size<1, 1>(scores); ++rep) {
                #pragma unroll
                for (int e = 0; e < size<1, 0>(scores); ++e) {
                    const int col = first_col + rep * 16 + e;
                    #pragma unroll
                    for (int r = 0; r < size<0>(scores); ++r) {
                        auto &&score = scores(r, make_coord(e, rep));
                        if constexpr (Has_alibi) {
                            score += alibi_slope * col;
                        }
                        if constexpr (!Is_even_MN) {
                            if (col >= max_seqlen_k) { score = -INFINITY; }
                        }
                    }
                }
            }
        } else {
            #pragma unroll
            for (int r = 0; r < size<0>(scores); ++r) {
                const int row = row_idx_offset + r * warp_row_stride;
                const int left_limit = std::max(0, row + max_seqlen_k - max_seqlen_q - window_size_left);
                const int right_limit = std::min(max_seqlen_k, row + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
                #pragma unroll
                for (int rep = 0; rep < size<1, 1>(scores); ++rep) {
                    #pragma unroll
                    for (int e = 0; e < size<1, 0>(scores); ++e) {
                        const int col = first_col + rep * 16 + e;
                        auto &&score = scores(r, make_coord(e, rep));
                        if constexpr (Has_alibi) {
                            if constexpr (Is_causal) {
                                score += alibi_slope * col;
                            } else {
                                score -= alibi_slope * abs(row + max_seqlen_k - max_seqlen_q - col);
                            }
                        }
                        // Collect every reason to drop this entry, then write -inf once.
                        bool drop = false;
                        if constexpr (Causal_mask) {
                            drop = drop || (col >= right_limit);
                        }
                        if constexpr (Is_local) {
                            drop = drop || (col >= right_limit || col < left_limit);
                        }
                        if constexpr (!Causal_mask && !Is_local && !Is_even_MN) {
                            // The causal and local limits already clamp to max_seqlen_k.
                            drop = drop || (col >= max_seqlen_k);
                        }
                        if (drop) { score = -INFINITY; }
                    }
                }
            }
        }
    }
};
// FP8 flavour of apply_mask_continuous: same ALiBi / causal / local / key-length handling,
// but for an accumulator whose first mode has size 8.
//
// Column mapping used below: with lane = threadIdx.x % 64, a lane's first column is
// col_idx_offset_ + 8 * (lane / 16); the inner column mode advances by 1 and the outer
// column mode by 32. Row `mi` of the reshaped fragment is row_idx_offset + mi * warp_row_stride.
//
// Masked entries are overwritten with -INFINITY; nothing happens when no masking is needed.
template <bool Causal_mask=false, bool Is_even_MN=true, typename Engine, typename Layout>
__forceinline__ __device__ void apply_mask_continuous_fp8(Tensor<Engine, Layout> &tensor_,
                                                          const int col_idx_offset_,
                                                          const int row_idx_offset,
                                                          const int warp_row_stride) {
    static_assert(!(Causal_mask && Is_local), "Cannot be both causal and local");
    static_assert(Layout::rank == 3, "Only support 3D Tensor");
    static_assert(decltype(size<0>(tensor_))::value == 8, "First dimension must be 8");
    static constexpr bool Need_masking = Has_alibi || Causal_mask || Is_local || !Is_even_MN;
    if constexpr (!Need_masking) {
        return;
    } else {
        // View the accumulator through the (row, (col_inner, col_outer)) layout.
        Tensor scores = make_tensor(tensor_.data(), flash::convert_layout_acc_rowcol(tensor_.layout()));
        // Row indices are only required for non-causal ALiBi, local or causal masking.
        static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask;
        // Lanes 0-15 start at +0, 16-31 at +8, 32-47 at +16, 48-63 at +24.
        const int lane = threadIdx.x % 64;
        const int first_col = col_idx_offset_ + (lane / 16) * 8;
        if constexpr (Col_idx_only) {
            #pragma unroll
            for (int rep = 0; rep < size<1, 1>(scores); ++rep) {
                #pragma unroll
                for (int e = 0; e < size<1, 0>(scores); ++e) {
                    const int col = first_col + rep * 32 + e;
                    #pragma unroll
                    for (int r = 0; r < size<0>(scores); ++r) {
                        auto &&score = scores(r, make_coord(e, rep));
                        if constexpr (Has_alibi) {
                            score += alibi_slope * col;
                        }
                        if constexpr (!Is_even_MN) {
                            if (col >= max_seqlen_k) { score = -INFINITY; }
                        }
                    }
                }
            }
        } else {
            #pragma unroll
            for (int r = 0; r < size<0>(scores); ++r) {
                const int row = row_idx_offset + r * warp_row_stride;
                const int left_limit = std::max(0, row + max_seqlen_k - max_seqlen_q - window_size_left);
                const int right_limit = std::min(max_seqlen_k, row + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
                #pragma unroll
                for (int rep = 0; rep < size<1, 1>(scores); ++rep) {
                    #pragma unroll
                    for (int e = 0; e < size<1, 0>(scores); ++e) {
                        const int col = first_col + rep * 32 + e;
                        auto &&score = scores(r, make_coord(e, rep));
                        if constexpr (Has_alibi) {
                            if constexpr (Is_causal) {
                                score += alibi_slope * col;
                            } else {
                                score -= alibi_slope * abs(row + max_seqlen_k - max_seqlen_q - col);
                            }
                        }
                        // Collect every reason to drop this entry, then write -inf once.
                        bool drop = false;
                        if constexpr (Causal_mask) {
                            drop = drop || (col >= right_limit);
                        }
                        if constexpr (Is_local) {
                            drop = drop || (col >= right_limit || col < left_limit);
                        }
                        if constexpr (!Causal_mask && !Is_local && !Is_even_MN) {
                            // The causal and local limits already clamp to max_seqlen_k.
                            drop = drop || (col >= max_seqlen_k);
                        }
                        if (drop) { score = -INFINITY; }
                    }
                }
            }
        }
    }
};
// Unified masking / biasing of one score fragment (first accumulator mode of size 4).
// On top of causal, local and key-length masking it optionally supports:
//   * Use_alibi_sqrt : ALiBi offset -sqrt(query_abs_pos - col) (0 when the key is ahead of the query)
//   * Use_qq_bias    : additive float bias read from qq_bias_ptr[row * qq_bias_stride_0 + (col - context_len)],
//                      divided by softmax_scale before being added
//   * Use_mm_prefix  : per-batch [start, end] ranges (inclusive on both ends, read from
//                      mm_prefix_range_ptr) inside which query/key pairs are never masked
//
// Column mapping: with lane = threadIdx.x % 64 a lane starts at col_idx_offset_ + 4 * (lane / 16);
// the inner column mode advances by 1, the outer by 16. Row `mi` is row_idx_offset + mi * warp_row_stride.
// NOTE(review): qq_bias_ptr / mm_prefix_range_ptr default to nullptr and are dereferenced without a
// check when the matching template flag is set -- callers must supply them.
template <bool Causal_mask=false, bool Is_even_MN=true,
          bool Use_alibi_sqrt=false, bool Use_qq_bias=false, bool Use_mm_prefix=false,
          typename Engine, typename Layout>
__forceinline__ __device__ void apply_mask_continuous_unified(
    Tensor<Engine, Layout> &tensor_,
    const int col_idx_offset_,
    const int row_idx_offset,
    const int warp_row_stride,
    const int context_len,
    const void * __restrict__ qq_bias_ptr = nullptr,
    const int qq_bias_stride_0 = 0,
    const int * __restrict__ mm_prefix_range_ptr = nullptr,
    const int max_mm_ranges = 0,
    const int bidb = 0,
    const float softmax_scale = 1.0f
) {
    static_assert(!(Causal_mask && Is_local), "Cannot be both causal and local");
    static_assert(Layout::rank == 3, "Only support 3D Tensor");
    static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4");
    static constexpr bool Need_masking = Has_alibi || Causal_mask || Is_local
                                       || !Is_even_MN || Use_qq_bias || Use_mm_prefix;
    if constexpr (Need_masking) {
        // The column-only fast path is valid only when nothing depends on the row index.
        static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local
                                           && !Causal_mask && !Use_mm_prefix && !Use_qq_bias
                                           && !(Has_alibi && Use_alibi_sqrt);
        Tensor scores = make_tensor(tensor_.data(), flash::convert_layout_acc_rowcol(tensor_.layout()));
        const int lane = threadIdx.x % 64;
        const int first_col = col_idx_offset_ + (lane / 16) * 4;
        if constexpr (Col_idx_only) {
            #pragma unroll
            for (int rep = 0; rep < size<1, 1>(scores); ++rep) {
                #pragma unroll
                for (int e = 0; e < size<1, 0>(scores); ++e) {
                    const int col = first_col + rep * 16 + e;
                    #pragma unroll
                    for (int r = 0; r < size<0>(scores); ++r) {
                        auto &&score = scores(r, make_coord(e, rep));
                        if constexpr (Has_alibi) {
                            // Causal ALiBi: slope * column index.
                            score += alibi_slope * col;
                        }
                        if constexpr (!Is_even_MN) {
                            if (col >= max_seqlen_k) { score = -INFINITY; }
                        }
                    }
                }
            }
        } else {
            #pragma unroll
            for (int r = 0; r < size<0>(scores); ++r) {
                const int row = row_idx_offset + r * warp_row_stride;
                // Position of this query on the key axis.
                const int q_abs = row + (max_seqlen_k - max_seqlen_q);
                const int left_limit = std::max(0,
                    row + max_seqlen_k - max_seqlen_q - window_size_left);
                const int right_limit = std::min(max_seqlen_k,
                    row + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
                #pragma unroll
                for (int rep = 0; rep < size<1, 1>(scores); ++rep) {
                    #pragma unroll
                    for (int e = 0; e < size<1, 0>(scores); ++e) {
                        const int col = first_col + rep * 16 + e;
                        bool masked = false;
                        if constexpr (Causal_mask) {
                            if (col >= right_limit) { masked = true; }
                        }
                        if constexpr (Is_local) {
                            if (col >= right_limit || col < left_limit) { masked = true; }
                        }
                        if constexpr (!Is_even_MN && !Causal_mask && !Is_local) {
                            // Causal / local limits already clamp to max_seqlen_k.
                            if (col >= max_seqlen_k) { masked = true; }
                        }
                        if constexpr (Use_mm_prefix) {
                            // A pair is bidirectional when query and key fall in the same valid range.
                            bool bidirectional = false;
                            const int range_base = bidb * max_mm_ranges * 2;
                            #pragma unroll
                            for (int k = 0; k < max_mm_ranges; ++k) {
                                const int lo = mm_prefix_range_ptr[range_base + k * 2];
                                const int hi = mm_prefix_range_ptr[range_base + k * 2 + 1];
                                const bool q_inside = (lo < hi) && (q_abs >= lo) && (q_abs <= hi);
                                const bool k_inside = (lo < hi) && (col >= lo) && (col <= hi);
                                if (q_inside && k_inside) { bidirectional = true; }
                            }
                            masked = masked && !bidirectional;
                        }
                        auto &&score = scores(r, make_coord(e, rep));
                        if (masked) {
                            // Masked entries get -inf and skip all bias terms.
                            score = -INFINITY;
                        } else {
                            if constexpr (Has_alibi) {
                                if constexpr (Use_alibi_sqrt) {
                                    const float distance = float(q_abs - col);
                                    const float alibi_offset = distance >= 0.f ? -sqrtf(distance) : 0.f;
                                    score += alibi_slope * alibi_offset;
                                } else {
                                    score += alibi_slope * (col - context_len);
                                }
                            }
                            if constexpr (Use_qq_bias) {
                                const int bias_row = row;
                                const int bias_col = col - context_len;
                                const bool row_ok = bias_row >= 0 && bias_row < max_seqlen_q;
                                const bool col_ok = bias_col >= 0 && bias_col < qq_bias_stride_0;
                                if (row_ok && col_ok) {
                                    const float bias_val = reinterpret_cast<const float*>(qq_bias_ptr)
                                        [bias_row * qq_bias_stride_0 + bias_col];
                                    score += bias_val / softmax_scale;
                                }
                            }
                        }
                    }
                }
            }
        }
    }
};
};
} // namespace flash
#include <torch/python.h>
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
// Number of lanes per wavefront assumed by the paged-attention code below.
#define WARP_SIZE 64
// Function-like min/max. Each argument may be evaluated twice, so do not pass
// expressions with side effects.
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
// Input validation macros (consistent with flash_api.cpp and flash_api_sparse.cpp)
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
// NOTE(review): not referenced in the visible part of the file; presumably the LDS
// (shared memory) budget per workgroup in bytes -- confirm against the launch code.
static constexpr int LDS_size = 65536;
// Offsets into the scratch buffer `out_tmp`. They are added to a scalar_t* in the kernels,
// so they count scalar_t elements, not bytes:
//   out_tmp + 0               : per-partition exp sums   (accessed as float)
//   out_tmp + max_tmp_offset  : per-partition max logits (accessed as float)
//   out_tmp + out_tmp_offset  : per-partition partial outputs
// NOTE(review): nothing visible here verifies that these regions cannot overlap for
// large num_seqs * num_heads * max_num_partitions -- confirm at the allocation site.
static constexpr int max_tmp_offset=4000000;
// NOTE(review): signal_tmp_offset / streamk_max_block are only used here to derive
// out_tmp_offset; presumably they serve a stream-K path outside this chunk.
static constexpr int signal_tmp_offset=8000000;
static constexpr int streamk_max_block=160*8;
static constexpr int out_tmp_offset=signal_tmp_offset+streamk_max_block*2;
// static constexpr int PARTITION_SIZE=512;
// Integer ceiling division for non-negative a and positive b.
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
// Stores the float `f` into `out`.
//  * scalar_t == _Float16 or float : plain conversion by assignment.
//  * any other scalar_t            : `out` receives the upper 16 bits of the float's bit pattern
//                                    after adding 0x8000, i.e. a bfloat16 bit pattern where ties
//                                    are rounded up in magnitude (not to even).
template<typename scalar_t>
static __device__ inline void from_float(scalar_t &out ,float f){
    constexpr bool stores_natively =
        std::is_same<scalar_t, _Float16>::value || std::is_same<scalar_t, float>::value;
    if constexpr (stores_natively) {
        out = f;
    } else {
        union {
            float fp32;
            uint32_t int32;
        } pun;
        pun.fp32 = f;
        const uint32_t rounded = pun.int32 + 0x8000;
        out = rounded >> 16;
    }
}
// Converts `in` to float.
//  * scalar_t == _Float16 or float : ordinary conversion.
//  * any other scalar_t            : `in` is treated as 16 raw bits that form the upper half of
//                                    a float bit pattern (bfloat16 storage); the lower half is zero.
template<typename scalar_t>
static __device__ inline float to_float(scalar_t in){
    constexpr bool converts_natively =
        std::is_same<scalar_t, _Float16>::value || std::is_same<scalar_t, float>::value;
    if constexpr (converts_natively) {
        return in;
    } else {
        union {
            uint32_t int32;
            float fp32;
        } pun;
        pun.int32 = uint32_t(in) << 16;
        return pun.fp32;
    }
}
// Decodes one FP8 E4M3 (e4m3fn: no infinities, 0x7F/0xFF are NaN) byte to float.
// On gfx938 the hardware conversion builtin is used; elsewhere the value is rebuilt with
// integer operations, following the c10 fp8e4m3fn -> fp32 reference routine.
//
// The software path previously omitted the reference's zero and NaN masks, so +/-0 decoded
// to +/-2^-35 (the renormalisation shift for an all-zero payload is 28, which leaves a
// non-zero exponent field) and NaN decoded to +/-480. Both masks are restored here.
inline __device__ float uint82float(const uint8_t& input) {
#if (defined(__gfx938__) )
    return __builtin_hcu_cvt_f32_fp8(input,false,0,0);
#else
    // Move the byte to the top of a 32-bit word: sign at bit 31, exponent+mantissa below it.
    const uint32_t w = (uint32_t)input << 24;
    const uint32_t sign = w & UINT32_C(0x80000000);
    const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF);
    // Subnormal inputs have more than 4 leading zeros after the sign; shift them up so the
    // leading one lands in the implicit-bit position, and compensate in the exponent below.
    uint32_t renorm_shift = __clz(nonsign);
    renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;
    // 0x7F800000 when exponent and mantissa are all ones (NaN), otherwise 0.
    const int32_t inf_nan_mask = ((int32_t)(nonsign + 0x01000000) >> 8) & INT32_C(0x7F800000);
    // All ones when the magnitude is exactly zero, otherwise 0.
    const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31;
    // 0x78 = 127 - 7: difference between the fp32 and E4M3 exponent biases.
    uint32_t result = sign | ((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) | inf_nan_mask) & ~zero_mask);
    return c10::detail::fp32_from_bits(result);
#endif
}
// Converts one FP8 byte to the 16-bit compute type `scalar_t`.
//  * is_e4m3 == false : the byte is placed in the upper 8 bits of an fp16 bit pattern
//                       (E5M2 shares fp16's sign/exponent layout).
//  * is_e4m3 == true  : the byte is decoded to float with uint82float() first.
// When scalar_t is _Float16 the result is an fp16 value; otherwise the upper 16 bits of
// the float bit pattern are returned (bfloat16 storage, truncated).
template<typename scalar_t,bool is_e4m3>
__forceinline__ __device__ scalar_t uint82half(const uint8_t& input) {
    constexpr bool wants_fp16 = std::is_same<scalar_t, _Float16>::value;
    union fp32_view {
        float as_value;
        uint32_t as_bits;
    };
    if constexpr (is_e4m3) {
        fp32_view wide;
        wide.as_value = uint82float(input);
        if constexpr (wants_fp16) {
            return (_Float16)(wide.as_value);
        } else {
            return (uint16_t)(wide.as_bits >> 16);
        }
    } else {
        union {
            uint16_t as_bits;
            _Float16 as_value;
        } narrow;
        narrow.as_bits = (uint16_t)input << 8;
        if constexpr (wants_fp16) {
            return narrow.as_value;
        } else {
            fp32_view wide;
            wide.as_value = (float)narrow.as_value;
            return wide.as_bits >> 16;
        }
    }
}
// Runtime-to-compile-time dispatch helpers. Each macro expands to an immediately invoked
// lambda that defines a compile-time name and then calls the callable given in __VA_ARGS__.

// BOOL_SWITCH: defines `constexpr static bool CONST_NAME` equal to the runtime value of COND.
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr static bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
// Input_Type_SWITCH: defines `scalar_t` as _Float16 for at::ScalarType::Half and as
// uint16_t for every other dtype.
// NOTE(review): the else branch is not restricted to BFloat16; any non-Half dtype is
// silently handled as 16 raw bits -- confirm callers validate the dtype first.
#define Input_Type_SWITCH(SRC_DTYPE, ...) \
[&] { \
if (SRC_DTYPE == at::ScalarType::Half) { \
using scalar_t=_Float16; \
return __VA_ARGS__(); \
}else { \
using scalar_t=uint16_t; \
return __VA_ARGS__(); \
} \
}()
// Cache_Type_SWITCH: defines `cache_t` and `is_e4m3` from the KV-cache dtype.
// Both FP8 dtypes use uint8_t storage; any other dtype uses `scalar_t` with is_e4m3 = false.
#define Cache_Type_SWITCH(scalar_t,dtype, ...) \
[&] { \
if(dtype==torch::kFloat8_e5m2){ \
using cache_t=uint8_t; \
constexpr bool is_e4m3=false; \
return __VA_ARGS__(); \
}else if(dtype==torch::kFloat8_e4m3fn){ \
using cache_t=uint8_t; \
constexpr bool is_e4m3=true; \
return __VA_ARGS__(); \
}else { \
using cache_t=scalar_t; \
constexpr bool is_e4m3=false; \
return __VA_ARGS__(); \
} \
}()
// REUSEKV_SWITCH: defines `REUSE_KV_TIMES` for the listed values 48/36/32/24/16/8.
// NOTE(review): every value not listed falls through to 4 without any check.
#define REUSEKV_SWITCH(reusekv,...) \
[&] { \
if (reusekv==48){ \
constexpr static int REUSE_KV_TIMES = 48; \
return __VA_ARGS__(); \
}else if (reusekv==36){ \
constexpr static int REUSE_KV_TIMES = 36; \
return __VA_ARGS__(); \
}else if (reusekv==32){ \
constexpr static int REUSE_KV_TIMES = 32; \
return __VA_ARGS__(); \
}else if (reusekv==24){ \
constexpr static int REUSE_KV_TIMES = 24; \
return __VA_ARGS__(); \
}else if (reusekv==16){ \
constexpr static int REUSE_KV_TIMES = 16; \
return __VA_ARGS__(); \
}else if (reusekv==8){ \
constexpr static int REUSE_KV_TIMES = 8; \
return __VA_ARGS__(); \
}else { \
constexpr static int REUSE_KV_TIMES = 4; \
return __VA_ARGS__(); \
} \
}()
// HEADSIZE_SWITCH: defines `HEAD_SIZE` for 64/128/192.
// NOTE(review): every other head size falls through to 256 without any check.
#define HEADSIZE_SWITCH(headsize,...) \
[&] { \
if (headsize==64){ \
constexpr static int HEAD_SIZE = 64; \
return __VA_ARGS__(); \
}else if(headsize==128){ \
constexpr static int HEAD_SIZE = 128; \
return __VA_ARGS__(); \
}else if(headsize==192){ \
constexpr static int HEAD_SIZE = 192; \
return __VA_ARGS__(); \
}else { \
constexpr static int HEAD_SIZE = 256; \
return __VA_ARGS__(); \
} \
}()
static std::string get_device_name()
{
hipDeviceProp_t props{};
int device;
auto status = hipGetDevice(&device);
if(status != hipSuccess)
{
return std::string();
}
status = hipGetDeviceProperties(&props, device);
if(status != hipSuccess)
{
return std::string();
}
const std::string raw_name(props.gcnArchName);
return raw_name.substr(0, raw_name.find(':')); // str.substr(0, npos) returns str.
}
static const std::string device_name=get_device_name();
// Reads the environment variable `env_var` as an integer via atoi().
// Returns 0 when the variable is not set (and whatever atoi yields otherwise,
// i.e. 0 for text that does not start with a number).
static inline int get_env_(const char *env_var) {
    const char *raw = std::getenv(env_var);
    return raw != nullptr ? atoi(raw) : 0;
}
// Tuning / debug knobs read once from the environment during static initialisation.
// Each one is 0 when the variable is unset (get_env_ returns 0 in that case).
// NOTE(review): the consumers of these values are outside this chunk; their exact
// meaning is inferred from the names only -- confirm at the use sites.
static const int PA_USE_STREAMK = get_env_("PA_USE_STREAMK");
static const int PA_MAX_BLOCKS = get_env_("PA_MAX_BLOCKS");
static const int PA_PRINT_PARAM = get_env_("PA_PRINT_PARAM");
static const int PA_PARTITION_SIZE = get_env_("PA_PARTITION_SIZE");
static const int PA_GFX938 = get_env_("PA_GFX938");
// Compiler vector types used as matrix-instruction operands and for wide loads/stores.
// 4 x uint8_t (FP8 storage).
using uint8x4_t = __attribute__( (__vector_size__(4 * sizeof(uint8_t)) )) uint8_t;
// 4 x _Float16; also used to carry bf16 bit patterns, which are reinterpreted as v4bh
// before being handed to the bf16 matrix builtin.
using half4_t = __attribute__( (__vector_size__(4 * sizeof(_Float16)) )) _Float16;
// 4 x short: operand type of the bf16 matrix builtins.
using v4bh = __attribute__( (__vector_size__(4 * sizeof(short)) )) short;
// 4 x float: accumulator type of the matrix builtins.
using float4_t = __attribute__( (__vector_size__(4 * sizeof(float)) )) float;
using float2_t = __attribute__( (__vector_size__(2 * sizeof(float)) )) float;
// `vec` consecutive half4_t values, so that 4*vec 16-bit elements can be moved with one
// reinterpret_cast load/store.
template<int vec>
struct half4vec{
half4_t data[vec];
};
using half4x2 = half4vec<2>;
// `vec` consecutive uint8x4_t values (4*vec FP8 bytes).
template<int vec>
struct uint8x4vec{
uint8x4_t data[vec];
};
using uint8x4x2 = uint8x4vec<2>;
using uint8x4x4 = uint8x4vec<4>;
// Sums `sum` over all threads of the block and returns the value held by lane 0 of each
// wavefront after the final reduction.
//   red_smem : scratch with at least NUM_WARPS floats, one slot per wavefront.
// Contains a __syncthreads(), so every thread of the block must call it.
template <int NUM_WARPS>
inline __device__ float block_sum(float* red_smem, float sum) {
    const int lane_in_warp = threadIdx.x % WARP_SIZE;
    const int warp_id = __builtin_amdgcn_readfirstlane(threadIdx.x / WARP_SIZE);

    // Stage 1: butterfly reduction inside each wavefront.
#pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
        sum += __shfl_xor(sum, offset);
    }

    // Stage 2: one partial per wavefront goes through shared memory.
    if (lane_in_warp == 0) {
        red_smem[warp_id] = sum;
    }
    __syncthreads();
    if (lane_in_warp < NUM_WARPS) {
        sum = red_smem[lane_in_warp];
    }

    // Stage 3: reduce the NUM_WARPS partials held by the first lanes.
#pragma unroll
    for (int offset = NUM_WARPS / 2; offset > 0; offset >>= 1) {
        sum += __shfl_xor(sum, offset);
    }

    // Lane 0 now holds the block total; hand it to every lane.
    return __shfl(sum, 0);
}
// Issues one 16x16x16 matrix multiply-accumulate: reg_c = mmac(reg_a, reg_b, reg_c).
//   is_half == true  : operands are fp16.
//   is_half == false : operands hold bf16 bit patterns and are reinterpreted as v4bh.
// The gfx938 build uses the __builtin_hcu_* variants, all other targets the
// __builtin_amdgcn_* variants.
template<bool is_half>
inline __device__ void builtin_amdgcn_mmac(const half4_t& reg_a, const half4_t& reg_b, float4_t& reg_c)
{
    if constexpr (is_half) {
#if (defined(__gfx938__) )
        reg_c = __builtin_hcu_mmac_f32_16x16x16_f16_lit_lts(reg_a, reg_b, reg_c, false, false);
#else
        reg_c = __builtin_amdgcn_mmac_f32_16x16x16f16(reg_a, reg_b, reg_c);
#endif
    } else {
        // Same 64 bits, viewed as four 16-bit integers for the bf16 instruction.
        const v4bh a_bits = *reinterpret_cast<const v4bh*>(&reg_a);
        const v4bh b_bits = *reinterpret_cast<const v4bh*>(&reg_b);
#if (defined(__gfx938__) )
        reg_c = __builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts(a_bits, b_bits, reg_c, false, false);
#else
        reg_c = __builtin_amdgcn_mmac_f32_16x16x16bf16(a_bits, b_bits, reg_c);
#endif
    }
}
// Paged-attention decode kernel: one thread block handles all query heads that share one
// KV head, for one sequence and one partition of that sequence.
//
// Launch layout (as read below):
//   blockIdx.x = KV head, blockIdx.y = sequence, blockIdx.z = partition; NUM_THREADS threads.
// Dynamic shared memory, in order:
//   logits     scalar_t[num_queries_per_kv * PARTITION_SIZE]  (its start is first reused to stage Q)
//   s_max      float[num_queries_per_kv * NUM_WARPS]
//   s_logit    float[NUM_WARPS]
//   max_out    float[num_queries_per_kv]
//   expsum_out float[num_queries_per_kv]
// Phases: (1) stage Q and compute scaled Q*K^T logits, (2) softmax over the partition,
// (3) multiply the probabilities with V, (4) write the result to `out` (single partition)
// or to the scratch regions of `out_tmp` (several partitions).
//
// Logits are kept in the base-2 domain: the softmax scale and all biases are multiplied by
// log2(e) = 1.4426950408889634 and exponentials use exp2.
//
// NOTE(review): the softmax phase hard-codes 4 partial maxima per head and a thread_idx*8
// addressing scheme, i.e. it appears to assume NUM_THREADS == 256 (NUM_WARPS == 4) and
// PARTITION_SIZE in {256, 512} -- confirm with the launcher.
// NOTE(review): num_queries_per_kv must not exceed REUSE_KV_TIMES (register arrays are
// sized from it); this is not checked here.
template <typename scalar_t, typename cache_t,bool is_e4m3 ,int HEAD_SIZE, int BLOCK_SIZE,
          int NUM_THREADS, int REUSE_KV_TIMES>
__launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
    scalar_t* __restrict__ out,            // [num_seqs, num_heads,head_size]
    scalar_t* __restrict__ out_tmp,        // [num_seqs, num_heads, max_num_partitions,head_size]
    const scalar_t* __restrict__ q,        // [num_seqs, num_heads, head_size]
    const cache_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads,
                                           // head_size/x, block_size, x]
    const cache_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads,
                                           // head_size, block_size]
    const int num_heads,
    const int num_kv_heads,                // [num_heads]
    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ seq_lens,      // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes,  // [num_heads]
    const int q_stride,const int kv_block_stride,
    const float* k_scale_ptr, const float* v_scale_ptr,int max_num_partitions,int PARTITION_SIZE,
    const scalar_t* __restrict__ s_aux_ptr,int mtp,bool has_abili) {  // Attention sinks: [num_heads] scalar_t
    const int seq_idx = blockIdx.y;
    const int partition_idx = blockIdx.z;
    constexpr int kv_head_stride=BLOCK_SIZE*HEAD_SIZE;
    const int seq_len = __builtin_amdgcn_readfirstlane(seq_lens[seq_idx]);
    const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
    const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
    // Partitions beyond the end of this sequence have nothing to do (uniform per block).
    if(num_partitions<=partition_idx)return ;
    constexpr bool is_half = std::is_same<scalar_t, _Float16>::value;
    constexpr bool is_fp8 = std::is_same<cache_t, uint8_t>::value;
    // 1/sqrt(HEAD_SIZE) * log2(e).
    constexpr float scale = (HEAD_SIZE==64?0.125f:(HEAD_SIZE==128? 0.0883883476f:(HEAD_SIZE==192?0.0721687836f:0.0625f)))*1.4426950408889634;
    constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
    const int thread_idx = threadIdx.x;
    const int warp_idx = __builtin_amdgcn_readfirstlane(thread_idx / WARP_SIZE);
    const int lane = thread_idx % WARP_SIZE;
    // A wavefront is viewed as 4 groups (`rows`) of 16 lanes (`rowid`).
    const int rowid = lane%16;
    const int rows = lane/16;
    float k_scale=scale;
    float v_scale=1.0;
    if(k_scale_ptr!=nullptr){
        k_scale*=(*k_scale_ptr);
        v_scale=*v_scale_ptr;
    }
    const int num_queries_per_kv = num_heads / num_kv_heads;
    const int head_idx=blockIdx.x*num_queries_per_kv;  // first query head of this KV head
    const int kv_head_idx = blockIdx.x;
    constexpr int reuse_group=(REUSE_KV_TIMES-1)/4+1;
    constexpr int Mloop=(REUSE_KV_TIMES-1)/16+1;
    extern __shared__ char shared_mem[];
    scalar_t* logits = reinterpret_cast<scalar_t*>(shared_mem);
    float* s_max = reinterpret_cast<float*>(shared_mem + sizeof(scalar_t)*num_queries_per_kv*PARTITION_SIZE);
    float* s_logit = s_max + num_queries_per_kv * NUM_WARPS;
    float* max_out = s_logit+NUM_WARPS;
    float* expsum_out = max_out+num_queries_per_kv;
    // Attention sinks: stage only the sink logits of the heads handled by this block,
    // indexed by local head. Copying s_aux_ptr[0 .. num_heads) and indexing by global head
    // overran this 64-entry array (and read garbage) whenever num_heads > 64.
    constexpr int MAX_SINK_HEADS = 64;
    __shared__ scalar_t smem_s_aux[MAX_SINK_HEADS];
    if (s_aux_ptr != nullptr) {
        for (int h = thread_idx; h < num_queries_per_kv && h < MAX_SINK_HEADS; h += NUM_THREADS) {
            smem_s_aux[h] = s_aux_ptr[head_idx + h];
        }
        __syncthreads();
    }
    const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
    const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
    // Per-lane ALiBi slopes (pre-multiplied by log2(e)) for the heads this lane accumulates.
    float alibi_slope[reuse_group]={0.f};
    if (has_abili){
        for(int i=0;i<reuse_group;i++){
            int reuse_kv_idx=rows+i*4;
            if(reuse_kv_idx<num_queries_per_kv) alibi_slope[i]=alibi_slopes[head_idx+reuse_kv_idx]*1.4426950408889634;
        }
    }
    float qk_max[reuse_group];
    for(int i=0;i<reuse_group;i++){
        qk_max[i]=-FLT_MAX;
    }
    half4x2 q_vec[Mloop][HEAD_SIZE/32];
    half4x2 q_zero;
    q_zero.data[0]={0,0,0,0};
    q_zero.data[1]={0,0,0,0};
    // Stage Q for all heads of this KV head in shared memory (aliases the logits buffer),
    // 8 elements per thread per step, then pull each lane's operand fragments into registers.
    scalar_t* s_q = reinterpret_cast<scalar_t*>(shared_mem);
    for(int i=thread_idx*8;i<num_queries_per_kv*HEAD_SIZE;i+=NUM_THREADS*8){
        *reinterpret_cast<half4x2*>(s_q+i)=*reinterpret_cast<const half4x2*>(q_ptr+i);
    }
    __syncthreads();
    for(int m=0;m<Mloop;m++){
        for(int i=0;i<HEAD_SIZE/32;i++){
            int head_idx_=rowid+16*m;
            if(head_idx_<num_queries_per_kv)q_vec[m][i]=*reinterpret_cast<const half4x2*>(s_q+head_idx_*HEAD_SIZE+(i*4+rows)*8);
            else q_vec[m][i]=q_zero;  // pad missing heads with zeros
        }
    }
    // The Q staging area is about to be overwritten with logits.
    __syncthreads();
    const int start_block_idx = partition_idx * PARTITION_SIZE / BLOCK_SIZE;
    const int end_block_idx =MIN(start_block_idx + PARTITION_SIZE / BLOCK_SIZE, num_seq_blocks);
    const int num_blocks = end_block_idx - start_block_idx;
    const int start_token_idx = start_block_idx * BLOCK_SIZE;
    const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
    const int num_tokens = end_token_idx - start_token_idx;
    // ---- Phase 1: logits = scale * Q * K^T; wavefronts take KV blocks round-robin ----
    {
        const cache_t* k_ptr_base = k_cache+kv_head_idx * kv_head_stride;
        for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;block_idx += NUM_WARPS) {
            const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
#pragma unroll
            for(int b=0;b<BLOCK_SIZE;b+=16){
                const cache_t* k_ptr=k_ptr_base + physical_block_number * kv_block_stride + b*HEAD_SIZE;
                float4_t qk_vec[Mloop];
                for(int m=0;m<Mloop;m++){
                    qk_vec[m]={0,0,0,0};
                }
#pragma unroll
                for(int i=0;i<HEAD_SIZE/32;i++){
                    half4x2 k_vec;
                    if constexpr(is_fp8){
                        // Load 8 FP8 bytes and widen them to the compute type.
                        uint8x4x2 k_vec_u8=*reinterpret_cast<const uint8x4x2*>(k_ptr+i*32+rowid*HEAD_SIZE+rows*8);
                        scalar_t *p1=(scalar_t*)&k_vec;
                        uint8_t *p2=(uint8_t*)&k_vec_u8;
                        for(int ii=0;ii<8;ii++){
                            p1[ii]=uint82half<scalar_t,is_e4m3>(p2[ii]);
                        }
                    }
                    else{
                        k_vec=*reinterpret_cast<const half4x2*>(k_ptr+i*32+rowid*HEAD_SIZE+rows*8);
                    }
                    for(int m=0;m<Mloop;m++){
                        builtin_amdgcn_mmac<is_half>(k_vec.data[0],q_vec[m][i].data[0],qk_vec[m]);
                        builtin_amdgcn_mmac<is_half>(k_vec.data[1],q_vec[m][i].data[1],qk_vec[m]);
                    }
                }
#pragma unroll
                for(int i=0;i<reuse_group;i++){
                    // Accumulator element i%4 of tile i/4 belongs to local head rows + 4*i
                    // and to the token selected by rowid.
                    int reuse_kv_idx=rows+i*4;
                    int m = reuse_kv_idx/16;
                    int ii = i%4;
                    if(reuse_kv_idx<num_queries_per_kv){
                        qk_vec[m][ii]*=k_scale;
                        const int token_idx = block_idx * BLOCK_SIZE+rowid + b;
                        if (has_abili){
                            float alibi=alibi_slope[i] * (token_idx - seq_len + 1);
                            qk_vec[m][ii] += alibi;
                        }
                        if(token_idx >= seq_len) {
                            // Past the end of the sequence: -inf up to the next multiple of 8
                            // (the softmax phase reads logits in groups of 8), 0 beyond that.
                            int seq_len_pad=DIVIDE_ROUND_UP(seq_len,8)*8;
                            if(token_idx<seq_len_pad) from_float(logits[PARTITION_SIZE*reuse_kv_idx+token_idx - start_token_idx],-INFINITY);
                            else logits[PARTITION_SIZE*reuse_kv_idx+token_idx - start_token_idx]=0;
                        }
                        else{
                            scalar_t temp;
                            if (mtp>1){
                                // Multi-token mode: hide the trailing tokens this query must not see.
                                int casual = mtp - reuse_kv_idx * mtp / num_heads ;
                                if(token_idx+casual>seq_len)qk_vec[m][ii]=-INFINITY;
                            }
                            from_float(temp,qk_vec[m][ii]);
                            logits[PARTITION_SIZE*reuse_kv_idx+token_idx- start_token_idx]=temp;
                            // Track the maximum of the value as stored (after rounding).
                            qk_max[i] = fmaxf(qk_max[i], to_float(temp));
                        }
                    }
                }
            }
        }
    }
    // Per-head maximum over the 16 lanes that share `rows`.
#pragma unroll
    for (int mask = 8; mask >= 1; mask /= 2) {
#pragma unroll
        for(int r=0;r<reuse_group;r++){
            qk_max[r]=fmaxf(qk_max[r],__shfl_xor(qk_max[r],mask));
        }
    }
#pragma unroll
    for(int r=0;r<reuse_group;r++){
        if(rowid==0&&r*4+rows<num_queries_per_kv){
            s_max[(r*4+rows)*NUM_WARPS+warp_idx] = qk_max[r];
        }
    }
    __syncthreads();
    // ---- Phase 2: softmax over this partition, 8 logits per lane ----
    if(PARTITION_SIZE==256){
        // Each half wavefront (32 lanes x 8 logits = 256) handles one head.
        for(int lineid = warp_idx;lineid<REUSE_KV_TIMES/2;lineid+=NUM_WARPS){
            int half_lane = lane%32;
            int which_half = lane/32;
            int real_line=lineid*2+which_half;
            if(real_line<num_queries_per_kv){
                float qk_max_tmp;
                float exp_sum=0;
                if(half_lane==0){
                    // Combine the per-wavefront maxima of this head.
                    int smax_offset = real_line*4;
                    qk_max_tmp=s_max[smax_offset];
                    for(int i=1;i<4;i++){
                        qk_max_tmp=fmaxf(qk_max_tmp,s_max[smax_offset+i]);
                    }
                }
                qk_max_tmp=__shfl(qk_max_tmp,which_half*32);
                // Number of 8-token groups that contain valid tokens.
                int seq_len_pad = DIVIDE_ROUND_UP(num_tokens,8);
                using f16x8_t = __attribute__( (__vector_size__(8 * sizeof(scalar_t)) )) scalar_t;
                using f32x8_t = __attribute__( (__vector_size__(8 * sizeof(float)) )) float;
                // The sink adds exp(s_aux - max) to the denominator, only in partition 0.
                float sink_contrib = 0.f;
                if (s_aux_ptr != nullptr && partition_idx == 0) {
                    float s_aux_val = to_float(smem_s_aux[real_line]);  // staged per local head
                    sink_contrib = __builtin_amdgcn_exp2f(s_aux_val*1.4426950408889634 - qk_max_tmp);
                }
                f32x8_t logit32;
                if(half_lane<seq_len_pad){
                    f16x8_t logit16 = *reinterpret_cast<f16x8_t*>(logits+lineid/NUM_WARPS*NUM_WARPS*2*PARTITION_SIZE+thread_idx*8);
                    for(int ii=0;ii<8;ii++){
                        logit32[ii]=__builtin_amdgcn_exp2f(to_float(logit16[ii])-qk_max_tmp);
                        exp_sum+=logit32[ii];
                    }
                }
                // Sum across the 32 lanes of this half wavefront.
                for (int mask = 16; mask >= 1; mask /= 2) {
                    exp_sum += __shfl_xor(exp_sum, mask);
                }
                exp_sum += sink_contrib;
                const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
                if(half_lane<seq_len_pad){
                    f16x8_t logit16;
                    for(int ii=0;ii<8;ii++){
                        scalar_t t;
                        from_float(t,logit32[ii]*inv_sum);
                        logit16[ii]=t;
                    }
                    *reinterpret_cast<f16x8_t*>(logits+lineid/NUM_WARPS*NUM_WARPS*2*PARTITION_SIZE+thread_idx*8)=logit16;
                    if(num_partitions>1&&half_lane==0){
                        max_out[real_line] = qk_max_tmp;
                        expsum_out[real_line] = exp_sum;
                    }
                }
            }
        }
    }
    else if(PARTITION_SIZE==512){
        // Each full wavefront (64 lanes x 8 logits = 512) handles one head.
        for(int lineid = warp_idx;lineid<num_queries_per_kv;lineid+=NUM_WARPS){
            if(lineid<num_queries_per_kv){
                float qk_max_tmp;
                float exp_sum=0;
                if(lane==0){
                    int smax_offset = lineid*4;
                    qk_max_tmp=s_max[smax_offset];
                    for(int i=1;i<4;i++){
                        qk_max_tmp=fmaxf(qk_max_tmp,s_max[smax_offset+i]);
                    }
                }
                qk_max_tmp=__shfl(qk_max_tmp,0);
                int seq_len_pad = DIVIDE_ROUND_UP(num_tokens,8);
                using f16x8_t = __attribute__( (__vector_size__(8 * sizeof(scalar_t)) )) scalar_t;
                using f32x8_t = __attribute__( (__vector_size__(8 * sizeof(float)) )) float;
                float sink_contrib = 0.f;
                if (s_aux_ptr != nullptr && partition_idx == 0) {
                    float s_aux_val = to_float(smem_s_aux[lineid]);  // staged per local head
                    sink_contrib = __builtin_amdgcn_exp2f(s_aux_val*1.4426950408889634 - qk_max_tmp);
                }
                f32x8_t logit32;
                if(lane<seq_len_pad){
                    f16x8_t logit16 = *reinterpret_cast<f16x8_t*>(logits+lineid/NUM_WARPS*NUM_WARPS*PARTITION_SIZE+thread_idx*8);
                    for(int ii=0;ii<8;ii++){
                        logit32[ii]=__builtin_amdgcn_exp2f(to_float(logit16[ii])-qk_max_tmp);
                        exp_sum+=logit32[ii];
                    }
                }
                for (int mask = 32; mask >= 1; mask /= 2) {
                    exp_sum += __shfl_xor(exp_sum, mask);
                }
                exp_sum += sink_contrib;
                const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
                if(lane<seq_len_pad){
                    f16x8_t logit16;
                    for(int ii=0;ii<8;ii++){
                        scalar_t t;
                        from_float(t,logit32[ii]*inv_sum);
                        logit16[ii]=t;
                    }
                    *reinterpret_cast<f16x8_t*>(logits+lineid/NUM_WARPS*NUM_WARPS*PARTITION_SIZE+thread_idx*8)=logit16;
                    if(num_partitions>1&&lane==0){
                        max_out[lineid] = qk_max_tmp;
                        expsum_out[lineid] = exp_sum;
                    }
                }
            }
        }
    }
    __syncthreads();
    // ---- Phase 3: out = P * V; each wavefront owns a slice of the head dimension ----
    constexpr int NUM_ROWS_PER_THREAD =DIVIDE_ROUND_UP(HEAD_SIZE, 16*NUM_WARPS);
    constexpr int GROUPS=reuse_group*4;
    // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
    float4_t accs[Mloop][NUM_ROWS_PER_THREAD];
    for(int m=0;m<Mloop;m++){
        for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
            accs[m][i] = {0.f,0.f,0.f,0.f};
        }
    }
    constexpr int vecsize=BLOCK_SIZE/16;
    using half4_vec = half4vec<vecsize>;
    using uint8x4_vec = uint8x4vec<vecsize>;
    for (int block_idx = start_block_idx; block_idx < end_block_idx; block_idx ++) {
        const int64_t physical_block_number =
            static_cast<int64_t>(block_table[block_idx]);
        // First token of this lane's quarter of the block.
        const int token_idx = block_idx * BLOCK_SIZE +rows*(BLOCK_SIZE/4);
        half4_vec logits_vec[Mloop];
        for(int m=0;m<Mloop;m++){
            for(int i=0;i<vecsize;i++){
                logits_vec[m].data[i]={0,0,0,0};
            }
        }
        for(int m=0;m<Mloop;m++){
            int real_row=rowid+m*16;
            if(real_row<num_queries_per_kv){
                for(int k=0;k<vecsize;k++){
                    logits_vec[m].data[k] = *reinterpret_cast<half4_t*>(logits + real_row * PARTITION_SIZE+token_idx - start_token_idx + k*4);
                }
            }
        }
        const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
                               kv_head_idx * kv_head_stride;
        if(partition_idx<num_partitions-1){
            // Not the last partition: every token of every block is valid.
#pragma unroll
            for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
                int offset=i*BLOCK_SIZE*HEAD_SIZE/NUM_ROWS_PER_THREAD+warp_idx*BLOCK_SIZE*HEAD_SIZE/NUM_ROWS_PER_THREAD/NUM_WARPS+rows*vecsize*4+rowid*BLOCK_SIZE;
                half4_vec v_vec;
                if constexpr(is_fp8){
                    uint8x4_vec vecu8 = *reinterpret_cast<const uint8x4_vec*>(v_ptr + offset);
                    scalar_t *p1=(scalar_t*)&v_vec;
                    uint8_t *p2=(uint8_t*)&vecu8;
                    for(int ii=0;ii<vecsize*4;ii++){
                        p1[ii]=uint82half<scalar_t,is_e4m3>(p2[ii]);
                    }
                }else{
                    v_vec=*reinterpret_cast<const half4_vec*>(v_ptr + offset);
                }
                for(int ii=0;ii<vecsize;ii++){
                    for(int m=0;m<Mloop;m++){
                        builtin_amdgcn_mmac<is_half>(v_vec.data[ii],logits_vec[m].data[ii],accs[m][i]);
                    }
                }
            }
        }
        else{
#pragma unroll
            for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
                int offset=i*BLOCK_SIZE*HEAD_SIZE/NUM_ROWS_PER_THREAD+warp_idx*BLOCK_SIZE*HEAD_SIZE/NUM_ROWS_PER_THREAD/NUM_WARPS+rows*vecsize*4+rowid*BLOCK_SIZE;
                half4_vec v_vec;
                if constexpr(is_fp8){
                    uint8x4_vec vecu8 = *reinterpret_cast<const uint8x4_vec*>(v_ptr + offset);
                    scalar_t *p1=(scalar_t*)&v_vec;
                    uint8_t *p2=(uint8_t*)&vecu8;
                    for(int ii=0;ii<vecsize*4;ii++){
                        p1[ii]=uint82half<scalar_t,is_e4m3>(p2[ii]);
                    }
                }else{
                    v_vec=*reinterpret_cast<const half4_vec*>(v_ptr + offset);
                }
                // This check costs some performance, so it is only evaluated in the last
                // partition: zero V entries past seq_len in the final block so that stale
                // cache contents cannot contribute to the output.
                if (block_idx == num_seq_blocks - 1) {
                    scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
                    for (int j = 0; j < 4*vecsize; j++) {
                        v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : 0;
                    }
                }
                for(int ii=0;ii<vecsize;ii++){
                    for(int m=0;m<Mloop;m++){
                        builtin_amdgcn_mmac<is_half>(v_vec.data[ii],logits_vec[m].data[ii],accs[m][i]);
                    }
                }
            }
        }
    }
    // ---- Phase 4: write results ----
    scalar_t* out_ptr_base;
    int out_offset;
    if(num_partitions>1){
        // Partial result: goes to the scratch region, one HEAD_SIZE slot per partition.
        out_offset=max_num_partitions*HEAD_SIZE;
        out_ptr_base=out_tmp+out_tmp_offset + seq_idx * num_heads * out_offset + head_idx*out_offset+partition_idx * HEAD_SIZE;
    }
    else{
        out_offset=HEAD_SIZE;
        out_ptr_base=out + seq_idx * num_heads * HEAD_SIZE + head_idx*HEAD_SIZE;
    }
    for(int g=0;g<reuse_group;g++){
        int reusekvid=g*4+rows;
        if(reusekvid<num_queries_per_kv){
            scalar_t* out_ptr = out_ptr_base + reusekvid*out_offset;
            for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
                const int row_idx = rowid+16*warp_idx + i * WARP_SIZE;
                from_float(*(out_ptr + row_idx), accs[reusekvid/16][i][g%4]*v_scale);
            }
        }
    }
    // Publish per-partition softmax statistics for the combine kernel.
    if (num_partitions>1&&thread_idx < num_queries_per_kv){
        int offset = seq_idx * num_heads * max_num_partitions + (head_idx+thread_idx) * max_num_partitions + partition_idx;
        float * exp_sums=reinterpret_cast<float*>(out_tmp);
        float * max_logits=reinterpret_cast<float*>(out_tmp+max_tmp_offset);
        *(exp_sums+offset)=expsum_out[thread_idx];
        *(max_logits+offset)=max_out[thread_idx];
    }
}
// Folds the per-partition partial results written by the attention kernel into
// the final output row of one (head, sequence) pair.
//
// Indexing used by the code: blockIdx.x = head, blockIdx.y = sequence,
// threadIdx.x = lane. The shuffle reductions span WARP_SIZE lanes, and the
// dynamic shared memory must hold 2 * num_partitions floats.
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS>
__global__ __launch_bounds__(NUM_THREADS, 1) void paged_attention_combine(
    scalar_t* __restrict__ out,        // [num_seqs, num_heads, head_size]
    scalar_t* out_tmp,                 // scratch: exp sums, max logits, partial outputs
    const int* __restrict__ seq_lens,  // [num_seqs]
    const int max_num_partitions,
    int num_heads,
    int PARTITION_SIZE) {
    extern __shared__ char shared_mem[];

    const int seq_idx = blockIdx.y;
    const int head_idx = blockIdx.x;
    const int lane = threadIdx.x;
    const int seq_len = __builtin_amdgcn_readfirstlane(seq_lens[seq_idx]);
    const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);

    // A sequence that fits in one partition was written straight to `out` by
    // the attention kernel, so there is nothing to merge.
    if (num_partitions == 1) {
        return;
    }

    // Shared memory: [rescaled exp sums | per-partition max logits].
    float* shared_exp_sums = reinterpret_cast<float*>(shared_mem);
    float* shared_max_logits = shared_exp_sums + num_partitions;

    // Base index of this (sequence, head) row inside the statistics arrays.
    const int stat_base = seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions;
    const float* exp_sums_ptr = reinterpret_cast<float*>(out_tmp) + stat_base;
    const float* max_logits_ptr = reinterpret_cast<float*>(out_tmp + max_tmp_offset) + stat_base;
    const scalar_t* tmp_out_ptr = out_tmp + out_tmp_offset + stat_base * HEAD_SIZE;
    const scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;

    // Pass 1: global maximum logit across partitions.
    float max_logit = -FLT_MAX;
    for (int p = lane; p < num_partitions; p += WARP_SIZE) {
        const float partition_max = max_logits_ptr[p];
        shared_max_logits[p] = partition_max;
        max_logit = fmaxf(max_logit, partition_max);
    }
    #pragma unroll
    for (int stride = WARP_SIZE / 2; stride >= 1; stride /= 2) {
        max_logit = fmaxf(max_logit, __shfl_xor(max_logit, stride));
    }

    // Pass 2: rescale every partition's exp sum to the global maximum
    // (the logits are kept in the base-2 domain, hence exp2) and total them.
    float global_exp_sum = 0.0f;
    for (int p = lane; p < num_partitions; p += WARP_SIZE) {
        const float rescaled = exp_sums_ptr[p] * __builtin_amdgcn_exp2f(shared_max_logits[p] - max_logit);
        global_exp_sum += rescaled;
        shared_exp_sums[p] = rescaled;
    }
    #pragma unroll
    for (int stride = WARP_SIZE / 2; stride >= 1; stride /= 2) {
        global_exp_sum += __shfl_xor(global_exp_sum, stride);
    }
    const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);

    // Each lane owns `vec_size` consecutive output elements. HEAD_SIZE == 192
    // would give 3 per lane, so it is widened to 4 and fewer lanes take part.
    constexpr int vec_size_o = HEAD_SIZE / 64;
    constexpr int vec_size = (vec_size_o == 3) ? 4 : vec_size_o;
    using half_vec = __attribute__( (__vector_size__(vec_size * sizeof(scalar_t)) )) scalar_t;
    using float_vec = __attribute__( (__vector_size__(vec_size * sizeof(float)) )) float;

    if (lane >= HEAD_SIZE / vec_size) {
        return;
    }

    // Pass 3: weighted sum of the partial outputs.
    float_vec acc = {0.0f};
    for (int p = 0; p < num_partitions; ++p) {
        const half_vec partial = *(half_vec*)(tmp_out_ptr + p * HEAD_SIZE + lane * vec_size);
        const float weight = shared_exp_sums[p] * inv_global_exp_sum;
        #pragma unroll
        for (int e = 0; e < vec_size; e++) {
            acc[e] += to_float(partial[e]) * weight;
        }
    }

    half_vec acc_half;
    #pragma unroll
    for (int e = 0; e < vec_size; e++) {
        scalar_t narrowed;
        from_float(narrowed, acc[e]);
        acc_half[e] = narrowed;
    }
    *(half_vec*)(out_ptr + lane * vec_size) = acc_half;
}
// Maps the query-head / kv-head ratio onto one of the REUSE_KV_TIMES buckets
// the kernels are instantiated for (4, 8, 16, 24, 32, 36 or 48).
static int get_reusekv(int qhead,int kv_head){
    // {ratio threshold, bucket}, scanned from the largest threshold down. A
    // bucket is selected when qhead is strictly greater than
    // kv_head * threshold. The 36 bucket exists for glm4.7 with mtp 3.
    static const int kBuckets[][2] = {
        {36, 48}, {32, 36}, {24, 32}, {16, 24}, {8, 16}, {4, 8},
    };
    for (const auto& bucket : kBuckets) {
        if (qhead > kv_head * bucket[0]) {
            return bucket[1];
        }
    }
    return 4;
}
// Forward declaration of the gfx938-specific host launcher. paged_attention()
// below forwards to it when the device name is "gfx938" and the key cache is
// stored as fp8 (e5m2 or e4m3fn). Its definition is not in this part of the file.
void paged_attention_938(
torch::Tensor& out, // [num_seqs,seqlen, num_heads, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, block_size, head_size]
torch::Tensor& value_cache,// [num_blocks, num_heads, head_size, block_size]
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs]
const c10::optional<torch::Tensor>& alibi_slopes,
const c10::optional<torch::Tensor>& q_scale,
const c10::optional<torch::Tensor>& k_scale,
const c10::optional<torch::Tensor>& v_scale,
int max_seq_len,
const c10::optional<at::Tensor> &s_aux_, // ★ Attention Sinks ★ (optional)
float *tmp_out_ptr, // scratch buffer owned by the caller
int PARTITION_SIZE); // partition size chosen by the caller
// Host entry point for decode-time paged attention.
//
// Picks a partition size, lazily allocates a process-wide scratch buffer and
// launches paged_attention_kernel, followed by paged_attention_combine when
// more than one partition is used. On gfx938 with an fp8 key cache the call is
// forwarded to paged_attention_938 instead.
//
// Parameters (as read by the code below):
//   out, query, key_cache, value_cache, block_tables, seq_lens - see the
//       shape comments on each parameter; query is indexed with four
//       dimensions (size(0)=num_seqs, size(1)=mtp, size(2)=heads,
//       size(3)=head size).
//   scale, kv_cache_dtype - accepted for interface compatibility; they are not
//       read on the path implemented here.
//   alibi_slopes, k_scale, v_scale - optional tensors, forwarded as float
//       pointers (nullptr when absent).
//   q_scale - only forwarded to paged_attention_938.
//   max_seq_len - upper bound on sequence length used to size the grid; values
//       <= 10, or >= 8192 and equal to the block-table capacity, are replaced
//       by a heuristic estimate.
//   s_aux_ - optional attention-sink logits, one per head, same dtype as query.
//
// Raises a torch error when validation fails, when the scratch buffer cannot
// be allocated or is too small for the requested problem, or when a kernel
// launch is rejected by the runtime.
//
// NOTE(review): the scratch buffer is a single function-local static, so it
// lives on whichever device handled the first call. Multi-device use within
// one process should be confirmed.
extern "C"
void paged_attention(
    torch::Tensor& out, // [num_seqs,seqlen, num_heads, head_size]
    torch::Tensor& query, // [num_seqs, num_heads, head_size]
    torch::Tensor& key_cache, // [num_blocks, num_heads, block_size, head_size]
    torch::Tensor& value_cache,// [num_blocks, num_heads, head_size, block_size]
    double scale,
    torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
    torch::Tensor& seq_lens, // [num_seqs]
    const c10::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, //auto,int8,fp8/fp8_e4m3
    const c10::optional<torch::Tensor>& q_scale,
    const c10::optional<torch::Tensor>& k_scale,
    const c10::optional<torch::Tensor>& v_scale,
    int max_seq_len,
    const c10::optional<at::Tensor> &s_aux_) // ★ Attention Sinks ★
{
    // Make the query's device current before the scratch allocation and the
    // launches, so the first-call allocation lands on the right device.
    const at::cuda::OptionalCUDAGuard device_guard(device_of(query));

    int num_seqs = query.size(0);
    int mtp = query.size(1);
    int headsize = query.size(3);
    int num_heads = query.size(2)*mtp;
    int num_blocks = key_cache.size(0);
    int num_kv_heads = key_cache.size(1);
    int block_size = key_cache.size(2);
    int max_num_blocks_per_seq = block_tables.size(1);

    // Empty batch: nothing to compute (and the heuristics below divide by num_seqs).
    if (num_seqs == 0) {
        return;
    }
    // These values are used as divisors below.
    TORCH_CHECK(headsize > 0 && num_heads > 0 && num_kv_heads > 0,
                "Page Attention requires positive head size, query head count and kv head count, got head size ",
                headsize, ", qheads*mtp ", num_heads, ", kvheads ", num_kv_heads);

    int PARTITION_SIZE=512;
    int reusekv=get_reusekv(num_heads,num_kv_heads);

    // When the caller passes no usable max_seq_len, estimate one from the
    // cache size and from what the fixed-size scratch buffer can hold.
    const int64_t table_capacity = static_cast<int64_t>(max_num_blocks_per_seq)*block_size;
    if (max_seq_len<=10||(max_seq_len>=8192&&max_seq_len==table_capacity)){
        int64_t meanseq = static_cast<int64_t>(num_blocks)*block_size/num_seqs+8192;
        int64_t maxseq = 100000000/num_seqs/headsize/num_heads*64;
        if(reusekv<=8) maxseq*=2;
        max_seq_len=static_cast<int>(MIN(table_capacity,MIN(meanseq,maxseq)));
    }

    int real_reuse_times = num_heads/num_kv_heads;
    int max_num_partitions=DIVIDE_ROUND_UP(max_seq_len,PARTITION_SIZE);
    if(max_num_partitions*num_seqs*num_kv_heads<=160||reusekv>15)PARTITION_SIZE=256;
    if(num_seqs*num_kv_heads<=32&&max_seq_len<=32768)PARTITION_SIZE=256;
    // if(max_num_partitions*num_seqs*num_kv_heads>200&&real_reuse_times<6&&max_seq_len>30000)PARTITION_SIZE=1024;
    if(PA_PARTITION_SIZE!=0)PARTITION_SIZE=PA_PARTITION_SIZE;
    max_num_partitions=DIVIDE_ROUND_UP(max_seq_len,PARTITION_SIZE);

    // Process-wide scratch buffer, allocated and zeroed on first use.
    static float* tmp_out_ptr = nullptr;
    constexpr int temp_out_size = 110000000;
    if(tmp_out_ptr == nullptr){
        float* fresh = nullptr;
        hipError_t alloc_status = hipMalloc(&fresh, temp_out_size); // 100m
        TORCH_CHECK(alloc_status == hipSuccess,
                    "Page Attention failed to allocate its scratch buffer: ", hipGetErrorString(alloc_status));
        hipError_t clear_status = hipMemset(fresh,0,temp_out_size);
        if (clear_status != hipSuccess) {
            (void)hipFree(fresh);
            TORCH_CHECK(false, "Page Attention failed to clear its scratch buffer: ", hipGetErrorString(clear_status));
        }
        // Publish only a fully initialised buffer.
        tmp_out_ptr = fresh;
    }

    if(device_name=="gfx938"&&(key_cache.dtype()==torch::kFloat8_e5m2||key_cache.dtype()==torch::kFloat8_e4m3fn)){
        paged_attention_938(out,query,key_cache,value_cache,block_tables,seq_lens,alibi_slopes,q_scale,k_scale,v_scale,max_seq_len,s_aux_,tmp_out_ptr,PARTITION_SIZE);
        return;
    }

    AT_ASSERTM(headsize%64==0 && headsize<=256, "Page Attention head size must be 64, 128, 192 or 256");
    AT_ASSERTM(num_heads<=num_kv_heads*48, "Page Attention qheads*mtp/kvheads must be smaller than 48");
    // The launch below only instantiates BLOCK_SIZE 64 and 128; any other
    // value would silently be treated as 128.
    TORCH_CHECK(block_size == 64 || block_size == 128,
                "Page Attention block size must be 64 or 128, got ", block_size);

    int q_stride = query.stride(0);
    int kv_block_stride = key_cache.stride(0);
    const float* alibi_slopes_ptr =alibi_slopes
        ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr()):nullptr;
    int* block_tables_ptr = block_tables.data_ptr<int>();
    int* seq_lens_ptr = seq_lens.data_ptr<int>();
    auto* out_ptr = out.data_ptr();
    const float* k_scale_ptr = k_scale? reinterpret_cast<const float*>(k_scale.value().data_ptr()):nullptr;
    const float* v_scale_ptr = v_scale? reinterpret_cast<const float*>(v_scale.value().data_ptr()):nullptr;

    // Attention Sinks: validate and set s_aux_ptr
    const void* s_aux_ptr = nullptr;
    if (s_aux_.has_value()) {
        auto s_aux = s_aux_.value();
        // ★ s_aux must match Q/K/V dtype (Element type) for mixed precision
        TORCH_CHECK(s_aux.dtype() == query.dtype(),
                    "s_aux must have the same dtype as query. Got s_aux dtype: ", s_aux.dtype(),
                    ", query dtype: ", query.dtype());
        TORCH_CHECK(s_aux.dtype() == torch::kFloat16 || s_aux.dtype() == torch::kBFloat16,
                    "s_aux must have dtype float16 or bfloat16 (to match query). Got: ", s_aux.dtype());
        TORCH_CHECK(num_heads <= 64,
                    "Attention Sinks only supports up to 64 heads (shared memory limit), got ", num_heads);
        CHECK_DEVICE(s_aux);
        CHECK_SHAPE(s_aux, num_heads);
        CHECK_CONTIGUOUS(s_aux);
        s_aux_ptr = s_aux.data_ptr();
    }

    auto* query_ptr = query.data_ptr();
    auto* key_cache_ptr = key_cache.data_ptr();
    auto* value_cache_ptr = value_cache.data_ptr();
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    dim3 reduce_grid(num_heads, num_seqs);
    dim3 grid;
    grid.x = num_kv_heads;
    grid.y = num_seqs;

    HEADSIZE_SWITCH(headsize,[&]{
    Input_Type_SWITCH(query.dtype(),[&]{
    Cache_Type_SWITCH(scalar_t,key_cache.dtype(),[&] {
    REUSEKV_SWITCH(reusekv,[&] {
    BOOL_SWITCH(block_size==64,is_block64,[&]{
        constexpr int BLOCK_SIZE = (is_block64?64:128);
        // constexpr int BLOCK_SIZE=128;
        // constexpr int HEAD_SIZE=128;
        // using scalar_t=_Float16;
        // using cache_t = scalar_t;
        constexpr bool is_e4m3=false;
        // constexpr static int REUSE_KV_TIMES = 4;
        constexpr static int NUM_THREADS = 256;
        constexpr static int NUM_WARPS = NUM_THREADS / WARP_SIZE;

        // With several partitions the kernels use the scratch buffer as
        // [exp sums | max logits | partial outputs]. The latter two start at
        // max_tmp_offset and out_tmp_offset scalar_t elements. Refuse problems
        // whose regions would overlap or run past the allocation.
        if (max_num_partitions > 1) {
            const int64_t elem_bytes = static_cast<int64_t>(sizeof(scalar_t));
            const int64_t stat_count = static_cast<int64_t>(num_seqs)*num_heads*max_num_partitions;
            const int64_t stat_bytes = stat_count*static_cast<int64_t>(sizeof(float));
            const int64_t max_logits_begin = static_cast<int64_t>(max_tmp_offset)*elem_bytes;
            const int64_t partial_begin = static_cast<int64_t>(out_tmp_offset)*elem_bytes;
            const int64_t partial_bytes = stat_count*HEAD_SIZE*elem_bytes;
            TORCH_CHECK(stat_bytes <= max_logits_begin
                            && max_logits_begin + stat_bytes <= partial_begin
                            && partial_begin + partial_bytes <= static_cast<int64_t>(temp_out_size),
                        "Page Attention scratch buffer is too small for batch ", num_seqs,
                        ", qheads*mtp ", num_heads, ", partitions ", max_num_partitions,
                        ", head size ", HEAD_SIZE);
        }

        int other_use = (real_reuse_times*NUM_WARPS+NUM_WARPS+ real_reuse_times*2)*sizeof(float);
        int shared_mem_size=PARTITION_SIZE*2*real_reuse_times+other_use;
        grid.z = max_num_partitions;
        dim3 block(NUM_THREADS);
        if(PA_PRINT_PARAM)printf("is_fp8=%d,shared_mem_size=%d,HEAD_SIZE=%d,BLOCK_SIZE=%d,num_thread=%d,grid={%d,%d,%d},qhead=%d,kvhead=%d,seq=%d,batch=%d,PARTITION_SIZE=%d,max_num_partitions=%d\n",
            (int)(sizeof(cache_t)==1),shared_mem_size,HEAD_SIZE,BLOCK_SIZE,NUM_THREADS,grid.x,grid.y,grid.z,num_heads,num_kv_heads,max_seq_len,num_seqs,PARTITION_SIZE,max_num_partitions);

        paged_attention_kernel<scalar_t,cache_t,is_e4m3,HEAD_SIZE,BLOCK_SIZE,NUM_THREADS,REUSE_KV_TIMES><<<grid,block,shared_mem_size,stream>>>(
            (scalar_t*)out_ptr,(scalar_t*)tmp_out_ptr, (scalar_t*)query_ptr,(cache_t*) key_cache_ptr, (cache_t*)value_cache_ptr,
            num_heads, num_kv_heads, block_tables_ptr, seq_lens_ptr,max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride,
            k_scale_ptr, v_scale_ptr,max_num_partitions,PARTITION_SIZE,(const scalar_t*)s_aux_ptr,mtp,alibi_slopes_ptr!=nullptr);
        // A rejected launch (bad grid, too much shared memory, ...) is only
        // visible through the runtime's last-error state.
        hipError_t launch_status = hipGetLastError();
        TORCH_CHECK(launch_status == hipSuccess,
                    "paged_attention_kernel launch failed: ", hipGetErrorString(launch_status));

        if(max_num_partitions>1){
            paged_attention_combine<scalar_t,HEAD_SIZE,64><<<reduce_grid,64,4*2*max_num_partitions,stream>>>(
                (scalar_t*)out_ptr,(scalar_t*)tmp_out_ptr,seq_lens_ptr,max_num_partitions,num_heads,PARTITION_SIZE);
            launch_status = hipGetLastError();
            TORCH_CHECK(launch_status == hipSuccess,
                        "paged_attention_combine launch failed: ", hipGetErrorString(launch_status));
        }
    });
    });
    });
    });
    });
}
#include <torch/python.h>
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
// Number of lanes the warp-level shuffle reductions in this file iterate over.
#define WARP_SIZE 64
// Function-like min/max. Both are macros, so each argument may be evaluated
// twice; do not pass expressions with side effects.
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
// Input validation macros (consistent with flash_api.cpp and flash_api_sparse.cpp)
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
// Compiler vector types (vector_size attribute) used for wide loads/stores and
// as operands of the matrix builtins.
using uint8x4_t = __attribute__( (__vector_size__(4 * sizeof(uint8_t)) )) uint8_t;
using half4_t = __attribute__( (__vector_size__(4 * sizeof(_Float16)) )) _Float16;
using half8_t = __attribute__( (__vector_size__(8 * sizeof(_Float16)) )) _Float16;
using v4bh = __attribute__( (__vector_size__(4 * sizeof(short)) )) short;
using float4_t = __attribute__( (__vector_size__(4 * sizeof(float)) )) float;
using float2_t = __attribute__( (__vector_size__(2 * sizeof(float)) )) float;
using intx2 = __attribute__( (__vector_size__(2 * sizeof(int)) )) int;
using intx4 = __attribute__( (__vector_size__(4 * sizeof(int)) )) int;
// NOTE(review): presumably the LDS (shared memory) capacity in bytes; it is not
// referenced in the visible part of this file — confirm.
static constexpr int LDS_size = 65536;
// Layout of the scratch buffer passed to the kernels as `out_tmp`. The kernels
// add these offsets to a scalar_t pointer, so they count scalar_t elements:
// exp sums start at 0, max logits at max_tmp_offset and the per-partition
// partial outputs at out_tmp_offset.
static constexpr int max_tmp_offset=4000000;
// NOTE(review): signal_tmp_offset and streamk_max_block are only used here to
// derive out_tmp_offset; their names suggest a stream-K signalling area —
// confirm against the code that uses it.
static constexpr int signal_tmp_offset=8000000;
static constexpr int streamk_max_block=160*8;
static constexpr int out_tmp_offset=signal_tmp_offset+streamk_max_block*2;
// static constexpr int PARTITION_SIZE=512;
// Integer ceiling division.
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
// Stores the float `f` into `out`.
//  - scalar_t == _Float16 or float: plain conversion by assignment.
//  - any other scalar_t (used in this file as a 16-bit bfloat16 bit pattern):
//    keeps the upper 16 bits of the float after adding 0x8000, i.e. rounds
//    the discarded half upwards (ties are not rounded to even).
template<typename scalar_t>
static __device__ inline void from_float(scalar_t &out ,float f){
    constexpr bool assignable_directly =
        std::is_same<scalar_t, _Float16>::value || std::is_same<scalar_t, float>::value;
    if constexpr (!assignable_directly) {
        union {
            float fp32;
            uint32_t int32;
        } bits = {f};
        // bits.int32 += 0x7fff + ((bits.int32 >> 16) & 1);  // round-to-nearest-even alternative
        const uint32_t rounded = bits.int32 + 0x8000;
        out = rounded >> 16;
    } else {
        out = f;
    }
}
// Widens `in` to float.
//  - scalar_t == _Float16 or float: plain conversion.
//  - any other scalar_t: `in` is treated as the upper 16 bits of a float
//    (bfloat16 bit pattern) and the lower 16 bits are zero-filled.
template<typename scalar_t>
static __device__ inline float to_float(scalar_t in){
    constexpr bool convertible_directly =
        std::is_same<scalar_t, _Float16>::value || std::is_same<scalar_t, float>::value;
    if constexpr (!convertible_directly) {
        union {
            uint32_t int32;
            float fp32;
        } widened = {uint32_t(in) << 16};
        return widened.fp32;
    } else {
        return in;
    }
}
// Decodes one fp8 (e4m3) byte to float.
//
// On gfx938 the hardware conversion builtin is used. Elsewhere the value is
// decoded in software by renormalising the 7 magnitude bits into a float
// exponent/mantissa pair.
//
// The software path masks the result for a zero magnitude, so 0x00 / 0x80
// decode to +0 / -0. Without the mask the renormalisation yields 2^-35 for a
// zero input, because the exponent bias is still added.
//
// NOTE(review): the software path does not map the e4m3fn NaN encoding
// (0x7F / 0xFF) to a float NaN; it decodes as an ordinary finite value.
// Confirm whether callers can ever see that encoding.
inline __device__ float uint82float(const uint8_t& input) {
#if (defined(__gfx938__) )
    return __builtin_hcu_cvt_f32_fp8(input,false,0,0);
#else
    // Move the byte to the top of a 32-bit word: sign lands on bit 31.
    const uint32_t w = (uint32_t)input << 24;
    const uint32_t sign = w & UINT32_C(0x80000000);
    const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF);
    // Subnormal inputs have more than 4 leading zeros; shift them up so the
    // leading one sits where a normal number's exponent LSB would be.
    uint32_t renorm_shift = __clz(nonsign);
    renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;
    // All ones when the magnitude is zero, all zeros otherwise.
    const uint32_t zero_mask = (uint32_t)((int32_t)(nonsign - 1) >> 31);
    uint32_t result = sign |
        (((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) & ~zero_mask);
    return c10::detail::fp32_from_bits(result);
#endif
}
// Converts one fp8 byte to the 16-bit compute type `scalar_t`.
//  - is_e4m3 == false (e5m2): the byte is placed in the upper half of a 16-bit
//    word and reinterpreted as _Float16.
//  - is_e4m3 == true: the byte is decoded through uint82float().
// When scalar_t is _Float16 the value is returned as such; otherwise the upper
// 16 bits of the float representation are returned (bfloat16 bit pattern,
// truncated).
template<typename scalar_t,bool is_e4m3>
__forceinline__ __device__ scalar_t uint82half(const uint8_t& input) {
    constexpr bool wants_fp16 = std::is_same<scalar_t, _Float16>::value;
    if constexpr (is_e4m3) {
        union {
            uint32_t as_bits;
            float as_value;
        } wide;
        wide.as_value = uint82float(input);
        if constexpr (wants_fp16) {
            return (_Float16)(wide.as_value);
        } else {
            return (uint16_t)(wide.as_bits >> 16);
        }
    } else {
        union {
            uint16_t as_bits;
            _Float16 as_value;
        } narrow;
        narrow.as_bits = (uint16_t)input << 8;
        if constexpr (wants_fp16) {
            return narrow.as_value;
        } else {
            union {
                uint32_t as_bits;
                float as_value;
            } wide;
            wide.as_value = (float)narrow.as_value;
            return wide.as_bits >> 16;
        }
    }
}
// Packs four floats into one 32-bit word of fp8 values using the gfx938
// pack-convert builtins (fp8 variant when is_e4m3, bf8 variant otherwise).
// The builtin is issued twice on the same word: first for (v1, v2) with the
// last argument false, then for (v3, v4) with it true.
// When not compiled for gfx938 the function returns 0.
template <bool is_e4m3>
static __device__ int to_f8_from_f32(float v1,float v2,float v3,float v4) {
    int packed = 0;
#if (defined(__gfx938__) )
    if constexpr (!is_e4m3) {
        packed = __builtin_hcu_cvt_pk_bf8_f32(v1,v2,packed,false);
        packed = __builtin_hcu_cvt_pk_bf8_f32(v3,v4,packed,true);
    } else {
        packed = __builtin_hcu_cvt_pk_fp8_f32(v1,v2,packed,false);
        packed = __builtin_hcu_cvt_pk_fp8_f32(v3,v4,packed,true);
    }
#endif
    return packed;
}
// Unpacks the four fp8 values stored in `val` into a float4_t using the gfx938
// conversion builtins (fp8 variant when is_e4m3, bf8 variant otherwise). The
// last builtin argument (0..3) selects which of the four values is converted.
//
// When not compiled for gfx938 the result is all zeros, matching
// to_f8_from_f32() which returns 0 in that case. Previously the vector was
// returned uninitialised there.
template <bool is_e4m3>
static __device__ float4_t to_fp32_from_fp8(int val) {
    float4_t ret = {0.f, 0.f, 0.f, 0.f};
#if (defined(__gfx938__) )
    if constexpr(is_e4m3){
        ret[0] = __builtin_hcu_cvt_f32_fp8(val,false,0,0);
        ret[1] = __builtin_hcu_cvt_f32_fp8(val,false,0,1);
        ret[2] = __builtin_hcu_cvt_f32_fp8(val,false,0,2);
        ret[3] = __builtin_hcu_cvt_f32_fp8(val,false,0,3);
    }
    else{
        ret[0] = __builtin_hcu_cvt_f32_bf8(val,false,0,0);
        ret[1] = __builtin_hcu_cvt_f32_bf8(val,false,0,1);
        ret[2] = __builtin_hcu_cvt_f32_bf8(val,false,0,2);
        ret[3] = __builtin_hcu_cvt_f32_bf8(val,false,0,3);
    }
#endif
    return ret;
}
// Turns the runtime condition COND into a compile-time constant: the callable
// passed as the last argument is invoked inside a scope where CONST_NAME is a
// `constexpr static bool` (true or false), and its result is returned.
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr static bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
// Invokes the callable with the type alias `scalar_t` in scope: _Float16 when
// SRC_DTYPE is at::ScalarType::Half, uint16_t for every other dtype (the
// 16-bit pattern the helpers in this file treat as bfloat16). No other dtype
// is rejected here.
#define Output_Type_SWITCH(SRC_DTYPE, ...) \
[&] { \
if (SRC_DTYPE == at::ScalarType::Half) { \
using scalar_t=_Float16; \
return __VA_ARGS__(); \
}else { \
using scalar_t=uint16_t; \
return __VA_ARGS__(); \
} \
}()
// Invokes the callable with `is_e4m3` (constexpr bool) and `q_type` (type
// alias) in scope:
//  - query dtype fp8 e5m2   -> is_e4m3 = false, q_type = uint8_t
//  - query dtype fp8 e4m3fn -> is_e4m3 = true,  q_type = uint8_t
//  - otherwise q_type = scalar_t, and is_e4m3 is false only when the key dtype
//    is fp8 e5m2; every other key dtype falls into the is_e4m3 = true branch.
#define Input_Type_SWITCH(scalar_t,qdtype,kdtype,...) \
[&] { \
if(qdtype==torch::kFloat8_e5m2){ \
constexpr bool is_e4m3=false; \
using q_type = uint8_t; \
return __VA_ARGS__(); \
}else if(qdtype==torch::kFloat8_e4m3fn){ \
constexpr bool is_e4m3=true; \
using q_type = uint8_t; \
return __VA_ARGS__(); \
}else if(kdtype==torch::kFloat8_e5m2){ \
constexpr bool is_e4m3=false; \
using q_type = scalar_t; \
return __VA_ARGS__(); \
}else{ \
constexpr bool is_e4m3=true; \
using q_type = scalar_t; \
return __VA_ARGS__(); \
} \
}()
// Invokes the callable with the compile-time constant REUSE_KV_TIMES set to
// the runtime value `reusekv` when it is one of 48, 36, 32, 24, 16 or 8; any
// other value selects REUSE_KV_TIMES = 4.
#define REUSEKV_SWITCH(reusekv,...) \
[&] { \
if (reusekv==48){ \
constexpr static int REUSE_KV_TIMES = 48; \
return __VA_ARGS__(); \
}else if (reusekv==36){ \
constexpr static int REUSE_KV_TIMES = 36; \
return __VA_ARGS__(); \
}else if (reusekv==32){ \
constexpr static int REUSE_KV_TIMES = 32; \
return __VA_ARGS__(); \
}else if (reusekv==24){ \
constexpr static int REUSE_KV_TIMES = 24; \
return __VA_ARGS__(); \
}else if (reusekv==16){ \
constexpr static int REUSE_KV_TIMES = 16; \
return __VA_ARGS__(); \
}else if (reusekv==8){ \
constexpr static int REUSE_KV_TIMES = 8; \
return __VA_ARGS__(); \
}else { \
constexpr static int REUSE_KV_TIMES = 4; \
return __VA_ARGS__(); \
} \
}()
// Invokes the callable with the compile-time constant HEAD_SIZE set to 64, 128
// or 192 when `headsize` matches; every other value selects HEAD_SIZE = 256.
// The macro itself does not validate `headsize`, so callers must reject
// unsupported sizes beforehand.
#define HEADSIZE_SWITCH(headsize,...) \
[&] { \
if (headsize==64){ \
constexpr static int HEAD_SIZE = 64; \
return __VA_ARGS__(); \
}else if(headsize==128){ \
constexpr static int HEAD_SIZE = 128; \
return __VA_ARGS__(); \
}else if(headsize==192){ \
constexpr static int HEAD_SIZE = 192; \
return __VA_ARGS__(); \
}else { \
constexpr static int HEAD_SIZE = 256; \
return __VA_ARGS__(); \
} \
}()
// Returns the architecture name of the current HIP device (gcnArchName with
// everything from the first ':' onwards removed), or an empty string when the
// device or its properties cannot be queried.
static std::string get_device_name()
{
    int device_id;
    if (hipGetDevice(&device_id) != hipSuccess) {
        return std::string();
    }

    hipDeviceProp_t props{};
    if (hipGetDeviceProperties(&props, device_id) != hipSuccess) {
        return std::string();
    }

    const std::string arch_name(props.gcnArchName);
    // find() yields npos when there is no ':'; substr(0, npos) is the whole string.
    const std::string::size_type feature_sep = arch_name.find(':');
    return arch_name.substr(0, feature_sep);
}
// Architecture name of the device that is current when this translation unit
// is initialised.
static const std::string device_name=get_device_name();
// Reads the environment variable `env_var` and parses it with atoi.
// Returns 0 when the variable is not set.
static inline int get_env_(const char *env_var) {
    const char *raw = std::getenv(env_var);
    return raw != nullptr ? atoi(raw) : 0;
}
// Tuning/debug switches, read once from the environment at load time
// (0 when unset).
static const int PA_USE_STREAMK = get_env_("PA_USE_STREAMK");
static const int PA_MAX_BLOCKS = get_env_("PA_MAX_BLOCKS");
static const int PA_PRINT_PARAM = get_env_("PA_PRINT_PARAM");
static const int PA_PARTITION_SIZE = get_env_("PA_PARTITION_SIZE");
// Plain aggregates bundling COUNT vector registers so that they can be moved
// with a single wide load/store through reinterpret_cast.

// COUNT groups of four _Float16 values.
template<int COUNT>
struct half4vec{
    half4_t data[COUNT];
};
using half4x2 = half4vec<2>;
using half4x4 = half4vec<4>;

// COUNT pairs of 32-bit integers.
template<int COUNT>
struct int2vec{
    intx2 data[COUNT];
};

// COUNT groups of four bytes.
template<int COUNT>
struct uint8x4vec{
    uint8x4_t data[COUNT];
};
using uint8x4x2 = uint8x4vec<2>;
using uint8x4x4 = uint8x4vec<4>;
// Sums `sum` over all threads of the block and returns the total.
// `red_smem` must provide at least NUM_WARPS floats of shared memory. The
// function contains a __syncthreads(), so every thread of the block has to
// call it.
template <int NUM_WARPS>
inline __device__ float block_sum(float* red_smem, float sum) {
    const int lane = threadIdx.x % WARP_SIZE;
    const int warp = __builtin_amdgcn_readfirstlane(threadIdx.x / WARP_SIZE);

    // Stage 1: butterfly reduction inside each warp.
    #pragma unroll
    for (int stride = WARP_SIZE / 2; stride >= 1; stride /= 2) {
        sum += __shfl_xor(sum, stride);
    }

    // Stage 2: one partial per warp goes through shared memory.
    if (lane == 0) {
        red_smem[warp] = sum;
    }
    __syncthreads();
    if (lane < NUM_WARPS) {
        sum = red_smem[lane];
    }

    // Stage 3: reduce the NUM_WARPS partials held by the first lanes.
    #pragma unroll
    for (int stride = NUM_WARPS / 2; stride >= 1; stride /= 2) {
        sum += __shfl_xor(sum, stride);
    }

    // Lane 0 holds the block total; hand it to every lane.
    return __shfl(sum, 0);
}
// Thin wrapper over the gfx938 16x16x32 8-bit matrix multiply-accumulate
// builtins: reg_c = mmac(reg_a, reg_b, reg_c). The fp8 variant is used when
// is_e4m3 is true, the bf8 variant otherwise.
// When not compiled for gfx938 the function leaves reg_c untouched.
template<bool is_e4m3>
inline __device__ void builtin_amdgcn_mmac(const intx2& reg_a, const intx2& reg_b, float4_t& reg_c)
{
#if (defined(__gfx938__) )
    if constexpr (!is_e4m3) {
        reg_c = __builtin_hcu_mmac_f32_16x16x32_bf8_bf8_lit_lts(reg_a, reg_b, reg_c, false, false);
    } else {
        reg_c = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(reg_a, reg_b, reg_c, false, false);
    }
#endif
}
// Split-KV paged attention for one (kv head, sequence, partition) using the
// gfx938 8-bit matrix instructions. The body is compiled only for gfx938; for
// any other target the kernel is empty.
//
// Indexing used by the code:
//   blockIdx.x = kv head, blockIdx.y = sequence, blockIdx.z = partition.
//   One block processes all `num_heads / num_kv_heads` query heads that share
//   the kv head. Blocks whose partition lies past the sequence end exit
//   immediately.
//
// Dynamic shared memory layout (in this order):
//   logits     scalar_t [num_queries_per_kv][PARTITION_SIZE]
//   s_max      float    [num_queries_per_kv][NUM_WARPS]
//   s_logit    float    [NUM_WARPS]
//   max_out    float    [num_queries_per_kv]
//   expsum_out float    [num_queries_per_kv]
// The start of the buffer is also used to stage the query tile before the
// logits are written.
//
// Results: when the sequence needs a single partition the output row is written
// to `out`. Otherwise the partial output, exp sum and max logit of this
// partition go to `out_tmp` for paged_attention_combine to merge.
//
// All exponentials are taken in base 2: the softmax scale is pre-multiplied by
// log2(e) (1.4426950408889634).
//
// NOTE(review): the softmax stage is only implemented for PARTITION_SIZE 256
// and 512; for any other value the logits are left unnormalised.
// NOTE(review): the softmax addressing (thread_idx*8) assumes that
// NUM_THREADS*8 logits cover a whole number of rows — confirm before
// instantiating with a NUM_THREADS other than 256.
template <typename scalar_t,typename q_type,bool is_e4m3 ,int HEAD_SIZE, int BLOCK_SIZE,
          int NUM_THREADS, int REUSE_KV_TIMES>
__launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
    scalar_t* __restrict__ out,             // [num_seqs, num_heads,head_size]
    scalar_t* __restrict__ out_tmp,         // scratch: exp sums, max logits, [num_seqs, num_heads, max_num_partitions,head_size] partial outputs
    const q_type* __restrict__ q,           // [num_seqs, num_heads, head_size]
    const uint8_t* __restrict__ k_cache,    // fp8; per block and kv head read as BLOCK_SIZE rows of HEAD_SIZE bytes
    const uint8_t* __restrict__ v_cache,    // fp8; per block and kv head read as HEAD_SIZE rows of BLOCK_SIZE bytes
    const int num_heads,
    const int num_kv_heads,
    const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ seq_lens,       // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes, // [num_heads]
    const int q_stride,const int kv_block_stride,
    const float* q_scale_ptr, const float* k_scale_ptr, const float* v_scale_ptr,
    int max_num_partitions,int PARTITION_SIZE,
    const scalar_t* __restrict__ s_aux_ptr,int mtp,bool has_abili) { // ★ Attention Sinks: [num_heads] scalar_t ★
#if (defined(__gfx938__) )
    const int seq_idx = blockIdx.y;
    const int partition_idx = blockIdx.z;
    constexpr int kv_head_stride=BLOCK_SIZE*HEAD_SIZE;
    const int seq_len = __builtin_amdgcn_readfirstlane(seq_lens[seq_idx]);
    const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
    const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
    // The grid is sized for the longest sequence; shorter ones have no work here.
    if(num_partitions<=partition_idx)return ;

    constexpr bool is_half = std::is_same<scalar_t, _Float16>::value;
    constexpr bool q_is_fp8 = std::is_same<q_type, uint8_t>::value;
    // 1/sqrt(HEAD_SIZE) * log2(e)
    constexpr float scale = (HEAD_SIZE==64?0.125f:(HEAD_SIZE==128? 0.0883883476f:(HEAD_SIZE==192?0.0721687836f:0.0625f)))*1.4426950408889634;
    constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
    const int thread_idx = threadIdx.x;
    const int warp_idx = __builtin_amdgcn_readfirstlane(thread_idx / WARP_SIZE);
    const int lane = thread_idx % WARP_SIZE;
    // Position of the lane inside the 16x4 arrangement used by the matrix builtin.
    const int rowid = lane%16;
    const int rows = lane/16;

    // Fold the softmax scale and the optional dequantisation scales together.
    float k_scale=scale;
    float v_scale=1.0;
    float q_scale=1.0;
    if(k_scale_ptr!=nullptr){
        k_scale*=(*k_scale_ptr);
    }
    if(q_scale_ptr!=nullptr){
        q_scale=*q_scale_ptr;
    }
    if(v_scale_ptr!=nullptr){
        v_scale=*v_scale_ptr;
    }
    k_scale*=q_scale;

    const int num_queries_per_kv = num_heads / num_kv_heads;
    const int head_idx=blockIdx.x*num_queries_per_kv;
    const int kv_head_idx = blockIdx.x;
    // Each lane keeps 4 query heads per accumulator vector, 16 per matrix tile.
    constexpr int reuse_group=(REUSE_KV_TIMES-1)/4+1;
    constexpr int Mloop=(REUSE_KV_TIMES-1)/16+1;

    extern __shared__ char shared_mem[];
    scalar_t* logits = reinterpret_cast<scalar_t*>(shared_mem);
    float* s_max = reinterpret_cast<float*>(shared_mem + sizeof(scalar_t)*num_queries_per_kv*PARTITION_SIZE);
    float* s_logit = s_max + num_queries_per_kv * NUM_WARPS;
    float* max_out = s_logit+NUM_WARPS;
    float* expsum_out = max_out+num_queries_per_kv;

    // ★ Attention Sinks: load s_aux to shared memory ★
    __shared__ scalar_t smem_s_aux[64];
    if (s_aux_ptr != nullptr) {
        if (thread_idx < num_heads) {
            smem_s_aux[thread_idx] = s_aux_ptr[thread_idx];
        }
        __syncthreads();
    }

    const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
    const q_type* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;

    // ALiBi slopes for the query heads owned by this lane, moved to the base-2 domain.
    float alibi_slope[reuse_group]={0.f};
    if (has_abili){
        for(int i=0;i<reuse_group;i++){
            int reuse_kv_idx=rows+i*4;
            if(reuse_kv_idx<num_queries_per_kv) alibi_slope[i]=alibi_slopes[head_idx+reuse_kv_idx]*1.4426950408889634;
        }
    }
    float qk_max[reuse_group];
    for(int i=0;i<reuse_group;i++){
        qk_max[i]=-FLT_MAX;
    }

    // ---- Stage the query tile in shared memory, 8 elements per thread and step.
    intx4 q_vec[Mloop][HEAD_SIZE/64];
    q_type* s_q = reinterpret_cast<q_type*>(shared_mem);
    for(int i=thread_idx*8;i<num_queries_per_kv*HEAD_SIZE;i+=NUM_THREADS*8){
        if constexpr (q_is_fp8){
            *reinterpret_cast<intx2*>(s_q+i)=*reinterpret_cast<const intx2*>(q_ptr+i);
        }
        else{
            *reinterpret_cast<half4x2*>(s_q+i)=*reinterpret_cast<const half4x2*>(q_ptr+i);
        }
    }
    __syncthreads();

    // ---- Pull this lane's 16-element query fragments into registers as fp8.
    for(int m=0;m<Mloop;m++){
        for(int i=0;i<HEAD_SIZE/64;i++){
            int head_idx_=rowid+16*m;
            if(head_idx_<num_queries_per_kv) {
                if constexpr(q_is_fp8){
                    q_vec[m][i]=*reinterpret_cast<const intx4*>(s_q+head_idx_*HEAD_SIZE+(i*4+rows)*16);
                }
                else{
                    // 16/bf16 query: divide by q_scale and quantise 4 values per 32-bit word.
                    auto q_temp = *reinterpret_cast<const half4x4*>(s_q+head_idx_*HEAD_SIZE+(i*4+rows)*16);
                    scalar_t *q_temp_ptr=(scalar_t*)&q_temp;
                    #pragma unroll
                    for(int w=0;w<4;w++){
                        q_vec[m][i][w]=to_f8_from_f32<is_e4m3>(
                            to_float(q_temp_ptr[4*w+0])/q_scale,to_float(q_temp_ptr[4*w+1])/q_scale,
                            to_float(q_temp_ptr[4*w+2])/q_scale,to_float(q_temp_ptr[4*w+3])/q_scale);
                    }
                }
            }
            else q_vec[m][i]={0,0,0,0};
        }
    }
    // The staging area is reused for the logits from here on.
    __syncthreads();

    const int start_block_idx = partition_idx * PARTITION_SIZE / BLOCK_SIZE;
    const int end_block_idx =MIN(start_block_idx + PARTITION_SIZE / BLOCK_SIZE, num_seq_blocks);
    const int num_blocks = end_block_idx - start_block_idx;
    const int start_token_idx = start_block_idx * BLOCK_SIZE;
    const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
    const int num_tokens = end_token_idx - start_token_idx;

    //compute q*k
    {
        const uint8_t* k_ptr_base = k_cache+kv_head_idx * kv_head_stride;
        // Warps take the partition's KV blocks round-robin.
        for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;block_idx += NUM_WARPS) {
            const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
            // 16 tokens per matrix tile.
            #pragma unroll
            for(int b=0;b<BLOCK_SIZE;b+=16){
                const uint8_t* k_ptr=k_ptr_base + physical_block_number * kv_block_stride + b*HEAD_SIZE;
                float4_t qk_vec[Mloop];
                for(int m=0;m<Mloop;m++){
                    qk_vec[m]={0,0,0,0};
                }
                #pragma unroll
                for(int i=0;i<HEAD_SIZE/64;i++){
                    intx4 k_vec=*reinterpret_cast<const intx4*>(k_ptr+i*64+rowid*HEAD_SIZE+rows*16);
                    intx2 *k_vec_2 = (intx2*)&k_vec;
                    for(int m=0;m<Mloop;m++){
                        intx2 *q_vec_2 = (intx2*)(&q_vec[m][i]);
                        builtin_amdgcn_mmac<is_e4m3>(k_vec_2[0],q_vec_2[0],qk_vec[m]);
                        builtin_amdgcn_mmac<is_e4m3>(k_vec_2[1],q_vec_2[1],qk_vec[m]);
                    }
                }
                // Scale, bias, mask and store the logits; track the running maximum.
                #pragma unroll
                for(int i=0;i<reuse_group;i++){
                    int reuse_kv_idx=rows+i*4;
                    int m = reuse_kv_idx/16;
                    int ii = i%4;
                    if(reuse_kv_idx<num_queries_per_kv){
                        qk_vec[m][ii]*=k_scale;
                        const int token_idx = block_idx * BLOCK_SIZE+rowid + b;
                        if (has_abili){
                            float alibi=alibi_slope[i] * (token_idx - seq_len + 1);
                            qk_vec[m][ii] += alibi;
                        }
                        if(token_idx >= seq_len) {
                            // Past the sequence end: -inf up to the next multiple of 8
                            // (the softmax reads 8 logits at a time), plain 0 beyond.
                            int seq_len_pad=DIVIDE_ROUND_UP(seq_len,8)*8;
                            if(token_idx<seq_len_pad) from_float(logits[PARTITION_SIZE*reuse_kv_idx+token_idx - start_token_idx],-INFINITY);
                            else logits[PARTITION_SIZE*reuse_kv_idx+token_idx - start_token_idx]=0;
                        }
                        else{
                            scalar_t temp;
                            if (mtp>1){
                                // Multi-token prediction: hide the trailing tokens this
                                // query position must not attend to.
                                int casual = mtp - reuse_kv_idx * mtp / num_heads ;
                                if(token_idx+casual>seq_len)qk_vec[m][ii]=-INFINITY;
                            }
                            from_float(temp,qk_vec[m][ii]);
                            logits[PARTITION_SIZE*reuse_kv_idx+token_idx- start_token_idx]=temp;
                            // The maximum is taken over the value as stored (after rounding).
                            qk_max[i] = fmaxf(qk_max[i], to_float(temp));
                            // if(partition_idx==0)printf("tid=%d,tokenid=%d,reuse_kv_idx=%d,m=%d,ii=%d,qk=%f\n",thread_idx,token_idx,reuse_kv_idx,m,i,qk_vec[m][ii]);
                        }
                    }
                }
            }
        }
    }

    // compute max
    // Reduce over the 16 lanes that share `rows` (xor masks 8..1), then publish
    // one value per (query head, warp).
    #pragma unroll
    for (int mask = 8; mask >= 1; mask /= 2) {
        #pragma unroll
        for(int r=0;r<reuse_group;r++){
            qk_max[r]=fmaxf(qk_max[r],__shfl_xor(qk_max[r],mask));
        }
    }
    #pragma unroll
    for(int r=0;r<reuse_group;r++){
        if(rowid==0&&r*4+rows<num_queries_per_kv){
            s_max[(r*4+rows)*NUM_WARPS+warp_idx] = qk_max[r];
        }
    }
    __syncthreads();

    // ---- Softmax over the partition. Every lane handles 8 consecutive logits.
    if(PARTITION_SIZE==256){
        // 32 lanes * 8 logits = one row, so each warp processes two rows at once.
        for(int lineid = warp_idx;lineid<REUSE_KV_TIMES/2;lineid+=NUM_WARPS){
            int half_lane = lane%32;
            int which_half = lane/32;
            int real_line=lineid*2+which_half;
            if(real_line<num_queries_per_kv){
                float qk_max_tmp;
                float exp_sum=0;
                if(half_lane==0){
                    // Combine the per-warp maxima written above.
                    int smax_offset = real_line*NUM_WARPS;
                    qk_max_tmp=s_max[smax_offset];
                    for(int i=1;i<NUM_WARPS;i++){
                        qk_max_tmp=fmaxf(qk_max_tmp,s_max[smax_offset+i]);
                    }
                }
                qk_max_tmp=__shfl(qk_max_tmp,which_half*32);
                // Number of 8-logit chunks that contain valid tokens.
                int seq_len_pad = DIVIDE_ROUND_UP(num_tokens,8);
                using f16x8_t = __attribute__( (__vector_size__(8 * sizeof(scalar_t)) )) scalar_t;
                using f32x8_t = __attribute__( (__vector_size__(8 * sizeof(float)) )) float;
                // The attention sink adds one extra term to the denominator, in partition 0 only.
                float sink_contrib = 0.f;
                if (s_aux_ptr != nullptr && partition_idx == 0) {
                    float s_aux_val = to_float(smem_s_aux[head_idx+real_line]); // Convert scalar_t (fp16/bf16) to float
                    sink_contrib = __builtin_amdgcn_exp2f(s_aux_val*1.4426950408889634 - qk_max_tmp);
                }
                f32x8_t logit32;
                if(half_lane<seq_len_pad){
                    f16x8_t logit16 = *reinterpret_cast<f16x8_t*>(logits+lineid/NUM_WARPS*NUM_WARPS*2*PARTITION_SIZE+thread_idx*8);
                    for(int ii=0;ii<8;ii++){
                        logit32[ii]=__builtin_amdgcn_exp2f(to_float(logit16[ii])-qk_max_tmp);
                        exp_sum+=logit32[ii];
                    }
                    // printf("tid=%d,logit32=%.4f,%.4f,%.4f,%.4f, %.4f,%.4f,%.4f,%.4f\n",thread_idx,logit32[0],logit32[1],logit32[2],logit32[3],logit32[4],logit32[5],logit32[6],logit32[7]);
                }
                // Sum inside the 32-lane half.
                for (int mask = 16; mask >= 1; mask /= 2) {
                    exp_sum += __shfl_xor(exp_sum, mask);
                }
                exp_sum += sink_contrib;
                // printf("tid=%d,exp_sum=%f\n",thread_idx,exp_sum);
                const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
                if(half_lane<seq_len_pad){
                    f16x8_t logit16;
                    for(int ii=0;ii<8;ii++){
                        scalar_t t;
                        from_float(t,logit32[ii]*inv_sum);
                        logit16[ii]=t;
                    }
                    *reinterpret_cast<f16x8_t*>(logits+lineid/NUM_WARPS*NUM_WARPS*2*PARTITION_SIZE+thread_idx*8)=logit16;
                    if(num_partitions>1&&half_lane==0){
                        max_out[real_line] = qk_max_tmp;
                        expsum_out[real_line] = exp_sum;
                    }
                }
            }
        }
    }
    else if(PARTITION_SIZE==512){
        // 64 lanes * 8 logits = one row per warp.
        for(int lineid = warp_idx;lineid<num_queries_per_kv;lineid+=NUM_WARPS){
            float qk_max_tmp;
            float exp_sum=0;
            if(lane==0){
                // Combine the per-warp maxima written above.
                int smax_offset = lineid*NUM_WARPS;
                qk_max_tmp=s_max[smax_offset];
                for(int i=1;i<NUM_WARPS;i++){
                    qk_max_tmp=fmaxf(qk_max_tmp,s_max[smax_offset+i]);
                }
            }
            qk_max_tmp=__shfl(qk_max_tmp,0);
            // Number of 8-logit chunks that contain valid tokens.
            int seq_len_pad = DIVIDE_ROUND_UP(num_tokens,8);
            using f16x8_t = __attribute__( (__vector_size__(8 * sizeof(scalar_t)) )) scalar_t;
            using f32x8_t = __attribute__( (__vector_size__(8 * sizeof(float)) )) float;
            // The attention sink adds one extra term to the denominator, in partition 0 only.
            float sink_contrib = 0.f;
            if (s_aux_ptr != nullptr && partition_idx == 0) {
                float s_aux_val = to_float(smem_s_aux[head_idx+lineid]); // Convert scalar_t (fp16/bf16) to float
                sink_contrib = __builtin_amdgcn_exp2f(s_aux_val*1.4426950408889634 - qk_max_tmp);
            }
            f32x8_t logit32;
            if(lane<seq_len_pad){
                f16x8_t logit16 = *reinterpret_cast<f16x8_t*>(logits+lineid/NUM_WARPS*NUM_WARPS*PARTITION_SIZE+thread_idx*8);
                for(int ii=0;ii<8;ii++){
                    logit32[ii]=__builtin_amdgcn_exp2f(to_float(logit16[ii])-qk_max_tmp);
                    exp_sum+=logit32[ii];
                }
                // printf("tid=%d,logit32=%.4f,%.4f,%.4f,%.4f, %.4f,%.4f,%.4f,%.4f\n",thread_idx,logit32[0],logit32[1],logit32[2],logit32[3],logit32[4],logit32[5],logit32[6],logit32[7]);
            }
            for (int mask = 32; mask >= 1; mask /= 2) {
                exp_sum += __shfl_xor(exp_sum, mask);
            }
            exp_sum += sink_contrib;
            // printf("tid=%d,exp_sum=%f\n",thread_idx,exp_sum);
            const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
            if(lane<seq_len_pad){
                f16x8_t logit16;
                for(int ii=0;ii<8;ii++){
                    scalar_t t;
                    from_float(t,logit32[ii]*inv_sum);
                    logit16[ii]=t;
                }
                *reinterpret_cast<f16x8_t*>(logits+lineid/NUM_WARPS*NUM_WARPS*PARTITION_SIZE+thread_idx*8)=logit16;
                if(num_partitions>1&&lane==0){
                    max_out[lineid] = qk_max_tmp;
                    expsum_out[lineid] = exp_sum;
                }
            }
        }
    }
    __syncthreads();

    // ---- probabilities * V
    constexpr int NUM_ROWS_PER_THREAD =DIVIDE_ROUND_UP(HEAD_SIZE, 16*NUM_WARPS);//2
    // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
    float4_t accs[Mloop][NUM_ROWS_PER_THREAD];
    for(int m=0;m<Mloop;m++){
        for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
            accs[m][i] = {0.f,0.f,0.f,0.f};
        }
    }
    // Each lane covers BLOCK_SIZE/4 tokens of a block: `vecsize` 8-byte fp8 groups.
    constexpr int vecsize=BLOCK_SIZE/32;//2
    // Tokens (= fp8 bytes of V) per lane. It must agree with the token offset
    // rows*(BLOCK_SIZE/4) used for the logits below, so it is derived from
    // vecsize rather than fixed at the BLOCK_SIZE==64 value of 16.
    constexpr int tokens_per_lane=vecsize*8;
    using int_vec = int2vec<vecsize>;
    for (int block_idx = start_block_idx; block_idx < end_block_idx; block_idx ++) {
        const int64_t physical_block_number =
            static_cast<int64_t>(block_table[block_idx]);
        const int token_idx = block_idx * BLOCK_SIZE +rows*(BLOCK_SIZE/4);

        // Quantise this lane's slice of the probabilities to fp8.
        intx2 logits_vec[Mloop][vecsize];
        for(int m=0;m<Mloop;m++){
            for(int i=0;i<vecsize;i++){
                logits_vec[m][i]={0,0};
            }
        }
        for(int m=0;m<Mloop;m++){
            int real_row=rowid+m*16;
            if(real_row<num_queries_per_kv){
                for(int k=0;k<vecsize;k++){
                    // half8_t is used purely as a 16-byte load; elements are read as scalar_t.
                    auto l_temp = *reinterpret_cast<half8_t*>(logits + real_row * PARTITION_SIZE+token_idx - start_token_idx + k*8);
                    scalar_t *l_temp_ptr=(scalar_t*)&l_temp;
                    logits_vec[m][k][0]=to_f8_from_f32<is_e4m3>(to_float(l_temp_ptr[0]),to_float(l_temp_ptr[1]),to_float(l_temp_ptr[2]),to_float(l_temp_ptr[3]));
                    logits_vec[m][k][1]=to_f8_from_f32<is_e4m3>(to_float(l_temp_ptr[4]),to_float(l_temp_ptr[5]),to_float(l_temp_ptr[6]),to_float(l_temp_ptr[7]));
                }
            }
        }

        const uint8_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
                               kv_head_idx * kv_head_stride;
        if(partition_idx<num_partitions-1){
            // Interior partition: every token is valid, no masking needed.
            #pragma unroll
            for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
                int offset=i*BLOCK_SIZE*HEAD_SIZE/NUM_ROWS_PER_THREAD+warp_idx*BLOCK_SIZE*HEAD_SIZE/NUM_ROWS_PER_THREAD/NUM_WARPS+rows*tokens_per_lane+rowid*BLOCK_SIZE;
                int_vec v_vec = *reinterpret_cast<const int_vec*>(v_ptr + offset);
                for(int ii=0;ii<vecsize;ii++){
                    for(int m=0;m<Mloop;m++){
                        builtin_amdgcn_mmac<is_e4m3>(v_vec.data[ii],logits_vec[m][ii],accs[m][i]);
                    }
                }
            }
        }
        else{
            #pragma unroll
            for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
                int offset=i*BLOCK_SIZE*HEAD_SIZE/NUM_ROWS_PER_THREAD+warp_idx*BLOCK_SIZE*HEAD_SIZE/NUM_ROWS_PER_THREAD/NUM_WARPS+rows*tokens_per_lane+rowid*BLOCK_SIZE;
                int_vec v_vec = *reinterpret_cast<const int_vec*>(v_ptr + offset);
                // This if-check costs some performance, so it is only evaluated
                // for the last partition.
                if (block_idx == num_seq_blocks - 1) {
                    // Zero the V entries of tokens past the sequence end.
                    uint8_t* v_vec_ptr = reinterpret_cast<uint8_t*>(&v_vec);
                    #pragma unroll
                    for (int j = 0; j < tokens_per_lane; j++) {
                        v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : 0;
                    }
                }
                for(int ii=0;ii<vecsize;ii++){
                    for(int m=0;m<Mloop;m++){
                        builtin_amdgcn_mmac<is_e4m3>(v_vec.data[ii],logits_vec[m][ii],accs[m][i]);
                    }
                }
            }
        }
    }

    // ---- Write the result: final output for a single partition, otherwise
    // this partition's slot in the scratch buffer.
    scalar_t* out_ptr_base;
    int out_offset;
    if(num_partitions>1){
        out_offset=max_num_partitions*HEAD_SIZE;
        out_ptr_base=out_tmp+out_tmp_offset + seq_idx * num_heads * out_offset + head_idx*out_offset+partition_idx * HEAD_SIZE;
    }
    else{
        out_offset=HEAD_SIZE;
        out_ptr_base=out + seq_idx * num_heads * HEAD_SIZE + head_idx*HEAD_SIZE;
    }
    for(int g=0;g<reuse_group;g++){
        int reusekvid=g*4+rows;
        if(reusekvid<num_queries_per_kv){
            scalar_t* out_ptr = out_ptr_base + reusekvid*out_offset;
            for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
                const int row_idx = rowid+16*warp_idx + i * WARP_SIZE;
                from_float(*(out_ptr + row_idx), accs[reusekvid/16][i][g%4]*v_scale);
                // if(reusekvid==0)printf("patition=%d,tid=%d,i=%d,g=%d,acc=%f\n",partition_idx,thread_idx,i,g,accs[i][g]);
            }
        }
    }
    // Softmax statistics of this partition, one thread per query head.
    if (num_partitions>1&&thread_idx < num_queries_per_kv){
        int offset = seq_idx * num_heads * max_num_partitions + (head_idx+thread_idx) * max_num_partitions + partition_idx;
        float * exp_sums=reinterpret_cast<float*>(out_tmp);
        float * max_logits=reinterpret_cast<float*>(out_tmp+max_tmp_offset);
        *(exp_sums+offset)=expsum_out[thread_idx];
        *(max_logits+offset)=max_out[thread_idx];
    }
#endif
}
// Merges the per-partition partial attention outputs of one (sequence, head)
// pair into the final output row.
//
// Launch layout as read by this kernel:
//   blockIdx.x -> head index, blockIdx.y -> sequence index, threadIdx.x -> lane.
// Both reductions use __shfl_xor across WARP_SIZE lanes only, so the block is
// expected to consist of exactly one warp.
// Dynamic shared memory: 2 * num_partitions floats (rescaled exp sums followed
// by the per-partition max logits).
//
// Scratch buffer `out_tmp` as addressed here:
//   reinterpret_cast<float*>(out_tmp)                  exp sums        [seq][head][partition]
//   reinterpret_cast<float*>(out_tmp + max_tmp_offset) max logits      [seq][head][partition]
//   out_tmp + out_tmp_offset                           partial outputs [seq][head][partition][HEAD_SIZE]
//
// A sequence whose length fits in a single partition is skipped: the block
// returns without touching `out`.
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS>
__global__ __launch_bounds__(NUM_THREADS, 1) void paged_attention_combine(
    scalar_t* __restrict__ out,        // [num_seqs, num_heads, head_size]
    scalar_t* out_tmp,                 // scratch buffer, see layout above
    const int* __restrict__ seq_lens,  // [num_seqs]
    const int max_num_partitions,
    int num_heads,
    int PARTITION_SIZE) {
  extern __shared__ char shared_mem[];
  const int head_idx = blockIdx.x;
  const int seq_idx = blockIdx.y;
  const int seq_len = __builtin_amdgcn_readfirstlane(seq_lens[seq_idx]);
  const int lane = threadIdx.x;
  const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
  // Uniform across the block (seq_len is per-sequence), so returning here
  // cannot strand other threads at the barrier below.
  if (num_partitions == 1) return;

  float* shared_exp_sums = reinterpret_cast<float*>(shared_mem);
  float* shared_max_logits = shared_exp_sums + num_partitions;
  float max_logit = -FLT_MAX;
  float global_exp_sum = 0.0f;

  // 64-bit row index so that the `* HEAD_SIZE` scaling below cannot wrap for
  // large num_seqs * num_heads * max_num_partitions.
  const int64_t head_row = static_cast<int64_t>(seq_idx) * num_heads + head_idx;
  const int64_t offset = head_row * max_num_partitions;
  const float* exp_sums = reinterpret_cast<const float*>(out_tmp);
  const float* max_logits = reinterpret_cast<const float*>(out_tmp + max_tmp_offset);
  const float* max_logits_ptr = max_logits + offset;
  const float* exp_sums_ptr = exp_sums + offset;
  // Written below, hence not const.
  scalar_t* out_ptr = out + head_row * HEAD_SIZE;
  const scalar_t* tmp_out_ptr = out_tmp + out_tmp_offset + offset * HEAD_SIZE;

  // Pass 1: stage every partition's max logit in shared memory and find the
  // overall maximum across the warp.
  for (int i = lane; i < num_partitions; i += WARP_SIZE) {
    const float l = max_logits_ptr[i];
    shared_max_logits[i] = l;
    max_logit = fmaxf(max_logit, l);
  }
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    max_logit = fmaxf(max_logit, __shfl_xor(max_logit, mask));
  }

  // Pass 2: rescale each partition's exp sum to the common maximum (exp2 of the
  // logit difference) and accumulate the global sum. Each lane only re-reads
  // the shared_max_logits entries it wrote itself in pass 1.
  for (int i = lane; i < num_partitions; i += WARP_SIZE) {
    float rescaled_exp_sum = exp_sums_ptr[i] * __builtin_amdgcn_exp2f(shared_max_logits[i] - max_logit);
    global_exp_sum += rescaled_exp_sum;
    shared_exp_sums[i] = rescaled_exp_sum;
  }
  // shared_exp_sums[j] is read by every lane below but was written by lane
  // j % WARP_SIZE; make those writes visible explicitly instead of relying on
  // implicit lockstep execution.
  __syncthreads();
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    global_exp_sum += __shfl_xor(global_exp_sum, mask);
  }
  // The small epsilon keeps the reciprocal finite when the sum is zero.
  const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);

  // Elements handled per lane: HEAD_SIZE/64, rounded up from 3 to 4 so the
  // vector width stays a power of two (fewer lanes are active in that case).
  constexpr int vec_size_o = HEAD_SIZE / 64;
  constexpr int vec_size = vec_size_o == 3 ? 4 : vec_size_o;
  using half_vec = __attribute__((__vector_size__(vec_size * sizeof(scalar_t)))) scalar_t;
  using float_vec = __attribute__((__vector_size__(vec_size * sizeof(float)))) float;
  float_vec acc = {0.0f};
  half_vec acc_half;
  if (lane < HEAD_SIZE / vec_size) {
    // Weighted sum of the partial outputs; weights are the normalized,
    // rescaled exp sums of each partition.
    for (int j = 0; j < num_partitions; ++j) {
      half_vec tout = *reinterpret_cast<const half_vec*>(tmp_out_ptr + j * HEAD_SIZE + lane * vec_size);
      float temp_sum = shared_exp_sums[j] * inv_global_exp_sum;
#pragma unroll
      for (int i = 0; i < vec_size; i++) {
        acc[i] += to_float(tout[i]) * temp_sum;
      }
    }
#pragma unroll
    for (int i = 0; i < vec_size; i++) {
      scalar_t temp;
      from_float(temp, acc[i]);
      acc_half[i] = temp;
    }
    *reinterpret_cast<half_vec*>(out_ptr + lane * vec_size) = acc_half;
  }
}
// Picks the KV-reuse bucket for a given number of query heads and KV heads.
// The tiers are scanned from the largest ratio down; the first tier whose
// threshold `kv_head * ratio` is strictly exceeded by `qhead` wins. When no
// tier matches, the smallest bucket (4) is returned.
static int get_reusekv(int qhead, int kv_head) {
    struct Tier {
        int ratio;  // qhead must be strictly greater than kv_head * ratio
        int reuse;  // bucket returned for this tier
    };
    static const Tier kTiers[] = {
        {36, 48},
        {32, 36},  // original note: glm4.7 mtp 3
        {24, 32},
        {16, 24},
        {8, 16},
        {4, 8},
    };
    for (const Tier& tier : kTiers) {
        if (qhead > kv_head * tier.ratio) {
            return tier.reuse;
        }
    }
    return 4;
}
// Host entry point for paged attention.
//
// Launches `paged_attention_kernel` on a (num_kv_heads, num_seqs,
// max_num_partitions) grid with 256 threads per block and, when more than one
// partition is needed, `paged_attention_combine` on a (num_heads, num_seqs)
// grid with 64 threads per block to merge the per-partition results.
// Both launches go to the current CUDA stream of query's device.
//
// `num_heads` below is query.size(2) * query.size(1), i.e. the query heads
// multiplied by the `mtp` dimension.
//
// The nested *_SWITCH macros are defined elsewhere; the innermost lambda relies
// on them to bring HEAD_SIZE, scalar_t, q_type, is_e4m3 and REUSE_KV_TIMES
// into scope.
//
// Raises (via TORCH_CHECK / AT_ASSERTM) when PARTITION_SIZE or the KV head
// count is not positive, when s_aux fails validation, when the head size is
// not a multiple of 64 up to 256, or when num_heads > 48 * num_kv_heads.
void paged_attention_938(
    torch::Tensor& out,           // [num_seqs, seqlen, num_heads, head_size]
    torch::Tensor& query,         // read as [num_seqs, mtp, num_heads, head_size]
    torch::Tensor& key_cache,     // [num_blocks, num_heads, block_size, head_size]
    torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
    torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
    torch::Tensor& seq_lens,      // [num_seqs]
    const c10::optional<torch::Tensor>& alibi_slopes,
    const c10::optional<torch::Tensor>& q_scale,
    const c10::optional<torch::Tensor>& k_scale,
    const c10::optional<torch::Tensor>& v_scale,
    int max_seq_len,
    const c10::optional<at::Tensor> &s_aux_,  // Attention Sinks (optional), one entry per head
    float *tmp_out_ptr,           // scratch buffer shared by both kernels
    int PARTITION_SIZE)
{
  // PARTITION_SIZE is used as a divisor when computing the partition count.
  TORCH_CHECK(PARTITION_SIZE > 0, "PARTITION_SIZE must be positive, got ", PARTITION_SIZE);

  int max_num_blocks_per_seq = block_tables.size(1);
  int num_seqs = query.size(0);
  int mtp = query.size(1);
  int block_size = key_cache.size(2);
  int num_heads = query.size(2) * mtp;
  int num_kv_heads = key_cache.size(1);
  int q_stride = query.stride(0);
  int kv_block_stride = key_cache.stride(0);

  // num_kv_heads is used as a divisor when sizing shared memory.
  TORCH_CHECK(num_kv_heads > 0, "key_cache must have at least one KV head, got ", num_kv_heads);

  const float* alibi_slopes_ptr = alibi_slopes
      ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr()) : nullptr;
  int* block_tables_ptr = block_tables.data_ptr<int>();
  int* seq_lens_ptr = seq_lens.data_ptr<int>();
  auto* out_ptr = out.data_ptr();
  const float* q_scale_ptr = q_scale ? reinterpret_cast<const float*>(q_scale.value().data_ptr()) : nullptr;
  const float* k_scale_ptr = k_scale ? reinterpret_cast<const float*>(k_scale.value().data_ptr()) : nullptr;
  const float* v_scale_ptr = v_scale ? reinterpret_cast<const float*>(v_scale.value().data_ptr()) : nullptr;

  // Attention Sinks: validate and set s_aux_ptr.
  const void* s_aux_ptr = nullptr;
  if (s_aux_.has_value()) {
    auto s_aux = s_aux_.value();
    // s_aux must match the query dtype (it is passed to the kernel as scalar_t).
    TORCH_CHECK(s_aux.dtype() == query.dtype(),
        "s_aux must have the same dtype as query. Got s_aux dtype: ", s_aux.dtype(),
        ", query dtype: ", query.dtype());
    TORCH_CHECK(s_aux.dtype() == torch::kFloat16 || s_aux.dtype() == torch::kBFloat16,
        "s_aux must have dtype float16 or bfloat16 (to match query). Got: ", s_aux.dtype());
    TORCH_CHECK(num_heads <= 64,
        "Attention Sinks only supports up to 64 heads (shared memory limit), got ", num_heads);
    CHECK_DEVICE(s_aux);
    CHECK_SHAPE(s_aux, num_heads);
    CHECK_CONTIGUOUS(s_aux);
    s_aux_ptr = s_aux.data_ptr();
  }

  auto* query_ptr = query.data_ptr();
  auto* key_cache_ptr = key_cache.data_ptr();
  auto* value_cache_ptr = value_cache.data_ptr();
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // Grid of the combine kernel: one block per (head, sequence).
  dim3 reduce_grid(num_heads, num_seqs);
  // Grid of the main kernel; grid.z is filled in once the partition count is known.
  dim3 grid;
  grid.x = num_kv_heads;
  grid.y = num_seqs;

  int reusekv = get_reusekv(num_heads, num_kv_heads);
  int headsize = query.size(3);
  AT_ASSERTM(headsize%64==0 && headsize<=256, "Page Attention head size must be 64, 128, 192 or 256");
  AT_ASSERTM(num_heads<=num_kv_heads*48, "Page Attention qheads*mtp/kvheads must be smaller than 48");

  HEADSIZE_SWITCH(headsize,[&]{
  Output_Type_SWITCH(out.dtype(),[&]{
  Input_Type_SWITCH(scalar_t,query.dtype(),key_cache.dtype(),[&] {
  REUSEKV_SWITCH(reusekv,[&] {
  BOOL_SWITCH(block_size==64,is_block64,[&]{
    constexpr int BLOCK_SIZE = (is_block64?64:128);
    constexpr static int NUM_THREADS = 256;
    constexpr static int NUM_WARPS = NUM_THREADS / WARP_SIZE;
    int real_reuse_times = num_heads/num_kv_heads;
    // Dynamic shared memory of the main kernel: PARTITION_SIZE scalar_t values
    // per reused query head plus a block of float scratch.
    int other_use = (real_reuse_times*NUM_WARPS+NUM_WARPS+ real_reuse_times*2)*sizeof(float);
    int shared_mem_size=PARTITION_SIZE*sizeof(scalar_t)*real_reuse_times+other_use;
    int max_num_partitions=DIVIDE_ROUND_UP(max_seq_len,PARTITION_SIZE);
    grid.z = max_num_partitions;
    dim3 block(NUM_THREADS);
    // sizeof yields size_t; cast so the argument matches the %d conversion.
    if(PA_PRINT_PARAM)printf("sizeof(q)=%d,shared_mem_size=%d,HEAD_SIZE=%d,num_thread=%d,grid={%d,%d,%d},qhead=%d,kvhead=%d,seq=%d,batch=%d,PARTITION_SIZE=%d,max_num_partitions=%d\n",
      static_cast<int>(sizeof(q_type)),shared_mem_size,HEAD_SIZE,NUM_THREADS,grid.x,grid.y,grid.z,num_heads,num_kv_heads,max_seq_len,num_seqs,PARTITION_SIZE,max_num_partitions);
    paged_attention_kernel<scalar_t,q_type,is_e4m3,HEAD_SIZE,BLOCK_SIZE,NUM_THREADS,REUSE_KV_TIMES><<<grid,block,shared_mem_size,stream>>>(
      (scalar_t*)out_ptr,(scalar_t*)tmp_out_ptr, (q_type*)query_ptr,(uint8_t*) key_cache_ptr, (uint8_t*)value_cache_ptr,
      num_heads, num_kv_heads, block_tables_ptr, seq_lens_ptr,max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride,
      q_scale_ptr,k_scale_ptr, v_scale_ptr,max_num_partitions,PARTITION_SIZE,(const scalar_t*)s_aux_ptr,mtp,alibi_slopes_ptr!=nullptr);
    if(max_num_partitions>1){
      // The combine kernel needs two floats of shared memory per partition.
      paged_attention_combine<scalar_t,HEAD_SIZE,64><<<reduce_grid,64,2*sizeof(float)*max_num_partitions,stream>>>(
        (scalar_t*)out_ptr,(scalar_t*)tmp_out_ptr,seq_lens_ptr,max_num_partitions,num_heads,PARTITION_SIZE);
    }
  });
  });
  });
  });
  });
}
// Pytorch also has an implementation of Philox RNG: https://github.com/pytorch/pytorch/blob/8ca3c881db3e3510fcb7725389f6a0633c9b992c/torch/csrc/jit/tensorexpr/cuda_random.h
#pragma once
// Philox CUDA.
namespace flash {

// Pair of 64-bit words. No longer used by the functions in this namespace;
// presumably still referenced elsewhere, so it is kept -- TODO confirm.
struct ull2 {
    unsigned long long x;
    unsigned long long y;
};

// Full 32x32 -> 64-bit product of a and b, returned as {low 32 bits, high 32 bits}.
// The halves are extracted with shifts rather than by reinterpreting the 64-bit
// value through a uint2 pointer, which avoids strict-aliasing violations and
// any dependence on byte order.
__forceinline__ __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) {
    const unsigned long long product = static_cast<unsigned long long>(a) * b;
    uint2 res;
    res.x = static_cast<unsigned int>(product);
    res.y = static_cast<unsigned int>(product >> 32);
    return res;
}

// One Philox round: multiplies ctr.x and ctr.z by the two round multipliers and
// mixes the high halves with the remaining counter words and the key.
__forceinline__ __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) {
    constexpr unsigned int kPhiloxSA = 0xD2511F53;
    constexpr unsigned int kPhiloxSB = 0xCD9E8D57;
    const uint2 res0 = mulhilo32(kPhiloxSA, ctr.x);
    const uint2 res1 = mulhilo32(kPhiloxSB, ctr.z);
    uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x};
    return ret;
}

// Produces four 32-bit random words from (seed, subsequence, offset).
// The key is the seed split into {low, high} words. The counter is
// {offset low, offset high, subsequence low, subsequence high}. Seven rounds
// are applied in total: six that also advance the key, then a final one.
__forceinline__ __device__ uint4 philox(unsigned long long seed,
                                        unsigned long long subsequence,
                                        unsigned long long offset) {
    constexpr unsigned int kPhilox10A = 0x9E3779B9;
    constexpr unsigned int kPhilox10B = 0xBB67AE85;

    uint2 key;
    key.x = static_cast<unsigned int>(seed);
    key.y = static_cast<unsigned int>(seed >> 32);

    uint4 counter;
    counter.x = static_cast<unsigned int>(offset);
    counter.y = static_cast<unsigned int>(offset >> 32);
    counter.z = static_cast<unsigned int>(subsequence);
    counter.w = static_cast<unsigned int>(subsequence >> 32);

    #pragma unroll
    for (int i = 0; i < 6; i++) {
        counter = philox_single_round(counter, key);
        // Key schedule: unsigned 32-bit additions, wrapping on overflow.
        key.x += kPhilox10A;
        key.y += kPhilox10B;
    }
    uint4 output = philox_single_round(counter, key);
    return output;
}

} // namespace flash
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment