Commit 66a183d7 authored by ThomasNing's avatar ThomasNing
Browse files

Update some of the code to better format

parent cca67d13
......@@ -66,6 +66,7 @@ else()
-Wunreachable-code
-Wunused
-Wno-reserved-identifier
-v --save-temps -Wno-gnu-line-marke
-Werror
-Wno-option-ignored
-Wsign-compare
......
......@@ -46,8 +46,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V2)
// Compute friendly for Intrawave scheduler
// Using the ping pong reader in the lds level
constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 128;
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t M_Warp = 2;
......
......@@ -41,7 +41,7 @@ struct GemmPipelineAgBgCrImplBase
{
load_tile(dst_block_tile, lds_tile_window);
}
CK_TILE_DEVICE auto GetABLdsTensorViews(void* p_smem) const
{
// A tile in LDS
......@@ -51,11 +51,10 @@ struct GemmPipelineAgBgCrImplBase
// TODO: LDS alignment should come from Policy!
constexpr index_t a_lds_block_space_size_aligned =
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) *
16;
integer_least_multiple(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16);
// B tile in LDS
BDataType* p_b_lds = static_cast<BDataType*>(
BDataType* __restrict__ p_b_lds = static_cast<BDataType*>(
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
......
......@@ -29,14 +29,14 @@ struct GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy
// TODO: this 8 is AK1! should be a policy parameter!
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / 8>{}, number<kMPerBlock>{}, number<8>{}),
make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}),
make_tuple(number<kMPerBlock * 8>{}, number<8>{}, number<1>{}),
number<8>{},
number<1>{});
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(make_pass_through_transform(kMPerBlock),
make_merge_transform(make_tuple(kKPerBlock / 8, 8))),
make_tuple(make_pass_through_transform(number<kMPerBlock>{}),
make_merge_transform(make_tuple(number<kKPerBlock>{} / 8, number<8>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
......@@ -52,14 +52,14 @@ struct GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / 8>{}, number<kNPerBlock>{}, number<8>{}),
make_tuple(number<(kNPerBlock + 1) * 8>{}, number<8>{}, number<1>{}),
make_tuple(number<(kNPerBlock) * 8>{}, number<8>{}, number<1>{}),
number<8>{},
number<1>{});
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
b_lds_block_desc_0,
make_tuple(make_pass_through_transform(kNPerBlock),
make_merge_transform(make_tuple(kKPerBlock / 8, 8))),
make_tuple(make_pass_through_transform(number<kNPerBlock>{}),
make_merge_transform(make_tuple(number<kKPerBlock / 8>{}, number<8>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
......@@ -69,16 +69,16 @@ struct GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
{
constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) *
MakeALdsBlockDescriptor<Problem>().get_element_space_size();
constexpr index_t smem_size_a = integer_least_multiple(sizeof(typename Problem::ADataType) *
MakeALdsBlockDescriptor<Problem>().get_element_space_size(), 16);
return smem_size_a;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
{
constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) *
MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
constexpr index_t smem_size_b = integer_least_multiple(sizeof(typename Problem::BDataType) *
MakeBLdsBlockDescriptor<Problem>().get_element_space_size(), 16);
return smem_size_b;
}
......@@ -87,9 +87,8 @@ struct GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy
{
constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
constexpr index_t smem_size = smem_size_a + smem_size_b;
return smem_size;
return smem_size_a + smem_size_b;
}
template <typename Problem>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment