gaoqiong / composable_kernel
Commit 29d5cbac authored Oct 23, 2023 by letaoqin

remove some ignores

parent 72faed1c
Showing 2 changed files with 3 additions and 9 deletions
example/52_flash_atten_bias/run_batched_multihead_attention_infer.inc    +1 -1
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_infer_xdl_cshuffle.hpp    +2 -8
example/52_flash_atten_bias/run_batched_multihead_attention_infer.inc

@@ -5,7 +5,7 @@ int run(int argc, char* argv[])
 {
     bool do_verification = true;
     int init_method      = 1;
-    bool time_kernel     = false;
+    bool time_kernel     = true;
 
     // GEMM shape for A/B0/B1/C
     // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
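Note: the only change in this file flips the example driver's time_kernel flag from false to true, so the example now reports kernel execution time in addition to verifying results. In CK examples this flag is typically forwarded to the device-op invoker, which times the launch with events when it is set. A minimal standalone sketch of that event-based timing pattern (the kernel and helper below are hypothetical and are not CK's API):

    // Illustration only: time a kernel launch with HIP events.
    #include <hip/hip_runtime.h>

    __global__ void dummy_kernel(float* p)
    {
        const int idx = blockIdx.x * blockDim.x + threadIdx.x;
        p[idx] *= 2.0f;
    }

    // Returns the elapsed kernel time in milliseconds (assumes n % 256 == 0).
    float time_kernel_ms(float* d_ptr, int n)
    {
        hipEvent_t start, stop;
        hipEventCreate(&start);
        hipEventCreate(&stop);

        hipEventRecord(start, nullptr);
        dummy_kernel<<<n / 256, 256>>>(d_ptr); // the launch being timed
        hipEventRecord(stop, nullptr);
        hipEventSynchronize(stop);

        float ms = 0.0f;
        hipEventElapsedTime(&ms, start, stop);
        hipEventDestroy(start);
        hipEventDestroy(stop);
        return ms;
    }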
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_infer_xdl_cshuffle.hpp

@@ -538,7 +538,6 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
         const Block2CTileMap& block_2_ctile_map,
         const C0MatrixMask& c0_matrix_mask)
     {
-        ignore = d0_grid_desc_m0_n0_n1_n2_m1_n3;
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
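The line removed in this hunk used the ignore-assignment idiom: assigning an otherwise-unused argument to a sink object, in the spirit of std::ignore, silences unused-parameter warnings without attributes. Dropping it here presumably means d0_grid_desc_m0_n0_n1_n2_m1_n3 is now consumed further down in the bias path. A minimal sketch of how such a sink can be written; this is an assumption for illustration, not CK's actual definition:

    // Illustration only: an "ignore" sink that swallows any assigned value.
    namespace demo {

    struct ignore_t
    {
        template <typename T>
        constexpr const ignore_t& operator=(T&&) const noexcept
        {
            return *this; // discard the value
        }
    };

    inline constexpr ignore_t ignore{};

    } // namespace demo

    void consume(int unused_arg)
    {
        demo::ignore = unused_arg; // suppresses -Wunused-parameter without attributes
    }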
@@ -899,7 +898,6 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
         running_max_new = NumericLimits<FloatGemmAcc>::Lowest();
 
         // gemm1 K loop
-        ignore = wave_m_n_id;
         auto d0_block_copy_global_to_lds = typename D0Operator::D0BlockwiseCopyGlobalToLds(
             d0_grid_desc_m0_n0_n1_n2_m1_n3,
             make_multi_index(block_work_idx[I0], 0, 0, 0, 0, 0),
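Here the change is again only the removal of an ignore assignment; the surrounding context constructs a blockwise copy object that stages the D0 (bias) tile from global memory into LDS before it is consumed. As a rough, self-contained analogy for what such a global-to-LDS staging step does (all names below are hypothetical, not CK's classes):

    // Illustration only: stage a tile of a bias row through LDS, then use it.
    #include <hip/hip_runtime.h>

    constexpr int TileN = 128;

    __global__ void add_bias_via_lds(float* acc, const float* d0, int n)
    {
        __shared__ float lds_tile[TileN];

        const int base = blockIdx.x * TileN;
        const int tid  = threadIdx.x;

        // Each thread copies one element of the tile from global memory to LDS.
        if(base + tid < n)
            lds_tile[tid] = d0[base + tid];
        __syncthreads();

        // Consume the staged tile.
        if(base + tid < n)
            acc[base + tid] += lds_tile[tid];
    }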
@@ -1004,7 +1002,6 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
         // add bias
         if constexpr(!is_same<D0DataType, void>::value)
         {
-            ignore = d0_thread_copy_lds_to_vgpr;
             if(p_d0_grid != nullptr)
             {
                 static constexpr auto& c_thread_desc = blockwise_gemm.GetCThreadDesc();
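The context here shows how the bias path is compiled out entirely when no bias type is supplied: D0DataType is set to void and the whole block is guarded by if constexpr, with a runtime nullptr check on top. A small sketch of the same pattern under assumed names (BiasType stands in for D0DataType):

    // Illustration only: compile a bias path in or out with a void type tag.
    #include <type_traits>

    template <typename BiasType>
    float apply_bias(float acc, const BiasType* bias, int i)
    {
        if constexpr(!std::is_same<BiasType, void>::value)
        {
            if(bias != nullptr)                     // runtime check, as in the diff
                acc += static_cast<float>(bias[i]); // only instantiated when BiasType != void
        }
        return acc;
    }

    // apply_bias<void>(x, nullptr, 0) compiles to a plain pass-through;
    // apply_bias<float>(x, bias_ptr, i) adds the bias element.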
@@ -1018,10 +1015,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
                 auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
                     D0Operator::d0_thread_desc_.GetElementSpaceSize());
-                ignore = c_thread_desc;
-                ignore = d0_grid_buf;
-                ignore = d0_block_buf;
-                ignore = d0_thread_buf;
                 static_for<0, D0N0, 1>{}([&](auto nr) {
                     // load data to lds
                     d0_block_copy_global_to_lds.RunRead(d0_grid_desc_m0_n0_n1_n2_m1_n3,
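Four more ignore assignments go away here, presumably because the bias-handling code below now consumes the descriptors and buffers they covered. The unchanged context uses CK's static_for to unroll over D0N0 iterations at compile time. A sketch of a compile-time loop in the same spirit (an illustrative implementation, not ck::static_for itself):

    // Illustration only: a recursive compile-time for loop.
    #include <cstddef>
    #include <type_traits>
    #include <utility>

    template <std::size_t Begin, std::size_t End, std::size_t Step = 1>
    struct static_for
    {
        template <typename F>
        constexpr void operator()(F&& f) const
        {
            if constexpr(Begin < End)
            {
                f(std::integral_constant<std::size_t, Begin>{}); // body sees a compile-time index
                static_for<Begin + Step, End, Step>{}(std::forward<F>(f));
            }
        }
    };

    // usage: static_for<0, 4, 1>{}([&](auto i) { out[i] = static_cast<float>(i); });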
@@ -1045,7 +1039,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
                     static_for<0, d0_thread_buf.Size(), 1>{}([&](auto i) {
                         constexpr index_t c_offset =
                             c_thread_desc.CalculateOffset(make_tuple(I0, nr, i));
-                        ignore = c_offset;
                         acc_thread_buf(Number<c_offset>{}) +=
                             ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
                     });
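The accumulation line in this last hunk adds the converted bias element into the accumulator at a compile-time offset: c_offset is computed as constexpr and passed as Number<c_offset>{}, so the indexing into the thread-local buffer is resolved at compile time. A small self-contained sketch of that indexing style, using std::integral_constant in place of CK's Number (names assumed for illustration):

    // Illustration only: index a small thread-local buffer with a
    // compile-time offset.
    #include <array>
    #include <cstddef>
    #include <type_traits>

    template <std::size_t I>
    using Number = std::integral_constant<std::size_t, I>;

    template <typename T, std::size_t N>
    struct StaticBuffer
    {
        std::array<T, N> data{};

        // A compile-time index keeps the access fully resolvable by the
        // compiler, which can help it keep elements in registers.
        template <std::size_t I>
        constexpr T& operator()(Number<I>)
        {
            static_assert(I < N, "index out of range");
            return data[I];
        }
    };

    int main()
    {
        StaticBuffer<float, 4> acc;
        acc(Number<2>{}) += 1.5f; // analogous to acc_thread_buf(Number<c_offset>{}) += ...
        return 0;
    }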