"vscode:/vscode.git/clone" did not exist on "231e3f8dda2b016f0d1bc0922087d304fae2d7ed"
Commit 072dfbfe authored by letaoqin's avatar letaoqin
Browse files

gemm0 debugged

parent 69114f25
...@@ -241,6 +241,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -241,6 +241,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host); ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
ck_tile::FillUniformDistribution<GDataType>{-.5f, .5f}(g_host); ck_tile::FillUniformDistribution<GDataType>{-.5f, .5f}(g_host);
// ck_tile::FillConstant<ADataType>{1}(a_host);
// ck_tile::FillConstant<GDataType>{1}(g_host);
//ck_tile::FillStepRange<GDataType>{0.0f, 32.0f*128,1.0f}(g_host);
ck_tile::FillUniformDistribution<DDataType>{-.5f, .5f}(d_host); ck_tile::FillUniformDistribution<DDataType>{-.5f, .5f}(d_host);
ck_tile::FillUniformDistribution<AScaleDataType>{-.5f, .5f}(sa_host); ck_tile::FillUniformDistribution<AScaleDataType>{-.5f, .5f}(sa_host);
ck_tile::FillUniformDistribution<GScaleDataType>{-.5f, .5f}(sg_host); ck_tile::FillUniformDistribution<GScaleDataType>{-.5f, .5f}(sg_host);
...@@ -282,7 +285,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -282,7 +285,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
// output_matrix_2d(a_host, tokens, hidden_size); // output_matrix_2d(a_host, tokens, hidden_size);
std::cout << sorted_token_ids_host << std::endl; std::cout << sorted_token_ids_host << std::endl;
// std::cout << num_sorted_tiles_host << std::endl; // std::cout << num_sorted_tiles_host << std::endl;
output_matrix_3d(g_host, experts, shared_intermediate_size_0, hidden_size); // output_matrix_3d(g_host, experts, shared_intermediate_size_0, hidden_size);
std::cout << sorted_expert_ids_host << std::endl; std::cout << sorted_expert_ids_host << std::endl;
// std::cout << topk_weight_host << std::endl; // std::cout << topk_weight_host << std::endl;
...@@ -290,8 +293,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -290,8 +293,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
// done, preparing GPU buffer // done, preparing GPU buffer
ck_tile::DeviceMem a_buf(a_host); ck_tile::DeviceMem a_buf(a_host);
ck_tile::DeviceMem g_perm_buf(g_perm_host); ck_tile::DeviceMem g_perm_buf(g_host);
ck_tile::DeviceMem d_perm_buf(d_perm_host); ck_tile::DeviceMem d_perm_buf(d_host);
ck_tile::DeviceMem sa_buf(sa_host); ck_tile::DeviceMem sa_buf(sa_host);
ck_tile::DeviceMem sg_buf(sg_host); ck_tile::DeviceMem sg_buf(sg_host);
ck_tile::DeviceMem sd_buf(sd_host); ck_tile::DeviceMem sd_buf(sd_host);
......
...@@ -278,6 +278,7 @@ struct FillConstant ...@@ -278,6 +278,7 @@ struct FillConstant
{ {
T value_{0}; T value_{0};
FillConstant(float value):value_(ck_tile::type_convert<T>(value)){}
template <typename ForwardIter> template <typename ForwardIter>
void operator()(ForwardIter first, ForwardIter last) const void operator()(ForwardIter first, ForwardIter last) const
{ {
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp" #include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise.hpp" #include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/core/tensor/tensor_view.hpp"
#include <string> #include <string>
#include <type_traits> #include <type_traits>
...@@ -298,7 +299,6 @@ struct FusedMoeGemmGlKernel ...@@ -298,7 +299,6 @@ struct FusedMoeGemmGlKernel
return a_window_; return a_window_;
}(); }();
// TODO: gtile using NSub to have less register pressure
const auto g_window = [&]() { const auto g_window = [&]() {
const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) + const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_0; static_cast<long_index_t>(expert_id) * expert_stride_0;
...@@ -313,6 +313,17 @@ struct FusedMoeGemmGlKernel ...@@ -313,6 +313,17 @@ struct FusedMoeGemmGlKernel
g_view_, g_view_,
make_tuple(number<BlockShape::Block_N0>{}, number<BlockShape::Block_K0>{}), make_tuple(number<BlockShape::Block_N0>{}, number<BlockShape::Block_K0>{}),
{idx_n0, 0}); {idx_n0, 0});
// if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
// {
// for(int i = 0; i < 16; i++)
// {
// printf("in G index is %d , value is: %f\n",
// i,
// ck_tile::type_convert<float>(g_ptr[i]));
// }
// }
return g_window_; return g_window_;
}(); }();
......
...@@ -115,6 +115,7 @@ struct FusedMoeGemmPipeline_General ...@@ -115,6 +115,7 @@ struct FusedMoeGemmPipeline_General
a_window_.get_window_origin(), a_window_.get_window_origin(),
Policy::template MakeGlobalTileDistribution_A<Problem>()); Policy::template MakeGlobalTileDistribution_A<Problem>());
// load g to register
auto g_global_to_dram_window = make_tile_window( auto g_global_to_dram_window = make_tile_window(
g_window_.get_bottom_tensor_view(), g_window_.get_bottom_tensor_view(),
make_tuple(number<BlockShape::Block_N0>{}, number<BlockShape::Block_K0>{}), make_tuple(number<BlockShape::Block_N0>{}, number<BlockShape::Block_K0>{}),
...@@ -153,27 +154,26 @@ struct FusedMoeGemmPipeline_General ...@@ -153,27 +154,26 @@ struct FusedMoeGemmPipeline_General
} }
#endif #endif
// load g to register
auto g_dram_block = load_tile(g_global_to_dram_window); auto g_dram_block = load_tile(g_global_to_dram_window);
#if 0 #if 0
{ {
constexpr auto a_spans = decltype(g_dram_block)::get_distributed_spans(); constexpr auto g_spans = decltype(g_dram_block)::get_distributed_spans();
int counter = 0; int counter = 0;
sweep_tile_span(a_spans[number<0>{}], [&](auto idxn) { sweep_tile_span(g_spans[number<0>{}], [&](auto idxn) {
sweep_tile_span(a_spans[number<1>{}], [&](auto idxk) { sweep_tile_span(g_spans[number<1>{}], [&](auto idxk) {
constexpr auto i_j_idx = make_tuple(idxn, idxk); constexpr auto i_j_idx = make_tuple(idxn, idxk);
const auto tile_idx = get_x_indices_from_distributed_indices(
g_dram_block.get_tile_distribution(), i_j_idx);
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{ {
counter = counter + 1; counter = counter + 1;
index_t idn_0 = idxn.impl_.at(0); const auto row = tile_idx.at(number<0>{});
index_t idk_0 = idxk.impl_.at(0); const auto col = tile_idx.at(number<1>{});
index_t idk_1 = idxk.impl_.at(1); printf("in G row is %d , col is %d, counter is %d, value is: %f"
printf("in A idn is %d , idk_0 is %d idk_1 is %d, counter is %d, value is: " " \n",
"%f \n", row,
idn_0, col,
idk_0,
idk_1,
counter, counter,
ck_tile::type_convert<float>(g_dram_block(i_j_idx))); ck_tile::type_convert<float>(g_dram_block(i_j_idx)));
} }
...@@ -185,7 +185,7 @@ struct FusedMoeGemmPipeline_General ...@@ -185,7 +185,7 @@ struct FusedMoeGemmPipeline_General
clear_tile(s_acc); // initialize C clear_tile(s_acc); // initialize C
constexpr index_t kK0 = BlockShape::Block_K0; constexpr index_t kK0 = BlockShape::Block_K0;
const index_t k0_loops = ck_tile::integer_divide_ceil(intermediate_size, kK0); const index_t k0_loops = ck_tile::integer_divide_ceil(hidden_size, kK0);
index_t iCounter0 = k0_loops - 1; index_t iCounter0 = k0_loops - 1;
while(iCounter0 > 0) while(iCounter0 > 0)
{ {
...@@ -208,25 +208,25 @@ struct FusedMoeGemmPipeline_General ...@@ -208,25 +208,25 @@ struct FusedMoeGemmPipeline_General
block_sync_lds(); block_sync_lds();
gemm_0(s_acc, a_lds_win, g_dram_block); gemm_0(s_acc, a_lds_win, g_dram_block);
} }
#if 1 #if 0
{ {
constexpr auto a_spans = decltype(s_acc)::get_distributed_spans(); constexpr auto a_spans = decltype(s_acc)::get_distributed_spans();
int counter = 0; int counter = 0;
//a_spans[0] = 1; //a_spans[0] = 1;
sweep_tile_span(a_spans[number<0>{}], [&](auto idxm) { sweep_tile_span(a_spans[number<0>{}], [&](auto idxm) {
sweep_tile_span(a_spans[number<1>{}], [&](auto idxn) { sweep_tile_span(a_spans[number<1>{}], [&](auto idxn) {
constexpr auto i_j_idx = make_tuple(idxn, idxn); constexpr auto i_j_idx = make_tuple(idxm, idxn);
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) const auto tile_idx = get_x_indices_from_distributed_indices(
g_dram_block.get_tile_distribution(), i_j_idx);
if(threadIdx.x == 1 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{ {
counter = counter + 1; counter = counter + 1;
index_t idm_0 = idxm.impl_.at(0); const auto row = tile_idx.at(number<0>{});
index_t idn_0 = idxn.impl_.at(0); const auto col = tile_idx.at(number<1>{});
index_t idn_1 = idxn.impl_.at(1); printf("in c row is %d , col is %d, counter is %d, value is: "
printf("in A idn is %d , idn_0 is %d, idn_1 is %d, counter is %d, value is: "
"%f \n", "%f \n",
idm_0, row,
idn_0, col,
idn_1,
counter, counter,
ck_tile::type_convert<float>(s_acc(i_j_idx))); ck_tile::type_convert<float>(s_acc(i_j_idx)));
} }
......
...@@ -186,14 +186,6 @@ struct FusedMoeGemmPipelineGeneralPolicy ...@@ -186,14 +186,6 @@ struct FusedMoeGemmPipelineGeneralPolicy
constexpr auto g_block_dstr_encode = detail::make_embed_tile_distribution_encoding( constexpr auto g_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
g_outer_dstr_enc, typename WG::BWarpDstrEncoding{}); g_outer_dstr_enc, typename WG::BWarpDstrEncoding{});
// constexpr auto g_block_dstr_encode = tile_distribution_encoding<
// sequence<1>,
// tuple<sequence<1, 4, 32>, sequence<4, 2, 4>>,
// tuple<sequence<0, 1>, sequence<2, 1>>,
// tuple<sequence<0, 1>, sequence<1, 2>>,
// sequence<1, 2, 2>,
// sequence<0, 0, 2>>{};
return make_static_tile_distribution(g_block_dstr_encode); return make_static_tile_distribution(g_block_dstr_encode);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment