Commit 426abafe authored by Jing Zhang's avatar Jing Zhang
Browse files

init desc

parent 88578483
add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp) add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp)
add_example_executable(example_grouped_gemm_transpose_xdl_fp16 grouped_gemm_transpose_xdl_fp16.cpp)
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include "check_err.hpp"
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_gemm.hpp"
#include "device_tensor.hpp"
#include "device_grouped_gemm_transpose_xdl.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float;
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// static constexpr auto GemmMNPadding =
// ck::tensor_operation::device::GemmSpecialization::MNPadding;
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmTransposeXdl
//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| Num|
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch|
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| |
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n");
exit(0);
}
int group_count = rand() % 16 + 1;
// GEMM shape
std::vector<ck::tensor_operation::device::GemmTransposeDesc> gemm_descs;
std::vector<const void*> p_a, p_b;
std::vector<void*> p_c;
gemm_descs.reserve(group_count);
for(int i = 0; i < group_count; i++)
{
int M = 1024;
int N = 1024;
int K = 1024;
gemm_descs.push_back({M, N, K, K, K, N});
}
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride}));
}
};
std::vector<Tensor<ADataType>> a_tensors;
std::vector<Tensor<BDataType>> b_tensors;
std::vector<Tensor<CDataType>> c_host_tensors;
std::vector<Tensor<CDataType>> c_device_tensors;
a_tensors.reserve(group_count);
b_tensors.reserve(group_count);
c_host_tensors.reserve(group_count);
c_device_tensors.reserve(group_count);
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, c_tensors_device;
a_tensors_device.reserve(group_count);
b_tensors_device.reserve(group_count);
c_tensors_device.reserve(group_count);
std::size_t flop = 0, num_btype = 0;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
gemm_descs[i].M, gemm_descs[i].K, gemm_descs[i].StrideA, ALayout{})));
b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
gemm_descs[i].K, gemm_descs[i].N, gemm_descs[i].StrideB, BLayout{})));
c_host_tensors.push_back(Tensor<CDataType>(f_host_tensor_descriptor(
gemm_descs[i].M, gemm_descs[i].N, gemm_descs[i].StrideC, CLayout{})));
c_device_tensors.push_back(Tensor<CDataType>(f_host_tensor_descriptor(
gemm_descs[i].M, gemm_descs[i].N, gemm_descs[i].StrideC, CLayout{})));
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
<< " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc
<< std::endl;
flop += std::size_t(2) * gemm_descs[i].M * gemm_descs[i].K * gemm_descs[i].N;
num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() +
sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() +
sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSize();
switch(init_method)
{
case 0: break;
case 1:
a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
case 2:
a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
default:
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
}
}
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
a_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(ADataType) * a_tensors[i].mDesc.GetElementSpace()));
b_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(BDataType) * b_tensors[i].mDesc.GetElementSpace()));
c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSpace()));
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
p_b.push_back(b_tensors_device[i]->GetDeviceBuffer());
p_c.push_back(c_tensors_device[i]->GetDeviceBuffer());
}
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
// do GEMM
auto argument =
gemm.MakeArgument(p_a, p_b, p_c, gemm_descs, a_element_op, b_element_op, c_element_op);
DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(&argument));
gemm.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer());
if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
bool pass = true;
if(do_verification)
{
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data());
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_tensors[i],
b_tensors[i],
c_host_tensors[i],
a_element_op,
b_element_op,
c_element_op);
ref_invoker.Run(ref_argument);
pass &= ck::utils::check_err(c_device_tensors[i].mData, c_host_tensors[i].mData);
}
}
return pass ? 0 : 1;
}
...@@ -81,11 +81,11 @@ int main(int argc, char* argv[]) ...@@ -81,11 +81,11 @@ int main(int argc, char* argv[])
int group_count = rand() % 16 + 1; int group_count = rand() % 16 + 1;
// GEMM shape // GEMM shape
std::vector<ck::tensor_operation::device::GemmShape> gemm_shapes; std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
std::vector<const void*> p_a, p_b; std::vector<const void*> p_a, p_b;
std::vector<void*> p_c; std::vector<void*> p_c;
gemm_shapes.reserve(group_count); gemm_descs.reserve(group_count);
for(int i = 0; i < group_count; i++) for(int i = 0; i < group_count; i++)
{ {
...@@ -93,7 +93,7 @@ int main(int argc, char* argv[]) ...@@ -93,7 +93,7 @@ int main(int argc, char* argv[])
int N = 128 + 128 * i; int N = 128 + 128 * i;
int K = 64 + 64 * i; int K = 64 + 64 * i;
gemm_shapes.push_back({M, N, K, K, K, N}); gemm_descs.push_back({M, N, K, K, K, N});
} }
auto f_host_tensor_descriptor = auto f_host_tensor_descriptor =
...@@ -111,7 +111,6 @@ int main(int argc, char* argv[]) ...@@ -111,7 +111,6 @@ int main(int argc, char* argv[])
}; };
std::vector<Tensor<ADataType>> a_tensors; std::vector<Tensor<ADataType>> a_tensors;
;
std::vector<Tensor<BDataType>> b_tensors; std::vector<Tensor<BDataType>> b_tensors;
std::vector<Tensor<CDataType>> c_host_tensors; std::vector<Tensor<CDataType>> c_host_tensors;
std::vector<Tensor<CDataType>> c_device_tensors; std::vector<Tensor<CDataType>> c_device_tensors;
...@@ -131,22 +130,22 @@ int main(int argc, char* argv[]) ...@@ -131,22 +130,22 @@ int main(int argc, char* argv[])
std::size_t flop = 0, num_btype = 0; std::size_t flop = 0, num_btype = 0;
for(std::size_t i = 0; i < gemm_shapes.size(); i++) for(std::size_t i = 0; i < gemm_descs.size(); i++)
{ {
a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor( a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
gemm_shapes[i].M, gemm_shapes[i].K, gemm_shapes[i].StrideA, ALayout{}))); gemm_descs[i].M, gemm_descs[i].K, gemm_descs[i].StrideA, ALayout{})));
b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor( b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
gemm_shapes[i].K, gemm_shapes[i].N, gemm_shapes[i].StrideB, BLayout{}))); gemm_descs[i].K, gemm_descs[i].N, gemm_descs[i].StrideB, BLayout{})));
c_host_tensors.push_back(Tensor<CDataType>(f_host_tensor_descriptor( c_host_tensors.push_back(Tensor<CDataType>(f_host_tensor_descriptor(
gemm_shapes[i].M, gemm_shapes[i].N, gemm_shapes[i].StrideC, CLayout{}))); gemm_descs[i].M, gemm_descs[i].N, gemm_descs[i].StrideC, CLayout{})));
c_device_tensors.push_back(Tensor<CDataType>(f_host_tensor_descriptor( c_device_tensors.push_back(Tensor<CDataType>(f_host_tensor_descriptor(
gemm_shapes[i].M, gemm_shapes[i].N, gemm_shapes[i].StrideC, CLayout{}))); gemm_descs[i].M, gemm_descs[i].N, gemm_descs[i].StrideC, CLayout{})));
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
<< " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc << " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc
<< std::endl; << std::endl;
flop += std::size_t(2) * gemm_shapes[i].M * gemm_shapes[i].K * gemm_shapes[i].N; flop += std::size_t(2) * gemm_descs[i].M * gemm_descs[i].K * gemm_descs[i].N;
num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() + num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() +
sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() + sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() +
sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSize(); sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSize();
...@@ -168,7 +167,7 @@ int main(int argc, char* argv[]) ...@@ -168,7 +167,7 @@ int main(int argc, char* argv[])
} }
} }
for(std::size_t i = 0; i < gemm_shapes.size(); i++) for(std::size_t i = 0; i < gemm_descs.size(); i++)
{ {
a_tensors_device.emplace_back( a_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(ADataType) * a_tensors[i].mDesc.GetElementSpace())); std::make_unique<DeviceMem>(sizeof(ADataType) * a_tensors[i].mDesc.GetElementSpace()));
...@@ -194,7 +193,7 @@ int main(int argc, char* argv[]) ...@@ -194,7 +193,7 @@ int main(int argc, char* argv[])
// do GEMM // do GEMM
auto argument = auto argument =
gemm.MakeArgument(p_a, p_b, p_c, gemm_shapes, a_element_op, b_element_op, c_element_op); gemm.MakeArgument(p_a, p_b, p_c, gemm_descs, a_element_op, b_element_op, c_element_op);
DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(&argument)); DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(&argument));
...@@ -219,7 +218,7 @@ int main(int argc, char* argv[]) ...@@ -219,7 +218,7 @@ int main(int argc, char* argv[])
bool pass = true; bool pass = true;
if(do_verification) if(do_verification)
{ {
for(std::size_t i = 0; i < gemm_shapes.size(); i++) for(std::size_t i = 0; i < gemm_descs.size(); i++)
{ {
c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data()); c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data());
auto ref_gemm = ReferenceGemmInstance{}; auto ref_gemm = ReferenceGemmInstance{};
......
...@@ -8,12 +8,6 @@ namespace ck { ...@@ -8,12 +8,6 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
struct GemmShape
{
ck::index_t M, N, K;
ck::index_t StrideA, StrideB, StrideC;
};
template <typename AElementwiseOperation, template <typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation> typename CElementwiseOperation>
...@@ -42,29 +36,6 @@ template <typename AElementwiseOperation, ...@@ -42,29 +36,6 @@ template <typename AElementwiseOperation,
using DeviceGemmPtr = std::unique_ptr< using DeviceGemmPtr = std::unique_ptr<
DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>; DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGroupedGemm : public BaseOperator
{
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<const void*>& p_a,
std::vector<const void*>& p_b,
std::vector<void*>& p_c,
std::vector<GemmShape>& gemm_shapes,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
ck::index_t KBatch = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
using DeviceGroupedGemmPtr = std::unique_ptr<
DeviceGroupedGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma once
#include <iostream>
#include <vector>
#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
struct GemmDesc
{
ck::index_t M, N, K;
ck::index_t StrideA, StrideB, StrideC;
};
struct GemmTransposeDesc
{
ck::index_t M, N, K;;
ck::index_t StrideA, StrideB, StrideC;
ck::index_t B, S, NumHead, HeadDim;
std::vector<ck::index_t> transpose;
};
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGroupedGemm : public BaseOperator
{
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<const void*>& p_a,
std::vector<const void*>& p_b,
std::vector<void*>& p_c,
std::vector<GemmDesc>& gemm_desc,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
ck::index_t KBatch = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
using DeviceGroupedGemmPtr = std::unique_ptr<
DeviceGroupedGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGroupedGemmTranspose : public BaseOperator
{
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<const void*>& p_a,
std::vector<const void*>& p_b,
std::vector<void*>& p_c,
std::vector<GemmTransposeDesc>& gemm_transpose_desc,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
ck::index_t KBatch = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
using DeviceGroupedGemmTransposePtr = std::unique_ptr<
DeviceGroupedGemmTranspose<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include <sstream> #include <sstream>
#include "device.hpp" #include "device.hpp"
#include "device_base.hpp" #include "device_base.hpp"
#include "device_gemm.hpp" #include "device_grouped_gemm.hpp"
#include "common_header.hpp" #include "common_header.hpp"
#include "tensor_layout.hpp" #include "tensor_layout.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
...@@ -349,7 +349,7 @@ struct DeviceGroupedGemmXdl ...@@ -349,7 +349,7 @@ struct DeviceGroupedGemmXdl
Argument(std::vector<const void*>& p_a, Argument(std::vector<const void*>& p_a,
std::vector<const void*>& p_b, std::vector<const void*>& p_b,
std::vector<void*>& p_c, std::vector<void*>& p_c,
std::vector<GemmShape>& gemm_shapes, std::vector<GemmDesc>& gemm_descs,
index_t M01, index_t M01,
index_t N01, index_t N01,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
...@@ -365,7 +365,7 @@ struct DeviceGroupedGemmXdl ...@@ -365,7 +365,7 @@ struct DeviceGroupedGemmXdl
gemm_descs_args_workspace_ = nullptr; gemm_descs_args_workspace_ = nullptr;
group_count_ = ck::type_convert<ck::index_t>(gemm_shapes.size()); group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());
if(!(group_count_ == ck::type_convert<ck::index_t>(p_a.size()) && if(!(group_count_ == ck::type_convert<ck::index_t>(p_a.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_b.size()) && group_count_ == ck::type_convert<ck::index_t>(p_b.size()) &&
...@@ -376,15 +376,15 @@ struct DeviceGroupedGemmXdl ...@@ -376,15 +376,15 @@ struct DeviceGroupedGemmXdl
gemm_desc_kernel_arg_.reserve(group_count_); gemm_desc_kernel_arg_.reserve(group_count_);
for(std::size_t i = 0; i < gemm_shapes.size(); i++) for(std::size_t i = 0; i < gemm_descs.size(); i++)
{ {
const index_t M = gemm_shapes[i].M; const index_t M = gemm_descs[i].M;
const index_t N = gemm_shapes[i].N; const index_t N = gemm_descs[i].N;
const index_t K = gemm_shapes[i].K; const index_t K = gemm_descs[i].K;
const index_t StrideA = gemm_shapes[i].StrideA; const index_t StrideA = gemm_descs[i].StrideA;
const index_t StrideB = gemm_shapes[i].StrideB; const index_t StrideB = gemm_descs[i].StrideB;
const index_t StrideC = gemm_shapes[i].StrideC; const index_t StrideC = gemm_descs[i].StrideC;
const auto a_grid_desc_k0_m_k1_ = const auto a_grid_desc_k0_m_k1_ =
DeviceGroupedGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); DeviceGroupedGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
...@@ -580,12 +580,12 @@ struct DeviceGroupedGemmXdl ...@@ -580,12 +580,12 @@ struct DeviceGroupedGemmXdl
static auto MakeArgument(std::vector<const void*>& p_a, static auto MakeArgument(std::vector<const void*>& p_a,
std::vector<const void*>& p_b, std::vector<const void*>& p_b,
std::vector<void*>& p_c, std::vector<void*>& p_c,
std::vector<GemmShape> gemm_shapes, std::vector<GemmDesc> gemm_descs,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) CElementwiseOperation c_element_op)
{ {
return Argument{p_a, p_b, p_c, gemm_shapes, 1, 1, a_element_op, b_element_op, c_element_op}; return Argument{p_a, p_b, p_c, gemm_descs, 1, 1, a_element_op, b_element_op, c_element_op};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
...@@ -594,14 +594,14 @@ struct DeviceGroupedGemmXdl ...@@ -594,14 +594,14 @@ struct DeviceGroupedGemmXdl
std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<const void*>& p_a, std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<const void*>& p_a,
std::vector<const void*>& p_b, std::vector<const void*>& p_b,
std::vector<void*>& p_c, std::vector<void*>& p_c,
std::vector<GemmShape>& gemm_shapes, std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
index_t /* KBatch */ = 1) override index_t /* KBatch */ = 1) override
{ {
return std::make_unique<Argument>( return std::make_unique<Argument>(
p_a, p_b, p_c, gemm_shapes, 1, 1, a_element_op, b_element_op, c_element_op); p_a, p_b, p_c, gemm_descs, 1, 1, a_element_op, b_element_op, c_element_op);
} }
// polymorphic // polymorphic
......
...@@ -52,11 +52,11 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr) ...@@ -52,11 +52,11 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
int group_count = rand() % 10 + 1; int group_count = rand() % 10 + 1;
// GEMM shape // GEMM shape
std::vector<ck::tensor_operation::device::GemmShape> gemm_shapes; std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
std::vector<const void*> p_a, p_b; std::vector<const void*> p_a, p_b;
std::vector<void*> p_c; std::vector<void*> p_c;
gemm_shapes.reserve(group_count); gemm_descs.reserve(group_count);
for(int i = 0; i < group_count; i++) for(int i = 0; i < group_count; i++)
{ {
...@@ -68,7 +68,7 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr) ...@@ -68,7 +68,7 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
int BStride = std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value ? N : K; int BStride = std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value ? N : K;
int CStride = std::is_same<ck::tensor_layout::gemm::RowMajor, CLayout>::value ? N : M; int CStride = std::is_same<ck::tensor_layout::gemm::RowMajor, CLayout>::value ? N : M;
gemm_shapes.push_back({M, N, K, AStride, BStride, CStride}); gemm_descs.push_back({M, N, K, AStride, BStride, CStride});
} }
auto f_host_tensor_descriptor = auto f_host_tensor_descriptor =
...@@ -104,22 +104,22 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr) ...@@ -104,22 +104,22 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
b_tensors_device.reserve(group_count); b_tensors_device.reserve(group_count);
c_tensors_device.reserve(group_count); c_tensors_device.reserve(group_count);
for(std::size_t i = 0; i < gemm_shapes.size(); i++) for(std::size_t i = 0; i < gemm_descs.size(); i++)
{ {
a_tensors.emplace_back(Tensor<ADataType>(f_host_tensor_descriptor( a_tensors.emplace_back(Tensor<ADataType>(f_host_tensor_descriptor(
gemm_shapes[i].M, gemm_shapes[i].K, gemm_shapes[i].StrideA, ALayout{}))); gemm_descs[i].M, gemm_descs[i].K, gemm_descs[i].StrideA, ALayout{})));
b_tensors.emplace_back(Tensor<BDataType>(f_host_tensor_descriptor( b_tensors.emplace_back(Tensor<BDataType>(f_host_tensor_descriptor(
gemm_shapes[i].K, gemm_shapes[i].N, gemm_shapes[i].StrideB, BLayout{}))); gemm_descs[i].K, gemm_descs[i].N, gemm_descs[i].StrideB, BLayout{})));
c_host_tensors.emplace_back(Tensor<CDataType>(f_host_tensor_descriptor( c_host_tensors.emplace_back(Tensor<CDataType>(f_host_tensor_descriptor(
gemm_shapes[i].M, gemm_shapes[i].N, gemm_shapes[i].StrideC, CLayout{}))); gemm_descs[i].M, gemm_descs[i].N, gemm_descs[i].StrideC, CLayout{})));
c_device_tensors.emplace_back(Tensor<CDataType>(f_host_tensor_descriptor( c_device_tensors.emplace_back(Tensor<CDataType>(f_host_tensor_descriptor(
gemm_shapes[i].M, gemm_shapes[i].N, gemm_shapes[i].StrideC, CLayout{}))); gemm_descs[i].M, gemm_descs[i].N, gemm_descs[i].StrideC, CLayout{})));
a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}); a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}); b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
} }
for(std::size_t i = 0; i < gemm_shapes.size(); i++) for(std::size_t i = 0; i < gemm_descs.size(); i++)
{ {
a_tensors_device.emplace_back( a_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize())); std::make_unique<DeviceMem>(sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize()));
...@@ -144,7 +144,7 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr) ...@@ -144,7 +144,7 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
auto invoker_ptr = groupedGemmPtr->MakeInvokerPointer(); auto invoker_ptr = groupedGemmPtr->MakeInvokerPointer();
auto argument_ptr = groupedGemmPtr->MakeArgumentPointer( auto argument_ptr = groupedGemmPtr->MakeArgumentPointer(
p_a, p_b, p_c, gemm_shapes, a_element_op, b_element_op, c_element_op); p_a, p_b, p_c, gemm_descs, a_element_op, b_element_op, c_element_op);
DeviceMem gemm_desc_workspace(groupedGemmPtr->GetWorkSpaceSize(argument_ptr.get())); DeviceMem gemm_desc_workspace(groupedGemmPtr->GetWorkSpaceSize(argument_ptr.get()));
...@@ -152,7 +152,7 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr) ...@@ -152,7 +152,7 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
invoker_ptr->Run(argument_ptr.get()); invoker_ptr->Run(argument_ptr.get());
for(std::size_t i = 0; i < gemm_shapes.size(); i++) for(std::size_t i = 0; i < gemm_descs.size(); i++)
{ {
c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data()); c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment