Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
c8c016dd
Commit
c8c016dd
authored
Dec 13, 2024
by
aska-0096
Browse files
Merge branch 'develop' of
https://github.com/ROCm/composable_kernel
into update_cka8w8
parents
e8ca3daf
4e731776
Changes
399
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
4650 additions
and
110 deletions
+4650
-110
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
...ude/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
+421
-0
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp
+125
-0
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp
...e/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp
+33
-0
include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp
include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp
+77
-6
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp
...s/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp
+651
-0
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp
...sed_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp
+831
-0
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp
...s/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp
+354
-0
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp
...ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp
+46
-0
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp
...e/ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp
+48
-0
include/ck_tile/ops/gemm.hpp
include/ck_tile/ops/gemm.hpp
+5
-0
include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp
.../ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp
+652
-0
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
+258
-0
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+73
-0
include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp
include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp
+36
-0
include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp
+310
-0
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp
...ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp
+111
-0
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
...tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
+383
-0
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
.../ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
+233
-103
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp
...le/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp
+2
-0
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
...e/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
+1
-1
No files found.
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
0 → 100644
View file @
c8c016dd
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include <string>
#include <type_traits>
// clang-format off
// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
// each slice map to one expert, and one expert can have multiple slices
// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// tok-0 tok-1 tok-2 tok-3 tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 5]]
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
//
// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
//
// * Note on token_id_per_expert/sorted_token_ids_ptr data:
// currently we do not have topk information from the data of token_id_per_expert/sorted_token_ids_ptr.
// In some cases(like smooth-quant), we need topk information to indexing into tokens quant from
// different expert smooth quant. So we modify the number stored inside token_id_per_expert/sorted_token_ids_ptr
//
// 32bit 0........23 24.....31 bit
// (data) -> (token_id | topk_id)
// low 24 bit is for token id, top 8 bit is for topk id
//
// the input after smooth-quant is [token, topk, hidden_dim], originally it is [token, hidden_dim]
// the input scale for token is [topk, token, 1], the smooth-quant scale for first gemm is [expert, interm_dim]
//
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
// * length is (max_num_tokens_padded + block_size - 1) / block_size
//
// num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr : [7]
//
// * different from vLLM
// 1) token_id stored in sorted_token_ids_ptr is actual token_id, not token_id*top_K expanded id
// 2)need sorted_weight_ptr
// 3) use num_sorted_tiles_ptr, already divided by M_a
//
// * below used for indexing
// 1) sorted_token_ids_ptr [max_num_tokens_padded]
// 2) sorted_weight_ptr
// 3) sorted_expert_ids_ptr
// 4)num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
//
// max_num_tokens_padded: opk_ids.numel() + num_experts * (block_size - 1)
//
// [indexing implementation-2]
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// tok-0 tok-1 tok-2 tok-3 tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
//
// we generate original rol/col id as
// topk_rc_ids : [[0, 5, A], [1, 6, B], [2, 7, C], [3, 8, D], [4, 9, E]]
// let x be one element of above, we can get:
// tpok_row_id(token_id) = x % num_tokens(5)
// tpok_col_id(expert_Id) = x / num_tokens
// topk_row_id/col_id can be used to access original topk_ids/topk_weight
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 5, 5]]
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// we can get permuted_rc_ids:
// [[0], [2, 3, 4], [1, 8], [5, 6, 7, D, 9], [], [A, B, C, E]]
//
//
// clang-format on
//
namespace
ck_tile
{
// m: num_tokens (or token*input-batch)
// k: intermediate_size
// n: intermediate_size used between 2 FC (TP slice this)
// e: num expert
// if doing pre-shuffle
// nr : n / Block_Nr
// kr : k / Block_Kr
// w : fattened 1d wave buffer
struct
FusedMoeGemmHostArgs
{
const
void
*
a_ptr
;
// [m, k], input token
const
void
*
a_scale_ptr
;
// [m, 1], token scale
const
void
*
g_ptr
;
// [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
const
void
*
d_ptr
;
// [e, n, k], pre-shuffle([e, nr, kr, w])
const
void
*
g_scale_ptr
;
// [e, 1, n], gate(up) scale
const
void
*
d_scale_ptr
;
// [e, 1, k], down scale
const
void
*
y_smooth_scale_ptr
;
// [e, 1, n], smooth-quant-scale for 2nd gemm input
void
*
o_ptr
;
// [m, k], output token
const
void
*
sorted_token_ids_ptr
;
// [max_num_tokens_padded]
const
void
*
sorted_weight_ptr
;
// [max_num_tokens_padded]
const
void
*
sorted_expert_ids_ptr
;
// [(max_num_tokens_padded + block_size - 1) / block_size]
const
void
*
num_sorted_tiles_ptr
;
// [1]
index_t
hidden_size
;
// k
index_t
intermediate_size
;
// n / TP, for Gate. if Gate+Up, Down need divide by 2
index_t
num_tokens
;
// input number of tokens for current iteration
index_t
num_experts
;
// number of groups
index_t
topk
;
// need this?
index_t
stride_token
;
// for input/output, stride for each row, should >= hidden_size
};
// This is scatter/gather b2b group-gemm
template
<
typename
Partitioner_
,
typename
Pipeline_
,
typename
Epilogue_
>
struct
FusedMoeGemmKernel
{
using
Partitioner
=
remove_cvref_t
<
Partitioner_
>
;
using
Pipeline
=
remove_cvref_t
<
Pipeline_
>
;
using
Epilogue
=
remove_cvref_t
<
Epilogue_
>
;
// TODO: not used
// static constexpr index_t kBlockPerCu = Pipeline::kBlockPerCu;
// static_assert(kBlockPerCu > 0);
using
BlockShape
=
typename
Pipeline
::
BlockShape
;
// this is FusedMoeGemmShape
static
constexpr
index_t
BlockSize_
=
BlockShape
::
BlockSize
;
using
ADataType
=
typename
Pipeline
::
Problem
::
ADataType
;
using
GDataType
=
typename
Pipeline
::
Problem
::
GDataType
;
using
DDataType
=
typename
Pipeline
::
Problem
::
DDataType
;
using
AccDataType
=
typename
Pipeline
::
Problem
::
AccDataType
;
using
ODataType
=
typename
Pipeline
::
Problem
::
ODataType
;
using
AScaleDataType
=
typename
Pipeline
::
Problem
::
AScaleDataType
;
using
GScaleDataType
=
typename
Pipeline
::
Problem
::
GScaleDataType
;
using
DScaleDataType
=
typename
Pipeline
::
Problem
::
DScaleDataType
;
using
YSmoothScaleDataType
=
typename
Pipeline
::
Problem
::
YSmoothScaleDataType
;
using
TopkWeightDataType
=
typename
Pipeline
::
Problem
::
TopkWeightDataType
;
using
IndexDataType
=
typename
Pipeline
::
Problem
::
IndexDataType
;
using
YDataType
=
typename
Pipeline
::
Problem
::
YDataType
;
using
Traits
=
typename
Pipeline
::
Problem
::
Traits
;
static
constexpr
bool
UseUK
=
true
;
static
constexpr
bool
IsGateOnly
=
Traits
::
IsGateOnly
;
static
constexpr
bool
UseSmoothQuant
=
Traits
::
UseSmoothQuant
;
static
constexpr
bool
PadHiddenSize
=
Traits
::
PadHiddenSize
;
static
constexpr
bool
PadIntermediateSize
=
Traits
::
PadIntermediateSize
;
// clang-format off
template
<
typename
T
>
struct
t2s
;
template
<
>
struct
t2s
<
float
>
{
static
constexpr
const
char
*
name
=
"fp32"
;
};
template
<
>
struct
t2s
<
fp16_t
>
{
static
constexpr
const
char
*
name
=
"fp16"
;
};
template
<
>
struct
t2s
<
bf16_t
>
{
static
constexpr
const
char
*
name
=
"bf16"
;
};
template
<
>
struct
t2s
<
fp8_t
>
{
static
constexpr
const
char
*
name
=
"fp8"
;
};
template
<
>
struct
t2s
<
bf8_t
>
{
static
constexpr
const
char
*
name
=
"bf8"
;
};
template
<
>
struct
t2s
<
int8_t
>
{
static
constexpr
const
char
*
name
=
"int8"
;
};
// clang-format on
CK_TILE_HOST
static
std
::
string
GetName
()
{
#define _SS_ std::string
#define _TS_ std::to_string
// clang-format off
using
S_
=
BlockShape
;
auto
prec_str
=
[
&
]
()
{
std
::
string
base_str
=
_SS_
(
t2s
<
ADataType
>::
name
);
if
(
!
std
::
is_same_v
<
ADataType
,
GDataType
>
)
{
base_str
+=
_SS_
(
"_"
)
+
_SS_
(
t2s
<
GDataType
>::
name
);
}
return
base_str
;
}();
return
_SS_
(
"fused_moe_"
)
+
_SS_
(
prec_str
)
+
"_"
+
_TS_
(
S_
::
Block_M0
)
+
"x"
+
_TS_
(
S_
::
Block_N0
)
+
"x"
+
_TS_
(
S_
::
Block_K0
)
+
"x"
+
_TS_
(
S_
::
Block_N1
)
+
"_"
+
_TS_
(
S_
::
WarpPerBlock_M0
)
+
"x"
+
_TS_
(
S_
::
WarpPerBlock_N0
)
+
"x"
+
_TS_
(
S_
::
WarpPerBlock_K0
)
+
"_"
+
_TS_
(
S_
::
Warp_M0
)
+
"x"
+
_TS_
(
S_
::
Warp_N0
)
+
"x"
+
_TS_
(
S_
::
Warp_K0
)
+
"_"
+
_SS_
(
Pipeline
::
name
);
#undef _SS_
#undef _TS_
// clang-format on
}
struct
FusedMoeGemmKargs
{
const
void
*
a_ptr
;
// [m, k], input token
const
void
*
a_scale_ptr
;
// [m, 1], token scale
const
void
*
g_ptr
;
// [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
const
void
*
d_ptr
;
// [e, n, k], pre-shuffle([e, nr, kr, w])
const
void
*
g_scale_ptr
;
// [e, 1, n], gate(up) scale
const
void
*
d_scale_ptr
;
// [e, 1, k], down scale
const
void
*
y_smooth_scale_ptr
;
// [e, 1, n], smooth-quant-scale for 2nd gemm input
void
*
o_ptr
;
// [m, k], output token
const
void
*
sorted_token_ids_ptr
;
const
void
*
sorted_weight_ptr
;
const
void
*
sorted_expert_ids_ptr
;
const
void
*
num_sorted_tiles_ptr
;
index_t
hidden_size
;
// k
index_t
intermediate_size
;
// n / TP, for Gate. if Gate+Up, Down need divide by 2
index_t
num_tokens
;
// input number of tokens for current iteration
index_t
num_experts
;
// number of groups
index_t
topk
;
// need this?
index_t
stride_token
;
// for input/output, stride for each row, should >= hidden_size
};
// TODO: switch karg based on
using
Kargs
=
FusedMoeGemmKargs
;
using
Hargs
=
FusedMoeGemmHostArgs
;
CK_TILE_HOST
static
constexpr
Kargs
MakeKargs
(
const
Hargs
&
hargs
)
{
// TODO: hargs/kargs not guranteed to be the same
return
bit_cast
<
Kargs
>
(
hargs
);
}
CK_TILE_HOST
static
constexpr
auto
GridSize
(
const
Hargs
&
hargs
)
{
constexpr
index_t
block_m
=
BlockShape
::
Block_M0
;
int
max_num_tokens_padded
=
hargs
.
topk
*
hargs
.
num_tokens
+
hargs
.
num_experts
*
block_m
-
hargs
.
topk
;
// printf("xxx max_num_tokens_padded:%d\n", max_num_tokens_padded);
return
Partitioner
::
GridSize
(
max_num_tokens_padded
,
hargs
.
intermediate_size
);
}
CK_TILE_HOST
static
constexpr
auto
BlockSize
()
{
return
dim3
(
BlockSize_
);
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
Pipeline
::
GetSmemSize
();
}
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
{
if
constexpr
(
UseUK
)
{
__shared__
CK_TILE_LDS_ADDR
ADataType
smem
[
GetSmemSize
()];
IndexDataType
num_sorted_tiles
=
__builtin_amdgcn_readfirstlane
(
*
reinterpret_cast
<
const
IndexDataType
*>
(
kargs
.
num_sorted_tiles_ptr
));
num_sorted_tiles
=
num_sorted_tiles
/
BlockShape
::
Block_M0
;
const
auto
[
sorted_tile_id
,
intermediate_tile_id
]
=
Partitioner
{}(
num_sorted_tiles
,
kargs
.
intermediate_size
);
// if(threadIdx.x == 0)
// printf("bid:%d,%d, num_sorted_tiles:%d, sorted_tile_id:%d(%d),
// intermediate_tile_id:%d\n", static_cast<int>(blockIdx.x),
// static_cast<int>(blockIdx.y), num_sorted_tiles, sorted_tile_id, sorted_tile_id >=
// num_sorted_tiles? 1 : 0, intermediate_tile_id);
if
(
sorted_tile_id
>=
num_sorted_tiles
)
return
;
Pipeline
{}(
kargs
,
smem
,
sorted_tile_id
,
intermediate_tile_id
);
}
else
{
// allocate LDS
// __shared__ char smem_ptr[GetSmemSize()];
IndexDataType
num_sorted_tiles
=
__builtin_amdgcn_readfirstlane
(
*
reinterpret_cast
<
const
IndexDataType
*>
(
kargs
.
num_sorted_tiles_ptr
));
constexpr
index_t
hidden_radio_0
=
IsGateOnly
?
1
:
2
;
index_t
nr_0
=
kargs
.
intermediate_size
/
BlockShape
::
Block_Nr0
;
index_t
kr_0
=
kargs
.
hidden_size
/
BlockShape
::
Block_Kr0
;
index_t
nr_1
=
kargs
.
hidden_size
/
BlockShape
::
Block_Nr1
;
// should be same as kr_0
index_t
kr_1
=
kargs
.
intermediate_size
/
BlockShape
::
Block_Kr1
;
// should be same as nr_0
index_t
expert_stride_0
=
kargs
.
intermediate_size
*
hidden_radio_0
*
kargs
.
hidden_size
;
index_t
expert_stride_1
=
kargs
.
intermediate_size
*
kargs
.
hidden_size
;
__shared__
CK_TILE_LDS_ADDR
ADataType
smem
[
GetSmemSize
()];
// note this is in unit of tile, need multiple tile size to get the index
const
auto
[
sorted_tile_id
,
intermediate_tile_id
]
=
Partitioner
{}(
num_sorted_tiles
,
kargs
.
intermediate_size
);
if
(
sorted_tile_id
>=
num_sorted_tiles
)
return
;
const
IndexDataType
expert_id
=
__builtin_amdgcn_readfirstlane
(
reinterpret_cast
<
const
IndexDataType
*>
(
kargs
.
sorted_expert_ids_ptr
)[
sorted_tile_id
]);
// index along intermediate_size
// index_t hidden_idx = __builtin_amdgcn_readfirstlane(intermediate_tile_id *
// BlockShape::Block_N0);
index_t
interm_idx_nr
=
__builtin_amdgcn_readfirstlane
(
intermediate_tile_id
*
BlockShape
::
Block_Nr0
);
const
auto
a_coord
=
Pipeline
::
GetACoord
();
// 2d thread offset, [i_row, i_col]
const
auto
sorted_token_id
=
a_coord
[
number
<
0
>
{}]
+
sorted_tile_id
*
BlockShape
::
Block_M0
;
index_t
token_id
=
reinterpret_cast
<
const
index_t
*>
(
kargs
.
sorted_token_ids_ptr
)[
sorted_token_id
];
auto
topk_weight
=
reinterpret_cast
<
const
TopkWeightDataType
*>
(
kargs
.
sorted_weight_ptr
)[
sorted_token_id
];
const
auto
a_window
=
[
&
]()
{
// A is already pre-padded in previous kernel
const
ADataType
*
a_ptr
=
reinterpret_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
);
const
auto
a_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
a_ptr
,
make_tuple
(
kargs
.
num_tokens
,
kargs
.
hidden_size
),
make_tuple
(
kargs
.
stride_token
,
1
),
number
<
Pipeline
::
kAlignmentA
>
{},
number
<
1
>
{});
// gather is here use indexing transform
const
auto
a_gather_view_
=
transform_tensor_view
(
a_view_
,
make_tuple
(
make_indexing_transform
(
kargs
.
num_tokens
,
token_id
),
make_pass_through_transform
(
kargs
.
hidden_size
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
const
auto
a_window_
=
make_tile_window
(
a_gather_view_
,
make_tuple
(
number
<
BlockShape
::
Block_M0
>
{},
number
<
BlockShape
::
Block_K0
>
{}),
{
0
,
0
});
return
a_window_
;
}();
// TODO: gtile using NSub to have less register pressure
const
auto
g_window
=
[
&
]()
{
const
GDataType
*
g_ptr
=
reinterpret_cast
<
const
GDataType
*>
(
kargs
.
g_ptr
)
+
static_cast
<
long_index_t
>
(
expert_id
)
*
expert_stride_0
+
interm_idx_nr
*
kr_0
*
BlockShape
::
Block_W0
;
const
auto
g_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
g_ptr
,
make_tuple
(
nr_0
,
kr_0
,
number
<
BlockShape
::
Block_W0
>
{}),
make_tuple
(
kr_0
*
BlockShape
::
Block_W0
,
number
<
BlockShape
::
Block_W0
>
{},
1
),
number
<
Pipeline
::
kAlignmentG
>
{},
number
<
1
>
{});
const
auto
g_view_1_
=
pad_tensor_view
(
g_view_
,
make_tuple
(
number
<
BlockShape
::
Block_Nr0
>
{},
number
<
BlockShape
::
Block_Kr0
>
{},
number
<
BlockShape
::
Block_W0
>
{}),
sequence
<
PadIntermediateSize
,
PadHiddenSize
,
0
>
{});
const
auto
g_window_
=
make_tile_window
(
g_view_1_
,
make_tuple
(
number
<
BlockShape
::
Block_Nr0
>
{},
number
<
BlockShape
::
Block_Kr0
>
{},
number
<
BlockShape
::
Block_W0
>
{}),
{
0
,
0
,
0
});
return
g_window_
;
}();
const
auto
d_window
=
[
&
]()
{
const
DDataType
*
d_ptr
=
reinterpret_cast
<
const
DDataType
*>
(
kargs
.
d_ptr
)
+
static_cast
<
long_index_t
>
(
expert_id
)
*
expert_stride_1
+
interm_idx_nr
*
BlockShape
::
Block_W1
;
// note interm_idx_nr is along the gemm-k dim of 2nd gemm
const
auto
d_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
d_ptr
,
make_tuple
(
nr_1
,
kr_1
,
BlockShape
::
Block_W1
),
make_tuple
(
kr_1
*
BlockShape
::
Block_W1
,
BlockShape
::
Block_W1
,
1
),
number
<
Pipeline
::
kAlignmentD
>
{},
number
<
1
>
{});
const
auto
d_view_1_
=
pad_tensor_view
(
d_view_
,
make_tuple
(
number
<
BlockShape
::
Block_Nr1
>
{},
number
<
BlockShape
::
Block_Kr1
>
{},
number
<
BlockShape
::
Block_W1
>
{}),
sequence
<
PadHiddenSize
,
PadIntermediateSize
,
0
>
{});
const
auto
d_window_
=
make_tile_window
(
d_view_1_
,
make_tuple
(
number
<
BlockShape
::
Block_Nr1
>
{},
number
<
BlockShape
::
Block_Kr1
>
{},
number
<
BlockShape
::
Block_W1
>
{}),
{
0
,
0
,
0
});
return
d_window_
;
}();
auto
o_window
=
[
&
]()
{
ODataType
*
o_ptr
=
reinterpret_cast
<
ODataType
*>
(
kargs
.
o_ptr
);
auto
o_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
,
memory_operation_enum
::
atomic_add
>
(
o_ptr
,
make_tuple
(
kargs
.
num_tokens
,
kargs
.
hidden_size
),
make_tuple
(
kargs
.
stride_token
,
1
),
number
<
Pipeline
::
kAlignmentO
>
{},
number
<
1
>
{});
// gather is here
auto
o_scatter_view_
=
transform_tensor_view
(
o_view_
,
make_tuple
(
make_indexing_transform
(
kargs
.
num_tokens
,
token_id
),
make_pass_through_transform
(
kargs
.
hidden_size
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
auto
o_window_
=
make_tile_window
(
o_scatter_view_
,
make_tuple
(
number
<
BlockShape
::
Block_M1
>
{},
number
<
BlockShape
::
Block_N1
>
{}),
{
0
,
0
});
return
o_window_
;
}();
// do compute yeah
Pipeline
{}(
a_window
,
g_window
,
d_window
,
o_window
,
topk_weight
,
smem
,
kargs
.
hidden_size
,
kargs
.
intermediate_size
,
kargs
.
stride_token
);
}
}
};
}
// namespace ck_tile
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp
0 → 100644
View file @
c8c016dd
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
/*
tensors:
1. act (A): input feature map
2. gate (G): B matrix for first gemm, output will do activation(Silu)
3. up (U): B matrix for first gemm
4. down (D): B matrix for second gemm
N1
/ \
+----------+ |
| Down | |
x----------x |
hidden hidden K1 | | |
N0 N0 x----------x |
| +------x-----x------+------x-----x------+ | | |
dim | | Gate | | | Up | | | | | |
contiguous | | | | | | | | | | |
| | | | | | | | | | |
v +------x-----x------+------x-----x------+ +----------+ V
K0 | | | | | contiguous
/ \ v v v v |
+---------+ +------x-----x------+------x-----x------+ |
M0 | A | | | | | | | | |
+---------+ +------x-----x------+------x-----x------+ |
----------> | | |
contiguous | V V
| x-----x +----------+
+------------> M1 | Y | ---------> | Out(O) |
ACT x-----x +----------+
K1 = N0 dim
* Note: Act could be Gelu/Silu/...
* Note: some model does not have Up
*/
template
<
typename
BlockTile_0_
,
typename
WarpPerBlock_0_
,
typename
WarpTile_0_
,
typename
BlockTile_1_
,
typename
WarpPerBlock_1_
,
typename
WarpTile_1_
>
struct
FusedMoeGemmShape
{
using
BlockTile_0
=
remove_cvref_t
<
BlockTile_0_
>
;
using
WarpPerBlock_0
=
remove_cvref_t
<
WarpPerBlock_0_
>
;
using
WarpTile_0
=
remove_cvref_t
<
WarpTile_0_
>
;
using
BlockTile_1
=
remove_cvref_t
<
BlockTile_1_
>
;
using
WarpPerBlock_1
=
remove_cvref_t
<
WarpPerBlock_1_
>
;
using
WarpTile_1
=
remove_cvref_t
<
WarpTile_1_
>
;
static
constexpr
index_t
NumWarps
=
reduce_on_sequence
(
WarpPerBlock_0
{},
multiplies
{},
number
<
1
>
{});
// TODO: we don't support half warps aound to 1 warp here
static_assert
(
NumWarps
==
reduce_on_sequence
(
WarpPerBlock_1
{},
multiplies
{},
number
<
1
>
{}));
static
constexpr
index_t
Block_M0
=
BlockTile_0
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Block_N0
=
BlockTile_0
::
at
(
number
<
1
>
{});
static
constexpr
index_t
Block_K0
=
BlockTile_0
::
at
(
number
<
2
>
{});
static
constexpr
index_t
WarpPerBlock_M0
=
WarpPerBlock_0
::
at
(
number
<
0
>
{});
static
constexpr
index_t
WarpPerBlock_N0
=
WarpPerBlock_0
::
at
(
number
<
1
>
{});
static
constexpr
index_t
WarpPerBlock_K0
=
WarpPerBlock_0
::
at
(
number
<
2
>
{});
static
constexpr
index_t
Warp_M0
=
WarpTile_0
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Warp_N0
=
WarpTile_0
::
at
(
number
<
1
>
{});
static
constexpr
index_t
Warp_K0
=
WarpTile_0
::
at
(
number
<
2
>
{});
static
constexpr
index_t
ThreadPerBlock_M0
=
Warp_M0
*
WarpPerBlock_M0
;
static
constexpr
index_t
ThreadPerBlock_N0
=
Warp_N0
*
WarpPerBlock_N0
;
static
constexpr
index_t
ThreadPerBlock_K0
=
Warp_K0
*
WarpPerBlock_K0
;
static_assert
(
Block_M0
%
ThreadPerBlock_M0
==
0
);
static_assert
(
Block_N0
%
ThreadPerBlock_N0
==
0
);
static_assert
(
Block_K0
%
ThreadPerBlock_K0
==
0
);
static
constexpr
index_t
Repeat_M0
=
Block_M0
/
ThreadPerBlock_M0
;
static
constexpr
index_t
Repeat_N0
=
Block_N0
/
ThreadPerBlock_N0
;
static
constexpr
index_t
Repeat_K0
=
Block_K0
/
ThreadPerBlock_K0
;
static
constexpr
index_t
Block_M1
=
BlockTile_1
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Block_N1
=
BlockTile_1
::
at
(
number
<
1
>
{});
static
constexpr
index_t
Block_K1
=
BlockTile_1
::
at
(
number
<
2
>
{});
static
constexpr
index_t
WarpPerBlock_M1
=
WarpPerBlock_1
::
at
(
number
<
0
>
{});
static
constexpr
index_t
WarpPerBlock_N1
=
WarpPerBlock_1
::
at
(
number
<
1
>
{});
static
constexpr
index_t
WarpPerBlock_K1
=
WarpPerBlock_1
::
at
(
number
<
2
>
{});
static
constexpr
index_t
Warp_M1
=
WarpTile_1
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Warp_N1
=
WarpTile_1
::
at
(
number
<
1
>
{});
static
constexpr
index_t
Warp_K1
=
WarpTile_1
::
at
(
number
<
2
>
{});
static
constexpr
index_t
ThreadPerBlock_M1
=
Warp_M1
*
WarpPerBlock_M1
;
static
constexpr
index_t
ThreadPerBlock_N1
=
Warp_N1
*
WarpPerBlock_N1
;
static
constexpr
index_t
ThreadPerBlock_K1
=
Warp_K1
*
WarpPerBlock_K1
;
static_assert
(
Block_M1
%
ThreadPerBlock_M1
==
0
);
static_assert
(
Block_N1
%
ThreadPerBlock_N1
==
0
);
static_assert
(
Block_K1
%
ThreadPerBlock_K1
==
0
);
static
constexpr
index_t
Repeat_M1
=
Block_M1
/
ThreadPerBlock_M1
;
static
constexpr
index_t
Repeat_N1
=
Block_N1
/
ThreadPerBlock_N1
;
static
constexpr
index_t
Repeat_K1
=
Block_K1
/
ThreadPerBlock_K1
;
static
constexpr
index_t
BlockSize
=
warpSize
*
NumWarps
;
// some assert
static_assert
(
Block_M0
==
Block_M1
);
static_assert
(
Block_N0
==
Block_K1
||
(
Block_N0
/
2
)
==
Block_K1
);
// Gate Only or Gate+Up
// pre-shuffle tile size compute (assume only for B matrix)
// we flatten the each wave tile to a 1d linear tensor(at model loading time)
// e.g. originally we have Block_N*Block_K tile size, after pre-shuffle
// we can have Block_Nr*Block_Kr*Block_W, where Block_W is Warp_N*Warp_K,
// and Block_Nr=Block_N/Warp_N, Block_Kr=Block_K/Warp_K
static
constexpr
index_t
Block_W0
=
Warp_N0
*
Warp_K0
;
static
constexpr
index_t
Block_Nr0
=
Block_N0
/
Warp_N0
;
static
constexpr
index_t
Block_Kr0
=
Block_K0
/
Warp_K0
;
static
constexpr
index_t
Block_W1
=
Warp_N1
*
Warp_K1
;
static
constexpr
index_t
Block_Nr1
=
Block_N1
/
Warp_N1
;
static
constexpr
index_t
Block_Kr1
=
Block_K1
/
Warp_K1
;
static_assert
(
Block_W0
==
Block_W1
);
// static_assert(Block_Nr0 == Block_Kr1);
};
}
// namespace ck_tile
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp
0 → 100644
View file @
c8c016dd
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace
ck_tile
{
template
<
typename
BlockShape_
>
struct
FusedMoeGemmTilePartitioner_Linear
{
// FusedMoeGemmShape
using
BlockShape
=
ck_tile
::
remove_cvref_t
<
BlockShape_
>
;
static
constexpr
const
char
*
name
=
"lin"
;
CK_TILE_DEVICE
auto
operator
()(
ck_tile
::
index_t
/*num_sorted_tiles*/
,
ck_tile
::
index_t
/*intermediate_size*/
)
{
index_t
i_n
=
blockIdx
.
x
;
index_t
i_m
=
blockIdx
.
y
;
return
ck_tile
::
make_tuple
(
i_m
,
i_n
);
}
CK_TILE_HOST
static
constexpr
auto
GridSize
(
index_t
max_tokens
,
index_t
intermediate_size
)
{
// TODO: this may need tuning
index_t
ms
=
ck_tile
::
integer_divide_ceil
(
max_tokens
,
BlockShape
::
Block_M0
);
index_t
ns
=
ck_tile
::
integer_divide_ceil
(
intermediate_size
,
BlockShape
::
Block_N0
);
return
dim3
(
ns
,
ms
,
1
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp
View file @
c8c016dd
...
...
@@ -12,20 +12,77 @@
namespace
ck_tile
{
#define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \
static_cast<uint32_t>(((token_id_)&0x00ffffff) | (((topk_id_)&0xff) << 24))
// clang-format off
// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
// each slice map to one expert, and one expert can have multiple slices
// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// tok-0 tok-1 tok-2 tok-3 tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 5]]
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
//
// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
//
// * Note on token_id_per_expert/sorted_token_ids_ptr data:
// currently we do not have topk information from the data of token_id_per_expert/sorted_token_ids_ptr.
// In some cases(like smooth-quant), we need topk information to indexing into tokens quant from
// different expert smooth quant. So we modify the number stored inside token_id_per_expert/sorted_token_ids_ptr
//
// 32bit 0........23 24.....31 bit
// (data) -> (token_id | topk_id)
// low 24 bit is for token id, top 8 bit is for topk id
//
// the input after smooth-quant is [topk, token, hidden_dim], originally it is [token, hidden_dim]
// the input scale for token is [topk, token, 1], the smooth-quant scale for first gemm is [expert, interm_dim]
//
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
// * length is (max_num_tokens_padded + block_size - 1) / block_size
//
// num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr : [7]
//
// * different from vLLM
// 1) token_id stored in sorted_token_ids_ptr is actual token_id, not token_id*top_K expanded id
// 2)need sorted_weight_ptr
// 3) use num_sorted_tiles_ptr, already divided by M_a
//
// * below used for indexing
// 1) sorted_token_ids_ptr [max_num_tokens_padded]
// 2) sorted_weight_ptr
// 3) sorted_expert_ids_ptr
// 4)num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
//
// max_num_tokens_padded: opk_ids.numel() + num_experts * (block_size - 1)
struct
MoeSortingHostArgs
{
const
void
*
p_topk_ids
;
const
void
*
p_weights
;
const
void
*
p_topk_ids
;
// [token, topk]
const
void
*
p_weights
;
// [token, topk]
void
*
p_sorted_token_ids
;
void
*
p_sorted_weights
;
void
*
p_sorted_expert_ids
;
void
*
p_total_tokens_post_pad
;
// we fused the setzero of output of fused-moe buffer
// set this pointer to nullptr will skip this operation
void
*
p_moe_buf
;
index_t
tokens
;
index_t
unit_size
;
index_t
unit_size
;
// this is the M_a of fused-moe kernel
index_t
num_experts
;
index_t
topk
;
index_t
moe_buf_bytes
;
index_t
moe_buf_bytes
;
// byte size of p_moe_buf
};
template
<
typename
Problem_
>
...
...
@@ -183,7 +240,13 @@ struct MoeSortingKernel
index_t
expert_id
=
topk_id
[
i
];
index_t
rank_post_pad
=
tokens_cnts
[
calc_index
(
num_experts
,
tid
,
expert_id
)]
+
cumsum
[
expert_id
];
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
uint32_t
curr_token_id
,
curr_topk_id
;
topk_mdiv
.
divmod
(
i
,
curr_token_id
,
curr_topk_id
);
p_sorted_token_ids
[
rank_post_pad
]
=
MOE_SORTING_MOCK_ID
(
curr_token_id
,
curr_topk_id
);
#else
p_sorted_token_ids
[
rank_post_pad
]
=
topk_mdiv
.
div
(
i
);
#endif
p_sorted_weights
[
rank_post_pad
]
=
weights
[
i
];
++
tokens_cnts
[
calc_index
(
num_experts
,
tid
,
expert_id
)];
}
...
...
@@ -195,7 +258,12 @@ struct MoeSortingKernel
cumsum
[
tid
]
+
tokens_cnts
[
calc_index
(
num_experts
,
blockDim
.
x
,
tid
)];
while
(
expert_offset
<
cumsum
[
tid
+
1
])
{
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
p_sorted_token_ids
[
expert_offset
]
=
MOE_SORTING_MOCK_ID
(
prefill_token
,
topk_mdiv
.
divisor
);
#else
p_sorted_token_ids
[
expert_offset
]
=
prefill_token
;
#endif
p_sorted_weights
[
expert_offset
]
=
static_cast
<
WeightType
>
(
0.0
);
expert_offset
++
;
}
...
...
@@ -229,4 +297,7 @@ struct MoeSortingKernel
smem
);
}
};
#undef MOE_SORTING_MOCK_ID
}
// namespace ck_tile
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp
0 → 100644
View file @
c8c016dd
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
namespace
ck_tile
{
/*
This pipeline deal with a gemm(actually 2 gemm) with one very small(token), one very big(weight)
we need to design the pipeline such that all waves along gemm-N dim (gemm-m only 1 wave)
<----- gemm-N ------>
+----+----+----+----+
| w0 | w1 | w2 | w3 | gemm-m
+----+----+----+----+
*/
template
<
typename
Problem_
,
typename
Policy_
=
FusedMoeGemmPipelineFlatmmPolicy
>
struct
FusedMoeGemmPipeline_FlatmmEx
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
BlockShape
=
typename
Problem
::
BlockShape
;
// this is FusedMoeGemmShape
using
ADataType
=
typename
Problem
::
ADataType
;
using
GDataType
=
typename
Problem
::
GDataType
;
using
DDataType
=
typename
Problem
::
DDataType
;
using
AccDataType
=
typename
Problem
::
AccDataType
;
using
ODataType
=
typename
Problem
::
ODataType
;
using
AScaleDataType
=
typename
Problem
::
AScaleDataType
;
using
GScaleDataType
=
typename
Problem
::
GScaleDataType
;
using
DScaleDataType
=
typename
Problem
::
DScaleDataType
;
using
YSmoothScaleDataType
=
typename
Problem
::
YSmoothScaleDataType
;
using
TopkWeightDataType
=
typename
Problem
::
TopkWeightDataType
;
using
IndexDataType
=
typename
Problem
::
IndexDataType
;
using
YDataType
=
typename
Problem
::
YDataType
;
using
Traits
=
typename
Problem
::
Traits
;
static
constexpr
bool
IsGateOnly
=
Traits
::
IsGateOnly
;
static
constexpr
bool
UseSmoothQuant
=
Traits
::
UseSmoothQuant
;
static
constexpr
bool
PadHiddenSize
=
Traits
::
PadHiddenSize
;
static
constexpr
bool
PadIntermediateSize
=
Traits
::
PadIntermediateSize
;
static
constexpr
index_t
kAlignmentA
=
Policy
::
template
GetAlignment_A
<
Problem
>();
static
constexpr
index_t
kAlignmentG
=
Policy
::
template
GetAlignment_G
<
Problem
>();
static
constexpr
index_t
kAlignmentD
=
Policy
::
template
GetAlignment_D
<
Problem
>();
static
constexpr
index_t
kAlignmentO
=
Policy
::
template
GetAlignment_O
<
Problem
>();
static
constexpr
index_t
SLD_A
=
static_cast
<
index_t
>
(
FusedMoeGemmPipelineSequencerEnum
::
SLD_A
);
static
constexpr
index_t
GLD_A
=
static_cast
<
index_t
>
(
FusedMoeGemmPipelineSequencerEnum
::
GLD_A
);
static
constexpr
index_t
GLD_B
=
static_cast
<
index_t
>
(
FusedMoeGemmPipelineSequencerEnum
::
GLD_B
);
static
constexpr
index_t
GST_O
=
static_cast
<
index_t
>
(
FusedMoeGemmPipelineSequencerEnum
::
GST_O
);
static
constexpr
index_t
kBlockPerCu
=
[]()
{
if
constexpr
(
Problem
::
kBlockPerCu
!=
-
1
)
return
Problem
::
kBlockPerCu
;
else
{
// minimize occupancy
return
2
;
}
}();
static
constexpr
const
char
*
name
=
"fused_moe_flatmm"
;
// TODO: there are multiple buffers
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize_A
()
{
return
Policy
::
template
GetSmemSize_A
<
Problem
>();
}
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
// this is the thread-offset along row/col
CK_TILE_HOST_DEVICE
static
auto
GetACoord
()
{
constexpr
auto
a_dist
=
Policy
::
template
MakeGlobalTileDistribution_A
<
Problem
>();
const
auto
a_coord
=
a_dist
.
calculate_index
();
return
a_coord
;
}
// this is the thread-offset along row/col
CK_TILE_HOST_DEVICE
static
auto
GetOCoord
()
{
constexpr
auto
o_dist
=
Policy
::
template
MakeOGlobalTileDistribution
<
Problem
>();
const
auto
o_coord
=
o_dist
.
calculate_index
();
return
o_coord
;
}
template
<
typename
AWindow
,
typename
GWindow
,
typename
DWindow
,
typename
OWindow
>
CK_TILE_DEVICE
auto
operator
()(
const
AWindow
&
a_window_
,
const
GWindow
&
g_window_
,
const
DWindow
&
d_window_
,
OWindow
&
o_window_
,
TopkWeightDataType
/*topk_weight*/
,
CK_TILE_LDS_ADDR
void
*
smem
,
index_t
hidden_size
,
index_t
intermediate_size
)
{
_Pragma
(
"clang diagnostic push"
)
_Pragma
(
"clang diagnostic ignored
\"
-Wc++20-extensions
\"
"
);
constexpr
auto
NEG1
=
number
<-
1
>
{};
constexpr
auto
I0
=
number
<
0
>
{};
constexpr
auto
I1
=
number
<
1
>
{};
constexpr
auto
TRUE
=
bool_constant
<
true
>
{};
constexpr
auto
FALSE
=
bool_constant
<
false
>
{};
CK_TILE_LDS_ADDR
ADataType
*
smem_0
=
reinterpret_cast
<
CK_TILE_LDS_ADDR
ADataType
*>
(
smem
);
CK_TILE_LDS_ADDR
ADataType
*
smem_1
=
reinterpret_cast
<
CK_TILE_LDS_ADDR
ADataType
*>
(
reinterpret_cast
<
CK_TILE_LDS_ADDR
char
*>
(
smem
)
+
Policy
::
template
GetSmemSize_A
<
Problem
>());
auto
g_view
=
g_window_
.
get_bottom_tensor_view
();
auto
u_view
=
[
&
]()
{
if
constexpr
(
IsGateOnly
)
{
return
g_view
;
}
else
{
index_t
nr_0
=
intermediate_size
/
BlockShape
::
Block_Nr0
;
index_t
kr_0
=
hidden_size
/
BlockShape
::
Block_Kr0
;
const
GDataType
*
g_ptr
=
g_window_
.
get_bottom_tensor_view
().
get_buffer_view
().
p_data_
;
const
GDataType
*
u_ptr
=
g_ptr
+
(
nr_0
/
2
)
*
kr_0
*
number
<
BlockShape
::
Block_W0
>
{};
const
auto
u_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
u_ptr
,
make_tuple
(
nr_0
,
kr_0
,
number
<
BlockShape
::
Block_W0
>
{}),
make_tuple
(
kr_0
*
BlockShape
::
Block_W0
,
number
<
BlockShape
::
Block_W0
>
{},
1
),
number
<
kAlignmentG
>
{},
number
<
1
>
{});
const
auto
u_view_1_
=
pad_tensor_view
(
u_view_
,
make_tuple
(
number
<
BlockShape
::
Block_Nr0
>
{},
number
<
BlockShape
::
Block_Kr0
>
{},
number
<
BlockShape
::
Block_W0
>
{}),
sequence
<
PadIntermediateSize
,
PadHiddenSize
,
0
>
{});
return
u_view_1_
;
}
}();
auto
a_win
=
make_tile_window_linear
(
a_window_
,
Policy
::
template
MakeGlobalTileDistribution_A
<
Problem
>());
auto
g_win
=
make_tile_window_linear
(
g_window_
,
Policy
::
template
MakeGlobalTileDistribution_G
<
Problem
>(),
sequence
<
0
,
1
,
1
>
{});
auto
d_win
=
make_tile_window_linear
(
d_window_
,
Policy
::
template
MakeGlobalTileDistribution_D
<
Problem
>(),
sequence
<
0
,
1
,
1
>
{});
auto
o_win
=
make_tile_window_linear
(
o_window_
,
Policy
::
template
MakeGlobalTileDistribution_O
<
Problem
>());
using
g_thread_type
=
decltype
(
load_tile
(
g_win
));
using
d_thread_type
=
decltype
(
load_tile
(
d_win
));
using
WarpGemm0
=
decltype
(
Policy
::
template
GetWarpGemm0
<
Problem
>());
using
WarpGemm1
=
decltype
(
Policy
::
template
GetWarpGemm1
<
Problem
>());
auto
warp_gemm_0
=
WarpGemm0
{};
auto
warp_gemm_1
=
WarpGemm1
{};
// issues_warps_lanes
auto
a_sst_win0
=
make_tile_window
(
make_tensor_view
<
address_space_enum
::
lds
>
(
smem_0
,
Policy
::
template
MakeLdsStoreDesc_A
<
Problem
>()),
Policy
::
template
MakeLdsStoreDesc_A
<
Problem
>().
get_lengths
(),
{
0
,
0
,
0
});
auto
a_sst_win1
=
make_tile_window
(
make_tensor_view
<
address_space_enum
::
lds
>
(
smem_1
,
Policy
::
template
MakeLdsStoreDesc_A
<
Problem
>()),
Policy
::
template
MakeLdsStoreDesc_A
<
Problem
>().
get_lengths
(),
{
0
,
0
,
0
});
// m*k
auto
a_sld_win0
=
[
&
]()
{
using
WG
=
WarpGemm0
;
constexpr
auto
a_outer_dstr_enc
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
BlockShape
::
Repeat_M0
,
BlockShape
::
WarpPerBlock_M0
>
,
sequence
<
BlockShape
::
Repeat_K0
>>
,
tuple
<
sequence
<
1
>>
,
tuple
<
sequence
<
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
a_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
a_outer_dstr_enc
,
typename
WG
::
AWarpDstrEncoding
{});
return
make_tile_window_linear
(
make_tensor_view
<
address_space_enum
::
lds
>
(
smem_0
,
Policy
::
template
MakeLdsLoadDesc_A
<
Problem
>()),
Policy
::
template
MakeLdsLoadDesc_A
<
Problem
>().
get_lengths
(),
{
0
,
0
},
make_static_tile_distribution
(
a_block_dstr_encode
));
}();
// m*k
auto
a_sld_win1
=
[
&
]()
{
using
WG
=
WarpGemm0
;
constexpr
auto
a_outer_dstr_enc
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
BlockShape
::
Repeat_M0
,
BlockShape
::
WarpPerBlock_M0
>
,
sequence
<
BlockShape
::
Repeat_K0
>>
,
tuple
<
sequence
<
1
>>
,
tuple
<
sequence
<
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
a_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
a_outer_dstr_enc
,
typename
WG
::
AWarpDstrEncoding
{});
return
make_tile_window_linear
(
make_tensor_view
<
address_space_enum
::
lds
>
(
smem_1
,
Policy
::
template
MakeLdsLoadDesc_A
<
Problem
>()),
Policy
::
template
MakeLdsLoadDesc_A
<
Problem
>().
get_lengths
(),
{
0
,
0
},
make_static_tile_distribution
(
a_block_dstr_encode
));
}();
auto
bridge_sst_win
=
[
&
]()
{
return
make_tile_window
(
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
YDataType
*>
(
smem
),
Policy
::
template
MakeBridgeLdsStoreDesc
<
Problem
>()),
Policy
::
template
MakeBridgeLdsStoreDesc
<
Problem
>().
get_lengths
(),
{
0
,
0
});
}();
auto
bridge_sld_win
=
[
&
]()
{
return
make_tile_window_linear
(
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
YDataType
*>
(
smem
),
Policy
::
template
MakeBridgeLdsLoadDesc
<
Problem
>()),
Policy
::
template
MakeBridgeLdsLoadDesc
<
Problem
>().
get_lengths
(),
{
0
,
0
},
Policy
::
template
MakeYTileDistribution
<
Problem
>());
}();
// also OK with C array, 2 register buffer
statically_indexed_array
<
g_thread_type
,
2
>
gs
;
constexpr
auto
issues_a
=
number
<
a_win
.
get_num_of_access
()
>
{};
constexpr
auto
issues_g
=
number
<
g_win
.
get_num_of_access
()
>
{};
// constexpr auto issues_d = number<d_win.get_num_of_access()>{};
// constexpr auto issues_o = number<o_win.get_num_of_access()>{};
constexpr
auto
issues_gemm0
=
number
<
BlockShape
::
Repeat_M0
*
BlockShape
::
Repeat_N0
*
BlockShape
::
Repeat_K0
*
warp_gemm_0
.
get_num_of_access
()
>
{};
constexpr
auto
issues_gemm1
=
number
<
BlockShape
::
Repeat_M1
*
BlockShape
::
Repeat_N1
*
BlockShape
::
Repeat_K1
*
warp_gemm_1
.
get_num_of_access
()
>
{};
// constexpr auto issues_sld_a = number<a_sld_win0.get_num_of_access()>{};
const
index_t
num_blocks_k0
=
(
hidden_size
+
BlockShape
::
Block_K0
-
1
)
/
BlockShape
::
Block_K0
;
const
index_t
num_blocks_n1
=
(
hidden_size
+
BlockShape
::
Block_N1
-
1
)
/
BlockShape
::
Block_N1
;
using
a_thread_type
=
decltype
(
load_tile
(
a_sld_win0
));
statically_indexed_array
<
a_thread_type
,
2
>
as
;
auto
gld_a
=
[
&
]
<
typename
PreNop
=
bool_constant
<
false
>>
(
auto
&
a_store_
,
auto
i_access
,
PreNop
=
{})
{
async_load_tile_raw
(
a_store_
,
a_win
,
i_access
,
PreNop
{});
};
auto
move_a
=
[
&
]()
{
move_tile_window
(
a_win
,
{
number
<
0
>
{},
number
<
BlockShape
::
Block_K0
>
{}});
};
auto
sld_a
=
[
&
](
auto
&
a_
,
auto
&
win_
,
auto
i_access
)
{
load_tile_raw
(
a_
,
win_
,
i_access
);
};
auto
gld_g
=
[
&
]
<
typename
PreNop
=
bool_constant
<
false
>>
(
auto
&
g_
,
auto
i_access
,
PreNop
=
{})
{
if
constexpr
(
IsGateOnly
)
{
// TODO: hack!
if
constexpr
(
i_access
.
value
==
0
)
{
g_win
.
bottom_tensor_view_
=
g_view
;
}
else
if
constexpr
(
i_access
.
value
==
issues_g
/
2
)
{
g_win
.
bottom_tensor_view_
=
u_view
;
}
}
load_tile_raw
(
g_
,
g_win
,
i_access
,
FALSE
,
PreNop
{});
};
auto
move_g
=
[
&
]()
{
move_tile_window
(
g_win
,
{
number
<
0
>
{},
number
<
BlockShape
::
Block_Kr0
>
{},
number
<
0
>
{}});
};
statically_indexed_array
<
d_thread_type
,
2
>
ds
;
auto
gld_d
=
[
&
]
<
typename
PreNop
=
bool_constant
<
false
>>
(
auto
&
d_
,
auto
i_access
,
PreNop
=
{})
{
load_tile_raw
(
d_
,
d_win
,
i_access
,
FALSE
,
PreNop
{});
};
auto
move_d
=
[
&
]()
{
// d move along gemm-n
move_tile_window
(
d_win
,
{
number
<
BlockShape
::
Block_N1
>
{},
number
<
0
>
{}});
};
auto
atomic_add_o
=
[
&
]
<
typename
PreNop
=
bool_constant
<
false
>>
(
auto
&
o_
,
auto
i_access
,
PreNop
=
{})
{
update_tile_raw
(
o_win
,
o_
,
i_access
,
TRUE
,
PreNop
{});
};
auto
acc_0
=
Policy
::
template
MakeCBlockTile_Gemm0
<
Problem
>();
auto
acc_1s
=
generate_tuple
(
[
&
](
auto
)
{
return
Policy
::
template
MakeCBlockTile_Gemm1
<
Problem
>();
},
number
<
2
>
{});
// clang-format off
auto
gemm_0
=
[
&
]
<
typename
PostNop
=
bool_constant
<
false
>>
(
auto
&
t_c
,
auto
&
t_a
,
auto
&
t_b
,
auto
i_access
,
PostNop
=
{})
{
using
WarpGemm
=
remove_cvref_t
<
decltype
(
warp_gemm_0
)
>
;
constexpr
auto
repeat_sub
=
WarpGemm
::
get_num_of_access
();
constexpr
auto
repeat_m
=
BlockShape
::
Repeat_M0
;
// constexpr auto repeat_n = BlockShape::Repeat_N0;
constexpr
auto
repeat_k
=
BlockShape
::
Repeat_K0
;
// loop order n->m->k
constexpr
auto
i_sub
=
i_access
%
repeat_sub
;
constexpr
auto
i_k
=
(
i_access
/
repeat_sub
)
%
repeat_k
;
constexpr
auto
i_m
=
(
i_access
/
(
repeat_sub
*
repeat_k
))
%
repeat_m
;
constexpr
auto
i_n
=
(
i_access
/
(
repeat_sub
*
repeat_k
))
/
repeat_m
;
using
AWarpTensor
=
typename
WarpGemm
::
AWarpTensor
;
using
BWarpTensor
=
typename
WarpGemm
::
BWarpTensor
;
using
CWarpTensor
=
typename
WarpGemm
::
CWarpTensor
;
using
AWarpDstr
=
typename
WarpGemm
::
AWarpDstr
;
using
BWarpDstr
=
typename
WarpGemm
::
BWarpDstr
;
using
CWarpDstr
=
typename
WarpGemm
::
CWarpDstr
;
constexpr
auto
a_warp_y_index_zeros
=
uniform_sequence_gen_t
<
AWarpDstr
::
NDimY
,
0
>
{};
constexpr
auto
b_warp_y_index_zeros
=
uniform_sequence_gen_t
<
BWarpDstr
::
NDimY
,
0
>
{};
constexpr
auto
c_warp_y_index_zeros
=
uniform_sequence_gen_t
<
CWarpDstr
::
NDimY
,
0
>
{};
constexpr
auto
a_warp_y_lengths
=
to_sequence
(
AWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
b_warp_y_lengths
=
to_sequence
(
BWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
c_warp_y_lengths
=
to_sequence
(
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
AWarpTensor
w_a
;
w_a
.
get_thread_buffer
()
=
t_a
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
i_m
,
i_k
>
{},
a_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
a_warp_y_lengths
));
BWarpTensor
w_b
;
w_b
.
get_thread_buffer
()
=
t_b
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
i_n
,
i_k
>
{},
b_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
b_warp_y_lengths
));
CWarpTensor
w_c
;
w_c
.
get_thread_buffer
()
=
t_c
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
i_m
,
i_n
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
));
warp_gemm_0
(
w_c
,
w_a
,
w_b
,
number
<
i_sub
>
{},
PostNop
{});
t_c
.
set_y_sliced_thread_data
(
merge_sequences
(
sequence
<
i_m
,
i_n
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
),
w_c
.
get_thread_buffer
());
};
// clang-format on
// clang-format off
auto
gemm_1
=
[
&
]
<
typename
PostNop
=
bool_constant
<
false
>>
(
auto
&
t_c
,
auto
&
t_a
,
auto
&
t_b
,
auto
i_access
,
PostNop
=
{})
{
using
WarpGemm
=
remove_cvref_t
<
decltype
(
warp_gemm_1
)
>
;
constexpr
auto
repeat_sub
=
WarpGemm
::
get_num_of_access
();
constexpr
auto
repeat_m
=
BlockShape
::
Repeat_M0
;
// constexpr auto repeat_n = BlockShape::Repeat_N0;
constexpr
auto
repeat_k
=
BlockShape
::
Repeat_K0
;
// loop order n->m->k
constexpr
auto
i_sub
=
i_access
%
repeat_sub
;
constexpr
auto
i_k
=
(
i_access
/
repeat_sub
)
%
repeat_k
;
constexpr
auto
i_m
=
(
i_access
/
(
repeat_sub
*
repeat_k
))
%
repeat_m
;
constexpr
auto
i_n
=
(
i_access
/
(
repeat_sub
*
repeat_k
))
/
repeat_m
;
using
AWarpTensor
=
typename
WarpGemm
::
AWarpTensor
;
using
BWarpTensor
=
typename
WarpGemm
::
BWarpTensor
;
using
CWarpTensor
=
typename
WarpGemm
::
CWarpTensor
;
using
AWarpDstr
=
typename
WarpGemm
::
AWarpDstr
;
using
BWarpDstr
=
typename
WarpGemm
::
BWarpDstr
;
using
CWarpDstr
=
typename
WarpGemm
::
CWarpDstr
;
constexpr
auto
a_warp_y_index_zeros
=
uniform_sequence_gen_t
<
AWarpDstr
::
NDimY
,
0
>
{};
constexpr
auto
b_warp_y_index_zeros
=
uniform_sequence_gen_t
<
BWarpDstr
::
NDimY
,
0
>
{};
constexpr
auto
c_warp_y_index_zeros
=
uniform_sequence_gen_t
<
CWarpDstr
::
NDimY
,
0
>
{};
constexpr
auto
a_warp_y_lengths
=
to_sequence
(
AWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
b_warp_y_lengths
=
to_sequence
(
BWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
c_warp_y_lengths
=
to_sequence
(
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
AWarpTensor
w_a
;
w_a
.
get_thread_buffer
()
=
t_a
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
i_m
,
i_k
>
{},
a_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
a_warp_y_lengths
));
BWarpTensor
w_b
;
w_b
.
get_thread_buffer
()
=
t_b
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
i_n
,
i_k
>
{},
b_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
b_warp_y_lengths
));
CWarpTensor
w_c
;
w_c
.
get_thread_buffer
()
=
t_c
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
i_m
,
i_n
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
));
warp_gemm_1
(
w_c
,
w_a
,
w_b
,
number
<
i_sub
>
{},
PostNop
{});
t_c
.
set_y_sliced_thread_data
(
merge_sequences
(
sequence
<
i_m
,
i_n
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
),
w_c
.
get_thread_buffer
());
};
// clang-format on
_Pragma
(
"clang diagnostic pop"
);
// this gemm pipeline is designed with assumption that issues of buffer-load/ds_read can
// be hide under mfma. In other words, issues of mfma is >= memory this is true if we
// pre-shuffle B matrix, and A matrix is relatively small we prefer use multiple mfma
// paired with 1 buffer-load B matrix, to get max throughput of buffer_load. and by
// preshuffle, we always pack to dwordx4 load, and this will already extend to multiple
// mfma but that is already consumed inside warpgemm-impl. So indeed how many extra
// mfma(that can reuse the B matrix) only affected by M repeat.
auto
pipeline_gemm0
=
[
&
]()
{
constexpr
index_t
total_loops
=
issues_gemm0
;
constexpr
auto
sr
=
Policy
::
template
GetSequencer_0
<
Problem
>();
static_assert
(
sr
.
size
()
==
total_loops
);
constexpr
auto
c_sld_a_0
=
MAKE_SC
();
constexpr
auto
c_gld_a_0
=
MAKE_SC
();
constexpr
auto
c_gld_b_0
=
MAKE_SC
();
// compute buffer 1
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
gemm_0
(
acc_0
,
as
[
I0
],
gs
[
I0
],
i_issue
);
constexpr
index_t
slot
=
sr
.
at
(
i_issue
);
if
constexpr
(
slot
&
SLD_A
)
sld_a
(
as
[
I1
],
a_sld_win1
,
number
<
NEXT_SCI
(
c_sld_a_0
,
i_issue
)
>
{});
if
constexpr
(
slot
&
GLD_A
)
gld_a
(
a_sst_win0
,
number
<
NEXT_SCI
(
c_gld_a_0
,
i_issue
)
>
{});
if
constexpr
(
slot
&
GLD_B
)
gld_g
(
gs
[
I0
],
number
<
NEXT_SCI
(
c_gld_b_0
,
i_issue
)
>
{});
});
move_g
();
move_a
();
block_sync_load_raw
(
issues_a
+
issues_g
);
lds_load_fence
();
constexpr
auto
c_sld_a_1
=
MAKE_SC
();
constexpr
auto
c_gld_a_1
=
MAKE_SC
();
constexpr
auto
c_gld_b_1
=
MAKE_SC
();
// compute buffer 1
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
gemm_0
(
acc_0
,
as
[
I1
],
gs
[
I1
],
i_issue
);
constexpr
index_t
slot
=
sr
.
at
(
i_issue
);
if
constexpr
(
slot
&
SLD_A
)
sld_a
(
as
[
I0
],
a_sld_win0
,
number
<
NEXT_SCI
(
c_sld_a_1
,
i_issue
)
>
{});
if
constexpr
(
slot
&
GLD_A
)
gld_a
(
a_sst_win1
,
number
<
NEXT_SCI
(
c_gld_a_1
,
i_issue
)
>
{});
if
constexpr
(
slot
&
GLD_B
)
gld_g
(
gs
[
I1
],
number
<
NEXT_SCI
(
c_gld_b_1
,
i_issue
)
>
{});
});
move_g
();
move_a
();
block_sync_load_raw
(
issues_a
+
issues_g
);
lds_load_fence
();
};
auto
pipeline_gemm0_tail
=
[
&
]()
{
constexpr
index_t
total_loops
=
issues_gemm0
;
constexpr
auto
sr
=
Policy
::
template
GetSequencer_0
<
Problem
>();
static_assert
(
sr
.
size
()
==
total_loops
);
constexpr
auto
c_gld_b_0
=
MAKE_SC
();
// compute buffer 0
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
gemm_0
(
acc_0
,
as
[
I0
],
gs
[
I0
],
i_issue
);
constexpr
index_t
slot
=
sr
.
at
(
i_issue
);
if
constexpr
(
slot
&
GLD_B
)
gld_g
(
gs
[
I1
],
number
<
NEXT_SCI
(
c_gld_b_0
,
i_issue
)
>
{});
});
block_sync_load_raw
(
issues_g
);
sld_a
(
as
[
I1
],
a_sld_win1
,
NEG1
);
// compute buffer 1
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
constexpr
auto
last_nop
=
[
&
]()
{
if
constexpr
(
i_issue
==
(
total_loops
-
1
))
return
TRUE
;
else
return
FALSE
;
}();
gemm_0
(
acc_0
,
as
[
I1
],
gs
[
I1
],
i_issue
,
last_nop
);
// last gemm has nop
});
};
auto
y
=
Policy
::
template
MakeYBlockTile
<
Problem
>();
auto
pipeline_bridge
=
[
&
]()
{
// cast to Y data
auto
y_pre
=
cast_tile
<
YDataType
>
(
acc_0
);
store_tile
(
bridge_sst_win
,
y_pre
);
clear_tile
(
acc_1s
(
I0
));
// wave_barrier();
load_tile
(
y
,
bridge_sld_win
);
clear_tile
(
acc_1s
(
I1
));
};
// note, gemm-1 start from idx-1 to N-2 (0, 1, 2....N-1)
auto
pipeline_gemm1
=
[
&
]()
{
constexpr
index_t
total_loops
=
issues_gemm1
;
constexpr
auto
sr
=
Policy
::
template
GetSequencer_1
<
Problem
>();
static_assert
(
sr
.
size
()
==
total_loops
);
constexpr
auto
c_gld_b_0
=
MAKE_SC
();
constexpr
auto
c_gst_o_0
=
MAKE_SC
();
constexpr
auto
c_gld_b_1
=
MAKE_SC
();
constexpr
auto
c_gst_o_1
=
MAKE_SC
();
// compute buffer 0
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
gemm_1
(
acc_1s
[
I1
],
y
,
ds
[
I1
],
i_issue
);
constexpr
index_t
slot
=
sr
.
at
(
i_issue
);
if
constexpr
(
slot
&
GLD_B
)
gld_d
(
ds
[
I0
],
number
<
NEXT_SCI
(
c_gld_b_0
,
i_issue
)
>
{});
if
constexpr
(
slot
&
GST_O
)
{
auto
out
=
cast_tile
<
ODataType
>
(
acc_1s
[
I0
]);
atomic_add_o
(
out
,
number
<
NEXT_SCI
(
c_gst_o_0
,
i_issue
)
>
{});
}
});
move_d
();
// move_o();
// compute buffer 1
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
gemm_1
(
acc_1s
[
I0
],
y
,
ds
[
I0
],
i_issue
);
constexpr
index_t
slot
=
sr
.
at
(
i_issue
);
if
constexpr
(
slot
&
GLD_B
)
gld_d
(
ds
[
I1
],
number
<
NEXT_SCI
(
c_gld_b_1
,
i_issue
)
>
{});
if
constexpr
(
slot
&
GST_O
)
{
auto
out
=
cast_tile
<
ODataType
>
(
acc_1s
[
I1
]);
atomic_add_o
(
out
,
number
<
NEXT_SCI
(
c_gst_o_1
,
i_issue
)
>
{});
}
});
move_d
();
};
auto
pipeline_gemm1_head
=
[
&
]()
{
constexpr
index_t
total_loops
=
issues_gemm1
;
constexpr
auto
sr
=
Policy
::
template
GetSequencer_1
<
Problem
>();
static_assert
(
sr
.
size
()
==
total_loops
);
constexpr
auto
c_gld_b_0
=
MAKE_SC
();
// compute buffer 0
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
gemm_1
(
acc_1s
[
I0
],
y
,
ds
[
I0
],
i_issue
);
constexpr
index_t
slot
=
sr
.
at
(
i_issue
);
if
constexpr
(
slot
&
GLD_B
)
gld_d
(
ds
[
I1
],
number
<
NEXT_SCI
(
c_gld_b_0
,
i_issue
)
>
{});
});
move_d
();
};
auto
pipeline_gemm1_tail
=
[
&
]()
{
constexpr
index_t
total_loops
=
issues_gemm1
;
constexpr
auto
sr
=
Policy
::
template
GetSequencer_1
<
Problem
>();
static_assert
(
sr
.
size
()
==
total_loops
);
constexpr
auto
c_gst_o_0
=
MAKE_SC
();
// compute buffer 1
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
gemm_1
(
acc_1s
[
I1
],
y
,
ds
[
I1
],
i_issue
);
constexpr
index_t
slot
=
sr
.
at
(
i_issue
);
if
constexpr
(
slot
&
GST_O
)
{
auto
out
=
cast_tile
<
ODataType
>
(
acc_1s
[
I0
]);
atomic_add_o
(
out
,
number
<
NEXT_SCI
(
c_gst_o_0
,
i_issue
)
>
{});
}
});
{
auto
out
=
cast_tile
<
ODataType
>
(
acc_1s
[
I1
]);
atomic_add_o
(
out
,
NEG1
);
}
};
// start of pipeline
// clang-format off
gld_a
(
a_sst_win0
,
NEG1
,
TRUE
);
gld_g
(
gs
[
I0
],
NEG1
,
TRUE
);
move_a
();
move_g
();
clear_tile
(
acc_0
);
// preload for next round
gld_a
(
a_sst_win1
,
NEG1
);
gld_g
(
gs
[
I1
],
NEG1
);
// make sure a,g loaded
block_sync_load_raw
(
issues_a
+
issues_g
);
lds_load_fence
();
// we manually unroll double buffer inside hot loop
const
index_t
iters_0
=
(
num_blocks_k0
-
2
)
/
2
;
index_t
i_0
=
0
;
// (void)i_0; (void)iters_0; (void)pipeline_gemm0;
while
(
i_0
++
<
iters_0
)
{
pipeline_gemm0
();
}
pipeline_gemm0_tail
();
pipeline_bridge
();
const
index_t
iters_1
=
(
num_blocks_n1
-
2
)
/
2
;
index_t
i_1
=
0
;
// (void) i_1; (void)iters_1; (void)pipeline_gemm1;
pipeline_gemm1_head
();
while
(
i_1
++
<
iters_1
)
{
pipeline_gemm1
();
}
pipeline_gemm1_tail
();
// clang-format on
}
};
}
// namespace ck_tile
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp
0 → 100644
View file @
c8c016dd
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
#include "ck_tile/ops/flatmm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
namespace
ck_tile
{
struct
FusedMoeGemmPipelineFlatmmPolicy
{
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetAsyncCopyDwords
()
{
// TODO: always 1 dword
return
1
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignment_A
()
{
// using async
constexpr
index_t
copy_bytes
=
4
*
GetAsyncCopyDwords
();
constexpr
index_t
data_bytes
=
sizeof
(
typename
Problem
::
ADataType
);
static_assert
(
copy_bytes
%
data_bytes
==
0
);
return
copy_bytes
/
data_bytes
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignment_G
()
{
constexpr
index_t
copy_bytes
=
[
&
]()
{
return
16
;
}();
constexpr
index_t
data_bytes
=
sizeof
(
typename
Problem
::
GDataType
);
static_assert
(
copy_bytes
%
data_bytes
==
0
);
return
copy_bytes
/
data_bytes
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignment_D
()
{
constexpr
index_t
copy_bytes
=
[
&
]()
{
return
16
;
}();
constexpr
index_t
data_bytes
=
sizeof
(
typename
Problem
::
DDataType
);
static_assert
(
copy_bytes
%
data_bytes
==
0
);
return
copy_bytes
/
data_bytes
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignment_O
()
{
if
constexpr
(
Problem
::
Traits
::
OAtomic
==
1
)
{
// pack fp16/bf16 atomic
static_assert
(
sizeof
(
typename
Problem
::
ODataType
)
==
2
);
return
2
;
}
else
if
constexpr
(
Problem
::
Traits
::
OAtomic
==
2
)
{
// fp32 atomic
return
1
;
}
else
{
return
16
/
sizeof
(
typename
Problem
::
ODataType
);
}
}
template
<
typename
DataType_
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPack
()
{
// TODO: this is for 3d layout
return
16
/
sizeof
(
remove_cvref_t
<
DataType_
>
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPack_A
()
{
return
GetSmemKPack
<
typename
Problem
::
ADataType
>
();
}
// used for bridge LDS shuffle
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPack_Y
()
{
// TODO: this should match mfma layout
return
16
/
sizeof
(
typename
Problem
::
YDataType
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize_A
()
{
constexpr
auto
a_sld_desc
=
MakeLdsLoadDesc_A
<
Problem
>
();
constexpr
auto
a_sst_desc
=
MakeLdsStoreDesc_A
<
Problem
>
();
static_assert
(
a_sld_desc
.
get_element_space_size
()
==
a_sst_desc
.
get_element_space_size
());
return
a_sld_desc
.
get_element_space_size
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize_Bridge
()
{
constexpr
auto
bridge_sld_desc
=
MakeBridgeLdsLoadDesc
<
Problem
>
();
constexpr
auto
bridge_sst_desc
=
MakeBridgeLdsStoreDesc
<
Problem
>
();
static_assert
(
bridge_sld_desc
.
get_element_space_size
()
==
bridge_sst_desc
.
get_element_space_size
());
return
bridge_sld_desc
.
get_element_space_size
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
constexpr
index_t
a_lds
=
GetSmemSize_A
<
Problem
>
();
constexpr
index_t
bridge_lds
=
GetSmemSize_Bridge
<
Problem
>
();
return
max
(
a_lds
,
bridge_lds
);
}
template
<
index_t
MPerBlock
,
index_t
KPerBlock
,
index_t
NumWarps
,
index_t
Alignment
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeGlobalTileDistribution_SimpleMxK
()
{
constexpr
index_t
K_vec
=
Alignment
;
constexpr
index_t
K_rem
=
KPerBlock
/
K_vec
;
if
constexpr
(
get_warp_size
()
<
K_rem
)
{
static_assert
(
K_rem
%
get_warp_size
()
==
0
);
constexpr
index_t
K_lan
=
get_warp_size
();
// lane within same wave is along gemm-k
constexpr
index_t
K_wav
=
K_rem
/
get_warp_size
();
static_assert
(
K_wav
<=
NumWarps
,
"not not support thread has repeat along K yet"
);
constexpr
index_t
M_wav
=
NumWarps
/
K_wav
;
static_assert
(
MPerBlock
%
M_wav
==
0
,
"this tile size is too small please check"
);
constexpr
index_t
M_rep
=
MPerBlock
/
M_wav
;
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M_rep
,
M_wav
>
,
sequence
<
K_wav
,
K_lan
,
K_vec
>>
,
tuple
<
sequence
<
1
,
2
>
,
sequence
<
2
>>
,
tuple
<
sequence
<
1
,
0
>
,
sequence
<
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
2
>>
{});
}
else
{
constexpr
index_t
K_lan
=
K_rem
;
constexpr
index_t
M_lan
=
get_warp_size
()
/
K_lan
;
constexpr
index_t
M_wav
=
NumWarps
;
static_assert
(
MPerBlock
%
(
M_lan
*
M_wav
)
==
0
,
"this tile size is too small please check"
);
constexpr
index_t
M_rep
=
MPerBlock
/
(
M_lan
*
M_wav
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M_rep
,
M_wav
,
M_lan
>
,
sequence
<
K_lan
,
K_vec
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
}
// optimized version for async, not same as simple MXK dist(pay attention!!)
template
<
index_t
MPerBlock
,
index_t
KPerBlock
,
index_t
NumWarps
,
index_t
Alignment
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeGlobalTileDistribution_SimpleMxK_Async
()
{
constexpr
index_t
K_vec
=
Alignment
;
constexpr
index_t
K_rem
=
KPerBlock
/
K_vec
;
if
constexpr
(
get_warp_size
()
<=
K_rem
)
{
static_assert
(
K_rem
%
get_warp_size
()
==
0
);
constexpr
index_t
K_lan
=
get_warp_size
();
// lane within same wave is along gemm-k
constexpr
index_t
K_wav
=
K_rem
/
get_warp_size
();
static_assert
(
K_wav
<=
NumWarps
,
"do not support thread has repeat along K yet"
);
constexpr
index_t
M_wav
=
NumWarps
/
K_wav
;
static_assert
(
MPerBlock
%
M_wav
==
0
,
"this tile size is too small please check"
);
constexpr
index_t
M_rep
=
MPerBlock
/
M_wav
;
// NOTE: no swap, but hard to avoid LDS bank conflict
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M_rep
,
M_wav
>
,
sequence
<
K_wav
,
K_lan
,
K_vec
>>
,
tuple
<
sequence
<
1
,
2
>
,
sequence
<
2
>>
,
tuple
<
sequence
<
1
,
0
>
,
sequence
<
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
2
>>
{});
}
else
{
constexpr
index_t
K_lan
=
K_rem
;
constexpr
index_t
M_lan
=
get_warp_size
()
/
K_lan
;
constexpr
index_t
M_wav
=
NumWarps
;
static_assert
(
MPerBlock
%
(
M_lan
*
M_wav
)
==
0
,
"this tile size is too small please check"
);
constexpr
index_t
M_rep
=
MPerBlock
/
(
M_lan
*
M_wav
);
// NOTE: swapped for LDS load bank conflict free
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
// Note M_wave(num waves) is the fastest dim, different from sipmle 2d
// distribution
tuple
<
sequence
<
M_rep
,
M_lan
,
M_wav
>
,
sequence
<
K_lan
,
K_vec
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
}
template
<
index_t
WarpPerBlock_N_
,
index_t
WarpPerBlock_K_
,
index_t
Repeat_N_
,
index_t
Repeat_K_
,
index_t
WarpSize_
,
index_t
Alignment_
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeGlobalTileDistribution_Nr_Kr_W
()
{
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
Repeat_N_
,
WarpPerBlock_N_
>
,
sequence
<
Repeat_K_
,
WarpPerBlock_K_
>
,
sequence
<
WarpSize_
,
Alignment_
>>
,
tuple
<
sequence
<
1
,
2
>
,
sequence
<
3
>>
,
tuple
<
sequence
<
1
,
1
>
,
sequence
<
0
>>
,
sequence
<
1
,
2
,
3
>
,
sequence
<
0
,
0
,
1
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeGlobalTileDistribution_A
()
{
constexpr
index_t
Block_M_
=
Problem
::
BlockShape
::
Block_M0
;
constexpr
index_t
Block_K_
=
Problem
::
BlockShape
::
Block_K0
;
constexpr
index_t
NumWarps_
=
Problem
::
BlockShape
::
NumWarps
;
constexpr
index_t
Alignment_
=
GetAlignment_A
<
Problem
>
();
return
MakeGlobalTileDistribution_SimpleMxK_Async
<
Block_M_
,
Block_K_
,
NumWarps_
,
Alignment_
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeGlobalTileDistribution_G
()
{
constexpr
auto
PermuteEnum
=
Problem
::
Traits
::
PermuteEnum
;
// constexpr index_t hidden_radio_0 = Problem::Traits::IsGateOnly ? 1 : 2;
using
S_
=
typename
Problem
::
BlockShape
;
if
constexpr
(
PermuteEnum
==
FusedMoeGemmWeightPermuteEnum
::
b_nr_kr_waveflatten
)
{
// number<S_::WarpPerBlock_N0>{}.rrr();
// number<S_::Repeat_N0>{}.eee();
return
MakeGlobalTileDistribution_Nr_Kr_W
<
S_
::
WarpPerBlock_N0
,
S_
::
WarpPerBlock_K0
,
S_
::
Repeat_N0
,
/// hidden_radio_0,
S_
::
Repeat_K0
,
get_warp_size
(),
GetAlignment_G
<
Problem
>
()
>
();
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeGlobalTileDistribution_D
()
{
constexpr
auto
PermuteEnum
=
Problem
::
Traits
::
PermuteEnum
;
using
S_
=
typename
Problem
::
BlockShape
;
if
constexpr
(
PermuteEnum
==
FusedMoeGemmWeightPermuteEnum
::
b_nr_kr_waveflatten
)
{
return
MakeGlobalTileDistribution_Nr_Kr_W
<
S_
::
WarpPerBlock_N1
,
S_
::
WarpPerBlock_K1
,
S_
::
Repeat_N1
,
S_
::
Repeat_K1
,
get_warp_size
(),
GetAlignment_D
<
Problem
>
()
>
();
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeGlobalTileDistribution_O
()
{
using
S_
=
remove_cvref_t
<
typename
Problem
::
BlockShape
>
;
using
WarpGemm
=
remove_cvref_t
<
decltype
(
GetWarpGemm1
<
Problem
>
())
>
;
// using CDataType = typename WarpGemm::CDataType;
constexpr
auto
c_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
S_
::
Repeat_M1
,
S_
::
WarpPerBlock_M1
>
,
sequence
<
S_
::
Repeat_N1
,
S_
::
WarpPerBlock_N1
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
c_block_outer_dstr_encoding
,
typename
WarpGemm
::
CWarpDstrEncoding
{});
constexpr
auto
c_block_dstr
=
make_static_tile_distribution
(
c_block_dstr_encode
);
return
c_block_dstr
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLdsStoreDesc_A
()
{
// A async->LDS
constexpr
index_t
Block_M
=
Problem
::
BlockShape
::
Block_M0
;
constexpr
index_t
Block_K
=
Problem
::
BlockShape
::
Block_K0
;
// constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
constexpr
index_t
warpSize
=
ck_tile
::
get_warp_size
();
constexpr
index_t
NumWarps
=
Problem
::
BlockShape
::
NumWarps
;
constexpr
index_t
KPack
=
GetSmemKPack_A
<
Problem
>
();
// LDS
constexpr
index_t
KVector
=
GetAlignment_A
<
Problem
>
();
// async copy 1 dword
constexpr
index_t
KPad
=
KPack
;
// pad between warps
static_assert
(
Block_K
%
KVector
==
0
);
constexpr
index_t
LanesPerK
=
Block_K
/
KVector
;
// how many thread loading K
if
constexpr
(
LanesPerK
>=
warpSize
)
{
// need multiple waves to load K
static_assert
(
LanesPerK
%
warpSize
==
0
);
constexpr
index_t
wavesPerK
=
LanesPerK
/
warpSize
;
if
constexpr
(
wavesPerK
>
NumWarps
)
{
// TODO: need multiple issues along K to load all data
}
else
{
constexpr
index_t
wavesPerM
=
NumWarps
/
wavesPerK
;
constexpr
index_t
NumIssues
=
Block_M
/
wavesPerM
;
constexpr
auto
lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
NumIssues
>
{},
// m0
number
<
wavesPerM
>
{},
// m1
number
<
wavesPerK
>
{},
// k0
number
<
warpSize
>
{},
// k1
number
<
KVector
>
{}),
// k2
make_tuple
(
number
<
NumWarps
*
(
warpSize
*
KVector
+
KPad
)
>
{},
// m0
number
<
wavesPerK
*
(
warpSize
*
KVector
+
KPad
)
>
{},
// m1
number
<
warpSize
*
KVector
+
KPad
>
{},
// k0
number
<
KVector
>
{},
// k1
number
<
1
>
{}),
// k2
number
<
KVector
>
{},
// lds store vector(actually no explicit store)
number
<
1
>
{});
constexpr
auto
lds_block_desc_issues_warps_lanes
=
transform_tensor_descriptor
(
lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
number
<
NumIssues
>
{}),
make_merge_transform
(
make_tuple
(
number
<
wavesPerM
>
{},
number
<
wavesPerK
>
{})),
make_merge_transform
(
make_tuple
(
number
<
warpSize
>
{},
number
<
KVector
>
{}))),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
,
2
>
{},
sequence
<
3
,
4
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{}));
return
lds_block_desc_issues_warps_lanes
;
}
}
else
{
// lanes within a wave load different M but same K
static_assert
(
warpSize
%
LanesPerK
==
0
);
constexpr
index_t
LaneGroups
=
warpSize
/
LanesPerK
;
// along m
constexpr
index_t
NumIssues
=
Block_M
/
(
LaneGroups
*
NumWarps
);
constexpr
auto
lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
NumIssues
>
{},
// m0
number
<
LaneGroups
>
{},
// m1
number
<
NumWarps
>
{},
// m2
number
<
LanesPerK
>
{},
// k0
number
<
KVector
>
{}),
// k1
make_tuple
(
number
<
NumWarps
*
(
warpSize
*
KVector
+
KPad
)
>
{},
// m0
number
<
Block_K
>
{},
// m1
number
<
warpSize
*
KVector
+
KPad
>
{},
// m2
number
<
KVector
>
{},
// k0
number
<
1
>
{}),
// k1
number
<
KVector
>
{},
// lds store vector(actually no explicit store)
number
<
1
>
{});
constexpr
auto
lds_block_desc_issues_warps_lanes
=
transform_tensor_descriptor
(
lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
number
<
NumIssues
>
{}),
make_pass_through_transform
(
number
<
NumWarps
>
{}),
make_merge_transform
(
make_tuple
(
number
<
LaneGroups
>
{},
number
<
LanesPerK
>
{},
number
<
KVector
>
{}))),
make_tuple
(
sequence
<
0
>
{},
sequence
<
2
>
{},
sequence
<
1
,
3
,
4
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{}));
return
lds_block_desc_issues_warps_lanes
;
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLdsLoadDesc_A
()
{
// A async->LDS
// Note that, this descriptor is only to construct the layout inside LDS
// in real Gemm pipeline, ds_read may not follow this pattern
// (may follow that in tile_distribution)
// below code is almost the same as SmemStore dist, with difference:
// 1). modify the GuaranteedLastDimensionVectorLength of naive tensor desc
// 2). return discriptor is in NxK 2d layout
constexpr
index_t
Block_M
=
Problem
::
BlockShape
::
Block_M0
;
constexpr
index_t
Block_K
=
Problem
::
BlockShape
::
Block_K0
;
// constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
constexpr
index_t
warpSize
=
ck_tile
::
get_warp_size
();
constexpr
index_t
NumWarps
=
Problem
::
BlockShape
::
NumWarps
;
constexpr
index_t
KPack
=
GetSmemKPack_A
<
Problem
>
();
// LDS
constexpr
index_t
KVector
=
GetAlignment_A
<
Problem
>
();
// async copy 1 dword
constexpr
index_t
KPad
=
KPack
;
// pad between warps
static_assert
(
Block_K
%
KVector
==
0
);
constexpr
index_t
LanesPerK
=
Block_K
/
KVector
;
// how many thread loading K
if
constexpr
(
LanesPerK
>=
warpSize
)
{
// need multiple waves to load K
static_assert
(
LanesPerK
%
warpSize
==
0
);
constexpr
index_t
wavesPerK
=
LanesPerK
/
warpSize
;
if
constexpr
(
wavesPerK
>=
NumWarps
)
{
// TODO: need multiple issues along K to load all data
}
else
{
constexpr
index_t
wavesPerM
=
NumWarps
/
wavesPerK
;
constexpr
index_t
NumIssues
=
Block_M
/
wavesPerM
;
constexpr
auto
lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
NumIssues
>
{},
// m0
number
<
wavesPerM
>
{},
// m1
number
<
wavesPerK
>
{},
// k0
number
<
warpSize
>
{},
// k1
number
<
KVector
>
{}),
// k2
make_tuple
(
number
<
NumWarps
*
(
warpSize
*
KVector
+
KPad
)
>
{},
// m0
number
<
wavesPerK
*
(
warpSize
*
KVector
+
KPad
)
>
{},
// m1
number
<
warpSize
*
KVector
+
KPad
>
{},
// k0
number
<
KVector
>
{},
// k1
number
<
1
>
{}),
// k2
number
<
KPack
>
{},
// lds load vector
number
<
1
>
{});
constexpr
auto
lds_desc_m_k
=
transform_tensor_descriptor
(
lds_block_desc_0
,
make_tuple
(
make_merge_transform
(
make_tuple
(
number
<
NumIssues
>
{},
number
<
wavesPerM
>
{})),
make_merge_transform
(
make_tuple
(
number
<
wavesPerK
>
{},
number
<
warpSize
>
{},
number
<
KVector
>
{}))),
make_tuple
(
sequence
<
0
,
1
>
{},
sequence
<
2
,
3
,
4
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
lds_desc_m_k
;
}
}
else
{
// lanes within a wave load different M but same K
static_assert
(
warpSize
%
LanesPerK
==
0
);
constexpr
index_t
LaneGroups
=
warpSize
/
LanesPerK
;
// along m
constexpr
index_t
NumIssues
=
Block_M
/
(
LaneGroups
*
NumWarps
);
constexpr
auto
lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
NumIssues
>
{},
// m0
number
<
LaneGroups
>
{},
// m1
number
<
NumWarps
>
{},
// m2
number
<
LanesPerK
>
{},
// k0
number
<
KVector
>
{}),
// k1
make_tuple
(
number
<
NumWarps
*
(
warpSize
*
KVector
+
KPad
)
>
{},
// m0
number
<
Block_K
>
{},
// m1
number
<
warpSize
*
KVector
+
KPad
>
{},
// m2
number
<
KVector
>
{},
// k0
number
<
1
>
{}),
// k1
number
<
KPack
>
{},
// lds load vector
number
<
1
>
{});
constexpr
auto
lds_desc_m_k
=
transform_tensor_descriptor
(
lds_block_desc_0
,
make_tuple
(
make_merge_transform
(
make_tuple
(
number
<
NumIssues
>
{},
number
<
LaneGroups
>
{},
number
<
NumWarps
>
{})),
make_merge_transform
(
make_tuple
(
number
<
LanesPerK
>
{},
number
<
KVector
>
{}))),
make_tuple
(
sequence
<
0
,
1
,
2
>
{},
sequence
<
3
,
4
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
lds_desc_m_k
;
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBridgeLdsLoadDesc
()
{
constexpr
index_t
Block_M
=
Problem
::
BlockShape
::
Block_M0
;
constexpr
index_t
Block_N
=
Problem
::
BlockShape
::
Block_N0
;
constexpr
index_t
KVector
=
GetSmemKPack_Y
<
Problem
>
();
// async copy 1 dword
constexpr
index_t
KPad
=
0
;
// pad between warps
constexpr
auto
desc
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
make_tuple
(
number
<
Block_N
+
KPad
>
{},
number
<
1
>
{}),
number
<
KVector
>
{},
number
<
1
>
{});
return
desc
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBridgeLdsStoreDesc
()
{
constexpr
index_t
Block_M
=
Problem
::
BlockShape
::
Block_M0
;
constexpr
index_t
Block_N
=
Problem
::
BlockShape
::
Block_N0
;
constexpr
index_t
KVector
=
GetSmemKPack_Y
<
Problem
>
();
// async copy 1 dword
constexpr
index_t
KPad
=
0
;
// KVector; // pad between warps
constexpr
auto
desc
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
make_tuple
(
number
<
Block_N
+
KPad
>
{},
number
<
1
>
{}),
number
<
KVector
>
{},
number
<
1
>
{});
return
desc
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBridgeLdsStoreForUKDesc
()
{
constexpr
index_t
WarpPerBlock_N
=
Problem
::
BlockShape
::
WarpPerBlock_N0
;
constexpr
index_t
Repeat_N
=
Problem
::
BlockShape
::
Repeat_N0
;
constexpr
index_t
Repeat_M
=
Problem
::
BlockShape
::
Repeat_M0
;
constexpr
index_t
kAMLane
=
16
;
constexpr
index_t
kABKLane
=
4
;
constexpr
index_t
kABKPerLane
=
4
;
constexpr
index_t
KPack
=
kABKPerLane
;
constexpr
auto
lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
Repeat_M
>
{},
// m
number
<
Repeat_N
>
{},
// n
number
<
WarpPerBlock_N
>
{},
// n
number
<
kABKLane
>
{},
// n
number
<
kAMLane
>
{},
// m
number
<
KPack
>
{}),
// n
make_tuple
(
number
<
Repeat_N
*
WarpPerBlock_N
*
kABKLane
*
kAMLane
*
KPack
>
{},
// m
number
<
WarpPerBlock_N
*
kABKLane
*
kAMLane
*
KPack
>
{},
// n
number
<
kABKLane
*
kAMLane
*
KPack
>
{},
// n
number
<
kAMLane
*
KPack
>
{},
// n
number
<
KPack
>
{},
// m
number
<
1
>
{}),
// n
number
<
KPack
>
{},
// lds store vector(actually no explicit store)
number
<
1
>
{});
constexpr
auto
desc
=
transform_tensor_descriptor
(
lds_block_desc_0
,
make_tuple
(
make_merge_transform
(
make_tuple
(
number
<
Repeat_M
>
{},
number
<
kAMLane
>
{})),
make_merge_transform
(
make_tuple
(
number
<
Repeat_N
>
{},
number
<
WarpPerBlock_N
>
{},
number
<
kABKLane
>
{},
number
<
KPack
>
{}))),
make_tuple
(
sequence
<
0
,
4
>
{},
sequence
<
1
,
2
,
3
,
5
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
desc
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetWarpGemm0
()
{
using
S_
=
typename
Problem
::
BlockShape
;
// A is vgpr, B is agpr. But since we transposed, so also need swap this
// TODO: this is ugly
constexpr
auto
wg_ctrl
=
WGAttrCtlEnum
::
Raw_avv
;
// TODO: ugly
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
ADataType
,
ck_tile
::
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
GDataType
,
ck_tile
::
bf16_t
>
&&
S_
::
Warp_M0
==
32
&&
S_
::
Warp_N0
==
32
&&
S_
::
Warp_K0
==
16
)
{
return
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
<
wg_ctrl
>
,
2
>>
{};
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
ADataType
,
ck_tile
::
int8_t
>
&&
std
::
is_same_v
<
typename
Problem
::
GDataType
,
ck_tile
::
int8_t
>
&&
S_
::
Warp_M0
==
32
&&
S_
::
Warp_N0
==
32
&&
S_
::
Warp_K0
==
32
)
{
return
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
<
WarpGemmAttributeMfmaImpl_i32_32x32x16_i8
<
wg_ctrl
>
,
2
>>
{};
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSequencer_0
()
{
// this function return seq<...> used to identify gld/sld/valu... inside mfma sequence
// the purpose is to hide thoes instructions under mfma
// every value inside seq<...> is a mask, indicating a specific operation
using
S_
=
typename
Problem
::
BlockShape
;
constexpr
index_t
SLD_A
=
static_cast
<
index_t
>
(
FusedMoeGemmPipelineSequencerEnum
::
SLD_A
);
constexpr
index_t
GLD_A
=
static_cast
<
index_t
>
(
FusedMoeGemmPipelineSequencerEnum
::
GLD_A
);
constexpr
index_t
GLD_B
=
static_cast
<
index_t
>
(
FusedMoeGemmPipelineSequencerEnum
::
GLD_B
);
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
YDataType
,
ck_tile
::
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
DDataType
,
ck_tile
::
bf16_t
>
&&
S_
::
Warp_M0
==
32
&&
S_
::
Warp_N0
==
32
&&
S_
::
Warp_K0
==
16
&&
S_
::
Block_M0
==
32
&&
S_
::
Block_N0
==
512
&&
S_
::
Block_K0
==
128
&&
S_
::
Block_N1
==
128
)
{
// Total 64 instructions, 32 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async
// gld_a 8x ds_read_b128 sld_a total 64 slot :)
// clang-format off
constexpr
auto
seq_all
=
// 0 1 2 3 4 5 6 7
sequence
<
GLD_B
,
GLD_A
,
GLD_B
,
GLD_A
,
GLD_B
,
GLD_A
,
GLD_B
,
GLD_A
,
// 0
GLD_B
,
GLD_A
,
GLD_B
,
GLD_A
,
GLD_B
,
GLD_A
,
GLD_B
,
GLD_A
,
// 1
GLD_B
,
SLD_A
,
GLD_B
,
SLD_A
,
GLD_B
,
SLD_A
,
GLD_B
,
SLD_A
,
// 2
GLD_B
,
SLD_A
,
GLD_B
,
SLD_A
,
GLD_B
,
SLD_A
,
GLD_B
,
SLD_A
,
// 3
GLD_B
,
0
,
GLD_B
,
0
,
GLD_B
,
0
,
GLD_B
,
0
,
// 4
GLD_B
,
0
,
GLD_B
,
0
,
GLD_B
,
0
,
GLD_B
,
0
,
// 5
GLD_B
,
0
,
GLD_B
,
0
,
GLD_B
,
0
,
GLD_B
,
0
,
// 6
GLD_B
,
0
,
GLD_B
,
0
,
GLD_B
,
0
,
GLD_B
,
0
>
{};
// 7
return
seq_all
;
// clang-format on
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
YDataType
,
ck_tile
::
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
DDataType
,
ck_tile
::
bf16_t
>
&&
S_
::
Warp_M0
==
32
&&
S_
::
Warp_N0
==
32
&&
S_
::
Warp_K0
==
16
&&
S_
::
Block_M0
==
32
&&
S_
::
Block_N0
==
256
&&
S_
::
Block_K0
==
128
&&
S_
::
Block_N1
==
128
)
{
// Total 32 instructions, 16 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async
// gld_a 8x ds_read_b128 sld_a total 64 slot :)
// clang-format off
constexpr
auto
seq_all
=
// 0 1 2 3 4 5 6 7
sequence
<
GLD_B
,
GLD_A
,
GLD_B
,
GLD_A
,
GLD_B
,
GLD_A
,
GLD_B
,
GLD_A
,
// 0
GLD_B
,
GLD_A
,
GLD_B
,
GLD_A
,
GLD_B
,
GLD_A
,
GLD_B
,
GLD_A
,
// 1
GLD_B
,
SLD_A
,
GLD_B
,
SLD_A
,
GLD_B
,
SLD_A
,
GLD_B
,
SLD_A
,
// 2
GLD_B
,
SLD_A
,
GLD_B
,
SLD_A
,
GLD_B
,
SLD_A
,
GLD_B
,
SLD_A
>
{};
// 3
return
seq_all
;
// clang-format on
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSequencer_1
()
{
// this function return seq<...> used to identify gld/sld/valu... inside mfma sequence
// the purpose is to hide thoes instructions under mfma
// every value inside seq<...> is a mask, indicating a specific operation
using
S_
=
typename
Problem
::
BlockShape
;
constexpr
index_t
GLD_B
=
static_cast
<
index_t
>
(
FusedMoeGemmPipelineSequencerEnum
::
GLD_B
);
constexpr
index_t
GST_O
=
static_cast
<
index_t
>
(
FusedMoeGemmPipelineSequencerEnum
::
GST_O
);
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
YDataType
,
ck_tile
::
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
DDataType
,
ck_tile
::
bf16_t
>
&&
S_
::
Warp_M1
==
32
&&
S_
::
Warp_N1
==
32
&&
S_
::
Warp_K1
==
16
&&
S_
::
Block_M0
==
32
&&
S_
::
Block_N0
==
512
&&
S_
::
Block_K0
==
128
&&
S_
::
Block_N1
==
128
)
{
// Total 64 instructions, 32 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async
// gld_a 8x ds_read_b128 sld_a total 64 slot :)
// clang-format off
constexpr
auto
seq_all
=
// 0 1 2 3 4 5 6 7
sequence
<
GLD_B
,
GST_O
,
GLD_B
,
GST_O
,
GLD_B
,
GST_O
,
GLD_B
,
GST_O
,
// 0
GLD_B
,
GST_O
,
GLD_B
,
GST_O
,
GLD_B
,
GST_O
,
GLD_B
,
GST_O
,
// 1
GLD_B
,
0
,
GLD_B
,
0
,
GLD_B
,
0
,
GLD_B
,
0
,
// 2
GLD_B
,
0
,
GLD_B
,
0
,
GLD_B
,
0
,
GLD_B
,
0
,
// 3
GLD_B
,
0
,
GLD_B
,
0
,
GLD_B
,
0
,
GLD_B
,
0
,
// 4
GLD_B
,
0
,
GLD_B
,
0
,
GLD_B
,
0
,
GLD_B
,
0
,
// 5
GLD_B
,
0
,
GLD_B
,
0
,
GLD_B
,
0
,
GLD_B
,
0
,
// 6
GLD_B
,
0
,
GLD_B
,
0
,
GLD_B
,
0
,
GLD_B
,
0
>
{};
// 7
return
seq_all
;
// clang-format on
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
YDataType
,
ck_tile
::
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
DDataType
,
ck_tile
::
bf16_t
>
&&
S_
::
Warp_M1
==
32
&&
S_
::
Warp_N1
==
32
&&
S_
::
Warp_K1
==
16
&&
S_
::
Block_M0
==
32
&&
S_
::
Block_N0
==
256
&&
S_
::
Block_K0
==
128
&&
S_
::
Block_N1
==
128
)
{
// Total 64 instructions, 32 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async
// gld_a 8x ds_read_b128 sld_a total 64 slot :)
// clang-format off
constexpr
auto
seq_all
=
// 0 1 2 3 4 5 6 7
sequence
<
GLD_B
,
GST_O
,
GLD_B
,
GST_O
,
GLD_B
,
GST_O
,
GLD_B
,
GST_O
,
// 0
GLD_B
,
GST_O
,
GLD_B
,
GST_O
,
GLD_B
,
GST_O
,
GLD_B
,
GST_O
,
// 1
GLD_B
,
0
,
GLD_B
,
0
,
GLD_B
,
0
,
GLD_B
,
0
,
// 2
GLD_B
,
0
,
GLD_B
,
0
,
GLD_B
,
0
,
GLD_B
,
0
>
{};
// 3
return
seq_all
;
// clang-format on
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetWarpGemm1
()
{
using
S_
=
typename
Problem
::
BlockShape
;
constexpr
auto
wg_ctrl
=
WGAttrCtlEnum
::
Raw_avv
;
// TODO: ugly
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
YDataType
,
ck_tile
::
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
DDataType
,
ck_tile
::
bf16_t
>
&&
S_
::
Warp_M0
==
32
&&
S_
::
Warp_N0
==
32
&&
S_
::
Warp_K0
==
16
)
{
return
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
<
wg_ctrl
>
,
2
>>
{};
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
YDataType
,
ck_tile
::
int8_t
>
&&
std
::
is_same_v
<
typename
Problem
::
DDataType
,
ck_tile
::
int8_t
>
&&
S_
::
Warp_M0
==
32
&&
S_
::
Warp_N0
==
32
&&
S_
::
Warp_K0
==
32
)
{
return
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
<
WarpGemmAttributeMfmaImpl_i32_32x32x16_i8
<
wg_ctrl
>
,
2
>>
{};
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeCBlockTile_Gemm0
()
{
using
S_
=
remove_cvref_t
<
typename
Problem
::
BlockShape
>
;
using
WarpGemm
=
remove_cvref_t
<
decltype
(
GetWarpGemm0
<
Problem
>
())
>
;
using
CDataType
=
typename
WarpGemm
::
CDataType
;
constexpr
auto
c_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
S_
::
Repeat_M0
,
S_
::
WarpPerBlock_M0
>
,
sequence
<
S_
::
Repeat_N0
,
S_
::
WarpPerBlock_N0
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
c_block_outer_dstr_encoding
,
typename
WarpGemm
::
CWarpDstrEncoding
{});
constexpr
auto
c_block_dstr
=
make_static_tile_distribution
(
c_block_dstr_encode
);
auto
c_block_tensor
=
make_static_distributed_tensor
<
CDataType
>
(
c_block_dstr
);
return
c_block_tensor
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeCBlockTile_Gemm1
()
{
using
S_
=
remove_cvref_t
<
typename
Problem
::
BlockShape
>
;
using
WarpGemm
=
remove_cvref_t
<
decltype
(
GetWarpGemm1
<
Problem
>
())
>
;
using
CDataType
=
typename
WarpGemm
::
CDataType
;
constexpr
auto
c_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
S_
::
Repeat_M1
,
S_
::
WarpPerBlock_M1
>
,
sequence
<
S_
::
Repeat_N1
,
S_
::
WarpPerBlock_N1
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
c_block_outer_dstr_encoding
,
typename
WarpGemm
::
CWarpDstrEncoding
{});
constexpr
auto
c_block_dstr
=
make_static_tile_distribution
(
c_block_dstr_encode
);
auto
c_block_tensor
=
make_static_distributed_tensor
<
CDataType
>
(
c_block_dstr
);
return
c_block_tensor
;
}
// this is used as A matrix for 2nd gemm
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeYTileDistribution
()
{
using
S_
=
remove_cvref_t
<
typename
Problem
::
BlockShape
>
;
using
WarpGemm
=
remove_cvref_t
<
decltype
(
GetWarpGemm1
<
Problem
>
())
>
;
// TODO: all waves a along different N, but same M
constexpr
auto
y_outer_dstr_enc
=
tile_distribution_encoding
<
sequence
<
S_
::
WarpPerBlock_M1
>
,
tuple
<
sequence
<
S_
::
Repeat_M1
>
,
sequence
<
S_
::
Repeat_K1
>>
,
tuple
<
sequence
<
0
>>
,
tuple
<
sequence
<
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
y_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
y_outer_dstr_enc
,
typename
WarpGemm
::
AWarpDstrEncoding
{});
constexpr
auto
y_block_dstr
=
make_static_tile_distribution
(
y_block_dstr_encode
);
return
y_block_dstr
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeYBlockTile
()
{
constexpr
auto
y_block_dstr
=
MakeYTileDistribution
<
Problem
>
();
auto
y_block_tensor
=
make_static_distributed_tensor
<
typename
Problem
::
YDataType
>
(
y_block_dstr
);
return
y_block_tensor
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetUK_0
()
{
using
S_
=
typename
Problem
::
BlockShape
;
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
ADataType
,
ck_tile
::
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
GDataType
,
ck_tile
::
bf16_t
>
&&
S_
::
Block_M0
==
32
&&
S_
::
Block_N0
==
512
&&
S_
::
Block_K0
==
128
&&
S_
::
Warp_M0
==
16
&&
S_
::
Warp_N0
==
16
&&
S_
::
Warp_K0
==
32
)
{
return
Flatmm_32x512x128_1x4x1_16x16x32_BF16
{};
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
ADataType
,
ck_tile
::
fp16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
GDataType
,
ck_tile
::
fp16_t
>
&&
S_
::
Block_M0
==
32
&&
S_
::
Block_N0
==
512
&&
S_
::
Block_K0
==
128
&&
S_
::
Warp_M0
==
16
&&
S_
::
Warp_N0
==
16
&&
S_
::
Warp_K0
==
32
)
{
return
Flatmm_32x512x128_1x4x1_16x16x32_FP16
{};
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetUK_1
()
{
using
S_
=
typename
Problem
::
BlockShape
;
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
YDataType
,
ck_tile
::
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
DDataType
,
ck_tile
::
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
TopkWeightDataType
,
float
>
&&
S_
::
Block_M1
==
32
&&
S_
::
Block_N1
==
128
&&
S_
::
Block_K1
==
512
&&
S_
::
Warp_M0
==
16
&&
S_
::
Warp_N0
==
16
&&
S_
::
Warp_K0
==
32
)
{
return
FlatmmSn_32x128x512_1x4x1_16x16x32_BF16
{};
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
YDataType
,
ck_tile
::
fp16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
DDataType
,
ck_tile
::
fp16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
TopkWeightDataType
,
float
>
&&
S_
::
Block_M1
==
32
&&
S_
::
Block_N1
==
128
&&
S_
::
Block_K1
==
512
&&
S_
::
Warp_M0
==
16
&&
S_
::
Warp_N0
==
16
&&
S_
::
Warp_K0
==
32
)
{
return
FlatmmSn_32x128x512_1x4x1_16x16x32_FP16
{};
}
}
};
}
// namespace ck_tile
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp
0 → 100644
View file @
c8c016dd
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
namespace
ck_tile
{
/*
This pipeline deal with a gemm(actually 2 gemm) with one very small(token), one very big(weight)
we need to design the pipeline such that all waves along gemm-N dim (gemm-m only 1 wave)
<----- gemm-N ------>
+----+----+----+----+
| w0 | w1 | w2 | w3 | gemm-m
+----+----+----+----+
*/
template
<
typename
Problem_
,
typename
Policy_
=
FusedMoeGemmPipelineFlatmmPolicy
>
struct
FusedMoeGemmPipeline_FlatmmUk
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
BlockShape
=
typename
Problem
::
BlockShape
;
// this is FusedMoeGemmShape
using
ADataType
=
typename
Problem
::
ADataType
;
using
GDataType
=
typename
Problem
::
GDataType
;
using
DDataType
=
typename
Problem
::
DDataType
;
using
AccDataType
=
typename
Problem
::
AccDataType
;
using
ODataType
=
typename
Problem
::
ODataType
;
using
AScaleDataType
=
typename
Problem
::
AScaleDataType
;
using
GScaleDataType
=
typename
Problem
::
GScaleDataType
;
using
DScaleDataType
=
typename
Problem
::
DScaleDataType
;
using
YSmoothScaleDataType
=
typename
Problem
::
YSmoothScaleDataType
;
using
TopkWeightDataType
=
typename
Problem
::
TopkWeightDataType
;
using
IndexDataType
=
typename
Problem
::
IndexDataType
;
using
YDataType
=
typename
Problem
::
YDataType
;
using
Traits
=
typename
Problem
::
Traits
;
static
constexpr
bool
IsGateOnly
=
Traits
::
IsGateOnly
;
static
constexpr
bool
UseSmoothQuant
=
Traits
::
UseSmoothQuant
;
static
constexpr
bool
PadHiddenSize
=
Traits
::
PadHiddenSize
;
static
constexpr
bool
PadIntermediateSize
=
Traits
::
PadIntermediateSize
;
static
constexpr
index_t
kAlignmentA
=
Policy
::
template
GetAlignment_A
<
Problem
>();
static
constexpr
index_t
kAlignmentG
=
Policy
::
template
GetAlignment_G
<
Problem
>();
static
constexpr
index_t
kAlignmentD
=
Policy
::
template
GetAlignment_D
<
Problem
>();
static
constexpr
index_t
kAlignmentO
=
Policy
::
template
GetAlignment_O
<
Problem
>();
static
constexpr
index_t
SLD_A
=
static_cast
<
index_t
>
(
FusedMoeGemmPipelineSequencerEnum
::
SLD_A
);
static
constexpr
index_t
GLD_A
=
static_cast
<
index_t
>
(
FusedMoeGemmPipelineSequencerEnum
::
GLD_A
);
static
constexpr
index_t
GLD_B
=
static_cast
<
index_t
>
(
FusedMoeGemmPipelineSequencerEnum
::
GLD_B
);
static
constexpr
index_t
GST_O
=
static_cast
<
index_t
>
(
FusedMoeGemmPipelineSequencerEnum
::
GST_O
);
static
constexpr
index_t
kBlockPerCu
=
[]()
{
if
constexpr
(
Problem
::
kBlockPerCu
!=
-
1
)
return
Problem
::
kBlockPerCu
;
else
{
// minimize occupancy
return
2
;
}
}();
static
constexpr
const
char
*
name
=
"flatmm_uk"
;
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
constexpr
index_t
smem_0
=
Policy
::
template
GetUK_0
<
Problem
>().
GetSmemSize
();
constexpr
index_t
smem_1
=
Policy
::
template
GetUK_1
<
Problem
>().
GetSmemSize
();
constexpr
index_t
smem_bridge
=
BlockShape
::
Block_M0
*
BlockShape
::
Block_N0
*
sizeof
(
YDataType
);
return
max
(
smem_0
,
max
(
smem_1
,
smem_bridge
));
}
// this is the thread-offset along row/col
CK_TILE_HOST_DEVICE
static
auto
GetACoord
()
{
constexpr
auto
a_dist
=
Policy
::
template
MakeGlobalTileDistribution_A
<
Problem
>();
const
auto
a_coord
=
a_dist
.
calculate_index
();
return
a_coord
;
}
// this is the thread-offset along row/col
CK_TILE_HOST_DEVICE
static
auto
GetOCoord
()
{
constexpr
auto
o_dist
=
Policy
::
template
MakeOGlobalTileDistribution
<
Problem
>();
const
auto
o_coord
=
o_dist
.
calculate_index
();
return
o_coord
;
}
CK_TILE_DEVICE
constexpr
auto
GetNumRowCoords_A
()
{
constexpr
index_t
KLans
=
BlockShape
::
Block_K0
/
kAlignmentA
;
constexpr
index_t
MLans
=
BlockShape
::
BlockSize
/
KLans
;
constexpr
index_t
MRepeat
=
BlockShape
::
Block_M0
/
MLans
;
return
MRepeat
;
}
// TODO: properlly support scatter/gather
CK_TILE_DEVICE
auto
GetRowCoords_A
(
index_t
base_offset
)
{
constexpr
index_t
KLans
=
BlockShape
::
Block_K0
/
kAlignmentA
;
constexpr
index_t
MLans
=
BlockShape
::
BlockSize
/
KLans
;
constexpr
index_t
MRepeat
=
BlockShape
::
Block_M0
/
MLans
;
auto
base_coord
=
threadIdx
.
x
/
KLans
+
base_offset
;
array
<
index_t
,
MRepeat
>
coords
;
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
i
)
{
coords
.
at
(
i
)
=
base_coord
+
i
*
MLans
;
});
return
coords
;
}
template
<
typename
ROW_COORDS
>
CK_TILE_DEVICE
auto
GetRowID
(
const
ROW_COORDS
coords
,
const
IndexDataType
*
sorted_token_ids_ptr
)
{
constexpr
index_t
n_size
=
coords
.
size
();
array
<
index_t
,
n_size
>
row_ids
;
static_for
<
0
,
n_size
,
1
>
{}([
&
](
auto
i
)
{
row_ids
.
at
(
i
)
=
sorted_token_ids_ptr
[
coords
[
i
]];
// base_coord + i * MLans;
});
return
row_ids
;
}
template
<
typename
ROW_COORDS
>
CK_TILE_DEVICE
auto
GetWeightScale
(
const
ROW_COORDS
coords
,
const
TopkWeightDataType
*
sorted_weight_ptr
)
{
constexpr
index_t
n_size
=
coords
.
size
();
array
<
TopkWeightDataType
,
n_size
>
w
;
static_for
<
0
,
n_size
,
1
>
{}([
&
](
auto
i
)
{
w
.
at
(
i
)
=
sorted_weight_ptr
[
coords
[
i
]];
// base_coord + i * MLans;
});
return
w
;
}
// TODO: this row id is before shuffle atomic, need use acc distribution
CK_TILE_DEVICE
auto
GetRowCoords_O
(
index_t
base_offset
)
{
constexpr
index_t
MLanes
=
BlockShape
::
Warp_M1
;
constexpr
index_t
Repeat_M
=
BlockShape
::
Repeat_M1
;
auto
base_coord
=
threadIdx
.
x
%
MLanes
+
base_offset
;
array
<
index_t
,
Repeat_M
>
coords
;
static_for
<
0
,
Repeat_M
,
1
>
{}([
&
](
auto
i
)
{
coords
.
at
(
i
)
=
base_coord
+
i
*
MLanes
;
});
return
coords
;
}
template
<
typename
Karg
>
CK_TILE_DEVICE
auto
operator
()(
const
Karg
&
kargs
,
CK_TILE_LDS_ADDR
void
*
smem
,
index_t
sorted_tile_id
,
index_t
intermediate_tile_id
)
{
constexpr
index_t
hidden_radio_0
=
IsGateOnly
?
1
:
2
;
ck_tile
::
index_t
shared_intermediate_size_0
=
kargs
.
intermediate_size
;
ck_tile
::
index_t
shared_intermediate_size_1
=
kargs
.
intermediate_size
/
hidden_radio_0
;
index_t
nr_0
=
shared_intermediate_size_0
/
BlockShape
::
Warp_N0
;
// divide N in W
index_t
kr_0
=
kargs
.
hidden_size
/
BlockShape
::
Warp_K0
;
// divide K in W
index_t
nr_1
=
kargs
.
hidden_size
/
BlockShape
::
Warp_N1
;
index_t
kr_1
=
shared_intermediate_size_1
/
BlockShape
::
Warp_K1
;
const
IndexDataType
expert_id
=
__builtin_amdgcn_readfirstlane
(
reinterpret_cast
<
const
IndexDataType
*>
(
kargs
.
sorted_expert_ids_ptr
)[
sorted_tile_id
]);
index_t
expert_stride_0
=
shared_intermediate_size_0
*
kargs
.
hidden_size
;
index_t
expert_stride_1
=
shared_intermediate_size_1
*
kargs
.
hidden_size
;
// nr*kr*w
index_t
interm_idx_nr0
=
__builtin_amdgcn_readfirstlane
(
intermediate_tile_id
*
BlockShape
::
Block_Nr0
);
// intermediate_tile_id * Block_N / (N in W)
index_t
interm_idx_kr1
=
__builtin_amdgcn_readfirstlane
(
intermediate_tile_id
*
BlockShape
::
Block_Kr1
);
// intermediate_tile_id * Block_N / (N in W)
auto
row_coords_a
=
GetRowCoords_A
(
sorted_tile_id
*
BlockShape
::
Block_M0
);
auto
row_ids_a
=
GetRowID
(
row_coords_a
,
reinterpret_cast
<
const
IndexDataType
*>
(
kargs
.
sorted_token_ids_ptr
));
auto
a_coords
=
generate_tuple
(
[
&
](
auto
i
)
{
return
row_ids_a
[
i
]
*
kargs
.
stride_token
+
threadIdx
.
x
%
(
BlockShape
::
Block_K0
/
kAlignmentA
)
*
kAlignmentA
;
},
number
<
row_ids_a
.
size
()
>
{});
auto
a_res
=
make_wave_buffer_resource
(
reinterpret_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
),
kargs
.
num_tokens
*
kargs
.
stride_token
*
sizeof
(
ADataType
));
auto
g_win
=
[
&
]()
{
const
GDataType
*
g_ptr
=
reinterpret_cast
<
const
GDataType
*>
(
kargs
.
g_ptr
)
+
static_cast
<
long_index_t
>
(
expert_id
)
*
expert_stride_0
+
interm_idx_nr0
*
kr_0
*
BlockShape
::
Block_W0
;
auto
g_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
g_ptr
,
make_tuple
(
nr_0
,
kr_0
,
number
<
BlockShape
::
Block_W0
>
{}),
make_tuple
(
kr_0
*
BlockShape
::
Block_W0
,
number
<
BlockShape
::
Block_W0
>
{},
1
),
number
<
kAlignmentG
>
{},
number
<
1
>
{});
auto
g_window_
=
make_tile_window_linear_raw
(
g_view_
,
make_tuple
(
number
<
BlockShape
::
Block_Nr0
>
{},
number
<
BlockShape
::
Block_Kr0
>
{},
number
<
BlockShape
::
Block_W0
>
{}),
{
0
,
0
,
0
},
Policy
::
template
MakeGlobalTileDistribution_G
<
Problem
>(),
sequence
<
0
,
1
,
1
>
{});
return
g_window_
;
}();
auto
g_res
=
g_win
.
get_bottom_tensor_view
().
get_buffer_view
().
cached_buf_res_
;
auto
g_coords
=
generate_tuple
([
&
](
auto
i
)
{
return
g_win
.
cached_coords_
[
i
].
get_offset
();
},
number
<
decltype
(
g_win
)
::
NumAccess_NonLinear
>
{});
const
auto
d_win
=
[
&
]()
{
const
DDataType
*
d_ptr
=
reinterpret_cast
<
const
DDataType
*>
(
kargs
.
d_ptr
)
+
static_cast
<
long_index_t
>
(
expert_id
)
*
expert_stride_1
+
interm_idx_kr1
*
BlockShape
::
Block_W1
;
// note interm_idx_nr0 is along the gemm-k dim of 2nd gemm
const
auto
d_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
d_ptr
,
make_tuple
(
nr_1
,
kr_1
,
BlockShape
::
Block_W1
),
make_tuple
(
kr_1
*
BlockShape
::
Block_W1
,
BlockShape
::
Block_W1
,
1
),
number
<
kAlignmentD
>
{},
number
<
1
>
{});
const
auto
d_window_
=
make_tile_window_linear_raw
(
d_view_
,
make_tuple
(
number
<
BlockShape
::
Block_Nr1
>
{},
number
<
BlockShape
::
Block_Kr1
>
{},
number
<
BlockShape
::
Block_W1
>
{}),
{
0
,
0
,
0
},
Policy
::
template
MakeGlobalTileDistribution_D
<
Problem
>(),
sequence
<
0
,
1
,
1
>
{});
return
d_window_
;
}();
auto
d_res
=
d_win
.
get_bottom_tensor_view
().
get_buffer_view
().
cached_buf_res_
;
// TODO: load D order is N0.K0...127, N64.K0...127, N0.K128...255, N64.K128...255
// block-k=512, block-n=128
// wg |<----- W_ ----->|
// Nr(2)*Nw(4)* Kr *Kr0(4)*Kr1(4) * [Kl(4)*Nl(16)*Kv(8)]->one issue
// y p y y p p y
// 1 2 0(imm)
auto
d_coords
=
[
&
]()
{
constexpr
index_t
Nr_
=
2
;
constexpr
index_t
Nw_
=
4
;
constexpr
index_t
Kr0_
=
4
;
constexpr
index_t
Kr1_
=
4
;
constexpr
index_t
Kl_
=
4
;
constexpr
index_t
Nl_
=
16
;
constexpr
index_t
Kv_
=
8
;
constexpr
index_t
W_
=
Kl_
*
Nl_
*
Kv_
;
constexpr
index_t
num_offsets_
=
Nr_
*
Kr0_
;
index_t
base_os_
=
(
threadIdx
.
x
%
64
)
*
Kv_
+
(
threadIdx
.
x
/
64
)
*
shared_intermediate_size_1
*
Nl_
;
// Kr0_ * Kr1_ * W_;
return
generate_tuple
(
[
&
](
auto
i
)
{
constexpr
auto
i_nr_
=
number
<
i
%
Nr_
>
{};
constexpr
auto
i_kr0_
=
number
<
i
/
Nr_
>
{};
return
i_nr_
*
shared_intermediate_size_1
*
Nw_
*
Nl_
+
i_kr0_
*
Kr1_
*
W_
+
base_os_
;
},
number
<
num_offsets_
>
{});
}();
auto
o_coords
=
generate_tuple
(
[
&
](
auto
i
)
{
return
row_ids_a
[
i
]
*
kargs
.
stride_token
+
threadIdx
.
x
%
(
BlockShape
::
Block_N1
/
kAlignmentO
)
*
kAlignmentO
;
},
number
<
row_ids_a
.
size
()
>
{});
auto
o_flags
=
generate_tuple
([
&
](
auto
i
)
{
return
cmp_lt_to_exec
(
row_ids_a
[
i
],
kargs
.
num_tokens
);
},
number
<
row_ids_a
.
size
()
>
{});
auto
bridge_sst_win
=
[
&
]()
{
constexpr
auto
desc_
=
Policy
::
template
MakeBridgeLdsStoreForUKDesc
<
Problem
>();
constexpr
auto
dist_
=
Policy
::
template
GetUK_0
<
Problem
>().
MakeCBlockDist
();
return
make_tile_window_linear
(
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
YDataType
*>
(
smem
),
desc_
),
desc_
.
get_lengths
(),
{
0
,
0
},
dist_
);
}();
auto
o_res
=
make_wave_buffer_resource
(
reinterpret_cast
<
const
ODataType
*>
(
kargs
.
o_ptr
),
kargs
.
num_tokens
*
kargs
.
stride_token
*
sizeof
(
ODataType
));
auto
row_coords_o
=
GetRowCoords_O
(
sorted_tile_id
*
BlockShape
::
Block_M0
);
auto
w_scale
=
GetWeightScale
(
row_coords_o
,
reinterpret_cast
<
const
TopkWeightDataType
*>
(
kargs
.
sorted_weight_ptr
));
auto
uk_0
=
Policy
::
template
GetUK_0
<
Problem
>();
auto
acc_0
=
uk_0
(
a_res
,
a_coords
,
g_res
,
g_coords
,
smem
,
kargs
.
hidden_size
,
BlockShape
::
Block_K0
,
// tile offset for B matrix each unroll
BlockShape
::
Block_Kr0
*
BlockShape
::
Block_W0
);
// tile offset for B matrix each unroll
sweep_tile
(
acc_0
,
[
&
](
auto
idx0
,
auto
idx1
)
{
fp32x2_t
v_
{
acc_0
(
idx0
),
acc_0
(
idx1
)};
typename
Problem
::
GateActivation
{}(
v_
,
v_
);
acc_0
(
idx0
)
=
v_
.
x
;
acc_0
(
idx1
)
=
v_
.
y
;
},
sequence
<
1
,
2
>
{});
auto
y_pre
=
cast_tile
<
YDataType
>
(
acc_0
);
block_sync_lds
();
store_tile
(
bridge_sst_win
,
y_pre
);
block_sync_lds
();
auto
uk_1
=
Policy
::
template
GetUK_1
<
Problem
>();
uk_1
(
d_res
,
d_coords
,
o_res
,
o_coords
,
o_flags
,
smem
,
kargs
.
hidden_size
,
// total n number
w_scale
,
BlockShape
::
Block_Nr1
*
kr_1
*
BlockShape
::
Block_W1
,
// along N
BlockShape
::
Block_N1
);
// along N
}
};
}
// namespace ck_tile
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp
0 → 100644
View file @
c8c016dd
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
// TODO: alow 2 gemm have different type
template
<
typename
ADataType_
,
typename
GDataType_
,
typename
DDataType_
,
typename
AccDataType_
,
typename
ODataType_
,
typename
AScaleDataType_
,
typename
GScaleDataType_
,
typename
DScaleDataType_
,
typename
YSmoothScaleDataType_
,
typename
TopkWeightDataType_
,
typename
IndexDataType_
,
// data type for all indexing
typename
GateActivation_
,
// = ck_tile::element_wise::Silu,
typename
BlockShape_
,
// shoule be FusedMoeGemmShape
typename
Traits_
>
struct
FusedMoeGemmPipelineProblem
{
using
ADataType
=
remove_cvref_t
<
ADataType_
>
;
using
GDataType
=
remove_cvref_t
<
GDataType_
>
;
using
DDataType
=
remove_cvref_t
<
DDataType_
>
;
using
AccDataType
=
remove_cvref_t
<
AccDataType_
>
;
using
ODataType
=
remove_cvref_t
<
ODataType_
>
;
using
AScaleDataType
=
remove_cvref_t
<
AScaleDataType_
>
;
using
GScaleDataType
=
remove_cvref_t
<
GScaleDataType_
>
;
using
DScaleDataType
=
remove_cvref_t
<
DScaleDataType_
>
;
using
YSmoothScaleDataType
=
remove_cvref_t
<
YSmoothScaleDataType_
>
;
using
TopkWeightDataType
=
remove_cvref_t
<
TopkWeightDataType_
>
;
using
IndexDataType
=
remove_cvref_t
<
IndexDataType_
>
;
// the input for next gemm should have same time as
using
YDataType
=
ADataType
;
using
GateActivation
=
remove_cvref_t
<
GateActivation_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
using
Traits
=
remove_cvref_t
<
Traits_
>
;
};
}
// namespace ck_tile
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp
0 → 100644
View file @
c8c016dd
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
enum
class
FusedMoeGemmWeightPermuteEnum
{
// permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
// permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
no_permute
=
0
,
b_nr_kr_kw_nw_kv
=
1
,
// 0,1,3,4,2,5
b_nr_kr_waveflatten
=
b_nr_kr_kw_nw_kv
,
};
template
<
bool
IsGateOnly_
,
bool
UseSmoothQuant_
,
index_t
OAtomic_
,
// 0-no atomic, 1-atomic-pk-f16/bf16, 2-atomic-f32
FusedMoeGemmWeightPermuteEnum
PermuteEnum_
=
FusedMoeGemmWeightPermuteEnum
::
b_nr_kr_waveflatten
,
bool
PadHiddenSize_
=
false
,
bool
PadIntermediateSize_
=
false
>
struct
FusedMoeGemmTraits
{
// Gate+Up or Gate only
static
constexpr
bool
IsGateOnly
=
IsGateOnly_
;
static
constexpr
bool
UseSmoothQuant
=
UseSmoothQuant_
;
static
constexpr
index_t
OAtomic
=
OAtomic_
;
static
constexpr
FusedMoeGemmWeightPermuteEnum
PermuteEnum
=
PermuteEnum_
;
static
constexpr
bool
PadHiddenSize
=
PadHiddenSize_
;
static
constexpr
bool
PadIntermediateSize
=
PadIntermediateSize_
;
};
// Note: this need to be a bit mask
enum
class
FusedMoeGemmPipelineSequencerEnum
{
SLD_A
=
1
<<
0
,
// shared load a
SLD_B
=
1
<<
1
,
GLD_A
=
1
<<
2
,
// global load a
GLD_B
=
1
<<
3
,
SST_A
=
1
<<
4
,
// shared store a
SST_B
=
1
<<
5
,
GST_O
=
1
<<
6
,
// global store out
};
}
// namespace ck_tile
include/ck_tile/ops/gemm.hpp
View file @
c8c016dd
...
...
@@ -22,8 +22,13 @@
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp"
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
...
...
include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp
0 → 100644
View file @
c8c016dd
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
namespace
ck_tile
{
// A is block window on shared memory
// B is block window on shared memory
// C is block distributed tensor
template
<
typename
Problem_
,
typename
Policy_
=
BlockGemmASmemBSmemCRegV1DefaultPolicy
>
struct
BlockUniversalGemmAsBsCr
{
private:
// TODO: This should be in Policy - UniversalGemmPolicyBase ?
template
<
typename
PipelineProblem_
,
typename
GemmPolicy_
>
struct
GemmTraits_
{
using
Problem
=
remove_cvref_t
<
PipelineProblem_
>
;
using
Policy
=
remove_cvref_t
<
GemmPolicy_
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
CDataType
=
remove_cvref_t
<
typename
Problem
::
CDataType
>
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
Problem
::
BlockGemmShape
>
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
auto
Scheduler
=
Problem
::
Scheduler
;
static
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
static
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
static
constexpr
auto
config
=
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
static
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
static
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
using
I0
=
number
<
0
>
;
using
I1
=
number
<
1
>
;
static_assert
(
MWarp
==
BlockGemmShape
::
BlockWarps
::
at
(
I0
{}),
"Error! WarpGemm's MWarp is not consisten with BlockGemmShape!"
);
static_assert
(
NWarp
==
BlockGemmShape
::
BlockWarps
::
at
(
I1
{}),
"Error! WarpGemm's NWarp is not consisten with BlockGemmShape!"
);
static_assert
(
WarpGemm
::
kM
==
BlockGemmShape
::
WarpTile
::
at
(
I0
{}),
"Error! WarpGemm's M is not consisten with BlockGemmShape!"
);
static_assert
(
WarpGemm
::
kN
==
BlockGemmShape
::
WarpTile
::
at
(
I1
{}),
"Error! WarpGemm's N is not consisten with BlockGemmShape!"
);
static
constexpr
index_t
MIterPerWarp
=
MPerBlock
/
(
MWarp
*
WarpGemm
::
kM
);
static
constexpr
index_t
NIterPerWarp
=
NPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
static
constexpr
index_t
KIterPerWarp
=
KPerBlock
/
WarpGemm
::
kK
;
static_assert
(
MIterPerWarp
*
MWarp
*
WarpGemm
::
kM
==
MPerBlock
,
"Error! Warps should cover all Block tile!"
);
static_assert
(
NIterPerWarp
*
NWarp
*
WarpGemm
::
kN
==
NPerBlock
,
"Error! Warps should cover all Block tile!"
);
static
constexpr
index_t
MPerBlockPerIter
=
MWarp
*
WarpGemm
::
kM
;
static
constexpr
index_t
NPerBlockPerIter
=
NWarp
*
WarpGemm
::
kN
;
static
constexpr
index_t
KPerBlockPerIter
=
WarpGemm
::
kK
;
using
AWarpTileDistr
=
remove_cvref_t
<
decltype
(
make_static_tile_distribution
(
typename
WarpGemm
::
AWarpDstrEncoding
{}))
>
;
using
BWarpTileDistr
=
remove_cvref_t
<
decltype
(
make_static_tile_distribution
(
typename
WarpGemm
::
BWarpDstrEncoding
{}))
>
;
using
AWarpTile
=
remove_cvref_t
<
decltype
(
make_static_distributed_tensor
<
ADataType
>
(
AWarpTileDistr
{}))
>
;
using
BWarpTile
=
remove_cvref_t
<
decltype
(
make_static_distributed_tensor
<
BDataType
>
(
BWarpTileDistr
{}))
>
;
// TODO: Should we have two policies? Interwave & Intrawave ??
static
constexpr
index_t
InterWaveSchedulingMacClusters
=
1
;
static
constexpr
index_t
KPack
=
WarpGemm
::
kKPerThread
;
static
constexpr
index_t
KPerThread
=
KPerBlock
/
WarpGemm
::
kK
*
KPack
;
static
constexpr
index_t
KRepeat
=
KPerThread
/
KPack
;
};
public:
using
Traits
=
GemmTraits_
<
Problem_
,
Policy_
>
;
using
ADataType
=
remove_cvref_t
<
typename
Traits
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Traits
::
BDataType
>
;
using
CDataType
=
remove_cvref_t
<
typename
Traits
::
CDataType
>
;
using
WarpGemm
=
remove_cvref_t
<
typename
Traits
::
WarpGemm
>
;
static
constexpr
index_t
KIterPerWarp
=
Traits
::
KIterPerWarp
;
static
constexpr
index_t
MIterPerWarp
=
Traits
::
MIterPerWarp
;
static
constexpr
index_t
NIterPerWarp
=
Traits
::
NIterPerWarp
;
static
constexpr
index_t
MWarp
=
Traits
::
MWarp
;
static
constexpr
index_t
NWarp
=
Traits
::
NWarp
;
static
constexpr
auto
Scheduler
=
Traits
::
Scheduler
;
using
I0
=
number
<
0
>
;
using
I1
=
number
<
1
>
;
private:
template
<
GemmPipelineScheduler
Scheduler
,
typename
GemmTraits
>
struct
BlockGemmImpl
{
};
template
<
typename
GemmTraits
>
struct
BlockGemmImpl
<
GemmPipelineScheduler
::
Default
,
GemmTraits
>
{
// C += A * B
template
<
typename
CBlockTensor
,
typename
ASmemBlockWindow
,
typename
BSmemBlockWindow
>
CK_TILE_DEVICE
void
operator
()(
CBlockTensor
&
c_block_tensor
,
const
ASmemBlockWindow
&
a_block_window
,
const
BSmemBlockWindow
&
b_block_window
)
{
static_assert
(
std
::
is_same_v
<
CDataType
,
typename
CBlockTensor
::
DataType
>
,
"The CDataType as defined in traits should be the same as correspoinding "
"C block tensor data type!"
);
static_assert
(
std
::
is_same_v
<
ADataType
,
typename
ASmemBlockWindow
::
DataType
>
&&
std
::
is_same_v
<
BDataType
,
typename
BSmemBlockWindow
::
DataType
>
,
"The ADataType and BDataType as defined in "
"traits should be the same as correspoinding block window data type!"
);
static_assert
(
GemmTraits
::
MPerBlock
==
ASmemBlockWindow
{}.
get_window_lengths
()[
I0
{}]
&&
GemmTraits
::
NPerBlock
==
BSmemBlockWindow
{}.
get_window_lengths
()[
I0
{}]
&&
GemmTraits
::
KPerBlock
==
ASmemBlockWindow
{}.
get_window_lengths
()[
I1
{}],
"MPerBlock, NPerBlock, KPerBlock defined in "
" BlockGemmShape are different from A/B block smem windows apropriate dims!"
);
const
index_t
iMWarp
=
get_warp_id
()
/
NWarp
;
const
index_t
iNWarp
=
get_warp_id
()
-
(
iMWarp
*
NWarp
);
// TODO: refactor warp_window tile type to class member as it should be
// compile-time known information.
auto
a_warp_window_tmp
=
make_tile_window
(
a_block_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
WarpGemm
::
kM
>
{},
number
<
WarpGemm
::
kK
>
{}),
a_block_window
.
get_window_origin
()
+
multi_index
<
2
>
{
iMWarp
*
WarpGemm
::
kM
,
0
},
make_static_tile_distribution
(
typename
WarpGemm
::
AWarpDstrEncoding
{}));
using
AWarpWindow
=
remove_cvref_t
<
decltype
(
a_warp_window_tmp
)
>
;
static_assert
(
GemmTraits
::
AWarpTile
::
get_num_of_dimension
()
==
AWarpWindow
::
get_num_of_dimension
(),
"AWarpWindow number of dimensions must be equal to "
"AWarpTile number of dimensions!"
);
static_assert
(
GemmTraits
::
AWarpTile
::
get_lengths
()
==
AWarpWindow
{}.
get_window_lengths
(),
"AWarpWindow lengths must be equal to AWarpTile lengths!"
);
statically_indexed_array
<
statically_indexed_array
<
AWarpWindow
,
GemmTraits
::
KIterPerWarp
>
,
MIterPerWarp
>
a_warp_windows
;
// construct B-warp-window
auto
b_warp_window_tmp
=
make_tile_window
(
b_block_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
WarpGemm
::
kN
>
{},
number
<
WarpGemm
::
kK
>
{}),
b_block_window
.
get_window_origin
()
+
multi_index
<
2
>
{
iNWarp
*
WarpGemm
::
kN
,
0
},
make_static_tile_distribution
(
typename
WarpGemm
::
BWarpDstrEncoding
{}));
using
BWarpWindow
=
remove_cvref_t
<
decltype
(
b_warp_window_tmp
)
>
;
static_assert
(
GemmTraits
::
BWarpTile
::
get_num_of_dimension
()
==
BWarpWindow
::
get_num_of_dimension
(),
"BWarpWindow number of dimensions must be equal to "
"BWarpTile number of dimensions!"
);
static_assert
(
GemmTraits
::
BWarpTile
::
get_lengths
()
==
BWarpWindow
{}.
get_window_lengths
(),
"BWarpWindow lengths must be equal to BWarpTile lengths!"
);
statically_indexed_array
<
statically_indexed_array
<
BWarpWindow
,
GemmTraits
::
KIterPerWarp
>
,
NIterPerWarp
>
b_warp_windows
;
static_for
<
0
,
MIterPerWarp
,
1
>
{}([
&
](
auto
mIter
)
{
static_for
<
0
,
GemmTraits
::
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
a_warp_windows
(
mIter
)(
kIter
)
=
a_warp_window_tmp
;
// TODO: I don't have to move 0,0 window!
move_tile_window
(
a_warp_windows
(
mIter
)(
kIter
),
{
mIter
*
GemmTraits
::
MPerBlockPerIter
,
kIter
*
GemmTraits
::
KPerBlockPerIter
});
});
});
static_for
<
0
,
NIterPerWarp
,
1
>
{}([
&
](
auto
nIter
)
{
static_for
<
0
,
GemmTraits
::
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
b_warp_windows
(
nIter
)(
kIter
)
=
b_warp_window_tmp
;
move_tile_window
(
b_warp_windows
(
nIter
)(
kIter
),
{
nIter
*
GemmTraits
::
NPerBlockPerIter
,
kIter
*
GemmTraits
::
KPerBlockPerIter
});
});
});
using
CWarpDstr
=
typename
WarpGemm
::
CWarpDstr
;
using
CWarpTensor
=
typename
WarpGemm
::
CWarpTensor
;
constexpr
auto
c_warp_y_lengths
=
to_sequence
(
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
c_warp_y_index_zeros
=
uniform_sequence_gen_t
<
CWarpDstr
::
NDimY
,
0
>
{};
// hot loop:
static_for
<
0
,
GemmTraits
::
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
static_for
<
0
,
MIterPerWarp
,
1
>
{}([
&
](
auto
mIter
)
{
const
auto
a_warp_tile
=
load_tile
(
a_warp_windows
(
mIter
)(
kIter
));
static_for
<
0
,
NIterPerWarp
,
1
>
{}([
&
](
auto
nIter
)
{
const
auto
b_warp_tile
=
load_tile
(
b_warp_windows
(
nIter
)(
kIter
));
// read C warp tensor from C block tensor-
CWarpTensor
c_warp_tensor
;
c_warp_tensor
.
get_thread_buffer
()
=
c_block_tensor
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
nIter
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
));
// warp GEMM
WarpGemm
{}(
c_warp_tensor
,
a_warp_tile
,
b_warp_tile
);
// write C warp tensor into C block tensor
c_block_tensor
.
set_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
nIter
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
),
c_warp_tensor
.
get_thread_buffer
());
});
});
});
}
};
template
<
typename
GemmTraits
>
struct
BlockGemmImpl
<
GemmPipelineScheduler
::
Intrawave
,
GemmTraits
>
{
statically_indexed_array
<
statically_indexed_array
<
typename
GemmTraits
::
AWarpTile
,
KIterPerWarp
>
,
MIterPerWarp
>
a_warp_tiles_
;
statically_indexed_array
<
statically_indexed_array
<
typename
GemmTraits
::
BWarpTile
,
KIterPerWarp
>
,
NIterPerWarp
>
b_warp_tiles_
;
template
<
typename
ASmemBlockWindow
,
typename
BSmemBlockWindow
>
CK_TILE_DEVICE
void
LocalPrefetch
(
const
ASmemBlockWindow
&
a_block_window
,
const
BSmemBlockWindow
&
b_block_window
)
{
static_assert
(
GemmTraits
::
MPerBlock
==
ASmemBlockWindow
{}.
get_window_lengths
()[
I0
{}]
&&
GemmTraits
::
NPerBlock
==
BSmemBlockWindow
{}.
get_window_lengths
()[
I0
{}]
&&
GemmTraits
::
KPerBlock
==
ASmemBlockWindow
{}.
get_window_lengths
()[
I1
{}],
"MPerBlock, NPerBlock, KPerBlock defined in "
" BlockGemmShape are different from A/B block smem windows apropriate dims!"
);
static_assert
(
std
::
is_same_v
<
ADataType
,
typename
ASmemBlockWindow
::
DataType
>
&&
std
::
is_same_v
<
BDataType
,
typename
BSmemBlockWindow
::
DataType
>
,
"The ADataType and BDataType as defined in "
"traits should be the same as correspoinding block window data type!"
);
const
index_t
iMWarp
=
get_warp_id
()
/
NWarp
;
const
index_t
iNWarp
=
get_warp_id
()
-
(
iMWarp
*
NWarp
);
// TODO: refactor warp_window tile type to class member as it should be
// compile-time known information.
auto
a_warp_window_tmp
=
make_tile_window
(
a_block_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
WarpGemm
::
kM
>
{},
number
<
WarpGemm
::
kK
>
{}),
a_block_window
.
get_window_origin
()
+
multi_index
<
2
>
{
iMWarp
*
WarpGemm
::
kM
,
0
},
make_static_tile_distribution
(
typename
WarpGemm
::
AWarpDstrEncoding
{}));
using
AWarpWindow
=
remove_cvref_t
<
decltype
(
a_warp_window_tmp
)
>
;
static_assert
(
GemmTraits
::
AWarpTile
::
get_num_of_dimension
()
==
AWarpWindow
::
get_num_of_dimension
(),
"AWarpWindow number of dimensions must be equal to "
"AWarpTile number of dimensions!"
);
static_assert
(
GemmTraits
::
AWarpTile
::
get_lengths
()
==
AWarpWindow
{}.
get_window_lengths
(),
"AWarpWindow lengths must be equal to AWarpTile lengths!"
);
statically_indexed_array
<
statically_indexed_array
<
AWarpWindow
,
KIterPerWarp
>
,
MIterPerWarp
>
a_warp_windows
;
// construct B-warp-window
auto
b_warp_window_tmp
=
make_tile_window
(
b_block_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
WarpGemm
::
kN
>
{},
number
<
WarpGemm
::
kK
>
{}),
b_block_window
.
get_window_origin
()
+
multi_index
<
2
>
{
iNWarp
*
WarpGemm
::
kN
,
0
},
make_static_tile_distribution
(
typename
WarpGemm
::
BWarpDstrEncoding
{}));
using
BWarpWindow
=
remove_cvref_t
<
decltype
(
b_warp_window_tmp
)
>
;
static_assert
(
GemmTraits
::
BWarpTile
::
get_num_of_dimension
()
==
BWarpWindow
::
get_num_of_dimension
(),
"BWarpWindow number of dimensions must be equal to "
"BWarpTile number of dimensions!"
);
static_assert
(
GemmTraits
::
BWarpTile
::
get_lengths
()
==
BWarpWindow
{}.
get_window_lengths
(),
"BWarpWindow lengths must be equal to BWarpTile lengths!"
);
statically_indexed_array
<
statically_indexed_array
<
BWarpWindow
,
KIterPerWarp
>
,
NIterPerWarp
>
b_warp_windows
;
static_for
<
0
,
MIterPerWarp
,
1
>
{}([
&
](
auto
mIter
)
{
static_for
<
0
,
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
a_warp_windows
(
mIter
)(
kIter
)
=
a_warp_window_tmp
;
// TODO: I don't have to move 0,0 window!
move_tile_window
(
a_warp_windows
(
mIter
)(
kIter
),
{
mIter
*
GemmTraits
::
MPerBlockPerIter
,
kIter
*
GemmTraits
::
KPerBlockPerIter
});
});
});
static_for
<
0
,
NIterPerWarp
,
1
>
{}([
&
](
auto
nIter
)
{
static_for
<
0
,
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
b_warp_windows
(
nIter
)(
kIter
)
=
b_warp_window_tmp
;
move_tile_window
(
b_warp_windows
(
nIter
)(
kIter
),
{
nIter
*
GemmTraits
::
NPerBlockPerIter
,
kIter
*
GemmTraits
::
KPerBlockPerIter
});
});
});
static_for
<
0
,
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
static_for
<
0
,
MIterPerWarp
,
1
>
{}([
&
](
auto
mIter
)
{
// read A warp tensor from A block window
load_tile
(
a_warp_tiles_
(
mIter
)(
kIter
),
a_warp_windows
(
mIter
)(
kIter
));
});
static_for
<
0
,
NIterPerWarp
,
1
>
{}([
&
](
auto
nIter
)
{
// read B warp tensor from B Block window
load_tile
(
b_warp_tiles_
(
nIter
)(
kIter
),
b_warp_windows
(
nIter
)(
kIter
));
});
});
}
// C += A * B
template
<
typename
CBlockTensor
,
typename
ASmemBlockWindow
,
typename
BSmemBlockWindow
>
CK_TILE_DEVICE
void
operator
()(
CBlockTensor
&
c_block_tensor
,
[[
maybe_unused
]]
const
ASmemBlockWindow
&
a_block_window
,
[[
maybe_unused
]]
const
BSmemBlockWindow
&
b_block_window
)
{
static_assert
(
std
::
is_same_v
<
CDataType
,
typename
CBlockTensor
::
DataType
>
,
"The CDataType as defined in traits should be the same as correspoinding "
"C block tensor data type!"
);
using
CWarpDstr
=
typename
WarpGemm
::
CWarpDstr
;
using
CWarpTensor
=
typename
WarpGemm
::
CWarpTensor
;
constexpr
auto
c_warp_y_lengths
=
to_sequence
(
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
c_warp_y_index_zeros
=
uniform_sequence_gen_t
<
CWarpDstr
::
NDimY
,
0
>
{};
// hot loop:
static_for
<
0
,
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
static_for
<
0
,
MIterPerWarp
,
1
>
{}([
&
](
auto
mIter
)
{
static_for
<
0
,
NIterPerWarp
,
1
>
{}([
&
](
auto
nIter
)
{
// read C warp tensor from C block tensor-
CWarpTensor
c_warp_tensor
;
c_warp_tensor
.
get_thread_buffer
()
=
c_block_tensor
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
nIter
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
));
// warp GEMM
WarpGemm
{}(
c_warp_tensor
,
a_warp_tiles_
[
mIter
][
kIter
],
b_warp_tiles_
[
nIter
][
kIter
]);
// write C warp tensor into C block tensor
c_block_tensor
.
set_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
nIter
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
),
c_warp_tensor
.
get_thread_buffer
());
});
});
});
}
};
template
<
typename
GemmTraits
>
struct
BlockGemmImpl
<
GemmPipelineScheduler
::
Interwave
,
GemmTraits
>
{
static
constexpr
index_t
KPerThread
=
GemmTraits
::
KPerThread
;
static
constexpr
index_t
NumMacClusters
=
GemmTraits
::
InterWaveSchedulingMacClusters
;
static
constexpr
index_t
KPerInnerLoop
=
ck_tile
::
max
(
KPerThread
/
NumMacClusters
,
GemmTraits
::
KPack
);
// TODO: do we really need this?? Are there any cases when this would be >=1 ??
// Would we need InterWaveSchedulingMacClusters > 1 ???
static
constexpr
index_t
KRepeat
=
KPerThread
/
KPerInnerLoop
;
static
constexpr
index_t
KInnerLoopIter
=
KPerInnerLoop
/
GemmTraits
::
KPack
;
statically_indexed_array
<
statically_indexed_array
<
typename
GemmTraits
::
AWarpTile
,
KInnerLoopIter
>
,
MIterPerWarp
>
a_warp_tiles_
;
statically_indexed_array
<
statically_indexed_array
<
typename
GemmTraits
::
BWarpTile
,
KInnerLoopIter
>
,
NIterPerWarp
>
b_warp_tiles_
;
template
<
index_t
KIdx
,
typename
ASmemBlockWindow
,
typename
BSmemBlockWindow
>
CK_TILE_DEVICE
void
LocalPrefetch
(
const
ASmemBlockWindow
&
a_block_window
,
const
BSmemBlockWindow
&
b_block_window
)
{
static_assert
(
GemmTraits
::
MPerBlock
==
ASmemBlockWindow
{}.
get_window_lengths
()[
I0
{}]
&&
GemmTraits
::
NPerBlock
==
BSmemBlockWindow
{}.
get_window_lengths
()[
I0
{}]
&&
GemmTraits
::
KPerBlock
==
ASmemBlockWindow
{}.
get_window_lengths
()[
I1
{}],
"MPerBlock, NPerBlock, KPerBlock defined in "
" BlockGemmShape are different from A/B block smem windows apropriate dims!"
);
static_assert
(
std
::
is_same_v
<
ADataType
,
typename
ASmemBlockWindow
::
DataType
>
&&
std
::
is_same_v
<
BDataType
,
typename
BSmemBlockWindow
::
DataType
>
,
"The ADataType and BDataType as defined in "
"traits should be the same as correspoinding block window data type!"
);
const
index_t
iMWarp
=
get_warp_id
()
/
NWarp
;
const
index_t
iNWarp
=
get_warp_id
()
-
(
iMWarp
*
NWarp
);
// TODO: refactor warp_window tile type to class member as it should be
// compile-time known information.
auto
a_warp_window_tmp
=
make_tile_window
(
a_block_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
WarpGemm
::
kM
>
{},
number
<
WarpGemm
::
kK
>
{}),
a_block_window
.
get_window_origin
()
+
multi_index
<
2
>
{
iMWarp
*
WarpGemm
::
kM
,
KIdx
*
KPerInnerLoop
},
make_static_tile_distribution
(
typename
WarpGemm
::
AWarpDstrEncoding
{}));
using
AWarpWindow
=
remove_cvref_t
<
decltype
(
a_warp_window_tmp
)
>
;
static_assert
(
GemmTraits
::
AWarpTile
::
get_num_of_dimension
()
==
AWarpWindow
::
get_num_of_dimension
(),
"AWarpWindow number of dimensions must be equal to "
"AWarpTile number of dimensions!"
);
static_assert
(
GemmTraits
::
AWarpTile
::
get_lengths
()
==
AWarpWindow
{}.
get_window_lengths
(),
"AWarpWindow lengths must be equal to AWarpTile lengths!"
);
statically_indexed_array
<
statically_indexed_array
<
AWarpWindow
,
KInnerLoopIter
>
,
MIterPerWarp
>
a_warp_windows
;
// construct B-warp-window
auto
b_warp_window_tmp
=
make_tile_window
(
b_block_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
WarpGemm
::
kN
>
{},
number
<
WarpGemm
::
kK
>
{}),
b_block_window
.
get_window_origin
()
+
multi_index
<
2
>
{
iNWarp
*
WarpGemm
::
kN
,
KIdx
*
KPerInnerLoop
},
make_static_tile_distribution
(
typename
WarpGemm
::
BWarpDstrEncoding
{}));
using
BWarpWindow
=
remove_cvref_t
<
decltype
(
b_warp_window_tmp
)
>
;
static_assert
(
GemmTraits
::
BWarpTile
::
get_num_of_dimension
()
==
BWarpWindow
::
get_num_of_dimension
(),
"BWarpWindow number of dimensions must be equal to "
"BWarpTile number of dimensions!"
);
static_assert
(
GemmTraits
::
BWarpTile
::
get_lengths
()
==
BWarpWindow
{}.
get_window_lengths
(),
"BWarpWindow lengths must be equal to BWarpTile lengths!"
);
statically_indexed_array
<
statically_indexed_array
<
BWarpWindow
,
KInnerLoopIter
>
,
NIterPerWarp
>
b_warp_windows
;
static_for
<
0
,
MIterPerWarp
,
1
>
{}([
&
](
auto
mIter
)
{
static_for
<
0
,
KInnerLoopIter
,
1
>
{}([
&
](
auto
kIter
)
{
a_warp_windows
(
mIter
)(
kIter
)
=
a_warp_window_tmp
;
move_tile_window
(
a_warp_windows
(
mIter
)(
kIter
),
{
mIter
*
GemmTraits
::
MPerBlockPerIter
,
kIter
*
GemmTraits
::
KPerBlockPerIter
});
});
});
static_for
<
0
,
NIterPerWarp
,
1
>
{}([
&
](
auto
nIter
)
{
static_for
<
0
,
KInnerLoopIter
,
1
>
{}([
&
](
auto
kIter
)
{
b_warp_windows
(
nIter
)(
kIter
)
=
b_warp_window_tmp
;
move_tile_window
(
b_warp_windows
(
nIter
)(
kIter
),
{
nIter
*
GemmTraits
::
NPerBlockPerIter
,
kIter
*
GemmTraits
::
KPerBlockPerIter
});
});
});
// TODO check if a_warp_tiles has same desc as a_warp_window
static_for
<
0
,
KInnerLoopIter
,
1
>
{}([
&
](
auto
kIter
)
{
static_for
<
0
,
MIterPerWarp
,
1
>
{}([
&
](
auto
mIter
)
{
// read A warp tensor from A block window
load_tile
(
a_warp_tiles_
(
mIter
)(
kIter
),
a_warp_windows
(
mIter
)(
kIter
));
});
static_for
<
0
,
NIterPerWarp
,
1
>
{}([
&
](
auto
nIter
)
{
// read B warp tensor from B Block window
load_tile
(
b_warp_tiles_
(
nIter
)(
kIter
),
b_warp_windows
(
nIter
)(
kIter
));
});
});
}
// C += A * B
template
<
typename
CBlockTensor
,
typename
ASmemBlockWindow
,
typename
BSmemBlockWindow
>
CK_TILE_DEVICE
void
operator
()(
CBlockTensor
&
c_block_tensor
,
const
ASmemBlockWindow
&
a_block_window
,
const
BSmemBlockWindow
&
b_block_window
)
{
static_assert
(
std
::
is_same_v
<
CDataType
,
typename
CBlockTensor
::
DataType
>
,
"The CDataType as defined in traits should be the same as correspoinding "
"C block tensor data type!"
);
using
CWarpDstr
=
typename
WarpGemm
::
CWarpDstr
;
using
CWarpTensor
=
typename
WarpGemm
::
CWarpTensor
;
constexpr
auto
c_warp_y_lengths
=
to_sequence
(
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
c_warp_y_index_zeros
=
uniform_sequence_gen_t
<
CWarpDstr
::
NDimY
,
0
>
{};
// hot loop:
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
kIter
)
{
LocalPrefetch
<
kIter
.
value
>
(
a_block_window
,
b_block_window
);
__builtin_amdgcn_sched_barrier
(
0
);
// NOTE: Synchronize threads in a workgroup at the start of each MAC
// cluster, but except the first, as we can shorten non-MAC cluster a bit
// and there's no observable negative impact. The desired effect is waves in
// a workgroup executing MAC in sync. This avoids some out-of-sync waves
// hijacking MAC resource from other workgroups and reducing the chance of
// latency hiding by waiting for the rest of the workgroup at the eventual
// sync point.
if
constexpr
(
kIter
.
value
!=
0
||
KRepeat
==
1
)
{
__builtin_amdgcn_s_barrier
();
__builtin_amdgcn_sched_barrier
(
0
);
}
static_for
<
0
,
KInnerLoopIter
,
1
>
{}([
&
](
auto
kInnerIter
)
{
static_for
<
0
,
MIterPerWarp
,
1
>
{}([
&
](
auto
mIter
)
{
static_for
<
0
,
NIterPerWarp
,
1
>
{}([
&
](
auto
nIter
)
{
// read C warp tensor from C block tensor-
CWarpTensor
c_warp_tensor
;
c_warp_tensor
.
get_thread_buffer
()
=
c_block_tensor
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
nIter
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
));
// The block_sync_lds() here performs double duty:
// A) safeguard against data hazard because barrier from
// blockwise_gemm is moved here B) reduce VMEM FIFO congestion
// by applying small delays to different wavefronts It is
// performed near the end of MAC cluster to minimize lgkmcnt
// penalty
if
constexpr
(
kIter
.
value
==
KRepeat
-
1
&&
kInnerIter
.
value
==
KInnerLoopIter
-
1
&&
mIter
.
value
==
MIterPerWarp
-
1
&&
nIter
.
value
==
NIterPerWarp
-
1
)
{
__builtin_amdgcn_sched_barrier
(
0
);
block_sync_lds
();
__builtin_amdgcn_sched_barrier
(
0
);
}
// warp GEMM
WarpGemm
{}(
c_warp_tensor
,
a_warp_tiles_
[
mIter
][
kInnerIter
],
b_warp_tiles_
[
nIter
][
kInnerIter
]);
// write C warp tensor into C block tensor
c_block_tensor
.
set_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
nIter
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
),
c_warp_tensor
.
get_thread_buffer
());
if
constexpr
(
kInnerIter
.
value
==
0
&&
mIter
.
value
==
0
&&
nIter
.
value
==
0
)
{
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_s_setprio
(
1
);
__builtin_amdgcn_sched_barrier
(
0
);
}
});
});
});
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_s_setprio
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
});
}
};
public:
CK_TILE_DEVICE
static
constexpr
auto
MakeCBlockTile
()
{
constexpr
auto
c_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
NIterPerWarp
,
NWarp
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
c_block_outer_dstr_encoding
,
typename
WarpGemm
::
CWarpDstrEncoding
{});
constexpr
auto
c_block_dstr
=
make_static_tile_distribution
(
c_block_dstr_encode
);
auto
c_block_tensor
=
make_static_distributed_tensor
<
CDataType
>
(
c_block_dstr
);
return
c_block_tensor
;
}
template
<
typename
ASmemBlockWindow
,
typename
BSmemBlockWindow
>
CK_TILE_DEVICE
void
LocalPrefetch
(
const
ASmemBlockWindow
&
a_block_window
,
const
BSmemBlockWindow
&
b_block_window
)
{
block_gemm_impl_
.
LocalPrefetch
(
a_block_window
,
b_block_window
);
}
// C += A * B
template
<
typename
CBlockTensor
,
typename
ASmemBlockWindow
,
typename
BSmemBlockWindow
>
CK_TILE_DEVICE
void
operator
()(
CBlockTensor
&
c_block_tensor
,
const
ASmemBlockWindow
&
a_block_window
,
const
BSmemBlockWindow
&
b_block_window
)
{
block_gemm_impl_
(
c_block_tensor
,
a_block_window
,
b_block_window
);
}
// C = A * B
template
<
typename
ASmemBlockWindow
,
typename
BSmemBlockWindow
>
CK_TILE_DEVICE
auto
operator
()(
const
ASmemBlockWindow
&
a_block_window
,
const
BSmemBlockWindow
&
b_block_window
)
{
auto
c_block_tensor
=
MakeCBlockTile
();
block_gemm_impl_
(
c_block_tensor
,
a_block_window
,
b_block_window
);
return
c_block_tensor
;
}
private:
BlockGemmImpl
<
Scheduler
,
Traits
>
block_gemm_impl_
{};
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
0 → 100644
View file @
c8c016dd
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
namespace
ck_tile
{
struct
BatchedGemmHostArgs
{
const
void
*
a_ptr
;
const
void
*
b_ptr
;
void
*
c_ptr
;
index_t
M
;
index_t
N
;
index_t
K
;
index_t
stride_A
;
index_t
stride_B
;
index_t
stride_C
;
index_t
batch_stride_A
;
index_t
batch_stride_B
;
index_t
batch_stride_C
;
index_t
batch_count
;
};
template
<
typename
TilePartitioner_
,
typename
GemmPipeline_
,
typename
EpiloguePipeline_
>
struct
BatchedGemmKernel
{
using
TilePartitioner
=
remove_cvref_t
<
TilePartitioner_
>
;
using
GemmPipeline
=
remove_cvref_t
<
GemmPipeline_
>
;
using
EpiloguePipeline
=
remove_cvref_t
<
EpiloguePipeline_
>
;
using
ALayout
=
remove_cvref_t
<
typename
GemmPipeline
::
ALayout
>
;
using
BLayout
=
remove_cvref_t
<
typename
GemmPipeline
::
BLayout
>
;
using
CLayout
=
remove_cvref_t
<
typename
GemmPipeline
::
CLayout
>
;
static
constexpr
index_t
KernelBlockSize
=
GemmPipeline
::
BlockSize
;
using
ADataType
=
remove_cvref_t
<
typename
GemmPipeline
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
GemmPipeline
::
BDataType
>
;
using
CDataType
=
remove_cvref_t
<
typename
EpiloguePipeline
::
ODataType
>
;
struct
BatchedGemmKargs
{
const
void
*
a_ptr
;
const
void
*
b_ptr
;
void
*
c_ptr
;
index_t
M
;
index_t
N
;
index_t
K
;
index_t
stride_A
;
index_t
stride_B
;
index_t
stride_C
;
index_t
batch_stride_A
;
index_t
batch_stride_B
;
index_t
batch_stride_C
;
index_t
batch_count
;
};
using
Kargs
=
BatchedGemmKargs
;
using
Hargs
=
BatchedGemmHostArgs
;
__host__
static
constexpr
auto
GridSize
(
const
Hargs
&
h
)
{
return
TilePartitioner
::
GridSize
(
h
.
M
,
h
.
N
,
h
.
batch_count
);
}
__host__
static
constexpr
auto
BlockSize
()
{
return
dim3
(
KernelBlockSize
);
}
CK_TILE_HOST
static
constexpr
BatchedGemmKargs
MakeKargs
(
const
Hargs
&
h
)
{
Kargs
k
;
k
.
a_ptr
=
h
.
a_ptr
;
k
.
b_ptr
=
h
.
b_ptr
;
k
.
c_ptr
=
h
.
c_ptr
;
k
.
M
=
h
.
M
;
k
.
N
=
h
.
N
;
k
.
K
=
h
.
K
;
k
.
stride_A
=
h
.
stride_A
;
k
.
stride_B
=
h
.
stride_B
;
k
.
stride_C
=
h
.
stride_C
;
k
.
batch_stride_A
=
h
.
batch_stride_A
;
k
.
batch_stride_B
=
h
.
batch_stride_B
;
k
.
batch_stride_C
=
h
.
batch_stride_C
;
k
.
batch_count
=
h
.
batch_count
;
return
k
;
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
max
(
GemmPipeline
::
GetSmemSize
(),
EpiloguePipeline
::
GetSmemSize
());
}
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
{
const
auto
[
i_m
,
i_n
]
=
TilePartitioner
{}();
const
auto
i_batch
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
z
);
// options
const
auto
batch_stride_A
=
__builtin_amdgcn_readfirstlane
(
kargs
.
batch_stride_A
);
const
auto
batch_offset_A
=
__builtin_amdgcn_readfirstlane
(
i_batch
*
batch_stride_A
);
const
ADataType
*
a_start
=
static_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
);
const
auto
batch_stride_B
=
__builtin_amdgcn_readfirstlane
(
kargs
.
batch_stride_B
);
const
auto
batch_offset_B
=
__builtin_amdgcn_readfirstlane
(
i_batch
*
batch_stride_B
);
const
BDataType
*
b_start
=
static_cast
<
const
BDataType
*>
(
kargs
.
b_ptr
);
// Convert pointers to tensor views
auto
a_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
a_start
+
batch_offset_A
,
make_tuple
(
kargs
.
M
,
kargs
.
K
),
make_tuple
(
kargs
.
stride_A
,
1
),
number
<
GemmPipeline
::
VectorSizeA
>
{},
number
<
1
>
{});
}
else
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
a_start
+
batch_offset_A
,
make_tuple
(
kargs
.
M
,
kargs
.
K
),
make_tuple
(
1
,
kargs
.
stride_A
),
number
<
1
>
{},
number
<
1
>
{});
}
}();
auto
b_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
b_start
+
batch_offset_B
,
make_tuple
(
kargs
.
N
,
kargs
.
K
),
make_tuple
(
1
,
kargs
.
stride_B
),
number
<
1
>
{},
number
<
1
>
{});
}
else
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
b_start
+
batch_offset_B
,
make_tuple
(
kargs
.
N
,
kargs
.
K
),
make_tuple
(
kargs
.
stride_B
,
1
),
number
<
GemmPipeline
::
VectorSizeB
>
{},
number
<
1
>
{});
}
}();
auto
a_pad_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
pad_tensor_view
(
a_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
sequence
<
false
,
GemmPipeline
::
kPadK
>
{});
}
else
{
return
pad_tensor_view
(
a_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
sequence
<
GemmPipeline
::
kPadM
,
false
>
{});
}
}();
// clang-format on
auto
a_block_window
=
make_tile_window
(
a_pad_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
{
i_m
,
0
});
auto
b_pad_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
{
return
pad_tensor_view
(
b_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kN
>
{},
number
<
TilePartitioner
::
kK
>
{}),
sequence
<
false
,
GemmPipeline
::
kPadK
>
{});
}
else
{
return
pad_tensor_view
(
b_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kN
>
{},
number
<
TilePartitioner
::
kK
>
{}),
sequence
<
GemmPipeline
::
kPadN
,
false
>
{});
}
}();
// clang-format on
auto
b_block_window
=
make_tile_window
(
b_pad_view
,
make_tuple
(
number
<
TilePartitioner
::
kN
>
{},
number
<
TilePartitioner
::
kK
>
{}),
{
i_n
,
0
});
// allocate LDS
__shared__
char
smem_ptr
[
GetSmemSize
()];
const
index_t
num_loop
=
TilePartitioner
::
GetLoopNum
(
kargs
.
K
);
// Run GEMM cooperatively by whole wokrgroup.
auto
c_block_tile
=
GemmPipeline
{}.
template
operator
()(
a_block_window
,
b_block_window
,
num_loop
,
smem_ptr
);
const
auto
batch_stride_C
=
__builtin_amdgcn_readfirstlane
(
kargs
.
batch_stride_C
);
const
auto
batch_offset_C
=
__builtin_amdgcn_readfirstlane
(
i_batch
*
batch_stride_C
);
CDataType
*
c_start
=
static_cast
<
CDataType
*>
(
kargs
.
c_ptr
);
auto
c_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
c_start
+
batch_offset_C
,
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
kargs
.
stride_C
,
1
),
number
<
GemmPipeline
::
VectorSizeC
>
{},
number
<
1
>
{});
}
else
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
c_start
+
batch_offset_C
,
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
1
,
kargs
.
stride_C
),
number
<
1
>
{},
number
<
1
>
{});
}
}();
auto
c_pad_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
pad_tensor_view
(
c_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
sequence
<
false
,
GemmPipeline
::
kPadN
>
{});
}
else
{
return
pad_tensor_view
(
c_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
sequence
<
GemmPipeline
::
kPadM
,
false
>
{});
}
}();
auto
c_block_window
=
make_tile_window
(
c_pad_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
{
i_m
,
i_n
});
EpiloguePipeline
{}(
c_block_window
,
c_block_tile
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
c8c016dd
...
...
@@ -66,6 +66,79 @@ struct GemmKernel
return
max
(
GemmPipeline
::
GetSmemSize
(),
EpiloguePipeline
::
GetSmemSize
());
}
CK_TILE_HOST
static
bool
IsSupportedArgument
(
const
GemmCommonKargs
&
kargs
)
{
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
if
(
kargs
.
K
%
TilePartitioner
::
kK
!=
0
&&
GemmPipeline
::
kPadK
==
false
)
{
return
false
;
}
if
(
kargs
.
K
%
GemmPipeline
::
VectorSizeA
!=
0
)
{
return
false
;
}
}
else
{
if
(
kargs
.
M
%
TilePartitioner
::
kM
!=
0
&&
GemmPipeline
::
kPadM
==
false
)
{
return
false
;
}
if
(
kargs
.
M
%
GemmPipeline
::
VectorSizeA
!=
0
)
{
return
false
;
}
}
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
if
(
kargs
.
N
%
TilePartitioner
::
kN
!=
0
&&
GemmPipeline
::
kPadN
==
false
)
{
return
false
;
}
if
(
kargs
.
N
%
GemmPipeline
::
VectorSizeB
!=
0
)
{
return
false
;
}
}
else
{
if
(
kargs
.
K
%
TilePartitioner
::
kK
!=
0
&&
GemmPipeline
::
kPadK
==
false
)
{
return
false
;
}
if
(
kargs
.
K
%
GemmPipeline
::
VectorSizeB
!=
0
)
{
return
false
;
}
}
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
if
(
kargs
.
N
%
TilePartitioner
::
kN
!=
0
&&
GemmPipeline
::
kPadN
==
false
)
{
return
false
;
}
if
(
kargs
.
N
%
GemmPipeline
::
VectorSizeC
!=
0
)
{
return
false
;
}
}
else
{
if
(
kargs
.
M
%
TilePartitioner
::
kM
!=
0
&&
GemmPipeline
::
kPadM
==
false
)
{
return
false
;
}
if
(
kargs
.
M
%
GemmPipeline
::
VectorSizeC
!=
0
)
{
return
false
;
}
}
return
true
;
}
CK_TILE_DEVICE
void
operator
()(
GemmCommonKargs
kargs
)
const
{
const
auto
[
i_m
,
i_n
]
=
TilePartitioner
{}();
...
...
include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp
View file @
c8c016dd
...
...
@@ -35,4 +35,40 @@ struct GemmTilePartitioner
return
make_tuple
(
iM
,
iN
);
}
};
template
<
typename
BlockGemmShape_
>
struct
GemmTile1DPartitioner
{
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
static
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
static
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
CK_TILE_HOST
static
constexpr
auto
GridSize
(
index_t
M
,
index_t
N
)
{
index_t
GridDimX
=
(
M
+
MPerBlock
-
1
)
/
MPerBlock
;
index_t
GridDimY
=
(
N
+
NPerBlock
-
1
)
/
NPerBlock
;
return
dim3
(
GridDimX
*
GridDimY
,
1
,
1
);
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetNBlock
(
index_t
N
)
{
return
integer_divide_ceil
(
N
,
NPerBlock
);
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetLoopNum
(
index_t
K
)
{
return
integer_divide_ceil
(
K
,
KPerBlock
);
}
CK_TILE_DEVICE
auto
operator
()(
index_t
blockOffset
,
index_t
NBlockSize
)
{
index_t
iM
=
__builtin_amdgcn_readfirstlane
((
blockIdx
.
x
-
blockOffset
)
/
GetNBlock
(
NBlockSize
)
*
MPerBlock
);
index_t
iN
=
__builtin_amdgcn_readfirstlane
((
blockIdx
.
x
-
blockOffset
)
%
GetNBlock
(
NBlockSize
)
*
NPerBlock
);
return
make_tuple
(
iM
,
iN
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp
0 → 100644
View file @
c8c016dd
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <string>
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/literals.hpp"
#include "ck_tile/core/utility/amd_address_space.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host.hpp"
namespace
ck_tile
{
struct
GroupedGemmHostArgs
{
const
void
*
a_ptr
;
const
void
*
b_ptr
;
void
*
c_ptr
;
index_t
M
;
index_t
N
;
index_t
K
;
index_t
stride_A
;
index_t
stride_B
;
index_t
stride_C
;
};
template
<
typename
TilePartitioner_
,
typename
GemmPipeline_
,
typename
EpiloguePipeline_
>
struct
GroupedGemmKernel
{
using
TilePartitioner
=
remove_cvref_t
<
TilePartitioner_
>
;
using
GemmPipeline
=
remove_cvref_t
<
GemmPipeline_
>
;
using
EpiloguePipeline
=
remove_cvref_t
<
EpiloguePipeline_
>
;
using
ALayout
=
remove_cvref_t
<
typename
GemmPipeline
::
ALayout
>
;
using
BLayout
=
remove_cvref_t
<
typename
GemmPipeline
::
BLayout
>
;
using
CLayout
=
remove_cvref_t
<
typename
GemmPipeline
::
CLayout
>
;
static
constexpr
index_t
KernelBlockSize
=
GemmPipeline
::
BlockSize
;
using
ADataType
=
remove_cvref_t
<
typename
GemmPipeline
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
GemmPipeline
::
BDataType
>
;
using
CDataType
=
remove_cvref_t
<
typename
EpiloguePipeline
::
ODataType
>
;
struct
GemmTransKernelArg
{
GroupedGemmHostArgs
group_karg
;
ck_tile
::
index_t
block_start
;
ck_tile
::
index_t
block_end
;
GemmTransKernelArg
()
=
default
;
GemmTransKernelArg
(
GroupedGemmHostArgs
&&
karg
,
index_t
bl_start
,
index_t
bl_end
)
:
group_karg
{
karg
},
block_start
{
bl_start
},
block_end
{
bl_end
}
{
}
};
__host__
static
size_t
GetWorkSpaceSize
(
const
std
::
vector
<
GroupedGemmHostArgs
>&
gemm_descs
)
{
return
gemm_descs
.
size
()
*
sizeof
(
GemmTransKernelArg
);
}
__host__
static
constexpr
auto
BlockSize
()
{
return
dim3
(
KernelBlockSize
);
}
using
Hargs
=
GroupedGemmHostArgs
;
__host__
static
constexpr
auto
GridSize
(
const
std
::
vector
<
Hargs
>&
gemm_descs
)
{
index_t
grid_size
=
0
;
for
(
const
auto
&
it_desc
:
gemm_descs
)
{
const
auto
dim3
=
TilePartitioner
::
GridSize
(
it_desc
.
M
,
it_desc
.
N
);
grid_size
+=
dim3
.
x
*
dim3
.
y
*
1
;
}
return
dim3
(
grid_size
,
1
,
1
);
}
CK_TILE_HOST
static
auto
MakeKargs
(
const
std
::
vector
<
Hargs
>&
gemm_descs
)
{
std
::
vector
<
GemmTransKernelArg
>
gemm_kernel_args_
;
index_t
group_count
=
ck_tile
::
type_convert
<
ck_tile
::
index_t
>
(
gemm_descs
.
size
());
index_t
grid_size
=
0
;
gemm_kernel_args_
.
reserve
(
group_count
);
for
(
std
::
size_t
i
=
0
;
i
<
gemm_descs
.
size
();
++
i
)
{
const
index_t
M
=
gemm_descs
[
i
].
M
;
const
index_t
N
=
gemm_descs
[
i
].
N
;
const
index_t
K
=
gemm_descs
[
i
].
K
;
if
(
M
==
0
||
N
==
0
||
K
==
0
)
{
continue
;
}
const
index_t
stride_a
=
gemm_descs
[
i
].
stride_A
;
const
index_t
stride_b
=
gemm_descs
[
i
].
stride_B
;
const
index_t
stride_c
=
gemm_descs
[
i
].
stride_C
;
const
auto
dim3
=
TilePartitioner
::
GridSize
(
M
,
N
);
const
index_t
grid_size_grp
=
dim3
.
x
*
1
*
1
;
const
index_t
block_start
=
grid_size
;
const
index_t
block_end
=
grid_size
+
grid_size_grp
;
grid_size
+=
grid_size_grp
;
auto
karg
=
GroupedGemmHostArgs
{
type_convert
<
const
ADataType
*>
(
gemm_descs
[
i
].
a_ptr
),
type_convert
<
const
BDataType
*>
(
gemm_descs
[
i
].
b_ptr
),
type_convert
<
CDataType
*>
(
gemm_descs
[
i
].
c_ptr
),
M
,
N
,
K
,
stride_a
,
stride_b
,
stride_c
};
gemm_kernel_args_
.
emplace_back
(
std
::
move
(
karg
),
block_start
,
block_end
);
}
return
gemm_kernel_args_
;
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
max
(
GemmPipeline
::
GetSmemSize
(),
EpiloguePipeline
::
GetSmemSize
());
}
CK_TILE_DEVICE
void
Run
(
const
Hargs
&
kargs
,
const
index_t
block_start
)
const
{
const
auto
[
i_m
,
i_n
]
=
TilePartitioner
{}(
block_start
,
kargs
.
N
);
// options
const
ADataType
*
a_start
=
static_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
);
const
BDataType
*
b_start
=
static_cast
<
const
BDataType
*>
(
kargs
.
b_ptr
);
// Convert pointers to tensor views
auto
a_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
a_start
,
make_tuple
(
kargs
.
M
,
kargs
.
K
),
make_tuple
(
kargs
.
stride_A
,
1
),
number
<
GemmPipeline
::
VectorSizeA
>
{},
number
<
1
>
{});
}
else
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
a_start
,
make_tuple
(
kargs
.
M
,
kargs
.
K
),
make_tuple
(
1
,
kargs
.
stride_A
),
number
<
1
>
{},
number
<
1
>
{});
}
}();
auto
b_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
b_start
,
make_tuple
(
kargs
.
N
,
kargs
.
K
),
make_tuple
(
1
,
kargs
.
stride_B
),
number
<
1
>
{},
number
<
1
>
{});
}
else
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
b_start
,
make_tuple
(
kargs
.
N
,
kargs
.
K
),
make_tuple
(
kargs
.
stride_B
,
1
),
number
<
GemmPipeline
::
VectorSizeB
>
{},
number
<
1
>
{});
}
}();
auto
a_pad_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
pad_tensor_view
(
a_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
MPerBlock
>
{},
number
<
TilePartitioner
::
KPerBlock
>
{}),
sequence
<
false
,
GemmPipeline
::
kPadK
>
{});
}
else
{
return
pad_tensor_view
(
a_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
MPerBlock
>
{},
number
<
TilePartitioner
::
KPerBlock
>
{}),
sequence
<
GemmPipeline
::
kPadM
,
false
>
{});
}
}();
// clang-format on
auto
a_block_window
=
make_tile_window
(
a_pad_view
,
make_tuple
(
number
<
TilePartitioner
::
MPerBlock
>
{},
number
<
TilePartitioner
::
KPerBlock
>
{}),
{
i_m
,
0
});
auto
b_pad_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
{
return
pad_tensor_view
(
b_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
NPerBlock
>
{},
number
<
TilePartitioner
::
KPerBlock
>
{}),
sequence
<
false
,
GemmPipeline
::
kPadK
>
{});
}
else
{
return
pad_tensor_view
(
b_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
NPerBlock
>
{},
number
<
TilePartitioner
::
KPerBlock
>
{}),
sequence
<
GemmPipeline
::
kPadN
,
false
>
{});
}
}();
auto
b_block_window
=
make_tile_window
(
b_pad_view
,
make_tuple
(
number
<
TilePartitioner
::
NPerBlock
>
{},
number
<
TilePartitioner
::
KPerBlock
>
{}),
{
i_n
,
0
});
// allocate LDS
__shared__
char
smem_ptr
[
GetSmemSize
()];
const
index_t
num_loop
=
TilePartitioner
::
GetLoopNum
(
kargs
.
K
);
// Run GEMM cooperatively by whole wokrgroup.
auto
c_block_tile
=
GemmPipeline
{}.
template
operator
()(
a_block_window
,
b_block_window
,
num_loop
,
smem_ptr
);
CDataType
*
c_start
=
static_cast
<
CDataType
*>
(
kargs
.
c_ptr
);
auto
c_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
c_start
,
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
kargs
.
stride_C
,
1
),
number
<
GemmPipeline
::
VectorSizeC
>
{},
number
<
1
>
{});
}
else
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
c_start
,
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
1
,
kargs
.
stride_C
),
number
<
1
>
{},
number
<
1
>
{});
}
}();
auto
c_pad_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
pad_tensor_view
(
c_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
MPerBlock
>
{},
number
<
TilePartitioner
::
NPerBlock
>
{}),
sequence
<
false
,
GemmPipeline
::
kPadN
>
{});
}
else
{
return
pad_tensor_view
(
c_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
MPerBlock
>
{},
number
<
TilePartitioner
::
NPerBlock
>
{}),
sequence
<
GemmPipeline
::
kPadM
,
false
>
{});
}
}();
auto
CBlockWindow_pad
=
make_tile_window
(
c_pad_view
,
make_tuple
(
number
<
TilePartitioner
::
MPerBlock
>
{},
number
<
TilePartitioner
::
NPerBlock
>
{}),
{
i_m
,
i_n
});
EpiloguePipeline
{}(
CBlockWindow_pad
,
c_block_tile
);
}
CK_TILE_DEVICE
void
operator
()(
const
void
CK_CONSTANT_ADDRESS_SPACE
*
gemm_descs_const
,
int
group_count
)
const
{
const
index_t
block_id
=
ck_tile
::
get_block_1d_id
();
const
auto
gemm_desc_ptr
=
reinterpret_cast
<
const
GemmTransKernelArg
*>
(
cast_pointer_to_generic_address_space
(
gemm_descs_const
));
index_t
left
=
0
;
index_t
right
=
group_count
;
index_t
group_id
=
index_t
((
left
+
right
)
/
2
);
while
((
!
(
block_id
>=
gemm_desc_ptr
[
group_id
].
block_start
&&
block_id
<
gemm_desc_ptr
[
group_id
].
block_end
))
&&
left
<=
right
)
{
if
(
block_id
<
gemm_desc_ptr
[
group_id
].
block_start
)
{
right
=
group_id
;
}
else
{
left
=
group_id
;
}
group_id
=
index_t
((
left
+
right
)
/
2
);
}
Run
(
gemm_desc_ptr
[
group_id
].
group_karg
,
gemm_desc_ptr
[
group_id
].
block_start
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp
0 → 100644
View file @
c8c016dd
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
template
<
typename
Problem
,
typename
Policy
>
struct
GemmPipelineAgBgCrImplBase
{
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
Problem
::
BlockGemmShape
>
;
static
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
static
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
template
<
typename
DstBlockTile
,
typename
SrcTileWindow
>
CK_TILE_DEVICE
void
GlobalPrefetch
(
DstBlockTile
&
dst_block_tile
,
SrcTileWindow
&
dram_tile_window
)
const
{
load_tile
(
dst_block_tile
,
dram_tile_window
);
move_tile_window
(
dram_tile_window
,
{
0
,
KPerBlock
});
}
template
<
typename
DstTileWindow
,
typename
SrcBlockTile
,
typename
ElementFunction
>
CK_TILE_DEVICE
void
LocalPrefill
(
DstTileWindow
&
lds_tile_window
,
const
SrcBlockTile
&
src_block_tile
,
const
ElementFunction
&
element_func
)
const
{
const
auto
block_tile_tmp
=
tile_elementwise_in
(
element_func
,
src_block_tile
);
store_tile
(
lds_tile_window
,
block_tile_tmp
);
}
CK_TILE_DEVICE
auto
GetABLdsTensorViews
(
void
*
p_smem
)
const
{
// A tile in LDS
ADataType
*
p_a_lds
=
static_cast
<
ADataType
*>
(
p_smem
);
constexpr
auto
a_lds_block_desc
=
Policy
::
template
MakeALdsBlockDescriptor
<
Problem
>();
auto
a_lds_block
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_a_lds
,
a_lds_block_desc
);
// TODO: LDS alignment should come from Policy!
constexpr
index_t
a_lds_block_space_size_aligned
=
integer_divide_ceil
(
sizeof
(
ADataType
)
*
a_lds_block_desc
.
get_element_space_size
(),
16
)
*
16
;
// B tile in LDS
BDataType
*
p_b_lds
=
static_cast
<
BDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
p_smem
)
+
a_lds_block_space_size_aligned
));
constexpr
auto
b_lds_block_desc
=
Policy
::
template
MakeBLdsBlockDescriptor
<
Problem
>();
auto
b_lds_block
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_b_lds
,
b_lds_block_desc
);
return
make_tuple
(
std
::
move
(
a_lds_block
),
std
::
move
(
b_lds_block
));
}
template
<
typename
ADramBlockWindowTmp
,
typename
ALdsTensorView
>
CK_TILE_DEVICE
auto
GetAWindows
(
const
ADramBlockWindowTmp
&
a_dram_block_window_tmp
,
const
ALdsTensorView
&
a_lds_block_view
)
const
{
// A DRAM tile window for load
auto
a_copy_dram_window
=
make_tile_window
(
a_dram_block_window_tmp
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
MPerBlock
>
{},
number
<
KPerBlock
>
{}),
a_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeADramTileDistribution
<
Problem
>());
// A LDS tile window for store
auto
a_copy_lds_window
=
make_tile_window
(
a_lds_block_view
,
make_tuple
(
number
<
MPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
},
a_copy_dram_window
.
get_tile_distribution
());
auto
a_lds_gemm_window
=
make_tile_window
(
a_lds_block_view
,
make_tuple
(
number
<
MPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
});
return
make_tuple
(
std
::
move
(
a_copy_dram_window
),
std
::
move
(
a_copy_lds_window
),
std
::
move
(
a_lds_gemm_window
));
}
template
<
typename
BDramBlockWindowTmp
,
typename
BLdsTensorView
>
CK_TILE_DEVICE
auto
GetBWindows
(
const
BDramBlockWindowTmp
&
b_dram_block_window_tmp
,
const
BLdsTensorView
&
b_lds_block_view
)
const
{
auto
b_copy_dram_window
=
make_tile_window
(
b_dram_block_window_tmp
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
NPerBlock
>
{},
number
<
KPerBlock
>
{}),
b_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeBDramTileDistribution
<
Problem
>());
// B LDS tile window for store
auto
b_copy_lds_window
=
make_tile_window
(
b_lds_block_view
,
make_tuple
(
number
<
NPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
},
b_copy_dram_window
.
get_tile_distribution
());
auto
b_lds_gemm_window
=
make_tile_window
(
b_lds_block_view
,
make_tuple
(
number
<
NPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
});
return
make_tuple
(
std
::
move
(
b_copy_dram_window
),
std
::
move
(
b_copy_lds_window
),
std
::
move
(
b_lds_gemm_window
));
}
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
0 → 100644
View file @
c8c016dd
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
namespace
ck_tile
{
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template
<
typename
Problem
>
struct
BaseGemmPipelineAgBgCrCompV3
{
static
constexpr
index_t
PrefetchStages
=
2
;
static
constexpr
index_t
PrefillStages
=
1
;
static
constexpr
index_t
GlobalBufferNum
=
1
;
CK_TILE_HOST
static
constexpr
bool
BlockHasHotloop
(
index_t
num_loop
)
{
return
num_loop
>
PrefetchStages
;
}
CK_TILE_HOST
static
constexpr
TailNumber
GetBlockLoopTailNum
(
index_t
num_loop
)
{
ignore
=
num_loop
;
return
TailNumber
::
Full
;
}
};
// Compute optimized pipeline
// GlobalPrefetchStages: 2
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1
template
<
typename
Problem
,
typename
Policy
=
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
>
struct
GemmPipelineAgBgCrCompV3
:
public
BaseGemmPipelineAgBgCrCompV3
<
Problem
>
{
using
Base
=
BaseGemmPipelineAgBgCrCompV3
<
Problem
>
;
using
PipelineImplBase
=
GemmPipelineAgBgCrImplBase
<
Problem
,
Policy
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
CDataType
=
remove_cvref_t
<
typename
Problem
::
CDataType
>
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
Problem
::
BlockGemmShape
>
;
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
using
BLayout
=
remove_cvref_t
<
typename
Problem
::
BLayout
>
;
using
CLayout
=
remove_cvref_t
<
typename
Problem
::
CLayout
>
;
using
BlockGemm
=
remove_cvref_t
<
decltype
(
Policy
::
template
GetBlockGemm
<
Problem
>())
>
;
using
I0
=
number
<
0
>
;
using
I1
=
number
<
1
>
;
using
I2
=
number
<
2
>
;
static
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
static
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
static
constexpr
index_t
VectorSizeA
=
Problem
::
VectorSizeA
;
static
constexpr
index_t
VectorSizeB
=
Problem
::
VectorSizeB
;
static
constexpr
index_t
VectorSizeC
=
Problem
::
VectorSizeC
;
static
constexpr
bool
kPadM
=
Problem
::
kPadM
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadK
=
Problem
::
kPadK
;
// Where is the right place for HasHotLoop and TailNum ???
static
constexpr
bool
HasHotLoop
=
Problem
::
HasHotLoop
;
static
constexpr
auto
TailNum
=
Problem
::
TailNum
;
static
constexpr
auto
Scheduler
=
Problem
::
Scheduler
;
using
Base
::
PrefetchStages
;
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
template
<
GemmPipelineScheduler
Scheduler
>
struct
PipelineImpl
:
public
PipelineImplBase
{
};
template
<
>
struct
PipelineImpl
<
GemmPipelineScheduler
::
Intrawave
>
:
public
PipelineImplBase
{
using
Base
=
PipelineImplBase
;
CK_TILE_DEVICE
static
constexpr
auto
HotLoopScheduler
()
{
constexpr
index_t
MPerXDL
=
BlockGemmShape
::
WarpTile
::
at
(
I0
{});
constexpr
index_t
NPerXDL
=
BlockGemmShape
::
WarpTile
::
at
(
I1
{});
constexpr
index_t
KPerXDL
=
BlockGemmShape
::
WarpTile
::
at
(
I2
{});
constexpr
index_t
WaveSize
=
64
;
constexpr
index_t
WaveNumM
=
BlockGemmShape
::
BlockWarps
::
at
(
I0
{});
constexpr
index_t
WaveNumN
=
BlockGemmShape
::
BlockWarps
::
at
(
I1
{});
constexpr
index_t
A_LDS_Read_Width
=
KPerXDL
;
constexpr
index_t
B_LDS_Read_Width
=
KPerXDL
;
constexpr
index_t
A_Buffer_Load_Inst_Num
=
MPerBlock
*
KPerBlock
/
(
BlockSize
*
VectorSizeA
);
constexpr
index_t
B_Buffer_Load_Inst_Num
=
NPerBlock
*
KPerBlock
/
(
BlockSize
*
VectorSizeB
);
constexpr
index_t
A_LDS_Write_Inst_Num
=
MPerBlock
*
KPerBlock
/
(
BlockSize
*
KPerXDL
);
constexpr
index_t
B_LDS_Write_Inst_Num
=
NPerBlock
*
KPerBlock
/
(
BlockSize
*
KPerXDL
);
constexpr
index_t
A_LDS_Read_Inst_Num
=
WaveNumN
*
MPerBlock
*
KPerBlock
/
(
BlockSize
*
KPerXDL
);
constexpr
index_t
B_LDS_Read_Inst_Num
=
WaveNumM
*
MPerBlock
*
KPerBlock
/
(
BlockSize
*
KPerXDL
);
constexpr
index_t
C_MFMA_Inst_Num
=
MPerBlock
*
NPerBlock
*
KPerBlock
/
(
BlockSize
/
WaveSize
)
/
(
MPerXDL
*
NPerXDL
*
KPerXDL
);
// A/B split schedule
// compiler is likely to use ds_read2 when instruction width smaller than 16bytes
constexpr
auto
num_ds_read_inst_a
=
A_LDS_Read_Width
*
sizeof
(
ADataType
)
==
16
?
A_LDS_Read_Inst_Num
:
A_LDS_Read_Inst_Num
/
2
;
constexpr
auto
num_ds_read_inst_b
=
B_LDS_Read_Width
*
sizeof
(
BDataType
)
==
16
?
B_LDS_Read_Inst_Num
:
B_LDS_Read_Inst_Num
/
2
;
constexpr
auto
num_ds_write_inst_a
=
A_LDS_Write_Inst_Num
;
constexpr
auto
num_ds_write_inst_b
=
B_LDS_Write_Inst_Num
;
constexpr
auto
num_buffer_load_inst_a
=
A_Buffer_Load_Inst_Num
;
constexpr
auto
num_buffer_load_inst_b
=
B_Buffer_Load_Inst_Num
;
constexpr
auto
num_mfma_inst
=
C_MFMA_Inst_Num
;
constexpr
auto
mfma_cycle
=
NPerXDL
==
16
?
16
:
32
;
constexpr
auto
ds_read_a_issue_cycle
=
A_LDS_Read_Width
*
sizeof
(
ADataType
)
==
16
?
8
:
4
;
constexpr
auto
ds_read_b_issue_cycle
=
B_LDS_Read_Width
*
sizeof
(
BDataType
)
==
16
?
8
:
4
;
constexpr
auto
ds_read_a_mfma_rate
=
(
mfma_cycle
-
4
+
2
*
ds_read_a_issue_cycle
-
1
)
/
(
2
*
ds_read_a_issue_cycle
);
constexpr
auto
ds_read_b_mfma_rate
=
(
mfma_cycle
-
4
+
2
*
ds_read_b_issue_cycle
-
1
)
/
(
2
*
ds_read_b_issue_cycle
);
constexpr
auto
num_dsread_a_mfma
=
(
num_ds_read_inst_a
+
ds_read_a_mfma_rate
-
1
)
/
ds_read_a_mfma_rate
;
constexpr
auto
num_dsread_b_mfma
=
(
num_ds_read_inst_b
+
ds_read_b_mfma_rate
-
1
)
/
ds_read_b_mfma_rate
;
// stage 1
// Separate this part?
// constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
// sizeof(ComputeDataType) /
// sizeof(BDataType)
// ? sizeof(ComputeDataType) /
// sizeof(ADataType) : sizeof(ComputeDataType)
// / sizeof(BDataType);
constexpr
auto
num_mfma_stage1
=
num_mfma_inst
-
(
num_dsread_a_mfma
+
num_dsread_b_mfma
);
constexpr
auto
num_mfma_per_issue
=
num_mfma_stage1
/
(
num_buffer_load_inst_a
+
num_buffer_load_inst_b
);
constexpr
auto
num_dswrite_per_issue_a
=
num_ds_write_inst_a
/
num_buffer_load_inst_a
;
constexpr
auto
num_dswrite_per_issue_b
=
num_ds_write_inst_b
/
num_buffer_load_inst_b
;
static_for
<
0
,
num_buffer_load_inst_a
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
static_for
<
0
,
num_dswrite_per_issue_a
,
1
>
{}([
&
](
auto
idswrite
)
{
ignore
=
idswrite
;
__builtin_amdgcn_sched_group_barrier
(
0x200
,
1
,
0
);
// DS write
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
});
__builtin_amdgcn_sched_group_barrier
(
0x020
,
1
,
0
);
// VMEM read
__builtin_amdgcn_sched_group_barrier
(
0x008
,
num_mfma_per_issue
-
num_dswrite_per_issue_a
,
0
);
// MFMA
});
static_for
<
0
,
num_buffer_load_inst_b
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
static_for
<
0
,
num_dswrite_per_issue_b
,
1
>
{}([
&
](
auto
idswrite
)
{
ignore
=
idswrite
;
__builtin_amdgcn_sched_group_barrier
(
0x200
,
1
,
0
);
// DS write
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
});
__builtin_amdgcn_sched_group_barrier
(
0x020
,
1
,
0
);
// VMEM read
__builtin_amdgcn_sched_group_barrier
(
0x008
,
num_mfma_per_issue
-
num_dswrite_per_issue_b
,
0
);
// MFMA
});
// stage 2
static_for
<
0
,
num_dsread_a_mfma
,
1
>
{}([
&
](
auto
i
)
{
if
constexpr
((
num_ds_read_inst_a
-
(
i
+
1
)
*
ds_read_a_mfma_rate
)
>=
ds_read_a_mfma_rate
)
{
__builtin_amdgcn_sched_group_barrier
(
0x100
,
ds_read_a_mfma_rate
,
0
);
// DS read
}
else
{
__builtin_amdgcn_sched_group_barrier
(
0x100
,
num_ds_read_inst_a
-
(
num_dsread_a_mfma
-
1
)
*
ds_read_a_mfma_rate
,
0
);
// DS read
}
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
});
static_for
<
0
,
num_dsread_b_mfma
,
1
>
{}([
&
](
auto
i
)
{
if
constexpr
((
num_ds_read_inst_b
-
(
i
+
1
)
*
ds_read_b_mfma_rate
)
>=
ds_read_b_mfma_rate
)
{
__builtin_amdgcn_sched_group_barrier
(
0x100
,
ds_read_b_mfma_rate
,
0
);
// DS read
}
else
{
__builtin_amdgcn_sched_group_barrier
(
0x100
,
num_ds_read_inst_b
-
(
num_dsread_b_mfma
-
1
)
*
ds_read_b_mfma_rate
,
0
);
// DS read
}
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
});
}
template
<
bool
HasHotLoop
,
TailNumber
TailNum
,
typename
ADramBlockWindowTmp
,
typename
BDramBlockWindowTmp
,
typename
AElementFunction
,
typename
BElementFunction
>
CK_TILE_DEVICE
auto
operator
()(
const
ADramBlockWindowTmp
&
a_dram_block_window_tmp
,
const
AElementFunction
&
a_element_func
,
const
BDramBlockWindowTmp
&
b_dram_block_window_tmp
,
const
BElementFunction
&
b_element_func
,
index_t
num_loop
,
void
*
p_smem
)
const
{
static_assert
(
std
::
is_same_v
<
ADataType
,
remove_cvref_t
<
typename
ADramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
BDataType
,
remove_cvref_t
<
typename
BDramBlockWindowTmp
::
DataType
>>
,
"A/B Dram block window should have the same data type as appropriate "
"([A|B]DataType) defined in Problem definition!"
);
static_assert
(
MPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
NPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
KPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}],
"A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
" or KPerBlock!"
);
// ------------------------------------------------------------------------------------
// Definitions of all needed tiles
// A/B tiles in LDS
auto
&&
[
a_lds_block
,
b_lds_block
]
=
Base
::
GetABLdsTensorViews
(
p_smem
);
// A DRAM tile window for load
// A LDS tile window for store
// A LDS tile for block GEMM
auto
&&
[
a_copy_dram_window
,
a_copy_lds_window
,
a_lds_gemm_window
]
=
Base
::
GetAWindows
(
a_dram_block_window_tmp
,
a_lds_block
);
// B DRAM tile window for load
// B LDS tile window for store
// B LDS tile for block GEMM
auto
&&
[
b_copy_dram_window
,
b_copy_lds_window
,
b_lds_gemm_window
]
=
Base
::
GetBWindows
(
b_dram_block_window_tmp
,
b_lds_block
);
// Block GEMM
auto
block_gemm
=
BlockGemm
();
auto
c_block_tile
=
block_gemm
.
MakeCBlockTile
();
using
ABlockTileDistr
=
decltype
(
a_copy_dram_window
.
get_tile_distribution
());
using
BBlockTileDistr
=
decltype
(
b_copy_dram_window
.
get_tile_distribution
());
using
ABlockTile
=
decltype
(
make_static_distributed_tensor
<
ADataType
>
(
ABlockTileDistr
{}));
using
BBlockTile
=
decltype
(
make_static_distributed_tensor
<
BDataType
>
(
BBlockTileDistr
{}));
ABlockTile
a_block_tile
;
BBlockTile
b_block_tile
;
// -----------------------------------------------------------------------------------------
// Gemm pipeline start
// prefetch
// global read 0
Base
::
GlobalPrefetch
(
a_block_tile
,
a_copy_dram_window
);
Base
::
GlobalPrefetch
(
b_block_tile
,
b_copy_dram_window
);
// initialize C
tile_elementwise_inout
([](
auto
&
c
)
{
c
=
0
;
},
c_block_tile
);
// LDS write 0
Base
::
LocalPrefill
(
a_copy_lds_window
,
a_block_tile
,
a_element_func
);
Base
::
LocalPrefill
(
b_copy_lds_window
,
b_block_tile
,
b_element_func
);
Base
::
GlobalPrefetch
(
a_block_tile
,
a_copy_dram_window
);
Base
::
GlobalPrefetch
(
b_block_tile
,
b_copy_dram_window
);
block_sync_lds
();
block_gemm
.
LocalPrefetch
(
a_lds_gemm_window
,
b_lds_gemm_window
);
__builtin_amdgcn_sched_barrier
(
0
);
// main body
if
constexpr
(
HasHotLoop
)
{
index_t
i
=
0
;
do
{
block_sync_lds
();
Base
::
LocalPrefill
(
a_copy_lds_window
,
a_block_tile
,
a_element_func
);
Base
::
LocalPrefill
(
b_copy_lds_window
,
b_block_tile
,
b_element_func
);
Base
::
GlobalPrefetch
(
a_block_tile
,
a_copy_dram_window
);
Base
::
GlobalPrefetch
(
b_block_tile
,
b_copy_dram_window
);
block_gemm
(
c_block_tile
,
a_lds_gemm_window
,
b_lds_gemm_window
);
block_sync_lds
();
block_gemm
.
LocalPrefetch
(
a_lds_gemm_window
,
b_lds_gemm_window
);
HotLoopScheduler
();
__builtin_amdgcn_sched_barrier
(
0
);
i
+=
1
;
}
while
(
i
<
(
num_loop
-
1
));
}
// tail
if
constexpr
(
TailNum
==
TailNumber
::
Full
)
{
block_gemm
(
c_block_tile
,
a_lds_gemm_window
,
b_lds_gemm_window
);
}
// Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
// latency
// __builtin_amdgcn_sched_barrier(0);
return
c_block_tile
;
}
};
template
<
typename
ADramBlockWindowTmp
,
typename
BDramBlockWindowTmp
,
typename
AElementFunction
,
typename
BElementFunction
>
CK_TILE_DEVICE
auto
operator
()(
const
ADramBlockWindowTmp
&
a_dram_block_window_tmp
,
const
AElementFunction
&
a_element_func
,
const
BDramBlockWindowTmp
&
b_dram_block_window_tmp
,
const
BElementFunction
&
b_element_func
,
index_t
num_loop
,
void
*
p_smem
)
const
{
return
PipelineImpl
<
Scheduler
>
{}.
template
operator
()
<
HasHotLoop
,
TailNum
>(
a_dram_block_window_tmp
,
a_element_func
,
b_dram_block_window_tmp
,
b_element_func
,
num_loop
,
p_smem
);
}
template
<
typename
ADramBlockWindowTmp
,
typename
BDramBlockWindowTmp
>
CK_TILE_DEVICE
auto
operator
()(
const
ADramBlockWindowTmp
&
a_dram_block_window_tmp
,
const
BDramBlockWindowTmp
&
b_dram_block_window_tmp
,
index_t
num_loop
,
void
*
p_smem
)
const
{
return
PipelineImpl
<
Scheduler
>
{}.
template
operator
()
<
HasHotLoop
,
TailNum
>(
a_dram_block_window_tmp
,
[](
const
ADataType
&
a
)
{
return
a
;
},
b_dram_block_window_tmp
,
[](
const
BDataType
&
b
)
{
return
b
;
},
num_loop
,
p_smem
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
View file @
c8c016dd
...
...
@@ -6,6 +6,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
namespace
ck_tile
{
...
...
@@ -91,6 +92,7 @@ template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV1Defaul
struct
GemmPipelineAgBgCrMem
:
public
BaseGemmPipelineAgBgCrMem
<
Problem
>
{
using
Base
=
BaseGemmPipelineAgBgCrMem
<
Problem
>
;
using
PipelineImplBase
=
GemmPipelineAgBgCrImplBase
<
Problem
,
Policy
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
...
...
@@ -103,8 +105,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
using
BlockGemm
=
remove_cvref_t
<
decltype
(
Policy
::
template
GetBlockGemm
<
Problem
>())
>
;
using
I0
=
number
<
0
>
;
using
I1
=
number
<
1
>
;
using
I2
=
number
<
2
>
;
static
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
static
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
...
...
@@ -124,47 +127,208 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
using
Base
::
PrefetchStages
;
CK_TILE_HOST_DEVICE
constexpr
index_t
GetStaticLdsSize
()
{
return
integer_divide_ceil
(
sizeof
(
ADataType
)
*
Policy
::
template
MakeALdsBlockDescriptor
<
Problem
>().
get_element_space_size
(),
16
)
*
16
+
sizeof
(
BDataType
)
*
Policy
::
template
MakeBLdsBlockDescriptor
<
Problem
>().
get_element_space_size
();
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
template
<
GemmPipelineScheduler
Scheduler
>
struct
PipelineImpl
struct
PipelineImpl
:
public
PipelineImplBase
{
};
template
<
>
struct
PipelineImpl
<
GemmPipelineScheduler
::
Intrawave
>
struct
PipelineImpl
<
GemmPipelineScheduler
::
Intrawave
>
:
public
PipelineImplBase
{
template
<
typename
DstBlockTile
,
typename
SrcTileWindow
>
CK_TILE_DEVICE
void
GlobalPrefetch
(
DstBlockTile
&
dst_block_tile
,
SrcTileWindow
&
dram_tile_window
)
const
using
Base
=
PipelineImplBase
;
template
<
bool
HasHotLoop
,
TailNumber
TailNum
,
typename
ADramBlockWindowTmp
,
typename
BDramBlockWindowTmp
,
typename
AElementFunction
,
typename
BElementFunction
>
CK_TILE_DEVICE
auto
operator
()(
const
ADramBlockWindowTmp
&
a_dram_block_window_tmp
,
const
AElementFunction
&
a_element_func
,
const
BDramBlockWindowTmp
&
b_dram_block_window_tmp
,
const
BElementFunction
&
b_element_func
,
index_t
num_loop
,
void
*
p_smem
)
const
{
load_tile
(
dst_block_tile
,
dram_tile_window
);
move_tile_window
(
dram_tile_window
,
{
0
,
KPerBlock
});
static_assert
(
std
::
is_same_v
<
ADataType
,
remove_cvref_t
<
typename
ADramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
BDataType
,
remove_cvref_t
<
typename
BDramBlockWindowTmp
::
DataType
>>
,
"A/B Dram block window should have the same data type as appropriate "
"([A|B]DataType) defined in Problem definition!"
);
static_assert
(
MPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
NPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
KPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}],
"A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
" or KPerBlock!"
);
// ------------------------------------------------------------------------------------
// Definitions of all needed tiles
// A/B tiles in LDS
// With c++20 could simplify to below line.
// Currently get error: captured structured bindings are a C++20 extension
// auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);
auto
ab_lds_blocks
=
Base
::
GetABLdsTensorViews
(
p_smem
);
auto
&
a_lds_block
=
ab_lds_blocks
.
at
(
I0
{});
auto
&
b_lds_block
=
ab_lds_blocks
.
at
(
I1
{});
// A DRAM tile window for load
// A LDS tile window for store
// A LDS tile for block GEMM
auto
a_windows
=
Base
::
GetAWindows
(
a_dram_block_window_tmp
,
a_lds_block
);
auto
&
a_copy_dram_window
=
a_windows
.
at
(
I0
{});
auto
&
a_copy_lds_window
=
a_windows
.
at
(
I1
{});
auto
&
a_lds_gemm_window
=
a_windows
.
at
(
I2
{});
// B DRAM tile window for load
// B LDS tile window for store
// B LDS tile for block GEMM
auto
b_windows
=
Base
::
GetBWindows
(
b_dram_block_window_tmp
,
b_lds_block
);
auto
&
b_copy_dram_window
=
b_windows
.
at
(
I0
{});
auto
&
b_copy_lds_window
=
b_windows
.
at
(
I1
{});
auto
&
b_lds_gemm_window
=
b_windows
.
at
(
I2
{});
// Block GEMM
auto
block_gemm
=
BlockGemm
();
auto
c_block_tile
=
block_gemm
.
MakeCBlockTile
();
using
ABlockTileDistr
=
decltype
(
a_copy_dram_window
.
get_tile_distribution
());
using
BBlockTileDistr
=
decltype
(
b_copy_dram_window
.
get_tile_distribution
());
using
ABlockTile
=
decltype
(
make_static_distributed_tensor
<
ADataType
>
(
ABlockTileDistr
{}));
using
BBlockTile
=
decltype
(
make_static_distributed_tensor
<
BDataType
>
(
BBlockTileDistr
{}));
tuple_array
<
ABlockTile
,
PrefetchStages
>
a_block_tiles
;
tuple_array
<
BBlockTile
,
PrefetchStages
>
b_block_tiles
;
// -----------------------------------------------------------------------------------------
// Gemm pipeline start
// prefetch
// global read 0
Base
::
GlobalPrefetch
(
a_block_tiles
.
get
(
I0
{}),
a_copy_dram_window
);
Base
::
GlobalPrefetch
(
b_block_tiles
.
get
(
I0
{}),
b_copy_dram_window
);
// initialize C
tile_elementwise_inout
([](
auto
&
c
)
{
c
=
0
;
},
c_block_tile
);
// LDS write 0
Base
::
LocalPrefill
(
a_copy_lds_window
,
a_block_tiles
.
get
(
I0
{}),
a_element_func
);
Base
::
LocalPrefill
(
b_copy_lds_window
,
b_block_tiles
.
get
(
I0
{}),
b_element_func
);
// Global prefetch [1, PrefetchStages]
static_for
<
1
,
PrefetchStages
,
1
>
{}([
&
](
auto
prefetch_idx
)
{
Base
::
GlobalPrefetch
(
a_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
a_copy_dram_window
);
Base
::
GlobalPrefetch
(
b_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
b_copy_dram_window
);
});
// main body
if
constexpr
(
HasHotLoop
)
{
index_t
i
=
0
;
do
{
static_for
<
0
,
PrefetchStages
,
1
>
{}([
&
](
auto
prefetch_idx
)
{
block_sync_lds
();
block_gemm
.
LocalPrefetch
(
a_lds_gemm_window
,
b_lds_gemm_window
);
block_gemm
(
c_block_tile
,
a_lds_gemm_window
,
b_lds_gemm_window
);
block_sync_lds
();
Base
::
LocalPrefill
(
a_copy_lds_window
,
a_block_tiles
.
get
(
number
<
(
prefetch_idx
+
1
)
%
PrefetchStages
>
{}),
a_element_func
);
Base
::
LocalPrefill
(
b_copy_lds_window
,
b_block_tiles
.
get
(
number
<
(
prefetch_idx
+
1
)
%
PrefetchStages
>
{}),
b_element_func
);
Base
::
GlobalPrefetch
(
a_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
a_copy_dram_window
);
Base
::
GlobalPrefetch
(
b_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
b_copy_dram_window
);
});
i
+=
PrefetchStages
;
}
while
(
i
<
(
num_loop
-
PrefetchStages
));
}
template
<
typename
DstTileWindow
,
typename
SrcBlockTile
,
typename
ElementFunction
>
CK_TILE_DEVICE
void
LocalPrefill
(
DstTileWindow
&
lds_tile_window
,
const
SrcBlockTile
&
src_block_tile
,
const
ElementFunction
&
element_func
)
const
auto
HotLoopTail
=
[
&
](
auto
tail_num
)
{
static_for
<
1
,
tail_num
,
1
>
{}([
&
](
auto
prefetch_idx
)
{
block_sync_lds
();
block_gemm
.
LocalPrefetch
(
a_lds_gemm_window
,
b_lds_gemm_window
);
block_gemm
(
c_block_tile
,
a_lds_gemm_window
,
b_lds_gemm_window
);
block_sync_lds
();
Base
::
LocalPrefill
(
a_copy_lds_window
,
a_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
a_element_func
);
Base
::
LocalPrefill
(
b_copy_lds_window
,
b_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
b_element_func
);
});
block_sync_lds
();
block_gemm
.
LocalPrefetch
(
a_lds_gemm_window
,
b_lds_gemm_window
);
block_gemm
(
c_block_tile
,
a_lds_gemm_window
,
b_lds_gemm_window
);
};
if
constexpr
(
TailNum
==
TailNumber
::
One
)
{
block_sync_lds
();
block_gemm
.
LocalPrefetch
(
a_lds_gemm_window
,
b_lds_gemm_window
);
block_gemm
(
c_block_tile
,
a_lds_gemm_window
,
b_lds_gemm_window
);
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Two
)
{
HotLoopTail
(
number
<
2
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Three
)
{
const
auto
block_tile_tmp
=
tile_elementwise_in
(
element_func
,
src_block_tile
);
store_tile
(
lds_tile_window
,
block_tile_tmp
);
HotLoopTail
(
number
<
3
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Four
)
{
HotLoopTail
(
number
<
4
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Five
)
{
HotLoopTail
(
number
<
5
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Six
)
{
HotLoopTail
(
number
<
6
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Seven
)
{
HotLoopTail
(
number
<
7
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Full
)
{
HotLoopTail
(
number
<
PrefetchStages
>
{});
}
return
c_block_tile
;
}
};
template
<
>
struct
PipelineImpl
<
GemmPipelineScheduler
::
Interwave
>
:
public
PipelineImplBase
{
using
Base
=
PipelineImplBase
;
template
<
bool
HasHotLoop
,
TailNumber
TailNum
,
typename
ADramBlockWindowTmp
,
...
...
@@ -185,69 +349,41 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
"A/B Dram block window should have the same data type as appropriate "
"([A|B]DataType) defined in Problem definition!"
);
static_assert
(
MPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
NPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
KPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}],
static_assert
(
MPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
NPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
KPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}],
"A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
" or KPerBlock!"
);
// ------------------------------------------------------------------------------------
// Definitions of all needed tiles
// A tile in LDS
ADataType
*
p_a_lds
=
static_cast
<
ADataType
*>
(
p_smem
);
constexpr
auto
a_lds_block_desc
=
Policy
::
template
MakeALdsBlockDescriptor
<
Problem
>();
auto
a_lds_block
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_a_lds
,
a_lds_block_desc
);
// TODO: LDS alignment should come from Policy!
constexpr
index_t
a_lds_block_space_size_aligned
=
integer_divide_ceil
(
sizeof
(
ADataType
)
*
a_lds_block_desc
.
get_element_space_size
(),
16
)
*
16
;
// B tile in LDS
BDataType
*
p_b_lds
=
static_cast
<
BDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
p_smem
)
+
a_lds_block_space_size_aligned
));
constexpr
auto
b_lds_block_desc
=
Policy
::
template
MakeBLdsBlockDescriptor
<
Problem
>();
auto
b_lds_block
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_b_lds
,
b_lds_block_desc
);
// A/B tiles in LDS
// With c++20 could simplify to below line.
// Currently get error: captured structured bindings are a C++20 extension
// auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);
auto
ab_lds_blocks
=
Base
::
GetABLdsTensorViews
(
p_smem
);
auto
&
a_lds_block
=
ab_lds_blocks
.
at
(
I0
{});
auto
&
b_lds_block
=
ab_lds_blocks
.
at
(
I1
{});
// A DRAM tile window for load
auto
a_copy_dram_window
=
make_tile_window
(
a_dram_block_window_tmp
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
MPerBlock
>
{},
number
<
KPerBlock
>
{}),
a_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeADramTileDistribution
<
Problem
>());
// A LDS tile window for store
auto
a_copy_lds_window
=
make_tile_window
(
a_lds_block
,
make_tuple
(
number
<
MPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
},
a_copy_dram_window
.
get_tile_distribution
());
// B DRAM tile window for load
auto
b_copy_dram_window
=
make_tile_window
(
b_dram_block_window_tmp
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
NPerBlock
>
{},
number
<
KPerBlock
>
{}),
b_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeBDramTileDistribution
<
Problem
>());
// A LDS tile for block GEMM
auto
a_windows
=
Base
::
GetAWindows
(
a_dram_block_window_tmp
,
a_lds_block
);
auto
&
a_copy_dram_window
=
a_windows
.
at
(
I0
{});
auto
&
a_copy_lds_window
=
a_windows
.
at
(
I1
{});
auto
&
a_lds_gemm_window
=
a_windows
.
at
(
I2
{});
// B DRAM tile window for load
// B LDS tile window for store
auto
b_copy_lds_window
=
make_tile_window
(
b_lds_block
,
make_tuple
(
number
<
NPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
},
b_copy_dram_window
.
get_tile_distribution
());
// A LDS tile for block GEMM
auto
a_lds_gemm_window
=
make_tile_window
(
a_lds_block
,
make_tuple
(
number
<
MPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
});
// B LDS tile for block GEMM
auto
b_lds_gemm_window
=
make_tile_window
(
b_lds_block
,
make_tuple
(
number
<
NPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
});
auto
b_windows
=
Base
::
GetBWindows
(
b_dram_block_window_tmp
,
b_lds_block
);
auto
&
b_copy_dram_window
=
b_windows
.
at
(
I0
{});
auto
&
b_copy_lds_window
=
b_windows
.
at
(
I1
{});
auto
&
b_lds_gemm_window
=
b_windows
.
at
(
I2
{});
// Block GEMM
constexpr
auto
block_gemm
=
BlockGemm
();
auto
block_gemm
=
BlockGemm
();
auto
c_block_tile
=
block_gemm
.
MakeCBlockTile
();
using
ABlockTileDistr
=
decltype
(
a_copy_dram_window
.
get_tile_distribution
());
...
...
@@ -266,20 +402,20 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
// prefetch
// global read 0
GlobalPrefetch
(
a_block_tiles
.
get
(
I0
{}),
a_copy_dram_window
);
GlobalPrefetch
(
b_block_tiles
.
get
(
I0
{}),
b_copy_dram_window
);
Base
::
GlobalPrefetch
(
a_block_tiles
.
get
(
I0
{}),
a_copy_dram_window
);
Base
::
GlobalPrefetch
(
b_block_tiles
.
get
(
I0
{}),
b_copy_dram_window
);
// initialize C
tile_elementwise_inout
([](
auto
&
c
)
{
c
=
0
;
},
c_block_tile
);
// LDS write 0
LocalPrefill
(
a_copy_lds_window
,
a_block_tiles
.
get
(
I0
{}),
a_element_func
);
LocalPrefill
(
b_copy_lds_window
,
b_block_tiles
.
get
(
I0
{}),
b_element_func
);
Base
::
LocalPrefill
(
a_copy_lds_window
,
a_block_tiles
.
get
(
I0
{}),
a_element_func
);
Base
::
LocalPrefill
(
b_copy_lds_window
,
b_block_tiles
.
get
(
I0
{}),
b_element_func
);
// Global prefetch [1, PrefetchStages]
static_for
<
1
,
PrefetchStages
,
1
>
{}([
&
](
auto
prefetch_idx
)
{
GlobalPrefetch
(
a_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
a_copy_dram_window
);
GlobalPrefetch
(
b_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
b_copy_dram_window
);
Base
::
GlobalPrefetch
(
a_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
a_copy_dram_window
);
Base
::
GlobalPrefetch
(
b_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
b_copy_dram_window
);
});
// main body
...
...
@@ -290,23 +426,21 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
{
static_for
<
0
,
PrefetchStages
,
1
>
{}([
&
](
auto
prefetch_idx
)
{
block_sync_lds
();
// block_gemm.LocalPrefetch();
block_gemm
(
c_block_tile
,
a_lds_gemm_window
,
b_lds_gemm_window
);
// no second block_sync_lds because it's interwave
block_sync_lds
();
LocalPrefill
(
Base
::
LocalPrefill
(
a_copy_lds_window
,
a_block_tiles
.
get
(
number
<
(
prefetch_idx
+
1
)
%
PrefetchStages
>
{}),
a_element_func
);
LocalPrefill
(
Base
::
LocalPrefill
(
b_copy_lds_window
,
b_block_tiles
.
get
(
number
<
(
prefetch_idx
+
1
)
%
PrefetchStages
>
{}),
b_element_func
);
GlobalPrefetch
(
a_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
Base
::
GlobalPrefetch
(
a_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
a_copy_dram_window
);
GlobalPrefetch
(
b_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
Base
::
GlobalPrefetch
(
b_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
b_copy_dram_window
);
});
...
...
@@ -317,28 +451,24 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
auto
HotLoopTail
=
[
&
](
auto
tail_num
)
{
static_for
<
1
,
tail_num
,
1
>
{}([
&
](
auto
prefetch_idx
)
{
block_sync_lds
();
// block_gemm.LocalPrefetch();
block_gemm
(
c_block_tile
,
a_lds_gemm_window
,
b_lds_gemm_window
);
// no second block_sync_lds because it's interwave
block_sync_lds
();
LocalPrefill
(
a_copy_lds_window
,
Base
::
LocalPrefill
(
a_copy_lds_window
,
a_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
a_element_func
);
LocalPrefill
(
b_copy_lds_window
,
Base
::
LocalPrefill
(
b_copy_lds_window
,
b_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
b_element_func
);
});
block_sync_lds
();
// block_gemm.LocalPrefetch();
block_gemm
(
c_block_tile
,
a_lds_gemm_window
,
b_lds_gemm_window
);
};
if
constexpr
(
TailNum
==
TailNumber
::
One
)
{
block_sync_lds
();
// block_gemm.LocalPrefetch();
block_gemm
(
c_block_tile
,
a_lds_gemm_window
,
b_lds_gemm_window
);
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Two
)
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp
View file @
c8c016dd
...
...
@@ -11,6 +11,7 @@ namespace ck_tile {
enum
struct
GemmPipelineScheduler
{
Default
,
Intrawave
,
Interwave
,
};
...
...
@@ -43,6 +44,7 @@ inline std::ostream& operator<<(std::ostream& os, const ck_tile::GemmPipelineSch
{
switch
(
s
)
{
case
ck_tile
::
GemmPipelineScheduler
::
Default
:
os
<<
"Default"
;
break
;
case
ck_tile
::
GemmPipelineScheduler
::
Intrawave
:
os
<<
"Intrawave"
;
break
;
case
ck_tile
::
GemmPipelineScheduler
::
Interwave
:
os
<<
"Interwave"
;
break
;
default:
os
<<
""
;
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
View file @
c8c016dd
...
...
@@ -124,7 +124,7 @@ struct GemmPipelineAGmemBGmemCRegV1
b_lds_block
,
make_tuple
(
number
<
kNPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
});
// Block GEMM
constexpr
auto
block_gemm
=
Policy
::
template
GetBlockGemm
<
Problem
>();
auto
block_gemm
=
Policy
::
template
GetBlockGemm
<
Problem
>();
// Acc register tile
auto
c_block_tile
=
decltype
(
block_gemm
(
a_lds_gemm_window
,
b_lds_gemm_window
)){};
...
...
Prev
1
…
8
9
10
11
12
13
14
15
16
…
20
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment