"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "c03f0e3c3d6e66b5ac3359fa3cbc5322c72568ce"
Commit 114f9298 authored by ltqin
Browse files

using atomic

parent b7ec2078
......@@ -30,9 +30,9 @@ using device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn = std::tuple<
//#################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//#################| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true, false, 1>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true, true, 360>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true, true, 480>
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true, 360>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true, 480>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true, 720>
// clang-format on
>;
#else
......
......@@ -11,6 +11,10 @@
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r4.hpp"
// Compile-time switch for how the Run lambda in this file launches the GEMM
// kernel:
//   non-zero -> launch_and_time_kernel(kernel, nrepeat, ...) is used and its
//               result is stored in ave_time;
//   0        -> launch_kernel(kernel, ...) is used and ave_time keeps its
//               initial value of 0.
// The #ifndef guard lets an earlier definition override this; otherwise the
// value defaults to 0.
#ifndef CK_RUN_KERNEL_AND_TIME
#define CK_RUN_KERNEL_AND_TIME 0
#endif
namespace ck {
namespace tensor_operation {
namespace device {
......@@ -49,7 +53,6 @@ template <typename ADataType,
ck::index_t CThreadTransferDstScalarPerVector,
bool ABlockLdsAddExtraM,
bool BBlockLdsAddExtraN,
bool IsSplitK,
ck::index_t DesiredGridSize>
struct DeviceGemmSplitKXdl : public DeviceGemm
{
......@@ -63,7 +66,7 @@ struct DeviceGemmSplitKXdl : public DeviceGemm
static auto
MakeAGridDescriptor_KBatch_K0_M_K1(index_t M, index_t K, index_t StrideA, int KBatch, int KPad)
{
assert(K % K1 == 0);
assert(KPad % (K1 * KBatch) == 0);
const index_t K0 = KPad / (K1 * KBatch);
......@@ -96,7 +99,7 @@ struct DeviceGemmSplitKXdl : public DeviceGemm
static auto
MakeBGridDescriptor_KBatch_K0_N_K1(index_t K, index_t N, index_t StrideB, int KBatch, int KPad)
{
assert(K % K1 == 0);
assert(KPad % (K1 * KBatch) == 0);
const index_t K0 = KPad / (K1 * KBatch);
......@@ -141,8 +144,6 @@ struct DeviceGemmSplitKXdl : public DeviceGemm
static auto GetKBatchAndKPad(index_t M, index_t N, index_t K)
{
if(!IsSplitK)
return std::make_tuple(1, K);
const auto GridMN = M * N / (MPerBlock * NPerBlock);
const index_t KBatch = std::max(DesiredGridSize / GridMN, 1);
const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
......@@ -405,18 +406,8 @@ struct DeviceGemmSplitKXdl : public DeviceGemm
float ave_time = 0;
if(has_main_k0_block_loop)
{
const auto kernel = kernel_gemm_xdlops_v2r4<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmSplitKXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmSplitKXdl::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmSplitKXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<DeviceGemmSplitKXdl::Block2CTileMap>,
true>;
const auto Run = [&](const auto& kernel) {
#if CK_RUN_KERNEL_AND_TIME
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
......@@ -429,31 +420,82 @@ struct DeviceGemmSplitKXdl : public DeviceGemm
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.block_2_ctile_map_);
#else
nrepeat++;
launch_kernel(kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.block_2_ctile_map_);
#endif
};
if(has_main_k0_block_loop)
{
if(kbatch == 1)
{
const auto kernel = kernel_gemm_xdlops_v2r4<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmSplitKXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmSplitKXdl::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmSplitKXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<DeviceGemmSplitKXdl::Block2CTileMap>,
true>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdlops_v2r4<
GridwiseGemmAtomicAdd,
ADataType, // TODO: distinguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmSplitKXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmSplitKXdl::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmSplitKXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<DeviceGemmSplitKXdl::Block2CTileMap>,
true>;
Run(kernel);
}
}
else
{
const auto kernel = kernel_gemm_xdlops_v2r4<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmSplitKXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmSplitKXdl::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmSplitKXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<DeviceGemmSplitKXdl::Block2CTileMap>,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.block_2_ctile_map_);
if(kbatch == 1)
{
const auto kernel = kernel_gemm_xdlops_v2r4<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmSplitKXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmSplitKXdl::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmSplitKXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<DeviceGemmSplitKXdl::Block2CTileMap>,
false>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdlops_v2r4<
GridwiseGemmAtomicAdd,
ADataType, // TODO: distinguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmSplitKXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmSplitKXdl::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmSplitKXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<DeviceGemmSplitKXdl::Block2CTileMap>,
false>;
Run(kernel);
}
}
return ave_time;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment