Commit 97eced0e authored by wooway777's avatar wooway777
Browse files

issue/923 - ninetoothed kv caching for nv, il, mtx

parent 5614e1be
......@@ -6,6 +6,7 @@
#include "ops/causal_softmax.hpp"
#include "ops/embedding.hpp"
#include "ops/flash_attention.hpp"
#include "ops/kv_caching.hpp"
#include "ops/matmul.hpp"
#include "ops/ones.hpp"
#include "ops/paged_attention.hpp"
......
#pragma once
#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"
namespace infinicore::op {

// Declares the KVCaching graph-op class: two cache tensors that are mutated
// in place, followed by the new key/value tensors and the per-batch
// past-length tensor.
INFINICORE_GRAPH_OP_CLASS(KVCaching, Tensor, Tensor, const Tensor &, const Tensor &, const Tensor &);

// In-place KV cache update: writes `k`/`v` into `k_cache`/`v_cache` at the
// per-batch offsets given by `past_kv_lengths` (the trailing underscore marks
// the in-place convention used across this namespace).
void kv_caching_(Tensor k_cache,
                 Tensor v_cache,
                 const Tensor &k,
                 const Tensor &v,
                 const Tensor &past_kv_lengths);

} // namespace infinicore::op
......@@ -13,6 +13,7 @@
#include "infiniop/ops/flash_attention.h"
#include "infiniop/ops/gelu.h"
#include "infiniop/ops/gemm.h"
#include "infiniop/ops/kv_caching.h"
#include "infiniop/ops/layer_norm.h"
#include "infiniop/ops/logsoftmax.h"
#include "infiniop/ops/lp_norm.h"
......
#ifndef __INFINIOP_KV_CACHING_API_H__
#define __INFINIOP_KV_CACHING_API_H__

#include "../operator_descriptor.h"

// Opaque handle to a KV-caching operator descriptor.
typedef struct InfiniopDescriptor *infiniopKVCachingDescriptor_t;

// Creates a KV-caching descriptor from the tensor descriptors of the caches
// (k_cache/v_cache), the new key/value tensors (k/v), and the per-batch
// past-length tensor (past_kv_lengths).
__C __export infiniStatus_t infiniopCreateKVCachingDescriptor(
    infiniopHandle_t handle,
    infiniopKVCachingDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t k_cache,
    infiniopTensorDescriptor_t v_cache,
    infiniopTensorDescriptor_t k,
    infiniopTensorDescriptor_t v,
    infiniopTensorDescriptor_t past_kv_lengths);

// Queries the scratch-memory requirement of the descriptor, in bytes.
__C __export infiniStatus_t infiniopGetKVCachingWorkspaceSize(infiniopKVCachingDescriptor_t desc, size_t *size);

// Executes the update on `stream`: k/v are written into k_cache/v_cache in
// place, at offsets taken from past_kv_lengths. `workspace` must be at least
// the size reported by infiniopGetKVCachingWorkspaceSize.
__C __export infiniStatus_t infiniopKVCaching(infiniopKVCachingDescriptor_t desc,
                                              void *workspace,
                                              size_t workspace_size,
                                              void *k_cache,
                                              void *v_cache,
                                              const void *k,
                                              const void *v,
                                              const void *past_kv_lengths,
                                              void *stream);

// Releases a descriptor created by infiniopCreateKVCachingDescriptor.
__C __export infiniStatus_t infiniopDestroyKVCachingDescriptor(infiniopKVCachingDescriptor_t desc);

#endif
......@@ -45,6 +45,7 @@ from infinicore.dtype import (
from infinicore.ops.add import add
from infinicore.ops.add_rms_norm import add_rms_norm
from infinicore.ops.attention import attention
from infinicore.ops.kv_caching import kv_caching
from infinicore.ops.matmul import matmul
from infinicore.ops.mul import mul
from infinicore.ops.narrow import narrow
......@@ -115,6 +116,7 @@ __all__ = [
"add_rms_norm",
"add_rms_norm_",
"attention",
"kv_caching",
"matmul",
"mul",
"narrow",
......
from infinicore.lib import _infinicore
def kv_caching(k_cache, v_cache, k, v, past_kv_lengths):
_infinicore.kv_caching_(
k_cache._underlying,
v_cache._underlying,
k._underlying,
v._underlying,
past_kv_lengths._underlying,
)
return k_cache, v_cache
#include "infinicore/ops/kv_caching.hpp"
#include "../../utils.hpp"
namespace infinicore::op {

// Generates the per-device dispatcher table for the KVCaching op.
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(KVCaching);

// Constructs the op and dispatches to the backend registered for the device
// of k_cache. All five tensors are asserted to share one device.
KVCaching::KVCaching(Tensor k_cache,
                     Tensor v_cache,
                     const Tensor &k,
                     const Tensor &v,
                     const Tensor &past_kv_lengths) {
    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(k_cache, v_cache, k, v, past_kv_lengths);
    INFINICORE_GRAPH_OP_DISPATCH(k_cache->device().getType(),
                                 k_cache,
                                 v_cache,
                                 k,
                                 v,
                                 past_kv_lengths);
}

// Records the op into the graph when capture is active, otherwise runs it
// eagerly (behavior provided by the RECORD_OR_RUN macro).
void KVCaching::execute(Tensor k_cache,
                        Tensor v_cache,
                        const Tensor &k,
                        const Tensor &v,
                        const Tensor &past_kv_lengths) {
    INFINICORE_GRAPH_OP_RECORD_OR_RUN(KVCaching,
                                      k_cache,
                                      v_cache,
                                      k,
                                      v,
                                      past_kv_lengths);
}

// Public in-place entry point; thin forwarder to KVCaching::execute.
void kv_caching_(Tensor k_cache,
                 Tensor v_cache,
                 const Tensor &k,
                 const Tensor &v,
                 const Tensor &past_kv_lengths) {
    KVCaching::execute(k_cache, v_cache, k, v, past_kv_lengths);
}

} // namespace infinicore::op
#include "../infiniop_impl.hpp"
#include "infinicore/ops/kv_caching.hpp"
namespace infinicore::op::kv_caching_impl::infiniop {

// Descriptor cache (capacity 100) keyed by a hash of argument tensor
// metadata, so repeated layouts reuse a single infiniop descriptor.
INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, KVCaching, 100);

// State captured at plan time and consumed by run()/cleanup().
struct PlannedMeta {
    std::shared_ptr<Descriptor> descriptor;
    graph::GraphTensor workspace, k_cache, v_cache, k, v, past_kv_lengths;
};

// Creates (or fetches from cache) the descriptor for this tensor
// configuration and snapshots all operands as graph tensors. Returns an
// owning PlannedMeta* that cleanup() deletes.
void *plan(Tensor k_cache,
           Tensor v_cache,
           const Tensor &k,
           const Tensor &v,
           const Tensor &past_kv_lengths) {
    size_t seed = hash_combine(k_cache, v_cache, k, v, past_kv_lengths);
    INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
        Descriptor, descriptor, KVCaching,
        seed, k_cache->desc(), v_cache->desc(),
        k->desc(), v->desc(), past_kv_lengths->desc());
    INFINIOP_WORKSPACE_TENSOR(workspace, KVCaching, descriptor);
    auto planned = new PlannedMeta{
        descriptor,
        graph::GraphTensor(workspace),
        graph::GraphTensor(k_cache),
        graph::GraphTensor(v_cache),
        graph::GraphTensor(k),
        graph::GraphTensor(v),
        graph::GraphTensor(past_kv_lengths)};
    return planned;
}

// Launches the kernel on the current stream.
void run(void *planned_meta) {
    auto planned = reinterpret_cast<PlannedMeta *>(planned_meta);
    // NOTE(review): a workspace tensor is allocated in plan() and stored in
    // PlannedMeta, yet nullptr/0 is passed here. The ninetoothed backend
    // reports a zero workspace size today, so this works — but confirm, and
    // pass planned->workspace through if any backend ever needs scratch.
    INFINICORE_CHECK_ERROR(infiniopKVCaching(
        planned->descriptor->desc,
        nullptr, 0,
        planned->k_cache->data(),
        planned->v_cache->data(),
        planned->k->data(),
        planned->v->data(),
        planned->past_kv_lengths->data(),
        context::getStream()));
}

// Frees the planned state and nulls the caller's pointer.
void cleanup(void **planned_meta_ptr) {
    delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
    *planned_meta_ptr = nullptr;
}

// Registers this plan/run/cleanup triple for every device type.
INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(KVCaching, &plan, &run, cleanup);

} // namespace infinicore::op::kv_caching_impl::infiniop
......@@ -8,6 +8,7 @@
#include "ops/causal_softmax.hpp"
#include "ops/embedding.hpp"
#include "ops/flash_attention.hpp"
#include "ops/kv_caching.hpp"
#include "ops/linear.hpp"
#include "ops/matmul.hpp"
#include "ops/mul.hpp"
......@@ -31,6 +32,7 @@ inline void bind(py::module &m) {
bind_attention(m);
bind_causal_softmax(m);
bind_flash_attention(m);
bind_kv_caching(m);
bind_linear(m);
bind_matmul(m);
bind_mul(m);
......
#pragma once
#include <pybind11/pybind11.h>
#include "infinicore/ops/kv_caching.hpp"
namespace py = pybind11;
namespace infinicore::ops {
inline void bind_kv_caching(py::module &m) {
m.def("kv_caching_",
&op::kv_caching_,
py::arg("k_cache"),
py::arg("v_cache"),
py::arg("k"),
py::arg("v"),
py::arg("past_kv_lengths"),
R"doc(In-place Key-Value Caching.
Updates the KV cache in-place with new key and value tensors.
Args:
k_cache: Key cache tensor to update in-place
v_cache: Value cache tensor to update in-place
k: New key tensor to append
v: New value tensor to append
past_kv_lengths: Tensor containing current sequence lengths for each batch
)doc");
}
} // namespace infinicore::ops
import ninetoothed
from . import kv_caching
import infiniop.ninetoothed.build
def build():
    """Ahead-of-time compile the ninetoothed KV-caching kernel.

    Builds one kernel variant per point in the constexpr parameter grid
    (embedding dim x dtype x block sizes) into the shared ninetoothed
    build directory.
    """
    param_grid = {
        "emb_dim": (1, 16, 32, 64, 128, 256),
        "dtype": (
            ninetoothed.float16,
            ninetoothed.bfloat16,
            ninetoothed.float32,
        ),
        "block_size_m": (64,),
        "block_size_n": (64,),
    }
    infiniop.ninetoothed.build.build(
        kv_caching.premake,
        param_grid,
        caller="cuda",
        op_name="kv_caching",
        output_dir=infiniop.ninetoothed.build.BUILD_DIRECTORY_PATH,
    )
// Include guard renamed from the collision-prone generic `KV_CACHING_H` to
// the project's `__INFINIOP_..._H__` convention (cf. the public API header).
#ifndef __INFINIOP_NINETOOTHED_KV_CACHING_H__
#define __INFINIOP_NINETOOTHED_KV_CACHING_H__

#include "../../../handle.h"
#include "../../../operator.h"
#include "../../../tensor.h"

#include "../../../../../build/ninetoothed/kv_caching.h"
#include "../../../ninetoothed/utils.h"

namespace op::kv_caching::ninetoothed {

// Descriptor for the ninetoothed-generated KV-caching kernel. Shapes,
// strides and the key dtype are captured at creation; calculate() wraps the
// raw device pointers into ninetoothed tensors and launches the prebuilt
// kernel.
class Descriptor final : public InfiniopDescriptor {
public:
    Descriptor(
        infiniopHandle_t handle,
        infiniopTensorDescriptor_t k_cache_desc,
        infiniopTensorDescriptor_t v_cache_desc,
        infiniopTensorDescriptor_t k_desc,
        infiniopTensorDescriptor_t v_desc,
        infiniopTensorDescriptor_t past_kv_lengths_desc) : InfiniopDescriptor{handle->device, handle->device_id},
                                                           k_cache_shape_{k_cache_desc->shape()},
                                                           k_cache_strides_{k_cache_desc->strides()},
                                                           v_cache_shape_{v_cache_desc->shape()},
                                                           v_cache_strides_{v_cache_desc->strides()},
                                                           k_shape_{k_desc->shape()},
                                                           k_strides_{k_desc->strides()},
                                                           v_shape_{v_desc->shape()},
                                                           v_strides_{v_desc->strides()},
                                                           past_kv_lengths_shape_{past_kv_lengths_desc->shape()},
                                                           past_kv_lengths_strides_{past_kv_lengths_desc->strides()},
                                                           dtype_{k_desc->dtype()} {}

    ~Descriptor() = default;

    // The generated kernel requires no scratch memory.
    size_t get_workspace_size() const { return 0; }

    // Allocates a new descriptor from the operand tensor descriptors.
    // TODO(review): no validation that the k/v/cache dtypes and shapes agree —
    // confirm that upstream layers guarantee this before creation.
    static infiniStatus_t create(
        infiniopHandle_t handle,
        Descriptor **desc_ptr,
        infiniopTensorDescriptor_t k_cache,
        infiniopTensorDescriptor_t v_cache,
        infiniopTensorDescriptor_t k,
        infiniopTensorDescriptor_t v,
        infiniopTensorDescriptor_t past_kv_lengths) {
        *desc_ptr = new Descriptor{handle, k_cache, v_cache, k, v, past_kv_lengths};
        return INFINI_STATUS_SUCCESS;
    }

    // Launches the kernel on `stream`. The workspace arguments are accepted
    // for API uniformity and unused (see get_workspace_size).
    infiniStatus_t calculate(
        void *workspace, size_t workspace_size,
        void *k_cache,
        void *v_cache,
        const void *k,
        const void *v,
        const void *past_kv_lengths,
        void *stream) const {
        auto k_cache_nt{::ninetoothed::Tensor{k_cache, k_cache_shape_, k_cache_strides_}};
        auto v_cache_nt{::ninetoothed::Tensor{v_cache, v_cache_shape_, v_cache_strides_}};
        auto k_nt{::ninetoothed::Tensor{k, k_shape_, k_strides_}};
        auto v_nt{::ninetoothed::Tensor{v, v_shape_, v_strides_}};
        auto past_kv_lengths_nt{::ninetoothed::Tensor{past_kv_lengths, past_kv_lengths_shape_, past_kv_lengths_strides_}};
        // k_shape_[3] is the trailing (embedding) dimension; 64/64 are the
        // block sizes the kernel grid was prebuilt with — keep in sync with
        // the ninetoothed build script's constexpr_param_grid.
        if (launch_kv_caching(stream,
                              k_cache_nt,
                              v_cache_nt,
                              k_nt,
                              v_nt,
                              past_kv_lengths_nt,
                              k_shape_[3],
                              dtype_,
                              64, 64)) {
            // A non-zero return means no prebuilt variant matches this
            // configuration.
            return INFINI_STATUS_NOT_IMPLEMENTED;
        }
        return INFINI_STATUS_SUCCESS;
    }

private:
    using Size = ::ninetoothed::Tensor<>::Size;
    using Stride = ::ninetoothed::Tensor<>::Stride;

    std::vector<Size> k_cache_shape_;
    std::vector<Stride> k_cache_strides_;
    std::vector<Size> v_cache_shape_;
    std::vector<Stride> v_cache_strides_;
    std::vector<Size> k_shape_;
    std::vector<Stride> k_strides_;
    std::vector<Size> v_shape_;
    std::vector<Stride> v_strides_;
    std::vector<Size> past_kv_lengths_shape_;
    std::vector<Stride> past_kv_lengths_strides_;
    infiniDtype_t dtype_;
};

} // namespace op::kv_caching::ninetoothed

#endif // __INFINIOP_NINETOOTHED_KV_CACHING_H__
import functools
import ninetoothed
from ninetoothed import Tensor
def arrangement(
    k_cache,
    v_cache,
    k,
    v,
    past_lengths,
    block_size_m=ninetoothed.block_size(),
    block_size_n=ninetoothed.block_size(),
):
    # Tiles the four 4-D tensors identically: first into blocks of
    # block_size_m along dim 1 with the full trailing dim, then regrouped so
    # each program instance sees one block.
    # NOTE(review): block_size_n is accepted (and swept in the build grid) but
    # never used in any tile() call below — confirm whether a second block
    # dimension was intended, or drop it from the build grid.
    k_cache_arranged = k_cache.tile((1, block_size_m, 1, -1)).tile((1, 1, -1, 1))
    v_cache_arranged = v_cache.tile((1, block_size_m, 1, -1)).tile((1, 1, -1, 1))
    k_arranged = k.tile((1, block_size_m, 1, -1)).tile((1, 1, -1, 1))
    v_arranged = v.tile((1, block_size_m, 1, -1)).tile((1, 1, -1, 1))
    # Broadcast the 1-D per-batch lengths so each k-block sees its scalar
    # offset: tile to single elements, then unsqueeze/expand to k's
    # arranged rank.
    past_lengths_arranged = (
        past_lengths.tile((1,))
        .unsqueeze(1)
        .unsqueeze(2)
        .unsqueeze(3)
        .unsqueeze(4)
        .expand((-1, *k_arranged.shape))
    )
    return (
        k_cache_arranged,
        v_cache_arranged,
        k_arranged,
        v_arranged,
        past_lengths_arranged,
    )
def application(k_cache, v_cache, k, v, past_lengths):
    # Per-block kernel body (traced by ninetoothed): copy each of the block's
    # rows of k/v into the caches, shifted by this batch's past length.
    pos = past_lengths
    for i in range(k.shape[-2]):
        k_cache[0, 0, pos + i, 0] = k[0, 0, i, 0]
        v_cache[0, 0, pos + i, 0] = v[0, 0, i, 0]
def premake(emb_dim=None, dtype=None, block_size_m=None, block_size_n=None):
    """Prepare the (arrangement, application, tensors) triple for one kernel
    variant of the constexpr parameter grid.

    Args:
        emb_dim: If given, pins the trailing (embedding) dimension of the
            four 4-D K/V tensors to this constant.
        dtype: Element dtype of the K/V tensors.
        block_size_m, block_size_n: Block sizes bound into the arrangement.
    """
    arrangement_ = functools.partial(
        arrangement, block_size_m=block_size_m, block_size_n=block_size_n
    )
    # The last dimension is constexpr with an upper bound so the generated
    # kernel can specialize on it.
    shape_options = (None, None, None, {"constexpr": True, "upper_bound": 256})
    tensors = (
        Tensor(4, dtype=dtype, shape_options=shape_options),  # k_cache
        Tensor(4, dtype=dtype, shape_options=shape_options),  # v_cache
        Tensor(4, dtype=dtype, shape_options=shape_options),  # k
        Tensor(4, dtype=dtype, shape_options=shape_options),  # v
        Tensor(1, dtype=ninetoothed.int64),  # past_kv_lengths
    )
    if emb_dim is not None:
        # Pin emb_dim on the 4-D K/V tensors only. The trailing 1-D
        # past_kv_lengths tensor is indexed by batch, not embedding dim, so
        # overwriting its sole dimension (as the previous loop over all five
        # tensors did) would constrain it incorrectly.
        for tensor in tensors[:-1]:
            tensor.shape = tensor.shape[:-1] + (emb_dim,)
    return arrangement_, application, tensors
#include "../../operator.h"

#include "../../handle.h"
#include "infiniop/ops/kv_caching.h"

// The ninetoothed backend is only dispatched below for NVIDIA, Iluvatar and
// Metax, so keep this include condition in sync with the switch cases.
// ENABLE_MOORE_API was listed here but has no corresponding dispatch case
// anywhere in this translation unit; it is dropped until Moore dispatch is
// actually added.
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_METAX_API)
#include "ninetoothed/kv_caching.h"
#endif
#endif
// Creates a KV-caching descriptor for the device carried by `handle`.
// Only the ninetoothed backend exists (NVIDIA / Iluvatar / Metax); any other
// device yields INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED.
__C infiniStatus_t infiniopCreateKVCachingDescriptor(
    infiniopHandle_t handle,
    infiniopKVCachingDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t k_cache,
    infiniopTensorDescriptor_t v_cache,
    infiniopTensorDescriptor_t k,
    infiniopTensorDescriptor_t v,
    infiniopTensorDescriptor_t past_kv_lengths) {

// Expands to one switch case forwarding creation to the backend descriptor
// type, reinterpreting desc_ptr to that backend's pointer type.
#define CREATE(CASE, NAMESPACE)                                                   \
    case CASE:                                                                    \
        return op::kv_caching::NAMESPACE::Descriptor::create(                     \
            handle,                                                               \
            reinterpret_cast<op::kv_caching::NAMESPACE::Descriptor **>(desc_ptr), \
            k_cache,                                                              \
            v_cache,                                                              \
            k,                                                                    \
            v,                                                                    \
            past_kv_lengths)

    switch (handle->device) {
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
        CREATE(INFINI_DEVICE_NVIDIA, ninetoothed);
#endif
#if defined(ENABLE_ILUVATAR_API)
        CREATE(INFINI_DEVICE_ILUVATAR, ninetoothed);
#endif
#if defined(ENABLE_METAX_API)
        CREATE(INFINI_DEVICE_METAX, ninetoothed);
#endif
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef CREATE
}
// Reports the descriptor's workspace requirement (currently always 0 for the
// ninetoothed backend) via `size`.
__C infiniStatus_t infiniopGetKVCachingWorkspaceSize(
    infiniopKVCachingDescriptor_t desc,
    size_t *size) {

// Expands to one switch case querying the backend descriptor.
#define GET_SIZE(CASE, NAMESPACE)                                               \
    case CASE:                                                                  \
        *size = reinterpret_cast<const op::kv_caching::NAMESPACE::Descriptor *>(desc) \
                    ->get_workspace_size();                                     \
        return INFINI_STATUS_SUCCESS;

    switch (desc->device_type) {
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
        GET_SIZE(INFINI_DEVICE_NVIDIA, ninetoothed);
#endif
#if defined(ENABLE_ILUVATAR_API)
        GET_SIZE(INFINI_DEVICE_ILUVATAR, ninetoothed);
#endif
#if defined(ENABLE_METAX_API)
        GET_SIZE(INFINI_DEVICE_METAX, ninetoothed);
#endif
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef GET_SIZE
}
// Executes the KV-caching update on `stream`, dispatching on the device type
// recorded in the descriptor. k_cache/v_cache are mutated in place.
__C infiniStatus_t infiniopKVCaching(
    infiniopKVCachingDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *k_cache,
    void *v_cache,
    const void *k,
    const void *v,
    const void *past_kv_lengths,
    void *stream) {

// Expands to one switch case forwarding to the backend's calculate().
#define CALCULATE(CASE, NAMESPACE)                                                    \
    case CASE:                                                                        \
        return reinterpret_cast<const op::kv_caching::NAMESPACE::Descriptor *>(desc)  \
            ->calculate(workspace, workspace_size, k_cache, v_cache, k, v, past_kv_lengths, stream)

    switch (desc->device_type) {
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
        CALCULATE(INFINI_DEVICE_NVIDIA, ninetoothed);
#endif
#if defined(ENABLE_ILUVATAR_API)
        CALCULATE(INFINI_DEVICE_ILUVATAR, ninetoothed);
#endif
#if defined(ENABLE_METAX_API)
        CALCULATE(INFINI_DEVICE_METAX, ninetoothed);
#endif
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef CALCULATE
}
// Destroys a descriptor previously created by
// infiniopCreateKVCachingDescriptor, dispatching on its device type.
__C infiniStatus_t infiniopDestroyKVCachingDescriptor(
    infiniopKVCachingDescriptor_t desc) {

// Expands to one switch case deleting through the backend's concrete type.
#define DELETE(CASE, NAMESPACE)                                            \
    case CASE:                                                             \
        delete reinterpret_cast<op::kv_caching::NAMESPACE::Descriptor *>(desc); \
        return INFINI_STATUS_SUCCESS;

    switch (desc->device_type) {
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
        DELETE(INFINI_DEVICE_NVIDIA, ninetoothed);
#endif
#if defined(ENABLE_ILUVATAR_API)
        DELETE(INFINI_DEVICE_ILUVATAR, ninetoothed);
#endif
#if defined(ENABLE_METAX_API)
        DELETE(INFINI_DEVICE_METAX, ninetoothed);
#endif
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef DELETE
}
......@@ -342,7 +342,10 @@ class BaseOperatorTest(ABC):
for i, inp in enumerate(inputs):
if isinstance(inp, torch.Tensor):
# Clone only if this input will be used for comparison
if comparison_target == i:
if comparison_target == i or (
isinstance(comparison_target, (list, tuple))
and i in comparison_target
):
cloned_inp = clone_torch_tensor(inp)
infini_tensor = infinicore_tensor_from_torch(cloned_inp)
cloned_tensors.append(cloned_inp)
......@@ -508,7 +511,9 @@ class BaseOperatorTest(ABC):
# Handle multiple outputs comparison
# Determine what to compare based on comparison_target
if comparison_target is None:
if comparison_target is None or isinstance(
comparison_target, (list, tuple)
):
# Compare return values (out-of-place multiple outputs)
torch_comparison = torch_result
infini_comparison = infini_result
......@@ -573,7 +578,9 @@ class BaseOperatorTest(ABC):
# ==========================================================================
else:
# Determine comparison targets for single output
if comparison_target is None:
if comparison_target is None or isinstance(
comparison_target, (list, tuple)
):
# Compare return values (out-of-place)
torch_comparison = torch_result
infini_comparison = infini_result
......
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
from framework import (
BaseOperatorTest,
TensorSpec,
TensorInitializer,
TestCase,
GenericTestRunner,
is_broadcast,
)
# ==============================================================================
# Operator-specific configuration
# ==============================================================================
# Test cases format: (shape (bs, nkvh, seq_len, hd), strides)
# Cache-tensor shapes as (bs, nkvh, seq_len, hd); a strides entry of None
# means contiguous layout.
_TEST_CASES_DATA = [
    ((1, 1, 8, 1), None),
    ((1, 8, 32, 32), None),
    ((8, 8, 64, 32), None),
    ((1, 32, 8, 64), (32768, 1024, 64, 1)),
    ((4, 8, 32, 16), (65536, 8192, 256, 16)),
    ((8, 16, 64, 128), (8388608, 524288, 8192, 1)),
]

# Tolerance configuration — the op only copies values, so exact equality
# (atol = rtol = 0) is expected for every dtype.
_TOLERANCE_MAP = {
    infinicore.float16: {"atol": 0, "rtol": 0},
    infinicore.bfloat16: {"atol": 0, "rtol": 0},
    infinicore.float32: {"atol": 0, "rtol": 0},
}

# Data types to test
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
def parse_test_cases():
    """Build the TestCase list for the KV-caching operator.

    For every cache configuration a random number of new tokens and a random
    (but always cache-fitting) past length are drawn, then one case per dtype
    is emitted. Both cache tensors (inputs 0 and 1) are the comparison
    targets since the op mutates them in place.
    """
    # Hoisted out of the loop body (it was re-imported on every iteration).
    import random

    test_cases = []
    for cache_shape, strides in _TEST_CASES_DATA:
        # New k/v carry between 1 token and the full cache capacity.
        kv_shape = (
            cache_shape[0],
            cache_shape[1],
            random.randint(1, cache_shape[2]),
            cache_shape[3],
        )
        past_shape = (cache_shape[0],)
        # Chosen so past_length + seq_len never overflows the cache.
        past_length = random.randint(0, cache_shape[2] - kv_shape[2])
        for dtype in _TENSOR_DTYPES:
            tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 0})
            cache_spec = TensorSpec.from_tensor(cache_shape, strides, dtype)
            kv_spec = TensorSpec.from_tensor(kv_shape, None, dtype)
            # RANDINT with low == high - 1 pins every batch entry to
            # past_length.
            past_kv_lengths_spec = TensorSpec.from_tensor(
                past_shape,
                None,
                infinicore.int64,
                init_mode=TensorInitializer.RANDINT,
                low=past_length,
                high=past_length + 1,
            )
            test_cases.append(
                TestCase(
                    inputs=[
                        cache_spec,
                        cache_spec,
                        kv_spec,
                        kv_spec,
                        past_kv_lengths_spec,
                    ],
                    kwargs={},
                    output_spec=None,
                    comparison_target=[0, 1],
                    tolerance=tolerance,
                    description="KV Caching",
                )
            )
    return test_cases
def torch_kv_caching(k_cache, v_cache, k, v, past_kv_lengths):
    """PyTorch reference: scatter k/v into the caches in place.

    For each batch row b, writes k[b]/v[b] (all heads at once) into the
    caches at sequence offsets [past_kv_lengths[b], past_kv_lengths[b] +
    seq_len). The previous per-head inner loop was redundant — a single
    slice assignment covers every head.

    Returns:
        The mutated (k_cache, v_cache) pair.
    """
    batch_size = k_cache.shape[0]
    seq_len = k.shape[2]
    for b in range(batch_size):
        # Each batch row may start writing at a different offset.
        past_len = int(past_kv_lengths[b].item())
        k_cache[b, :, past_len : past_len + seq_len, :] = k[b]
        v_cache[b, :, past_len : past_len + seq_len, :] = v[b]
    return k_cache, v_cache
def infinicore_kv_caching(k_cache, v_cache, k, v, past_kv_lengths):
    """Run the InfiniCore KV-caching op; the caches are updated in place and
    the (k_cache, v_cache) pair is forwarded back to the caller."""
    return infinicore.kv_caching(k_cache, v_cache, k, v, past_kv_lengths)
class OpTest(BaseOperatorTest):
    """Wires the KV-caching operator into the generic operator-test framework
    by delegating to the module-level reference and InfiniCore runners."""

    def __init__(self):
        super().__init__("KV Caching")

    def get_test_cases(self):
        # Cases are regenerated (with fresh random sizes) on every call.
        return parse_test_cases()

    def torch_operator(self, *args, **kwargs):
        # PyTorch reference implementation.
        return torch_kv_caching(*args, **kwargs)

    def infinicore_operator(self, *args, **kwargs):
        # Implementation under test.
        return infinicore_kv_caching(*args, **kwargs)
def main():
    """Entry point: run the KV-caching operator test suite and exit with the
    appropriate status code."""
    runner = GenericTestRunner(OpTest)
    runner.run_and_exit()


if __name__ == "__main__":
    main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment