Commit d862fdf0 authored by ltqin's avatar ltqin
Browse files

add desiredgridsize parameter to ckProfiler

parent adc79bdd
#ifndef DEVICE_GEMM_XDL_INSTANCE
#define DEVICE_GEMM_XDL_INSTANCE
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
template <>
void add_device_splitk_gemm_instance<float,
float,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
std::vector<DeviceGemmNoOpPtr>&);
template <>
void add_device_splitk_gemm_instance<float,
float,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
std::vector<DeviceGemmNoOpPtr>&);
template <>
void add_device_splitk_gemm_instance<float,
float,
float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
std::vector<DeviceGemmNoOpPtr>&);
template <>
void add_device_splitk_gemm_instance<float,
float,
float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
std::vector<DeviceGemmNoOpPtr>&);
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
#pragma once #pragma once
#include "device_gemm_instance.hpp" #include "device_gemm_instance.hpp"
#include "device_gemm_xdl_splitk_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -93,7 +94,8 @@ void profile_gemm_impl(int do_verification, ...@@ -93,7 +94,8 @@ void profile_gemm_impl(int do_verification,
int K, int K,
int StrideA, int StrideA,
int StrideB, int StrideB,
int StrideC) int StrideC,
int DesiredGridSize = 1)
{ {
auto f_host_tensor_descriptor = auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) { [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
...@@ -154,9 +156,18 @@ void profile_gemm_impl(int do_verification, ...@@ -154,9 +156,18 @@ void profile_gemm_impl(int do_verification,
// add device GEMM instances // add device GEMM instances
std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmNoOpPtr> gemm_ptrs; std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmNoOpPtr> gemm_ptrs;
ck::tensor_operation::device::device_gemm_instance:: if(DesiredGridSize > 1 && is_same<ADataType, float>::value)
add_device_gemm_instance<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>( {
gemm_ptrs); ck::tensor_operation::device::device_gemm_instance::
add_device_splitk_gemm_instance<float, float, float, ALayout, BLayout, CLayout>(
gemm_ptrs);
}
else
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_instance<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>(
gemm_ptrs);
}
if(gemm_ptrs.size() <= 0) if(gemm_ptrs.size() <= 0)
{ {
...@@ -183,7 +194,8 @@ void profile_gemm_impl(int do_verification, ...@@ -183,7 +194,8 @@ void profile_gemm_impl(int do_verification,
StrideC, StrideC,
ck::tensor_operation::element_wise::PassThrough{}, ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{}, ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{}); ck::tensor_operation::element_wise::PassThrough{},
DesiredGridSize);
auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
......
...@@ -35,7 +35,7 @@ enum GemmDataType ...@@ -35,7 +35,7 @@ enum GemmDataType
int profile_gemm(int argc, char* argv[]) int profile_gemm(int argc, char* argv[])
{ {
if(argc != 14) if(!(argc == 14 || argc == 15))
{ {
printf("arg1: tensor operation (gemm: GEMM)\n"); printf("arg1: tensor operation (gemm: GEMM)\n");
printf("arg2: data type (0: fp32; 1: fp16)\n"); printf("arg2: data type (0: fp32; 1: fp16)\n");
...@@ -48,6 +48,7 @@ int profile_gemm(int argc, char* argv[]) ...@@ -48,6 +48,7 @@ int profile_gemm(int argc, char* argv[])
printf("arg8: print tensor value (0: no; 1: yes)\n"); printf("arg8: print tensor value (0: no; 1: yes)\n");
printf("arg7: run kernel # of times (>1)\n"); printf("arg7: run kernel # of times (>1)\n");
printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n"); printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n");
printf("arg14: desired grid size\n");
exit(1); exit(1);
} }
...@@ -62,9 +63,12 @@ int profile_gemm(int argc, char* argv[]) ...@@ -62,9 +63,12 @@ int profile_gemm(int argc, char* argv[])
const int N = std::stoi(argv[9]); const int N = std::stoi(argv[9]);
const int K = std::stoi(argv[10]); const int K = std::stoi(argv[10]);
const int StrideA = std::stoi(argv[11]); const int StrideA = std::stoi(argv[11]);
const int StrideB = std::stoi(argv[12]); const int StrideB = std::stoi(argv[12]);
const int StrideC = std::stoi(argv[13]); const int StrideC = std::stoi(argv[13]);
int DesiredGridSize = 1;
if(argc == 15)
DesiredGridSize = std::stoi(argv[14]);
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
...@@ -159,7 +163,8 @@ int profile_gemm(int argc, char* argv[]) ...@@ -159,7 +163,8 @@ int profile_gemm(int argc, char* argv[])
K, K,
(StrideA < 0) ? K : StrideA, (StrideA < 0) ? K : StrideA,
(StrideB < 0) ? N : StrideB, (StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC); (StrideC < 0) ? N : StrideC,
DesiredGridSize);
} }
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN) else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
{ {
...@@ -178,7 +183,8 @@ int profile_gemm(int argc, char* argv[]) ...@@ -178,7 +183,8 @@ int profile_gemm(int argc, char* argv[])
K, K,
(StrideA < 0) ? K : StrideA, (StrideA < 0) ? K : StrideA,
(StrideB < 0) ? K : StrideB, (StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC); (StrideC < 0) ? N : StrideC,
DesiredGridSize);
} }
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN) else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
{ {
...@@ -197,7 +203,8 @@ int profile_gemm(int argc, char* argv[]) ...@@ -197,7 +203,8 @@ int profile_gemm(int argc, char* argv[])
K, K,
(StrideA < 0) ? M : StrideA, (StrideA < 0) ? M : StrideA,
(StrideB < 0) ? N : StrideB, (StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC); (StrideC < 0) ? N : StrideC,
DesiredGridSize);
} }
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN) else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
{ {
...@@ -216,7 +223,8 @@ int profile_gemm(int argc, char* argv[]) ...@@ -216,7 +223,8 @@ int profile_gemm(int argc, char* argv[])
K, K,
(StrideA < 0) ? M : StrideA, (StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB, (StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC); (StrideC < 0) ? N : StrideC,
DesiredGridSize);
} }
else else
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment