Commit 62a860a5 authored by ltqin's avatar ltqin
Browse files

change desired gride size to kbatch

parent accb4ca5
...@@ -13,8 +13,7 @@ template <typename AElementwiseOperation, ...@@ -13,8 +13,7 @@ template <typename AElementwiseOperation,
typename CElementwiseOperation> typename CElementwiseOperation>
struct DeviceGemm : public BaseOperator struct DeviceGemm : public BaseOperator
{ {
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
MakeArgumentPointer(const void* p_a,
const void* p_b, const void* p_b,
void* p_c, void* p_c,
ck::index_t M, ck::index_t M,
...@@ -26,7 +25,7 @@ struct DeviceGemm : public BaseOperator ...@@ -26,7 +25,7 @@ struct DeviceGemm : public BaseOperator
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
ck::index_t desired_gride_size = 1) = 0; ck::index_t KBatch = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
......
...@@ -144,13 +144,11 @@ struct DeviceGemmSplitKXdl ...@@ -144,13 +144,11 @@ struct DeviceGemmSplitKXdl
} }
} }
static auto GetKBatchAndKPad(index_t M, index_t N, index_t K, index_t DesiredGridSize) static auto GetKPad(index_t K, index_t KBatch)
{ {
const auto GridMN = M * N / (MPerBlock * NPerBlock);
const index_t KBatch = std::max(DesiredGridSize / GridMN, 1);
const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock; const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
const index_t KPad = KBatch * K0 * K1; const index_t KPad = KBatch * K0 * K1;
return std::make_tuple(KBatch, KPad); return KPad;
} }
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_KBatch_K0_M_K1(1, 1, 1, 1, 1)); using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_KBatch_K0_M_K1(1, 1, 1, 1, 1));
...@@ -262,7 +260,7 @@ struct DeviceGemmSplitKXdl ...@@ -262,7 +260,7 @@ struct DeviceGemmSplitKXdl
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
index_t desired_grid_size) index_t k_batch)
: p_a_grid_{p_a_grid}, : p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid}, p_c_grid_{p_c_grid},
...@@ -276,16 +274,14 @@ struct DeviceGemmSplitKXdl ...@@ -276,16 +274,14 @@ struct DeviceGemmSplitKXdl
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
c_element_op_{c_element_op}, c_element_op_{c_element_op},
desired_grid_size_{desired_grid_size} k_batch_{k_batch}
{ {
int KBatch = 1, KPad = K; int KPad = DeviceGemmSplitKXdl::GetKPad(K, k_batch_);
std::tie(KBatch, KPad) =
DeviceGemmSplitKXdl::GetKBatchAndKPad(M, N, K, desired_grid_size_);
a_grid_desc_kbatch_k0_m_k1_ = DeviceGemmSplitKXdl::MakeAGridDescriptor_KBatch_K0_M_K1( a_grid_desc_kbatch_k0_m_k1_ = DeviceGemmSplitKXdl::MakeAGridDescriptor_KBatch_K0_M_K1(
M, K, StrideA, KBatch, KPad); M, K, StrideA, k_batch_, KPad);
b_grid_desc_kbatch_k0_n_k1_ = DeviceGemmSplitKXdl::MakeBGridDescriptor_KBatch_K0_N_K1( b_grid_desc_kbatch_k0_n_k1_ = DeviceGemmSplitKXdl::MakeBGridDescriptor_KBatch_K0_N_K1(
K, N, StrideB, KBatch, KPad); K, N, StrideB, k_batch_, KPad);
c_grid_desc_m_n_ = DeviceGemmSplitKXdl::MakeCGridDescriptor_M_N(M, N, StrideC); c_grid_desc_m_n_ = DeviceGemmSplitKXdl::MakeCGridDescriptor_M_N(M, N, StrideC);
if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_, if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_,
...@@ -298,7 +294,7 @@ struct DeviceGemmSplitKXdl ...@@ -298,7 +294,7 @@ struct DeviceGemmSplitKXdl
GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_grid_desc_m_n_); GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_grid_desc_m_n_);
block_2_ctile_map_ = block_2_ctile_map_ =
GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, KBatch); GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
} }
} }
...@@ -316,7 +312,7 @@ struct DeviceGemmSplitKXdl ...@@ -316,7 +312,7 @@ struct DeviceGemmSplitKXdl
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_; CElementwiseOperation c_element_op_;
index_t desired_grid_size_; index_t k_batch_;
}; };
// Invoker // Invoker
...@@ -526,7 +522,7 @@ struct DeviceGemmSplitKXdl ...@@ -526,7 +522,7 @@ struct DeviceGemmSplitKXdl
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
index_t desired_grid_Size) index_t KBatch)
{ {
return Argument{p_a, return Argument{p_a,
p_b, p_b,
...@@ -542,7 +538,7 @@ struct DeviceGemmSplitKXdl ...@@ -542,7 +538,7 @@ struct DeviceGemmSplitKXdl
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
desired_grid_Size}; KBatch};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
...@@ -560,7 +556,7 @@ struct DeviceGemmSplitKXdl ...@@ -560,7 +556,7 @@ struct DeviceGemmSplitKXdl
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
ck::index_t desired_gride_size = 1) override ck::index_t KBatch = 1) override
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_b),
...@@ -576,7 +572,7 @@ struct DeviceGemmSplitKXdl ...@@ -576,7 +572,7 @@ struct DeviceGemmSplitKXdl
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
desired_gride_size); KBatch);
} }
// polymorphic // polymorphic
......
#pragma once #pragma once
#include "device_gemm_instance.hpp" #include "device_gemm_instance.hpp"
#include "device_gemm_xdl_splitk_instance.hpp" #include "device_gemm_splitk_xdl_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -95,7 +95,7 @@ void profile_gemm_impl(int do_verification, ...@@ -95,7 +95,7 @@ void profile_gemm_impl(int do_verification,
int StrideA, int StrideA,
int StrideB, int StrideB,
int StrideC, int StrideC,
int DesiredGridSize = 1) int KBatch = 1)
{ {
auto f_host_tensor_descriptor = auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) { [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
...@@ -156,7 +156,7 @@ void profile_gemm_impl(int do_verification, ...@@ -156,7 +156,7 @@ void profile_gemm_impl(int do_verification,
// add device GEMM instances // add device GEMM instances
std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmNoOpPtr> gemm_ptrs; std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmNoOpPtr> gemm_ptrs;
if(DesiredGridSize > 1 && is_same<ADataType, float>::value) if(KBatch > 1 && is_same<ADataType, float>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
add_device_splitk_gemm_instance<float, float, float, ALayout, BLayout, CLayout>( add_device_splitk_gemm_instance<float, float, float, ALayout, BLayout, CLayout>(
...@@ -195,7 +195,7 @@ void profile_gemm_impl(int do_verification, ...@@ -195,7 +195,7 @@ void profile_gemm_impl(int do_verification,
ck::tensor_operation::element_wise::PassThrough{}, ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{}, ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{}, ck::tensor_operation::element_wise::PassThrough{},
DesiredGridSize); KBatch);
auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
......
...@@ -48,7 +48,7 @@ int profile_gemm(int argc, char* argv[]) ...@@ -48,7 +48,7 @@ int profile_gemm(int argc, char* argv[])
printf("arg8: print tensor value (0: no; 1: yes)\n"); printf("arg8: print tensor value (0: no; 1: yes)\n");
printf("arg7: run kernel # of times (>1)\n"); printf("arg7: run kernel # of times (>1)\n");
printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n"); printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n");
printf("arg14: desired grid size\n"); printf("arg14: split k into mulitiple batch\n");
exit(1); exit(1);
} }
...@@ -66,9 +66,9 @@ int profile_gemm(int argc, char* argv[]) ...@@ -66,9 +66,9 @@ int profile_gemm(int argc, char* argv[])
const int StrideA = std::stoi(argv[11]); const int StrideA = std::stoi(argv[11]);
const int StrideB = std::stoi(argv[12]); const int StrideB = std::stoi(argv[12]);
const int StrideC = std::stoi(argv[13]); const int StrideC = std::stoi(argv[13]);
int DesiredGridSize = 1; int KBatch = 1;
if(argc == 15) if(argc == 15)
DesiredGridSize = std::stoi(argv[14]); KBatch = std::stoi(argv[14]);
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
...@@ -164,7 +164,7 @@ int profile_gemm(int argc, char* argv[]) ...@@ -164,7 +164,7 @@ int profile_gemm(int argc, char* argv[])
(StrideA < 0) ? K : StrideA, (StrideA < 0) ? K : StrideA,
(StrideB < 0) ? N : StrideB, (StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC, (StrideC < 0) ? N : StrideC,
DesiredGridSize); KBatch);
} }
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN) else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
{ {
...@@ -184,7 +184,7 @@ int profile_gemm(int argc, char* argv[]) ...@@ -184,7 +184,7 @@ int profile_gemm(int argc, char* argv[])
(StrideA < 0) ? K : StrideA, (StrideA < 0) ? K : StrideA,
(StrideB < 0) ? K : StrideB, (StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC, (StrideC < 0) ? N : StrideC,
DesiredGridSize); KBatch);
} }
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN) else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
{ {
...@@ -204,7 +204,7 @@ int profile_gemm(int argc, char* argv[]) ...@@ -204,7 +204,7 @@ int profile_gemm(int argc, char* argv[])
(StrideA < 0) ? M : StrideA, (StrideA < 0) ? M : StrideA,
(StrideB < 0) ? N : StrideB, (StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC, (StrideC < 0) ? N : StrideC,
DesiredGridSize); KBatch);
} }
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN) else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
{ {
...@@ -224,7 +224,7 @@ int profile_gemm(int argc, char* argv[]) ...@@ -224,7 +224,7 @@ int profile_gemm(int argc, char* argv[])
(StrideA < 0) ? M : StrideA, (StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB, (StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC, (StrideC < 0) ? N : StrideC,
DesiredGridSize); KBatch);
} }
else else
{ {
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
#include "device_gemm_instance.hpp" #include "device_gemm_instance.hpp"
#include "host_gemm.hpp" #include "host_gemm.hpp"
#include "tensor_layout.hpp" #include "tensor_layout.hpp"
#include "device_gemm_xdl_instance.hpp" #include "device_gemm_splitk_xdl_instance.hpp"
#include "device_gemm_splitk_xdl.hpp" #include "device_gemm_splitk_xdl.hpp"
enum GemmMatrixLayout enum GemmMatrixLayout
...@@ -112,7 +112,7 @@ int main(int argc, char* argv[]) ...@@ -112,7 +112,7 @@ int main(int argc, char* argv[])
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
printf(" 2: A[k, n] * B[k, n] = C[m, n];\n"); printf(" 2: A[k, n] * B[k, n] = C[m, n];\n");
printf(" 3: A[k, n] * B[n, k] = C[m, n])\n"); printf(" 3: A[k, n] * B[n, k] = C[m, n])\n");
printf("arg2 to 7: M, N, K, StrideA, StrideB, StrideC DesiredGridSize\n"); printf("arg2 to 7: M, N, K, StrideA, StrideB, StrideC KBatch\n");
return 1; return 1;
} }
...@@ -125,7 +125,7 @@ int main(int argc, char* argv[]) ...@@ -125,7 +125,7 @@ int main(int argc, char* argv[])
const int StrideA = std::stoi(argv[5]); const int StrideA = std::stoi(argv[5]);
const int StrideB = std::stoi(argv[6]); const int StrideB = std::stoi(argv[6]);
const int StrideC = std::stoi(argv[7]); const int StrideC = std::stoi(argv[7]);
const int DesiredGridSize = std::stoi(argv[8]); const int KBatch = std::stoi(argv[8]);
if(layout > 3 || layout < 0) if(layout > 3 || layout < 0)
{ {
...@@ -194,7 +194,7 @@ int main(int argc, char* argv[]) ...@@ -194,7 +194,7 @@ int main(int argc, char* argv[])
ck::tensor_operation::element_wise::PassThrough{}, ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{}, ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{}, ck::tensor_operation::element_wise::PassThrough{},
DesiredGridSize); KBatch);
auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment