jerrrrry / infinicore · Commits · ce2c4813

Commit ce2c4813, authored Jun 03, 2025 by Catheriany
issue/228: clip operator update
Parents: 6bb801f6, 6ca0e313
Changes: 30

Showing 10 changed files with 578 additions and 64 deletions (+578 -64), page 1 of 2.
src/infiniop/ops/rope/operator.cc                        +11  -20
src/infiniop/ops/swiglu/ascend/swiglu_ascend.cc           +0   -6
src/infiniop/ops/swiglu/ascend/swiglu_ascend.h            +6   -0
src/infiniop/ops/swiglu/ascend/swiglu_ascend_kernel.cpp  +46  -36
src/utils.h                                               +8   -0
src/utils/custom_types.cc                                 +1   -1
test/infiniop-test/test_generate/testcases/clip.py      +242   -0
test/infiniop/attention.py                               +15   -1
test/infiniop/clip.py                                   +246   -0
test/infiniop/rope.py                                     +3   -0
src/infiniop/ops/rope/operator.cc

@@ -8,6 +8,9 @@
 #ifdef ENABLE_CUDA_API
 #include "cuda/rope_cuda.cuh"
 #endif
+#ifdef ENABLE_ASCEND_API
+#include "ascend/rope_ascend.h"
+#endif

 __C infiniStatus_t infiniopCreateRoPEDescriptor(
     infiniopHandle_t handle,

@@ -43,12 +46,8 @@ __C infiniStatus_t infiniopCreateRoPEDescriptor(
             pos_ids,
             sin_table,
             cos_table);
     }
 #endif
-#ifdef ENABLE_ASCEND_NPU
-    case DevAscendNpu: {
-        return ascendCreateRoPEDescriptor((AscendHandle_t)handle,
-                                          (RoPEAscendDescriptor_t *)desc_ptr,
-                                          t, pos_ids, sin_table, cos_table);
-    }
+#ifdef ENABLE_ASCEND_API
+    CREATE(INFINI_DEVICE_ASCEND, ascend);
 #endif
 #ifdef ENABLE_METAX_GPU
     case DevMetaxGpu: {

@@ -90,10 +89,8 @@ __C infiniStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc,
         return bangGetRoPEWorkspaceSize((RoPEBangDescriptor_t)desc, size);
     }
 #endif
-#ifdef ENABLE_ASCEND_NPU
-    case DevAscendNpu: {
-        return ascendGetRoPEWorkspaceSize((RoPEAscendDescriptor_t)desc, size);
-    }
+#ifdef ENABLE_ASCEND_API
+    GET(INFINI_DEVICE_ASCEND, ascend);
 #endif
 #ifdef ENABLE_METAX_GPU
     case DevMetaxGpu: {

@@ -141,12 +138,8 @@ __C infiniStatus_t infiniopRoPE(
             t,
             pos_ids,
             sin_table,
             cos_table,
             stream);
     }
 #endif
-#ifdef ENABLE_ASCEND_NPU
-    case DevAscendNpu: {
-        return ascendRoPE((RoPEAscendDescriptor_t)desc, workspace, workspace_size,
-                          t, pos_ids, sin_table, cos_table, stream);
-    }
+#ifdef ENABLE_ASCEND_API
+    CALCULATE(INFINI_DEVICE_ASCEND, ascend);
 #endif
 #ifdef ENABLE_METAX_GPU
     case DevMetaxGpu: {

@@ -187,10 +180,8 @@ infiniopDestroyRoPEDescriptor(infiniopRoPEDescriptor_t desc) {
         return bangDestroyRoPEDescriptor((RoPEBangDescriptor_t)desc);
     }
 #endif
-#ifdef ENABLE_ASCEND_NPU
-    case DevAscendNpu: {
-        return ascendDestroyRoPEDescriptor((RoPEAscendDescriptor_t)desc);
-    }
+#ifdef ENABLE_ASCEND_API
+    DELETE(INFINI_DEVICE_ASCEND, ascend);
 #endif
 #ifdef ENABLE_METAX_GPU
     case DevMetaxGpu: {
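The hand-rolled ENABLE_ASCEND_NPU case blocks above are replaced by the project's CREATE/GET/CALCULATE/DELETE dispatch macros, whose definitions are not part of this diff. As a rough sketch only (the namespace layout and signature are assumptions, inferred from the Descriptor::create / Descriptor::calculate pattern visible in the swiglu files below), such a CREATE macro could expand along these lines:

// Hypothetical expansion; the real macro is defined elsewhere in operator.cc.
// It assumes each backend exposes op::rope::<ns>::Descriptor::create().
#define CREATE(CASE, NAMESPACE)                                              \
    case CASE: {                                                             \
        return op::rope::NAMESPACE::Descriptor::create(                      \
            handle,                                                          \
            reinterpret_cast<op::rope::NAMESPACE::Descriptor **>(desc_ptr),  \
            t, pos_ids, sin_table, cos_table);                               \
    }

The point of the refactor is that each backend then needs only one line per entry point, guarded by its ENABLE_*_API flag.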
src/infiniop/ops/swiglu/ascend/swiglu_ascend.cc

@@ -26,12 +26,6 @@ infiniStatus_t Descriptor::create(infiniopHandle_t handle, Descriptor **desc_ptr
     return INFINI_STATUS_SUCCESS;
 }

-extern "C" infiniStatus_t swiglu_kernel_launch(
-    void *c, void *a, void *b, infiniDtype_t dtype,
-    size_t batch, size_t seq, size_t hd,
-    ptrdiff_t stride_batch_c, ptrdiff_t stride_batch_a, ptrdiff_t stride_batch_b,
-    ptrdiff_t stride_seq_c, ptrdiff_t stride_seq_a, ptrdiff_t stride_seq_b,
-    void *stream);
-
 infiniStatus_t Descriptor::calculate(
     void *workspace,
     size_t workspace_size,
     void *c,
src/infiniop/ops/swiglu/ascend/swiglu_ascend.h

@@ -69,5 +69,11 @@ public:
         void *stream) const;
 };

+extern "C" infiniStatus_t swiglu_kernel_launch(
+    void *c, void *a, void *b, infiniDtype_t dtype,
+    size_t batch, size_t seq, size_t hd,
+    ptrdiff_t stride_batch_c, ptrdiff_t stride_batch_a, ptrdiff_t stride_batch_b,
+    ptrdiff_t stride_seq_c, ptrdiff_t stride_seq_a, ptrdiff_t stride_seq_b,
+    void *stream);
+
 } // namespace op::swiglu::ascend

 #endif // __ACLNN_SWIGLU_H__
src/infiniop/ops/swiglu/ascend/swiglu_ascend_kernel.cpp

@@ -6,15 +6,20 @@ template <typename T>
 class SwigluKernel {
 public:
     __aicore__ inline SwigluKernel() {}
-    __aicore__ inline void init(GM_ADDR c, GM_ADDR a, GM_ADDR b,
-                                int64_t batch_, int64_t seq, int64_t hd,
-                                int64_t stride_batch_c, int64_t stride_batch_a, int64_t stride_batch_b,
-                                int64_t stride_seq_c, int64_t stride_seq_a, int64_t stride_seq_b);
+    __aicore__ inline void init(GM_ADDR c, GM_ADDR a, GM_ADDR b,
+                                size_t batch_, size_t seq, size_t hd,
+                                ptrdiff_t stride_batch_c, ptrdiff_t stride_batch_a, ptrdiff_t stride_batch_b,
+                                ptrdiff_t stride_seq_c, ptrdiff_t stride_seq_a, ptrdiff_t stride_seq_b);
     __aicore__ inline void process();

 private:
-    __aicore__ inline void copyIn(int64_t i);
-    __aicore__ inline void compute(int64_t i);
-    __aicore__ inline void copyOut(int64_t i);
+    __aicore__ inline void copyIn(size_t i);
+    __aicore__ inline void compute(size_t i);
+    __aicore__ inline void copyOut(size_t i);

 private:
     GlobalTensor<T> _c_gm, _a_gm, _b_gm;

@@ -23,16 +28,21 @@ private:
     TPipe _pipe;
     float _beta_value = 1.0f;
-    int64_t _block_idx, _tile_len, _copy_len,
-        _batch, _seq_len, _hidden_size,
-        _stride_seq_a, _stride_seq_b, _stride_seq_c;
+    size_t _block_idx, _tile_len, _copy_len,
+        _batch, _seq_len, _hidden_size,
+        _stride_seq_a, _stride_seq_b, _stride_seq_c;
     int64_t _stride_batch_a = 1, _stride_batch_b = 1, _stride_batch_c = 1;
 };

 template <typename T>
-__aicore__ inline void SwigluKernel<T>::init(GM_ADDR c, GM_ADDR a, GM_ADDR b,
-                                             int64_t batch_, int64_t seq, int64_t hd,
-                                             int64_t stride_batch_c, int64_t stride_batch_a, int64_t stride_batch_b,
-                                             int64_t stride_seq_c, int64_t stride_seq_a, int64_t stride_seq_b) {
+__aicore__ inline void SwigluKernel<T>::init(GM_ADDR c, GM_ADDR a, GM_ADDR b,
+                                             size_t batch_, size_t seq, size_t hd,
+                                             ptrdiff_t stride_batch_c, ptrdiff_t stride_batch_a, ptrdiff_t stride_batch_b,
+                                             ptrdiff_t stride_seq_c, ptrdiff_t stride_seq_a, ptrdiff_t stride_seq_b) {
     // Init Shape & Stride Variables
     _batch = batch_;
     _seq_len = seq;

@@ -46,7 +56,7 @@ __aicore__ inline void SwigluKernel<T>::init(GM_ADDR c, GM_ADDR a, GM_ADDR b, in
     _block_idx = GetBlockIdx();
     _tile_len = _block_idx < (_hidden_size % BLOCK_NUM) ? (_hidden_size / BLOCK_NUM) + 1 : (_hidden_size / BLOCK_NUM);
-    _copy_len = (_tile_len * sizeof(T)) % BYTE_ALIGN == 0 ? _tile_len : (_tile_len * sizeof(T) + (BYTE_ALIGN - _tile_len * sizeof(T) % BYTE_ALIGN)) / sizeof(T);
+    _copy_len = alignTileLen<T>(_tile_len, BYTE_ALIGN);
     // Set global tensor
     _a_gm.SetGlobalBuffer((__gm__ T *)a);

@@ -60,7 +70,7 @@ __aicore__ inline void SwigluKernel<T>::init(GM_ADDR c, GM_ADDR a, GM_ADDR b, in
 }

 template <typename T>
-__aicore__ inline void SwigluKernel<T>::copyIn(int64_t i) {
+__aicore__ inline void SwigluKernel<T>::copyIn(size_t i) {
     // Alloc tensor from queue memory
     LocalTensor<T> aLocal = _in_queue_a.AllocTensor<T>();
     LocalTensor<T> bLocal = _in_queue_b.AllocTensor<T>();

@@ -68,8 +78,8 @@ __aicore__ inline void SwigluKernel<T>::copyIn(int64_t i) {
     auto batch_idx = _batch == 1 ? 0 : i / _seq_len;
     auto seq_idx = _batch == 1 ? i : i % _seq_len;
-    int64_t idxa = batch_idx * _stride_batch_a + seq_idx * _stride_seq_a + _block_idx * _tile_len;
-    int64_t idxb = batch_idx * _stride_batch_b + seq_idx * _stride_seq_b + _block_idx * _tile_len;
+    ptrdiff_t idxa = batch_idx * _stride_batch_a + seq_idx * _stride_seq_a + _block_idx * _tile_len;
+    ptrdiff_t idxb = batch_idx * _stride_batch_b + seq_idx * _stride_seq_b + _block_idx * _tile_len;
     // Copy process_th tile from global tensor to local tensor
     DataCopy(aLocal, _a_gm[idxa], _copy_len);
     DataCopy(bLocal, _b_gm[idxb], _copy_len);

@@ -80,7 +90,7 @@ __aicore__ inline void SwigluKernel<T>::copyIn(int64_t i) {
 }

 template <typename T>
-__aicore__ inline void SwigluKernel<T>::compute(int64_t i) {
+__aicore__ inline void SwigluKernel<T>::compute(size_t i) {
     // Deque input tensors from VECIN queue
     LocalTensor<T> aLocal = _in_queue_a.DeQue<T>();
     LocalTensor<T> bLocal = _in_queue_b.DeQue<T>();

@@ -94,12 +104,12 @@ __aicore__ inline void SwigluKernel<T>::compute(int64_t i) {
 }

 template <typename T>
-__aicore__ inline void SwigluKernel<T>::copyOut(int64_t i) {
+__aicore__ inline void SwigluKernel<T>::copyOut(size_t i) {
     // Deque output tensor from VECOUT queue
     LocalTensor<T> cLocal = _out_queue_c.DeQue<T>();
     auto batch_idx = _batch == 1 ? 0 : i / _seq_len;
     auto seq_idx = _batch == 1 ? i : i % _seq_len;
-    int64_t idxc = batch_idx * _stride_batch_c + seq_idx * _stride_seq_c + _block_idx * _tile_len;
+    ptrdiff_t idxc = batch_idx * _stride_batch_c + seq_idx * _stride_seq_c + _block_idx * _tile_len;
     // Copy progress_th tile from local tensor to global tensor
     if (_tile_len * sizeof(T) % BYTE_ALIGN != 0) {
         DataCopyExtParams dcep = {1, static_cast<uint32_t>(_tile_len * sizeof(T)), 0, 0, 0};

@@ -113,28 +123,28 @@ __aicore__ inline void SwigluKernel<T>::copyOut(int64_t i) {
 template <typename T>
 __aicore__ inline void SwigluKernel<T>::process() {
-    for (int64_t i = 0; i < _batch * _seq_len; ++i) {
+    for (size_t i = 0; i < _batch * _seq_len; ++i) {
         copyIn(i);
         compute(i);
         copyOut(i);
     }
 }

-#define DEFINE_SWIGLU_KERNEL(KERNEL_NAME, TYPE)                             \
-    __global__ __aicore__ void KERNEL_NAME(GM_ADDR c, GM_ADDR a, GM_ADDR b, \
-                                           int64_t batch, int64_t seq, int64_t hd, \
-                                           int64_t stride_batch_c,          \
-                                           int64_t stride_batch_a,          \
-                                           int64_t stride_batch_b,          \
-                                           int64_t stride_seq_c,            \
-                                           int64_t stride_seq_a,            \
-                                           int64_t stride_seq_b) {          \
-        SwigluKernel<TYPE> op;                                              \
-        op.init(c, a, b,                                                    \
-                batch, seq, hd,                                             \
-                stride_batch_c, stride_batch_a, stride_batch_b,             \
-                stride_seq_c, stride_seq_a, stride_seq_b);                  \
-        op.process();                                                       \
+#define DEFINE_SWIGLU_KERNEL(KERNEL_NAME, TYPE)                             \
+    __global__ __aicore__ void KERNEL_NAME(GM_ADDR c, GM_ADDR a, GM_ADDR b, \
+                                           size_t batch, size_t seq, size_t hd, \
+                                           ptrdiff_t stride_batch_c,        \
+                                           ptrdiff_t stride_batch_a,        \
+                                           ptrdiff_t stride_batch_b,        \
+                                           ptrdiff_t stride_seq_c,          \
+                                           ptrdiff_t stride_seq_a,          \
+                                           ptrdiff_t stride_seq_b) {        \
+        SwigluKernel<TYPE> op;                                              \
+        op.init(c, a, b,                                                    \
+                batch, seq, hd,                                             \
+                stride_batch_c, stride_batch_a, stride_batch_b,             \
+                stride_seq_c, stride_seq_a, stride_seq_b);                  \
+        op.process();                                                       \
     }

 DEFINE_SWIGLU_KERNEL(swiglu_kernel_half, half)

@@ -152,9 +162,9 @@ extern "C" infiniStatus_t swiglu_kernel_launch(
     case DTYPE_ENUM:                                                        \
         KERNEL_NAME<<<BLOCK_NUM, nullptr, stream>>>(                        \
             c, a, b,                                                        \
-            static_cast<int64_t>(batch),                                    \
-            static_cast<int64_t>(seq),                                      \
-            static_cast<int64_t>(hd),                                       \
+            batch,                                                          \
+            seq,                                                            \
+            hd,                                                             \
             stride_batch_c, stride_batch_a, stride_batch_b,                 \
             stride_seq_c, stride_seq_a, stride_seq_b);                      \
         return INFINI_STATUS_SUCCESS;
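The new _copy_len line delegates to an alignTileLen<T> helper whose definition is outside this hunk. Judging from the inline expression it replaces, it rounds a tile length up so that the tile's byte size is a multiple of BYTE_ALIGN; a sketch consistent with the removed code would be:

// Sketch inferred from the expression this helper replaces; the actual
// definition is not shown in this diff.
template <typename T>
__aicore__ inline size_t alignTileLen(size_t tile_len, size_t byte_align) {
    size_t bytes = tile_len * sizeof(T);
    // Already aligned: keep the length; otherwise pad up to the next multiple.
    return bytes % byte_align == 0
               ? tile_len
               : (bytes + (byte_align - bytes % byte_align)) / sizeof(T);
}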
src/utils.h

@@ -100,4 +100,12 @@ inline std::string infiniDtypeToString(infiniDtype_t dtype) {
 #define CEIL_DIV(x, y) (((x) + (y)-1) / (y))

+namespace utils {
+
+inline size_t align(size_t size, size_t alignment) {
+    return (size + alignment - 1) & ~(alignment - 1);
+}
+
+} // namespace utils
+
 #endif
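The new helper rounds size up to the next multiple of alignment with the standard bitmask trick, which is only correct when alignment is a power of two. A small self-contained check (the function body is copied from the diff above; the main() harness is ours):

#include <cassert>
#include <cstddef>

namespace utils {
inline size_t align(size_t size, size_t alignment) {
    return (size + alignment - 1) & ~(alignment - 1);
}
} // namespace utils

int main() {
    assert(utils::align(13, 8) == 16); // rounds up to the next multiple of 8
    assert(utils::align(16, 8) == 16); // already-aligned sizes are unchanged
    assert(utils::align(0, 8) == 0);   // zero stays zero
    // Caveat: for a non-power-of-two alignment the mask is wrong,
    // e.g. utils::align(10, 6) yields 10 rather than 12.
    return 0;
}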
src/utils/custom_types.cc

@@ -43,7 +43,7 @@ fp16_t _f32_to_f16(float val) {
     int32_t exponent = ((f32 >> 23) & 0xFF) - 127; // Extract and de-bias the exponent
     uint32_t mantissa = f32 & 0x7FFFFF;            // Extract the mantissa (fraction part)

-    if (exponent >= 31) { // Special cases for Inf and NaN
+    if (exponent >= 16) { // Special cases for Inf and NaN
         // NaN
         if (exponent == 128 && mantissa != 0) {
             return fp16_t{static_cast<uint16_t>(sign | 0x7E00)};
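This one-line change fixes an overflow bug. Here exponent is the de-biased f32 exponent (bias 127 is removed two lines above, and the NaN test against 128 confirms this), while fp16's largest finite value, 65504, sits just below 2^16; so any de-biased exponent of 16 or more must saturate to Inf or propagate NaN. The old >= 31 test was written against fp16's biased 5-bit exponent field and let values in [2^16, 2^31) fall through to the normal-number path. A quick check against a compiler-provided half type (assuming _Float16 support, as in recent GCC/Clang):

#include <cmath>
#include <cstdio>

int main() {
    float v = 1e5f; // above fp16's max finite value (65504), below 2^31
    // The de-biased f32 exponent of 1e5 is 16, so the fixed `exponent >= 16`
    // branch catches it while the old `exponent >= 31` did not.
    int exponent = static_cast<int>(std::floor(std::log2(v)));
    std::printf("de-biased exponent: %d\n", exponent); // prints 16
    _Float16 h = static_cast<_Float16>(v); // reference conversion
    std::printf("fp16(1e5) is inf: %d\n",
                std::isinf(static_cast<float>(h)) ? 1 : 0); // expects 1
    return 0;
}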
test/infiniop-test/test_generate/testcases/clip.py (new file, 0 → 100644)

import numpy as np
import gguf
from typing import List, Optional, Tuple

from .. import InfiniopTestWriter, InfiniopTestCase, np_dtype_to_ggml, gguf_strides


def clip(
    x: np.ndarray,
    min_val: np.ndarray,
    max_val: np.ndarray,
) -> np.ndarray:
    """
    Clip the values in input tensor x to the range [min_val, max_val].

    Args:
        x: Input tensor
        min_val: Tensor with minimum values (same shape as x)
        max_val: Tensor with maximum values (same shape as x)

    Returns:
        Clipped tensor with the same shape as x
    """
    return np.maximum(np.minimum(x, max_val), min_val)


def random_tensor(shape, dtype):
    """
    Generate a random tensor with values in the range [-2, 2].

    Args:
        shape: Shape of the tensor
        dtype: Data type of the tensor

    Returns:
        Random tensor with the specified shape and dtype
    """
    return np.random.rand(*shape).astype(dtype) * 4.0 - 2.0


class ClipTestCase(InfiniopTestCase):
    """
    Test case for the Clip operator.
    """

    def __init__(
        self,
        x: np.ndarray,
        x_stride: Optional[List[int]],
        min_val: np.ndarray,
        min_stride: Optional[List[int]],
        max_val: np.ndarray,
        max_stride: Optional[List[int]],
        y: np.ndarray,
        y_stride: Optional[List[int]],
    ):
        super().__init__("clip")
        self.x = x
        self.x_stride = x_stride
        self.min_val = min_val
        self.min_stride = min_stride
        self.max_val = max_val
        self.max_stride = max_stride
        self.y = y
        self.y_stride = y_stride

    def write_test(self, test_writer: "InfiniopTestWriter"):
        super().write_test(test_writer)
        # Add strides as arrays if they exist
        if self.x_stride is not None:
            test_writer.add_array(test_writer.gguf_key("x.strides"), self.x_stride)
        if self.min_stride is not None:
            test_writer.add_array(test_writer.gguf_key("min_val.strides"), self.min_stride)
        if self.max_stride is not None:
            test_writer.add_array(test_writer.gguf_key("max_val.strides"), self.max_stride)
        if self.y_stride is not None:
            test_writer.add_array(test_writer.gguf_key("y.strides"), self.y_stride)
        # Add tensors to the test
        test_writer.add_tensor(
            test_writer.gguf_key("x"), self.x, raw_dtype=np_dtype_to_ggml(self.x.dtype)
        )
        test_writer.add_tensor(
            test_writer.gguf_key("min_val"), self.min_val, raw_dtype=np_dtype_to_ggml(self.min_val.dtype)
        )
        test_writer.add_tensor(
            test_writer.gguf_key("max_val"), self.max_val, raw_dtype=np_dtype_to_ggml(self.max_val.dtype)
        )
        test_writer.add_tensor(
            test_writer.gguf_key("y"), self.y, raw_dtype=np_dtype_to_ggml(self.y.dtype)
        )
        # Calculate the expected result
        ans = clip(
            self.x.astype(np.float64),
            self.min_val.astype(np.float64),
            self.max_val.astype(np.float64),
        )
        # Add the expected result to the test
        test_writer.add_tensor(
            test_writer.gguf_key("ans"), ans, raw_dtype=gguf.GGMLQuantizationType.F64
        )


if __name__ == "__main__":
    test_writer = InfiniopTestWriter("clip.gguf")

    # Create test cases for different shapes, strides, and data types
    test_cases = []

    # Test case shapes
    shapes = [
        (10,),         # 1D tensor
        (5, 10),       # 2D tensor
        (2, 3, 4),     # 3D tensor
        (7, 13),       # Prime dimensions
        (1, 1),        # Minimum shape
        (100, 100),    # Large shape
        (16, 16, 16),  # Large 3D
    ]

    # Test case min/max values
    min_max_values = [
        (-1.0, 1.0),        # Standard range
        (0.0, 2.0),         # Positive range
        (-2.0, 0.0),        # Negative range
        (-1000.0, 1000.0),  # Large range
        (-0.001, 0.001),    # Small range
        (0.0, 0.0),         # min=max
    ]

    # Data types to test
    dtypes = [np.float16, np.float32, np.float64]

    # Generate test cases with contiguous tensors
    for shape in shapes:
        for min_val, max_val in min_max_values:
            for dtype in dtypes:
                x = random_tensor(shape, dtype)
                min_tensor = np.full(shape, min_val, dtype=dtype)
                max_tensor = np.full(shape, max_val, dtype=dtype)
                y = np.zeros(shape, dtype=dtype)
                test_cases.append(
                    ClipTestCase(
                        x=x, x_stride=None,
                        min_val=min_tensor, min_stride=None,
                        max_val=max_tensor, max_stride=None,
                        y=y, y_stride=None,
                    )
                )

    # Generate test cases with strided tensors (for 2D shapes only)
    for shape in [s for s in shapes if len(s) == 2]:
        for dtype in dtypes:
            # Row-major stride
            row_stride = gguf_strides(shape[1], 1)
            # Column-major stride
            col_stride = gguf_strides(1, shape[0])

            # Test case with row-major input and output
            x = random_tensor(shape, dtype)
            min_tensor = np.full(shape, -1.0, dtype=dtype)
            max_tensor = np.full(shape, 1.0, dtype=dtype)
            y = np.zeros(shape, dtype=dtype)
            test_cases.append(
                ClipTestCase(
                    x=x, x_stride=row_stride,
                    min_val=min_tensor, min_stride=row_stride,
                    max_val=max_tensor, max_stride=row_stride,
                    y=y, y_stride=row_stride,
                )
            )

            # Test case with column-major input and output
            x = random_tensor(shape, dtype)
            min_tensor = np.full(shape, -1.0, dtype=dtype)
            max_tensor = np.full(shape, 1.0, dtype=dtype)
            y = np.zeros(shape, dtype=dtype)
            test_cases.append(
                ClipTestCase(
                    x=x, x_stride=col_stride,
                    min_val=min_tensor, min_stride=col_stride,
                    max_val=max_tensor, max_stride=col_stride,
                    y=y, y_stride=col_stride,
                )
            )

            # Test case with different strides for input and output
            x = random_tensor(shape, dtype)
            min_tensor = np.full(shape, -1.0, dtype=dtype)
            max_tensor = np.full(shape, 1.0, dtype=dtype)
            y = np.zeros(shape, dtype=dtype)
            test_cases.append(
                ClipTestCase(
                    x=x, x_stride=row_stride,
                    min_val=min_tensor, min_stride=row_stride,
                    max_val=max_tensor, max_stride=row_stride,
                    y=y, y_stride=col_stride,
                )
            )

    # Add all test cases to the writer
    test_writer.add_tests(test_cases)

    # Save the test cases to a GGUF file
    test_writer.save()
    print(f"Generated {len(test_cases)} test cases for the Clip operator")
test/infiniop/attention.py

@@ -215,7 +215,7 @@ if __name__ == "__main__":
     # Tolerance map for different data types
     _TOLERANCE_MAP = {
         torch.float16: {"atol": 1e-4, "rtol": 1e-2},
-        torch.float32: {"atol": 1e-6, "rtol": 1e-4},
+        torch.float32: {"atol": 1e-5, "rtol": 1e-3},
     }

     DEBUG = False

@@ -268,6 +268,20 @@ if __name__ == "__main__":
             None,  # k_cache_stride
             None,  # v_cache_stride
         ),
+        (
+            28,    # n_q_head
+            28,    # n_kv_head
+            15,    # seq_len
+            128,   # head_dim
+            0,     # pos
+            2048,  # k_cache_buf_len
+            2048,  # v_cache_buf_len
+            [128, 10752, 1],  # q_stride
+            [128, 10752, 1],  # k_stride
+            [128, 10752, 1],  # v_stride
+            [128, 3584, 1],   # k_cache_stride
+            [128, 3584, 1],   # v_cache_stride
+        ),
     ]

     args = get_args()
     lib = open_lib()
test/infiniop/clip.py (new file, 0 → 100644)

#!/usr/bin/env python3

import torch
import ctypes
from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
from libinfiniop import (
    infiniopHandle_t,
    infiniopTensorDescriptor_t,
    open_lib,
    to_tensor,
    get_test_devices,
    check_error,
    rearrange_if_needed,
    create_workspace,
    test_operator,
    get_args,
    debug,
    get_tolerance,
    profile_operation,
)
from enum import Enum, auto

# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES_ = [
    # shape, x_stride, y_stride, min_val, max_val
    # Basic shape tests
    ((10,), None, None, -1.0, 1.0),
    ((5, 10), None, None, -1.0, 1.0),
    ((2, 3, 4), None, None, -1.0, 1.0),
    # Different min_val and max_val
    ((10,), None, None, 0.0, 2.0),
    ((5, 10), None, None, 0.0, 2.0),
    ((2, 3, 4), None, None, 0.0, 2.0),
    ((10,), None, None, -2.0, 0.0),
    ((5, 10), None, None, -2.0, 0.0),
    ((2, 3, 4), None, None, -2.0, 0.0),
    # Odd shape tests
    ((7, 13), None, None, -1.0, 1.0),    # prime dimensions
    ((3, 5, 7), None, None, -1.0, 1.0),  # 3D primes
    # Non-standard shape tests
    ((1, 1), None, None, -1.0, 1.0),        # minimum shape
    ((100, 100), None, None, -1.0, 1.0),    # large shape
    ((16, 16, 16), None, None, -1.0, 1.0),  # large 3D
    # Extreme value tests
    ((10,), None, None, -1000.0, 1000.0),  # large range
    ((10,), None, None, -0.001, 0.001),    # small range
    ((10,), None, None, 0.0, 0.0),         # min=max
    # Special shape tests
    ((0,), None, None, -1.0, 1.0),   # empty tensor
    ((1, 0), None, None, -1.0, 1.0), # empty dimension
]

_TENSOR_DTYPES = [torch.float16, torch.float32]

_TOLERANCE_MAP = {
    torch.float16: {"atol": 1e-3, "rtol": 1e-3},
    torch.float32: {"atol": 1e-7, "rtol": 1e-6},
}


class Inplace(Enum):
    OUT_OF_PLACE = auto()
    INPLACE_X = auto()


_INPLACE = [
    Inplace.INPLACE_X,
    Inplace.OUT_OF_PLACE,
]

_TEST_CASES = [
    test_case + (inplace_item,)
    for test_case in _TEST_CASES_
    for inplace_item in _INPLACE
]

DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000


class ClipDescriptor(Structure):
    _fields_ = [("device_type", c_int32), ("device_id", c_int32)]


infiniopClipDescriptor_t = POINTER(ClipDescriptor)


def clip(x, min_val, max_val):
    return torch.clamp(x, min_val, max_val)


def create_tensor_with_stride(shape, stride, dtype, device):
    """Create a tensor with specific stride without using view() that might cause errors."""
    x = torch.rand(shape, dtype=dtype, device=device) * 4.0 - 2.0  # Range: [-2, 2]
    if stride is None:
        return x
    if len(shape) == 2 and len(stride) == 2:
        if stride == (shape[1], 1):
            return x.contiguous()
        elif stride == (1, shape[0]):
            return x.transpose(0, 1).contiguous().transpose(0, 1)
        else:
            y = torch.zeros(shape, dtype=dtype, device=device)
            for i in range(shape[0]):
                for j in range(shape[1]):
                    y[i, j] = x[i, j]
            return y.contiguous()
    return x


def test(
    lib,
    handle,
    torch_device,
    shape,
    x_stride=None,
    y_stride=None,
    min_val=-1.0,
    max_val=1.0,
    inplace=Inplace.OUT_OF_PLACE,
    dtype=torch.float32,
):
    print(
        f"Testing Clip on {torch_device} with shape:{shape} x_stride:{x_stride} y_stride:{y_stride} "
        f"min_val:{min_val} max_val:{max_val} dtype:{dtype} inplace:{inplace}"
    )

    x = create_tensor_with_stride(shape, x_stride, dtype, torch_device)
    ans = clip(x, min_val, max_val)

    x = rearrange_if_needed(x, x_stride)
    x_tensor = to_tensor(x, lib)

    if inplace == Inplace.INPLACE_X:
        y = x
        y_tensor = x_tensor
    else:
        y = torch.zeros(shape, dtype=dtype).to(torch_device)
        y = rearrange_if_needed(y, y_stride)
        y_tensor = to_tensor(y, lib)

    descriptor = infiniopClipDescriptor_t()
    check_error(
        lib.infiniopCreateClipDescriptor(
            handle,
            ctypes.byref(descriptor),
            y_tensor.descriptor,
            x_tensor.descriptor,
        )
    )

    workspace_size = c_uint64(0)
    check_error(
        lib.infiniopGetClipWorkspaceSize(descriptor, ctypes.byref(workspace_size))
    )
    workspace = create_workspace(workspace_size.value, x.device)

    def lib_clip():
        check_error(
            lib.infiniopClip(
                descriptor,
                workspace.data_ptr() if workspace is not None else None,
                workspace_size.value,
                y_tensor.data,
                x_tensor.data,
                c_float(min_val),
                c_float(max_val),
                None,
            )
        )

    lib_clip()

    # Now we can destroy the tensor descriptors
    x_tensor.destroyDesc(lib)
    if inplace != Inplace.INPLACE_X:
        y_tensor.destroyDesc(lib)

    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG or not torch.allclose(y, ans, atol=atol, rtol=rtol):
        print("\nExpected:")
        print(ans)
        print("\nActual:")
        print(y)
        print("\nDifference:")
        print(torch.abs(y - ans))
        print("\nMax difference:", torch.max(torch.abs(y - ans)).item())
        debug(y, ans, atol=atol, rtol=rtol)
    assert torch.allclose(y, ans, atol=atol, rtol=rtol)

    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: clip(x, min_val, max_val), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation("    lib", lambda: lib_clip(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on

    check_error(lib.infiniopDestroyClipDescriptor(descriptor))


if __name__ == "__main__":
    args = get_args()
    lib = open_lib()

    lib.infiniopCreateClipDescriptor.restype = c_int32
    lib.infiniopCreateClipDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopClipDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
    ]

    lib.infiniopGetClipWorkspaceSize.restype = c_int32
    lib.infiniopGetClipWorkspaceSize.argtypes = [
        infiniopClipDescriptor_t,
        POINTER(c_uint64),
    ]

    lib.infiniopClip.restype = c_int32
    lib.infiniopClip.argtypes = [
        infiniopClipDescriptor_t,
        c_void_p,
        c_uint64,
        c_void_p,
        c_void_p,
        c_float,
        c_float,
        c_void_p,
    ]

    lib.infiniopDestroyClipDescriptor.restype = c_int32
    lib.infiniopDestroyClipDescriptor.argtypes = [
        infiniopClipDescriptor_t,
    ]

    # Configure testing options
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    for device in get_test_devices(args):
        test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")
test/infiniop/rope.py

@@ -189,6 +189,9 @@ def test(
     )

     lib_rope()

+    if sync is not None:
+        sync()
+
     atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
     if DEBUG: