Unverified Commit 784139b9 authored by thatPepe's avatar thatPepe Committed by GitHub
Browse files

Merge pull request #990 from InfiniTensor/demo131

Demo-131 Cuda graph with optimized paged attention
parents 3c8fb3c0 1d6527cb
#pragma once
#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"
namespace infinicore::op {
// Operator entry point for `Mul`: a function-pointer schema, a static execute(),
// and access to the dispatcher parameterised on that schema.
// NOTE(review): presumably `c` is the destination and `a`, `b` are the operands — confirm
// against the implementation, which is not visible here.
class Mul {
public:
// Function-pointer type the dispatcher is parameterised on: three Tensors by value, no return.
using schema = void (*)(Tensor, Tensor, Tensor);
// Same parameter list as `schema`; presumably forwards to the implementation chosen via dispatcher() — confirm.
static void execute(Tensor c, Tensor a, Tensor b);
// Returns a reference to the common::OpDispatcher for `schema`.
static common::OpDispatcher<schema> &dispatcher();
};
// NOTE(review): this listing shows by-value and const-reference declarations of mul/mul_
// side by side, next to both a hand-written `class Mul` and the INFINICORE_GRAPH_OP_CLASS(Mul, ...)
// macro. It looks like a diff rendering with removed and added lines interleaved — confirm which
// set is current. If both overload sets were really declared, a call with Tensor lvalues would
// be ambiguous.

// Returns a Tensor; presumably the freshly produced result for operands `a` and `b` — confirm.
Tensor mul(Tensor a, Tensor b);
// Destination-first variant (trailing underscore); presumably writes the result into `c` — confirm.
void mul_(Tensor c, Tensor a, Tensor b);
// Macro defined outside this listing; presumably declares the graph-capable `Mul` operator for
// the argument types (Tensor, const Tensor &, const Tensor &) — confirm against its definition.
INFINICORE_GRAPH_OP_CLASS(Mul, Tensor, const Tensor &, const Tensor &);
// Const-reference forms of the two functions above; the parameter types match the macro's list.
Tensor mul(const Tensor &a, const Tensor &b);
void mul_(Tensor c, const Tensor &a, const Tensor &b);
} // namespace infinicore::op
#pragma once
#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"
#include <optional>
namespace infinicore::op {
// Operator entry point for `PagedAttention`: a function-pointer schema, a static execute(),
// and access to the dispatcher parameterised on that schema.
// NOTE(review): tensor shapes/dtypes and the exact meaning of each argument are not visible
// in this header — confirm against the implementation.
class PagedAttention {
public:
    // Function-pointer type the dispatcher is parameterised on: six Tensors, an optional
    // Tensor, and a float, all by value.
    using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, std::optional<Tensor>, float);

    // Same parameter list as `schema`. `alibi_slopes` may be std::nullopt.
    // The trailing float is named `scale` to match the paged_attention / paged_attention_
    // declarations in this header (it was previously left unnamed).
    static void execute(Tensor out,
                        Tensor q,
                        Tensor k_cache,
                        Tensor v_cache,
                        Tensor block_tables,
                        Tensor cache_lens,
                        std::optional<Tensor> alibi_slopes,
                        float scale);

    // Returns a reference to the common::OpDispatcher for `schema`.
    static common::OpDispatcher<schema> &dispatcher();
};
// NOTE(review): by-value and const-reference declarations of paged_attention / paged_attention_
// both appear here, together with `class PagedAttention` and the macro below. This looks like a
// diff rendering with removed and added lines interleaved — confirm which set is current.

// Macro defined outside this listing; presumably declares the graph-capable `PagedAttention`
// operator for the listed argument types — confirm against its definition.
INFINICORE_GRAPH_OP_CLASS(PagedAttention, Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, const Tensor &, std::optional<Tensor>, float);
// Returns a Tensor; presumably the attention output for query `q` — confirm.
// `alibi_slopes` is optional and may be std::nullopt.
Tensor paged_attention(const Tensor &q, const Tensor &k_cache, const Tensor &v_cache,
const Tensor &block_tables, const Tensor &kv_lens,
std::optional<Tensor> alibi_slopes, float scale);
// Destination-first variant (trailing underscore); presumably writes into `out` — confirm.
void paged_attention_(Tensor out, const Tensor &q, const Tensor &k_cache, const Tensor &v_cache,
const Tensor &block_tables, const Tensor &kv_lens,
std::optional<Tensor> alibi_slopes, float scale);
// By-value forms. Note the length tensor is called `cache_lens` here but `kv_lens` above.
Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale);
void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale);
} // namespace infinicore::op
#pragma once
#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"
namespace infinicore::op {
// Operator entry point for `PagedCaching`: a function-pointer schema, a static execute(),
// and access to the dispatcher parameterised on that schema.
// NOTE(review): presumably `k`/`v` are written into `k_cache`/`v_cache` at positions given by
// `slot_mapping` — inferred from names only; confirm against the implementation.
class PagedCaching {
public:
// Function-pointer type the dispatcher is parameterised on: five Tensors by value, no return.
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor);
// Same parameter list as `schema`; there is no return value.
static void execute(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping);
// Returns a reference to the common::OpDispatcher for `schema`.
static common::OpDispatcher<schema> &dispatcher();
};
// NOTE(review): the by-value and const-reference declarations of paged_caching_ both appear
// here, which looks like a diff rendering with removed and added lines interleaved — confirm
// which one is current.

// Macro defined outside this listing; presumably declares the graph-capable `PagedCaching`
// operator. Its type list (Tensor, Tensor, const Tensor & x3) matches the second declaration below.
INFINICORE_GRAPH_OP_CLASS(PagedCaching, Tensor, Tensor, const Tensor &, const Tensor &, const Tensor &);
// By-value form. Only an underscore-suffixed function is declared; there is no value-returning variant.
void paged_caching_(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping);
// Form taking `k`, `v` and `slot_mapping` by const reference; the caches are still passed by value.
void paged_caching_(Tensor k_cache, Tensor v_cache, const Tensor &k, const Tensor &v, const Tensor &slot_mapping);
} // namespace infinicore::op
#pragma once
#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"
#include <optional>
namespace infinicore::op {
// Macro defined outside this listing; presumably declares the graph-capable `PerChannelQuantI8`
// operator for the argument types (const Tensor &, Tensor, Tensor) — confirm against its definition.
INFINICORE_GRAPH_OP_CLASS(PerChannelQuantI8, const Tensor &, Tensor, Tensor);
// Takes the input `x` by const reference and `x_packed` / `x_scale` by value; returns nothing.
// NOTE(review): presumably `x_packed` and `x_scale` are outputs (quantized values and
// per-channel scales) — inferred from names only; confirm.
void per_channel_quant_i8_(const Tensor &x, Tensor x_packed, Tensor x_scale);
} // namespace infinicore::op
#pragma once
#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"
namespace infinicore::op {
// Operator entry point for `Rearrange`: a function-pointer schema, a static execute(),
// and access to the dispatcher parameterised on that schema.
// NOTE(review): presumably `y` is the destination and `x` the source — confirm.
class Rearrange {
public:
// Function-pointer type the dispatcher is parameterised on: two Tensors by value, no return.
using schema = void (*)(Tensor, Tensor);
// Same parameter list as `schema`; there is no return value.
static void execute(Tensor y, Tensor x);
// Returns a reference to the common::OpDispatcher for `schema`.
static common::OpDispatcher<schema> &dispatcher();
};
// NOTE(review): by-value and const-reference declarations of rearrange / rearrange_ both appear
// here, next to `class Rearrange` and the macro below. This looks like a diff rendering with
// removed and added lines interleaved — confirm which set is current.

// Returns a Tensor derived from `x`; presumably a newly produced result — confirm.
Tensor rearrange(Tensor x);
// Destination-first variant (trailing underscore); presumably writes into `y` — confirm.
void rearrange_(Tensor y, Tensor x);
// Macro defined outside this listing; presumably declares the graph-capable `Rearrange`
// operator for the argument types (Tensor, const Tensor &) — confirm against its definition.
INFINICORE_GRAPH_OP_CLASS(Rearrange, Tensor, const Tensor &);
// Const-reference forms of the two functions above.
Tensor rearrange(const Tensor &x);
void rearrange_(Tensor y, const Tensor &x);
} // namespace infinicore::op
#pragma once
#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"
namespace infinicore::op {
// Operator entry point for `RMSNorm`: a function-pointer schema, a static execute(),
// and access to the dispatcher parameterised on that schema.
// NOTE(review): presumably `y` is the destination, `x` the input and `weight` the scale
// vector — inferred from names only; confirm.
class RMSNorm {
public:
// Function-pointer type the dispatcher is parameterised on: three Tensors and a float.
using schema = void (*)(Tensor, Tensor, Tensor, float);
// Same parameter list as `schema`; `epsilon` defaults to 1e-5f when omitted.
static void execute(Tensor y, Tensor x, Tensor weight, float epsilon = 1e-5f);
// Returns a reference to the common::OpDispatcher for `schema`.
static common::OpDispatcher<schema> &dispatcher();
};
// NOTE(review): by-value and const-reference declarations of rms_norm / rms_norm_ both appear
// here, next to `class RMSNorm` and the macro below. This looks like a diff rendering with
// removed and added lines interleaved — confirm which set is current.

// Returns a Tensor; `epsilon` defaults to 1e-5f.
Tensor rms_norm(Tensor x, Tensor weight, float epsilon = 1e-5f);
// Destination-first variant (trailing underscore); presumably writes into `y` — confirm.
void rms_norm_(Tensor y, Tensor x, Tensor weight, float epsilon = 1e-5f);
// Macro defined outside this listing; presumably declares the graph-capable `RMSNorm` operator
// for (Tensor, const Tensor &, const Tensor &, float) — confirm against its definition.
INFINICORE_GRAPH_OP_CLASS(RMSNorm, Tensor, const Tensor &, const Tensor &, float);
// Const-reference forms; they carry the same 1e-5f default for `epsilon`.
Tensor rms_norm(const Tensor &x, const Tensor &weight, float epsilon = 1e-5f);
void rms_norm_(Tensor y, const Tensor &x, const Tensor &weight, float epsilon = 1e-5f);
} // namespace infinicore::op
#pragma once
#include "../device.hpp"
#include "../graph/graph.hpp"
#include "../nn/rope.hpp"
#include "../tensor.hpp"
#include "common/op.hpp"
namespace infinicore::op {
// Operator entry point for `RoPE`: a function-pointer schema, a static execute(),
// and access to the dispatcher parameterised on that schema.
// NOTE(review): the layout of `pos` and of the sin/cos tables is not visible in this
// header — confirm against the implementation.
class RoPE {
public:
    // Function-pointer type the dispatcher is parameterised on: an output Tensor by value,
    // four Tensors by const reference, and the algorithm selector from infinicore::nn::RoPE.
    using schema = void (*)(Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, infinicore::nn::RoPE::Algo);

    // Same parameter list as `schema`. The last table is named `cos_table` (previously
    // `cos_cache`) so it pairs with `sin_table` and matches the rope / rope_ declarations
    // in this header.
    static void execute(Tensor x_out,
                        const Tensor &x,
                        const Tensor &pos,
                        const Tensor &sin_table,
                        const Tensor &cos_table,
                        infinicore::nn::RoPE::Algo algo);

    // Returns a reference to the common::OpDispatcher for `schema`.
    static common::OpDispatcher<schema> &dispatcher();
};
// NOTE(review): rope_ and rope are each declared twice below with identical signatures
// (once on a single line, once split across lines). Repeating a declaration is legal, but this
// looks like a diff rendering with removed and added lines interleaved — confirm which
// copies are current.

// Internal function: destination-first variant; presumably writes into `x_out` — confirm.
void rope_(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_table, const Tensor &cos_table, infinicore::nn::RoPE::Algo algo);
// Macro defined outside this listing; presumably declares the graph-capable `RoPE` operator
// for the listed argument types — confirm against its definition.
INFINICORE_GRAPH_OP_CLASS(RoPE, Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, infinicore::nn::RoPE::Algo);
// Internal: same signature as the single-line rope_ declaration above.
void rope_(Tensor x_out,
const Tensor &x,
const Tensor &pos,
const Tensor &sin_table,
const Tensor &cos_table,
infinicore::nn::RoPE::Algo algo);
// Public API: returns a Tensor instead of taking a destination argument.
Tensor rope(const Tensor &x,
const Tensor &pos,
const Tensor &sin_table,
const Tensor &cos_table,
infinicore::nn::RoPE::Algo algo);
// Public API that uses infinicore::nn::RoPE::Algo (same signature as the declaration above).
Tensor rope(const Tensor &x, const Tensor &pos, const Tensor &sin_table, const Tensor &cos_table, infinicore::nn::RoPE::Algo algo);
} // namespace infinicore::op
#pragma once
#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"
#include <optional>
namespace infinicore::op {
// Macro defined outside this listing; presumably declares the graph-capable `I8Gemm` operator
// for the listed argument types — confirm against its definition.
INFINICORE_GRAPH_OP_CLASS(I8Gemm, Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, std::optional<Tensor>);
// Free-function entry with the same argument types as the macro above. `bias` is optional and
// may be std::nullopt.
// NOTE(review): presumably `a_p`/`b_p` are packed int8 operands, `a_s`/`b_s` their scales, and
// `c` the destination — inferred from names only; confirm.
void scaled_mm_i8_(Tensor c, const Tensor &a_p, const Tensor &a_s, const Tensor &b_p, const Tensor &b_s, std::optional<Tensor> bias);
} // namespace infinicore::op
#pragma once
#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"
namespace infinicore::op {
// Macro defined outside this listing; presumably declares the graph-capable `SiluAndMul`
// operator for the argument types (Tensor, const Tensor &) — confirm against its definition.
INFINICORE_GRAPH_OP_CLASS(SiluAndMul, Tensor, const Tensor &);
// Returns a Tensor derived from the single input `x`.
// NOTE(review): how `x` is split into the SiLU operand and the multiplier is not visible here — confirm.
Tensor silu_and_mul(const Tensor &x);
// Destination-first variant (trailing underscore); presumably writes into `out` — confirm.
void silu_and_mul_(Tensor out, const Tensor &x);
} // namespace infinicore::op
#pragma once
#include "../device.hpp"
#include "../graph/graph.hpp"
#include "../tensor.hpp"
#include "common/op.hpp"
namespace infinicore::op {
// Operator entry point for `SwiGLU`: a function-pointer schema, a static execute(),
// and access to the dispatcher parameterised on that schema.
// NOTE(review): presumably `c` is the destination and `a`, `b` the two inputs; which of them is
// the gate is not visible here — confirm.
class SwiGLU {
public:
// Function-pointer type the dispatcher is parameterised on: three Tensors by value, no return.
using schema = void (*)(Tensor, Tensor, Tensor);
// Same parameter list as `schema`; there is no return value.
static void execute(Tensor c, Tensor a, Tensor b);
// Returns a reference to the common::OpDispatcher for `schema`.
static common::OpDispatcher<schema> &dispatcher();
};
// NOTE(review): by-value and const-reference declarations of swiglu / swiglu_ both appear
// here, next to `class SwiGLU` and the macro below. This looks like a diff rendering with
// removed and added lines interleaved — confirm which set is current.

// Returns a Tensor computed from `a` and `b`; presumably a newly produced result — confirm.
Tensor swiglu(Tensor a, Tensor b);
// Destination-first variant (trailing underscore); presumably writes into `c` — confirm.
void swiglu_(Tensor c, Tensor a, Tensor b);
// Macro defined outside this listing; presumably declares the graph-capable `SwiGLU` operator
// for the argument types (Tensor, const Tensor &, const Tensor &) — confirm against its definition.
INFINICORE_GRAPH_OP_CLASS(SwiGLU, Tensor, const Tensor &, const Tensor &);
// Const-reference forms of the two functions above.
Tensor swiglu(const Tensor &a, const Tensor &b);
void swiglu_(Tensor c, const Tensor &a, const Tensor &b);
} // namespace infinicore::op
#pragma once
#include "quantization/awq.hpp"
#include "quantization/base_quantization.hpp"
#include "quantization/compressed_tensors.hpp"
#include "quantization/none_quantizaiton.hpp"
#include "quantization/quantization_scheme.hpp"
#pragma once
#include "base_quantization.hpp"
namespace infinicore::quantization {
// AWQ quantization descriptor.
//
// Temporary implementation: the reported scheme is hard-wired to AWQ_W4A16 and the
// configuration is only forwarded to the base class. A later revision should parse
// quant_config for detailed quantization settings and support several schemes.
class AWQ : public BaseQuantization {
public:
    // Forwards quant_config to BaseQuantization unchanged; nothing is parsed here.
    explicit AWQ(const nlohmann::json &quant_config) : BaseQuantization(quant_config) {}

    // Always reports AWQ_W4A16, whatever the configuration contains.
    infinicore::quantization::QuantScheme get_quant_scheme() const override {
        return infinicore::quantization::QuantScheme::AWQ_W4A16;
    }
};
} // namespace infinicore::quantization
#pragma once
#include "nlohmann/json.hpp"
#include "quantization_scheme.hpp"
namespace infinicore::quantization {
// Abstract base for quantization descriptors. Concrete quantization methods derive from it
// and report their scheme through get_quant_scheme().
class BaseQuantization {
public:
    // Keeps a copy of the supplied configuration for use by derived classes.
    // Parentheses are deliberate: brace-initialising a nlohmann::json from another json
    // value may wrap it in an array instead of copying it (see the library FAQ).
    explicit BaseQuantization(const nlohmann::json &quant_config)
        : quant_config_(quant_config) {}

    // Virtual so that derived objects can be destroyed through a base pointer.
    virtual ~BaseQuantization() = default;

    // Each derived class states which QuantScheme it represents.
    virtual infinicore::quantization::QuantScheme get_quant_scheme() const = 0;

protected:
    // Copy of the configuration passed to the constructor.
    nlohmann::json quant_config_;
};
} // namespace infinicore::quantization
#pragma once
#include "base_quantization.hpp"
namespace infinicore::quantization {
// Compressed-tensors quantization descriptor.
//
// Temporary implementation: the reported scheme is hard-wired to COMPRESSED_TENSOR_W8A8I8
// and the configuration is only forwarded to the base class. A later revision should parse
// quant_config for detailed quantization settings and support several schemes.
class CompressedTensors : public BaseQuantization {
public:
    // Forwards quant_config to BaseQuantization unchanged; nothing is parsed here.
    explicit CompressedTensors(const nlohmann::json &quant_config) : BaseQuantization(quant_config) {}

    // Always reports COMPRESSED_TENSOR_W8A8I8, whatever the configuration contains.
    infinicore::quantization::QuantScheme get_quant_scheme() const override {
        return infinicore::quantization::QuantScheme::COMPRESSED_TENSOR_W8A8I8;
    }
};
} // namespace infinicore::quantization
#pragma once
#include "base_quantization.hpp"
namespace infinicore::quantization {
// Descriptor for the unquantized case.
//
// get_quant_scheme() always returns QuantScheme::NONE. The configuration is forwarded to
// BaseQuantization unchanged and is not inspected by this class.
// (The previous comment was copied from CompressedTensors and wrongly said this class
// returns COMPRESSED_TENSOR_W8A8I8.)
class NoneQuantization : public BaseQuantization {
public:
    // Forwards quant_config to BaseQuantization unchanged; nothing is parsed here.
    explicit NoneQuantization(const nlohmann::json &quant_config)
        : BaseQuantization(quant_config) {}

    // Always reports NONE, whatever the configuration contains.
    infinicore::quantization::QuantScheme get_quant_scheme() const override {
        return infinicore::quantization::QuantScheme::NONE;
    }
};
} // namespace infinicore::quantization
// quantization_scheme.hpp
#pragma once
namespace infinicore::quantization {
// Identifies which quantization scheme a quantization descriptor reports.
// The numeric values are written out explicitly; they equal the implicit numbering
// (0, 1, 2) of the declaration order.
enum class QuantScheme {
    // Reported by NoneQuantization.
    NONE = 0,
    // Reported by CompressedTensors.
    // NOTE(review): the name suggests 8-bit weights and int8 activations — confirm.
    COMPRESSED_TENSOR_W8A8I8 = 1,
    // Reported by AWQ.
    // NOTE(review): the name suggests 4-bit weights and 16-bit activations — confirm.
    AWQ_W4A16 = 2,
};
} // namespace infinicore::quantization
......@@ -9,8 +9,12 @@
#include "infiniop/ops/clip.h"
#include "infiniop/ops/conv.h"
#include "infiniop/ops/dequantize_awq.h"
#include "infiniop/ops/embedding.h"
#include "infiniop/ops/flash_attention.h"
#include "infiniop/ops/gelu.h"
#include "infiniop/ops/gemm.h"
#include "infiniop/ops/int8_gemm.h"
#include "infiniop/ops/kv_caching.h"
#include "infiniop/ops/layer_norm.h"
#include "infiniop/ops/logsoftmax.h"
#include "infiniop/ops/lp_norm.h"
......@@ -19,6 +23,7 @@
#include "infiniop/ops/paged_attention.h"
#include "infiniop/ops/paged_attention_prefill.h"
#include "infiniop/ops/paged_caching.h"
#include "infiniop/ops/quant/per_channel_quant_int8.h"
#include "infiniop/ops/random_sample.h"
#include "infiniop/ops/rearrange.h"
#include "infiniop/ops/relu.h"
......@@ -26,6 +31,7 @@
#include "infiniop/ops/rope.h"
#include "infiniop/ops/sigmoid.h"
#include "infiniop/ops/silu.h"
#include "infiniop/ops/silu_and_mul.h"
#include "infiniop/ops/softmax.h"
#include "infiniop/ops/softplus.h"
#include "infiniop/ops/sub.h"
......
......@@ -9,11 +9,11 @@ __C __export infiniStatus_t infiniopCreateAddRMSNormDescriptor(
infiniopHandle_t handle,
infiniopAddRMSNormDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t residual_out_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t weight_desc,
float epsilon,
infiniopTensorDescriptor_t residual_out_desc);
float epsilon);
__C __export infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(infiniopAddRMSNormDescriptor_t desc, size_t *size);
......@@ -21,10 +21,10 @@ __C __export infiniStatus_t infiniopAddRMSNorm(infiniopAddRMSNormDescriptor_t de
void *workspace,
size_t workspace_size,
void *y,
void *residual_out,
const void *a,
const void *b,
const void *weight,
void *residual_out,
void *stream);
__C __export infiniStatus_t infiniopDestroyAddRMSNormDescriptor(infiniopAddRMSNormDescriptor_t desc);
......
#ifndef __INFINIOP_EMBEDDING_API_H__
#define __INFINIOP_EMBEDDING_API_H__
#include "../operator_descriptor.h"
// Handle type for the embedding operator: an alias for `struct InfiniopDescriptor *`.
typedef struct InfiniopDescriptor *infiniopEmbeddingDescriptor_t;
// Declares the creation entry point for an embedding descriptor.
//   handle      - library handle.
//   desc_ptr    - pointer to a descriptor handle; presumably receives the new descriptor
//                 on success — confirm against the implementation.
//   output_desc - tensor descriptor for the output.
//   input_desc  - tensor descriptor for the input.
//   weight_desc - tensor descriptor for the weight.
// Returns an infiniStatus_t; the status values are defined outside this header.
// NOTE(review): presumably `input` holds indices into the rows of `weight` — confirm.
__C __export infiniStatus_t infiniopCreateEmbeddingDescriptor(
infiniopHandle_t handle,
infiniopEmbeddingDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t output_desc,
infiniopTensorDescriptor_t input_desc,
infiniopTensorDescriptor_t weight_desc);
// Declares the execution entry point for the embedding operator.
//   desc   - descriptor handle of type infiniopEmbeddingDescriptor_t.
//   output - mutable buffer (void *).
//   input  - read-only buffer (const void *).
//   weight - read-only buffer (const void *).
//   stream - execution stream handle, passed as void *.
// This call takes no workspace pointer or size, and this header declares no
// workspace-size query for the operator.
// Returns an infiniStatus_t; the status values are defined outside this header.
__C __export infiniStatus_t infiniopEmbedding(
infiniopEmbeddingDescriptor_t desc,
void *output,
const void *input,
const void *weight,
void *stream);
// Declares the teardown entry point for an embedding descriptor handle.
// NOTE(review): presumably the handle must not be used after this call — confirm.
// Returns an infiniStatus_t; the status values are defined outside this header.
__C __export infiniStatus_t infiniopDestroyEmbeddingDescriptor(
infiniopEmbeddingDescriptor_t desc);
#endif
#ifndef __INFINIOP_FLASH_ATTENTION_API_H__
#define __INFINIOP_FLASH_ATTENTION_API_H__
#include "../operator_descriptor.h"
// Handle type for the flash-attention operator: an alias for `struct InfiniopDescriptor *`.
typedef struct InfiniopDescriptor *infiniopFlashAttentionDescriptor_t;
// Declares the creation entry point for a flash-attention descriptor.
//   handle       - library handle.
//   desc_ptr     - pointer to a descriptor handle; presumably receives the new descriptor
//                  on success — confirm against the implementation.
//   out_desc     - tensor descriptor for the output.
//   q_desc, k_desc, v_desc - tensor descriptors for query, key and value.
//   total_kv_len - a tensor descriptor (not a scalar); the matching run-time argument of
//                  infiniopFlashAttention is a const void * buffer.
//   scale        - float scaling factor; NOTE(review): presumably applied to the attention
//                  scores before softmax — confirm.
//   is_causal    - passed as char; NOTE(review): presumably a boolean flag where non-zero
//                  enables causal masking — confirm.
// Returns an infiniStatus_t; the status values are defined outside this header.
__C __export infiniStatus_t infiniopCreateFlashAttentionDescriptor(
infiniopHandle_t handle,
infiniopFlashAttentionDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t total_kv_len,
float scale,
char is_causal);
// Declares the workspace-size query for a flash-attention descriptor.
//   desc - descriptor handle.
//   size - pointer to a size_t; presumably receives the workspace size in bytes that
//          infiniopFlashAttention expects — confirm the unit against the implementation.
// Returns an infiniStatus_t; the status values are defined outside this header.
__C __export infiniStatus_t infiniopGetFlashAttentionWorkspaceSize(
infiniopFlashAttentionDescriptor_t desc,
size_t *size);
// Declares the execution entry point for the flash-attention operator.
//   desc           - descriptor handle of type infiniopFlashAttentionDescriptor_t.
//   workspace      - mutable scratch buffer (void *).
//   workspace_size - size of `workspace`; presumably must be at least the value reported by
//                    infiniopGetFlashAttentionWorkspaceSize — confirm.
//   out            - mutable buffer (void *).
//   q, k, v        - read-only buffers (const void *).
//   total_kv_len   - read-only buffer (const void *); described by a tensor descriptor
//                    at descriptor-creation time.
//   stream         - execution stream handle, passed as void *.
// Returns an infiniStatus_t; the status values are defined outside this header.
__C __export infiniStatus_t infiniopFlashAttention(
infiniopFlashAttentionDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *out,
const void *q,
const void *k,
const void *v,
const void *total_kv_len,
void *stream);
// Declares the teardown entry point for a flash-attention descriptor handle.
// NOTE(review): presumably the handle must not be used after this call — confirm.
// Returns an infiniStatus_t; the status values are defined outside this header.
__C __export infiniStatus_t infiniopDestroyFlashAttentionDescriptor(
infiniopFlashAttentionDescriptor_t desc);
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment