Project: gaoqiong/composable_kernel_ROCM

Commit e076a320 (parent 839cf897), authored Feb 10, 2025 by feifei14119

Commit message: debug a
Showing 10 changed files with 742 additions and 608 deletions.
Changed files:

- example/ck_tile/18_flatmm/CMakeLists.txt (+1, -1)
- example/ck_tile/18_flatmm/run_flatmm_example.inc (+4, -4)
- include/ck_tile/ops/flatmm.hpp (+7, -33)
- include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp (+442, -0)
- include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_custom_policy.hpp (+38, -0)
- include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_default_policy.hpp (+59, -0)
- include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp (+59, -13)
- include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp (+114, -23)
- include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp (+0, -517)
- include/ck_tile/ops/flatmm/pipeline/flatmm_universal_pipeline_ag_bg_cr_policy.hpp (+18, -17)
example/ck_tile/18_flatmm/CMakeLists.txt (+1, -1)

```diff
@@ -5,5 +5,5 @@ list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-flo
 list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-variable -Wno-unused-parameter)
 list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-local-typedef)
 #list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -ggdb -g -O0 -v -save-temps)
-list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DFEIFEI_DEBUG=1)
+list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DFEIFEI_DEBUG=1 -DDEBUG_CNT=64)
 target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
```
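The new `-DDEBUG_CNT=64` define is consumed by the per-thread debug buffers introduced in the rest of this commit. A minimal sketch (not part of this commit; the guard is a hypothetical addition) of how the sources could keep compiling when the CMake defines are absent:

```cpp
// Hypothetical fallback guard (not in this commit): gives FEIFEI_DEBUG and
// DEBUG_CNT defaults so the debug code paths compile without the CMake flags.
#ifndef FEIFEI_DEBUG
#define FEIFEI_DEBUG 0
#endif

#if FEIFEI_DEBUG && !defined(DEBUG_CNT)
#define DEBUG_CNT 64 // per-thread debug slot count, matching the CMake default
#endif
```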
example/ck_tile/18_flatmm/run_flatmm_example.inc (+4, -4)

```diff
@@ -183,9 +183,9 @@ int run_flatmm_example_with_layouts(int argc,
     ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
 #if FEIFEI_DEBUG
-    ck_tile::HostTensor<int> dbg_int({M * N * 64});
-    ck_tile::HostTensor<float> dbg_fp32({M * N * 64});
-    ck_tile::HostTensor<ADataType> dbg_f168({M * N * 64});
+    ck_tile::HostTensor<int> dbg_int({M * N * DEBUG_CNT});
+    ck_tile::HostTensor<float> dbg_fp32({M * N * DEBUG_CNT});
+    ck_tile::HostTensor<ADataType> dbg_f168({M * N * DEBUG_CNT});
     ck_tile::DeviceMem dbg_int_buf(dbg_int.get_element_space_size_in_bytes());
     ck_tile::DeviceMem dbg_fp32_buf(dbg_fp32.get_element_space_size_in_bytes());
@@ -362,7 +362,7 @@ int run_flatmm_example_with_layouts(int argc,
     int GridDimY  = 1;
     int BlockDimX = 64;
     int BlockDimY = 4;
-    int DbgCnt    = 64;
+    int DbgCnt    = DEBUG_CNT;
     int BlockSize = BlockDimX * BlockDimY;
     // a_host
     {
```
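For reference, the layout these buffers serve: each flattened thread id `gid` owns `DEBUG_CNT` consecutive slots, so slot `i` lives at `gid * DEBUG_CNT + i`, which is the indexing the kernel-side debug writes in this commit use. A standalone plain-C++ sketch (M, N, and the touched slot are made-up values; `std::vector` stands in for `ck_tile::HostTensor`/`DeviceMem`):

```cpp
// Standalone sketch of the debug-buffer slot layout: DEBUG_CNT slots per
// flattened thread id, addressed as gid * DEBUG_CNT + i.
#include <cstdio>
#include <vector>

#ifndef DEBUG_CNT
#define DEBUG_CNT 64
#endif

int main()
{
    const int M = 32, N = 128; // illustrative problem sizes
    std::vector<float> dbg_fp32(static_cast<size_t>(M) * N * DEBUG_CNT, -1.0f);

    const int gid = 5, i = 3; // example entry and slot
    dbg_fp32[static_cast<size_t>(gid) * DEBUG_CNT + i] = 1.0f;

    std::printf("slot (%d, %d) -> flat index %d\n", gid, i, gid * DEBUG_CNT + i);
    return 0;
}
```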
include/ck_tile/ops/flatmm.hpp (+7, -33)

```diff
@@ -3,40 +3,9 @@
 #pragma once
 
-#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp"
-#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
-#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
-#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
-#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
-#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp"
-#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
-#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
-#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
-#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp"
-#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp"
-#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp"
-#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp"
-#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp"
-#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp"
-#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
-#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
-#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
 #include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
 #include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp"
-#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
-// #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
 #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
-#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
-#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
-#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
-#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
-#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
-// #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
-// #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
-#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
-#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
 #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
-// #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
 #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
 #include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
 #include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
@@ -45,10 +14,15 @@
 #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
 #include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
-#include "ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp"
+// block
+#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp"
+#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_default_policy.hpp"
+#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_custom_policy.hpp"
+// pipeline
 #include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp"
-#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
+#include "ck_tile/ops/flatmm/pipeline/flatmm_universal_pipeline_ag_bg_cr_policy.hpp"
+// kernel
+#include "ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp"
 #include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp"
 #include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp"
```
include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp (new file, +442)

```cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_default_policy.hpp"

namespace ck_tile {

// A is block window on shared memory
// B is block window on shared memory
// C is block distributed tensor
template <typename Problem_, typename BlockPolicy_ = BlockFlatmmASmemBSmemCRegV1DefaultPolicy>
struct BlockFlatmmASmemBSmemCRegV1
{
    using Problem        = remove_cvref_t<Problem_>;
    using BlockPolicy    = remove_cvref_t<BlockPolicy_>;
    using ADataType      = remove_cvref_t<typename Problem::ADataType>;
    using BDataType      = remove_cvref_t<typename Problem::BDataType>;
    using CDataType      = remove_cvref_t<typename Problem::CDataType>;
    using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

    static constexpr index_t kBlockSize = Problem::kBlockSize;

    CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
    {
        constexpr index_t MPerBlock = BlockGemmShape::kM;
        constexpr index_t NPerBlock = BlockGemmShape::kN;

        constexpr auto config   = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
        using WG                = remove_cvref_t<decltype(config.template at<0>())>;
        constexpr index_t MWarp = config.template at<1>();
        constexpr index_t NWarp = config.template at<2>();

        constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
        constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);

        constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
            sequence<>,
            tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
            tuple<sequence<1, 2>>,
            tuple<sequence<1, 1>>,
            sequence<1, 2>,
            sequence<0, 0>>{};

        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
        constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
        auto c_block_tensor         = make_static_distributed_tensor<CDataType>(c_block_dstr);
        return c_block_tensor;
    }
```
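A standalone sketch of the per-warp iteration arithmetic `MakeCBlockTile` performs, with illustrative numbers; the 4x1 warp layout and 32x32x16 warp tile match the default policy added later in this commit, while the 128x128 block tile is an assumed example:

```cpp
// Standalone sketch: how MIterPerWarp/NIterPerWarp fall out of the block tile,
// the warp layout, and the warp-gemm tile. All numbers are illustrative.
#include <cstdio>

int main()
{
    constexpr int MPerBlock = 128, NPerBlock = 128; // assumed block tile
    constexpr int MWarp = 4, NWarp = 1;             // default policy: make_tuple(WG{}, 4, 1)
    constexpr int WG_kM = 32, WG_kN = 32;           // M32N32K16 MFMA warp gemm

    constexpr int MIterPerWarp = MPerBlock / (MWarp * WG_kM); // 128 / (4*32) = 1
    constexpr int NIterPerWarp = NPerBlock / (NWarp * WG_kN); // 128 / (1*32) = 4

    static_assert(MPerBlock % (MWarp * WG_kM) == 0, "M tile must divide evenly");
    static_assert(NPerBlock % (NWarp * WG_kN) == 0, "N tile must divide evenly");
    std::printf("MIterPerWarp = %d, NIterPerWarp = %d\n", MIterPerWarp, NIterPerWarp);
    return 0;
}
```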
```cpp
#if 1
    // C += A * B
    // template <typename CBlockTensor, typename ABlockWindow, typename BBlockWindow>
    template <typename ABlockWindow>
    CK_TILE_DEVICE void operator()(const ABlockWindow& a_block_window
#if FEIFEI_DEBUG
                                   ,
                                   const BDataType* b_ptr,
                                   int* dbg_int,
                                   float* dbg_fp32,
                                   void* dbg_f168
#endif
                                   ) const
    {
#if FEIFEI_DEBUG
        if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
        {
            printf("[BLOCK ] BlockFlatmmASmemBSmemCRegV1():\n");
        }
        uint32_t tidx = threadIdx.x;
        uint32_t tidy = threadIdx.y;
        uint32_t bidx = blockIdx.x;
        uint32_t bidy = blockIdx.y;
        uint32_t bdmx = blockDim.x;
        uint32_t bdmy = blockDim.y;
        uint32_t gdmx = gridDim.x;
        uint32_t gdmy = gridDim.y;
        uint32_t gid  = ((bdmx * bdmy) * gdmx) * bidy + (bdmx * bdmy) * bidx + bdmx * tidy + tidx;

        half_t* dbg_f16 = static_cast<half_t*>(dbg_f168);
        for(int i = 0; i < DEBUG_CNT; i++)
        {
            dbg_int[gid * DEBUG_CNT + i]  = -1;
            dbg_fp32[gid * DEBUG_CNT + i] = -1.0f;
            dbg_f16[gid * DEBUG_CNT + i]  = ck_tile::type_convert<ck_tile::half_t>(-1.0f);
        }
#endif
```
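The `gid` expression above flattens the 2-D grid and 2-D block indices into one dense thread id (blocks in row-major order with y outermost, threads likewise). A host-side check of that formula, a sketch with made-up launch dimensions:

```cpp
// Standalone sketch: verify the gid flattening used by the debug writes is
// dense and collision-free. Launch dimensions here are assumptions.
#include <cassert>

int main()
{
    const unsigned bdmx = 64, bdmy = 4, gdmx = 2, gdmy = 3;
    unsigned expected = 0;
    for(unsigned bidy = 0; bidy < gdmy; ++bidy)
        for(unsigned bidx = 0; bidx < gdmx; ++bidx)
            for(unsigned tidy = 0; tidy < bdmy; ++tidy)
                for(unsigned tidx = 0; tidx < bdmx; ++tidx)
                {
                    const unsigned gid = ((bdmx * bdmy) * gdmx) * bidy +
                                         (bdmx * bdmy) * bidx + bdmx * tidy + tidx;
                    assert(gid == expected++); // enumerates 0..N-1 with no gaps
                }
    return 0;
}
```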
```cpp
        /*
        static_assert(std::is_same_v<ADataType, typename ABlockWindow::DataType> &&
                          std::is_same_v<BDataType, typename BBlockWindow::DataType> &&
                          std::is_same_v<CDataType, typename CBlockTensor::DataType>,
                      "wrong!");
        */
        constexpr index_t MPerBlock = ABlockWindow{}.get_window_lengths()[number<0>{}];
        // constexpr index_t NPerBlock = BBlockWindow{}.get_window_lengths()[number<0>{}];
        constexpr index_t KPerBlock = ABlockWindow{}.get_window_lengths()[number<1>{}];
        /*
        static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
                          KPerBlock == BlockGemmShape::kK,
                      "wrong!");
        */
        constexpr auto config   = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
        using WG                = remove_cvref_t<decltype(config.template at<0>())>;
        constexpr index_t MWarp = config.template at<1>();
        constexpr index_t NWarp = config.template at<2>();

        constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
        // constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
        constexpr index_t KIterPerWarp = KPerBlock / WG::kK;

        constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
        // constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
        constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;

        const index_t iMWarp = get_warp_id() / NWarp;
        const index_t iNWarp = get_warp_id() % NWarp;

        // construct A-warp-window
        auto a_warp_window_tmp = make_tile_window(
            a_block_window.get_bottom_tensor_view(),
            make_tuple(number<WG::kM>{}, number<WG::kK>{}),
            a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
            make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));

        statically_indexed_array<
            statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
            MIterPerWarp>
            a_warp_windows;

        static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
                a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
                move_tile_window(a_warp_windows(mIter)(kIter),
                                 {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
            });
        });

        // Warp loop in block:
        constexpr index_t kIter = 0;
        constexpr index_t mIter = 0;
        const auto a_warp_tensor = load_tile(a_warp_windows(number<mIter>{})(number<kIter>{}));
```
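The nested `static_for` above unrolls the window setup completely at compile time, handing each index to a generic lambda as a compile-time constant so it can be used in constant expressions. A standalone analog (`static_for_sketch` is hypothetical, not ck_tile's implementation):

```cpp
// Standalone sketch of the static_for pattern: a counted compile-time loop
// that passes each index as an integral_constant, usable as a template or
// array-offset constant inside the body.
#include <cstdio>
#include <type_traits>
#include <utility>

template <int Begin, int End, int Inc, typename F>
constexpr void static_for_sketch(F&& f)
{
    if constexpr(Begin < End)
    {
        f(std::integral_constant<int, Begin>{}); // body sees a constexpr index
        static_for_sketch<Begin + Inc, End, Inc>(std::forward<F>(f));
    }
}

int main()
{
    static_for_sketch<0, 2, 1>([&](auto mIter) {
        static_for_sketch<0, 4, 1>([&](auto kIter) {
            std::printf("tile (%d, %d)\n", mIter(), kIter());
        });
    });
    return 0;
}
```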
```cpp
#if 1
        // feifei TODO: Implement gemm here
#else
        constexpr auto config   = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
        using WG                = remove_cvref_t<decltype(config.template at<0>())>;
        constexpr index_t MWarp = config.template at<1>();
        constexpr index_t NWarp = config.template at<2>();

        constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
        constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
        constexpr index_t KIterPerWarp = KPerBlock / WG::kK;

        constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
        constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
        constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;

        const index_t iMWarp = get_warp_id() / NWarp;
        const index_t iNWarp = get_warp_id() % NWarp;

        // construct A-warp-window
        auto a_warp_window_tmp = make_tile_window(
            a_block_window.get_bottom_tensor_view(),
            make_tuple(number<WG::kM>{}, number<WG::kK>{}),
            a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
            make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));

#if 0 // FIXME: using array will cause register spill
        array<array<decltype(a_warp_window_tmp), KIterPerWarp>, MIterPerWarp> a_warp_windows{
            {a_warp_window_tmp}};
        for(index_t mIter = 0; mIter < MIterPerWarp; mIter++)
        {
            for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
            {
                move_tile_window(a_warp_windows(mIter)(kIter),
                                 {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
            }
        }
#else
        statically_indexed_array<
            statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
            MIterPerWarp>
            a_warp_windows;

        static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
                a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
                move_tile_window(a_warp_windows(mIter)(kIter),
                                 {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
            });
        });
#endif

        // construct B-warp-window
        auto b_warp_window_tmp = make_tile_window(
            b_block_window.get_bottom_tensor_view(),
            make_tuple(number<WG::kN>{}, number<WG::kK>{}),
            b_block_window.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
            make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));

#if 0 // FIXME: using array will cause register spill
        array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
            {b_warp_window_tmp}};
        for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
        {
            for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
            {
                move_tile_window(b_warp_windows(nIter)(kIter),
                                 {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
            }
        }
#else
        statically_indexed_array<
            statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
            NIterPerWarp>
            b_warp_windows;

        static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
                b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
                move_tile_window(b_warp_windows(nIter)(kIter),
                                 {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
            });
        });
#endif

        using CWarpDstr   = typename WG::CWarpDstr;
        using CWarpTensor = typename WG::CWarpTensor;

        constexpr auto c_warp_y_lengths =
            to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
        constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};

        // hot loop:
        static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
            static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
                // read A warp tensor from A block window
                const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));

                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
                    // read B warp tensor from B Block window
                    const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));

                    // read C warp tensor from C block tensor
                    CWarpTensor c_warp_tensor;
                    c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

                    // warp GEMM
                    WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);

                    // write C warp tensor into C block tensor
                    c_block_tensor.set_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
                        c_warp_tensor.get_thread_buffer());
                });
            });
        });
#endif
    }
#else
    // C += A * B
    template <typename CBlockTensor, typename ABlockWindow, typename BBlockWindow>
    CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
                                   const ABlockWindow& a_block_window,
                                   const BBlockWindow& b_block_window) const
    {
        static_assert(std::is_same_v<ADataType, typename ABlockWindow::DataType> &&
                          std::is_same_v<BDataType, typename BBlockWindow::DataType> &&
                          std::is_same_v<CDataType, typename CBlockTensor::DataType>,
                      "wrong!");

        constexpr index_t MPerBlock = ABlockWindow{}.get_window_lengths()[number<0>{}];
        constexpr index_t NPerBlock = BBlockWindow{}.get_window_lengths()[number<0>{}];
        constexpr index_t KPerBlock = ABlockWindow{}.get_window_lengths()[number<1>{}];

        static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
                          KPerBlock == BlockGemmShape::kK,
                      "wrong!");

        constexpr auto config   = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
        using WG                = remove_cvref_t<decltype(config.template at<0>())>;
        constexpr index_t MWarp = config.template at<1>();
        constexpr index_t NWarp = config.template at<2>();

        constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
        constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
        constexpr index_t KIterPerWarp = KPerBlock / WG::kK;

        constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
        constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
        constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;

        const index_t iMWarp = get_warp_id() / NWarp;
        const index_t iNWarp = get_warp_id() % NWarp;

        // construct A-warp-window
        auto a_warp_window_tmp = make_tile_window(
            a_block_window.get_bottom_tensor_view(),
            make_tuple(number<WG::kM>{}, number<WG::kK>{}),
            a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
            make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));

#if 0 // FIXME: using array will cause register spill
        array<array<decltype(a_warp_window_tmp), KIterPerWarp>, MIterPerWarp> a_warp_windows{
            {a_warp_window_tmp}};
        for(index_t mIter = 0; mIter < MIterPerWarp; mIter++)
        {
            for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
            {
                move_tile_window(a_warp_windows(mIter)(kIter),
                                 {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
            }
        }
#else
        statically_indexed_array<
            statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
            MIterPerWarp>
            a_warp_windows;

        static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
                a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
                move_tile_window(a_warp_windows(mIter)(kIter),
                                 {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
            });
        });
#endif

        // construct B-warp-window
        auto b_warp_window_tmp = make_tile_window(
            b_block_window.get_bottom_tensor_view(),
            make_tuple(number<WG::kN>{}, number<WG::kK>{}),
            b_block_window.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
            make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));

#if 0 // FIXME: using array will cause register spill
        array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
            {b_warp_window_tmp}};
        for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
        {
            for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
            {
                move_tile_window(b_warp_windows(nIter)(kIter),
                                 {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
            }
        }
#else
        statically_indexed_array<
            statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
            NIterPerWarp>
            b_warp_windows;

        static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
                b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
                move_tile_window(b_warp_windows(nIter)(kIter),
                                 {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
            });
        });
#endif

        using CWarpDstr   = typename WG::CWarpDstr;
        using CWarpTensor = typename WG::CWarpTensor;

        constexpr auto c_warp_y_lengths =
            to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
        constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};

        // hot loop:
        static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
            static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
                // read A warp tensor from A block window
                const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));

                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
                    // read B warp tensor from B Block window
                    const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));

                    // read C warp tensor from C block tensor
                    CWarpTensor c_warp_tensor;
                    c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

                    // warp GEMM
                    WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);

                    // write C warp tensor into C block tensor
                    c_block_tensor.set_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
                        c_warp_tensor.get_thread_buffer());
                });
            });
        });
    }

    // C = A * B
    template <typename ABlockTensorTmp, typename BBlockWindow>
    CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
                                   const BBlockWindow& b_block_window) const
    {
        auto c_block_tensor = MakeCBlockTile();
        operator()(c_block_tensor, a_block_tensor_tmp, b_block_window);
        return c_block_tensor;
    }
#endif
};

} // namespace ck_tile
```
include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_custom_policy.hpp (new file, +38)

```cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

// Custom policy for BlockFlatmmASmemBSmemCRegV1
// Policy class should not be templated, put template on member functions instead
template <typename AType_,
          typename BType_,
          typename CType_,
          typename BlockWarps_,
          typename WarpGemm_>
struct BlockFlatmmASmemBSmemCRegV1CustomPolicy
{
    using AType = remove_cvref_t<AType_>;
    using BType = remove_cvref_t<BType_>;
    using CType = remove_cvref_t<CType_>;

    using BlockWarps = remove_cvref_t<BlockWarps_>;
    static constexpr index_t kMWarps = BlockWarps::at(number<0>{});
    static constexpr index_t kNWarps = BlockWarps::at(number<1>{});
    static constexpr index_t kKWarps = BlockWarps::at(number<2>{});

    using WarpGemm = remove_cvref_t<WarpGemm_>;

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
    {
        return make_tuple(WarpGemm{}, kMWarps, kNWarps);
    }
};

} // namespace ck_tile
```
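A standalone analog of this policy pattern, with illustrative names (nothing here is ck_tile's API): configuration travels in template parameters, and the block algorithm retrieves the warp-gemm type plus warp counts through one static constexpr accessor.

```cpp
// Standalone sketch of the compile-time policy pattern used above.
#include <cstdio>
#include <tuple>

struct FakeWarpGemm { static constexpr int kM = 32, kN = 32, kK = 16; };

template <typename WarpGemm_, int MWarps_, int NWarps_>
struct CustomPolicySketch
{
    template <typename /*Problem*/>
    static constexpr auto GetWarpGemmMWarpNWarp()
    {
        // (warp-gemm tag, warps along M, warps along N)
        return std::make_tuple(WarpGemm_{}, MWarps_, NWarps_);
    }
};

int main()
{
    constexpr auto cfg =
        CustomPolicySketch<FakeWarpGemm, 4, 1>::GetWarpGemmMWarpNWarp<void>();
    std::printf("MWarps = %d, NWarps = %d\n", std::get<1>(cfg), std::get<2>(cfg));
    return 0;
}
```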
include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_default_policy.hpp (new file, +59)

```cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"

namespace ck_tile {

// Default policy for BlockFlatmmASmemBSmemCRegV1
// Default policy class should not be templated, put template on member functions instead
struct BlockFlatmmASmemBSmemCRegV1DefaultPolicy
{
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
    {
        if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
                     std::is_same_v<typename Problem::BDataType, half_t> &&
                     std::is_same_v<typename Problem::CDataType, float>)
        {
#if 0
            constexpr index_t kBlockSize = Problem::kBlockSize;
            constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
            constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
            constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
            static_assert(kBlockSize % get_warp_size() == 0, "wrong!");
            constexpr index_t NumWarp = kBlockSize / get_warp_size();
            if constexpr(NumWarp == 4 && kMPerBlock % 128 == 0 &&
                         kNPerBlock % 128 == 0 && kKPerBlock % 16 == 0)
            {
                return make_tuple(WarpGemmMfmaF16F16F32M32N32K16{}, 2, 2);
            }
            else
            {
                return make_tuple(WarpGemmMfmaF16F16F32M32N32K16{}, 2, 2);
            }
#else
            return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
#endif
        }
        else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
                          std::is_same_v<typename Problem::BDataType, bf16_t> &&
                          std::is_same_v<typename Problem::CDataType, float>)
        {
            return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1);
        }
        else
        {
            static_assert(false, "Unsupported data type configuration for GEMM warp execution.");
        }
    }
};

} // namespace ck_tile
```
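One caveat on the final `else` branch: a plain `static_assert(false, ...)` inside a discarded `if constexpr` branch is ill-formed before the C++23-era resolution of CWG2518, so older compilers may reject it even though the branch is never taken. A sketch of the usual workaround, which makes the condition dependent on the template parameter:

```cpp
// Sketch of the pre-C++23 workaround: a dependent-false trait delays the
// static_assert until the template is actually instantiated with a bad type.
#include <type_traits>

template <typename>
inline constexpr bool always_false_v = false;

template <typename T>
constexpr int dispatch_sketch()
{
    if constexpr(std::is_same_v<T, float>) { return 32; }
    else if constexpr(std::is_same_v<T, double>) { return 64; }
    else
    {
        static_assert(always_false_v<T>, "Unsupported data type configuration.");
        return 0; // unreachable; keeps every branch returning a value
    }
}

static_assert(dispatch_sketch<float>() == 32);
static_assert(dispatch_sketch<double>() == 64);
```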
include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp (+59, -13)

```diff
@@ -85,7 +85,7 @@ struct FlatmmKernel
     CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
     {
-        return TilePartitioner::GridSize(M, N);
+        return TilePartitioner::GridSize(M, N);
         // feifei TODO: split K here
         // return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
     }
@@ -178,7 +178,7 @@ struct FlatmmKernel
     index_t a_k_split_offset;
     index_t b_k_split_offset;
-    index_t splitted_k;
+    index_t splitted_k; // problem K after splitted
 };
 
 CK_TILE_HOST static bool IsSupportedArgument(const FlatmmKernelArgs& kargs)
@@ -473,7 +473,7 @@ struct FlatmmKernel
         const BDataType* b_ptr,
         int* dbg_int,
         float* dbg_fp32,
-        short* dbg_f168
+        void* dbg_f168
 #endif
     )
     {
@@ -481,12 +481,55 @@ struct FlatmmKernel
         // Create Flatmm tensor views, pad views and tile windows
         const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<DstInMemOp>(
             a_ptr, b_shuffle_ptr, c_ptr, kargs, splitk_batch_offset);
-        // origin layout
-        // const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<DstInMemOp>(
-        //     a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset);
+        // Debug origin layout
+        // const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<DstInMemOp>(
+        //     a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset);
         const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
-        auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
+        const auto& gemm_tile_windows =
+            MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
+#if FEIFEI_DEBUG
+        ////////////////////////////////////////////////////////
+        const auto& a_gemm_tensor_views = gemm_tensor_views_tuple.at(I0); // tensor_view
+        const auto& a_gemm_tensor_desc  = a_gemm_tensor_views.desc_;      // tensor_descriptor
+        const auto& a_gemm_buff_views   = a_gemm_tensor_views.buf_;       // buffer_view
+        if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
+        {
+            printf("[KERNEL] a_gemm_tensor_view: size = %ld, len = [%d, %d], top = [%d, %d], "
+                   "upper = %d, lower = %d\n",
+                   a_gemm_tensor_desc.get_element_space_size(),
+                   a_gemm_tensor_desc.get_length(I0),
+                   a_gemm_tensor_desc.get_length(I1),
+                   a_gemm_tensor_desc.get_top_dimension_hidden_ids()[0],
+                   a_gemm_tensor_desc.get_top_dimension_hidden_ids()[1],
+                   a_gemm_tensor_desc.get_upper_dimension_hidden_idss()(I0)[0],
+                   a_gemm_tensor_desc.get_lower_dimension_hidden_idss()(I0)[0]);
+        }
+        const auto& a_pad_tensor_views = gemm_pad_views.at(I0);    // tensor_view
+        const auto& a_pad_tensor_desc  = a_pad_tensor_views.desc_; // tensor_descriptor
+        const auto& a_pad_buff_views   = a_pad_tensor_views.buf_;  // buffer_view
+        if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
+        {
+            printf("[KERNEL] a_pad_tensor_view: size = %ld, len = [%d, %d], top = [%d, %d], "
+                   "upper = %d, lower = %d\n",
+                   a_pad_tensor_desc.get_element_space_size(),
+                   a_pad_tensor_desc.get_length(I0),
+                   a_pad_tensor_desc.get_length(I1),
+                   a_pad_tensor_desc.get_top_dimension_hidden_ids()[0],
+                   a_pad_tensor_desc.get_top_dimension_hidden_ids()[1],
+                   a_pad_tensor_desc.get_upper_dimension_hidden_idss()(I0)[0],
+                   a_pad_tensor_desc.get_lower_dimension_hidden_idss()(I0)[0]);
+        }
+        const auto& a_tile_win = gemm_tile_windows.at(I0); // tile_window_with_static_lengths
+        if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
+        {
+            printf("[KERNEL] a_gemm_tile_window: dim_num = %d\n",
+                   a_tile_win.get_num_of_dimension());
+        }
+        ////////////////////////////////////////////////////////
+#endif
 
         const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
@@ -555,11 +598,14 @@ struct FlatmmKernel
         int* dbg_int    = static_cast<int*>(kargs.dbg_int_ptr);
         float* dbg_fp32 = static_cast<float*>(kargs.dbg_fp32_ptr);
-        short* dbg_f168 = static_cast<short*>(kargs.dbg_f168_ptr);
-        dbg_int[gid]  = 1;
-        dbg_fp32[gid] = 1.0f;
-        dbg_f168[gid] = ck_tile::type_convert<ck_tile::half_t>(1.0f);
+        half_t* dbg_f16 = static_cast<half_t*>(kargs.dbg_f168_ptr);
+        for(int i = 0; i < DEBUG_CNT; i++)
+        {
+            dbg_int[gid * DEBUG_CNT + i]  = 0;
+            dbg_fp32[gid * DEBUG_CNT + i] = .0f;
+            dbg_f16[gid * DEBUG_CNT + i]  = ck_tile::type_convert<ck_tile::half_t>(0.f);
+        }
 #endif
         const auto [iM, iN] = TilePartitioner::GetOutputTileIndex(blockIdx.x, blockIdx.y);
@@ -592,7 +638,7 @@ struct FlatmmKernel
             b_ptr,
             dbg_int,
             dbg_fp32,
-            dbg_f168
+            kargs.dbg_f168_ptr
 #endif
         );
     }
@@ -611,7 +657,7 @@ struct FlatmmKernel
             b_ptr,
             dbg_int,
             dbg_fp32,
-            dbg_f168
+            kargs.dbg_f168_ptr
 #endif
         );
     }
```
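A standalone sketch of the split-K launch geometry the TODO in `GridSize` points at: keep the M x N tile grid in the x dimension and replicate it `KBatch` times in z, so each z-slice reduces a K/KBatch chunk. `dim3_sketch` and the tile arithmetic are stand-ins, not ck_tile's `TilePartitioner`:

```cpp
// Standalone sketch of a split-K grid: one workgroup per output tile in x,
// KBatch replicas in z. All names and numbers are illustrative.
#include <cstdio>

struct dim3_sketch { unsigned x, y, z; };

constexpr unsigned ceil_div(unsigned a, unsigned b) { return (a + b - 1) / b; }

constexpr dim3_sketch grid_size_sketch(unsigned M, unsigned N, unsigned KBatch,
                                       unsigned MPerBlock, unsigned NPerBlock)
{
    return {ceil_div(M, MPerBlock) * ceil_div(N, NPerBlock), 1u, KBatch};
}

int main()
{
    constexpr auto g = grid_size_sketch(4096, 4096, 4, 128, 128);
    std::printf("grid = (%u, %u, %u)\n", g.x, g.y, g.z); // (1024, 1, 4)
    return 0;
}
```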
include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp (+114, -23)

```diff
@@ -4,14 +4,14 @@
 #pragma once
 
 #include "ck_tile/core.hpp"
-#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
+#include "ck_tile/ops/flatmm/pipeline/flatmm_universal_pipeline_ag_bg_cr_policy.hpp"
 
 namespace ck_tile {
 
 // A Tile Window: global memory
 // B Tile Window: global memory
 // C Distributed tensor: register
-template <typename Problem, typename Policy = FlatmmPipelineAGmemBGmemCRegV1DefaultPolicy>
+template <typename Problem, typename PipelinePolicy = UniversalFlatmmPipelineAgBgCrPolicy>
+// feifei TODO: add default policy
 struct FlatmmPipelineAGmemBGmemCRegV1
 {
     using ADataType = remove_cvref_t<typename Problem::ADataType>;
@@ -23,7 +23,8 @@ struct FlatmmPipelineAGmemBGmemCRegV1
     using BLayout = remove_cvref_t<typename Problem::BLayout>;
     using CLayout = remove_cvref_t<typename Problem::CLayout>;
 
-    using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
+    using BlockFlatmm =
+        remove_cvref_t<decltype(PipelinePolicy::template GetBlockFlatmm<Problem>())>;
 
     static constexpr index_t BlockSize = Problem::kBlockSize;
@@ -41,21 +42,24 @@ struct FlatmmPipelineAGmemBGmemCRegV1
     CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
     {
-        return integer_divide_ceil(
-                   sizeof(ADataType) *
-                       Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
-                   16) *
-                   16 +
-               sizeof(BDataType) *
-                   Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
+        return integer_divide_ceil(
+                   sizeof(ADataType) * PipelinePolicy::template MakeALdsBlockDescriptor<Problem>()
+                                           .get_element_space_size(),
+                   16) *
+                   16 +
+               sizeof(BDataType) * PipelinePolicy::template MakeBLdsBlockDescriptor<Problem>()
+                                       .get_element_space_size();
     }
 
     CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
     {
-        return Policy::template GetSmemSize<Problem>();
+        return PipelinePolicy::template GetSmemSize<Problem>();
    }
 
-    CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); }
+    CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
+    {
+        return PipelinePolicy::IsTransposeC();
+    }
```
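A standalone sketch of the size computation `GetStaticLdsSize` performs on both sides of this hunk: the A tile's byte size is rounded up to a 16-byte boundary before the B tile is appended, matching the `a_lds_block_space_size_aligned` offset used for `p_b_lds` further down this file. Element counts and element sizes here are illustrative:

```cpp
// Standalone sketch of the LDS sizing: pad A's bytes to 16, then append B.
#include <cstdio>

constexpr int integer_divide_ceil_sketch(int a, int b) { return (a + b - 1) / b; }

constexpr int lds_size_sketch(int a_elems, int a_elem_bytes, int b_elems, int b_elem_bytes)
{
    const int a_bytes_aligned =
        integer_divide_ceil_sketch(a_elems * a_elem_bytes, 16) * 16; // 16-byte aligned
    return a_bytes_aligned + b_elems * b_elem_bytes;
}

int main()
{
    // illustrative numbers: 128x32 fp16 A tile, 128x32 fp16 B tile
    std::printf("LDS bytes = %d\n", lds_size_sketch(128 * 32, 2, 128 * 32, 2));
    return 0;
}
```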
```diff
     template <typename ADramBlockWindowTmp,
               typename BDramBlockWindowTmp,
@@ -72,7 +76,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
         const BDataType* b_ptr,
         int* dbg_int,
         float* dbg_fp32,
-        short* dbg_f168
+        void* dbg_f168
 #endif
     ) const
     {
@@ -80,6 +84,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
         if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
         {
             printf("[PIPELN] FlatmmPipelinen():\n");
+            printf("[PIPELN] num_loop = %d\n", num_loop);
         }
         uint32_t tidx = threadIdx.x;
@@ -92,9 +97,13 @@ struct FlatmmPipelineAGmemBGmemCRegV1
         uint32_t gdmy = gridDim.y;
         uint32_t gid  = ((bdmx * bdmy) * gdmx) * bidy + (bdmx * bdmy) * bidx + bdmx * tidy + tidx;
-        dbg_int[gid]  = -1;
-        dbg_fp32[gid] = -1.0f;
-        dbg_f168[gid] = ck_tile::type_convert<ck_tile::half_t>(-1.0f);
+        half_t* dbg_f16 = static_cast<half_t*>(dbg_f168);
+        for(int i = 0; i < DEBUG_CNT; i++)
+        {
+            dbg_int[gid * DEBUG_CNT + i]  = 1;
+            dbg_fp32[gid * DEBUG_CNT + i] = 1.0f;
+            dbg_f16[gid * DEBUG_CNT + i]  = ck_tile::type_convert<ck_tile::half_t>(1.0f);
+        }
 #endif
 
         static_assert(
             std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
@@ -108,12 +117,93 @@ struct FlatmmPipelineAGmemBGmemCRegV1
 #if 1
         // feifei TODO: Implement gemm here
+        // Get block flatmm
+        auto block_flatmm = BlockFlatmm(); // struct BlockFlatmmASmemBSmemCRegV1
+
+        // A tile in LDS
+        ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
+        constexpr auto a_lds_block_desc =
+            PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
+        auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
+        constexpr index_t a_lds_block_space_size_aligned =
+            integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(),
+                                16) *
+            16;
+
+        // A DRAM tile window for load
+        auto a_copy_dram_window =
+            make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
+                             make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
+                             a_dram_block_window_tmp.get_window_origin(),
+                             PipelinePolicy::template MakeADramTileDistribution<Problem>());
+
+        // A LDS tile window for store
+        auto a_copy_lds_window = make_tile_window(
+            a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
+
+        // A LDS tile for block GEMM
+        auto a_lds_gemm_window = make_tile_window(
+            a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
+
+        // Prefetch -----------------------------------------------------------
+        // global read 0
+        auto a_block_tile = load_tile(a_copy_dram_window);
+
+#if FEIFEI_DEBUG // debug A global load
+        int a_dim = a_block_tile.get_num_of_dimension();
+        int a_sz  = a_block_tile.get_thread_buffer_size();
+        if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
+        {
+            printf("[PIPELN] a_dim = %d, a_sz = %d\n", a_dim, a_sz);
+        }
+        for(auto i = 0; i < a_sz; i++)
+        {
+            dbg_f16[gid * DEBUG_CNT + i] = a_block_tile.get_thread_buffer()[i];
+        }
+        return nullptr;
+#endif
+
+        // move to 1
+        move_tile_window(a_copy_dram_window, {0, kKPerBlock});
+
+        // LDS write 0
+        if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
+        {
+            auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
+                PipelinePolicy::template MakeShuffledARegBlockDescriptor<Problem>());
+            shuffle_tile(a_shuffle_tmp, a_block_tile);
+            const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
+            store_tile(a_copy_lds_window, a_block_tile_tmp);
+        }
+        else
+        {
+            store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile));
+        }
+
+        // Loop ---------------------------------------------------------------
+        // Do flatmm
+        block_flatmm(a_lds_gemm_window
+#if FEIFEI_DEBUG
+                     ,
+                     b_ptr,
+                     dbg_int,
+                     dbg_fp32,
+                     dbg_f168
+#endif
+        );
+
+        // Tail ---------------------------------------------------------------
+        return nullptr;
 #else
         // A tile in LDS
         ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
-        constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
+        constexpr auto a_lds_block_desc =
+            PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
         auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
@@ -125,7 +215,8 @@ struct FlatmmPipelineAGmemBGmemCRegV1
         BDataType* p_b_lds = static_cast<BDataType*>(
             static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
 
-        constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
+        constexpr auto b_lds_block_desc =
+            PipelinePolicy::template MakeBLdsBlockDescriptor<Problem>();
 
         auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
@@ -134,7 +225,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
             make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
                              make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
                              a_dram_block_window_tmp.get_window_origin(),
-                             Policy::template MakeADramTileDistribution<Problem>());
+                             PipelinePolicy::template MakeADramTileDistribution<Problem>());
 
         // A LDS tile window for store
         auto a_copy_lds_window = make_tile_window(
@@ -145,7 +236,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
             make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
                              make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
                              b_dram_block_window_tmp.get_window_origin(),
-                             Policy::template MakeBDramTileDistribution<Problem>());
+                             PipelinePolicy::template MakeBDramTileDistribution<Problem>());
 
         // B LDS tile window for store
         auto b_copy_lds_window = make_tile_window(
@@ -184,7 +275,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
         if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
         {
             auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
-                Policy::template MakeShuffledARegBlockDescriptor<Problem>());
+                PipelinePolicy::template MakeShuffledARegBlockDescriptor<Problem>());
             shuffle_tile(a_shuffle_tmp, a_block_tile);
             const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
             store_tile(a_copy_lds_window, a_block_tile_tmp);
@@ -198,7 +289,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
         if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
         {
             auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
-                Policy::template MakeShuffledBRegBlockDescriptor<Problem>());
+                PipelinePolicy::template MakeShuffledBRegBlockDescriptor<Problem>());
             shuffle_tile(b_shuffle_tmp, b_block_tile);
             const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_shuffle_tmp);
             store_tile(b_copy_lds_window, b_block_tile_tmp);
@@ -235,7 +326,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
             if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
             {
                 auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
-                    Policy::template MakeShuffledBRegBlockDescriptor<Problem>());
+                    PipelinePolicy::template MakeShuffledBRegBlockDescriptor<Problem>());
                 shuffle_tile(b_shuffle_tmp_loop, b_block_tile);
                 store_tile(b_copy_lds_window,
                            tile_elementwise_in(b_element_func, b_shuffle_tmp_loop));
@@ -271,7 +362,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
         const BDataType* b_ptr,
         int* dbg_int,
         float* dbg_fp32,
-        short* dbg_f168
+        void* dbg_f168
 #endif
     ) const
     {
```
include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp (deleted, -517; diff collapsed in the commit view)
include/ck_tile/ops/flatmm/pipeline/flatmm_universal_pipeline_ag_bg_cr_policy.hpp (+18, -17)

```diff
@@ -446,24 +446,25 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
     CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return TransposeC; }
 
     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
+    CK_TILE_HOST_DEVICE static constexpr auto GetBlockFlatmm()
     {
-        using AccDataType = float;
-        using BlockWarps  = typename Problem::BlockGemmShape::BlockWarps;
-        using WarpTile    = typename Problem::BlockGemmShape::WarpTile;
-        using WarpGemm    = WarpGemmMfmaDispatcher<typename Problem::ADataType,
-                                                   typename Problem::BDataType,
-                                                   AccDataType,
-                                                   WarpTile::at(I0),
-                                                   WarpTile::at(I1),
-                                                   WarpTile::at(I2),
-                                                   TransposeC>;
-        using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
-                                                                      typename Problem::BDataType,
-                                                                      typename Problem::CDataType,
-                                                                      BlockWarps,
-                                                                      WarpGemm>;
-        return BlockGemmASmemBSmemCRegV1<Problem, BlockGemmPolicy>{};
+        using AccDataType = float;
+        using BlockWarps  = typename Problem::BlockGemmShape::BlockWarps;
+        using WarpTile    = typename Problem::BlockGemmShape::WarpTile;
+        using WarpGemm    = WarpGemmMfmaDispatcher<typename Problem::ADataType,
+                                                   typename Problem::BDataType,
+                                                   AccDataType,
+                                                   WarpTile::at(I0),
+                                                   WarpTile::at(I1),
+                                                   WarpTile::at(I2),
+                                                   TransposeC>;
+        using BlockFlatmmPolicy =
+            BlockFlatmmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
+                                                    typename Problem::BDataType,
+                                                    typename Problem::CDataType,
+                                                    BlockWarps,
+                                                    WarpGemm>;
+        return BlockFlatmmASmemBSmemCRegV1<Problem, BlockFlatmmPolicy>{};
     }
 };
```