Commit 9e3825a2 authored by Paul's avatar Paul
Browse files

Format

parent 028bb4b6
...@@ -32,11 +32,14 @@ extern "C" __global__ void f(const ck::half_t* a, const ck::half_t* b, ck::half_ ...@@ -32,11 +32,14 @@ extern "C" __global__ void f(const ck::half_t* a, const ck::half_t* b, ck::half_
static_assert(desc.IsValid(), "Invalid ck gemm."); static_assert(desc.IsValid(), "Invalid ck gemm.");
if constexpr(desc.IsValid())
{
${template}::Run(desc, ${template}::Run(desc,
a, a,
b, b,
ck::make_tuple(), ck::make_tuple(),
c); c);
}
} }
)__ck__"; )__ck__";
......
...@@ -503,7 +503,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -503,7 +503,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
// check vector load/store // check vector load/store
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
// check vector load of A // check vector load of A
if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2) if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
{ {
...@@ -524,7 +523,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -524,7 +523,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
{ {
return false; return false;
} }
// check vector laod of B // check vector laod of B
if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2) if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
{ {
...@@ -723,6 +721,13 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -723,6 +721,13 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
[&](auto d) constexpr { return DeviceOp::matrix_padder.PadCDescriptor_M_N(d); }, [&](auto d) constexpr { return DeviceOp::matrix_padder.PadCDescriptor_M_N(d); },
DsDesc{}); DsDesc{});
} }
using AGridDesc_M_K =
remove_cvref_t<decltype(DeviceOp::matrix_padder.PadADescriptor_M_K(ADesc{}))>;
using BGridDesc_N_K =
remove_cvref_t<decltype(DeviceOp::matrix_padder.PadBDescriptor_N_K(BDesc{}))>;
using DsGridDesc_M_N = remove_cvref_t<decltype(ds_tuple())>;
using EGridDesc_M_N =
remove_cvref_t<decltype(DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{}))>;
using AGridDesc_AK0_M_AK1 = using AGridDesc_AK0_M_AK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1( remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
DeviceOp::matrix_padder.PadADescriptor_M_K(ADesc{})))>; DeviceOp::matrix_padder.PadADescriptor_M_K(ADesc{})))>;
...@@ -806,7 +811,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -806,7 +811,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
constexpr bool IsValid() const constexpr bool IsValid() const
{ {
return GridwiseGemm::CheckValidity((a_grid_desc_m_k), return GridwiseGemm::CheckValidity(a_grid_desc_m_k,
b_grid_desc_n_k, b_grid_desc_n_k,
ds_grid_desc_m_n, ds_grid_desc_m_n,
e_grid_desc_m_n, e_grid_desc_m_n,
...@@ -844,7 +849,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -844,7 +849,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
EDataType* __restrict__ p_e_grid) EDataType* __restrict__ p_e_grid)
{ {
__shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
assert(desc.is_valid); assert(desc.IsValid());
if(desc.has_main_k_block_loop) if(desc.has_main_k_block_loop)
{ {
GridwiseGemm::template Run<true>(p_a_grid, GridwiseGemm::template Run<true>(p_a_grid,
......
...@@ -161,7 +161,7 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void> ...@@ -161,7 +161,7 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
} }
template <typename CGridDesc_M_N> template <typename CGridDesc_M_N>
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const __host__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
{ {
return true; return true;
} }
...@@ -317,7 +317,10 @@ struct BlockToCTileMap_KSplit_M00_N0_M01Adapt ...@@ -317,7 +317,10 @@ struct BlockToCTileMap_KSplit_M00_N0_M01Adapt
return true; // always valid provided that user gets grid size from CalculateGridSize() return true; // always valid provided that user gets grid size from CalculateGridSize()
} }
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; } __host__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
{
return true;
}
private: private:
index_t M01_; index_t M01_;
...@@ -375,7 +378,7 @@ struct BlockToCTileMap_M00_N00_M01_N01 ...@@ -375,7 +378,7 @@ struct BlockToCTileMap_M00_N00_M01_N01
return true; return true;
} }
__host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
{ {
if constexpr(DeviceCTileIndexCheck) if constexpr(DeviceCTileIndexCheck)
return true; // validity check moved to kernel return true; // validity check moved to kernel
...@@ -487,7 +490,7 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01 ...@@ -487,7 +490,7 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01
return true; return true;
} }
__host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
{ {
if constexpr(DeviceCTileIndexCheck) if constexpr(DeviceCTileIndexCheck)
return true; // validity check moved to kernel return true; // validity check moved to kernel
...@@ -611,7 +614,7 @@ struct OffsettedBlockToCTileMap ...@@ -611,7 +614,7 @@ struct OffsettedBlockToCTileMap
} }
template <typename CGridDesc_M_N> template <typename CGridDesc_M_N>
__host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
{ {
return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n); return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
} }
...@@ -668,7 +671,7 @@ struct BlockToCTileMap_3DGrid_KSplit ...@@ -668,7 +671,7 @@ struct BlockToCTileMap_3DGrid_KSplit
} }
template <typename CGridDesc_M_N> template <typename CGridDesc_M_N>
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const __host__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
{ {
return true; return true;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment