Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
2ac0eefd
Commit
2ac0eefd
authored
Jan 15, 2023
by
guangzlu
Browse files
Added dropout rescale to grouped_gemm_softmax_gemm
parent
6926effa
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
22 additions
and
11 deletions
+22
-11
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp
+8
-4
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp
..._grouped_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp
+9
-1
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v2.hpp
...id/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v2.hpp
+5
-6
No files found.
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp
View file @
2ac0eefd
...
@@ -8,7 +8,7 @@
...
@@ -8,7 +8,7 @@
namespace
ck
{
namespace
ck
{
template
<
typename
ThreadSliceDesc_M_K
>
template
<
typename
DataType
,
typename
ThreadSliceDesc_M_K
>
struct
BlockwiseDropout
struct
BlockwiseDropout
{
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
@@ -18,13 +18,14 @@ struct BlockwiseDropout
...
@@ -18,13 +18,14 @@ struct BlockwiseDropout
template
<
typename
CThreadBuffer
>
template
<
typename
CThreadBuffer
>
__host__
__device__
void
ApplyDropout
(
CThreadBuffer
&
in_thread_buf
,
__host__
__device__
void
ApplyDropout
(
CThreadBuffer
&
in_thread_buf
,
ushort
p_dropout_16bits
,
ck
::
philox
ph
,
ck
::
philox
ph
,
const
int
repeat_index
,
const
int
repeat_index
,
const
int
total_repeats
)
const
int
total_repeats
)
{
{
auto
if_dropout
=
[](
bool
keep
,
float
val
)
{
return
keep
?
val
:
float
(
0
);
};
auto
execute_dropout
=
[
&
](
bool
keep
,
DataType
val
)
{
return
keep
?
val
*
p_dropout_rescale
:
float
(
0
);
};
constexpr
int
tmp_size
=
MRepeat
*
KRepeat
;
constexpr
int
tmp_size
=
MRepeat
*
KRepeat
;
int
philox_calls
=
tmp_size
/
8
;
int
philox_calls
=
tmp_size
/
8
;
...
@@ -45,11 +46,14 @@ struct BlockwiseDropout
...
@@ -45,11 +46,14 @@ struct BlockwiseDropout
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
iK
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
iK
)
{
auto
offset
=
Number
<
ThreadSliceDesc_M_K
{}.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
auto
offset
=
Number
<
ThreadSliceDesc_M_K
{}.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
in_thread_buf
(
offset
)
=
in_thread_buf
(
offset
)
=
if
_dropout
(
tmp
[
tmp_index
]
<
p_dropout_16bits
,
in_thread_buf
(
offset
));
execute
_dropout
(
tmp
[
tmp_index
]
<
p_dropout_16bits
,
in_thread_buf
(
offset
));
tmp_index
=
tmp_index
+
1
;
tmp_index
=
tmp_index
+
1
;
});
});
});
});
}
}
ushort
p_dropout_16bits
;
DataType
p_dropout_rescale
;
};
};
}
// namespace ck
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp
View file @
2ac0eefd
...
@@ -24,6 +24,7 @@ namespace tensor_operation {
...
@@ -24,6 +24,7 @@ namespace tensor_operation {
namespace
device
{
namespace
device
{
template
<
typename
GridwiseGemm
,
template
<
typename
GridwiseGemm
,
typename
GemmAccDataType
,
typename
GroupKernelArg
,
typename
GroupKernelArg
,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
...
@@ -45,6 +46,7 @@ __global__ void
...
@@ -45,6 +46,7 @@ __global__ void
const
B1ElementwiseOperation
b1_element_op
,
const
B1ElementwiseOperation
b1_element_op
,
const
CElementwiseOperation
c_element_op
,
const
CElementwiseOperation
c_element_op
,
const
ushort
p_dropout_in_16bits
,
const
ushort
p_dropout_in_16bits
,
GemmAccDataType
p_dropout_rescale
,
const
unsigned
long
long
seed
)
const
unsigned
long
long
seed
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
...
@@ -52,7 +54,7 @@ __global__ void
...
@@ -52,7 +54,7 @@ __global__ void
const
index_t
block_id
=
get_block_1d_id
();
const
index_t
block_id
=
get_block_1d_id
();
ck
::
philox
ph
(
seed
,
0
,
block_id
);
ck
::
philox
ph
(
seed
,
0
,
block_id
*
4
);
const
auto
arg_ptr
=
reinterpret_cast
<
const
GroupKernelArg
*>
(
const
auto
arg_ptr
=
reinterpret_cast
<
const
GroupKernelArg
*>
(
cast_pointer_to_generic_address_space
(
group_kernel_args
));
cast_pointer_to_generic_address_space
(
group_kernel_args
));
...
@@ -111,6 +113,7 @@ __global__ void
...
@@ -111,6 +113,7 @@ __global__ void
arg_ptr
[
group_id
].
block_2_ctile_map_
,
arg_ptr
[
group_id
].
block_2_ctile_map_
,
arg_ptr
[
group_id
].
c0_matrix_mask_
,
arg_ptr
[
group_id
].
c0_matrix_mask_
,
p_dropout_in_16bits
,
p_dropout_in_16bits
,
p_dropout_rescale
,
ph
);
ph
);
#else
#else
ignore
=
group_kernel_args
;
ignore
=
group_kernel_args
;
...
@@ -642,6 +645,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
...
@@ -642,6 +645,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
is_dropout_
=
p_dropout
>
0.0
;
//
is_dropout_
=
p_dropout
>
0.0
;
//
p_dropout_
=
1.
f
-
p_dropout
;
p_dropout_
=
1.
f
-
p_dropout
;
p_dropout_in_16bits_
=
uint16_t
(
std
::
floor
(
p_dropout_
*
65535.0
));
p_dropout_in_16bits_
=
uint16_t
(
std
::
floor
(
p_dropout_
*
65535.0
));
p_dropout_
=
1.
f
/
p_dropout_
;
p_dropout_rescale_
=
type_convert
<
GemmAccDataType
>
(
p_dropout_
);
}
}
std
::
vector
<
GroupKernelArg
>
group_kernel_args_
;
std
::
vector
<
GroupKernelArg
>
group_kernel_args_
;
...
@@ -659,6 +664,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
...
@@ -659,6 +664,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
float
p_dropout_
;
float
p_dropout_
;
ushort
p_dropout_in_16bits_
;
ushort
p_dropout_in_16bits_
;
unsigned
long
long
seed_
;
unsigned
long
long
seed_
;
GemmAccDataType
p_dropout_rescale_
;
bool
is_dropout_
;
bool
is_dropout_
;
};
};
...
@@ -695,6 +701,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
...
@@ -695,6 +701,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
,
auto
is_dropout_
)
{
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
,
auto
is_dropout_
)
{
const
auto
kernel
=
const
auto
kernel
=
kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2
<
GridwiseGemm
,
kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2
<
GridwiseGemm
,
GemmAccDataType
,
GroupKernelArg
,
GroupKernelArg
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
...
@@ -718,6 +725,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
...
@@ -718,6 +725,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
arg
.
b1_element_op_
,
arg
.
b1_element_op_
,
arg
.
c_element_op_
,
arg
.
c_element_op_
,
arg
.
p_dropout_in_16bits_
,
arg
.
p_dropout_in_16bits_
,
arg
.
p_dropout_rescale_
,
arg
.
seed_
);
arg
.
seed_
);
};
};
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v2.hpp
View file @
2ac0eefd
...
@@ -383,6 +383,7 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
...
@@ -383,6 +383,7 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
const
Block2CTileMap
&
block_2_ctile_map
,
const
Block2CTileMap
&
block_2_ctile_map
,
const
C0MatrixMask
&
c0_matrix_mask
,
const
C0MatrixMask
&
c0_matrix_mask
,
const
ushort
p_dropout_in_16bits
,
const
ushort
p_dropout_in_16bits
,
FloatGemmAcc
p_dropout_rescale
,
ck
::
philox
ph
)
ck
::
philox
ph
)
{
{
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
...
@@ -728,7 +729,8 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
...
@@ -728,7 +729,8 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
decltype
(
thread_cluster_desc_m_n
),
decltype
(
thread_cluster_desc_m_n
),
decltype
(
thread_slice_desc_m_n
)
>
{};
decltype
(
thread_slice_desc_m_n
)
>
{};
auto
blockwise_dropout
=
BlockwiseDropout
<
decltype
(
thread_slice_desc_m_n
)
>
{};
auto
blockwise_dropout
=
BlockwiseDropout
<
FloatGemmAcc
,
decltype
(
thread_slice_desc_m_n
)
>
{
p_dropout_in_16bits
,
p_dropout_rescale
};
const
index_t
num_gemm1_k_block_outer_loop
=
const
index_t
num_gemm1_k_block_outer_loop
=
b_grid_desc_bk0_n_bk1
.
GetLength
(
I1
)
/
NPerBlock
;
b_grid_desc_bk0_n_bk1
.
GetLength
(
I1
)
/
NPerBlock
;
...
@@ -873,11 +875,8 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
...
@@ -873,11 +875,8 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
if
constexpr
(
IsDropout
)
// dropout
if
constexpr
(
IsDropout
)
// dropout
{
{
blockwise_dropout
.
ApplyDropout
(
acc_thread_buf
,
blockwise_dropout
.
ApplyDropout
(
p_dropout_in_16bits
,
acc_thread_buf
,
ph
,
gemm1_k_block_outer_index
,
num_gemm1_k_block_outer_loop
);
ph
,
gemm1_k_block_outer_index
,
num_gemm1_k_block_outer_loop
);
}
}
// TODO: may convert to log domain
// TODO: may convert to log domain
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment