Commit 2baf0613 authored by Jing Zhang

clean code: add multiA into example

parent b164b0ef
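
The diff below extends the fixed-NK bias example to consume two A operands (A0 in fp16, A1 in fp32) that are fused element-wise by an AddScale functor, a = scale * (a0 + a1), before the GEMM. The following is a minimal, self-contained sketch of that fusion semantics only; it uses plain float stand-ins for the example's ck::half_t / float types and is an illustration, not part of the commit.

    #include <cstdio>

    // simplified host-side version of the AddScale element-wise op added below:
    // each output A element is scale * (a0 + a1)
    struct AddScale
    {
        void operator()(float& a, const float& a0, const float& a1) const
        {
            a = scale * (a0 + a1);
        }
        float scale = 1.0f;
    };

    int main()
    {
        AddScale op{/*scale=*/2.0f};
        float a = 0.0f;
        op(a, 1.5f, 0.5f);                      // a becomes 2 * (1.5 + 0.5) = 4
        std::printf("fused A element: %f\n", a); // prints 4.000000
        return 0;
    }
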
@@ -32,6 +32,4 @@ if(USE_BITINT_EXTENSION_INT4)
endif()
add_example_executable(example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16 grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16)
endif()
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16)
@@ -33,8 +33,9 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Add = ck::tensor_operation::element_wise::Add;
using A0DataType = F16;
using A1DataType = F32;
using AsDataType = ck::Tuple<A0DataType, A1DataType>;
using B0DataType = F16;
using AsDataType = ck::Tuple<A0DataType>;
using BsDataType = ck::Tuple<B0DataType>;
using AccDataType = F32;
using CShuffleDataType = F32;
@@ -43,14 +44,26 @@ using DsDataType = ck::Tuple<D0DataType>;
using EDataType = F32;
using A0Layout = Row;
using A1Layout = Row;
using AsLayout = ck::Tuple<A0Layout, A1Layout>;
using B0Layout = Col;
using AsLayout = ck::Tuple<A0Layout>;
using BsLayout = ck::Tuple<B0Layout>;
using D0Layout = Row;
using DsLayout = ck::Tuple<D0Layout>;
using ELayout = Row;
using AElementOp = PassThrough;
struct AddScale
{
__host__ __device__ constexpr void
operator()(ck::half_t& a, const ck::half_t& a0, const float& a1) const
{
a = scale * (a0 + a1);
}
float scale = 1.0;
};
using AElementOp = AddScale;
using BElementOp = PassThrough;
using CDEElementOp = Add;
@@ -113,13 +126,15 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
}
};
std::vector<Tensor<A0DataType>> a_tensors;
std::vector<Tensor<A0DataType>> a0_tensors;
std::vector<Tensor<A1DataType>> a1_tensors;
std::vector<Tensor<B0DataType>> b_tensors;
std::vector<Tensor<D0DataType>> d0_tensors;
std::vector<Tensor<EDataType>> c_host_tensors;
std::vector<Tensor<EDataType>> c_device_tensors;
a_tensors.reserve(group_count);
a0_tensors.reserve(group_count);
a1_tensors.reserve(group_count);
b_tensors.reserve(group_count);
d0_tensors.reserve(group_count);
c_host_tensors.reserve(group_count);
@@ -127,10 +142,11 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, d0_tensors_device,
c_tensors_device;
std::vector<DeviceMemPtr> a0_tensors_device, a1_tensors_device, b_tensors_device,
d0_tensors_device, c_tensors_device;
a_tensors_device.reserve(group_count);
a0_tensors_device.reserve(group_count);
a1_tensors_device.reserve(group_count);
b_tensors_device.reserve(group_count);
d0_tensors_device.reserve(group_count);
c_tensors_device.reserve(group_count);
@@ -140,8 +156,10 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
for(int i = 0; i < group_count; i++)
{
sum_of_m += problem_size.Ms[i];
a_tensors.push_back(Tensor<A0DataType>(f_host_tensor_descriptor(
a0_tensors.push_back(Tensor<A0DataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], A0Layout{})));
a1_tensors.push_back(Tensor<A1DataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], A1Layout{})));
b_tensors.push_back(Tensor<B0DataType>(f_host_tensor_descriptor(
problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], B0Layout{})));
d0_tensors.push_back(Tensor<D0DataType>(
@@ -150,12 +168,13 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
c_device_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
std::cout << "gemm[" << i << "] a_m_k: " << a0_tensors[i].mDesc
<< " b_k_n: " << b_tensors[i].mDesc << " d_m_n: " << d0_tensors[i].mDesc
<< " c_m_n: " << c_device_tensors[i].mDesc << std::endl;
flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i];
num_btype += sizeof(A0DataType) * a_tensors[i].mDesc.GetElementSize() +
num_btype += sizeof(A0DataType) * a0_tensors[i].mDesc.GetElementSize() +
sizeof(A1DataType) * a1_tensors[i].mDesc.GetElementSize() +
sizeof(B0DataType) * b_tensors[i].mDesc.GetElementSize() +
sizeof(D0DataType) * d0_tensors[i].mDesc.GetElementSize() +
sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSize();
@@ -164,15 +183,18 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
{
case 0: break;
case 1:
a_tensors[i].GenerateTensorValue(GeneratorTensor_2<A0DataType>{-5, 5});
a0_tensors[i].GenerateTensorValue(GeneratorTensor_2<A0DataType>{-5, 5});
a1_tensors[i].GenerateTensorValue(GeneratorTensor_2<A1DataType>{-5, 5});
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
break;
case 2:
a_tensors[i].GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
a0_tensors[i].GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
a1_tensors[i].GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
break;
default:
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
a1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
}
@@ -180,16 +202,19 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
}
using GroupedGemmKernelArgument =
ck::tensor_operation::device::GroupedGemmMultiABDKernelArgument<1, 1, 1>;
ck::tensor_operation::device::GroupedGemmMultiABDKernelArgument<2, 1, 1>;
std::vector<GroupedGemmKernelArgument> grouped_gemm_kernel_args_;
grouped_gemm_kernel_args_.reserve(group_count);
for(int i = 0; i < group_count; i++)
{
a_tensors_device.emplace_back(
a0_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(A0DataType) * sum_of_m * problem_size.Ks[i]));
a1_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(A1DataType) * sum_of_m * problem_size.Ks[i]));
b_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i]));
@@ -199,9 +224,13 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
c_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(EDataType) * sum_of_m * problem_size.Ns[i]));
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data(),
a_tensors[i].mDesc.GetElementSpaceSize() *
a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data(),
a0_tensors[i].mDesc.GetElementSpaceSize() *
sizeof(A0DataType));
a1_tensors_device[i]->ToDevice(a1_tensors[i].mData.data(),
a1_tensors[i].mDesc.GetElementSpaceSize() *
sizeof(A1DataType));
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data(),
b_tensors[i].mDesc.GetElementSpaceSize() *
sizeof(B0DataType));
@@ -211,20 +240,21 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
gemm_descs.push_back({sum_of_m,
problem_size.Ns[i],
problem_size.Ks[i],
{1},
{1, 1},
{problem_size.stride_Bs[i]},
{0},
1});
grouped_gemm_kernel_args_.push_back(
{std::array<const void*, 1>{a_tensors_device[i]->GetDeviceBuffer()},
{std::array<const void*, 2>{a0_tensors_device[i]->GetDeviceBuffer(),
a1_tensors_device[i]->GetDeviceBuffer()},
std::array<const void*, 1>{b_tensors_device[i]->GetDeviceBuffer()},
std::array<const void*, 1>{d0_tensors_device[i]->GetDeviceBuffer()},
c_tensors_device[i]->GetDeviceBuffer(),
problem_size.Ms[i],
problem_size.Ns[i],
problem_size.Ks[i],
std::array<ck::index_t, 1>{problem_size.stride_As[i]},
std::array<ck::index_t, 2>{problem_size.stride_As[i], problem_size.stride_As[i]},
std::array<ck::index_t, 1>{problem_size.stride_Bs[i]},
std::array<ck::index_t, 1>{0},
problem_size.stride_Cs[i]});
@@ -237,7 +267,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
std::vector<std::array<const void*, 1>> p_As = {};
std::vector<std::array<const void*, 2>> p_As = {};
std::vector<std::array<const void*, 1>> p_Bs = {};
std::vector<std::array<const void*, 1>> p_Ds = {};
std::vector<void*> p_Cs = {};
@@ -281,16 +311,25 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
bool pass = true;
if(config.do_verification)
{
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType,
B0DataType,
EDataType,
AccDataType,
AElementOp,
PassThrough,
BElementOp,
PassThrough>;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
for(int m = 0; m < problem_size.Ms[i]; ++m)
{
for(int k = 0; k < problem_size.Ks[i]; ++k)
{
a_element_op(a0_tensors[i](m, k), a0_tensors[i](m, k), a1_tensors[i](m, k));
}
}
c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(),
c_device_tensors[i].mDesc.GetElementSize() *
sizeof(EDataType));
@@ -298,10 +337,10 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_tensors[i],
auto ref_argument = ref_gemm.MakeArgument(a0_tensors[i],
b_tensors[i],
c_host_tensors[i],
a_element_op,
PassThrough{},
b_element_op,
PassThrough{});
......
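
For verification, the example applies the same A element-wise op on the host to fold a1 into a0 and then runs an ordinary single-A reference GEMM with PassThrough element-wise ops, as shown in the hunks above. Below is a simplified sketch of that check, assuming plain float tensors in row-major layout and a naive loop GEMM standing in for ck::tensor_operation::host::ReferenceGemm; all names here are illustrative.

    #include <vector>

    struct AddScale
    {
        void operator()(float& a, const float& a0, const float& a1) const
        {
            a = scale * (a0 + a1);
        }
        float scale = 1.0f;
    };

    // naive stand-in for the ck host reference GEMM (row-major A, B, C)
    void reference_gemm(const std::vector<float>& a,
                        const std::vector<float>& b,
                        std::vector<float>& c,
                        int M, int N, int K)
    {
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
            {
                float acc = 0.f;
                for(int k = 0; k < K; ++k)
                    acc += a[m * K + k] * b[k * N + n];
                c[m * N + n] = acc;
            }
    }

    // host-side check: fold a1 into a0 with the A element-wise op, then run the
    // plain single-A reference GEMM; the device output should match c_host
    void verify_group(std::vector<float>& a0,
                      const std::vector<float>& a1,
                      const std::vector<float>& b,
                      std::vector<float>& c_host,
                      int M, int N, int K,
                      AddScale a_element_op)
    {
        for(int m = 0; m < M; ++m)
            for(int k = 0; k < K; ++k)
                a_element_op(a0[m * K + k], a0[m * K + k], a1[m * K + k]);

        reference_gemm(a0, b, c_host, M, N, K);
    }
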
@@ -93,14 +93,6 @@ __global__ void
typename GridwiseGemm::BsGridPointer p_bs_grid_;
typename GridwiseGemm::DsGridPointer p_ds_grid_;
// constexpr auto I0 = Number<0>{};
// using AsDataType = remove_cvref_t<decltype(p_as_grid_(I0))>;
// p_as_grid_(I0) = static_cast<AsDataType>(gemm_desc_ptr[group_id].p_a_grid);
// using BsDataType = remove_cvref_t<decltype(p_bs_grid_(I0))>;
// p_bs_grid_(I0) = static_cast<BsDataType>(gemm_desc_ptr[group_id].p_b_grid);
static_for<0, NumATensor, 1>{}([&](auto i) {
using ADataType = remove_cvref_t<decltype(p_as_grid_(i))>;
p_as_grid_(i) = static_cast<ADataType>(gemm_desc_ptr[group_id].p_as_grid[i]);
@@ -500,35 +492,32 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
const index_t StrideE = gemm_descs[i].stride_C_;
static_for<0, NumATensor, 1>{}([&](auto j) {
if(gemm_descs[i].stride_As_.size() != NumATensor)
{
throw std::runtime_error(
"wrong! gemm_descs[i].stride_As_.size() does not match NumATensor");
}
StrideAs[j] = gemm_descs[i].stride_As_[j];
});
static_for<0, NumATensor, 1>{}(
[&](auto j) { StrideAs[j] = gemm_descs[i].stride_As_[j]; });
static_for<0, NumBTensor, 1>{}([&](auto j) {
if(gemm_descs[i].stride_Bs_.size() != NumBTensor)
{
throw std::runtime_error(
"wrong! gemm_descs[i].stride_Bs_.size() does not match NumBTensor");
}
StrideBs[j] = gemm_descs[i].stride_Bs_[j];
});
static_for<0, NumBTensor, 1>{}(
[&](auto j) { StrideBs[j] = gemm_descs[i].stride_Bs_[j]; });
static_for<0, NumDTensor, 1>{}([&](auto j) {
if(gemm_descs[i].stride_Ds_.size() != NumDTensor)
{
throw std::runtime_error(
"wrong! gemm_descs[i].stride_Ds_.size() does not match NumDTensor");
}
StrideDs[j] = gemm_descs[i].stride_Ds_[j];
});
static_for<0, NumDTensor, 1>{}(
[&](auto j) { StrideDs[j] = gemm_descs[i].stride_Ds_[j]; });
const auto e_grid_desc_m_n =
GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(
@@ -552,14 +541,6 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
throw std::runtime_error("wrong! block_2_etile_map validation failed");
}
// if(!GridwiseGemm::
// template CheckValidity<AsLayout, BsLayout, DsLayout, ELayout, GemmSpec>(
// AverM, N, K, StrideA, StrideB, StrideDs, StrideE, 1))
//{
// throw std::runtime_error(
//"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
//}
gemm_desc_kernel_arg_.push_back(GemmBiasTransKernelArg{
p_as_grid,
p_bs_grid,
......
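
In the device op, the kernel entry point assigns each per-group A/B grid pointer from a type-erased pointer array via ck::static_for, as in the p_as_grid_ hunk above. The snippet below sketches that pattern in plain C++17, with std::tuple and std::index_sequence standing in for ck::Tuple and ck::static_for; the helper name and the placeholder fp16 type are assumptions for illustration.

    #include <array>
    #include <cstddef>
    #include <tuple>
    #include <utility>

    using half_t = unsigned short; // placeholder for the device fp16 type

    // cast each void* in the erased array back to the statically known pointer
    // type of the matching tuple slot (one slot per A tensor)
    template <typename PtrTuple, std::size_t... Is>
    void assign_grid_pointers(PtrTuple& typed_ptrs,
                              const std::array<const void*, sizeof...(Is)>& erased_ptrs,
                              std::index_sequence<Is...>)
    {
        ((std::get<Is>(typed_ptrs) =
              static_cast<std::tuple_element_t<Is, PtrTuple>>(erased_ptrs[Is])),
         ...);
    }

    int main()
    {
        // two A tensors with different element types, as in the example (fp16 + fp32)
        std::tuple<const half_t*, const float*> p_as_grid{};
        std::array<const void*, 2> p_as_from_args{nullptr, nullptr};

        assign_grid_pointers(p_as_grid, p_as_from_args, std::make_index_sequence<2>{});
        return 0;
    }
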