Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
7a18d241
Commit
7a18d241
authored
Jan 26, 2026
by
wooway777
Browse files
issue/983 - adapted the optimized paged attention to metax
parent
4cd1f688
Changes
14
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
3565 additions
and
30 deletions
+3565
-30
src/infiniop/devices/metax/metax_ht2mc.h
src/infiniop/devices/metax/metax_ht2mc.h
+3
-0
src/infiniop/devices/metax/metax_kernel_common.h
src/infiniop/devices/metax/metax_kernel_common.h
+6
-0
src/infiniop/ops/paged_attention/metax/paged_attention_hd128.maca
...niop/ops/paged_attention/metax/paged_attention_hd128.maca
+1028
-0
src/infiniop/ops/paged_attention/metax/paged_attention_hd64.maca
...iniop/ops/paged_attention/metax/paged_attention_hd64.maca
+528
-0
src/infiniop/ops/paged_attention/metax/paged_attention_metax.h
...nfiniop/ops/paged_attention/metax/paged_attention_metax.h
+8
-0
src/infiniop/ops/paged_attention/metax/paged_attention_metax.maca
...niop/ops/paged_attention/metax/paged_attention_metax.maca
+218
-0
src/infiniop/ops/paged_attention/operator.cc
src/infiniop/ops/paged_attention/operator.cc
+15
-15
src/infiniop/ops/paged_attention_prefill/cuda/kernel_v2.cuh
src/infiniop/ops/paged_attention_prefill/cuda/kernel_v2.cuh
+2
-0
src/infiniop/ops/paged_attention_prefill/metax/paged_attention_prefill_metax.h
...d_attention_prefill/metax/paged_attention_prefill_metax.h
+8
-0
src/infiniop/ops/paged_attention_prefill/metax/paged_attention_prefill_metax.maca
...ttention_prefill/metax/paged_attention_prefill_metax.maca
+1554
-0
src/infiniop/ops/paged_attention_prefill/operator.cc
src/infiniop/ops/paged_attention_prefill/operator.cc
+15
-0
src/infiniop/ops/paged_caching/metax/paged_caching_metax.h
src/infiniop/ops/paged_caching/metax/paged_caching_metax.h
+8
-0
src/infiniop/ops/paged_caching/metax/paged_caching_metax.maca
...infiniop/ops/paged_caching/metax/paged_caching_metax.maca
+157
-0
src/infiniop/ops/paged_caching/operator.cc
src/infiniop/ops/paged_caching/operator.cc
+15
-15
No files found.
src/infiniop/devices/metax/metax_ht2mc.h
View file @
7a18d241
...
@@ -85,6 +85,9 @@
...
@@ -85,6 +85,9 @@
#define hcclSuccess mcclSuccess
#define hcclSuccess mcclSuccess
#define hcclCommDestroy mcclCommDestroy
#define hcclCommDestroy mcclCommDestroy
#define hcclAllReduce mcclAllReduce
#define hcclAllReduce mcclAllReduce
#define hcGetDevice mcGetDevice
#define hcDeviceAttributeMultiProcessorCount mcDeviceAttributeMultiProcessorCount
#define hcDeviceGetAttribute mcDeviceGetAttribute
#define hcStreamCaptureMode mcStreamCaptureMode
#define hcStreamCaptureMode mcStreamCaptureMode
#define hcStreamCaptureModeGlobal mcStreamCaptureModeGlobal
#define hcStreamCaptureModeGlobal mcStreamCaptureModeGlobal
#define hcStreamCaptureModeThreadLocal mcStreamCaptureModeThreadLocal
#define hcStreamCaptureModeThreadLocal mcStreamCaptureModeThreadLocal
...
...
src/infiniop/devices/metax/metax_kernel_common.h
View file @
7a18d241
...
@@ -19,6 +19,12 @@ using cuda_bfloat16 = hpcc_bfloat16;
...
@@ -19,6 +19,12 @@ using cuda_bfloat16 = hpcc_bfloat16;
using
cuda_bfloat162
=
hpcc_bfloat162
;
using
cuda_bfloat162
=
hpcc_bfloat162
;
using
cuda_fp8_e4m3
=
__hpcc_fp8_e4m3
;
using
cuda_fp8_e4m3
=
__hpcc_fp8_e4m3
;
#ifdef ENABLE_METAX_MC_API
using
__nv_bfloat16
=
__maca_bfloat16
;
#else
using
__nv_bfloat16
=
__hpcc_bfloat16
;
#endif
namespace
device
::
metax
{
namespace
device
::
metax
{
// get the memory offset of the given element in a tensor given its flat index
// get the memory offset of the given element in a tensor given its flat index
...
...
src/infiniop/ops/paged_attention/metax/paged_attention_hd128.maca
0 → 100644
View file @
7a18d241
This diff is collapsed.
Click to expand it.
src/infiniop/ops/paged_attention/metax/paged_attention_hd64.maca
0 → 100644
View file @
7a18d241
This diff is collapsed.
Click to expand it.
src/infiniop/ops/paged_attention/metax/paged_attention_metax.h
0 → 100644
View file @
7a18d241
#ifndef __PAGED_ATTENTION_METAX_H__
#define __PAGED_ATTENTION_METAX_H__

#include "../paged_attention.h"

// Instantiate the operator's Descriptor class for the metax backend.
DESCRIPTOR(metax)

#endif // __PAGED_ATTENTION_METAX_H__
src/infiniop/ops/paged_attention/metax/paged_attention_metax.maca
0 → 100644
View file @
7a18d241
#ifdef ENABLE_METAX_MC_API
#include <mc_runtime.h>
#else
#include <hc_runtime.h>
#endif
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include "../../../devices/metax/metax_common.h"
#include "paged_attention_metax.h"
namespace op::paged_attention::metax {
// Decode-kernel launchers implemented in paged_attention_hd64.maca and
// paged_attention_hd128.maca. One launcher exists per (head_size, index dtype)
// combination; all share the same argument list: optional split-kv workspace,
// raw device buffers (out/q/k_cache/v_cache), element dtype tag, index arrays
// (block tables and cache lengths), problem geometry, softmax scale, tensor
// strides, and the target stream.

// head_size == 64, int64 block-table / cache-length indices.
infiniStatus_t launch_decode_hd64_i64(
    void *workspace, size_t workspace_size,
    void *out, const void *q, const void *k_cache, const void *v_cache,
    infiniDtype_t dtype, const int64_t *block_tables, const int64_t *cache_lens, const float *alibi_slopes,
    size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
    hcStream_t stream);

// head_size == 64, int32 indices.
infiniStatus_t launch_decode_hd64_i32(
    void *workspace, size_t workspace_size,
    void *out, const void *q, const void *k_cache, const void *v_cache,
    infiniDtype_t dtype, const int32_t *block_tables, const int32_t *cache_lens, const float *alibi_slopes,
    size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
    hcStream_t stream);

// head_size == 64, uint32 indices.
infiniStatus_t launch_decode_hd64_u32(
    void *workspace, size_t workspace_size,
    void *out, const void *q, const void *k_cache, const void *v_cache,
    infiniDtype_t dtype, const uint32_t *block_tables, const uint32_t *cache_lens, const float *alibi_slopes,
    size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
    hcStream_t stream);

// head_size == 128, int64 indices.
infiniStatus_t launch_decode_hd128_i64(
    void *workspace, size_t workspace_size,
    void *out, const void *q, const void *k_cache, const void *v_cache,
    infiniDtype_t dtype, const int64_t *block_tables, const int64_t *cache_lens, const float *alibi_slopes,
    size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
    hcStream_t stream);

// head_size == 128, int32 indices.
infiniStatus_t launch_decode_hd128_i32(
    void *workspace, size_t workspace_size,
    void *out, const void *q, const void *k_cache, const void *v_cache,
    infiniDtype_t dtype, const int32_t *block_tables, const int32_t *cache_lens, const float *alibi_slopes,
    size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
    hcStream_t stream);

// head_size == 128, uint32 indices.
infiniStatus_t launch_decode_hd128_u32(
    void *workspace, size_t workspace_size,
    void *out, const void *q, const void *k_cache, const void *v_cache,
    infiniDtype_t dtype, const uint32_t *block_tables, const uint32_t *cache_lens, const float *alibi_slopes,
    size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
    hcStream_t stream);
// PIMPL payload: holds a shared handle to the metax device internals.
struct Descriptor::Opaque {
    std::shared_ptr<device::metax::Handle::Internal> internal;
};

// Releases the PIMPL state; remaining members are value types.
Descriptor::~Descriptor() {
    delete _opaque;
}
// Validates the tensor descriptors, sizes the split-kv workspace, and builds
// a ready-to-run Descriptor. Returns the validation error from
// PagedAttentionInfo::create on failure.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t q_desc,
    infiniopTensorDescriptor_t k_cache_desc,
    infiniopTensorDescriptor_t v_cache_desc,
    infiniopTensorDescriptor_t block_tables_desc,
    infiniopTensorDescriptor_t cache_lens_desc,
    const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
    float scale) {
    auto info_result = PagedAttentionInfo::create(
        out_desc, q_desc, k_cache_desc, v_cache_desc,
        block_tables_desc, cache_lens_desc, alibi_slopes_desc, scale);
    CHECK_RESULT(info_result);
    auto info = info_result.take();

    // Workspace for the optional split-kv decode path: per split, every
    // (sequence, head) pair keeps head_size floats of partial accumulator
    // plus two floats (m/l). Sized for the maximum split count regardless of
    // runtime env toggles; kernels clamp num_splits <= kMaxSplits.
    constexpr size_t kMaxSplits = 8;
    const size_t bytes_per_split = info.num_seqs * info.num_heads * (info.head_size + 2) * sizeof(float);

    *desc_ptr = new Descriptor(
        new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
        info, kMaxSplits * bytes_per_split, handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
// Dispatches the paged-attention decode to the launcher matching the
// descriptor's (index dtype, head size) pair.
//
// Returns INFINI_STATUS_INSUFFICIENT_WORKSPACE when split-kv may run but the
// caller's workspace is too small, INFINI_STATUS_BAD_TENSOR_SHAPE for an
// unsupported head size, and INFINI_STATUS_BAD_TENSOR_DTYPE for an
// unsupported index dtype.
infiniStatus_t Descriptor::calculate(
    void *workspace, size_t workspace_size,
    void *out, const void *q, const void *k_cache, const void *v_cache,
    const void *block_tables, const void *cache_lens, const void *alibi_slopes,
    void *stream_) const {
    // Decide whether the split-kv decode path (and hence workspace) may be used.
    bool need_workspace = false;
    if (const char *env = std::getenv("INFINIOP_FLASH_DECODE_SPLITKV")) {
        // "auto" may enable split-kv depending on the runtime heuristic.
        need_workspace = (std::strcmp(env, "auto") == 0) || (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
    } else {
        // Keep hd64 behavior unchanged, but for hd128 we default to split-kv decode, which needs workspace.
        need_workspace = (_info.head_size == 128);
    }
    if (need_workspace && workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    auto stream = static_cast<hcStream_t>(stream_);
    const float *alibi_ptr = (alibi_slopes == nullptr) ? nullptr : static_cast<const float *>(alibi_slopes);

    // Every launcher takes the same argument list; `launch` is one of the
    // declared launch_decode_* functions and bt/cl are the index buffers cast
    // to the matching element type. The generic lambda keeps the call site
    // written once instead of six times.
    auto run = [&](auto launch, const auto *bt, const auto *cl) -> infiniStatus_t {
        return launch(
            workspace, workspace_size,
            out, q, k_cache, v_cache, _info.dtype,
            bt, cl, alibi_ptr,
            _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale,
            _info.max_num_blocks_per_seq, _info.page_block_size,
            _info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride,
            _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride,
            _info.o_stride, stream);
    };

    if (_info.index_dtype == INFINI_DTYPE_I64) {
        const auto *bt = static_cast<const int64_t *>(block_tables);
        const auto *cl = static_cast<const int64_t *>(cache_lens);
        switch (_info.head_size) {
        case 64:
            return run(launch_decode_hd64_i64, bt, cl);
        case 128:
            return run(launch_decode_hd128_i64, bt, cl);
        default:
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
    }
    if (_info.index_dtype == INFINI_DTYPE_I32) {
        const auto *bt = static_cast<const int32_t *>(block_tables);
        const auto *cl = static_cast<const int32_t *>(cache_lens);
        switch (_info.head_size) {
        case 64:
            return run(launch_decode_hd64_i32, bt, cl);
        case 128:
            return run(launch_decode_hd128_i32, bt, cl);
        default:
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
    }
    if (_info.index_dtype == INFINI_DTYPE_U32) {
        const auto *bt = static_cast<const uint32_t *>(block_tables);
        const auto *cl = static_cast<const uint32_t *>(cache_lens);
        switch (_info.head_size) {
        case 64:
            return run(launch_decode_hd64_u32, bt, cl);
        case 128:
            return run(launch_decode_hd128_u32, bt, cl);
        default:
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
    }
    // Unsupported block-table / cache-length index dtype.
    return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} // namespace op::paged_attention::metax
src/infiniop/ops/paged_attention/operator.cc
View file @
7a18d241
...
@@ -5,9 +5,9 @@
...
@@ -5,9 +5,9 @@
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NVIDIA_API
#include "nvidia/paged_attention_nvidia.cuh"
#include "nvidia/paged_attention_nvidia.cuh"
#endif
#endif
//
#ifdef ENABLE_METAX_API
#ifdef ENABLE_METAX_API
//
#include "metax/paged_attention_metax.h"
#include "metax/paged_attention_metax.h"
//
#endif
#endif
__C
infiniStatus_t
infiniopCreatePagedAttentionDescriptor
(
__C
infiniStatus_t
infiniopCreatePagedAttentionDescriptor
(
infiniopHandle_t
handle
,
infiniopHandle_t
handle
,
...
@@ -34,9 +34,9 @@ __C infiniStatus_t infiniopCreatePagedAttentionDescriptor(
...
@@ -34,9 +34,9 @@ __C infiniStatus_t infiniopCreatePagedAttentionDescriptor(
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NVIDIA_API
CREATE
(
INFINI_DEVICE_NVIDIA
,
nvidia
)
CREATE
(
INFINI_DEVICE_NVIDIA
,
nvidia
)
#endif
#endif
//
#ifdef ENABLE_METAX_API
#ifdef ENABLE_METAX_API
//
CREATE(INFINI_DEVICE_METAX, metax)
CREATE
(
INFINI_DEVICE_METAX
,
metax
)
//
#endif
#endif
default:
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
}
...
@@ -55,9 +55,9 @@ __C infiniStatus_t infiniopGetPagedAttentionWorkspaceSize(
...
@@ -55,9 +55,9 @@ __C infiniStatus_t infiniopGetPagedAttentionWorkspaceSize(
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NVIDIA_API
GET
(
INFINI_DEVICE_NVIDIA
,
nvidia
)
GET
(
INFINI_DEVICE_NVIDIA
,
nvidia
)
#endif
#endif
//
#ifdef ENABLE_METAX_API
#ifdef ENABLE_METAX_API
//
GET(INFINI_DEVICE_METAX, metax)
GET
(
INFINI_DEVICE_METAX
,
metax
)
//
#endif
#endif
default:
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
}
...
@@ -80,9 +80,9 @@ __C infiniStatus_t infiniopPagedAttention(
...
@@ -80,9 +80,9 @@ __C infiniStatus_t infiniopPagedAttention(
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NVIDIA_API
CALCULATE
(
INFINI_DEVICE_NVIDIA
,
nvidia
)
CALCULATE
(
INFINI_DEVICE_NVIDIA
,
nvidia
)
#endif
#endif
//
#ifdef ENABLE_METAX_API
#ifdef ENABLE_METAX_API
//
CALCULATE(INFINI_DEVICE_METAX, metax)
CALCULATE
(
INFINI_DEVICE_METAX
,
metax
)
//
#endif
#endif
default:
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
}
...
@@ -100,9 +100,9 @@ __C infiniStatus_t infiniopDestroyPagedAttentionDescriptor(
...
@@ -100,9 +100,9 @@ __C infiniStatus_t infiniopDestroyPagedAttentionDescriptor(
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NVIDIA_API
DESTROY
(
INFINI_DEVICE_NVIDIA
,
nvidia
)
DESTROY
(
INFINI_DEVICE_NVIDIA
,
nvidia
)
#endif
#endif
//
#ifdef ENABLE_METAX_API
#ifdef ENABLE_METAX_API
//
DESTROY(INFINI_DEVICE_METAX, metax)
DESTROY
(
INFINI_DEVICE_METAX
,
metax
)
//
#endif
#endif
default:
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
}
...
...
src/infiniop/ops/paged_attention_prefill/cuda/kernel_v2.cuh
View file @
7a18d241
#ifndef __PAGED_ATTENTION_PREFILL_KERNEL_V2_CUH__
#ifndef __PAGED_ATTENTION_PREFILL_KERNEL_V2_CUH__
#define __PAGED_ATTENTION_PREFILL_KERNEL_V2_CUH__
#define __PAGED_ATTENTION_PREFILL_KERNEL_V2_CUH__
#ifdef ENABLE_NVIDIA_API
#include <cuda_bf16.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cuda_runtime.h>
#include <mma.h>
#include <mma.h>
#endif
#include <cstdint>
#include <cstdint>
#include <type_traits>
#include <type_traits>
...
...
src/infiniop/ops/paged_attention_prefill/metax/paged_attention_prefill_metax.h
0 → 100644
View file @
7a18d241
#ifndef __PAGED_ATTENTION_PREFILL_METAX_H__
#define __PAGED_ATTENTION_PREFILL_METAX_H__

#include "../paged_attention_prefill.h"

// Instantiate the operator's Descriptor class for the metax backend.
DESCRIPTOR(metax)

#endif // __PAGED_ATTENTION_PREFILL_METAX_H__
src/infiniop/ops/paged_attention_prefill/metax/paged_attention_prefill_metax.maca
0 → 100644
View file @
7a18d241
This diff is collapsed.
Click to expand it.
src/infiniop/ops/paged_attention_prefill/operator.cc
View file @
7a18d241
...
@@ -5,6 +5,9 @@
...
@@ -5,6 +5,9 @@
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NVIDIA_API
#include "nvidia/paged_attention_prefill_nvidia.cuh"
#include "nvidia/paged_attention_prefill_nvidia.cuh"
#endif
#endif
#ifdef ENABLE_METAX_API
#include "metax/paged_attention_prefill_metax.h"
#endif
__C
infiniStatus_t
infiniopCreatePagedAttentionPrefillDescriptor
(
__C
infiniStatus_t
infiniopCreatePagedAttentionPrefillDescriptor
(
infiniopHandle_t
handle
,
infiniopHandle_t
handle
,
...
@@ -32,6 +35,9 @@ __C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
...
@@ -32,6 +35,9 @@ __C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
switch
(
handle
->
device
)
{
switch
(
handle
->
device
)
{
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NVIDIA_API
CREATE
(
INFINI_DEVICE_NVIDIA
,
nvidia
)
CREATE
(
INFINI_DEVICE_NVIDIA
,
nvidia
)
#endif
#ifdef ENABLE_METAX_API
CREATE
(
INFINI_DEVICE_METAX
,
metax
)
#endif
#endif
default:
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
@@ -50,6 +56,9 @@ __C infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize(
...
@@ -50,6 +56,9 @@ __C infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize(
switch
(
desc
->
device_type
)
{
switch
(
desc
->
device_type
)
{
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NVIDIA_API
GET
(
INFINI_DEVICE_NVIDIA
,
nvidia
)
GET
(
INFINI_DEVICE_NVIDIA
,
nvidia
)
#endif
#ifdef ENABLE_METAX_API
GET
(
INFINI_DEVICE_METAX
,
metax
)
#endif
#endif
default:
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
@@ -75,6 +84,9 @@ __C infiniStatus_t infiniopPagedAttentionPrefill(
...
@@ -75,6 +84,9 @@ __C infiniStatus_t infiniopPagedAttentionPrefill(
switch
(
desc
->
device_type
)
{
switch
(
desc
->
device_type
)
{
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NVIDIA_API
CALCULATE
(
INFINI_DEVICE_NVIDIA
,
nvidia
)
CALCULATE
(
INFINI_DEVICE_NVIDIA
,
nvidia
)
#endif
#ifdef ENABLE_METAX_API
CALCULATE
(
INFINI_DEVICE_METAX
,
metax
)
#endif
#endif
default:
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
@@ -92,6 +104,9 @@ __C infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor(
...
@@ -92,6 +104,9 @@ __C infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor(
switch
(
desc
->
device_type
)
{
switch
(
desc
->
device_type
)
{
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NVIDIA_API
DESTROY
(
INFINI_DEVICE_NVIDIA
,
nvidia
)
DESTROY
(
INFINI_DEVICE_NVIDIA
,
nvidia
)
#endif
#ifdef ENABLE_METAX_API
DESTROY
(
INFINI_DEVICE_METAX
,
metax
)
#endif
#endif
default:
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
src/infiniop/ops/paged_caching/metax/paged_caching_metax.h
0 → 100644
View file @
7a18d241
#ifndef __PAGED_CACHING_METAX_H__
#define __PAGED_CACHING_METAX_H__

#include "../paged_caching.h"

// Instantiate the operator's Descriptor class for the metax backend.
DESCRIPTOR(metax)

#endif // __PAGED_CACHING_METAX_H__
src/infiniop/ops/paged_caching/metax/paged_caching_metax.maca
0 → 100644
View file @
7a18d241
#include "../../../devices/metax/metax_common.h"
#include "../../../devices/metax/metax_kernel_common.h"
#include "../cuda/kernel.cuh"
#include "paged_caching_metax.h"
// Metax device entry point. Thin wrapper that forwards all arguments to the
// shared kernel body in ../cuda/kernel.cuh so the caching logic is written
// once for both backends. (The actual copy semantics live in that kernel;
// see cuda/kernel.cuh for details.)
template <typename Tdata, int NUM_THREADS>
INFINIOP_METAX_KERNEL pagedCaching(
    Tdata *k_cache, Tdata *v_cache,
    const Tdata *k, const Tdata *v,
    const int64_t *slot_mapping,
    const size_t head_size, const size_t block_size,
    const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride,
    const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride) {
    op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>(
        k_cache, v_cache, k, v, slot_mapping, head_size,
        block_size, k_src_stride, v_src_stride, k_cache_block_stride, v_cache_block_stride);
}
namespace op::paged_caching::metax {
// PIMPL payload: holds a shared handle to the metax device internals
// (queried in calculate() for maxThreadsPerBlock).
struct Descriptor::Opaque {
    std::shared_ptr<device::metax::Handle::Internal> internal;
};

// Releases the PIMPL state; remaining members are value types.
Descriptor::~Descriptor() {
    delete _opaque;
}
// Static factory: validates the tensor descriptors and builds a ready-to-run
// Descriptor. This operator needs no scratch memory, so the workspace size is
// fixed at 0.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t k_cache_desc,
    infiniopTensorDescriptor_t v_cache_desc,
    infiniopTensorDescriptor_t k_desc,
    infiniopTensorDescriptor_t v_desc,
    infiniopTensorDescriptor_t slot_mapping_desc) {
    auto info_result = PagedCachingInfo::create(k_cache_desc, v_cache_desc, k_desc, v_desc, slot_mapping_desc);
    CHECK_RESULT(info_result);

    auto *metax_handle = reinterpret_cast<device::metax::Handle *>(handle);
    *desc_ptr = new Descriptor(
        new Opaque{metax_handle->internal()},
        info_result.take(), 0, handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
// The launchKernel function is a templated helper to encapsulate the kernel launch.
// It sets up grid/block dimensions and calls the device-side kernel.
template <int NUM_THREADS>
infiniStatus_t launchKernel(const PagedCachingInfo &info,
void *k_cache, void *v_cache,
infiniDtype_t dtype,
const void *k, const void *v,
const void *slot_mapping,
size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size,
ptrdiff_t k_src_stride, ptrdiff_t v_src_stride,
ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride,
hcStream_t stream) {
// Grid dimension is 1D, with one block per token, as we decided.
dim3 grid(uint64_t(num_kv_heads), uint64_t(num_tokens), 1);
// Block dimension is 1D, using the number of threads specified at compile time.
dim3 block(NUM_THREADS);
// This kernel does not require dynamic shared memory.
size_t shared_mem_size = 0;
// Launch the device-side kernel.
if (dtype == INFINI_DTYPE_F16) {
pagedCaching<half, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(half *)k_cache,
(half *)v_cache,
(const half *)k,
(const half *)v,
(const int64_t *)slot_mapping,
head_size,
block_size,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
} else if (dtype == INFINI_DTYPE_BF16) {
pagedCaching<cuda_bfloat16, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(cuda_bfloat16 *)k_cache,
(cuda_bfloat16 *)v_cache,
(const cuda_bfloat16 *)k,
(const cuda_bfloat16 *)v,
(const int64_t *)slot_mapping,
head_size,
block_size,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
} else if (dtype == INFINI_DTYPE_F32) {
pagedCaching<float, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(float *)k_cache,
(float *)v_cache,
(const float *)k,
(const float *)v,
(const int64_t *)slot_mapping,
head_size,
block_size,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
// Execution method: launches the paged-caching kernel on the given stream.
// workspace / workspace_size are unused -- this operator needs no scratch
// memory (workspace size is fixed at 0 in create()).
//
// Fix over the previous revision: the launcher's status is now returned
// instead of being discarded, so an unsupported element dtype surfaces as
// INFINI_STATUS_BAD_TENSOR_DTYPE rather than a silent no-op SUCCESS.
infiniStatus_t Descriptor::calculate(
    void *workspace, size_t workspace_size,
    void *k_cache, void *v_cache,
    const void *k, const void *v,
    const void *slot_mapping,
    void *stream_) const {
    auto stream = static_cast<hcStream_t>(stream_);

    // Dispatch on the device's maximum threads per block, selecting the
    // largest block size the hardware supports.
    const int max_threads = _opaque->internal->maxThreadsPerBlock();
    if (max_threads >= METAX_BLOCK_SIZE_1024) {
        return launchKernel<METAX_BLOCK_SIZE_1024>(
            _info, k_cache, v_cache, _info.dtype, k, v, slot_mapping,
            _info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
            _info.k_src_stride, _info.v_src_stride,
            _info.k_cache_block_stride, _info.v_cache_block_stride,
            stream);
    }
    if (max_threads >= METAX_BLOCK_SIZE_512) {
        return launchKernel<METAX_BLOCK_SIZE_512>(
            _info, k_cache, v_cache, _info.dtype, k, v, slot_mapping,
            _info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
            _info.k_src_stride, _info.v_src_stride,
            _info.k_cache_block_stride, _info.v_cache_block_stride,
            stream);
    }
    // Devices supporting fewer threads per block are not handled.
    return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
} // namespace op::paged_caching::metax
src/infiniop/ops/paged_caching/operator.cc
View file @
7a18d241
...
@@ -5,9 +5,9 @@
...
@@ -5,9 +5,9 @@
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NVIDIA_API
#include "nvidia/paged_caching_nvidia.cuh"
#include "nvidia/paged_caching_nvidia.cuh"
#endif
#endif
//
#ifdef ENABLE_METAX_API
#ifdef ENABLE_METAX_API
//
#include "metax/paged_caching_metax.h"
#include "metax/paged_caching_metax.h"
//
#endif
#endif
__C
infiniStatus_t
infiniopCreatePagedCachingDescriptor
(
__C
infiniStatus_t
infiniopCreatePagedCachingDescriptor
(
infiniopHandle_t
handle
,
infiniopHandle_t
handle
,
...
@@ -29,9 +29,9 @@ __C infiniStatus_t infiniopCreatePagedCachingDescriptor(
...
@@ -29,9 +29,9 @@ __C infiniStatus_t infiniopCreatePagedCachingDescriptor(
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NVIDIA_API
CREATE
(
INFINI_DEVICE_NVIDIA
,
nvidia
)
CREATE
(
INFINI_DEVICE_NVIDIA
,
nvidia
)
#endif
#endif
//
#ifdef ENABLE_METAX_API
#ifdef ENABLE_METAX_API
//
CREATE(INFINI_DEVICE_METAX, metax)
CREATE
(
INFINI_DEVICE_METAX
,
metax
)
//
#endif
#endif
default:
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
}
...
@@ -50,9 +50,9 @@ __C infiniStatus_t infiniopGetPagedCachingWorkspaceSize(
...
@@ -50,9 +50,9 @@ __C infiniStatus_t infiniopGetPagedCachingWorkspaceSize(
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NVIDIA_API
GET
(
INFINI_DEVICE_NVIDIA
,
nvidia
)
GET
(
INFINI_DEVICE_NVIDIA
,
nvidia
)
#endif
#endif
//
#ifdef ENABLE_METAX_API
#ifdef ENABLE_METAX_API
//
GET(INFINI_DEVICE_METAX, metax)
GET
(
INFINI_DEVICE_METAX
,
metax
)
//
#endif
#endif
default:
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
}
...
@@ -75,9 +75,9 @@ __C infiniStatus_t infiniopPagedCaching(
...
@@ -75,9 +75,9 @@ __C infiniStatus_t infiniopPagedCaching(
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NVIDIA_API
CALCULATE
(
INFINI_DEVICE_NVIDIA
,
nvidia
)
CALCULATE
(
INFINI_DEVICE_NVIDIA
,
nvidia
)
#endif
#endif
//
#ifdef ENABLE_METAX_API
#ifdef ENABLE_METAX_API
//
CALCULATE(INFINI_DEVICE_METAX, metax)
CALCULATE
(
INFINI_DEVICE_METAX
,
metax
)
//
#endif
#endif
default:
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
}
...
@@ -95,9 +95,9 @@ __C infiniStatus_t infiniopDestroyPagedCachingDescriptor(
...
@@ -95,9 +95,9 @@ __C infiniStatus_t infiniopDestroyPagedCachingDescriptor(
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NVIDIA_API
DESTROY
(
INFINI_DEVICE_NVIDIA
,
nvidia
)
DESTROY
(
INFINI_DEVICE_NVIDIA
,
nvidia
)
#endif
#endif
//
#ifdef ENABLE_METAX_API
#ifdef ENABLE_METAX_API
//
DESTROY(INFINI_DEVICE_METAX, metax)
DESTROY
(
INFINI_DEVICE_METAX
,
metax
)
//
#endif
#endif
default:
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment