"...composable_kernel.git" did not exist on "3af8c81a72b5b5a0155eb0e95c4f0aba1b375cca"
Commit cc6d659f authored by Po Yen, Chen's avatar Po Yen, Chen
Browse files

Re-format interface sources

parent 5a683756
...@@ -5,37 +5,43 @@ ...@@ -5,37 +5,43 @@
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
namespace native { namespace native {
enum class ScalarType { enum class ScalarType
{
Half, Half,
BFloat16, BFloat16,
}; };
inline std::ostream& operator<<(std::ostream& stream, ScalarType scalar_type) { inline std::ostream& operator<<(std::ostream& stream, ScalarType scalar_type)
switch(scalar_type) { {
switch(scalar_type)
{
case ScalarType::Half: stream << "Half"; break; case ScalarType::Half: stream << "Half"; break;
case ScalarType::BFloat16: stream << "BFloat16"; break; case ScalarType::BFloat16: stream << "BFloat16"; break;
} }
return stream; return stream;
} }
enum class Fp8KVCacheDataType { enum class Fp8KVCacheDataType
kAuto = 0, {
kFp8E4M3 = 1, kAuto = 0,
kFp8E5M2 = 2, kFp8E4M3 = 1,
kFp8E5M2 = 2,
}; };
struct paged_attention_traits { struct paged_attention_traits
{
ScalarType q_type; ScalarType q_type;
std::string kv_cache_dtype; std::string kv_cache_dtype;
}; };
struct paged_attention_args { struct paged_attention_args
{
int head_size; int head_size;
int num_seqs; int num_seqs;
int num_heads; int num_heads;
int num_kv_heads; int num_kv_heads;
int max_num_blocks_per_seq; int max_num_blocks_per_seq;
int q_stride; int q_stride;
int kv_block_stride; int kv_block_stride;
...@@ -63,9 +69,7 @@ struct paged_attention_args { ...@@ -63,9 +69,7 @@ struct paged_attention_args {
int64_t partition_size; int64_t partition_size;
}; };
void paged_attention( void paged_attention(const paged_attention_traits& traits,
const paged_attention_traits& traits, const paged_attention_args& args,
const paged_attention_args& args, hipStream_t stream);
hipStream_t stream } // namespace native
); \ No newline at end of file
}
\ No newline at end of file
...@@ -21,234 +21,210 @@ ...@@ -21,234 +21,210 @@
#include "paged_attention.hpp" #include "paged_attention.hpp"
#include "kernel/paged_attention_kernel.hpp" #include "kernel/paged_attention_kernel.hpp"
#define LAUNCH_CUSTOM_ATTENTION(GQA_RATIO) \ #define LAUNCH_CUSTOM_ATTENTION(GQA_RATIO) \
paged_attention_ll4mi_QKV_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \ paged_attention_ll4mi_QKV_kernel<T, \
HEAD_SIZE, NTHR, GQA_RATIO> \ KVT, \
<<<grid, block, 0, stream>>>( \ KV_DTYPE, \
query_ptr, key_cache_ptr, value_cache_ptr, args.num_kv_heads, args.scale, \ OUTT, \
args.block_tables_ptr, args.context_lens_ptr, args.max_num_blocks_per_seq, \ BLOCK_SIZE, \
args.alibi_slopes_ptr, args.q_stride, args.kv_block_stride, args.kv_head_stride, \ HEAD_SIZE, \
args.exp_sums_ptr, args.max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \ NTHR, \
args.k_scale, args.v_scale, args.fp8_out_scale_ptr); GQA_RATIO> \
<<<grid, block, 0, stream>>>(query_ptr, \
#define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS) \ key_cache_ptr, \
paged_attention_ll4mi_reduce_kernel<T, OUTT, HEAD_SIZE, HEAD_SIZE, \ value_cache_ptr, \
PARTITION_SIZE, NPAR_LOOPS> \ args.num_kv_heads, \
<<<reduce_grid, reduce_block, 0, stream>>>( \ args.scale, \
out_ptr, args.exp_sums_ptr, args.max_logits_ptr, tmp_out_ptr, \ args.block_tables_ptr, \
args.context_lens_ptr, max_num_partitions, args.fp8_out_scale_ptr); args.context_lens_ptr, \
args.max_num_blocks_per_seq, \
args.alibi_slopes_ptr, \
args.q_stride, \
args.kv_block_stride, \
args.kv_head_stride, \
args.exp_sums_ptr, \
args.max_logits_ptr, \
tmp_out_ptr, \
out_ptr, \
max_ctx_blocks, \
args.k_scale, \
args.v_scale, \
args.fp8_out_scale_ptr);
#define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS) \
paged_attention_ll4mi_reduce_kernel<T, OUTT, HEAD_SIZE, HEAD_SIZE, PARTITION_SIZE, NPAR_LOOPS> \
<<<reduce_grid, reduce_block, 0, stream>>>(out_ptr, \
args.exp_sums_ptr, \
args.max_logits_ptr, \
tmp_out_ptr, \
args.context_lens_ptr, \
max_num_partitions, \
args.fp8_out_scale_ptr);
namespace { namespace {
template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE, template <typename T,
int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE> typename KVT,
void paged_attention_custom_launcher( vllm::Fp8KVCacheDataType KV_DTYPE,
const native::paged_attention_args& args, int BLOCK_SIZE,
hipStream_t stream) { int HEAD_SIZE,
typename OUTT,
T* tmp_out_ptr = reinterpret_cast<T*>(args.tmp_out_ptr); int PARTITION_SIZE>
T* query_ptr = reinterpret_cast<T*>(args.query_ptr); void paged_attention_custom_launcher(const native::paged_attention_args& args, hipStream_t stream)
KVT* key_cache_ptr = reinterpret_cast<KVT*>(args.key_cache_ptr); {
KVT* value_cache_ptr = reinterpret_cast<KVT*>(args.value_cache_ptr);
OUTT* out_ptr = reinterpret_cast<OUTT*>(args.out_ptr); T* tmp_out_ptr = reinterpret_cast<T*>(args.tmp_out_ptr);
T* query_ptr = reinterpret_cast<T*>(args.query_ptr);
const int max_ctx_blocks = DIVIDE_ROUND_UP(args.max_context_len, BLOCK_SIZE); KVT* key_cache_ptr = reinterpret_cast<KVT*>(args.key_cache_ptr);
const int max_num_partitions = KVT* value_cache_ptr = reinterpret_cast<KVT*>(args.value_cache_ptr);
DIVIDE_ROUND_UP(args.max_context_len, PARTITION_SIZE); OUTT* out_ptr = reinterpret_cast<OUTT*>(args.out_ptr);
const int gqa_ratio = args.num_heads / args.num_kv_heads;
assert(args.num_heads % args.num_kv_heads == 0); const int max_ctx_blocks = DIVIDE_ROUND_UP(args.max_context_len, BLOCK_SIZE);
assert(args.head_size == HEAD_SIZE); const int max_num_partitions = DIVIDE_ROUND_UP(args.max_context_len, PARTITION_SIZE);
const int gqa_ratio = args.num_heads / args.num_kv_heads;
constexpr int NTHR = PARTITION_SIZE; assert(args.num_heads % args.num_kv_heads == 0);
dim3 grid(args.num_seqs, max_num_partitions, args.num_kv_heads); assert(args.head_size == HEAD_SIZE);
dim3 block(NTHR);
constexpr int NTHR = PARTITION_SIZE;
switch (gqa_ratio) { dim3 grid(args.num_seqs, max_num_partitions, args.num_kv_heads);
case 1: dim3 block(NTHR);
LAUNCH_CUSTOM_ATTENTION(1);
break; switch(gqa_ratio)
case 2: {
LAUNCH_CUSTOM_ATTENTION(2); case 1: LAUNCH_CUSTOM_ATTENTION(1); break;
break; case 2: LAUNCH_CUSTOM_ATTENTION(2); break;
case 3: case 3: LAUNCH_CUSTOM_ATTENTION(3); break;
LAUNCH_CUSTOM_ATTENTION(3); case 4: LAUNCH_CUSTOM_ATTENTION(4); break;
break; case 5: LAUNCH_CUSTOM_ATTENTION(5); break;
case 4: case 6: LAUNCH_CUSTOM_ATTENTION(6); break;
LAUNCH_CUSTOM_ATTENTION(4); case 7: LAUNCH_CUSTOM_ATTENTION(7); break;
break; case 8: LAUNCH_CUSTOM_ATTENTION(8); break;
case 5: case 9: LAUNCH_CUSTOM_ATTENTION(9); break;
LAUNCH_CUSTOM_ATTENTION(5); case 10: LAUNCH_CUSTOM_ATTENTION(10); break;
break; case 11: LAUNCH_CUSTOM_ATTENTION(11); break;
case 6: case 12: LAUNCH_CUSTOM_ATTENTION(12); break;
LAUNCH_CUSTOM_ATTENTION(6); case 13: LAUNCH_CUSTOM_ATTENTION(13); break;
break; case 14: LAUNCH_CUSTOM_ATTENTION(14); break;
case 7: case 15: LAUNCH_CUSTOM_ATTENTION(15); break;
LAUNCH_CUSTOM_ATTENTION(7); case 16: LAUNCH_CUSTOM_ATTENTION(16); break;
break; default: TORCH_CHECK(false, "Unsupported gqa ratio: ", gqa_ratio); break;
case 8: }
LAUNCH_CUSTOM_ATTENTION(8);
break; // reduction kernel is only required if max_context_len > partition size,
case 9: // otherwise main kernel writes directly to final output
LAUNCH_CUSTOM_ATTENTION(9); // note there are cases with graphing where max_context_len is the max
break; // supported by graphing, not the actual max among all the sequences: in that
case 10: // case reduction kernel will still run but return immediately
LAUNCH_CUSTOM_ATTENTION(10); if(args.max_context_len > PARTITION_SIZE)
break; {
case 11: dim3 reduce_grid(args.num_heads, args.num_seqs);
LAUNCH_CUSTOM_ATTENTION(11); dim3 reduce_block(args.head_size);
break; const int npar_loops = DIVIDE_ROUND_UP(max_num_partitions, WARP_SIZE);
case 12: // support upto 8*64*256=128K context length
LAUNCH_CUSTOM_ATTENTION(12); switch(npar_loops)
break; {
case 13: case 1: LAUNCH_CUSTOM_REDUCTION(1); break;
LAUNCH_CUSTOM_ATTENTION(13); case 2: LAUNCH_CUSTOM_REDUCTION(2); break;
break; case 3: LAUNCH_CUSTOM_REDUCTION(3); break;
case 14: case 4: LAUNCH_CUSTOM_REDUCTION(4); break;
LAUNCH_CUSTOM_ATTENTION(14); case 5: LAUNCH_CUSTOM_REDUCTION(5); break;
break; case 6: LAUNCH_CUSTOM_REDUCTION(6); break;
case 15: case 7: LAUNCH_CUSTOM_REDUCTION(7); break;
LAUNCH_CUSTOM_ATTENTION(15); case 8: LAUNCH_CUSTOM_REDUCTION(8); break;
break; default: TORCH_CHECK(false, "Unsupported npar_loops: ", npar_loops); break;
case 16: }
LAUNCH_CUSTOM_ATTENTION(16);
break;
default:
TORCH_CHECK(false, "Unsupported gqa ratio: ", gqa_ratio);
break;
}
// reduction kernel is only required if max_context_len > partition size,
// otherwise main kernel writes directly to final output
// note there are cases with graphing where max_context_len is the max
// supported by graphing, not the actual max among all the sequences: in that
// case reduction kernel will still run but return immediately
if (args.max_context_len > PARTITION_SIZE) {
dim3 reduce_grid(args.num_heads, args.num_seqs);
dim3 reduce_block(args.head_size);
const int npar_loops = DIVIDE_ROUND_UP(max_num_partitions, WARP_SIZE);
// support upto 8*64*256=128K context length
switch (npar_loops) {
case 1:
LAUNCH_CUSTOM_REDUCTION(1);
break;
case 2:
LAUNCH_CUSTOM_REDUCTION(2);
break;
case 3:
LAUNCH_CUSTOM_REDUCTION(3);
break;
case 4:
LAUNCH_CUSTOM_REDUCTION(4);
break;
case 5:
LAUNCH_CUSTOM_REDUCTION(5);
break;
case 6:
LAUNCH_CUSTOM_REDUCTION(6);
break;
case 7:
LAUNCH_CUSTOM_REDUCTION(7);
break;
case 8:
LAUNCH_CUSTOM_REDUCTION(8);
break;
default:
TORCH_CHECK(false, "Unsupported npar_loops: ", npar_loops);
break;
} }
}
}
} }
} // namespace
#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE) \
paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE>(args, \
stream);
#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \ #define CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT) \
PSIZE) \ switch(args.partition_size) \
paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \ { \
PSIZE>(args, stream); case 256: CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, 256); break; \
case 512: CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, 512); break; \
#define CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \ default: TORCH_CHECK(false, "Unsupported partition size: ", args.partition_size); break; \
OUTT) \ }
switch (args.partition_size) { \
case 256: \
CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, 256); \
break; \
case 512: \
CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, 512); \
break; \
default: \
TORCH_CHECK(false, "Unsupported partition size: ", args.partition_size); \
break; \
}
#if defined(__HIPCC__) && defined(__gfx90a__) #if defined(__HIPCC__) && defined(__gfx90a__)
#define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \ #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \
if (args.fp8_out_scale_ptr) { \ if(args.fp8_out_scale_ptr) \
TORCH_CHECK(false, "fp8 out scale unsupported for gfx90a"); \ { \
} else { \ TORCH_CHECK(false, "fp8 out scale unsupported for gfx90a"); \
CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T); \ } \
else \
{ \
CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T); \
} }
#else #else
#define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \ #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \
if (args.fp8_out_scale_ptr) { \ if(args.fp8_out_scale_ptr) \
CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \ { \
uint8_t); \ CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, uint8_t); \
} else { \ } \
CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T); \ else \
{ \
CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T); \
} }
#endif #endif
#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \ #define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \
switch (args.block_size) { \ switch(args.block_size) \
case 16: \ { \
CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 16, HEAD_SIZE); \ case 16: CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 16, HEAD_SIZE); break; \
break; \ case 32: CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 32, HEAD_SIZE); break; \
case 32: \ default: TORCH_CHECK(false, "Unsupported block size: ", args.block_size); break; \
CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 32, HEAD_SIZE); \ }
break; \
default: \ #define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE) \
TORCH_CHECK(false, "Unsupported block size: ", args.block_size); \ switch(args.head_size) \
break; \ { \
} case 64: CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 64); break; \
case 128: CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 128); break; \
#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE) \ default: TORCH_CHECK(false, "Unsupported head size: ", args.head_size); break; \
switch (args.head_size) { \ }
case 64: \
CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 64); \
break; \
case 128: \
CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 128); \
break; \
default: \
TORCH_CHECK(false, "Unsupported head size: ", args.head_size); \
break; \
}
namespace native { namespace native {
void paged_attention( void paged_attention(const paged_attention_traits& traits,
const paged_attention_traits& traits, const paged_attention_args& args,
const paged_attention_args& args, hipStream_t stream)
hipStream_t stream
)
{ {
if (traits.kv_cache_dtype == "auto") { if(traits.kv_cache_dtype == "auto")
if (traits.q_type == ScalarType::Half) { {
CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, _Float16, if(traits.q_type == ScalarType::Half)
vllm::Fp8KVCacheDataType::kAuto); {
} else if (traits.q_type == ScalarType::BFloat16) { CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, _Float16, vllm::Fp8KVCacheDataType::kAuto);
CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, __hip_bfloat16, }
vllm::Fp8KVCacheDataType::kAuto); else if(traits.q_type == ScalarType::BFloat16)
} else { {
TORCH_CHECK(false, "Unsupported data type: ", traits.q_type); CALL_CUSTOM_LAUNCHER_BLK_HEAD(
__hip_bfloat16, __hip_bfloat16, vllm::Fp8KVCacheDataType::kAuto);
}
else
{
TORCH_CHECK(false, "Unsupported data type: ", traits.q_type);
}
} }
} else if (traits.kv_cache_dtype == "fp8" || traits.kv_cache_dtype == "fp8_e4m3") { else if(traits.kv_cache_dtype == "fp8" || traits.kv_cache_dtype == "fp8_e4m3")
if (traits.q_type == ScalarType::Half) { {
CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, uint8_t, if(traits.q_type == ScalarType::Half)
vllm::Fp8KVCacheDataType::kFp8E4M3); {
} else if (traits.q_type == ScalarType::BFloat16) { CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, uint8_t, }
vllm::Fp8KVCacheDataType::kFp8E4M3); else if(traits.q_type == ScalarType::BFloat16)
} else { {
TORCH_CHECK(false, "Unsupported data type: ", traits.q_type); CALL_CUSTOM_LAUNCHER_BLK_HEAD(
__hip_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
}
else
{
TORCH_CHECK(false, "Unsupported data type: ", traits.q_type);
}
}
else
{
TORCH_CHECK(false, "Unsupported KV cache dtype: ", traits.kv_cache_dtype);
} }
} else {
TORCH_CHECK(false, "Unsupported KV cache dtype: ", traits.kv_cache_dtype);
}
}
} }
} // namespace native
...@@ -25,76 +25,73 @@ void paged_attention( ...@@ -25,76 +25,73 @@ void paged_attention(
torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor& torch::Tensor& tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size] torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor& torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] int64_t num_kv_heads,
torch::Tensor& double scale,
value_cache, // [num_blocks, num_heads, head_size, block_size] torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
int64_t num_kv_heads, double scale, torch::Tensor& context_lens, // [num_seqs]
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] int64_t block_size,
torch::Tensor& context_lens, // [num_seqs] int64_t max_context_len,
int64_t block_size, int64_t max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype,
const c10::optional<torch::Tensor>& fp8_out_scale, int64_t partition_size) { double k_scale,
double v_scale,
native::paged_attention_traits traits; const c10::optional<torch::Tensor>& fp8_out_scale,
int64_t partition_size)
traits.q_type = ( {
query.dtype() == at::ScalarType::Half ? native::ScalarType::Half
: native::ScalarType::BFloat16 native::paged_attention_traits traits;
);
traits.kv_cache_dtype = kv_cache_dtype; traits.q_type = (query.dtype() == at::ScalarType::Half ? native::ScalarType::Half
: native::ScalarType::BFloat16);
native::paged_attention_args args; traits.kv_cache_dtype = kv_cache_dtype;
args.head_size = query.size(2); native::paged_attention_args args;
args.num_seqs = query.size(0); args.head_size = query.size(2);
args.num_heads = query.size(1);
args.head_size = query.size(2); args.num_seqs = query.size(0);
args.max_num_blocks_per_seq = block_tables.size(1); args.num_heads = query.size(1);
args.q_stride = query.stride(0); args.head_size = query.size(2);
args.kv_block_stride = key_cache.stride(0); args.max_num_blocks_per_seq = block_tables.size(1);
args.kv_head_stride = key_cache.stride(1); args.q_stride = query.stride(0);
args.kv_block_stride = key_cache.stride(0);
// NOTE: alibi_slopes is optional. args.kv_head_stride = key_cache.stride(1);
args.alibi_slopes_ptr =
alibi_slopes // NOTE: alibi_slopes is optional.
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr()) args.alibi_slopes_ptr =
: nullptr; alibi_slopes ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr()) : nullptr;
args.exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr()); args.exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
args.max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr()); args.max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
args.tmp_out_ptr = tmp_out.data_ptr(); args.tmp_out_ptr = tmp_out.data_ptr();
args.query_ptr = query.data_ptr(); args.query_ptr = query.data_ptr();
args.key_cache_ptr = key_cache.data_ptr(); args.key_cache_ptr = key_cache.data_ptr();
args.value_cache_ptr = value_cache.data_ptr(); args.value_cache_ptr = value_cache.data_ptr();
args.block_tables_ptr = block_tables.data_ptr<int>(); args.block_tables_ptr = block_tables.data_ptr<int>();
args.context_lens_ptr = context_lens.data_ptr<int>(); args.context_lens_ptr = context_lens.data_ptr<int>();
// NOTE: fp8_out_scale is optional. // NOTE: fp8_out_scale is optional.
args.fp8_out_scale_ptr = args.fp8_out_scale_ptr =
fp8_out_scale fp8_out_scale ? reinterpret_cast<const float*>(fp8_out_scale.value().data_ptr()) : nullptr;
? reinterpret_cast<const float*>(fp8_out_scale.value().data_ptr()) args.out_ptr = out.data_ptr();
: nullptr;
args.out_ptr = out.data_ptr(); args.block_size = block_size;
args.block_size = block_size; args.max_context_len = max_context_len;
args.num_kv_heads = num_kv_heads;
args.max_context_len = max_context_len; args.partition_size = partition_size;
args.num_kv_heads = num_kv_heads; args.scale = scale;
args.partition_size = partition_size; args.k_scale = k_scale;
args.scale = scale; args.v_scale = v_scale;
args.k_scale = k_scale;
args.v_scale = v_scale; hipStream_t stream = nullptr;
HIP_CHECK_ERROR(hipStreamCreate(&stream));
hipStream_t stream = nullptr;
HIP_CHECK_ERROR(hipStreamCreate(&stream)); native::paged_attention(traits, args, stream);
native::paged_attention(traits, args, stream); HIP_CHECK_ERROR(hipStreamDestroy(stream));
HIP_CHECK_ERROR(hipStreamDestroy(stream));
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment