Commit d4a0a8ee authored by letaoqin
Browse files

add gelu and weight

parent d846292c
...@@ -264,6 +264,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -264,6 +264,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
{(max_num_tokens_padded + block_m - 1) / block_m}); {(max_num_tokens_padded + block_m - 1) / block_m});
ck_tile::HostTensor<IndexDataType> num_sorted_tiles_host({1}); ck_tile::HostTensor<IndexDataType> num_sorted_tiles_host({1});
sorted_token_ids_host.SetValue(max_num_tokens_padded);
if(init == 0) if(init == 0)
{ {
ck_tile::FillStepRange<ADataType>{-.5f, .5f, 0.01f}(a_host); ck_tile::FillStepRange<ADataType>{-.5f, .5f, 0.01f}(a_host);
...@@ -280,9 +281,6 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -280,9 +281,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f, seed, true}(a_host); ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f, seed, true}(a_host);
ck_tile::FillUniformDistribution<GDataType>{-.5f, .5f, seed, true}(g_host); ck_tile::FillUniformDistribution<GDataType>{-.5f, .5f, seed, true}(g_host);
ck_tile::FillUniformDistribution<DDataType>{-.5f, .5f, seed, true}(d_host); ck_tile::FillUniformDistribution<DDataType>{-.5f, .5f, seed, true}(d_host);
// ck_tile::FillConstant<ADataType>{1}(a_host);
// ck_tile::FillConstant<GDataType>{1}(g_host);
// ck_tile::FillConstant<DDataType>{1}(d_host);
ck_tile::FillUniformDistribution<AScaleDataType>{-.5f, .5f, seed, true}(sa_host); ck_tile::FillUniformDistribution<AScaleDataType>{-.5f, .5f, seed, true}(sa_host);
ck_tile::FillUniformDistribution<GScaleDataType>{-.5f, .5f, seed, true}(sg_host); ck_tile::FillUniformDistribution<GScaleDataType>{-.5f, .5f, seed, true}(sg_host);
ck_tile::FillUniformDistribution<DScaleDataType>{-.5f, .5f, seed, true}(sd_host); ck_tile::FillUniformDistribution<DScaleDataType>{-.5f, .5f, seed, true}(sd_host);
...@@ -301,6 +299,18 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -301,6 +299,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::FillNormalDistribution<YSmoothScaleDataType>{0.f, 1.f, seed, true}(sy_host); ck_tile::FillNormalDistribution<YSmoothScaleDataType>{0.f, 1.f, seed, true}(sy_host);
ck_tile::FillNormalDistribution<TopkWeightDataType>{0.f, 1.f, seed, true}(topk_weight_host); ck_tile::FillNormalDistribution<TopkWeightDataType>{0.f, 1.f, seed, true}(topk_weight_host);
} }
else if(init == 3)
{
ck_tile::FillConstant<ADataType>{1}(a_host);
ck_tile::FillConstant<GDataType>{1}(g_host);
ck_tile::FillConstant<DDataType>{1}(d_host);
ck_tile::FillUniformDistribution<AScaleDataType>{-.5f, .5f, seed, true}(sa_host);
ck_tile::FillUniformDistribution<GScaleDataType>{-.5f, .5f, seed, true}(sg_host);
ck_tile::FillUniformDistribution<DScaleDataType>{-.5f, .5f, seed, true}(sd_host);
ck_tile::FillUniformDistribution<YSmoothScaleDataType>{-.5f, .5f, seed, true}(sy_host);
ck_tile::FillUniformDistribution<TopkWeightDataType>{0.0f, 1.0f, seed, true}(
topk_weight_host);
}
// permute weight // permute weight
// ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1); // ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1);
...@@ -393,7 +403,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -393,7 +403,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::cout << sorted_expert_ids_host << std::endl; std::cout << sorted_expert_ids_host << std::endl;
// std::cout << topk_weight_host << std::endl; // std::cout << topk_weight_host << std::endl;
// std::cout << sorted_weight_host << std::endl; std::cout << sorted_weight_host << std::endl;
// done, preparing GPU buffer // done, preparing GPU buffer
ck_tile::DeviceMem a_buf(a_host); ck_tile::DeviceMem a_buf(a_host);
ck_tile::DeviceMem g_perm_buf(g_host); ck_tile::DeviceMem g_perm_buf(g_host);
...@@ -490,7 +500,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -490,7 +500,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
auto o_dev = o_buf.ToHost<ODataType>(); auto o_dev = o_buf.ToHost<ODataType>();
auto c_dev = c_buf.ToHost<ADataType>(); auto c_dev = c_buf.ToHost<ADataType>();
std::cout << std::endl; std::cout << std::endl;
std::cout << o_dev << std::endl; // std::cout << o_dev << std::endl;
// std::cout << c_dev << std::endl; // std::cout << c_dev << std::endl;
// int count = 0; // int count = 0;
// std::cout << "["; // std::cout << "[";
......
...@@ -349,6 +349,7 @@ struct HostTensor ...@@ -349,6 +349,7 @@ struct HostTensor
// void SetZero() { ck_tile::ranges::fill<T>(mData, 0); } // void SetZero() { ck_tile::ranges::fill<T>(mData, 0); }
void SetZero() { std::fill(mData.begin(), mData.end(), 0); } void SetZero() { std::fill(mData.begin(), mData.end(), 0); }
void SetValue(int value) { std::fill(mData.begin(), mData.end(), value); }
template <typename F> template <typename F>
void ForEach_impl(F&& f, std::vector<size_t>& idx, size_t rank) void ForEach_impl(F&& f, std::vector<size_t>& idx, size_t rank)
......
...@@ -252,12 +252,12 @@ struct FusedMoeGemmGlKernel ...@@ -252,12 +252,12 @@ struct FusedMoeGemmGlKernel
index_t idx_n0 = index_t idx_n0 =
__builtin_amdgcn_readfirstlane(intermediate_tile_id * BlockShape::Block_N0); __builtin_amdgcn_readfirstlane(intermediate_tile_id * BlockShape::Block_N0);
const auto a_coord = Pipeline::GetACoord(); // 2d thread offset, [i_row, i_col] // const auto a_coord = Pipeline::GetACoord(); // 2d thread offset, [i_row, i_col]
const auto sorted_token_id = a_coord[number<0>{}] + idx_m0; // start block_m // const auto sorted_token_id = a_coord[number<0>{}] + idx_m0; // start block_m
// position // // position
auto topk_weight = // auto topk_weight =
reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr)[sorted_token_id]; // reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr)[sorted_token_id];
const index_t* sorted_token_ids_ptr = const index_t* sorted_token_ids_ptr =
reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr); reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr);
...@@ -374,12 +374,28 @@ struct FusedMoeGemmGlKernel ...@@ -374,12 +374,28 @@ struct FusedMoeGemmGlKernel
return o_window_; return o_window_;
}(); }();
const auto w_window = [&]() {
const TopkWeightDataType* w_ptr = reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr);
const auto w_view_ = make_naive_tensor_view<address_space_enum::global>(
w_ptr,
make_tuple(kargs.max_num_tokens_padded),
make_tuple(1),
number<1>{},
number<1>{});
const auto w_window_ = make_tile_window(
w_view_,
make_tuple(number<BlockShape::Block_M0>{}),
{idx_m0});
return w_window_;
}();
// do compute yeah // do compute yeah
Pipeline{}(a_window, Pipeline{}(a_window,
g_window, g_window,
d_window, d_window,
w_window,
o_window, o_window,
topk_weight,
smem, smem,
kargs.hidden_size, kargs.hidden_size,
kargs.intermediate_size, kargs.intermediate_size,
......
...@@ -89,14 +89,6 @@ struct FusedMoeGemmPipeline_General ...@@ -89,14 +89,6 @@ struct FusedMoeGemmPipeline_General
// return Policy::template GetSmemSize<Problem>(); // return Policy::template GetSmemSize<Problem>();
} }
// this is the thread-offset along row/col
CK_TILE_HOST_DEVICE static auto GetACoord()
{
constexpr auto a_dist = Policy::template MakeGlobalTileDistribution_A<Problem>();
const auto a_coord = a_dist.calculate_index();
return a_coord;
}
template <typename T> template <typename T>
CK_TILE_HOST_DEVICE static void CK_TILE_HOST_DEVICE static void
PrintMem(T& tensor, const char* pstr, unsigned int threadid = 0, unsigned int blockid = 0) PrintMem(T& tensor, const char* pstr, unsigned int threadid = 0, unsigned int blockid = 0)
...@@ -129,20 +121,21 @@ struct FusedMoeGemmPipeline_General ...@@ -129,20 +121,21 @@ struct FusedMoeGemmPipeline_General
typename GWindow, typename GWindow,
typename DWindow, typename DWindow,
typename OWindow, typename OWindow,
typename CWindow> typename CWindow,
typename WWindow>
CK_TILE_DEVICE auto operator()(const AWindow& a_window_, CK_TILE_DEVICE auto operator()(const AWindow& a_window_,
const GWindow& g_window_, const GWindow& g_window_,
const DWindow& d_window_, const DWindow& d_window_,
const WWindow& w_window_,
OWindow& o_window_, OWindow& o_window_,
TopkWeightDataType topk_weight,
CK_TILE_LDS_ADDR void* smem, CK_TILE_LDS_ADDR void* smem,
index_t hidden_size, index_t hidden_size,
index_t /*intermediate_size*/, index_t /*intermediate_size*/,
CWindow& c_window_) CWindow& c_window_)
{ {
ignore = topk_weight;
ignore = c_window_; ignore = c_window_;
ignore = hidden_size; ignore = hidden_size;
ignore = w_window_;
CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem); CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem);
CK_TILE_LDS_ADDR GDataType* smem_1 = reinterpret_cast<CK_TILE_LDS_ADDR GDataType*>( CK_TILE_LDS_ADDR GDataType* smem_1 = reinterpret_cast<CK_TILE_LDS_ADDR GDataType*>(
smem_0 + GetSmemSizeA() / sizeof(ADataType)); smem_0 + GetSmemSizeA() / sizeof(ADataType));
...@@ -233,8 +226,8 @@ struct FusedMoeGemmPipeline_General ...@@ -233,8 +226,8 @@ struct FusedMoeGemmPipeline_General
PrintMem(s_acc, "S", 0); PrintMem(s_acc, "S", 0);
#endif #endif
// gelu // gelu
// const auto activation = ck_tile::element_wise::Gelu{}; const auto activation = ck_tile::element_wise::Gelu{};
// tile_elementwise_inout(activation, s_acc, s_acc); tile_elementwise_inout(activation, s_acc, s_acc);
// cast data to YDataType // cast data to YDataType
auto y_pre = cast_tile<YDataType>(s_acc); auto y_pre = cast_tile<YDataType>(s_acc);
...@@ -260,6 +253,28 @@ struct FusedMoeGemmPipeline_General ...@@ -260,6 +253,28 @@ struct FusedMoeGemmPipeline_General
constexpr auto gemm_1 = Policy::template GetBlockGemm1<Problem>(); constexpr auto gemm_1 = Policy::template GetBlockGemm1<Problem>();
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
auto o_acc = OaccBlockTileType{}; auto o_acc = OaccBlockTileType{};
constexpr auto w_dstr =
make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
s_acc.get_tile_distribution().get_static_tile_distribution_encoding(), sequence<1>{}));
auto w_global_to_dram_window = make_tile_window(
w_window_.get_bottom_tensor_view(),
make_tuple(number<BlockShape::Block_M0>{}),
w_window_.get_window_origin(),
w_dstr);
auto w = load_tile(w_global_to_dram_window);
float weight = type_convert<float>(w.get_thread_buffer()[0]);
#if 0
constexpr index_t w_buffer_size = decltype(w)::get_thread_buffer_size();
if(threadIdx.x == 1 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{
for(int i = 0; i < w_buffer_size; i++)
{
printf("\n len: %d, w[%d]: %f weight: %f", w_buffer_size, i, type_convert<float>(w.get_thread_buffer()[i]), topk_weight);
}
}
#endif
ignore = w;
// y data // y data
auto bridge_llds_win = auto bridge_llds_win =
make_tile_window(bridge_lds_view, make_tile_window(bridge_lds_view,
...@@ -308,7 +323,7 @@ struct FusedMoeGemmPipeline_General ...@@ -308,7 +323,7 @@ struct FusedMoeGemmPipeline_General
Policy::template MakeGlobalTileDistribution_O<Problem>()); Policy::template MakeGlobalTileDistribution_O<Problem>());
auto save_o = [&]() { auto save_o = [&]() {
if(blockIdx.x == 0 && (blockIdx.y == 0 || blockIdx.y == 1) && blockIdx.z == 0) //if(blockIdx.x == 0 && (blockIdx.y == 0 || blockIdx.y == 1) && blockIdx.z == 0)
{ {
if(threadIdx.x < 64) if(threadIdx.x < 64)
{ {
...@@ -352,8 +367,8 @@ struct FusedMoeGemmPipeline_General ...@@ -352,8 +367,8 @@ struct FusedMoeGemmPipeline_General
gemm_1(o_acc, y, d); gemm_1(o_acc, y, d);
// block_sync_lds(); // block_sync_lds();
// tile_elementwise_inout( tile_elementwise_inout(
// [&topk_weight](auto& x) { x = x * type_convert<float>(topk_weight); }, o_acc); [&weight](auto& x) { x = x * type_convert<float>(weight); }, o_acc);
auto o = cast_tile<ODataType>(o_acc); auto o = cast_tile<ODataType>(o_acc);
store_tile(o_alds_win, o); store_tile(o_alds_win, o);
block_sync_lds(); block_sync_lds();
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment