Commit d0405504 authored by carlushuang
Browse files

update

parent 9d3cdd21
...@@ -105,9 +105,9 @@ auto create_args(int argc, char* argv[]) ...@@ -105,9 +105,9 @@ auto create_args(int argc, char* argv[])
.insert("prec_kw", "auto", "topk-weight data type. auto will set to fp32") .insert("prec_kw", "auto", "topk-weight data type. auto will set to fp32")
.insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant") .insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
.insert( .insert(
"gate_only", "0", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate") "gate_only", "1", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate")
.insert("balance", .insert("balance",
"1", "0",
"if set to 1, will try balance the expert in topk-ids(convenient for testing)") "if set to 1, will try balance the expert in topk-ids(convenient for testing)")
.insert("warmup", "5", "cold iter") .insert("warmup", "5", "cold iter")
.insert("repeat", "20", "hot iter"); .insert("repeat", "20", "hot iter");
...@@ -121,6 +121,8 @@ auto create_args(int argc, char* argv[]) ...@@ -121,6 +121,8 @@ auto create_args(int argc, char* argv[])
template <typename I, typename W, typename O, typename ST, typename SW, typename SQ, typename KW> template <typename I, typename W, typename O, typename ST, typename SW, typename SQ, typename KW>
bool run(const ck_tile::ArgParser& arg_parser) bool run(const ck_tile::ArgParser& arg_parser)
{ {
std::cout << "xxxx " << __LINE__ << std::flush << std::endl;
ck_tile::index_t tokens = arg_parser.get_int("t"); ck_tile::index_t tokens = arg_parser.get_int("t");
ck_tile::index_t experts = arg_parser.get_int("e"); ck_tile::index_t experts = arg_parser.get_int("e");
ck_tile::index_t topk = arg_parser.get_int("k"); ck_tile::index_t topk = arg_parser.get_int("k");
...@@ -150,7 +152,28 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -150,7 +152,28 @@ bool run(const ck_tile::ArgParser& arg_parser)
int balance = arg_parser.get_int("balance"); int balance = arg_parser.get_int("balance");
int tp = arg_parser.get_int("tp"); int tp = arg_parser.get_int("tp");
ck_tile::index_t shared_intermediate_size = intermediate_size * (gate_only ? 1 : 2) / tp; // w0 (Gate+Up or Gate only, N size)
ck_tile::index_t shared_intermediate_size_0 = intermediate_size * (gate_only ? 1 : 2) / tp;
// w1 (Down, N size)
ck_tile::index_t shared_intermediate_size_1 = intermediate_size / tp;
auto prec_str = [&]() {
auto base_str = prec_i;
if(prec_i != prec_w)
base_str += "x" + prec_w;
if(prec_i != prec_o)
base_str += "=" + prec_o;
if(fused_quant != 0)
{
base_str += std::string("(") + prec_st + "|" + prec_sw + "|" + prec_sq + ")";
}
return base_str;
}();
std::cout << "[" << prec_str << "]"
<< " t:" << tokens << ", e:" << experts << ", k:" << topk << ", st:" << stride
<< ", hidden:" << hidden_size << ", interm:" << intermediate_size << ", tp:" << tp
<< ", go:" << gate_only << ", q:" << fused_quant << std::flush;
using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>; using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;
using ADataType = typename TypeConfig::ADataType; using ADataType = typename TypeConfig::ADataType;
...@@ -167,36 +190,36 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -167,36 +190,36 @@ bool run(const ck_tile::ArgParser& arg_parser)
// host verify // host verify
ck_tile::HostTensor<ADataType> a_host({tokens, hidden_size}, {stride, 1}); ck_tile::HostTensor<ADataType> a_host({tokens, hidden_size}, {stride, 1});
ck_tile::HostTensor<GDataType> g_host({experts, shared_intermediate_size, hidden_size}); ck_tile::HostTensor<GDataType> g_host({experts, shared_intermediate_size_0, hidden_size});
ck_tile::HostTensor<DDataType> d_host({experts, intermediate_size, hidden_size}); ck_tile::HostTensor<DDataType> d_host({experts, shared_intermediate_size_1, hidden_size});
ck_tile::HostTensor<ODataType> o_host({tokens, hidden_size}, {stride, 1}); ck_tile::HostTensor<ODataType> o_host({tokens, hidden_size}, {stride, 1});
ck_tile::HostTensor<AScaleDataType> sa_host({tokens}); ck_tile::HostTensor<AScaleDataType> sa_host({tokens});
ck_tile::HostTensor<GScaleDataType> sg_host({shared_intermediate_size}); ck_tile::HostTensor<GScaleDataType> sg_host({shared_intermediate_size_0});
ck_tile::HostTensor<DScaleDataType> sd_host({intermediate_size}); ck_tile::HostTensor<DScaleDataType> sd_host({shared_intermediate_size_1});
ck_tile::HostTensor<YSmoothScaleDataType> sy_host({intermediate_size}); // smooth-quant ck_tile::HostTensor<YSmoothScaleDataType> sy_host({shared_intermediate_size_1}); // smooth-quant
ck_tile::HostTensor<IndexDataType> topk_ids_host({tokens, topk}); // to be sort ck_tile::HostTensor<IndexDataType> topk_ids_host({tokens, topk}); // to be sort
ck_tile::HostTensor<TopkWeightDataType> topk_weight_host({tokens, topk}); // to be sort ck_tile::HostTensor<TopkWeightDataType> topk_weight_host({tokens, topk}); // to be sort
int max_num_tokens_padded = topk * tokens + experts * (block_m - 1); int max_num_tokens_padded = topk * tokens + experts * block_m - topk;
ck_tile::HostTensor<IndexDataType> sorted_token_ids_host({max_num_tokens_padded}); ck_tile::HostTensor<IndexDataType> sorted_token_ids_host({max_num_tokens_padded});
ck_tile::HostTensor<TopkWeightDataType> sorted_weight_host({max_num_tokens_padded}); ck_tile::HostTensor<TopkWeightDataType> sorted_weight_host({max_num_tokens_padded});
ck_tile::HostTensor<IndexDataType> sorted_expert_ids_host( ck_tile::HostTensor<IndexDataType> sorted_expert_ids_host(
{(max_num_tokens_padded + block_m - 1) / block_m}); {(max_num_tokens_padded + block_m - 1) / block_m});
ck_tile::HostTensor<IndexDataType> num_sorted_tiles_host({1}); ck_tile::HostTensor<IndexDataType> num_sorted_tiles_host({1});
// permute weight
ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w);
ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w);
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host); ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
ck_tile::FillUniformDistribution<GDataType>{-.5f, .5f}(g_perm_host); ck_tile::FillUniformDistribution<GDataType>{-.5f, .5f}(g_host);
ck_tile::FillUniformDistribution<DDataType>{-.5f, .5f}(d_perm_host); ck_tile::FillUniformDistribution<DDataType>{-.5f, .5f}(d_host);
ck_tile::FillUniformDistribution<AScaleDataType>{-.5f, .5f}(sa_host); ck_tile::FillUniformDistribution<AScaleDataType>{-.5f, .5f}(sa_host);
ck_tile::FillUniformDistribution<GScaleDataType>{-.5f, .5f}(sg_host); ck_tile::FillUniformDistribution<GScaleDataType>{-.5f, .5f}(sg_host);
ck_tile::FillUniformDistribution<DScaleDataType>{-.5f, .5f}(sd_host); ck_tile::FillUniformDistribution<DScaleDataType>{-.5f, .5f}(sd_host);
ck_tile::FillUniformDistribution<YSmoothScaleDataType>{-.5f, .5f}(sy_host); ck_tile::FillUniformDistribution<YSmoothScaleDataType>{-.5f, .5f}(sy_host);
ck_tile::FillUniformDistribution<TopkWeightDataType>{-.5f, .5f}(topk_weight_host); ck_tile::FillUniformDistribution<TopkWeightDataType>{-.5f, .5f}(topk_weight_host);
// permute weight
ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1);
ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1);
// do moe sorting // do moe sorting
if(balance) if(balance)
{ {
...@@ -223,6 +246,10 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -223,6 +246,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
num_sorted_tiles_host.mData[0], num_sorted_tiles_host.mData[0],
experts, experts,
block_m); block_m);
// std::cout << sorted_token_ids_host << std::endl;
// std::cout << num_sorted_tiles_host << std::endl;
// done, preparing GPU buffer // done, preparing GPU buffer
ck_tile::DeviceMem a_buf(a_host); ck_tile::DeviceMem a_buf(a_host);
ck_tile::DeviceMem g_perm_buf(g_perm_host); ck_tile::DeviceMem g_perm_buf(g_perm_host);
...@@ -238,24 +265,6 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -238,24 +265,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem sorted_expert_ids_buf(sorted_expert_ids_host); ck_tile::DeviceMem sorted_expert_ids_buf(sorted_expert_ids_host);
ck_tile::DeviceMem num_sorted_tiles_buf(num_sorted_tiles_host); ck_tile::DeviceMem num_sorted_tiles_buf(num_sorted_tiles_host);
auto prec_str = [&]() {
auto base_str = prec_i;
if(prec_i != prec_w)
base_str += "x" + prec_w;
if(prec_i != prec_o)
base_str += "=" + prec_o;
if(fused_quant != 0)
{
base_str += std::string("(") + prec_st + "|" + prec_sw + "|" + prec_sq + ")";
}
return base_str;
}();
std::cout << "[" << prec_str << "]"
<< " t:" << tokens << ", e:" << experts << ", k:" << topk << ", st:" << stride
<< ", hidden:" << hidden_size << ", interm:" << intermediate_size << ", tp:" << tp
<< ", go:" << gate_only << ", q:" << fused_quant << std::flush;
fused_moegemm_traits traits{prec_i, fused_moegemm_traits traits{prec_i,
prec_w, prec_w,
prec_o, prec_o,
...@@ -280,7 +289,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -280,7 +289,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
sorted_expert_ids_buf.GetDeviceBuffer(), sorted_expert_ids_buf.GetDeviceBuffer(),
num_sorted_tiles_buf.GetDeviceBuffer(), num_sorted_tiles_buf.GetDeviceBuffer(),
hidden_size, hidden_size,
intermediate_size, shared_intermediate_size_0,
tokens, tokens,
experts, experts,
topk, topk,
...@@ -302,8 +311,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -302,8 +311,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
float gb_per_sec = num_byte / 1.E6 / ave_time; float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush; std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush;
#else #else
std::size_t flop_gemm_0 = 2 * tokens * topk * shared_intermediate_size * hidden_size; std::size_t flop_gemm_0 = 2 * tokens * topk * shared_intermediate_size_0 * hidden_size;
std::size_t flop_gemm_1 = 2 * tokens * topk * hidden_size * hidden_size; std::size_t flop_gemm_1 = 2 * tokens * topk * shared_intermediate_size_1 * hidden_size;
double tflops = (flop_gemm_0 + flop_gemm_1) / (static_cast<double>(ave_time) * 1e-3) / 1e12; double tflops = (flop_gemm_0 + flop_gemm_1) / (static_cast<double>(ave_time) * 1e-3) / 1e12;
// float gb_per_sec = num_byte / 1.E6 / ave_time; // float gb_per_sec = num_byte / 1.E6 / ave_time;
...@@ -331,7 +340,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -331,7 +340,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
tokens, tokens,
experts, experts,
hidden_size, hidden_size,
intermediate_size, shared_intermediate_size_0,
topk, topk,
gate_only); gate_only);
......
...@@ -590,6 +590,22 @@ struct HostTensor ...@@ -590,6 +590,22 @@ struct HostTensor
size() * FromSize / ToSize}; size() * FromSize / ToSize};
} }
friend std::ostream& operator<<(std::ostream& os, const HostTensor<T>& t)
{
os << t.mDesc;
os << "[";
for(typename Data::size_type idx = 0; idx < t.mData.size(); ++idx)
{
if(0 < idx)
{
os << ", ";
}
os << t.mData[idx];
}
os << "]";
return os;
}
Descriptor mDesc; Descriptor mDesc;
Data mData; Data mData;
}; };
......
...@@ -21,6 +21,7 @@ namespace ck_tile { ...@@ -21,6 +21,7 @@ namespace ck_tile {
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]] // weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
// //
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1) // max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
// * this could be larger than actual, since actual tokens are on GPU // * this could be larger than actual, since actual tokens are on GPU
// //
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, // sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6,
...@@ -94,7 +95,7 @@ void reference_fused_moe( ...@@ -94,7 +95,7 @@ void reference_fused_moe(
ck_tile::HostTensor<AccDataType> out_topk_tokens({tokens, topk, hidden_size}); ck_tile::HostTensor<AccDataType> out_topk_tokens({tokens, topk, hidden_size});
int max_num_tokens_padded = topk * tokens + experts * (block_m - 1); int max_num_tokens_padded = topk * tokens + experts * block_m - topk;
// assert(); // assert();
auto f = [&](auto i_flatten) { auto f = [&](auto i_flatten) {
ck_tile::index_t i_tile = i_flatten / block_m; ck_tile::index_t i_tile = i_flatten / block_m;
......
...@@ -6,10 +6,14 @@ ...@@ -6,10 +6,14 @@
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp" #include "ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp" #include "ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp" #include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp"
#include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp" #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp" #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp" #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp" #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp" #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/tensor_layout.hpp"
...@@ -99,7 +99,7 @@ struct FusedMoeGemmHostArgs ...@@ -99,7 +99,7 @@ struct FusedMoeGemmHostArgs
const void* num_sorted_tiles_ptr; // [1] const void* num_sorted_tiles_ptr; // [1]
index_t hidden_size; // k index_t hidden_size; // k
index_t intermediate_size; // n (TP slice this) index_t intermediate_size; // n / TP, for Gate. if Gate+Up, Down need divide by 2
index_t num_tokens; // input number of tokens for current iteration index_t num_tokens; // input number of tokens for current iteration
index_t num_experts; // number of groups index_t num_experts; // number of groups
index_t topk; // need this? index_t topk; // need this?
...@@ -176,7 +176,7 @@ struct FusedMoeGemmKernel ...@@ -176,7 +176,7 @@ struct FusedMoeGemmKernel
const void* num_sorted_tiles_ptr; const void* num_sorted_tiles_ptr;
index_t hidden_size; // k index_t hidden_size; // k
index_t intermediate_size; // n (TP slice this) index_t intermediate_size; // n / TP, for Gate. if Gate+Up, Down need divide by 2
index_t num_tokens; // input number of tokens for current iteration index_t num_tokens; // input number of tokens for current iteration
index_t num_experts; // number of groups index_t num_experts; // number of groups
index_t topk; // need this? index_t topk; // need this?
...@@ -198,7 +198,7 @@ struct FusedMoeGemmKernel ...@@ -198,7 +198,7 @@ struct FusedMoeGemmKernel
{ {
constexpr index_t block_m = BlockShape::Block_M0; constexpr index_t block_m = BlockShape::Block_M0;
int max_num_tokens_padded = int max_num_tokens_padded =
hargs.topk * hargs.num_tokens + hargs.num_experts * (block_m - 1); hargs.topk * hargs.num_tokens + hargs.num_experts * block_m - hargs.topk;
return Partitioner::GridSize(max_num_tokens_padded, hargs.intermediate_size); return Partitioner::GridSize(max_num_tokens_padded, hargs.intermediate_size);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment