Commit 13a0c55d authored by letaoqin

add grouped example

parent 9ca20e2c
@@ -42,7 +42,7 @@ using B1DataType = F16;
 using AccDataType      = F32;
 using CShuffleDataType = F32;
 using CDataType        = F16;
-using Acc0BiasDataType = void;
+using Acc0BiasDataType = F16;
 using Acc1BiasDataType = void;

 static constexpr ck::index_t NumDimG = 2;
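Note: flipping `Acc0BiasDataType` from `void` to `F16` is the switch that enables the optional bias input D0 throughout this example; the device op adds it elementwise to the first GEMM's accumulator before masking and softmax. In the notation of the comment updated below ("C_m_o = (A_m_k * B0_k_n + bias) * B1_n_o"), and leaving out the Acc0ElementOp scale and the mask:

$$C_{m,o} \;=\; \sum_n \mathrm{softmax}_n\!\Big(\sum_k A_{m,k}\,{B0}_{k,n} \;+\; {D0}_{m,n}\Big)\,{B1}_{n,o}$$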
@@ -66,72 +66,71 @@ static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecC  = ck::tensor_operation::device::TensorSpecialization::Default;

-using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl<
+using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl<
     NumDimG,
     NumDimM,
     NumDimN,
     NumDimK,
     NumDimO,
     ADataType,
     B0DataType,
     B1DataType,
     CDataType,
     Acc0BiasDataType,
     Acc1BiasDataType,
     AccDataType,
     CShuffleDataType,
     AElementOp,
     B0ElementOp,
     Acc0ElementOp,
     B1ElementOp,
     CElementOp,
     GemmSpec,
     TensorSpecA,
     TensorSpecB0,
     TensorSpecB1,
     TensorSpecC,
     1,
     256,
     128,          // MPerBlock
     128,          // NPerBlock
     32,           // KPerBlock
     64,           // Gemm1NPerBlock
     32,           // Gemm1KPerBlock
     8,            // AK1
     8,            // BK1
     2,            // B1K1
     32,           // MPerXDL
     32,           // NPerXDL
     1,            // MXdlPerWave
     4,            // NXdlPerWave
     2,            // Gemm1NXdlPerWave
     S<4, 64, 1>,  // ABlockTransfer
     S<1, 0, 2>,
     S<1, 0, 2>,
     2,
     8,
     8,
     true,
     S<4, 64, 1>,  // BBlockTransfer
     S<1, 0, 2>,
     S<1, 0, 2>,
     2,
     8,
     8,
     true,
     S<16, 16, 1>, // B1BlockTransfer
     S<0, 2, 1>,
     S<0, 2, 1>,
     1,
     4,
     2,
     false,
     1,              // CShuffleMXdlPerWavePerShuffle
     2,              // CShuffleNXdlPerWavePerShuffle
     S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
     8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
     MaskingSpec>;   // MaskingSpecialization

 // Ref Gemm0: fp16 in, fp32 out
 using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
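A note on the template arguments: the `S<...>` entries are compile-time integer sequences describing block-transfer thread clusters and shuffle tiles. In the CK examples `S` is typically an alias declared near the top of the file; that declaration is not part of this diff, so the exact form below is an assumption:

    // Assumed alias, as commonly declared in CK example sources:
    template <ck::index_t... Is>
    using S = ck::Sequence<Is...>;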
...
@@ -194,7 +194,7 @@ int run(int argc, char* argv[])
     std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
     std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
-                             sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
+                             sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O +
+                             sizeof(Acc0BiasDataType) * M * N) *
                             BatchCount;

     float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
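Two things are worth noting about the perf accounting. The FLOP estimate is unchanged, since the elementwise bias add is O(M·N) and is conventionally excluded from the 2MNK + 2MNO GEMM count; the byte estimate, however, now charges one extra read of the M×N bias per batch:

$$\text{bytes per batch} = s_A MK + s_{B0} KN + s_{B1} NO + s_C MO + s_{D0} MN, \qquad s_X = \texttt{sizeof}(X)$$

With every I/O type set to F16 (2 bytes) as above, the bias contributes an extra 2MN bytes per batch.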
...
@@ -46,18 +46,21 @@ int run(int argc, char* argv[])
     std::vector<DeviceGemmInstance::ProblemDesc> problem_descs;
     std::vector<const void*> p_a;
     std::vector<const void*> p_b0;
+    std::vector<const void*> p_d0;
     std::vector<const void*> p_b1;
     std::vector<void*> p_c;
     std::vector<std::vector<int>> g0_g1_m_n_k_o;

     std::vector<Tensor<ADataType>> a_tensors;
     std::vector<Tensor<B0DataType>> b0_tensors;
+    std::vector<Tensor<Acc0BiasDataType>> d0_tensors;
     std::vector<Tensor<B1DataType>> b1_tensors;
     std::vector<Tensor<CDataType>> c_tensors;

     using DeviceMemPtr = std::unique_ptr<DeviceMem>;

     std::vector<DeviceMemPtr> a_tensors_device;
     std::vector<DeviceMemPtr> b0_tensors_device;
+    std::vector<DeviceMemPtr> d0_tensors_device;
     std::vector<DeviceMemPtr> b1_tensors_device;
     std::vector<DeviceMemPtr> c_tensors_device;
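The bias follows the same three-container pattern as the other operands: a host `Tensor`, an owning `DeviceMem`, and a raw `const void*` view collected for the grouped API. A minimal sketch of that ownership split, with `int` standing in for `DeviceMem` and all names hypothetical:

    #include <memory>
    #include <vector>

    int main()
    {
        std::vector<std::unique_ptr<int>> owners; // like d0_tensors_device: owns the buffers
        std::vector<const void*> views;           // like p_d0: non-owning views for the API
        for(int i = 0; i < 3; ++i)
        {
            owners.emplace_back(std::make_unique<int>(i));
            views.push_back(owners.back().get()); // valid only while `owners` is alive
        }
        return views.size() == owners.size() ? 0 : 1;
    }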
@@ -99,6 +102,12 @@ int run(int argc, char* argv[])
                 ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
                 : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]

+        std::vector<ck::index_t> d0_gs_ms_ns_lengths{G0, G1, M, N};
+        std::vector<ck::index_t> d0_gs_ms_ns_strides =
+            input_permute
+                ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // d0 layout [G0, M, G1, N]
+                : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // d0 layout [G0, G1, M, N]
+
         problem_descs.push_back({a_gs_ms_ks_lengths,
                                  a_gs_ms_ks_strides,
                                  b0_gs_ns_ks_lengths,
@@ -107,22 +116,24 @@ int run(int argc, char* argv[])
                                  b1_gs_os_ns_strides,
                                  c_gs_ms_os_lengths,
                                  c_gs_ms_os_strides,
-                                 {},   // acc0_biases_gs_ms_ns_lengths
-                                 {},   // acc0_biases_gs_ms_ns_strides
-                                 {},   // acc1_biases_gs_ms_os_lengths
-                                 {}}); // acc1_biases_gs_ms_os_strides
+                                 d0_gs_ms_ns_lengths, // acc0_bias_gs_ms_ns_lengths
+                                 d0_gs_ms_ns_strides, // acc0_bias_gs_ms_ns_strides
+                                 {},   // acc1_bias_gs_ms_os_lengths
+                                 {}}); // acc1_bias_gs_ms_os_strides

-        // C_m_o = A_m_k * B0_k_n * B1_n_o
+        // C_m_o = (A_m_k * B0_k_n + bias) * B1_n_o
         Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
         Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
+        Tensor<Acc0BiasDataType> d0_gs_ms_ns(d0_gs_ms_ns_lengths, d0_gs_ms_ns_strides);
         Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
         Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);

         int Batch = G0 * G1;
         flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * Batch;
-        num_byte += (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
-                     sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
-                    Batch;
+        num_byte +=
+            (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + sizeof(B1DataType) * N * O +
+             sizeof(CDataType) * M * O + sizeof(Acc0BiasDataType) * M * N) *
+            Batch;

         if(i < 4)
         {
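The new `d0_gs_ms_ns_*` vectors mirror the C tensor's handling: the extents are always ordered [G0, G1, M, N], and `input_permute` only changes the strides, selecting either a head-interleaved [G0, M, G1, N] memory layout or a plain [G0, G1, M, N] one. A quick standalone way to sanity-check a stride vector like the permuted one (hypothetical helper, not part of the example):

    #include <cassert>
    #include <vector>

    // Dot product of a multi-index with its strides gives the linear offset.
    int linear_offset(const std::vector<int>& idx, const std::vector<int>& strides)
    {
        int off = 0;
        for(std::size_t d = 0; d < idx.size(); ++d)
            off += idx[d] * strides[d];
        return off;
    }

    int main()
    {
        const int G0 = 2, G1 = 3, M = 4, N = 5;
        // input_permute == true: memory order [G0, M, G1, N], dims ordered [G0, G1, M, N]
        const std::vector<int> strides{M * G1 * N, N, G1 * N, 1};
        // The last element of the tensor must land on the last slot of the dense span.
        assert(linear_offset({G0 - 1, G1 - 1, M - 1, N - 1}, strides) == G0 * G1 * M * N - 1);
        return 0;
    }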
@@ -138,26 +149,31 @@ int run(int argc, char* argv[])
         case 1:
             a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
             b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
+            d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<Acc0BiasDataType>{-2, 2});
             b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
             break;
         case 2:
             a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
             b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
+            d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<Acc0BiasDataType>{0.0, 1.0});
             b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
             break;
         case 3:
             a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
             b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
+            d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_Diagonal<Acc0BiasDataType>{});
             b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
             break;
         default:
             a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
             b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+            d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
             b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
         }

         a_tensors.push_back(a_gs_ms_ks);
         b0_tensors.push_back(b0_gs_ns_ks);
+        d0_tensors.push_back(d0_gs_ms_ns);
         b1_tensors.push_back(b1_gs_os_ns);
         c_tensors.push_back(c_gs_ms_os_device_result);
@@ -165,6 +181,8 @@ int run(int argc, char* argv[])
             sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize()));
         b0_tensors_device.emplace_back(std::make_unique<DeviceMem>(
             sizeof(B0DataType) * b0_gs_ns_ks.mDesc.GetElementSpaceSize()));
+        d0_tensors_device.emplace_back(std::make_unique<DeviceMem>(
+            sizeof(Acc0BiasDataType) * d0_gs_ms_ns.mDesc.GetElementSpaceSize()));
         b1_tensors_device.emplace_back(std::make_unique<DeviceMem>(
             sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize()));
         c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
@@ -172,10 +190,12 @@ int run(int argc, char* argv[])
         a_tensors_device[i]->ToDevice(a_gs_ms_ks.mData.data());
         b0_tensors_device[i]->ToDevice(b0_gs_ns_ks.mData.data());
+        d0_tensors_device[i]->ToDevice(d0_gs_ms_ns.mData.data());
         b1_tensors_device[i]->ToDevice(b1_gs_os_ns.mData.data());

         p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
         p_b0.push_back(b0_tensors_device[i]->GetDeviceBuffer());
+        p_d0.push_back(d0_tensors_device[i]->GetDeviceBuffer());
         p_b1.push_back(b1_tensors_device[i]->GetDeviceBuffer());
         p_c.push_back(c_tensors_device[i]->GetDeviceBuffer());
     }
@@ -193,8 +213,8 @@ int run(int argc, char* argv[])
                                       p_b0,
                                       p_b1,
                                       p_c,
-                                      {},   // p_acc0_biases
-                                      {},   // p_acc1_biases
+                                      p_d0, // p_acc0_bias
+                                      {},   // p_acc1_bias
                                       problem_descs,
                                       a_element_op,
                                       b0_element_op,
@@ -240,6 +260,7 @@ int run(int argc, char* argv[])
         const auto& a_gs_ms_ks  = a_tensors[i];
         const auto& b0_gs_ns_ks = b0_tensors[i];
+        const auto& d0_gs_ms_ns = d0_tensors[i];
         const auto& b1_gs_os_ns = b1_tensors[i];
         auto& c_gs_ms_os_device_result = c_tensors[i];
         auto& c_gs_ms_os_device_buf    = *c_tensors_device[i];
@@ -248,6 +269,7 @@ int run(int argc, char* argv[])
         Tensor<ADataType> a_g_m_k({G0 * G1, M, K});
         Tensor<B0DataType> b0_g_k_n({G0 * G1, K, N});
+        Tensor<Acc0BiasDataType> d0_g_m_n({G0 * G1, M, N});
         Tensor<B1DataType> b1_g_n_o({G0 * G1, N, O});
         Tensor<AccDataType> acc0_g_m_n({G0 * G1, M, N}); // scratch object after gemm0
         Tensor<ADataType> a1_g_m_n({G0 * G1, M, N});     // scratch object after softmax
@@ -261,6 +283,9 @@ int run(int argc, char* argv[])
         b0_gs_ns_ks.ForEach([&](auto& self, auto idx) {
             b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
         });
+        d0_gs_ms_ns.ForEach([&](auto& self, auto idx) {
+            d0_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+        });
         b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
             b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
         });
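For the host-side check, every [G0, G1, ...] tensor is first repacked into a flat batch using

$$g = g_0 \cdot G_1 + g_1,$$

which is exactly what `idx[0] * G1 + idx[1]` computes in the `ForEach` lambdas. Note that B0 and B1 also swap their last two indices during the copy (repacking to [g, k, n] and [g, n, o] for the reference GEMMs), while the new D0 bias copies straight through as [g, m, n].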
@@ -273,6 +298,11 @@ int run(int argc, char* argv[])
         ref_gemm0_invoker.Run(ref_gemm0_argument);

+        // bias
+        acc0_g_m_n.ForEach([&](auto& self, auto idx) {
+            self(idx) += ck::type_convert<AccDataType>(d0_g_m_n(idx));
+        });
+
         // masking
         const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
         acc0_g_m_n.ForEach([&](auto& self, auto idx) {
...
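In the reference path above, the bias is widened from F16 to the F32 accumulator type (`ck::type_convert<AccDataType>`) and added before masking, so any position the mask later overwrites is unaffected by its bias value. A self-contained sketch of that ordering, with `_Float16` standing in for F16 (requires compiler support) and a causal mask standing in for whatever `MaskingSpec` selects; both are assumptions:

    #include <cmath>
    #include <limits>
    #include <vector>

    int main()
    {
        const int M = 2, N = 3;
        std::vector<float>    acc0(M * N, 1.0f);         // stand-in for the gemm0 output
        std::vector<_Float16> d0(M * N, _Float16(0.5f)); // stand-in for the F16 bias

        // 1) bias add, widened to the accumulator type first
        for(int i = 0; i < M * N; ++i)
            acc0[i] += static_cast<float>(d0[i]); // plays the role of ck::type_convert<AccDataType>

        // 2) masking afterwards, so masked positions ignore the bias entirely
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                if(n > m)
                    acc0[m * N + n] = -std::numeric_limits<float>::infinity();

        return std::isinf(acc0[1]) ? 0 : 1; // element (0,1) is masked
    }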
@@ -924,7 +924,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl
         auto str = std::stringstream();
         // clang-format off
-        str << "DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle"
+        str << "DeviceBatchedMultiheadAttentionForward_Xdl"
             << "<"
             << BlockSize << ", "
             << MPerBlock << ", "
...
@@ -552,8 +552,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl
             std::vector<index_t> tmp_d0_gs_ms_ns_strides;
             if constexpr(!is_same<D0DataType, void>::value)
             {
-                tmp_d0_gs_ms_ns_lengths = problem_desc.acc0_biases_gs_ms_ns_lengths;
-                tmp_d0_gs_ms_ns_strides = problem_desc.acc0_biases_gs_ms_ns_strides;
+                tmp_d0_gs_ms_ns_lengths = problem_desc.acc0_bias_gs_ms_ns_lengths;
+                tmp_d0_gs_ms_ns_strides = problem_desc.acc0_bias_gs_ms_ns_strides;
             }
             else
             {
...
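Finally, the renamed `acc0_bias_gs_ms_ns_*` fields are only read inside an `if constexpr` guard, which is what keeps the bias free when disabled: with `D0DataType = void` the branch is discarded at compile time, so the example can toggle the feature with the single `using Acc0BiasDataType` change above. A standalone sketch of the pattern (hypothetical names, std types only):

    #include <type_traits>
    #include <vector>

    // Sketch: compile-time-optional bias descriptor, mirroring the dispatch above.
    template <typename D0DataType>
    std::vector<int> d0_lengths_or_dummy(const std::vector<int>& bias_lengths)
    {
        if constexpr(!std::is_same<D0DataType, void>::value)
        {
            return bias_lengths; // bias enabled: forward the real extents
        }
        else
        {
            return {1, 1, 1, 1}; // bias disabled: placeholder; this is the only live branch
        }
    }

    int main()
    {
        auto with_bias    = d0_lengths_or_dummy<float>({2, 3, 4, 5});
        auto without_bias = d0_lengths_or_dummy<void>({});
        return (with_bias.size() + without_bias.size()) == 8 ? 0 : 1;
    }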