Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
40df5c8b
Commit
40df5c8b
authored
Dec 05, 2024
by
letaoqin
Browse files
add weight
parent
b616b254
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
102 additions
and
52 deletions
+102
-52
example/ck_tile/17_fused_moe_general/main.cpp
example/ck_tile/17_fused_moe_general/main.cpp
+1
-0
include/ck_tile/core/config.hpp
include/ck_tile/core/config.hpp
+1
-1
include/ck_tile/host/reference/reference_fused_moe.hpp
include/ck_tile/host/reference/reference_fused_moe.hpp
+8
-1
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
...ile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
+6
-15
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
...ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
+53
-14
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
...ed_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
+33
-21
No files found.
example/ck_tile/17_fused_moe_general/main.cpp
View file @
40df5c8b
...
@@ -375,6 +375,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -375,6 +375,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
std
::
cout
<<
sorted_token_ids_host
<<
std
::
endl
;
std
::
cout
<<
sorted_token_ids_host
<<
std
::
endl
;
std
::
cout
<<
num_sorted_tiles_host
<<
std
::
endl
;
std
::
cout
<<
num_sorted_tiles_host
<<
std
::
endl
;
// output_matrix_3d(g_host, experts, shared_intermediate_size_0, hidden_size);
// output_matrix_3d(g_host, experts, shared_intermediate_size_0, hidden_size);
// output_matrix_3d(d_host, experts, hidden_size, shared_intermediate_size_1);
std
::
cout
<<
sorted_expert_ids_host
<<
std
::
endl
;
std
::
cout
<<
sorted_expert_ids_host
<<
std
::
endl
;
// std::cout << topk_weight_host << std::endl;
// std::cout << topk_weight_host << std::endl;
...
...
include/ck_tile/core/config.hpp
View file @
40df5c8b
...
@@ -228,5 +228,5 @@
...
@@ -228,5 +228,5 @@
#endif
#endif
#ifndef CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
#ifndef CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
#define CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
1
#define CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
0
#endif
#endif
include/ck_tile/host/reference/reference_fused_moe.hpp
View file @
40df5c8b
...
@@ -113,6 +113,8 @@ void reference_fused_moe(
...
@@ -113,6 +113,8 @@ void reference_fused_moe(
ck_tile
::
HostTensor
<
AccDataType
>
acc_0
({
1
,
intermediate_size_0
});
ck_tile
::
HostTensor
<
AccDataType
>
acc_0
({
1
,
intermediate_size_0
});
// first gemm
// first gemm
// if(i_expert == 0)
// printf("ie:%2d, it:%3d \n", i_expert, i_token);
for
(
ck_tile
::
index_t
i_n
=
0
;
i_n
<
intermediate_size_0
;
i_n
++
)
for
(
ck_tile
::
index_t
i_n
=
0
;
i_n
<
intermediate_size_0
;
i_n
++
)
{
{
AccDataType
acc
=
static_cast
<
AccDataType
>
(
0
);
AccDataType
acc
=
static_cast
<
AccDataType
>
(
0
);
...
@@ -122,7 +124,8 @@ void reference_fused_moe(
...
@@ -122,7 +124,8 @@ void reference_fused_moe(
type_convert
<
AccDataType
>
(
g_host
(
i_expert
,
i_n
,
i_k
));
type_convert
<
AccDataType
>
(
g_host
(
i_expert
,
i_n
,
i_k
));
}
}
acc_0
(
0
,
i_n
)
=
acc
;
acc_0
(
0
,
i_n
)
=
acc
;
// printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, acc);
// if(i_expert == 0)
// printf("in:%d, %f\t", i_n, acc);
}
}
ck_tile
::
HostTensor
<
AccDataType
>
y
({
1
,
intermediate_size_1
});
ck_tile
::
HostTensor
<
AccDataType
>
y
({
1
,
intermediate_size_1
});
...
@@ -135,6 +138,8 @@ void reference_fused_moe(
...
@@ -135,6 +138,8 @@ void reference_fused_moe(
for
(
ck_tile
::
index_t
i_n
=
0
;
i_n
<
intermediate_size_1
;
i_n
++
)
for
(
ck_tile
::
index_t
i_n
=
0
;
i_n
<
intermediate_size_1
;
i_n
++
)
{
{
Activation
{}(
y
(
0
,
i_n
),
acc_0
(
0
,
i_n
));
Activation
{}(
y
(
0
,
i_n
),
acc_0
(
0
,
i_n
));
// if(i_expert == 0)
// printf("in:%d, %f\t", i_n, y(0, i_n));
// printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, y(0, i_n));
// printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, y(0, i_n));
}
}
}
}
...
@@ -161,6 +166,8 @@ void reference_fused_moe(
...
@@ -161,6 +166,8 @@ void reference_fused_moe(
{
{
acc
+=
y
(
0
,
i_k
)
*
type_convert
<
AccDataType
>
(
d_host
(
i_expert
,
i_n
,
i_k
));
acc
+=
y
(
0
,
i_k
)
*
type_convert
<
AccDataType
>
(
d_host
(
i_expert
,
i_n
,
i_k
));
}
}
// if(i_expert == 0)
// printf("in:%d, %f\t", i_n, acc);
acc_1
(
0
,
i_n
)
=
acc
*
weight
;
// multiple weight here
acc_1
(
0
,
i_n
)
=
acc
*
weight
;
// multiple weight here
}
}
...
...
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
View file @
40df5c8b
...
@@ -247,21 +247,12 @@ struct FusedMoeGemmGlKernel
...
@@ -247,21 +247,12 @@ struct FusedMoeGemmGlKernel
const
IndexDataType
expert_id
=
__builtin_amdgcn_readfirstlane
(
const
IndexDataType
expert_id
=
__builtin_amdgcn_readfirstlane
(
reinterpret_cast
<
const
IndexDataType
*>
(
kargs
.
sorted_expert_ids_ptr
)[
sorted_tile_id
]);
reinterpret_cast
<
const
IndexDataType
*>
(
kargs
.
sorted_expert_ids_ptr
)[
sorted_tile_id
]);
// index along intermediate_size
// index_t hidden_idx = __builtin_amdgcn_readfirstlane(intermediate_tile_id *
// BlockShape::Block_N0);
index_t
idx_m0
=
__builtin_amdgcn_readfirstlane
(
sorted_tile_id
*
BlockShape
::
Block_M0
);
index_t
idx_m0
=
__builtin_amdgcn_readfirstlane
(
sorted_tile_id
*
BlockShape
::
Block_M0
);
index_t
idx_n0
=
__builtin_amdgcn_readfirstlane
(
sorted_tile_id
*
BlockShape
::
Block_N0
);
index_t
idx_n0
=
__builtin_amdgcn_readfirstlane
(
intermediate_tile_id
*
BlockShape
::
Block_N0
);
// const auto a_coord = Pipeline::GetACoord(); // 2d thread offset, [i_row, i_col]
const
auto
a_coord
=
Pipeline
::
GetACoord
();
// 2d thread offset, [i_row, i_col]
const
auto
sorted_token_id
=
a_coord
[
number
<
0
>
{}]
+
idx_m0
;
// start block_m
// if(threadIdx.x == 200 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0){
// printf("\n*************a_coord[0]: %d, a_coord[1]: %d size: %d \n",
// a_coord[number<0>{}], a_coord[number<1>{}], a_coord.size());
// }
// const auto sorted_token_id = a_coord[number<0>{}] + sorted_tile_id *
// BlockShape::Block_M0; //not block pos?
const
auto
sorted_token_id
=
sorted_tile_id
*
BlockShape
::
Block_M0
;
// start block_m
// position
// position
// index_t token_id =
// index_t token_id =
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
View file @
40df5c8b
...
@@ -90,7 +90,8 @@ struct FusedMoeGemmPipeline_General
...
@@ -90,7 +90,8 @@ struct FusedMoeGemmPipeline_General
}
}
template
<
typename
T
>
template
<
typename
T
>
CK_TILE_HOST_DEVICE
static
void
PrintMem
(
T
&
tensor
)
CK_TILE_HOST_DEVICE
static
void
PrintMem
(
T
&
tensor
,
const
char
*
pstr
,
unsigned
int
threadid
=
0
,
unsigned
int
blockid
=
0
)
{
{
constexpr
auto
spans
=
T
::
get_distributed_spans
();
constexpr
auto
spans
=
T
::
get_distributed_spans
();
int
counter
=
0
;
int
counter
=
0
;
...
@@ -99,12 +100,14 @@ struct FusedMoeGemmPipeline_General
...
@@ -99,12 +100,14 @@ struct FusedMoeGemmPipeline_General
constexpr
auto
i_j_idx
=
make_tuple
(
idxn
,
idxk
);
constexpr
auto
i_j_idx
=
make_tuple
(
idxn
,
idxk
);
const
auto
tile_idx
=
const
auto
tile_idx
=
get_x_indices_from_distributed_indices
(
tensor
.
get_tile_distribution
(),
i_j_idx
);
get_x_indices_from_distributed_indices
(
tensor
.
get_tile_distribution
(),
i_j_idx
);
if
(
threadIdx
.
x
==
0
&&
blockIdx
.
x
==
0
&&
blockIdx
.
y
==
0
&&
blockIdx
.
z
==
0
)
if
(
threadIdx
.
x
==
threadid
&&
blockIdx
.
x
==
0
&&
blockIdx
.
y
==
blockid
&&
blockIdx
.
z
==
0
)
{
{
const
auto
row
=
tile_idx
.
at
(
number
<
0
>
{});
const
auto
row
=
tile_idx
.
at
(
number
<
0
>
{});
const
auto
col
=
tile_idx
.
at
(
number
<
1
>
{});
const
auto
col
=
tile_idx
.
at
(
number
<
1
>
{});
printf
(
"in
G
row is %d , col is %d, counter is %d, value is: %f"
printf
(
"in
%s
row is %d , col is %d, counter is %d, value is: %f"
"
\n
"
,
"
\n
"
,
pstr
,
row
,
row
,
col
,
col
,
counter
,
counter
,
...
@@ -119,14 +122,11 @@ struct FusedMoeGemmPipeline_General
...
@@ -119,14 +122,11 @@ struct FusedMoeGemmPipeline_General
const
GWindow
&
g_window_
,
const
GWindow
&
g_window_
,
const
DWindow
&
d_window_
,
const
DWindow
&
d_window_
,
OWindow
&
o_window_
,
OWindow
&
o_window_
,
TopkWeightDataType
/*
topk_weight
*/
,
TopkWeightDataType
topk_weight
,
CK_TILE_LDS_ADDR
void
*
smem
,
CK_TILE_LDS_ADDR
void
*
smem
,
index_t
hidden_size
,
index_t
hidden_size
,
index_t
intermediate_size
)
index_t
intermediate_size
)
{
{
ignore
=
d_window_
;
ignore
=
hidden_size
;
ignore
=
intermediate_size
;
CK_TILE_LDS_ADDR
ADataType
*
smem_0
=
reinterpret_cast
<
CK_TILE_LDS_ADDR
ADataType
*>
(
smem
);
CK_TILE_LDS_ADDR
ADataType
*
smem_0
=
reinterpret_cast
<
CK_TILE_LDS_ADDR
ADataType
*>
(
smem
);
auto
a_lds_view
=
make_tensor_view
<
address_space_enum
::
lds
>
(
auto
a_lds_view
=
make_tensor_view
<
address_space_enum
::
lds
>
(
smem_0
,
Policy
::
template
MakeLdsBlockDesc_A
<
Problem
>());
smem_0
,
Policy
::
template
MakeLdsBlockDesc_A
<
Problem
>());
...
@@ -157,13 +157,13 @@ struct FusedMoeGemmPipeline_General
...
@@ -157,13 +157,13 @@ struct FusedMoeGemmPipeline_General
auto
a_dram_block
=
load_tile
(
a_global_to_dram_window
);
auto
a_dram_block
=
load_tile
(
a_global_to_dram_window
);
store_tile
(
a_lds_win
,
a_dram_block
);
store_tile
(
a_lds_win
,
a_dram_block
);
#if 0
#if 0
PrintMem(a_dram_block);
PrintMem(a_dram_block
,"A", 0, 1
);
#endif
#endif
auto
g_dram_block
=
load_tile
(
g_global_to_dram_window
);
auto
g_dram_block
=
load_tile
(
g_global_to_dram_window
);
#if
0
#if
1
PrintMem(g_dram_block);
PrintMem
(
g_dram_block
,
"G"
,
0
,
1
);
#endif
#endif
clear_tile
(
s_acc
);
// initialize C
clear_tile
(
s_acc
);
// initialize C
...
@@ -191,7 +191,7 @@ struct FusedMoeGemmPipeline_General
...
@@ -191,7 +191,7 @@ struct FusedMoeGemmPipeline_General
block_sync_lds
();
block_sync_lds
();
gemm_0
(
s_acc
,
a_lds_win
,
g_dram_block
);
gemm_0
(
s_acc
,
a_lds_win
,
g_dram_block
);
}
}
#if
1
#if
0
PrintMem(s_acc);
PrintMem(s_acc);
#endif
#endif
// relu
// relu
...
@@ -233,6 +233,25 @@ struct FusedMoeGemmPipeline_General
...
@@ -233,6 +233,25 @@ struct FusedMoeGemmPipeline_General
d_window_
.
get_window_origin
(),
d_window_
.
get_window_origin
(),
Policy
::
template
MakeGlobalTileDistribution_D
<
Problem
>());
Policy
::
template
MakeGlobalTileDistribution_D
<
Problem
>());
auto
d
=
load_tile
(
d_global_to_dram_window
);
auto
d
=
load_tile
(
d_global_to_dram_window
);
#if 0
PrintMem(d,"D",64);
#endif
// add to LDS
auto
o_alds_view
=
make_naive_tensor_view
<
address_space_enum
::
lds
,
memory_operation_enum
::
atomic_add
>
(
smem_0
,
make_tuple
(
number
<
32
>
{},
number
<
32
>
{}),
make_tuple
(
32
,
1
),
number
<
8
>
{},
number
<
1
>
{});
auto
o_alds_win
=
make_tile_window
(
o_alds_view
,
make_tuple
(
number
<
32
>
{},
number
<
32
>
{}),
{
0
,
0
});
auto
o_olds_win
=
make_tile_window
(
o_alds_view
,
make_tuple
(
number
<
32
>
{},
number
<
32
>
{}),
{
0
,
0
},
Policy
::
template
MakeGlobalTileDistribution_O
<
Problem
>());
ignore
=
o_alds_win
;
constexpr
index_t
kN1
=
BlockShape
::
Block_N1
;
constexpr
index_t
kN1
=
BlockShape
::
Block_N1
;
const
index_t
n1_loops
=
ck_tile
::
integer_divide_ceil
(
hidden_size
,
kN1
);
const
index_t
n1_loops
=
ck_tile
::
integer_divide_ceil
(
hidden_size
,
kN1
);
...
@@ -258,12 +277,32 @@ struct FusedMoeGemmPipeline_General
...
@@ -258,12 +277,32 @@ struct FusedMoeGemmPipeline_General
block_sync_lds
();
block_sync_lds
();
gemm_1
(
o_acc
,
y
,
d
);
gemm_1
(
o_acc
,
y
,
d
);
// block_sync_lds();
tile_elementwise_inout
(
[
&
topk_weight
](
auto
&
x
)
{
x
=
x
*
type_convert
<
float
>
(
topk_weight
);
},
o_acc
);
auto
o
=
cast_tile
<
ODataType
>
(
o_acc
);
auto
o
=
cast_tile
<
ODataType
>
(
o_acc
);
store_tile
(
o_window_
,
o
);
store_tile
(
o_alds_win
,
o
);
block_sync_lds
();
// if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0)
// {
// for(int i = 0; i < 42; i++)
// {
// printf("\n%d value is %f\t", i, type_convert<float>(smem_0[i]));
// }
// }
if
(
threadIdx
.
x
<
64
)
{
auto
o_out
=
load_tile
(
o_olds_win
);
block_sync_lds
();
store_tile
(
o_window_
,
o_out
);
}
}
// ignore = o_olds_win;
// store_tile(o_window_, o);
#if 0
#if 0
PrintMem(o
_acc
);
PrintMem(o
,"O"
);
#endif
#endif
}
// store_tile(o_window_, a_dram_block);
// store_tile(o_window_, a_dram_block);
}
}
};
};
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
View file @
40df5c8b
...
@@ -189,6 +189,18 @@ struct FusedMoeGemmPipelineGeneralPolicy
...
@@ -189,6 +189,18 @@ struct FusedMoeGemmPipelineGeneralPolicy
return
make_static_tile_distribution
(
g_block_dstr_encode
);
return
make_static_tile_distribution
(
g_block_dstr_encode
);
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeGlobalTileDistribution_O
()
{
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
1
,
2
,
16
>
,
sequence
<
4
,
8
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
0
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockGemm0
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockGemm0
()
{
{
...
@@ -276,27 +288,27 @@ struct FusedMoeGemmPipelineGeneralPolicy
...
@@ -276,27 +288,27 @@ struct FusedMoeGemmPipelineGeneralPolicy
return
d_block_dstr
;
return
d_block_dstr
;
}
}
template
<
typename
Problem
>
//
template <typename Problem>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeGlobalTileDistribution_O
()
//
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_O()
{
//
{
using
S_
=
remove_cvref_t
<
typename
Problem
::
BlockShape
>
;
//
using S_ = remove_cvref_t<typename Problem::BlockShape>;
using
WarpGemm
=
remove_cvref_t
<
decltype
(
GetWarpGemm1
<
Problem
>
())
>
;
//
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
// using CDataType = typename WarpGemm::CDataType;
//
// using CDataType = typename WarpGemm::CDataType;
constexpr
auto
c_block_outer_dstr_encoding
=
//
constexpr auto c_block_outer_dstr_encoding =
tile_distribution_encoding
<
sequence
<>
,
//
tile_distribution_encoding<sequence<>,
tuple
<
sequence
<
S_
::
Repeat_M1
,
S_
::
WarpPerBlock_M1
>
,
//
tuple<sequence<S_::Repeat_M1, S_::WarpPerBlock_M1>,
sequence
<
S_
::
Repeat_N1
,
S_
::
WarpPerBlock_N1
>>
,
//
sequence<S_::Repeat_N1, S_::WarpPerBlock_N1>>,
tuple
<
sequence
<
1
,
2
>>
,
//
tuple<sequence<1, 2>>,
tuple
<
sequence
<
1
,
1
>>
,
//
tuple<sequence<1, 1>>,
sequence
<
1
,
2
>
,
//
sequence<1, 2>,
sequence
<
0
,
0
>>
{};
//
sequence<0, 0>>{};
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
//
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding
,
typename
WarpGemm
::
CWarpDstrEncoding
{});
//
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
constexpr
auto
c_block_dstr
=
make_static_tile_distribution
(
c_block_dstr_encode
);
//
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
return
c_block_dstr
;
//
return c_block_dstr;
}
//
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLdsBlockDesc_A
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLdsBlockDesc_A
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment