Commit e0d6326b authored by letaoqin
Browse files

change interface

parent b60595f9
...@@ -71,8 +71,8 @@ using ShuffleDataType = F32; ...@@ -71,8 +71,8 @@ using ShuffleDataType = F32;
using LSEDataType = F32; using LSEDataType = F32;
using ZDataType = U16; // INT32 using ZDataType = U16; // INT32
using DDataType = F16; using DDataType = F16;
using Acc0BiasDataType = ck::Tuple<DDataType>; using Acc0BiasDataType = DDataType;
using Acc1BiasDataType = ck::Tuple<>; using Acc1BiasDataType = void;
static constexpr ck::index_t NumDimG = 2; static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1; static constexpr ck::index_t NumDimM = 1;
...@@ -529,9 +529,8 @@ int run(int argc, char* argv[]) ...@@ -529,9 +529,8 @@ int run(int argc, char* argv[])
static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
std::array<void*, 1>{ static_cast<DDataType*>(d_device_buf.GetDeviceBuffer()), // p_acc0_biases;
d_device_buf.GetDeviceBuffer()}, // std::array<void*, 1> p_acc0_biases; nullptr, // p_acc1_biases;
{}, // std::array<void*, 1> p_acc1_biases;
q_gs_ms_ks_lengths, q_gs_ms_ks_lengths,
q_gs_ms_ks_strides, q_gs_ms_ks_strides,
k_gs_ns_ks_lengths, k_gs_ns_ks_lengths,
...@@ -543,12 +542,10 @@ int run(int argc, char* argv[]) ...@@ -543,12 +542,10 @@ int run(int argc, char* argv[])
y_gs_ms_os_lengths, y_gs_ms_os_lengths,
y_gs_ms_os_strides, y_gs_ms_os_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
std::array<std::vector<ck::index_t>, 1>{ d_gs_ms_ns_lengths, // acc0_biases_gs_ms_ns_lengths
d_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths d_gs_ms_ns_strides, // acc0_biases_gs_ms_ns_strides
std::array<std::vector<ck::index_t>, 1>{ {}, // acc1_biases_gs_ms_os_lengths,
d_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides {}, // acc1_biases_gs_ms_os_strides,
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
QKVElementOp{}, QKVElementOp{},
QKVElementOp{}, QKVElementOp{},
Scale{alpha}, Scale{alpha},
...@@ -566,41 +563,41 @@ int run(int argc, char* argv[]) ...@@ -566,41 +563,41 @@ int run(int argc, char* argv[])
invoker.Run(argument, StreamConfig{nullptr, false}); invoker.Run(argument, StreamConfig{nullptr, false});
} }
// no need to output the z matrix // no need to output the z matrix
auto argument = gemm.MakeArgument( auto argument =
static_cast<InputDataType*>(q_device_buf.GetDeviceBuffer()), gemm.MakeArgument(static_cast<InputDataType*>(q_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(k_device_buf.GetDeviceBuffer()), static_cast<InputDataType*>(k_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(nullptr), // set to nullptr static_cast<ZDataType*>(nullptr), // set to nullptr
static_cast<InputDataType*>(v_device_buf.GetDeviceBuffer()), static_cast<InputDataType*>(v_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(y_device_buf.GetDeviceBuffer()), static_cast<InputDataType*>(y_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()), static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(ygrad_device_buf.GetDeviceBuffer()), static_cast<InputDataType*>(ygrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
std::array<void*, 1>{d_device_buf.GetDeviceBuffer()}, // std::array<void*, 1> p_acc0_biases; static_cast<DDataType*>(d_device_buf.GetDeviceBuffer()), // p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases; nullptr, // p_acc1_biases;
q_gs_ms_ks_lengths, q_gs_ms_ks_lengths,
q_gs_ms_ks_strides, q_gs_ms_ks_strides,
k_gs_ns_ks_lengths, k_gs_ns_ks_lengths,
k_gs_ns_ks_strides, k_gs_ns_ks_strides,
z_gs_ms_ns_lengths, z_gs_ms_ns_lengths,
z_gs_ms_ns_strides, z_gs_ms_ns_strides,
v_gs_os_ns_lengths, v_gs_os_ns_lengths,
v_gs_os_ns_strides, v_gs_os_ns_strides,
y_gs_ms_os_lengths, y_gs_ms_os_lengths,
y_gs_ms_os_strides, y_gs_ms_os_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths d_gs_ms_ns_lengths, // acc0_biases_gs_ms_ns_lengths
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides d_gs_ms_ns_strides, // acc0_biases_gs_ms_ns_strides
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths}, {}, // acc1_biases_gs_ms_os_lengths,
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides}, {}, // acc1_biases_gs_ms_os_strides,
QKVElementOp{}, QKVElementOp{},
QKVElementOp{}, QKVElementOp{},
Scale{alpha}, Scale{alpha},
QKVElementOp{}, QKVElementOp{},
YElementOp{}, YElementOp{},
p_drop, p_drop,
std::tuple<unsigned long long, unsigned long long>(seed, offset)); std::tuple<unsigned long long, unsigned long long>(seed, offset));
qgrad_device_buf.SetZero(); qgrad_device_buf.SetZero();
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
......
...@@ -291,11 +291,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -291,11 +291,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0"); "Number of dimension must be greater than 0");
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size(); using D0DataType = Acc0BiasDataType;
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size(); using D1DataType = Acc1BiasDataType;
// TODO: implement bias combination // TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented"); static_assert(std::is_void<D1DataType>::value, "Acc1 Bias addition is unimplemented");
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1; using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1;
...@@ -702,43 +702,42 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -702,43 +702,42 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument( Argument(const InputDataType* p_a_grid,
const InputDataType* p_a_grid, const InputDataType* p_b_grid,
const InputDataType* p_b_grid, ZDataType* p_z_grid,
ZDataType* p_z_grid, const InputDataType* p_b1_grid,
const InputDataType* p_b1_grid, const InputDataType* p_c_grid, // for dS
const InputDataType* p_c_grid, // for dS const LSEDataType* p_lse_grid,
const LSEDataType* p_lse_grid, const InputDataType* p_ygrad_grid,
const InputDataType* p_ygrad_grid, OutputDataType* p_qgrad_grid,
OutputDataType* p_qgrad_grid, OutputDataType* p_kgrad_grid,
OutputDataType* p_kgrad_grid, OutputDataType* p_vgrad_grid,
OutputDataType* p_vgrad_grid, const D0DataType* p_acc0_biases,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const D1DataType* p_acc1_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_strides, const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_lengths, const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& z_gs_ms_ns_strides, const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& lse_gs_ms_lengths,
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, const std::vector<ck::index_t>&
const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths const std::vector<ck::index_t>&
const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides AElementwiseOperation a_element_op,
AElementwiseOperation a_element_op, BElementwiseOperation b_element_op,
BElementwiseOperation b_element_op, AccElementwiseOperation acc_element_op,
AccElementwiseOperation acc_element_op, B1ElementwiseOperation b1_element_op,
B1ElementwiseOperation b1_element_op, CElementwiseOperation c_element_op,
CElementwiseOperation c_element_op, float p_drop,
float p_drop, std::tuple<unsigned long long, unsigned long long> seeds)
std::tuple<unsigned long long, unsigned long long> seeds)
: p_a_grid_{p_a_grid}, : p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
p_z_grid_{p_z_grid}, p_z_grid_{p_z_grid},
...@@ -1108,43 +1107,43 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1108,43 +1107,43 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg)); return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
} }
static auto MakeArgument( static auto
const InputDataType* p_a, MakeArgument(const InputDataType* p_a,
const InputDataType* p_b, const InputDataType* p_b,
ZDataType* p_z, ZDataType* p_z,
const InputDataType* p_b1, const InputDataType* p_b1,
const InputDataType* p_c, const InputDataType* p_c,
const LSEDataType* p_lse, const LSEDataType* p_lse,
const InputDataType* p_ygrad_grid, const InputDataType* p_ygrad_grid,
OutputDataType* p_qgrad_grid, OutputDataType* p_qgrad_grid,
OutputDataType* p_kgrad_grid, OutputDataType* p_kgrad_grid,
OutputDataType* p_vgrad_grid, OutputDataType* p_vgrad_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const D0DataType* p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const D1DataType* p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides, const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& z_gs_ms_ns_lengths, const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides, const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op, AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
float p_drop, float p_drop,
std::tuple<unsigned long long, unsigned long long> seeds) std::tuple<unsigned long long, unsigned long long> seeds)
{ {
return Argument{p_a, return Argument{p_a,
p_b, p_b,
...@@ -1197,8 +1196,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1197,8 +1196,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
void* p_qgrad_grid, void* p_qgrad_grid,
void* p_kgrad_grid, void* p_kgrad_grid,
void* p_vgrad_grid, void* p_vgrad_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const D0DataType* p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const D1DataType* p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
...@@ -1210,11 +1209,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1210,11 +1209,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -1224,40 +1223,41 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1224,40 +1223,41 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
float p_drop, float p_drop,
std::tuple<unsigned long long, unsigned long long> seeds) // override std::tuple<unsigned long long, unsigned long long> seeds) // override
{ {
return std::make_unique<Argument>(static_cast<const InputDataType*>(p_a), return std::make_unique<Argument>(
static_cast<const InputDataType*>(p_b), static_cast<const InputDataType*>(p_a),
static_cast<ZDataType*>(p_z), static_cast<const InputDataType*>(p_b),
static_cast<const InputDataType*>(p_b1), static_cast<ZDataType*>(p_z),
static_cast<const InputDataType*>(p_c), static_cast<const InputDataType*>(p_b1),
static_cast<const LSEDataType*>(p_lse), static_cast<const InputDataType*>(p_c),
static_cast<const InputDataType*>(p_ygrad_grid), static_cast<const LSEDataType*>(p_lse),
static_cast<OutputDataType*>(p_qgrad_grid), static_cast<const InputDataType*>(p_ygrad_grid),
static_cast<OutputDataType*>(p_kgrad_grid), static_cast<OutputDataType*>(p_qgrad_grid),
static_cast<OutputDataType*>(p_vgrad_grid), static_cast<OutputDataType*>(p_kgrad_grid),
p_acc0_biases, // cast in struct Argument static_cast<OutputDataType*>(p_vgrad_grid),
p_acc1_biases, // cast in struct Argument static_cast<const D0DataType*>(p_acc0_biases), // cast in struct Argument
a_gs_ms_ks_lengths, static_cast<const D1DataType*>(p_acc1_biases), // cast in struct Argument
a_gs_ms_ks_strides, a_gs_ms_ks_lengths,
b_gs_ns_ks_lengths, a_gs_ms_ks_strides,
b_gs_ns_ks_strides, b_gs_ns_ks_lengths,
z_gs_ms_ns_lengths, b_gs_ns_ks_strides,
z_gs_ms_ns_strides, z_gs_ms_ns_lengths,
b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths z_gs_ms_ns_strides,
b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
lse_gs_ms_lengths, c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
acc0_biases_gs_ms_ns_lengths, lse_gs_ms_lengths,
acc0_biases_gs_ms_ns_strides, acc0_biases_gs_ms_ns_lengths,
acc1_biases_gs_ms_gemm1ns_lengths, acc0_biases_gs_ms_ns_strides,
acc1_biases_gs_ms_gemm1ns_strides, acc1_biases_gs_ms_gemm1ns_lengths,
a_element_op, acc1_biases_gs_ms_gemm1ns_strides,
b_element_op, a_element_op,
acc_element_op, b_element_op,
b1_element_op, acc_element_op,
c_element_op, b1_element_op,
p_drop, c_element_op,
seeds); p_drop,
seeds);
} }
// polymorphic // polymorphic
......
...@@ -299,11 +299,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -299,11 +299,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0"); "Number of dimension must be greater than 0");
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size(); using D0DataType = Acc0BiasDataType;
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size(); using D1DataType = Acc1BiasDataType;
// TODO: implement bias combination // TODO: implement bias combination
static_assert(NumAcc1Bias == 0, "Bias addition is unimplemented"); static_assert(std::is_void<D1DataType>::value, "Acc1 Bias addition is unimplemented");
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2; using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2;
...@@ -718,43 +718,42 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -718,43 +718,42 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument( Argument(const InputDataType* p_a_grid,
const InputDataType* p_a_grid, const InputDataType* p_b_grid,
const InputDataType* p_b_grid, ZDataType* p_z_grid,
ZDataType* p_z_grid, const InputDataType* p_b1_grid,
const InputDataType* p_b1_grid, const InputDataType* p_c_grid, // for dS
const InputDataType* p_c_grid, // for dS const LSEDataType* p_lse_grid,
const LSEDataType* p_lse_grid, const InputDataType* p_ygrad_grid,
const InputDataType* p_ygrad_grid, OutputDataType* p_qgrad_grid,
OutputDataType* p_qgrad_grid, OutputDataType* p_kgrad_grid,
OutputDataType* p_kgrad_grid, OutputDataType* p_vgrad_grid,
OutputDataType* p_vgrad_grid, const D0DataType* p_acc0_biases,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const D1DataType* p_acc1_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_strides, const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_lengths, const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& z_gs_ms_ns_strides, const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& lse_gs_ms_lengths,
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, const std::vector<ck::index_t>&
const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths const std::vector<ck::index_t>&
const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides AElementwiseOperation a_element_op,
AElementwiseOperation a_element_op, BElementwiseOperation b_element_op,
BElementwiseOperation b_element_op, AccElementwiseOperation acc_element_op,
AccElementwiseOperation acc_element_op, B1ElementwiseOperation b1_element_op,
B1ElementwiseOperation b1_element_op, CElementwiseOperation c_element_op,
CElementwiseOperation c_element_op, float p_drop,
float p_drop, std::tuple<unsigned long long, unsigned long long> seeds)
std::tuple<unsigned long long, unsigned long long> seeds)
: p_a_grid_{p_a_grid}, : p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
p_z_grid_{p_z_grid}, p_z_grid_{p_z_grid},
...@@ -1143,43 +1142,43 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1143,43 +1142,43 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg)); return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
} }
static auto MakeArgument( static auto
const InputDataType* p_a, MakeArgument(const InputDataType* p_a,
const InputDataType* p_b, const InputDataType* p_b,
ZDataType* p_z, ZDataType* p_z,
const InputDataType* p_b1, const InputDataType* p_b1,
const InputDataType* p_c, const InputDataType* p_c,
const LSEDataType* p_lse, const LSEDataType* p_lse,
const InputDataType* p_ygrad_grid, const InputDataType* p_ygrad_grid,
OutputDataType* p_qgrad_grid, OutputDataType* p_qgrad_grid,
OutputDataType* p_kgrad_grid, OutputDataType* p_kgrad_grid,
OutputDataType* p_vgrad_grid, OutputDataType* p_vgrad_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const D0DataType* p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const D1DataType* p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides, const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& z_gs_ms_ns_lengths, const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides, const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op, AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
float p_drop, float p_drop,
std::tuple<unsigned long long, unsigned long long> seeds) std::tuple<unsigned long long, unsigned long long> seeds)
{ {
return Argument{p_a, return Argument{p_a,
p_b, p_b,
...@@ -1232,8 +1231,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1232,8 +1231,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
void* p_qgrad_grid, void* p_qgrad_grid,
void* p_kgrad_grid, void* p_kgrad_grid,
void* p_vgrad_grid, void* p_vgrad_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const void* p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const void* p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
...@@ -1245,11 +1244,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1245,11 +1244,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -1259,40 +1258,41 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1259,40 +1258,41 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
float p_drop, float p_drop,
std::tuple<unsigned long long, unsigned long long> seeds) // override std::tuple<unsigned long long, unsigned long long> seeds) // override
{ {
return std::make_unique<Argument>(static_cast<const InputDataType*>(p_a), return std::make_unique<Argument>(
static_cast<const InputDataType*>(p_b), static_cast<const InputDataType*>(p_a),
static_cast<ZDataType*>(p_z), static_cast<const InputDataType*>(p_b),
static_cast<const InputDataType*>(p_b1), static_cast<ZDataType*>(p_z),
static_cast<const InputDataType*>(p_c), static_cast<const InputDataType*>(p_b1),
static_cast<const LSEDataType*>(p_lse), static_cast<const InputDataType*>(p_c),
static_cast<const InputDataType*>(p_ygrad_grid), static_cast<const LSEDataType*>(p_lse),
static_cast<OutputDataType*>(p_qgrad_grid), static_cast<const InputDataType*>(p_ygrad_grid),
static_cast<OutputDataType*>(p_kgrad_grid), static_cast<OutputDataType*>(p_qgrad_grid),
static_cast<OutputDataType*>(p_vgrad_grid), static_cast<OutputDataType*>(p_kgrad_grid),
p_acc0_biases, // cast in struct Argument static_cast<OutputDataType*>(p_vgrad_grid),
p_acc1_biases, // cast in struct Argument static_cast<const D0DataType*>(p_acc0_biases), // cast in struct Argument
a_gs_ms_ks_lengths, static_cast<const D1DataType*>(p_acc1_biases), // cast in struct Argument
a_gs_ms_ks_strides, a_gs_ms_ks_lengths,
b_gs_ns_ks_lengths, a_gs_ms_ks_strides,
b_gs_ns_ks_strides, b_gs_ns_ks_lengths,
z_gs_ms_ns_lengths, b_gs_ns_ks_strides,
z_gs_ms_ns_strides, z_gs_ms_ns_lengths,
b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths z_gs_ms_ns_strides,
b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
lse_gs_ms_lengths, c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
acc0_biases_gs_ms_ns_lengths, lse_gs_ms_lengths,
acc0_biases_gs_ms_ns_strides, acc0_biases_gs_ms_ns_lengths,
acc1_biases_gs_ms_gemm1ns_lengths, acc0_biases_gs_ms_ns_strides,
acc1_biases_gs_ms_gemm1ns_strides, acc1_biases_gs_ms_gemm1ns_lengths,
a_element_op, acc1_biases_gs_ms_gemm1ns_strides,
b_element_op, a_element_op,
acc_element_op, b_element_op,
b1_element_op, acc_element_op,
c_element_op, b1_element_op,
p_drop, c_element_op,
seeds); p_drop,
seeds);
} }
// polymorphic // polymorphic
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment