Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
a8d88d8d
Commit
a8d88d8d
authored
Dec 02, 2024
by
coderfeli
Browse files
tmp before merge
parent
c7d08b7c
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
29 additions
and
28 deletions
+29
-28
example/ck_tile/03_gemm/gemm_basic.cpp
example/ck_tile/03_gemm/gemm_basic.cpp
+1
-1
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
...e/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
+8
-1
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
...e/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
+19
-10
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
...line/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
+0
-15
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp
+1
-1
No files found.
example/ck_tile/03_gemm/gemm_basic.cpp
View file @
a8d88d8d
...
...
@@ -58,7 +58,7 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
using
GemmEpilogue
=
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
>>
;
using
CodegenGemmTraits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
,
true
,
3
>
;
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
,
true
,
2
>
;
using
CodegenPipelineProblem
=
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
CodegenGemmShape
,
CodegenGemmTraits
>
;
using
CodegenGemmPolicy
=
ck_tile
::
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
;
...
...
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
View file @
a8d88d8d
...
...
@@ -189,9 +189,16 @@ struct BlockGemmARegBRegCRegV2
CK_TILE_DEVICE
static
constexpr
auto
MakeABlockDistribution
()
{
// M->N Warp
// using AWarpDstrEncoding = tile_distribution_encoding<
// sequence<>,
// tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>, //<32>, <2, 8>
// tuple<sequence<2, 1>>,
// tuple<sequence<0, 0>>,
// sequence<2>,
// sequence<1>>;
constexpr
auto
a_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
NWarp
>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
KIterPerWarp
>>
,
// <4, 2>, <2>
tuple
<
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
View file @
a8d88d8d
...
...
@@ -254,13 +254,17 @@ struct GemmPipelineAGmemBGmemCRegV1
// local prefetch 0
// a b register tile for lds prefetch & mfma
using
ALdsTileDistr
=
decltype
(
Policy
::
template
BlockGemm
<
Problem
>
::
MakeABlockDistribution
());
using
BLdsTileDistr
=
decltype
(
Policy
::
template
BlockGemm
<
Problem
>
::
MakeBBlockDistribution
());
using
ALdsTile
=
decltype
(
make_static_distributed_tensor
<
ADataType
>
(
ALdsTileDistr
{}
));
using
BLdsTile
=
decltype
(
make_static_distributed_tensor
<
BDataType
>
(
BLdsTileDistr
{}
));
constexpr
auto
ALdsTileDistr
=
decltype
(
Policy
::
template
BlockGemm
<
Problem
>
::
MakeABlockDistribution
())
{}
;
constexpr
auto
BLdsTileDistr
=
decltype
(
Policy
::
template
BlockGemm
<
Problem
>
::
MakeBBlockDistribution
())
{}
;
using
ALdsTile
=
decltype
(
make_static_distributed_tensor
<
ADataType
>
(
ALdsTileDistr
));
using
BLdsTile
=
decltype
(
make_static_distributed_tensor
<
BDataType
>
(
BLdsTileDistr
));
ALdsTile
a_block_tile0
;
BLdsTile
b_block_tile0
;
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_lds_window0
,
a_block_tile0
);
auto
a_lds_ld_window0
=
make_tile_window_linear
(
a_lds_block0
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
},
ALdsTileDistr
);
auto
a_lds_ld_window1
=
make_tile_window_linear
(
a_lds_block1
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
},
ALdsTileDistr
);
// Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0);
load_tile
(
a_block_tile0
,
a_lds_ld_window0
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_lds_window0
,
b_block_tile0
);
// LDS write 1
...
...
@@ -281,7 +285,8 @@ struct GemmPipelineAGmemBGmemCRegV1
// ping
{
block_sync_lds
();
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_lds_window1
,
a_block_tile1
);
// Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1);
load_tile
(
a_block_tile1
,
a_lds_ld_window1
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_lds_window1
,
b_block_tile1
);
LocalPrefill
(
a_lds_window0
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
b_lds_window0
,
b_global_load_tile
,
b_element_func
);
...
...
@@ -293,7 +298,8 @@ struct GemmPipelineAGmemBGmemCRegV1
// pong
{
block_sync_lds
();
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_lds_window0
,
a_block_tile0
);
// Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0);
load_tile
(
a_block_tile0
,
a_lds_ld_window0
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_lds_window0
,
b_block_tile0
);
LocalPrefill
(
a_lds_window1
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
b_lds_window1
,
b_global_load_tile
,
b_element_func
);
...
...
@@ -311,7 +317,8 @@ struct GemmPipelineAGmemBGmemCRegV1
// 3
{
block_sync_lds
();
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_lds_window1
,
a_block_tile1
);
// Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1);
load_tile
(
a_block_tile1
,
a_lds_ld_window1
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_lds_window1
,
b_block_tile1
);
LocalPrefill
(
a_lds_window0
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
b_lds_window0
,
b_global_load_tile
,
b_element_func
);
...
...
@@ -320,7 +327,8 @@ struct GemmPipelineAGmemBGmemCRegV1
// 2
{
block_sync_lds
();
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_lds_window0
,
a_block_tile0
);
// Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0);
load_tile
(
a_block_tile0
,
a_lds_ld_window0
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_lds_window0
,
b_block_tile0
);
block_gemm
(
c_block_tile
,
a_block_tile1
,
b_block_tile1
);
}
...
...
@@ -334,7 +342,8 @@ struct GemmPipelineAGmemBGmemCRegV1
// //tail 2
{
block_sync_lds
();
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_lds_window1
,
a_block_tile1
);
// Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1);
load_tile
(
a_block_tile1
,
a_lds_ld_window1
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_lds_window1
,
b_block_tile1
);
block_gemm
(
c_block_tile
,
a_block_tile0
,
b_block_tile0
);
}
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
View file @
a8d88d8d
...
...
@@ -70,21 +70,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
return
a_lds_block_desc
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeALdsBlockLinearDescriptor
()
{
using
namespace
ck_tile
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
auto
a_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kKPerBlock
/
8
>
{},
number
<
kMPerBlock
>
{},
number
<
8
>
{}),
make_tuple
(
number
<
(
kMPerBlock
)
*
8
>
{},
number
<
8
>
{},
number
<
1
>
{}),
number
<
8
>
{},
number
<
1
>
{});
return
a_lds_block_desc_0
;
}
// 3d + padding
template
<
typename
Problem
>
...
...
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp
View file @
a8d88d8d
...
...
@@ -170,7 +170,7 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kBNLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
>>
,
tuple
<
sequence
<
Impl
::
kBNLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
>>
,
// <32>, <2, 4>
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
>>
,
sequence
<
2
>
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment