gaoqiong / composable_kernel_ROCM / Commits

Commit c6c3c142, authored Nov 06, 2024 by carlushuang (parent a288c57c)

    update cpu reference
Showing 3 changed files with 217 additions and 152 deletions:
    example/ck_tile/15_fused_moe/main.cpp                    +33   -152
    include/ck_tile/host.hpp                                 +1    -0
    include/ck_tile/host/reference/reference_fused_moe.hpp   +183  -0
example/ck_tile/15_fused_moe/main.cpp

@@ -152,11 +152,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::index_t shared_intermediate_size = intermediate_size * (gate_only ? 1 : 2) / tp;

     using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;

     using ADataType      = typename TypeConfig::ADataType;
     using GDataType      = typename TypeConfig::GDataType;
     using DDataType      = typename TypeConfig::DDataType;
-    // using AccDataType = typename TypeConfig::AccDataType;
+    using AccDataType    = typename TypeConfig::AccDataType;
     using ODataType      = typename TypeConfig::ODataType;
     using AScaleDataType = typename TypeConfig::AScaleDataType;
     using GScaleDataType = typename TypeConfig::GScaleDataType;
@@ -313,154 +313,35 @@ bool run(const ck_tile::ArgParser& arg_parser)
     if(do_validation)
     {
-        if(fused_add != 0)
-        {
-            // fused pre_add/pre_add_store
-            // TODO we accumulate directly to a_host for simplcity here...
-            std::transform(a_host.mData.cbegin(),
-                           a_host.mData.cend(),
-                           x_residual_host.mData.cbegin(),
-                           a_host.mData.begin(),
-                           [](auto x_, auto r_) {
-                               auto o_ = ck_tile::type_convert<ComputeDataType>(x_) +
-                                         ck_tile::type_convert<ComputeDataType>(r_);
-                               return ck_tile::type_convert<ADataType>(o_);
-                           });
-        }
-
-        ck_tile::reference_layernorm2d_fwd<ADataType,
-                                           GammaDataType,
-                                           BetaDataType,
-                                           ComputeDataType,
-                                           YDataType,
-                                           MeanDataType,
-                                           InvStdDataType>(
-            a_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon);
-
-        if(fused_quant != 0)
-        {
-            auto dquant_functor = [&](int m_, auto& o_, auto& acc_) {
-                int N_ = acc_.mDesc.get_lengths()[1];
-                if(fused_quant == 1)
-                {
-                    for(int n_ = 0; n_ < N_; n_++)
-                    {
-                        // input smooth outlier
-                        acc_(m_, n_) =
-                            acc_(m_, n_) * ck_tile::type_convert<ComputeDataType>(x_scale_host(n_));
-                    }
-                }
-                ComputeDataType absmax = static_cast<ComputeDataType>(0);
-                for(int n_ = 0; n_ < N_; n_++)
-                {
-                    const auto a = ck_tile::abs(acc_(m_, n_));
-                    absmax       = a > absmax ? a : absmax;
-                }
-                // printf("cpu:absmax:%f\n", absmax);
-                ComputeDataType y_scale = absmax / static_cast<ComputeDataType>(127.0);
-                y_scale_host_ref(m_)    = ck_tile::type_convert<YScaleDataType>(y_scale);
-                for(int n_ = 0; n_ < N_; n_++)
-                {
-                    o_(m_, n_) = ck_tile::type_convert<YDataType>(acc_(m_, n_) / y_scale);
-                }
-            };
-            ck_tile::reference_layernorm2d_fwd<ADataType,
-                                               GammaDataType,
-                                               BetaDataType,
-                                               ComputeDataType,
-                                               YDataType,
-                                               MeanDataType,
-                                               InvStdDataType>(a_host,
-                                                               gamma_host,
-                                                               beta_host,
-                                                               y_host_ref,
-                                                               mean_host_ref,
-                                                               invStd_host_ref,
-                                                               epsilon,
-                                                               dquant_functor);
-        }
-        else
-        {
-            ck_tile::reference_layernorm2d_fwd<ADataType,
-                                               GammaDataType,
-                                               BetaDataType,
-                                               ComputeDataType,
-                                               YDataType,
-                                               MeanDataType,
-                                               InvStdDataType>(
-                a_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon);
-        }
-
-        y_buf.FromDevice(y_host_dev.data());
-
-        ck_tile::HostTensor<YResidualDataType> y_residual_host_dev({m, n}, {stride, 1});
-        if(fused_add == 1)
-        {
-            y_residual_buf.FromDevice(y_residual_host_dev.data());
-        }
-
-        auto [rtol, atol] = get_elimit<InDataType>();
-        if(stride == n)
-        {
-            pass = ck_tile::check_err(
-                y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
-            if(fused_add == 1)
-            {
-                pass &= ck_tile::check_err(y_residual_host_dev,
-                                           a_host,
-                                           std::string("ADD Error: Incorrect results!"),
-                                           rtol,
-                                           atol);
-            }
-        }
-        else
-        {
-            for(int i_r = 0; i_r < m; i_r++)
-            {
-                std::vector<YDataType> y_host_dev_row(y_host_dev.begin() + i_r * stride,
-                                                      y_host_dev.begin() + i_r * stride + n);
-                std::vector<YDataType> y_host_ref_row(y_host_ref.begin() + i_r * stride,
-                                                      y_host_ref.begin() + i_r * stride + n);
-                pass &= ck_tile::check_err(y_host_dev_row,
-                                           y_host_ref_row,
-                                           std::string("OUT[") + std::to_string(i_r) +
-                                               std::string("] Error: Incorrect results!"),
-                                           rtol,
-                                           atol);
-                if(fused_add == 1)
-                {
-                    std::vector<YResidualDataType> y_residual_host_dev_row(
-                        y_residual_host_dev.begin() + i_r * stride,
-                        y_residual_host_dev.begin() + i_r * stride + n);
-                    std::vector<YResidualDataType> y_residual_host_ref_row(
-                        a_host.begin() + i_r * stride, a_host.begin() + i_r * stride + n);
-                    pass &= ck_tile::check_err(y_residual_host_dev_row,
-                                               y_residual_host_ref_row,
-                                               std::string("ADD[") + std::to_string(i_r) +
-                                                   std::string("] Error: Incorrect results!"),
-                                               rtol,
-                                               atol);
-                }
-            }
-        }
-
-        if(fused_quant == 1)
-        {
-            y_scale_buf.FromDevice(y_scale_host_dev.data());
-            pass &= ck_tile::check_err(y_scale_host_dev,
-                                       y_scale_host_ref,
-                                       std::string("SCALE Error: Incorrect results!"),
-                                       rtol,
-                                       atol);
-        }
-
-        std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
+#if 0
+        ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>( // reference
+            a_host,
+            g_host,
+            d_host,
+            sa_host,
+            sg_host,
+            sd_host,
+            sy_host,
+            o_host,
+            sorted_token_ids_host,
+            sorted_weight_host,
+            sorted_expert_ids_host,
+            num_sorted_tiles_host,
+            topk_ids_host,
+            block_m,
+            tokens,
+            experts,
+            hidden_size,
+            intermediate_size,
+            topk,
+            gate_only);
+
+        auto o_dev = o_buf.ToHost<ODataType>();
+
+        auto [rtol, atol] = get_elimit<ADataType>();
+        pass &= ck_tile::check_err(
+            o_dev, o_host, std::string("OUT Error: Incorrect results!"), rtol, atol);
+
+        std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
+#else
+        std::cout << std::flush << std::endl;
+#endif
     }
+    std::cout << std::flush << std::endl;

     return pass;
 }
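
For reference, the computation the new CPU routine checks against can be stated compactly. A sketch for the gate+up path (gate_only == 0), with H = hidden_size, activation \phi (Gelu in this example), token row a_t, expert weight matrices G_e and D_e, and top-k routing weight w_{t,k}:

    \mathrm{acc}_0 = a_t G_e^{\top} \in \mathbb{R}^{2H}, \qquad
    y_n = \phi(\mathrm{acc}_{0,n}) \cdot \mathrm{acc}_{0,\,n+H} \quad (0 \le n < H), \qquad
    o_t = \sum_{k=1}^{\mathrm{topk}} w_{t,k} \; y^{(t,e_k)} D_{e_k}^{\top}

With gate_only == 1 the elementwise product is skipped (y_n = \phi(\mathrm{acc}_{0,n})) and the intermediate size equals H; this matches the two gemm loops in reference_fused_moe.hpp below.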
include/ck_tile/host.hpp

@@ -20,6 +20,7 @@
 #include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp"
 #include "ck_tile/host/reference/reference_batched_softmax.hpp"
 #include "ck_tile/host/reference/reference_elementwise.hpp"
+#include "ck_tile/host/reference/reference_fused_moe.hpp"
 #include "ck_tile/host/reference/reference_gemm.hpp"
 #include "ck_tile/host/reference/reference_im2col.hpp"
 #include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
include/ck_tile/host/reference/reference_fused_moe.hpp   (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

namespace ck_tile {

// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
// each slice map to one expert, and one expert can have multiple slices
// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
//                             tok-0      tok-1      tok-2      tok-3      tok-4
//           topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [],    [0, 1, 2, 4]]
//  (only for reference)    exp-0  exp-1      exp-2   exp-3           exp-4   exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [],    [c, f, i, o]]
//
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr  : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 4]
//                         |- exp-0  -|- exp-1  -|- exp-2  -|-       exp-3        -|- exp-4  -|- exp-5  -|
// sorted_weight_ptr     : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
//
// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
//
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
// * length is (max_num_tokens_padded + block_size - 1) / block_size
//
// num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr : [7]
template <typename AccDataType, // you only need to explicitly set this one
          typename Activation,  // e.g. ck_tile::element_wise::Gelu
          typename ADataType,
          typename GDataType,
          typename DDataType,
          typename ODataType,
          typename AScaleDataType,
          typename GScaleDataType,
          typename DScaleDataType,
          typename YSmoothScaleDataType,
          typename TopkWeightDataType,
          typename IndexDataType>
void reference_fused_moe(
    const ck_tile::HostTensor<ADataType>& a_host,             // [tokens, hidden_size]
    const ck_tile::HostTensor<GDataType>& g_host,             // [experts, interme_size, hidden_size]
    const ck_tile::HostTensor<DDataType>& d_host,             // [experts, hidden_size, hidden_size]
    const ck_tile::HostTensor<AScaleDataType>& sa_host,       // [tokens, 1]
    const ck_tile::HostTensor<GScaleDataType>& sg_host,       // [experts, 1, interme_size]
    const ck_tile::HostTensor<DScaleDataType>& sd_host,       // [experts, 1, hidden_size]
    const ck_tile::HostTensor<YSmoothScaleDataType>& sy_host, // [experts, 1, interme_size]
    ck_tile::HostTensor<ODataType>& o_host,                   // [tokens, hidden_size]
    const ck_tile::HostTensor<IndexDataType>& sorted_token_ids_host,   // [max_num_tokens_padded]
    const ck_tile::HostTensor<TopkWeightDataType>& sorted_weight_host, // [max_num_tokens_padded]
    const ck_tile::HostTensor<IndexDataType>& sorted_expert_ids_host,
    // [(max_num_tokens_padded + block_size - 1) / block_size]
    const ck_tile::HostTensor<IndexDataType>& num_sorted_tiles_host, // [1]
    const ck_tile::HostTensor<IndexDataType>& token_ids_host, // [tokens, topk] --> ugly!!! remove in the future
    ck_tile::index_t block_m,
    ck_tile::index_t tokens,
    ck_tile::index_t experts,
    ck_tile::index_t hidden_size,
    ck_tile::index_t intermediate_size,
    ck_tile::index_t topk,
    ck_tile::index_t gate_only)
{
    assert(sorted_token_ids_host.get_num_of_dimension() == 1);
    assert(sorted_weight_host.get_num_of_dimension() == 1);
    assert(sorted_expert_ids_host.get_num_of_dimension() == 1);
    assert(num_sorted_tiles_host.get_element_size() == 1);

    ck_tile::index_t num_sorted_tiles = num_sorted_tiles_host.mData[0];

    // TODO: better remove this in the future, or modify the token_id value
    auto get_topk_id = [&](ck_tile::index_t token_id_, ck_tile::index_t expert_id_) {
        for(ck_tile::index_t i_ = 0; i_ < topk; i_++)
        {
            if(token_ids_host(token_id_, i_) == expert_id_)
                return i_;
        }
        return -1; // TODO: not correct!!
    };

    ck_tile::HostTensor<AccDataType> out_topk_tokens({tokens, topk, hidden_size});

    int max_num_tokens_padded = topk * tokens + experts * (block_m - 1);
    // assert();

    // per sorted (padded) slot: run both expert gemms for one routed token
    auto f = [&](auto i_flatten) {
        ck_tile::index_t i_tile = i_flatten / block_m;
        if(i_tile >= num_sorted_tiles)
            return;
        ck_tile::index_t i_expert = sorted_expert_ids_host.mData[i_tile];
        ck_tile::index_t i_token  = sorted_token_ids_host.mData[i_flatten];
        if(i_token >= tokens)
            return; // padded slot
        ck_tile::index_t i_topk = get_topk_id(i_token, i_expert); // TODO: ugly
        auto weight             = sorted_weight_host.mData[i_flatten];

        ck_tile::HostTensor<AccDataType> acc_0({1, intermediate_size});
        // first gemm
        for(ck_tile::index_t i_n = 0; i_n < intermediate_size; i_n++)
        {
            AccDataType acc = static_cast<AccDataType>(0);
            for(ck_tile::index_t i_k = 0; i_k < hidden_size; i_k++)
            {
                acc += type_convert<AccDataType>(a_host(i_token, i_k)) *
                       type_convert<AccDataType>(g_host(i_expert, i_n, i_k));
            }
            acc_0(0, i_n) = acc;
        }

        ck_tile::HostTensor<AccDataType> y({1, hidden_size});
        if(gate_only)
        {
            assert(hidden_size == intermediate_size);
            for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
            {
                Activation{}(y(0, i_n), acc_0(0, i_n));
            }
        }
        else
        {
            assert(hidden_size * 2 == intermediate_size);
            for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
            {
                AccDataType tmp;
                Activation{}(tmp, acc_0(0, i_n));
                y(0, i_n) = tmp * acc_0(0, i_n + hidden_size); // TODO: elementwise mul
            }
        }

        // second gemm
        ck_tile::HostTensor<AccDataType> acc_1({1, hidden_size});
        for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
        {
            AccDataType acc = static_cast<AccDataType>(0);
            for(ck_tile::index_t i_k = 0; i_k < hidden_size; i_k++)
            {
                acc += y(0, i_k) * type_convert<AccDataType>(d_host(i_expert, i_n, i_k));
            }
            acc_1(0, i_n) = acc * weight; // multiply weight here
        }

        for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
        {
            out_topk_tokens(i_token, i_topk, i_n) = acc_1(0, i_n);
        }
    };
    make_ParallelTensorFunctor(f, max_num_tokens_padded)(std::thread::hardware_concurrency());

    // reduce: accumulate the topk expert outputs for each token
    auto r = [&](auto i_token) {
        for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
        {
            ODataType acc = type_convert<ODataType>(0);
            for(ck_tile::index_t i_topk = 0; i_topk < topk; i_topk++)
            {
                acc += out_topk_tokens(i_token, i_topk, i_n);
            }
            o_host(i_token, i_n) = acc;
        }
    };
    make_ParallelTensorFunctor(r, tokens)(std::thread::hardware_concurrency());

    (void)num_sorted_tiles_host;
    (void)sa_host;
    (void)sg_host;
    (void)sd_host;
    (void)sy_host;
}
} // namespace ck_tile
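
To make the sorting/padding bookkeeping documented in the header comment concrete, the following is a minimal standalone host-side sketch (illustration only, not part of ck_tile; the pad id and the one-padded-tile rule for empty experts are inferred from the worked example in the comment) that rebuilds sorted_token_ids and sorted_expert_ids for that example:

    // moe_sort_sketch.cpp : reproduce the worked example (experts=6, topk=3, M_a=4, tokens=5)
    #include <cstdio>
    #include <vector>

    int main()
    {
        const int experts = 6, topk = 3, block_m = 4, tokens = 5;
        const int pad_id  = 6; // any id >= tokens is skipped by the reference; the example uses 6
        const int topk_ids[5][3] = {{0, 3, 5}, {2, 3, 5}, {1, 3, 5}, {1, 2, 3}, {1, 3, 5}};

        // upper bound from the header comment: topk * tokens + experts * (block_m - 1) = 33
        const int max_num_tokens_padded = topk * tokens + experts * (block_m - 1);

        std::vector<int> sorted_token_ids, sorted_expert_ids;
        for(int e = 0; e < experts; e++)
        {
            // gather the tokens routed to expert e, preserving token order
            std::vector<int> ids;
            for(int t = 0; t < tokens; t++)
                for(int k = 0; k < topk; k++)
                    if(topk_ids[t][k] == e)
                        ids.push_back(t);

            // round up to whole block_m tiles; in the worked example even the
            // empty expert (exp-4) still occupies one fully padded tile
            int tiles = (static_cast<int>(ids.size()) + block_m - 1) / block_m;
            if(tiles == 0)
                tiles = 1;
            ids.resize(tiles * block_m, pad_id);

            for(int i = 0; i < tiles; i++)
                sorted_expert_ids.push_back(e);
            sorted_token_ids.insert(sorted_token_ids.end(), ids.begin(), ids.end());
        }

        std::printf("num_tokens_post_padded = %zu (bound %d), num_sorted_tiles = %zu\n",
                    sorted_token_ids.size(), max_num_tokens_padded, sorted_expert_ids.size());
        for(int id : sorted_token_ids)
            std::printf("%d ", id); // expect: 0 6 6 6 2 3 4 6 1 3 6 6 0 1 2 3 4 6 6 6 6 6 6 6 0 1 2 4
        std::printf("\n");
        return 0;
    }

Running it reproduces the arrays from the comment: 28 padded slots in 7 tiles, within the max_num_tokens_padded = 3*5 + 6*3 = 33 bound. sorted_weight would follow the same permutation, with '*' (don't-care) entries in the padded slots.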