"vscode:/vscode.git/clone" did not exist on "ea8cc8cf69212df23b98333c28e6c4a2c4dcf279"
Commit 9e3825a2 authored by Paul
Browse files

Format

parent 028bb4b6
......@@ -32,11 +32,14 @@ extern "C" __global__ void f(const ck::half_t* a, const ck::half_t* b, ck::half_
static_assert(desc.IsValid(), "Invalid ck gemm.");
${template}::Run(desc,
a,
b,
ck::make_tuple(),
c);
if constexpr(desc.IsValid())
{
${template}::Run(desc,
a,
b,
ck::make_tuple(),
c);
}
}
)__ck__";
......
......@@ -503,7 +503,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
// check vector load/store
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
// check vector load of A
if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
{
......@@ -524,7 +523,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
{
return false;
}
// check vector load of B
if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
{
......@@ -723,6 +721,13 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
[&](auto d) constexpr { return DeviceOp::matrix_padder.PadCDescriptor_M_N(d); },
DsDesc{});
}
using AGridDesc_M_K =
remove_cvref_t<decltype(DeviceOp::matrix_padder.PadADescriptor_M_K(ADesc{}))>;
using BGridDesc_N_K =
remove_cvref_t<decltype(DeviceOp::matrix_padder.PadBDescriptor_N_K(BDesc{}))>;
using DsGridDesc_M_N = remove_cvref_t<decltype(ds_tuple())>;
using EGridDesc_M_N =
remove_cvref_t<decltype(DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{}))>;
using AGridDesc_AK0_M_AK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
DeviceOp::matrix_padder.PadADescriptor_M_K(ADesc{})))>;
......@@ -806,7 +811,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
constexpr bool IsValid() const
{
return GridwiseGemm::CheckValidity((a_grid_desc_m_k),
return GridwiseGemm::CheckValidity(a_grid_desc_m_k,
b_grid_desc_n_k,
ds_grid_desc_m_n,
e_grid_desc_m_n,
......@@ -844,7 +849,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
EDataType* __restrict__ p_e_grid)
{
__shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
assert(desc.is_valid);
assert(desc.IsValid());
if(desc.has_main_k_block_loop)
{
GridwiseGemm::template Run<true>(p_a_grid,
......
......@@ -161,7 +161,7 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
}
template <typename CGridDesc_M_N>
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
__host__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
{
return true;
}
......@@ -317,7 +317,10 @@ struct BlockToCTileMap_KSplit_M00_N0_M01Adapt
return true; // always valid provided that user gets grid size from CalculateGridSize()
}
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; }
__host__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
{
return true;
}
private:
index_t M01_;
......@@ -375,7 +378,7 @@ struct BlockToCTileMap_M00_N00_M01_N01
return true;
}
__host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
__host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
{
if constexpr(DeviceCTileIndexCheck)
return true; // validity check moved to kernel
......@@ -487,7 +490,7 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01
return true;
}
__host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
__host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
{
if constexpr(DeviceCTileIndexCheck)
return true; // validity check moved to kernel
......@@ -611,7 +614,7 @@ struct OffsettedBlockToCTileMap
}
template <typename CGridDesc_M_N>
__host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
__host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
{
return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
}
......@@ -668,7 +671,7 @@ struct BlockToCTileMap_3DGrid_KSplit
}
template <typename CGridDesc_M_N>
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
__host__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
{
return true;
}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment