Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
bab3161b
"vscode:/vscode.git/clone" did not exist on "9e67148c56e0b193edd8de8c10b34579cfe50a3c"
Commit
bab3161b
authored
Jun 21, 2023
by
ltqin
Browse files
regular code
parent
cf9ef868
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
19 deletions
+8
-19
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_bacckward_ydotygrad.hpp
...dwise_batched_multihead_attention_bacckward_ydotygrad.hpp
+8
-19
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_bacckward_ydotygrad.hpp
View file @
bab3161b
...
@@ -78,18 +78,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
...
@@ -78,18 +78,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
}
}
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeORSGridDescriptor_MBlock_M
Repeat_NWave_MPerXdl
(
const
ORSGridDesc_M
&
lse_grid_desc_m
)
MakeORSGridDescriptor_MBlock_M
PerBlock
(
const
ORSGridDesc_M
&
lse_grid_desc_m
)
{
{
const
index_t
M
=
lse_grid_desc_m
.
GetLength
(
I0
);
const
index_t
M
=
lse_grid_desc_m
.
GetLength
(
I0
);
const
index_t
MBlock
=
M
/
MPerBlock
;
const
index_t
MBlock
=
M
/
MPerBlock
;
const
auto
lse_grid_desc_mblock_m
repeat_mwave_mperxdl
=
transform_tensor_descriptor
(
const
auto
lse_grid_desc_mblock_m
perblock
=
transform_tensor_descriptor
(
lse_grid_desc_m
,
lse_grid_desc_m
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
Number
<
MPerBlock
>
{}))),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
Number
<
MPerBlock
>
{}))),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{}));
make_tuple
(
Sequence
<
0
,
1
>
{}));
return
lse_grid_desc_mblock_m
repeat_mwave_mperxdl
;
return
lse_grid_desc_mblock_m
perblock
;
}
}
// return block_id to C matrix tile idx (m0, n0) mapping
// return block_id to C matrix tile idx (m0, n0) mapping
...
@@ -185,15 +185,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
...
@@ -185,15 +185,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
const
auto
y_thread_data_on_block_idx
=
const
auto
y_thread_data_on_block_idx
=
y_thread_cluster_idx
*
y_thread_desc_m0_m1_o0_o1
.
GetLengths
();
y_thread_cluster_idx
*
y_thread_desc_m0_m1_o0_o1
.
GetLengths
();
// if(get_thread_global_1d_id() == 1)
// {
// printf("y_thread_data_on_block_idx:{ %d, %d, %d,%d}, get_thread_local_1d_id: %d\n",
// y_thread_data_on_block_idx[I0],
// y_thread_data_on_block_idx[I1],
// y_thread_data_on_block_idx[I2],
// y_thread_data_on_block_idx[I3],
// get_thread_local_1d_id());
// }
const
auto
y_thread_data_on_grid_idx
=
const
auto
y_thread_data_on_grid_idx
=
make_multi_index
(
make_multi_index
(
block_work_idx_m
,
I0
,
I0
/* all WGs start from o_block_idx = 0 */
,
I0
)
+
block_work_idx_m
,
I0
,
I0
/* all WGs start from o_block_idx = 0 */
,
I0
)
+
...
@@ -253,14 +244,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
...
@@ -253,14 +244,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
oblock_idx
++
;
oblock_idx
++
;
}
while
(
oblock_idx
<
y_grid_desc_mblock_mperblock_oblock_operblock
.
GetLength
(
I2
));
}
while
(
oblock_idx
<
y_grid_desc_mblock_mperblock_oblock_operblock
.
GetLength
(
I2
));
auto
ors_grid_desc_mblock_m
repeat_mwave_mperxdl
=
auto
ors_grid_desc_mblock_m
perblock
=
MakeORSGridDescriptor_MBlock_M
Repeat_NWave_MPerXdl
(
ors_grid_desc_m
);
MakeORSGridDescriptor_MBlock_M
PerBlock
(
ors_grid_desc_m
);
auto
ors_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
auto
ors_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
FloatORS
,
FloatORS
,
FloatORS
,
FloatORS
,
decltype
(
ors_thread_desc_mblock_mrepeat_mwave_mperxdl
),
decltype
(
ors_thread_desc_mblock_mrepeat_mwave_mperxdl
),
decltype
(
ors_grid_desc_mblock_m
repeat_mwave_mperxdl
),
decltype
(
ors_grid_desc_mblock_m
perblock
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
1
,
1
>
,
Sequence
<
1
,
1
>
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
...
@@ -268,7 +259,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
...
@@ -268,7 +259,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
1
,
1
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
1
,
1
,
false
>
{
ors_grid_desc_mblock_m
repeat_mwave_mperxdl
,
false
>
{
ors_grid_desc_mblock_m
perblock
,
make_multi_index
(
block_work_idx_m
,
// mblock
make_multi_index
(
block_work_idx_m
,
// mblock
get_thread_local_1d_id
()),
// mperxdl
get_thread_local_1d_id
()),
// mperxdl
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}};
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}};
...
@@ -277,11 +268,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
...
@@ -277,11 +268,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
ors_thread_copy_vgpr_to_global
.
Run
(
ors_thread_desc_mblock_mrepeat_mwave_mperxdl
,
ors_thread_copy_vgpr_to_global
.
Run
(
ors_thread_desc_mblock_mrepeat_mwave_mperxdl
,
make_tuple
(
I0
,
I0
),
make_tuple
(
I0
,
I0
),
y_dot_ygrad_thread_accum_buf
,
y_dot_ygrad_thread_accum_buf
,
ors_grid_desc_mblock_m
repeat_mwave_mperxdl
,
ors_grid_desc_mblock_m
perblock
,
ors_grid_buf
);
ors_grid_buf
);
ignore
=
ors_thread_copy_vgpr_to_global
;
ignore
=
ors_grid_desc_mblock_mrepeat_mwave_mperxdl
;
}
}
};
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment