Commit e0d6326b authored by letaoqin
Browse files

change interface

parent b60595f9
...@@ -71,8 +71,8 @@ using ShuffleDataType = F32; ...@@ -71,8 +71,8 @@ using ShuffleDataType = F32;
using LSEDataType = F32; using LSEDataType = F32;
using ZDataType = U16; // INT32 using ZDataType = U16; // INT32
using DDataType = F16; using DDataType = F16;
using Acc0BiasDataType = ck::Tuple<DDataType>; using Acc0BiasDataType = DDataType;
using Acc1BiasDataType = ck::Tuple<>; using Acc1BiasDataType = void;
static constexpr ck::index_t NumDimG = 2; static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1; static constexpr ck::index_t NumDimM = 1;
...@@ -529,9 +529,8 @@ int run(int argc, char* argv[]) ...@@ -529,9 +529,8 @@ int run(int argc, char* argv[])
static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
std::array<void*, 1>{ static_cast<DDataType*>(d_device_buf.GetDeviceBuffer()), // p_acc0_biases;
d_device_buf.GetDeviceBuffer()}, // std::array<void*, 1> p_acc0_biases; nullptr, // p_acc1_biases;
{}, // std::array<void*, 1> p_acc1_biases;
q_gs_ms_ks_lengths, q_gs_ms_ks_lengths,
q_gs_ms_ks_strides, q_gs_ms_ks_strides,
k_gs_ns_ks_lengths, k_gs_ns_ks_lengths,
...@@ -543,12 +542,10 @@ int run(int argc, char* argv[]) ...@@ -543,12 +542,10 @@ int run(int argc, char* argv[])
y_gs_ms_os_lengths, y_gs_ms_os_lengths,
y_gs_ms_os_strides, y_gs_ms_os_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
std::array<std::vector<ck::index_t>, 1>{ d_gs_ms_ns_lengths, // acc0_biases_gs_ms_ns_lengths
d_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths d_gs_ms_ns_strides, // acc0_biases_gs_ms_ns_strides
std::array<std::vector<ck::index_t>, 1>{ {}, // acc1_biases_gs_ms_os_lengths,
d_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides {}, // acc1_biases_gs_ms_os_strides,
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
QKVElementOp{}, QKVElementOp{},
QKVElementOp{}, QKVElementOp{},
Scale{alpha}, Scale{alpha},
...@@ -566,8 +563,8 @@ int run(int argc, char* argv[]) ...@@ -566,8 +563,8 @@ int run(int argc, char* argv[])
invoker.Run(argument, StreamConfig{nullptr, false}); invoker.Run(argument, StreamConfig{nullptr, false});
} }
// not need output z matrix // not need output z matrix
auto argument = gemm.MakeArgument( auto argument =
static_cast<InputDataType*>(q_device_buf.GetDeviceBuffer()), gemm.MakeArgument(static_cast<InputDataType*>(q_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(k_device_buf.GetDeviceBuffer()), static_cast<InputDataType*>(k_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(nullptr), // set to nullptr static_cast<ZDataType*>(nullptr), // set to nullptr
static_cast<InputDataType*>(v_device_buf.GetDeviceBuffer()), static_cast<InputDataType*>(v_device_buf.GetDeviceBuffer()),
...@@ -577,8 +574,8 @@ int run(int argc, char* argv[]) ...@@ -577,8 +574,8 @@ int run(int argc, char* argv[])
static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
std::array<void*, 1>{d_device_buf.GetDeviceBuffer()}, // std::array<void*, 1> p_acc0_biases; static_cast<DDataType*>(d_device_buf.GetDeviceBuffer()), // p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases; nullptr, // p_acc1_biases;
q_gs_ms_ks_lengths, q_gs_ms_ks_lengths,
q_gs_ms_ks_strides, q_gs_ms_ks_strides,
k_gs_ns_ks_lengths, k_gs_ns_ks_lengths,
...@@ -590,10 +587,10 @@ int run(int argc, char* argv[]) ...@@ -590,10 +587,10 @@ int run(int argc, char* argv[])
y_gs_ms_os_lengths, y_gs_ms_os_lengths,
y_gs_ms_os_strides, y_gs_ms_os_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths d_gs_ms_ns_lengths, // acc0_biases_gs_ms_ns_lengths
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides d_gs_ms_ns_strides, // acc0_biases_gs_ms_ns_strides
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths}, {}, // acc1_biases_gs_ms_os_lengths,
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides}, {}, // acc1_biases_gs_ms_os_strides,
QKVElementOp{}, QKVElementOp{},
QKVElementOp{}, QKVElementOp{},
Scale{alpha}, Scale{alpha},
......
...@@ -291,11 +291,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -291,11 +291,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0"); "Number of dimension must be greater than 0");
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size(); using D0DataType = Acc0BiasDataType;
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size(); using D1DataType = Acc1BiasDataType;
// TODO: implement bias combination // TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented"); static_assert(std::is_void<D1DataType>::value, "Acc1 Bias addition is unimplemented");
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1; using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1;
...@@ -702,8 +702,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -702,8 +702,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument( Argument(const InputDataType* p_a_grid,
const InputDataType* p_a_grid,
const InputDataType* p_b_grid, const InputDataType* p_b_grid,
ZDataType* p_z_grid, ZDataType* p_z_grid,
const InputDataType* p_b1_grid, const InputDataType* p_b1_grid,
...@@ -713,8 +712,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -713,8 +712,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
OutputDataType* p_qgrad_grid, OutputDataType* p_qgrad_grid,
OutputDataType* p_kgrad_grid, OutputDataType* p_kgrad_grid,
OutputDataType* p_vgrad_grid, OutputDataType* p_vgrad_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const D0DataType* p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const D1DataType* p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
...@@ -726,11 +725,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -726,11 +725,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -1108,8 +1107,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1108,8 +1107,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg)); return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
} }
static auto MakeArgument( static auto
const InputDataType* p_a, MakeArgument(const InputDataType* p_a,
const InputDataType* p_b, const InputDataType* p_b,
ZDataType* p_z, ZDataType* p_z,
const InputDataType* p_b1, const InputDataType* p_b1,
...@@ -1119,8 +1118,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1119,8 +1118,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
OutputDataType* p_qgrad_grid, OutputDataType* p_qgrad_grid,
OutputDataType* p_kgrad_grid, OutputDataType* p_kgrad_grid,
OutputDataType* p_vgrad_grid, OutputDataType* p_vgrad_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const D0DataType* p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const D1DataType* p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
...@@ -1132,11 +1131,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1132,11 +1131,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -1197,8 +1196,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1197,8 +1196,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
void* p_qgrad_grid, void* p_qgrad_grid,
void* p_kgrad_grid, void* p_kgrad_grid,
void* p_vgrad_grid, void* p_vgrad_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const D0DataType* p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const D1DataType* p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
...@@ -1210,11 +1209,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1210,11 +1209,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -1224,7 +1223,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1224,7 +1223,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
float p_drop, float p_drop,
std::tuple<unsigned long long, unsigned long long> seeds) // override std::tuple<unsigned long long, unsigned long long> seeds) // override
{ {
return std::make_unique<Argument>(static_cast<const InputDataType*>(p_a), return std::make_unique<Argument>(
static_cast<const InputDataType*>(p_a),
static_cast<const InputDataType*>(p_b), static_cast<const InputDataType*>(p_b),
static_cast<ZDataType*>(p_z), static_cast<ZDataType*>(p_z),
static_cast<const InputDataType*>(p_b1), static_cast<const InputDataType*>(p_b1),
...@@ -1234,8 +1234,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1234,8 +1234,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static_cast<OutputDataType*>(p_qgrad_grid), static_cast<OutputDataType*>(p_qgrad_grid),
static_cast<OutputDataType*>(p_kgrad_grid), static_cast<OutputDataType*>(p_kgrad_grid),
static_cast<OutputDataType*>(p_vgrad_grid), static_cast<OutputDataType*>(p_vgrad_grid),
p_acc0_biases, // cast in struct Argument static_cast<const D0DataType*>(p_acc0_biases), // cast in struct Argument
p_acc1_biases, // cast in struct Argument static_cast<const D1DataType*>(p_acc1_biases), // cast in struct Argument
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
a_gs_ms_ks_strides, a_gs_ms_ks_strides,
b_gs_ns_ks_lengths, b_gs_ns_ks_lengths,
......
...@@ -299,11 +299,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -299,11 +299,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0"); "Number of dimension must be greater than 0");
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size(); using D0DataType = Acc0BiasDataType;
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size(); using D1DataType = Acc1BiasDataType;
// TODO: implement bias combination // TODO: implement bias combination
static_assert(NumAcc1Bias == 0, "Bias addition is unimplemented"); static_assert(std::is_void<D1DataType>::value, "Acc1 Bias addition is unimplemented");
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2; using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2;
...@@ -718,8 +718,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -718,8 +718,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument( Argument(const InputDataType* p_a_grid,
const InputDataType* p_a_grid,
const InputDataType* p_b_grid, const InputDataType* p_b_grid,
ZDataType* p_z_grid, ZDataType* p_z_grid,
const InputDataType* p_b1_grid, const InputDataType* p_b1_grid,
...@@ -729,8 +728,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -729,8 +728,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
OutputDataType* p_qgrad_grid, OutputDataType* p_qgrad_grid,
OutputDataType* p_kgrad_grid, OutputDataType* p_kgrad_grid,
OutputDataType* p_vgrad_grid, OutputDataType* p_vgrad_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const D0DataType* p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const D1DataType* p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
...@@ -742,11 +741,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -742,11 +741,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -1143,8 +1142,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1143,8 +1142,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg)); return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
} }
static auto MakeArgument( static auto
const InputDataType* p_a, MakeArgument(const InputDataType* p_a,
const InputDataType* p_b, const InputDataType* p_b,
ZDataType* p_z, ZDataType* p_z,
const InputDataType* p_b1, const InputDataType* p_b1,
...@@ -1154,8 +1153,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1154,8 +1153,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
OutputDataType* p_qgrad_grid, OutputDataType* p_qgrad_grid,
OutputDataType* p_kgrad_grid, OutputDataType* p_kgrad_grid,
OutputDataType* p_vgrad_grid, OutputDataType* p_vgrad_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const D0DataType* p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const D1DataType* p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
...@@ -1167,11 +1166,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1167,11 +1166,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -1232,8 +1231,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1232,8 +1231,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
void* p_qgrad_grid, void* p_qgrad_grid,
void* p_kgrad_grid, void* p_kgrad_grid,
void* p_vgrad_grid, void* p_vgrad_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const void* p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const void* p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
...@@ -1245,11 +1244,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1245,11 +1244,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::vector<ck::index_t>&
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -1259,7 +1258,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1259,7 +1258,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
float p_drop, float p_drop,
std::tuple<unsigned long long, unsigned long long> seeds) // override std::tuple<unsigned long long, unsigned long long> seeds) // override
{ {
return std::make_unique<Argument>(static_cast<const InputDataType*>(p_a), return std::make_unique<Argument>(
static_cast<const InputDataType*>(p_a),
static_cast<const InputDataType*>(p_b), static_cast<const InputDataType*>(p_b),
static_cast<ZDataType*>(p_z), static_cast<ZDataType*>(p_z),
static_cast<const InputDataType*>(p_b1), static_cast<const InputDataType*>(p_b1),
...@@ -1269,8 +1269,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1269,8 +1269,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static_cast<OutputDataType*>(p_qgrad_grid), static_cast<OutputDataType*>(p_qgrad_grid),
static_cast<OutputDataType*>(p_kgrad_grid), static_cast<OutputDataType*>(p_kgrad_grid),
static_cast<OutputDataType*>(p_vgrad_grid), static_cast<OutputDataType*>(p_vgrad_grid),
p_acc0_biases, // cast in struct Argument static_cast<const D0DataType*>(p_acc0_biases), // cast in struct Argument
p_acc1_biases, // cast in struct Argument static_cast<const D1DataType*>(p_acc1_biases), // cast in struct Argument
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
a_gs_ms_ks_strides, a_gs_ms_ks_strides,
b_gs_ns_ks_lengths, b_gs_ns_ks_lengths,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment