Unverified Commit dce99862 authored by thatPepe's avatar thatPepe Committed by GitHub
Browse files

Merge pull request #1053 from InfiniTensor/issue/1033xmake

Issue/1033 patch aten and fa adaptations
parents 8d99a8f5 d6e44e84
...@@ -5,15 +5,15 @@ ...@@ -5,15 +5,15 @@
typedef struct InfiniopDescriptor *infiniopSubDescriptor_t; typedef struct InfiniopDescriptor *infiniopSubDescriptor_t;
__C __export infiniStatus_t infiniopCreateSubDescriptor(infiniopHandle_t handle, __INFINI_C __export infiniStatus_t infiniopCreateSubDescriptor(infiniopHandle_t handle,
infiniopSubDescriptor_t *desc_ptr, infiniopSubDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t c, infiniopTensorDescriptor_t c,
infiniopTensorDescriptor_t a, infiniopTensorDescriptor_t a,
infiniopTensorDescriptor_t b); infiniopTensorDescriptor_t b);
__C __export infiniStatus_t infiniopGetSubWorkspaceSize(infiniopSubDescriptor_t desc, size_t *size); __INFINI_C __export infiniStatus_t infiniopGetSubWorkspaceSize(infiniopSubDescriptor_t desc, size_t *size);
__C __export infiniStatus_t infiniopSub(infiniopSubDescriptor_t desc, __INFINI_C __export infiniStatus_t infiniopSub(infiniopSubDescriptor_t desc,
void *workspace, void *workspace,
size_t workspace_size, size_t workspace_size,
void *c, void *c,
...@@ -21,6 +21,6 @@ __C __export infiniStatus_t infiniopSub(infiniopSubDescriptor_t desc, ...@@ -21,6 +21,6 @@ __C __export infiniStatus_t infiniopSub(infiniopSubDescriptor_t desc,
const void *b, const void *b,
void *stream); void *stream);
__C __export infiniStatus_t infiniopDestroySubDescriptor(infiniopSubDescriptor_t desc); __INFINI_C __export infiniStatus_t infiniopDestroySubDescriptor(infiniopSubDescriptor_t desc);
#endif #endif
...@@ -5,15 +5,15 @@ ...@@ -5,15 +5,15 @@
typedef struct InfiniopDescriptor *infiniopSwiGLUDescriptor_t; typedef struct InfiniopDescriptor *infiniopSwiGLUDescriptor_t;
__C __export infiniStatus_t infiniopCreateSwiGLUDescriptor(infiniopHandle_t handle, __INFINI_C __export infiniStatus_t infiniopCreateSwiGLUDescriptor(infiniopHandle_t handle,
infiniopSwiGLUDescriptor_t *desc_ptr, infiniopSwiGLUDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t c_desc, infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc); infiniopTensorDescriptor_t b_desc);
__C __export infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t desc, size_t *size); __INFINI_C __export infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t desc, size_t *size);
__C __export infiniStatus_t infiniopSwiGLU(infiniopSwiGLUDescriptor_t desc, __INFINI_C __export infiniStatus_t infiniopSwiGLU(infiniopSwiGLUDescriptor_t desc,
void *workspace, void *workspace,
size_t workspace_size, size_t workspace_size,
void *c, void *c,
...@@ -21,6 +21,6 @@ __C __export infiniStatus_t infiniopSwiGLU(infiniopSwiGLUDescriptor_t desc, ...@@ -21,6 +21,6 @@ __C __export infiniStatus_t infiniopSwiGLU(infiniopSwiGLUDescriptor_t desc,
void const *b, void const *b,
void *stream); void *stream);
__C __export infiniStatus_t infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc); __INFINI_C __export infiniStatus_t infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc);
#endif #endif
...@@ -5,20 +5,20 @@ ...@@ -5,20 +5,20 @@
typedef struct InfiniopDescriptor *infiniopTanhDescriptor_t; typedef struct InfiniopDescriptor *infiniopTanhDescriptor_t;
__C __export infiniStatus_t infiniopCreateTanhDescriptor(infiniopHandle_t handle, __INFINI_C __export infiniStatus_t infiniopCreateTanhDescriptor(infiniopHandle_t handle,
infiniopTanhDescriptor_t *desc_ptr, infiniopTanhDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t output, infiniopTensorDescriptor_t output,
infiniopTensorDescriptor_t input); infiniopTensorDescriptor_t input);
__C __export infiniStatus_t infiniopGetTanhWorkspaceSize(infiniopTanhDescriptor_t desc, size_t *size); __INFINI_C __export infiniStatus_t infiniopGetTanhWorkspaceSize(infiniopTanhDescriptor_t desc, size_t *size);
__C __export infiniStatus_t infiniopTanh(infiniopTanhDescriptor_t desc, __INFINI_C __export infiniStatus_t infiniopTanh(infiniopTanhDescriptor_t desc,
void *workspace, void *workspace,
size_t workspace_size, size_t workspace_size,
void *output, void *output,
const void *input, const void *input,
void *stream); void *stream);
__C __export infiniStatus_t infiniopDestroyTanhDescriptor(infiniopTanhDescriptor_t desc); __INFINI_C __export infiniStatus_t infiniopDestroyTanhDescriptor(infiniopTanhDescriptor_t desc);
#endif #endif
...@@ -5,14 +5,14 @@ ...@@ -5,14 +5,14 @@
typedef struct InfiniopDescriptor *infiniopTopkrouterDescriptor_t; typedef struct InfiniopDescriptor *infiniopTopkrouterDescriptor_t;
__C __export infiniStatus_t infiniopCreateTopkrouterDescriptor(infiniopHandle_t handle, __INFINI_C __export infiniStatus_t infiniopCreateTopkrouterDescriptor(infiniopHandle_t handle,
infiniopTopkrouterDescriptor_t *desc_ptr, infiniopTopkrouterDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t x_desc, infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t correction_bias_desc); infiniopTensorDescriptor_t correction_bias_desc);
__C __export infiniStatus_t infiniopGetTopkrouterWorkspaceSize(infiniopTopkrouterDescriptor_t desc, size_t *size); __INFINI_C __export infiniStatus_t infiniopGetTopkrouterWorkspaceSize(infiniopTopkrouterDescriptor_t desc, size_t *size);
__C __export infiniStatus_t infiniopTopkrouter(infiniopTopkrouterDescriptor_t desc, __INFINI_C __export infiniStatus_t infiniopTopkrouter(infiniopTopkrouterDescriptor_t desc,
void *workspace, void *workspace,
size_t workspace_size, size_t workspace_size,
void *values, void *values,
...@@ -23,6 +23,6 @@ __C __export infiniStatus_t infiniopTopkrouter(infiniopTopkrouterDescriptor_t de ...@@ -23,6 +23,6 @@ __C __export infiniStatus_t infiniopTopkrouter(infiniopTopkrouterDescriptor_t de
const size_t topk, const size_t topk,
void *stream); void *stream);
__C __export infiniStatus_t infiniopDestroyTopkrouterDescriptor(infiniopTopkrouterDescriptor_t desc); __INFINI_C __export infiniStatus_t infiniopDestroyTopkrouterDescriptor(infiniopTopkrouterDescriptor_t desc);
#endif #endif
...@@ -5,13 +5,13 @@ ...@@ -5,13 +5,13 @@
typedef struct InfiniopDescriptor *infiniopTopksoftmaxDescriptor_t; typedef struct InfiniopDescriptor *infiniopTopksoftmaxDescriptor_t;
__C __export infiniStatus_t infiniopCreateTopksoftmaxDescriptor(infiniopHandle_t handle, __INFINI_C __export infiniStatus_t infiniopCreateTopksoftmaxDescriptor(infiniopHandle_t handle,
infiniopTopksoftmaxDescriptor_t *desc_ptr, infiniopTopksoftmaxDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t x_desc); infiniopTensorDescriptor_t x_desc);
__C __export infiniStatus_t infiniopGetTopksoftmaxWorkspaceSize(infiniopTopksoftmaxDescriptor_t desc, size_t *size); __INFINI_C __export infiniStatus_t infiniopGetTopksoftmaxWorkspaceSize(infiniopTopksoftmaxDescriptor_t desc, size_t *size);
__C __export infiniStatus_t infiniopTopksoftmax(infiniopTopksoftmaxDescriptor_t desc, __INFINI_C __export infiniStatus_t infiniopTopksoftmax(infiniopTopksoftmaxDescriptor_t desc,
void *workspace, void *workspace,
size_t workspace_size, size_t workspace_size,
void *values, void *values,
...@@ -21,6 +21,6 @@ __C __export infiniStatus_t infiniopTopksoftmax(infiniopTopksoftmaxDescriptor_t ...@@ -21,6 +21,6 @@ __C __export infiniStatus_t infiniopTopksoftmax(infiniopTopksoftmaxDescriptor_t
const int norm, const int norm,
void *stream); void *stream);
__C __export infiniStatus_t infiniopDestroyTopksoftmaxDescriptor(infiniopTopksoftmaxDescriptor_t desc); __INFINI_C __export infiniStatus_t infiniopDestroyTopksoftmaxDescriptor(infiniopTopksoftmaxDescriptor_t desc);
#endif #endif
...@@ -5,20 +5,20 @@ ...@@ -5,20 +5,20 @@
typedef struct InfiniopDescriptor *infiniopZerosDescriptor_t; typedef struct InfiniopDescriptor *infiniopZerosDescriptor_t;
__C __export infiniStatus_t infiniopCreateZerosDescriptor(infiniopHandle_t handle, __INFINI_C __export infiniStatus_t infiniopCreateZerosDescriptor(infiniopHandle_t handle,
infiniopZerosDescriptor_t *desc_ptr, infiniopZerosDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y, infiniopTensorDescriptor_t y,
infiniopTensorDescriptor_t x); infiniopTensorDescriptor_t x);
__C __export infiniStatus_t infiniopGetZerosWorkspaceSize(infiniopZerosDescriptor_t desc, size_t *size); __INFINI_C __export infiniStatus_t infiniopGetZerosWorkspaceSize(infiniopZerosDescriptor_t desc, size_t *size);
__C __export infiniStatus_t infiniopZeros(infiniopZerosDescriptor_t desc, __INFINI_C __export infiniStatus_t infiniopZeros(infiniopZerosDescriptor_t desc,
void *workspace, void *workspace,
size_t workspace_size, size_t workspace_size,
void *y, void *y,
const void *x, const void *x,
void *stream); void *stream);
__C __export infiniStatus_t infiniopDestroyZerosDescriptor(infiniopZerosDescriptor_t desc); __INFINI_C __export infiniStatus_t infiniopDestroyZerosDescriptor(infiniopZerosDescriptor_t desc);
#endif #endif
...@@ -7,8 +7,8 @@ struct InfiniopTensorDescriptor; ...@@ -7,8 +7,8 @@ struct InfiniopTensorDescriptor;
typedef struct InfiniopTensorDescriptor *infiniopTensorDescriptor_t; typedef struct InfiniopTensorDescriptor *infiniopTensorDescriptor_t;
__C __export infiniStatus_t infiniopCreateTensorDescriptor(infiniopTensorDescriptor_t *desc_ptr, size_t ndim, const size_t *shape, const ptrdiff_t *strides, infiniDtype_t dtype); __INFINI_C __export infiniStatus_t infiniopCreateTensorDescriptor(infiniopTensorDescriptor_t *desc_ptr, size_t ndim, const size_t *shape, const ptrdiff_t *strides, infiniDtype_t dtype);
__C __export infiniStatus_t infiniopDestroyTensorDescriptor(infiniopTensorDescriptor_t desc); __INFINI_C __export infiniStatus_t infiniopDestroyTensorDescriptor(infiniopTensorDescriptor_t desc);
#endif // __INFINIOP_TENSOR_DESCRIPTOR__ #endif // __INFINIOP_TENSOR_DESCRIPTOR__
...@@ -10,20 +10,20 @@ typedef void *infinirtGraph_t; ...@@ -10,20 +10,20 @@ typedef void *infinirtGraph_t;
typedef void *infinirtGraphNode_t; typedef void *infinirtGraphNode_t;
typedef void *infinirtGraphExec_t; typedef void *infinirtGraphExec_t;
__C __export infiniStatus_t infinirtInit(); __INFINI_C __export infiniStatus_t infinirtInit();
// Device // Device
__C __export infiniStatus_t infinirtGetAllDeviceCount(int *count_array); __INFINI_C __export infiniStatus_t infinirtGetAllDeviceCount(int *count_array);
__C __export infiniStatus_t infinirtGetDeviceCount(infiniDevice_t device, int *count); __INFINI_C __export infiniStatus_t infinirtGetDeviceCount(infiniDevice_t device, int *count);
__C __export infiniStatus_t infinirtSetDevice(infiniDevice_t device, int device_id); __INFINI_C __export infiniStatus_t infinirtSetDevice(infiniDevice_t device, int device_id);
__C __export infiniStatus_t infinirtGetDevice(infiniDevice_t *device_ptr, int *device_id_ptr); __INFINI_C __export infiniStatus_t infinirtGetDevice(infiniDevice_t *device_ptr, int *device_id_ptr);
__C __export infiniStatus_t infinirtDeviceSynchronize(); __INFINI_C __export infiniStatus_t infinirtDeviceSynchronize();
// Stream // Stream
__C __export infiniStatus_t infinirtStreamCreate(infinirtStream_t *stream_ptr); __INFINI_C __export infiniStatus_t infinirtStreamCreate(infinirtStream_t *stream_ptr);
__C __export infiniStatus_t infinirtStreamDestroy(infinirtStream_t stream); __INFINI_C __export infiniStatus_t infinirtStreamDestroy(infinirtStream_t stream);
__C __export infiniStatus_t infinirtStreamSynchronize(infinirtStream_t stream); __INFINI_C __export infiniStatus_t infinirtStreamSynchronize(infinirtStream_t stream);
__C __export infiniStatus_t infinirtStreamWaitEvent(infinirtStream_t stream, infinirtEvent_t event); __INFINI_C __export infiniStatus_t infinirtStreamWaitEvent(infinirtStream_t stream, infinirtEvent_t event);
// Event // Event
typedef enum { typedef enum {
...@@ -38,13 +38,13 @@ typedef enum { ...@@ -38,13 +38,13 @@ typedef enum {
INFINIRT_EVENT_BLOCKING_SYNC = 0x2, // Event uses blocking synchronization INFINIRT_EVENT_BLOCKING_SYNC = 0x2, // Event uses blocking synchronization
} infinirtEventFlags_t; } infinirtEventFlags_t;
__C __export infiniStatus_t infinirtEventCreate(infinirtEvent_t *event_ptr); __INFINI_C __export infiniStatus_t infinirtEventCreate(infinirtEvent_t *event_ptr);
__C __export infiniStatus_t infinirtEventCreateWithFlags(infinirtEvent_t *event_ptr, uint32_t flags); __INFINI_C __export infiniStatus_t infinirtEventCreateWithFlags(infinirtEvent_t *event_ptr, uint32_t flags);
__C __export infiniStatus_t infinirtEventRecord(infinirtEvent_t event, infinirtStream_t stream); __INFINI_C __export infiniStatus_t infinirtEventRecord(infinirtEvent_t event, infinirtStream_t stream);
__C __export infiniStatus_t infinirtEventQuery(infinirtEvent_t event, infinirtEventStatus_t *status_ptr); __INFINI_C __export infiniStatus_t infinirtEventQuery(infinirtEvent_t event, infinirtEventStatus_t *status_ptr);
__C __export infiniStatus_t infinirtEventSynchronize(infinirtEvent_t event); __INFINI_C __export infiniStatus_t infinirtEventSynchronize(infinirtEvent_t event);
__C __export infiniStatus_t infinirtEventDestroy(infinirtEvent_t event); __INFINI_C __export infiniStatus_t infinirtEventDestroy(infinirtEvent_t event);
__C __export infiniStatus_t infinirtEventElapsedTime(float *ms_ptr, infinirtEvent_t start, infinirtEvent_t end); __INFINI_C __export infiniStatus_t infinirtEventElapsedTime(float *ms_ptr, infinirtEvent_t start, infinirtEvent_t end);
// Memory // Memory
typedef enum { typedef enum {
...@@ -54,17 +54,17 @@ typedef enum { ...@@ -54,17 +54,17 @@ typedef enum {
INFINIRT_MEMCPY_D2D = 3, INFINIRT_MEMCPY_D2D = 3,
} infinirtMemcpyKind_t; } infinirtMemcpyKind_t;
__C __export infiniStatus_t infinirtMalloc(void **p_ptr, size_t size); __INFINI_C __export infiniStatus_t infinirtMalloc(void **p_ptr, size_t size);
__C __export infiniStatus_t infinirtMallocHost(void **p_ptr, size_t size); __INFINI_C __export infiniStatus_t infinirtMallocHost(void **p_ptr, size_t size);
__C __export infiniStatus_t infinirtFree(void *ptr); __INFINI_C __export infiniStatus_t infinirtFree(void *ptr);
__C __export infiniStatus_t infinirtFreeHost(void *ptr); __INFINI_C __export infiniStatus_t infinirtFreeHost(void *ptr);
__C __export infiniStatus_t infinirtMemcpy(void *dst, const void *src, size_t size, infinirtMemcpyKind_t kind); __INFINI_C __export infiniStatus_t infinirtMemcpy(void *dst, const void *src, size_t size, infinirtMemcpyKind_t kind);
__C __export infiniStatus_t infinirtMemcpyAsync(void *dst, const void *src, size_t size, infinirtMemcpyKind_t kind, infinirtStream_t stream); __INFINI_C __export infiniStatus_t infinirtMemcpyAsync(void *dst, const void *src, size_t size, infinirtMemcpyKind_t kind, infinirtStream_t stream);
// Stream-ordered memory // Stream-ordered memory
__C __export infiniStatus_t infinirtMallocAsync(void **p_ptr, size_t size, infinirtStream_t stream); __INFINI_C __export infiniStatus_t infinirtMallocAsync(void **p_ptr, size_t size, infinirtStream_t stream);
__C __export infiniStatus_t infinirtFreeAsync(void *ptr, infinirtStream_t stream); __INFINI_C __export infiniStatus_t infinirtFreeAsync(void *ptr, infinirtStream_t stream);
// Graph // Graph
typedef enum { typedef enum {
...@@ -74,16 +74,16 @@ typedef enum { ...@@ -74,16 +74,16 @@ typedef enum {
} infinirtStreamCaptureMode_t; } infinirtStreamCaptureMode_t;
__C __export infiniStatus_t infinirtStreamBeginCapture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode); __INFINI_C __export infiniStatus_t infinirtStreamBeginCapture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode);
__C __export infiniStatus_t infinirtStreamEndCapture(infinirtStream_t stream, infinirtGraph_t *graph_ptr); __INFINI_C __export infiniStatus_t infinirtStreamEndCapture(infinirtStream_t stream, infinirtGraph_t *graph_ptr);
__C __export infiniStatus_t infinirtGraphDestroy(infinirtGraph_t graph); __INFINI_C __export infiniStatus_t infinirtGraphDestroy(infinirtGraph_t graph);
__C __export infiniStatus_t infinirtGraphInstantiate( __INFINI_C __export infiniStatus_t infinirtGraphInstantiate(
infinirtGraphExec_t *graph_exec_ptr, infinirtGraphExec_t *graph_exec_ptr,
infinirtGraph_t graph, infinirtGraph_t graph,
infinirtGraphNode_t *node_ptr, infinirtGraphNode_t *node_ptr,
char *log_buffer, char *log_buffer,
size_t buffer_size); size_t buffer_size);
__C __export infiniStatus_t infinirtGraphExecDestroy(infinirtGraphExec_t graph_exec); __INFINI_C __export infiniStatus_t infinirtGraphExecDestroy(infinirtGraphExec_t graph_exec);
__C __export infiniStatus_t infinirtGraphLuanch(infinirtGraphExec_t graph_exec, infinirtStream_t stream); __INFINI_C __export infiniStatus_t infinirtGraphLuanch(infinirtGraphExec_t graph_exec, infinirtStream_t stream);
#endif // __INFINIRT_API_H__ #endif // __INFINIRT_API_H__
...@@ -52,6 +52,7 @@ from infinicore.ops.add_rms_norm import add_rms_norm ...@@ -52,6 +52,7 @@ from infinicore.ops.add_rms_norm import add_rms_norm
from infinicore.ops.attention import attention from infinicore.ops.attention import attention
from infinicore.ops.kv_caching import kv_caching from infinicore.ops.kv_caching import kv_caching
from infinicore.ops.matmul import matmul from infinicore.ops.matmul import matmul
from infinicore.ops.mha_varlen import mha_varlen
from infinicore.ops.mul import mul from infinicore.ops.mul import mul
from infinicore.ops.narrow import narrow from infinicore.ops.narrow import narrow
from infinicore.ops.paged_attention import paged_attention from infinicore.ops.paged_attention import paged_attention
...@@ -134,6 +135,7 @@ __all__ = [ ...@@ -134,6 +135,7 @@ __all__ = [
"from_list", "from_list",
"from_numpy", "from_numpy",
"from_torch", "from_torch",
"mha_varlen",
"paged_caching", "paged_caching",
"paged_attention", "paged_attention",
"paged_attention_prefill", "paged_attention_prefill",
......
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor
def mha_varlen(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    cum_seqlens_q: Tensor,
    cum_seqlens_k: Tensor,
    block_table: Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    alibi_slopes: Tensor | None = None,
    scale: float = 1.0,
    *,
    out: Tensor | None = None,
):
    """Variable-length (packed-batch) multi-head attention.

    Thin Python wrapper over the native ``_infinicore`` kernels.  When
    ``out`` is ``None`` the out-of-place kernel allocates and returns a new
    ``Tensor``; otherwise the in-place kernel writes into ``out`` and
    returns it.

    Args:
        q, k, v: Packed query/key/value tensors.
        cum_seqlens_q, cum_seqlens_k: Cumulative sequence-length offsets
            delimiting each batch element in the packed tensors.
        block_table: Mapping of KV positions to paged-cache blocks.
        max_seqlen_q, max_seqlen_k: Maximum per-batch sequence lengths.
        alibi_slopes: Optional ALiBi slopes tensor; ``None`` disables ALiBi.
        scale: Softmax scaling factor applied to attention scores.
        out: Optional pre-allocated output tensor (keyword-only).
    """
    slopes = alibi_slopes._underlying if alibi_slopes is not None else None
    # Both native entry points take the same trailing argument list.
    common_args = (
        q._underlying,
        k._underlying,
        v._underlying,
        cum_seqlens_q._underlying,
        cum_seqlens_k._underlying,
        block_table._underlying,
        max_seqlen_q,
        max_seqlen_k,
        slopes,
        scale,
    )
    if out is None:
        return Tensor(_infinicore.mha_varlen(*common_args))
    _infinicore.mha_varlen_(out._underlying, *common_args)
    return out
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include "./metax/infiniccl_metax.h" #include "./metax/infiniccl_metax.h"
#include "./moore/infiniccl_moore.h" #include "./moore/infiniccl_moore.h"
__C infiniStatus_t infinicclCommInitAll( __INFINI_C infiniStatus_t infinicclCommInitAll(
infiniDevice_t device_type, infiniDevice_t device_type,
infinicclComm_t *comms, infinicclComm_t *comms,
int ndevice, int ndevice,
...@@ -35,7 +35,7 @@ __C infiniStatus_t infinicclCommInitAll( ...@@ -35,7 +35,7 @@ __C infiniStatus_t infinicclCommInitAll(
#undef COMM_INIT_ALL #undef COMM_INIT_ALL
} }
__C infiniStatus_t infinicclCommDestroy(infinicclComm_t comm) { __INFINI_C infiniStatus_t infinicclCommDestroy(infinicclComm_t comm) {
if (comm == nullptr) { if (comm == nullptr) {
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
...@@ -61,7 +61,7 @@ __C infiniStatus_t infinicclCommDestroy(infinicclComm_t comm) { ...@@ -61,7 +61,7 @@ __C infiniStatus_t infinicclCommDestroy(infinicclComm_t comm) {
#undef COMM_DESTROY #undef COMM_DESTROY
} }
__C infiniStatus_t infinicclAllReduce( __INFINI_C infiniStatus_t infinicclAllReduce(
void *sendbuf, void *sendbuf,
void *recvbuf, void *recvbuf,
size_t count, size_t count,
......
#ifdef ENABLE_ATEN
#include "infinicore/adaptor/aten_adaptor.hpp"
namespace infinicore::adaptor {
// Wrap an infinicore tensor as a non-owning at::Tensor view over the same
// device memory (zero copy). The caller must keep `t`'s storage alive for
// as long as the returned ATen tensor is in use.
at::Tensor to_aten_tensor(const infinicore::Tensor &t) {
    void *data_ptr = (void *)(t->data());
    // at::from_blob expects int64_t sizes; widen the infinicore shape.
    auto sizes = std::vector<int64_t>(
        t->shape().begin(),
        t->shape().end());
    // NOTE(review): assumes t->strides() is expressed in elements (not
    // bytes), matching at::from_blob's convention — confirm.
    auto strides = t->strides();
    auto dtype = to_at_dtype(t->dtype());
    auto device = to_at_device(t->device());
    // No-op deleter: ownership stays with infinicore; ATen must not free.
    auto deleter_ = [](void * /*unused*/) mutable {
    };
    at::TensorOptions options = at::TensorOptions()
                                    .dtype(dtype)
                                    .device(device)
                                    .requires_grad(false);
    return at::from_blob(
        data_ptr,
        sizes,
        strides,
        deleter_,
        options);
}
#ifdef ENABLE_NVIDIA_API
// Adopt the current infinicore stream as an "external" CUDA stream so that
// ATen / FlashAttention kernels are enqueued on the same stream as the
// rest of the infinicore work on the current device.
c10::cuda::CUDAStream get_cuda_stream() {
    return c10::cuda::getStreamFromExternal(
        cudaStream_t(infinicore::context::getStream()), infinicore::context::getDevice().getIndex());
}
#endif
} // namespace infinicore::adaptor
#endif // ENABLE_ATEN
#include "infinicore/ops/mha_varlen.hpp"
#include "../../utils.hpp"
namespace infinicore::op {
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(MultiheadAttentionVarlen);
// Construct-and-dispatch entry point: verifies that every participating
// tensor lives on the same device, then invokes the implementation
// registered for `out`'s device type via the dispatch macro.
MultiheadAttentionVarlen::MultiheadAttentionVarlen(Tensor out,
                                                   const Tensor &q,
                                                   const Tensor &k,
                                                   const Tensor &v,
                                                   const Tensor &cum_seqlens_q,
                                                   const Tensor &cum_seqlens_kv,
                                                   const Tensor &block_table,
                                                   int max_seqlen_q,
                                                   int max_seqlen_k,
                                                   std::optional<Tensor> alibi_slopes,
                                                   float scale) {
    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k, v, cum_seqlens_q, cum_seqlens_kv, block_table);
    INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(),
                                 out, q, k, v, cum_seqlens_q, cum_seqlens_kv, block_table, max_seqlen_q, max_seqlen_k, alibi_slopes, scale);
}
// Static entry point used by the public mha_varlen()/mha_varlen_()
// helpers. The RECORD_OR_RUN macro either records this op into the graph
// under capture or runs it eagerly with the given arguments.
void MultiheadAttentionVarlen::execute(Tensor out,
                                       const Tensor &q,
                                       const Tensor &k,
                                       const Tensor &v,
                                       const Tensor &cum_seqlens_q,
                                       const Tensor &cum_seqlens_kv,
                                       const Tensor &block_table,
                                       int max_seqlen_q,
                                       int max_seqlen_k,
                                       std::optional<Tensor> alibi_slopes,
                                       float scale) {
    INFINICORE_GRAPH_OP_RECORD_OR_RUN(
        MultiheadAttentionVarlen,
        out, q, k, v, cum_seqlens_q, cum_seqlens_kv, block_table, max_seqlen_q, max_seqlen_k, alibi_slopes, scale);
}
// Out-of-place variable-length multi-head attention.
// Allocates a result tensor with q's shape/dtype/device and delegates to
// the in-place variant mha_varlen_().
Tensor mha_varlen(
    const Tensor &q,
    const Tensor &k,
    const Tensor &v,
    const Tensor &cum_seqlens_q,
    const Tensor &cum_seqlens_kv,
    const Tensor &block_table,
    int max_seqlen_q,
    int max_seqlen_k,
    std::optional<Tensor> alibi_slopes,
    float scale) {
    Tensor result = Tensor::empty(q->shape(), q->dtype(), q->device());
    mha_varlen_(result, q, k, v, cum_seqlens_q, cum_seqlens_kv, block_table,
                max_seqlen_q, max_seqlen_k, alibi_slopes, scale);
    return result;
}
// In-place variable-length multi-head attention: writes attention output
// for the packed (varlen) q/k/v batch into `out`.
// - cum_seqlens_q / cum_seqlens_kv: cumulative sequence-length offsets
//   delimiting each batch element inside the packed tensors.
// - block_table: maps logical KV positions to paged-cache blocks.
// NOTE(review): presumably `out` must match q's shape/dtype/device (the
// out-of-place overload allocates it that way) — confirm.
void mha_varlen_(Tensor out,
                 const Tensor &q,
                 const Tensor &k,
                 const Tensor &v,
                 const Tensor &cum_seqlens_q,
                 const Tensor &cum_seqlens_kv,
                 const Tensor &block_table,
                 int max_seqlen_q,
                 int max_seqlen_k,
                 std::optional<Tensor> alibi_slopes,
                 float scale) {
    MultiheadAttentionVarlen::execute(out, q, k, v, cum_seqlens_q, cum_seqlens_kv, block_table, max_seqlen_q, max_seqlen_k, alibi_slopes, scale);
}
} // namespace infinicore::op
#include "infinicore/ops/mha_varlen.hpp"
#include "infinicore/adaptor/flash_attention_adaptor.hpp"
#include <stdexcept>
namespace infinicore::op::mha_varlen_impl::flashattn {
// Captured launch arguments for a planned varlen-MHA invocation.
// GraphTensor wrappers hold the tensor references so the op can be
// replayed during graph execution without re-planning. Instances are
// heap-allocated by plan() and destroyed by cleanup().
struct PlannedMeta {
    graph::GraphTensor out, q, k, v, cum_seqlens_q, cum_seqlens_k, block_table;
    int max_seqlen_q, max_seqlen_k;
    std::optional<graph::GraphTensor> alibi_slopes;
    float scale;
};
// Capture all launch arguments into a heap-allocated PlannedMeta.
// Ownership of the returned pointer passes to the graph framework, which
// releases it through cleanup().
void *plan(Tensor out,
           const Tensor &q,
           const Tensor &k,
           const Tensor &v,
           const Tensor &cum_seqlens_q,
           const Tensor &cum_seqlens_k,
           const Tensor &block_table,
           int max_seqlen_q,
           int max_seqlen_k,
           std::optional<Tensor> alibi_slopes,
           float scale) {
    return new PlannedMeta{
        graph::GraphTensor(out),
        graph::GraphTensor(q),
        graph::GraphTensor(k),
        graph::GraphTensor(v),
        graph::GraphTensor(cum_seqlens_q),
        graph::GraphTensor(cum_seqlens_k),
        graph::GraphTensor(block_table),
        max_seqlen_q,
        max_seqlen_k,
        // Only wrap alibi_slopes when it was supplied; otherwise keep nullopt.
        alibi_slopes ? std::optional<graph::GraphTensor>(graph::GraphTensor(*alibi_slopes)) : std::nullopt,
        scale};
}
// Execute a previously planned varlen-MHA launch by converting the
// captured infinicore tensors to non-owning ATen views and calling the
// FlashAttention varlen forward kernel. Throws if the build lacks
// FlashAttention support.
void run(void *planned_meta) {
#ifdef ENABLE_FLASH_ATTN
    // Route ATen work onto the current infinicore stream for the duration
    // of this call (guard restores the previous stream on scope exit).
    c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream());
    auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);
    auto q = infinicore::adaptor::to_aten_tensor(p->q);
    auto k = infinicore::adaptor::to_aten_tensor(p->k);
    auto v = infinicore::adaptor::to_aten_tensor(p->v);
    auto out = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->out));
    auto cu_seqlens_q = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_q);
    auto cu_seqlens_kv = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_k);
    // Optional features this path does not use.
    std::optional<at::Tensor> seqused_k = std::nullopt;
    std::optional<const at::Tensor> leftpad_k = std::nullopt;
    auto block_table = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->block_table));
    auto max_seqlen_q = p->max_seqlen_q;
    auto max_seqlen_k = p->max_seqlen_k;
    auto alibi_slopes = p->alibi_slopes ? std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes)) : std::nullopt;
    auto scale = p->scale;
    // NOTE(review): the trailing positional literals (0.0, false, true,
    // -1, -1, 0.0, false, nullopt) presumably correspond to dropout,
    // zero-tensors, causal masking, window sizes, softcap, return-softmax
    // and RNG generator — confirm against flash::mha_varlen_fwd's
    // signature for the pinned FlashAttention version.
    flash::mha_varlen_fwd(
        q,
        k,
        v,
        out,
        cu_seqlens_q,
        cu_seqlens_kv,
        seqused_k,
        leftpad_k,
        block_table,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        0.0,
        scale,
        false,
        true,
        -1,
        -1,
        0.0,
        false,
        std::nullopt);
#else
    throw std::runtime_error("FlashAttention is not enabled in this build");
#endif
}
// Destroy the PlannedMeta allocated by plan() and null the caller's slot
// to guard against double-free / use-after-free.
void cleanup(void **planned_meta_ptr) {
    delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
    *planned_meta_ptr = nullptr;
}
INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(MultiheadAttentionVarlen, &plan, &run, &cleanup);
} // namespace infinicore::op::mha_varlen_impl::flashattn
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include "ops/linear.hpp" #include "ops/linear.hpp"
#include "ops/linear_w8a8i8.hpp" #include "ops/linear_w8a8i8.hpp"
#include "ops/matmul.hpp" #include "ops/matmul.hpp"
#include "ops/mha_varlen.hpp"
#include "ops/mul.hpp" #include "ops/mul.hpp"
#include "ops/paged_attention.hpp" #include "ops/paged_attention.hpp"
#include "ops/paged_attention_prefill.hpp" #include "ops/paged_attention_prefill.hpp"
...@@ -38,6 +39,7 @@ inline void bind(py::module &m) { ...@@ -38,6 +39,7 @@ inline void bind(py::module &m) {
bind_linear(m); bind_linear(m);
bind_matmul(m); bind_matmul(m);
bind_mul(m); bind_mul(m);
bind_mha_varlen(m);
bind_paged_attention(m); bind_paged_attention(m);
bind_paged_attention_prefill(m); bind_paged_attention_prefill(m);
bind_paged_caching(m); bind_paged_caching(m);
......
#pragma once
#include <pybind11/pybind11.h>
#include "infinicore/ops/mha_varlen.hpp"
namespace py = pybind11;
namespace infinicore::ops {
// Python-facing wrapper for out-of-place varlen multi-head attention.
// `alibi_slopes` is received as a generic py::object so that Python's
// `None` can be translated into std::nullopt before calling the C++ op.
//
// Marked `inline`: this file is a header (#pragma once) and may be
// included from multiple translation units; a non-inline definition here
// would violate the one-definition rule and cause duplicate-symbol link
// errors.
inline Tensor py_mha_varlen(Tensor q,
                            Tensor k,
                            Tensor v,
                            Tensor cum_seqlens_q,
                            Tensor cum_seqlens_k,
                            Tensor block_table,
                            int max_seqlen_q,
                            int max_seqlen_k,
                            pybind11::object alibi_slopes,
                            float scale) {
    std::optional<Tensor> alibi_slopes_tensor = std::nullopt;
    if (!alibi_slopes.is_none()) {
        alibi_slopes_tensor = alibi_slopes.cast<Tensor>();
    }
    return op::mha_varlen(
        q,
        k,
        v,
        cum_seqlens_q,
        cum_seqlens_k,
        block_table,
        max_seqlen_q,
        max_seqlen_k,
        alibi_slopes_tensor,
        scale);
}
// Python-facing wrapper for in-place varlen multi-head attention: writes
// the result into `out`. `alibi_slopes` is received as py::object so that
// Python's `None` maps to std::nullopt.
//
// Marked `inline`: this file is a header (#pragma once) and may be
// included from multiple translation units; a non-inline definition here
// would violate the one-definition rule and cause duplicate-symbol link
// errors.
inline void py_mha_varlen_(Tensor out,
                           Tensor q,
                           Tensor k,
                           Tensor v,
                           Tensor cum_seqlens_q,
                           Tensor cum_seqlens_k,
                           Tensor block_table,
                           int max_seqlen_q,
                           int max_seqlen_k,
                           pybind11::object alibi_slopes,
                           float scale) {
    std::optional<Tensor> alibi_slopes_tensor = std::nullopt;
    if (!alibi_slopes.is_none()) {
        alibi_slopes_tensor = alibi_slopes.cast<Tensor>();
    }
    op::mha_varlen_(
        out,
        q,
        k,
        v,
        cum_seqlens_q,
        cum_seqlens_k,
        block_table,
        max_seqlen_q,
        max_seqlen_k,
        alibi_slopes_tensor,
        scale);
}
// Register the mha_varlen / mha_varlen_ Python bindings on module `m`.
// All arguments are keyword-addressable; `alibi_slopes` has no default,
// so Python callers must pass it explicitly (possibly as None).
inline void bind_mha_varlen(py::module &m) {
    m.def(
        "mha_varlen",
        &ops::py_mha_varlen,
        py::arg("q"),
        py::arg("k"),
        py::arg("v"),
        py::arg("cum_seqlens_q"),
        py::arg("cum_seqlens_k"),
        py::arg("block_table"),
        py::arg("max_seqlen_q"),
        py::arg("max_seqlen_k"),
        py::arg("alibi_slopes"),
        py::arg("scale"),
        R"doc(Variable-length multi-head attention.)doc");
    m.def(
        "mha_varlen_",
        &ops::py_mha_varlen_,
        py::arg("out"),
        py::arg("q"),
        py::arg("k"),
        py::arg("v"),
        py::arg("cum_seqlens_q"),
        py::arg("cum_seqlens_k"),
        py::arg("block_table"),
        py::arg("max_seqlen_q"),
        py::arg("max_seqlen_k"),
        py::arg("alibi_slopes"),
        py::arg("scale"),
        R"doc(In-place variable-length multi-head attention.)doc");
}
} // namespace infinicore::ops
...@@ -16,15 +16,15 @@ InfiniOP 是 InfiniCore 下属的统一底层算子框架,为相同算子在 ...@@ -16,15 +16,15 @@ InfiniOP 是 InfiniCore 下属的统一底层算子框架,为相同算子在
typedef struct InfiniopDescriptor *infiniopAddDescriptor_t; typedef struct InfiniopDescriptor *infiniopAddDescriptor_t;
__C __export infiniStatus_t infiniopCreateAddDescriptor(infiniopHandle_t handle, __INFINI_C __export infiniStatus_t infiniopCreateAddDescriptor(infiniopHandle_t handle,
infiniopAddDescriptor_t *desc_ptr, infiniopAddDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t c, infiniopTensorDescriptor_t c,
infiniopTensorDescriptor_t a, infiniopTensorDescriptor_t a,
infiniopTensorDescriptor_t b); infiniopTensorDescriptor_t b);
__C __export infiniStatus_t infiniopGetAddWorkspaceSize(infiniopAddDescriptor_t desc, size_t *size); __INFINI_C __export infiniStatus_t infiniopGetAddWorkspaceSize(infiniopAddDescriptor_t desc, size_t *size);
__C __export infiniStatus_t infiniopAdd(infiniopAddDescriptor_t desc, __INFINI_C __export infiniStatus_t infiniopAdd(infiniopAddDescriptor_t desc,
void *workspace, void *workspace,
size_t workspace_size, size_t workspace_size,
void *c, void *c,
...@@ -32,7 +32,7 @@ InfiniOP 是 InfiniCore 下属的统一底层算子框架,为相同算子在 ...@@ -32,7 +32,7 @@ InfiniOP 是 InfiniCore 下属的统一底层算子框架,为相同算子在
const void *b, const void *b,
void *stream); void *stream);
__C __export infiniStatus_t infiniopDestroyAddDescriptor(infiniopAddDescriptor_t desc); __INFINI_C __export infiniStatus_t infiniopDestroyAddDescriptor(infiniopAddDescriptor_t desc);
#endif #endif
``` ```
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include "metax/metax_handle.h" #include "metax/metax_handle.h"
#endif #endif
__C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr) { __INFINI_C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr) {
if (handle_ptr == nullptr) { if (handle_ptr == nullptr) {
return INFINI_STATUS_NULL_POINTER; return INFINI_STATUS_NULL_POINTER;
} }
...@@ -79,7 +79,7 @@ __C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr) { ...@@ -79,7 +79,7 @@ __C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr) {
#undef CREATE #undef CREATE
} }
__C infiniStatus_t infiniopDestroyHandle(infiniopHandle_t handle) { __INFINI_C infiniStatus_t infiniopDestroyHandle(infiniopHandle_t handle) {
#define DELETE(CASE, NAMESPACE) \ #define DELETE(CASE, NAMESPACE) \
case CASE: \ case CASE: \
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include "moore/add_moore.h" #include "moore/add_moore.h"
#endif #endif
__C infiniStatus_t infiniopCreateAddDescriptor( __INFINI_C infiniStatus_t infiniopCreateAddDescriptor(
infiniopHandle_t handle, infiniopHandle_t handle,
infiniopAddDescriptor_t *desc_ptr, infiniopAddDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t c_desc, infiniopTensorDescriptor_t c_desc,
...@@ -77,7 +77,7 @@ __C infiniStatus_t infiniopCreateAddDescriptor( ...@@ -77,7 +77,7 @@ __C infiniStatus_t infiniopCreateAddDescriptor(
#undef CREATE #undef CREATE
} }
__C infiniStatus_t infiniopGetAddWorkspaceSize(infiniopAddDescriptor_t desc, size_t *size) { __INFINI_C infiniStatus_t infiniopGetAddWorkspaceSize(infiniopAddDescriptor_t desc, size_t *size) {
#define GET(CASE, NAMESPACE) \ #define GET(CASE, NAMESPACE) \
case CASE: \ case CASE: \
...@@ -123,7 +123,7 @@ __C infiniStatus_t infiniopGetAddWorkspaceSize(infiniopAddDescriptor_t desc, siz ...@@ -123,7 +123,7 @@ __C infiniStatus_t infiniopGetAddWorkspaceSize(infiniopAddDescriptor_t desc, siz
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
__C infiniStatus_t infiniopAdd( __INFINI_C infiniStatus_t infiniopAdd(
infiniopAddDescriptor_t desc, infiniopAddDescriptor_t desc,
void *workspace, void *workspace,
size_t workspace_size, size_t workspace_size,
...@@ -177,7 +177,7 @@ __C infiniStatus_t infiniopAdd( ...@@ -177,7 +177,7 @@ __C infiniStatus_t infiniopAdd(
#undef CALCULATE #undef CALCULATE
} }
__C infiniStatus_t __INFINI_C infiniStatus_t
infiniopDestroyAddDescriptor(infiniopAddDescriptor_t desc) { infiniopDestroyAddDescriptor(infiniopAddDescriptor_t desc) {
#define DELETE(CASE, NAMESPACE) \ #define DELETE(CASE, NAMESPACE) \
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
// #include "kunlun/add_rms_norm_kunlun.h" // #include "kunlun/add_rms_norm_kunlun.h"
#endif #endif
__C infiniStatus_t infiniopCreateAddRMSNormDescriptor( __INFINI_C infiniStatus_t infiniopCreateAddRMSNormDescriptor(
infiniopHandle_t handle, infiniopHandle_t handle,
infiniopAddRMSNormDescriptor_t *desc_ptr, infiniopAddRMSNormDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t residual_out_desc, infiniopTensorDescriptor_t residual_out_desc,
...@@ -84,7 +84,7 @@ __C infiniStatus_t infiniopCreateAddRMSNormDescriptor( ...@@ -84,7 +84,7 @@ __C infiniStatus_t infiniopCreateAddRMSNormDescriptor(
#undef CREATE #undef CREATE
} }
__C infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(infiniopAddRMSNormDescriptor_t desc, size_t *size) { __INFINI_C infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(infiniopAddRMSNormDescriptor_t desc, size_t *size) {
#define GET(CASE, NAMESPACE) \ #define GET(CASE, NAMESPACE) \
case CASE: \ case CASE: \
...@@ -127,7 +127,7 @@ __C infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(infiniopAddRMSNormDescript ...@@ -127,7 +127,7 @@ __C infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(infiniopAddRMSNormDescript
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
__C infiniStatus_t infiniopAddRMSNorm( __INFINI_C infiniStatus_t infiniopAddRMSNorm(
infiniopAddRMSNormDescriptor_t desc, infiniopAddRMSNormDescriptor_t desc,
void *workspace, void *workspace,
size_t workspace_size, size_t workspace_size,
...@@ -178,7 +178,7 @@ __C infiniStatus_t infiniopAddRMSNorm( ...@@ -178,7 +178,7 @@ __C infiniStatus_t infiniopAddRMSNorm(
#undef CALCULATE #undef CALCULATE
} }
__C infiniStatus_t infiniopDestroyAddRMSNormDescriptor(infiniopAddRMSNormDescriptor_t desc) { __INFINI_C infiniStatus_t infiniopDestroyAddRMSNormDescriptor(infiniopAddRMSNormDescriptor_t desc) {
if (desc == nullptr) { if (desc == nullptr) {
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment