Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
532eb870
Commit
532eb870
authored
Nov 30, 2024
by
coderfeli
Browse files
fix warning and use default epilog and one out
parent
613e45b9
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
70 additions
and
121 deletions
+70
-121
example/ck_tile/03_gemm/gemm_basic.cpp
example/ck_tile/03_gemm/gemm_basic.cpp
+9
-8
include/ck_tile/ops/epilogue/cshuffle_epilogue_v2.hpp
include/ck_tile/ops/epilogue/cshuffle_epilogue_v2.hpp
+10
-8
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+12
-52
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
...e/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
+24
-53
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
...line/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
+15
-0
No files found.
example/ck_tile/03_gemm/gemm_basic.cpp
View file @
532eb870
...
@@ -48,14 +48,15 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
...
@@ -48,14 +48,15 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
using
TilePartitioner
=
ck_tile
::
GemmTilePartitioner
<
CodegenGemmShape
>
;
using
TilePartitioner
=
ck_tile
::
GemmTilePartitioner
<
CodegenGemmShape
>
;
using
GemmEpilogue
=
ck_tile
::
CShuffleEpilogueV2
<
ck_tile
::
CShuffleEpilogueV2Problem
<
AccDataType
,
// using GemmEpilogue = ck_tile::CShuffleEpilogueV2<ck_tile::CShuffleEpilogueV2Problem<AccDataType,
CDataType
,
// CDataType,
M_Warp
*
N_Warp
*
K_Warp
*
Warp_Size
,
// M_Warp * N_Warp * K_Warp * Warp_Size,
TilePartitioner
::
kM
,
// 64,
TilePartitioner
::
kN
,
// TilePartitioner::kN,
kPadM
,
// kPadM,
kPadN
>>
;
// kPadN>>;
using
GemmEpilogue
=
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
>>
;
using
CodegenGemmTraits
=
using
CodegenGemmTraits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
CodegenPipelineProblem
=
ck_tile
::
using
CodegenPipelineProblem
=
ck_tile
::
...
...
include/ck_tile/ops/epilogue/cshuffle_epilogue_v2.hpp
View file @
532eb870
...
@@ -32,8 +32,8 @@ struct CShuffleEpilogueV2Problem
...
@@ -32,8 +32,8 @@ struct CShuffleEpilogueV2Problem
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeOLdsBlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeOLdsBlockDescriptor
()
{
{
static
constexpr
index_t
kMPerBlock
=
64
;
constexpr
index_t
kMPerBlock
=
Problem
::
kMPerBlock
;
static
constexpr
index_t
kNPerBlock
=
Problem
::
kNPerBlock
;
constexpr
index_t
kNPerBlock
=
Problem
::
kNPerBlock
;
return
make_naive_tensor_descriptor
(
return
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kNPerBlock
>
{}),
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kNPerBlock
>
{}),
...
@@ -45,10 +45,10 @@ CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 65536; }
...
@@ -45,10 +45,10 @@ CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 65536; }
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeODramTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeODramTileDistribution
()
{
{
static
constexpr
index_t
kMPerBlock
=
64
;
constexpr
index_t
kMPerBlock
=
Problem
::
kMPerBlock
;
static
constexpr
index_t
kNPerBlock
=
Problem
::
kNPerBlock
;
constexpr
index_t
kNPerBlock
=
Problem
::
kNPerBlock
;
static
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
WaveSize
=
get_warp_size
();
constexpr
index_t
WaveSize
=
get_warp_size
();
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
// using OLayout = remove_cvref_t<typename Problem::OLayout>;
// using OLayout = remove_cvref_t<typename Problem::OLayout>;
...
@@ -83,8 +83,9 @@ struct CShuffleEpilogueV2
...
@@ -83,8 +83,9 @@ struct CShuffleEpilogueV2
static
constexpr
bool
kPadM
=
Problem
::
kPadM
;
static
constexpr
bool
kPadM
=
Problem
::
kPadM
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
UseRawStore
=
Problem
::
UseRawStore
;
static
constexpr
bool
UseRawStore
=
Problem
::
UseRawStore
;
static
constexpr
bool
kMPerBlock
=
64
;
static
constexpr
index_t
kMPerBlock
=
Problem
::
kMPerBlock
;
static
constexpr
bool
kNPerBlock
=
Problem
::
kNPerBlock
;
// static constexpr bool kMPerBlock = 64;
static
constexpr
index_t
kNPerBlock
=
Problem
::
kNPerBlock
;
// Reports the shared-memory size this epilogue asks for.
// The value is a fixed 65536 (presumably bytes -- TODO confirm the unit with
// the kernel that consumes it). The tile-derived expression is kept below,
// disabled, exactly as the original author left it.
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
    // kMPerBlock * kNPerBlock * sizeof(ODataType);
    return 64 * 1024;
}
// Reports the shared-memory size this epilogue asks for.
// The value is a fixed 65536 (presumably bytes -- TODO confirm the unit with
// the kernel that consumes it). The tile-derived expression is kept below,
// disabled, exactly as the original author left it.
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
    // kMPerBlock * kNPerBlock * sizeof(ODataType);
    return 64 * 1024;
}
...
@@ -104,6 +105,7 @@ struct CShuffleEpilogueV2
...
@@ -104,6 +105,7 @@ struct CShuffleEpilogueV2
auto
o_dram_distri
=
MakeODramTileDistribution
<
Problem
>
();
auto
o_dram_distri
=
MakeODramTileDistribution
<
Problem
>
();
auto
o_dram_tile
=
load_tile
(
make_tile_window
(
o_lds_window0
,
o_dram_distri
));
auto
o_dram_tile
=
load_tile
(
make_tile_window
(
o_lds_window0
,
o_dram_distri
));
store_tile
(
o_dram_window_tmp
,
o_dram_tile
);
store_tile
(
o_dram_window_tmp
,
o_dram_tile
);
block_sync_lds
();
}
}
};
};
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
532eb870
...
@@ -212,61 +212,21 @@ struct GemmKernel
...
@@ -212,61 +212,21 @@ struct GemmKernel
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
{
i_m
,
i_n
});
{
i_m
,
i_n
});
using
CSubTileDistr
=
decltype
(
GemmPipeline
::
MakeCBlockSubTile
());
EpiloguePipeline
{}(
CBlockWindow_pad
,
c_block_tile
);
CSubTileDistr
c_sub_tile
;
// using CSubTileDistr = decltype(GemmPipeline::MakeCBlockSubTile());
// printf("!!!!!!!!!!!!!!!!!!!!");
// c_sub_tile.get_tile_distribution().print();
// static_for<0, GemmPipeline::NumCSubTile(), 1>{}([&](auto i_m0)
// if (threadIdx.x==0) {
// {
// printf("!!!!!!!!!!!!!!!!!!!!~~~ %d %d\n", c_block_tile.get_tile_distribution().get_num_of_dimension_y(), c_sub_tile.get_tile_distribution().get_num_of_dimension_y());
// CSubTileDistr c_sub_tile;
// // c_block_tile.get_tile_distribution().print();
// constexpr auto c_sub_y_index_zeros = uniform_sequence_gen_t<c_sub_tile.get_tile_distribution().get_num_of_dimension_y(), 0>{};
// constexpr auto c_sub_y_index_zeros = uniform_sequence_gen_t<c_sub_tile.get_tile_distribution().get_num_of_dimension_y(), 0>{};
// constexpr auto c_sub_y_lengths = to_sequence(c_sub_tile.get_tile_distribution().get_ys_to_d_descriptor().get_lengths());
// constexpr auto c_sub_y_lengths = to_sequence(c_sub_tile.get_tile_distribution().get_ys_to_d_descriptor().get_lengths());
// c_sub_y_index_zeros.print();
// c_sub_tile.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
// c_sub_y_lengths.print();
// merge_sequences(sequence<i_m0>{}, c_sub_y_index_zeros),
// }
// merge_sequences(sequence<1>{}, c_sub_y_lengths));
// auto c_sub_y_index_zeros = uniform_sequence_gen_t<c_sub_tile.get_tile_distribution().get_num_of_dimension_y(), 0>{};
// auto c_sub_y_lengths = to_sequence(c_sub_tile.get_tile_distribution().get_ys_to_d_descriptor().get_lengths());
// if (threadIdx.x == 0) {
// c_sub_y_index_zeros.print();
// printf("\n");
// c_sub_y_lengths.print();
// printf("\n");
// printf("%d %d\n", GemmPipeline::NumCSubTile(), c_sub_tile.get_tile_distribution().get_num_of_dimension_y());
// }
// auto tbuf = c_block_tile.get_thread_buffer();
// for (index_t i = 0; i < tbuf.size(); i++) {
// if (threadIdx.x<16) {
// tbuf.set_as(i, float(threadIdx.x * 100 + i));
// } else {
// tbuf.set_as(i, float(threadIdx.x));
// }
// }
// c_block_tile.get_thread_buffer() = tbuf;
// if (threadIdx.x ==0) {
// constexpr auto span_2d = decltype(c_block_tile)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(c_block_tile(i_j_idx)));
// });
// printf("\n");
// });
// }
static_for
<
0
,
GemmPipeline
::
NumCSubTile
(),
1
>
{}([
&
](
auto
i_m0
)
{
constexpr
auto
c_sub_y_index_zeros
=
uniform_sequence_gen_t
<
c_sub_tile
.
get_tile_distribution
().
get_num_of_dimension_y
(),
0
>
{};
constexpr
auto
c_sub_y_lengths
=
to_sequence
(
c_sub_tile
.
get_tile_distribution
().
get_ys_to_d_descriptor
().
get_lengths
());
c_sub_tile
.
get_thread_buffer
()
=
c_block_tile
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
i_m0
>
{},
c_sub_y_index_zeros
),
merge_sequences
(
sequence
<
1
>
{},
c_sub_y_lengths
));
EpiloguePipeline
{}(
CBlockWindow_pad
,
c_sub_tile
,
smem_ptr
);
// EpiloguePipeline{}(CBlockWindow_pad, c_sub_tile, smem_ptr);
move_tile_window
(
CBlockWindow_pad
,
{
TilePartitioner
::
kM
/
GemmPipeline
::
NumCSubTile
(),
0
});
// move_tile_window(CBlockWindow_pad, {TilePartitioner::kM / GemmPipeline::NumCSubTile(), 0});
// });
});
}
}
};
};
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
View file @
532eb870
...
@@ -249,36 +249,6 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -249,36 +249,6 @@ struct GemmPipelineAGmemBGmemCRegV1
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_lds_window0
,
a_block_tile0
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_lds_window0
,
a_block_tile0
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_lds_window0
,
b_block_tile0
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_lds_window0
,
b_block_tile0
);
// if (threadIdx.x == 64) {
// constexpr auto span_2d = decltype(a_block_tile0)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f, %f; ", type_convert<float>(a_block_tile0(i_j_idx)), type_convert<float>(b_block_tile0(i_j_idx)));
// });
// printf("\n");
// });
// }
// if (threadIdx.x == 0) {
// printf("aalds\n");
// constexpr auto span_2d = decltype(a_block_tile0)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(a_block_tile0(i_j_idx)));
// });
// printf("\n");
// });
// printf("bbbbblds\n");
// constexpr auto span_2d2 = decltype(b_block_tile0)::get_distributed_spans();
// sweep_tile_span(span_2d2[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d2[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(b_block_tile0(i_j_idx)));
// });
// printf("\n");
// });
// }
// LDS write 1
// LDS write 1
LocalPrefill
(
a_lds_window1
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
a_lds_window1
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
b_lds_window1
,
b_global_load_tile
,
b_element_func
);
LocalPrefill
(
b_lds_window1
,
b_global_load_tile
,
b_element_func
);
...
@@ -321,29 +291,30 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -321,29 +291,30 @@ struct GemmPipelineAGmemBGmemCRegV1
}
}
//tail 3
//tail 3
if
(
iCounter
==
1
)
{
// if (iCounter == 1) {
// 3
// // 3
{
// {
block_sync_lds
();
// block_sync_lds();
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_lds_window1
,
a_block_tile1
);
// Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_lds_window1
,
b_block_tile1
);
// Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window1, b_block_tile1);
LocalPrefill
(
a_lds_window0
,
a_global_load_tile
,
a_element_func
);
// LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func);
LocalPrefill
(
b_lds_window0
,
b_global_load_tile
,
b_element_func
);
// LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func);
block_gemm
(
c_block_tile
,
a_block_tile0
,
b_block_tile0
);
// block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
}
// }
// 2
// // 2
{
// {
block_sync_lds
();
// block_sync_lds();
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_lds_window0
,
a_block_tile0
);
// Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_lds_window0
,
b_block_tile0
);
// Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window0, b_block_tile0);
block_gemm
(
c_block_tile
,
a_block_tile1
,
b_block_tile1
);
// block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
}
// }
//1
// //1
{
// {
block_gemm
(
c_block_tile
,
a_block_tile0
,
b_block_tile0
);
// block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
}
// }
//tail 2
// //tail 2
}
else
{
// } else
{
{
{
block_sync_lds
();
block_sync_lds
();
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_lds_window1
,
a_block_tile1
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_lds_window1
,
a_block_tile1
);
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
View file @
532eb870
...
@@ -70,6 +70,21 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
...
@@ -70,6 +70,21 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
return
a_lds_block_desc
;
return
a_lds_block_desc
;
}
}
// Builds a 3-D naive tensor descriptor for the A block tile with
//   lengths (kK / 8, kM, 8) and strides (kM * 8, 8, 1),
// i.e. a densely packed layout with no padding between rows (each stride is
// the product of the lengths to its right).
//
// Problem must expose BlockGemmShape::kM and BlockGemmShape::kK.
// NOTE(review): kK / 8 is integer division -- assumes kK is a multiple of 8;
// confirm against the shapes this policy is instantiated with.
// NOTE(review): the trailing number<8>{} / number<1>{} arguments are presumably
// the guaranteed vector length / stride of the last dimension -- verify against
// make_naive_tensor_descriptor's signature.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockLinearDescriptor()
{
    using namespace ck_tile;

    constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
    constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;

    // Width of the innermost dimension; the original spelled this as a literal 8.
    constexpr index_t kInner = 8;

    return make_naive_tensor_descriptor(
        make_tuple(number<kKPerBlock / kInner>{}, number<kMPerBlock>{}, number<kInner>{}),
        make_tuple(number<kMPerBlock * kInner>{}, number<kInner>{}, number<1>{}),
        number<kInner>{},
        number<1>{});
}
// 3d + padding
// 3d + padding
template
<
typename
Problem
>
template
<
typename
Problem
>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment