Commit 7a18d241 authored by wooway777's avatar wooway777
Browse files

issue/983 - adapted the optimized paged attention to metax

parent 4cd1f688
......@@ -85,6 +85,9 @@
#define hcclSuccess mcclSuccess
#define hcclCommDestroy mcclCommDestroy
#define hcclAllReduce mcclAllReduce
#define hcGetDevice mcGetDevice
#define hcDeviceAttributeMultiProcessorCount mcDeviceAttributeMultiProcessorCount
#define hcDeviceGetAttribute mcDeviceGetAttribute
#define hcStreamCaptureMode mcStreamCaptureMode
#define hcStreamCaptureModeGlobal mcStreamCaptureModeGlobal
#define hcStreamCaptureModeThreadLocal mcStreamCaptureModeThreadLocal
......
......@@ -19,6 +19,12 @@ using cuda_bfloat16 = hpcc_bfloat16;
using cuda_bfloat162 = hpcc_bfloat162;
using cuda_fp8_e4m3 = __hpcc_fp8_e4m3;
#ifdef ENABLE_METAX_MC_API
using __nv_bfloat16 = __maca_bfloat16;
#else
using __nv_bfloat16 = __hpcc_bfloat16;
#endif
namespace device::metax {
// get the memory offset of the given element in a tensor given its flat index
......
#ifndef __PAGED_ATTENTION_METAX_H__
#define __PAGED_ATTENTION_METAX_H__
#include "../paged_attention.h"
DESCRIPTOR(metax)
#endif // __PAGED_ATTENTION_METAX_H__
#ifdef ENABLE_METAX_MC_API
#include <mc_runtime.h>
#else
#include <hc_runtime.h>
#endif
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include "../../../devices/metax/metax_common.h"
#include "paged_attention_metax.h"
namespace op::paged_attention::metax {
infiniStatus_t launch_decode_hd64_i64(
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
infiniDtype_t dtype, const int64_t *block_tables, const int64_t *cache_lens, const float *alibi_slopes,
size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
hcStream_t stream);
infiniStatus_t launch_decode_hd64_i32(
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
infiniDtype_t dtype, const int32_t *block_tables, const int32_t *cache_lens, const float *alibi_slopes,
size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
hcStream_t stream);
infiniStatus_t launch_decode_hd64_u32(
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
infiniDtype_t dtype, const uint32_t *block_tables, const uint32_t *cache_lens, const float *alibi_slopes,
size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
hcStream_t stream);
infiniStatus_t launch_decode_hd128_i64(
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
infiniDtype_t dtype, const int64_t *block_tables, const int64_t *cache_lens, const float *alibi_slopes,
size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
hcStream_t stream);
infiniStatus_t launch_decode_hd128_i32(
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
infiniDtype_t dtype, const int32_t *block_tables, const int32_t *cache_lens, const float *alibi_slopes,
size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
hcStream_t stream);
infiniStatus_t launch_decode_hd128_u32(
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
infiniDtype_t dtype, const uint32_t *block_tables, const uint32_t *cache_lens, const float *alibi_slopes,
size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
hcStream_t stream);
struct Descriptor::Opaque {
std::shared_ptr<device::metax::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t cache_lens_desc,
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
float scale) {
auto info_res = PagedAttentionInfo::create(out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, cache_lens_desc, alibi_slopes_desc, scale);
CHECK_RESULT(info_res);
auto info = info_res.take();
// Reserve workspace for optional split-kv decode (partial acc + m/l).
// Workspace is independent of runtime env toggles; kernels will clamp num_splits <= kMaxSplits.
constexpr size_t kMaxSplits = 8;
const size_t per_split = info.num_seqs * info.num_heads * (info.head_size + 2) * sizeof(float);
const size_t workspace_bytes = kMaxSplits * per_split;
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
info, workspace_bytes, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
const void *block_tables, const void *cache_lens, const void *alibi_slopes,
void *stream_) const {
bool need_workspace = false;
if (const char *env = std::getenv("INFINIOP_FLASH_DECODE_SPLITKV")) {
// "auto" may enable split-kv depending on the runtime heuristic.
need_workspace = (std::strcmp(env, "auto") == 0) || (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
} else {
// Keep hd64 behavior unchanged, but for hd128 we default to split-kv decode, which needs workspace.
need_workspace = (_info.head_size == 128);
}
if (need_workspace && workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
auto stream = static_cast<hcStream_t>(stream_);
const float *alibi_ptr = (alibi_slopes == nullptr) ? nullptr : static_cast<const float *>(alibi_slopes);
if (_info.index_dtype == INFINI_DTYPE_I64) {
const auto *block_table_i64 = static_cast<const int64_t *>(block_tables);
const auto *cache_lens_i64 = static_cast<const int64_t *>(cache_lens);
switch (_info.head_size) {
case 64:
return launch_decode_hd64_i64(
workspace, workspace_size,
out, q, k_cache, v_cache, _info.dtype,
block_table_i64, cache_lens_i64, alibi_ptr,
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale,
_info.max_num_blocks_per_seq, _info.page_block_size,
_info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride,
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride,
_info.o_stride, stream);
case 128:
return launch_decode_hd128_i64(
workspace, workspace_size,
out, q, k_cache, v_cache, _info.dtype,
block_table_i64, cache_lens_i64, alibi_ptr,
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale,
_info.max_num_blocks_per_seq, _info.page_block_size,
_info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride,
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride,
_info.o_stride, stream);
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
if (_info.index_dtype == INFINI_DTYPE_I32) {
const auto *block_table_i32 = static_cast<const int32_t *>(block_tables);
const auto *cache_lens_i32 = static_cast<const int32_t *>(cache_lens);
switch (_info.head_size) {
case 64:
return launch_decode_hd64_i32(
workspace, workspace_size,
out, q, k_cache, v_cache, _info.dtype,
block_table_i32, cache_lens_i32, alibi_ptr,
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale,
_info.max_num_blocks_per_seq, _info.page_block_size,
_info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride,
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride,
_info.o_stride, stream);
case 128:
return launch_decode_hd128_i32(
workspace, workspace_size,
out, q, k_cache, v_cache, _info.dtype,
block_table_i32, cache_lens_i32, alibi_ptr,
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale,
_info.max_num_blocks_per_seq, _info.page_block_size,
_info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride,
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride,
_info.o_stride, stream);
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
if (_info.index_dtype == INFINI_DTYPE_U32) {
const auto *block_table_u32 = static_cast<const uint32_t *>(block_tables);
const auto *cache_lens_u32 = static_cast<const uint32_t *>(cache_lens);
switch (_info.head_size) {
case 64:
return launch_decode_hd64_u32(
workspace, workspace_size,
out, q, k_cache, v_cache, _info.dtype,
block_table_u32, cache_lens_u32, alibi_ptr,
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale,
_info.max_num_blocks_per_seq, _info.page_block_size,
_info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride,
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride,
_info.o_stride, stream);
case 128:
return launch_decode_hd128_u32(
workspace, workspace_size,
out, q, k_cache, v_cache, _info.dtype,
block_table_u32, cache_lens_u32, alibi_ptr,
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale,
_info.max_num_blocks_per_seq, _info.page_block_size,
_info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride,
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride,
_info.o_stride, stream);
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} // namespace op::paged_attention::metax
......@@ -5,9 +5,9 @@
#ifdef ENABLE_NVIDIA_API
#include "nvidia/paged_attention_nvidia.cuh"
#endif
// #ifdef ENABLE_METAX_API
// #include "metax/paged_attention_metax.h"
// #endif
#ifdef ENABLE_METAX_API
#include "metax/paged_attention_metax.h"
#endif
__C infiniStatus_t infiniopCreatePagedAttentionDescriptor(
infiniopHandle_t handle,
......@@ -34,9 +34,9 @@ __C infiniStatus_t infiniopCreatePagedAttentionDescriptor(
#ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
// #ifdef ENABLE_METAX_API
// CREATE(INFINI_DEVICE_METAX, metax)
// #endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
......@@ -55,9 +55,9 @@ __C infiniStatus_t infiniopGetPagedAttentionWorkspaceSize(
#ifdef ENABLE_NVIDIA_API
GET(INFINI_DEVICE_NVIDIA, nvidia)
#endif
// #ifdef ENABLE_METAX_API
// GET(INFINI_DEVICE_METAX, metax)
// #endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
......@@ -80,9 +80,9 @@ __C infiniStatus_t infiniopPagedAttention(
#ifdef ENABLE_NVIDIA_API
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
// #ifdef ENABLE_METAX_API
// CALCULATE(INFINI_DEVICE_METAX, metax)
// #endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
......@@ -100,9 +100,9 @@ __C infiniStatus_t infiniopDestroyPagedAttentionDescriptor(
#ifdef ENABLE_NVIDIA_API
DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
#endif
// #ifdef ENABLE_METAX_API
// DESTROY(INFINI_DEVICE_METAX, metax)
// #endif
#ifdef ENABLE_METAX_API
DESTROY(INFINI_DEVICE_METAX, metax)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
......
#ifndef __PAGED_ATTENTION_PREFILL_KERNEL_V2_CUH__
#define __PAGED_ATTENTION_PREFILL_KERNEL_V2_CUH__
#ifdef ENABLE_NVIDIA_API
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <mma.h>
#endif
#include <cstdint>
#include <type_traits>
......
#ifndef __PAGED_ATTENTION_PREFILL_METAX_H__
#define __PAGED_ATTENTION_PREFILL_METAX_H__
#include "../paged_attention_prefill.h"
DESCRIPTOR(metax)
#endif // __PAGED_ATTENTION_PREFILL_METAX_H__
......@@ -5,6 +5,9 @@
#ifdef ENABLE_NVIDIA_API
#include "nvidia/paged_attention_prefill_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
#include "metax/paged_attention_prefill_metax.h"
#endif
__C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
infiniopHandle_t handle,
......@@ -32,6 +35,9 @@ __C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
switch (handle->device) {
#ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -50,6 +56,9 @@ __C infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize(
switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
GET(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -75,6 +84,9 @@ __C infiniStatus_t infiniopPagedAttentionPrefill(
switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -92,6 +104,9 @@ __C infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor(
switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
DESTROY(INFINI_DEVICE_METAX, metax)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......
#ifndef __PAGED_CACHING_METAX_H__
#define __PAGED_CACHING_METAX_H__
#include "../paged_caching.h"
DESCRIPTOR(metax)
#endif // __PAGED_CACHING_METAX_H__
#include "../../../devices/metax/metax_common.h"
#include "../../../devices/metax/metax_kernel_common.h"
#include "../cuda/kernel.cuh"
#include "paged_caching_metax.h"
template <typename Tdata, int NUM_THREADS>
INFINIOP_METAX_KERNEL pagedCaching(
Tdata *k_cache, Tdata *v_cache,
const Tdata *k, const Tdata *v,
const int64_t *slot_mapping,
const size_t head_size, const size_t block_size,
const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride,
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride) {
op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>(
k_cache, v_cache, k, v, slot_mapping, head_size,
block_size, k_src_stride, v_src_stride, k_cache_block_stride, v_cache_block_stride);
}
namespace op::paged_caching::metax {
// PIMPL struct definition
struct Descriptor::Opaque {
std::shared_ptr<device::metax::Handle::Internal> internal;
};
// Destructor implementation
Descriptor::~Descriptor() {
delete _opaque;
}
// Static factory method implementation
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t slot_mapping_desc) {
auto info = PagedCachingInfo::create(k_cache_desc, v_cache_desc, k_desc, v_desc, slot_mapping_desc);
CHECK_RESULT(info);
// Create and return the Descriptor instance.
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
info.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
// The launchKernel function is a templated helper to encapsulate the kernel launch.
// It sets up grid/block dimensions and calls the device-side kernel.
template <int NUM_THREADS>
infiniStatus_t launchKernel(const PagedCachingInfo &info,
void *k_cache, void *v_cache,
infiniDtype_t dtype,
const void *k, const void *v,
const void *slot_mapping,
size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size,
ptrdiff_t k_src_stride, ptrdiff_t v_src_stride,
ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride,
hcStream_t stream) {
// Grid dimension is 1D, with one block per token, as we decided.
dim3 grid(uint64_t(num_kv_heads), uint64_t(num_tokens), 1);
// Block dimension is 1D, using the number of threads specified at compile time.
dim3 block(NUM_THREADS);
// This kernel does not require dynamic shared memory.
size_t shared_mem_size = 0;
// Launch the device-side kernel.
if (dtype == INFINI_DTYPE_F16) {
pagedCaching<half, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(half *)k_cache,
(half *)v_cache,
(const half *)k,
(const half *)v,
(const int64_t *)slot_mapping,
head_size,
block_size,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
} else if (dtype == INFINI_DTYPE_BF16) {
pagedCaching<cuda_bfloat16, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(cuda_bfloat16 *)k_cache,
(cuda_bfloat16 *)v_cache,
(const cuda_bfloat16 *)k,
(const cuda_bfloat16 *)v,
(const int64_t *)slot_mapping,
head_size,
block_size,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
} else if (dtype == INFINI_DTYPE_F32) {
pagedCaching<float, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(float *)k_cache,
(float *)v_cache,
(const float *)k,
(const float *)v,
(const int64_t *)slot_mapping,
head_size,
block_size,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
// Execution method implementation
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *k_cache, void *v_cache,
const void *k, const void *v,
const void *slot_mapping,
void *stream_) const {
hcStream_t stream = (hcStream_t)stream_;
// Dispatch logic based on the device's maximum threads per block.
// This allows selecting the largest, most efficient block size the hardware supports.
int max_threads = _opaque->internal->maxThreadsPerBlock();
if (max_threads >= METAX_BLOCK_SIZE_1024) {
// Dispatch based on data type for a 1024-thread block.
launchKernel<METAX_BLOCK_SIZE_1024>(
_info, k_cache, v_cache, _info.dtype, k, v, slot_mapping,
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
stream);
} else if (max_threads >= METAX_BLOCK_SIZE_512) {
launchKernel<METAX_BLOCK_SIZE_512>(
_info, k_cache, v_cache, _info.dtype, k, v, slot_mapping,
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
stream);
} else {
// If the device supports fewer threads, return an error.
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::paged_caching::metax
......@@ -5,9 +5,9 @@
#ifdef ENABLE_NVIDIA_API
#include "nvidia/paged_caching_nvidia.cuh"
#endif
// #ifdef ENABLE_METAX_API
// #include "metax/paged_caching_metax.h"
// #endif
#ifdef ENABLE_METAX_API
#include "metax/paged_caching_metax.h"
#endif
__C infiniStatus_t infiniopCreatePagedCachingDescriptor(
infiniopHandle_t handle,
......@@ -29,9 +29,9 @@ __C infiniStatus_t infiniopCreatePagedCachingDescriptor(
#ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
// #ifdef ENABLE_METAX_API
// CREATE(INFINI_DEVICE_METAX, metax)
// #endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
......@@ -50,9 +50,9 @@ __C infiniStatus_t infiniopGetPagedCachingWorkspaceSize(
#ifdef ENABLE_NVIDIA_API
GET(INFINI_DEVICE_NVIDIA, nvidia)
#endif
// #ifdef ENABLE_METAX_API
// GET(INFINI_DEVICE_METAX, metax)
// #endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
......@@ -75,9 +75,9 @@ __C infiniStatus_t infiniopPagedCaching(
#ifdef ENABLE_NVIDIA_API
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
// #ifdef ENABLE_METAX_API
// CALCULATE(INFINI_DEVICE_METAX, metax)
// #endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
......@@ -95,9 +95,9 @@ __C infiniStatus_t infiniopDestroyPagedCachingDescriptor(
#ifdef ENABLE_NVIDIA_API
DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
#endif
// #ifdef ENABLE_METAX_API
// DESTROY(INFINI_DEVICE_METAX, metax)
// #endif
#ifdef ENABLE_METAX_API
DESTROY(INFINI_DEVICE_METAX, metax)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment