Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
f196d88b
Commit
f196d88b
authored
Jan 14, 2023
by
guangzlu
Browse files
added dropout into batched_gemm_softmax_gemm
parent
9fe6407e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
87 additions
and
33 deletions
+87
-33
example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_train.inc
...emm/run_batched_gemm_scale_softmax_gemm_permute_train.inc
+2
-1
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp
...n/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp
+3
-1
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp
..._batched_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp
+82
-31
No files found.
example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_train.inc
View file @
f196d88b
...
@@ -177,7 +177,8 @@ int run(int argc, char* argv[])
...
@@ -177,7 +177,8 @@ int run(int argc, char* argv[])
b0_element_op
,
b0_element_op
,
acc0_element_op
,
acc0_element_op
,
b1_element_op
,
b1_element_op
,
c_element_op
);
c_element_op
,
0
);
// dropout ratio
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
{
{
...
...
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp
View file @
f196d88b
...
@@ -115,7 +115,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermuteTrain : public BaseOperator
...
@@ -115,7 +115,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermuteTrain : public BaseOperator
B0ElementwiseOperation
b0_element_op
,
B0ElementwiseOperation
b0_element_op
,
Acc0ElementwiseOperation
acc0_element_op
,
Acc0ElementwiseOperation
acc0_element_op
,
B1ElementwiseOperation
b1_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
)
=
0
;
CElementwiseOperation
c_element_op
,
float
p_dropout
,
const
unsigned
long
long
seed
=
0
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
};
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp
100755 → 100644
View file @
f196d88b
...
@@ -7,6 +7,7 @@
...
@@ -7,6 +7,7 @@
#include <sstream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/philox_rand.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
...
@@ -39,7 +40,8 @@ template <typename GridwiseGemm,
...
@@ -39,7 +40,8 @@ template <typename GridwiseGemm,
typename
Block2CTileMap
,
typename
Block2CTileMap
,
typename
ComputeBasePtrOfStridedBatch
,
typename
ComputeBasePtrOfStridedBatch
,
typename
C0MatrixMask
,
typename
C0MatrixMask
,
bool
HasMainKBlockLoop
>
bool
HasMainKBlockLoop
,
bool
IsDropout
>
__global__
void
__global__
void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
...
@@ -64,7 +66,9 @@ __global__ void
...
@@ -64,7 +66,9 @@ __global__ void
const
Block2CTileMap
block_2_ctile_map
,
const
Block2CTileMap
block_2_ctile_map
,
const
index_t
batch_count
,
const
index_t
batch_count
,
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
,
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
,
const
C0MatrixMask
c0_matrix_mask
)
const
C0MatrixMask
c0_matrix_mask
,
const
ushort
p_dropout_in_16bits
,
const
unsigned
long
long
seed
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
...
@@ -83,24 +87,30 @@ __global__ void
...
@@ -83,24 +87,30 @@ __global__ void
const
long_index_t
lse_batch_offset
=
__builtin_amdgcn_readfirstlane
(
const
long_index_t
lse_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetLSEBasePtr
(
g_idx
)));
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetLSEBasePtr
(
g_idx
)));
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
+
a_batch_offset
,
const
index_t
block_id
=
get_block_1d_id
();
p_b_grid
+
b_batch_offset
,
ck
::
philox
ph
(
seed
,
0
,
block_id
*
4
);
p_b1_grid
+
b1_batch_offset
,
p_c_grid
+
c_batch_offset
,
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
>(
p_lse_grid
+
lse_batch_offset
,
p_a_grid
+
a_batch_offset
,
p_shared
,
p_b_grid
+
b_batch_offset
,
a_element_op
,
p_b1_grid
+
b1_batch_offset
,
b_element_op
,
p_c_grid
+
c_batch_offset
,
acc_element_op
,
p_lse_grid
+
lse_batch_offset
,
b1_element_op
,
p_shared
,
c_element_op
,
a_element_op
,
a_grid_desc_ak0_m_ak1
,
b_element_op
,
b_grid_desc_bk0_n_bk1
,
acc_element_op
,
b1_grid_desc_bk0_n_bk1
,
b1_element_op
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_element_op
,
lse_grid_desc_m
,
a_grid_desc_ak0_m_ak1
,
block_2_ctile_map
,
b_grid_desc_bk0_n_bk1
,
c0_matrix_mask
);
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
lse_grid_desc_m
,
block_2_ctile_map
,
c0_matrix_mask
,
p_dropout_in_16bits
,
ph
);
#else
#else
ignore
=
p_a_grid
;
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
ignore
=
p_b_grid
;
...
@@ -463,7 +473,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
...
@@ -463,7 +473,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
AccElementwiseOperation
acc_element_op
,
AccElementwiseOperation
acc_element_op
,
B1ElementwiseOperation
b1_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
)
CElementwiseOperation
c_element_op
,
float
p_dropout
,
unsigned
long
long
seed
)
:
p_a_grid_
{
p_a_grid
},
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_b_grid_
{
p_b_grid
},
p_b1_grid_
{
p_b1_grid
},
p_b1_grid_
{
p_b1_grid
},
...
@@ -512,7 +524,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
...
@@ -512,7 +524,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
b_grid_desc_g_n_k_
,
b_grid_desc_g_n_k_
,
b1_grid_desc_g_n_k_
,
b1_grid_desc_g_n_k_
,
c_grid_desc_g_m_n_
,
c_grid_desc_g_m_n_
,
type_convert
<
index_t
>
(
lse_grid_desc_m_
.
GetElementSpaceSize
())}
type_convert
<
index_t
>
(
lse_grid_desc_m_
.
GetElementSpaceSize
())},
seed_
(
seed
)
{
{
// TODO ANT: implement bias addition
// TODO ANT: implement bias addition
ignore
=
p_acc0_biases
;
ignore
=
p_acc0_biases
;
...
@@ -532,6 +545,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
...
@@ -532,6 +545,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n_
);
c_grid_desc_m_n_
);
}
}
is_dropout_
=
p_dropout
>
0.0
;
//
p_dropout_
=
1.
f
-
p_dropout
;
p_dropout_in_16bits_
=
uint16_t
(
std
::
floor
(
p_dropout_
*
65535.0
));
}
}
void
Print
()
const
void
Print
()
const
...
@@ -592,6 +609,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
...
@@ -592,6 +609,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
index_t
batch_count_
;
index_t
batch_count_
;
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch_
;
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch_
;
float
p_dropout_
;
ushort
p_dropout_in_16bits_
;
unsigned
long
long
seed_
;
bool
is_dropout_
;
};
};
// Invoker
// Invoker
...
@@ -615,7 +637,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
...
@@ -615,7 +637,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
float
ave_time
=
0
;
float
ave_time
=
0
;
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
)
{
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
,
auto
is_dropout_
)
{
const
auto
kernel
=
kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v2
<
const
auto
kernel
=
kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v2
<
GridwiseGemm
,
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
ADataType
,
// TODO: distiguish A/B datatype
...
@@ -634,7 +656,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
...
@@ -634,7 +656,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
ComputeBasePtrOfStridedBatch
,
ComputeBasePtrOfStridedBatch
,
C0MatrixMask
,
C0MatrixMask
,
has_main_k_block_loop_
>
;
has_main_k_block_loop_
,
is_dropout_
>
;
return
launch_and_time_kernel
(
stream_config
,
return
launch_and_time_kernel
(
stream_config
,
kernel
,
kernel
,
...
@@ -659,18 +682,38 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
...
@@ -659,18 +682,38 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
arg
.
block_2_ctile_map_
,
arg
.
block_2_ctile_map_
,
arg
.
batch_count_
,
arg
.
batch_count_
,
arg
.
compute_base_ptr_of_batch_
,
arg
.
compute_base_ptr_of_batch_
,
arg
.
c0_matrix_mask_
);
arg
.
c0_matrix_mask_
,
arg
.
p_dropout_in_16bits_
,
arg
.
seed_
);
};
};
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
// to concern Gemm0's loop
// to concern Gemm0's loop
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
if
(
arg
.
is_dropout_
)
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
false
>
{});
}
}
}
else
else
{
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
if
(
arg
.
is_dropout_
)
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
false
>
{});
}
}
}
return
ave_time
;
return
ave_time
;
...
@@ -793,7 +836,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
...
@@ -793,7 +836,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
AccElementwiseOperation
acc_element_op
,
AccElementwiseOperation
acc_element_op
,
B1ElementwiseOperation
b1_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
)
CElementwiseOperation
c_element_op
,
float
p_dropout
,
const
unsigned
long
long
seed
=
0
)
{
{
return
Argument
{
p_a
,
return
Argument
{
p_a
,
p_b
,
p_b
,
...
@@ -819,7 +864,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
...
@@ -819,7 +864,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
b_element_op
,
b_element_op
,
acc_element_op
,
acc_element_op
,
b1_element_op
,
b1_element_op
,
c_element_op
};
c_element_op
,
p_dropout
,
seed
};
}
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
@@ -853,7 +900,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
...
@@ -853,7 +900,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
AccElementwiseOperation
acc_element_op
,
AccElementwiseOperation
acc_element_op
,
B1ElementwiseOperation
b1_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
)
override
CElementwiseOperation
c_element_op
,
float
p_dropout
,
const
unsigned
long
long
seed
=
0
)
override
{
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
static_cast
<
const
BDataType
*>
(
p_b
),
...
@@ -879,7 +928,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
...
@@ -879,7 +928,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
b_element_op
,
b_element_op
,
acc_element_op
,
acc_element_op
,
b1_element_op
,
b1_element_op
,
c_element_op
);
c_element_op
,
p_dropout
,
seed
);
}
}
// polymorphic
// polymorphic
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment