Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
6e3c786e
Commit
6e3c786e
authored
Dec 06, 2024
by
Jing Zhang
Browse files
merge develop
parents
1bb510cb
261f1759
Changes
465
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2244 additions
and
146 deletions
+2244
-146
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp
...s/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp
+354
-0
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp
...ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp
+46
-0
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp
...e/ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp
+48
-0
include/ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp
...e/ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp
+39
-0
include/ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp
...ude/ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp
+15
-0
include/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp
...de/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp
+23
-0
include/ck_tile/ops/gemm.hpp
include/ck_tile/ops/gemm.hpp
+16
-5
include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp
.../ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp
+1
-1
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp
...e/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp
+1
-1
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp
...ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp
+237
-0
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp
.../ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp
+1
-1
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp
.../ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp
+1
-1
include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp
.../ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp
+1
-1
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp
...ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp
+16
-16
include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp
.../ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp
+652
-0
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
+258
-0
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+175
-74
include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp
include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp
+50
-10
include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp
+310
-0
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp
...ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp
+0
-36
No files found.
Too many changes to show.
To preserve performance only
465 of 465+
files are displayed.
Plain diff
Email patch
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp
0 → 100644
View file @
6e3c786e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"

namespace ck_tile {

/*
 * This pipeline deals with a gemm (actually 2 gemms) with one very small input (token)
 * and one very big input (weight). We design the pipeline such that all waves are laid
 * out along the gemm-N dim (gemm-M uses only 1 wave):
 *
 *        <----- gemm-N ------>
 *        +----+----+----+----+
 *        | w0 | w1 | w2 | w3 |  gemm-M
 *        +----+----+----+----+
 */
template <typename Problem_, typename Policy_ = FusedMoeGemmPipelineFlatmmPolicy>
struct FusedMoeGemmPipeline_FlatmmUk
{
    using Problem    = remove_cvref_t<Problem_>;
    using Policy     = remove_cvref_t<Policy_>;
    using BlockShape = typename Problem::BlockShape; // this is FusedMoeGemmShape

    using ADataType            = typename Problem::ADataType;
    using GDataType            = typename Problem::GDataType;
    using DDataType            = typename Problem::DDataType;
    using AccDataType          = typename Problem::AccDataType;
    using ODataType            = typename Problem::ODataType;
    using AScaleDataType       = typename Problem::AScaleDataType;
    using GScaleDataType       = typename Problem::GScaleDataType;
    using DScaleDataType       = typename Problem::DScaleDataType;
    using YSmoothScaleDataType = typename Problem::YSmoothScaleDataType;
    using TopkWeightDataType   = typename Problem::TopkWeightDataType;
    using IndexDataType        = typename Problem::IndexDataType;
    using YDataType            = typename Problem::YDataType;

    using Traits = typename Problem::Traits;

    static constexpr bool IsGateOnly          = Traits::IsGateOnly;
    static constexpr bool UseSmoothQuant      = Traits::UseSmoothQuant;
    static constexpr bool PadHiddenSize       = Traits::PadHiddenSize;
    static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;

    static constexpr index_t kAlignmentA = Policy::template GetAlignment_A<Problem>();
    static constexpr index_t kAlignmentG = Policy::template GetAlignment_G<Problem>();
    static constexpr index_t kAlignmentD = Policy::template GetAlignment_D<Problem>();
    static constexpr index_t kAlignmentO = Policy::template GetAlignment_O<Problem>();

    // sequencer bit-mask values as plain integers (see FusedMoeGemmPipelineSequencerEnum)
    static constexpr index_t SLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::SLD_A);
    static constexpr index_t GLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_A);
    static constexpr index_t GLD_B = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_B);
    static constexpr index_t GST_O = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GST_O);

    static constexpr index_t kBlockPerCu = []() {
        if constexpr(Problem::kBlockPerCu != -1)
            return Problem::kBlockPerCu;
        else
        {
            // minimize occupancy
            return 2;
        }
    }();

    static constexpr const char* name = "flatmm_uk";

    /// Shared-memory requirement: the maximum over both micro-kernels and the
    /// M0xN0 "bridge" buffer used to hand the first gemm's output to the second.
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
        constexpr index_t smem_0 = Policy::template GetUK_0<Problem>().GetSmemSize();
        constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize();
        constexpr index_t smem_bridge =
            BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType);
        return max(smem_0, max(smem_1, smem_bridge));
    }

    // this is the thread-offset along row/col
    CK_TILE_HOST_DEVICE static auto GetACoord()
    {
        constexpr auto a_dist = Policy::template MakeGlobalTileDistribution_A<Problem>();
        const auto a_coord    = a_dist.calculate_index();
        return a_coord;
    }

    // this is the thread-offset along row/col
    CK_TILE_HOST_DEVICE static auto GetOCoord()
    {
        constexpr auto o_dist = Policy::template MakeOGlobalTileDistribution<Problem>();
        const auto o_coord    = o_dist.calculate_index();
        return o_coord;
    }

    /// Number of row coordinates (M-repeats) each lane owns for the A load.
    CK_TILE_DEVICE constexpr auto GetNumRowCoords_A()
    {
        constexpr index_t KLans   = BlockShape::Block_K0 / kAlignmentA;
        constexpr index_t MLans   = BlockShape::BlockSize / KLans;
        constexpr index_t MRepeat = BlockShape::Block_M0 / MLans;
        return MRepeat;
    }

    // TODO: properly support scatter/gather
    CK_TILE_DEVICE auto GetRowCoords_A(index_t base_offset)
    {
        constexpr index_t KLans   = BlockShape::Block_K0 / kAlignmentA;
        constexpr index_t MLans   = BlockShape::BlockSize / KLans;
        constexpr index_t MRepeat = BlockShape::Block_M0 / MLans;

        auto base_coord = threadIdx.x / KLans + base_offset;

        array<index_t, MRepeat> coords;
        static_for<0, MRepeat, 1>{}([&](auto i) { coords.at(i) = base_coord + i * MLans; });
        return coords;
    }

    /// Gather token row ids for the given sorted-row coordinates.
    template <typename ROW_COORDS>
    CK_TILE_DEVICE auto GetRowID(const ROW_COORDS coords,
                                 const IndexDataType* sorted_token_ids_ptr)
    {
        constexpr index_t n_size = coords.size();

        array<index_t, n_size> row_ids;
        static_for<0, n_size, 1>{}(
            [&](auto i) { row_ids.at(i) = sorted_token_ids_ptr[coords[i]]; });
        return row_ids;
    }

    /// Gather per-row topk weights for the given sorted-row coordinates.
    template <typename ROW_COORDS>
    CK_TILE_DEVICE auto GetWeightScale(const ROW_COORDS coords,
                                       const TopkWeightDataType* sorted_weight_ptr)
    {
        constexpr index_t n_size = coords.size();

        array<TopkWeightDataType, n_size> w;
        static_for<0, n_size, 1>{}([&](auto i) { w.at(i) = sorted_weight_ptr[coords[i]]; });
        return w;
    }

    // TODO: this row id is before shuffle atomic, need use acc distribution
    CK_TILE_DEVICE auto GetRowCoords_O(index_t base_offset)
    {
        constexpr index_t MLanes   = BlockShape::Warp_M1;
        constexpr index_t Repeat_M = BlockShape::Repeat_M1;

        auto base_coord = threadIdx.x % MLanes + base_offset;

        array<index_t, Repeat_M> coords;
        static_for<0, Repeat_M, 1>{}([&](auto i) { coords.at(i) = base_coord + i * MLanes; });
        return coords;
    }

    /// Run the fused 2-gemm pipeline for one (sorted_tile_id, intermediate_tile_id) pair:
    /// gemm0 (A x G, gated activation) -> LDS bridge -> gemm1 (Y x D -> O).
    template <typename Karg>
    CK_TILE_DEVICE auto operator()(const Karg& kargs,
                                   CK_TILE_LDS_ADDR void* smem,
                                   index_t sorted_tile_id,
                                   index_t intermediate_tile_id)
    {
        // gate-only MoE uses the full intermediate size; gate+up splits it in half
        constexpr index_t hidden_ratio_0 = IsGateOnly ? 1 : 2;
        ck_tile::index_t shared_intermediate_size_0 = kargs.intermediate_size;
        ck_tile::index_t shared_intermediate_size_1 = kargs.intermediate_size / hidden_ratio_0;

        index_t nr_0 = shared_intermediate_size_0 / BlockShape::Warp_N0; // divide N in W
        index_t kr_0 = kargs.hidden_size / BlockShape::Warp_K0;          // divide K in W
        index_t nr_1 = kargs.hidden_size / BlockShape::Warp_N1;
        index_t kr_1 = shared_intermediate_size_1 / BlockShape::Warp_K1;

        const IndexDataType expert_id = __builtin_amdgcn_readfirstlane(
            reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]);

        index_t expert_stride_0 = shared_intermediate_size_0 * kargs.hidden_size;
        index_t expert_stride_1 = shared_intermediate_size_1 * kargs.hidden_size;

        // nr*kr*w
        index_t interm_idx_nr0 =
            __builtin_amdgcn_readfirstlane(intermediate_tile_id * BlockShape::Block_Nr0);
        // intermediate_tile_id * Block_N / (N in W)
        index_t interm_idx_kr1 =
            __builtin_amdgcn_readfirstlane(intermediate_tile_id * BlockShape::Block_Kr1);

        auto row_coords_a = GetRowCoords_A(sorted_tile_id * BlockShape::Block_M0);
        auto row_ids_a    = GetRowID(
            row_coords_a, reinterpret_cast<const IndexDataType*>(kargs.sorted_token_ids_ptr));

        auto a_coords = generate_tuple(
            [&](auto i) {
                return row_ids_a[i] * kargs.stride_token +
                       threadIdx.x % (BlockShape::Block_K0 / kAlignmentA) * kAlignmentA;
            },
            number<row_ids_a.size()>{});

        auto a_res = make_wave_buffer_resource(
            reinterpret_cast<const ADataType*>(kargs.a_ptr),
            kargs.num_tokens * kargs.stride_token * sizeof(ADataType));

        auto g_win = [&]() {
            const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
                                     static_cast<long_index_t>(expert_id) * expert_stride_0 +
                                     interm_idx_nr0 * kr_0 * BlockShape::Block_W0;

            auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
                g_ptr,
                make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}),
                make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
                number<kAlignmentG>{},
                number<1>{});

            auto g_window_ = make_tile_window_linear_raw(
                g_view_,
                make_tuple(number<BlockShape::Block_Nr0>{},
                           number<BlockShape::Block_Kr0>{},
                           number<BlockShape::Block_W0>{}),
                {0, 0, 0},
                Policy::template MakeGlobalTileDistribution_G<Problem>(),
                sequence<0, 1, 1>{});
            return g_window_;
        }();

        auto g_res = g_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;

        auto g_coords =
            generate_tuple([&](auto i) { return g_win.cached_coords_[i].get_offset(); },
                           number<decltype(g_win)::NumAccess_NonLinear>{});

        const auto d_win = [&]() {
            const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
                                     static_cast<long_index_t>(expert_id) * expert_stride_1 +
                                     interm_idx_kr1 * BlockShape::Block_W1;

            // note interm_idx_nr0 is along the gemm-k dim of 2nd gemm
            const auto d_view_ = make_naive_tensor_view<address_space_enum::global>(
                d_ptr,
                make_tuple(nr_1, kr_1, BlockShape::Block_W1),
                make_tuple(kr_1 * BlockShape::Block_W1, BlockShape::Block_W1, 1),
                number<kAlignmentD>{},
                number<1>{});

            const auto d_window_ = make_tile_window_linear_raw(
                d_view_,
                make_tuple(number<BlockShape::Block_Nr1>{},
                           number<BlockShape::Block_Kr1>{},
                           number<BlockShape::Block_W1>{}),
                {0, 0, 0},
                Policy::template MakeGlobalTileDistribution_D<Problem>(),
                sequence<0, 1, 1>{});
            return d_window_;
        }();

        auto d_res = d_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;

        // TODO: load D order is N0.K0...127, N64.K0...127, N0.K128...255, N64.K128...255
        //       block-k=512, block-n=128
        //  wg                              |<----- W_ ----->|
        //  Nr(2)*Nw(4)* Kr *Kr0(4)*Kr1(4) * [Kl(4)*Nl(16)*Kv(8)] -> one issue
        //   y    p       y    y      p            p    y
        //   1    2            0(imm)
        auto d_coords = [&]() {
            constexpr index_t Nr_  = 2;
            constexpr index_t Nw_  = 4;
            constexpr index_t Kr0_ = 4;
            constexpr index_t Kr1_ = 4;
            constexpr index_t Kl_  = 4;
            constexpr index_t Nl_  = 16;
            constexpr index_t Kv_  = 8;
            constexpr index_t W_   = Kl_ * Nl_ * Kv_;

            constexpr index_t num_offsets_ = Nr_ * Kr0_;

            index_t base_os_ = (threadIdx.x % 64) * Kv_ +
                               (threadIdx.x / 64) * shared_intermediate_size_1 * Nl_;
            // Kr0_ * Kr1_ * W_;
            return generate_tuple(
                [&](auto i) {
                    constexpr auto i_nr_  = number<i % Nr_>{};
                    constexpr auto i_kr0_ = number<i / Nr_>{};
                    return i_nr_ * shared_intermediate_size_1 * Nw_ * Nl_ +
                           i_kr0_ * Kr1_ * W_ + base_os_;
                },
                number<num_offsets_>{});
        }();

        auto o_coords = generate_tuple(
            [&](auto i) {
                return row_ids_a[i] * kargs.stride_token +
                       threadIdx.x % (BlockShape::Block_N1 / kAlignmentO) * kAlignmentO;
            },
            number<row_ids_a.size()>{});

        // execution flags: only rows whose token id is in range actually store
        auto o_flags =
            generate_tuple([&](auto i) { return cmp_lt_to_exec(row_ids_a[i], kargs.num_tokens); },
                           number<row_ids_a.size()>{});

        auto bridge_sst_win = [&]() {
            constexpr auto desc_ = Policy::template MakeBridgeLdsStoreForUKDesc<Problem>();
            constexpr auto dist_ = Policy::template GetUK_0<Problem>().MakeCBlockDist();
            return make_tile_window_linear(
                make_tensor_view<address_space_enum::lds>(
                    reinterpret_cast<YDataType*>(smem), desc_),
                desc_.get_lengths(),
                {0, 0},
                dist_);
        }();

        auto o_res = make_wave_buffer_resource(
            reinterpret_cast<const ODataType*>(kargs.o_ptr),
            kargs.num_tokens * kargs.stride_token * sizeof(ODataType));

        auto row_coords_o = GetRowCoords_O(sorted_tile_id * BlockShape::Block_M0);
        auto w_scale      = GetWeightScale(
            row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr));

        // gemm-0: A x G -> acc_0
        auto uk_0  = Policy::template GetUK_0<Problem>();
        auto acc_0 = uk_0(a_res,
                          a_coords,
                          g_res,
                          g_coords,
                          smem,
                          kargs.hidden_size,
                          BlockShape::Block_K0,                           // tile offset for B matrix each unroll
                          BlockShape::Block_Kr0 * BlockShape::Block_W0); // tile offset for B matrix each unroll

        // apply the gate activation pairwise on acc_0
        sweep_tile(
            acc_0,
            [&](auto idx0, auto idx1) {
                fp32x2_t v_{acc_0(idx0), acc_0(idx1)};
                typename Problem::GateActivation{}(v_, v_);
                acc_0(idx0) = v_.x;
                acc_0(idx1) = v_.y;
            },
            sequence<1, 2>{});

        auto y_pre = cast_tile<YDataType>(acc_0);

        // hand gemm-0 output to gemm-1 through the LDS bridge buffer
        block_sync_lds();
        store_tile(bridge_sst_win, y_pre);
        block_sync_lds();

        // gemm-1: Y x D -> O (scaled by per-row topk weight)
        auto uk_1 = Policy::template GetUK_1<Problem>();
        uk_1(d_res,
             d_coords,
             o_res,
             o_coords,
             o_flags,
             smem,
             kargs.hidden_size, // total n number
             w_scale,
             BlockShape::Block_Nr1 * kr_1 * BlockShape::Block_W1, // along N
             BlockShape::Block_N1);                               // along N
    }
};

} // namespace ck_tile
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp
0 → 100644
View file @
6e3c786e
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

/// Compile-time problem description for the fused-MoE gemm pipeline: bundles all
/// element types, the gate activation, the block shape and the traits.
// TODO: allow the 2 gemms to have different types
template <typename ADataType_,
          typename GDataType_,
          typename DDataType_,
          typename AccDataType_,
          typename ODataType_,
          typename AScaleDataType_,
          typename GScaleDataType_,
          typename DScaleDataType_,
          typename YSmoothScaleDataType_,
          typename TopkWeightDataType_,
          typename IndexDataType_, // data type for all indexing
          typename GateActivation_, // = ck_tile::element_wise::Silu,
          typename BlockShape_,     // should be FusedMoeGemmShape
          typename Traits_>
struct FusedMoeGemmPipelineProblem
{
    using ADataType            = remove_cvref_t<ADataType_>;
    using GDataType            = remove_cvref_t<GDataType_>;
    using DDataType            = remove_cvref_t<DDataType_>;
    using AccDataType          = remove_cvref_t<AccDataType_>;
    using ODataType            = remove_cvref_t<ODataType_>;
    using AScaleDataType       = remove_cvref_t<AScaleDataType_>;
    using GScaleDataType       = remove_cvref_t<GScaleDataType_>;
    using DScaleDataType       = remove_cvref_t<DScaleDataType_>;
    using YSmoothScaleDataType = remove_cvref_t<YSmoothScaleDataType_>;
    using TopkWeightDataType   = remove_cvref_t<TopkWeightDataType_>;
    using IndexDataType        = remove_cvref_t<IndexDataType_>;

    // the input of the next gemm should have the same type as A
    using YDataType = ADataType;

    using GateActivation = remove_cvref_t<GateActivation_>;
    using BlockShape     = remove_cvref_t<BlockShape_>;
    using Traits         = remove_cvref_t<Traits_>;
};

} // namespace ck_tile
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp
0 → 100644
View file @
6e3c786e
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

/// Layout permutation applied to the (pre-shuffled) MoE weight tensor.
enum class FusedMoeGemmWeightPermuteEnum
{
    // permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
    // permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
    no_permute          = 0,
    b_nr_kr_kw_nw_kv    = 1, // 0,1,3,4,2,5
    b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv,
};

/// Static configuration knobs for the fused-MoE gemm kernel.
template <bool IsGateOnly_,
          bool UseSmoothQuant_,
          index_t OAtomic_, // 0-no atomic, 1-atomic-pk-f16/bf16, 2-atomic-f32
          FusedMoeGemmWeightPermuteEnum PermuteEnum_ =
              FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten,
          bool PadHiddenSize_       = false,
          bool PadIntermediateSize_ = false>
struct FusedMoeGemmTraits
{
    // Gate+Up or Gate only
    static constexpr bool IsGateOnly                            = IsGateOnly_;
    static constexpr bool UseSmoothQuant                        = UseSmoothQuant_;
    static constexpr index_t OAtomic                            = OAtomic_;
    static constexpr FusedMoeGemmWeightPermuteEnum PermuteEnum  = PermuteEnum_;
    static constexpr bool PadHiddenSize                         = PadHiddenSize_;
    static constexpr bool PadIntermediateSize                   = PadIntermediateSize_;
};

// Note: this needs to be a bit mask
enum class FusedMoeGemmPipelineSequencerEnum
{
    SLD_A = 1 << 0, // shared load a
    SLD_B = 1 << 1,
    GLD_A = 1 << 2, // global load a
    GLD_B = 1 << 3,
    SST_A = 1 << 4, // shared store a
    SST_B = 1 << 5,
    GST_O = 1 << 6, // global store out
};

} // namespace ck_tile
include/ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp
0 → 100644
View file @
6e3c786e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"

#include <string>
#include <type_traits>

#ifndef TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
#define TOPK_SOFTMAX_USE_RAW_TILE_WINDOW 0
#endif

namespace ck_tile {

// Placeholder: the tile-level MoE sorting pipeline is not implemented yet;
// the sketch below documents the intended interface.
//
// template <typename Problem_, typename Policy_ = MoeSortingPolicy>
// struct MoeSortingPipeline
// {
//     // TODO: this kernel only support warp per row
//     using Problem    = remove_cvref_t<Problem_>;
//     using Policy     = remove_cvref_t<Policy_>;
//     using WeightType = typename Problem::WeightType;
//
//     template <typename TopkIdWindow, typename WeightWindow>
//     CK_TILE_DEVICE auto operator()(const TopkIdWindow& topk_id_window,
//                                    const WeightWindow& weight_window,
//                                    index_t* p_sorted_token_ids,
//                                    WeightType* p_sorted_weights,
//                                    index_t* p_sorted_expert_ids,
//                                    index_t* p_total_tokens_post_pad,
//                                    const index_t num_experts,
//                                    const index_t unit_size,
//                                    const size_t numel,
//                                    const index_t topk)
//     {
//     }
// };

} // namespace ck_tile
include/ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp
0 → 100644
View file @
6e3c786e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/softmax.hpp"
#include "ck_tile/ops/topk.hpp"

namespace ck_tile {

/// Policy hook for the MoE sorting pipeline (currently empty; reserved for
/// distribution/descriptor customization).
struct MoeSortingPolicy
{
};

} // namespace ck_tile
include/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp
0 → 100644
View file @
6e3c786e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

#include <string>
#include <type_traits>

namespace ck_tile {

/// Compile-time problem description for the MoE token-sorting kernel.
template <typename IndexType_, typename WeightType_, index_t InternalLoadUnroll_>
struct MoeSortingProblem
{
    // TODO: this kernel only supports one warp per row
    using WeightType = remove_cvref_t<WeightType_>;
    using IndexType  = remove_cvref_t<IndexType_>;

    static constexpr index_t WarpSize           = get_warp_size();
    static constexpr index_t WarpsPerBlock      = 1;
    // how many elements each lane loads per iteration
    static constexpr index_t InternalLoadUnroll = InternalLoadUnroll_;
};

} // namespace ck_tile
include/ck_tile/ops/gemm.hpp
View file @
6e3c786e
...
...
@@ -8,6 +8,7 @@
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
...
...
@@ -21,17 +22,27 @@
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp
View file @
6e3c786e
...
...
@@ -32,7 +32,7 @@ struct BlockGemmARegBGmemCRegV1
BlockGemmProblem
<
ADataType
,
BDataType
,
CDataType
,
kBlockSize
,
BlockGemmShape
>
,
BlockGemmARegBGmemCRegV1DefaultPolicy
>
;
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetStaticLdsSize
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetStaticLdsSize
()
{
return
sizeof
(
BDataType
)
*
Policy
::
template
MakeBSmemBlockDescriptor
<
Problem
>().
get_element_space_size
();
...
...
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp
View file @
6e3c786e
...
...
@@ -157,7 +157,7 @@ struct BlockGemmARegBRegCRegV1
});
}
CK_TILE_DEVICE
constexpr
auto
MakeCBlockTile
()
const
CK_TILE_DEVICE
static
constexpr
auto
MakeCBlockTile
()
{
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
...
...
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp
0 → 100644
View file @
6e3c786e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
namespace
ck_tile
{
// A is block distributed tensor
// B is block window on shared memory
// C is block distributed tensor
template
<
typename
Problem_
,
typename
Policy_
=
BlockGemmARegBSmemCRegV1DefaultPolicy
>
struct
BlockGemmARegBSmemCRegOneWarpV1
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
CDataType
=
remove_cvref_t
<
typename
Problem
::
CDataType
>
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
Problem
::
BlockGemmShape
>
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static_assert
(
kBlockSize
==
get_warp_size
(),
"Check failed!"
);
// C += A * B
template
<
typename
CBlockTensor
,
typename
ABlockTensorTmp
,
typename
BBlockWindowTmp
>
CK_TILE_DEVICE
void
operator
()(
CBlockTensor
&
c_block_tensor
,
const
ABlockTensorTmp
&
a_block_tensor_tmp
,
const
BBlockWindowTmp
&
b_block_window_tmp
)
const
{
static_assert
(
std
::
is_same_v
<
ADataType
,
remove_cv_t
<
typename
ABlockTensorTmp
::
DataType
>>
&&
std
::
is_same_v
<
BDataType
,
remove_cv_t
<
typename
BBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
CDataType
,
remove_cv_t
<
typename
CBlockTensor
::
DataType
>>
,
"wrong!"
);
// constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
// constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
// constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
// static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
// KPerBlock == BlockGemmShape::kK,
// "wrong!");
constexpr
auto
config
=
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
static_assert
(
MWarp
==
1
&&
NWarp
==
1
,
"Check failed!"
);
constexpr
index_t
MIterPerWarp
=
MPerBlock
/
(
MWarp
*
WG
::
kM
);
constexpr
index_t
NIterPerWarp
=
NPerBlock
/
(
NWarp
*
WG
::
kN
);
constexpr
index_t
KIterPerWarp
=
KPerBlock
/
WG
::
kK
;
constexpr
index_t
NPerBlockPerIter
=
NPerBlock
/
NIterPerWarp
;
constexpr
index_t
KPerBlockPerIter
=
KPerBlock
/
KIterPerWarp
;
const
index_t
iNWarp
=
0
;
constexpr
auto
a_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
NWarp
>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
c_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
MIterPerWarp
>
,
sequence
<
NIterPerWarp
>>
,
tuple
<>
,
tuple
<>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
a_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
a_block_outer_dstr_encoding
,
typename
WG
::
AWarpDstrEncoding
{});
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
c_block_outer_dstr_encoding
,
typename
WG
::
CWarpDstrEncoding
{});
constexpr
auto
a_block_dstr
=
make_static_tile_distribution
(
a_block_dstr_encode
);
// constrcut from A-block-tensor from A-Block-tensor-tmp
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
// distribution
auto
a_block_tensor
=
make_static_distributed_tensor
<
typename
ABlockTensorTmp
::
DataType
>
(
a_block_dstr
);
a_block_tensor
.
get_thread_buffer
()
=
a_block_tensor_tmp
.
get_thread_buffer
();
// construct B-warp-window
auto
b_warp_window_tmp
=
make_tile_window
(
b_block_window_tmp
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
WG
::
kN
>
{},
number
<
WG
::
kK
>
{}),
b_block_window_tmp
.
get_window_origin
()
+
multi_index
<
2
>
{
iNWarp
*
WG
::
kN
,
0
},
make_static_tile_distribution
(
typename
WG
::
BWarpDstrEncoding
{}));
#if 0 // FIXME: using array will cause register spill
array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
{b_warp_window_tmp}};
for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
{
for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
{
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
}
}
#else
statically_indexed_array
<
statically_indexed_array
<
decltype
(
b_warp_window_tmp
),
KIterPerWarp
>
,
NIterPerWarp
>
b_warp_windows
;
static_for
<
0
,
NIterPerWarp
,
1
>
{}([
&
](
auto
nIter
)
{
static_for
<
0
,
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
b_warp_windows
(
nIter
)(
kIter
)
=
b_warp_window_tmp
;
move_tile_window
(
b_warp_windows
(
nIter
)(
kIter
),
{
nIter
*
NPerBlockPerIter
,
kIter
*
KPerBlockPerIter
});
});
});
#endif
// check C-block-distribution
static_assert
(
std
::
is_same_v
<
remove_cvref_t
<
decltype
(
c_block_dstr_encode
)
>
,
remove_cvref_t
<
decltype
(
CBlockTensor
::
get_tile_distribution
()
.
get_static_tile_distribution_encoding
())
>>
,
"wrong!"
);
using
AWarpDstr
=
typename
WG
::
AWarpDstr
;
using
CWarpDstr
=
typename
WG
::
CWarpDstr
;
using
AWarpTensor
=
typename
WG
::
AWarpTensor
;
using
CWarpTensor
=
typename
WG
::
CWarpTensor
;
constexpr
auto
a_warp_y_lengths
=
to_sequence
(
AWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
c_warp_y_lengths
=
to_sequence
(
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
a_warp_y_index_zeros
=
uniform_sequence_gen_t
<
AWarpDstr
::
NDimY
,
0
>
{};
constexpr
auto
c_warp_y_index_zeros
=
uniform_sequence_gen_t
<
CWarpDstr
::
NDimY
,
0
>
{};
// hot loop:
static_for
<
0
,
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
static_for
<
0
,
MIterPerWarp
,
1
>
{}([
&
](
auto
mIter
)
{
// read A warp tensor from A block tensor
AWarpTensor
a_warp_tensor
;
a_warp_tensor
.
get_thread_buffer
()
=
a_block_tensor
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
kIter
>
{},
a_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
a_warp_y_lengths
));
static_for
<
0
,
NIterPerWarp
,
1
>
{}([
&
](
auto
nIter
)
{
// read B warp tensor from B Block window
const
auto
b_warp_tensor
=
load_tile
(
b_warp_windows
(
nIter
)(
kIter
));
// read C warp tensor from C block tensor
CWarpTensor
c_warp_tensor
;
c_warp_tensor
.
get_thread_buffer
()
=
c_block_tensor
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
nIter
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
));
// warp GEMM
WG
{}(
c_warp_tensor
,
a_warp_tensor
,
b_warp_tensor
);
// write C warp tensor into C block tensor
c_block_tensor
.
set_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
nIter
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
),
c_warp_tensor
.
get_thread_buffer
());
});
});
});
}
/// @brief Create the per-thread C accumulator block tile as a static distributed tensor.
///
/// The distribution is derived from the Policy's warp-gemm configuration
/// (tuple of {WarpGemm, MWarp, NWarp}). This block-gemm flavour only supports
/// a single warp per block in both M and N (enforced below).
///
/// @return a default-constructed static distributed tensor of CDataType with
///         the block-level C distribution (caller is expected to clear it
///         before accumulation — NOTE(review): zero-init not done here; confirm callers clear).
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
    constexpr index_t MPerBlock = BlockGemmShape::kM;
    constexpr index_t NPerBlock = BlockGemmShape::kN;

    // config = tuple{WarpGemm, MWarp, NWarp} supplied by the policy.
    constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();

    using WG = remove_cvref_t<decltype(config.template at<0>())>;

    constexpr index_t MWarp = config.template at<1>();
    constexpr index_t NWarp = config.template at<2>();

    // One-warp variant: exactly one warp in M and one in N.
    static_assert(MWarp == 1 && NWarp == 1, "Check failed!");

    // Number of warp-gemm repetitions needed to cover the block tile.
    constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
    constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
    // constexpr index_t KIterPerWarp = KPerBlock / WG::kK;

    // Outer (per-iteration) part of the C distribution; the intra-warp part is
    // embedded below from the warp-gemm's own C encoding.
    constexpr auto c_block_outer_dstr_encoding =
        tile_distribution_encoding<sequence<>,
                                   tuple<sequence<MIterPerWarp>, sequence<NIterPerWarp>>,
                                   tuple<>,
                                   tuple<>,
                                   sequence<1, 2>,
                                   sequence<0, 0>>{};

    constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
        c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});

    // With a single warp the resulting encoding must have exactly one P dimension.
    static_assert(decltype(c_block_dstr_encode)::NDimP == 1, "Check failed!");

    constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);

    auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);

    return c_block_tensor;
}
/// @brief C = A * B convenience overload.
///
/// Allocates a fresh accumulator tile via MakeCBlockTile(), runs the
/// accumulating operator() on it, and returns the result by value.
template <typename ABlockTensorTmp, typename BBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
                               const BBlockWindowTmp& b_block_window_tmp) const
{
    auto acc_tile = MakeCBlockTile();
    operator()(acc_tile, a_block_tensor_tmp, b_block_window_tmp);
    return acc_tile;
}
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp
View file @
6e3c786e
...
...
@@ -181,7 +181,7 @@ struct BlockGemmARegBSmemCRegV1
});
}
CK_TILE_DEVICE
constexpr
auto
MakeCBlockTile
()
const
CK_TILE_DEVICE
static
constexpr
auto
MakeCBlockTile
()
{
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
...
...
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp
View file @
6e3c786e
...
...
@@ -182,7 +182,7 @@ struct BlockGemmARegBSmemCRegV2
});
}
CK_TILE_DEVICE
constexpr
auto
MakeCBlockTile
()
const
CK_TILE_DEVICE
static
constexpr
auto
MakeCBlockTile
()
{
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
...
...
include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp
View file @
6e3c786e
...
...
@@ -180,7 +180,7 @@ struct BlockGemmASmemBRegCRegV1
});
}
CK_TILE_DEVICE
constexpr
auto
MakeCBlockTile
()
const
CK_TILE_DEVICE
static
constexpr
auto
MakeCBlockTile
()
{
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
...
...
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp
View file @
6e3c786e
...
...
@@ -24,19 +24,19 @@ struct BlockGemmASmemBSmemCRegV1
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
// C += A * B
template
<
typename
CBlockTensor
,
typename
ABlockWindow
Tmp
,
typename
BBlockWindow
Tmp
>
template
<
typename
CBlockTensor
,
typename
ABlockWindow
,
typename
BBlockWindow
>
CK_TILE_DEVICE
void
operator
()(
CBlockTensor
&
c_block_tensor
,
const
ABlockWindow
Tmp
&
a_block_window
_tmp
,
const
BBlockWindow
Tmp
&
b_block_window
_tmp
)
const
const
ABlockWindow
&
a_block_window
,
const
BBlockWindow
&
b_block_window
)
const
{
static_assert
(
std
::
is_same_v
<
ADataType
,
typename
ABlockWindow
Tmp
::
DataType
>
&&
std
::
is_same_v
<
BDataType
,
typename
BBlockWindow
Tmp
::
DataType
>
&&
static_assert
(
std
::
is_same_v
<
ADataType
,
typename
ABlockWindow
::
DataType
>
&&
std
::
is_same_v
<
BDataType
,
typename
BBlockWindow
::
DataType
>
&&
std
::
is_same_v
<
CDataType
,
typename
CBlockTensor
::
DataType
>
,
"wrong!"
);
constexpr
index_t
MPerBlock
=
ABlockWindow
Tmp
{}.
get_window_lengths
()[
number
<
0
>
{}];
constexpr
index_t
NPerBlock
=
BBlockWindow
Tmp
{}.
get_window_lengths
()[
number
<
0
>
{}];
constexpr
index_t
KPerBlock
=
ABlockWindow
Tmp
{}.
get_window_lengths
()[
number
<
1
>
{}];
constexpr
index_t
MPerBlock
=
ABlockWindow
{}.
get_window_lengths
()[
number
<
0
>
{}];
constexpr
index_t
NPerBlock
=
BBlockWindow
{}.
get_window_lengths
()[
number
<
0
>
{}];
constexpr
index_t
KPerBlock
=
ABlockWindow
{}.
get_window_lengths
()[
number
<
1
>
{}];
static_assert
(
MPerBlock
==
BlockGemmShape
::
kM
&&
NPerBlock
==
BlockGemmShape
::
kN
&&
KPerBlock
==
BlockGemmShape
::
kK
,
...
...
@@ -62,9 +62,9 @@ struct BlockGemmASmemBSmemCRegV1
// construct A-warp-window
auto
a_warp_window_tmp
=
make_tile_window
(
a_block_window
_tmp
.
get_bottom_tensor_view
(),
a_block_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
WG
::
kM
>
{},
number
<
WG
::
kK
>
{}),
a_block_window
_tmp
.
get_window_origin
()
+
multi_index
<
2
>
{
iMWarp
*
WG
::
kM
,
0
},
a_block_window
.
get_window_origin
()
+
multi_index
<
2
>
{
iMWarp
*
WG
::
kM
,
0
},
make_static_tile_distribution
(
typename
WG
::
AWarpDstrEncoding
{}));
#if 0 // FIXME: using array will cause register spill
...
...
@@ -97,9 +97,9 @@ struct BlockGemmASmemBSmemCRegV1
// construct B-warp-window
auto
b_warp_window_tmp
=
make_tile_window
(
b_block_window
_tmp
.
get_bottom_tensor_view
(),
b_block_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
WG
::
kN
>
{},
number
<
WG
::
kK
>
{}),
b_block_window
_tmp
.
get_window_origin
()
+
multi_index
<
2
>
{
iNWarp
*
WG
::
kN
,
0
},
b_block_window
.
get_window_origin
()
+
multi_index
<
2
>
{
iNWarp
*
WG
::
kN
,
0
},
make_static_tile_distribution
(
typename
WG
::
BWarpDstrEncoding
{}));
#if 0 // FIXME: using array will cause register spill
...
...
@@ -167,7 +167,7 @@ struct BlockGemmASmemBSmemCRegV1
});
}
CK_TILE_DEVICE
constexpr
auto
MakeCBlockTile
()
const
CK_TILE_DEVICE
static
constexpr
auto
MakeCBlockTile
()
{
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
...
...
@@ -200,12 +200,12 @@ struct BlockGemmASmemBSmemCRegV1
}
// C = A * B
template
<
typename
ABlockTensorTmp
,
typename
BBlockWindow
Tmp
>
template
<
typename
ABlockTensorTmp
,
typename
BBlockWindow
>
CK_TILE_DEVICE
auto
operator
()(
const
ABlockTensorTmp
&
a_block_tensor_tmp
,
const
BBlockWindow
Tmp
&
b_block_window
_tmp
)
const
const
BBlockWindow
&
b_block_window
)
const
{
auto
c_block_tensor
=
MakeCBlockTile
();
operator
()(
c_block_tensor
,
a_block_tensor_tmp
,
b_block_window
_tmp
);
operator
()(
c_block_tensor
,
a_block_tensor_tmp
,
b_block_window
);
return
c_block_tensor
;
}
};
...
...
include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp
0 → 100644
View file @
6e3c786e
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
namespace
ck_tile
{
// A is block window on shared memory
// B is block window on shared memory
// C is block distributed tensor
template
<
typename
Problem_
,
typename
Policy_
=
BlockGemmASmemBSmemCRegV1DefaultPolicy
>
struct
BlockUniversalGemmAsBsCr
{
private:
// TODO: This should be in Policy - UniversalGemmPolicyBase ?
/// @brief Compile-time trait bundle shared by all scheduler-specific
///        BlockGemmImpl specializations.
///
/// Extracts data types, block/warp tile geometry and iteration counts from the
/// pipeline Problem and the block-gemm Policy, and sanity-checks that the
/// warp-gemm configuration tiles the whole block.
template <typename PipelineProblem_, typename GemmPolicy_>
struct GemmTraits_
{
    using Problem        = remove_cvref_t<PipelineProblem_>;
    using Policy         = remove_cvref_t<GemmPolicy_>;
    using ADataType      = remove_cvref_t<typename Problem::ADataType>;
    using BDataType      = remove_cvref_t<typename Problem::BDataType>;
    using CDataType      = remove_cvref_t<typename Problem::CDataType>;
    using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

    static constexpr index_t kBlockSize = Problem::kBlockSize;
    static constexpr auto Scheduler     = Problem::Scheduler;

    // Block-tile extents.
    static constexpr index_t MPerBlock = BlockGemmShape::kM;
    static constexpr index_t NPerBlock = BlockGemmShape::kN;
    static constexpr index_t KPerBlock = BlockGemmShape::kK;

    // config = tuple{WarpGemm, MWarp, NWarp} supplied by the policy.
    static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();

    using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;

    static constexpr index_t MWarp = config.template at<1>();
    static constexpr index_t NWarp = config.template at<2>();

    using I0 = number<0>;
    using I1 = number<1>;

    // The policy-selected warp layout must agree with BlockGemmShape.
    static_assert(MWarp == BlockGemmShape::BlockWarps::at(I0{}),
                  "Error! WarpGemm's MWarp is not consisten with BlockGemmShape!");
    static_assert(NWarp == BlockGemmShape::BlockWarps::at(I1{}),
                  "Error! WarpGemm's NWarp is not consisten with BlockGemmShape!");
    static_assert(WarpGemm::kM == BlockGemmShape::WarpTile::at(I0{}),
                  "Error! WarpGemm's M is not consisten with BlockGemmShape!");
    static_assert(WarpGemm::kN == BlockGemmShape::WarpTile::at(I1{}),
                  "Error! WarpGemm's N is not consisten with BlockGemmShape!");

    // Warp-gemm repetitions per warp needed to cover the block tile.
    static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
    static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
    static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;

    // Guard against non-dividing tile sizes (integer division above would
    // silently truncate otherwise).
    static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock,
                  "Error! Warps should cover all Block tile!");
    static_assert(NIterPerWarp * NWarp * WarpGemm::kN == NPerBlock,
                  "Error! Warps should cover all Block tile!");

    // Window step sizes for one iteration in each direction.
    static constexpr index_t MPerBlockPerIter = MWarp * WarpGemm::kM;
    static constexpr index_t NPerBlockPerIter = NWarp * WarpGemm::kN;
    static constexpr index_t KPerBlockPerIter = WarpGemm::kK;

    // Per-warp A/B tile distributions and the distributed-tensor types that
    // hold one warp tile's worth of data.
    using AWarpTileDistr = remove_cvref_t<decltype(make_static_tile_distribution(
        typename WarpGemm::AWarpDstrEncoding{}))>;
    using BWarpTileDistr = remove_cvref_t<decltype(make_static_tile_distribution(
        typename WarpGemm::BWarpDstrEncoding{}))>;

    using AWarpTile = remove_cvref_t<decltype(make_static_distributed_tensor<ADataType>(
        AWarpTileDistr{}))>;
    using BWarpTile = remove_cvref_t<decltype(make_static_distributed_tensor<BDataType>(
        BWarpTileDistr{}))>;

    // TODO: Should we have two policies? Interwave & Intrawave ??
    static constexpr index_t InterWaveSchedulingMacClusters = 1;

    // K split used by the Interwave scheduler's inner-loop bookkeeping.
    static constexpr index_t KPack      = WarpGemm::kKPerThread;
    static constexpr index_t KPerThread = KPerBlock / WarpGemm::kK * KPack;
    static constexpr index_t KRepeat    = KPerThread / KPack;
};
public:
// Public re-exports of the trait bundle so pipelines can query the block
// gemm's geometry and data types directly.
using Traits = GemmTraits_<Problem_, Policy_>;

using ADataType = remove_cvref_t<typename Traits::ADataType>;
using BDataType = remove_cvref_t<typename Traits::BDataType>;
using CDataType = remove_cvref_t<typename Traits::CDataType>;

using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;

// Warp-gemm repetition counts along K, M and N.
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
static constexpr index_t MIterPerWarp = Traits::MIterPerWarp;
static constexpr index_t NIterPerWarp = Traits::NIterPerWarp;

// Warp layout within the block.
static constexpr index_t MWarp = Traits::MWarp;
static constexpr index_t NWarp = Traits::NWarp;

// Scheduler tag selecting the BlockGemmImpl specialization below.
static constexpr auto Scheduler = Traits::Scheduler;

using I0 = number<0>;
using I1 = number<1>;
private:
// Primary template; only the GemmPipelineScheduler specializations below are defined.
template <GemmPipelineScheduler Scheduler, typename GemmTraits>
struct BlockGemmImpl
{
};

/// @brief Default scheduler: loads A/B warp tiles directly from the LDS block
///        windows inside the hot loop (no persistent prefetch state, no
///        explicit scheduling barriers).
template <typename GemmTraits>
struct BlockGemmImpl<GemmPipelineScheduler::Default, GemmTraits>
{
    // C += A * B
    template <typename CBlockTensor, typename ASmemBlockWindow, typename BSmemBlockWindow>
    CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
                                   const ASmemBlockWindow& a_block_window,
                                   const BSmemBlockWindow& b_block_window)
    {
        static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
                      "The CDataType as defined in traits should be the same as correspoinding "
                      "C block tensor data type!");
        static_assert(std::is_same_v<ADataType, typename ASmemBlockWindow::DataType> &&
                          std::is_same_v<BDataType, typename BSmemBlockWindow::DataType>,
                      "The ADataType and BDataType as defined in "
                      "traits should be the same as correspoinding block window data type!");
        static_assert(GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[I0{}] &&
                          GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[I0{}] &&
                          GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[I1{}],
                      "MPerBlock, NPerBlock, KPerBlock defined in "
                      " BlockGemmShape are different from A/B block smem windows apropriate dims!");

        // Row-major warp id decomposition: warp = iMWarp * NWarp + iNWarp.
        const index_t iMWarp = get_warp_id() / NWarp;
        const index_t iNWarp = get_warp_id() - (iMWarp * NWarp);

        // TODO: refactor warp_window tile type to class member as it should be
        // compile-time known information.
        // Base A-warp-window at this warp's M offset; shifted copies cover all
        // (mIter, kIter) positions below.
        auto a_warp_window_tmp = make_tile_window(
            a_block_window.get_bottom_tensor_view(),
            make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
            a_block_window.get_window_origin() + multi_index<2>{iMWarp * WarpGemm::kM, 0},
            make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{}));

        using AWarpWindow = remove_cvref_t<decltype(a_warp_window_tmp)>;

        static_assert(GemmTraits::AWarpTile::get_num_of_dimension() ==
                          AWarpWindow::get_num_of_dimension(),
                      "AWarpWindow number of dimensions must be equal to "
                      "AWarpTile number of dimensions!");
        static_assert(GemmTraits::AWarpTile::get_lengths() == AWarpWindow{}.get_window_lengths(),
                      "AWarpWindow lengths must be equal to AWarpTile lengths!");

        statically_indexed_array<
            statically_indexed_array<AWarpWindow, GemmTraits::KIterPerWarp>,
            MIterPerWarp>
            a_warp_windows;

        // construct B-warp-window
        auto b_warp_window_tmp = make_tile_window(
            b_block_window.get_bottom_tensor_view(),
            make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
            b_block_window.get_window_origin() + multi_index<2>{iNWarp * WarpGemm::kN, 0},
            make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}));

        using BWarpWindow = remove_cvref_t<decltype(b_warp_window_tmp)>;

        static_assert(GemmTraits::BWarpTile::get_num_of_dimension() ==
                          BWarpWindow::get_num_of_dimension(),
                      "BWarpWindow number of dimensions must be equal to "
                      "BWarpTile number of dimensions!");
        static_assert(GemmTraits::BWarpTile::get_lengths() == BWarpWindow{}.get_window_lengths(),
                      "BWarpWindow lengths must be equal to BWarpTile lengths!");

        statically_indexed_array<
            statically_indexed_array<BWarpWindow, GemmTraits::KIterPerWarp>,
            NIterPerWarp>
            b_warp_windows;

        // Position every (mIter, kIter) A-window relative to the base window.
        static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
            static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
                a_warp_windows(mIter)(kIter) = a_warp_window_tmp;

                // TODO: I don't have to move 0,0 window!
                move_tile_window(a_warp_windows(mIter)(kIter),
                                 {mIter * GemmTraits::MPerBlockPerIter,
                                  kIter * GemmTraits::KPerBlockPerIter});
            });
        });

        // Position every (nIter, kIter) B-window relative to the base window.
        static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
            static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
                b_warp_windows(nIter)(kIter) = b_warp_window_tmp;

                move_tile_window(b_warp_windows(nIter)(kIter),
                                 {nIter * GemmTraits::NPerBlockPerIter,
                                  kIter * GemmTraits::KPerBlockPerIter});
            });
        });

        using CWarpDstr   = typename WarpGemm::CWarpDstr;
        using CWarpTensor = typename WarpGemm::CWarpTensor;

        constexpr auto c_warp_y_lengths =
            to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
        constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};

        // hot loop:
        // K outermost so each A warp tile is loaded once and reused for all N
        // iterations at that K slice.
        static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
            static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
                const auto a_warp_tile = load_tile(a_warp_windows(mIter)(kIter));

                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
                    const auto b_warp_tile = load_tile(b_warp_windows(nIter)(kIter));

                    // read C warp tensor from C block tensor-
                    CWarpTensor c_warp_tensor;

                    c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

                    // warp GEMM
                    WarpGemm{}(c_warp_tensor, a_warp_tile, b_warp_tile);

                    // write C warp tensor into C block tensor
                    c_block_tensor.set_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
                        c_warp_tensor.get_thread_buffer());
                });
            });
        });
    }
};
/// @brief Intrawave scheduler: splits the block gemm into an explicit
///        LocalPrefetch() phase (LDS -> a_warp_tiles_/b_warp_tiles_) and a
///        compute-only operator() that consumes the stashed warp tiles.
///        The caller is responsible for invoking LocalPrefetch before the
///        accumulate call (see the pipeline using this class).
template <typename GemmTraits>
struct BlockGemmImpl<GemmPipelineScheduler::Intrawave, GemmTraits>
{
    // Stashed A/B warp tiles for every (mIter|nIter, kIter) position,
    // filled by LocalPrefetch() and consumed by operator().
    statically_indexed_array<
        statically_indexed_array<typename GemmTraits::AWarpTile, KIterPerWarp>,
        MIterPerWarp>
        a_warp_tiles_;
    statically_indexed_array<
        statically_indexed_array<typename GemmTraits::BWarpTile, KIterPerWarp>,
        NIterPerWarp>
        b_warp_tiles_;

    /// @brief Load all A/B warp tiles for the current LDS contents into the
    ///        member tile arrays.
    template <typename ASmemBlockWindow, typename BSmemBlockWindow>
    CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
                                      const BSmemBlockWindow& b_block_window)
    {
        static_assert(GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[I0{}] &&
                          GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[I0{}] &&
                          GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[I1{}],
                      "MPerBlock, NPerBlock, KPerBlock defined in "
                      " BlockGemmShape are different from A/B block smem windows apropriate dims!");
        static_assert(std::is_same_v<ADataType, typename ASmemBlockWindow::DataType> &&
                          std::is_same_v<BDataType, typename BSmemBlockWindow::DataType>,
                      "The ADataType and BDataType as defined in "
                      "traits should be the same as correspoinding block window data type!");

        // Row-major warp id decomposition: warp = iMWarp * NWarp + iNWarp.
        const index_t iMWarp = get_warp_id() / NWarp;
        const index_t iNWarp = get_warp_id() - (iMWarp * NWarp);

        // TODO: refactor warp_window tile type to class member as it should be
        // compile-time known information.
        auto a_warp_window_tmp = make_tile_window(
            a_block_window.get_bottom_tensor_view(),
            make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
            a_block_window.get_window_origin() + multi_index<2>{iMWarp * WarpGemm::kM, 0},
            make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{}));

        using AWarpWindow = remove_cvref_t<decltype(a_warp_window_tmp)>;

        static_assert(GemmTraits::AWarpTile::get_num_of_dimension() ==
                          AWarpWindow::get_num_of_dimension(),
                      "AWarpWindow number of dimensions must be equal to "
                      "AWarpTile number of dimensions!");
        static_assert(GemmTraits::AWarpTile::get_lengths() == AWarpWindow{}.get_window_lengths(),
                      "AWarpWindow lengths must be equal to AWarpTile lengths!");

        statically_indexed_array<statically_indexed_array<AWarpWindow, KIterPerWarp>,
                                 MIterPerWarp>
            a_warp_windows;

        // construct B-warp-window
        auto b_warp_window_tmp = make_tile_window(
            b_block_window.get_bottom_tensor_view(),
            make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
            b_block_window.get_window_origin() + multi_index<2>{iNWarp * WarpGemm::kN, 0},
            make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}));

        using BWarpWindow = remove_cvref_t<decltype(b_warp_window_tmp)>;

        static_assert(GemmTraits::BWarpTile::get_num_of_dimension() ==
                          BWarpWindow::get_num_of_dimension(),
                      "BWarpWindow number of dimensions must be equal to "
                      "BWarpTile number of dimensions!");
        static_assert(GemmTraits::BWarpTile::get_lengths() == BWarpWindow{}.get_window_lengths(),
                      "BWarpWindow lengths must be equal to BWarpTile lengths!");

        statically_indexed_array<statically_indexed_array<BWarpWindow, KIterPerWarp>,
                                 NIterPerWarp>
            b_warp_windows;

        // Position every (mIter, kIter) A-window relative to the base window.
        static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
                a_warp_windows(mIter)(kIter) = a_warp_window_tmp;

                // TODO: I don't have to move 0,0 window!
                move_tile_window(a_warp_windows(mIter)(kIter),
                                 {mIter * GemmTraits::MPerBlockPerIter,
                                  kIter * GemmTraits::KPerBlockPerIter});
            });
        });

        // Position every (nIter, kIter) B-window relative to the base window.
        static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
                b_warp_windows(nIter)(kIter) = b_warp_window_tmp;

                move_tile_window(b_warp_windows(nIter)(kIter),
                                 {nIter * GemmTraits::NPerBlockPerIter,
                                  kIter * GemmTraits::KPerBlockPerIter});
            });
        });

        // Pull every warp tile out of LDS into the member arrays.
        static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
            static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
                // read A warp tensor from A block window
                load_tile(a_warp_tiles_(mIter)(kIter), a_warp_windows(mIter)(kIter));
            });
            static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
                // read B warp tensor from B Block window
                load_tile(b_warp_tiles_(nIter)(kIter), b_warp_windows(nIter)(kIter));
            });
        });
    }

    // C += A * B
    // Pure compute phase: consumes the warp tiles stashed by LocalPrefetch();
    // the window parameters are accepted only for interface symmetry.
    template <typename CBlockTensor, typename ASmemBlockWindow, typename BSmemBlockWindow>
    CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
                                   [[maybe_unused]] const ASmemBlockWindow& a_block_window,
                                   [[maybe_unused]] const BSmemBlockWindow& b_block_window)
    {
        static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
                      "The CDataType as defined in traits should be the same as correspoinding "
                      "C block tensor data type!");

        using CWarpDstr   = typename WarpGemm::CWarpDstr;
        using CWarpTensor = typename WarpGemm::CWarpTensor;

        constexpr auto c_warp_y_lengths =
            to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
        constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};

        // hot loop:
        static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
            static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
                    // read C warp tensor from C block tensor-
                    CWarpTensor c_warp_tensor;

                    c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

                    // warp GEMM
                    WarpGemm{}(c_warp_tensor,
                               a_warp_tiles_[mIter][kIter],
                               b_warp_tiles_[nIter][kIter]);

                    // write C warp tensor into C block tensor
                    c_block_tensor.set_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
                        c_warp_tensor.get_thread_buffer());
                });
            });
        });
    }
};
/// @brief Interwave scheduler: splits K into KRepeat "MAC clusters" of
///        KPerInnerLoop each, prefetching one cluster at a time and using
///        AMDGCN scheduling barriers / priority hints to keep waves in a
///        workgroup executing their MAC phases in lock-step.
template <typename GemmTraits>
struct BlockGemmImpl<GemmPipelineScheduler::Interwave, GemmTraits>
{
    static constexpr index_t KPerThread     = GemmTraits::KPerThread;
    static constexpr index_t NumMacClusters = GemmTraits::InterWaveSchedulingMacClusters;
    // One cluster's K extent per thread; never smaller than a single KPack.
    static constexpr index_t KPerInnerLoop =
        ck_tile::max(KPerThread / NumMacClusters, GemmTraits::KPack);
    // TODO: do we really need this?? Are there any cases when this would be >=1 ??
    // Would we need InterWaveSchedulingMacClusters > 1 ???
    static constexpr index_t KRepeat        = KPerThread / KPerInnerLoop;
    static constexpr index_t KInnerLoopIter = KPerInnerLoop / GemmTraits::KPack;

    // Stashed A/B warp tiles for the *current* MAC cluster only
    // (inner dimension is KInnerLoopIter, not the full KIterPerWarp).
    statically_indexed_array<
        statically_indexed_array<typename GemmTraits::AWarpTile, KInnerLoopIter>,
        MIterPerWarp>
        a_warp_tiles_;
    statically_indexed_array<
        statically_indexed_array<typename GemmTraits::BWarpTile, KInnerLoopIter>,
        NIterPerWarp>
        b_warp_tiles_;

    /// @brief Load the A/B warp tiles of MAC cluster KIdx (K offset
    ///        KIdx * KPerInnerLoop) from LDS into the member tile arrays.
    template <index_t KIdx, typename ASmemBlockWindow, typename BSmemBlockWindow>
    CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
                                      const BSmemBlockWindow& b_block_window)
    {
        static_assert(GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[I0{}] &&
                          GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[I0{}] &&
                          GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[I1{}],
                      "MPerBlock, NPerBlock, KPerBlock defined in "
                      " BlockGemmShape are different from A/B block smem windows apropriate dims!");
        static_assert(std::is_same_v<ADataType, typename ASmemBlockWindow::DataType> &&
                          std::is_same_v<BDataType, typename BSmemBlockWindow::DataType>,
                      "The ADataType and BDataType as defined in "
                      "traits should be the same as correspoinding block window data type!");

        // Row-major warp id decomposition: warp = iMWarp * NWarp + iNWarp.
        const index_t iMWarp = get_warp_id() / NWarp;
        const index_t iNWarp = get_warp_id() - (iMWarp * NWarp);

        // TODO: refactor warp_window tile type to class member as it should be
        // compile-time known information.
        // Base A-warp-window at this warp's M offset and the cluster's K offset.
        auto a_warp_window_tmp = make_tile_window(
            a_block_window.get_bottom_tensor_view(),
            make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
            a_block_window.get_window_origin() +
                multi_index<2>{iMWarp * WarpGemm::kM, KIdx * KPerInnerLoop},
            make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{}));

        using AWarpWindow = remove_cvref_t<decltype(a_warp_window_tmp)>;

        static_assert(GemmTraits::AWarpTile::get_num_of_dimension() ==
                          AWarpWindow::get_num_of_dimension(),
                      "AWarpWindow number of dimensions must be equal to "
                      "AWarpTile number of dimensions!");
        static_assert(GemmTraits::AWarpTile::get_lengths() == AWarpWindow{}.get_window_lengths(),
                      "AWarpWindow lengths must be equal to AWarpTile lengths!");

        statically_indexed_array<statically_indexed_array<AWarpWindow, KInnerLoopIter>,
                                 MIterPerWarp>
            a_warp_windows;

        // construct B-warp-window
        auto b_warp_window_tmp = make_tile_window(
            b_block_window.get_bottom_tensor_view(),
            make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
            b_block_window.get_window_origin() +
                multi_index<2>{iNWarp * WarpGemm::kN, KIdx * KPerInnerLoop},
            make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}));

        using BWarpWindow = remove_cvref_t<decltype(b_warp_window_tmp)>;

        static_assert(GemmTraits::BWarpTile::get_num_of_dimension() ==
                          BWarpWindow::get_num_of_dimension(),
                      "BWarpWindow number of dimensions must be equal to "
                      "BWarpTile number of dimensions!");
        static_assert(GemmTraits::BWarpTile::get_lengths() == BWarpWindow{}.get_window_lengths(),
                      "BWarpWindow lengths must be equal to BWarpTile lengths!");

        statically_indexed_array<statically_indexed_array<BWarpWindow, KInnerLoopIter>,
                                 NIterPerWarp>
            b_warp_windows;

        // Position every (mIter, kIter) A-window relative to the base window.
        static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
            static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) {
                a_warp_windows(mIter)(kIter) = a_warp_window_tmp;

                move_tile_window(a_warp_windows(mIter)(kIter),
                                 {mIter * GemmTraits::MPerBlockPerIter,
                                  kIter * GemmTraits::KPerBlockPerIter});
            });
        });

        // Position every (nIter, kIter) B-window relative to the base window.
        static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
            static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) {
                b_warp_windows(nIter)(kIter) = b_warp_window_tmp;

                move_tile_window(b_warp_windows(nIter)(kIter),
                                 {nIter * GemmTraits::NPerBlockPerIter,
                                  kIter * GemmTraits::KPerBlockPerIter});
            });
        });

        // TODO check if a_warp_tiles has same desc as a_warp_window
        static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) {
            static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
                // read A warp tensor from A block window
                load_tile(a_warp_tiles_(mIter)(kIter), a_warp_windows(mIter)(kIter));
            });
            static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
                // read B warp tensor from B Block window
                load_tile(b_warp_tiles_(nIter)(kIter), b_warp_windows(nIter)(kIter));
            });
        });
    }

    // C += A * B
    // Iterates KRepeat MAC clusters; each iteration prefetches its cluster,
    // then runs the warp gemms with explicit wave-priority and barrier hints.
    template <typename CBlockTensor, typename ASmemBlockWindow, typename BSmemBlockWindow>
    CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
                                   const ASmemBlockWindow& a_block_window,
                                   const BSmemBlockWindow& b_block_window)
    {
        static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
                      "The CDataType as defined in traits should be the same as correspoinding "
                      "C block tensor data type!");

        using CWarpDstr   = typename WarpGemm::CWarpDstr;
        using CWarpTensor = typename WarpGemm::CWarpTensor;

        constexpr auto c_warp_y_lengths =
            to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
        constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};

        // hot loop:
        static_for<0, KRepeat, 1>{}([&](auto kIter) {
            LocalPrefetch<kIter.value>(a_block_window, b_block_window);
            __builtin_amdgcn_sched_barrier(0);
            // NOTE: Synchronize threads in a workgroup at the start of each MAC
            // cluster, but except the first, as we can shorten non-MAC cluster a bit
            // and there's no observable negative impact. The desired effect is waves in
            // a workgroup executing MAC in sync. This avoids some out-of-sync waves
            // hijacking MAC resource from other workgroups and reducing the chance of
            // latency hiding by waiting for the rest of the workgroup at the eventual
            // sync point.
            if constexpr(kIter.value != 0 || KRepeat == 1)
            {
                __builtin_amdgcn_s_barrier();
                __builtin_amdgcn_sched_barrier(0);
            }
            static_for<0, KInnerLoopIter, 1>{}([&](auto kInnerIter) {
                static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
                    static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
                        // read C warp tensor from C block tensor-
                        CWarpTensor c_warp_tensor;

                        c_warp_tensor.get_thread_buffer() =
                            c_block_tensor.get_y_sliced_thread_data(
                                merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                                merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

                        // The block_sync_lds() here performs double duty:
                        // A) safeguard against data hazard because barrier from
                        // blockwise_gemm is moved here B) reduce VMEM FIFO congestion
                        // by applying small delays to different wavefronts It is
                        // performed near the end of MAC cluster to minimize lgkmcnt
                        // penalty
                        if constexpr(kIter.value == KRepeat - 1 &&
                                     kInnerIter.value == KInnerLoopIter - 1 &&
                                     mIter.value == MIterPerWarp - 1 &&
                                     nIter.value == NIterPerWarp - 1)
                        {
                            __builtin_amdgcn_sched_barrier(0);
                            block_sync_lds();
                            __builtin_amdgcn_sched_barrier(0);
                        }
                        // warp GEMM
                        WarpGemm{}(c_warp_tensor,
                                   a_warp_tiles_[mIter][kInnerIter],
                                   b_warp_tiles_[nIter][kInnerIter]);

                        // write C warp tensor into C block tensor
                        c_block_tensor.set_y_sliced_thread_data(
                            merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                            merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
                            c_warp_tensor.get_thread_buffer());

                        // Raise wave priority for the duration of the MAC cluster.
                        if constexpr(kInnerIter.value == 0 && mIter.value == 0 &&
                                     nIter.value == 0)
                        {
                            __builtin_amdgcn_sched_barrier(0);
                            __builtin_amdgcn_s_setprio(1);
                            __builtin_amdgcn_sched_barrier(0);
                        }
                    });
                });
            });
            // Restore normal wave priority after the cluster completes.
            __builtin_amdgcn_sched_barrier(0);
            __builtin_amdgcn_s_setprio(0);
            __builtin_amdgcn_sched_barrier(0);
        });
    }
};
public:
/// @brief Create the per-thread C accumulator tile as a static distributed tensor.
///
/// The outer encoding spreads (MIterPerWarp, MWarp) x (NIterPerWarp, NWarp)
/// across the block; the intra-warp layout is embedded from the warp-gemm's
/// own C encoding. NOTE(review): tile is not zeroed here — callers must clear
/// it before accumulating.
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
    constexpr auto c_block_outer_dstr_encoding =
        tile_distribution_encoding<sequence<>,
                                   tuple<sequence<MIterPerWarp, MWarp>,
                                         sequence<NIterPerWarp, NWarp>>,
                                   tuple<sequence<1, 2>>,
                                   tuple<sequence<1, 1>>,
                                   sequence<1, 2>,
                                   sequence<0, 0>>{};

    constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
        c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
    constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
    auto c_block_tensor         = make_static_distributed_tensor<CDataType>(c_block_dstr);
    return c_block_tensor;
}
template
<
typename
ASmemBlockWindow
,
typename
BSmemBlockWindow
>
CK_TILE_DEVICE
void
LocalPrefetch
(
const
ASmemBlockWindow
&
a_block_window
,
const
BSmemBlockWindow
&
b_block_window
)
{
block_gemm_impl_
.
LocalPrefetch
(
a_block_window
,
b_block_window
);
}
// C += A * B
template
<
typename
CBlockTensor
,
typename
ASmemBlockWindow
,
typename
BSmemBlockWindow
>
CK_TILE_DEVICE
void
operator
()(
CBlockTensor
&
c_block_tensor
,
const
ASmemBlockWindow
&
a_block_window
,
const
BSmemBlockWindow
&
b_block_window
)
{
block_gemm_impl_
(
c_block_tensor
,
a_block_window
,
b_block_window
);
}
// C = A * B
template
<
typename
ASmemBlockWindow
,
typename
BSmemBlockWindow
>
CK_TILE_DEVICE
auto
operator
()(
const
ASmemBlockWindow
&
a_block_window
,
const
BSmemBlockWindow
&
b_block_window
)
{
auto
c_block_tensor
=
MakeCBlockTile
();
block_gemm_impl_
(
c_block_tensor
,
a_block_window
,
b_block_window
);
return
c_block_tensor
;
}
private:
BlockGemmImpl
<
Scheduler
,
Traits
>
block_gemm_impl_
{};
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
0 → 100644
View file @
6e3c786e
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
namespace
ck_tile
{
// Host-side argument pack for a strided-batched GEMM launch.
// Each batch b reads A at a_ptr + b * batch_stride_A (and likewise for B/C),
// as done in BatchedGemmKernel::operator().
struct BatchedGemmHostArgs
{
    const void* a_ptr; // base pointer of A (batch 0)
    const void* b_ptr; // base pointer of B (batch 0)
    void* c_ptr;       // base pointer of C (batch 0)
    index_t M;         // GEMM M dimension
    index_t N;         // GEMM N dimension
    index_t K;         // GEMM K dimension
    index_t stride_A;  // leading-dimension stride of A
    index_t stride_B;  // leading-dimension stride of B
    index_t stride_C;  // leading-dimension stride of C
    index_t batch_stride_A; // element offset between consecutive A batches
    index_t batch_stride_B; // element offset between consecutive B batches
    index_t batch_stride_C; // element offset between consecutive C batches
    index_t batch_count;    // number of batches (maps to gridDim.z)
};
/// @brief Strided-batched GEMM kernel: one workgroup computes one (M, N) tile
///        of one batch; blockIdx.z selects the batch.
/// @tparam TilePartitioner_  maps blockIdx to tile coordinates and grid size
/// @tparam GemmPipeline_     main-loop pipeline (supplies layouts, data types, LDS size)
/// @tparam EpiloguePipeline_ writes the accumulator tile to C (supplies ODataType)
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct BatchedGemmKernel
{
    using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
    using GemmPipeline     = remove_cvref_t<GemmPipeline_>;
    using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;

    // Operand layouts are dictated by the GEMM pipeline.
    using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
    using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
    using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;

    static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;

    using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
    using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
    // Output element type comes from the epilogue, not the pipeline accumulator.
    using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;

    // Device-side kernel arguments; mirrors BatchedGemmHostArgs field-for-field.
    struct BatchedGemmKargs
    {
        const void* a_ptr;
        const void* b_ptr;
        void* c_ptr;
        index_t M;
        index_t N;
        index_t K;
        index_t stride_A;
        index_t stride_B;
        index_t stride_C;
        index_t batch_stride_A;
        index_t batch_stride_B;
        index_t batch_stride_C;
        index_t batch_count;
    };

    using Kargs = BatchedGemmKargs;
    using Hargs = BatchedGemmHostArgs;

    // Grid dimensions as produced by the tile partitioner (batch_count on Z).
    __host__ static constexpr auto GridSize(const Hargs& h)
    {
        return TilePartitioner::GridSize(h.M, h.N, h.batch_count);
    }

    __host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); }

    // 1:1 copy of host args into the device argument struct.
    CK_TILE_HOST static constexpr BatchedGemmKargs MakeKargs(const Hargs& h)
    {
        Kargs k;
        k.a_ptr          = h.a_ptr;
        k.b_ptr          = h.b_ptr;
        k.c_ptr          = h.c_ptr;
        k.M              = h.M;
        k.N              = h.N;
        k.K              = h.K;
        k.stride_A       = h.stride_A;
        k.stride_B       = h.stride_B;
        k.stride_C       = h.stride_C;
        k.batch_stride_A = h.batch_stride_A;
        k.batch_stride_B = h.batch_stride_B;
        k.batch_stride_C = h.batch_stride_C;
        k.batch_count    = h.batch_count;
        return k;
    }

    // LDS requirement: main loop and epilogue reuse the same scratch buffer,
    // so the kernel only needs the larger of the two.
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }

    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        // Tile coordinates of this workgroup; blockIdx.z selects the batch.
        const auto [i_m, i_n] = TilePartitioner{}();
        const auto i_batch = __builtin_amdgcn_readfirstlane(blockIdx.z);
        // options
        // readfirstlane keeps batch offsets in scalar registers (uniform per wave).
        const auto batch_stride_A = __builtin_amdgcn_readfirstlane(kargs.batch_stride_A);
        const auto batch_offset_A = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_A);
        const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
        const auto batch_stride_B = __builtin_amdgcn_readfirstlane(kargs.batch_stride_B);
        const auto batch_offset_B = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_B);
        const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);

        // Convert pointers to tensor views.
        // A is viewed as (M, K); vectorized access only along the contiguous dim.
        auto a_tensor_view = [&]() {
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    a_start + batch_offset_A,
                    make_tuple(kargs.M, kargs.K),
                    make_tuple(kargs.stride_A, 1),
                    number<GemmPipeline::VectorSizeA>{},
                    number<1>{});
            }
            else
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    a_start + batch_offset_A,
                    make_tuple(kargs.M, kargs.K),
                    make_tuple(1, kargs.stride_A),
                    number<1>{},
                    number<1>{});
            }
        }();
        // B is viewed as (N, K); RowMajor here means K is the outer stride dim.
        auto b_tensor_view = [&]() {
            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    b_start + batch_offset_B,
                    make_tuple(kargs.N, kargs.K),
                    make_tuple(1, kargs.stride_B),
                    number<1>{},
                    number<1>{});
            }
            else
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    b_start + batch_offset_B,
                    make_tuple(kargs.N, kargs.K),
                    make_tuple(kargs.stride_B, 1),
                    number<GemmPipeline::VectorSizeB>{},
                    number<1>{});
            }
        }();

        // Pad only the non-contiguous dim of A (the contiguous one keeps vector loads).
        auto a_pad_view = [&]() {
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            {
                return pad_tensor_view(
                    a_tensor_view,
                    make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
                    sequence<false, GemmPipeline::kPadK>{});
            }
            else
            {
                return pad_tensor_view(
                    a_tensor_view,
                    make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
                    sequence<GemmPipeline::kPadM, false>{});
            }
        }();
        // clang-format on
        auto a_block_window = make_tile_window(
            a_pad_view,
            make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
            {i_m, 0});

        auto b_pad_view = [&]() {
            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
            {
                return pad_tensor_view(
                    b_tensor_view,
                    make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
                    sequence<false, GemmPipeline::kPadK>{});
            }
            else
            {
                return pad_tensor_view(
                    b_tensor_view,
                    make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
                    sequence<GemmPipeline::kPadN, false>{});
            }
        }();
        // clang-format on
        auto b_block_window = make_tile_window(
            b_pad_view,
            make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
            {i_n, 0});

        // allocate LDS
        __shared__ char smem_ptr[GetSmemSize()];

        const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);

        // Run GEMM cooperatively by whole workgroup.
        auto c_block_tile = GemmPipeline{}.template operator()(
            a_block_window, b_block_window, num_loop, smem_ptr);

        const auto batch_stride_C = __builtin_amdgcn_readfirstlane(kargs.batch_stride_C);
        const auto batch_offset_C = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_C);
        CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr);
        auto c_tensor_view = [&]() {
            if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    c_start + batch_offset_C,
                    make_tuple(kargs.M, kargs.N),
                    make_tuple(kargs.stride_C, 1),
                    number<GemmPipeline::VectorSizeC>{},
                    number<1>{});
            }
            else
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    c_start + batch_offset_C,
                    make_tuple(kargs.M, kargs.N),
                    make_tuple(1, kargs.stride_C),
                    number<1>{},
                    number<1>{});
            }
        }();
        auto c_pad_view = [&]() {
            if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
            {
                return pad_tensor_view(
                    c_tensor_view,
                    make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
                    sequence<false, GemmPipeline::kPadN>{});
            }
            else
            {
                return pad_tensor_view(
                    c_tensor_view,
                    make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
                    sequence<GemmPipeline::kPadM, false>{});
            }
        }();
        auto c_block_window = make_tile_window(
            c_pad_view,
            make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
            {i_m, i_n});

        // Epilogue converts/stores the accumulator tile into global C.
        EpiloguePipeline{}(c_block_window, c_block_tile);
    }
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
6e3c786e
...
...
@@ -3,38 +3,34 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include <iostream>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
namespace
ck_tile
{
template
<
typename
TilePartitioner_
,
typename
GemmPipeline_
,
typename
EpiloguePipeline_
,
typename
LayoutA_
,
typename
LayoutB_
,
typename
LayoutC_
>
template
<
typename
TilePartitioner_
,
typename
GemmPipeline_
,
typename
EpiloguePipeline_
>
struct
GemmKernel
{
using
TilePartitioner
=
remove_cvref_t
<
TilePartitioner_
>
;
using
GemmPipeline
=
remove_cvref_t
<
GemmPipeline_
>
;
using
EpiloguePipeline
=
remove_cvref_t
<
EpiloguePipeline_
>
;
using
Layout
A
=
remove_cvref_t
<
Layout
A_
>
;
using
Layout
B
=
remove_cvref_t
<
Layout
B_
>
;
using
Layout
C
=
remove_cvref_t
<
Layout
C_
>
;
static
constexpr
index_t
KernelBlockSize
=
GemmPipeline
::
k
BlockSize
;
using
A
Layout
=
remove_cvref_t
<
typename
GemmPipeline
::
A
Layout
>
;
using
B
Layout
=
remove_cvref_t
<
typename
GemmPipeline
::
B
Layout
>
;
using
C
Layout
=
remove_cvref_t
<
typename
GemmPipeline
::
C
Layout
>
;
static
constexpr
index_t
KernelBlockSize
=
GemmPipeline
::
BlockSize
;
using
ADataType
=
remove_cvref_t
<
typename
GemmPipeline
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
GemmPipeline
::
BDataType
>
;
using
CAccDataType
=
remove_cvref_t
<
typename
GemmPipeline
::
CDataType
>
;
using
C
O
DataType
=
remove_cvref_t
<
typename
EpiloguePipeline
::
ODataType
>
;
using
ADataType
=
remove_cvref_t
<
typename
GemmPipeline
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
GemmPipeline
::
BDataType
>
;
//
using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
using
CDataType
=
remove_cvref_t
<
typename
EpiloguePipeline
::
ODataType
>
;
__host__
static
constexpr
auto
GridSize
(
index_t
M
_size
,
index_t
N
_size
,
index_t
Batch
_size
)
__host__
static
constexpr
auto
GridSize
(
index_t
M
,
index_t
N
,
index_t
K
Batch
)
{
return
TilePartitioner
::
GridSize
(
M
_size
,
N_size
,
Batch
_size
);
return
TilePartitioner
::
GridSize
(
M
,
N
,
K
Batch
);
}
__host__
static
constexpr
auto
BlockSize
()
{
return
dim3
(
KernelBlockSize
);
}
...
...
@@ -44,34 +40,103 @@ struct GemmKernel
const
void
*
a_ptr
;
const
void
*
b_ptr
;
void
*
c_ptr
;
float
epsilon
;
ck_tile
::
index_t
M
;
ck_tile
::
index_t
N
;
ck_tile
::
index_t
K
;
ck_tile
::
index_t
stride_A
;
ck_tile
::
index_t
stride_B
;
ck_tile
::
index_t
stride_C
;
index_t
M
;
index_t
N
;
index_t
K
;
index_t
stride_A
;
index_t
stride_B
;
index_t
stride_C
;
};
CK_TILE_HOST
static
constexpr
GemmCommonKargs
MakeKargs
(
const
void
*
a_ptr
,
const
void
*
b_ptr
,
void
*
c_ptr
,
float
epsilon
,
ck_tile
::
index_t
M
,
ck_tile
::
index_t
N
,
ck_tile
::
index_t
K
,
ck_tile
::
index_t
stride_A
,
ck_tile
::
index_t
stride_B
,
ck_tile
::
index_t
stride_C
)
index_t
M
,
index_t
N
,
index_t
K
,
index_t
stride_A
,
index_t
stride_B
,
index_t
stride_C
)
{
return
GemmCommonKargs
{
a_ptr
,
b_ptr
,
c_ptr
,
epsilon
,
M
,
N
,
K
,
stride_A
,
stride_B
,
stride_C
};
return
GemmCommonKargs
{
a_ptr
,
b_ptr
,
c_ptr
,
M
,
N
,
K
,
stride_A
,
stride_B
,
stride_C
};
}
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
ck_tile
::
max
(
GemmPipeline
::
GetSmemSize
(),
EpiloguePipeline
::
GetSmemSize
());
return
max
(
GemmPipeline
::
GetSmemSize
(),
EpiloguePipeline
::
GetSmemSize
());
}
CK_TILE_HOST
static
bool
IsSupportedArgument
(
const
GemmCommonKargs
&
kargs
)
{
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
if
(
kargs
.
K
%
TilePartitioner
::
kK
!=
0
&&
GemmPipeline
::
kPadK
==
false
)
{
return
false
;
}
if
(
kargs
.
K
%
GemmPipeline
::
VectorSizeA
!=
0
)
{
return
false
;
}
}
else
{
if
(
kargs
.
M
%
TilePartitioner
::
kM
!=
0
&&
GemmPipeline
::
kPadM
==
false
)
{
return
false
;
}
if
(
kargs
.
M
%
GemmPipeline
::
VectorSizeA
!=
0
)
{
return
false
;
}
}
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
if
(
kargs
.
N
%
TilePartitioner
::
kN
!=
0
&&
GemmPipeline
::
kPadN
==
false
)
{
return
false
;
}
if
(
kargs
.
N
%
GemmPipeline
::
VectorSizeB
!=
0
)
{
return
false
;
}
}
else
{
if
(
kargs
.
K
%
TilePartitioner
::
kK
!=
0
&&
GemmPipeline
::
kPadK
==
false
)
{
return
false
;
}
if
(
kargs
.
K
%
GemmPipeline
::
VectorSizeB
!=
0
)
{
return
false
;
}
}
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
if
(
kargs
.
N
%
TilePartitioner
::
kN
!=
0
&&
GemmPipeline
::
kPadN
==
false
)
{
return
false
;
}
if
(
kargs
.
N
%
GemmPipeline
::
VectorSizeC
!=
0
)
{
return
false
;
}
}
else
{
if
(
kargs
.
M
%
TilePartitioner
::
kM
!=
0
&&
GemmPipeline
::
kPadM
==
false
)
{
return
false
;
}
if
(
kargs
.
M
%
GemmPipeline
::
VectorSizeC
!=
0
)
{
return
false
;
}
}
return
true
;
}
CK_TILE_DEVICE
void
operator
()(
GemmCommonKargs
kargs
)
const
...
...
@@ -82,13 +147,13 @@ struct GemmKernel
const
BDataType
*
b_start
=
static_cast
<
const
BDataType
*>
(
kargs
.
b_ptr
);
// Convert pointers to tensor views
auto
a_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
Layout
A
,
tensor_layout
::
gemm
::
Column
Major
>
)
if
constexpr
(
std
::
is_same_v
<
A
Layout
,
tensor_layout
::
gemm
::
Row
Major
>
)
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
a_start
,
make_tuple
(
kargs
.
M
,
kargs
.
K
),
make_tuple
(
1
,
kargs
.
stride_A
),
number
<
GemmPipeline
::
Alignment
A
>
{},
make_tuple
(
kargs
.
stride_A
,
1
),
number
<
GemmPipeline
::
VectorSize
A
>
{},
number
<
1
>
{});
}
else
...
...
@@ -96,51 +161,74 @@ struct GemmKernel
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
a_start
,
make_tuple
(
kargs
.
M
,
kargs
.
K
),
make_tuple
(
kargs
.
stride_A
,
1
),
number
<
GemmPipeline
::
AlignmentA
>
{},
make_tuple
(
1
,
kargs
.
stride_A
),
number
<
1
>
{},
number
<
1
>
{});
}
}();
auto
b_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
Layout
B
,
tensor_layout
::
gemm
::
RowMajor
>
)
if
constexpr
(
std
::
is_same_v
<
B
Layout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
b_start
,
make_tuple
(
kargs
.
N
,
kargs
.
K
),
make_tuple
(
1
,
kargs
.
stride_B
),
number
<
GemmPipeline
::
AlignmentB
>
{},
number
<
1
>
{},
number
<
1
>
{});
}
else
{
// Default NK layout
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
b_start
,
make_tuple
(
kargs
.
N
,
kargs
.
K
),
make_tuple
(
kargs
.
stride_B
,
1
),
number
<
GemmPipeline
::
Alignment
B
>
{},
number
<
GemmPipeline
::
VectorSize
B
>
{},
number
<
1
>
{});
}
}();
auto
a_pad_view
=
pad_tensor_view
(
a_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
sequence
<
0
,
GemmPipeline
::
kPadA
?
1
:
0
>
{});
auto
a_pad_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
pad_tensor_view
(
a_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
sequence
<
false
,
GemmPipeline
::
kPadK
>
{});
}
else
{
return
pad_tensor_view
(
a_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
sequence
<
GemmPipeline
::
kPadM
,
false
>
{});
}
}();
// clang-format on
auto
AB
lock
W
indow
=
make_tile_window
(
auto
a_b
lock
_w
indow
=
make_tile_window
(
a_pad_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
{
i_m
,
0
});
auto
b_pad_view
=
pad_tensor_view
(
b_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kN
>
{},
number
<
TilePartitioner
::
kK
>
{}),
sequence
<
0
,
GemmPipeline
::
kPadB
?
1
:
0
>
{});
auto
b_pad_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
{
return
pad_tensor_view
(
b_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kN
>
{},
number
<
TilePartitioner
::
kK
>
{}),
sequence
<
false
,
GemmPipeline
::
kPadK
>
{});
}
else
{
return
pad_tensor_view
(
b_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kN
>
{},
number
<
TilePartitioner
::
kK
>
{}),
sequence
<
GemmPipeline
::
kPadN
,
false
>
{});
}
}();
auto
BB
lock
W
indow
=
make_tile_window
(
auto
b_b
lock
_w
indow
=
make_tile_window
(
b_pad_view
,
make_tuple
(
number
<
TilePartitioner
::
kN
>
{},
number
<
TilePartitioner
::
kK
>
{}),
{
i_n
,
0
});
...
...
@@ -148,20 +236,21 @@ struct GemmKernel
// allocate LDS
__shared__
char
smem_ptr
[
GetSmemSize
()];
const
index_t
num_loop
=
(
kargs
.
K
+
TilePartitioner
::
kK
-
1
)
/
TilePartitioner
::
kK
;
auto
acc
=
GemmPipeline
{}(
ABlockWindow
,
BBlockWindow
,
num_loop
,
smem_ptr
);
const
index_t
num_loop
=
TilePartitioner
::
GetLoopNum
(
kargs
.
K
);
CODataType
*
c_start
=
static_cast
<
CODataType
*>
(
kargs
.
c_ptr
);
// Run GEMM cooperatively by whole workgroup.
auto
c_block_tile
=
GemmPipeline
{}.
template
operator
()(
a_block_window
,
b_block_window
,
num_loop
,
smem_ptr
);
CDataType
*
c_start
=
static_cast
<
CDataType
*>
(
kargs
.
c_ptr
);
auto
c_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
Layout
C
,
tensor_layout
::
gemm
::
Column
Major
>
)
if
constexpr
(
std
::
is_same_v
<
C
Layout
,
tensor_layout
::
gemm
::
Row
Major
>
)
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
c_start
,
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
1
,
kargs
.
stride_C
),
number
<
GemmPipeline
::
Alignment
C
>
{},
make_tuple
(
kargs
.
stride_C
,
1
),
number
<
GemmPipeline
::
VectorSize
C
>
{},
number
<
1
>
{});
}
else
...
...
@@ -169,22 +258,34 @@ struct GemmKernel
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
c_start
,
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
kargs
.
stride_C
,
1
),
number
<
GemmPipeline
::
AlignmentC
>
{},
make_tuple
(
1
,
kargs
.
stride_C
),
number
<
1
>
{},
number
<
1
>
{});
}
}();
auto
c_pad_view
=
pad_tensor_view
(
c_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
sequence
<
0
,
GemmPipeline
::
kPadC
?
1
:
0
>
{});
auto
c_pad_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
pad_tensor_view
(
c_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
sequence
<
false
,
GemmPipeline
::
kPadN
>
{});
}
else
{
return
pad_tensor_view
(
c_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
sequence
<
GemmPipeline
::
kPadM
,
false
>
{});
}
}();
auto
CBlockWindow_pad
=
make_tile_window
(
c_pad_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
{
i_m
,
i_n
});
EpiloguePipeline
{}(
CBlockWindow_pad
,
acc
);
EpiloguePipeline
{}(
CBlockWindow_pad
,
c_block_tile
);
}
};
...
...
include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp
View file @
6e3c786e
...
...
@@ -9,26 +9,66 @@ namespace ck_tile {
template
<
typename
BlockGemmShape_
>
struct
GemmTilePartitioner
{
using
BlockGemmShape
=
ck_tile
::
remove_cvref_t
<
BlockGemmShape_
>
;
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
static
constexpr
ck_tile
::
index_t
kM
=
BlockGemmShape
::
kM
;
static
constexpr
ck_tile
::
index_t
kN
=
BlockGemmShape
::
kN
;
static
constexpr
ck_tile
::
index_t
kK
=
BlockGemmShape
::
kK
;
static
constexpr
index_t
kM
=
BlockGemmShape
::
kM
;
static
constexpr
index_t
kN
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
kK
=
BlockGemmShape
::
kK
;
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
M
,
ck_tile
::
index_t
N
,
ck_tile
::
index_t
batch_size
)
CK_TILE_HOST
static
constexpr
auto
GridSize
(
index_t
M
,
index_t
N
,
index_t
batch_size
)
{
ck_tile
::
index_t
GridDimX
=
(
M
+
kM
-
1
)
/
kM
;
ck_tile
::
index_t
GridDimY
=
(
N
+
kN
-
1
)
/
kN
;
ck_tile
::
index_t
GridDimZ
=
batch_size
;
index_t
GridDimX
=
(
M
+
kM
-
1
)
/
kM
;
index_t
GridDimY
=
(
N
+
kN
-
1
)
/
kN
;
index_t
GridDimZ
=
batch_size
;
return
dim3
(
GridDimX
,
GridDimY
,
GridDimZ
);
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetLoopNum
(
index_t
K
)
{
return
integer_divide_ceil
(
K
,
kK
);
}
CK_TILE_DEVICE
auto
operator
()()
{
const
index_t
iM
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
x
*
kM
);
const
index_t
iN
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
y
*
kN
);
return
ck_tile
::
make_tuple
(
iM
,
iN
);
return
make_tuple
(
iM
,
iN
);
}
};
/// @brief 1-D tile partitioner: the 2-D (M, N) tile grid is flattened onto
///        gridDim.x so several GEMMs can share one linear grid (see
///        GroupedGemmKernel, which passes a per-group block offset).
template <typename BlockGemmShape_>
struct GemmTile1DPartitioner
{
    using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;

    static constexpr index_t MPerBlock = BlockGemmShape::kM;
    static constexpr index_t NPerBlock = BlockGemmShape::kN;
    static constexpr index_t KPerBlock = BlockGemmShape::kK;

    /// Total workgroups: ceil(M/MPerBlock) * ceil(N/NPerBlock), flattened to X.
    CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N)
    {
        const index_t m_tiles = (M + MPerBlock - 1) / MPerBlock;
        const index_t n_tiles = (N + NPerBlock - 1) / NPerBlock;
        return dim3(m_tiles * n_tiles, 1, 1);
    }

    /// Number of tile columns along N.
    CK_TILE_HOST_DEVICE static constexpr auto GetNBlock(index_t N)
    {
        return integer_divide_ceil(N, NPerBlock);
    }

    /// Number of K-loop iterations for the main loop.
    CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K)
    {
        return integer_divide_ceil(K, KPerBlock);
    }

    /// Map this workgroup's linear id (relative to blockOffset) back to 2-D
    /// tile coordinates in element units: row-major over the N-tile count.
    CK_TILE_DEVICE auto operator()(index_t blockOffset, index_t NBlockSize)
    {
        index_t iM = __builtin_amdgcn_readfirstlane(
            (blockIdx.x - blockOffset) / GetNBlock(NBlockSize) * MPerBlock);
        index_t iN = __builtin_amdgcn_readfirstlane(
            (blockIdx.x - blockOffset) % GetNBlock(NBlockSize) * NPerBlock);
        return make_tuple(iM, iN);
    }
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp
0 → 100644
View file @
6e3c786e
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <string>
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/literals.hpp"
#include "ck_tile/core/utility/amd_address_space.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host.hpp"
namespace
ck_tile
{
// Host-side description of one GEMM in a grouped-GEMM launch.
// One instance per group; GroupedGemmKernel::MakeKargs packs these (plus a
// block range) into the device-visible argument buffer.
struct GroupedGemmHostArgs
{
    const void* a_ptr; // A operand for this group
    const void* b_ptr; // B operand for this group
    void* c_ptr;       // C output for this group
    index_t M;         // GEMM M dimension
    index_t N;         // GEMM N dimension
    index_t K;         // GEMM K dimension
    index_t stride_A;  // leading-dimension stride of A
    index_t stride_B;  // leading-dimension stride of B
    index_t stride_C;  // leading-dimension stride of C
};
/// @brief Grouped GEMM kernel: many independent GEMMs share one 1-D grid.
///        Each group owns a contiguous [block_start, block_end) range of
///        workgroup ids; the device operator() binary-searches the group for
///        the current block, then runs a plain tiled GEMM for it.
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct GroupedGemmKernel
{
    using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
    using GemmPipeline     = remove_cvref_t<GemmPipeline_>;
    using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;

    using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
    using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
    using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;

    static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;

    using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
    using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
    // Output element type comes from the epilogue.
    using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;

    // Per-group record stored in the device argument buffer: the group's GEMM
    // description plus its [block_start, block_end) slice of the grid.
    struct GemmTransKernelArg
    {
        GroupedGemmHostArgs group_karg;
        ck_tile::index_t block_start;
        ck_tile::index_t block_end;

        GemmTransKernelArg() = default;
        GemmTransKernelArg(GroupedGemmHostArgs&& karg, index_t bl_start, index_t bl_end)
            : group_karg{karg}, block_start{bl_start}, block_end{bl_end}
        {
        }
    };

    // Bytes needed for the device-side argument buffer (one record per group).
    __host__ static size_t GetWorkSpaceSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
    {
        return gemm_descs.size() * sizeof(GemmTransKernelArg);
    }

    __host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); }

    using Hargs = GroupedGemmHostArgs;

    // Sum of every group's tile count; the whole launch uses a 1-D grid.
    __host__ static constexpr auto GridSize(const std::vector<Hargs>& gemm_descs)
    {
        index_t grid_size = 0;
        for(const auto& it_desc : gemm_descs)
        {
            // NOTE(review): this local named 'dim3' shadows the CUDA/HIP type
            // inside the loop; the return below refers to the type again.
            const auto dim3 = TilePartitioner::GridSize(it_desc.M, it_desc.N);
            grid_size += dim3.x * dim3.y * 1;
        }
        return dim3(grid_size, 1, 1);
    }

    // Pack host descriptions into per-group records with contiguous block
    // ranges. Groups with any zero dimension are skipped entirely, so the
    // returned vector may be shorter than gemm_descs.
    // NOTE(review): callers must pass the *filtered* count to operator();
    // verify no caller passes gemm_descs.size() when groups were skipped.
    CK_TILE_HOST static auto MakeKargs(const std::vector<Hargs>& gemm_descs)
    {
        std::vector<GemmTransKernelArg> gemm_kernel_args_;
        index_t group_count = ck_tile::type_convert<ck_tile::index_t>(gemm_descs.size());
        index_t grid_size   = 0;
        gemm_kernel_args_.reserve(group_count);
        for(std::size_t i = 0; i < gemm_descs.size(); ++i)
        {
            const index_t M = gemm_descs[i].M;
            const index_t N = gemm_descs[i].N;
            const index_t K = gemm_descs[i].K;
            if(M == 0 || N == 0 || K == 0)
            {
                continue; // nothing to compute for this group
            }
            const index_t stride_a = gemm_descs[i].stride_A;
            const index_t stride_b = gemm_descs[i].stride_B;
            const index_t stride_c = gemm_descs[i].stride_C;
            // 1-D partitioner returns all tiles on .x; .y/.z are 1.
            const auto dim3 = TilePartitioner::GridSize(M, N);
            const index_t grid_size_grp = dim3.x * 1 * 1;
            const index_t block_start = grid_size;
            const index_t block_end   = grid_size + grid_size_grp;
            grid_size += grid_size_grp;
            auto karg = GroupedGemmHostArgs{type_convert<const ADataType*>(gemm_descs[i].a_ptr),
                                            type_convert<const BDataType*>(gemm_descs[i].b_ptr),
                                            type_convert<CDataType*>(gemm_descs[i].c_ptr),
                                            M,
                                            N,
                                            K,
                                            stride_a,
                                            stride_b,
                                            stride_c};
            gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end);
        }
        return gemm_kernel_args_;
    }

    // LDS requirement shared by the main loop and the epilogue.
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }

    // Run one tiled GEMM for a single group. block_start is the group's first
    // workgroup id; the partitioner turns (blockIdx.x - block_start) into
    // (i_m, i_n) tile coordinates.
    CK_TILE_DEVICE void Run(const Hargs& kargs, const index_t block_start) const
    {
        const auto [i_m, i_n] = TilePartitioner{}(block_start, kargs.N);
        // options
        const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
        const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
        // Convert pointers to tensor views; vectorized access only along the
        // contiguous dimension of each operand.
        auto a_tensor_view = [&]() {
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    a_start,
                    make_tuple(kargs.M, kargs.K),
                    make_tuple(kargs.stride_A, 1),
                    number<GemmPipeline::VectorSizeA>{},
                    number<1>{});
            }
            else
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    a_start,
                    make_tuple(kargs.M, kargs.K),
                    make_tuple(1, kargs.stride_A),
                    number<1>{},
                    number<1>{});
            }
        }();
        auto b_tensor_view = [&]() {
            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    b_start,
                    make_tuple(kargs.N, kargs.K),
                    make_tuple(1, kargs.stride_B),
                    number<1>{},
                    number<1>{});
            }
            else
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    b_start,
                    make_tuple(kargs.N, kargs.K),
                    make_tuple(kargs.stride_B, 1),
                    number<GemmPipeline::VectorSizeB>{},
                    number<1>{});
            }
        }();
        // Pad only the non-contiguous dim of each view.
        auto a_pad_view = [&]() {
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            {
                return pad_tensor_view(a_tensor_view,
                                       make_tuple(number<TilePartitioner::MPerBlock>{},
                                                  number<TilePartitioner::KPerBlock>{}),
                                       sequence<false, GemmPipeline::kPadK>{});
            }
            else
            {
                return pad_tensor_view(a_tensor_view,
                                       make_tuple(number<TilePartitioner::MPerBlock>{},
                                                  number<TilePartitioner::KPerBlock>{}),
                                       sequence<GemmPipeline::kPadM, false>{});
            }
        }();
        // clang-format on
        auto a_block_window = make_tile_window(
            a_pad_view,
            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
            {i_m, 0});
        auto b_pad_view = [&]() {
            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
            {
                return pad_tensor_view(b_tensor_view,
                                       make_tuple(number<TilePartitioner::NPerBlock>{},
                                                  number<TilePartitioner::KPerBlock>{}),
                                       sequence<false, GemmPipeline::kPadK>{});
            }
            else
            {
                return pad_tensor_view(b_tensor_view,
                                       make_tuple(number<TilePartitioner::NPerBlock>{},
                                                  number<TilePartitioner::KPerBlock>{}),
                                       sequence<GemmPipeline::kPadN, false>{});
            }
        }();
        auto b_block_window = make_tile_window(
            b_pad_view,
            make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
            {i_n, 0});
        // allocate LDS
        __shared__ char smem_ptr[GetSmemSize()];
        const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);
        // Run GEMM cooperatively by whole workgroup.
        auto c_block_tile = GemmPipeline{}.template operator()(
            a_block_window, b_block_window, num_loop, smem_ptr);
        CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr);
        auto c_tensor_view = [&]() {
            if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    c_start,
                    make_tuple(kargs.M, kargs.N),
                    make_tuple(kargs.stride_C, 1),
                    number<GemmPipeline::VectorSizeC>{},
                    number<1>{});
            }
            else
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    c_start,
                    make_tuple(kargs.M, kargs.N),
                    make_tuple(1, kargs.stride_C),
                    number<1>{},
                    number<1>{});
            }
        }();
        auto c_pad_view = [&]() {
            if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
            {
                return pad_tensor_view(c_tensor_view,
                                       make_tuple(number<TilePartitioner::MPerBlock>{},
                                                  number<TilePartitioner::NPerBlock>{}),
                                       sequence<false, GemmPipeline::kPadN>{});
            }
            else
            {
                return pad_tensor_view(c_tensor_view,
                                       make_tuple(number<TilePartitioner::MPerBlock>{},
                                                  number<TilePartitioner::NPerBlock>{}),
                                       sequence<GemmPipeline::kPadM, false>{});
            }
        }();
        auto CBlockWindow_pad = make_tile_window(
            c_pad_view,
            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
            {i_m, i_n});
        // Epilogue converts/stores the accumulator tile into global C.
        EpiloguePipeline{}(CBlockWindow_pad, c_block_tile);
    }

    // Kernel entry: locate the group that owns this workgroup id via binary
    // search over the per-group [block_start, block_end) ranges, then run it.
    // @param gemm_descs_const device buffer of GemmTransKernelArg (constant space)
    // @param group_count      number of records in the buffer
    CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
                                   int group_count) const
    {
        const index_t block_id = ck_tile::get_block_1d_id();
        const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg*>(
            cast_pointer_to_generic_address_space(gemm_descs_const));
        index_t left     = 0;
        index_t right    = group_count;
        index_t group_id = index_t((left + right) / 2);
        // NOTE(review): the lower bound advances with 'left = group_id' (not
        // group_id + 1); the loop relies on the containment check to break.
        // Verify termination for every valid block_id/group layout.
        while((!(block_id >= gemm_desc_ptr[group_id].block_start &&
                 block_id < gemm_desc_ptr[group_id].block_end)) &&
              left <= right)
        {
            if(block_id < gemm_desc_ptr[group_id].block_start)
            {
                right = group_id;
            }
            else
            {
                left = group_id;
            }
            group_id = index_t((left + right) / 2);
        }
        Run(gemm_desc_ptr[group_id].group_karg, gemm_desc_ptr[group_id].block_start);
    }
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp
deleted
100644 → 0
View file @
1bb510cb
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#define VectorLoadSize 16
namespace
ck_tile
{
/// @brief Compile-time problem description for a block-GEMM pipeline:
///        operand types, tile shape, padding flags, and the resulting
///        per-operand vector-access alignments.
template <typename ADataType_,
          typename BDataType_,
          typename CDataType_,
          typename BlockGemmShape_,
          bool kPadA_ = false,
          bool kPadB_ = false,
          bool kPadC_ = false>
struct BlockGemmPipelineProblem
{
    using ADataType      = remove_cvref_t<ADataType_>;
    using BDataType      = remove_cvref_t<BDataType_>;
    using CDataType      = remove_cvref_t<CDataType_>;
    using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;

    // Threads per workgroup: one full wavefront per warp in the tile shape.
    static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();

    static constexpr bool kPadA = kPadA_;
    static constexpr bool kPadB = kPadB_;
    static constexpr bool kPadC = kPadC_;

    // Padded operands fall back to scalar access (alignment 1); otherwise use
    // a full VectorLoadSize-byte vector load's worth of elements.
    static constexpr index_t AlignmentA = kPadA ? 1 : VectorLoadSize / sizeof(ADataType);
    static constexpr index_t AlignmentB = kPadB ? 1 : VectorLoadSize / sizeof(BDataType);
    static constexpr index_t AlignmentC = kPadC ? 1 : VectorLoadSize / sizeof(CDataType);
};
}
// namespace ck_tile
Prev
1
…
19
20
21
22
23
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment