Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
ff88ffa4
Commit
ff88ffa4
authored
May 29, 2023
by
guangzlu
Browse files
bwd pass for v4
parent
01073007
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
182 additions
and
8 deletions
+182
-8
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v4.cpp
..._softmax_gemm/batched_multihead_attention_backward_v4.cpp
+2
-2
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp
+74
-1
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt4.hpp
...batched_multihead_attention_backward_xdl_cshuffle_pt4.hpp
+106
-5
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v4.cpp
View file @
ff88ffa4
...
@@ -25,7 +25,7 @@ Kernel outputs:
...
@@ -25,7 +25,7 @@ Kernel outputs:
#define PRINT_HOST 0
#define PRINT_HOST 0
#define USING_MASK 0
#define USING_MASK 0
#define DIM
64
// DIM should be a multiple of 8.
#define DIM
32
// DIM should be a multiple of 8.
#include <iostream>
#include <iostream>
#include <numeric>
#include <numeric>
...
@@ -497,7 +497,7 @@ int run(int argc, char* argv[])
...
@@ -497,7 +497,7 @@ int run(int argc, char* argv[])
bool
input_permute
=
false
;
bool
input_permute
=
false
;
bool
output_permute
=
false
;
bool
output_permute
=
false
;
float
p_drop
=
0.
0
;
float
p_drop
=
0.
1
;
const
unsigned
long
long
seed
=
1
;
const
unsigned
long
long
seed
=
1
;
const
unsigned
long
long
offset
=
0
;
const
unsigned
long
long
offset
=
0
;
...
...
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp
View file @
ff88ffa4
...
@@ -145,6 +145,15 @@ struct BlockwiseDropout
...
@@ -145,6 +145,15 @@ struct BlockwiseDropout
ph
.
get_random_4x16
((
tmp
+
i
*
4
),
element_global_1d_id
+
i
*
8
);
ph
.
get_random_4x16
((
tmp
+
i
*
4
),
element_global_1d_id
+
i
*
8
);
}
}
ushort
tmp_id
[
tmp_size
];
for
(
int
i
=
0
;
i
<
philox_calls
;
i
++
)
{
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
tmp_id
[
i
*
4
+
j
]
=
element_global_1d_id
+
i
*
8
;
}
}
block_sync_lds
();
block_sync_lds
();
int
tmp_index
=
0
;
int
tmp_index
=
0
;
...
@@ -153,7 +162,71 @@ struct BlockwiseDropout
...
@@ -153,7 +162,71 @@ struct BlockwiseDropout
auto
offset
=
Number
<
ThreadSliceDesc_M_K
{}.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
auto
offset
=
Number
<
ThreadSliceDesc_M_K
{}.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
in_thread_buf
(
offset
)
=
in_thread_buf
(
offset
)
=
execute_dropout
(
tmp
[
tmp_index
]
<=
p_dropout_16bits
,
in_thread_buf
(
offset
));
execute_dropout
(
tmp
[
tmp_index
]
<=
p_dropout_16bits
,
in_thread_buf
(
offset
));
z_thread_buf
(
offset
)
=
tmp
[
tmp_index
];
z_thread_buf
(
offset
)
=
tmp_id
[
tmp_index
];
tmp_index
=
tmp_index
+
1
;
});
});
}
// Applies dropout to `in_thread_buf` using values that are already stored in
// `z_thread_buf` (no random numbers are generated here).
//
// Every (m, k) pair in [0, MRepeat) x [0, KRepeat) is visited; the element offset
// comes from ThreadSliceDesc_M_K and is used to index both buffers. An element is
// kept when its `z_thread_buf` value is <= p_dropout_16bits.
//
//   using_sign_bit == true : kept elements are left unchanged, dropped elements
//                            are negated (the sign carries the drop decision).
//   using_sign_bit == false: kept elements are multiplied by p_dropout_rescale,
//                            dropped elements become 0.
//
// `z_thread_buf` is only read; `in_thread_buf` is updated in place.
template <typename CThreadBuffer, typename ZThreadBuffer, bool using_sign_bit = false>
__host__ __device__ void ApplyDropout_v2(CThreadBuffer& in_thread_buf,
                                         ZThreadBuffer& z_thread_buf)
{
    // Maps one input value to its post-dropout value given the keep decision.
    auto apply_mask = [&](bool is_kept, DataType value) {
        if constexpr(using_sign_bit)
            return is_kept ? value : -value;
        else
            return is_kept ? value * p_dropout_rescale : float(0);
    };

    static_for<0, MRepeat, 1>{}([&](auto m_idx) {
        static_for<0, KRepeat, 1>{}([&](auto k_idx) {
            // Compile-time offset of element (m_idx, k_idx) inside the thread slice.
            auto elem_offset =
                Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(m_idx, k_idx))>{};

            const bool is_kept = z_thread_buf(elem_offset) <= p_dropout_16bits;

            in_thread_buf(elem_offset) = apply_mask(is_kept, in_thread_buf(elem_offset));
        });
    });
}
// get raw z matrix with random number for shuffle
template
<
typename
ZThreadBuffer
>
__host__
__device__
void
GenerateZMatrix
(
ck
::
philox
&
ph
,
index_t
element_global_1d_id
,
ZThreadBuffer
&
z_thread_buf
,
index_t
MRaw
)
{
// if(get_thread_global_1d_id() == 0){
// printf("MRepeat & KRepeat is %d , %d . \n", MRepeat, KRepeat);
// }
constexpr
int
tmp_size
=
MRepeat
*
KRepeat
;
int
philox_calls
=
tmp_size
/
4
;
ushort
tmp
[
tmp_size
];
for
(
int
i
=
0
;
i
<
philox_calls
;
i
++
)
{
ph
.
get_random_4x16
((
tmp
+
i
*
4
),
element_global_1d_id
+
i
*
8
);
}
ushort
tmp_id
[
tmp_size
];
for
(
int
i
=
0
;
i
<
philox_calls
;
i
++
)
{
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
tmp_id
[
i
*
4
+
j
]
=
element_global_1d_id
+
i
*
8
*
MRaw
;
}
}
block_sync_lds
();
int
tmp_index
=
0
;
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
iK
)
{
auto
offset
=
Number
<
ThreadSliceDesc_M_K
{}.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
z_thread_buf
(
offset
)
=
tmp_id
[
tmp_index
];
tmp_index
=
tmp_index
+
1
;
tmp_index
=
tmp_index
+
1
;
});
});
});
});
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt4.hpp
View file @
ff88ffa4
...
@@ -1600,13 +1600,83 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -1600,13 +1600,83 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
z_tenor_buffer
;
z_tenor_buffer
;
z_tenor_buffer
.
Clear
();
z_tenor_buffer
.
Clear
();
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
unsigned
short
,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
(),
true
>
z_tenor_buffer_tmp
;
z_tenor_buffer_tmp
.
Clear
();
// z matrix global desc
// z matrix global desc
auto
z_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
auto
z_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_z_grid
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
());
p_z_grid
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
());
auto
z_grid_buf_tmp
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
// tmp buffer for shuffle
p_z_grid
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
());
const
auto
wave_id
=
GetGemm0WaveIdx
();
const
auto
wave_id
=
GetGemm0WaveIdx
();
const
auto
wave_m_n_id
=
GetGemm0WaveMNIdx
(
wave_id
[
I2
]);
// I2: 0~63
const
auto
wave_m_n_id
=
GetGemm0WaveMNIdx
(
wave_id
[
I2
]);
// I2: 0~63
auto
z_tmp_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
ushort
,
ZDataType
,
decltype
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
),
decltype
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
),
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
I1
,
// MBlockId
I1
,
// NBlockID
m0
,
// MRepeat
n0
,
// NRepeat
m1
,
// MWaveId
n1
,
// NWaveId
m2
,
// MPerXdl
m3
,
// NGroupNum
m4
,
// NInputNum
n2
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
>
,
9
,
// DstVectorDim
1
,
// DstScalarPerVector
InMemoryDataOperationEnum
::
Set
,
1
,
// DstScalarStrideInVector
true
>
{
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
make_multi_index
(
0
,
// MBlockId
block_work_idx_n
,
// NBlockId
0
,
// mrepeat
0
,
// nrepeat
wave_id
[
I0
],
// MWaveId
wave_id
[
I1
],
// NWaveId
0
,
// MPerXdl
wave_m_n_id
[
I0
],
// group
0
,
// NInputIndex
wave_m_n_id
[
I1
]),
tensor_operation
::
element_wise
::
PassThrough
{}};
auto
z_tmp_thread_copy_global_to_vgpr
=
ThreadwiseTensorSliceTransfer_v2
<
ZDataType
,
ushort
,
decltype
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
),
decltype
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
),
Sequence
<
I1
,
I1
,
m0
,
n0
,
m1
,
n1
,
m2
,
m3
,
m4
,
n2
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
>
,
9
,
1
,
1
,
true
/* ResetCoordAfterRun */
>
{
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
make_multi_index
(
0
,
// MBlockId
block_work_idx_n
,
// NBlockId
0
,
// mrepeat
0
,
// nrepeat
wave_id
[
I0
],
// MWaveId
wave_id
[
I1
],
// NWaveId
0
,
// MPerXdl
wave_m_n_id
[
I0
],
// group
0
,
// NInputIndex
wave_m_n_id
[
I1
])};
auto
z_thread_copy_vgpr_to_global
=
auto
z_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
ushort
,
ThreadwiseTensorSliceTransfer_v1r3
<
ushort
,
ZDataType
,
ZDataType
,
...
@@ -1986,9 +2056,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -1986,9 +2056,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
auto
global_elem_id
=
auto
global_elem_id
_raw
=
MRaw
*
NRaw
*
g_idx
+
m_global
*
NRaw
+
n_global
;
// unique element global 1d id
MRaw
*
NRaw
*
g_idx
+
m_global
*
NRaw
+
n_global
;
// unique element global 1d id
auto
global_elem_id
=
(
global_elem_id_raw
%
4
)
*
MRaw
+
int
(
global_elem_id_raw
/
4
)
*
4
;
// if(get_block_1d_id() == 0 && get_thread_local_1d_id()==64){
// if(get_block_1d_id() == 0 && get_thread_local_1d_id()==64){
// printf("global_elem_id is %d \n", global_elem_id);
// printf("global_elem_id is %d \n", global_elem_id);
//}
//}
...
@@ -2001,10 +2074,31 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -2001,10 +2074,31 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// printf("id_step is %d \n", id_step);
// printf("id_step is %d \n", id_step);
//}
//}
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
s_slash_p_thread_buf
),
// dropout
decltype
(
z_tenor_buffer
),
// z_tenor_buffer_tmp -> z_grid_buf_tmp -> shuffle -> z_tenor_buffer -> z_grid_buf
true
>(
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
z_tenor_buffer
);
// generate random number
blockwise_dropout
.
template
GenerateZMatrix
<
decltype
(
z_tenor_buffer_tmp
)>(
ph
,
global_elem_id
,
z_tenor_buffer_tmp
,
MRaw
);
z_tmp_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer_tmp
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
z_grid_buf_tmp
);
z_tmp_thread_copy_global_to_vgpr
.
Run
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
z_grid_buf_tmp
,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer
);
blockwise_dropout
.
template
ApplyDropout_v2
<
decltype
(
s_slash_p_thread_buf
),
decltype
(
z_tenor_buffer
),
true
>(
s_slash_p_thread_buf
,
z_tenor_buffer
);
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
...
@@ -2079,6 +2173,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -2079,6 +2173,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// index_t id_step = Acc0TileIterator::GetNumOfAccess() / n0.value;
// index_t id_step = Acc0TileIterator::GetNumOfAccess() / n0.value;
ignore
=
z_grid_buf
;
ignore
=
z_grid_buf
;
ignore
=
z_grid_buf_tmp
;
// P_dropped
// P_dropped
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
s_slash_p_thread_buf
),
true
>(
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
s_slash_p_thread_buf
),
true
>(
s_slash_p_thread_buf
,
ph
,
global_elem_id
);
s_slash_p_thread_buf
,
ph
,
global_elem_id
);
...
@@ -2266,6 +2361,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -2266,6 +2361,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
Gemm1
::
b_block_reset_copy_step
);
// rewind M
Gemm1
::
b_block_reset_copy_step
);
// rewind M
qgrad_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
qgrad_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4
,
Gemm2
::
c_block_slice_copy_step
);
// step M
qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4
,
Gemm2
::
c_block_slice_copy_step
);
// step M
z_tmp_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
make_multi_index
(
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
z_tmp_thread_copy_global_to_vgpr
.
MoveSrcSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
make_multi_index
(
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
make_multi_index
(
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
make_multi_index
(
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment