Commit 13a0c55d authored by letaoqin's avatar letaoqin

add grouped example

parent 9ca20e2c
......@@ -42,7 +42,7 @@ using B1DataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using CDataType = F16;
-using Acc0BiasDataType = void;
+using Acc0BiasDataType = F16;
using Acc1BiasDataType = void;
static constexpr ck::index_t NumDimG = 2;
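Switching Acc0BiasDataType from void to F16 is what enables the D0 bias path: an elementwise bias is added to the Gemm0 accumulator before softmax. A minimal host-side sketch of the math this example now computes (hypothetical helper name, plain scalar loops, single batch, no masking or scaling):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical scalar reference for one batch:
// C_m_o = softmax_n(A_m_k * B0_k_n + D0_m_n) * B1_n_o
void attention_bias_ref(const std::vector<std::vector<float>>& A,  // M x K
                        const std::vector<std::vector<float>>& B0, // K x N
                        const std::vector<std::vector<float>>& D0, // M x N bias
                        const std::vector<std::vector<float>>& B1, // N x O
                        std::vector<std::vector<float>>& C)        // M x O
{
    const std::size_t M = A.size(), K = B0.size(), N = B0[0].size(), O = B1[0].size();
    for(std::size_t m = 0; m < M; ++m)
    {
        // Gemm0 plus the new bias term, accumulated in fp32 like AccDataType.
        std::vector<float> acc(N, 0.f);
        for(std::size_t n = 0; n < N; ++n)
        {
            for(std::size_t k = 0; k < K; ++k)
                acc[n] += A[m][k] * B0[k][n];
            acc[n] += D0[m][n];
        }
        // Numerically stable row softmax (assumes N > 0).
        const float mx = *std::max_element(acc.begin(), acc.end());
        float sum = 0.f;
        for(float& v : acc) { v = std::exp(v - mx); sum += v; }
        for(float& v : acc) v /= sum;
        // Gemm1.
        for(std::size_t o = 0; o < O; ++o)
        {
            float r = 0.f;
            for(std::size_t n = 0; n < N; ++n)
                r += acc[n] * B1[n][o];
            C[m][o] = r;
        }
    }
}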
......@@ -66,72 +66,71 @@ static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecial
static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
-using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl<
+using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
B0DataType,
B1DataType,
CDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AccDataType,
CShuffleDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
GemmSpec,
TensorSpecA,
TensorSpecB0,
TensorSpecB1,
TensorSpecC,
1,
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
64, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
2, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
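For readers tuning these template arguments, one quick consistency check is that the M/N tile factorizations multiply back to the block tile. The wave decomposition below is an assumption read off the inline comments (64-wide wavefronts for an XDL instance), not something this diff states:

#include <cassert>

int main()
{
    // Values copied from the template arguments above.
    const int BlockSize = 256, WaveSize = 64; // WaveSize is an assumption
    const int MPerBlock = 128, NPerBlock = 128;
    const int MPerXDL = 32, NPerXDL = 32;
    const int MXdlPerWave = 1, NXdlPerWave = 4;

    const int MWaves = MPerBlock / (MXdlPerWave * MPerXDL); // 4
    const int NWaves = NPerBlock / (NXdlPerWave * NPerXDL); // 1
    // The M x N wave grid should account for every wave in the workgroup.
    assert(MWaves * NWaves * WaveSize == BlockSize);
    return 0;
}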
// Ref Gemm0: fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
......
......@@ -194,7 +194,7 @@ int run(int argc, char* argv[])
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
-                         sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
+                         sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O + sizeof(Acc0BiasDataType) * M * N) *
BatchCount;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
......
......@@ -46,18 +46,21 @@ int run(int argc, char* argv[])
std::vector<DeviceGemmInstance::ProblemDesc> problem_descs;
std::vector<const void*> p_a;
std::vector<const void*> p_b0;
+std::vector<const void*> p_d0;
std::vector<const void*> p_b1;
std::vector<void*> p_c;
std::vector<std::vector<int>> g0_g1_m_n_k_o;
std::vector<Tensor<ADataType>> a_tensors;
std::vector<Tensor<B0DataType>> b0_tensors;
+std::vector<Tensor<Acc0BiasDataType>> d0_tensors;
std::vector<Tensor<B1DataType>> b1_tensors;
std::vector<Tensor<CDataType>> c_tensors;
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<DeviceMemPtr> a_tensors_device;
std::vector<DeviceMemPtr> b0_tensors_device;
+std::vector<DeviceMemPtr> d0_tensors_device;
std::vector<DeviceMemPtr> b1_tensors_device;
std::vector<DeviceMemPtr> c_tensors_device;
......@@ -99,6 +102,12 @@ int run(int argc, char* argv[])
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
+std::vector<ck::index_t> d0_gs_ms_ns_lengths{G0, G1, M, N};
+std::vector<ck::index_t> d0_gs_ms_ns_strides =
+    input_permute
+        ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // d0 layout [G0, M, G1, N]
+        : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // d0 layout [G0, G1, M, N]
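Both stride vectors describe the same logical [G0, G1, M, N] bias, just stored in different physical orders. A hypothetical offset helper (not part of the example) makes the mapping explicit:

#include <array>
#include <cstdint>

// Hypothetical helper: linear element offset of logical index (g0, g1, m, n)
// under the per-dimension strides defined above.
std::int64_t d0_offset(const std::array<std::int64_t, 4>& s,
                       std::int64_t g0, std::int64_t g1, std::int64_t m, std::int64_t n)
{
    return g0 * s[0] + g1 * s[1] + m * s[2] + n * s[3];
}

// input_permute == true : s = {M*G1*N, N, G1*N, 1} -> memory order [G0, M, G1, N]
// input_permute == false: s = {G1*M*N, M*N, N, 1}  -> memory order [G0, G1, M, N]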
problem_descs.push_back({a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b0_gs_ns_ks_lengths,
......@@ -107,22 +116,24 @@ int run(int argc, char* argv[])
b1_gs_os_ns_strides,
c_gs_ms_os_lengths,
c_gs_ms_os_strides,
-                        {},   // acc0_biases_gs_ms_ns_lengths
-                        {},   // acc0_biases_gs_ms_ns_strides
-                        {},   // acc1_biases_gs_ms_os_lengths
-                        {}}); // acc1_biases_gs_ms_os_strides
+                        d0_gs_ms_ns_lengths, // acc0_bias_gs_ms_ns_lengths
+                        d0_gs_ms_ns_strides, // acc0_bias_gs_ms_ns_strides
+                        {},   // acc1_bias_gs_ms_os_lengths
+                        {}}); // acc1_bias_gs_ms_os_strides
-// C_m_o = A_m_k * B0_k_n * B1_n_o
+// C_m_o = (A_m_k * B0_k_n + bias) * B1_n_o
Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
+Tensor<Acc0BiasDataType> d0_gs_ms_ns(d0_gs_ms_ns_lengths, d0_gs_ms_ns_strides);
Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
int Batch = G0 * G1;
flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * Batch;
-num_byte += (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
-             sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
-            Batch;
+num_byte +=
+    (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + sizeof(B1DataType) * N * O +
+     sizeof(CDataType) * M * O + sizeof(Acc0BiasDataType) * M * N) *
+    Batch;
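To sanity-check the updated accounting, here is a standalone computation with hypothetical sizes (chosen only for easy arithmetic; all data types in this example are 2-byte fp16):

#include <cstddef>
#include <cstdio>

int main()
{
    // Hypothetical problem sizes, used only to exercise the formulas above.
    const std::size_t M = 1024, N = 1024, K = 64, O = 64, Batch = 1;
    const std::size_t fp16 = 2; // bytes per element for every tensor here

    const std::size_t flop = (M * N * K * 2 + M * N * O * 2) * Batch;
    const std::size_t num_byte =
        (fp16 * M * K + fp16 * K * N + fp16 * N * O + fp16 * M * O +
         fp16 * M * N /* the new bias read */) *
        Batch;

    // Prints flop = 268435456, num_byte = 2621440. Note the M*N bias read
    // dominates the traffic whenever M and N are large relative to K and O.
    std::printf("flop = %zu, num_byte = %zu\n", flop, num_byte);
    return 0;
}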
if(i < 4)
{
......@@ -138,26 +149,31 @@ int run(int argc, char* argv[])
case 1:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
+d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<Acc0BiasDataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
break;
case 2:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
+d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<Acc0BiasDataType>{0.0, 1.0});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
break;
case 3:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
+d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_Diagonal<Acc0BiasDataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
break;
default:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
}
a_tensors.push_back(a_gs_ms_ks);
b0_tensors.push_back(b0_gs_ns_ks);
+d0_tensors.push_back(d0_gs_ms_ns);
b1_tensors.push_back(b1_gs_os_ns);
c_tensors.push_back(c_gs_ms_os_device_result);
......@@ -165,6 +181,8 @@ int run(int argc, char* argv[])
sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize()));
b0_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(B0DataType) * b0_gs_ns_ks.mDesc.GetElementSpaceSize()));
+d0_tensors_device.emplace_back(std::make_unique<DeviceMem>(
+    sizeof(Acc0BiasDataType) * d0_gs_ms_ns.mDesc.GetElementSpaceSize()));
b1_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize()));
c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
......@@ -172,10 +190,12 @@ int run(int argc, char* argv[])
a_tensors_device[i]->ToDevice(a_gs_ms_ks.mData.data());
b0_tensors_device[i]->ToDevice(b0_gs_ns_ks.mData.data());
+d0_tensors_device[i]->ToDevice(d0_gs_ms_ns.mData.data());
b1_tensors_device[i]->ToDevice(b1_gs_os_ns.mData.data());
p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
p_b0.push_back(b0_tensors_device[i]->GetDeviceBuffer());
+p_d0.push_back(d0_tensors_device[i]->GetDeviceBuffer());
p_b1.push_back(b1_tensors_device[i]->GetDeviceBuffer());
p_c.push_back(c_tensors_device[i]->GetDeviceBuffer());
}
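Since the grouped interface matches pointers to descriptors purely by index, every per-problem vector must stay aligned with problem_descs before being handed to MakeArgument in the next hunk. A small hypothetical guard (not in the example; names taken from the surrounding code) can catch a forgotten push_back:

#include <cassert>
#include <cstddef>

// Hypothetical guard: all per-problem vectors must have one entry per problem.
template <typename... Vecs>
void check_group_count(std::size_t num_problems, const Vecs&... vecs)
{
    (assert(vecs.size() == num_problems), ...);
}

// Usage: check_group_count(problem_descs.size(), p_a, p_b0, p_d0, p_b1, p_c);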
......@@ -193,8 +213,8 @@ int run(int argc, char* argv[])
p_b0,
p_b1,
p_c,
-{}, // p_acc0_biases
-{}, // p_acc1_biases
+p_d0, // p_acc0_bias
+{},   // p_acc1_bias
problem_descs,
a_element_op,
b0_element_op,
......@@ -240,6 +260,7 @@ int run(int argc, char* argv[])
const auto& a_gs_ms_ks = a_tensors[i];
const auto& b0_gs_ns_ks = b0_tensors[i];
+const auto& d0_gs_ms_ns = d0_tensors[i];
const auto& b1_gs_os_ns = b1_tensors[i];
auto& c_gs_ms_os_device_result = c_tensors[i];
auto& c_gs_ms_os_device_buf = *c_tensors_device[i];
......@@ -248,6 +269,7 @@ int run(int argc, char* argv[])
Tensor<ADataType> a_g_m_k({G0 * G1, M, K});
Tensor<B0DataType> b0_g_k_n({G0 * G1, K, N});
+Tensor<Acc0BiasDataType> d0_g_m_n({G0 * G1, M, N});
Tensor<B1DataType> b1_g_n_o({G0 * G1, N, O});
Tensor<AccDataType> acc0_g_m_n({G0 * G1, M, N}); // scratch object after gemm0
Tensor<ADataType> a1_g_m_n({G0 * G1, M, N}); // scratch object after softmax
......@@ -261,6 +283,9 @@ int run(int argc, char* argv[])
b0_gs_ns_ks.ForEach([&](auto& self, auto idx) {
b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
});
+d0_gs_ms_ns.ForEach([&](auto& self, auto idx) {
+    d0_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+});
b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
});
......@@ -273,6 +298,11 @@ int run(int argc, char* argv[])
ref_gemm0_invoker.Run(ref_gemm0_argument);
+// bias
+acc0_g_m_n.ForEach([&](auto& self, auto idx) {
+    self(idx) += ck::type_convert<AccDataType>(d0_g_m_n(idx));
+});
// masking
const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
......
......@@ -924,7 +924,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl
auto str = std::stringstream();
// clang-format off
str << "DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle"
str << "DeviceBatchedMultiheadAttentionForward_Xdl"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
......@@ -552,8 +552,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl
std::vector<index_t> tmp_d0_gs_ms_ns_strides;
if constexpr(!is_same<D0DataType, void>::value)
{
-    tmp_d0_gs_ms_ns_lengths = problem_desc.acc0_biases_gs_ms_ns_lengths;
-    tmp_d0_gs_ms_ns_strides = problem_desc.acc0_biases_gs_ms_ns_strides;
+    tmp_d0_gs_ms_ns_lengths = problem_desc.acc0_bias_gs_ms_ns_lengths;
+    tmp_d0_gs_ms_ns_strides = problem_desc.acc0_bias_gs_ms_ns_strides;
}
else
{
......
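The is_same<D0DataType, void> branch above uses the common void-sentinel idiom: when no bias type is supplied, the descriptor copy compiles away entirely. A standalone sketch of the same idiom (hypothetical struct and function, not the library's code):

#include <type_traits>
#include <vector>

// Hypothetical descriptor mirroring the fields renamed above.
struct ProblemDescSketch
{
    std::vector<int> acc0_bias_gs_ms_ns_lengths;
    std::vector<int> acc0_bias_gs_ms_ns_strides;
};

template <typename D0DataType>
std::vector<int> get_d0_lengths(const ProblemDescSketch& desc)
{
    if constexpr(!std::is_same<D0DataType, void>::value)
        return desc.acc0_bias_gs_ms_ns_lengths; // bias enabled: real lengths
    else
        return {}; // bias disabled (D0DataType == void): branch compiles out
}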