Commit 52b80192 authored by guangzlu's avatar guangzlu
Browse files

added inherit relationship for bwd qloop

parent 172835a5
......@@ -122,6 +122,257 @@ struct DeviceBatchedMultiheadAttentionForward : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
index_t NumDimO,
typename InputDataType,
typename OutputDataType,
typename ZDataType,
typename LSEDataType,
typename Acc0BiasDataType,
typename Acc1BiasDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
MaskingSpecialization MaskingSpec>
struct DeviceBatchedMultiheadAttentionBackwardQloopV1 : public BaseOperator
{
using D0DataType = Acc0BiasDataType;
using D1DataType = Acc1BiasDataType;
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a,
const void* p_b,
void* p_z,
const void* p_b1,
const void* p_c,
const void* p_lse,
const void* p_ygrad_grid,
void* p_qgrad_grid,
void* p_kgrad_grid,
void* p_vgrad_grid,
const D0DataType* p_acc0_bias,
const D1DataType* p_acc1_bias,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
const std::vector<ck::index_t>&
acc1_bias_gs_ms_gemm1ns_lengths, // acc1_bias_gs_ms_os_lengths
const std::vector<ck::index_t>&
acc1_bias_gs_ms_gemm1ns_strides, // acc1_bias_gs_ms_os_strides
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_drop,
std::tuple<unsigned long long, unsigned long long> seeds) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
index_t NumDimO,
typename InputDataType,
typename OutputDataType,
typename ZDataType,
typename LSEDataType,
typename Acc0BiasDataType,
typename Acc1BiasDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
MaskingSpecialization MaskingSpec>
struct DeviceBatchedMultiheadAttentionBackwardQloopV2 : public BaseOperator
{
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a,
const void* p_b,
void* p_z,
const void* p_b1,
const void* p_c,
const void* p_lse,
const void* p_ygrad_grid,
void* p_qgrad_grid,
void* p_kgrad_grid,
void* p_vgrad_grid,
const void* p_acc0_bias,
const void* p_acc1_bias,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
const std::vector<ck::index_t>&
acc1_bias_gs_ms_gemm1ns_lengths, // acc1_bias_gs_ms_os_lengths
const std::vector<ck::index_t>&
acc1_bias_gs_ms_gemm1ns_strides, // acc1_bias_gs_ms_os_strides
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_drop,
std::tuple<unsigned long long, unsigned long long> seeds) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
index_t NumDimO,
typename InputDataType,
typename OutputDataType,
typename ZDataType,
typename LSEDataType,
typename DDataType,
typename Acc0BiasDataType,
typename Acc1BiasDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
MaskingSpecialization MaskingSpec>
struct DeviceBatchedMultiheadAttentionBackwardQloopLightV1 : public BaseOperator
{
using D0DataType = Acc0BiasDataType;
using D1DataType = Acc1BiasDataType;
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a,
const void* p_b,
void* p_z,
const void* p_b1,
const void* p_c,
const void* p_lse,
void* p_d_grid,
const void* p_ygrad_grid,
void* p_qgrad_grid,
void* p_kgrad_grid,
void* p_vgrad_grid,
const D0DataType* p_acc0_bias,
const D1DataType* p_acc1_bias,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
const std::vector<ck::index_t>&
acc1_bias_gs_ms_gemm1ns_lengths, // acc1_bias_gs_ms_os_lengths
const std::vector<ck::index_t>&
acc1_bias_gs_ms_gemm1ns_strides, // acc1_bias_gs_ms_os_strides
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_drop,
std::tuple<unsigned long long, unsigned long long> seeds) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
index_t NumDimO,
typename InputDataType,
typename OutputDataType,
typename ZDataType,
typename LSEDataType,
typename DDataType,
typename Acc0BiasDataType,
typename Acc1BiasDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
MaskingSpecialization MaskingSpec>
struct DeviceBatchedMultiheadAttentionBackwardQloopLightV2 : public BaseOperator
{
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a,
const void* p_b,
void* p_z,
const void* p_b1,
const void* p_c,
const void* p_lse,
void* p_d_grid,
const void* p_ygrad_grid,
void* p_qgrad_grid,
void* p_kgrad_grid,
void* p_vgrad_grid,
const void* p_acc0_bias,
const void* p_acc1_bias,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
const std::vector<ck::index_t>&
acc1_bias_gs_ms_gemm1ns_lengths, // acc1_bias_gs_ms_os_lengths
const std::vector<ck::index_t>&
acc1_bias_gs_ms_gemm1ns_strides, // acc1_bias_gs_ms_os_strides
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_drop,
std::tuple<unsigned long long, unsigned long long> seeds) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -11,7 +11,7 @@
#include "ck/utility/philox_rand.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
......@@ -361,7 +361,24 @@ template <index_t NumDimG,
bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
: public BaseOperator // TODO inherit atten bwd op once API stablizes
: public DeviceBatchedMultiheadAttentionBackwardQloopLightV1<NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
InputDataType,
OutputDataType,
ZDataType,
LSEDataType,
DDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
MaskingSpec>
{
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0");
......@@ -1445,7 +1462,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_drop,
std::tuple<unsigned long long, unsigned long long> seeds) // override
std::tuple<unsigned long long, unsigned long long> seeds) override
{
return std::make_unique<Argument>(
static_cast<const InputDataType*>(p_a),
......@@ -1486,7 +1503,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() // override
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
......
......@@ -11,7 +11,7 @@
#include "ck/utility/philox_rand.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
......@@ -369,7 +369,24 @@ template <index_t NumDimG,
bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
: public BaseOperator // TODO inherit atten bwd op once API stablizes
: public DeviceBatchedMultiheadAttentionBackwardQloopLightV2<NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
InputDataType,
OutputDataType,
ZDataType,
LSEDataType,
DDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
MaskingSpec>
{
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0");
......@@ -1477,7 +1494,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_drop,
std::tuple<unsigned long long, unsigned long long> seeds) // override
std::tuple<unsigned long long, unsigned long long> seeds) override
{
return std::make_unique<Argument>(
static_cast<const InputDataType*>(p_a),
......@@ -1518,7 +1535,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() // override
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
......
......@@ -11,7 +11,7 @@
#include "ck/utility/philox_rand.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
......@@ -306,7 +306,23 @@ template <index_t NumDimG,
bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
: public BaseOperator // TODO inherit atten bwd op once API stablizes
: public DeviceBatchedMultiheadAttentionBackwardQloopV1<NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
InputDataType,
OutputDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
MaskingSpec>
{
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0");
......@@ -1301,7 +1317,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_drop,
std::tuple<unsigned long long, unsigned long long> seeds) // override
std::tuple<unsigned long long, unsigned long long> seeds) override
{
return std::make_unique<Argument>(
static_cast<const InputDataType*>(p_a),
......@@ -1341,7 +1357,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() // override
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
......
......@@ -11,7 +11,7 @@
#include "ck/utility/philox_rand.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
......@@ -314,7 +314,23 @@ template <index_t NumDimG,
bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
: public BaseOperator // TODO inherit atten bwd op once API stablizes
: public DeviceBatchedMultiheadAttentionBackwardQloopV2<NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
InputDataType,
OutputDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
MaskingSpec>
{
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0");
......@@ -1334,7 +1350,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_drop,
std::tuple<unsigned long long, unsigned long long> seeds) // override
std::tuple<unsigned long long, unsigned long long> seeds) override
{
return std::make_unique<Argument>(
static_cast<const InputDataType*>(p_a),
......@@ -1374,7 +1390,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() // override
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment