Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
5beab8c0
Unverified
Commit
5beab8c0
authored
May 22, 2025
by
PanZezhong1725
Committed by
GitHub
May 22, 2025
Browse files
Merge pull request #229 from InfiniTensor/issue/191/fix
issue/191/fix 为attention增加对齐,修复cuda causal softmax
parents
2e624a8e
b79f2607
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
93 additions
and
66 deletions
+93
-66
src/infiniop/ops/attention/operator.cc
src/infiniop/ops/attention/operator.cc
+69
-64
src/infiniop/ops/causal_softmax/cuda/causal_softmax_kernel.cuh
...nfiniop/ops/causal_softmax/cuda/causal_softmax_kernel.cuh
+1
-1
src/utils.h
src/utils.h
+8
-0
test/infiniop/attention.py
test/infiniop/attention.py
+15
-1
No files found.
src/infiniop/ops/attention/operator.cc
View file @
5beab8c0
...
...
@@ -20,15 +20,14 @@ struct InfiniopAttentionDescriptor {
infiniopGemmDescriptor_t
matmul_desc1
;
infiniopGemmDescriptor_t
matmul_desc2
;
infiniopCausalSoftmaxDescriptor_t
softmax_desc
;
uint64_t
workspace_size
;
uint64_t
rearranged_q_size
;
uint64_t
matmul1_workspace_size
;
uint64_t
matmul1_tensor_size
;
uint64_t
matmul2_workspace_size
;
uint64_t
matmul2_tensor_size
;
uint64_t
softmax_workspace_size
;
uint64_t
k_cache_offset
;
uint64_t
v_cache_offset
;
size_t
workspace_size
;
size_t
op_workspace_offset
;
size_t
op_workspace_size
;
size_t
q_cont_offset
;
size_t
att_score_offset
;
size_t
att_val_offset
;
size_t
k_cache_offset
;
size_t
v_cache_offset
;
float
qk_alpha
;
};
...
...
@@ -40,7 +39,7 @@ __C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t h
infiniopTensorDescriptor_t
v_desc
,
infiniopTensorDescriptor_t
k_cache_desc
,
infiniopTensorDescriptor_t
v_cache_desc
,
uint64
_t
pos
)
{
size
_t
pos
)
{
if
(
out_desc
->
ndim
()
!=
3
||
q_desc
->
ndim
()
!=
3
||
k_desc
->
ndim
()
!=
3
||
v_desc
->
ndim
()
!=
3
||
k_cache_desc
->
ndim
()
!=
3
||
v_cache_desc
->
ndim
()
!=
3
)
{
return
INFINI_STATUS_BAD_TENSOR_SHAPE
;
}
...
...
@@ -53,13 +52,14 @@ __C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t h
return
INFINI_STATUS_BAD_TENSOR_STRIDES
;
}
uint64_t
n_q_head
=
q_desc
->
shape
()[
0
];
uint64_t
seq_len
=
q_desc
->
shape
()[
1
];
uint64_t
head_dim
=
q_desc
->
shape
()[
2
];
uint64_t
hidden_size
=
n_q_head
*
head_dim
;
uint64_t
n_kv_head
=
k_desc
->
shape
()[
0
];
uint64_t
total_seq_len
=
seq_len
+
pos
;
uint64_t
n_group
=
n_q_head
/
n_kv_head
;
size_t
n_q_head
=
q_desc
->
shape
()[
0
];
size_t
seq_len
=
q_desc
->
shape
()[
1
];
size_t
head_dim
=
q_desc
->
shape
()[
2
];
size_t
hidden_size
=
n_q_head
*
head_dim
;
size_t
n_kv_head
=
k_desc
->
shape
()[
0
];
size_t
total_seq_len
=
seq_len
+
pos
;
size_t
n_group
=
n_q_head
/
n_kv_head
;
size_t
alignment
=
256
;
if
(
out_desc
->
shape
()[
0
]
!=
seq_len
||
out_desc
->
shape
()[
1
]
!=
n_q_head
||
out_desc
->
shape
()[
2
]
!=
head_dim
)
{
return
INFINI_STATUS_BAD_PARAM
;
...
...
@@ -98,12 +98,12 @@ __C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t h
CHECK_STATUS
(
infiniopCreateRearrangeDescriptor
(
handle
,
&
rearrange_desc_v
,
dst_v_desc
,
v_desc
));
infiniopRearrangeDescriptor_t
rearrange_desc_q
=
nullptr
;
uint64_t
rearranged_q
_size
=
0
;
size_t
q_cont
_size
=
0
;
infiniopTensorDescriptor_t
rearranged_q_desc
;
// Rearrange q into contiguous
if
(
!
q_desc
->
isContiguous
(
0
,
1
))
{
CHECK_STATUS
(
infiniopCreateTensorDescriptor
(
&
rearranged_q_desc
,
3
,
q_desc
->
shape
().
data
(),
nullptr
,
q_desc
->
dtype
()));
rearranged_q_size
=
rearranged_q_desc
->
numel
()
*
infiniSizeOf
(
rearranged_q_desc
->
dtype
());
q_cont_size
=
utils
::
align
(
rearranged_q_desc
->
numel
()
*
infiniSizeOf
(
rearranged_q_desc
->
dtype
())
,
alignment
)
;
rearrange_desc_q
=
new
InfiniopDescriptor
;
CHECK_STATUS
(
infiniopCreateRearrangeDescriptor
(
handle
,
&
rearrange_desc_q
,
rearranged_q_desc
,
q_desc
));
}
...
...
@@ -116,12 +116,12 @@ __C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t h
TRANSFORM_TENSOR_DESC
(
reshaped_q_desc
,
dimMerge
(
1
,
2
));
// full_k: [n_kv_head, head_dim, total_seq_len]
infiniopTensorDescriptor_t
full_k_desc
;
uint64
_t
full_k_shape
[
3
]
=
{
n_kv_head
,
total_seq_len
,
head_dim
};
size
_t
full_k_shape
[
3
]
=
{
n_kv_head
,
total_seq_len
,
head_dim
};
CHECK_STATUS
(
infiniopCreateTensorDescriptor
(
&
full_k_desc
,
3
,
full_k_shape
,
k_cache_desc
->
strides
().
data
(),
k_cache_desc
->
dtype
()));
TRANSFORM_TENSOR_DESC
(
full_k_desc
,
dimPermute
({
0
,
2
,
1
}));
// qk: [n_kv_head, n_group * seq_len, total_seq_len]
infiniopTensorDescriptor_t
qk_desc
;
uint64
_t
qk_shape
[
3
]
=
{
n_kv_head
,
n_group
*
seq_len
,
total_seq_len
};
size
_t
qk_shape
[
3
]
=
{
n_kv_head
,
n_group
*
seq_len
,
total_seq_len
};
CHECK_STATUS
(
infiniopCreateTensorDescriptor
(
&
qk_desc
,
3
,
qk_shape
,
nullptr
,
q_desc
->
dtype
()));
// matmul1_desc
// qk_alpha
...
...
@@ -129,10 +129,11 @@ __C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t h
infiniopGemmDescriptor_t
matmul1_desc
;
CHECK_STATUS
(
infiniopCreateGemmDescriptor
(
handle
,
&
matmul1_desc
,
qk_desc
,
reshaped_q_desc
,
full_k_desc
));
// matmul1 workspace size
uint64
_t
matmul1_workspace_size
;
size
_t
matmul1_workspace_size
;
CHECK_STATUS
(
infiniopGetGemmWorkspaceSize
(
matmul1_desc
,
&
matmul1_workspace_size
));
// matmul1 tensor size
uint64_t
matmul1_tensor_size
=
qk_desc
->
numel
()
*
infiniSizeOf
(
qk_desc
->
dtype
());
matmul1_workspace_size
=
utils
::
align
(
matmul1_workspace_size
,
alignment
);
// attention score tensor size
size_t
attn_score_size
=
utils
::
align
(
qk_desc
->
numel
()
*
infiniSizeOf
(
qk_desc
->
dtype
()),
alignment
);
// CausalSoftmax: softmax(qk)
// qk: [n_kv_head, n_group * seq_len, total_seq_len] -> [n_q_head, seq_len, total_seq_len]
...
...
@@ -141,8 +142,9 @@ __C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t h
infiniopCausalSoftmaxDescriptor_t
softmax_desc
;
CHECK_STATUS
(
infiniopCreateCausalSoftmaxDescriptor
(
handle
,
&
softmax_desc
,
qk_desc
,
qk_desc
));
// softmax workspace size
uint64
_t
softmax_workspace_size
;
size
_t
softmax_workspace_size
;
CHECK_STATUS
(
infiniopGetCausalSoftmaxWorkspaceSize
(
softmax_desc
,
&
softmax_workspace_size
));
softmax_workspace_size
=
utils
::
align
(
softmax_workspace_size
,
alignment
);
// Matmul2: softmax(qk) * full_v
// softmax(qk): [n_q_head, seq_len, total_seq_len] -> [n_kv_head, n_group * seq_len, total_seq_len]
...
...
@@ -150,41 +152,44 @@ __C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t h
TRANSFORM_TENSOR_DESC
(
qk_desc
,
dimSplit
(
0
,
{
n_kv_head
,
n_group
}));
TRANSFORM_TENSOR_DESC
(
qk_desc
,
dimMerge
(
1
,
2
));
infiniopTensorDescriptor_t
full_v_desc
;
uint64
_t
full_v_shape
[
3
]
=
{
n_kv_head
,
total_seq_len
,
head_dim
};
size
_t
full_v_shape
[
3
]
=
{
n_kv_head
,
total_seq_len
,
head_dim
};
CHECK_STATUS
(
infiniopCreateTensorDescriptor
(
&
full_v_desc
,
3
,
full_v_shape
,
v_cache_desc
->
strides
().
data
(),
v_cache_desc
->
dtype
()));
// temp_out: [n_kv_head, n_group * seq_len, head_dim]
infiniopTensorDescriptor_t
temp_out
_desc
;
uint64
_t
temp_out_shape
[
3
]
=
{
n_kv_head
,
n_group
*
seq_len
,
head_dim
};
CHECK_STATUS
(
infiniopCreateTensorDescriptor
(
&
temp_out
_desc
,
3
,
temp_out_shape
,
nullptr
,
q_desc
->
dtype
()));
infiniopTensorDescriptor_t
att_val
_desc
;
size
_t
temp_out_shape
[
3
]
=
{
n_kv_head
,
n_group
*
seq_len
,
head_dim
};
CHECK_STATUS
(
infiniopCreateTensorDescriptor
(
&
att_val
_desc
,
3
,
temp_out_shape
,
nullptr
,
q_desc
->
dtype
()));
// matmul2_desc
infiniopGemmDescriptor_t
matmul2_desc
;
CHECK_STATUS
(
infiniopCreateGemmDescriptor
(
handle
,
&
matmul2_desc
,
temp_out
_desc
,
qk_desc
,
full_v_desc
));
CHECK_STATUS
(
infiniopCreateGemmDescriptor
(
handle
,
&
matmul2_desc
,
att_val
_desc
,
qk_desc
,
full_v_desc
));
// matmul2 workspace size
uint64
_t
matmul2_workspace_size
;
size
_t
matmul2_workspace_size
;
CHECK_STATUS
(
infiniopGetGemmWorkspaceSize
(
matmul2_desc
,
&
matmul2_workspace_size
));
// matmul2 tensor size
uint64_t
matmul2_tensor_size
=
temp_out_desc
->
numel
()
*
infiniSizeOf
(
temp_out_desc
->
dtype
());
matmul2_workspace_size
=
utils
::
align
(
matmul2_workspace_size
,
alignment
);
// attention value tensor size
size_t
att_val_size
=
utils
::
align
(
att_val_desc
->
numel
()
*
infiniSizeOf
(
att_val_desc
->
dtype
()),
alignment
);
// Rearrange temp_out into out
// out: [seq_len, n_q_head, head_dim]
// temp_out: [n_kv_head, n_group * seq_len, head_dim] -> [n_q_head, seq_len, head_dim] -> [seq_len, n_q_head, head_dim]
TRANSFORM_TENSOR_DESC
(
temp_out
_desc
,
dimSplit
(
1
,
{
n_group
,
seq_len
}));
TRANSFORM_TENSOR_DESC
(
temp_out
_desc
,
dimMerge
(
0
,
1
));
TRANSFORM_TENSOR_DESC
(
temp_out
_desc
,
dimPermute
({
1
,
0
,
2
}));
TRANSFORM_TENSOR_DESC
(
att_val
_desc
,
dimSplit
(
1
,
{
n_group
,
seq_len
}));
TRANSFORM_TENSOR_DESC
(
att_val
_desc
,
dimMerge
(
0
,
1
));
TRANSFORM_TENSOR_DESC
(
att_val
_desc
,
dimPermute
({
1
,
0
,
2
}));
infiniopRearrangeDescriptor_t
rearrange_desc_out
;
CHECK_STATUS
(
infiniopCreateRearrangeDescriptor
(
handle
,
&
rearrange_desc_out
,
out_desc
,
temp_out
_desc
));
CHECK_STATUS
(
infiniopCreateRearrangeDescriptor
(
handle
,
&
rearrange_desc_out
,
out_desc
,
att_val
_desc
));
// workspace size
uint64_t
workspace_size
=
rearranged_q_size
+
std
::
max
(
std
::
max
(
matmul1_workspace_size
+
matmul1_tensor_size
,
matmul1_tensor_size
+
softmax_workspace_size
),
matmul1_tensor_size
+
matmul2_workspace_size
+
matmul2_tensor_size
);
size_t
op_workspace_size
=
utils
::
align
(
std
::
max
(
std
::
max
(
matmul1_workspace_size
,
matmul2_workspace_size
),
softmax_workspace_size
),
alignment
);
size_t
temp_tensors_size
=
attn_score_size
+
std
::
max
(
q_cont_size
,
att_val_size
);
size_t
workspace_size
=
temp_tensors_size
+
op_workspace_size
;
// k_cache_offset
uint64
_t
k_cache_offset
=
0
;
size
_t
k_cache_offset
=
0
;
if
(
pos
>
0
)
{
k_cache_offset
=
pos
*
k_cache_desc
->
getByteStrides
()[
1
];
}
// v_cache_offset
uint64
_t
v_cache_offset
=
0
;
size
_t
v_cache_offset
=
0
;
if
(
pos
>
0
)
{
v_cache_offset
=
pos
*
v_cache_desc
->
getByteStrides
()[
1
];
}
...
...
@@ -200,12 +205,11 @@ __C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t h
matmul2_desc
,
softmax_desc
,
workspace_size
,
rearranged_q_size
,
matmul1_workspace_size
,
matmul1_tensor_size
,
matmul2_workspace_size
,
matmul2_tensor_size
,
softmax_workspace_size
,
temp_tensors_size
,
op_workspace_size
,
attn_score_size
,
0
,
attn_score_size
,
k_cache_offset
,
v_cache_offset
,
1.
f
/
std
::
sqrt
(
float
(
head_dim
)),
...
...
@@ -214,14 +218,14 @@ __C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t h
return
INFINI_STATUS_SUCCESS
;
}
__C
__export
infiniStatus_t
infiniopGetAttentionWorkspaceSize
(
infiniopAttentionDescriptor_t
desc
,
uint64
_t
*
size
)
{
__C
__export
infiniStatus_t
infiniopGetAttentionWorkspaceSize
(
infiniopAttentionDescriptor_t
desc
,
size
_t
*
size
)
{
*
size
=
((
InfiniopAttentionDescriptor
*
)
desc
)
->
workspace_size
;
return
INFINI_STATUS_SUCCESS
;
}
__C
__export
infiniStatus_t
infiniopAttention
(
infiniopAttentionDescriptor_t
desc_
,
void
*
workspace
,
uint64
_t
workspace_size
,
void
*
workspace
_
,
size
_t
workspace_size
_
,
void
*
out
,
void
const
*
q
,
void
const
*
k
,
...
...
@@ -230,11 +234,14 @@ __C __export infiniStatus_t infiniopAttention(infiniopAttentionDescriptor_t desc
void
*
v_cache
,
void
*
stream
)
{
auto
desc
=
(
InfiniopAttentionDescriptor
*
)
desc_
;
void
*
workspace_
=
workspace
;
if
(
workspace_size
<
desc
->
workspace_size
)
{
if
(
workspace_size_
<
desc
->
workspace_size
)
{
return
INFINI_STATUS_INSUFFICIENT_WORKSPACE
;
// STATUS_MEMORY_NOT_ALLOCATED
}
void
*
workspace
=
(
char
*
)
workspace_
+
desc
->
op_workspace_offset
;
size_t
workspace_size
=
desc
->
op_workspace_size
;
void
*
att_score
=
(
char
*
)
workspace_
+
desc
->
att_score_offset
;
void
*
att_val
=
(
char
*
)
workspace_
+
desc
->
att_val_offset
;
void
const
*
q_
=
q
;
// concat k and v to k_cache and v_cache
CHECK_STATUS
(
infiniopRearrange
(
desc
->
rearrange_desc_k
,
(
char
*
)
k_cache
+
desc
->
k_cache_offset
,
k
,
stream
));
...
...
@@ -243,28 +250,26 @@ __C __export infiniStatus_t infiniopAttention(infiniopAttentionDescriptor_t desc
(
char
*
)
v_cache
+
desc
->
v_cache_offset
,
v
,
stream
));
// rearrange q into contiguous
void
const
*
_q
=
q
;
if
(
desc
->
rearrange_desc_q
)
{
CHECK_STATUS
(
infiniopRearrange
(
desc
->
rearrange_desc_q
,
(
char
*
)
workspace_
,
q
,
stream
))
;
_q
=
workspace_
;
workspace_
=
(
char
*
)
workspace_
+
desc
->
rearranged_q_size
;
void
*
q_cont
=
(
char
*
)
workspace_
+
desc
->
q_cont_offset
;
CHECK_STATUS
(
infiniopRearrange
(
desc
->
rearrange_desc_q
,
q_cont
,
q
,
stream
))
;
q_
=
q_cont
;
}
// matmul1: q * full_k
CHECK_STATUS
(
infiniopGemm
(
desc
->
matmul_desc1
,
(
char
*
)
workspace
_
+
desc
->
matmul1_tensor_size
,
workspace_size
-
desc
->
matmul1_tensor
_size
,
workspace_
,
_q
,
k_cache
,
desc
->
qk_alpha
,
0.0
,
stream
));
workspace
,
workspace
_size
,
att_score
,
q_
,
k_cache
,
desc
->
qk_alpha
,
0.0
,
stream
));
// softmax(qk)
CHECK_STATUS
(
infiniopCausalSoftmax
(
desc
->
softmax_desc
,
(
char
*
)
workspace
_
+
desc
->
matmul1_tensor_size
,
workspace_size
-
desc
->
matmul1_tensor
_size
,
workspace_
,
workspace_
,
stream
));
workspace
,
workspace
_size
,
att_score
,
att_score
,
stream
));
// matmul2: softmax(qk) * full_v
CHECK_STATUS
(
infiniopGemm
(
desc
->
matmul_desc2
,
(
char
*
)
workspace_
+
desc
->
matmul1_tensor_size
+
desc
->
matmul2_tensor_size
,
workspace_size
-
desc
->
matmul1_tensor_size
-
desc
->
matmul2_tensor_size
,
(
char
*
)
workspace_
+
desc
->
matmul1_tensor_size
,
workspace_
,
v_cache
,
1.0
,
0.0
,
stream
));
workspace
,
workspace_size
,
att_val
,
att_score
,
v_cache
,
1.0
,
0.0
,
stream
));
// rearrange out
CHECK_STATUS
(
infiniopRearrange
(
desc
->
rearrange_desc_out
,
out
,
(
char
*
)
workspace_
+
desc
->
matmul1_tensor_size
,
stream
));
CHECK_STATUS
(
infiniopRearrange
(
desc
->
rearrange_desc_out
,
out
,
att_val
,
stream
));
return
INFINI_STATUS_SUCCESS
;
}
...
...
src/infiniop/ops/causal_softmax/cuda/causal_softmax_kernel.cuh
View file @
5beab8c0
...
...
@@ -18,7 +18,7 @@ INFINIOP_CUDA_KERNEL causalSoftmax(
// [Reduce] Find max value in each row and store in shared memory
__shared__
Tdata
max_
;
Tdata
max_0
=
op
::
common_cuda
::
reduce_op
::
max
<
BLOCK_SIZE
,
Tdata
>
(
x
,
width
);
Tdata
max_0
=
op
::
common_cuda
::
reduce_op
::
max
<
BLOCK_SIZE
,
Tdata
>
(
x
,
width
-
height
+
1
+
blockIdx
.
x
);
if
(
threadIdx
.
x
==
0
)
{
max_
=
max_0
;
}
...
...
src/utils.h
View file @
5beab8c0
...
...
@@ -100,4 +100,12 @@ inline std::string infiniDtypeToString(infiniDtype_t dtype) {
#define CEIL_DIV(x, y) (((x) + (y)-1) / (y))
namespace
utils
{
inline
size_t
align
(
size_t
size
,
size_t
alignment
)
{
return
(
size
+
alignment
-
1
)
&
~
(
alignment
-
1
);
}
}
// namespace utils
#endif
test/infiniop/attention.py
View file @
5beab8c0
...
...
@@ -215,7 +215,7 @@ if __name__ == "__main__":
# Tolerance map for different data types
_TOLERANCE_MAP
=
{
torch
.
float16
:
{
"atol"
:
1e-4
,
"rtol"
:
1e-2
},
torch
.
float32
:
{
"atol"
:
1e-
6
,
"rtol"
:
1e-
4
},
torch
.
float32
:
{
"atol"
:
1e-
5
,
"rtol"
:
1e-
3
},
}
DEBUG
=
False
...
...
@@ -268,6 +268,20 @@ if __name__ == "__main__":
None
,
# k_cache_stride
None
,
# v_cache_stride
),
(
28
,
# n_q_head
28
,
# n_kv_head
15
,
# seq_len
128
,
# head_dim
0
,
# pos
2048
,
# k_cache_buf_len
2048
,
# v_cache_buf_len
[
128
,
10752
,
1
],
# q_stride
[
128
,
10752
,
1
],
# k_stride
[
128
,
10752
,
1
],
# v_stride
[
128
,
3584
,
1
],
# k_cache_stride
[
128
,
3584
,
1
],
# v_cache_stride
),
]
args
=
get_args
()
lib
=
open_lib
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment