gaoqiong / composable_kernel

Commit 9fe6407e, authored Jan 14, 2023 by guangzlu

    added dropout into fla training

Parent: 393470f5
Showing 4 changed files with 124 additions and 48 deletions:
- example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute_train.inc (+23 -21)
- include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp (+3 -1)
- include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp (+78 -24)
- include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v2.hpp (+20 -2)
example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute_train.inc (file mode 100755 → 100644)
```diff
@@ -140,7 +140,8 @@ int run(int argc, char* argv[])
                      << "b0_gs_ns_ks[" << i << "]: " << b0_gs_ns_ks.mDesc << ", "
                      << "b1_gs_os_ns[" << i << "]: " << b1_gs_os_ns.mDesc << ", "
                      << "c_gs_ms_os[" << i << "]: " << c_gs_ms_os_device_result.mDesc << ", "
-                     << "lse_gs_ms_os[" << i << "]: " << lse_gs_ms_device_result.mDesc << std::endl;
+                     << "lse_gs_ms_os[" << i << "]: " << lse_gs_ms_device_result.mDesc
+                     << std::endl;
        }

        switch(init_method)
```
```diff
@@ -173,7 +174,6 @@ int run(int argc, char* argv[])
        c_tensors.push_back(c_gs_ms_os_device_result);
        lse_tensors.push_back(lse_gs_ms_device_result);
-
        a_tensors_device.emplace_back(std::make_unique<DeviceMem>(
            sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize()));

        b0_tensors_device.emplace_back(std::make_unique<DeviceMem>(
```
```diff
@@ -217,7 +217,8 @@ int run(int argc, char* argv[])
                                         b0_element_op,
                                         acc0_element_op,
                                         b1_element_op,
-                                        c_element_op);
+                                        c_element_op,
+                                        0); // dropout ratio

        // specify workspace for problem_desc
        DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument));
```
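Note the new trailing argument: the example passes a literal `0`, i.e. dropout disabled, so the pre-existing test path is unchanged. A minimal sketch of how that value is interpreted downstream; the predicate mirrors `Argument::is_dropout_` in the device-op diff below, and the values are illustrative, not from the commit:

```cpp
#include <cassert>

int main()
{
    float p_dropout = 0.0f;            // what this example passes
    bool is_dropout = p_dropout > 0.0; // mirrors is_dropout_ = p_dropout > 0.0
    assert(!is_dropout);               // literal 0 -> the non-dropout kernel is selected

    p_dropout  = 0.2f;                 // a real training run would pass a positive ratio
    is_dropout = p_dropout > 0.0;
    assert(is_dropout);                // positive ratio -> IsDropout == true kernel
    return 0;
}
```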
```diff
@@ -307,7 +308,8 @@ int run(int argc, char* argv[])
        // softmax
        auto ref_softmax          = ReferenceSoftmaxInstance{};
        auto ref_softmax_invoker  = ref_softmax.MakeInvoker();
-       auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2}, &lse_g_m_host_result);
+       auto ref_softmax_argument =
+           ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2}, &lse_g_m_host_result);

        ref_softmax_invoker.Run(ref_softmax_argument);
```
```diff
@@ -354,9 +356,9 @@ int run(int argc, char* argv[])
            }

-           // bool pass_ =
-           //     ck::utils::check_err(c_gs_ms_os_device_result.mData, c_gs_ms_os_host_result.mData);
-           bool pass_ = ck::utils::check_err(c_gs_ms_os_device_result.mData,
+           // ck::utils::check_err(c_gs_ms_os_device_result.mData,
+           // c_gs_ms_os_host_result.mData);
+           bool pass_ = ck::utils::check_err(c_gs_ms_os_device_result.mData,
                                              c_gs_ms_os_host_result.mData,
                                              "Error: Incorrect results c!",
                                              rtol,
```
include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp
```diff
@@ -127,7 +127,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermuteTrain : public BaseOperator
                        B0ElementwiseOperation b0_element_op,
                        Acc0ElementwiseOperation acc0_element_op,
                        B1ElementwiseOperation b1_element_op,
-                       CElementwiseOperation c_element_op) = 0;
+                       CElementwiseOperation c_element_op,
+                       float p_dropout,
+                       const unsigned long long seed = 0) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
```
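The pure virtual `MakeArgumentPointer` gains `float p_dropout` and a defaulted `seed`; the concrete override in the xdl_cshuffle header (further down) repeats `seed = 0`. That repetition matters because C++ binds default arguments to the static type of the call, not the dynamic one. A minimal sketch with generic names (not CK's):

```cpp
#include <iostream>

struct Base
{
    virtual void make(unsigned long long seed = 0) = 0;
    virtual ~Base() = default;
};

struct Derived : Base
{
    // Omitting "= 0" here would make d.make() ill-formed on a Derived
    // object, even though b->make() through a Base* would still compile.
    void make(unsigned long long seed = 0) override
    {
        std::cout << "seed = " << seed << '\n';
    }
};

int main()
{
    Derived d;
    Base* b = &d;
    b->make();  // default bound via Base: seed = 0
    d.make(42); // explicit seed
    return 0;
}
```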
include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp (file mode 100755 → 100644)
```diff
@@ -7,6 +7,7 @@
 #include <sstream>

 #include "ck/utility/common_header.hpp"
+#include "ck/utility/philox_rand.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
```
```diff
@@ -29,7 +30,8 @@ template <typename GridwiseGemm,
          typename AccElementwiseOperation,
          typename B1ElementwiseOperation,
          typename CElementwiseOperation,
-         bool HasMainKBlockLoop>
+         bool HasMainKBlockLoop,
+         bool IsDropout>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
```
```diff
@@ -41,13 +43,17 @@ __global__ void
            const BElementwiseOperation b_element_op,
            const AccElementwiseOperation acc_element_op,
            const B1ElementwiseOperation b1_element_op,
-           const CElementwiseOperation c_element_op)
+           const CElementwiseOperation c_element_op,
+           const ushort p_dropout_in_16bits,
+           const unsigned long long seed)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

+   const index_t block_id = get_block_1d_id();
+   ck::philox ph(seed, 0, block_id);
+
    const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>(
        cast_pointer_to_generic_address_space(group_kernel_args));
```
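Each workgroup builds its own `ck::philox` engine from the shared `seed` and its `block_id`, so no random state needs to be stored or synchronized between blocks: a counter-based generator reproduces the same stream for the same (seed, block) pair on every launch. A host-side miniature of that idea, assuming nothing about the real Philox in `ck/utility/philox_rand.hpp` beyond its counter-based design (the splitmix64-style finalizer here is a stand-in mixer, not the Philox algorithm):

```cpp
#include <cstdint>

// splitmix64-style finalizer: a stand-in for the Philox rounds.
static uint64_t mix(uint64_t x)
{
    x += 0x9E3779B97F4A7C15ULL;
    x = (x ^ (x >> 30)) * 0xBF58476D1CE4E5B9ULL;
    x = (x ^ (x >> 27)) * 0x94D049BB133111EBULL;
    return x ^ (x >> 31);
}

// Counter-based RNG: the output is a pure function of (seed, stream, counter),
// so every block can regenerate its own stream independently and reproducibly,
// which is what per-block construction like ck::philox ph(seed, 0, block_id)
// relies on.
struct CounterRng
{
    uint64_t seed;
    uint64_t stream; // e.g. the block id
    uint64_t counter = 0;

    uint64_t next() { return mix(seed ^ mix(stream ^ mix(counter++))); }
};
```

Two blocks sharing a seed but differing in `stream` draw from distinct, independently reproducible sequences.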
```diff
@@ -82,10 +88,10 @@ __global__ void
            arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(g_idx)));
        const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
            static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
-       const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(
-           static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetLSEBasePtr(g_idx)));
+       const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
+           arg_ptr[group_id].compute_base_ptr_of_batch_.GetLSEBasePtr(g_idx)));

-       GridwiseGemm::template Run<HasMainKBlockLoop>(
+       GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
            arg_ptr[group_id].p_a_grid_ + a_batch_offset,
            arg_ptr[group_id].p_b_grid_ + b_batch_offset,
            arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
```
```diff
@@ -103,7 +109,9 @@ __global__ void
            arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
            arg_ptr[group_id].lse_grid_desc_m_,
            arg_ptr[group_id].block_2_ctile_map_,
-           arg_ptr[group_id].c0_matrix_mask_);
+           arg_ptr[group_id].c0_matrix_mask_,
+           p_dropout_in_16bits,
+           ph);
#else
    ignore = group_kernel_args;
    ignore = group_count;
```
```diff
@@ -506,12 +514,15 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
                 BElementwiseOperation b_element_op,
                 AccElementwiseOperation acc_element_op,
                 B1ElementwiseOperation b1_element_op,
-                CElementwiseOperation c_element_op)
+                CElementwiseOperation c_element_op,
+                float p_dropout,
+                unsigned long long seed)
            : a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              acc_element_op_{acc_element_op},
              b1_element_op_{b1_element_op},
-             c_element_op_{c_element_op}
+             c_element_op_{c_element_op},
+             seed_(seed)
        {
            // TODO ANT: implement bias addition
            group_count_ = problem_desc_vec.size();
```
```diff
@@ -547,7 +558,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
                problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
            const auto c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N(
                problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides);
-           const auto lse_grid_desc_m = DeviceOp::MakeLSEGridDescriptor_M(problem_desc.lse_gs_ms_lengths[NumDimG]);
+           const auto lse_grid_desc_m =
+               DeviceOp::MakeLSEGridDescriptor_M(problem_desc.lse_gs_ms_lengths[NumDimG]);
            const auto a_grid_desc_g_m_k = Transform::MakeAGridDescriptor_G_M_K(
                problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
```
```diff
@@ -571,7 +583,11 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
            // batch stride
-           const auto compute_base_ptr_of_batch = ComputeBasePtrOfStridedBatch(
-               a_grid_desc_g_m_k, b_grid_desc_g_n_k, b1_grid_desc_g_n_k, c_grid_desc_g_m_n,
-               type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize()));
+           const auto compute_base_ptr_of_batch =
+               ComputeBasePtrOfStridedBatch(a_grid_desc_g_m_k,
+                                            b_grid_desc_g_n_k,
+                                            b1_grid_desc_g_n_k,
+                                            c_grid_desc_g_m_n,
+                                            type_convert<index_t>(
+                                                lse_grid_desc_m.GetElementSpaceSize()));

            // C0 mask
            const auto c0_matrix_mask = C0MatrixMask(b_grid_desc_g_n_k.GetLength(I1));
```
```diff
@@ -622,6 +638,10 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
                     problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM + NumDimO - 1]},
                    c_grid_desc_m_n});
            }
+
+           is_dropout_          = p_dropout > 0.0; //
+           p_dropout_           = 1.f - p_dropout;
+           p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0));
        }

        std::vector<GroupKernelArg> group_kernel_args_;
```
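`p_dropout_` actually stores the keep probability `1 - p_dropout`, and `p_dropout_in_16bits_` quantizes it onto the 16-bit range so the device code can compare raw Philox output against an integer threshold instead of doing a float compare per element. A host-side model of the arithmetic; the keep-if-below comparison direction and the `1/keep_prob` rescale are the conventional inverted-dropout choices and live in `blockwise_dropout.hpp`, which this diff does not show:

```cpp
#include <cmath>
#include <cstdint>
#include <random>
#include <vector>

int main()
{
    const float p_dropout = 0.2f;
    const float keep_prob = 1.f - p_dropout;                         // p_dropout_
    const uint16_t threshold = uint16_t(std::floor(keep_prob * 65535.0)); // p_dropout_in_16bits_
    // keep_prob = 0.8 -> threshold = 52428, i.e. ~80% of 16-bit draws survive

    std::mt19937 gen(42);                                            // stand-in for Philox
    std::uniform_int_distribution<uint32_t> draw16(0, 65535);

    std::vector<float> acc = {0.1f, 0.5f, 0.9f, 0.3f};               // softmax outputs
    for(float& x : acc)
    {
        if(draw16(gen) <= threshold)
            x /= keep_prob; // survivor: rescale so the expectation is unchanged
        else
            x = 0.f;        // dropped
    }
    return 0;
}
```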
```diff
@@ -635,6 +655,11 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
        AccElementwiseOperation acc_element_op_;
        B1ElementwiseOperation b1_element_op_;
        CElementwiseOperation c_element_op_;
+
+       float p_dropout_;
+       ushort p_dropout_in_16bits_;
+       unsigned long long seed_;
+       bool is_dropout_;
    };

    // Invoker
```
```diff
@@ -667,7 +692,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
        float ave_time = 0;

-       auto launch_kernel = [&](auto has_main_k_block_loop_) {
+       auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
            const auto kernel = kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2<
                GridwiseGemm,
                GroupKernelArg,
```
```diff
@@ -676,7 +701,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
                AccElementwiseOperation,
                B1ElementwiseOperation,
                CElementwiseOperation,
-               has_main_k_block_loop_>;
+               has_main_k_block_loop_,
+               is_dropout_>;

            return launch_and_time_kernel(stream_config,
```
```diff
@@ -690,18 +716,38 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
                arg.b_element_op_,
                arg.acc_element_op_,
                arg.b1_element_op_,
-               arg.c_element_op_);
+               arg.c_element_op_,
+               arg.p_dropout_in_16bits_,
+               arg.seed_);
        };

        // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
        // to concern Gemm0's loop
        if(all_has_main_k_block_loop)
        {
-           ave_time = launch_kernel(integral_constant<bool, true>{});
+           if(arg.is_dropout_)
+           {
+               ave_time = launch_kernel(integral_constant<bool, true>{},
+                                        integral_constant<bool, true>{});
+           }
+           else
+           {
+               ave_time = launch_kernel(integral_constant<bool, true>{},
+                                        integral_constant<bool, false>{});
+           }
        }
        else if(!some_has_main_k_block_loop)
        {
-           ave_time = launch_kernel(integral_constant<bool, false>{});
+           if(arg.is_dropout_)
+           {
+               ave_time = launch_kernel(integral_constant<bool, false>{},
+                                        integral_constant<bool, true>{});
+           }
+           else
+           {
+               ave_time = launch_kernel(integral_constant<bool, false>{},
+                                        integral_constant<bool, false>{});
+           }
        }
        else
        {
```
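The runtime flags `has_main_k_block_loop` and `arg.is_dropout_` are lifted into template parameters through `integral_constant`, so the two-level `if/else` selects one of four fully specialized kernels and the dropout path costs nothing when `IsDropout` is false. The pattern in isolation, as plain host code with generic names:

```cpp
#include <iostream>
#include <type_traits>

template <bool HasMainKBlockLoop, bool IsDropout>
void kernel_stub() // stands in for the real __global__ kernel
{
    std::cout << HasMainKBlockLoop << ' ' << IsDropout << '\n';
}

int main()
{
    const bool has_main_loop = true; // runtime values
    const bool is_dropout    = false;

    // The generic lambda receives integral_constant objects; their ::value is
    // a compile-time constant, usable as a template argument.
    auto launch = [](auto has_main_loop_, auto is_dropout_) {
        kernel_stub<decltype(has_main_loop_)::value, decltype(is_dropout_)::value>();
    };

    using std::integral_constant;
    if(has_main_loop)
    {
        if(is_dropout)
            launch(integral_constant<bool, true>{}, integral_constant<bool, true>{});
        else
            launch(integral_constant<bool, true>{}, integral_constant<bool, false>{});
    }
    else
    {
        if(is_dropout)
            launch(integral_constant<bool, false>{}, integral_constant<bool, true>{});
        else
            launch(integral_constant<bool, false>{}, integral_constant<bool, false>{});
    }
    return 0;
}
```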
```diff
@@ -839,7 +885,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
                             BElementwiseOperation b_element_op,
                             AccElementwiseOperation acc_element_op,
                             B1ElementwiseOperation b1_element_op,
-                            CElementwiseOperation c_element_op)
+                            CElementwiseOperation c_element_op,
+                            float p_dropout,
+                            const unsigned long long seed = 0)
    {
        return Argument{p_a_vec,
                        p_b_vec,
```
```diff
@@ -853,7 +901,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
                        b_element_op,
                        acc_element_op,
                        b1_element_op,
-                       c_element_op};
+                       c_element_op,
+                       p_dropout,
+                       seed};
    }

    static auto MakeInvoker() { return Invoker{}; }
```
```diff
@@ -872,7 +922,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
        BElementwiseOperation b_element_op,
        AccElementwiseOperation acc_element_op,
        B1ElementwiseOperation b1_element_op,
-       CElementwiseOperation c_element_op) override
+       CElementwiseOperation c_element_op,
+       float p_dropout,
+       const unsigned long long seed = 0) override
    {
        return std::make_unique<Argument>(p_a_vec,
                                          p_b_vec,
```
```diff
@@ -886,7 +938,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
                                          b_element_op,
                                          acc_element_op,
                                          b1_element_op,
-                                         c_element_op);
+                                         c_element_op,
+                                         p_dropout,
+                                         seed);
    }

    // polymorphic
```
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v2.hpp (file mode 100755 → 100644)
```diff
@@ -4,6 +4,7 @@
 #pragma once

 #include "ck/utility/common_header.hpp"
+#include "ck/utility/philox_rand.hpp"
 #include "ck/tensor_description/multi_index_transform_helper.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
```
```diff
@@ -15,6 +16,7 @@
 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/tensor_operation/gpu/block/blockwise_softmax.hpp"
+#include "ck/tensor_operation/gpu/block/blockwise_dropout.hpp"

 namespace ck {
```
```diff
@@ -357,7 +359,10 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
    };

-   template <bool HasMainKBlockLoop, typename Block2CTileMap, typename C0MatrixMask>
+   template <bool HasMainKBlockLoop,
+             bool IsDropout,
+             typename Block2CTileMap,
+             typename C0MatrixMask>
    __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                               const FloatAB* __restrict__ p_b_grid,
                               const FloatAB* __restrict__ p_b1_grid,
```
```diff
@@ -376,7 +381,9 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
            c_grid_desc_mblock_mperblock_nblock_nperblock,
        const LSEGridDesc_M& lse_grid_desc_m,
        const Block2CTileMap& block_2_ctile_map,
-       const C0MatrixMask& c0_matrix_mask)
+       const C0MatrixMask& c0_matrix_mask,
+       const ushort p_dropout_in_16bits,
+       ck::philox ph)
    {
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
```
```diff
@@ -721,6 +728,8 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
            decltype(thread_cluster_desc_m_n),
            decltype(thread_slice_desc_m_n)>{};

+       auto blockwise_dropout = BlockwiseDropout<decltype(thread_slice_desc_m_n)>{};
+
        const index_t num_gemm1_k_block_outer_loop =
            b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock;
        constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock;
```
```diff
@@ -862,6 +871,15 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
            blockwise_softmax.Run(acc_thread_buf, workspace_buf);

+           if constexpr(IsDropout) // dropout
+           {
+               blockwise_dropout.ApplyDropout(acc_thread_buf,
+                                              p_dropout_in_16bits,
+                                              ph,
+                                              gemm1_k_block_outer_index,
+                                              num_gemm1_k_block_outer_loop);
+           }
+
            // TODO: may convert to log domain
            running_max_new = mathext::max(max, running_max);
            running_sum_new = mathext::exp(running_max - running_max_new) * running_sum +
```
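`ApplyDropout` runs immediately after `blockwise_softmax.Run`, inside the outer Gemm1-K loop, so the mask hits the attention probabilities before they are multiplied into V, which is the standard placement for attention dropout. An unfused host reference for one row, ignoring the kernel's tiling and online-softmax rescaling; the `1/keep_prob` scaling is assumed (conventional inverted dropout, not visible in this diff):

```cpp
#include <cmath>
#include <random>
#include <vector>

// One attention row: S -> P = softmax(S) -> P' = dropout(P) -> O = P' * V.
std::vector<float> attention_row_with_dropout(const std::vector<float>& scores,
                                              const std::vector<std::vector<float>>& v,
                                              float p_dropout,
                                              unsigned long long seed)
{
    // numerically stable softmax over the row
    float m = scores[0];
    for(float s : scores)
        m = std::max(m, s);
    std::vector<float> p(scores.size());
    float sum = 0.f;
    for(std::size_t i = 0; i < scores.size(); ++i)
    {
        p[i] = std::exp(scores[i] - m);
        sum += p[i];
    }
    for(float& x : p)
        x /= sum;

    // dropout on the probabilities (the role ApplyDropout plays on acc_thread_buf)
    const float keep_prob = 1.f - p_dropout;
    std::mt19937_64 gen(seed);
    std::bernoulli_distribution keep(keep_prob);
    for(float& x : p)
        x = keep(gen) ? x / keep_prob : 0.f;

    // second GEMM: O = P' * V
    std::vector<float> out(v[0].size(), 0.f);
    for(std::size_t i = 0; i < p.size(); ++i)
        for(std::size_t j = 0; j < out.size(); ++j)
            out[j] += p[i] * v[i][j];
    return out;
}
```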