Commit cc6d659f authored by Po Yen, Chen

Re-format interface sources

parent 5a683756
@@ -5,37 +5,43 @@
#include <hip/hip_runtime.h>

namespace native {

enum class ScalarType
{
    Half,
    BFloat16,
};

inline std::ostream& operator<<(std::ostream& stream, ScalarType scalar_type)
{
    switch(scalar_type)
    {
    case ScalarType::Half: stream << "Half"; break;
    case ScalarType::BFloat16: stream << "BFloat16"; break;
    }
    return stream;
}

enum class Fp8KVCacheDataType
{
    kAuto = 0,
    kFp8E4M3 = 1,
    kFp8E5M2 = 2,
};

struct paged_attention_traits
{
    ScalarType q_type;
    std::string kv_cache_dtype;
};

struct paged_attention_args
{
    int head_size;
    int num_seqs;
    int num_heads;
    int num_kv_heads;
    int max_num_blocks_per_seq;
    int q_stride;
    int kv_block_stride;
@@ -63,9 +69,7 @@ struct paged_attention_args {
    int64_t partition_size;
};

void paged_attention(const paged_attention_traits& traits,
                     const paged_attention_args& args,
                     hipStream_t stream);
} // namespace native
\ No newline at end of file
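The header above only declares the device-side entry point; the kernel implementation is not part of this commit. For orientation, here is a minimal sketch of how a dispatcher could consume paged_attention_traits and paged_attention_args and launch a HIP kernel on the caller-supplied stream. The kernel template, launch geometry, and element types below are illustrative assumptions, not code from this repository.

// Sketch only: assumes the declarations from the interface header above are visible.
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <hip/hip_bfloat16.h>

namespace native {

// Hypothetical kernel template standing in for the real implementation.
template <typename QType>
__global__ void paged_attention_kernel(paged_attention_args args)
{
    // ... attention over the paged KV cache would go here ...
}

// One plausible shape for the host-side dispatcher: pick a kernel
// instantiation from the traits and launch it on the given stream.
void paged_attention(const paged_attention_traits& traits,
                     const paged_attention_args& args,
                     hipStream_t stream)
{
    const dim3 grid(args.num_heads, args.num_seqs); // assumed launch geometry
    const dim3 block(256);

    switch(traits.q_type)
    {
    case ScalarType::Half:
        hipLaunchKernelGGL(paged_attention_kernel<__half>, grid, block, 0, stream, args);
        break;
    case ScalarType::BFloat16:
        hipLaunchKernelGGL(paged_attention_kernel<hip_bfloat16>, grid, block, 0, stream, args);
        break;
    }
}

} // namespace native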
@@ -25,76 +25,73 @@ void paged_attention(
    torch::Tensor& out,         // [num_seqs, num_heads, head_size]
    torch::Tensor& exp_sums,    // [num_seqs, num_heads, max_num_partitions]
    torch::Tensor& max_logits,  // [num_seqs, num_heads, max_num_partitions]
    torch::Tensor& tmp_out,     // [num_seqs, num_heads, max_num_partitions, head_size]
    torch::Tensor& query,       // [num_seqs, num_heads, head_size]
    torch::Tensor& key_cache,   // [num_blocks, num_heads, head_size/x, block_size, x]
    torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
    int64_t num_kv_heads,
    double scale,
    torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
    torch::Tensor& context_lens, // [num_seqs]
    int64_t block_size,
    int64_t max_context_len,
    const c10::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype,
    double k_scale,
    double v_scale,
    const c10::optional<torch::Tensor>& fp8_out_scale,
    int64_t partition_size)
{
    native::paged_attention_traits traits;

    traits.q_type = (query.dtype() == at::ScalarType::Half ? native::ScalarType::Half
                                                            : native::ScalarType::BFloat16);
    traits.kv_cache_dtype = kv_cache_dtype;

    native::paged_attention_args args;

    args.head_size = query.size(2);
    args.num_seqs = query.size(0);
    args.num_heads = query.size(1);
    args.head_size = query.size(2);
    args.max_num_blocks_per_seq = block_tables.size(1);
    args.q_stride = query.stride(0);
    args.kv_block_stride = key_cache.stride(0);
    args.kv_head_stride = key_cache.stride(1);

    // NOTE: alibi_slopes is optional.
    args.alibi_slopes_ptr =
        alibi_slopes ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr()) : nullptr;

    args.exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
    args.max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
    args.tmp_out_ptr = tmp_out.data_ptr();
    args.query_ptr = query.data_ptr();
    args.key_cache_ptr = key_cache.data_ptr();
    args.value_cache_ptr = value_cache.data_ptr();
    args.block_tables_ptr = block_tables.data_ptr<int>();
    args.context_lens_ptr = context_lens.data_ptr<int>();

    // NOTE: fp8_out_scale is optional.
    args.fp8_out_scale_ptr =
        fp8_out_scale ? reinterpret_cast<const float*>(fp8_out_scale.value().data_ptr()) : nullptr;
    args.out_ptr = out.data_ptr();

    args.block_size = block_size;
    args.max_context_len = max_context_len;
    args.num_kv_heads = num_kv_heads;
    args.partition_size = partition_size;
    args.scale = scale;
    args.k_scale = k_scale;
    args.v_scale = v_scale;

    hipStream_t stream = nullptr;
    HIP_CHECK_ERROR(hipStreamCreate(&stream));

    native::paged_attention(traits, args, stream);

    HIP_CHECK_ERROR(hipStreamDestroy(stream));
}
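HIP_CHECK_ERROR is used above but its definition lies outside this diff. A minimal stand-in, assuming it simply converts a failing hipError_t into an exception (the project's real macro may report errors differently), could look like this:

#include <hip/hip_runtime.h>
#include <sstream>
#include <stdexcept>

// Hypothetical stand-in for HIP_CHECK_ERROR: evaluate a HIP call and throw
// with file/line and hipGetErrorString() text when it does not return hipSuccess.
#define HIP_CHECK_ERROR(call)                                             \
    do                                                                    \
    {                                                                     \
        const hipError_t status_ = (call);                                \
        if(status_ != hipSuccess)                                         \
        {                                                                 \
            std::ostringstream msg_;                                      \
            msg_ << __FILE__ << ":" << __LINE__ << ": HIP error: "        \
                 << hipGetErrorString(status_);                           \
            throw std::runtime_error(msg_.str());                         \
        }                                                                 \
    } while(0)

Note that the wrapper above creates and destroys a fresh HIP stream around every call; an alternative would be to run on the stream PyTorch already associates with the input tensors, but that choice is outside the scope of this formatting commit.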