Unverified Commit 95d76f67 authored by Dan Yao's avatar Dan Yao Committed by GitHub
Browse files

Merge pull request #845 from ROCmSoftwarePlatform/mha-train-develop-reduce-interface

Mha train develop reduce interface
parents 102c9661 f158f4d4
...@@ -52,8 +52,8 @@ using CShuffleDataType = F32; ...@@ -52,8 +52,8 @@ using CShuffleDataType = F32;
using CDataType = DataType; using CDataType = DataType;
using ZDataType = U16; // INT32 using ZDataType = U16; // INT32
using LSEDataType = F32; using LSEDataType = F32;
using Acc0BiasDataType = ck::Tuple<>; using Acc0BiasDataType = void;
using Acc1BiasDataType = ck::Tuple<>; using Acc1BiasDataType = void;
static constexpr ck::index_t NumDimG = 2; static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1; static constexpr ck::index_t NumDimM = 1;
......
...@@ -79,8 +79,8 @@ using AccDataType = F32; ...@@ -79,8 +79,8 @@ using AccDataType = F32;
using ShuffleDataType = F32; using ShuffleDataType = F32;
using LSEDataType = F32; using LSEDataType = F32;
using ZDataType = U16; // INT32 using ZDataType = U16; // INT32
using Acc0BiasDataType = ck::Tuple<>; using Acc0BiasDataType = void;
using Acc1BiasDataType = ck::Tuple<>; using Acc1BiasDataType = void;
static constexpr ck::index_t NumDimG = 2; static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1; static constexpr ck::index_t NumDimM = 1;
...@@ -534,8 +534,8 @@ int run(int argc, char* argv[]) ...@@ -534,8 +534,8 @@ int run(int argc, char* argv[])
static_cast<InputDataType*>(y_device_buf.GetDeviceBuffer()), static_cast<InputDataType*>(y_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(nullptr), static_cast<ZDataType*>(nullptr),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()), static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases; nullptr, // p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases; nullptr, // p_acc1_biases;
q_gs_ms_ks_lengths, q_gs_ms_ks_lengths,
q_gs_ms_ks_strides, q_gs_ms_ks_strides,
k_gs_ns_ks_lengths, k_gs_ns_ks_lengths,
...@@ -594,8 +594,8 @@ int run(int argc, char* argv[]) ...@@ -594,8 +594,8 @@ int run(int argc, char* argv[])
static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases; nullptr, // p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases; nullptr, // p_acc1_biases;
q_gs_ms_ks_lengths, q_gs_ms_ks_lengths,
q_gs_ms_ks_strides, q_gs_ms_ks_strides,
k_gs_ns_ks_lengths, k_gs_ns_ks_lengths,
......
...@@ -52,8 +52,8 @@ using CShuffleDataType = F32; ...@@ -52,8 +52,8 @@ using CShuffleDataType = F32;
using CDataType = DataType; using CDataType = DataType;
using ZDataType = U16; // INT32 using ZDataType = U16; // INT32
using LSEDataType = F32; using LSEDataType = F32;
using Acc0BiasDataType = ck::Tuple<>; using Acc0BiasDataType = void;
using Acc1BiasDataType = ck::Tuple<>; using Acc1BiasDataType = void;
static constexpr ck::index_t NumDimG = 2; static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1; static constexpr ck::index_t NumDimM = 1;
......
...@@ -78,8 +78,8 @@ using AccDataType = F32; ...@@ -78,8 +78,8 @@ using AccDataType = F32;
using ShuffleDataType = F32; using ShuffleDataType = F32;
using LSEDataType = F32; using LSEDataType = F32;
using ZDataType = U16; // INT32 using ZDataType = U16; // INT32
using Acc0BiasDataType = ck::Tuple<>; using Acc0BiasDataType = void;
using Acc1BiasDataType = ck::Tuple<>; using Acc1BiasDataType = void;
static constexpr ck::index_t NumDimG = 2; static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1; static constexpr ck::index_t NumDimM = 1;
......
...@@ -177,8 +177,8 @@ int run(int argc, char* argv[]) ...@@ -177,8 +177,8 @@ int run(int argc, char* argv[])
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()), static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(nullptr), static_cast<ZDataType*>(nullptr),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()), static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases; nullptr, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases; nullptr, // std::array<void*, 1> p_acc1_biases;
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
a_gs_ms_ks_strides, a_gs_ms_ks_strides,
b0_gs_ns_ks_lengths, b0_gs_ns_ks_lengths,
......
...@@ -287,8 +287,8 @@ int run(int argc, char* argv[]) ...@@ -287,8 +287,8 @@ int run(int argc, char* argv[])
p_c, p_c,
p_z, p_z,
p_lse, p_lse,
{}, // p_acc0_biases nullptr, // p_acc0_biases
{}, // p_acc1_biases nullptr, // p_acc1_biases
problem_descs, problem_descs,
a_element_op, a_element_op,
b0_element_op, b0_element_op,
......
...@@ -53,8 +53,8 @@ using CDataType = DataType; ...@@ -53,8 +53,8 @@ using CDataType = DataType;
using DDataType = F16; using DDataType = F16;
using ZDataType = U16; // INT32 using ZDataType = U16; // INT32
using LSEDataType = F32; using LSEDataType = F32;
using Acc0BiasDataType = ck::Tuple<DDataType>; using Acc0BiasDataType = DDataType;
using Acc1BiasDataType = ck::Tuple<>; using Acc1BiasDataType = void;
static constexpr ck::index_t NumDimG = 2; static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1; static constexpr ck::index_t NumDimM = 1;
......
...@@ -53,8 +53,8 @@ using CShuffleDataType = F32; ...@@ -53,8 +53,8 @@ using CShuffleDataType = F32;
using CDataType = DataType; using CDataType = DataType;
using ZDataType = U16; // INT32 using ZDataType = U16; // INT32
using LSEDataType = F32; using LSEDataType = F32;
using Acc0BiasDataType = ck::Tuple<DDataType>; using Acc0BiasDataType = DDataType;
using Acc1BiasDataType = ck::Tuple<>; using Acc1BiasDataType = void;
static constexpr ck::index_t NumDimG = 2; static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1; static constexpr ck::index_t NumDimM = 1;
......
...@@ -137,7 +137,7 @@ int run(int argc, char* argv[]) ...@@ -137,7 +137,7 @@ int run(int argc, char* argv[])
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2}); a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<DDataType>{-2, 2}); d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<DDataType>{-1, 1});
break; break;
case 2: case 2:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}); a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
...@@ -190,8 +190,8 @@ int run(int argc, char* argv[]) ...@@ -190,8 +190,8 @@ int run(int argc, char* argv[])
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()), static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(nullptr), static_cast<ZDataType*>(nullptr),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()), static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
std::array<void*, 1>{d_device_buf.GetDeviceBuffer()}, // std::array<void*, 1> p_acc0_biases; static_cast<DDataType*>(d_device_buf.GetDeviceBuffer()), //
{}, // std::array<void*, 1> p_acc1_biases; nullptr,
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
a_gs_ms_ks_strides, a_gs_ms_ks_strides,
b0_gs_ns_ks_lengths, b0_gs_ns_ks_lengths,
...@@ -203,10 +203,10 @@ int run(int argc, char* argv[]) ...@@ -203,10 +203,10 @@ int run(int argc, char* argv[])
z_gs_ms_ns_lengths, z_gs_ms_ns_lengths,
z_gs_ms_ns_strides, z_gs_ms_ns_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths d_gs_ms_ns_lengths, // acc0_biases_gs_ms_ns_lengths
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides d_gs_ms_ns_strides, // acc0_biases_gs_ms_ns_strides
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths}, {}, // std::vector<ck::index_t>
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides}, {}, // std::vector<ck::index_t>
a_element_op, a_element_op,
b0_element_op, b0_element_op,
acc0_element_op, acc0_element_op,
...@@ -230,7 +230,7 @@ int run(int argc, char* argv[]) ...@@ -230,7 +230,7 @@ int run(int argc, char* argv[])
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount; std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O + sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O +
sizeof(DDataType) * M * N * Acc0BiasDataType::Size()) * sizeof(DDataType) * M * N * std::is_void<DDataType>::value?1:0) *
BatchCount; BatchCount;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
...@@ -250,9 +250,8 @@ int run(int argc, char* argv[]) ...@@ -250,9 +250,8 @@ int run(int argc, char* argv[])
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()), static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(z_device_buf.GetDeviceBuffer()), static_cast<ZDataType*>(z_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()), static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
std::array<void*, 1>{ static_cast<DDataType*>(d_device_buf.GetDeviceBuffer()),
d_device_buf.GetDeviceBuffer()}, // std::array<void*, 1> p_acc0_biases; nullptr,
{}, // std::array<void*, 1> p_acc1_biases;
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
a_gs_ms_ks_strides, a_gs_ms_ks_strides,
b0_gs_ns_ks_lengths, b0_gs_ns_ks_lengths,
...@@ -264,12 +263,10 @@ int run(int argc, char* argv[]) ...@@ -264,12 +263,10 @@ int run(int argc, char* argv[])
z_gs_ms_ns_lengths, z_gs_ms_ns_lengths,
z_gs_ms_ns_strides, z_gs_ms_ns_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
std::array<std::vector<ck::index_t>, 1>{ d_gs_ms_ns_lengths,
d_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths d_gs_ms_ns_strides,
std::array<std::vector<ck::index_t>, 1>{ {},
d_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides {},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
a_element_op, a_element_op,
b0_element_op, b0_element_op,
acc0_element_op, acc0_element_op,
......
...@@ -57,7 +57,7 @@ int run(int argc, char* argv[]) ...@@ -57,7 +57,7 @@ int run(int argc, char* argv[])
std::vector<const void*> p_b0; std::vector<const void*> p_b0;
std::vector<const void*> p_b1; std::vector<const void*> p_b1;
std::vector<void*> p_c; std::vector<void*> p_c;
std::vector<std::vector<const void*>> p_d; std::vector<const void*> p_d;
std::vector<void*> p_z; // for result verification std::vector<void*> p_z; // for result verification
std::vector<void*> p_z_nullptr; // for time test std::vector<void*> p_z_nullptr; // for time test
std::vector<void*> p_lse; std::vector<void*> p_lse;
...@@ -147,10 +147,8 @@ int run(int argc, char* argv[]) ...@@ -147,10 +147,8 @@ int run(int argc, char* argv[])
z_gs_ms_ns_strides, z_gs_ms_ns_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
lse_gs_ms_strides, lse_gs_ms_strides,
std::vector<std::vector<ck::index_t>>{ d_gs_ms_ns_lengths, // acc0_biases_gs_ms_ns_lengths
d_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths d_gs_ms_ns_strides, // acc0_biases_gs_ms_ns_strides
std::vector<std::vector<ck::index_t>>{
d_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides
{}, // acc1_biases_gs_ms_os_lengths {}, // acc1_biases_gs_ms_os_lengths
{}}); // acc1_biases_gs_ms_os_strides {}}); // acc1_biases_gs_ms_os_strides
...@@ -167,7 +165,7 @@ int run(int argc, char* argv[]) ...@@ -167,7 +165,7 @@ int run(int argc, char* argv[])
flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * Batch; flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * Batch;
num_byte += (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + num_byte += (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O + sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O +
sizeof(DDataType) * M * N * (Acc0BiasDataType::Size() ? 0 : 1)) * sizeof(DDataType) * M * N * (std::is_void<Acc0BiasDataType>::value ? 0 : 1)) *
Batch; Batch;
if(i < 4) if(i < 4)
...@@ -244,9 +242,7 @@ int run(int argc, char* argv[]) ...@@ -244,9 +242,7 @@ int run(int argc, char* argv[])
p_b0.push_back(b0_tensors_device[i]->GetDeviceBuffer()); p_b0.push_back(b0_tensors_device[i]->GetDeviceBuffer());
p_b1.push_back(b1_tensors_device[i]->GetDeviceBuffer()); p_b1.push_back(b1_tensors_device[i]->GetDeviceBuffer());
p_c.push_back(c_tensors_device[i]->GetDeviceBuffer()); p_c.push_back(c_tensors_device[i]->GetDeviceBuffer());
p_d.push_back({d_tensors_device[i]->GetDeviceBuffer()}); p_d.push_back(d_tensors_device[i]->GetDeviceBuffer());
// std::cout << "from host group id: " << i << " d address: " <<
// d_tensors_device[i]->GetDeviceBuffer() << std::endl;
p_z.push_back(z_tensors_device[i]->GetDeviceBuffer()); p_z.push_back(z_tensors_device[i]->GetDeviceBuffer());
p_z_nullptr.push_back(nullptr); p_z_nullptr.push_back(nullptr);
p_lse.push_back(lse_tensors_device[i]->GetDeviceBuffer()); p_lse.push_back(lse_tensors_device[i]->GetDeviceBuffer());
......
...@@ -87,9 +87,6 @@ template <index_t NumDimG, ...@@ -87,9 +87,6 @@ template <index_t NumDimG,
MaskingSpecialization MaskingSpec> MaskingSpecialization MaskingSpec>
struct DeviceBatchedMultiheadAttentionForward : public BaseOperator struct DeviceBatchedMultiheadAttentionForward : public BaseOperator
{ {
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer( virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a, const void* p_a,
const void* p_b0, const void* p_b0,
...@@ -97,8 +94,8 @@ struct DeviceBatchedMultiheadAttentionForward : public BaseOperator ...@@ -97,8 +94,8 @@ struct DeviceBatchedMultiheadAttentionForward : public BaseOperator
void* p_c, void* p_c,
void* p_z, void* p_z,
void* p_lse, void* p_lse,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const void* p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const void* p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
...@@ -110,11 +107,11 @@ struct DeviceBatchedMultiheadAttentionForward : public BaseOperator ...@@ -110,11 +107,11 @@ struct DeviceBatchedMultiheadAttentionForward : public BaseOperator
const std::vector<index_t>& z_gs_ms_ns_lengths, // z_gs_ms_os_lengths const std::vector<index_t>& z_gs_ms_ns_lengths, // z_gs_ms_os_lengths
const std::vector<index_t>& z_gs_ms_ns_strides, // z_gs_ms_os_strides const std::vector<index_t>& z_gs_ms_ns_strides, // z_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths, // lse_gs_ms_lengths const std::vector<index_t>& lse_gs_ms_lengths, // lse_gs_ms_lengths
const std::array<std::vector<index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, const std::vector<index_t>& acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, const std::vector<index_t>& acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<index_t>, NumAcc1Bias> const std::vector<index_t>&
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<index_t>, NumAcc1Bias> const std::vector<index_t>&
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
B0ElementwiseOperation b0_element_op, B0ElementwiseOperation b0_element_op,
......
...@@ -111,11 +111,11 @@ struct DeviceGroupedMultiheadAttentionForward : public BaseOperator ...@@ -111,11 +111,11 @@ struct DeviceGroupedMultiheadAttentionForward : public BaseOperator
std::vector<index_t> lse_gs_ms_lengths; std::vector<index_t> lse_gs_ms_lengths;
std::vector<index_t> lse_gs_ms_strides; std::vector<index_t> lse_gs_ms_strides;
std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_lengths; std::vector<index_t> acc0_biases_gs_ms_ns_lengths;
std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_strides; std::vector<index_t> acc0_biases_gs_ms_ns_strides;
std::vector<std::vector<index_t>> acc1_biases_gs_ms_os_lengths; std::vector<index_t> acc1_biases_gs_ms_os_lengths;
std::vector<std::vector<index_t>> acc1_biases_gs_ms_os_strides; std::vector<index_t> acc1_biases_gs_ms_os_strides;
}; };
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
...@@ -125,9 +125,9 @@ struct DeviceGroupedMultiheadAttentionForward : public BaseOperator ...@@ -125,9 +125,9 @@ struct DeviceGroupedMultiheadAttentionForward : public BaseOperator
std::vector<void*> p_c_vec, std::vector<void*> p_c_vec,
std::vector<void*> p_z_vec, std::vector<void*> p_z_vec,
std::vector<void*> p_lse_vec, std::vector<void*> p_lse_vec,
std::vector<std::vector<const void*>> p_acc0_biases_vec, std::vector<const void*> p_acc0_biases_vec,
std::vector<std::vector<const void*>> p_acc1_biases_vec, std::vector<const void*> p_acc1_biases_vec,
std::vector<ProblemDesc> problem_desc_vec, std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
B0ElementwiseOperation b0_element_op, B0ElementwiseOperation b0_element_op,
Acc0ElementwiseOperation acc0_element_op, Acc0ElementwiseOperation acc0_element_op,
......
...@@ -25,7 +25,7 @@ namespace ck { ...@@ -25,7 +25,7 @@ namespace ck {
* *
*/ */
template <typename FloatAB, template <typename FloatAB,
typename D0sDataType, typename D0DataType,
typename ZDataType, typename ZDataType,
typename FloatGemm, typename FloatGemm,
typename FloatGemmAcc, typename FloatGemmAcc,
...@@ -40,7 +40,7 @@ template <typename FloatAB, ...@@ -40,7 +40,7 @@ template <typename FloatAB,
InMemoryDataOperationEnum CGlobalMemoryDataOperation, InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename D0sGridDesc_M_N, typename D0GridDesc_M_N,
typename B1GridDesc_BK0_N_BK1, typename B1GridDesc_BK0_N_BK1,
typename CGridDesc_M_N, typename CGridDesc_M_N,
typename ZGridDesc_M_N, typename ZGridDesc_M_N,
...@@ -102,7 +102,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -102,7 +102,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
D0BlockTransferSrcScalarPerVector == 2 || D0BlockTransferSrcScalarPerVector == 2 ||
D0BlockTransferSrcScalarPerVector == 4, D0BlockTransferSrcScalarPerVector == 4,
"D0BlockTransferSrcScalarPerVector must be 1 or 2 or 4"); "D0BlockTransferSrcScalarPerVector must be 1 or 2 or 4");
static constexpr index_t NumD0Tensor = D0sDataType::Size();
static_assert(LoopSched == LoopScheduler::Default, static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported"); "Non-default loop scheduler is currently not supported");
...@@ -441,20 +440,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -441,20 +440,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
c_grid_desc_m_n); c_grid_desc_m_n);
} }
static constexpr auto MakeD0sGridPointer()
{
return generate_tuple(
[&](auto i) {
using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
return static_cast<const D0DataType*>(nullptr);
},
Number<NumD0Tensor>{});
}
// D0 desc for source in blockwise copy // D0 desc for source in blockwise copy
template <typename D0GridDesc_M_N>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeGemm0D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const D0GridDesc_M_N& d0_grid_desc_m_n) MakeD0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const D0GridDesc_M_N& d0_grid_desc_m_n)
{ {
const auto M = d0_grid_desc_m_n.GetLength(I0); const auto M = d0_grid_desc_m_n.GetLength(I0);
const auto N = d0_grid_desc_m_n.GetLength(I1); const auto N = d0_grid_desc_m_n.GetLength(I1);
...@@ -472,20 +460,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -472,20 +460,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{})); make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
} }
// D0s desc for source in blockwise copy using D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<decltype(
__host__ __device__ static constexpr auto MakeD0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(D0GridDesc_M_N{}))>;
MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const D0sGridDesc_M_N& ds_grid_desc_m_n)
{
return generate_tuple(
[&](auto i) {
return MakeGemm0D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(ds_grid_desc_m_n[i]);
},
Number<NumD0Tensor>{});
}
using D0sGridPointer = decltype(MakeD0sGridPointer());
using D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<decltype(
MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(D0sGridDesc_M_N{}))>;
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>; MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
...@@ -544,7 +520,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -544,7 +520,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
typename C0MatrixMask> typename C0MatrixMask>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid, __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
D0sGridPointer p_d0s_grid, const D0DataType* __restrict__ p_d0_grid,
const FloatAB* __restrict__ p_b1_grid, const FloatAB* __restrict__ p_b1_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
ZDataType* __restrict__ p_z_grid, ZDataType* __restrict__ p_z_grid,
...@@ -557,8 +533,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -557,8 +533,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
const CElementwiseOperation& c_element_op, const CElementwiseOperation& c_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5& const D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5&
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
...@@ -985,49 +961,36 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -985,49 +961,36 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
n3, // NInputNum n3, // NInputNum
n4)); // RegisterNum n4)); // RegisterNum
auto d0s_threadwise_copy = generate_tuple( auto d0_threadwise_copy =
[&](auto i) { ThreadwiseTensorSliceTransfer_v2<D0DataType,
using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>; D0DataType,
return ThreadwiseTensorSliceTransfer_v2< decltype(d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
D0DataType, decltype(d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
D0DataType, Sequence<I1, // MBlockId
decltype(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i]), I1, // NBlockID
decltype(d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5), m0, // MRepeat
Sequence<I1, // MBlockId n0, // NRepeat
I1, // NBlockID m1, // MWaveId
m0, // MRepeat n1, // NWaveId
n0, // NRepeat m2, // MPerXdl
m1, // MWaveId n2, // NGroupNum
n1, // NWaveId n3, // NInputNum
m2, // MPerXdl n4>,
n2, // NGroupNum Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
n3, // NInputNum 9,
n4>, D0BlockTransferSrcScalarPerVector,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, 1,
9, false>(d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
D0BlockTransferSrcScalarPerVector, make_multi_index(block_work_idx_m, // MBlockId
1, 0, // NBlockId
false>(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], 0, // mrepeat
make_multi_index(block_work_idx_m, // MBlockId 0, // nrepeat
0, // NBlockId wave_id[I0], // MWaveId
0, // mrepeat wave_id[I1], // NWaveId
0, // nrepeat wave_m_n_id[I1], // MPerXdl
wave_id[I0], // MWaveId 0, // group
wave_id[I1], // NWaveId wave_m_n_id[I0], // NInputIndex
wave_m_n_id[I1], // MPerXdl 0)); // register number
0, // group
wave_m_n_id[I0], // NInputIndex
0)); // register number
},
Number<NumD0Tensor>{});
const auto d0s_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0s_grid[i],
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i].GetElementSpaceSize());
},
Number<NumD0Tensor>{});
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 = // for blockwise copy constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 = // for blockwise copy
make_naive_tensor_descriptor_packed(make_tuple(m0, // MRepeat make_naive_tensor_descriptor_packed(make_tuple(m0, // MRepeat
...@@ -1325,9 +1288,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -1325,9 +1288,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
// add bias // add bias
static_for<0, NumD0Tensor, 1>{}([&](auto i) { if constexpr(!is_same<D0DataType, void>::value)
{
const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0_grid, d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
// get register // get register
using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
StaticBuffer<AddressSpaceEnum::Vgpr, StaticBuffer<AddressSpaceEnum::Vgpr,
D0DataType, D0DataType,
d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(), d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
...@@ -1335,20 +1300,20 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -1335,20 +1300,20 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
d0_thread_buf; d0_thread_buf;
// load data from global // load data from global
d0s_threadwise_copy(i).Run(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], d0_threadwise_copy.Run(d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
d0s_grid_buf[i], d0_grid_buf,
d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
d0_thread_buf); d0_thread_buf);
// acc add bias // acc add bias
static_for<0, m0 * n0 * n2 * n4, 1>{}( static_for<0, m0 * n0 * n2 * n4, 1>{}(
[&](auto j) { acc_thread_buf(j) += d0_thread_buf[j]; }); [&](auto i) { acc_thread_buf(i) += d0_thread_buf[i]; });
d0s_threadwise_copy(i).MoveSrcSliceWindow( d0_threadwise_copy.MoveSrcSliceWindow(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0)); make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
}); }
// softmax // softmax
SoftmaxBuf& max = blockwise_softmax.max_value_buf; SoftmaxBuf& max = blockwise_softmax.max_value_buf;
SoftmaxBuf& sum = blockwise_softmax.sum_value_buf; SoftmaxBuf& sum = blockwise_softmax.sum_value_buf;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment