Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
b4887801
Commit
b4887801
authored
Sep 29, 2024
by
carlushuang
Browse files
tmp
parent
a5670e67
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
190 additions
and
78 deletions
+190
-78
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex.hpp
...ile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex.hpp
+87
-28
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex_policy.hpp
.../fmha/pipeline/block_fmha_pipeline_qr_async_ex_policy.hpp
+28
-50
include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp
include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp
+33
-0
include/ck_tile/ops/gemm.hpp
include/ck_tile/ops/gemm.hpp
+1
-0
include/ck_tile/ops/gemm/block/block_gemm_utils.hpp
include/ck_tile/ops/gemm/block/block_gemm_utils.hpp
+41
-0
No files found.
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex.hpp
View file @
b4887801
...
...
@@ -45,6 +45,34 @@ struct BlockFmhaPipelineQRAsyncEx
static
constexpr
index_t
kK1
=
BlockFmhaShape
::
kK1
;
static
constexpr
index_t
kK0BlockLength
=
BlockFmhaShape
::
kK0BlockLength
;
static
constexpr
index_t
Block_M0
=
BlockFmhaShape
::
Block_M0
;
static
constexpr
index_t
Block_N0
=
BlockFmhaShape
::
Block_N0
;
static
constexpr
index_t
Block_K0
=
BlockFmhaShape
::
Block_K0
;
static
constexpr
index_t
BlockWarps_M0
=
BlockFmhaShape
::
BlockWarps_M0
;
static
constexpr
index_t
BlockWarps_N0
=
BlockFmhaShape
::
BlockWarps_N0
;
static
constexpr
index_t
BlockWarps_K0
=
BlockFmhaShape
::
BlockWarps_K0
;
static
constexpr
index_t
Warps_M0
=
BlockFmhaShape
::
Warps_M0
;
static
constexpr
index_t
Warps_N0
=
BlockFmhaShape
::
Warps_N0
;
static
constexpr
index_t
Warps_K0
=
BlockFmhaShape
::
Warps_K0
;
static
constexpr
index_t
Repeat_M0
=
BlockFmhaShape
::
Repeat_M0
;
static
constexpr
index_t
Repeat_N0
=
BlockFmhaShape
::
Repeat_N0
;
static
constexpr
index_t
Repeat_K0
=
BlockFmhaShape
::
Repeat_K0
;
static
constexpr
index_t
Block_M1
=
BlockFmhaShape
::
Block_M1
;
static
constexpr
index_t
Block_N1
=
BlockFmhaShape
::
Block_N1
;
static
constexpr
index_t
Block_K1
=
BlockFmhaShape
::
Block_K1
;
static
constexpr
index_t
BlockWarps_M1
=
BlockFmhaShape
::
BlockWarps_M1
;
static
constexpr
index_t
BlockWarps_N1
=
BlockFmhaShape
::
BlockWarps_N1
;
static
constexpr
index_t
BlockWarps_K1
=
BlockFmhaShape
::
BlockWarps_K1
;
static
constexpr
index_t
Warps_M1
=
BlockFmhaShape
::
Warps_M1
;
static
constexpr
index_t
Warps_N1
=
BlockFmhaShape
::
Warps_N1
;
static
constexpr
index_t
Warps_K1
=
BlockFmhaShape
::
Warps_K1
;
static
constexpr
index_t
Repeat_M1
=
BlockFmhaShape
::
Repeat_M1
;
static
constexpr
index_t
Repeat_N1
=
BlockFmhaShape
::
Repeat_N1
;
static
constexpr
index_t
Repeat_K1
=
BlockFmhaShape
::
Repeat_K1
;
static
constexpr
index_t
UnrollStages
=
2
;
// pipeline unroll the gemm/softmax/gemm
static
constexpr
bool
kIsGroupMode
=
Problem
::
kIsGroupMode
;
// TODO: seq_q always supports padding; hdim_q/v support multiples of the vector size (e.g. 8x)
// only seq_k padding needs special care (out-of-bounds entries of p must be set to -INF instead of zero)
...
...
@@ -176,11 +204,10 @@ struct BlockFmhaPipelineQRAsyncEx
kN0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}],
"wrong!"
);
constexpr
auto
LdsSeq
=
Policy
::
template
GetLdsBufferSequence
<
Problem
>();
// K tile in LDS
auto
k_lds_ptr
=
reinterpret_cast
<
KDataType
*>
(
smem_ptr
);
auto
k_lds_store
=
generate_tuple
(
auto
k_lds_store
=
[
&
](){
auto
k_lds_ptr
=
reinterpret_cast
<
KDataType
*>
(
smem_ptr
);
return
generate_tuple
(
[
&
](
auto
i_buf
)
{
return
make_tile_window
(
make_tensor_view
<
address_space_enum
::
lds
>
(
...
...
@@ -189,28 +216,64 @@ struct BlockFmhaPipelineQRAsyncEx
{
0
,
0
,
0
});
},
number
<
Policy
::
NumPrefetchK
>
{});
}();
auto
k_lds_Load_view
=
make_tensor_view
<
address_space_enum
::
lds
>
(
k_lds_ptr
,
Policy
::
template
MakeSmemLoadDesc_K
<
Problem
>());
auto
k_lds_load
=
make_tile_window
(
k_lds_Load_view
,
Policy
::
template
MakeSmemLoadDesc_K
<
Problem
>().
get_lengths
(),
{
0
,
0
});
auto
k_lds_load
=
[
&
](){
auto
k_lds_ptr
=
reinterpret_cast
<
KDataType
*>
(
smem_ptr
);
return
make_tile_window
(
make_tensor_view
<
address_space_enum
::
lds
>
(
k_lds_ptr
,
Policy
::
template
MakeSmemLoadDesc_K
<
Problem
>()),
Policy
::
template
MakeSmemLoadDesc_K
<
Problem
>().
get_lengths
(),
{
0
,
0
});
}();
// V tile in LDS
auto
v_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
VDataType
*>
(
smem_ptr
),
Policy
::
template
MakeSmemLoadDesc_V
<
Problem
>());
auto
v_lds_window
=
make_tile_window
(
v_lds
,
Policy
::
template
MakeSmemLoadDesc_V
<
Problem
>().
get_lengths
(),
{
0
,
0
});
auto
v_lds_store
=
[
&
](){
auto
v_lds_ptr
=
reinterpret_cast
<
VDataType
*>
(
smem_ptr
);
return
generate_tuple
(
[
&
](
auto
i_buf
)
{
return
make_tile_window
(
make_tensor_view
<
address_space_enum
::
lds
>
(
v_lds_ptr
,
Policy
::
template
MakeSmemStoreDesc_V
<
Problem
>(
i_buf
)),
Policy
::
template
MakeSmemStoreDesc_V
<
Problem
>(
i_buf
).
get_lengths
(),
{
0
,
0
,
0
});
},
number
<
Policy
::
NumPrefetchV
>
{});
}();
auto
v_lds_load
=
[
&
](){
auto
v_lds_ptr
=
reinterpret_cast
<
VDataType
*>
(
smem_ptr
);
return
make_tile_window
(
make_tensor_view
<
address_space_enum
::
lds
>
(
v_lds_ptr
,
Policy
::
template
MakeSmemLoadDesc_V
<
Problem
>()),
Policy
::
template
MakeSmemLoadDesc_V
<
Problem
>().
get_lengths
(),
{
0
,
0
});
}();
// reduction function for softmax
const
auto
f_max
=
[](
auto
e0
,
auto
e1
)
{
return
max
(
e0
,
e1
);
};
const
auto
f_sum
=
[](
auto
e0
,
auto
e1
)
{
return
e0
+
e1
;
};
// Block GEMM
constexpr
auto
gemm_0
=
Policy
::
template
GetBlockGemm_0
<
Problem
>();
constexpr
auto
gemm_1
=
Policy
::
template
GetBlockGemm_1
<
Problem
>();
constexpr
auto
warp_gemm_0
=
Policy
::
template
GetWarpGemm_0
<
Problem
>();
constexpr
auto
warp_gemm_1
=
Policy
::
template
GetWarpGemm_1
<
Problem
>();
auto
gemm_0
=
[
&
](){
constexpr
index_t
total_repeats
=
Repeat_M0
*
Repeat_N0
*
Repeat_K0
;
// n*k*m, more relaxed ds_read
static_for
<
0
,
total_repeats
,
1
>
{}(
[
&
](
auto
i_r
){
constexpr
index_t
i_m
=
i_r
%
Repeat_M0
;
constexpr
index_t
i_k
=
(
i_r
/
Repeat_M0
)
%
Repeat_K0
;
constexpr
index_t
i_n
=
i_r
/
(
Repeat_M0
*
Repeat_K0
);
}
);
};
auto
q_dram_window
=
make_tile_window
(
q_dram_block_window_tmp
.
get_bottom_tensor_view
(),
auto
q_dram_window
=
make_tile_window
_raw
(
q_dram_block_window_tmp
.
get_bottom_tensor_view
(),
q_dram_block_window_tmp
.
get_window_lengths
(),
q_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeGlobalDesc_Q
<
Problem
>());
q_dram_window
.
init_raw
();
// TODO: we use async copy for K, which is implemented with inline asm;
// as a side effect, we have to use inline asm for q as well
...
...
@@ -221,12 +284,8 @@ struct BlockFmhaPipelineQRAsyncEx
load_tile_raw
(
q
,
q_dram_window
);
__builtin_amdgcn_sched_barrier
(
0
);
using
SaccBlockTileType
=
decltype
(
gemm_0
.
MakeCBlockTile
());
auto
s_accs
=
generate_tuple
([
&
](
auto
)
{
return
SaccBlockTileType
{};
},
number
<
2
>
{});
// reduction function for softmax
const
auto
f_max
=
[](
auto
e0
,
auto
e1
)
{
return
max
(
e0
,
e1
);
};
const
auto
f_sum
=
[](
auto
e0
,
auto
e1
)
{
return
e0
+
e1
;
};
using
SaccBlockTileType
=
decltype
(
Policy
::
template
MakeBlockGemmAccTile_0
<
Problem
>());
auto
s_accs
=
generate_tuple
([
&
](
auto
)
{
return
SaccBlockTileType
{};
},
number
<
UnrollStages
>
{});
// infer Sacc, S, P, M, L, Oacc type
using
SBlockTileType
=
decltype
(
cast_tile
<
SMPLComputeDataType
>
(
s_accs
));
...
...
@@ -234,14 +293,14 @@ struct BlockFmhaPipelineQRAsyncEx
using
MLBlockTileType
=
decltype
(
block_tile_reduce
<
SMPLComputeDataType
>
(
SBlockTileType
{},
sequence
<
1
>
{},
f_max
,
SMPLComputeDataType
{
0
}));
using
OaccBlockTileType
=
decltype
(
gemm_1
.
Make
C
Block
Tile
());
using
OaccBlockTileType
=
decltype
(
Policy
::
template
MakeBlock
GemmAccTile_1
<
Problem
>
());
// init Oacc, M, L
auto
o_accs
=
generate_tuple
([
&
](
auto
)
{
return
OaccBlockTileType
{};
},
number
<
2
>
{});
auto
ms
=
generate_tuple
([
&
](
auto
)
{
return
MLBlockTileType
{};
},
number
<
2
>
{});
auto
ls
=
generate_tuple
([
&
](
auto
)
{
return
MLBlockTileType
{};
},
number
<
2
>
{});
auto
o_accs
=
generate_tuple
([
&
](
auto
)
{
return
OaccBlockTileType
{};
},
number
<
UnrollStages
>
{});
auto
ms
=
generate_tuple
([
&
](
auto
)
{
return
MLBlockTileType
{};
},
number
<
UnrollStages
>
{});
auto
ls
=
generate_tuple
([
&
](
auto
)
{
return
MLBlockTileType
{};
},
number
<
UnrollStages
>
{});
static_for
<
0
,
2
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
UnrollStages
,
1
>
{}([
&
](
auto
i
)
{
clear_tile
(
o_accs
(
i
));
set_tile
(
ms
(
i
),
-
numeric
<
SMPLComputeDataType
>::
infinity
());
clear_tile
(
ls
(
i
));
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex_policy.hpp
View file @
b4887801
...
...
@@ -102,28 +102,20 @@ struct BlockFmhaPipelineQRAsyncEx
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Get
BlockGemm_0
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
BlockGemm
AccTile
_0
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
SaccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>>
;
constexpr
auto
warp_gemm
=
GetWarpGemm_0
<
Problem
>
();
using
BlockGemmPolicy
=
BlockGemmARegBSmemCRegV2CustomPolicy
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
SaccDataType
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
decltype
(
warp_gemm
)
>
;
return
BlockGemmARegBSmemCRegV2
<
BlockGemmProblem
,
BlockGemmPolicy
>
{};
using
AccWarpDescEnc_
=
typename
decltype
(
GetWarpGemm_0
())
::
CWarpDstrEncoding
;
using
BlockTile_
=
sequence
<
Problem
::
BlockFmhaShape
::
Block_M0
,
Problem
::
BlockFmhaShape
::
Block_N0
>
;
using
BlockWarps_
=
sequence
<
Problem
::
BlockFmhaShape
::
BlockWarps_M0
,
Problem
::
BlockFmhaShape
::
BlockWarps_N0
>
;
using
WarpTile_
=
sequence
<
Problem
::
BlockFmhaShape
::
Warp_M0
,
Problem
::
BlockFmhaShape
::
Warp_N0
>
;
constexpr
auto
enc
=
make_block_gemm_acc_enc
<
AccWarpDescEnc_
,
BlockTile_
,
BlockWarps_
,
WarpTile_
>
();
constexpr
auto
dstr
=
make_static_tile_distribution
(
enc
);
auto
t
=
make_static_distributed_tensor
<
typename
Problem
::
SaccDataType
>
(
dstr
);
return
t
;
}
template
<
typename
Problem
>
...
...
@@ -451,13 +443,8 @@ struct BlockFmhaPipelineQRAsyncEx
{
if
constexpr
(
Problem
::
kHasDropout
)
{
constexpr
auto
gemm_0
=
QXPolicy
::
template
GetBlockGemm_0
<
Problem
>();
constexpr
auto
config
=
decltype
(
gemm_0
)
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
kMPerStep
=
MWarp
*
WG
::
kM
;
constexpr
index_t
kNPerStep
=
WG
::
kN
;
constexpr
index_t
kMPerStep
=
Problem
::
BlockFmhaShape
::
BlockWarps_M0
*
Problem
::
BlockFmhaShape
::
Warp_M0
;
constexpr
index_t
kNPerStep
=
Problem
::
BlockFmhaShape
::
BlockWarps_N0
*
Problem
::
BlockFmhaShape
::
Warp_N0
;
return
(
kMPerStep
+
1
)
*
kNPerStep
*
sizeof
(
uint8_t
);
}
...
...
@@ -622,29 +609,20 @@ struct BlockFmhaPipelineQRAsyncEx
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Get
BlockGemm_1
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
BlockGemm
AccTile
_1
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
PDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
OaccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN1
,
Problem
::
BlockFmhaShape
::
kK1
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
>>
;
auto
warp_gemm
=
GetWarpGemm_1
<
Problem
>
();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
warp_gemm
)
>
;
using
BlockGemmPolicy
=
BlockGemmARegBSmemCRegV2CustomPolicy
<
typename
Problem
::
PDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
OaccDataType
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
WarpGemm
>
;
return
BlockGemmARegBSmemCRegV2
<
BlockGemmProblem
,
BlockGemmPolicy
>
{};
using
AccWarpDescEnc_
=
typename
decltype
(
GetWarpGemm_1
())
::
CWarpDstrEncoding
;
using
BlockTile_
=
sequence
<
Problem
::
BlockFmhaShape
::
Block_M1
,
Problem
::
BlockFmhaShape
::
Block_N1
>
;
using
BlockWarps_
=
sequence
<
Problem
::
BlockFmhaShape
::
BlockWarps_M1
,
Problem
::
BlockFmhaShape
::
BlockWarps_N1
>
;
using
WarpTile_
=
sequence
<
Problem
::
BlockFmhaShape
::
Warp_M1
,
Problem
::
BlockFmhaShape
::
Warp_N1
>
;
constexpr
auto
enc
=
make_block_gemm_acc_enc
<
AccWarpDescEnc_
,
BlockTile_
,
BlockWarps_
,
WarpTile_
>
();
constexpr
auto
dstr
=
make_static_tile_distribution
(
enc
);
auto
t
=
make_static_distributed_tensor
<
typename
Problem
::
OaccDataType
>
(
dstr
);
return
t
;
}
};
...
...
include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp
View file @
b4887801
...
...
@@ -41,6 +41,39 @@ struct TileFmhaShape
using
VLayout
=
std
::
conditional_t
<
IsVLayoutRowMajor
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
;
// gemm-0 shapes TODO: naming?
static
constexpr
index_t
Block_M0
=
kM0
;
static
constexpr
index_t
Block_N0
=
kN0
;
static
constexpr
index_t
Block_K0
=
kK0
;
static
constexpr
index_t
BlockWarps_M0
=
Gemm0BlockWarps
::
at
(
number
<
0
>
{});
static
constexpr
index_t
BlockWarps_N0
=
Gemm0BlockWarps
::
at
(
number
<
1
>
{});
static
constexpr
index_t
BlockWarps_K0
=
Gemm0BlockWarps
::
at
(
number
<
2
>
{});
static
constexpr
index_t
Warps_M0
=
Gemm0WarpTile
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Warps_N0
=
Gemm0WarpTile
::
at
(
number
<
1
>
{});
static
constexpr
index_t
Warps_K0
=
Gemm0WarpTile
::
at
(
number
<
2
>
{});
static_assert
(
Block_M0
%
(
BlockWarps_M0
*
Warps_M0
)
==
0
);
static_assert
(
Block_N0
%
(
BlockWarps_N0
*
Warps_N0
)
==
0
);
static_assert
(
Block_K0
%
(
BlockWarps_K0
*
Warps_K0
)
==
0
);
static
constexpr
index_t
Repeat_M0
=
Block_M0
/
(
BlockWarps_M0
*
Warps_M0
);
static
constexpr
index_t
Repeat_N0
=
Block_N0
/
(
BlockWarps_N0
*
Warps_N0
);
static
constexpr
index_t
Repeat_K0
=
Block_K0
/
(
BlockWarps_K0
*
Warps_K0
);
static
constexpr
index_t
Block_M1
=
kM0
;
static
constexpr
index_t
Block_N1
=
kN1
;
static
constexpr
index_t
Block_K1
=
kK1
;
static
constexpr
index_t
BlockWarps_M1
=
Gemm1BlockWarps
::
at
(
number
<
0
>
{});
static
constexpr
index_t
BlockWarps_N1
=
Gemm1BlockWarps
::
at
(
number
<
1
>
{});
static
constexpr
index_t
BlockWarps_K1
=
Gemm1BlockWarps
::
at
(
number
<
2
>
{});
static
constexpr
index_t
Warps_M1
=
Gemm1WarpTile
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Warps_N1
=
Gemm1WarpTile
::
at
(
number
<
1
>
{});
static
constexpr
index_t
Warps_K1
=
Gemm1WarpTile
::
at
(
number
<
2
>
{});
static_assert
(
Block_M1
%
(
BlockWarps_M1
*
Warps_M1
)
==
0
);
static_assert
(
Block_N1
%
(
BlockWarps_N1
*
Warps_N1
)
==
0
);
static_assert
(
Block_K1
%
(
BlockWarps_K1
*
Warps_K1
)
==
0
);
static
constexpr
index_t
Repeat_M1
=
Block_M1
/
(
BlockWarps_M1
*
Warps_M1
);
static
constexpr
index_t
Repeat_N1
=
Block_N1
/
(
BlockWarps_N1
*
Warps_N1
);
static
constexpr
index_t
Repeat_K1
=
Block_K1
/
(
BlockWarps_K1
*
Warps_K1
);
};
template
<
typename
BlockTile_
,
// sequence<...
...
...
include/ck_tile/ops/gemm.hpp
View file @
b4887801
...
...
@@ -21,6 +21,7 @@
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_utils.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp"
...
...
include/ck_tile/ops/gemm/block/block_gemm_utils.hpp
0 → 100644
View file @
b4887801
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
template
<
typename
AccWarpDescEnc
,
typename
BlockTile
,
// seq<M, N>
typename
BlockWarps
,
typename
WarpTile
>
CK_TILE_DEVICE_HOST
constexpr
auto
make_block_gemm_acc_enc
()
{
constexpr
index_t
Block_M
=
BlockTile
::
at
(
number
<
0
>
{});
constexpr
index_t
Block_N
=
BlockTile
::
at
(
number
<
1
>
{});
constexpr
index_t
BlockWarps_M
=
BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
BlockWarps_N
=
BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
Warp_M
=
WarpTile
::
at
(
number
<
0
>
{});
constexpr
index_t
Warp_N
=
WarpTile
::
at
(
number
<
1
>
{});
constexpr
index_t
Repeat_M
=
Block_M
/
(
BlockWarps_M
*
Warp_M
);
constexpr
index_t
Repeat_N
=
Block_N
/
(
BlockWarps_N
*
Warp_N
);
constexpr
auto
c_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Repeat_M
,
BlockWarps_M
>
,
sequence
<
Repeat_N
,
BlockWarps_N
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
c_block_outer_dstr_encoding
,
AccWarpDescEnc
{});
return
c_block_dstr_encode
;
}
}
// namespace ck_tile
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment