Commit 6cc7d0de authored by danyao12's avatar danyao12
Browse files

rename device ops

parent 38f48480
......@@ -40,7 +40,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
#endif
kernel_grouped_multihead_attention_backward_xdl_cshuffle_v1(
kernel_grouped_multihead_attention_backward_kloop_xdl_cshuffle_v1(
const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
const index_t group_count,
const AElementwiseOperation a_element_op,
......@@ -254,7 +254,7 @@ template <index_t NumDimG,
MaskingSpecialization MaskingSpec,
bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
: public BaseOperator // TODO inherit atten bwd op once API stablizes
{
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
......@@ -266,7 +266,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
// TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1;
using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1;
struct ProblemDesc
{
std::vector<index_t> a_gs_ms_ks_lengths;
......@@ -956,16 +956,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_grouped_multihead_attention_backward_xdl_cshuffle_v1<
GridwiseGemm,
GroupKernelArg,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
has_main_k_block_loop_,
Deterministic>;
const auto kernel =
kernel_grouped_multihead_attention_backward_kloop_xdl_cshuffle_v1<
GridwiseGemm,
GroupKernelArg,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
has_main_k_block_loop_,
Deterministic>;
return launch_and_time_kernel(
stream_config,
......@@ -1209,7 +1210,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
auto str = std::stringstream();
// clang-format off
str << "DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1"
str << "DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
......@@ -40,7 +40,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
#endif
kernel_grouped_multihead_attention_backward_xdl_cshuffle_v2(
kernel_grouped_multihead_attention_backward_kloop_xdl_cshuffle_v2(
const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
const index_t group_count,
const AElementwiseOperation a_element_op,
......@@ -254,7 +254,7 @@ template <index_t NumDimG,
MaskingSpecialization MaskingSpec,
bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
: public BaseOperator // TODO inherit atten bwd op once API stablizes
{
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
......@@ -266,7 +266,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
// TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2;
using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2;
struct ProblemDesc
{
std::vector<index_t> a_gs_ms_ks_lengths;
......@@ -948,16 +948,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_grouped_multihead_attention_backward_xdl_cshuffle_v2<
GridwiseGemm,
GroupKernelArg,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
has_main_k_block_loop_,
Deterministic>;
const auto kernel =
kernel_grouped_multihead_attention_backward_kloop_xdl_cshuffle_v2<
GridwiseGemm,
GroupKernelArg,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
has_main_k_block_loop_,
Deterministic>;
return launch_and_time_kernel(
stream_config,
......@@ -1200,7 +1201,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
auto str = std::stringstream();
// clang-format off
str << "DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2"
str << "DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
......@@ -40,7 +40,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
#endif
kernel_grouped_multihead_attention_backward_xdl_cshuffle_v1(
kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_v1(
const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
const index_t group_count,
const AElementwiseOperation a_element_op,
......@@ -251,7 +251,7 @@ template <index_t NumDimG,
MaskingSpecialization MaskingSpec,
bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
: public BaseOperator // TODO inherit atten bwd op once API stablizes
{
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
......@@ -263,7 +263,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
// TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1;
using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1;
struct ProblemDesc
{
std::vector<index_t> a_gs_ms_ks_lengths;
......@@ -961,16 +961,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_grouped_multihead_attention_backward_xdl_cshuffle_v1<
GridwiseGemm,
GroupKernelArg,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
has_main_k_block_loop_,
Deterministic>;
const auto kernel =
kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_v1<
GridwiseGemm,
GroupKernelArg,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
has_main_k_block_loop_,
Deterministic>;
return launch_and_time_kernel(
stream_config,
......@@ -1207,7 +1208,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
auto str = std::stringstream();
// clang-format off
str << "DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1"
str << "DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
......@@ -40,7 +40,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
#endif
kernel_grouped_multihead_attention_backward_xdl_cshuffle_v2(
kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_v2(
const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
const index_t group_count,
const AElementwiseOperation a_element_op,
......@@ -258,7 +258,7 @@ template <index_t NumDimG,
MaskingSpecialization MaskingSpec,
bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
: public BaseOperator // TODO inherit atten bwd op once API stablizes
{
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
......@@ -270,7 +270,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
// TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2;
using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2;
struct ProblemDesc
{
std::vector<index_t> a_gs_ms_ks_lengths;
......@@ -968,16 +968,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_grouped_multihead_attention_backward_xdl_cshuffle_v2<
GridwiseGemm,
GroupKernelArg,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
has_main_k_block_loop_,
Deterministic>;
const auto kernel =
kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_v2<
GridwiseGemm,
GroupKernelArg,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
has_main_k_block_loop_,
Deterministic>;
return launch_and_time_kernel(
stream_config,
......@@ -1219,7 +1220,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
auto str = std::stringstream();
// clang-format off
str << "DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2"
str << "DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
......@@ -39,7 +39,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2(
kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1(
const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
const index_t group_count,
const AElementwiseOperation a_element_op,
......@@ -250,7 +250,7 @@ template <index_t NumDimG,
MaskingSpecialization MaskingSpec,
bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
: public DeviceGroupedMultiheadAttentionForward<NumDimG,
NumDimM,
NumDimN,
......@@ -290,7 +290,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
static constexpr index_t NumDimGemm1K = NumDimN;
#endif
using DeviceOp = DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle;
using DeviceOp = DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1;
using ProblemDesc = typename DeviceGroupedMultiheadAttentionForward<NumDimG,
NumDimM,
NumDimN,
......@@ -813,7 +813,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
auto launch_kernel =
[&](auto has_main_k_block_loop_, auto is_dropout_, auto is_lse_storing_) {
const auto kernel =
kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2<GridwiseGemm,
kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1<GridwiseGemm,
GemmAccDataType,
GroupKernelArg,
AElementwiseOperation,
......@@ -1123,7 +1123,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
auto str = std::stringstream();
// clang-format off
str << "DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle"
str << "DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
......@@ -256,7 +256,7 @@ template <index_t NumDimG,
MaskingSpecialization MaskingSpec,
bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
: public DeviceGroupedMultiheadAttentionForward<NumDimG,
NumDimM,
NumDimN,
......@@ -296,7 +296,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
static constexpr index_t NumDimGemm1K = NumDimN;
#endif
using DeviceOp = DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle;
using DeviceOp = DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2;
using ProblemDesc = typename DeviceGroupedMultiheadAttentionForward<NumDimG,
NumDimM,
NumDimN,
......@@ -1145,7 +1145,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
auto str = std::stringstream();
// clang-format off
str << "DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle"
str << "DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment