Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
d4a0a8ee
Commit
d4a0a8ee
authored
Dec 13, 2024
by
letaoqin
Browse files
add gelu and weight
parent
d846292c
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
69 additions
and
27 deletions
+69
-27
example/ck_tile/17_fused_moe_general/main.cpp
example/ck_tile/17_fused_moe_general/main.cpp
+15
-5
include/ck_tile/host/host_tensor.hpp
include/ck_tile/host/host_tensor.hpp
+1
-0
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
...ile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
+22
-6
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
...ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
+31
-16
No files found.
example/ck_tile/17_fused_moe_general/main.cpp
View file @
d4a0a8ee
...
...
@@ -264,6 +264,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
{(
max_num_tokens_padded
+
block_m
-
1
)
/
block_m
});
ck_tile
::
HostTensor
<
IndexDataType
>
num_sorted_tiles_host
({
1
});
sorted_token_ids_host
.
SetValue
(
max_num_tokens_padded
);
if
(
init
==
0
)
{
ck_tile
::
FillStepRange
<
ADataType
>
{
-
.5
f
,
.5
f
,
0.01
f
}(
a_host
);
...
...
@@ -280,9 +281,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
FillUniformDistribution
<
ADataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
a_host
);
ck_tile
::
FillUniformDistribution
<
GDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
g_host
);
ck_tile
::
FillUniformDistribution
<
DDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
d_host
);
// ck_tile::FillConstant<ADataType>{1}(a_host);
// ck_tile::FillConstant<GDataType>{1}(g_host);
// ck_tile::FillConstant<DDataType>{1}(d_host);
ck_tile
::
FillUniformDistribution
<
AScaleDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
sa_host
);
ck_tile
::
FillUniformDistribution
<
GScaleDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
sg_host
);
ck_tile
::
FillUniformDistribution
<
DScaleDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
sd_host
);
...
...
@@ -301,6 +299,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
FillNormalDistribution
<
YSmoothScaleDataType
>
{
0.
f
,
1.
f
,
seed
,
true
}(
sy_host
);
ck_tile
::
FillNormalDistribution
<
TopkWeightDataType
>
{
0.
f
,
1.
f
,
seed
,
true
}(
topk_weight_host
);
}
else
if
(
init
==
3
)
{
ck_tile
::
FillConstant
<
ADataType
>
{
1
}(
a_host
);
ck_tile
::
FillConstant
<
GDataType
>
{
1
}(
g_host
);
ck_tile
::
FillConstant
<
DDataType
>
{
1
}(
d_host
);
ck_tile
::
FillUniformDistribution
<
AScaleDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
sa_host
);
ck_tile
::
FillUniformDistribution
<
GScaleDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
sg_host
);
ck_tile
::
FillUniformDistribution
<
DScaleDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
sd_host
);
ck_tile
::
FillUniformDistribution
<
YSmoothScaleDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
sy_host
);
ck_tile
::
FillUniformDistribution
<
TopkWeightDataType
>
{
0.0
f
,
1.0
f
,
seed
,
true
}(
topk_weight_host
);
}
// permute weight
// ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1);
...
...
@@ -393,7 +403,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
std
::
cout
<<
sorted_expert_ids_host
<<
std
::
endl
;
// std::cout << topk_weight_host << std::endl;
//
std::cout << sorted_weight_host << std::endl;
std
::
cout
<<
sorted_weight_host
<<
std
::
endl
;
// done, preparing GPU buffer
ck_tile
::
DeviceMem
a_buf
(
a_host
);
ck_tile
::
DeviceMem
g_perm_buf
(
g_host
);
...
...
@@ -490,7 +500,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
auto
o_dev
=
o_buf
.
ToHost
<
ODataType
>
();
auto
c_dev
=
c_buf
.
ToHost
<
ADataType
>
();
std
::
cout
<<
std
::
endl
;
std
::
cout
<<
o_dev
<<
std
::
endl
;
//
std::cout << o_dev << std::endl;
// std::cout << c_dev << std::endl;
// int count = 0;
// std::cout << "[";
...
...
include/ck_tile/host/host_tensor.hpp
View file @
d4a0a8ee
...
...
@@ -349,6 +349,7 @@ struct HostTensor
// void SetZero() { ck_tile::ranges::fill<T>(mData, 0); }
void
SetZero
()
{
std
::
fill
(
mData
.
begin
(),
mData
.
end
(),
0
);
}
void
SetValue
(
int
value
)
{
std
::
fill
(
mData
.
begin
(),
mData
.
end
(),
value
);
}
template
<
typename
F
>
void
ForEach_impl
(
F
&&
f
,
std
::
vector
<
size_t
>&
idx
,
size_t
rank
)
...
...
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
View file @
d4a0a8ee
...
...
@@ -252,12 +252,12 @@ struct FusedMoeGemmGlKernel
index_t
idx_n0
=
__builtin_amdgcn_readfirstlane
(
intermediate_tile_id
*
BlockShape
::
Block_N0
);
const
auto
a_coord
=
Pipeline
::
GetACoord
();
// 2d thread offset, [i_row, i_col]
const
auto
sorted_token_id
=
a_coord
[
number
<
0
>
{}]
+
idx_m0
;
// start block_m
// position
//
const auto a_coord = Pipeline::GetACoord(); // 2d thread offset, [i_row, i_col]
//
const auto sorted_token_id = a_coord[number<0>{}] + idx_m0; // start block_m
//
// position
auto
topk_weight
=
reinterpret_cast
<
const
TopkWeightDataType
*>
(
kargs
.
sorted_weight_ptr
)[
sorted_token_id
];
//
auto topk_weight =
//
reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr)[sorted_token_id];
const
index_t
*
sorted_token_ids_ptr
=
reinterpret_cast
<
const
index_t
*>
(
kargs
.
sorted_token_ids_ptr
);
...
...
@@ -374,12 +374,28 @@ struct FusedMoeGemmGlKernel
return
o_window_
;
}();
const
auto
w_window
=
[
&
]()
{
const
TopkWeightDataType
*
w_ptr
=
reinterpret_cast
<
const
TopkWeightDataType
*>
(
kargs
.
sorted_weight_ptr
);
const
auto
w_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
w_ptr
,
make_tuple
(
kargs
.
max_num_tokens_padded
),
make_tuple
(
1
),
number
<
1
>
{},
number
<
1
>
{});
const
auto
w_window_
=
make_tile_window
(
w_view_
,
make_tuple
(
number
<
BlockShape
::
Block_M0
>
{}),
{
idx_m0
});
return
w_window_
;
}();
// do compute yeah
Pipeline
{}(
a_window
,
g_window
,
d_window
,
w_window
,
o_window
,
topk_weight
,
smem
,
kargs
.
hidden_size
,
kargs
.
intermediate_size
,
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
View file @
d4a0a8ee
...
...
@@ -89,14 +89,6 @@ struct FusedMoeGemmPipeline_General
// return Policy::template GetSmemSize<Problem>();
}
// this is the thread-offset along row/col
CK_TILE_HOST_DEVICE
static
auto
GetACoord
()
{
constexpr
auto
a_dist
=
Policy
::
template
MakeGlobalTileDistribution_A
<
Problem
>();
const
auto
a_coord
=
a_dist
.
calculate_index
();
return
a_coord
;
}
template
<
typename
T
>
CK_TILE_HOST_DEVICE
static
void
PrintMem
(
T
&
tensor
,
const
char
*
pstr
,
unsigned
int
threadid
=
0
,
unsigned
int
blockid
=
0
)
...
...
@@ -129,20 +121,21 @@ struct FusedMoeGemmPipeline_General
typename
GWindow
,
typename
DWindow
,
typename
OWindow
,
typename
CWindow
>
typename
CWindow
,
typename
WWindow
>
CK_TILE_DEVICE
auto
operator
()(
const
AWindow
&
a_window_
,
const
GWindow
&
g_window_
,
const
DWindow
&
d_window_
,
const
WWindow
&
w_window_
,
OWindow
&
o_window_
,
TopkWeightDataType
topk_weight
,
CK_TILE_LDS_ADDR
void
*
smem
,
index_t
hidden_size
,
index_t
/*intermediate_size*/
,
CWindow
&
c_window_
)
{
ignore
=
topk_weight
;
ignore
=
c_window_
;
ignore
=
hidden_size
;
ignore
=
w_window_
;
CK_TILE_LDS_ADDR
ADataType
*
smem_0
=
reinterpret_cast
<
CK_TILE_LDS_ADDR
ADataType
*>
(
smem
);
CK_TILE_LDS_ADDR
GDataType
*
smem_1
=
reinterpret_cast
<
CK_TILE_LDS_ADDR
GDataType
*>
(
smem_0
+
GetSmemSizeA
()
/
sizeof
(
ADataType
));
...
...
@@ -233,8 +226,8 @@ struct FusedMoeGemmPipeline_General
PrintMem(s_acc, "S", 0);
#endif
// relu
//
const auto activation = ck_tile::element_wise::Gelu{};
//
tile_elementwise_inout(activation, s_acc, s_acc);
const
auto
activation
=
ck_tile
::
element_wise
::
Gelu
{};
tile_elementwise_inout
(
activation
,
s_acc
,
s_acc
);
// cast data to YDataType
auto
y_pre
=
cast_tile
<
YDataType
>
(
s_acc
);
...
...
@@ -260,6 +253,28 @@ struct FusedMoeGemmPipeline_General
constexpr
auto
gemm_1
=
Policy
::
template
GetBlockGemm1
<
Problem
>();
using
OaccBlockTileType
=
decltype
(
gemm_1
.
MakeCBlockTile
());
auto
o_acc
=
OaccBlockTileType
{};
constexpr
auto
w_dstr
=
make_static_tile_distribution
(
detail
::
make_reduce_tile_distribution_encoding
(
s_acc
.
get_tile_distribution
().
get_static_tile_distribution_encoding
(),
sequence
<
1
>
{}));
auto
w_global_to_dram_window
=
make_tile_window
(
w_window_
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
BlockShape
::
Block_M0
>
{}),
w_window_
.
get_window_origin
(),
w_dstr
);
auto
w
=
load_tile
(
w_global_to_dram_window
);
float
weight
=
type_convert
<
float
>
(
w
.
get_thread_buffer
()[
0
]);
#if 0
constexpr index_t w_buffer_size = decltype(w)::get_thread_buffer_size();
if(threadIdx.x == 1 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{
for(int i = 0; i < w_buffer_size; i++)
{
printf("\n len: %d, w[%d]: %f weight: %f", w_buffer_size, i, type_convert<float>(w.get_thread_buffer()[i]), topk_weight);
}
}
#endif
ignore
=
w
;
// y data
auto
bridge_llds_win
=
make_tile_window
(
bridge_lds_view
,
...
...
@@ -308,7 +323,7 @@ struct FusedMoeGemmPipeline_General
Policy
::
template
MakeGlobalTileDistribution_O
<
Problem
>());
auto
save_o
=
[
&
]()
{
if
(
blockIdx
.
x
==
0
&&
(
blockIdx
.
y
==
0
||
blockIdx
.
y
==
1
)
&&
blockIdx
.
z
==
0
)
//
if(blockIdx.x == 0 && (blockIdx.y == 0 || blockIdx.y == 1) && blockIdx.z == 0)
{
if
(
threadIdx
.
x
<
64
)
{
...
...
@@ -352,8 +367,8 @@ struct FusedMoeGemmPipeline_General
gemm_1
(
o_acc
,
y
,
d
);
// block_sync_lds();
//
tile_elementwise_inout(
//
[&
topk_
weight](auto& x) { x = x * type_convert<float>(
topk_
weight); }, o_acc);
tile_elementwise_inout
(
[
&
weight
](
auto
&
x
)
{
x
=
x
*
type_convert
<
float
>
(
weight
);
},
o_acc
);
auto
o
=
cast_tile
<
ODataType
>
(
o_acc
);
store_tile
(
o_alds_win
,
o
);
block_sync_lds
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment