Commit 114f9298 authored by ltqin

using atomic

parent b7ec2078
```diff
@@ -30,9 +30,9 @@ using device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn = std::tuple<
 //#################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
 //#################| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
 //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-        DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true, false, 1>,
-        DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true, true, 360>,
-        DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true, true, 480>
+        DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true, 360>,
+        DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true, 480>,
+        DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true, 720>
 // clang-format on
     >;
 #else
```
```diff
@@ -11,6 +11,10 @@
 #include "tensor_descriptor_helper.hpp"
 #include "gridwise_gemm_xdlops_v2r4.hpp"
 
+#ifndef CK_RUN_KERNEL_AND_TIME
+#define CK_RUN_KERNEL_AND_TIME 0
+#endif
+
 namespace ck {
 namespace tensor_operation {
 namespace device {
```
```diff
@@ -49,7 +53,6 @@ template <typename ADataType,
           ck::index_t CThreadTransferDstScalarPerVector,
           bool ABlockLdsAddExtraM,
           bool BBlockLdsAddExtraN,
-          bool IsSplitK,
           ck::index_t DesiredGridSize>
 struct DeviceGemmSplitKXdl : public DeviceGemm
 {
@@ -63,7 +66,7 @@ struct DeviceGemmSplitKXdl : public DeviceGemm
     static auto
     MakeAGridDescriptor_KBatch_K0_M_K1(index_t M, index_t K, index_t StrideA, int KBatch, int KPad)
     {
-        assert(K % K1 == 0);
+        assert(KPad % (K1 * KBatch) == 0);
 
         const index_t K0 = KPad / (K1 * KBatch);
@@ -96,7 +99,7 @@ struct DeviceGemmSplitKXdl : public DeviceGemm
     static auto
     MakeBGridDescriptor_KBatch_K0_N_K1(index_t K, index_t N, index_t StrideB, int KBatch, int KPad)
     {
-        assert(K % K1 == 0);
+        assert(KPad % (K1 * KBatch) == 0);
 
         const index_t K0 = KPad / (K1 * KBatch);
@@ -141,8 +144,6 @@ struct DeviceGemmSplitKXdl : public DeviceGemm
     static auto GetKBatchAndKPad(index_t M, index_t N, index_t K)
     {
-        if(!IsSplitK)
-            return std::make_tuple(1, K);
 
         const auto GridMN = M * N / (MPerBlock * NPerBlock);
         const index_t KBatch = std::max(DesiredGridSize / GridMN, 1);
         const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
```
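With the `IsSplitK` early-out removed, `KBatch` is always derived from the grid-size target. The following standalone sketch walks through that arithmetic using the tuning parameters of the 256-thread instances above; the final `KPad = K0 * KBatch * K1` step is an assumption (it is not visible in this diff), but it is what makes the relaxed `assert(KPad % (K1 * KBatch) == 0)` hold by construction.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdio>

int main()
{
    // Tuning parameters matching the instances in the first file of this commit.
    const int MPerBlock = 128, NPerBlock = 128, K0PerBlock = 4, K1 = 4;
    const int DesiredGridSize = 720; // the new trailing template argument

    const int M = 1024, N = 1024, K = 4096; // example problem size

    const int GridMN = M * N / (MPerBlock * NPerBlock);       // C tiles covering M x N
    const int KBatch = std::max(DesiredGridSize / GridMN, 1); // extra split along K

    // math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock
    const int K0 = (K + K1 * K0PerBlock * KBatch - 1) / (K1 * K0PerBlock * KBatch) * K0PerBlock;

    const int KPad = K0 * KBatch * K1; // assumed padded K

    // Exactly the condition the asserts were changed to above.
    assert(KPad % (K1 * KBatch) == 0);

    std::printf("GridMN=%d KBatch=%d K0=%d KPad=%d\n", GridMN, KBatch, K0, KPad);
    return 0;
}
```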
```diff
@@ -405,18 +406,8 @@ struct DeviceGemmSplitKXdl : public DeviceGemm
 
         float ave_time = 0;
 
-        if(has_main_k0_block_loop)
-        {
-            const auto kernel = kernel_gemm_xdlops_v2r4<
-                GridwiseGemm,
-                ADataType, // TODO: distiguish A/B datatype
-                CDataType,
-                remove_reference_t<DeviceGemmSplitKXdl::AGridDesc_K0_M_K1>,
-                remove_reference_t<DeviceGemmSplitKXdl::BGridDesc_K0_N_K1>,
-                remove_reference_t<DeviceGemmSplitKXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
-                remove_reference_t<DeviceGemmSplitKXdl::Block2CTileMap>,
-                true>;
-
+        const auto Run = [&](const auto& kernel) {
+#if CK_RUN_KERNEL_AND_TIME
             ave_time = launch_and_time_kernel(kernel,
                                               nrepeat,
                                               dim3(grid_size),
@@ -429,31 +420,82 @@ struct DeviceGemmSplitKXdl : public DeviceGemm
                                               arg.b_grid_desc_kbatch_k0_n_k1_,
                                               arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
                                               arg.block_2_ctile_map_);
+#else
+            nrepeat++;
+            launch_kernel(kernel,
+                          dim3(grid_size),
+                          dim3(BlockSize),
+                          0,
+                          arg.p_a_grid_,
+                          arg.p_b_grid_,
+                          arg.p_c_grid_,
+                          arg.a_grid_desc_kbatch_k0_m_k1_,
+                          arg.b_grid_desc_kbatch_k0_n_k1_,
+                          arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
+                          arg.block_2_ctile_map_);
+#endif
+        };
+
+        if(has_main_k0_block_loop)
+        {
+            if(kbatch == 1)
+            {
+                const auto kernel = kernel_gemm_xdlops_v2r4<
+                    GridwiseGemm,
+                    ADataType, // TODO: distiguish A/B datatype
+                    CDataType,
+                    remove_reference_t<DeviceGemmSplitKXdl::AGridDesc_K0_M_K1>,
+                    remove_reference_t<DeviceGemmSplitKXdl::BGridDesc_K0_N_K1>,
+                    remove_reference_t<DeviceGemmSplitKXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
+                    remove_reference_t<DeviceGemmSplitKXdl::Block2CTileMap>,
+                    true>;
+
+                Run(kernel);
+            }
+            else
+            {
+                const auto kernel = kernel_gemm_xdlops_v2r4<
+                    GridwiseGemmAtomicAdd,
+                    ADataType, // TODO: distiguish A/B datatype
+                    CDataType,
+                    remove_reference_t<DeviceGemmSplitKXdl::AGridDesc_K0_M_K1>,
+                    remove_reference_t<DeviceGemmSplitKXdl::BGridDesc_K0_N_K1>,
+                    remove_reference_t<DeviceGemmSplitKXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
+                    remove_reference_t<DeviceGemmSplitKXdl::Block2CTileMap>,
+                    true>;
+
+                Run(kernel);
+            }
         }
         else
         {
-            const auto kernel = kernel_gemm_xdlops_v2r4<
-                GridwiseGemm,
-                ADataType, // TODO: distiguish A/B datatype
-                CDataType,
-                remove_reference_t<DeviceGemmSplitKXdl::AGridDesc_K0_M_K1>,
-                remove_reference_t<DeviceGemmSplitKXdl::BGridDesc_K0_N_K1>,
-                remove_reference_t<DeviceGemmSplitKXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
-                remove_reference_t<DeviceGemmSplitKXdl::Block2CTileMap>,
-                false>;
-
-            ave_time = launch_and_time_kernel(kernel,
-                                              nrepeat,
-                                              dim3(grid_size),
-                                              dim3(BlockSize),
-                                              0,
-                                              arg.p_a_grid_,
-                                              arg.p_b_grid_,
-                                              arg.p_c_grid_,
-                                              arg.a_grid_desc_kbatch_k0_m_k1_,
-                                              arg.b_grid_desc_kbatch_k0_n_k1_,
-                                              arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
-                                              arg.block_2_ctile_map_);
+            if(kbatch == 1)
+            {
+                const auto kernel = kernel_gemm_xdlops_v2r4<
+                    GridwiseGemm,
+                    ADataType, // TODO: distiguish A/B datatype
+                    CDataType,
+                    remove_reference_t<DeviceGemmSplitKXdl::AGridDesc_K0_M_K1>,
+                    remove_reference_t<DeviceGemmSplitKXdl::BGridDesc_K0_N_K1>,
+                    remove_reference_t<DeviceGemmSplitKXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
+                    remove_reference_t<DeviceGemmSplitKXdl::Block2CTileMap>,
+                    false>;
+
+                Run(kernel);
+            }
+            else
+            {
+                const auto kernel = kernel_gemm_xdlops_v2r4<
+                    GridwiseGemmAtomicAdd,
+                    ADataType, // TODO: distiguish A/B datatype
+                    CDataType,
+                    remove_reference_t<DeviceGemmSplitKXdl::AGridDesc_K0_M_K1>,
+                    remove_reference_t<DeviceGemmSplitKXdl::BGridDesc_K0_N_K1>,
+                    remove_reference_t<DeviceGemmSplitKXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
+                    remove_reference_t<DeviceGemmSplitKXdl::Block2CTileMap>,
+                    false>;
+
+                Run(kernel);
+            }
         }
 
         return ave_time;
```
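The `kbatch != 1` branches above select `GridwiseGemmAtomicAdd`, so the partial products of the K slices are accumulated directly into the C buffer rather than through a separate reduction pass. Below is a minimal host-side sketch of that split-K idea, modeling the GPU atomic add with a CAS loop on `std::atomic<float>`; `split_k_gemm` and `atomic_add_float` are illustrative names only and are not composable_kernel APIs.

```cpp
#include <atomic>
#include <cassert>
#include <cstdio>
#include <thread>
#include <vector>

// CAS-based float add; on the GPU this role is played by the hardware atomic
// used by the GridwiseGemmAtomicAdd specialization.
static void atomic_add_float(std::atomic<float>& dst, float v)
{
    float old = dst.load();
    while(!dst.compare_exchange_weak(old, old + v)) {}
}

// Row-major A (M x K) times row-major B (K x N); each k-slice accumulates its
// partial C tile with atomic adds, so no cross-slice reduction kernel is needed.
static void split_k_gemm(const std::vector<float>& A,
                         const std::vector<float>& B,
                         std::vector<std::atomic<float>>& C, // M x N, pre-zeroed
                         int M, int N, int K, int KBatch)
{
    assert(K % KBatch == 0);
    const int KPerBatch = K / KBatch;

    std::vector<std::thread> slices;
    for(int kb = 0; kb < KBatch; ++kb)
    {
        // Each thread stands in for the grid handling kbatch index kb.
        slices.emplace_back([&, kb] {
            for(int m = 0; m < M; ++m)
                for(int n = 0; n < N; ++n)
                {
                    float partial = 0.f;
                    for(int k = kb * KPerBatch; k < (kb + 1) * KPerBatch; ++k)
                        partial += A[m * K + k] * B[k * N + n];
                    atomic_add_float(C[m * N + n], partial);
                }
        });
    }
    for(auto& t : slices)
        t.join();
}

int main()
{
    const int M = 4, N = 4, K = 8, KBatch = 2;
    std::vector<float> A(M * K, 1.f), B(K * N, 1.f);
    std::vector<std::atomic<float>> C(M * N);
    for(auto& c : C)
        c.store(0.f); // C must start at zero, since slices only add into it

    split_k_gemm(A, B, C, M, N, K, KBatch);
    std::printf("C[0] = %f (expected %d)\n", C[0].load(), K);
    return 0;
}
```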