Commit 62a860a5 authored by ltqin's avatar ltqin
Browse files

change desired gride size to kbatch

parent accb4ca5
......@@ -13,20 +13,19 @@ template <typename AElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGemm : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
ck::index_t desired_gride_size = 1) = 0;
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
ck::index_t KBatch = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
......
......@@ -144,13 +144,11 @@ struct DeviceGemmSplitKXdl
}
}
static auto GetKBatchAndKPad(index_t M, index_t N, index_t K, index_t DesiredGridSize)
static auto GetKPad(index_t K, index_t KBatch)
{
const auto GridMN = M * N / (MPerBlock * NPerBlock);
const index_t KBatch = std::max(DesiredGridSize / GridMN, 1);
const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
const index_t KPad = KBatch * K0 * K1;
return std::make_tuple(KBatch, KPad);
const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
const index_t KPad = KBatch * K0 * K1;
return KPad;
}
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_KBatch_K0_M_K1(1, 1, 1, 1, 1));
......@@ -262,7 +260,7 @@ struct DeviceGemmSplitKXdl
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
index_t desired_grid_size)
index_t k_batch)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
......@@ -276,16 +274,14 @@ struct DeviceGemmSplitKXdl
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op},
desired_grid_size_{desired_grid_size}
k_batch_{k_batch}
{
int KBatch = 1, KPad = K;
std::tie(KBatch, KPad) =
DeviceGemmSplitKXdl::GetKBatchAndKPad(M, N, K, desired_grid_size_);
int KPad = DeviceGemmSplitKXdl::GetKPad(K, k_batch_);
a_grid_desc_kbatch_k0_m_k1_ = DeviceGemmSplitKXdl::MakeAGridDescriptor_KBatch_K0_M_K1(
M, K, StrideA, KBatch, KPad);
M, K, StrideA, k_batch_, KPad);
b_grid_desc_kbatch_k0_n_k1_ = DeviceGemmSplitKXdl::MakeBGridDescriptor_KBatch_K0_N_K1(
K, N, StrideB, KBatch, KPad);
K, N, StrideB, k_batch_, KPad);
c_grid_desc_m_n_ = DeviceGemmSplitKXdl::MakeCGridDescriptor_M_N(M, N, StrideC);
if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_,
......@@ -298,7 +294,7 @@ struct DeviceGemmSplitKXdl
GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_grid_desc_m_n_);
block_2_ctile_map_ =
GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, KBatch);
GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
}
}
......@@ -316,7 +312,7 @@ struct DeviceGemmSplitKXdl
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
index_t desired_grid_size_;
index_t k_batch_;
};
// Invoker
......@@ -526,7 +522,7 @@ struct DeviceGemmSplitKXdl
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
index_t desired_grid_Size)
index_t KBatch)
{
return Argument{p_a,
p_b,
......@@ -542,7 +538,7 @@ struct DeviceGemmSplitKXdl
a_element_op,
b_element_op,
c_element_op,
desired_grid_Size};
KBatch};
}
static auto MakeInvoker() { return Invoker{}; }
......@@ -560,7 +556,7 @@ struct DeviceGemmSplitKXdl
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
ck::index_t desired_gride_size = 1) override
ck::index_t KBatch = 1) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
......@@ -576,7 +572,7 @@ struct DeviceGemmSplitKXdl
a_element_op,
b_element_op,
c_element_op,
desired_gride_size);
KBatch);
}
// polymorphic
......
#pragma once
#include "device_gemm_instance.hpp"
#include "device_gemm_xdl_splitk_instance.hpp"
#include "device_gemm_splitk_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
......@@ -95,7 +95,7 @@ void profile_gemm_impl(int do_verification,
int StrideA,
int StrideB,
int StrideC,
int DesiredGridSize = 1)
int KBatch = 1)
{
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
......@@ -156,7 +156,7 @@ void profile_gemm_impl(int do_verification,
// add device GEMM instances
std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmNoOpPtr> gemm_ptrs;
if(DesiredGridSize > 1 && is_same<ADataType, float>::value)
if(KBatch > 1 && is_same<ADataType, float>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_splitk_gemm_instance<float, float, float, ALayout, BLayout, CLayout>(
......@@ -195,7 +195,7 @@ void profile_gemm_impl(int do_verification,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
DesiredGridSize);
KBatch);
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
......
......@@ -48,7 +48,7 @@ int profile_gemm(int argc, char* argv[])
printf("arg8: print tensor value (0: no; 1: yes)\n");
printf("arg7: run kernel # of times (>1)\n");
printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n");
printf("arg14: desired grid size\n");
printf("arg14: split k into mulitiple batch\n");
exit(1);
}
......@@ -66,9 +66,9 @@ int profile_gemm(int argc, char* argv[])
const int StrideA = std::stoi(argv[11]);
const int StrideB = std::stoi(argv[12]);
const int StrideC = std::stoi(argv[13]);
int DesiredGridSize = 1;
int KBatch = 1;
if(argc == 15)
DesiredGridSize = std::stoi(argv[14]);
KBatch = std::stoi(argv[14]);
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
......@@ -164,7 +164,7 @@ int profile_gemm(int argc, char* argv[])
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC,
DesiredGridSize);
KBatch);
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
{
......@@ -184,7 +184,7 @@ int profile_gemm(int argc, char* argv[])
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC,
DesiredGridSize);
KBatch);
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
{
......@@ -204,7 +204,7 @@ int profile_gemm(int argc, char* argv[])
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC,
DesiredGridSize);
KBatch);
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
{
......@@ -224,7 +224,7 @@ int profile_gemm(int argc, char* argv[])
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC,
DesiredGridSize);
KBatch);
}
else
{
......
......@@ -11,7 +11,7 @@
#include "device_gemm_instance.hpp"
#include "host_gemm.hpp"
#include "tensor_layout.hpp"
#include "device_gemm_xdl_instance.hpp"
#include "device_gemm_splitk_xdl_instance.hpp"
#include "device_gemm_splitk_xdl.hpp"
enum GemmMatrixLayout
......@@ -112,7 +112,7 @@ int main(int argc, char* argv[])
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
printf(" 2: A[k, n] * B[k, n] = C[m, n];\n");
printf(" 3: A[k, n] * B[n, k] = C[m, n])\n");
printf("arg2 to 7: M, N, K, StrideA, StrideB, StrideC DesiredGridSize\n");
printf("arg2 to 7: M, N, K, StrideA, StrideB, StrideC KBatch\n");
return 1;
}
......@@ -122,10 +122,10 @@ int main(int argc, char* argv[])
const int N = std::stoi(argv[3]);
const int K = std::stoi(argv[4]);
const int StrideA = std::stoi(argv[5]);
const int StrideB = std::stoi(argv[6]);
const int StrideC = std::stoi(argv[7]);
const int DesiredGridSize = std::stoi(argv[8]);
const int StrideA = std::stoi(argv[5]);
const int StrideB = std::stoi(argv[6]);
const int StrideC = std::stoi(argv[7]);
const int KBatch = std::stoi(argv[8]);
if(layout > 3 || layout < 0)
{
......@@ -194,7 +194,7 @@ int main(int argc, char* argv[])
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
DesiredGridSize);
KBatch);
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment