Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
38078981
Commit
38078981
authored
Dec 29, 2025
by
pengcheng888
Committed by
PanZezhong
Dec 29, 2025
Browse files
issue/847-paged caching和attention添加infinicore的接口和测试
parent
298feac2
Changes
16
Show whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
719 additions
and
4 deletions
+719
-4
include/infinicore/ops.hpp
include/infinicore/ops.hpp
+2
-0
include/infinicore/ops/paged_attention.hpp
include/infinicore/ops/paged_attention.hpp
+18
-0
include/infinicore/ops/paged_caching.hpp
include/infinicore/ops/paged_caching.hpp
+17
-0
python/infinicore/__init__.py
python/infinicore/__init__.py
+4
-0
python/infinicore/ops/paged_attention.py
python/infinicore/ops/paged_attention.py
+40
-0
python/infinicore/ops/paged_caching.py
python/infinicore/ops/paged_caching.py
+21
-0
src/infinicore/ops/paged_attention/paged_attention.cc
src/infinicore/ops/paged_attention/paged_attention.cc
+28
-0
src/infinicore/ops/paged_attention/paged_attention_infiniop.cc
...nfinicore/ops/paged_attention/paged_attention_infiniop.cc
+54
-0
src/infinicore/ops/paged_caching/paged_caching.cc
src/infinicore/ops/paged_caching/paged_caching.cc
+22
-0
src/infinicore/ops/paged_caching/paged_caching_infiniop.cc
src/infinicore/ops/paged_caching/paged_caching_infiniop.cc
+50
-0
src/infinicore/pybind11/ops.hpp
src/infinicore/pybind11/ops.hpp
+4
-0
src/infinicore/pybind11/ops/paged_attention.hpp
src/infinicore/pybind11/ops/paged_attention.hpp
+53
-0
src/infinicore/pybind11/ops/paged_caching.hpp
src/infinicore/pybind11/ops/paged_caching.hpp
+22
-0
test/infinicore/ops/paged_attention.py
test/infinicore/ops/paged_attention.py
+205
-0
test/infinicore/ops/paged_caching.py
test/infinicore/ops/paged_caching.py
+177
-0
test/infiniop/paged_attention.py
test/infiniop/paged_attention.py
+2
-4
No files found.
include/infinicore/ops.hpp
View file @
38078981
...
...
@@ -5,6 +5,8 @@
#include "ops/causal_softmax.hpp"
#include "ops/matmul.hpp"
#include "ops/ones.hpp"
#include "ops/paged_attention.hpp"
#include "ops/paged_caching.hpp"
#include "ops/random_sample.hpp"
#include "ops/rearrange.hpp"
#include "ops/rms_norm.hpp"
...
...
include/infinicore/ops/paged_attention.hpp
0 → 100644
View file @
38078981
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

#include <optional>

namespace infinicore::op {

// Paged attention: single-token decode attention where K/V live in a blocked
// (paged) cache and per-sequence block tables map logical positions to blocks.
class PagedAttention {
public:
    // Kernel signature: (out, q, k_cache, v_cache, block_tables, seq_lens,
    //                    alibi_slopes, scale).
    using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, std::optional<Tensor>, float);

    // Run the kernel registered for out's device type.
    // Fixed: the trailing parameter was an unnamed `float`; named `scale`
    // for consistency with the free functions below.
    static void execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache,
                        Tensor block_tables, Tensor seq_lens,
                        std::optional<Tensor> alibi_slopes, float scale);

    // Per-device-type dispatch table (backends register themselves here).
    static common::OpDispatcher<schema> &dispatcher();
};

// Out-of-place variant: allocates and returns the output tensor.
Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache,
                       Tensor block_tables, Tensor seq_lens,
                       std::optional<Tensor> alibi_slopes, float scale);

// In-place variant: writes the result into `out`.
void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache,
                      Tensor block_tables, Tensor seq_lens,
                      std::optional<Tensor> alibi_slopes, float scale);

} // namespace infinicore::op
include/infinicore/ops/paged_caching.hpp
0 → 100644
View file @
38078981
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

namespace infinicore::op {

// Paged caching: scatters per-token K/V tensors into blocked cache pools
// according to a slot mapping. Always in place on the cache tensors.
class PagedCaching {
public:
    // Kernel signature: (k, v, k_cache, v_cache, slot_mapping).
    using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor);

    // Run the kernel registered for k's device type.
    static void execute(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_mapping);

    // Per-device-type dispatch table (backends register themselves here).
    static common::OpDispatcher<schema> &dispatcher();
};

// In-place entry point: writes k/v into the cache pools via slot_mapping.
void paged_caching_(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_mapping);

} // namespace infinicore::op
python/infinicore/__init__.py
View file @
38078981
...
...
@@ -44,6 +44,8 @@ from infinicore.ops.attention import attention
from
infinicore.ops.matmul
import
matmul
from
infinicore.ops.mul
import
mul
from
infinicore.ops.narrow
import
narrow
from
infinicore.ops.paged_attention
import
paged_attention
from
infinicore.ops.paged_caching
import
paged_caching
from
infinicore.ops.rearrange
import
rearrange
from
infinicore.ops.squeeze
import
squeeze
from
infinicore.ops.unsqueeze
import
unsqueeze
...
...
@@ -115,6 +117,8 @@ __all__ = [
"from_list"
,
"from_numpy"
,
"from_torch"
,
"paged_caching"
,
"paged_attention"
,
"ones"
,
"strided_empty"
,
"strided_from_blob"
,
...
...
python/infinicore/ops/paged_attention.py
0 → 100644
View file @
38078981
from
infinicore.lib
import
_infinicore
from
infinicore.tensor
import
Tensor
def paged_attention(
    q: Tensor,
    k_cache: Tensor,
    v_cache: Tensor,
    block_tables: Tensor,
    seq_lens: Tensor,
    alibi_slopes: Tensor | None = None,
    scale: float = 1.0,
    *,
    out: Tensor | None = None,
):
    """Compute paged attention over blocked K/V cache pools.

    Args:
        q: Query tensor.
        k_cache: Blocked key cache pool.
        v_cache: Blocked value cache pool.
        block_tables: Per-sequence mapping of logical blocks to pool blocks.
        seq_lens: Per-sequence context lengths.
        alibi_slopes: Optional ALiBi slope tensor, or ``None`` to disable.
        scale: Softmax scaling factor applied to the attention logits.
        out: Optional pre-allocated output tensor; when given, the in-place
            binding is used and ``out`` is returned.

    Returns:
        The attention output tensor (a fresh tensor, or ``out`` if provided).
    """
    # The native bindings expect a raw underlying handle or None.
    slopes = None if alibi_slopes is None else alibi_slopes._underlying

    if out is None:
        result = _infinicore.paged_attention(
            q._underlying,
            k_cache._underlying,
            v_cache._underlying,
            block_tables._underlying,
            seq_lens._underlying,
            slopes,
            scale,
        )
        return Tensor(result)

    _infinicore.paged_attention_(
        out._underlying,
        q._underlying,
        k_cache._underlying,
        v_cache._underlying,
        block_tables._underlying,
        seq_lens._underlying,
        slopes,
        scale,
    )
    return out
python/infinicore/ops/paged_caching.py
0 → 100644
View file @
38078981
from
infinicore.lib
import
_infinicore
from
infinicore.tensor
import
Tensor
def paged_caching(
    k: Tensor,
    v: Tensor,
    k_cache: Tensor,
    v_cache: Tensor,
    slot_mapping: Tensor,
):
    """Scatter per-token key/value tensors into the paged cache pools.

    The operation is performed in place on ``k_cache`` and ``v_cache``.

    Args:
        k: Keys, one row per token.
        v: Values, one row per token.
        k_cache: Blocked key cache pool (modified in place).
        v_cache: Blocked value cache pool (modified in place).
        slot_mapping: Destination slot index for each token.

    Returns:
        The ``(k_cache, v_cache)`` pair, for convenience.
    """
    # Fix: the native binding is a void in-place op; its None return was
    # previously wrapped in Tensor(...), constructing a meaningless
    # Tensor(None). Call it directly instead.
    _infinicore.paged_caching_(
        k._underlying,
        v._underlying,
        k_cache._underlying,
        v_cache._underlying,
        slot_mapping._underlying,
    )
    return (k_cache, v_cache)
src/infinicore/ops/paged_attention/paged_attention.cc
0 → 100644
View file @
38078981
#include "infinicore/ops/paged_attention.hpp"

#include "../../utils.hpp"

namespace infinicore::op {

// Lazily-constructed dispatch table mapping device types to backend kernels.
common::OpDispatcher<PagedAttention::schema> &PagedAttention::dispatcher() {
    static common::OpDispatcher<PagedAttention::schema> dispatcher_;
    return dispatcher_;
}

// Validate device placement, select the device, then forward to the kernel
// registered for the output tensor's device type.
void PagedAttention::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache,
                             Tensor block_tables, Tensor seq_lens,
                             std::optional<Tensor> alibi_slopes, float scale) {
    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, seq_lens);
    infinicore::context::setDevice(out->device());
    auto kernel = dispatcher().lookup(out->device().getType());
    kernel(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, scale);
}

// Out-of-place variant: the result has q's shape, dtype, and device.
Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache,
                       Tensor block_tables, Tensor seq_lens,
                       std::optional<Tensor> alibi_slopes, float scale) {
    auto out = Tensor::empty(q->shape(), q->dtype(), q->device());
    paged_attention_(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, scale);
    return out;
}

// In-place variant writing into a caller-provided output tensor.
void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache,
                      Tensor block_tables, Tensor seq_lens,
                      std::optional<Tensor> alibi_slopes, float scale) {
    PagedAttention::execute(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, scale);
}

} // namespace infinicore::op
src/infinicore/ops/paged_attention/paged_attention_infiniop.cc
0 → 100644
View file @
38078981
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "infinicore/ops/paged_attention.hpp"

#include <infiniop.h>

namespace infinicore::op::paged_attention_impl::infiniop {

// Per-thread, per-device LRU cache of infiniop descriptors, keyed by a hash
// of the participating tensors' metadata. The deleter destroys the
// descriptor when an entry is evicted.
thread_local common::OpCache<size_t, infiniopPagedAttentionDescriptor_t> caches(
    100, // capacity
    [](infiniopPagedAttentionDescriptor_t &desc) {
        if (desc != nullptr) {
            INFINICORE_CHECK_ERROR(infiniopDestroyPagedAttentionDescriptor(desc));
            desc = nullptr;
        }
    });

// infiniop backend for PagedAttention: (re)builds or fetches a cached
// descriptor, allocates the required workspace, then launches the kernel on
// the current stream.
void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale) {
    // NOTE(review): the cache key covers only the six tensors — neither
    // `scale` nor `alibi_slopes` is hashed. If a caller varies scale (or
    // toggles alibi_slopes) while tensor metadata stays identical, a stale
    // descriptor created with the old values is reused. Confirm and fold
    // both into the seed if they can vary per call.
    size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, seq_lens);
    auto device = context::getDevice();
    auto &cache = caches.getCache(device);
    auto desc_opt = cache.get(seed);
    infiniopPagedAttentionDescriptor_t desc = nullptr;
    if (!desc_opt) {
        // Cache miss: build a descriptor from tensor descriptors and scale.
        INFINICORE_CHECK_ERROR(infiniopCreatePagedAttentionDescriptor(
            context::getInfiniopHandle(device), &desc,
            out->desc(), q->desc(), k_cache->desc(), v_cache->desc(),
            block_tables->desc(), seq_lens->desc(),
            alibi_slopes.has_value() ? alibi_slopes.value()->desc() : nullptr,
            scale));
        cache.put(seed, desc);
    } else {
        desc = *desc_opt;
    }
    // Query and allocate the scratch workspace the kernel needs.
    size_t workspace_size = 0;
    INFINICORE_CHECK_ERROR(infiniopGetPagedAttentionWorkspaceSize(desc, &workspace_size));
    std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);
    // Launch asynchronously on the context's current stream.
    INFINICORE_CHECK_ERROR(infiniopPagedAttention(
        desc, workspace->data(), workspace_size,
        out->data(), q->data(), k_cache->data(), v_cache->data(),
        block_tables->data(), seq_lens->data(),
        alibi_slopes.has_value() ? alibi_slopes.value()->data() : nullptr,
        context::getStream()));
}

// Self-registration: install this backend for all device types at load time
// (the `false` flag presumably means "do not overwrite existing entries" —
// TODO confirm against OpDispatcher::registerAll).
static bool registered = []() {
    PagedAttention::dispatcher().registerAll(&calculate, false);
    return true;
}();

} // namespace infinicore::op::paged_attention_impl::infiniop
src/infinicore/ops/paged_caching/paged_caching.cc
0 → 100644
View file @
38078981
#include "infinicore/ops/paged_caching.hpp"

#include "../../utils.hpp"

namespace infinicore::op {

// Lazily-constructed dispatch table mapping device types to backend kernels.
common::OpDispatcher<PagedCaching::schema> &PagedCaching::dispatcher() {
    static common::OpDispatcher<PagedCaching::schema> dispatcher_;
    return dispatcher_;
}

// Validate device placement, select the device, then forward to the kernel
// registered for k's device type. The cache tensors are updated in place.
void PagedCaching::execute(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_mapping) {
    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(k, v, k_cache, v_cache, slot_mapping);
    infinicore::context::setDevice(k->device());
    auto kernel = dispatcher().lookup(k->device().getType());
    kernel(k, v, k_cache, v_cache, slot_mapping);
}

// Public in-place entry point; thin wrapper over PagedCaching::execute.
void paged_caching_(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_mapping) {
    PagedCaching::execute(k, v, k_cache, v_cache, slot_mapping);
}

} // namespace infinicore::op
src/infinicore/ops/paged_caching/paged_caching_infiniop.cc
0 → 100644
View file @
38078981
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "infinicore/ops/paged_caching.hpp"

#include <infiniop.h>

namespace infinicore::op::paged_caching_impl::infiniop {

// Per-thread, per-device LRU cache of infiniop descriptors, keyed by a hash
// of the participating tensors' metadata. The deleter destroys the
// descriptor when an entry is evicted.
thread_local common::OpCache<size_t, infiniopPagedCachingDescriptor_t> caches(
    100, // capacity
    [](infiniopPagedCachingDescriptor_t &desc) {
        if (desc != nullptr) {
            INFINICORE_CHECK_ERROR(infiniopDestroyPagedCachingDescriptor(desc));
            desc = nullptr;
        }
    });

// infiniop backend for PagedCaching: (re)builds or fetches a cached
// descriptor, allocates the required workspace, then launches the kernel on
// the current stream. All arguments that shape the descriptor are covered
// by the hash seed, so the cache key is complete here.
void calculate(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_mapping) {
    size_t seed = hash_combine(k, v, k_cache, v_cache, slot_mapping);
    auto device = context::getDevice();
    auto &cache = caches.getCache(device);
    auto desc_opt = cache.get(seed);
    infiniopPagedCachingDescriptor_t desc = nullptr;
    if (!desc_opt) {
        // Cache miss: build a descriptor from the tensor descriptors.
        INFINICORE_CHECK_ERROR(infiniopCreatePagedCachingDescriptor(
            context::getInfiniopHandle(device), &desc,
            k->desc(), v->desc(), k_cache->desc(), v_cache->desc(),
            slot_mapping->desc()));
        cache.put(seed, desc);
    } else {
        desc = *desc_opt;
    }
    // Query and allocate the scratch workspace the kernel needs.
    size_t workspace_size = 0;
    INFINICORE_CHECK_ERROR(infiniopGetPagedCachingWorkspaceSize(desc, &workspace_size));
    std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);
    // Launch asynchronously on the context's current stream.
    INFINICORE_CHECK_ERROR(infiniopPagedCaching(
        desc, workspace->data(), workspace_size,
        k->data(), v->data(), k_cache->data(), v_cache->data(),
        slot_mapping->data(), context::getStream()));
}

// Self-registration: install this backend for all device types at load time
// (the `false` flag presumably means "do not overwrite existing entries" —
// TODO confirm against OpDispatcher::registerAll).
static bool registered = []() {
    PagedCaching::dispatcher().registerAll(&calculate, false);
    return true;
}();

} // namespace infinicore::op::paged_caching_impl::infiniop
src/infinicore/pybind11/ops.hpp
View file @
38078981
...
...
@@ -9,6 +9,8 @@
#include "ops/linear.hpp"
#include "ops/matmul.hpp"
#include "ops/mul.hpp"
#include "ops/paged_attention.hpp"
#include "ops/paged_caching.hpp"
#include "ops/random_sample.hpp"
#include "ops/rearrange.hpp"
#include "ops/rms_norm.hpp"
...
...
@@ -28,6 +30,8 @@ inline void bind(py::module &m) {
bind_linear
(
m
);
bind_matmul
(
m
);
bind_mul
(
m
);
bind_paged_attention
(
m
);
bind_paged_caching
(
m
);
bind_rearrange
(
m
);
bind_rms_norm
(
m
);
bind_silu
(
m
);
...
...
src/infinicore/pybind11/ops/paged_attention.hpp
0 → 100644
View file @
38078981
#pragma once

#include <pybind11/pybind11.h>

#include "infinicore/ops/paged_attention.hpp"

namespace py = pybind11;

namespace infinicore::ops {

// Fixed: these wrappers are defined in a header and were not `inline`, so
// every translation unit including this file emitted its own external
// definition — a one-definition-rule violation / multiple-definition link
// error once the header is included more than once.

// Out-of-place binding: translates Python None into std::nullopt for the
// optional ALiBi slopes, then forwards to the C++ op.
inline Tensor py_paged_attention(Tensor q, Tensor k_cache, Tensor v_cache,
                                 Tensor block_tables, Tensor seq_lens,
                                 pybind11::object alibi_slopes, float scale) {
    std::optional<Tensor> alibi_slopes_tensor = std::nullopt;
    if (!alibi_slopes.is_none()) {
        alibi_slopes_tensor = alibi_slopes.cast<Tensor>();
    }
    return op::paged_attention(q, k_cache, v_cache, block_tables, seq_lens,
                               alibi_slopes_tensor, scale);
}

// In-place binding: same None translation, writes into `out`.
inline void py_paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache,
                                Tensor block_tables, Tensor seq_lens,
                                pybind11::object alibi_slopes, float scale) {
    std::optional<Tensor> alibi_slopes_tensor = std::nullopt;
    if (!alibi_slopes.is_none()) {
        alibi_slopes_tensor = alibi_slopes.cast<Tensor>();
    }
    op::paged_attention_(out, q, k_cache, v_cache, block_tables, seq_lens,
                         alibi_slopes_tensor, scale);
}

// Register both bindings on the given module.
inline void bind_paged_attention(py::module &m) {
    m.def("paged_attention",
          &ops::py_paged_attention,
          py::arg("q"),
          py::arg("k_cache"),
          py::arg("v_cache"),
          py::arg("block_tables"),
          py::arg("seq_lens"),
          py::arg("alibi_slopes"),
          py::arg("scale"),
          R"doc(Paged attention of query and key cache tensors.)doc");

    m.def("paged_attention_",
          &ops::py_paged_attention_,
          py::arg("out"),
          py::arg("q"),
          py::arg("k_cache"),
          py::arg("v_cache"),
          py::arg("block_tables"),
          py::arg("seq_lens"),
          py::arg("alibi_slopes"),
          py::arg("scale"),
          R"doc(In-place paged attention of query and key cache tensors.)doc");
}

} // namespace infinicore::ops
src/infinicore/pybind11/ops/paged_caching.hpp
0 → 100644
View file @
38078981
#pragma once

#include <pybind11/pybind11.h>

#include "infinicore/ops/paged_caching.hpp"

namespace py = pybind11;

namespace infinicore::ops {

// Register the in-place paged_caching_ binding on the given module.
// The C++ op is bound directly — no Python-side adaptation is needed.
inline void bind_paged_caching(py::module &m) {
    m.def("paged_caching_",
          &op::paged_caching_,
          py::arg("k"),
          py::arg("v"),
          py::arg("k_cache"),
          py::arg("v_cache"),
          py::arg("slot_mapping"),
          R"doc(Paged caching of key and value tensors.)doc");
}

} // namespace infinicore::ops
test/infinicore/ops/paged_attention.py
0 → 100644
View file @
38078981
import sys
import os

# Make the shared test framework (one directory up) importable.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

import torch
import infinicore
from framework import (
    BaseOperatorTest,
    TensorSpec,
    TestCase,
    GenericTestRunner,
    is_broadcast,
    TensorInitializer,
)

# ==============================================================================
# Operator-specific configuration
# ==============================================================================

# Test cases format:
_TEST_CASES_DATA = [
    # (num_seqs, num_heads, num_kv_heads, head_size, block_size, max_seq_len, use_alibi)
    (1, 1, 1, 128, 16, 15, False),
    # Larger configurations kept for manual runs:
    # (4, 40, 40, 128, 16, 1024, False),
    # (6, 40, 40, 128, 16, 1024, False),
    # (3, 8, 8, 128, 16, 1024, False),
    # (8, 64, 8, 128, 16, 2048, False),
]

# Tolerance configuration: per-dtype atol/rtol used when comparing against
# the PyTorch reference implementation.
_TOLERANCE_MAP = {
    infinicore.float16: {"atol": 0, "rtol": 1e-2},
    infinicore.float32: {"atol": 1e-4, "rtol": 1e-3},
    infinicore.bfloat16: {"atol": 0, "rtol": 5e-2},
}

# Data types to test
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]

# ==============================================================================
# Reference Implementation
# ==============================================================================
def parse_test_cases():
    """
    Parse test case data and return list of TestCase objects for paged_attention operation.
    Each test case contains all necessary information for execution and validation.
    """
    test_cases = []
    for (
        num_seqs,
        num_heads,
        num_kv_heads,
        head_size,
        block_size,
        max_seq_len,
        use_alibi,
    ) in _TEST_CASES_DATA:
        scale = 1.0 / (head_size**0.5)
        max_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
        num_blocks = num_seqs * max_blocks_per_seq  # A reasonable number for testing
        # Fix: sequence lengths were drawn from randint(1, 1024) regardless of
        # max_seq_len, so with e.g. max_seq_len=15 (one block per sequence) a
        # sampled length could walk past the end of the block table and cache.
        # Bound the lengths by max_seq_len instead.
        seq_lens_torch = torch.randint(
            1, max_seq_len + 1, (num_seqs,), dtype=torch.int32
        )
        # Each sequence owns a contiguous run of blocks in the pool.
        block_tables = torch.arange(
            0, num_seqs * max_blocks_per_seq, dtype=torch.int32
        ).view(num_seqs, max_blocks_per_seq)

        q_shape = (num_seqs, num_heads, head_size)
        k_cache_shape = (num_blocks, num_kv_heads, block_size, head_size)
        v_cache_shape = (num_blocks, num_kv_heads, block_size, head_size)
        block_tables_shape = block_tables.shape
        seq_lens_shape = seq_lens_torch.shape

        # Generate test cases for all data types
        for dtype in _TENSOR_DTYPES:
            tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 1e-3})

            # Create typed tensor specs
            q_spec = TensorSpec.from_tensor(q_shape, None, dtype)
            k_cache_spec = TensorSpec.from_tensor(k_cache_shape, None, dtype)
            v_cache_spec = TensorSpec.from_tensor(v_cache_shape, None, dtype)
            # block_tables and seq_lens carry fixed integer data, so they use
            # MANUAL initialization with the tensors built above.
            block_tables_spec = TensorSpec.from_tensor(
                block_tables_shape,
                init_mode=TensorInitializer.MANUAL,
                set_tensor=block_tables,
                dtype=infinicore.int32,
            )
            seq_lens_spec = TensorSpec.from_tensor(
                seq_lens_shape,
                init_mode=TensorInitializer.MANUAL,
                set_tensor=seq_lens_torch,
                dtype=infinicore.int32,
            )

            # Paged attention operation: returns output tensor
            out_shape = (num_seqs, num_heads, head_size)
            out_spec = TensorSpec.from_tensor(out_shape, None, dtype)

            test_cases.append(
                TestCase(
                    inputs=[
                        q_spec,
                        k_cache_spec,
                        v_cache_spec,
                        block_tables_spec,
                        seq_lens_spec,
                    ],
                    kwargs={"alibi_slopes": None, "scale": scale},
                    output_spec=None,
                    comparison_target=0,
                    tolerance=tolerance,
                    description="PagedAttention",
                )
            )
    return test_cases
def ref_masked_attention(query, key, value, scale, attn_mask=None):
    """Reference implementation for a single masked attention head.

    Args:
        query: Tensor of shape [num_queries, num_heads, head_size].
        key: Tensor of shape [num_keys, num_heads, head_size].
        value: Tensor of shape [num_keys, num_heads, head_size].
        scale: Scaling factor applied to the raw attention logits.
        attn_mask: Optional additive bias broadcastable to [h, q, k].

    Returns:
        Attention output of shape [num_queries, num_heads, head_size].
    """
    # Logits are accumulated in float32 for numerical stability.
    logits = torch.einsum("qhd,khd->hqk", query, key).float() * scale
    if attn_mask is not None:
        logits = logits + attn_mask.float()
    probs = torch.nn.functional.softmax(logits, dim=-1).to(value.dtype)
    return torch.einsum("hqk,khd->qhd", probs, value)
def ref_single_query_cached_kv_attention(
    query, key_cache, value_cache, block_tables, seq_lens, alibi_slopes, scale
):
    """Reference paged attention, computed one sequence at a time.

    Gathers each sequence's K/V tokens out of the blocked cache via its block
    table, then delegates to ref_masked_attention.
    """
    output = torch.empty_like(query)
    num_seqs, num_query_heads = query.shape[0], query.shape[1]
    num_kv_heads = value_cache.shape[1]
    num_queries_per_kv = num_query_heads // num_kv_heads
    block_size, head_size = value_cache.shape[2], value_cache.shape[3]

    for seq_idx in range(num_seqs):
        q = query[seq_idx].unsqueeze(0)
        seq_len = seq_lens[seq_idx].item()
        block_table = block_tables[seq_idx]

        # Gather this sequence's tokens from the paged cache, token by token.
        gathered_k = []
        gathered_v = []
        for pos in range(seq_len):
            blk = block_table[pos // block_size].item()
            off = pos % block_size
            gathered_k.append(key_cache[blk, :, off, :])
            gathered_v.append(value_cache[blk, :, off, :])
        keys = torch.stack(gathered_k, dim=0)
        values = torch.stack(gathered_v, dim=0)

        # Expand KV heads for grouped-query attention.
        if num_queries_per_kv > 1:
            keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
            values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)

        # Optional ALiBi bias: linear distance penalty per head.
        alibi_bias = None
        if alibi_slopes is not None:
            pos = torch.arange(seq_len, device=query.device).int()
            dist = (pos - seq_len + 1).float()
            alibi_bias = alibi_slopes.view(-1, 1, 1) * dist.view(1, 1, -1)

        attn_out = ref_masked_attention(q, keys, values, scale, alibi_bias)
        output[seq_idx] = attn_out.view(num_query_heads, head_size)
    return output
class OpTest(BaseOperatorTest):
    """PagedAttention operator test with simplified implementation"""

    def __init__(self):
        super().__init__("PagedAttention")

    def get_test_cases(self):
        # Test cases are built once per run by the module-level parser.
        return parse_test_cases()

    def torch_operator(self, *args, **kwargs):
        """PyTorch reference paged_attention implementation"""
        return ref_single_query_cached_kv_attention(*args, **kwargs)

    def infinicore_operator(self, *args, **kwargs):
        """InfiniCore paged_attention implementation"""
        out = infinicore.paged_attention(*args, **kwargs)
        # Synchronize before comparing so the async kernel has finished.
        infinicore.sync_stream()
        return out
def main():
    """Entry point: run the PagedAttention suite and exit with its status."""
    GenericTestRunner(OpTest).run_and_exit()


if __name__ == "__main__":
    main()
test/infinicore/ops/paged_caching.py
0 → 100644
View file @
38078981
import sys
import os

# Make the shared test framework (one directory up) importable.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

import torch
import infinicore
from framework import (
    BaseOperatorTest,
    TensorSpec,
    TestCase,
    GenericTestRunner,
    is_broadcast,
    TensorInitializer,
)

# ==============================================================================
# Operator-specific configuration
# ==============================================================================

# Test cases format: (num_seqs, max_seq_len, num_kv_heads, head_size, block_size)
_TEST_CASES_DATA = [
    (1, 128, 8, 128, 16),
    (5, 512, 40, 128, 16),
    (16, 1024, 8, 64, 32),
    (10, 1024, 40, 64, 32),
]

# Tolerance configuration: per-dtype atol/rtol used when comparing against
# the PyTorch reference implementation.
_TOLERANCE_MAP = {
    infinicore.float16: {"atol": 0, "rtol": 1e-2},
    infinicore.float32: {"atol": 1e-4, "rtol": 1e-3},
    infinicore.bfloat16: {"atol": 0, "rtol": 5e-2},
}

# Data types to test
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]

# ==============================================================================
# Reference Implementation
# ==============================================================================
def ref_paged_caching(key, value, key_cache_pool, value_cache_pool, slot_mapping):
    """
    Reference implementation for paged_caching operator.

    Args:
        key (torch.Tensor): Keys, shape [ntok, nkvh, dh]
        value (torch.Tensor): Values, shape [ntok, nkvh, dh]
        key_cache_pool (torch.Tensor): K cache pool, shape [num_blocks, nkvh, block_size, dh]
        value_cache_pool (torch.Tensor): V cache pool, shape [num_blocks, nkvh, block_size, dh]
        slot_mapping (torch.Tensor): Slot mapping, shape [ntok]

    Returns:
        Tuple of updated (k_cache, v_cache) clones; the input pools are left
        untouched, mimicking an operator that writes to its output tensors.
    """
    block_size = key_cache_pool.shape[2]
    k_cache_ref = key_cache_pool.clone()
    v_cache_ref = value_cache_pool.clone()

    # Scatter each token into (block, offset) derived from its flat slot id.
    for tok, slot_t in enumerate(slot_mapping):
        slot = slot_t.item()
        blk, off = divmod(slot, block_size)
        k_cache_ref[blk, :, off, :] = key[tok]
        v_cache_ref[blk, :, off, :] = value[tok]
    return k_cache_ref, v_cache_ref
def parse_test_cases():
    """
    Parse test case data and return list of TestCase objects for paged_caching operation.
    Each test case contains all necessary information for execution and validation.
    """
    cases = []
    for num_seqs, max_seq_len, num_kv_heads, head_size, block_size in _TEST_CASES_DATA:
        num_blocks = 4096  # A reasonably large cache pool for testing

        # Variable context lengths for each sequence in the batch.
        context_lens_torch = torch.randint(
            1, max_seq_len + 1, (num_seqs,), dtype=torch.int32
        )
        ntok = torch.sum(context_lens_torch).item()

        # Simulate the scheduler: assign each sequence a contiguous run of
        # slots, building the per-token slot_mapping.
        slot_ids = []
        next_slot = 0
        for length in context_lens_torch:
            n = length.item()
            slot_ids.extend(range(next_slot, next_slot + n))
            next_slot += n
        # Ensure we don't exceed the total number of slots in the cache.
        assert next_slot <= num_blocks * block_size, (
            "Not enough blocks in the cache pool for this test case"
        )
        slot_mapping = torch.tensor(slot_ids, dtype=torch.int32)
        slot_mapping_shape = slot_mapping.shape

        k_shape = (ntok, num_kv_heads, head_size)
        v_shape = (ntok, num_kv_heads, head_size)
        k_cache_shape = (num_blocks, num_kv_heads, block_size, head_size)
        v_cache_shape = (num_blocks, num_kv_heads, block_size, head_size)

        # Generate test cases for all data types.
        for dtype in _TENSOR_DTYPES:
            tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 1e-3})

            k_spec = TensorSpec.from_tensor(k_shape, None, dtype)
            v_spec = TensorSpec.from_tensor(v_shape, None, dtype)
            k_cache_spec = TensorSpec.from_tensor(k_cache_shape, None, dtype)
            v_cache_spec = TensorSpec.from_tensor(v_cache_shape, None, dtype)
            # slot_mapping carries fixed integer data → MANUAL initialization.
            slot_mapping_spec = TensorSpec.from_tensor(
                slot_mapping_shape,
                init_mode=TensorInitializer.MANUAL,
                set_tensor=slot_mapping,
                dtype=infinicore.int32,
            )

            # In-place operation: modifies k_cache (index 2) and v_cache (index 3)
            cases.append(
                TestCase(
                    inputs=[
                        k_spec,
                        v_spec,
                        k_cache_spec,
                        v_cache_spec,
                        slot_mapping_spec,
                    ],
                    kwargs=None,
                    output_spec=None,
                    comparison_target=0,  # Only compare k_cache
                    tolerance=tolerance,
                    description="PagedCaching",
                )
            )
    return cases
class OpTest(BaseOperatorTest):
    """PagedCaching operator test with simplified implementation"""

    def __init__(self):
        super().__init__("PagedCaching")

    def get_test_cases(self):
        # Test cases are built once per run by the module-level parser.
        return parse_test_cases()

    def torch_operator(self, *args, **kwargs):
        """PyTorch paged_caching implementation"""
        return ref_paged_caching(*args, **kwargs)

    def infinicore_operator(self, *args, **kwargs):
        """InfiniCore paged_caching implementation"""
        return infinicore.paged_caching(*args, **kwargs)
def main():
    """Entry point: run the PagedCaching suite and exit with its status."""
    GenericTestRunner(OpTest).run_and_exit()


if __name__ == "__main__":
    main()
test/infiniop/paged_attention.py
View file @
38078981
...
...
@@ -148,10 +148,8 @@ def test(
(
num_blocks
,
num_kv_heads
,
block_size
,
head_size
),
None
,
dtype
,
device
)
seq_lens_direct
=
1023
seq_lens_torch
=
torch
.
randint
(
1
,
seq_lens_direct
+
1
,
(
num_seqs
,),
dtype
=
torch
.
int64
)
seq_lens_torch
=
torch
.
randint
(
1
,
1024
,
(
num_seqs
,),
dtype
=
torch
.
int64
)
seq_lens
=
TestTensor
.
from_torch
(
seq_lens_torch
,
InfiniDtype
.
I64
,
device
)
block_tables_py
=
torch
.
arange
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment