Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
405c05c0
Commit
405c05c0
authored
Nov 27, 2024
by
dummycoderfe
Browse files
add prefetch and fix output err
parent
6c270303
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
122 additions
and
93 deletions
+122
-93
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
...e/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
+2
-2
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
...e/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
+114
-87
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
...line/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
+6
-4
No files found.
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
View file @
405c05c0
...
...
@@ -205,8 +205,8 @@ struct BlockGemmARegBRegCRegV2
}
// Prefetch lds
template
<
typename
BlockWindow
Tmp
,
typename
BlockTensor
>
CK_TILE_DEVICE
static
auto
PrefetchLds
(
const
BlockWindow
Tmp
&
block_window
,
BlockTensor
&
block_tensor
)
template
<
typename
BlockWindow
,
typename
BlockTensor
>
CK_TILE_DEVICE
static
auto
PrefetchLds
(
const
BlockWindow
&
block_window
,
BlockTensor
&
block_tensor
)
{
auto
tileDist
=
BlockTensor
::
get_tile_distribution
();
//.get_static_tile_distribution_encoding()
return
load_tile
(
block_tensor
,
make_tile_window
(
block_window
,
tileDist
));
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
View file @
405c05c0
...
...
@@ -37,17 +37,17 @@ struct GemmPipelineAGmemBGmemCRegV1
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadK
=
Problem
::
kPadK
;
// Returns the combined LDS size in bytes for the A and B tiles.
// For each operand, the LDS block descriptor's element-space size is scaled by
// the element size, rounded up to a multiple of 16, and then counted twice
// (presumably one slot per ping/pong buffer -- confirm against the pipeline).
CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
{
    // A operand: aligned byte size of one LDS block.
    const auto a_bytes_aligned = integer_least_multiple(
        sizeof(ADataType) *
            Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
        16);

    // B operand: aligned byte size of one LDS block.
    const auto b_bytes_aligned = integer_least_multiple(
        sizeof(BDataType) *
            Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size(),
        16);

    return a_bytes_aligned * 2 + b_bytes_aligned * 2;
}
//
CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
//
{
//
return integer_least_multiple(
//
sizeof(ADataType) *
//
Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
//
16) * 2 +
//
integer_least_multiple(
//
sizeof(BDataType) *
//
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size(),
//
16) * 2;
//
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
...
...
@@ -91,46 +91,78 @@ struct GemmPipelineAGmemBGmemCRegV1
kNPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kKPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}],
"wrong!"
);
////////////// global window & register /////////////////
// A DRAM tile window for load
auto
a_copy_dram_window
=
make_tile_window
(
a_dram_block_window_tmp
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
a_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeADramTileDistribution
<
Problem
>());
// B DRAM tile window for load
auto
b_copy_dram_window
=
make_tile_window
(
b_dram_block_window_tmp
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kNPerBlock
>
{},
number
<
kKPerBlock
>
{}),
b_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeBDramTileDistribution
<
Problem
>());
// A register tile for global load
using
ABlockTileDistr
=
decltype
(
a_copy_dram_window
.
get_tile_distribution
());
using
BBlockTileDistr
=
decltype
(
b_copy_dram_window
.
get_tile_distribution
());
using
ABlockTile
=
decltype
(
make_static_distributed_tensor
<
ADataType
>
(
ABlockTileDistr
{}));
using
BBlockTile
=
decltype
(
make_static_distributed_tensor
<
BDataType
>
(
BBlockTileDistr
{}));
ABlockTile
a_global_load_tile
;
BBlockTile
b_global_load_tile
;
// A tile in LDS
// global prefetch 0
// global read 0
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
);
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
);
// if (threadIdx.x == 0) {
// constexpr auto span_2d = decltype(a_global_load_tile)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(a_global_load_tile(i_j_idx)));
// });
// printf("\n");
// });
// printf("bbbbb\n");
// constexpr auto span_2d2 = decltype(b_global_load_tile)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(b_global_load_tile(i_j_idx)));
// });
// printf("\n");
// });
// }
////////////// LDS desc, window & register /////////////////
// AB LDS desc
constexpr
auto
a_lds_block_desc
=
Policy
::
template
MakeALdsBlockDescriptor
<
Problem
>();
constexpr
auto
b_lds_block_desc
=
Policy
::
template
MakeBLdsBlockDescriptor
<
Problem
>();
constexpr
index_t
a_lds_block_space_size_aligned
=
integer_least_multiple
(
sizeof
(
ADataType
)
*
a_lds_block_desc
.
get_element_space_size
(),
16
);
constexpr
index_t
b_lds_block_space_size_aligned
=
integer_least_multiple
(
sizeof
(
BDataType
)
*
b_lds_block_desc
.
get_element_space_size
(),
16
);
// A tile in LDS view
ADataType
*
p_a_lds0
=
reinterpret_cast
<
ADataType
*>
(
p_smem
);
ADataType
*
p_a_lds1
=
reinterpret_cast
<
ADataType
*>
(
reinterpret_cast
<
char
*>
(
p_smem
)
+
a_lds_block_space_size_aligned
);
// B tile in LDS
BDataType
*
p_b_lds0
=
reinterpret_cast
<
BDataType
*>
(
reinterpret_cast
<
char
*>
(
p_smem
)
+
a_lds_block_space_size_aligned
*
2
);
BDataType
*
p_b_lds1
=
reinterpret_cast
<
BDataType
*>
(
reinterpret_cast
<
char
*>
(
p_b_lds0
)
+
b_lds_block_space_size_aligned
);
ADataType
*
p_a_lds1
=
reinterpret_cast
<
ADataType
*>
(
reinterpret_cast
<
char
*>
(
p_a_lds0
)
+
a_lds_block_space_size_aligned
);
auto
a_lds_block0
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_a_lds0
,
a_lds_block_desc
);
auto
b_lds_block0
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_b_lds0
,
b_lds_block_desc
);
auto
a_lds_block1
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_a_lds1
,
a_lds_block_desc
);
// B tile in LDS view
BDataType
*
p_b_lds0
=
reinterpret_cast
<
BDataType
*>
(
reinterpret_cast
<
char
*>
(
p_a_lds1
)
+
a_lds_block_space_size_aligned
);
BDataType
*
p_b_lds1
=
reinterpret_cast
<
BDataType
*>
(
reinterpret_cast
<
char
*>
(
p_b_lds0
)
+
b_lds_block_space_size_aligned
);
auto
b_lds_block0
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_b_lds0
,
b_lds_block_desc
);
auto
b_lds_block1
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_b_lds1
,
b_lds_block_desc
);
// A DRAM tile window for load
auto
a_copy_dram_window
=
make_tile_window
(
a_dram_block_window_tmp
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
a_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeADramTileDistribution
<
Problem
>());
// A LDS tile window for store
auto
a_store_lds_window0
=
make_tile_window
(
a_lds_block0
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
});
auto
a_store_lds_window1
=
make_tile_window
(
a_lds_block1
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
});
// B DRAM tile window for load
auto
b_copy_dram_window
=
make_tile_window
(
b_dram_block_window_tmp
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kNPerBlock
>
{},
number
<
kKPerBlock
>
{}),
b_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeBDramTileDistribution
<
Problem
>());
// B LDS tile window for store
auto
b_store_lds_window0
=
make_tile_window
(
b_lds_block0
,
make_tuple
(
number
<
kNPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
});
...
...
@@ -154,41 +186,46 @@ struct GemmPipelineAGmemBGmemCRegV1
// Acc register tile
auto
c_block_tile
=
Policy
::
template
BlockGemm
<
Problem
>
::
MakeCBlockTile
();
tile_elementwise_inout
([](
auto
&
c
)
{
c
=
0
;
},
c_block_tile
);
// a b register tile
auto
a_prefetch_tile0
=
make_static_distributed_tensor
<
ADataType
>
(
Policy
::
template
BlockGemm
<
Problem
>
::
MakeABlockDistribution
());
auto
a_prefetch_tile1
=
make_static_distributed_tensor
<
ADataType
>
(
Policy
::
template
BlockGemm
<
Problem
>
::
MakeABlockDistribution
());
auto
b_prefetch_tile0
=
make_static_distributed_tensor
<
BDataType
>
(
Policy
::
template
BlockGemm
<
Problem
>
::
MakeBBlockDistribution
());
auto
b_prefetch_tile1
=
make_static_distributed_tensor
<
BDataType
>
(
Policy
::
template
BlockGemm
<
Problem
>
::
MakeBBlockDistribution
());
using
ABlockTileDistr
=
decltype
(
a_copy_dram_window
.
get_tile_distribution
());
using
BBlockTileDistr
=
decltype
(
b_copy_dram_window
.
get_tile_distribution
());
using
ABlockTile
=
decltype
(
make_static_distributed_tensor
<
ADataType
>
(
ABlockTileDistr
{}));
using
BBlockTile
=
decltype
(
make_static_distributed_tensor
<
BDataType
>
(
BBlockTileDistr
{}));
// a b register tile for lds prefetch & mfma
auto
a_block_tile0
=
make_static_distributed_tensor
<
ADataType
>
(
Policy
::
template
BlockGemm
<
Problem
>
::
MakeABlockDistribution
());
auto
a_block_tile1
=
make_static_distributed_tensor
<
ADataType
>
(
Policy
::
template
BlockGemm
<
Problem
>
::
MakeABlockDistribution
());
auto
b_block_tile0
=
make_static_distributed_tensor
<
BDataType
>
(
Policy
::
template
BlockGemm
<
Problem
>
::
MakeBBlockDistribution
());
auto
b_block_tile1
=
make_static_distributed_tensor
<
BDataType
>
(
Policy
::
template
BlockGemm
<
Problem
>
::
MakeBBlockDistribution
());
ABlockTile
a_global_load_tile
;
BBlockTile
b_global_load_tile
;
// prefetch
// global read 0
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
);
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
);
tile_elementwise_inout
([](
auto
&
c
)
{
c
=
0
;
},
c_block_tile
);
// LDS write 0
LocalPrefill
(
a_store_lds_window0
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
b_store_lds_window0
,
b_global_load_tile
,
b_element_func
);
block_sync_lds
();
// global read 1
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
);
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
);
block_sync_lds
();
// local prefetch 0
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_load_lds_window0
,
a_
prefetch
_tile0
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_load_lds_window0
,
b_
prefetch
_tile0
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_load_lds_window0
,
a_
block
_tile0
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_load_lds_window0
,
b_
block
_tile0
);
// if (threadIdx.x == 0) {
// printf("aalds\n");
// constexpr auto span_2d = decltype(a_block_tile0)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(a_block_tile0(i_j_idx)));
// });
// printf("\n");
// });
// printf("bbbbblds\n");
// constexpr auto span_2d2 = decltype(b_block_tile0)::get_distributed_spans();
// sweep_tile_span(span_2d2[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d2[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(b_block_tile0(i_j_idx)));
// });
// printf("\n");
// });
// }
// LDS write 1
LocalPrefill
(
a_store_lds_window1
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
b_store_lds_window1
,
b_global_load_tile
,
b_element_func
);
...
...
@@ -197,37 +234,31 @@ struct GemmPipelineAGmemBGmemCRegV1
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
);
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
);
index_t
iCounter
=
num_loop
-
1
;
while
(
iCounter
>
2
)
index_t
iCounter
=
num_loop
-
2
;
while
(
iCounter
>
1
)
{
// ping
{
block_sync_lds
();
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_load_lds_window1
,
a_prefetch_tile1
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_load_lds_window1
,
b_prefetch_tile1
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_load_lds_window1
,
a_block_tile1
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_load_lds_window1
,
b_block_tile1
);
LocalPrefill
(
a_store_lds_window0
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
b_store_lds_window0
,
b_global_load_tile
,
b_element_func
);
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
);
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
);
block_gemm
(
c_block_tile
,
a_prefetch_tile0
,
b_prefetch_tile0
);
block_gemm
(
c_block_tile
,
a_block_tile0
,
b_block_tile0
);
}
__builtin_amdgcn_sched_barrier
(
0
);
// pong
{
block_sync_lds
();
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_load_lds_window0
,
a_
prefetch
_tile0
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_load_lds_window0
,
b_
prefetch
_tile0
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_load_lds_window0
,
a_
block
_tile0
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_load_lds_window0
,
b_
block
_tile0
);
LocalPrefill
(
a_store_lds_window1
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
b_store_lds_window1
,
b_global_load_tile
,
b_element_func
);
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
);
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
);
block_gemm
(
c_block_tile
,
a_prefetch_tile1
,
b_prefetch_tile1
);
block_gemm
(
c_block_tile
,
a_block_tile1
,
b_block_tile1
);
}
iCounter
-=
2
;
}
...
...
@@ -236,38 +267,34 @@ struct GemmPipelineAGmemBGmemCRegV1
// 3
{
block_sync_lds
();
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_load_lds_window1
,
a_prefetch_tile1
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_load_lds_window1
,
b_prefetch_tile1
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_load_lds_window1
,
a_block_tile1
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_load_lds_window1
,
b_block_tile1
);
LocalPrefill
(
a_store_lds_window0
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
b_store_lds_window0
,
b_global_load_tile
,
b_element_func
);
block_gemm
(
c_block_tile
,
a_prefetch_tile0
,
b_prefetch_tile0
);
__builtin_amdgcn_sched_barrier
(
0
);
block_gemm
(
c_block_tile
,
a_block_tile0
,
b_block_tile0
);
}
// 2
{
block_sync_lds
();
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_load_lds_window0
,
a_prefetch_tile0
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_load_lds_window0
,
b_prefetch_tile0
);
block_gemm
(
c_block_tile
,
a_prefetch_tile1
,
b_prefetch_tile1
);
__builtin_amdgcn_sched_barrier
(
0
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_load_lds_window0
,
a_block_tile0
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_load_lds_window0
,
b_block_tile0
);
block_gemm
(
c_block_tile
,
a_block_tile1
,
b_block_tile1
);
}
//1
{
block_gemm
(
c_block_tile
,
a_
prefetch
_tile0
,
b_
prefetch
_tile0
);
block_gemm
(
c_block_tile
,
a_
block
_tile0
,
b_
block
_tile0
);
}
//tail 2
}
else
{
{
block_sync_lds
();
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_load_lds_window1
,
a_prefetch_tile1
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_load_lds_window1
,
b_prefetch_tile1
);
block_gemm
(
c_block_tile
,
a_prefetch_tile0
,
b_prefetch_tile0
);
__builtin_amdgcn_sched_barrier
(
0
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_load_lds_window1
,
a_block_tile1
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_load_lds_window1
,
b_block_tile1
);
block_gemm
(
c_block_tile
,
a_block_tile0
,
b_block_tile0
);
}
// 2
{
block_gemm
(
c_block_tile
,
a_
prefetch
_tile1
,
b_
prefetch
_tile1
);
block_gemm
(
c_block_tile
,
a_
block
_tile1
,
b_
block
_tile1
);
}
}
return
c_block_tile
;
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
View file @
405c05c0
...
...
@@ -97,16 +97,18 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
// Returns the number of LDS bytes reserved for the A operand.
//
// The byte size of one A LDS block (element-space size of the A LDS block
// descriptor times sizeof(ADataType)) is rounded up to a multiple of 16 and
// then doubled. The pipeline places two A buffers back to back at that aligned
// stride, so an unaligned, single-buffer size would under-reserve LDS.
//
// Only one definition of smem_size_a is kept: declaring it twice in the same
// scope is a redefinition and does not compile.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
{
    constexpr index_t smem_size_a =
        integer_least_multiple(sizeof(typename Problem::ADataType) *
                                   MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
                               16) *
        2;
    return smem_size_a;
}
// Returns the number of LDS bytes reserved for the B operand.
//
// The byte size of one B LDS block (element-space size of the B LDS block
// descriptor times sizeof(BDataType)) is rounded up to a multiple of 16 and
// then doubled. The pipeline places two B buffers back to back at that aligned
// stride, so an unaligned, single-buffer size would under-reserve LDS.
//
// Only one definition of smem_size_b is kept: declaring it twice in the same
// scope is a redefinition and does not compile.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
{
    constexpr index_t smem_size_b =
        integer_least_multiple(sizeof(typename Problem::BDataType) *
                                   MakeBLdsBlockDescriptor<Problem>().get_element_space_size(),
                               16) *
        2;
    return smem_size_b;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment