Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
2f2f5490
Commit
2f2f5490
authored
Sep 09, 2023
by
letaoqin
Browse files
check bias pointer is nullptr
parent
e0f595de
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
70 additions
and
61 deletions
+70
-61
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
...id/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
+70
-61
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
View file @
2f2f5490
...
@@ -2336,50 +2336,55 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -2336,50 +2336,55 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// add bias
// add bias
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
{
{
static
constexpr
auto
&
c_thread_desc
=
s_blockwise_gemm
.
GetCThreadDesc
();
if
(
p_d0_grid
!=
nullptr
)
{
static
constexpr
auto
&
c_thread_desc
=
s_blockwise_gemm
.
GetCThreadDesc
();
const
auto
d0_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
d0_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_d0_grid
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
p_d0_grid
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
auto
d0_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
auto
d0_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
D0DataType
*>
(
p_shared
)
+
SharedMemTrait
::
d0_block_space_offset
,
static_cast
<
D0DataType
*>
(
p_shared
)
+
SharedMemTrait
::
d0_block_space_offset
,
D0Operator
::
d0_block_global_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
D0Operator
::
d0_block_global_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
auto
d0_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
D0DataType
>
(
auto
d0_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
D0DataType
>
(
D0Operator
::
d0_thread_desc_
.
GetElementSpaceSize
());
D0Operator
::
d0_thread_desc_
.
GetElementSpaceSize
());
ignore
=
d0_thread_buf
;
ignore
=
d0_thread_buf
;
static_for
<
0
,
D0M0
,
1
>
{}([
&
](
auto
mr
)
{
static_for
<
0
,
D0M0
,
1
>
{}([
&
](
auto
mr
)
{
// load data to lds
// load data to lds
d0_block_copy_global_to_lds
.
RunRead
(
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
d0_block_copy_global_to_lds
.
RunRead
(
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
d0_grid_buf
);
d0_grid_buf
);
d0_block_copy_global_to_lds
.
MoveSrcSliceWindow
(
d0_block_copy_global_to_lds
.
MoveSrcSliceWindow
(
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
0
,
0
,
1
,
0
,
0
,
0
));
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
0
,
0
,
1
,
0
,
0
,
0
));
d0_block_copy_global_to_lds
.
RunWrite
(
d0_block_copy_global_to_lds
.
RunWrite
(
D0Operator
::
d0_block_global_desc_m0_n0_m1_m2_n1_m3
,
d0_block_buf
);
D0Operator
::
d0_block_global_desc_m0_n0_m1_m2_n1_m3
,
d0_block_buf
);
block_sync_lds
();
block_sync_lds
();
// read data from lds
// read data from lds
d0_thread_copy_lds_to_vgpr
.
Run
(
D0Operator
::
d0_block_vgpr_desc_n0_n1_m0_m1_m2
,
d0_thread_copy_lds_to_vgpr
.
Run
(
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
D0Operator
::
d0_block_vgpr_desc_n0_n1_m0_m1_m2
,
d0_block_buf
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
D0Operator
::
d0_thread_desc_
,
d0_block_buf
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
D0Operator
::
d0_thread_desc_
,
d0_thread_buf
);
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
d0_thread_buf
);
// bias add
static_for
<
0
,
d0_thread_buf
.
Size
(),
1
>
{}([
&
](
auto
i
)
{
// bias add
constexpr
index_t
c_offset
=
static_for
<
0
,
d0_thread_buf
.
Size
(),
1
>
{}([
&
](
auto
i
)
{
c_thread_desc
.
CalculateOffset
(
make_tuple
(
mr
,
I0
,
i
));
constexpr
index_t
c_offset
=
c_thread_desc
.
CalculateOffset
(
make_tuple
(
mr
,
I0
,
i
));
s_slash_p_thread_buf
(
Number
<
c_offset
>
{})
+=
ck
::
type_convert
<
FloatGemmAcc
>
(
d0_thread_buf
[
i
]);
s_slash_p_thread_buf
(
Number
<
c_offset
>
{})
+=
ck
::
type_convert
<
FloatGemmAcc
>
(
d0_thread_buf
[
i
]);
});
});
});
});
d0_block_copy_global_to_lds
.
MoveSrcSliceWindow
(
d0_block_copy_global_to_lds
.
MoveSrcSliceWindow
(
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
-
1
,
0
,
-
D0M0
.
value
,
0
,
0
,
0
));
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
-
1
,
0
,
-
D0M0
.
value
,
0
,
0
,
0
));
}
}
}
// do MNK padding or upper triangular masking
// do MNK padding or upper triangular masking
...
@@ -2608,37 +2613,41 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -2608,37 +2613,41 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// output bias grad
// output bias grad
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
{
{
auto
d0grad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
if
(
p_d0grad_grid
!=
nullptr
)
p_d0grad_grid
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
{
auto
d0grad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_d0grad_grid
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
auto
d0grad_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
auto
d0grad_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
D0DataType
*>
(
p_shared
)
+
SharedMemTrait
::
d0_block_space_offset
,
static_cast
<
D0DataType
*>
(
p_shared
)
+
SharedMemTrait
::
d0_block_space_offset
,
D0Operator
::
d0_block_global_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
D0Operator
::
d0_block_global_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
static_for
<
0
,
D0M0
,
1
>
{}([
&
](
auto
mr
)
{
static_for
<
0
,
D0M0
,
1
>
{}([
&
](
auto
mr
)
{
d0grad_thread_copy_vgpr_to_lds
.
Run
(
d0grad_thread_copy_vgpr_to_lds
.
Run
(
D0Operator
::
d0_thread_desc_
,
D0Operator
::
d0_thread_desc_
,
make_tuple
(
mr
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
mr
,
I0
,
I0
,
I0
,
I0
),
sgrad_thread_buf
,
sgrad_thread_buf
,
D0Operator
::
d0_block_vgpr_desc_n0_n1_m0_m1_m2
,
D0Operator
::
d0_block_vgpr_desc_n0_n1_m0_m1_m2
,
d0grad_block_buf
);
d0grad_block_buf
);
block_sync_lds
();
block_sync_lds
();
// write data from lds to global
// write data from lds to global
d0_block_copy_lds_to_global
.
Run
(
d0_block_copy_lds_to_global
.
Run
(
D0Operator
::
d0_block_global_desc_m0_n0_m1_m2_n1_m3
,
D0Operator
::
d0_block_global_desc_m0_n0_m1_m2_n1_m3
,
d0grad_block_buf
,
d0grad_block_buf
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
d0grad_grid_buf
,
d0grad_grid_buf
,
I0
);
I0
);
d0_block_copy_lds_to_global
.
MoveDstSliceWindow
(
d0_block_copy_lds_to_global
.
MoveDstSliceWindow
(
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
0
,
0
,
1
,
0
,
0
,
0
));
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
0
,
0
,
1
,
0
,
0
,
0
));
});
});
d0_block_copy_lds_to_global
.
MoveDstSliceWindow
(
d0_block_copy_lds_to_global
.
MoveDstSliceWindow
(
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
-
1
,
0
,
-
D0M0
.
value
,
0
,
0
,
0
));
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
-
1
,
0
,
-
D0M0
.
value
,
0
,
0
,
0
));
}
}
}
SubThreadBlock
<
BlockSize
>
gemm2_a_copy_subgroup
(
s_blockwise_gemm
.
GetWaveIdx
()[
I0
],
SubThreadBlock
<
BlockSize
>
gemm2_a_copy_subgroup
(
s_blockwise_gemm
.
GetWaveIdx
()[
I0
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment