gaoqiong / composable_kernel_ROCM

Commit cdb83933, authored Sep 20, 2024 by shengnxu
"codes backup"
Parent: 41659ab1

Showing 4 changed files with 316 additions and 75 deletions (+316 / -75):
- example/ck_tile/05_moe/CMakeLists.txt  (+1, -75)
- example/ck_tile/05_moe/moe.cpp         (+243, -0)
- example/ck_tile/05_moe/moe.hpp         (+71, -0)
- example/ck_tile/CMakeLists.txt         (+1, -0)
example/ck_tile/05_moe/CMakeLists.txt @ cdb83933
# generate a list of kernels, but not actually emit files at config stage
add_executable(tile_example_moe EXCLUDE_FROM_ALL moe.cpp)
\ No newline at end of file

(removed by this commit: the previous FMHA build rules below, per the -75 deletions in this file)

execute_process(
    COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
    --api fwd,fwd_splitkv --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
)
execute_process(
    COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
    --api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt
)
# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory
# as current cmake list, otherwise will not figure out the dependency properly
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt FMHA_FWD_GEN_BLOBS)
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
add_custom_command(
    OUTPUT ${FMHA_FWD_GEN_BLOBS}
    COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
    --api fwd,fwd_splitkv --output_dir ${CMAKE_CURRENT_BINARY_DIR}
)
add_custom_command(
    OUTPUT ${FMHA_BWD_GEN_BLOBS}
    COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
    --api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}
)
set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd")
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
message("adding example ${EXAMPLE_FMHA_FWD}")
add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL fmha_fwd.cpp)
target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS})
set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd")
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
message("adding example ${EXAMPLE_FMHA_BWD}")
add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL fmha_bwd.cpp)
target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS})
# NOTE: this is dangerous since will change the whole kernel to flush denormals
# WIP with compiler team for an exp2 intrinsic..., then remove this
if(NOT DEFINED FMHA_FWD_FAST_EXP2)
    set(FMHA_FWD_FAST_EXP2 true)
endif()
set(EXAMPLE_FMHA_FWD_COMPILE_OPTIONS)
set(EXAMPLE_FMHA_BWD_COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
# ... because they are auto-generated
if(FMHA_FWD_FAST_EXP2)
    list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
    list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
else()
    list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
    list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
endif()
# Allow comparing floating points directly in order to check sentinel values
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal)
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal)
target_compile_options(${EXAMPLE_FMHA_FWD} PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS})
target_compile_options(${EXAMPLE_FMHA_BWD} PRIVATE ${EXAMPLE_FMHA_BWD_COMPILE_OPTIONS})
# TODO: we have to turn off this global prop, otherwise the progress bar generated
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
# however, this property may affect global
# TODO: consider codegen a makefile by us
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
example/ck_tile/05_moe/moe.cpp @ cdb83933
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe.hpp"
#include "ck_tile/host.hpp"
#include "rotary.hpp"
#include "utils.hpp"
#include "ck_tile/host/reference/reference_permute.hpp"
#include <array>
#include <cstring>
#include <functional>
#include <numeric>
#include <ostream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include <torch/torch.h>
// test args
auto create_args(int argc, char* argv[])
{
    // get command line data to internal params
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("num_tokens", "10", "")
        .insert("num_experts", "8", "")
        .insert("v", "0", "validation")
        .insert("hidden_size", "4096", "")
        .insert("shard_intermediate_size", "4096", "")
        .insert("topk", "2", "\n")
        .insert("dtype", "fp16", "\n")
        .insert("use_fp8_w8a8", "0", "")
        .insert("use_int8_w8a16", "0", "");

    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}
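// Note: run() below also queries "warmup", "repeat", "kname" and "timer", which create_args
// never registers. A minimal sketch of the extra registrations create_args would need; the
// default values and help strings here are illustrative assumptions, not taken from this commit:
//
//     arg_parser.insert("warmup", "5", "number of cold iterations before timing")
//         .insert("repeat", "20", "number of timed iterations")
//         .insert("kname", "0", "print kernel name")
//         .insert("timer", "gpu", "gpu: GPU timer, cpu: CPU timer");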
// run args assertion and tensor allocation
// and slope/scale gen, tensor copy to device
// init traits/feature instant
// init args, from tensor to tensor pointer and stride, real args to kernel
// call moe
// move result to host
// reference gen
// validation
template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
    std::string data_type = arg_parser.get_str("dtype");
    int do_validation     = arg_parser.get_int("v");
    bool use_fp8_w8a8     = arg_parser.get_bool("use_fp8_w8a8");
    bool use_int8_w8a16   = arg_parser.get_bool("use_int8_w8a16");
    // auto mode = static_cast<mode_enum>(arg_parser.get_uint32("mode"));
    ck_tile::index_t num_tokens              = arg_parser.get_int("num_tokens");
    ck_tile::index_t num_experts             = arg_parser.get_int("num_experts");
    ck_tile::index_t hidden_size             = arg_parser.get_int("hidden_size");
    ck_tile::index_t shard_intermediate_size = arg_parser.get_int("shard_intermediate_size");
    ck_tile::index_t topk                    = arg_parser.get_int("topk");

    int stream_warmup = arg_parser.get_int("warmup");
    int stream_repeat = arg_parser.get_int("repeat");
    bool kname        = arg_parser.get_bool("kname");

    ck_tile::stream_config stream_config{nullptr,
                                         true,
                                         /* log_level = */ (kname ? 1 : 0),
                                         stream_warmup,
                                         stream_repeat,
                                         arg_parser.get_str("timer") == std::string("gpu")};
    // type config, need type config before tensor gen, and define the acc types
    using TypeConfig = MoeTypeConfig<DataType>;

    using ADataType     = typename TypeConfig::ADataType;
    using GDataType     = typename TypeConfig::GDataType;
    using UDataType     = typename TypeConfig::UDataType;
    using DDataType     = typename TypeConfig::DDataType;
    using ODataType     = typename TypeConfig::ODataType;
    using AccDataType   = typename TypeConfig::AccDataType;
    using ScaleDataType = typename TypeConfig::ScaleDataType;
    // tensor
    ck_tile::HostTensor<GDataType> g_host_ref(
        {num_experts, shard_intermediate_size / 2, hidden_size / 16, 2, 8});
    ck_tile::HostTensor<UDataType> u_host_ref(
        {num_experts, shard_intermediate_size / 2, hidden_size / 16, 2, 8});
    ck_tile::HostTensor<DDataType> d_host_ref(
        {num_experts, hidden_size, shard_intermediate_size / 2 / 16, 2, 8});

    // reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::vector<index_t> dims)
    ck_tile::HostTensor<ADataType> a_host({num_tokens, hidden_size});
    ck_tile::HostTensor<GDataType> g_host({num_experts, shard_intermediate_size / 2, hidden_size});
    ck_tile::HostTensor<UDataType> u_host({num_experts, shard_intermediate_size / 2, hidden_size});
    ck_tile::HostTensor<DDataType> d_host({num_experts, hidden_size, shard_intermediate_size / 2});

    ck_tile::reference_permute<GDataType>(g_host_ref, g_host, {0, 1, 3, 4, 2, 5});
    ck_tile::reference_permute<GDataType>(u_host_ref, u_host, {0, 1, 3, 4, 2, 5});
    ck_tile::reference_permute<GDataType>(d_host_ref, d_host, {0, 1, 3, 4, 2, 5});

    ck_tile::HostTensor<ODataType> o_host({num_tokens, hidden_size});
    ck_tile::HostTensor<FP32> sorted_weights({num_tokens, topk});
    ck_tile::HostTensor<ck_tile::index_t> sorted_topk_ids({num_tokens, topk});
    ck_tile::HostTensor<ck_tile::index_t> sorted_expert_ids({num_tokens, topk});
    ck_tile::HostTensor<ck_tile::index_t> sorted_num_tokens_post_padded({1});
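    // A, G, U and D are copied to the device below but are never filled on the host in this
    // backup. A minimal sketch of a deterministic fill that could run before those copies;
    // it assumes HostTensor::data() returns a typed element pointer (as its use with
    // ToDevice() suggests) and that ck_tile::type_convert<T>(float) is available.
    auto fill_host = [](auto& t) {
        using T             = std::remove_pointer_t<decltype(t.data())>;
        T* p                = t.data();
        const std::size_t n = t.get_element_space_size_in_bytes() / sizeof(T);
        for(std::size_t i = 0; i < n; ++i)
            p[i] = ck_tile::type_convert<T>(0.01f * static_cast<float>(i % 17) - 0.08f);
    };
    fill_host(a_host);
    fill_host(g_host);
    fill_host(u_host);
    fill_host(d_host);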
    // device buffer
    ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem g_buf(g_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem u_buf(u_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem d_buf(d_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem sorted_weight_buf(sorted_weights.get_element_space_size_in_bytes());
    ck_tile::DeviceMem sorted_topk_ids_buf(sorted_topk_ids.get_element_space_size_in_bytes());
    ck_tile::DeviceMem sorted_expert_ids_buf(sorted_expert_ids.get_element_space_size_in_bytes());
    ck_tile::DeviceMem sorted_tiles_buf(sorted_num_tokens_post_padded.get_element_space_size_in_bytes());

    a_buf.ToDevice(a_host.data());
    g_buf.ToDevice(g_host.data());
    u_buf.ToDevice(u_host.data());
    d_buf.ToDevice(d_host.data());
    // init traits
    const auto init_traits = [&](auto& traits) { traits.DownPreShuffled = 0; };

    // init host args, pack internal params to a struct to pass to kernel
    const auto init_args = [&](auto& args) {
        const ck_tile::index_t stride_a         = hidden_size;
        const ck_tile::index_t stride_gu        = hidden_size;
        const ck_tile::index_t stride_d         = shard_intermediate_size / 2;
        const ck_tile::index_t stride_o         = hidden_size;
        const ck_tile::index_t stride_expert_gu = hidden_size * shard_intermediate_size / 2;
        const ck_tile::index_t stride_expert_d  = hidden_size * shard_intermediate_size / 2;

        args.a_ptr = a_buf.GetDeviceBuffer();
        args.g_ptr = g_buf.GetDeviceBuffer();
        args.u_ptr = u_buf.GetDeviceBuffer();
        args.d_ptr = d_buf.GetDeviceBuffer();
        args.o_ptr = o_buf.GetDeviceBuffer();
        args.sorted_token_ids_ptr  = sorted_topk_ids_buf.GetDeviceBuffer();
        args.sorted_weight_ptr     = sorted_weight_buf.GetDeviceBuffer();
        args.sorted_expert_ids_ptr = sorted_expert_ids_buf.GetDeviceBuffer();
        args.num_sorted_tiles_ptr  = sorted_tiles_buf.GetDeviceBuffer();

        args.stride_a         = stride_a;
        args.stride_gu        = stride_gu;
        args.stride_d         = stride_d;
        args.stride_o         = stride_o;
        args.stride_expert_gu = stride_expert_gu;
        args.stride_expert_d  = stride_expert_d;
        args.dim_size         = dim_size;
        args.hidden_size      = hidden_size;
        args.num_tokens       = num_tokens; // input number of tokens for current iteration
        args.num_experts      = num_experts;
    };
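    // The strides above describe fully packed row-major operands: A is [num_tokens, hidden_size],
    // G/U are per-expert blocks of shard_intermediate_size/2 rows of hidden_size, and D is
    // per-expert blocks of hidden_size rows of shard_intermediate_size/2. A small sketch of the
    // flat offsets those strides encode; these helpers are purely illustrative, the real
    // addressing lives inside FusedMoeKernel.
    auto offset_a = [&](ck_tile::index_t token, ck_tile::index_t h) {
        return token * hidden_size + h; // stride_a = hidden_size
    };
    auto offset_g = [&](ck_tile::index_t e, ck_tile::index_t n, ck_tile::index_t k) {
        return e * (hidden_size * shard_intermediate_size / 2) // stride_expert_gu
               + n * hidden_size                               // stride_gu
               + k;
    };
    auto offset_d = [&](ck_tile::index_t e, ck_tile::index_t h, ck_tile::index_t n) {
        return e * (hidden_size * shard_intermediate_size / 2) // stride_expert_d
               + h * (shard_intermediate_size / 2)             // stride_d
               + n;
    };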
    //
    constexpr ck_tile::index_t ts_experts = experts_;

    // tiling
    using moe_block_tile_0 = ck::Sequence<32,   // kM_a
                                          128,  // kN_g/u
                                          128,  // kN_sub0
                                          32,   // kK_a
                                          128>; // kN_d
    using moe_block_warps0_0 = ck::Sequence<1, 4, 1>; // mnk
    using moe_block_warps1_0 = ck::Sequence<4, 1, 1>;
    using moe_warp_tile_0    = ck::Sequence<32, 32, 16>;
    // using fmha_warp_tile_4 = ck::Sequence<32, 32, 8>;

    using moe_shape = ck::tile_program::FusedMoeTileShape<moe_block_tile_0,
                                                          moe_block_warps0_0,
                                                          moe_warp_tile_0,
                                                          moe_block_warps1_0,
                                                          moe_warp_tile_0>;
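    // moe_block_warps0_0 = <1, 4, 1> lays out 1 x 4 x 1 = 4 warps over the M/N/K directions of
    // the first GEMM, and moe_warp_tile_0 = <32, 32, 16> is each warp's per-instruction tile.
    // Illustrative launch-size arithmetic only; the wave size of 64 is an assumption
    // (gfx9-class GPUs), and kernel::BlockSize() below is what actually decides the value.
    constexpr ck_tile::index_t warps_per_block = 1 * 4 * 1;                   // from moe_block_warps0_0
    constexpr ck_tile::index_t wave_size       = 64;                          // assumption: wave64
    constexpr ck_tile::index_t threads_per_blk = warps_per_block * wave_size; // 256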
    using moe_traits = ck_tile::FusedMoeTraits<false, // down preshuffle
                                               -1,    // index_t kBlockPerCu_, overwrite occupancy if not -1
                                               0,     // index_t OAtomic_
                                               FusedMoeWeightPermuteEnum::permute_b_nr_kr_kw_nw_kv>; // FusedMoeWeightPermuteEnum WeightPermute_

    using moe_problem = ck_tile::FusedMoePipelineProblem<ADataType,
                                                         GDataType,
                                                         UDataType,
                                                         DDataType,
                                                         ODataType,
                                                         AccDataType,
                                                         ScaleDataType,
                                                         GateActivation,
                                                         moe_shape,
                                                         moe_traits>;

    using moe_pipeline = ck_tile::FusedMoePipelineNSplit2<moe_problem>;

    using Hargs           = ck_tile::FusedMoeKernel::FusedMoeCommonHargs;
    using moe_partitioner = ck_tile::FusedMoeTilePartitioner_PersistentSplitD<moe_shape>;
    using kernel          = ck_tile::FusedMoeKernel<moe_partitioner, moe_pipeline>;
    using Kargs           = ck_tile::FusedMoeKernel::FusedMoeCommonKargs;
    Hargs hargs;
    Kargs kargs;

    // args to hargs
    init_args(hargs);
    kargs = kernel::MakeKargs(hargs);

    int cu_count = getAvailableComputeUnitCount(stream_config);

    const dim3 grids      = kernel::GridSize(cu_count, moe_pipeline::kBlockPerCu);
    constexpr dim3 blocks = kernel::BlockSize();

    float ave_time = ck_tile::launch_kernel(
        stream_config, ck_tile::make_kernel<blocks.x, 1>(kernel{}, grids, blocks, 0, kargs));
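    // The comments at the top of run() promise "move result to host / reference gen /
    // validation", but this backup stops at the launch. A hedged sketch of the copy-back plus
    // a simple throughput estimate; DeviceMem::FromDevice() and the byte-traffic model
    // (each operand counted once) are assumptions, not taken from this commit.
    o_buf.FromDevice(o_host.data()); // assumed API, mirroring ToDevice() above
    const double bytes_moved = static_cast<double>(a_host.get_element_space_size_in_bytes()) +
                               g_host.get_element_space_size_in_bytes() +
                               u_host.get_element_space_size_in_bytes() +
                               d_host.get_element_space_size_in_bytes() +
                               o_host.get_element_space_size_in_bytes();
    const double gb_per_sec = bytes_moved / 1.0e6 / ave_time; // ave_time is in ms
    (void)gb_per_sec; // print or compare against a host reference once validation lands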
    return ave_time;
}
// main
int main(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
        return -1;

    const std::string data_type = arg_parser.get_str("dtype");
    if(data_type == "fp16")
    {
        return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
    }
    else if(data_type == "bf16")
    {
        return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
    }
    else if(data_type == "fp8")
    {
        return run<ck_tile::fp8_t>(arg_parser) ? 0 : -2;
    }
    return -3;
}
// call create_args
// call run
// return
\ No newline at end of file
example/ck_tile/05_moe/moe.hpp @ cdb83933
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include <type_traits>
template <typename DataType>
struct MoeTypeConfig;

template <>
struct MoeTypeConfig<ck_tile::half_t>
{
    using ADataType     = ck_tile::half_t;
    using GDataType     = ck_tile::half_t;
    using UDataType     = ck_tile::half_t;
    using DDataType     = ck_tile::half_t;
    using AccDataType   = float;
    using ScaleDataType = float;
    // data type for lse(logsumexp L_j = max_j + log(l_j))
    using SaccDataType = float;
    // data type for first gemm accumulation
    // data type for second gemm accumulation
    using ODataType = ck_tile::half_t;
};

template <>
struct MoeTypeConfig<ck_tile::bf16_t>
{
    using ADataType     = ck_tile::bf16_t;
    using GDataType     = ck_tile::bf16_t;
    using UDataType     = ck_tile::bf16_t;
    using DDataType     = ck_tile::bf16_t;
    using AccDataType   = float;
    using ScaleDataType = float;
    // data type for lse(logsumexp L_j = max_j + log(l_j))
    using SaccDataType = float;
    // data type for first gemm accumulation
    // data type for second gemm accumulation
    using ODataType = ck_tile::bf16_t;
};

template <>
struct MoeTypeConfig<ck_tile::fp8_t>
{
    using ADataType     = ck_tile::fp8_t;
    using GDataType     = ck_tile::fp8_t;
    using UDataType     = ck_tile::fp8_t;
    using DDataType     = ck_tile::fp8_t;
    using AccDataType   = float;
    using ScaleDataType = float;
    // data type for lse(logsumexp L_j = max_j + log(l_j))
    using SaccDataType = float;
    // data type for first gemm accumulation
    // data type for second gemm accumulation
    using ODataType = ck_tile::fp8_t;
};

template <>
struct MoeTypeConfig<ck_tile::bf8_t>
{
    using ADataType     = ck_tile::bf8_t;
    using GDataType     = ck_tile::bf8_t;
    using UDataType     = ck_tile::bf8_t;
    using DDataType     = ck_tile::bf8_t;
    using AccDataType   = float;
    using ScaleDataType = float;
    // data type for lse(logsumexp L_j = max_j + log(l_j))
    using SaccDataType = float;
    // data type for first gemm accumulation
    // data type for second gemm accumulation
    using ODataType = ck_tile::bf8_t;
};

// float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
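// moe.cpp consumes this trait as MoeTypeConfig<DataType> and pulls the per-operand types out of
// it; a tiny usage sketch (illustrative only, not part of the original header):
// compile-time check that fp16 inputs accumulate in fp32, per the specialization above
using cfg_fp16_sketch = MoeTypeConfig<ck_tile::half_t>;
static_assert(std::is_same_v<typename cfg_fp16_sketch::AccDataType, float>,
              "fp16 path accumulates in float");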
example/ck_tile/CMakeLists.txt @ cdb83933

@@ -5,4 +5,5 @@ include_directories(AFTER
 add_subdirectory(01_fmha)
 add_subdirectory(02_layernorm2d)
 add_subdirectory(03_gemm)
+add_subdirectory(05_moe)
 add_subdirectory(06_permute)