jerrrrry / infinicore

Commit 97eced0e, authored Jan 26, 2026 by wooway777
parent 5614e1be

issue/923 - ninetoothed kv caching for nv, il, mtx
Showing 16 changed files with 681 additions and 3 deletions.
include/infinicore/ops.hpp                             +1   −0
include/infinicore/ops/kv_caching.hpp                  +16  −0
include/infiniop.h                                     +1   −0
include/infiniop/ops/kv_caching.h                      +31  −0
python/infinicore/__init__.py                          +2   −0
python/infinicore/ops/kv_caching.py                    +13  −0
src/infinicore/ops/kv_caching/kv_caching.cc            +42  −0
src/infinicore/ops/kv_caching/kv_caching_infiniop.cc   +60  −0
src/infinicore/pybind11/ops.hpp                        +2   −0
src/infinicore/pybind11/ops/kv_caching.hpp             +32  −0
src/infiniop/ops/kv_caching/ninetoothed/build.py       +27  −0
src/infiniop/ops/kv_caching/ninetoothed/kv_caching.h   +101 −0
src/infiniop/ops/kv_caching/ninetoothed/kv_caching.py  +66  −0
src/infiniop/ops/kv_caching/operator.cc                +143 −0
test/infinicore/framework/base.py                      +10  −3
test/infinicore/ops/kv_caching.py                      +134 −0
include/infinicore/ops.hpp

@@ -6,6 +6,7 @@
 #include "ops/causal_softmax.hpp"
 #include "ops/embedding.hpp"
 #include "ops/flash_attention.hpp"
+#include "ops/kv_caching.hpp"
 #include "ops/matmul.hpp"
 #include "ops/ones.hpp"
 #include "ops/paged_attention.hpp"
include/infinicore/ops/kv_caching.hpp (new file, mode 100644)

#pragma once

#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"

namespace infinicore::op {

INFINICORE_GRAPH_OP_CLASS(KVCaching, Tensor, Tensor, const Tensor &, const Tensor &, const Tensor &);

void kv_caching_(Tensor k_cache, Tensor v_cache, const Tensor &k, const Tensor &v, const Tensor &past_kv_lengths);

} // namespace infinicore::op
include/infiniop.h

@@ -13,6 +13,7 @@
 #include "infiniop/ops/flash_attention.h"
 #include "infiniop/ops/gelu.h"
 #include "infiniop/ops/gemm.h"
+#include "infiniop/ops/kv_caching.h"
 #include "infiniop/ops/layer_norm.h"
 #include "infiniop/ops/logsoftmax.h"
 #include "infiniop/ops/lp_norm.h"
include/infiniop/ops/kv_caching.h (new file, mode 100644)

#ifndef __INFINIOP_KV_CACHING_API_H__
#define __INFINIOP_KV_CACHING_API_H__

#include "../operator_descriptor.h"

typedef struct InfiniopDescriptor *infiniopKVCachingDescriptor_t;

__C __export infiniStatus_t infiniopCreateKVCachingDescriptor(
    infiniopHandle_t handle,
    infiniopKVCachingDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t k_cache,
    infiniopTensorDescriptor_t v_cache,
    infiniopTensorDescriptor_t k,
    infiniopTensorDescriptor_t v,
    infiniopTensorDescriptor_t past_kv_lengths);

__C __export infiniStatus_t infiniopGetKVCachingWorkspaceSize(
    infiniopKVCachingDescriptor_t desc,
    size_t *size);

__C __export infiniStatus_t infiniopKVCaching(
    infiniopKVCachingDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *k_cache,
    void *v_cache,
    const void *k,
    const void *v,
    const void *past_kv_lengths,
    void *stream);

__C __export infiniStatus_t infiniopDestroyKVCachingDescriptor(
    infiniopKVCachingDescriptor_t desc);

#endif
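The header implies a four-step lifecycle: create the descriptor, query its workspace size, run, destroy. Below is a minimal ctypes sketch of that sequence, assuming the caller already has a loaded library object, a valid infiniopHandle_t, tensor descriptors, and device pointers from the surrounding infiniop setup; none of those setup APIs are part of this commit.

import ctypes

def kv_caching_lifecycle(lib, handle, k_cache_desc, v_cache_desc, k_desc,
                         v_desc, past_desc, k_cache_ptr, v_cache_ptr,
                         k_ptr, v_ptr, past_ptr, stream):
    """Create -> query workspace -> run -> destroy, using only the four
    functions declared in this header. Every argument comes from the
    caller's existing infiniop setup (not shown here)."""
    desc = ctypes.c_void_p()
    lib.infiniopCreateKVCachingDescriptor(
        handle, ctypes.byref(desc),
        k_cache_desc, v_cache_desc, k_desc, v_desc, past_desc)

    size = ctypes.c_size_t()
    lib.infiniopGetKVCachingWorkspaceSize(desc, ctypes.byref(size))
    # The ninetoothed backend below reports 0, so no workspace is allocated.

    lib.infiniopKVCaching(desc, None, 0, k_cache_ptr, v_cache_ptr,
                          k_ptr, v_ptr, past_ptr, stream)

    lib.infiniopDestroyKVCachingDescriptor(desc)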
python/infinicore/__init__.py

@@ -45,6 +45,7 @@ from infinicore.dtype import (
 from infinicore.ops.add import add
 from infinicore.ops.add_rms_norm import add_rms_norm
 from infinicore.ops.attention import attention
+from infinicore.ops.kv_caching import kv_caching
 from infinicore.ops.matmul import matmul
 from infinicore.ops.mul import mul
 from infinicore.ops.narrow import narrow

@@ -115,6 +116,7 @@ __all__ = [
     "add_rms_norm",
     "add_rms_norm_",
     "attention",
+    "kv_caching",
     "matmul",
     "mul",
     "narrow",
python/infinicore/ops/kv_caching.py (new file, mode 100644)

from infinicore.lib import _infinicore


def kv_caching(k_cache, v_cache, k, v, past_kv_lengths):
    _infinicore.kv_caching_(
        k_cache._underlying,
        v_cache._underlying,
        k._underlying,
        v._underlying,
        past_kv_lengths._underlying,
    )
    return k_cache, v_cache
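A hypothetical usage sketch of this wrapper. The make_tensor helper is a stand-in for whatever infinicore tensor construction the caller uses (for example, converting from torch); the actual creation API is not part of this commit. Shapes follow the (batch, num_kv_heads, seq_len, head_dim) layout used by the tests.

import infinicore

def make_tensor(shape, dtype="float16"):
    # Hypothetical helper: replace with real infinicore tensor construction.
    raise NotImplementedError

k_cache = make_tensor((1, 8, 128, 64))   # preallocated cache
v_cache = make_tensor((1, 8, 128, 64))
k = make_tensor((1, 8, 16, 64))          # 16 new positions to append
v = make_tensor((1, 8, 16, 64))
past = make_tensor((1,), dtype="int64")  # e.g. [32]: 32 rows already filled

# Appends k/v at rows 32..47 of each cache, in place, and returns the caches.
k_cache, v_cache = infinicore.kv_caching(k_cache, v_cache, k, v, past)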
src/infinicore/ops/kv_caching/kv_caching.cc (new file, mode 100644)

#include "infinicore/ops/kv_caching.hpp"

#include "../../utils.hpp"

namespace infinicore::op {

INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(KVCaching);

KVCaching::KVCaching(Tensor k_cache, Tensor v_cache, const Tensor &k, const Tensor &v, const Tensor &past_kv_lengths) {
    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(k_cache, v_cache, k, v, past_kv_lengths);
    INFINICORE_GRAPH_OP_DISPATCH(k_cache->device().getType(), k_cache, v_cache, k, v, past_kv_lengths);
}

void KVCaching::execute(Tensor k_cache, Tensor v_cache, const Tensor &k, const Tensor &v, const Tensor &past_kv_lengths) {
    INFINICORE_GRAPH_OP_RECORD_OR_RUN(KVCaching, k_cache, v_cache, k, v, past_kv_lengths);
}

void kv_caching_(Tensor k_cache, Tensor v_cache, const Tensor &k, const Tensor &v, const Tensor &past_kv_lengths) {
    KVCaching::execute(k_cache, v_cache, k, v, past_kv_lengths);
}

} // namespace infinicore::op
src/infinicore/ops/kv_caching/kv_caching_infiniop.cc (new file, mode 100644)

#include "../infiniop_impl.hpp"
#include "infinicore/ops/kv_caching.hpp"

namespace infinicore::op::kv_caching_impl::infiniop {

INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, KVCaching, 100);

struct PlannedMeta {
    std::shared_ptr<Descriptor> descriptor;
    graph::GraphTensor workspace, k_cache, v_cache, k, v, past_kv_lengths;
};

void *plan(Tensor k_cache, Tensor v_cache, const Tensor &k, const Tensor &v, const Tensor &past_kv_lengths) {
    size_t seed = hash_combine(k_cache, v_cache, k, v, past_kv_lengths);
    INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(Descriptor, descriptor, KVCaching, seed,
                                               k_cache->desc(), v_cache->desc(), k->desc(), v->desc(), past_kv_lengths->desc());
    INFINIOP_WORKSPACE_TENSOR(workspace, KVCaching, descriptor);
    auto planned = new PlannedMeta{
        descriptor,
        graph::GraphTensor(workspace),
        graph::GraphTensor(k_cache),
        graph::GraphTensor(v_cache),
        graph::GraphTensor(k),
        graph::GraphTensor(v),
        graph::GraphTensor(past_kv_lengths)};
    return planned;
}

void run(void *planned_meta) {
    auto planned = reinterpret_cast<PlannedMeta *>(planned_meta);
    INFINICORE_CHECK_ERROR(infiniopKVCaching(
        planned->descriptor->desc,
        nullptr,
        0,
        planned->k_cache->data(),
        planned->v_cache->data(),
        planned->k->data(),
        planned->v->data(),
        planned->past_kv_lengths->data(),
        context::getStream()));
}

void cleanup(void **planned_meta_ptr) {
    delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
    *planned_meta_ptr = nullptr;
}

INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(KVCaching, &plan, &run, cleanup);

} // namespace infinicore::op::kv_caching_impl::infiniop
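The three callbacks registered above follow the graph-op protocol: plan hashes the tensor descriptors, fetches or creates a cached infiniop descriptor, and captures graph tensors; run replays the captured pointers on the current stream without re-planning; cleanup frees the planned state. A compact Python analogue of that contract, purely for explanation (names and structure are illustrative, not the real infinicore API):

class PlannedMeta:
    def __init__(self, descriptor, tensors):
        self.descriptor = descriptor  # cached infiniop descriptor
        self.tensors = tensors        # graph tensors captured at plan time

_descriptor_cache = {}

def plan(make_descriptor, tensors):
    # Key the descriptor cache on the tensor descriptors
    # (shapes/strides/dtypes), mirroring hash_combine above.
    key = tuple(t["desc"] for t in tensors)
    descriptor = _descriptor_cache.setdefault(key, make_descriptor(key))
    return PlannedMeta(descriptor, tensors)

def run(planned):
    # Replay with the pointers captured at plan time; no re-planning,
    # which is what makes the op safe to record into a graph.
    print("launch", planned.descriptor, [t["data"] for t in planned.tensors])

def cleanup(planned):
    # The C++ side deletes the PlannedMeta and nulls the caller's pointer.
    del planned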
src/infinicore/pybind11/ops.hpp

@@ -8,6 +8,7 @@
 #include "ops/causal_softmax.hpp"
 #include "ops/embedding.hpp"
 #include "ops/flash_attention.hpp"
+#include "ops/kv_caching.hpp"
 #include "ops/linear.hpp"
 #include "ops/matmul.hpp"
 #include "ops/mul.hpp"

@@ -31,6 +32,7 @@ inline void bind(py::module &m) {
     bind_attention(m);
     bind_causal_softmax(m);
     bind_flash_attention(m);
+    bind_kv_caching(m);
     bind_linear(m);
     bind_matmul(m);
     bind_mul(m);
src/infinicore/pybind11/ops/kv_caching.hpp (new file, mode 100644)

#pragma once

#include <pybind11/pybind11.h>

#include "infinicore/ops/kv_caching.hpp"

namespace py = pybind11;

namespace infinicore::ops {

inline void bind_kv_caching(py::module &m) {
    m.def("kv_caching_",
          &op::kv_caching_,
          py::arg("k_cache"),
          py::arg("v_cache"),
          py::arg("k"),
          py::arg("v"),
          py::arg("past_kv_lengths"),
          R"doc(In-place Key-Value Caching.

Updates the KV cache in-place with new key and value tensors.

Args:
    k_cache: Key cache tensor to update in-place
    v_cache: Value cache tensor to update in-place
    k: New key tensor to append
    v: New value tensor to append
    past_kv_lengths: Tensor containing current sequence lengths for each batch
)doc");
}

} // namespace infinicore::ops
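The docstring's semantics, written out as a NumPy reference: for each batch entry, the new keys and values are copied into the cache starting at that entry's current length. This is a sketch mirroring the torch reference used by the tests below, not infinicore code.

import numpy as np

def kv_caching_reference(k_cache, v_cache, k, v, past_kv_lengths):
    """Append k/v at each batch's current length; caches are updated in place."""
    batch_size = k_cache.shape[0]
    seq_len = k.shape[2]
    for b in range(batch_size):
        start = int(past_kv_lengths[b])
        k_cache[b, :, start : start + seq_len, :] = k[b, :, :, :]
        v_cache[b, :, start : start + seq_len, :] = v[b, :, :, :]
    return k_cache, v_cache

# Example: a cache with room for 16 positions, 5 already filled.
k_cache = np.zeros((1, 2, 16, 4), dtype=np.float32)
v_cache = np.zeros_like(k_cache)
k = np.ones((1, 2, 3, 4), dtype=np.float32)
v = 2 * k
kv_caching_reference(k_cache, v_cache, k, v, np.array([5]))
assert k_cache[0, 0, 5:8].sum() == 3 * 4  # rows 5..7 now hold the new keys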
src/infiniop/ops/kv_caching/ninetoothed/build.py (new file, mode 100644)

import ninetoothed

from . import kv_caching

import infiniop.ninetoothed.build


def build():
    dtype_values = (
        ninetoothed.float16,
        ninetoothed.bfloat16,
        ninetoothed.float32,
    )

    constexpr_param_grid = {
        "emb_dim": (1, 16, 32, 64, 128, 256),
        "dtype": dtype_values,
        "block_size_m": (64,),
        "block_size_n": (64,),
    }

    infiniop.ninetoothed.build.build(
        kv_caching.premake,
        constexpr_param_grid,
        caller="cuda",
        op_name="kv_caching",
        output_dir=infiniop.ninetoothed.build.BUILD_DIRECTORY_PATH,
    )
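The constexpr grid above yields one ahead-of-time compiled kernel variant per parameter combination: 6 emb_dim values × 3 dtypes × 1 × 1 = 18 variants. A quick sketch of the expansion; the real enumeration happens inside infiniop.ninetoothed.build.build, so this is illustrative only.

import itertools

grid = {
    "emb_dim": (1, 16, 32, 64, 128, 256),
    "dtype": ("float16", "bfloat16", "float32"),
    "block_size_m": (64,),
    "block_size_n": (64,),
}

variants = [dict(zip(grid, combo)) for combo in itertools.product(*grid.values())]
assert len(variants) == 18  # one compiled kernel per constexpr combination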
src/infiniop/ops/kv_caching/ninetoothed/kv_caching.h (new file, mode 100644)

#ifndef KV_CACHING_H
#define KV_CACHING_H

#include "../../../handle.h"
#include "../../../operator.h"
#include "../../../tensor.h"

#include "../../../../../build/ninetoothed/kv_caching.h"
#include "../../../ninetoothed/utils.h"

namespace op::kv_caching::ninetoothed {

class Descriptor final : public InfiniopDescriptor {
public:
    Descriptor(infiniopHandle_t handle,
               infiniopTensorDescriptor_t k_cache_desc,
               infiniopTensorDescriptor_t v_cache_desc,
               infiniopTensorDescriptor_t k_desc,
               infiniopTensorDescriptor_t v_desc,
               infiniopTensorDescriptor_t past_kv_lengths_desc)
        : InfiniopDescriptor{handle->device, handle->device_id},
          k_cache_shape_{k_cache_desc->shape()},
          k_cache_strides_{k_cache_desc->strides()},
          v_cache_shape_{v_cache_desc->shape()},
          v_cache_strides_{v_cache_desc->strides()},
          k_shape_{k_desc->shape()},
          k_strides_{k_desc->strides()},
          v_shape_{v_desc->shape()},
          v_strides_{v_desc->strides()},
          past_kv_lengths_shape_{past_kv_lengths_desc->shape()},
          past_kv_lengths_strides_{past_kv_lengths_desc->strides()},
          dtype_{k_desc->dtype()} {}

    ~Descriptor() = default;

    size_t get_workspace_size() const { return 0; }

    static infiniStatus_t create(infiniopHandle_t handle,
                                 Descriptor **desc_ptr,
                                 infiniopTensorDescriptor_t k_cache,
                                 infiniopTensorDescriptor_t v_cache,
                                 infiniopTensorDescriptor_t k,
                                 infiniopTensorDescriptor_t v,
                                 infiniopTensorDescriptor_t past_kv_lengths) {
        *desc_ptr = new Descriptor{handle, k_cache, v_cache, k, v, past_kv_lengths};
        return INFINI_STATUS_SUCCESS;
    }

    infiniStatus_t calculate(void *workspace,
                             size_t workspace_size,
                             void *k_cache,
                             void *v_cache,
                             const void *k,
                             const void *v,
                             const void *past_kv_lengths,
                             void *stream) const {
        auto k_cache_nt{::ninetoothed::Tensor{k_cache, k_cache_shape_, k_cache_strides_}};
        auto v_cache_nt{::ninetoothed::Tensor{v_cache, v_cache_shape_, v_cache_strides_}};
        auto k_nt{::ninetoothed::Tensor{k, k_shape_, k_strides_}};
        auto v_nt{::ninetoothed::Tensor{v, v_shape_, v_strides_}};
        auto past_kv_lengths_nt{::ninetoothed::Tensor{past_kv_lengths, past_kv_lengths_shape_, past_kv_lengths_strides_}};

        if (launch_kv_caching(stream, k_cache_nt, v_cache_nt, k_nt, v_nt, past_kv_lengths_nt,
                              k_shape_[3], dtype_, 64, 64)) {
            return INFINI_STATUS_NOT_IMPLEMENTED;
        }

        return INFINI_STATUS_SUCCESS;
    }

private:
    using Size = ::ninetoothed::Tensor<>::Size;
    using Stride = ::ninetoothed::Tensor<>::Stride;

    std::vector<Size> k_cache_shape_;
    std::vector<Stride> k_cache_strides_;
    std::vector<Size> v_cache_shape_;
    std::vector<Stride> v_cache_strides_;
    std::vector<Size> k_shape_;
    std::vector<Stride> k_strides_;
    std::vector<Size> v_shape_;
    std::vector<Stride> v_strides_;
    std::vector<Size> past_kv_lengths_shape_;
    std::vector<Stride> past_kv_lengths_strides_;
    infiniDtype_t dtype_;
};

} // namespace op::kv_caching::ninetoothed

#endif // KV_CACHING_H
src/infiniop/ops/kv_caching/ninetoothed/kv_caching.py (new file, mode 100644)

import functools

import ninetoothed
from ninetoothed import Tensor


def arrangement(
    k_cache,
    v_cache,
    k,
    v,
    past_lengths,
    block_size_m=ninetoothed.block_size(),
    block_size_n=ninetoothed.block_size(),
):
    k_cache_arranged = k_cache.tile((1, block_size_m, 1, -1)).tile((1, 1, -1, 1))
    v_cache_arranged = v_cache.tile((1, block_size_m, 1, -1)).tile((1, 1, -1, 1))
    k_arranged = k.tile((1, block_size_m, 1, -1)).tile((1, 1, -1, 1))
    v_arranged = v.tile((1, block_size_m, 1, -1)).tile((1, 1, -1, 1))

    past_lengths_arranged = (
        past_lengths.tile((1,))
        .unsqueeze(1)
        .unsqueeze(2)
        .unsqueeze(3)
        .unsqueeze(4)
        .expand((-1, *k_arranged.shape))
    )

    return (
        k_cache_arranged,
        v_cache_arranged,
        k_arranged,
        v_arranged,
        past_lengths_arranged,
    )


def application(k_cache, v_cache, k, v, past_lengths):
    pos = past_lengths
    for i in range(k.shape[-2]):
        k_cache[0, 0, pos + i, 0] = k[0, 0, i, 0]
        v_cache[0, 0, pos + i, 0] = v[0, 0, i, 0]


def premake(emb_dim=None, dtype=None, block_size_m=None, block_size_n=None):
    arrangement_ = functools.partial(
        arrangement, block_size_m=block_size_m, block_size_n=block_size_n
    )

    shape_options = (None, None, None, {"constexpr": True, "upper_bound": 256})

    tensors = (
        Tensor(4, dtype=dtype, shape_options=shape_options),
        Tensor(4, dtype=dtype, shape_options=shape_options),
        Tensor(4, dtype=dtype, shape_options=shape_options),
        Tensor(4, dtype=dtype, shape_options=shape_options),
        Tensor(1, dtype=ninetoothed.int64),
    )

    if emb_dim is not None:
        for tensor in tensors:
            tensor.shape = tensor.shape[:-1] + (emb_dim,)

    return arrangement_, application, tensors
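After arrangement, each program instance sees one (batch, head) slice tiled into block_size_m-row chunks, and application copies each incoming row to offset past_lengths + i in the cache. A NumPy sketch of that per-slice behavior, assuming a single tile; ninetoothed applies the same loop across all tiles on device.

import numpy as np

block_size_m, emb_dim = 4, 8
k_cache_slice = np.zeros((16, emb_dim), dtype=np.float32)  # one (b, h) slice: seq x emb
k_block = np.random.rand(block_size_m, emb_dim).astype(np.float32)
pos = 5  # this batch's past length

for i in range(k_block.shape[0]):
    # Mirrors k_cache[0, 0, pos + i, 0] = k[0, 0, i, 0] in application above.
    k_cache_slice[pos + i] = k_block[i]

assert np.allclose(k_cache_slice[pos : pos + block_size_m], k_block)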
src/infiniop/ops/kv_caching/operator.cc (new file, mode 100644)

#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/kv_caching.h"

#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_METAX_API) || defined(ENABLE_MOORE_API)
#include "ninetoothed/kv_caching.h"
#endif
#endif

__C infiniStatus_t infiniopCreateKVCachingDescriptor(
    infiniopHandle_t handle,
    infiniopKVCachingDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t k_cache,
    infiniopTensorDescriptor_t v_cache,
    infiniopTensorDescriptor_t k,
    infiniopTensorDescriptor_t v,
    infiniopTensorDescriptor_t past_kv_lengths) {

#define CREATE(CASE, NAMESPACE)                                                   \
    case CASE:                                                                    \
        return op::kv_caching::NAMESPACE::Descriptor::create(                     \
            handle,                                                               \
            reinterpret_cast<op::kv_caching::NAMESPACE::Descriptor **>(desc_ptr), \
            k_cache,                                                              \
            v_cache,                                                              \
            k,                                                                    \
            v,                                                                    \
            past_kv_lengths)

    switch (handle->device) {
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
        CREATE(INFINI_DEVICE_NVIDIA, ninetoothed);
#endif
#if defined(ENABLE_ILUVATAR_API)
        CREATE(INFINI_DEVICE_ILUVATAR, ninetoothed);
#endif
#if defined(ENABLE_METAX_API)
        CREATE(INFINI_DEVICE_METAX, ninetoothed);
#endif
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef CREATE
}

__C infiniStatus_t infiniopGetKVCachingWorkspaceSize(
    infiniopKVCachingDescriptor_t desc,
    size_t *size) {

#define GET_SIZE(CASE, NAMESPACE)                                                      \
    case CASE:                                                                         \
        *size = reinterpret_cast<const op::kv_caching::NAMESPACE::Descriptor *>(desc)  \
                    ->get_workspace_size();                                            \
        return INFINI_STATUS_SUCCESS;

    switch (desc->device_type) {
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
        GET_SIZE(INFINI_DEVICE_NVIDIA, ninetoothed);
#endif
#if defined(ENABLE_ILUVATAR_API)
        GET_SIZE(INFINI_DEVICE_ILUVATAR, ninetoothed);
#endif
#if defined(ENABLE_METAX_API)
        GET_SIZE(INFINI_DEVICE_METAX, ninetoothed);
#endif
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef GET_SIZE
}

__C infiniStatus_t infiniopKVCaching(
    infiniopKVCachingDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *k_cache,
    void *v_cache,
    const void *k,
    const void *v,
    const void *past_kv_lengths,
    void *stream) {

#define CALCULATE(CASE, NAMESPACE)                                                    \
    case CASE:                                                                        \
        return reinterpret_cast<const op::kv_caching::NAMESPACE::Descriptor *>(desc)  \
            ->calculate(workspace, workspace_size, k_cache, v_cache, k, v, past_kv_lengths, stream)

    switch (desc->device_type) {
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
        CALCULATE(INFINI_DEVICE_NVIDIA, ninetoothed);
#endif
#if defined(ENABLE_ILUVATAR_API)
        CALCULATE(INFINI_DEVICE_ILUVATAR, ninetoothed);
#endif
#if defined(ENABLE_METAX_API)
        CALCULATE(INFINI_DEVICE_METAX, ninetoothed);
#endif
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef CALCULATE
}

__C infiniStatus_t infiniopDestroyKVCachingDescriptor(
    infiniopKVCachingDescriptor_t desc) {

#define DELETE(CASE, NAMESPACE)                                                   \
    case CASE:                                                                    \
        delete reinterpret_cast<op::kv_caching::NAMESPACE::Descriptor *>(desc);   \
        return INFINI_STATUS_SUCCESS;

    switch (desc->device_type) {
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
        DELETE(INFINI_DEVICE_NVIDIA, ninetoothed);
#endif
#if defined(ENABLE_ILUVATAR_API)
        DELETE(INFINI_DEVICE_ILUVATAR, ninetoothed);
#endif
#if defined(ENABLE_METAX_API)
        DELETE(INFINI_DEVICE_METAX, ninetoothed);
#endif
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef DELETE
}
test/infinicore/framework/base.py

@@ -342,7 +342,10 @@ class BaseOperatorTest(ABC):
         for i, inp in enumerate(inputs):
             if isinstance(inp, torch.Tensor):
                 # Clone only if this input will be used for comparison
-                if comparison_target == i:
+                if comparison_target == i or (
+                    isinstance(comparison_target, (list, tuple))
+                    and i in comparison_target
+                ):
                     cloned_inp = clone_torch_tensor(inp)
                     infini_tensor = infinicore_tensor_from_torch(cloned_inp)
                     cloned_tensors.append(cloned_inp)

@@ -508,7 +511,9 @@ class BaseOperatorTest(ABC):
             # Handle multiple outputs comparison
             # Determine what to compare based on comparison_target
-            if comparison_target is None:
+            if comparison_target is None or isinstance(
+                comparison_target, (list, tuple)
+            ):
                 # Compare return values (out-of-place multiple outputs)
                 torch_comparison = torch_result
                 infini_comparison = infini_result

@@ -573,7 +578,9 @@ class BaseOperatorTest(ABC):
         # ==========================================================================
         else:
             # Determine comparison targets for single output
-            if comparison_target is None:
+            if comparison_target is None or isinstance(
+                comparison_target, (list, tuple)
+            ):
                 # Compare return values (out-of-place)
                 torch_comparison = torch_result
                 infini_comparison = infini_result
test/infinicore/ops/kv_caching.py (new file, mode 100644)

import sys
import os
import random

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

import torch
import infinicore
from framework import (
    BaseOperatorTest,
    TensorSpec,
    TensorInitializer,
    TestCase,
    GenericTestRunner,
    is_broadcast,
)

# ==============================================================================
# Operator-specific configuration
# ==============================================================================

# Test cases format: (shape (bs, nkvh, seq_len, hd), strides)
_TEST_CASES_DATA = [
    ((1, 1, 8, 1), None),
    ((1, 8, 32, 32), None),
    ((8, 8, 64, 32), None),
    ((1, 32, 8, 64), (32768, 1024, 64, 1)),
    ((4, 8, 32, 16), (65536, 8192, 256, 16)),
    ((8, 16, 64, 128), (8388608, 524288, 8192, 1)),
]

# Tolerance configuration
_TOLERANCE_MAP = {
    infinicore.float16: {"atol": 0, "rtol": 0},
    infinicore.bfloat16: {"atol": 0, "rtol": 0},
    infinicore.float32: {"atol": 0, "rtol": 0},
}

# Data types to test
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]


def parse_test_cases():
    test_cases = []

    for data in _TEST_CASES_DATA:
        cache_shape = data[0]
        kv_shape = (
            cache_shape[0],
            cache_shape[1],
            random.randint(1, cache_shape[2]),
            cache_shape[3],
        )
        past_shape = (cache_shape[0],)
        strides = data[1]
        past_length = random.randint(0, cache_shape[2] - kv_shape[2])

        for dtype in _TENSOR_DTYPES:
            tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 0})

            cache_spec = TensorSpec.from_tensor(cache_shape, strides, dtype)
            kv_spec = TensorSpec.from_tensor(kv_shape, None, dtype)
            past_kv_lengths_spec = TensorSpec.from_tensor(
                past_shape,
                None,
                infinicore.int64,
                init_mode=TensorInitializer.RANDINT,
                low=past_length,
                high=past_length + 1,
            )

            test_cases.append(
                TestCase(
                    inputs=[
                        cache_spec,
                        cache_spec,
                        kv_spec,
                        kv_spec,
                        past_kv_lengths_spec,
                    ],
                    kwargs={},
                    output_spec=None,
                    comparison_target=[0, 1],
                    tolerance=tolerance,
                    description="KV Caching",
                )
            )

    return test_cases


def torch_kv_caching(k_cache, v_cache, k, v, past_kv_lengths):
    batch_size, num_kv_heads, _, head_dim = k_cache.shape
    seq_len = k.shape[2]

    for b in range(batch_size):
        past_len = past_kv_lengths[b].item()
        for h in range(num_kv_heads):
            k_cache[b, h, past_len : past_len + seq_len, :] = k[b, h, :, :]
            v_cache[b, h, past_len : past_len + seq_len, :] = v[b, h, :, :]

    return k_cache, v_cache


def infinicore_kv_caching(k_cache, v_cache, k, v, past_kv_lengths):
    infinicore.kv_caching(k_cache, v_cache, k, v, past_kv_lengths)
    return k_cache, v_cache


class OpTest(BaseOperatorTest):
    def __init__(self):
        super().__init__("KV Caching")

    def get_test_cases(self):
        return parse_test_cases()

    def torch_operator(self, *args, **kwargs):
        return torch_kv_caching(*args, **kwargs)

    def infinicore_operator(self, *args, **kwargs):
        return infinicore_kv_caching(*args, **kwargs)


def main():
    test_runner = GenericTestRunner(OpTest)
    test_runner.run_and_exit()


if __name__ == "__main__":
    main()