Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
613e45b9
Commit
613e45b9
authored
Nov 29, 2024
by
root
Browse files
cshuffle v2 result correct, but perf awful
parent
801f995c
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
71 additions
and
51 deletions
+71
-51
include/ck_tile/core/tensor/static_distributed_tensor.hpp
include/ck_tile/core/tensor/static_distributed_tensor.hpp
+11
-0
include/ck_tile/ops/epilogue/cshuffle_epilogue_v2.hpp
include/ck_tile/ops/epilogue/cshuffle_epilogue_v2.hpp
+12
-44
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
...e/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
+1
-1
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+46
-3
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
...e/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
+1
-3
No files found.
include/ck_tile/core/tensor/static_distributed_tensor.hpp
View file @
613e45b9
...
@@ -201,4 +201,15 @@ CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks_from_x_unpacks(YLengths, number
...
@@ -201,4 +201,15 @@ CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks_from_x_unpacks(YLengths, number
return
unpacks
;
return
unpacks
;
}
}
/// Debug helper that prints the elements of a static distributed tensor with printf.
///
/// The tensor's distributed spans are obtained from `StaticTensor::get_distributed_spans()`.
/// Span 0 is swept in the outer loop and span 1 in the inner loop; each element `t(i_j_idx)`
/// is converted to float and printed as "%f,", and a newline is printed after every complete
/// inner sweep.
///
/// This function does no thread filtering itself: every thread that calls it prints.
///
/// @tparam StaticTensor tensor type providing a static `get_distributed_spans()` and an
///                      `operator()` that accepts a tuple of the two span indices.
/// @param  t            tensor to print; it is only read here.
template <typename StaticTensor>
CK_TILE_DEVICE void dump_static_tensor(StaticTensor& t)
{
    // `t` is declared as a reference, so `decltype(t)` is a reference type and cannot be used
    // as a nested-name-specifier. Name the tensor type through the template parameter instead.
    constexpr auto span_2d = StaticTensor::get_distributed_spans();

    sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
        sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
            // idx0/idx1 are passed by value as compile-time indices, so the tuple can be
            // constexpr.
            constexpr auto i_j_idx = make_tuple(idx0, idx1);
            printf("%f,", type_convert<float>(t(i_j_idx)));
        });
        // End of one inner sweep.
        printf("\n");
    });
}
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/epilogue/cshuffle_epilogue_v2.hpp
View file @
613e45b9
...
@@ -22,8 +22,8 @@ struct CShuffleEpilogueV2Problem
...
@@ -22,8 +22,8 @@ struct CShuffleEpilogueV2Problem
using
ODataType
=
remove_cvref_t
<
ODataType_
>
;
using
ODataType
=
remove_cvref_t
<
ODataType_
>
;
// static constexpr bool UseRawStore = UseRawStore_;
// static constexpr bool UseRawStore = UseRawStore_;
static
constexpr
index_t
kBlockSize
=
kBlockSize_
;
static
constexpr
index_t
kBlockSize
=
kBlockSize_
;
static
constexpr
index_t
MPerBlock
=
kM_
;
static
constexpr
index_t
k
MPerBlock
=
kM_
;
static
constexpr
index_t
NPerBlock
=
kN_
;
static
constexpr
index_t
k
NPerBlock
=
kN_
;
static
constexpr
bool
kPadM
=
kPadM_
;
static
constexpr
bool
kPadM
=
kPadM_
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kPadN
=
kPadN_
;
};
};
...
@@ -32,14 +32,12 @@ struct CShuffleEpilogueV2Problem
...
@@ -32,14 +32,12 @@ struct CShuffleEpilogueV2Problem
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeOLdsBlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeOLdsBlockDescriptor
()
{
{
static
constexpr
index_t
kMPerBlock
=
Problem
::
MPerBlock
;
static
constexpr
index_t
kMPerBlock
=
64
;
static
constexpr
index_t
kNPerBlock
=
Problem
::
NPerBlock
;
static
constexpr
index_t
kNPerBlock
=
Problem
::
k
NPerBlock
;
return
make_naive_tensor_descriptor
(
return
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kNPerBlock
>
{}),
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kNPerBlock
>
{}),
make_tuple
(
number
<
kNPerBlock
>
{},
number
<
1
>
{}),
make_tuple
(
number
<
kNPerBlock
>
{},
number
<
1
>
{}));
number
<
1
>
{},
number
<
1
>
{});
}
}
...
@@ -47,8 +45,8 @@ CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 65536; }
...
@@ -47,8 +45,8 @@ CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 65536; }
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeODramTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeODramTileDistribution
()
{
{
static
constexpr
index_t
kMPerBlock
=
Problem
::
MPerBlock
;
static
constexpr
index_t
kMPerBlock
=
64
;
static
constexpr
index_t
kNPerBlock
=
Problem
::
NPerBlock
;
static
constexpr
index_t
kNPerBlock
=
Problem
::
k
NPerBlock
;
static
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
WaveSize
=
get_warp_size
();
static
constexpr
index_t
WaveSize
=
get_warp_size
();
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
...
@@ -85,36 +83,17 @@ struct CShuffleEpilogueV2
...
@@ -85,36 +83,17 @@ struct CShuffleEpilogueV2
static
constexpr
bool
kPadM
=
Problem
::
kPadM
;
static
constexpr
bool
kPadM
=
Problem
::
kPadM
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
UseRawStore
=
Problem
::
UseRawStore
;
static
constexpr
bool
UseRawStore
=
Problem
::
UseRawStore
;
static
constexpr
bool
kMPerBlock
=
Problem
::
MPerBlock
;
static
constexpr
bool
kMPerBlock
=
64
;
static
constexpr
bool
kNPerBlock
=
Problem
::
NPerBlock
;
static
constexpr
bool
kNPerBlock
=
Problem
::
kNPerBlock
;
// constexpr auto a_warp_y_lengths =
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
65536
;}
//kMPerBlock * kNPerBlock * sizeof(ODataType); }
// to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
// constexpr auto c_warp_y_lengths =
// to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
// constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
// constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
// static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// dst_warp_tensor.get_thread_buffer() = ds_in.get_y_sliced_thread_data(
// merge_sequences(sequence<kIter, mIter>{}, c_warp_y_index_zeros),
// merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// dst_out.set_y_sliced_thread_data(
// merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
// merge_sequences(sequence<1, 1>{}, a_warp_y_lengths),
// dst_warp_tensor.get_thread_buffer());
// });
// });
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
kMPerBlock
*
kNPerBlock
*
sizeof
(
ODataType
);
}
// TODO: this function assumes the store-out vector size is the same as the OAccTile last dimension size
// TODO: this function assumes the store-out vector size is the same as the OAccTile last dimension size
// how do we fix this?
// how do we fix this?
template
<
typename
ODramWindowTmp
,
typename
OAccTile
>
template
<
typename
ODramWindowTmp
,
typename
OAccTile
>
CK_TILE_DEVICE
auto
operator
()(
ODramWindowTmp
&
o_dram_window_tmp
,
const
OAccTile
&
o_acc_tile
,
void
*
p_smem
)
CK_TILE_DEVICE
auto
operator
()(
ODramWindowTmp
&
o_dram_window_tmp
,
const
OAccTile
&
o_acc_tile
,
void
*
p_smem
)
{
{
block_sync_lds
();
auto
o_lds_tile
=
cast_tile
<
ODataType
>
(
o_acc_tile
);
auto
o_lds_tile
=
cast_tile
<
ODataType
>
(
o_acc_tile
);
constexpr
auto
o_lds_block_desc
=
MakeOLdsBlockDescriptor
<
Problem
>
();
constexpr
auto
o_lds_block_desc
=
MakeOLdsBlockDescriptor
<
Problem
>
();
auto
o_lds_block
=
make_tensor_view
<
address_space_enum
::
lds
>
(
static_cast
<
ODataType
*>
(
p_smem
),
o_lds_block_desc
);
auto
o_lds_block
=
make_tensor_view
<
address_space_enum
::
lds
>
(
static_cast
<
ODataType
*>
(
p_smem
),
o_lds_block_desc
);
...
@@ -122,17 +101,6 @@ struct CShuffleEpilogueV2
...
@@ -122,17 +101,6 @@ struct CShuffleEpilogueV2
store_tile
(
o_lds_window0
,
o_lds_tile
);
store_tile
(
o_lds_window0
,
o_lds_tile
);
block_sync_lds
();
block_sync_lds
();
// if (threadIdx.x == 0) {
// printf("%f, %f\n",type_convert<float>(static_cast<ODataType*>(p_smem)[32767]), type_convert<float>(static_cast<ODataType*>(p_smem)[32768]));
// constexpr auto span_2d = decltype(o_lds_tile)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(o_lds_tile(i_j_idx)));
// });
// printf("\n");
// });
// }
auto
o_dram_distri
=
MakeODramTileDistribution
<
Problem
>
();
auto
o_dram_distri
=
MakeODramTileDistribution
<
Problem
>
();
auto
o_dram_tile
=
load_tile
(
make_tile_window
(
o_lds_window0
,
o_dram_distri
));
auto
o_dram_tile
=
load_tile
(
make_tile_window
(
o_lds_window0
,
o_dram_distri
));
store_tile
(
o_dram_window_tmp
,
o_dram_tile
);
store_tile
(
o_dram_window_tmp
,
o_dram_tile
);
...
...
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
View file @
613e45b9
...
@@ -175,7 +175,7 @@ struct BlockGemmARegBRegCRegV2
...
@@ -175,7 +175,7 @@ struct BlockGemmARegBRegCRegV2
sequence
<>
,
sequence
<>
,
tuple
<
sequence
<
MWarp
>
,
sequence
<
NIterPerWarp
,
NWarp
>>
,
tuple
<
sequence
<
MWarp
>
,
sequence
<
NIterPerWarp
,
NWarp
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
1
>>
,
tuple
<
sequence
<
0
,
1
>>
,
sequence
<
2
>
,
sequence
<
2
>
,
sequence
<
0
>>
{};
sequence
<
0
>>
{};
...
...
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
613e45b9
...
@@ -213,16 +213,59 @@ struct GemmKernel
...
@@ -213,16 +213,59 @@ struct GemmKernel
{
i_m
,
i_n
});
{
i_m
,
i_n
});
using
CSubTileDistr
=
decltype
(
GemmPipeline
::
MakeCBlockSubTile
());
using
CSubTileDistr
=
decltype
(
GemmPipeline
::
MakeCBlockSubTile
());
CSubTileDistr
c_sub_tile
;
// printf("!!!!!!!!!!!!!!!!!!!!");
// c_sub_tile.get_tile_distribution().print();
// if (threadIdx.x==0) {
// printf("!!!!!!!!!!!!!!!!!!!!~~~ %d %d\n", c_block_tile.get_tile_distribution().get_num_of_dimension_y(), c_sub_tile.get_tile_distribution().get_num_of_dimension_y());
// // c_block_tile.get_tile_distribution().print();
// constexpr auto c_sub_y_index_zeros = uniform_sequence_gen_t<c_sub_tile.get_tile_distribution().get_num_of_dimension_y(), 0>{};
// constexpr auto c_sub_y_lengths = to_sequence(c_sub_tile.get_tile_distribution().get_ys_to_d_descriptor().get_lengths());
// c_sub_y_index_zeros.print();
// c_sub_y_lengths.print();
// }
// auto c_sub_y_index_zeros = uniform_sequence_gen_t<c_sub_tile.get_tile_distribution().get_num_of_dimension_y(), 0>{};
// auto c_sub_y_lengths = to_sequence(c_sub_tile.get_tile_distribution().get_ys_to_d_descriptor().get_lengths());
// if (threadIdx.x == 0) {
// c_sub_y_index_zeros.print();
// printf("\n");
// c_sub_y_lengths.print();
// printf("\n");
// printf("%d %d\n", GemmPipeline::NumCSubTile(), c_sub_tile.get_tile_distribution().get_num_of_dimension_y());
// }
// auto tbuf = c_block_tile.get_thread_buffer();
// for (index_t i = 0; i < tbuf.size(); i++) {
// if (threadIdx.x<16) {
// tbuf.set_as(i, float(threadIdx.x * 100 + i));
// } else {
// tbuf.set_as(i, float(threadIdx.x));
// }
// }
// c_block_tile.get_thread_buffer() = tbuf;
// if (threadIdx.x ==0) {
// constexpr auto span_2d = decltype(c_block_tile)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(c_block_tile(i_j_idx)));
// });
// printf("\n");
// });
// }
static_for
<
0
,
GemmPipeline
::
NumCSubTile
(),
1
>
{}([
&
](
auto
i_m0
)
static_for
<
0
,
GemmPipeline
::
NumCSubTile
(),
1
>
{}([
&
](
auto
i_m0
)
{
{
auto
c_sub_tile
=
make_static_distributed_tensor
<
CDataType
>
(
CSubTileDistr
{});
constexpr
auto
c_sub_y_index_zeros
=
uniform_sequence_gen_t
<
c_sub_tile
.
get_tile_distribution
().
get_num_of_dimension_y
(),
0
>
{};
constexpr
auto
c_sub_y_index_zeros
=
uniform_sequence_gen_t
<
CSubTileDistr
::
NDimY
,
0
>
{};
constexpr
auto
c_sub_y_lengths
=
to_sequence
(
c_sub_tile
.
get_tile_distribution
().
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
c_sub_y_lengths
=
to_sequence
(
CSubTileDistr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
c_sub_tile
.
get_thread_buffer
()
=
c_block_tile
.
get_y_sliced_thread_data
(
c_sub_tile
.
get_thread_buffer
()
=
c_block_tile
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
i_m0
>
{},
c_sub_y_index_zeros
),
merge_sequences
(
sequence
<
i_m0
>
{},
c_sub_y_index_zeros
),
merge_sequences
(
sequence
<
1
>
{},
c_sub_y_lengths
));
merge_sequences
(
sequence
<
1
>
{},
c_sub_y_lengths
));
EpiloguePipeline
{}(
CBlockWindow_pad
,
c_sub_tile
,
smem_ptr
);
EpiloguePipeline
{}(
CBlockWindow_pad
,
c_sub_tile
,
smem_ptr
);
move_tile_window
(
CBlockWindow_pad
,
{
TilePartitioner
::
kM
/
GemmPipeline
::
NumCSubTile
(),
0
});
move_tile_window
(
CBlockWindow_pad
,
{
TilePartitioner
::
kM
/
GemmPipeline
::
NumCSubTile
(),
0
});
});
});
}
}
};
};
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
View file @
613e45b9
...
@@ -11,15 +11,13 @@ namespace ck_tile {
...
@@ -11,15 +11,13 @@ namespace ck_tile {
// A Tile Window: global memory
// A Tile Window: global memory
// B Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
// C Distributed tensor: register
template
<
typename
Problem
,
typename
Policy
_
=
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
>
template
<
typename
Problem
,
typename
Policy
=
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
>
struct
GemmPipelineAGmemBGmemCRegV1
struct
GemmPipelineAGmemBGmemCRegV1
{
{
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
CDataType
=
remove_cvref_t
<
typename
Problem
::
CDataType
>
;
using
CDataType
=
remove_cvref_t
<
typename
Problem
::
CDataType
>
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
Problem
::
BlockGemmShape
>
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
Problem
::
BlockGemmShape
>
;
using
Policy
=
Policy_
;
using
Problem
=
Problem
;
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
using
BLayout
=
remove_cvref_t
<
typename
Problem
::
BLayout
>
;
using
BLayout
=
remove_cvref_t
<
typename
Problem
::
BLayout
>
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment