Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
8d1207dd
Commit
8d1207dd
authored
May 19, 2025
by
PanZezhong
Browse files
issue/219 添加attention算子
parent
e580d751
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
423 additions
and
194 deletions
+423
-194
scripts/python_test.py
scripts/python_test.py
+1
-0
src/infiniop/ops/attention/attention.h
src/infiniop/ops/attention/attention.h
+37
-0
src/infiniop/ops/attention/operator.cc
src/infiniop/ops/attention/operator.cc
+286
-0
src/infiniop/ops/causal_softmax/cpu/causal_softmax_cpu.cc
src/infiniop/ops/causal_softmax/cpu/causal_softmax_cpu.cc
+1
-1
src/infiniop/reduce/cuda/reduce.cuh
src/infiniop/reduce/cuda/reduce.cuh
+1
-1
src/infiniop/tensor.h
src/infiniop/tensor.h
+13
-3
src/infiniop/tensor_descriptor.cc
src/infiniop/tensor_descriptor.cc
+16
-22
test/infiniop/attention.py
test/infiniop/attention.py
+68
-167
No files found.
scripts/python_test.py
View file @
8d1207dd
...
...
@@ -18,6 +18,7 @@ def run_tests(args):
"rms_norm.py"
,
"rope.py"
,
"swiglu.py"
,
"attention.py"
,
]:
result
=
subprocess
.
run
(
f
"python
{
test
}
{
args
}
"
,
text
=
True
,
encoding
=
"utf-8"
,
shell
=
True
...
...
src/infiniop/ops/attention/attention.h
0 → 100644
View file @
8d1207dd
// attention.h — per-backend descriptor boilerplate for the attention operator.
#ifndef ATTENTION_H
#define ATTENTION_H
#include "../../operator.h"
#include "info.h"
// Declares class op::attention::<NAMESPACE>::Descriptor: an opaque per-device
// attention descriptor that caches its required workspace size.
// NOTE(review): create() takes generic y_desc/x_desc parameters — this looks
// copied from a single-input operator template; confirm against the backend
// implementation that attention's create() really has this signature
// (the operator.cc entry point takes q/k/v and cache descriptors).
#define DESCRIPTOR(NAMESPACE) \
 \
namespace op::attention::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
size_t _workspace_size; \
 \
Descriptor( \
Opaque *opaque, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_workspace_size(workspace_size) {} \
 \
public: \
~Descriptor(); \
 \
size_t workspaceSize() const { return _workspace_size; } \
 \
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t y_desc, \
infiniopTensorDescriptor_t x_desc); \
}; \
}
#endif // ATTENTION_H
src/infiniop/ops/attention/operator.cc
0 → 100644
View file @
8d1207dd
#include "../../operator.h"
#include "../../../utils.h"
#include "../../../utils/check.h"
#include "../../handle.h"
#include "../../tensor.h"
#include "infiniop/ops/attention.h"
#include "infiniop/ops/causal_softmax.h"
#include "infiniop/ops/gemm.h"
#include "infiniop/ops/rearrange.h"
#include <cmath>
#include <cstdint>
// Aggregated state for the fused attention operator: the sub-operator
// descriptors it dispatches to, plus precomputed workspace bookkeeping.
struct InfiniopAttentionDescriptor {
    InfiniopDescriptor _super;                        // device type / device id header
    infiniopRearrangeDescriptor_t rearrange_desc_k;   // append k into k_cache
    infiniopRearrangeDescriptor_t rearrange_desc_v;   // append v into v_cache
    infiniopRearrangeDescriptor_t rearrange_desc_q;   // make q contiguous; nullptr when q already is
    infiniopRearrangeDescriptor_t rearrange_desc_out; // copy the temp result into out
    infiniopGemmDescriptor_t matmul_desc1;            // q x k^T (scores)
    infiniopGemmDescriptor_t matmul_desc2;            // softmax(qk) x v
    infiniopCausalSoftmaxDescriptor_t softmax_desc;   // causal softmax over the scores
    uint64_t workspace_size;                          // total bytes the caller must provide
    uint64_t rearranged_q_size;                       // bytes of the contiguous q copy (0 if unused)
    uint64_t matmul1_workspace_size;                  // gemm-1 scratch bytes
    uint64_t matmul1_tensor_size;                     // bytes of the qk score tensor
    uint64_t matmul2_workspace_size;                  // gemm-2 scratch bytes
    uint64_t matmul2_tensor_size;                     // bytes of the temp output tensor
    uint64_t softmax_workspace_size;                  // softmax scratch bytes
    uint64_t k_cache_offset;                          // byte offset of row `pos` in k_cache
    uint64_t v_cache_offset;                          // byte offset of row `pos` in v_cache
    float qk_alpha;                                   // 1/sqrt(head_dim) scale on q x k^T
};
/// Build an attention descriptor from the tensor layouts.
///
/// Expected 3-D layouts:
///   out:     [seq_len, n_q_head, head_dim] (fully contiguous)
///   q:       [n_q_head, seq_len, head_dim]
///   k, v:    [n_kv_head, seq_len, head_dim]
///   caches:  [n_kv_head, >= pos + seq_len, head_dim]
/// All inputs need a unit stride on the innermost (head_dim) axis.
/// On success *desc_ptr owns a heap-allocated InfiniopAttentionDescriptor;
/// free it with infiniopDestroyAttentionDescriptor.
__C __export infiniStatus_t infiniopCreateAttentionDescriptor(
    infiniopHandle_t handle,
    infiniopAttentionDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t q_desc,
    infiniopTensorDescriptor_t k_desc,
    infiniopTensorDescriptor_t v_desc,
    infiniopTensorDescriptor_t k_cache_desc,
    infiniopTensorDescriptor_t v_cache_desc,
    uint64_t pos) {
    if (out_desc->ndim() != 3 || q_desc->ndim() != 3 || k_desc->ndim() != 3
        || v_desc->ndim() != 3 || k_cache_desc->ndim() != 3 || v_cache_desc->ndim() != 3) {
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }
    if (!out_desc->isContiguous(0, 2)) {
        return INFINI_STATUS_BAD_TENSOR_STRIDES;
    }
    if (q_desc->strides()[2] != 1 || k_desc->strides()[2] != 1 || v_desc->strides()[2] != 1
        || k_cache_desc->strides()[2] != 1 || v_cache_desc->strides()[2] != 1) {
        return INFINI_STATUS_BAD_TENSOR_STRIDES;
    }
    uint64_t n_q_head = q_desc->shape()[0];
    uint64_t seq_len = q_desc->shape()[1];
    uint64_t head_dim = q_desc->shape()[2];
    uint64_t n_kv_head = k_desc->shape()[0];
    uint64_t total_seq_len = seq_len + pos;
    // NOTE(review): assumes n_q_head is a multiple of n_kv_head (grouped-query
    // attention); a non-divisible pair silently truncates here — consider
    // rejecting it with INFINI_STATUS_BAD_PARAM.
    uint64_t n_group = n_q_head / n_kv_head;
    // out: [seq_len, n_q_head, head_dim]
    if (out_desc->shape()[0] != seq_len || out_desc->shape()[1] != n_q_head || out_desc->shape()[2] != head_dim) {
        return INFINI_STATUS_BAD_PARAM;
    }
    // k: [n_kv_head, seq_len, head_dim]
    if (k_desc->shape()[0] != n_kv_head || k_desc->shape()[1] != seq_len || k_desc->shape()[2] != head_dim) {
        return INFINI_STATUS_BAD_PARAM;
    }
    // v: [n_kv_head, seq_len, head_dim]
    if (v_desc->shape()[0] != n_kv_head || v_desc->shape()[1] != seq_len || v_desc->shape()[2] != head_dim) {
        return INFINI_STATUS_BAD_PARAM;
    }
    // k_cache: [n_kv_head, >= total_seq_len, head_dim]
    if (k_cache_desc->shape()[0] != n_kv_head || k_cache_desc->shape()[1] < total_seq_len || k_cache_desc->shape()[2] != head_dim) {
        return INFINI_STATUS_BAD_PARAM;
    }
    // v_cache: [n_kv_head, >= total_seq_len, head_dim]
    if (v_cache_desc->shape()[0] != n_kv_head || v_cache_desc->shape()[1] < total_seq_len || v_cache_desc->shape()[2] != head_dim) {
        return INFINI_STATUS_BAD_PARAM;
    }
    // Rearrange k into k_cache: destination view has k's shape with the cache's strides.
    // NOTE(review): these temporary tensor descriptors (dst_k_desc, dst_v_desc,
    // reshaped_q_desc, ...) are never destroyed — confirm whether the create
    // calls take ownership; otherwise this leaks on every descriptor creation.
    infiniopTensorDescriptor_t dst_k_desc;
    CHECK_STATUS(infiniopCreateTensorDescriptor(&dst_k_desc, 3, k_desc->shape().data(), k_cache_desc->strides().data(), k_cache_desc->dtype()));
    infiniopRearrangeDescriptor_t rearrange_desc_k;
    CHECK_STATUS(infiniopCreateRearrangeDescriptor(handle, &rearrange_desc_k, dst_k_desc, k_desc));
    // Rearrange v into v_cache.
    infiniopTensorDescriptor_t dst_v_desc;
    CHECK_STATUS(infiniopCreateTensorDescriptor(&dst_v_desc, 3, v_desc->shape().data(), v_cache_desc->strides().data(), v_cache_desc->dtype()));
    infiniopRearrangeDescriptor_t rearrange_desc_v;
    CHECK_STATUS(infiniopCreateRearrangeDescriptor(handle, &rearrange_desc_v, dst_v_desc, v_desc));
    // Rearrange q into a contiguous workspace buffer when its head/seq axes are strided.
    infiniopRearrangeDescriptor_t rearrange_desc_q = nullptr;
    uint64_t rearranged_q_size = 0;
    infiniopTensorDescriptor_t rearranged_q_desc;
    if (!q_desc->isContiguous(0, 1)) {
        CHECK_STATUS(infiniopCreateTensorDescriptor(&rearranged_q_desc, 3, q_desc->shape().data(), nullptr, q_desc->dtype()));
        rearranged_q_size = rearranged_q_desc->numel() * infiniSizeOf(rearranged_q_desc->dtype());
        // BUGFIX: removed `rearrange_desc_q = new InfiniopDescriptor;` — the
        // create call below allocates the descriptor itself, so the `new`ed
        // object was immediately overwritten and leaked.
        CHECK_STATUS(infiniopCreateRearrangeDescriptor(handle, &rearrange_desc_q, rearranged_q_desc, q_desc));
    }
    // Matmul1: q @ k^T.
    // q: [n_q_head, seq_len, head_dim] -> [n_kv_head, n_group * seq_len, head_dim]
    infiniopTensorDescriptor_t reshaped_q_desc;
    CHECK_STATUS(infiniopCreateTensorDescriptor(&reshaped_q_desc, 3, q_desc->shape().data(), nullptr, q_desc->dtype()));
    TRANSFORM_TENSOR_DESC(reshaped_q_desc, dimSplit(0, {n_kv_head, n_group}));
    TRANSFORM_TENSOR_DESC(reshaped_q_desc, dimMerge(1, 2));
    // full_k: [n_kv_head, total_seq_len, head_dim] view of k_cache,
    // permuted to [n_kv_head, head_dim, total_seq_len] so the gemm computes q @ k^T.
    infiniopTensorDescriptor_t full_k_desc;
    uint64_t full_k_shape[3] = {n_kv_head, total_seq_len, head_dim};
    CHECK_STATUS(infiniopCreateTensorDescriptor(&full_k_desc, 3, full_k_shape, k_cache_desc->strides().data(), k_cache_desc->dtype()));
    TRANSFORM_TENSOR_DESC(full_k_desc, dimPermute({0, 2, 1}));
    // qk: [n_kv_head, n_group * seq_len, total_seq_len]
    infiniopTensorDescriptor_t qk_desc;
    uint64_t qk_shape[3] = {n_kv_head, n_group * seq_len, total_seq_len};
    CHECK_STATUS(infiniopCreateTensorDescriptor(&qk_desc, 3, qk_shape, nullptr, q_desc->dtype()));
    // Scale applied to the scores. Computed once here; the original computed a
    // local it never used and recomputed the value in the initializer below.
    float qk_alpha = 1.f / std::sqrt(float(head_dim));
    infiniopGemmDescriptor_t matmul1_desc;
    CHECK_STATUS(infiniopCreateGemmDescriptor(handle, &matmul1_desc, qk_desc, reshaped_q_desc, full_k_desc));
    uint64_t matmul1_workspace_size;
    CHECK_STATUS(infiniopGetGemmWorkspaceSize(matmul1_desc, &matmul1_workspace_size));
    uint64_t matmul1_tensor_size = qk_desc->numel() * infiniSizeOf(qk_desc->dtype());
    // CausalSoftmax over qk, viewed as [n_q_head, seq_len, total_seq_len].
    TRANSFORM_TENSOR_DESC(qk_desc, dimSplit(1, {n_group, seq_len}));
    TRANSFORM_TENSOR_DESC(qk_desc, dimMerge(0, 1));
    infiniopCausalSoftmaxDescriptor_t softmax_desc;
    CHECK_STATUS(infiniopCreateCausalSoftmaxDescriptor(handle, &softmax_desc, qk_desc, qk_desc));
    uint64_t softmax_workspace_size;
    CHECK_STATUS(infiniopGetCausalSoftmaxWorkspaceSize(softmax_desc, &softmax_workspace_size));
    // Matmul2: softmax(qk) @ v, with qk viewed back as
    // [n_kv_head, n_group * seq_len, total_seq_len].
    TRANSFORM_TENSOR_DESC(qk_desc, dimSplit(0, {n_kv_head, n_group}));
    TRANSFORM_TENSOR_DESC(qk_desc, dimMerge(1, 2));
    // full_v: [n_kv_head, total_seq_len, head_dim] view of v_cache.
    infiniopTensorDescriptor_t full_v_desc;
    uint64_t full_v_shape[3] = {n_kv_head, total_seq_len, head_dim};
    CHECK_STATUS(infiniopCreateTensorDescriptor(&full_v_desc, 3, full_v_shape, v_cache_desc->strides().data(), v_cache_desc->dtype()));
    // temp_out: [n_kv_head, n_group * seq_len, head_dim]
    infiniopTensorDescriptor_t temp_out_desc;
    uint64_t temp_out_shape[3] = {n_kv_head, n_group * seq_len, head_dim};
    CHECK_STATUS(infiniopCreateTensorDescriptor(&temp_out_desc, 3, temp_out_shape, nullptr, q_desc->dtype()));
    infiniopGemmDescriptor_t matmul2_desc;
    CHECK_STATUS(infiniopCreateGemmDescriptor(handle, &matmul2_desc, temp_out_desc, qk_desc, full_v_desc));
    uint64_t matmul2_workspace_size;
    CHECK_STATUS(infiniopGetGemmWorkspaceSize(matmul2_desc, &matmul2_workspace_size));
    uint64_t matmul2_tensor_size = temp_out_desc->numel() * infiniSizeOf(temp_out_desc->dtype());
    // Rearrange temp_out into out:
    // [n_kv_head, n_group * seq_len, head_dim] -> [n_q_head, seq_len, head_dim]
    //                                          -> [seq_len, n_q_head, head_dim]
    TRANSFORM_TENSOR_DESC(temp_out_desc, dimSplit(1, {n_group, seq_len}));
    TRANSFORM_TENSOR_DESC(temp_out_desc, dimMerge(0, 1));
    TRANSFORM_TENSOR_DESC(temp_out_desc, dimPermute({1, 0, 2}));
    infiniopRearrangeDescriptor_t rearrange_desc_out;
    CHECK_STATUS(infiniopCreateRearrangeDescriptor(handle, &rearrange_desc_out, out_desc, temp_out_desc));
    // Peak workspace: contiguous q copy + the largest of the three phases,
    // each of which keeps the qk tensor resident at the front.
    uint64_t workspace_size = rearranged_q_size
                            + std::max(std::max(matmul1_workspace_size + matmul1_tensor_size,
                                                matmul1_tensor_size + softmax_workspace_size),
                                       matmul1_tensor_size + matmul2_workspace_size + matmul2_tensor_size);
    // Byte offsets of row `pos` inside the caches (0 when prefilling from the start).
    uint64_t k_cache_offset = 0;
    if (pos > 0) {
        k_cache_offset = pos * k_cache_desc->getByteStrides()[1];
    }
    uint64_t v_cache_offset = 0;
    if (pos > 0) {
        v_cache_offset = pos * v_cache_desc->getByteStrides()[1];
    }
    *(InfiniopAttentionDescriptor **)desc_ptr = new InfiniopAttentionDescriptor{
        {handle->device, handle->device_id},
        rearrange_desc_k,
        rearrange_desc_v,
        rearrange_desc_q,
        rearrange_desc_out,
        matmul1_desc,
        matmul2_desc,
        softmax_desc,
        workspace_size,
        rearranged_q_size,
        matmul1_workspace_size,
        matmul1_tensor_size,
        matmul2_workspace_size,
        matmul2_tensor_size,
        softmax_workspace_size,
        k_cache_offset,
        v_cache_offset,
        qk_alpha,
    };
    return INFINI_STATUS_SUCCESS;
}
/// Report the number of workspace bytes the caller must pass to infiniopAttention.
__C __export infiniStatus_t infiniopGetAttentionWorkspaceSize(infiniopAttentionDescriptor_t desc, uint64_t *size) {
    auto attention_desc = (InfiniopAttentionDescriptor *)desc;
    *size = attention_desc->workspace_size;
    return INFINI_STATUS_SUCCESS;
}
/// Execute attention: append k/v to their caches, score q against the full
/// cache, apply causal softmax, weight v, and write the result to `out`.
///
/// Workspace layout after the optional contiguous-q copy:
///   [qk scores (matmul1_tensor_size)] [temp_out / sub-op scratch ...]
/// `workspace` must hold at least the size reported by
/// infiniopGetAttentionWorkspaceSize.
__C __export infiniStatus_t infiniopAttention(
    infiniopAttentionDescriptor_t desc_,
    void *workspace,
    uint64_t workspace_size,
    void *out,
    void const *q,
    void const *k,
    void const *v,
    void *k_cache,
    void *v_cache,
    void *stream) {
    auto desc = (InfiniopAttentionDescriptor *)desc_;
    void *workspace_ = workspace;
    if (workspace_size < desc->workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }
    // Bytes still available at and past `workspace_`.
    uint64_t remaining_size = workspace_size;
    // Concat k and v into k_cache / v_cache at row `pos`.
    CHECK_STATUS(infiniopRearrange(desc->rearrange_desc_k, (char *)k_cache + desc->k_cache_offset, k, stream));
    CHECK_STATUS(infiniopRearrange(desc->rearrange_desc_v, (char *)v_cache + desc->v_cache_offset, v, stream));
    // Make q contiguous in the workspace when the descriptor requires it.
    void const *_q = q;
    if (desc->rearrange_desc_q) {
        CHECK_STATUS(infiniopRearrange(desc->rearrange_desc_q, (char *)workspace_, q, stream));
        _q = workspace_;
        workspace_ = (char *)workspace_ + desc->rearranged_q_size;
        // BUGFIX: shrink the available size together with the pointer. The
        // original kept passing the full workspace_size to the sub-operators,
        // overstating the space remaining past `workspace_`.
        remaining_size -= desc->rearranged_q_size;
    }
    // matmul1: qk = qk_alpha * (q @ k^T); qk occupies the front of the
    // workspace, gemm scratch follows it.
    CHECK_STATUS(infiniopGemm(
        desc->matmul_desc1,
        (char *)workspace_ + desc->matmul1_tensor_size,
        remaining_size - desc->matmul1_tensor_size,
        workspace_, _q, k_cache,
        desc->qk_alpha, 0.0, stream));
    // Causal softmax in place over the qk scores.
    CHECK_STATUS(infiniopCausalSoftmax(
        desc->softmax_desc,
        (char *)workspace_ + desc->matmul1_tensor_size,
        remaining_size - desc->matmul1_tensor_size,
        workspace_, workspace_, stream));
    // matmul2: temp_out = softmax(qk) @ v; temp_out sits right after qk,
    // gemm scratch after temp_out.
    CHECK_STATUS(infiniopGemm(
        desc->matmul_desc2,
        (char *)workspace_ + desc->matmul1_tensor_size + desc->matmul2_tensor_size,
        remaining_size - desc->matmul1_tensor_size - desc->matmul2_tensor_size,
        (char *)workspace_ + desc->matmul1_tensor_size,
        workspace_, v_cache,
        1.0, 0.0, stream));
    // Copy temp_out into the caller's out tensor (layout change).
    CHECK_STATUS(infiniopRearrange(
        desc->rearrange_desc_out, out,
        (char *)workspace_ + desc->matmul1_tensor_size, stream));
    return INFINI_STATUS_SUCCESS;
}
/// Release all sub-operator descriptors owned by an attention descriptor,
/// then the descriptor itself.
__C __export infiniStatus_t infiniopDestroyAttentionDescriptor(infiniopAttentionDescriptor_t desc_) {
    auto desc = (InfiniopAttentionDescriptor *)desc_;
    // The q-rearrange descriptor only exists when q was non-contiguous.
    if (desc->rearrange_desc_q) {
        CHECK_STATUS(infiniopDestroyRearrangeDescriptor(desc->rearrange_desc_q));
    }
    CHECK_STATUS(infiniopDestroyRearrangeDescriptor(desc->rearrange_desc_k));
    CHECK_STATUS(infiniopDestroyRearrangeDescriptor(desc->rearrange_desc_v));
    CHECK_STATUS(infiniopDestroyRearrangeDescriptor(desc->rearrange_desc_out));
    CHECK_STATUS(infiniopDestroyGemmDescriptor(desc->matmul_desc1));
    CHECK_STATUS(infiniopDestroyGemmDescriptor(desc->matmul_desc2));
    CHECK_STATUS(infiniopDestroyCausalSoftmaxDescriptor(desc->softmax_desc));
    delete desc;
    return INFINI_STATUS_SUCCESS;
}
src/infiniop/ops/causal_softmax/cpu/causal_softmax_cpu.cc
View file @
8d1207dd
...
...
@@ -48,7 +48,7 @@ infiniStatus_t causal_softmax(const CausalSoftmaxInfo *info, T *y, const T *x) {
if
constexpr
(
std
::
is_same
<
T
,
fp16_t
>::
value
)
{
y_
[
j
*
info
->
y_stride_j
]
=
utils
::
cast
<
fp16_t
>
(
utils
::
cast
<
float
>
(
y_
[
j
*
info
->
y_stride_j
])
/
sum
);
}
else
{
y_
[
j
*
info
->
y_stride_j
]
=
y_
[
y_offset
+
j
*
info
->
y_stride_j
]
/
sum
;
y_
[
j
*
info
->
y_stride_j
]
=
y_
[
j
*
info
->
y_stride_j
]
/
sum
;
}
}
}
...
...
src/infiniop/reduce/cuda/reduce.cuh
View file @
8d1207dd
...
...
@@ -18,7 +18,7 @@ __device__ __forceinline__ Tcompute sumSquared(const Tdata *data_ptr, size_t cou
// Each thread computes its partial sum
for
(
size_t
i
=
threadIdx
.
x
;
i
<
count
;
i
+=
BLOCK_SIZE
)
{
ss
+=
Tcompute
(
data_ptr
[
i
]
*
data_ptr
[
i
]);
ss
+=
Tcompute
(
data_ptr
[
i
]
)
*
Tcompute
(
data_ptr
[
i
]);
}
// Use CUB block-level reduction
...
...
src/infiniop/tensor.h
View file @
8d1207dd
...
...
@@ -2,9 +2,19 @@
#define __INFINIOP_TENSOR_H__
#include "infiniop/tensor_descriptor.h"
#include "../utils.h"
#include <string>
#include <vector>
// Apply a shape transform (dimMerge / dimSplit / dimPermute) to a tensor
// descriptor variable in place: the transform returns utils::Result<...>;
// CHECK_RESULT propagates the error on failure, and on success the variable
// is replaced with the new descriptor via .take().
// NOTE(review): __TENSOR_DESC__ / __RESULT__ are double-underscore names,
// which are reserved for the implementation in C++ — consider renaming.
// NOTE(review): the previous descriptor pointer is overwritten without being
// destroyed — confirm ownership at the call sites (possible leak).
#define TRANSFORM_TENSOR_DESC(__TENSOR_DESC__, __OP__) \
do { \
auto __RESULT__ = __TENSOR_DESC__->__OP__; \
CHECK_RESULT(__RESULT__); \
__TENSOR_DESC__ = __RESULT__.take(); \
} while (0)
struct
InfiniopTensorDescriptor
{
private:
// Datatype
...
...
@@ -32,9 +42,9 @@ public:
bool
hasBroadcastDim
()
const
;
std
::
vector
<
size_t
>
getBroadcastDim
()
const
;
infiniopTensorDescriptor_t
dimMerge
(
size_t
dim_start
,
size_t
dim_end
)
const
;
infiniopTensorDescriptor_t
dimSplit
(
size_t
axis
,
const
std
::
vector
<
size_t
>
&
dims
)
const
;
infiniopTensorDescriptor_t
dimPermute
(
const
std
::
vector
<
size_t
>
&
order
)
const
;
utils
::
Result
<
infiniopTensorDescriptor_t
>
dimMerge
(
size_t
dim_start
,
size_t
dim_end
)
const
;
utils
::
Result
<
infiniopTensorDescriptor_t
>
dimSplit
(
size_t
axis
,
const
std
::
vector
<
size_t
>
&
dims
)
const
;
utils
::
Result
<
infiniopTensorDescriptor_t
>
dimPermute
(
const
std
::
vector
<
size_t
>
&
order
)
const
;
std
::
string
toString
()
const
;
};
...
...
src/infiniop/tensor_descriptor.cc
View file @
8d1207dd
...
...
@@ -12,7 +12,7 @@ __C __export infiniStatus_t infiniopCreateTensorDescriptor(infiniopTensorDescrip
std
::
vector
<
ptrdiff_t
>
strides
(
ndim
);
ptrdiff_t
dsize
=
1
;
if
(
ndim
>
0
)
{
for
(
size_
t
i
=
ndim
-
1
;
i
>=
0
;
i
--
)
{
for
(
in
t
i
=
(
int
)
ndim
-
1
;
i
>=
0
;
i
--
)
{
strides
[
i
]
=
dsize
;
dsize
*=
shape_
[
i
];
}
...
...
@@ -104,10 +104,8 @@ std::vector<size_t> InfiniopTensorDescriptor::getBroadcastDim() const {
return
res
;
}
infiniopTensorDescriptor_t
InfiniopTensorDescriptor
::
dimMerge
(
size_t
dim_start
,
size_t
dim_end
)
const
{
if
(
dim_start
>
dim_end
||
dim_end
>=
ndim
())
{
return
nullptr
;
}
utils
::
Result
<
infiniopTensorDescriptor_t
>
InfiniopTensorDescriptor
::
dimMerge
(
size_t
dim_start
,
size_t
dim_end
)
const
{
CHECK_OR_RETURN
(
dim_start
<=
dim_end
&&
dim_end
<
ndim
(),
INFINI_STATUS_BAD_PARAM
);
size_t
new_ndim
=
ndim
()
-
(
dim_end
-
dim_start
);
std
::
vector
<
size_t
>
new_shape
(
new_ndim
);
...
...
@@ -120,9 +118,7 @@ infiniopTensorDescriptor_t InfiniopTensorDescriptor::dimMerge(size_t dim_start,
index
++
;
}
if
(
!
isContiguous
(
dim_start
,
dim_end
))
{
return
nullptr
;
}
CHECK_OR_RETURN
(
isContiguous
(
dim_start
,
dim_end
),
INFINI_STATUS_BAD_PARAM
);
new_shape
[
index
]
=
1
;
for
(
size_t
i
=
dim_start
;
i
<=
dim_end
;
i
++
)
{
...
...
@@ -138,15 +134,15 @@ infiniopTensorDescriptor_t InfiniopTensorDescriptor::dimMerge(size_t dim_start,
index
++
;
}
return
new
InfiniopTensorDescriptor
(
_dtype
,
new_ndim
,
new_shape
.
data
(),
new_strides
.
data
());
return
utils
::
Result
<
infiniopTensorDescriptor_t
>
(
new
InfiniopTensorDescriptor
(
_dtype
,
new_ndim
,
new_shape
.
data
(),
new_strides
.
data
()));
}
infiniopTensorDescriptor_t
InfiniopTensorDescriptor
::
dimSplit
(
size_t
axis
,
const
std
::
vector
<
size_t
>
&
dims
)
const
{
utils
::
Result
<
infiniopTensorDescriptor_t
>
InfiniopTensorDescriptor
::
dimSplit
(
size_t
axis
,
const
std
::
vector
<
size_t
>
&
dims
)
const
{
size_t
ndim_
=
ndim
();
if
(
dim
(
axis
)
!=
std
::
accumulate
(
dims
.
begin
(),
dims
.
end
(),
(
size_t
)
1
,
std
::
multiplies
<
size_t
>
()))
{
return
nullptr
;
}
CHECK_OR_RETURN
(
dim
(
axis
)
==
std
::
accumulate
(
dims
.
begin
(),
dims
.
end
(),
(
size_t
)
1
,
std
::
multiplies
<
size_t
>
()),
INFINI_STATUS_BAD_PARAM
);
size_t
new_ndim
=
ndim_
+
dims
.
size
()
-
1
;
std
::
vector
<
size_t
>
new_shape
(
new_ndim
);
...
...
@@ -168,24 +164,22 @@ infiniopTensorDescriptor_t InfiniopTensorDescriptor::dimSplit(size_t axis, const
index
++
;
}
return
new
InfiniopTensorDescriptor
(
_dtype
,
new_ndim
,
new_shape
.
data
(),
new_strides
.
data
());
return
utils
::
Result
<
infiniopTensorDescriptor_t
>
(
new
InfiniopTensorDescriptor
(
_dtype
,
new_ndim
,
new_shape
.
data
(),
new_strides
.
data
()));
}
infiniopTensorDescriptor_t
InfiniopTensorDescriptor
::
dimPermute
(
const
std
::
vector
<
size_t
>
&
order
)
const
{
utils
::
Result
<
infiniopTensorDescriptor_t
>
InfiniopTensorDescriptor
::
dimPermute
(
const
std
::
vector
<
size_t
>
&
order
)
const
{
auto
ndim_
=
ndim
();
if
(
order
.
size
()
!=
ndim_
)
{
return
nullptr
;
}
CHECK_OR_RETURN
(
order
.
size
()
==
ndim_
,
INFINI_STATUS_BAD_PARAM
);
std
::
vector
<
size_t
>
new_shape
(
ndim_
);
std
::
vector
<
ptrdiff_t
>
new_strides
(
ndim_
);
for
(
size_t
i
=
0
;
i
<
ndim_
;
i
++
)
{
if
(
std
::
find
(
order
.
begin
(),
order
.
end
(),
i
)
==
order
.
end
())
{
return
nullptr
;
}
CHECK_OR_RETURN
(
std
::
find
(
order
.
begin
(),
order
.
end
(),
i
)
!=
order
.
end
(),
INFINI_STATUS_BAD_PARAM
);
new_shape
[
i
]
=
dim
(
order
[
i
]);
new_strides
[
i
]
=
stride
(
order
[
i
]);
}
return
new
InfiniopTensorDescriptor
(
_dtype
,
ndim_
,
new_shape
.
data
(),
new_strides
.
data
());
return
utils
::
Result
<
infiniopTensorDescriptor_t
>
(
new
InfiniopTensorDescriptor
(
_dtype
,
ndim_
,
new_shape
.
data
(),
new_strides
.
data
()));
}
std
::
string
InfiniopTensorDescriptor
::
toString
()
const
{
...
...
test/infiniop/attention.py
View file @
8d1207dd
from
ctypes
import
POINTER
,
Structure
,
c_int32
,
c_uint64
,
c_void_p
,
c_float
,
c_bool
from
ctypes
import
POINTER
,
Structure
,
c_int32
,
c_uint64
,
c_void_p
import
ctypes
import
sys
import
os
sys
.
path
.
insert
(
0
,
os
.
path
.
abspath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
".."
,
".."
)))
from
operatorspy
import
(
from
libinfiniop
import
(
open_lib
,
to_tensor
,
CTensor
,
DeviceEnum
,
infiniopHandle_t
,
infiniopTensorDescriptor_t
,
create_handle
,
destroy_handle
,
check_error
,
rearrange_tensor
,
create_workspace
,
get_args
,
get_test_devices
,
test_operator
,
debug
,
get_tolerance
,
profile_operation
,
)
from
operatorspy.tests.test_utils
import
get_args
import
torch
import
torch.nn.functional
as
F
class
AttentionDescriptor
(
Structure
):
...
...
@@ -95,13 +95,13 @@ def test(
pos
,
k_cache_buf_len
,
v_cache_buf_len
,
dtype
=
torch
.
float16
,
q_stride
=
None
,
k_stride
=
None
,
v_stride
=
None
,
k_cache_stride
=
None
,
v_cache_stride
=
None
,
sync
=
None
dtype
=
torch
.
float16
,
sync
=
None
,
):
print
(
f
"Testing Attention on
{
torch_device
}
with n_q_head:
{
n_q_head
}
n_kv_head:
{
n_kv_head
}
seq_len:
{
seq_len
}
head_dim:
{
head_dim
}
pos:
{
pos
}
"
...
...
@@ -140,7 +140,7 @@ def test(
v_tensor
=
to_tensor
(
v
,
lib
)
k_cache_tensor
=
to_tensor
(
k_cache
,
lib
)
v_cache_tensor
=
to_tensor
(
v_cache
,
lib
)
if
sync
is
not
None
:
sync
()
...
...
@@ -160,12 +160,15 @@ def test(
)
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
out_tensor
.
descriptor
.
contents
.
invalidate
()
q_tensor
.
descriptor
.
contents
.
invalidate
()
k_tensor
.
descriptor
.
contents
.
invalidate
()
v_tensor
.
descriptor
.
contents
.
invalidate
()
k_cache_tensor
.
descriptor
.
contents
.
invalidate
()
v_cache_tensor
.
descriptor
.
contents
.
invalidate
()
for
tensor
in
[
out_tensor
,
q_tensor
,
k_tensor
,
v_tensor
,
k_cache_tensor
,
v_cache_tensor
,
]:
tensor
.
destroyDesc
(
lib
)
workspace_size
=
c_uint64
(
0
)
check_error
(
...
...
@@ -173,152 +176,52 @@ def test(
)
workspace
=
create_workspace
(
workspace_size
.
value
,
out
.
device
)
check_error
(
lib
.
infiniopAttention
(
descriptor
,
workspace
.
data_ptr
()
if
workspace
is
not
None
else
None
,
workspace_size
.
value
,
out_tensor
.
data
,
q_tensor
.
data
,
k_tensor
.
data
,
v_tensor
.
data
,
k_cache_tensor
.
data
,
v_cache_tensor
.
data
,
None
,
def
lib_attention
():
check_error
(
lib
.
infiniopAttention
(
descriptor
,
workspace
.
data_ptr
()
if
workspace
is
not
None
else
None
,
workspace_size
.
value
,
out_tensor
.
data
,
q_tensor
.
data
,
k_tensor
.
data
,
v_tensor
.
data
,
k_cache_tensor
.
data
,
v_cache_tensor
.
data
,
None
,
)
)
)
assert
torch
.
allclose
(
out
,
ans
,
atol
=
1e-4
,
rtol
=
1e-2
)
lib_attention
(
)
check_error
(
lib
.
infiniopDestroyAttentionDescriptor
(
descriptor
))
def
test_cpu
(
lib
,
test_cases
):
device
=
DeviceEnum
.
DEVICE_CPU
handle
=
create_handle
(
lib
,
device
)
for
(
n_q_head
,
n_kv_head
,
seq_len
,
head_dim
,
pos
,
k_cache_buf_len
,
v_cache_buf_len
,
dtype
,
q_stride
,
k_stride
,
v_stride
,
k_cache_stride
,
v_cache_stride
,
)
in
test_cases
:
test
(
lib
,
handle
,
"cpu"
,
n_q_head
,
n_kv_head
,
seq_len
,
head_dim
,
pos
,
k_cache_buf_len
,
v_cache_buf_len
,
dtype
,
q_stride
,
k_stride
,
v_stride
,
k_cache_stride
,
v_cache_stride
,
)
# Validate results
atol
,
rtol
=
get_tolerance
(
_TOLERANCE_MAP
,
dtype
)
if
DEBUG
:
debug
(
out
,
ans
,
atol
=
atol
,
rtol
=
rtol
)
assert
torch
.
allclose
(
out
,
ans
,
atol
=
atol
,
rtol
=
rtol
)
destroy_handle
(
lib
,
handle
)
def
test_cuda
(
lib
,
test_cases
):
device
=
DeviceEnum
.
DEVICE_CUDA
handle
=
create_handle
(
lib
,
device
)
for
(
n_q_head
,
n_kv_head
,
seq_len
,
head_dim
,
pos
,
k_cache_buf_len
,
v_cache_buf_len
,
dtype
,
q_stride
,
k_stride
,
v_stride
,
k_cache_stride
,
v_cache_stride
,
)
in
test_cases
:
test
(
lib
,
handle
,
"cuda"
,
n_q_head
,
n_kv_head
,
seq_len
,
head_dim
,
pos
,
k_cache_buf_len
,
v_cache_buf_len
,
dtype
,
q_stride
,
k_stride
,
v_stride
,
k_cache_stride
,
v_cache_stride
,
)
destroy_handle
(
lib
,
handle
)
def
test_bang
(
lib
,
test_cases
):
import
torch_mlu
device
=
DeviceEnum
.
DEVICE_BANG
handle
=
create_handle
(
lib
,
device
)
for
(
n_q_head
,
n_kv_head
,
seq_len
,
head_dim
,
pos
,
k_cache_buf_len
,
v_cache_buf_len
,
dtype
,
q_stride
,
k_stride
,
v_stride
,
k_cache_stride
,
v_cache_stride
,
)
in
test_cases
:
test
(
lib
,
handle
,
"mlu"
,
n_q_head
,
n_kv_head
,
seq_len
,
head_dim
,
pos
,
k_cache_buf_len
,
v_cache_buf_len
,
dtype
,
q_stride
,
k_stride
,
v_stride
,
k_cache_stride
,
v_cache_stride
,
)
destroy_handle
(
lib
,
handle
)
# Profiling workflow
if
PROFILE
:
# fmt: off
profile_operation
(
"PyTorch"
,
lambda
:
attention
(
q
,
k
,
v
,
k_cache
,
v_cache
,
pos
),
torch_device
,
NUM_PRERUN
,
NUM_ITERATIONS
)
profile_operation
(
" lib"
,
lambda
:
lib_attention
(),
torch_device
,
NUM_PRERUN
,
NUM_ITERATIONS
)
# fmt: on
check_error
(
lib
.
infiniopDestroyAttentionDescriptor
(
descriptor
))
if
__name__
==
"__main__"
:
_TENSOR_DTYPES
=
[
torch
.
float16
,
torch
.
float32
]
# Tolerance map for different data types
_TOLERANCE_MAP
=
{
torch
.
float16
:
{
"atol"
:
1e-4
,
"rtol"
:
1e-2
},
torch
.
float32
:
{
"atol"
:
1e-6
,
"rtol"
:
1e-4
},
}
DEBUG
=
False
PROFILE
=
False
NUM_PRERUN
=
10
NUM_ITERATIONS
=
1000
test_cases
=
[
# prefill
(
...
...
@@ -329,7 +232,6 @@ if __name__ == "__main__":
0
,
# pos
2048
,
# k_cache_buf_len
2048
,
# v_cache_buf_len
torch
.
float16
,
# dtype
[
64
,
2560
,
1
],
# q_stride
[
64
,
2560
,
1
],
# k_stride
[
64
,
2560
,
1
],
# v_stride
...
...
@@ -345,7 +247,6 @@ if __name__ == "__main__":
3
,
# pos
2048
,
# k_cache_buf_len
2048
,
# v_cache_buf_len
torch
.
float16
,
# dtype
[
64
,
2560
,
1
],
# q_stride
[
64
,
2560
,
1
],
# k_stride
[
64
,
2560
,
1
],
# v_stride
...
...
@@ -361,7 +262,6 @@ if __name__ == "__main__":
1
,
# pos
8
,
# k_cache_buf_len
8
,
# v_cache_buf_len
torch
.
float16
,
# dtype
None
,
# q_stride
None
,
# k_stride
None
,
# v_stride
...
...
@@ -410,12 +310,13 @@ if __name__ == "__main__":
infiniopAttentionDescriptor_t
,
]
if
args
.
cpu
:
test_cpu
(
lib
,
test_cases
)
if
args
.
cuda
:
test_cuda
(
lib
,
test_cases
)
if
args
.
bang
:
test_bang
(
lib
,
test_cases
)
if
not
(
args
.
cpu
or
args
.
cuda
or
args
.
bang
):
test_cpu
(
lib
,
test_cases
)
# Configure testing options
DEBUG
=
args
.
debug
PROFILE
=
args
.
profile
NUM_PRERUN
=
args
.
num_prerun
NUM_ITERATIONS
=
args
.
num_iterations
# Execute tests
for
device
in
get_test_devices
(
args
):
test_operator
(
lib
,
device
,
test
,
test_cases
,
_TENSOR_DTYPES
)
print
(
"
\033
[92mTest passed!
\033
[0m"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment