Commit d862fdf0 authored by ltqin's avatar ltqin
Browse files

add desiredgridsize parameter to ckProfiler

parent adc79bdd
#ifndef DEVICE_GEMM_XDL_INSTANCE
#define DEVICE_GEMM_XDL_INSTANCE
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
template <>
void add_device_splitk_gemm_instance<float,
float,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
std::vector<DeviceGemmNoOpPtr>&);
template <>
void add_device_splitk_gemm_instance<float,
float,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
std::vector<DeviceGemmNoOpPtr>&);
template <>
void add_device_splitk_gemm_instance<float,
float,
float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
std::vector<DeviceGemmNoOpPtr>&);
template <>
void add_device_splitk_gemm_instance<float,
float,
float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
std::vector<DeviceGemmNoOpPtr>&);
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
#pragma once
#include "device_gemm_instance.hpp"
#include "device_gemm_xdl_splitk_instance.hpp"
namespace ck {
namespace tensor_operation {
......@@ -93,7 +94,8 @@ void profile_gemm_impl(int do_verification,
int K,
int StrideA,
int StrideB,
int StrideC)
int StrideC,
int DesiredGridSize = 1)
{
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
......@@ -154,9 +156,18 @@ void profile_gemm_impl(int do_verification,
// add device GEMM instances
std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmNoOpPtr> gemm_ptrs;
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_instance<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>(
gemm_ptrs);
if(DesiredGridSize > 1 && is_same<ADataType, float>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_splitk_gemm_instance<float, float, float, ALayout, BLayout, CLayout>(
gemm_ptrs);
}
else
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_instance<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>(
gemm_ptrs);
}
if(gemm_ptrs.size() <= 0)
{
......@@ -183,7 +194,8 @@ void profile_gemm_impl(int do_verification,
StrideC,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{});
ck::tensor_operation::element_wise::PassThrough{},
DesiredGridSize);
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
......
......@@ -35,7 +35,7 @@ enum GemmDataType
int profile_gemm(int argc, char* argv[])
{
if(argc != 14)
if(!(argc == 14 || argc == 15))
{
printf("arg1: tensor operation (gemm: GEMM)\n");
printf("arg2: data type (0: fp32; 1: fp16)\n");
......@@ -48,6 +48,7 @@ int profile_gemm(int argc, char* argv[])
printf("arg8: print tensor value (0: no; 1: yes)\n");
printf("arg7: run kernel # of times (>1)\n");
printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n");
printf("arg14: desired grid size\n");
exit(1);
}
......@@ -62,9 +63,12 @@ int profile_gemm(int argc, char* argv[])
const int N = std::stoi(argv[9]);
const int K = std::stoi(argv[10]);
const int StrideA = std::stoi(argv[11]);
const int StrideB = std::stoi(argv[12]);
const int StrideC = std::stoi(argv[13]);
const int StrideA = std::stoi(argv[11]);
const int StrideB = std::stoi(argv[12]);
const int StrideC = std::stoi(argv[13]);
int DesiredGridSize = 1;
if(argc == 15)
DesiredGridSize = std::stoi(argv[14]);
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
......@@ -159,7 +163,8 @@ int profile_gemm(int argc, char* argv[])
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC);
(StrideC < 0) ? N : StrideC,
DesiredGridSize);
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
{
......@@ -178,7 +183,8 @@ int profile_gemm(int argc, char* argv[])
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC);
(StrideC < 0) ? N : StrideC,
DesiredGridSize);
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
{
......@@ -197,7 +203,8 @@ int profile_gemm(int argc, char* argv[])
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC);
(StrideC < 0) ? N : StrideC,
DesiredGridSize);
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
{
......@@ -216,7 +223,8 @@ int profile_gemm(int argc, char* argv[])
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC);
(StrideC < 0) ? N : StrideC,
DesiredGridSize);
}
else
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment