Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
fa335f31
Commit
fa335f31
authored
Jan 07, 2025
by
feifei14119
Browse files
remove debug 9.8 tflops
parent
888317e6
Changes
12
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
3746 additions
and
0 deletions
+3746
-0
example/ck_tile/18_flatmm_uk/CMakeLists.txt
example/ck_tile/18_flatmm_uk/CMakeLists.txt
+19
-0
example/ck_tile/18_flatmm_uk/flatmm_uk.hpp
example/ck_tile/18_flatmm_uk/flatmm_uk.hpp
+100
-0
example/ck_tile/18_flatmm_uk/instances/flatmm_uk_api.cpp
example/ck_tile/18_flatmm_uk/instances/flatmm_uk_api.cpp
+192
-0
example/ck_tile/18_flatmm_uk/instances/flatmm_uk_api.hpp
example/ck_tile/18_flatmm_uk/instances/flatmm_uk_api.hpp
+76
-0
example/ck_tile/18_flatmm_uk/main.cpp
example/ck_tile/18_flatmm_uk/main.cpp
+692
-0
example/ck_tile/CMakeLists.txt
example/ck_tile/CMakeLists.txt
+1
-0
include/ck_tile/ops/flatmm/block/flatmm_ff_32x512x128_1x4x1_16x16x32.hpp
.../ops/flatmm/block/flatmm_ff_32x512x128_1x4x1_16x16x32.hpp
+665
-0
include/ck_tile/ops/flatmm/block/uk/flatmm_ff_uk_gfx9_32x512x128_1x1x1_16x16x16.inc
.../block/uk/flatmm_ff_uk_gfx9_32x512x128_1x1x1_16x16x16.inc
+574
-0
include/ck_tile/ops/flatmm_uk.hpp
include/ck_tile/ops/flatmm_uk.hpp
+17
-0
include/ck_tile/ops/fused_moe/kernel/flatmm_uk_kernel.hpp
include/ck_tile/ops/fused_moe/kernel/flatmm_uk_kernel.hpp
+264
-0
include/ck_tile/ops/fused_moe/pipeline/flatmm_uk_pipeline.hpp
...ude/ck_tile/ops/fused_moe/pipeline/flatmm_uk_pipeline.hpp
+337
-0
include/ck_tile/ops/fused_moe/pipeline/flatmm_uk_pipeline_policy.hpp
...tile/ops/fused_moe/pipeline/flatmm_uk_pipeline_policy.hpp
+809
-0
No files found.
example/ck_tile/18_flatmm_uk/CMakeLists.txt
0 → 100644
View file @
fa335f31
set
(
TILE_EXAPMLE_FLATMM_UK
"tile_example_flatmm_uk"
)
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
message
(
"adding
${
TILE_EXAPMLE_FLATMM_UK
}
"
)
file
(
GLOB INSTANCE_SRCS instances/*.cpp
)
add_executable
(
${
TILE_EXAPMLE_FLATMM_UK
}
EXCLUDE_FROM_ALL main.cpp
)
target_include_directories
(
${
TILE_EXAPMLE_FLATMM_UK
}
PRIVATE
${
CMAKE_CURRENT_LIST_DIR
}
)
target_sources
(
${
TILE_EXAPMLE_FLATMM_UK
}
PRIVATE
${
INSTANCE_SRCS
}
)
set
(
TILE_EXAPMLE_FLATMM_UK_COMPILE_OPTIONS
)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list
(
APPEND TILE_EXAPMLE_FLATMM_UK_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal
)
list
(
APPEND TILE_EXAPMLE_FLATMM_UK_COMPILE_OPTIONS -DCK_TILE_BUFFER_LOAD_AGPR=1
)
# TODO: enable load to a
list
(
APPEND TILE_EXAPMLE_FLATMM_UK_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=4
)
# rta
# list(APPEND TILE_EXAPMLE_FLATMM_UK_COMPILE_OPTIONS -mllvm -greedy-reverse-local-assignment=1)
# list(APPEND TILE_EXAPMLE_FLATMM_UK_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
target_compile_options
(
${
TILE_EXAPMLE_FLATMM_UK
}
PRIVATE
${
TILE_EXAPMLE_FLATMM_UK_COMPILE_OPTIONS
}
)
example/ck_tile/18_flatmm_uk/flatmm_uk.hpp
0 → 100644
View file @
fa335f31
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/flatmm_uk.hpp"
#include <string>
// this is only a convenient structure for creating an example
// this is not part of the host API
template
<
typename
I
,
typename
W
,
typename
O
,
typename
ST
,
typename
SW
,
typename
SQ
,
typename
KW
>
struct
FlatmmUkTypeConfig
;
template
<
typename
ST
,
typename
SW
,
typename
SQ
,
typename
KW
>
struct
FlatmmUkTypeConfig
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ST
,
SW
,
SQ
,
KW
>
{
using
ADataType
=
ck_tile
::
bf16_t
;
using
GDataType
=
ck_tile
::
bf16_t
;
using
DDataType
=
ck_tile
::
bf16_t
;
using
AccDataType
=
float
;
using
ODataType
=
ck_tile
::
bf16_t
;
using
AScaleDataType
=
ck_tile
::
remove_cvref_t
<
ST
>
;
using
GScaleDataType
=
ck_tile
::
remove_cvref_t
<
SW
>
;
using
DScaleDataType
=
ck_tile
::
remove_cvref_t
<
SW
>
;
using
YSmoothScaleDataType
=
ck_tile
::
remove_cvref_t
<
SQ
>
;
using
TopkWeightDataType
=
ck_tile
::
remove_cvref_t
<
KW
>
;
using
IndexDataType
=
ck_tile
::
index_t
;
};
template
<
typename
ST
,
typename
SW
,
typename
SQ
,
typename
KW
>
struct
FlatmmUkTypeConfig
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
ST
,
SW
,
SQ
,
KW
>
{
using
ADataType
=
ck_tile
::
fp16_t
;
using
GDataType
=
ck_tile
::
fp16_t
;
using
DDataType
=
ck_tile
::
fp16_t
;
using
AccDataType
=
float
;
using
ODataType
=
ck_tile
::
fp16_t
;
using
AScaleDataType
=
ck_tile
::
remove_cvref_t
<
ST
>
;
using
GScaleDataType
=
ck_tile
::
remove_cvref_t
<
SW
>
;
using
DScaleDataType
=
ck_tile
::
remove_cvref_t
<
SW
>
;
using
YSmoothScaleDataType
=
ck_tile
::
remove_cvref_t
<
SQ
>
;
using
TopkWeightDataType
=
ck_tile
::
remove_cvref_t
<
KW
>
;
using
IndexDataType
=
ck_tile
::
index_t
;
};
template
<
typename
ST
,
typename
SW
,
typename
SQ
,
typename
KW
>
struct
FlatmmUkTypeConfig
<
ck_tile
::
int8_t
,
ck_tile
::
int8_t
,
ck_tile
::
bf16_t
,
ST
,
SW
,
SQ
,
KW
>
{
using
ADataType
=
ck_tile
::
int8_t
;
using
GDataType
=
ck_tile
::
int8_t
;
using
DDataType
=
ck_tile
::
int8_t
;
using
AccDataType
=
int32_t
;
using
ODataType
=
ck_tile
::
bf16_t
;
using
AScaleDataType
=
ck_tile
::
remove_cvref_t
<
ST
>
;
using
GScaleDataType
=
ck_tile
::
remove_cvref_t
<
SW
>
;
using
DScaleDataType
=
ck_tile
::
remove_cvref_t
<
SW
>
;
using
YSmoothScaleDataType
=
ck_tile
::
remove_cvref_t
<
SQ
>
;
using
TopkWeightDataType
=
ck_tile
::
remove_cvref_t
<
KW
>
;
using
IndexDataType
=
ck_tile
::
index_t
;
};
struct
flatmm_uk_args
{
const
void
*
a_ptr
;
// [m, k], input token
const
void
*
b_ptr
;
// [m, k], input token
const
void
*
c_ptr
;
// [m, k], output token (no need to do zeroing)
void
*
d_ptr
;
// [m, k], output token (no need to do zeroing)
void
*
dbg_int_ptr
;
// [m, k], output token (no need to do zeroing)
void
*
dbg_bf16_ptr
;
// [m, k], output token (no need to do zeroing)
void
*
dbg_fp32_ptr
;
// [m, k], output token (no need to do zeroing)
ck_tile
::
index_t
block_m
;
// block_m, used to devide the input
ck_tile
::
index_t
hidden_size
;
// k
ck_tile
::
index_t
intermediate_size
;
// n / TP, for Gate. if Gate+Up, Down need divide by 2
ck_tile
::
index_t
num_tokens
;
// input number of tokens for current iteration
ck_tile
::
index_t
num_experts
;
// number of groups
ck_tile
::
index_t
topk
;
// need this?
ck_tile
::
index_t
stride_token
;
// for input/output, stride for each row, should >= hidden_size
};
// This is the public API, will be generated by script
struct
flatmm_uk_traits
{
std
::
string
prec_i
;
// input precision
std
::
string
prec_w
;
// weight precision
std
::
string
prec_o
;
// output precision
std
::
string
prec_st
;
// token scale data type
std
::
string
prec_sw
;
// weight scale data type
std
::
string
prec_sq
;
// smooth quant scale
std
::
string
prec_kw
;
// topk-weight data type
int
block_m
;
int
gate_only
;
int
fused_quant
;
// 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
};
float
flatmm_uk
(
flatmm_uk_traits
,
flatmm_uk_args
,
const
ck_tile
::
stream_config
&
);
example/ck_tile/18_flatmm_uk/instances/flatmm_uk_api.cpp
0 → 100644
View file @
fa335f31
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "flatmm_uk.hpp"
#include "flatmm_uk_api.hpp"
#include "ck_tile/ops/flatmm_uk.hpp"
#include <iostream>
template
<
ck_tile
::
index_t
...
Is
>
using
S
=
ck_tile
::
sequence
<
Is
...
>
;
// do not the define of this tepmlate function inside the _api.cpp, otherwise will block make -j
template
<
typename
Ts_
>
float
flatmm_uk_
(
const
ck_tile
::
stream_config
&
s_
,
flatmm_uk_args_
a_
)
{
printf
(
"[FF] ======= fused_moegemm_() =======
\n
\t
get moe arg in a_ <flatmm_uk_args>, get "
"config in Ts_
\n
"
);
using
f_traits
=
ck_tile
::
FusedMoeGemmTraits
<
Ts_
::
GateOnly
,
Ts_
::
FusedQuant
==
1
,
1
/*atomic*/
>
;
using
f_shape
=
ck_tile
::
FusedMoeGemmShape
<
typename
Ts_
::
BlockTile_0
,
typename
Ts_
::
WarpPerBlock_0
,
typename
Ts_
::
WarpTile_0
,
typename
Ts_
::
BlockTile_1
,
typename
Ts_
::
WarpPerBlock_0
,
typename
Ts_
::
WarpTile_0
>
;
printf
(
"[FF] --- fused_moegemm_(): <FusedMoeGemmShape> ---
\n
"
);
printf
(
"[FF] f_shape::BlockSize = %d
\n
"
,
static_cast
<
uint32_t
>
(
f_shape
::
BlockSize
));
printf
(
"[FF] f_shape::NumWarps = %d
\n
"
,
static_cast
<
uint32_t
>
(
f_shape
::
NumWarps
));
printf
(
"[FF] ---------
\n
"
);
printf
(
"[FF] f_shape::Block_M0 = %d
\n
"
,
static_cast
<
uint32_t
>
(
f_shape
::
Block_M0
));
printf
(
"[FF] f_shape::Block_N0 = %d
\n
"
,
static_cast
<
uint32_t
>
(
f_shape
::
Block_N0
));
printf
(
"[FF] f_shape::Block_K0 = %d
\n
"
,
static_cast
<
uint32_t
>
(
f_shape
::
Block_K0
));
printf
(
"[FF] f_shape::WarpPerBlock_M0 = %d
\n
"
,
static_cast
<
uint32_t
>
(
f_shape
::
WarpPerBlock_M0
));
printf
(
"[FF] f_shape::WarpPerBlock_N0 = %d
\n
"
,
static_cast
<
uint32_t
>
(
f_shape
::
WarpPerBlock_N0
));
printf
(
"[FF] f_shape::WarpPerBlock_K0 = %d
\n
"
,
static_cast
<
uint32_t
>
(
f_shape
::
WarpPerBlock_K0
));
printf
(
"[FF] f_shape::Warp_M0 = %d
\n
"
,
static_cast
<
uint32_t
>
(
f_shape
::
Warp_M0
));
printf
(
"[FF] f_shape::Warp_N0 = %d
\n
"
,
static_cast
<
uint32_t
>
(
f_shape
::
Warp_N0
));
printf
(
"[FF] f_shape::Warp_K0 = %d
\n
"
,
static_cast
<
uint32_t
>
(
f_shape
::
Warp_K0
));
printf
(
"[FF] f_shape::ThreadPerBlock_M0 = %d
\n
"
,
static_cast
<
uint32_t
>
(
f_shape
::
ThreadPerBlock_M0
));
printf
(
"[FF] f_shape::ThreadPerBlock_N0 = %d
\n
"
,
static_cast
<
uint32_t
>
(
f_shape
::
ThreadPerBlock_N0
));
printf
(
"[FF] f_shape::ThreadPerBlock_K0 = %d
\n
"
,
static_cast
<
uint32_t
>
(
f_shape
::
ThreadPerBlock_K0
));
printf
(
"[FF] f_shape::Repeat_M0 = %d
\n
"
,
static_cast
<
uint32_t
>
(
f_shape
::
Repeat_M0
));
printf
(
"[FF] f_shape::Repeat_N0 = %d
\n
"
,
static_cast
<
uint32_t
>
(
f_shape
::
Repeat_N0
));
printf
(
"[FF] f_shape::Repeat_K0 = %d
\n
"
,
static_cast
<
uint32_t
>
(
f_shape
::
Repeat_K0
));
printf
(
"[FF] f_shape::Block_W0 = %d
\n
"
,
static_cast
<
uint32_t
>
(
f_shape
::
Block_W0
));
printf
(
"[FF] f_shape::Block_Nr0 = %d
\n
"
,
static_cast
<
uint32_t
>
(
f_shape
::
Block_Nr0
));
printf
(
"[FF] f_shape::Block_Kr0 = %d
\n
"
,
static_cast
<
uint32_t
>
(
f_shape
::
Block_Kr0
));
printf
(
"[FF] ---------
\n
"
);
printf
(
"[FF] f_shape::Block_M1 = %d
\n
"
,
static_cast
<
uint32_t
>
(
f_shape
::
Block_M1
));
printf
(
"[FF] f_shape::Block_N1 = %d
\n
"
,
static_cast
<
uint32_t
>
(
f_shape
::
Block_N1
));
printf
(
"[FF] f_shape::Block_K1 = %d
\n
"
,
static_cast
<
uint32_t
>
(
f_shape
::
Block_K1
));
printf
(
"[FF] f_shape::WarpPerBlock_M1 = %d
\n
"
,
static_cast
<
uint32_t
>
(
f_shape
::
WarpPerBlock_M1
));
printf
(
"[FF] f_shape::WarpPerBlock_N1 = %d
\n
"
,
static_cast
<
uint32_t
>
(
f_shape
::
WarpPerBlock_N1
));
printf
(
"[FF] f_shape::WarpPerBlock_K1 = %d
\n
"
,
static_cast
<
uint32_t
>
(
f_shape
::
WarpPerBlock_K1
));
printf
(
"[FF] f_shape::Warp_M1 = %d
\n
"
,
static_cast
<
uint32_t
>
(
f_shape
::
Warp_M1
));
printf
(
"[FF] f_shape::Warp_N1 = %d
\n
"
,
static_cast
<
uint32_t
>
(
f_shape
::
Warp_N1
));
printf
(
"[FF] f_shape::Warp_K1 = %d
\n
"
,
static_cast
<
uint32_t
>
(
f_shape
::
Warp_K1
));
printf
(
"[FF] f_shape::ThreadPerBlock_M1 = %d
\n
"
,
static_cast
<
uint32_t
>
(
f_shape
::
ThreadPerBlock_M1
));
printf
(
"[FF] f_shape::ThreadPerBlock_N1 = %d
\n
"
,
static_cast
<
uint32_t
>
(
f_shape
::
ThreadPerBlock_N1
));
printf
(
"[FF] f_shape::ThreadPerBlock_K1 = %d
\n
"
,
static_cast
<
uint32_t
>
(
f_shape
::
ThreadPerBlock_K1
));
printf
(
"[FF] f_shape::Repeat_M1 = %d
\n
"
,
static_cast
<
uint32_t
>
(
f_shape
::
Repeat_M1
));
printf
(
"[FF] f_shape::Repeat_N1 = %d
\n
"
,
static_cast
<
uint32_t
>
(
f_shape
::
Repeat_N1
));
printf
(
"[FF] f_shape::Repeat_K1 = %d
\n
"
,
static_cast
<
uint32_t
>
(
f_shape
::
Repeat_K1
));
printf
(
"[FF] f_shape::Block_W1 = %d
\n
"
,
static_cast
<
uint32_t
>
(
f_shape
::
Block_W1
));
printf
(
"[FF] f_shape::Block_Nr1 = %d
\n
"
,
static_cast
<
uint32_t
>
(
f_shape
::
Block_Nr1
));
printf
(
"[FF] f_shape::Block_Kr1 = %d
\n
"
,
static_cast
<
uint32_t
>
(
f_shape
::
Block_Kr1
));
using
f_problem
=
ck_tile
::
FusedMoeGemmPipelineProblem
<
typename
Ts_
::
ADataType
,
typename
Ts_
::
GDataType
,
typename
Ts_
::
DDataType
,
typename
Ts_
::
AccDataType
,
typename
Ts_
::
ODataType
,
typename
Ts_
::
AScaleDataType
,
typename
Ts_
::
GScaleDataType
,
typename
Ts_
::
DScaleDataType
,
typename
Ts_
::
YSmoothScaleDataType
,
typename
Ts_
::
TopkWeightDataType
,
typename
Ts_
::
IndexDataType
,
ck_tile
::
element_wise
::
FastGeluAsm
,
// TODO: hardcoded
f_shape
,
f_traits
>
;
// using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx<f_problem>;
using
f_pipeline
=
ck_tile
::
GemmPipeline_FlatmmUk
<
f_problem
>
;
using
f_kernel
=
ck_tile
::
FlatmmUkKernel
<
f_pipeline
,
void
>
;
const
dim3
grids
=
f_kernel
::
GridSize
(
a_
);
constexpr
dim3
blocks
=
f_kernel
::
BlockSize
();
constexpr
ck_tile
::
index_t
kBlockPerCu
=
1
;
printf
(
"[FF] grids = [%d, %d, %d]
\n
"
,
grids
.
x
,
grids
.
y
,
grids
.
z
);
printf
(
"[FF] blocks = [%d, %d, %d]
\n
"
,
blocks
.
x
,
blocks
.
y
,
blocks
.
z
);
static
int
printed
=
0
;
auto
kargs
=
f_kernel
::
MakeKargs
(
a_
);
f_kernel
kernel
{};
auto
lambda_kenrel
=
ck_tile
::
make_kernel
<
blocks
.
x
,
kBlockPerCu
>
(
kernel
,
grids
,
blocks
,
0
,
kargs
);
if
(
s_
.
log_level_
>
0
&&
printed
==
10
)
{
// std::cout << ", " << f_kernel::GetName() << std::flush;
printed
=
1
;
}
return
ck_tile
::
launch_kernel
(
s_
,
lambda_kenrel
// ck_tile::make_kernel<blocks.x, kBlockPerCu>(f_kernel{}, grids, blocks, 0, kargs)
);
}
float
flatmm_uk
(
flatmm_uk_traits
t
,
flatmm_uk_args
a
,
const
ck_tile
::
stream_config
&
s
)
{
// auto s_ = ck_tile::stream_config{s.stream_id_, false, s.log_level_, 0, 1};
auto
s_
=
s
;
auto
t_
=
flatmm_uk_traits_
{
t
.
prec_i
,
t
.
prec_w
,
t
.
prec_o
,
t
.
prec_st
,
t
.
prec_sw
,
t
.
prec_sq
,
t
.
prec_kw
,
t
.
block_m
,
t
.
gate_only
,
t
.
fused_quant
};
auto
a_
=
flatmm_uk_args_
{
a
.
a_ptr
,
// const void* a_ptr;
a
.
b_ptr
,
// const void* a_ptr;
a
.
c_ptr
,
// void* o_ptr;
a
.
d_ptr
,
// void* o_ptr;
a
.
dbg_int_ptr
,
a
.
dbg_bf16_ptr
,
a
.
dbg_fp32_ptr
,
a
.
hidden_size
,
// index_t hidden_size;
a
.
intermediate_size
,
// index_t intermediate_size;
a
.
num_tokens
,
// index_t num_tokens;
a
.
num_experts
,
// index_t num_experts;
a
.
topk
,
// index_t topk;
a
.
stride_token
// index_t stride_token;
};
float
r
=
-
1
;
if
(
t_
.
prec_i
==
"bf16"
&&
t_
.
prec_w
==
"bf16"
&&
t_
.
prec_o
==
"bf16"
&&
t_
.
prec_st
==
"fp32"
&&
t_
.
prec_sw
==
"fp32"
&&
t_
.
prec_sq
==
"fp32"
&&
t_
.
prec_kw
==
"fp32"
&&
t_
.
block_m
==
32
&&
t_
.
gate_only
==
1
)
{
using
t_
=
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>
,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
1
,
0
>
;
r
=
flatmm_uk_
<
t_
>
(
s_
,
a_
);
}
else
if
(
t_
.
prec_i
==
"fp16"
&&
t_
.
prec_w
==
"fp16"
&&
t_
.
prec_o
==
"fp16"
&&
t_
.
prec_st
==
"fp32"
&&
t_
.
prec_sw
==
"fp32"
&&
t_
.
prec_sq
==
"fp32"
&&
t_
.
prec_kw
==
"fp32"
&&
t_
.
block_m
==
32
&&
t_
.
gate_only
==
1
)
{
using
t_
=
fmoe_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>
,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
1
,
0
>
;
r
=
flatmm_uk_
<
t_
>
(
s_
,
a_
);
}
// keep unsupported case return negative
if
(
r
<
0
)
return
-
1
;
return
r
;
}
example/ck_tile/18_flatmm_uk/instances/flatmm_uk_api.hpp
0 → 100644
View file @
fa335f31
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/flatmm_uk.hpp"
#include <string>
// runtime args
struct
flatmm_uk_args_
:
public
ck_tile
::
FlatmmUkHostArgs
{
};
// This is the public API, will be generated by script
struct
flatmm_uk_traits_
{
std
::
string
prec_i
;
// input precision
std
::
string
prec_w
;
// weight precision
std
::
string
prec_o
;
// output precision
std
::
string
prec_st
;
// token scale data type
std
::
string
prec_sw
;
// weight scale data type
std
::
string
prec_sq
;
// smooth quant scale
std
::
string
prec_kw
;
// topk-weight data type
int
block_m
;
int
gate_only
;
int
fused_quant
;
// 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
};
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
template
<
typename
I
,
typename
W
,
typename
O
,
typename
ST
,
typename
SW
,
typename
SQ
,
typename
KW
,
typename
BlockTIle_
,
// seq<b_token, b_interm, b_hidden, b_down>
typename
WarpPerBlock_
,
typename
WarpTile_
,
// seq<*,*,*>, used to select mfma
ck_tile
::
index_t
GateOnly_
=
0
,
ck_tile
::
index_t
FusedQuant_
=
0
>
struct
fmoe_
// traits, ugly name, only used for internal
{
using
TypeConfig
=
FlatmmUkTypeConfig
<
I
,
W
,
O
,
ST
,
SW
,
SQ
,
KW
>
;
using
ADataType
=
ck_tile
::
remove_cvref_t
<
typename
TypeConfig
::
ADataType
>
;
using
GDataType
=
ck_tile
::
remove_cvref_t
<
typename
TypeConfig
::
GDataType
>
;
using
DDataType
=
ck_tile
::
remove_cvref_t
<
typename
TypeConfig
::
DDataType
>
;
using
AccDataType
=
ck_tile
::
remove_cvref_t
<
typename
TypeConfig
::
AccDataType
>
;
using
ODataType
=
ck_tile
::
remove_cvref_t
<
typename
TypeConfig
::
ODataType
>
;
using
AScaleDataType
=
ck_tile
::
remove_cvref_t
<
typename
TypeConfig
::
AScaleDataType
>
;
using
GScaleDataType
=
ck_tile
::
remove_cvref_t
<
typename
TypeConfig
::
GScaleDataType
>
;
using
DScaleDataType
=
ck_tile
::
remove_cvref_t
<
typename
TypeConfig
::
DScaleDataType
>
;
using
YSmoothScaleDataType
=
ck_tile
::
remove_cvref_t
<
typename
TypeConfig
::
YSmoothScaleDataType
>
;
using
TopkWeightDataType
=
ck_tile
::
remove_cvref_t
<
typename
TypeConfig
::
TopkWeightDataType
>
;
using
IndexDataType
=
ck_tile
::
remove_cvref_t
<
typename
TypeConfig
::
IndexDataType
>
;
static
constexpr
ck_tile
::
index_t
BT_
=
BlockTIle_
::
at
(
ck_tile
::
number
<
0
>
{});
// block token
static
constexpr
ck_tile
::
index_t
BI_
=
BlockTIle_
::
at
(
ck_tile
::
number
<
1
>
{});
// block intermediate
static
constexpr
ck_tile
::
index_t
BH_
=
BlockTIle_
::
at
(
ck_tile
::
number
<
2
>
{});
// block hidden
static
constexpr
ck_tile
::
index_t
BD_
=
BlockTIle_
::
at
(
ck_tile
::
number
<
3
>
{});
// block down
using
BlockTile_0
=
ck_tile
::
sequence
<
BT_
,
BI_
,
BH_
>
;
using
WarpPerBlock_0
=
ck_tile
::
remove_cvref_t
<
WarpPerBlock_
>
;
using
WarpTile_0
=
ck_tile
::
remove_cvref_t
<
WarpTile_
>
;
using
BlockTile_1
=
ck_tile
::
sequence
<
BT_
,
BD_
,
BI_
/
(
GateOnly_
?
1
:
2
)
>
;
using
WarpPerBlock_1
=
ck_tile
::
remove_cvref_t
<
WarpPerBlock_
>
;
using
WarpTile_1
=
ck_tile
::
remove_cvref_t
<
WarpTile_
>
;
static
constexpr
ck_tile
::
index_t
GateOnly
=
GateOnly_
;
static
constexpr
ck_tile
::
index_t
FusedQuant
=
FusedQuant_
;
};
example/ck_tile/18_flatmm_uk/main.cpp
0 → 100644
View file @
fa335f31
This diff is collapsed.
Click to expand it.
example/ck_tile/CMakeLists.txt
View file @
fa335f31
...
...
@@ -17,3 +17,4 @@ add_subdirectory(14_moe_smoothquant)
add_subdirectory
(
15_fused_moe
)
add_subdirectory
(
16_batched_gemm
)
add_subdirectory
(
17_grouped_gemm
)
add_subdirectory
(
18_flatmm_uk
)
include/ck_tile/ops/flatmm/block/flatmm_ff_32x512x128_1x4x1_16x16x32.hpp
0 → 100644
View file @
fa335f31
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/flatmm/block/uk/flatmm_ff_uk_gfx9_32x512x128_1x1x1_16x16x16.inc
0 → 100644
View file @
fa335f31
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/flatmm_uk.hpp
0 → 100644
View file @
fa335f31
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_ff_32x512x128_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/fused_moe/kernel/flatmm_uk_kernel.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp"
#include "ck_tile/ops/fused_moe/pipeline/flatmm_uk_pipeline_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/flatmm_uk_pipeline.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/fused_moe/kernel/flatmm_uk_kernel.hpp
0 → 100644
View file @
fa335f31
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include <string>
#include <type_traits>
// clang-format off
// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
// each slice map to one expert, and one expert can have multiple slices
// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// tok-0 tok-1 tok-2 tok-3 tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 5]]
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
//
// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
//
// * Note on token_id_per_expert/sorted_token_ids_ptr data:
// currently we do not have topk information from the data of token_id_per_expert/sorted_token_ids_ptr.
// In some cases(like smooth-quant), we need topk information to indexing into tokens quant from
// different expert smooth quant. So we modify the number stored inside token_id_per_expert/sorted_token_ids_ptr
//
// 32bit 0........23 24.....31 bit
// (data) -> (token_id | topk_id)
// low 24 bit is for token id, top 8 bit is for topk id
//
// the input after smooth-quant is [token, topk, hidden_dim], originally it is [token, hidden_dim]
// the input scale for token is [topk, token, 1], the smooth-quant scale for first gemm is [expert, interm_dim]
//
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
// * length is (max_num_tokens_padded + block_size - 1) / block_size
//
// num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr : [7]
//
// * different from vLLM
// 1) token_id stored in sorted_token_ids_ptr is actual token_id, not token_id*top_K expanded id
// 2)need sorted_weight_ptr
// 3) use num_sorted_tiles_ptr, already divided by M_a
//
// * below used for indexing
// 1) sorted_token_ids_ptr [max_num_tokens_padded]
// 2) sorted_weight_ptr
// 3) sorted_expert_ids_ptr
// 4)num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
//
// max_num_tokens_padded: opk_ids.numel() + num_experts * (block_size - 1)
//
// [indexing implementation-2]
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// tok-0 tok-1 tok-2 tok-3 tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
//
// we generate original rol/col id as
// topk_rc_ids : [[0, 5, A], [1, 6, B], [2, 7, C], [3, 8, D], [4, 9, E]]
// let x be one element of above, we can get:
// tpok_row_id(token_id) = x % num_tokens(5)
// tpok_col_id(expert_Id) = x / num_tokens
// topk_row_id/col_id can be used to access original topk_ids/topk_weight
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 5, 5]]
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// we can get permuted_rc_ids:
// [[0], [2, 3, 4], [1, 8], [5, 6, 7, D, 9], [], [A, B, C, E]]
//
//
// clang-format on
//
namespace
ck_tile
{
// m: num_tokens (or token*input-batch)
// k: intermediate_size
// n: intermediate_size used between 2 FC (TP slice this)
// e: num expert
// if doing pre-shuffle
// nr : n / Block_Nr
// kr : k / Block_Kr
// w : fattened 1d wave buffer
struct
FlatmmUkHostArgs
{
const
void
*
a_ptr
;
// [m, k], input token
const
void
*
b_ptr
;
// [m, k], input token
const
void
*
c_ptr
;
// [m, k], output token
void
*
d_ptr
;
// [m, k], output token
void
*
dbg_int_ptr
;
// [m, k], output token
void
*
dbg_bf16_ptr
;
// [m, k], output token
void
*
dbg_fp32_ptr
;
// [m, k], output token
index_t
hidden_size
;
// K
index_t
intermediate_size
;
// N
index_t
num_tokens
;
// M
index_t
num_experts
;
// number of groups
index_t
topk
;
// need this?
index_t
stride_token
;
// for input/output, stride for each row, should >= hidden_size
};
// This is scatter/gather b2b group-gemm
template
<
typename
Pipeline_
,
typename
Epilogue_
>
struct
FlatmmUkKernel
{
using
Pipeline
=
remove_cvref_t
<
Pipeline_
>
;
using
Epilogue
=
remove_cvref_t
<
Epilogue_
>
;
// TODO: not used
// static constexpr index_t kBlockPerCu = Pipeline::kBlockPerCu;
// static_assert(kBlockPerCu > 0);
using
BlockShape
=
typename
Pipeline
::
BlockShape
;
// this is FusedMoeGemmShape
static
constexpr
index_t
BlockSize_
=
BlockShape
::
BlockSize
;
using
ADataType
=
typename
Pipeline
::
Problem
::
ADataType
;
using
GDataType
=
typename
Pipeline
::
Problem
::
GDataType
;
using
DDataType
=
typename
Pipeline
::
Problem
::
AccDataType
;
using
AccDataType
=
typename
Pipeline
::
Problem
::
AccDataType
;
using
ODataType
=
typename
Pipeline
::
Problem
::
ODataType
;
using
AScaleDataType
=
typename
Pipeline
::
Problem
::
AScaleDataType
;
using
GScaleDataType
=
typename
Pipeline
::
Problem
::
GScaleDataType
;
using
DScaleDataType
=
typename
Pipeline
::
Problem
::
DScaleDataType
;
using
YSmoothScaleDataType
=
typename
Pipeline
::
Problem
::
YSmoothScaleDataType
;
using
TopkWeightDataType
=
typename
Pipeline
::
Problem
::
TopkWeightDataType
;
using
IndexDataType
=
typename
Pipeline
::
Problem
::
IndexDataType
;
using
YDataType
=
typename
Pipeline
::
Problem
::
YDataType
;
using
Traits
=
typename
Pipeline
::
Problem
::
Traits
;
static
constexpr
bool
UseUK
=
true
;
static
constexpr
bool
IsGateOnly
=
Traits
::
IsGateOnly
;
static
constexpr
bool
UseSmoothQuant
=
Traits
::
UseSmoothQuant
;
static
constexpr
bool
PadHiddenSize
=
Traits
::
PadHiddenSize
;
static
constexpr
bool
PadIntermediateSize
=
Traits
::
PadIntermediateSize
;
// clang-format off
template
<
typename
T
>
struct
t2s
;
template
<
>
struct
t2s
<
float
>
{
static
constexpr
const
char
*
name
=
"fp32"
;
};
template
<
>
struct
t2s
<
fp16_t
>
{
static
constexpr
const
char
*
name
=
"fp16"
;
};
template
<
>
struct
t2s
<
bf16_t
>
{
static
constexpr
const
char
*
name
=
"bf16"
;
};
template
<
>
struct
t2s
<
fp8_t
>
{
static
constexpr
const
char
*
name
=
"fp8"
;
};
template
<
>
struct
t2s
<
bf8_t
>
{
static
constexpr
const
char
*
name
=
"bf8"
;
};
template
<
>
struct
t2s
<
int8_t
>
{
static
constexpr
const
char
*
name
=
"int8"
;
};
// clang-format on
CK_TILE_HOST
static
std
::
string
GetName
()
{
#define _SS_ std::string
#define _TS_ std::to_string
// clang-format off
using
S_
=
BlockShape
;
auto
prec_str
=
[
&
]
()
{
std
::
string
base_str
=
_SS_
(
t2s
<
ADataType
>::
name
);
if
(
!
std
::
is_same_v
<
ADataType
,
GDataType
>
)
{
base_str
+=
_SS_
(
"_"
)
+
_SS_
(
t2s
<
GDataType
>::
name
);
}
return
base_str
;
}();
return
_SS_
(
"fused_moe_"
)
+
_SS_
(
prec_str
)
+
"_"
+
_TS_
(
S_
::
Block_M0
)
+
"x"
+
_TS_
(
S_
::
Block_N0
)
+
"x"
+
_TS_
(
S_
::
Block_K0
)
+
"x"
+
_TS_
(
S_
::
Block_N1
)
+
"_"
+
_TS_
(
S_
::
WarpPerBlock_M0
)
+
"x"
+
_TS_
(
S_
::
WarpPerBlock_N0
)
+
"x"
+
_TS_
(
S_
::
WarpPerBlock_K0
)
+
"_"
+
_TS_
(
S_
::
Warp_M0
)
+
"x"
+
_TS_
(
S_
::
Warp_N0
)
+
"x"
+
_TS_
(
S_
::
Warp_K0
)
+
"_"
+
_SS_
(
Pipeline
::
name
);
#undef _SS_
#undef _TS_
// clang-format on
}
struct
FusedMoeGemmKargs
{
const
void
*
a_ptr
;
// [m, k], input token
const
void
*
b_ptr
;
// [m, k], input token
const
void
*
c_ptr
;
// [m, k], output token
void
*
d_ptr
;
// [m, k], output token
void
*
dbg_int_ptr
;
// [m, k], output token
void
*
dbg_bf16_ptr
;
// [m, k], output token
void
*
dbg_fp32_ptr
;
// [m, k], output token
index_t
hidden_size
;
// K
index_t
intermediate_size
;
// N
index_t
num_tokens
;
// M
index_t
num_experts
;
// number of groups
index_t
topk
;
// need this?
index_t
stride_token
;
// for input/output, stride for each row, should >= hidden_size
};
// TODO: switch karg based on
using
Kargs
=
FusedMoeGemmKargs
;
using
Hargs
=
FlatmmUkHostArgs
;
CK_TILE_HOST
static
constexpr
Kargs
MakeKargs
(
const
Hargs
&
hargs
)
{
// TODO: hargs/kargs not guranteed to be the same
return
bit_cast
<
Kargs
>
(
hargs
);
}
CK_TILE_HOST
static
constexpr
auto
GridSize
(
const
Hargs
&
hargs
)
{
index_t
ms
=
ck_tile
::
integer_divide_ceil
(
hargs
.
num_tokens
,
BlockShape
::
Block_M0
);
index_t
ns
=
ck_tile
::
integer_divide_ceil
(
hargs
.
intermediate_size
,
BlockShape
::
Block_N0
);
return
dim3
(
ns
,
ms
,
1
);
}
CK_TILE_HOST
static
constexpr
auto
BlockSize
()
{
return
dim3
(
BlockSize_
);
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
Pipeline
::
GetSmemSize
();
}
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
{
#if 0
if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
{
printf("[KERNEL] FlatmmUkKernel =====\n");
printf("[KERNEL] blockDim: [%d, %d], gridDim: [%d, %d]\n",
static_cast<int>(blockDim.x),
static_cast<int>(blockDim.y),
static_cast<int>(gridDim.x),
static_cast<int>(gridDim.y));
printf("[KERNEL] lds = %.3f (KB)\n", GetSmemSize() / 1024.0f);
}
[[maybe_unused]] uint32_t tidx = threadIdx.x; // 0~255
[[maybe_unused]] uint32_t tidy = threadIdx.y; // 0~0
[[maybe_unused]] uint32_t bidx = blockIdx.x; // 0~1
[[maybe_unused]] uint32_t bidy = blockIdx.y; // 0~51
[[maybe_unused]] uint32_t bdmx = blockDim.x; // 256
[[maybe_unused]] uint32_t bdmy = blockDim.y; // 1
[[maybe_unused]] uint32_t gdmx = gridDim.x; // 2
[[maybe_unused]] uint32_t gdmy = gridDim.y; // 52
[[maybe_unused]] uint32_t gid = ((bdmx * bdmy) * gdmx) * bidy
+ (bdmx * bdmy) * bidx
+ bdmx * tidy
+ tidx;
[[maybe_unused]]int * dbg_int = static_cast<int*>(kargs.dbg_int_ptr);
[[maybe_unused]]short * dbg_bf16 = static_cast<short*>(kargs.dbg_bf16_ptr);
[[maybe_unused]]float * dbg_fp32 = static_cast<float*>(kargs.dbg_fp32_ptr);
dbg_int[gid] = -1;
dbg_fp32[gid] = -1.0f;
#endif
__shared__
CK_TILE_LDS_ADDR
ADataType
smem
[
GetSmemSize
()];
Pipeline
{}(
kargs
,
smem
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/fused_moe/pipeline/flatmm_uk_pipeline.hpp
0 → 100644
View file @
fa335f31
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fused_moe/pipeline/flatmm_uk_pipeline_policy.hpp"
namespace
ck_tile
{
/*
This pipeline deal with a gemm(actually 2 gemm) with one very small(token), one very big(weight)
we need to design the pipeline such that all waves along gemm-N dim (gemm-m only 1 wave)
<----- gemm-N ------>
+----+----+----+----+
| w0 | w1 | w2 | w3 | gemm-m
+----+----+----+----+
*/
/// Single-pass flatmm pipeline built around a hand-written micro-kernel
/// (Policy::GetUK_0). It gathers rows of A via fused-MoE style row coords,
/// streams the pre-shuffled ("flat") B weights, runs the uk, then does a
/// hand-rolled LDS shuffle epilogue and stores fp32 results to kargs.d_ptr.
template <typename Problem_, typename Policy_ = GemmPipelineFlatmmPolicy>
struct GemmPipeline_FlatmmUk
{
    using Problem    = remove_cvref_t<Problem_>;
    using Policy     = remove_cvref_t<Policy_>;
    using BlockShape = typename Problem::BlockShape; // this is FusedMoeGemmShape

    using ADataType = typename Problem::ADataType;
    using GDataType = typename Problem::GDataType;
    // NOTE: D intentionally aliases the accumulator type (fp32 output path).
    using DDataType            = typename Problem::AccDataType;
    using AccDataType          = typename Problem::AccDataType;
    using ODataType            = typename Problem::ODataType;
    using AScaleDataType       = typename Problem::AScaleDataType;
    using GScaleDataType       = typename Problem::GScaleDataType;
    using DScaleDataType       = typename Problem::DScaleDataType;
    using YSmoothScaleDataType = typename Problem::YSmoothScaleDataType;
    using TopkWeightDataType   = typename Problem::TopkWeightDataType;
    using IndexDataType        = typename Problem::IndexDataType;
    using YDataType            = typename Problem::YDataType;

    using Traits = typename Problem::Traits;

    static constexpr bool IsGateOnly          = Traits::IsGateOnly;
    static constexpr bool UseSmoothQuant      = Traits::UseSmoothQuant;
    static constexpr bool PadHiddenSize       = Traits::PadHiddenSize;
    static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;

    // Vector widths (elements per thread access) for each operand, chosen by
    // the policy from the problem description.
    static constexpr index_t kAlignmentA = Policy::template GetAlignment_A<Problem>();
    static constexpr index_t kAlignmentG = Policy::template GetAlignment_G<Problem>();
    static constexpr index_t kAlignmentD = Policy::template GetAlignment_D<Problem>();
    static constexpr index_t kAlignmentO = Policy::template GetAlignment_O<Problem>();

    // Integer ids of the pipeline-sequencer stages (shared-load A, global-load
    // A/B, global-store O), usable in sequencing masks.
    static constexpr index_t SLD_A =
        static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::SLD_A);
    static constexpr index_t GLD_A =
        static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_A);
    static constexpr index_t GLD_B =
        static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_B);
    static constexpr index_t GST_O =
        static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GST_O);

    // Blocks per CU: honor an explicit Problem setting, otherwise default to 2.
    static constexpr index_t kBlockPerCu = []() {
        if constexpr(Problem::kBlockPerCu != -1)
            return Problem::kBlockPerCu;
        else
        {
            // minimize occupancy
            return 2;
        }
    }();

    static constexpr const char* name = "flatmm_uk";

    /// LDS requirement: the larger of the micro-kernel's own scratch and the
    /// "bridge" buffer used by the epilogue (Block_M0 x Block_N0 of YDataType).
    /// The bridge term is in bytes; callers treat the result as a byte count.
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
        constexpr index_t smem_0 = Policy::template GetUK_0<Problem>().GetSmemSize();
        constexpr index_t smem_bridge =
            BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType);
        return max(smem_0, smem_bridge);
    }

    // this is the thread-offset along row/col for the A-load distribution
    CK_TILE_HOST_DEVICE static auto GetACoord()
    {
        constexpr auto a_dist = Policy::template MakeGlobalTileDistribution_A<Problem>();
        const auto a_coord    = a_dist.calculate_index();
        return a_coord;
    }

    // this is the thread-offset along row/col for the O-store distribution
    CK_TILE_HOST_DEVICE static auto GetOCoord()
    {
        constexpr auto o_dist = Policy::template MakeOGlobalTileDistribution<Problem>();
        const auto o_coord    = o_dist.calculate_index();
        return o_coord;
    }

    /// Number of A rows each thread covers (M-repeat of the A-load layout:
    /// BlockSize threads laid out as (M-lanes x K-lanes) over Block_M0 rows).
    CK_TILE_DEVICE constexpr auto GetNumRowCoords_A()
    {
        constexpr index_t KLans   = BlockShape::Block_K0 / kAlignmentA;
        constexpr index_t MLans   = BlockShape::BlockSize / KLans;
        constexpr index_t MRepeat = BlockShape::Block_M0 / MLans;

        return MRepeat;
    }

    // TODO: properly support scatter/gather
    /// Per-thread row indices for loading A: this thread's row within the
    /// block (threadIdx.x / K-lanes) plus base_offset, strided by MLans for
    /// each of the MRepeat repeats.
    CK_TILE_DEVICE auto GetRowCoords_A(index_t base_offset)
    {
        constexpr index_t KLans   = BlockShape::Block_K0 / kAlignmentA;
        constexpr index_t MLans   = BlockShape::BlockSize / KLans;
        constexpr index_t MRepeat = BlockShape::Block_M0 / MLans;

        auto base_coord = threadIdx.x / KLans + base_offset;

        array<index_t, MRepeat> coords;
        static_for<0, MRepeat, 1>{}([&](auto i) { coords.at(i) = base_coord + i * MLans; });

        return coords;
    }

    /// Same layout computation as GetRowCoords_A but for the O-store side
    /// (N-lanes derived from Block_N0 / kAlignmentO).
    CK_TILE_DEVICE auto GetRowCoords_O2(index_t base_offset)
    {
        constexpr index_t NLans   = BlockShape::Block_N0 / kAlignmentO;
        constexpr index_t MLans   = BlockShape::BlockSize / NLans;
        constexpr index_t MRepeat = BlockShape::Block_M0 / MLans;

        auto base_coord = threadIdx.x / NLans + base_offset;

        array<index_t, MRepeat> coords;
        static_for<0, MRepeat, 1>{}([&](auto i) { coords.at(i) = base_coord + i * MLans; });

        return coords;
    }

    /// Gather the fused-MoE sorted token ids for each row coordinate
    /// (row_ids[i] = sorted_token_ids_ptr[coords[i]]).
    template <typename ROW_COORDS>
    CK_TILE_DEVICE auto GetRowID(const ROW_COORDS coords,
                                 const IndexDataType* sorted_token_ids_ptr)
    {
        constexpr index_t n_size = coords.size();

        array<index_t, n_size> row_ids;
        static_for<0, n_size, 1>{}(
            [&](auto i) { row_ids.at(i) = sorted_token_ids_ptr[coords[i]]; });

        return row_ids;
    }

    /// Gather the per-row top-k weights (w[i] = sorted_weight_ptr[coords[i]]).
    template <typename ROW_COORDS>
    CK_TILE_DEVICE auto GetWeightScale(const ROW_COORDS coords,
                                       const TopkWeightDataType* sorted_weight_ptr)
    {
        constexpr index_t n_size = coords.size();

        array<TopkWeightDataType, n_size> w;
        static_for<0, n_size, 1>{}([&](auto i) { w.at(i) = sorted_weight_ptr[coords[i]]; });

        return w;
    }

    // TODO: this row id is before shuffle atomic, need use acc distribution
    /// Per-thread row coordinates for O based on the warp-M layout
    /// (threadIdx.x % Warp_M1, repeated Repeat_M1 times with stride Warp_M1).
    CK_TILE_DEVICE auto GetRowCoords_O(index_t base_offset)
    {
        constexpr index_t MLanes   = BlockShape::Warp_M1;
        constexpr index_t Repeat_M = BlockShape::Repeat_M1;

        auto base_coord = threadIdx.x % MLanes + base_offset;

        array<index_t, Repeat_M> coords;
        static_for<0, Repeat_M, 1>{}([&](auto i) { coords.at(i) = base_coord + i * MLanes; });

        return coords;
    }

    /// Main pipeline body.
    /// @param kargs kernel arguments (a_ptr/b_ptr/d_ptr, num_tokens,
    ///              hidden_size, intermediate_size, dbg_* scratch pointers)
    /// @param smem  block LDS scratch (at least GetSmemSize() bytes)
    template <typename Karg>
    CK_TILE_DEVICE auto operator()(const Karg& kargs, CK_TILE_LDS_ADDR void* smem)
    {
#if 0
        if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
        {
            printf("[PIPE] GemmPipeline_FlatmmUk =====\n");
        }
        [[maybe_unused]] uint32_t tidx = threadIdx.x; // 0~255
        [[maybe_unused]] uint32_t tidy = threadIdx.y; // 0~0
        [[maybe_unused]] uint32_t bidx = blockIdx.x; // 0~1
        [[maybe_unused]] uint32_t bidy = blockIdx.y; // 0~51
        [[maybe_unused]] uint32_t bdmx = blockDim.x; // 256
        [[maybe_unused]] uint32_t bdmy = blockDim.y; // 1
        [[maybe_unused]] uint32_t gdmx = gridDim.x; // 2
        [[maybe_unused]] uint32_t gdmy = gridDim.y; // 52
        [[maybe_unused]] uint32_t gid = ((bdmx * bdmy) * gdmx) * bidy
                                      + (bdmx * bdmy) * bidx
                                      + bdmx * tidy
                                      + tidx;
#endif
        // Debug scratch buffers forwarded to the micro-kernel (may be unused).
        [[maybe_unused]] int* dbg_int    = static_cast<int*>(kargs.dbg_int_ptr);
        [[maybe_unused]] short* dbg_bf16 = static_cast<short*>(kargs.dbg_bf16_ptr);
        [[maybe_unused]] float* dbg_fp32 = static_cast<float*>(kargs.dbg_fp32_ptr);

        ck_tile::index_t shared_intermediate_size_0 = kargs.intermediate_size; // N
        index_t nr_0 = shared_intermediate_size_0 / BlockShape::Warp_N0;       // divide N in W
        index_t kr_0 = kargs.hidden_size / BlockShape::Warp_K0;                // divide K in W
        // Make the per-block N-tile base uniform across the wave (SGPR).
        index_t interm_idx_nr0 =
            __builtin_amdgcn_readfirstlane(blockIdx.x * BlockShape::Block_Nr0);
        // intermediate_tile_id * Block_N / (N in W)

        // ----------------------------------------------------------------------------
        // a: raw buffer resource over the token activations plus per-thread
        // scatter offsets (row id * hidden_size + this thread's K chunk).
        auto a_res = make_wave_buffer_resource(
            reinterpret_cast<const ADataType*>(kargs.a_ptr),
            kargs.num_tokens * kargs.hidden_size * sizeof(ADataType));
        auto row_ids_a = GetRowCoords_A(blockIdx.y * BlockShape::Block_M0);
        auto a_coords  = generate_tuple(
            [&](auto i) {
                return row_ids_a[i] * kargs.hidden_size +
                       threadIdx.x % (BlockShape::Block_K0 / kAlignmentA) * kAlignmentA;
            },
            number<row_ids_a.size()>{});

        // ----------------------------------------------------------------------------
        // b: linear tile window over the flat (nr, kr, w) weight layout; only
        // its cached buffer resource and non-linear access offsets are used.
        auto b_win = [&]() {
            const GDataType* b_ptr = reinterpret_cast<const GDataType*>(kargs.b_ptr) +
                                     interm_idx_nr0 * kr_0 * BlockShape::Block_W0;
            auto b_view_ = make_naive_tensor_view<address_space_enum::global>(
                b_ptr,
                make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}),
                make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
                number<kAlignmentG>{},
                number<1>{});
            auto b_window_ = make_tile_window_linear_raw(
                b_view_,
                make_tuple(number<BlockShape::Block_Nr0>{},
                           number<BlockShape::Block_Kr0>{},
                           number<BlockShape::Block_W0>{}),
                {0, 0, 0},
                Policy::template MakeGlobalTileDistribution_G<Problem>(),
                sequence<0, 1, 1>{});
            return b_window_;
        }();
        auto b_res = b_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
        auto b_coords =
            generate_tuple([&](auto i) { return b_win.cached_coords_[i].get_offset(); },
                           number<decltype(b_win)::NumAccess_NonLinear>{});

        // ----------------------------------------------------------------------------
        // core: run the micro-kernel; acc_0 holds this thread's fp32 partials.
        auto uk_0  = Policy::template GetUK_0<Problem>();
        auto acc_0 = uk_0(a_res,
                          a_coords,
                          b_res,
                          b_coords,
                          smem,
                          kargs.hidden_size,
                          BlockShape::Block_K0, // tile offset for B matrix each unroll
                          BlockShape::Block_Kr0 *
                              BlockShape::Block_W0, // tile offset for B matrix each unroll
                          dbg_int,
                          dbg_bf16,
                          dbg_fp32);

        // ----------------------------------------------------------------------------
        // Epilogue: round-trip the accumulators through LDS to re-arrange them
        // into row-major float4 chunks, then store to the global D buffer.
        // NOTE(review): index math below is hand-derived for the current
        // 256-thread / 16x(float4) accumulator shape — verify before reuse.
        {
            int tid        = threadIdx.x;
            float srdfp32  = 0.f;
            float* smemfp32 = static_cast<float*>(smem);

            // ----------------------------------------------------------------------------
            // store to lds: accumulator slice accIdx goes to its own
            // 4*blockDim.x float region, each thread writing one float4.
            for(uint32_t accIdx = 0; accIdx < 16; accIdx++)
            {
                float* accSmem = smemfp32 + 4 * blockDim.x * accIdx;
                for(int xyzw = 0; xyzw < 4; xyzw++)
                {
                    accSmem[tid * 4 + xyzw] = acc_0.get_thread_buffer()[accIdx * 4 + xyzw];
                }
            }
            block_sync_lds();

            // ----------------------------------------------------------------------------
            // read from lds with a transposed indexing so each thread ends up
            // owning contiguous output elements.
            int sldIdx = 0;
            // int MLn = 15;
            // int Nln = tid / MLn;
            int tidInWave = tid % 64;
            int waveId    = tid / 64;
            // sldIdx = (tid64 % 16 * 16 + tid64 / 16) % 64
            //        + tid / 64;
            sldIdx = (tidInWave % 16 * 16 + tidInWave / 16) + waveId * 4;

            const int accNLane   = 16;
            const int NLaneCnt   = BlockShape::Block_N0 / 4; // xyzw 512 / 4 = 128
            const int accBlkSize = blockDim.x;
            int accInnerId  = tid % accNLane;       // 0~15
            int accNIdx     = tid / NLaneCnt;       // 0~127 = 0; 128~255 = 1
            int acc01BlkIdx = tid % NLaneCnt / 16;  // 0 ~ 7
            int accBlkIdx   = acc01BlkIdx * 2;      // 0, 2, 4, ..., 14
            int acc4Id      = accBlkIdx * accBlkSize //
                         + accNIdx * accBlkSize + accInnerId * 16;
            // Overrides the wave-transpose value computed above; the final
            // shuffle uses the acc4Id formulation.
            sldIdx = acc4Id;

            float* d_buf = static_cast<float*>(kargs.d_ptr);
            // Block origin in D, in float4 units (hence the /4 on sizes).
            int c_blk_offset =
                blockIdx.y * BlockShape::Block_M0 * kargs.intermediate_size / 4 +
                blockIdx.x * BlockShape::Block_N0 / 4;
            for(uint32_t accIdx = 0; accIdx < 16; accIdx++)
            {
                // Pull the shuffled float4 back into this thread's buffer.
                for(int xyzw = 0; xyzw < 4; xyzw++)
                {
                    srdfp32 = smemfp32[accIdx * (1 * 4) + sldIdx * 4 + xyzw];
                    acc_0.get_thread_buffer()[accIdx * 4 + xyzw] = srdfp32;
                }
                // ----------------------------------------------------------------------------
                // store to vmem: row = accIdx + accNIdx*16, column = this
                // thread's float4 lane within the N tile.
                int c_m_idx_offset = (accIdx + accNIdx * 16) * kargs.intermediate_size / 4;
                int c_idx_offset   = c_blk_offset + c_m_idx_offset + (tid % NLaneCnt);
                for(int xyzw = 0; xyzw < 4; xyzw++)
                {
                    srdfp32                          = acc_0.get_thread_buffer()[accIdx * 4 + xyzw];
                    d_buf[c_idx_offset * 4 + xyzw] = srdfp32;
                }
            }
        }
#if 0
        // ----------------------------------------------------------------------------
        // debug
        for(uint32_t dbgi = 0; dbgi < 64; dbgi++)
        {
            dbg_fp32[gid * 64 + dbgi] = acc_0.get_thread_buffer()[dbgi];
        }
#endif
    }
};
}
// namespace ck_tile
include/ck_tile/ops/fused_moe/pipeline/flatmm_uk_pipeline_policy.hpp
0 → 100644
View file @
fa335f31
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment