Commit 66a183d7 authored by ThomasNing's avatar ThomasNing
Browse files

Update some of the code to better format

parent cca67d13
...@@ -66,6 +66,7 @@ else() ...@@ -66,6 +66,7 @@ else()
-Wunreachable-code -Wunreachable-code
-Wunused -Wunused
-Wno-reserved-identifier -Wno-reserved-identifier
-v --save-temps -Wno-gnu-line-marke
-Werror -Werror
-Wno-option-ignored -Wno-option-ignored
-Wsign-compare -Wsign-compare
......
...@@ -46,8 +46,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ...@@ -46,8 +46,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V2) #elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V2)
// Compute friendly for Intrawave scheduler // Compute friendly for Intrawave scheduler
// Using the ping pong reader in the lds level // Using the ping pong reader in the lds level
constexpr ck_tile::index_t M_Tile = 128; constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 128; constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 32; constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t M_Warp = 2;
......
...@@ -41,7 +41,7 @@ struct GemmPipelineAgBgCrImplBase ...@@ -41,7 +41,7 @@ struct GemmPipelineAgBgCrImplBase
{ {
load_tile(dst_block_tile, lds_tile_window); load_tile(dst_block_tile, lds_tile_window);
} }
CK_TILE_DEVICE auto GetABLdsTensorViews(void* p_smem) const CK_TILE_DEVICE auto GetABLdsTensorViews(void* p_smem) const
{ {
// A tile in LDS // A tile in LDS
...@@ -51,11 +51,10 @@ struct GemmPipelineAgBgCrImplBase ...@@ -51,11 +51,10 @@ struct GemmPipelineAgBgCrImplBase
// TODO: LDS alignment should come from Policy! // TODO: LDS alignment should come from Policy!
constexpr index_t a_lds_block_space_size_aligned = constexpr index_t a_lds_block_space_size_aligned =
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) * integer_least_multiple(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16);
16;
// B tile in LDS // B tile in LDS
BDataType* p_b_lds = static_cast<BDataType*>( BDataType* __restrict__ p_b_lds = static_cast<BDataType*>(
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned)); static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>(); constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc); auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
......
...@@ -29,14 +29,14 @@ struct GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy ...@@ -29,14 +29,14 @@ struct GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy
// TODO: this 8 is AK1! should be a policy parameter! // TODO: this 8 is AK1! should be a policy parameter!
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / 8>{}, number<kMPerBlock>{}, number<8>{}), make_tuple(number<kKPerBlock / 8>{}, number<kMPerBlock>{}, number<8>{}),
make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}), make_tuple(number<kMPerBlock * 8>{}, number<8>{}, number<1>{}),
number<8>{}, number<8>{},
number<1>{}); number<1>{});
constexpr auto a_lds_block_desc = transform_tensor_descriptor( constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_0, a_lds_block_desc_0,
make_tuple(make_pass_through_transform(kMPerBlock), make_tuple(make_pass_through_transform(number<kMPerBlock>{}),
make_merge_transform(make_tuple(kKPerBlock / 8, 8))), make_merge_transform(make_tuple(number<kKPerBlock>{} / 8, number<8>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}), make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{})); make_tuple(sequence<0>{}, sequence<1>{}));
...@@ -52,14 +52,14 @@ struct GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy ...@@ -52,14 +52,14 @@ struct GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / 8>{}, number<kNPerBlock>{}, number<8>{}), make_tuple(number<kKPerBlock / 8>{}, number<kNPerBlock>{}, number<8>{}),
make_tuple(number<(kNPerBlock + 1) * 8>{}, number<8>{}, number<1>{}), make_tuple(number<(kNPerBlock) * 8>{}, number<8>{}, number<1>{}),
number<8>{}, number<8>{},
number<1>{}); number<1>{});
constexpr auto b_lds_block_desc = transform_tensor_descriptor( constexpr auto b_lds_block_desc = transform_tensor_descriptor(
b_lds_block_desc_0, b_lds_block_desc_0,
make_tuple(make_pass_through_transform(kNPerBlock), make_tuple(make_pass_through_transform(number<kNPerBlock>{}),
make_merge_transform(make_tuple(kKPerBlock / 8, 8))), make_merge_transform(make_tuple(number<kKPerBlock / 8>{}, number<8>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}), make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{})); make_tuple(sequence<0>{}, sequence<1>{}));
...@@ -69,16 +69,16 @@ struct GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy ...@@ -69,16 +69,16 @@ struct GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA() CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
{ {
constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) * constexpr index_t smem_size_a = integer_least_multiple(sizeof(typename Problem::ADataType) *
MakeALdsBlockDescriptor<Problem>().get_element_space_size(); MakeALdsBlockDescriptor<Problem>().get_element_space_size(), 16);
return smem_size_a; return smem_size_a;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB() CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
{ {
constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) * constexpr index_t smem_size_b = integer_least_multiple(sizeof(typename Problem::BDataType) *
MakeBLdsBlockDescriptor<Problem>().get_element_space_size(); MakeBLdsBlockDescriptor<Problem>().get_element_space_size(), 16);
return smem_size_b; return smem_size_b;
} }
...@@ -87,9 +87,8 @@ struct GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy ...@@ -87,9 +87,8 @@ struct GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy
{ {
constexpr index_t smem_size_a = GetSmemSizeA<Problem>(); constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
constexpr index_t smem_size_b = GetSmemSizeB<Problem>(); constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
constexpr index_t smem_size = smem_size_a + smem_size_b;
return smem_size; return smem_size_a + smem_size_b;
} }
template <typename Problem> template <typename Problem>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment