Commit 2baf0613 authored by Jing Zhang

clean code: add multiA into example

parent b164b0ef
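
This commit extends the grouped multi-ABD fixed-NK GEMM example from one A tensor to two: alongside the existing fp16 A0 it adds an fp32 A1 with the same row-major layout and strides, threads the second tensor through host allocation, initialization, device upload, and the kernel arguments (the argument type changes from GroupedGemmMultiABDKernelArgument<1, 1, 1> to <2, 1, 1>, i.e. two A tensors, one B tensor, one D tensor), and introduces an AddScale element-wise op that fuses the two inputs as a = scale * (a0 + a1) before the GEMM. Below is a minimal host-only sketch of that fused behavior; it is not part of the commit, and it substitutes plain float for ck::half_t so it compiles without CK headers:

```cpp
#include <cstdio>

// Host-only stand-in for the AddScale element-wise op added in this commit;
// plain float replaces ck::half_t so no CK headers are needed.
struct AddScale
{
    // Fuse two A inputs into one value: a = scale * (a0 + a1).
    constexpr void operator()(float& a, const float& a0, const float& a1) const
    {
        a = scale * (a0 + a1);
    }

    float scale = 1.0f;
};

int main()
{
    const AddScale op{}; // default scale = 1.0, matching the example
    float a = 0.0f;
    op(a, 2.0f, 0.5f); // a = 1.0f * (2.0f + 0.5f) = 2.5f
    std::printf("fused a = %f\n", a);
    return 0;
}
```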
@@ -32,6 +32,4 @@ if(USE_BITINT_EXTENSION_INT4)
 endif()
 add_example_executable(example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16 grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp)
-if(result EQUAL 0)
-    add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16)
-endif()
+add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16)
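
With this, the example's dependency registration goes through the add_example_dependencies helper instead of hand-rolling the `if(result EQUAL 0)` guard around add_dependencies; judging from its use here, the helper encapsulates the same "was the example target actually created" check. How the helper itself is defined lies outside this diff.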
@@ -33,8 +33,9 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using Add = ck::tensor_operation::element_wise::Add;
 
 using A0DataType = F16;
+using A1DataType = F32;
+using AsDataType = ck::Tuple<A0DataType, A1DataType>;
 using B0DataType = F16;
-using AsDataType = ck::Tuple<A0DataType>;
 using BsDataType = ck::Tuple<B0DataType>;
 using AccDataType = F32;
 using CShuffleDataType = F32;
@@ -43,14 +44,26 @@ using DsDataType = ck::Tuple<D0DataType>;
 using EDataType = F32;
 
 using A0Layout = Row;
+using A1Layout = Row;
+using AsLayout = ck::Tuple<A0Layout, A1Layout>;
 using B0Layout = Col;
-using AsLayout = ck::Tuple<A0Layout>;
 using BsLayout = ck::Tuple<B0Layout>;
 using D0Layout = Row;
 using DsLayout = ck::Tuple<D0Layout>;
 using ELayout = Row;
 
-using AElementOp = PassThrough;
+struct AddScale
+{
+    __host__ __device__ constexpr void
+    operator()(ck::half_t& a, const ck::half_t& a0, const float& a1) const
+    {
+        a = scale * (a0 + a1);
+    }
+
+    float scale = 1.0;
+};
+
+using AElementOp = AddScale;
 using BElementOp = PassThrough;
 using CDEElementOp = Add;
@@ -113,13 +126,15 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
         }
     };
 
-    std::vector<Tensor<A0DataType>> a_tensors;
+    std::vector<Tensor<A0DataType>> a0_tensors;
+    std::vector<Tensor<A1DataType>> a1_tensors;
     std::vector<Tensor<B0DataType>> b_tensors;
     std::vector<Tensor<D0DataType>> d0_tensors;
    std::vector<Tensor<EDataType>> c_host_tensors;
     std::vector<Tensor<EDataType>> c_device_tensors;
 
-    a_tensors.reserve(group_count);
+    a0_tensors.reserve(group_count);
+    a1_tensors.reserve(group_count);
     b_tensors.reserve(group_count);
     d0_tensors.reserve(group_count);
     c_host_tensors.reserve(group_count);
@@ -127,10 +142,11 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
     using DeviceMemPtr = std::unique_ptr<DeviceMem>;
 
-    std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, d0_tensors_device,
-        c_tensors_device;
+    std::vector<DeviceMemPtr> a0_tensors_device, a1_tensors_device, b_tensors_device,
+        d0_tensors_device, c_tensors_device;
 
-    a_tensors_device.reserve(group_count);
+    a0_tensors_device.reserve(group_count);
+    a1_tensors_device.reserve(group_count);
     b_tensors_device.reserve(group_count);
     d0_tensors_device.reserve(group_count);
     c_tensors_device.reserve(group_count);
@@ -140,8 +156,10 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
     for(int i = 0; i < group_count; i++)
     {
         sum_of_m += problem_size.Ms[i];
-        a_tensors.push_back(Tensor<A0DataType>(f_host_tensor_descriptor(
+        a0_tensors.push_back(Tensor<A0DataType>(f_host_tensor_descriptor(
             problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], A0Layout{})));
+        a1_tensors.push_back(Tensor<A1DataType>(f_host_tensor_descriptor(
+            problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], A1Layout{})));
         b_tensors.push_back(Tensor<B0DataType>(f_host_tensor_descriptor(
             problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], B0Layout{})));
         d0_tensors.push_back(Tensor<D0DataType>(
@@ -150,12 +168,13 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
             problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
         c_device_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
             problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
 
-        std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
+        std::cout << "gemm[" << i << "] a_m_k: " << a0_tensors[i].mDesc
                   << " b_k_n: " << b_tensors[i].mDesc << " d_m_n: " << d0_tensors[i].mDesc
                   << " c_m_n: " << c_device_tensors[i].mDesc << std::endl;
 
         flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i];
-        num_btype += sizeof(A0DataType) * a_tensors[i].mDesc.GetElementSize() +
+        num_btype += sizeof(A0DataType) * a0_tensors[i].mDesc.GetElementSize() +
+                     sizeof(A1DataType) * a1_tensors[i].mDesc.GetElementSize() +
                      sizeof(B0DataType) * b_tensors[i].mDesc.GetElementSize() +
                      sizeof(D0DataType) * d0_tensors[i].mDesc.GetElementSize() +
                      sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSize();
@@ -164,15 +183,18 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
         {
         case 0: break;
         case 1:
-            a_tensors[i].GenerateTensorValue(GeneratorTensor_2<A0DataType>{-5, 5});
+            a0_tensors[i].GenerateTensorValue(GeneratorTensor_2<A0DataType>{-5, 5});
+            a1_tensors[i].GenerateTensorValue(GeneratorTensor_2<A1DataType>{-5, 5});
             b_tensors[i].GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
             break;
         case 2:
-            a_tensors[i].GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
+            a0_tensors[i].GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
+            a1_tensors[i].GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
             b_tensors[i].GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
             break;
         default:
-            a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
+            a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
+            a1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
             b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
         }
@@ -180,16 +202,19 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
     }
 
     using GroupedGemmKernelArgument =
-        ck::tensor_operation::device::GroupedGemmMultiABDKernelArgument<1, 1, 1>;
+        ck::tensor_operation::device::GroupedGemmMultiABDKernelArgument<2, 1, 1>;
 
     std::vector<GroupedGemmKernelArgument> grouped_gemm_kernel_args_;
     grouped_gemm_kernel_args_.reserve(group_count);
 
     for(int i = 0; i < group_count; i++)
     {
-        a_tensors_device.emplace_back(
+        a0_tensors_device.emplace_back(
             std::make_unique<DeviceMem>(sizeof(A0DataType) * sum_of_m * problem_size.Ks[i]));
+        a1_tensors_device.emplace_back(
+            std::make_unique<DeviceMem>(sizeof(A1DataType) * sum_of_m * problem_size.Ks[i]));
         b_tensors_device.emplace_back(std::make_unique<DeviceMem>(
             sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i]));
@@ -199,9 +224,13 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
         c_tensors_device.emplace_back(
             std::make_unique<DeviceMem>(sizeof(EDataType) * sum_of_m * problem_size.Ns[i]));
 
-        a_tensors_device[i]->ToDevice(a_tensors[i].mData.data(),
-                                      a_tensors[i].mDesc.GetElementSpaceSize() *
-                                          sizeof(A0DataType));
+        a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data(),
+                                       a0_tensors[i].mDesc.GetElementSpaceSize() *
+                                           sizeof(A0DataType));
+        a1_tensors_device[i]->ToDevice(a1_tensors[i].mData.data(),
+                                       a1_tensors[i].mDesc.GetElementSpaceSize() *
+                                           sizeof(A1DataType));
         b_tensors_device[i]->ToDevice(b_tensors[i].mData.data(),
                                       b_tensors[i].mDesc.GetElementSpaceSize() *
                                           sizeof(B0DataType));
@@ -211,20 +240,21 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
         gemm_descs.push_back({sum_of_m,
                               problem_size.Ns[i],
                               problem_size.Ks[i],
-                              {1},
+                              {1, 1},
                               {problem_size.stride_Bs[i]},
                               {0},
                               1});
 
         grouped_gemm_kernel_args_.push_back(
-            {std::array<const void*, 1>{a_tensors_device[i]->GetDeviceBuffer()},
+            {std::array<const void*, 2>{a0_tensors_device[i]->GetDeviceBuffer(),
+                                        a1_tensors_device[i]->GetDeviceBuffer()},
             std::array<const void*, 1>{b_tensors_device[i]->GetDeviceBuffer()},
             std::array<const void*, 1>{d0_tensors_device[i]->GetDeviceBuffer()},
             c_tensors_device[i]->GetDeviceBuffer(),
             problem_size.Ms[i],
             problem_size.Ns[i],
             problem_size.Ks[i],
-            std::array<ck::index_t, 1>{problem_size.stride_As[i]},
+            std::array<ck::index_t, 2>{problem_size.stride_As[i], problem_size.stride_As[i]},
             std::array<ck::index_t, 1>{problem_size.stride_Bs[i]},
             std::array<ck::index_t, 1>{0},
             problem_size.stride_Cs[i]});
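
With NumA = 2, each per-group kernel argument now carries two A pointers and two A strides; A0 and A1 share the same row-major layout in this example, so problem_size.stride_As[i] is passed twice, and the per-A entry list in the gemm_descs initializer likewise grows from {1} to {1, 1}, one entry per A tensor.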
@@ -237,7 +267,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
     auto gemm    = DeviceGemmInstance{};
     auto invoker = gemm.MakeInvoker();
 
-    std::vector<std::array<const void*, 1>> p_As = {};
+    std::vector<std::array<const void*, 2>> p_As = {};
     std::vector<std::array<const void*, 1>> p_Bs = {};
     std::vector<std::array<const void*, 1>> p_Ds = {};
     std::vector<void*> p_Cs = {};
@@ -281,16 +311,25 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
     bool pass = true;
 
     if(config.do_verification)
     {
         using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType,
                                                                                 B0DataType,
                                                                                 EDataType,
                                                                                 AccDataType,
-                                                                                AElementOp,
+                                                                                PassThrough,
                                                                                 BElementOp,
                                                                                 PassThrough>;
 
         for(std::size_t i = 0; i < gemm_descs.size(); i++)
         {
+            for(int m = 0; m < problem_size.Ms[i]; ++m)
+            {
+                for(int k = 0; k < problem_size.Ks[i]; ++k)
+                {
+                    a_element_op(a0_tensors[i](m, k), a0_tensors[i](m, k), a1_tensors[i](m, k));
+                }
+            }
+
             c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(),
                                             c_device_tensors[i].mDesc.GetElementSize() *
                                                 sizeof(EDataType));
@@ -298,10 +337,10 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
             auto ref_gemm    = ReferenceGemmInstance{};
             auto ref_invoker = ref_gemm.MakeInvoker();
 
-            auto ref_argument = ref_gemm.MakeArgument(a_tensors[i],
+            auto ref_argument = ref_gemm.MakeArgument(a0_tensors[i],
                                                       b_tensors[i],
                                                       c_host_tensors[i],
-                                                      a_element_op,
+                                                      PassThrough{},
                                                       b_element_op,
                                                       PassThrough{});
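
A note on the verification path: the host-side ReferenceGemm consumes a single A tensor, so before invoking it the example folds A1 into A0 in place by applying a_element_op (AddScale) over every (m, k) element on the host, then runs the reference GEMM with PassThrough in the A slot. The device kernel applies the same fusion internally via AElementOp, so both paths compute on identical effective A data.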
...
@@ -93,14 +93,6 @@ __global__ void
        typename GridwiseGemm::BsGridPointer p_bs_grid_;
        typename GridwiseGemm::DsGridPointer p_ds_grid_;
 
-        // constexpr auto I0 = Number<0>{};
-
-        // using AsDataType = remove_cvref_t<decltype(p_as_grid_(I0))>;
-        // p_as_grid_(I0)   = static_cast<AsDataType>(gemm_desc_ptr[group_id].p_a_grid);
-
-        // using BsDataType = remove_cvref_t<decltype(p_bs_grid_(I0))>;
-        // p_bs_grid_(I0)   = static_cast<BsDataType>(gemm_desc_ptr[group_id].p_b_grid);
-
        static_for<0, NumATensor, 1>{}([&](auto i) {
            using ADataType = remove_cvref_t<decltype(p_as_grid_(i))>;
            p_as_grid_(i)   = static_cast<ADataType>(gemm_desc_ptr[group_id].p_as_grid[i]);
@@ -500,35 +492,32 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
                const index_t StrideE = gemm_descs[i].stride_C_;
 
-                static_for<0, NumATensor, 1>{}([&](auto j) {
-                    if(gemm_descs[i].stride_As_.size() != NumATensor)
-                    {
-                        throw std::runtime_error(
-                            "wrong! gemm_descs[i].stride_As_.size() does not match NumATensor");
-                    }
-
-                    StrideAs[j] = gemm_descs[i].stride_As_[j];
-                });
+                if(gemm_descs[i].stride_As_.size() != NumATensor)
+                {
+                    throw std::runtime_error(
+                        "wrong! gemm_descs[i].stride_As_.size() does not match NumATensor");
+                }
+
+                static_for<0, NumATensor, 1>{}(
+                    [&](auto j) { StrideAs[j] = gemm_descs[i].stride_As_[j]; });
 
-                static_for<0, NumBTensor, 1>{}([&](auto j) {
-                    if(gemm_descs[i].stride_Bs_.size() != NumBTensor)
-                    {
-                        throw std::runtime_error(
-                            "wrong! gemm_descs[i].stride_Bs_.size() does not match NumBTensor");
-                    }
-
-                    StrideBs[j] = gemm_descs[i].stride_Bs_[j];
-                });
+                if(gemm_descs[i].stride_Bs_.size() != NumBTensor)
+                {
+                    throw std::runtime_error(
+                        "wrong! gemm_descs[i].stride_Bs_.size() does not match NumBTensor");
+                }
+
+                static_for<0, NumBTensor, 1>{}(
+                    [&](auto j) { StrideBs[j] = gemm_descs[i].stride_Bs_[j]; });
 
-                static_for<0, NumDTensor, 1>{}([&](auto j) {
-                    if(gemm_descs[i].stride_Ds_.size() != NumDTensor)
-                    {
-                        throw std::runtime_error(
-                            "wrong! gemm_descs[i].stride_Ds_.size() does not match NumDTensor");
-                    }
-
-                    StrideDs[j] = gemm_descs[i].stride_Ds_[j];
-                });
+                if(gemm_descs[i].stride_Ds_.size() != NumDTensor)
+                {
+                    throw std::runtime_error(
+                        "wrong! gemm_descs[i].stride_Ds_.size() does not match NumDTensor");
+                }
+
+                static_for<0, NumDTensor, 1>{}(
+                    [&](auto j) { StrideDs[j] = gemm_descs[i].stride_Ds_[j]; });
 
                const auto e_grid_desc_m_n =
                    GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(
@@ -552,14 +541,6 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
                    throw std::runtime_error("wrong! block_2_etile_map validation failed");
                }
 
-                // if(!GridwiseGemm::
-                //    template CheckValidity<AsLayout, BsLayout, DsLayout, ELayout, GemmSpec>(
-                //        AverM, N, K, StrideA, StrideB, StrideDs, StrideE, 1))
-                //{
-                //    throw std::runtime_error(
-                //        "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
-                //}
-
                gemm_desc_kernel_arg_.push_back(GemmBiasTransKernelArg{
                    p_as_grid,
                    p_bs_grid,
...
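
The device-op cleanup moves each stride size check out of its static_for lambda: the check does not depend on the compile-time index j, so hoisting it runs the runtime comparison once per tensor family instead of once per unrolled iteration, and each lambda shrinks to a single assignment. The stale commented-out pointer setup and CheckValidity blocks are deleted outright. A self-contained sketch of the hoisting pattern, using a hypothetical recursive static_for stand-in rather than CK's own utility:

```cpp
#include <array>
#include <cstddef>
#include <stdexcept>
#include <type_traits>
#include <vector>

// Hypothetical stand-in for CK's compile-time loop; CK's real static_for
// lives in its utility headers and has a different interface.
template <std::size_t I, std::size_t N, typename F>
void static_for(F&& f)
{
    if constexpr(I < N)
    {
        f(std::integral_constant<std::size_t, I>{});
        static_for<I + 1, N>(f);
    }
}

constexpr std::size_t NumATensor = 2;

void copy_strides(const std::vector<int>& stride_As, std::array<int, NumATensor>& StrideAs)
{
    // Hoisted invariant check: independent of the loop index, so it runs
    // once instead of NumATensor times inside the unrolled loop.
    if(stride_As.size() != NumATensor)
    {
        throw std::runtime_error("wrong! stride_As.size() does not match NumATensor");
    }

    static_for<0, NumATensor>([&](auto j) { StrideAs[j] = stride_As[j]; });
}

int main()
{
    std::vector<int> strides{128, 128};
    std::array<int, NumATensor> StrideAs{};
    copy_strides(strides, StrideAs);
    return (StrideAs[0] == 128 && StrideAs[1] == 128) ? 0 : 1;
}
```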