Unverified Commit dce99862 authored by thatPepe's avatar thatPepe Committed by GitHub
Browse files

Merge pull request #1053 from InfiniTensor/issue/1033xmake

Issue/1033 patch aten and fa adaptations
parents 8d99a8f5 d6e44e84
......@@ -5,15 +5,15 @@
typedef struct InfiniopDescriptor *infiniopSubDescriptor_t;
__C __export infiniStatus_t infiniopCreateSubDescriptor(infiniopHandle_t handle,
__INFINI_C __export infiniStatus_t infiniopCreateSubDescriptor(infiniopHandle_t handle,
infiniopSubDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t c,
infiniopTensorDescriptor_t a,
infiniopTensorDescriptor_t b);
__C __export infiniStatus_t infiniopGetSubWorkspaceSize(infiniopSubDescriptor_t desc, size_t *size);
__INFINI_C __export infiniStatus_t infiniopGetSubWorkspaceSize(infiniopSubDescriptor_t desc, size_t *size);
__C __export infiniStatus_t infiniopSub(infiniopSubDescriptor_t desc,
__INFINI_C __export infiniStatus_t infiniopSub(infiniopSubDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *c,
......@@ -21,6 +21,6 @@ __C __export infiniStatus_t infiniopSub(infiniopSubDescriptor_t desc,
const void *b,
void *stream);
__C __export infiniStatus_t infiniopDestroySubDescriptor(infiniopSubDescriptor_t desc);
__INFINI_C __export infiniStatus_t infiniopDestroySubDescriptor(infiniopSubDescriptor_t desc);
#endif
......@@ -5,15 +5,15 @@
typedef struct InfiniopDescriptor *infiniopSwiGLUDescriptor_t;
__C __export infiniStatus_t infiniopCreateSwiGLUDescriptor(infiniopHandle_t handle,
__INFINI_C __export infiniStatus_t infiniopCreateSwiGLUDescriptor(infiniopHandle_t handle,
infiniopSwiGLUDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc);
__C __export infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t desc, size_t *size);
__INFINI_C __export infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t desc, size_t *size);
__C __export infiniStatus_t infiniopSwiGLU(infiniopSwiGLUDescriptor_t desc,
__INFINI_C __export infiniStatus_t infiniopSwiGLU(infiniopSwiGLUDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *c,
......@@ -21,6 +21,6 @@ __C __export infiniStatus_t infiniopSwiGLU(infiniopSwiGLUDescriptor_t desc,
void const *b,
void *stream);
__C __export infiniStatus_t infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc);
__INFINI_C __export infiniStatus_t infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc);
#endif
......@@ -5,20 +5,20 @@
typedef struct InfiniopDescriptor *infiniopTanhDescriptor_t;
__C __export infiniStatus_t infiniopCreateTanhDescriptor(infiniopHandle_t handle,
__INFINI_C __export infiniStatus_t infiniopCreateTanhDescriptor(infiniopHandle_t handle,
infiniopTanhDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t output,
infiniopTensorDescriptor_t input);
__C __export infiniStatus_t infiniopGetTanhWorkspaceSize(infiniopTanhDescriptor_t desc, size_t *size);
__INFINI_C __export infiniStatus_t infiniopGetTanhWorkspaceSize(infiniopTanhDescriptor_t desc, size_t *size);
__C __export infiniStatus_t infiniopTanh(infiniopTanhDescriptor_t desc,
__INFINI_C __export infiniStatus_t infiniopTanh(infiniopTanhDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *output,
const void *input,
void *stream);
__C __export infiniStatus_t infiniopDestroyTanhDescriptor(infiniopTanhDescriptor_t desc);
__INFINI_C __export infiniStatus_t infiniopDestroyTanhDescriptor(infiniopTanhDescriptor_t desc);
#endif
......@@ -5,14 +5,14 @@
typedef struct InfiniopDescriptor *infiniopTopkrouterDescriptor_t;
__C __export infiniStatus_t infiniopCreateTopkrouterDescriptor(infiniopHandle_t handle,
__INFINI_C __export infiniStatus_t infiniopCreateTopkrouterDescriptor(infiniopHandle_t handle,
infiniopTopkrouterDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t correction_bias_desc);
__C __export infiniStatus_t infiniopGetTopkrouterWorkspaceSize(infiniopTopkrouterDescriptor_t desc, size_t *size);
__INFINI_C __export infiniStatus_t infiniopGetTopkrouterWorkspaceSize(infiniopTopkrouterDescriptor_t desc, size_t *size);
__C __export infiniStatus_t infiniopTopkrouter(infiniopTopkrouterDescriptor_t desc,
__INFINI_C __export infiniStatus_t infiniopTopkrouter(infiniopTopkrouterDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *values,
......@@ -23,6 +23,6 @@ __C __export infiniStatus_t infiniopTopkrouter(infiniopTopkrouterDescriptor_t de
const size_t topk,
void *stream);
__C __export infiniStatus_t infiniopDestroyTopkrouterDescriptor(infiniopTopkrouterDescriptor_t desc);
__INFINI_C __export infiniStatus_t infiniopDestroyTopkrouterDescriptor(infiniopTopkrouterDescriptor_t desc);
#endif
......@@ -5,13 +5,13 @@
typedef struct InfiniopDescriptor *infiniopTopksoftmaxDescriptor_t;
__C __export infiniStatus_t infiniopCreateTopksoftmaxDescriptor(infiniopHandle_t handle,
__INFINI_C __export infiniStatus_t infiniopCreateTopksoftmaxDescriptor(infiniopHandle_t handle,
infiniopTopksoftmaxDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t x_desc);
__C __export infiniStatus_t infiniopGetTopksoftmaxWorkspaceSize(infiniopTopksoftmaxDescriptor_t desc, size_t *size);
__INFINI_C __export infiniStatus_t infiniopGetTopksoftmaxWorkspaceSize(infiniopTopksoftmaxDescriptor_t desc, size_t *size);
__C __export infiniStatus_t infiniopTopksoftmax(infiniopTopksoftmaxDescriptor_t desc,
__INFINI_C __export infiniStatus_t infiniopTopksoftmax(infiniopTopksoftmaxDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *values,
......@@ -21,6 +21,6 @@ __C __export infiniStatus_t infiniopTopksoftmax(infiniopTopksoftmaxDescriptor_t
const int norm,
void *stream);
__C __export infiniStatus_t infiniopDestroyTopksoftmaxDescriptor(infiniopTopksoftmaxDescriptor_t desc);
__INFINI_C __export infiniStatus_t infiniopDestroyTopksoftmaxDescriptor(infiniopTopksoftmaxDescriptor_t desc);
#endif
......@@ -5,20 +5,20 @@
typedef struct InfiniopDescriptor *infiniopZerosDescriptor_t;
__C __export infiniStatus_t infiniopCreateZerosDescriptor(infiniopHandle_t handle,
__INFINI_C __export infiniStatus_t infiniopCreateZerosDescriptor(infiniopHandle_t handle,
infiniopZerosDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y,
infiniopTensorDescriptor_t x);
__C __export infiniStatus_t infiniopGetZerosWorkspaceSize(infiniopZerosDescriptor_t desc, size_t *size);
__INFINI_C __export infiniStatus_t infiniopGetZerosWorkspaceSize(infiniopZerosDescriptor_t desc, size_t *size);
__C __export infiniStatus_t infiniopZeros(infiniopZerosDescriptor_t desc,
__INFINI_C __export infiniStatus_t infiniopZeros(infiniopZerosDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *y,
const void *x,
void *stream);
__C __export infiniStatus_t infiniopDestroyZerosDescriptor(infiniopZerosDescriptor_t desc);
__INFINI_C __export infiniStatus_t infiniopDestroyZerosDescriptor(infiniopZerosDescriptor_t desc);
#endif
......@@ -7,8 +7,8 @@ struct InfiniopTensorDescriptor;
typedef struct InfiniopTensorDescriptor *infiniopTensorDescriptor_t;
__C __export infiniStatus_t infiniopCreateTensorDescriptor(infiniopTensorDescriptor_t *desc_ptr, size_t ndim, const size_t *shape, const ptrdiff_t *strides, infiniDtype_t dtype);
__INFINI_C __export infiniStatus_t infiniopCreateTensorDescriptor(infiniopTensorDescriptor_t *desc_ptr, size_t ndim, const size_t *shape, const ptrdiff_t *strides, infiniDtype_t dtype);
__C __export infiniStatus_t infiniopDestroyTensorDescriptor(infiniopTensorDescriptor_t desc);
__INFINI_C __export infiniStatus_t infiniopDestroyTensorDescriptor(infiniopTensorDescriptor_t desc);
#endif // __INFINIOP_TENSOR_DESCRIPTOR__
......@@ -10,20 +10,20 @@ typedef void *infinirtGraph_t;
typedef void *infinirtGraphNode_t;
typedef void *infinirtGraphExec_t;
__C __export infiniStatus_t infinirtInit();
__INFINI_C __export infiniStatus_t infinirtInit();
// Device
__C __export infiniStatus_t infinirtGetAllDeviceCount(int *count_array);
__C __export infiniStatus_t infinirtGetDeviceCount(infiniDevice_t device, int *count);
__C __export infiniStatus_t infinirtSetDevice(infiniDevice_t device, int device_id);
__C __export infiniStatus_t infinirtGetDevice(infiniDevice_t *device_ptr, int *device_id_ptr);
__C __export infiniStatus_t infinirtDeviceSynchronize();
__INFINI_C __export infiniStatus_t infinirtGetAllDeviceCount(int *count_array);
__INFINI_C __export infiniStatus_t infinirtGetDeviceCount(infiniDevice_t device, int *count);
__INFINI_C __export infiniStatus_t infinirtSetDevice(infiniDevice_t device, int device_id);
__INFINI_C __export infiniStatus_t infinirtGetDevice(infiniDevice_t *device_ptr, int *device_id_ptr);
__INFINI_C __export infiniStatus_t infinirtDeviceSynchronize();
// Stream
__C __export infiniStatus_t infinirtStreamCreate(infinirtStream_t *stream_ptr);
__C __export infiniStatus_t infinirtStreamDestroy(infinirtStream_t stream);
__C __export infiniStatus_t infinirtStreamSynchronize(infinirtStream_t stream);
__C __export infiniStatus_t infinirtStreamWaitEvent(infinirtStream_t stream, infinirtEvent_t event);
__INFINI_C __export infiniStatus_t infinirtStreamCreate(infinirtStream_t *stream_ptr);
__INFINI_C __export infiniStatus_t infinirtStreamDestroy(infinirtStream_t stream);
__INFINI_C __export infiniStatus_t infinirtStreamSynchronize(infinirtStream_t stream);
__INFINI_C __export infiniStatus_t infinirtStreamWaitEvent(infinirtStream_t stream, infinirtEvent_t event);
// Event
typedef enum {
......@@ -38,13 +38,13 @@ typedef enum {
INFINIRT_EVENT_BLOCKING_SYNC = 0x2, // Event uses blocking synchronization
} infinirtEventFlags_t;
__C __export infiniStatus_t infinirtEventCreate(infinirtEvent_t *event_ptr);
__C __export infiniStatus_t infinirtEventCreateWithFlags(infinirtEvent_t *event_ptr, uint32_t flags);
__C __export infiniStatus_t infinirtEventRecord(infinirtEvent_t event, infinirtStream_t stream);
__C __export infiniStatus_t infinirtEventQuery(infinirtEvent_t event, infinirtEventStatus_t *status_ptr);
__C __export infiniStatus_t infinirtEventSynchronize(infinirtEvent_t event);
__C __export infiniStatus_t infinirtEventDestroy(infinirtEvent_t event);
__C __export infiniStatus_t infinirtEventElapsedTime(float *ms_ptr, infinirtEvent_t start, infinirtEvent_t end);
__INFINI_C __export infiniStatus_t infinirtEventCreate(infinirtEvent_t *event_ptr);
__INFINI_C __export infiniStatus_t infinirtEventCreateWithFlags(infinirtEvent_t *event_ptr, uint32_t flags);
__INFINI_C __export infiniStatus_t infinirtEventRecord(infinirtEvent_t event, infinirtStream_t stream);
__INFINI_C __export infiniStatus_t infinirtEventQuery(infinirtEvent_t event, infinirtEventStatus_t *status_ptr);
__INFINI_C __export infiniStatus_t infinirtEventSynchronize(infinirtEvent_t event);
__INFINI_C __export infiniStatus_t infinirtEventDestroy(infinirtEvent_t event);
__INFINI_C __export infiniStatus_t infinirtEventElapsedTime(float *ms_ptr, infinirtEvent_t start, infinirtEvent_t end);
// Memory
typedef enum {
......@@ -54,17 +54,17 @@ typedef enum {
INFINIRT_MEMCPY_D2D = 3,
} infinirtMemcpyKind_t;
__C __export infiniStatus_t infinirtMalloc(void **p_ptr, size_t size);
__C __export infiniStatus_t infinirtMallocHost(void **p_ptr, size_t size);
__C __export infiniStatus_t infinirtFree(void *ptr);
__C __export infiniStatus_t infinirtFreeHost(void *ptr);
__INFINI_C __export infiniStatus_t infinirtMalloc(void **p_ptr, size_t size);
__INFINI_C __export infiniStatus_t infinirtMallocHost(void **p_ptr, size_t size);
__INFINI_C __export infiniStatus_t infinirtFree(void *ptr);
__INFINI_C __export infiniStatus_t infinirtFreeHost(void *ptr);
__C __export infiniStatus_t infinirtMemcpy(void *dst, const void *src, size_t size, infinirtMemcpyKind_t kind);
__C __export infiniStatus_t infinirtMemcpyAsync(void *dst, const void *src, size_t size, infinirtMemcpyKind_t kind, infinirtStream_t stream);
__INFINI_C __export infiniStatus_t infinirtMemcpy(void *dst, const void *src, size_t size, infinirtMemcpyKind_t kind);
__INFINI_C __export infiniStatus_t infinirtMemcpyAsync(void *dst, const void *src, size_t size, infinirtMemcpyKind_t kind, infinirtStream_t stream);
// Stream-ordered memory
__C __export infiniStatus_t infinirtMallocAsync(void **p_ptr, size_t size, infinirtStream_t stream);
__C __export infiniStatus_t infinirtFreeAsync(void *ptr, infinirtStream_t stream);
__INFINI_C __export infiniStatus_t infinirtMallocAsync(void **p_ptr, size_t size, infinirtStream_t stream);
__INFINI_C __export infiniStatus_t infinirtFreeAsync(void *ptr, infinirtStream_t stream);
// Graph
typedef enum {
......@@ -74,16 +74,16 @@ typedef enum {
} infinirtStreamCaptureMode_t;
__C __export infiniStatus_t infinirtStreamBeginCapture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode);
__C __export infiniStatus_t infinirtStreamEndCapture(infinirtStream_t stream, infinirtGraph_t *graph_ptr);
__C __export infiniStatus_t infinirtGraphDestroy(infinirtGraph_t graph);
__C __export infiniStatus_t infinirtGraphInstantiate(
__INFINI_C __export infiniStatus_t infinirtStreamBeginCapture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode);
__INFINI_C __export infiniStatus_t infinirtStreamEndCapture(infinirtStream_t stream, infinirtGraph_t *graph_ptr);
__INFINI_C __export infiniStatus_t infinirtGraphDestroy(infinirtGraph_t graph);
__INFINI_C __export infiniStatus_t infinirtGraphInstantiate(
infinirtGraphExec_t *graph_exec_ptr,
infinirtGraph_t graph,
infinirtGraphNode_t *node_ptr,
char *log_buffer,
size_t buffer_size);
__C __export infiniStatus_t infinirtGraphExecDestroy(infinirtGraphExec_t graph_exec);
__C __export infiniStatus_t infinirtGraphLuanch(infinirtGraphExec_t graph_exec, infinirtStream_t stream);
__INFINI_C __export infiniStatus_t infinirtGraphExecDestroy(infinirtGraphExec_t graph_exec);
__INFINI_C __export infiniStatus_t infinirtGraphLuanch(infinirtGraphExec_t graph_exec, infinirtStream_t stream);
#endif // __INFINIRT_API_H__
......@@ -52,6 +52,7 @@ from infinicore.ops.add_rms_norm import add_rms_norm
from infinicore.ops.attention import attention
from infinicore.ops.kv_caching import kv_caching
from infinicore.ops.matmul import matmul
from infinicore.ops.mha_varlen import mha_varlen
from infinicore.ops.mul import mul
from infinicore.ops.narrow import narrow
from infinicore.ops.paged_attention import paged_attention
......@@ -134,6 +135,7 @@ __all__ = [
"from_list",
"from_numpy",
"from_torch",
"mha_varlen",
"paged_caching",
"paged_attention",
"paged_attention_prefill",
......
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor
def mha_varlen(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    cum_seqlens_q: Tensor,
    cum_seqlens_k: Tensor,
    block_table: Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    alibi_slopes: Tensor | None = None,
    scale: float = 1.0,
    *,
    out: Tensor | None = None,
):
    """Variable-length multi-head attention.

    Forwards to the native ``_infinicore`` implementation. When ``out`` is
    supplied the result is written into it in place and ``out`` is returned;
    otherwise a new ``Tensor`` wrapping the native result is returned.
    """
    # Unwrap the optional ALiBi slopes once so both call paths share it.
    slopes = None if alibi_slopes is None else alibi_slopes._underlying
    common_args = (
        q._underlying,
        k._underlying,
        v._underlying,
        cum_seqlens_q._underlying,
        cum_seqlens_k._underlying,
        block_table._underlying,
        max_seqlen_q,
        max_seqlen_k,
        slopes,
        scale,
    )
    if out is not None:
        _infinicore.mha_varlen_(out._underlying, *common_args)
        return out
    return Tensor(_infinicore.mha_varlen(*common_args))
......@@ -7,7 +7,7 @@
#include "./metax/infiniccl_metax.h"
#include "./moore/infiniccl_moore.h"
__C infiniStatus_t infinicclCommInitAll(
__INFINI_C infiniStatus_t infinicclCommInitAll(
infiniDevice_t device_type,
infinicclComm_t *comms,
int ndevice,
......@@ -35,7 +35,7 @@ __C infiniStatus_t infinicclCommInitAll(
#undef COMM_INIT_ALL
}
__C infiniStatus_t infinicclCommDestroy(infinicclComm_t comm) {
__INFINI_C infiniStatus_t infinicclCommDestroy(infinicclComm_t comm) {
if (comm == nullptr) {
return INFINI_STATUS_SUCCESS;
}
......@@ -61,7 +61,7 @@ __C infiniStatus_t infinicclCommDestroy(infinicclComm_t comm) {
#undef COMM_DESTROY
}
__C infiniStatus_t infinicclAllReduce(
__INFINI_C infiniStatus_t infinicclAllReduce(
void *sendbuf,
void *recvbuf,
size_t count,
......
#ifdef ENABLE_ATEN
#include "infinicore/adaptor/aten_adaptor.hpp"

namespace infinicore::adaptor {

// Wraps an infinicore tensor's existing buffer as a *non-owning* ATen tensor.
// No data is copied and ownership stays with infinicore: the deleter handed to
// at::from_blob is a no-op, so `t` must outlive the returned at::Tensor.
at::Tensor to_aten_tensor(const infinicore::Tensor &t) {
    void *data_ptr = (void *)(t->data());
    // ATen expects int64_t extents; convert the shape element-wise.
    auto sizes = std::vector<int64_t>(
        t->shape().begin(),
        t->shape().end());
    // NOTE(review): strides are forwarded as-is; assumes infinicore strides are
    // element-counts in a type convertible to ATen's IntArrayRef — confirm.
    auto strides = t->strides();
    auto dtype = to_at_dtype(t->dtype());
    auto device = to_at_device(t->device());
    // Intentionally empty deleter: the storage is managed by infinicore,
    // never freed by ATen.
    auto deleter_ = [](void * /*unused*/) mutable {
    };
    at::TensorOptions options = at::TensorOptions()
                                    .dtype(dtype)
                                    .device(device)
                                    .requires_grad(false);
    return at::from_blob(
        data_ptr,
        sizes,
        strides,
        deleter_,
        options);
}

#ifdef ENABLE_NVIDIA_API
// Exposes infinicore's current CUDA stream to PyTorch so ATen kernels are
// launched on the same stream and stay ordered with infinicore's own work.
c10::cuda::CUDAStream get_cuda_stream() {
    return c10::cuda::getStreamFromExternal(
        cudaStream_t(infinicore::context::getStream()), infinicore::context::getDevice().getIndex());
}
#endif

} // namespace infinicore::adaptor
#endif // ENABLE_ATEN
#include "infinicore/ops/mha_varlen.hpp"
#include "../../utils.hpp"

namespace infinicore::op {

// Instantiates the per-device plan/run/cleanup dispatcher table for this op;
// backends attach themselves via INFINICORE_GRAPH_OP_REGISTER_* elsewhere.
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(MultiheadAttentionVarlen);

// Validates that all tensors live on the same device, then dispatches to the
// backend registered for the output tensor's device type.
MultiheadAttentionVarlen::MultiheadAttentionVarlen(Tensor out,
                                                   const Tensor &q,
                                                   const Tensor &k,
                                                   const Tensor &v,
                                                   const Tensor &cum_seqlens_q,
                                                   const Tensor &cum_seqlens_kv,
                                                   const Tensor &block_table,
                                                   int max_seqlen_q,
                                                   int max_seqlen_k,
                                                   std::optional<Tensor> alibi_slopes,
                                                   float scale) {
    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k, v, cum_seqlens_q, cum_seqlens_kv, block_table);
    INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(),
                                 out, q, k, v, cum_seqlens_q, cum_seqlens_kv, block_table, max_seqlen_q, max_seqlen_k, alibi_slopes, scale);
}

// Either records the op into a graph under capture or runs it eagerly,
// depending on the current capture state (decided by the macro).
void MultiheadAttentionVarlen::execute(Tensor out,
                                       const Tensor &q,
                                       const Tensor &k,
                                       const Tensor &v,
                                       const Tensor &cum_seqlens_q,
                                       const Tensor &cum_seqlens_kv,
                                       const Tensor &block_table,
                                       int max_seqlen_q,
                                       int max_seqlen_k,
                                       std::optional<Tensor> alibi_slopes,
                                       float scale) {
    INFINICORE_GRAPH_OP_RECORD_OR_RUN(
        MultiheadAttentionVarlen,
        out, q, k, v, cum_seqlens_q, cum_seqlens_kv, block_table, max_seqlen_q, max_seqlen_k, alibi_slopes, scale);
}

// Out-of-place entry point: allocates the output with q's shape/dtype/device,
// then delegates to the in-place variant.
Tensor mha_varlen(
    const Tensor &q,
    const Tensor &k,
    const Tensor &v,
    const Tensor &cum_seqlens_q,
    const Tensor &cum_seqlens_kv,
    const Tensor &block_table,
    int max_seqlen_q,
    int max_seqlen_k,
    std::optional<Tensor> alibi_slopes,
    float scale) {
    auto out = Tensor::empty(q->shape(), q->dtype(), q->device());
    mha_varlen_(out, q, k, v, cum_seqlens_q, cum_seqlens_kv, block_table, max_seqlen_q, max_seqlen_k, alibi_slopes, scale);
    return out;
}

// In-place entry point: writes the attention result into `out`.
void mha_varlen_(Tensor out,
                 const Tensor &q,
                 const Tensor &k,
                 const Tensor &v,
                 const Tensor &cum_seqlens_q,
                 const Tensor &cum_seqlens_kv,
                 const Tensor &block_table,
                 int max_seqlen_q,
                 int max_seqlen_k,
                 std::optional<Tensor> alibi_slopes,
                 float scale) {
    MultiheadAttentionVarlen::execute(out, q, k, v, cum_seqlens_q, cum_seqlens_kv, block_table, max_seqlen_q, max_seqlen_k, alibi_slopes, scale);
}

} // namespace infinicore::op
#include "infinicore/ops/mha_varlen.hpp"
#include "infinicore/adaptor/flash_attention_adaptor.hpp"

#include <stdexcept>

namespace infinicore::op::mha_varlen_impl::flashattn {

// Everything the deferred run() call needs, captured at plan time.
// Tensors are held as graph::GraphTensor — presumably so graph replay can
// rebind buffer addresses; confirm GraphTensor semantics.
struct PlannedMeta {
    graph::GraphTensor out, q, k, v, cum_seqlens_q, cum_seqlens_k, block_table;
    int max_seqlen_q, max_seqlen_k;
    std::optional<graph::GraphTensor> alibi_slopes;
    float scale;
};

// Snapshots all arguments into a heap-allocated PlannedMeta.
// Ownership transfers to the framework, which releases it via cleanup().
void *plan(Tensor out,
           const Tensor &q,
           const Tensor &k,
           const Tensor &v,
           const Tensor &cum_seqlens_q,
           const Tensor &cum_seqlens_k,
           const Tensor &block_table,
           int max_seqlen_q,
           int max_seqlen_k,
           std::optional<Tensor> alibi_slopes,
           float scale) {
    return new PlannedMeta{
        graph::GraphTensor(out),
        graph::GraphTensor(q),
        graph::GraphTensor(k),
        graph::GraphTensor(v),
        graph::GraphTensor(cum_seqlens_q),
        graph::GraphTensor(cum_seqlens_k),
        graph::GraphTensor(block_table),
        max_seqlen_q,
        max_seqlen_k,
        alibi_slopes ? std::optional<graph::GraphTensor>(graph::GraphTensor(*alibi_slopes)) : std::nullopt,
        scale};
}

// Executes the planned attention through flash-attn's varlen forward kernel.
// NOTE(review): relies on adaptor::get_cuda_stream(), which the ATen adaptor
// only declares under ENABLE_NVIDIA_API — this assumes ENABLE_FLASH_ATTN
// implies an NVIDIA build; confirm.
void run(void *planned_meta) {
#ifdef ENABLE_FLASH_ATTN
    // Route ATen launches onto infinicore's current stream for ordering.
    c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream());
    auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);
    // Non-owning ATen views over the captured infinicore tensors.
    auto q = infinicore::adaptor::to_aten_tensor(p->q);
    auto k = infinicore::adaptor::to_aten_tensor(p->k);
    auto v = infinicore::adaptor::to_aten_tensor(p->v);
    auto out = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->out));
    auto cu_seqlens_q = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_q);
    auto cu_seqlens_kv = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_k);
    std::optional<at::Tensor> seqused_k = std::nullopt;
    std::optional<const at::Tensor> leftpad_k = std::nullopt;
    auto block_table = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->block_table));
    auto max_seqlen_q = p->max_seqlen_q;
    auto max_seqlen_k = p->max_seqlen_k;
    auto alibi_slopes = p->alibi_slopes ? std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes)) : std::nullopt;
    auto scale = p->scale;
    // NOTE(review): trailing positional flags assumed to match flash-attn v2's
    // mha_varlen_fwd signature — confirm against the adaptor header.
    flash::mha_varlen_fwd(
        q,
        k,
        v,
        out,
        cu_seqlens_q,
        cu_seqlens_kv,
        seqused_k,
        leftpad_k,
        block_table,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        0.0,           // dropout probability: none at inference
        scale,         // softmax scale
        false,         // zero_tensors
        true,          // causal attention
        -1,            // left window (-1 = unrestricted)
        -1,            // right window (-1 = unrestricted)
        0.0,           // softcap disabled
        false,         // do not return the softmax
        std::nullopt); // RNG generator: unused without dropout
#else
    throw std::runtime_error("FlashAttention is not enabled in this build");
#endif
}

// Frees the PlannedMeta allocated by plan() and nulls the caller's pointer
// to guard against double-free on repeated cleanup.
void cleanup(void **planned_meta_ptr) {
    delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
    *planned_meta_ptr = nullptr;
}

// Register this implementation for every device type.
INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(MultiheadAttentionVarlen, &plan, &run, &cleanup);

} // namespace infinicore::op::mha_varlen_impl::flashattn
......@@ -12,6 +12,7 @@
#include "ops/linear.hpp"
#include "ops/linear_w8a8i8.hpp"
#include "ops/matmul.hpp"
#include "ops/mha_varlen.hpp"
#include "ops/mul.hpp"
#include "ops/paged_attention.hpp"
#include "ops/paged_attention_prefill.hpp"
......@@ -38,6 +39,7 @@ inline void bind(py::module &m) {
bind_linear(m);
bind_matmul(m);
bind_mul(m);
bind_mha_varlen(m);
bind_paged_attention(m);
bind_paged_attention_prefill(m);
bind_paged_caching(m);
......
#pragma once
#include <pybind11/pybind11.h>
#include "infinicore/ops/mha_varlen.hpp"
namespace py = pybind11;
namespace infinicore::ops {
// Python-facing wrapper for the out-of-place op: translates the optional
// `alibi_slopes` Python object into std::optional<Tensor> and forwards.
Tensor py_mha_varlen(Tensor q,
                     Tensor k,
                     Tensor v,
                     Tensor cum_seqlens_q,
                     Tensor cum_seqlens_k,
                     Tensor block_table,
                     int max_seqlen_q,
                     int max_seqlen_k,
                     pybind11::object alibi_slopes,
                     float scale) {
    const auto slopes = alibi_slopes.is_none()
                            ? std::optional<Tensor>{}
                            : std::optional<Tensor>{alibi_slopes.cast<Tensor>()};
    return op::mha_varlen(q, k, v, cum_seqlens_q, cum_seqlens_k, block_table,
                          max_seqlen_q, max_seqlen_k, slopes, scale);
}
// Python-facing wrapper for the in-place op: translates the optional
// `alibi_slopes` Python object into std::optional<Tensor> and forwards.
void py_mha_varlen_(Tensor out,
                    Tensor q,
                    Tensor k,
                    Tensor v,
                    Tensor cum_seqlens_q,
                    Tensor cum_seqlens_k,
                    Tensor block_table,
                    int max_seqlen_q,
                    int max_seqlen_k,
                    pybind11::object alibi_slopes,
                    float scale) {
    const auto slopes = alibi_slopes.is_none()
                            ? std::optional<Tensor>{}
                            : std::optional<Tensor>{alibi_slopes.cast<Tensor>()};
    op::mha_varlen_(out, q, k, v, cum_seqlens_q, cum_seqlens_k, block_table,
                    max_seqlen_q, max_seqlen_k, slopes, scale);
}
// Registers the out-of-place (`mha_varlen`) and in-place (`mha_varlen_`)
// variants on the given pybind11 module.
inline void bind_mha_varlen(py::module &m) {
    m.def(
        "mha_varlen",
        &ops::py_mha_varlen,
        py::arg("q"),
        py::arg("k"),
        py::arg("v"),
        py::arg("cum_seqlens_q"),
        py::arg("cum_seqlens_k"),
        py::arg("block_table"),
        py::arg("max_seqlen_q"),
        py::arg("max_seqlen_k"),
        py::arg("alibi_slopes"), // no default: the Python wrapper always supplies it (may be None)
        py::arg("scale"),
        R"doc(Variable-length multi-head attention.)doc");
    m.def(
        "mha_varlen_",
        &ops::py_mha_varlen_,
        py::arg("out"),
        py::arg("q"),
        py::arg("k"),
        py::arg("v"),
        py::arg("cum_seqlens_q"),
        py::arg("cum_seqlens_k"),
        py::arg("block_table"),
        py::arg("max_seqlen_q"),
        py::arg("max_seqlen_k"),
        py::arg("alibi_slopes"),
        py::arg("scale"),
        R"doc(In-place variable-length multi-head attention.)doc");
}
} // namespace infinicore::ops
......@@ -16,15 +16,15 @@ InfiniOP 是 InfiniCore 下属的统一底层算子框架,为相同算子在
typedef struct InfiniopDescriptor *infiniopAddDescriptor_t;
__C __export infiniStatus_t infiniopCreateAddDescriptor(infiniopHandle_t handle,
__INFINI_C __export infiniStatus_t infiniopCreateAddDescriptor(infiniopHandle_t handle,
infiniopAddDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t c,
infiniopTensorDescriptor_t a,
infiniopTensorDescriptor_t b);
__C __export infiniStatus_t infiniopGetAddWorkspaceSize(infiniopAddDescriptor_t desc, size_t *size);
__INFINI_C __export infiniStatus_t infiniopGetAddWorkspaceSize(infiniopAddDescriptor_t desc, size_t *size);
__C __export infiniStatus_t infiniopAdd(infiniopAddDescriptor_t desc,
__INFINI_C __export infiniStatus_t infiniopAdd(infiniopAddDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *c,
......@@ -32,7 +32,7 @@ InfiniOP 是 InfiniCore 下属的统一底层算子框架,为相同算子在
const void *b,
void *stream);
__C __export infiniStatus_t infiniopDestroyAddDescriptor(infiniopAddDescriptor_t desc);
__INFINI_C __export infiniStatus_t infiniopDestroyAddDescriptor(infiniopAddDescriptor_t desc);
#endif
```
......
......@@ -24,7 +24,7 @@
#include "metax/metax_handle.h"
#endif
__C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr) {
__INFINI_C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr) {
if (handle_ptr == nullptr) {
return INFINI_STATUS_NULL_POINTER;
}
......@@ -79,7 +79,7 @@ __C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr) {
#undef CREATE
}
__C infiniStatus_t infiniopDestroyHandle(infiniopHandle_t handle) {
__INFINI_C infiniStatus_t infiniopDestroyHandle(infiniopHandle_t handle) {
#define DELETE(CASE, NAMESPACE) \
case CASE: \
......
......@@ -21,7 +21,7 @@
#include "moore/add_moore.h"
#endif
__C infiniStatus_t infiniopCreateAddDescriptor(
__INFINI_C infiniStatus_t infiniopCreateAddDescriptor(
infiniopHandle_t handle,
infiniopAddDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t c_desc,
......@@ -77,7 +77,7 @@ __C infiniStatus_t infiniopCreateAddDescriptor(
#undef CREATE
}
__C infiniStatus_t infiniopGetAddWorkspaceSize(infiniopAddDescriptor_t desc, size_t *size) {
__INFINI_C infiniStatus_t infiniopGetAddWorkspaceSize(infiniopAddDescriptor_t desc, size_t *size) {
#define GET(CASE, NAMESPACE) \
case CASE: \
......@@ -123,7 +123,7 @@ __C infiniStatus_t infiniopGetAddWorkspaceSize(infiniopAddDescriptor_t desc, siz
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopAdd(
__INFINI_C infiniStatus_t infiniopAdd(
infiniopAddDescriptor_t desc,
void *workspace,
size_t workspace_size,
......@@ -177,7 +177,7 @@ __C infiniStatus_t infiniopAdd(
#undef CALCULATE
}
__C infiniStatus_t
__INFINI_C infiniStatus_t
infiniopDestroyAddDescriptor(infiniopAddDescriptor_t desc) {
#define DELETE(CASE, NAMESPACE) \
......
......@@ -27,7 +27,7 @@
// #include "kunlun/add_rms_norm_kunlun.h"
#endif
__C infiniStatus_t infiniopCreateAddRMSNormDescriptor(
__INFINI_C infiniStatus_t infiniopCreateAddRMSNormDescriptor(
infiniopHandle_t handle,
infiniopAddRMSNormDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t residual_out_desc,
......@@ -84,7 +84,7 @@ __C infiniStatus_t infiniopCreateAddRMSNormDescriptor(
#undef CREATE
}
__C infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(infiniopAddRMSNormDescriptor_t desc, size_t *size) {
__INFINI_C infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(infiniopAddRMSNormDescriptor_t desc, size_t *size) {
#define GET(CASE, NAMESPACE) \
case CASE: \
......@@ -127,7 +127,7 @@ __C infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(infiniopAddRMSNormDescript
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopAddRMSNorm(
__INFINI_C infiniStatus_t infiniopAddRMSNorm(
infiniopAddRMSNormDescriptor_t desc,
void *workspace,
size_t workspace_size,
......@@ -178,7 +178,7 @@ __C infiniStatus_t infiniopAddRMSNorm(
#undef CALCULATE
}
__C infiniStatus_t infiniopDestroyAddRMSNormDescriptor(infiniopAddRMSNormDescriptor_t desc) {
__INFINI_C infiniStatus_t infiniopDestroyAddRMSNormDescriptor(infiniopAddRMSNormDescriptor_t desc) {
if (desc == nullptr) {
return INFINI_STATUS_SUCCESS;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment