Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
dad06b35
Commit
dad06b35
authored
Jun 20, 2023
by
danyao12
Browse files
code cleanup for fwd dropout
parent
51e102e5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
96 deletions
+10
-96
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp
+1
-28
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp
...ion/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp
+9
-68
No files found.
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp
View file @
dad06b35
...
@@ -122,8 +122,7 @@ struct BlockwiseDropout
...
@@ -122,8 +122,7 @@ struct BlockwiseDropout
});
});
}
}
template
<
typename
CThreadBuffer
,
bool
using_sign_bit
=
false
>
template
<
typename
CThreadBuffer
,
bool
using_sign_bit
=
false
>
__host__
__device__
void
ApplyDropoutAttnBwd
(
CThreadBuffer
&
in_thread_buf
,
__host__
__device__
void
ApplyDropoutAttnBwd
(
CThreadBuffer
&
in_thread_buf
,
ck
::
philox
&
ph
,
ck
::
philox
&
ph
,
index_t
element_global_1d_id
,
index_t
element_global_1d_id
,
...
@@ -185,15 +184,6 @@ struct BlockwiseDropout
...
@@ -185,15 +184,6 @@ struct BlockwiseDropout
ph
.
get_random_4x16
((
tmp
+
i
*
4
),
element_global_1d_id
+
i
*
8
*
MRaw
);
ph
.
get_random_4x16
((
tmp
+
i
*
4
),
element_global_1d_id
+
i
*
8
*
MRaw
);
}
}
// ushort tmp_id[tmp_size];
// for(int i = 0; i < philox_calls; i++)
//{
// for(int j = 0; j < 4; j++)
// {
// tmp_id[i * 4 + j] = element_global_1d_id + i * 8 * MRaw;
// }
//}
block_sync_lds
();
block_sync_lds
();
int
tmp_index
=
0
;
int
tmp_index
=
0
;
...
@@ -227,9 +217,6 @@ struct BlockwiseDropout
...
@@ -227,9 +217,6 @@ struct BlockwiseDropout
in_thread_buf
(
offset
)
=
execute_dropout
(
z_thread_buf
(
offset
)
<=
p_dropout_16bits
,
in_thread_buf
(
offset
)
=
execute_dropout
(
z_thread_buf
(
offset
)
<=
p_dropout_16bits
,
in_thread_buf
(
offset
));
in_thread_buf
(
offset
));
tmp_index
=
tmp_index
+
1
;
tmp_index
=
tmp_index
+
1
;
// if(get_thread_global_1d_id()==0){
// printf("z at %d is %u \n", tmp_index, z_thread_buf(offset));
//}
});
});
});
});
}
}
...
@@ -240,11 +227,6 @@ struct BlockwiseDropout
...
@@ -240,11 +227,6 @@ struct BlockwiseDropout
index_t
element_global_1d_id
,
index_t
element_global_1d_id
,
ZThreadBuffer
&
z_thread_buf
)
ZThreadBuffer
&
z_thread_buf
)
{
{
// if(get_thread_global_1d_id() == 0){
// printf("MRepeat & KRepeat is %d , %d . \n", MRepeat, KRepeat);
// }
constexpr
int
tmp_size
=
MRepeat
*
KRepeat
;
constexpr
int
tmp_size
=
MRepeat
*
KRepeat
;
int
philox_calls
=
tmp_size
/
4
;
int
philox_calls
=
tmp_size
/
4
;
...
@@ -255,15 +237,6 @@ struct BlockwiseDropout
...
@@ -255,15 +237,6 @@ struct BlockwiseDropout
ph
.
get_random_4x16
((
tmp
+
i
*
4
),
element_global_1d_id
+
i
*
8
);
ph
.
get_random_4x16
((
tmp
+
i
*
4
),
element_global_1d_id
+
i
*
8
);
}
}
// ushort tmp_id[tmp_size];
// for(int i = 0; i < philox_calls; i++)
//{
// for(int j = 0; j < 4; j++)
// {
// tmp_id[i * 4 + j] = element_global_1d_id + i * 8;
// }
//}
block_sync_lds
();
block_sync_lds
();
int
tmp_index
=
0
;
int
tmp_index
=
0
;
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp
View file @
dad06b35
...
@@ -145,9 +145,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -145,9 +145,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
__host__
__device__
static
constexpr
auto
GetPaddedSize
(
const
index_t
size
)
__host__
__device__
static
constexpr
auto
GetPaddedSize
(
const
index_t
size
)
{
{
constexpr
auto
mfma
=
MfmaSelector
<
FloatGemm
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
mfma
=
MfmaSelector
<
FloatGemm
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
N5
=
mfma
.
group_size
;
constexpr
auto
group_size
=
mfma
.
group_size
;
return
index_t
(
ceil
(
float
(
size
)
/
N5
)
*
N5
)
;
return
math
::
integer_divide_ceil
(
size
,
group_size
)
*
group_size
;
}
}
__device__
static
auto
GetGemm0WaveIdx
()
__device__
static
auto
GetGemm0WaveIdx
()
...
@@ -263,7 +263,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -263,7 +263,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
SharedMemTrait
::
c_block_space_size
*
sizeof
(
FloatCShuffle
);
SharedMemTrait
::
c_block_space_size
*
sizeof
(
FloatCShuffle
);
const
index_t
z_block_bytes_end
=
const
index_t
z_block_bytes_end
=
SharedMemTrait
::
z_shuffle_block_space_size
*
sizeof
(
ZDataType
);
SharedMemTrait
::
z_shuffle_block_space_size
*
sizeof
(
ushort
);
return
math
::
max
(
gemm0_bytes_end
,
return
math
::
max
(
gemm0_bytes_end
,
gemm1_bytes_end
,
gemm1_bytes_end
,
...
@@ -871,14 +871,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -871,14 +871,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// z vgpr copy to global
// z vgpr copy to global
//
//
// z matrix threadwise desc
// z matrix threadwise desc
// if(get_thread_global_1d_id()==0){
// printf("m2 is %d \n",m2.value);
// printf("n2 is %d \n",n2.value);
// printf("n3 is %d \n",n3.value);
// printf("n4 is %d \n",n4.value);
//}
constexpr
auto
z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
=
// for blockwise copy
constexpr
auto
z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
=
// for blockwise copy
make_naive_tensor_descriptor_packed
(
make_tuple
(
m0
,
// MRepeat
make_naive_tensor_descriptor_packed
(
make_tuple
(
m0
,
// MRepeat
n0
,
// NRepeat
n0
,
// NRepeat
...
@@ -915,8 +907,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -915,8 +907,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
constexpr
auto
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
=
constexpr
auto
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
=
blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
();
blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
();
// constexpr auto z_block_lengths = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLengths();
constexpr
auto
zM0
=
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I0
);
constexpr
auto
zM0
=
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I0
);
constexpr
auto
zN0
=
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I1
);
constexpr
auto
zN0
=
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I1
);
constexpr
auto
zM1
=
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I2
);
constexpr
auto
zM1
=
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I2
);
...
@@ -954,17 +944,13 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -954,17 +944,13 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
Sequence
<
6
>
{},
Sequence
<
6
>
{},
Sequence
<
8
>
{}));
Sequence
<
8
>
{}));
// ignore = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4;
// ignore = z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
unsigned
short
,
ushort
,
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4
.
GetElementSpaceSize
(),
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4
.
GetElementSpaceSize
(),
true
>
true
>
z_tensor_buffer
;
// z buffer after shuffle
z_tensor_buffer
;
z_tensor_buffer
.
Clear
();
z_tensor_buffer
.
Clear
();
// z matrix global desc
auto
z_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
auto
z_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_z_grid
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetElementSpaceSize
());
p_z_grid
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetElementSpaceSize
());
...
@@ -972,50 +958,12 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -972,50 +958,12 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
static_cast
<
ZDataType
*>
(
p_shared
),
static_cast
<
ZDataType
*>
(
p_shared
),
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetElementSpaceSize
());
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetElementSpaceSize
());
// if(get_thread_global_1d_id()==0){
// printf("z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize() is %ld \n",
// z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize().value);
// printf("z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4.GetElementSpaceSize() is %ld
// \n", z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4.GetElementSpaceSize().value);
// printf("z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize() is %ld
// \n",z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize().value);
// printf("z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4.GetElementSpaceSize() is %ld
// \n",z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4.GetElementSpaceSize().value);
// }
const
auto
wave_id
=
GetGemm0WaveIdx
();
const
auto
wave_id
=
GetGemm0WaveIdx
();
const
auto
wave_m_n_id
=
GetGemm0WaveMNIdx
(
wave_id
[
I2
]);
// I2: 0~63
const
auto
wave_m_n_id
=
GetGemm0WaveMNIdx
(
wave_id
[
I2
]);
// I2: 0~63
// if(get_block_1d_id()==0){
// if(get_thread_local_1d_id()==0){
// printf("tid = 0 , wave_m_n_id[I0] & wave_m_n_id[I1] is %d & %d
// \n",wave_m_n_id[I0], wave_m_n_id[I1]);
// }
// if(get_thread_local_1d_id()==1){
// printf("tid = 1 , wave_m_n_id[I0] & wave_m_n_id[I1] is %d & %d
// \n",wave_m_n_id[I0], wave_m_n_id[I1]);
// }
// if(get_thread_local_1d_id()==2){
// printf("tid = 2 , wave_m_n_id[I0] & wave_m_n_id[I1] is %d & %d
// \n",wave_m_n_id[I0], wave_m_n_id[I1]);
// }
// if(get_thread_local_1d_id()==3){
// printf("tid = 3 , wave_m_n_id[I0] & wave_m_n_id[I1] is %d & %d
// \n",wave_m_n_id[I0], wave_m_n_id[I1]);
// }
// if(get_thread_local_1d_id()==32){
// printf("tid = 32 , wave_m_n_id[I0] & wave_m_n_id[I1] is %d & %d
// \n",wave_m_n_id[I0], wave_m_n_id[I1]);
// }
// if(get_thread_local_1d_id()==64){
// printf("tid = 32 , wave_m_n_id[I0] & wave_m_n_id[I1] is %d & %d
// \n",wave_m_n_id[I0], wave_m_n_id[I1]);
// }
//}
auto
z_tmp_thread_copy_vgpr_to_lds
=
auto
z_tmp_thread_copy_vgpr_to_lds
=
ThreadwiseTensorSliceTransfer_v1r3
<
ushort
,
ThreadwiseTensorSliceTransfer_v1r3
<
ushort
,
ZDataType
,
ushort
,
decltype
(
z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
),
decltype
(
z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
),
decltype
(
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
),
decltype
(
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
),
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
...
@@ -1045,7 +993,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -1045,7 +993,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
tensor_operation
::
element_wise
::
PassThrough
{}};
tensor_operation
::
element_wise
::
PassThrough
{}};
auto
z_shuffle_thread_copy_lds_to_vgpr
=
ThreadwiseTensorSliceTransfer_v2
<
auto
z_shuffle_thread_copy_lds_to_vgpr
=
ThreadwiseTensorSliceTransfer_v2
<
ZDataType
,
ushort
,
ushort
,
ushort
,
decltype
(
z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4
),
decltype
(
z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4
),
decltype
(
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4
),
decltype
(
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4
),
...
@@ -1111,10 +1059,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -1111,10 +1059,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
if
(
c0_matrix_mask
.
IsTileSkippable
(
if
(
c0_matrix_mask
.
IsTileSkippable
(
m_block_data_idx_on_grid
,
n_block_data_idx_on_grid
,
MPerBlock
,
NPerBlock
))
m_block_data_idx_on_grid
,
n_block_data_idx_on_grid
,
MPerBlock
,
NPerBlock
))
{
{
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
continue
;
continue
;
}
}
// gemm0
// gemm0
...
@@ -1199,10 +1143,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -1199,10 +1143,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
auto
global_elem_id
=
z_random_matrix_offset
+
m_global
*
raw_n_padded
+
auto
global_elem_id
=
z_random_matrix_offset
+
m_global
*
raw_n_padded
+
n_global
;
// unique element global 1d id
n_global
;
// unique element global 1d id
// index_t id_step = Acc0TileIterator::GetNumOfAccess() / n0.value;
// save z to global
blockwise_dropout
.
template
GenerateZMatrixAttnFwd
<
decltype
(
z_tensor_buffer
)>(
blockwise_dropout
.
template
GenerateZMatrixAttnFwd
<
decltype
(
z_tensor_buffer
)>(
ph
,
global_elem_id
,
z_tensor_buffer
);
ph
,
global_elem_id
,
z_tensor_buffer
);
...
@@ -1224,6 +1164,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -1224,6 +1164,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
false
>(
acc_thread_buf
,
false
>(
acc_thread_buf
,
z_tensor_buffer
);
z_tensor_buffer
);
// save z to global
if
(
p_z_grid
)
if
(
p_z_grid
)
{
{
// ignore = z_tensor_buffer;
// ignore = z_tensor_buffer;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment