Commit 81e5fe94 authored by PanZezhong's avatar PanZezhong Committed by wooway777
Browse files

issue/810 support more ops as graph op

parent 0611cb1b
...@@ -11,7 +11,7 @@ struct PlannedMeta { ...@@ -11,7 +11,7 @@ struct PlannedMeta {
float alpha, beta; float alpha, beta;
}; };
void *plan(Tensor c, Tensor a, Tensor b, float alpha, float beta) { void *plan(Tensor c, const Tensor &a, const Tensor &b, float alpha, float beta) {
size_t seed = hash_combine(c, a, b); size_t seed = hash_combine(c, a, b);
INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
......
...@@ -5,23 +5,46 @@ ...@@ -5,23 +5,46 @@
#include "infinicore/ops/common/cache.hpp" #include "infinicore/ops/common/cache.hpp"
#include <infiniop.h> #include <infiniop.h>
#define INFINIOP_CACHABLE_DESCRIPTOR(__DESC_TYPE__, __OP_NAME__, __SIZE__) \ #define INFINIOP_CACHABLE_DESCRIPTOR(__DESC_TYPE__, __OP_NAME__, __SIZE__) \
struct __DESC_TYPE__ { \ struct __DESC_TYPE__ { \
infiniop##__OP_NAME__##Descriptor_t desc; \ infiniop##__OP_NAME__##Descriptor_t desc = nullptr; \
Descriptor(infiniop##__OP_NAME__##Descriptor_t desc) : desc(desc) {} \ \
~Descriptor() { \ explicit __DESC_TYPE__(infiniop##__OP_NAME__##Descriptor_t d) \
if (desc != nullptr) { \ : desc(d) {} \
infiniopDestroy##__OP_NAME__##Descriptor(desc); \ \
desc = nullptr; \ /* non-copyable */ \
} \ __DESC_TYPE__(const __DESC_TYPE__ &) = delete; \
} \ __DESC_TYPE__ &operator=(const __DESC_TYPE__ &) = delete; \
}; \ \
\ /* movable */ \
thread_local common::OpCache<size_t, std::shared_ptr<__DESC_TYPE__>> \ __DESC_TYPE__(__DESC_TYPE__ &&other) noexcept \
caches( \ : desc(other.desc) { \
__SIZE__, \ other.desc = nullptr; \
[](std::shared_ptr<__DESC_TYPE__> &desc) { \ } \
desc = nullptr; \ \
__DESC_TYPE__ &operator=(__DESC_TYPE__ &&other) noexcept { \
if (this != &other) { \
if (desc != nullptr) { \
infiniopDestroy##__OP_NAME__##Descriptor(desc); \
} \
desc = other.desc; \
other.desc = nullptr; \
} \
return *this; \
} \
\
~__DESC_TYPE__() { \
if (desc != nullptr) { \
infiniopDestroy##__OP_NAME__##Descriptor(desc); \
} \
} \
}; \
\
thread_local common::OpCache<size_t, std::shared_ptr<__DESC_TYPE__>> \
caches( \
__SIZE__, \
[](std::shared_ptr<__DESC_TYPE__> &desc) { \
desc = nullptr; \
}); });
#define INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(__DESC_TYPE__, __DESC_NAME__, __INFINIOP_NAME__, __HASH_KEY__, ...) \ #define INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(__DESC_TYPE__, __DESC_NAME__, __INFINIOP_NAME__, __HASH_KEY__, ...) \
......
#include "infinicore/ops/mul.hpp" #include "infinicore/ops/mul.hpp"
#include "../../utils.hpp" #include "../../utils.hpp"
namespace infinicore::op { namespace infinicore::op {
common::OpDispatcher<Mul::schema> &Mul::dispatcher() { INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(Mul);
static common::OpDispatcher<Mul::schema> dispatcher_;
return dispatcher_;
};
void Mul::execute(Tensor c, Tensor a, Tensor b) { Mul::Mul(Tensor c, const Tensor &a, const Tensor &b) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b); INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b);
infinicore::context::setDevice(c->device()); INFINICORE_GRAPH_OP_DISPATCH(c->device().getType(), c, a, b);
dispatcher().lookup(c->device().getType())(c, a, b); }
void Mul::execute(Tensor c, const Tensor &a, const Tensor &b) {
INFINICORE_GRAPH_OP_RECORD_OR_RUN(Mul, c, a, b);
} }
Tensor mul(Tensor a, Tensor b) { Tensor mul(const Tensor &a, const Tensor &b) {
auto c = Tensor::empty(a->shape(), a->dtype(), a->device()); auto c = Tensor::empty(a->shape(), a->dtype(), a->device());
mul_(c, a, b); mul_(c, a, b);
return c; return c;
} }
void mul_(Tensor c, Tensor a, Tensor b) { void mul_(Tensor c, const Tensor &a, const Tensor &b) {
Mul::execute(c, a, b); Mul::execute(c, a, b);
} }
......
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "infinicore/ops/mul.hpp" #include "infinicore/ops/mul.hpp"
#include <infiniop.h>
#include "../infiniop_impl.hpp"
namespace infinicore::op::mul_impl::infiniop { namespace infinicore::op::mul_impl::infiniop {
thread_local common::OpCache<size_t, infiniopMulDescriptor_t> caches( INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, Mul, 100);
100, // capacity
[](infiniopMulDescriptor_t &desc) { struct PlannedMeta {
if (desc != nullptr) { std::shared_ptr<Descriptor> descriptor;
INFINICORE_CHECK_ERROR(infiniopDestroyMulDescriptor(desc)); graph::GraphTensor workspace, c, a, b;
desc = nullptr; };
}
});
void calculate(Tensor c, Tensor a, Tensor b) { void *plan(Tensor c, const Tensor &a, const Tensor &b) {
size_t seed = hash_combine(c, b, a); size_t seed = hash_combine(c, b, a);
auto device = context::getDevice(); INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
auto &cache = caches.getCache(device); Descriptor, descriptor, Mul,
seed, c->desc(), a->desc(), b->desc());
auto desc_opt = cache.get(seed); INFINIOP_WORKSPACE_TENSOR(workspace, Mul, descriptor);
infiniopMulDescriptor_t desc = nullptr;
if (!desc_opt) { return new PlannedMeta{
INFINICORE_CHECK_ERROR(infiniopCreateMulDescriptor( descriptor,
context::getInfiniopHandle(device), &desc, graph::GraphTensor(workspace),
c->desc(), a->desc(), b->desc())); graph::GraphTensor(c),
cache.put(seed, desc); graph::GraphTensor(a),
} else { graph::GraphTensor(b)};
desc = *desc_opt; }
}
size_t workspace_size = 0; void run(void *planned_meta) {
INFINICORE_CHECK_ERROR(infiniopGetMulWorkspaceSize(desc, &workspace_size)); auto planned = reinterpret_cast<PlannedMeta *>(planned_meta);
std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);
INFINICORE_CHECK_ERROR(infiniopMul( INFINICORE_CHECK_ERROR(infiniopMul(
desc, workspace->data(), workspace_size, planned->descriptor->desc,
c->data(), a->data(), b->data(), context::getStream())); planned->workspace->data(),
planned->workspace->numel(),
planned->c->data(),
planned->a->data(),
planned->b->data(),
context::getStream()));
}
void cleanup(void **planned_meta_ptr) {
delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
*planned_meta_ptr = nullptr;
} }
static bool registered = []() { INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(Mul, &plan, &run, &cleanup);
Mul::dispatcher().registerAll(&calculate, false);
return true;
}();
} // namespace infinicore::op::mul_impl::infiniop } // namespace infinicore::op::mul_impl::infiniop
#include "infinicore/ops/paged_attention.hpp" #include "infinicore/ops/paged_attention.hpp"
#include "../../utils.hpp" #include "../../utils.hpp"
namespace infinicore::op { namespace infinicore::op {
common::OpDispatcher<PagedAttention::schema> &PagedAttention::dispatcher() { INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(PagedAttention);
static common::OpDispatcher<PagedAttention::schema> dispatcher_;
return dispatcher_;
};
void PagedAttention::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor kv_lens, std::optional<Tensor> alibi_slopes, float scale) { PagedAttention::PagedAttention(Tensor out, const Tensor &q, const Tensor &k_cache, const Tensor &v_cache,
const Tensor &block_tables, const Tensor &kv_lens,
std::optional<Tensor> alibi_slopes, float scale) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, kv_lens); INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, kv_lens);
infinicore::context::setDevice(out->device()); INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(),
dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables, kv_lens, alibi_slopes, scale); out, q, k_cache, v_cache, block_tables, kv_lens, alibi_slopes, scale);
}
void PagedAttention::execute(Tensor out, const Tensor &q, const Tensor &k_cache, const Tensor &v_cache,
const Tensor &block_tables, const Tensor &kv_lens,
std::optional<Tensor> alibi_slopes, float scale) {
INFINICORE_GRAPH_OP_RECORD_OR_RUN(
PagedAttention,
out, q, k_cache, v_cache, block_tables, kv_lens, alibi_slopes, scale);
} }
Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor kv_lens, std::optional<Tensor> alibi_slopes, float scale) { Tensor paged_attention(const Tensor &q, const Tensor &k_cache, const Tensor &v_cache,
const Tensor &block_tables, const Tensor &kv_lens,
std::optional<Tensor> alibi_slopes, float scale) {
auto out = Tensor::empty(q->shape(), q->dtype(), q->device()); auto out = Tensor::empty(q->shape(), q->dtype(), q->device());
paged_attention_(out, q, k_cache, v_cache, block_tables, kv_lens, alibi_slopes, scale); paged_attention_(out, q, k_cache, v_cache, block_tables, kv_lens, alibi_slopes, scale);
return out; return out;
} }
void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor kv_lens, std::optional<Tensor> alibi_slopes, float scale) { void paged_attention_(Tensor out, const Tensor &q, const Tensor &k_cache, const Tensor &v_cache,
const Tensor &block_tables, const Tensor &kv_lens,
std::optional<Tensor> alibi_slopes, float scale) {
PagedAttention::execute(out, q, k_cache, v_cache, block_tables, kv_lens, alibi_slopes, scale); PagedAttention::execute(out, q, k_cache, v_cache, block_tables, kv_lens, alibi_slopes, scale);
} }
......
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "infinicore/ops/paged_attention.hpp" #include "infinicore/ops/paged_attention.hpp"
#include <infiniop.h>
#include "../infiniop_impl.hpp"
namespace infinicore::op::paged_attention_impl::infiniop { namespace infinicore::op::paged_attention_impl::infiniop {
thread_local common::OpCache<size_t, infiniopPagedAttentionDescriptor_t> caches( INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, PagedAttention, 100);
100, // capacity
[](infiniopPagedAttentionDescriptor_t &desc) { struct PlannedMeta {
if (desc != nullptr) { std::shared_ptr<Descriptor> descriptor;
INFINICORE_CHECK_ERROR(infiniopDestroyPagedAttentionDescriptor(desc)); graph::GraphTensor workspace, out, q, k_cache, v_cache, block_tables, cache_lens;
desc = nullptr; std::optional<graph::GraphTensor> alibi_slopes;
} float scale;
}); };
void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor kv_lens, std::optional<Tensor> alibi_slopes, float scale) { void *plan(Tensor out, const Tensor &q, const Tensor &k_cache, const Tensor &v_cache,
size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, kv_lens, alibi_slopes, scale); const Tensor &block_tables, const Tensor &cache_lens,
std::optional<Tensor> alibi_slopes, float scale) {
auto device = context::getDevice(); size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes);
auto &cache = caches.getCache(device); INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
Descriptor, descriptor, PagedAttention,
auto desc_opt = cache.get(seed); seed,
infiniopPagedAttentionDescriptor_t desc = nullptr; out->desc(), q->desc(), k_cache->desc(), v_cache->desc(),
block_tables->desc(), cache_lens->desc(),
if (!desc_opt) { alibi_slopes ? alibi_slopes.value()->desc() : nullptr,
INFINICORE_CHECK_ERROR(infiniopCreatePagedAttentionDescriptor( scale);
context::getInfiniopHandle(device), &desc,
out->desc(), q->desc(), k_cache->desc(), v_cache->desc(), block_tables->desc(), kv_lens->desc(), INFINIOP_WORKSPACE_TENSOR(workspace, PagedAttention, descriptor);
alibi_slopes.has_value() ? alibi_slopes.value()->desc() : nullptr,
scale)); return new PlannedMeta{
cache.put(seed, desc); descriptor,
} else { graph::GraphTensor(workspace),
desc = *desc_opt; graph::GraphTensor(out),
} graph::GraphTensor(q),
graph::GraphTensor(k_cache),
size_t workspace_size = 0; graph::GraphTensor(v_cache),
INFINICORE_CHECK_ERROR(infiniopGetPagedAttentionWorkspaceSize(desc, &workspace_size)); graph::GraphTensor(block_tables),
std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size); graph::GraphTensor(cache_lens),
alibi_slopes ? std::optional<graph::GraphTensor>(graph::GraphTensor(*alibi_slopes)) : std::nullopt,
INFINICORE_CHECK_ERROR(infiniopPagedAttention( scale};
desc, workspace->data(), workspace_size, }
out->data(), q->data(), k_cache->data(), v_cache->data(), block_tables->data(), kv_lens->data(),
alibi_slopes.has_value() ? alibi_slopes.value()->data() : nullptr, void run(void *planned_meta) {
context::getStream())); auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);
INFINICORE_CHECK_ERROR(
infiniopPagedAttention(
p->descriptor->desc,
p->workspace->data(),
p->workspace->numel(),
p->out->data(),
p->q->data(),
p->k_cache->data(),
p->v_cache->data(),
p->block_tables->data(),
p->cache_lens->data(),
p->alibi_slopes.has_value() ? p->alibi_slopes.value()->data() : nullptr,
context::getStream()));
}
void cleanup(void **planned_meta_ptr) {
delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
*planned_meta_ptr = nullptr;
} }
static bool registered = []() { INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(PagedAttention, &plan, &run, &cleanup);
PagedAttention::dispatcher().registerAll(&calculate, false);
return true;
}();
} // namespace infinicore::op::paged_attention_impl::infiniop } // namespace infinicore::op::paged_attention_impl::infiniop
#include "infinicore/ops/paged_caching.hpp" #include "infinicore/ops/paged_caching.hpp"
#include "../../utils.hpp" #include "../../utils.hpp"
namespace infinicore::op { namespace infinicore::op {
common::OpDispatcher<PagedCaching::schema> &PagedCaching::dispatcher() { INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(PagedCaching);
static common::OpDispatcher<PagedCaching::schema> dispatcher_;
return dispatcher_;
};
void PagedCaching::execute(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping) { PagedCaching::PagedCaching(Tensor k_cache, Tensor v_cache, const Tensor &k, const Tensor &v, const Tensor &slot_mapping) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(k_cache, v_cache, k, v, slot_mapping); INFINICORE_ASSERT_TENSORS_SAME_DEVICE(k_cache, v_cache, k, v, slot_mapping);
infinicore::context::setDevice(k_cache->device()); INFINICORE_GRAPH_OP_DISPATCH(k->device().getType(), k_cache, v_cache, k, v, slot_mapping);
dispatcher().lookup(k_cache->device().getType())(k_cache, v_cache, k, v, slot_mapping); }
void PagedCaching::execute(Tensor k_cache, Tensor v_cache, const Tensor &k, const Tensor &v, const Tensor &slot_mapping) {
INFINICORE_GRAPH_OP_RECORD_OR_RUN(PagedCaching, k_cache, v_cache, k, v, slot_mapping);
} }
void paged_caching_(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping) { void paged_caching_(Tensor k_cache, Tensor v_cache, const Tensor &k, const Tensor &v, const Tensor &slot_mapping) {
PagedCaching::execute(k_cache, v_cache, k, v, slot_mapping); PagedCaching::execute(k_cache, v_cache, k, v, slot_mapping);
} }
......
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "infinicore/ops/paged_caching.hpp" #include "infinicore/ops/paged_caching.hpp"
#include <infiniop.h>
#include "../infiniop_impl.hpp"
namespace infinicore::op::paged_caching_impl::infiniop { namespace infinicore::op::paged_caching_impl::infiniop {
thread_local common::OpCache<size_t, infiniopPagedCachingDescriptor_t> caches( INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, PagedCaching, 100);
100, // capacity
[](infiniopPagedCachingDescriptor_t &desc) { struct PlannedMeta {
if (desc != nullptr) { std::shared_ptr<Descriptor> descriptor;
INFINICORE_CHECK_ERROR(infiniopDestroyPagedCachingDescriptor(desc));
desc = nullptr; graph::GraphTensor workspace, k_cache, v_cache, k, v, slot_mapping;
} };
});
void *plan(Tensor k_cache, Tensor v_cache, const Tensor &k, const Tensor &v, const Tensor &slot_mapping) {
void calculate(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping) { size_t key = hash_combine(k_cache, v_cache, k, v, slot_mapping);
size_t seed = hash_combine(k_cache, v_cache, k, v, slot_mapping);
INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
auto device = context::getDevice(); Descriptor, descriptor, PagedCaching,
auto &cache = caches.getCache(device); key, k_cache->desc(), v_cache->desc(), k->desc(), v->desc(), slot_mapping->desc());
auto desc_opt = cache.get(seed); INFINIOP_WORKSPACE_TENSOR(workspace, PagedCaching, descriptor);
infiniopPagedCachingDescriptor_t desc = nullptr;
return new PlannedMeta{
if (!desc_opt) { descriptor,
INFINICORE_CHECK_ERROR(infiniopCreatePagedCachingDescriptor( graph::GraphTensor(workspace),
context::getInfiniopHandle(device), &desc, graph::GraphTensor(k_cache),
k_cache->desc(), v_cache->desc(), k->desc(), v->desc(), slot_mapping->desc())); graph::GraphTensor(v_cache),
cache.put(seed, desc); graph::GraphTensor(k),
} else { graph::GraphTensor(v),
desc = *desc_opt; graph::GraphTensor(slot_mapping)};
} }
size_t workspace_size = 0; void run(void *planned_meta) {
INFINICORE_CHECK_ERROR(infiniopGetPagedCachingWorkspaceSize(desc, &workspace_size)); auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);
std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);
INFINICORE_CHECK_ERROR(
INFINICORE_CHECK_ERROR(infiniopPagedCaching( infiniopPagedCaching(
desc, workspace->data(), workspace_size, p->descriptor->desc,
k_cache->data(), v_cache->data(), k->data(), v->data(), slot_mapping->data(), context::getStream())); p->workspace->data(),
p->workspace->numel(),
p->k_cache->data(),
p->v_cache->data(),
p->k->data(),
p->v->data(),
p->slot_mapping->data(),
context::getStream()));
}
void cleanup(void **planned_meta_ptr) {
delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
*planned_meta_ptr = nullptr;
} }
static bool registered = []() { INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(PagedCaching, &plan, &run, &cleanup);
PagedCaching::dispatcher().registerAll(&calculate, false);
return true;
}();
} // namespace infinicore::op::paged_caching_impl::infiniop } // namespace infinicore::op::paged_caching_impl::infiniop
...@@ -3,24 +3,30 @@ ...@@ -3,24 +3,30 @@
namespace infinicore::op { namespace infinicore::op {
common::OpDispatcher<Rearrange::schema> &Rearrange::dispatcher() { INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(Rearrange);
static common::OpDispatcher<Rearrange::schema> dispatcher_;
return dispatcher_;
};
void Rearrange::execute(Tensor y, Tensor x) { Rearrange::Rearrange(Tensor y, const Tensor &x) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(y, x); INFINICORE_ASSERT_TENSORS_SAME_DEVICE(y, x);
infinicore::context::setDevice(y->device()); INFINICORE_GRAPH_OP_DISPATCH(y->device().getType(), y, x);
dispatcher().lookup(y->device().getType())(y, x);
} }
Tensor rearrange(Tensor x) { void Rearrange::execute(Tensor y, const Tensor &x) {
auto op = std::make_shared<Rearrange>(y, x);
if (context::isGraphRecording()) {
context::addGraphOperator(op);
} else {
op->run();
}
}
Tensor rearrange(const Tensor &x) {
auto y = Tensor::empty(x->shape(), x->dtype(), x->device()); auto y = Tensor::empty(x->shape(), x->dtype(), x->device());
rearrange_(y, x); rearrange_(y, x);
return y; return y;
} }
void rearrange_(Tensor y, Tensor x) { void rearrange_(Tensor y, const Tensor &x) {
Rearrange::execute(y, x); Rearrange::execute(y, x);
} }
} // namespace infinicore::op } // namespace infinicore::op
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "infinicore/ops/rearrange.hpp" #include "infinicore/ops/rearrange.hpp"
#include <infiniop.h>
#include "../infiniop_impl.hpp"
namespace infinicore::op::rearrange_impl::infiniop { namespace infinicore::op::rearrange_impl::infiniop {
thread_local common::OpCache<size_t, infiniopRearrangeDescriptor_t> caches( INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, Rearrange, 100);
100, // capacity
[](infiniopRearrangeDescriptor_t &desc) { struct PlannedMeta {
if (desc != nullptr) { std::shared_ptr<Descriptor> descriptor;
INFINICORE_CHECK_ERROR(infiniopDestroyRearrangeDescriptor(desc)); graph::GraphTensor y, x;
desc = nullptr; };
}
});
void calculate(Tensor y, Tensor x) { void *plan(Tensor y, const Tensor &x) {
size_t seed = hash_combine(y, x); size_t seed = hash_combine(y, x);
auto device = context::getDevice(); INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
auto &cache = caches.getCache(device); Descriptor, descriptor, Rearrange,
seed, y->desc(),
x->desc());
auto desc_opt = cache.get(seed); return new PlannedMeta{
infiniopRearrangeDescriptor_t desc = nullptr; descriptor,
graph::GraphTensor(y),
graph::GraphTensor(x)};
}
if (!desc_opt) { void run(void *planned_meta) {
INFINICORE_CHECK_ERROR(infiniopCreateRearrangeDescriptor(context::getInfiniopHandle(device), &desc, y->desc(), x->desc())); auto planned = reinterpret_cast<PlannedMeta *>(planned_meta);
cache.put(seed, desc);
} else {
desc = *desc_opt;
}
INFINICORE_CHECK_ERROR( INFINICORE_CHECK_ERROR(
infiniopRearrange( infiniopRearrange(
desc, planned->descriptor->desc,
y->data(), planned->y->data(),
x->data(), planned->x->data(),
context::getStream())); context::getStream()));
} }
static bool registered = []() { void cleanup(void **planned_meta_ptr) {
Rearrange::dispatcher().registerAll(&calculate, false); delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
return true; *planned_meta_ptr = nullptr;
}(); }
INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(Rearrange, &plan, &run, &cleanup);
} // namespace infinicore::op::rearrange_impl::infiniop } // namespace infinicore::op::rearrange_impl::infiniop
#include "infinicore/ops/rms_norm.hpp" #include "infinicore/ops/rms_norm.hpp"
#include "../../utils.hpp" #include "../../utils.hpp"
namespace infinicore::op { namespace infinicore::op {
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(RMSNorm);
common::OpDispatcher<RMSNorm::schema> &RMSNorm::dispatcher() { RMSNorm::RMSNorm(Tensor y, const Tensor &x, const Tensor &weight, float epsilon) {
static common::OpDispatcher<RMSNorm::schema> dispatcher_;
return dispatcher_;
};
void RMSNorm::execute(Tensor y, Tensor x, Tensor weight, float epsilon) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(y, x, weight); INFINICORE_ASSERT_TENSORS_SAME_DEVICE(y, x, weight);
infinicore::context::setDevice(y->device()); INFINICORE_GRAPH_OP_DISPATCH(y->device().getType(), y, x, weight, epsilon);
dispatcher().lookup(y->device().getType())(y, x, weight, epsilon); }
void RMSNorm::execute(Tensor y, const Tensor &x, const Tensor &weight, float epsilon) {
INFINICORE_GRAPH_OP_RECORD_OR_RUN(RMSNorm, y, x, weight, epsilon);
} }
Tensor rms_norm(Tensor x, Tensor weight, float epsilon) { Tensor rms_norm(const Tensor &x, const Tensor &weight, float epsilon) {
auto y = Tensor::empty(x->shape(), x->dtype(), x->device()); auto y = Tensor::empty(x->shape(), x->dtype(), x->device());
rms_norm_(y, x, weight, epsilon); rms_norm_(y, x, weight, epsilon);
return y; return y;
} }
void rms_norm_(Tensor y, Tensor x, Tensor weight, float epsilon) { void rms_norm_(Tensor y, const Tensor &x, const Tensor &weight, float epsilon) {
RMSNorm::execute(y, x, weight, epsilon); RMSNorm::execute(y, x, weight, epsilon);
} }
......
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "infinicore/ops/rms_norm.hpp" #include "infinicore/ops/rms_norm.hpp"
#include <infiniop.h>
namespace infinicore::op::rms_norm_impl::infiniop { #include "../infiniop_impl.hpp"
thread_local common::OpCache<size_t, infiniopRMSNormDescriptor_t> caches( namespace infinicore::op::rms_norm_impl::infiniop {
100, // capacity
[](infiniopRMSNormDescriptor_t &desc) {
if (desc != nullptr) {
INFINICORE_CHECK_ERROR(infiniopDestroyRMSNormDescriptor(desc));
desc = nullptr;
}
});
void calculate(Tensor y, Tensor x, Tensor weight, float epsilon) { INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, RMSNorm, 100);
size_t seed = hash_combine(y, x, weight, epsilon);
auto device = context::getDevice(); struct PlannedMeta {
auto &cache = caches.getCache(device); std::shared_ptr<Descriptor> descriptor;
graph::GraphTensor workspace, y, x, weight;
};
auto desc_opt = cache.get(seed); void *plan(Tensor y, const Tensor &x, const Tensor &weight, float epsilon) {
infiniopRMSNormDescriptor_t desc = nullptr; size_t seed = hash_combine(y, x, weight, epsilon);
if (!desc_opt) { INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
INFINICORE_CHECK_ERROR(infiniopCreateRMSNormDescriptor( Descriptor, descriptor, RMSNorm,
context::getInfiniopHandle(device), &desc, seed, y->desc(),
y->desc(), x->desc(), weight->desc(), epsilon)); x->desc(),
cache.put(seed, desc); weight->desc(),
} else { epsilon);
desc = *desc_opt;
} INFINIOP_WORKSPACE_TENSOR(workspace, RMSNorm, descriptor);
return new PlannedMeta{
descriptor,
graph::GraphTensor(workspace),
graph::GraphTensor(y),
graph::GraphTensor(x),
graph::GraphTensor(weight)};
}
size_t workspace_size = 0; void run(void *planned_meta) {
INFINICORE_CHECK_ERROR(infiniopGetRMSNormWorkspaceSize(desc, &workspace_size)); auto planned = reinterpret_cast<PlannedMeta *>(planned_meta);
std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);
INFINICORE_CHECK_ERROR(
infiniopRMSNorm(
planned->descriptor->desc,
planned->workspace->data(),
planned->workspace->numel(),
planned->y->data(),
planned->x->data(),
planned->weight->data(),
context::getStream()));
}
INFINICORE_CHECK_ERROR(infiniopRMSNorm( void cleanup(void **planned_meta_ptr) {
desc, workspace->data(), workspace_size, delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
y->data(), x->data(), weight->data(), context::getStream())); *planned_meta_ptr = nullptr;
} }
static bool registered = []() { INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(RMSNorm, &plan, &run, &cleanup);
RMSNorm::dispatcher().registerAll(&calculate, false);
return true;
}();
} // namespace infinicore::op::rms_norm_impl::infiniop } // namespace infinicore::op::rms_norm_impl::infiniop
#include "infinicore/ops/rope.hpp" #include "infinicore/ops/rope.hpp"
#include "../../utils.hpp" #include "../../utils.hpp"
#include "infinicore/context/context.hpp"
#include <stdexcept>
namespace infinicore::op { namespace infinicore::op {
common::OpDispatcher<RoPE::schema> &RoPE::dispatcher() { INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(RoPE);
static common::OpDispatcher<RoPE::schema> dispatcher_;
return dispatcher_;
};
void RoPE::execute(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_table, const Tensor &cos_table, infinicore::nn::RoPE::Algo algo) { RoPE::RoPE(Tensor x_out,
const Tensor &x,
const Tensor &pos,
const Tensor &sin_table,
const Tensor &cos_table,
infinicore::nn::RoPE::Algo algo) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(x_out, x, pos, sin_table, cos_table); INFINICORE_ASSERT_TENSORS_SAME_DEVICE(x_out, x, pos, sin_table, cos_table);
infinicore::context::setDevice(x_out->device()); INFINICORE_GRAPH_OP_DISPATCH(x_out->device().getType(), x_out, x, pos, sin_table, cos_table, algo);
auto device_type = x_out->device().getType(); }
auto func = dispatcher().lookup(device_type);
if (func == nullptr) {
throw std::runtime_error("No RoPE implementation found for device type: " + std::to_string(static_cast<int>(device_type)));
}
func(x_out, x, pos, sin_table, cos_table, algo); void RoPE::execute(Tensor x_out,
const Tensor &x,
const Tensor &pos,
const Tensor &sin_table,
const Tensor &cos_table,
infinicore::nn::RoPE::Algo algo) {
INFINICORE_GRAPH_OP_RECORD_OR_RUN(RoPE, x_out, x, pos, sin_table, cos_table, algo);
} }
void rope_(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_table, const Tensor &cos_table, infinicore::nn::RoPE::Algo algo) { void rope_(Tensor x_out,
const Tensor &x,
const Tensor &pos,
const Tensor &sin_table,
const Tensor &cos_table,
infinicore::nn::RoPE::Algo algo) {
RoPE::execute(x_out, x, pos, sin_table, cos_table, algo); RoPE::execute(x_out, x, pos, sin_table, cos_table, algo);
} }
Tensor rope(const Tensor &x, const Tensor &pos, const Tensor &sin_table, const Tensor &cos_table, infinicore::nn::RoPE::Algo algo) { Tensor rope(const Tensor &x,
Shape shape = x->shape(); const Tensor &pos,
auto x_out = Tensor::empty(shape, x->dtype(), x->device()); const Tensor &sin_table,
const Tensor &cos_table,
infinicore::nn::RoPE::Algo algo) {
auto x_out = Tensor::empty(x->shape(), x->dtype(), x->device());
rope_(x_out, x, pos, sin_table, cos_table, algo); rope_(x_out, x, pos, sin_table, cos_table, algo);
return x_out; return x_out;
} }
......
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "infinicore/ops/rope.hpp" #include "infinicore/ops/rope.hpp"
#include <infiniop.h>
#include "../infiniop_impl.hpp"
namespace infinicore::op::rope_impl::infiniop { namespace infinicore::op::rope_impl::infiniop {
thread_local common::OpCache<size_t, infiniopRoPEDescriptor_t> caches( INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, RoPE, 100);
100, // capacity
[](infiniopRoPEDescriptor_t &desc) { struct PlannedMeta {
if (desc != nullptr) { std::shared_ptr<Descriptor> descriptor;
INFINICORE_CHECK_ERROR(infiniopDestroyRoPEDescriptor(desc)); graph::GraphTensor workspace;
desc = nullptr; graph::GraphTensor x_out;
} graph::GraphTensor x;
}); graph::GraphTensor pos;
graph::GraphTensor sin;
graph::GraphTensor cos;
};
void calculate(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_cache, const Tensor &cos_cache, infinicore::nn::RoPE::Algo algo) { static infiniopRoPEAlgo_t to_infiniop_algo(infinicore::nn::RoPE::Algo algo) {
// Convert infinicore::nn::RoPE::Algo to infiniopRoPEAlgo_t
infiniopRoPEAlgo_t infiniop_algo;
switch (algo) { switch (algo) {
case infinicore::nn::RoPE::Algo::GPT_J: case infinicore::nn::RoPE::Algo::GPT_J:
infiniop_algo = INFINIOP_ROPE_ALGO_GPT_J; return INFINIOP_ROPE_ALGO_GPT_J;
break;
case infinicore::nn::RoPE::Algo::GPT_NEOX: case infinicore::nn::RoPE::Algo::GPT_NEOX:
infiniop_algo = INFINIOP_ROPE_ALGO_GPT_NEOX; return INFINIOP_ROPE_ALGO_GPT_NEOX;
break;
default: default:
throw std::runtime_error("Unsupported RoPE algorithm: " + std::to_string(static_cast<int>(algo))); throw std::runtime_error("Unsupported RoPE algorithm");
} }
}
// Create hash key for descriptor caching void *plan(Tensor x_out,
size_t key = hash_combine(x_out, x, pos, sin_cache, cos_cache); const Tensor &x,
hash_combine(key, std::hash<int>()(static_cast<int>(infiniop_algo))); const Tensor &pos,
const Tensor &sin,
const Tensor &cos,
infinicore::nn::RoPE::Algo algo) {
auto infiniop_algo = to_infiniop_algo(algo);
size_t key = hash_combine(x_out, x, pos, sin, cos, static_cast<int>(infiniop_algo));
auto device = context::getDevice(); INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
auto &cache = caches.getCache(device); Descriptor, descriptor, RoPE, key, x_out->desc(),
x->desc(),
pos->desc(),
sin->desc(),
cos->desc(),
infiniop_algo);
auto desc_opt = cache.get(key); INFINIOP_WORKSPACE_TENSOR(workspace, RoPE, descriptor);
infiniopRoPEDescriptor_t desc = nullptr; return new PlannedMeta{
descriptor,
graph::GraphTensor(workspace),
graph::GraphTensor(x_out),
graph::GraphTensor(x),
graph::GraphTensor(pos),
graph::GraphTensor(sin),
graph::GraphTensor(cos)};
}
if (!desc_opt) { void run(void *planned_meta) {
INFINICORE_CHECK_ERROR(infiniopCreateRoPEDescriptor( auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);
context::getInfiniopHandle(device), &desc,
x_out->desc(), x->desc(),
pos->desc(), sin_cache->desc(), cos_cache->desc(),
infiniop_algo));
cache.put(key, desc);
} else {
desc = *desc_opt;
}
size_t workspace_size = 0; INFINICORE_CHECK_ERROR(
INFINICORE_CHECK_ERROR(infiniopGetRoPEWorkspaceSize(desc, &workspace_size)); infiniopRoPE(
std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size); p->descriptor->desc,
p->workspace->data(),
p->workspace->numel(),
p->x_out->data(),
p->x->data(),
p->pos->data(),
p->sin->data(),
p->cos->data(),
context::getStream()));
}
// InfiniOP reads from x and writes to x_out (handles copying internally) void cleanup(void **planned_meta_ptr) {
INFINICORE_CHECK_ERROR(infiniopRoPE( delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
desc, workspace->data(), workspace_size, *planned_meta_ptr = nullptr;
x_out->data(), x->data(), pos->data(),
sin_cache->data(), cos_cache->data(), context::getStream()));
} }
static bool registered = []() { INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(RoPE, &plan, &run, &cleanup);
RoPE::dispatcher().registerAll(&calculate, false);
return true;
}();
} // namespace infinicore::op::rope_impl::infiniop } // namespace infinicore::op::rope_impl::infiniop
#include "infinicore/ops/swiglu.hpp" #include "infinicore/ops/swiglu.hpp"
#include "../../utils.hpp" #include "../../utils.hpp"
#include <stdexcept>
namespace infinicore::op { namespace infinicore::op {
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(SwiGLU);
common::OpDispatcher<SwiGLU::schema> &SwiGLU::dispatcher() { SwiGLU::SwiGLU(Tensor c, const Tensor &a, const Tensor &b) {
static common::OpDispatcher<SwiGLU::schema> dispatcher_;
return dispatcher_;
};
void SwiGLU::execute(Tensor c, Tensor a, Tensor b) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b); INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b);
infinicore::context::setDevice(c->device()); INFINICORE_GRAPH_OP_DISPATCH(c->device().getType(), c, a, b);
auto device_type = c->device().getType(); }
auto func = dispatcher().lookup(device_type);
if (func == nullptr) {
throw std::runtime_error("No SwiGLU implementation found for device type: " + std::to_string(static_cast<int>(device_type)));
}
func(c, a, b); void SwiGLU::execute(Tensor c, const Tensor &a, const Tensor &b) {
INFINICORE_GRAPH_OP_RECORD_OR_RUN(SwiGLU, c, a, b);
} }
Tensor swiglu(Tensor a, Tensor b) { Tensor swiglu(const Tensor &a, const Tensor &b) {
Shape shape = a->shape(); auto c = Tensor::empty(a->shape(), a->dtype(), a->device());
auto c = Tensor::empty(shape, a->dtype(), a->device());
swiglu_(c, a, b); swiglu_(c, a, b);
return c; return c;
} }
void swiglu_(Tensor c, Tensor a, Tensor b) { void swiglu_(Tensor c, const Tensor &a, const Tensor &b) {
SwiGLU::execute(c, a, b); SwiGLU::execute(c, a, b);
} }
} // namespace infinicore::op } // namespace infinicore::op
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "infinicore/ops/swiglu.hpp" #include "infinicore/ops/swiglu.hpp"
#include <infiniop.h>
#include "../infiniop_impl.hpp"
namespace infinicore::op::swiglu_impl::infiniop { namespace infinicore::op::swiglu_impl::infiniop {
thread_local common::OpCache<size_t, infiniopSwiGLUDescriptor_t> caches( INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, SwiGLU, 100);
100, // capacity
[](infiniopSwiGLUDescriptor_t &desc) { struct PlannedMeta {
if (desc != nullptr) { std::shared_ptr<Descriptor> descriptor;
INFINICORE_CHECK_ERROR(infiniopDestroySwiGLUDescriptor(desc)); graph::GraphTensor workspace;
desc = nullptr; graph::GraphTensor c;
} graph::GraphTensor a;
}); graph::GraphTensor b;
};
void calculate(Tensor c, Tensor a, Tensor b) {
size_t seed = hash_combine(c, b, a); void *plan(Tensor c, const Tensor &a, const Tensor &b) {
size_t key = hash_combine(c, a, b);
auto device = context::getDevice();
auto &cache = caches.getCache(device); INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
Descriptor, descriptor, SwiGLU,
auto desc_opt = cache.get(seed); key, c->desc(), a->desc(), b->desc());
infiniopSwiGLUDescriptor_t desc = nullptr;
INFINIOP_WORKSPACE_TENSOR(workspace, SwiGLU, descriptor);
if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateSwiGLUDescriptor( return new PlannedMeta{
context::getInfiniopHandle(device), &desc, descriptor,
c->desc(), a->desc(), b->desc())); graph::GraphTensor(workspace),
cache.put(seed, desc); graph::GraphTensor(c),
} else { graph::GraphTensor(a),
desc = *desc_opt; graph::GraphTensor(b)};
} }
size_t workspace_size = 0; void run(void *planned_meta) {
INFINICORE_CHECK_ERROR(infiniopGetSwiGLUWorkspaceSize(desc, &workspace_size)); auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);
std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);
INFINICORE_CHECK_ERROR(
INFINICORE_CHECK_ERROR(infiniopSwiGLU( infiniopSwiGLU(
desc, workspace->data(), workspace_size, p->descriptor->desc,
c->data(), a->data(), b->data(), context::getStream())); p->workspace->data(),
p->workspace->numel(),
p->c->data(),
p->a->data(),
p->b->data(),
context::getStream()));
}
void cleanup(void **planned_meta_ptr) {
delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
*planned_meta_ptr = nullptr;
} }
static bool registered = []() { INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(SwiGLU, &plan, &run, &cleanup);
SwiGLU::dispatcher().registerAll(&calculate, false);
return true;
}();
} // namespace infinicore::op::swiglu_impl::infiniop } // namespace infinicore::op::swiglu_impl::infiniop
...@@ -49,6 +49,7 @@ inline struct SpdlogInitializer { ...@@ -49,6 +49,7 @@ inline struct SpdlogInitializer {
+ ":" + std::to_string(__LINE__) + "."); \ + ":" + std::to_string(__LINE__) + "."); \
} \ } \
} \ } \
infinicore::context::setDevice((FIRST___)->device()); \
} while (0) } while (0)
#define INFINICORE_ASSERT(CONDITION__) \ #define INFINICORE_ASSERT(CONDITION__) \
......
import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
from framework import BaseOperatorTest, GenericTestRunner, TensorSpec, TestCase
from framework.tensor import TensorInitializer
import infinicore
# Test cases format: (nlayers, batch_size, hidden_size, nhead, nkvhead, dim, seqlen, past_seqlen, max_seqlen)
# Single decode-step configuration (seqlen=1) with a pre-filled KV cache of 256.
_TEST_CASES_DATA = [
    (28, 1, 3584, 28, 28, 128, 1, 256, 512),
]

# Per-dtype comparison tolerances against the torch reference implementation.
_TOLERANCE_MAP = {
    infinicore.float16: {"atol": 1e-4, "rtol": 1e-2},
    infinicore.float32: {"atol": 1e-4, "rtol": 1e-3},
    infinicore.bfloat16: {"atol": 1e-4, "rtol": 5e-2},
}

# Floating-point dtypes exercised for every configuration tuple.
_TENSOR_DTYPES = [infinicore.float16, infinicore.float32, infinicore.bfloat16]
def parse_test_cases():
    """Expand every configuration tuple in _TEST_CASES_DATA across all dtypes.

    Each resulting TestCase bundles the activations, position ids, KV caches,
    projection weights, norm weight, and RoPE tables consumed by the graph
    test below, in the exact positional order the operators expect.
    """
    cases = []
    for (
        nlayers,
        batch_size,
        hidden_size,
        nhead,
        nkvhead,
        dim,
        seqlen,
        past_seqlen,
        max_seqlen,
    ) in _TEST_CASES_DATA:
        for dtype in _TENSOR_DTYPES:
            tol = _TOLERANCE_MAP[dtype]

            def fspec(shape):
                # Shared small-magnitude uniform init for all float tensors.
                return TensorSpec.from_tensor(
                    shape, dtype=dtype, scale=1e-1, bias=-5e-2
                )

            kv_cache_shape = (nlayers, batch_size, nkvhead, max_seqlen, dim)
            rope_table_shape = (max_seqlen, dim // 2)
            inputs = [
                fspec((batch_size, seqlen, hidden_size)),  # hidden_states
                TensorSpec.from_tensor(  # pos_ids: random valid positions
                    (batch_size, seqlen),
                    dtype=infinicore.int64,
                    init_mode=TensorInitializer.RANDINT,
                    low=0,
                    high=max_seqlen,
                ),
                nhead,
                nkvhead,
                dim,
                past_seqlen,
                nlayers,
                fspec(kv_cache_shape),  # k_cache
                fspec(kv_cache_shape),  # v_cache
                fspec((nhead * dim, hidden_size)),  # q_proj_w
                fspec((nkvhead * dim, hidden_size)),  # k_proj_w
                fspec((nkvhead * dim, hidden_size)),  # v_proj_w
                fspec((hidden_size, nhead * dim)),  # o_proj_w
                fspec((hidden_size,)),  # norm_w
                fspec(rope_table_shape),  # sin_table
                fspec(rope_table_shape),  # cos_table
            ]
            # Out-of-place
            cases.append(
                TestCase(
                    inputs=inputs,
                    kwargs={},
                    output_spec=None,
                    comparison_target=None,
                    tolerance=tol,
                    description="Graph",
                )
            )
    return cases
def torch_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    sin: torch.Tensor,
    cos: torch.Tensor,
    pos_ids: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply interleaved (GPT-J style) rotary position embedding to q and k.

    Args:
        q, k: [B, H, S, D] tensors with D even.
        sin, cos: [max_S, D//2] rotation tables indexed by position.
        pos_ids: [B, S] integer positions.

    Returns:
        (q_rot, k_rot), each with the same shape as its input.
    """
    head_dim = q.shape[-1]
    assert head_dim % 2 == 0

    def rotate_half(t: torch.Tensor) -> torch.Tensor:
        # Pairwise rotation: (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...)
        even = t[..., 0::2]
        odd = t[..., 1::2]
        return torch.stack((-odd, even), dim=-1).flatten(-2)

    # Gather per-position rotations [B, S, D//2], insert a broadcast head
    # axis [B, 1, S, D//2], then interleave each value twice to reach the
    # full head dimension [B, 1, S, D].
    sin_full = sin[pos_ids].unsqueeze(1).repeat_interleave(2, dim=-1)
    cos_full = cos[pos_ids].unsqueeze(1).repeat_interleave(2, dim=-1)

    rotate = lambda t: (t * cos_full) + (rotate_half(t) * sin_full)
    return rotate(q), rotate(k)
class OpTest(BaseOperatorTest):
    """Test Operator Graph"""

    def __init__(self):
        super().__init__("Graph")

    def get_test_cases(self):
        return parse_test_cases()

    def torch_operator(
        self,
        hidden_states,
        pos_ids,
        nhead,
        nkvhead,
        dim,
        past_seqlen,
        nlayers,
        k_cache,
        v_cache,
        q_proj_w,
        k_proj_w,
        v_proj_w,
        o_proj_w,
        norm_w,
        sin_table,
        cos_table,
        **kwargs,
    ):
        """Eager torch reference: `nlayers` iterations of
        RMSNorm -> QKV projection -> RoPE -> KV-cache update -> causal GQA
        attention -> output projection.

        Note there is no residual add or MLP: this mirrors exactly what the
        infinicore graph in `infinicore_operator` records, not a full
        transformer layer.  All layers share one set of weights; only the
        per-layer slice of k_cache/v_cache differs.
        """
        B, S, D = hidden_states.shape
        for layer in range(nlayers):
            # ---- RMSNorm ----
            var = hidden_states.pow(2).mean(-1, keepdim=True)
            hidden_states = hidden_states * torch.rsqrt(var + 1e-5) * norm_w
            # ---- QKV projection ----
            q = hidden_states @ q_proj_w.T
            k = hidden_states @ k_proj_w.T
            v = hidden_states @ v_proj_w.T
            q = q.view(B, S, nhead, dim).transpose(1, 2)  # [B,H,S,Dh]
            k = k.view(B, S, nkvhead, dim).transpose(1, 2)
            v = v.view(B, S, nkvhead, dim).transpose(1, 2)
            # ---- RoPE ----
            q, k = torch_rope(
                q,
                k,
                sin_table,
                cos_table,
                pos_ids,
            )
            # ---- KV cache update ----
            # New keys/values land at [past_seqlen, past_seqlen + S);
            # attention then reads the whole prefix [0, past_seqlen + S).
            k_cache[layer, :, :, past_seqlen : past_seqlen + S, :] = k
            v_cache[layer, :, :, past_seqlen : past_seqlen + S, :] = v
            K = k_cache[layer, :, :, 0 : past_seqlen + S, :]
            V = v_cache[layer, :, :, 0 : past_seqlen + S, :]

            # ---- Scaled Dot Product Attention (fused) ----
            # Manual SDPA (instead of torch.nn.functional) so the math matches
            # infinicore's matmul + causal_softmax + matmul decomposition.
            def scaled_dot_product_attention(
                query, key, value, is_causal=False, enable_gqa=False
            ) -> torch.Tensor:
                S, L = query.size(-2), key.size(-2)
                scale_factor = query.size(-1) ** -0.5
                attn_bias = torch.zeros(S, L, dtype=query.dtype, device=query.device)
                if is_causal:
                    # Bottom-right-aligned causal mask for an S x L score
                    # matrix: query row i may attend key columns <= L - S + i.
                    mask = torch.tril(attn_bias + 1, diagonal=-1).flip(dims=[-2, -1])
                    attn_bias = torch.where(mask == 1, -torch.inf, 0.0)
                if enable_gqa:
                    # Replicate each KV head across its group of query heads.
                    key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
                    value = value.repeat_interleave(
                        query.size(-3) // value.size(-3), -3
                    )
                attn_weight = query @ key.transpose(-2, -1) * scale_factor
                attn_weight += attn_bias
                attn_weight = torch.softmax(attn_weight, dim=-1)
                return attn_weight @ value

            attn_out = scaled_dot_product_attention(
                q,
                K,
                V,
                is_causal=True,
                enable_gqa=True,
            )  # [B,H,S,Dh]

            # ---- Output projection ----
            attn_out = attn_out.transpose(1, 2).contiguous()
            attn_out = attn_out.view(B, S, nhead * dim)
            hidden_states = attn_out @ o_proj_w.T

        return hidden_states

    def infinicore_operator(
        self,
        hidden_states,
        pos_ids,
        nhead,
        nkvhead,
        dim,
        past_seqlen,
        nlayers,
        k_cache,
        v_cache,
        q_proj_w,
        k_proj_w,
        v_proj_w,
        o_proj_w,
        norm_w,
        sin_table,
        cos_table,
        **kwargs,
    ):
        """Record graph and run"""
        input_hidden_states = hidden_states
        B, S, D = input_hidden_states.shape
        # Everything between start/stop is captured into the op graph and
        # only executed when op_graph.run() is called below.
        infinicore.start_graph_recording()
        for layer in range(nlayers):
            hidden_states = infinicore.nn.functional.rms_norm(
                hidden_states, norm_w.shape, norm_w, 1e-5
            )
            q = infinicore.nn.functional.linear(hidden_states, q_proj_w)
            k = infinicore.nn.functional.linear(hidden_states, k_proj_w)
            v = infinicore.nn.functional.linear(hidden_states, v_proj_w)
            q = q.view((B, S, nhead, dim))
            k = k.view((B, S, nkvhead, dim))
            v = v.view((B, S, nkvhead, dim))
            q = infinicore.nn.functional.rope(
                q,
                pos_ids,
                sin_table,
                cos_table,
                infinicore.nn.functional.RopeAlgo.GPT_J,
            )
            k = infinicore.nn.functional.rope(
                k,
                pos_ids,
                sin_table,
                cos_table,
                infinicore.nn.functional.RopeAlgo.GPT_J,
            )
            # [B, KVH, total_len, D]
            # Views into this layer's cache slice; copy_ writes the new
            # [past_seqlen, past_seqlen + S) rows in place.
            full_k = (
                k_cache.narrow(0, layer, 1).squeeze(0).narrow(2, 0, past_seqlen + S)
            )
            full_v = (
                v_cache.narrow(0, layer, 1).squeeze(0).narrow(2, 0, past_seqlen + S)
            )
            full_k.narrow(2, past_seqlen, S).copy_(k.permute((0, 2, 1, 3)))
            full_v.narrow(2, past_seqlen, S).copy_(v.permute((0, 2, 1, 3)))
            # GQA: fold the G query heads of each KV head into the row
            # dimension so one batched matmul covers the whole group.
            G = nhead // nkvhead
            L = past_seqlen + S
            full_q = (
                q.permute((0, 2, 1, 3)).contiguous().view((B * nkvhead, G * S, dim))
            )
            full_k = full_k.view((B * nkvhead, L, dim))
            full_v = full_v.view((B * nkvhead, L, dim))
            attn_score = infinicore.matmul(
                full_q, full_k.permute((0, 2, 1)), alpha=dim**-0.5
            )
            # [B * H, S, total_len]
            attn_score = attn_score.view((B * nhead, S, L))
            infinicore.nn.functional.causal_softmax(attn_score, out=attn_score)
            attn_out = infinicore.matmul(attn_score, full_v)
            attn_out = (
                attn_out.view((B, nhead, S, dim))
                .permute((0, 2, 1, 3))
                .contiguous()
                .view((B, S, nhead * dim))
            )
            hidden_states = infinicore.nn.functional.linear(attn_out, o_proj_w)
        op_graph = infinicore.stop_graph_recording()
        op_graph.run()
        return hidden_states
def main():
    """Main entry point: run every graph test case and exit with its status."""
    GenericTestRunner(OpTest).run_and_exit()
# Allow running this test module directly as a script.
if __name__ == "__main__":
    main()
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
from framework import BaseOperatorTest, TensorSpec, TestCase, GenericTestRunner
# Test cases format: (in_shape, proj_w_shape)
_TEST_CASES_DATA = [
    ((32, 4096), (4096, 4096)),
]

# Per-dtype comparison tolerances against the torch reference.
_TOLERANCE_MAP = {
    infinicore.float16: {"atol": 0, "rtol": 1e-2},
    infinicore.float32: {"atol": 1e-4, "rtol": 1e-3},
    infinicore.bfloat16: {"atol": 0, "rtol": 5e-2},
}

# Floating-point dtypes exercised for every shape pair.
_TENSOR_DTYPES = [infinicore.float16, infinicore.float32, infinicore.bfloat16]
def parse_test_cases():
    """Build one matmul-graph TestCase per (shape pair, dtype) combination.

    Inputs are [a, b, temp_a]: the operands plus a staging tensor of a's
    shape that the graph records against before a's data is copied in.
    """
    cases = []
    for in_shape, proj_w_shape in _TEST_CASES_DATA:
        for dtype in _TENSOR_DTYPES:
            specs = [
                TensorSpec.from_tensor(in_shape, dtype=dtype),  # a
                TensorSpec.from_tensor(proj_w_shape, dtype=dtype),  # b
                TensorSpec.from_tensor(in_shape, dtype=dtype),  # temp_a
            ]
            # Out-of-place
            cases.append(
                TestCase(
                    inputs=specs,
                    kwargs={},
                    output_spec=None,
                    comparison_target=None,
                    tolerance=_TOLERANCE_MAP[dtype],
                    description="Graph",
                )
            )
    return cases
class OpTest(BaseOperatorTest):
    """Test Operator Graph"""

    def __init__(self):
        super().__init__("Graph")

    def get_test_cases(self):
        return parse_test_cases()

    def torch_operator(self, *args, **kwargs):
        # Reference result: plain eager matmul of the first two inputs.
        lhs, rhs = args[0], args[1]
        return torch.matmul(lhs, rhs)

    def infinicore_operator(self, *args, **kwargs):
        """Record graph and run"""
        lhs = args[0]
        rhs = args[1]
        staging = args[2]
        # Record a matmul against the (still uninitialized) staging tensor.
        infinicore.start_graph_recording()
        result = infinicore.matmul(staging, rhs)
        op_graph = infinicore.stop_graph_recording()
        # Fill the staging tensor only after recording: replay must read the
        # tensor's current storage, not a snapshot taken at record time.
        staging.copy_(lhs)
        op_graph.run()
        return result
def main():
    """Main entry point: execute the matmul graph test suite and exit."""
    GenericTestRunner(OpTest).run_and_exit()
# Allow running this test module directly as a script.
if __name__ == "__main__":
    main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment