Commit eab497e8 authored by letaoqin's avatar letaoqin
Browse files

format

parent 1476d7bb
...@@ -208,7 +208,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -208,7 +208,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<IndexDataType> num_sorted_tiles_host({1}); ck_tile::HostTensor<IndexDataType> num_sorted_tiles_host({1});
#if 0 #if 0
# if 1 #if 1
ck_tile::FillStepRange<ADataType>{-.5f, .5f, 0.01f}(a_host); ck_tile::FillStepRange<ADataType>{-.5f, .5f, 0.01f}(a_host);
ck_tile::FillStepRange<GDataType>{-.5f, .5f, 0.01f}(g_host); ck_tile::FillStepRange<GDataType>{-.5f, .5f, 0.01f}(g_host);
ck_tile::FillStepRange<DDataType, false>{.5f, -.5f, -0.01f}(d_host); ck_tile::FillStepRange<DDataType, false>{.5f, -.5f, -0.01f}(d_host);
...@@ -217,7 +217,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -217,7 +217,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::FillStepRange<DScaleDataType>{0.f, 1.f, 0.01f}(sd_host); ck_tile::FillStepRange<DScaleDataType>{0.f, 1.f, 0.01f}(sd_host);
ck_tile::FillStepRange<YSmoothScaleDataType>{0.f, 1.f, 0.01f}(sy_host); ck_tile::FillStepRange<YSmoothScaleDataType>{0.f, 1.f, 0.01f}(sy_host);
ck_tile::FillStepRange<TopkWeightDataType>{-.5f, .5f, 0.01f}(topk_weight_host); ck_tile::FillStepRange<TopkWeightDataType>{-.5f, .5f, 0.01f}(topk_weight_host);
# else #else
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host); ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
ck_tile::FillUniformDistribution<GDataType>{-.5f, .5f}(g_host); ck_tile::FillUniformDistribution<GDataType>{-.5f, .5f}(g_host);
ck_tile::FillUniformDistribution<DDataType>{-.5f, .5f}(d_host); ck_tile::FillUniformDistribution<DDataType>{-.5f, .5f}(d_host);
...@@ -226,7 +226,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -226,7 +226,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::FillUniformDistribution<DScaleDataType>{-.5f, .5f}(sd_host); ck_tile::FillUniformDistribution<DScaleDataType>{-.5f, .5f}(sd_host);
ck_tile::FillUniformDistribution<YSmoothScaleDataType>{-.5f, .5f}(sy_host); ck_tile::FillUniformDistribution<YSmoothScaleDataType>{-.5f, .5f}(sy_host);
ck_tile::FillUniformDistribution<TopkWeightDataType>{-.5f, .5f}(topk_weight_host); ck_tile::FillUniformDistribution<TopkWeightDataType>{-.5f, .5f}(topk_weight_host);
# endif #endif
// permute weight // permute weight
ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1); ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1);
...@@ -266,7 +266,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -266,7 +266,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1); ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1);
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl; std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
# if 0 #if 0
ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>( ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
topk_ids_host, topk_ids_host,
topk_weight_host, topk_weight_host,
...@@ -319,7 +319,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -319,7 +319,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
} }
return 1; return 1;
# endif #endif
#endif #endif
(void)balance; (void)balance;
......
...@@ -19,7 +19,7 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile: ...@@ -19,7 +19,7 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" && if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1) t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
{ {
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>; using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 128, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
r = fused_moegemm_<t_>(s, a); r = fused_moegemm_<t_>(s, a);
} }
// clang-format on // clang-format on
......
...@@ -34,11 +34,14 @@ struct fmoe_ // traits, ugly name, only used for internal ...@@ -34,11 +34,14 @@ struct fmoe_ // traits, ugly name, only used for internal
using TopkWeightDataType = ck_tile::remove_cvref_t<typename TypeConfig::TopkWeightDataType>; using TopkWeightDataType = ck_tile::remove_cvref_t<typename TypeConfig::TopkWeightDataType>;
using IndexDataType = ck_tile::remove_cvref_t<typename TypeConfig::IndexDataType>; using IndexDataType = ck_tile::remove_cvref_t<typename TypeConfig::IndexDataType>;
static constexpr ck_tile::index_t BT_ = BlockTIle_::at(ck_tile::number<0>{}); // block token(block_m0, block_m1) static constexpr ck_tile::index_t BT_ =
BlockTIle_::at(ck_tile::number<0>{}); // block token(block_m0, block_m1)
static constexpr ck_tile::index_t BI_ = static constexpr ck_tile::index_t BI_ =
BlockTIle_::at(ck_tile::number<1>{}); // block intermediate (block_n0, block_k1) BlockTIle_::at(ck_tile::number<1>{}); // block intermediate (block_n0, block_k1)
static constexpr ck_tile::index_t BH_ = BlockTIle_::at(ck_tile::number<2>{}); // block hidden(block_k0) static constexpr ck_tile::index_t BH_ =
static constexpr ck_tile::index_t BD_ = BlockTIle_::at(ck_tile::number<3>{}); // block down(block_n1) BlockTIle_::at(ck_tile::number<2>{}); // block hidden(block_k0)
static constexpr ck_tile::index_t BD_ =
BlockTIle_::at(ck_tile::number<3>{}); // block down(block_n1)
using BlockTile_0 = ck_tile::sequence<BT_, BI_, BH_>; using BlockTile_0 = ck_tile::sequence<BT_, BI_, BH_>;
using WarpPerBlock_0 = ck_tile::remove_cvref_t<WarpPerBlock_>; using WarpPerBlock_0 = ck_tile::remove_cvref_t<WarpPerBlock_>;
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
// clang-format off // clang-format off
template float fused_moegemm_< template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0> fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 128, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a); >(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on // clang-format on
...@@ -216,7 +216,6 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -216,7 +216,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::FillUniformDistribution<YSmoothScaleDataType>{-.5f, .5f}(sy_host); ck_tile::FillUniformDistribution<YSmoothScaleDataType>{-.5f, .5f}(sy_host);
ck_tile::FillUniformDistribution<TopkWeightDataType>{0.0f, 1.0f}(topk_weight_host); ck_tile::FillUniformDistribution<TopkWeightDataType>{0.0f, 1.0f}(topk_weight_host);
// permute weight // permute weight
ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1); ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1);
ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1); ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1);
......
...@@ -66,448 +66,27 @@ struct FusedMoeGemmPipeline_FlatmmGl ...@@ -66,448 +66,27 @@ struct FusedMoeGemmPipeline_FlatmmGl
} }
}(); }();
static constexpr const char* name = "flatmm_uk"; static constexpr const char* name = "flatmm_gl";
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{ {
constexpr index_t smem_0 = Policy::template GetUK_1<Problem>().GetSmemSize();
constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize();
constexpr index_t smem_bridge = constexpr index_t smem_bridge =
BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType); BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType);
return max(smem_0, max(smem_1, smem_bridge)); return smem_bridge;
} }
// this is the thread-offset along row/col
CK_TILE_HOST_DEVICE static auto GetACoord()
{
constexpr auto a_dist = Policy::template MakeGlobalTileDistribution_A<Problem>();
const auto a_coord = a_dist.calculate_index();
return a_coord;
}
// this is the thread-offset along row/col
CK_TILE_HOST_DEVICE static auto GetOCoord()
{
constexpr auto o_dist = Policy::template MakeOGlobalTileDistribution<Problem>();
const auto o_coord = o_dist.calculate_index();
return o_coord;
}
CK_TILE_DEVICE constexpr auto GetNumRowCoords_A()
{
constexpr index_t KLans = BlockShape::Block_K0 / kAlignmentA;
constexpr index_t MLans = BlockShape::BlockSize / KLans;
constexpr index_t MRepeat = BlockShape::Block_M0 / MLans;
return MRepeat;
}
// TODO: properlly support scatter/gather
CK_TILE_DEVICE auto GetRowCoords_A(index_t base_offset)
{
constexpr index_t KLans = BlockShape::Block_K0 / kAlignmentA;
constexpr index_t MLans = BlockShape::BlockSize / KLans;
constexpr index_t MRepeat = BlockShape::Block_M0 / MLans;
auto base_coord = threadIdx.x / KLans + base_offset;
array<index_t, MRepeat> coords;
static_for<0, MRepeat, 1>{}([&](auto i) { coords.at(i) = base_coord + i * MLans; });
return coords;
}
template <typename ROW_COORDS>
CK_TILE_DEVICE auto GetRowID_A(const ROW_COORDS coords,
const IndexDataType* sorted_token_ids_ptr)
{
constexpr index_t n_size = coords.size();
array<index_t, n_size> row_ids;
static_for<0, n_size, 1>{}([&](auto i) {
row_ids.at(i) = sorted_token_ids_ptr[coords[i]]; // base_coord + i * MLans;
});
return row_ids;
}
// TODO: properlly support scatter/gather
CK_TILE_DEVICE auto GetRowCoords_O(index_t base_offset)
{
constexpr index_t WarpGemmLane_M = 16; // TODO: use 16x16
constexpr index_t WarpGemmRepeat_M = BlockShape::Block_M0 / WarpGemmLane_M;
auto base_coord = threadIdx.x % WarpGemmLane_M + base_offset;
array<index_t, WarpGemmRepeat_M> coords;
static_for<0, WarpGemmRepeat_M, 1>{}(
[&](auto i) { coords.at(i) = base_coord + i * WarpGemmLane_M; });
return coords;
}
template <typename ROW_COORDS>
CK_TILE_DEVICE auto GetWeightScale(const ROW_COORDS coords,
const TopkWeightDataType* sorted_weight_ptr)
{
constexpr index_t n_size = coords.size();
array<TopkWeightDataType, n_size> w;
static_for<0, n_size, 1>{}([&](auto i) {
w.at(i) = sorted_weight_ptr[coords[i]]; // base_coord + i * MLans;
});
return w;
}
CK_TILE_DEVICE auto GetRowCoords_O()
{
constexpr index_t NLans = BlockShape::Block_N1 / kAlignmentA;
constexpr index_t MLans = BlockShape::BlockSize / NLans;
constexpr index_t MRepeat = BlockShape::Block_M1 / MLans;
auto base_coord = threadIdx.x / NLans;
array<index_t, MRepeat> coords;
static_for<0, MRepeat, 1>{}([&](auto i) { coords.at(i) = base_coord + i * MLans; });
return coords;
}
/*
struct FusedMoeGemmKargs
{
const void* a_ptr; // [m, k], input token
const void* a_scale_ptr; // [m, 1], token scale
const void* g_ptr; // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
const void* d_ptr; // [e, n, k], pre-shuffle([e, nr, kr, w])
const void* g_scale_ptr; // [e, 1, n], gate(up) scale
const void* d_scale_ptr; // [e, 1, k], down scale
const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
void* o_ptr; // [m, k], output token
const void* sorted_token_ids_ptr;
const void* sorted_weight_ptr;
const void* sorted_expert_ids_ptr;
const void* num_sorted_tiles_ptr;
index_t hidden_size; // k
index_t intermediate_size; // n (TP slice this)
index_t num_tokens; // input number of tokens for current iteration
index_t num_experts; // number of groups
index_t topk; // need this?
index_t stride_token; // for input/output, stride for each row, should >= hidden_size
};
*/
template <typename Karg> template <typename Karg>
CK_TILE_DEVICE auto operator()(const Karg& kargs, CK_TILE_DEVICE auto operator()(const Karg& kargs,
CK_TILE_LDS_ADDR void* smem, CK_TILE_LDS_ADDR void* smem,
index_t sorted_tile_id, index_t sorted_tile_id,
index_t intermediate_tile_id) index_t intermediate_tile_id)
{ {
constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2; ignore = kargs;
ck_tile::index_t shared_intermediate_size_0 = kargs.intermediate_size; ignore = smem;
// w1 (Down, N size) ignore = sorted_tile_id;
ck_tile::index_t shared_intermediate_size_1 = kargs.intermediate_size / hidden_radio_0; ignore = intermediate_tile_id;
index_t nr_0 = shared_intermediate_size_0 / BlockShape::Warp_N0; // divide N in W
index_t kr_0 = kargs.hidden_size / BlockShape::Warp_K0; // divide K in W
index_t nr_1 = kargs.hidden_size / BlockShape::Warp_N1;
index_t kr_1 = shared_intermediate_size_1 / BlockShape::Warp_K1;
const IndexDataType expert_id = __builtin_amdgcn_readfirstlane(
reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]);
index_t expert_stride_0 = shared_intermediate_size_0 * kargs.hidden_size;
index_t expert_stride_1 = shared_intermediate_size_1 * kargs.hidden_size;
// nr*kr*w
index_t interm_idx_nr = __builtin_amdgcn_readfirstlane(
intermediate_tile_id *
BlockShape::Block_Nr0); // intermediate_tile_id * Block_N / (N in W)
// printf("bid:%d,%d, sorted_tile_id:%d(, intermediate_tile_id:%d, expert_id:%d,
// interm_idx_nr:%d\n", static_cast<int>(blockIdx.x),
// static_cast<int>(blockIdx.y), sorted_tile_id, intermediate_tile_id, expert_id,
// interm_idx_nr);
auto row_coords_a = GetRowCoords_A(sorted_tile_id * BlockShape::Block_M0);
auto row_ids_a = GetRowID_A(
row_coords_a, reinterpret_cast<const IndexDataType*>(kargs.sorted_token_ids_ptr));
auto a_coords = generate_tuple(
[&](auto i) {
return row_ids_a[i] * kargs.stride_token +
threadIdx.x % (BlockShape::Block_K0 / kAlignmentA) * kAlignmentA;
},
number<row_ids_a.size()>{});
auto a_res =
make_wave_buffer_resource(reinterpret_cast<const ADataType*>(kargs.a_ptr),
kargs.num_tokens * kargs.stride_token * sizeof(ADataType));
auto g_win = [&]() {
const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_0 +
interm_idx_nr * kr_0 * BlockShape::Block_W0;
auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
g_ptr,
make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}),
make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
number<kAlignmentG>{},
number<1>{});
// number<BlockShape::Block_Nr0>{}.fff();
// number<kAlignmentG>{}.zzz();
auto g_window_ = make_tile_window_linear_raw(
g_view_,
make_tuple(number<BlockShape::Block_Nr0>{},
number<BlockShape::Block_Kr0>{},
number<BlockShape::Block_W0>{}),
{0, 0, 0},
Policy::template MakeGlobalTileDistribution_G<Problem>(),
sequence<0, 1, 1>{});
return g_window_;
}();
// number<decltype(g_win)::NumAccess_NonLinear>{}.rrr2();
auto g_res = g_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
auto g_coords = generate_tuple([&](auto i) { return g_win.cached_coords_[i].get_offset(); },
number<decltype(g_win)::NumAccess_NonLinear>{});
const auto d_win = [&]() {
const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_1 +
interm_idx_nr * BlockShape::Block_W1;
// note interm_idx_nr is along the gemm-k dim of 2nd gemm
const auto d_view_ = make_naive_tensor_view<address_space_enum::global>(
d_ptr,
make_tuple(nr_1, kr_1, BlockShape::Block_W1),
make_tuple(kr_1 * BlockShape::Block_W1, BlockShape::Block_W1, 1),
number<kAlignmentD>{},
number<1>{});
const auto d_window_ = make_tile_window_linear_raw(
d_view_,
make_tuple(number<BlockShape::Block_Nr1>{},
number<BlockShape::Block_Kr1>{},
number<BlockShape::Block_W1>{}),
{0, 0, 0},
Policy::template MakeGlobalTileDistribution_D<Problem>(),
sequence<0, 1, 1>{});
return d_window_;
}();
auto d_res = d_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
#if 0
auto d_coords = generate_tuple([&](auto i) {
return d_win.cached_coords_[i].get_offset(); },
number<decltype(d_win)::NumAccess_NonLinear>{});
#else
// TODO: load D order is N0.K0...127, N64.K0...127, N0.K128...255, N64.K128...255
// block-k=512, block-n=128
// |<----- W_ ----->|
// Nr(2)*Nw(4)* Kr *Kr0(4)*Kr1(4) * [Kl(4)*Nl(16)*Kv(8)]->one issue
// y p y y p p y
// 1 2 0(imm)
auto d_coords = [&]() {
constexpr index_t Nr_ = 2;
constexpr index_t Nw_ = 4;
constexpr index_t Kr0_ = 4;
constexpr index_t Kr1_ = 4;
constexpr index_t Kl_ = 4;
constexpr index_t Nl_ = 16;
constexpr index_t Kv_ = 8;
constexpr index_t W_ = Kl_ * Nl_ * Kv_;
constexpr index_t num_offsets_ = Nr_ * Kr0_;
index_t base_os_ = (threadIdx.x % 64) * Kv_ + (threadIdx.x / 64) * Kr0_ * Kr1_ * W_;
return generate_tuple(
[&](auto i) {
constexpr auto i_nr_ = number<i % Nr_>{};
constexpr auto i_kr0_ = number<i / Nr_>{};
return i_nr_ * shared_intermediate_size_1 * Nw_ * Nl_ + i_kr0_ * Kr1_ * W_ +
base_os_;
},
number<num_offsets_>{});
}();
#endif
auto o_coords = generate_tuple(
[&](auto i) {
return row_ids_a[i] * kargs.stride_token +
threadIdx.x % (BlockShape::Block_N1 / kAlignmentO) * kAlignmentO;
},
number<row_ids_a.size()>{});
auto o_flags =
generate_tuple([&](auto i) { return cmp_lt_to_exec(row_ids_a[i], kargs.num_tokens); },
number<row_ids_a.size()>{});
auto bridge_sst_win = [&]() {
constexpr auto desc_ = Policy::template MakeBridgeLdsStoreForUKDesc<Problem>();
constexpr auto dist_ = Policy::template GetUK_0<Problem>().MakeCBlockDist();
return make_tile_window_linear(
make_tensor_view<address_space_enum::lds>(
reinterpret_cast<YDataType*>(smem),
desc_),
desc_.get_lengths(),
{0, 0},
dist_);
}();
auto o_res =
make_wave_buffer_resource(reinterpret_cast<const ODataType*>(kargs.o_ptr),
kargs.num_tokens * kargs.stride_token * sizeof(ODataType));
auto row_coords_o = GetRowCoords_O(sorted_tile_id * BlockShape::Block_M0);
auto w_scale = GetWeightScale(
row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr));
#if 0
printf("bid:%d,%d, tid:%d, sorted_tile_id:%d(, intermediate_tile_id:%d, e:%d, "
"interm_idx_nr:%d, coords:a:%d,%d,%d, row_ids_a:%d,%d,%d, (%d)g_coords:%d.%d.%d, "
"o_coords:%d,%d,%d,%d,%d,%d,%d,%d(%d,%d,%d,%d,%d,%d,%d,%d)\n",
static_cast<int>(blockIdx.x),
static_cast<int>(blockIdx.y),
static_cast<int>(threadIdx.x),
sorted_tile_id,
intermediate_tile_id,
expert_id,
interm_idx_nr,
row_coords_a[0],
row_coords_a[1],
row_coords_a[7],
row_ids_a[0],
row_ids_a[1],
row_ids_a[7],
kr_0 * BlockShape::Block_W0,
g_coords[number<0>{}],
g_coords[number<1>{}],
g_coords[number<7>{}],
o_coords[number<0>{}],
o_coords[number<1>{}],
o_coords[number<2>{}],
o_coords[number<3>{}],
o_coords[number<4>{}],
o_coords[number<5>{}],
o_coords[number<6>{}],
o_coords[number<7>{}],
// (row_ids_a[0] >= kargs.num_tokens ? 1 : 0),
// (row_ids_a[1] >= kargs.num_tokens ? 1 : 0),
// (row_ids_a[2] >= kargs.num_tokens ? 1 : 0),
// (row_ids_a[3] >= kargs.num_tokens ? 1 : 0),
// (row_ids_a[4] >= kargs.num_tokens ? 1 : 0),
// (row_ids_a[5] >= kargs.num_tokens ? 1 : 0),
// (row_ids_a[6] >= kargs.num_tokens ? 1 : 0),
// (row_ids_a[7] >= kargs.num_tokens ? 1 : 0)
(row_ids_a[0] < kargs.num_tokens && static_cast<index_t>(o_coords[number<0>{}]) >=
(kargs.num_tokens * kargs.stride_token)
? 7777
: 0),
(row_ids_a[1] < kargs.num_tokens && static_cast<index_t>(o_coords[number<1>{}]) >=
(kargs.num_tokens * kargs.stride_token)
? 7777
: 0),
(row_ids_a[2] < kargs.num_tokens && static_cast<index_t>(o_coords[number<2>{}]) >=
(kargs.num_tokens * kargs.stride_token)
? 7777
: 0),
(row_ids_a[3] < kargs.num_tokens && static_cast<index_t>(o_coords[number<3>{}]) >=
(kargs.num_tokens * kargs.stride_token)
? 7777
: 0),
(row_ids_a[4] < kargs.num_tokens && static_cast<index_t>(o_coords[number<4>{}]) >=
(kargs.num_tokens * kargs.stride_token)
? 7777
: 0),
(row_ids_a[5] < kargs.num_tokens && static_cast<index_t>(o_coords[number<5>{}]) >=
(kargs.num_tokens * kargs.stride_token)
? 7777
: 0),
(row_ids_a[6] < kargs.num_tokens && static_cast<index_t>(o_coords[number<6>{}]) >=
(kargs.num_tokens * kargs.stride_token)
? 7777
: 0),
(row_ids_a[7] < kargs.num_tokens && static_cast<index_t>(o_coords[number<7>{}]) >=
(kargs.num_tokens * kargs.stride_token)
? 7777
: 0)
);
#endif
auto uk_0 = Policy::template GetUK_0<Problem>();
auto acc_0 = uk_0(a_res,
a_coords,
g_res,
g_coords,
smem,
kargs.hidden_size,
BlockShape::Block_K0, // tile offset for B matrix each unroll
BlockShape::Block_Kr0 *
BlockShape::Block_W0); // tile offset for B matrix each unroll
// return ;
//sweep_tile(acc_0,
// [&](auto idx) { typename Problem::GateActivation{}(acc_0(idx), acc_0[idx]); });
sweep_tile(acc_0,
[&](auto idx0, auto idx1) {
fp32x2_t v_ {acc_0(idx0), acc_0(idx1)};
typename Problem::GateActivation{}(v_, v_);
acc_0(idx0) = v_.x;
acc_0(idx1) = v_.y;
},
sequence<1, 2>{});
#if 0
printf("bid:%d,%d, tid:%d, sorted_tile_id:%d(, intermediate_tile_id:%d, e:%d, "
"interm_idx_nr:%d, coords:a:%d,%d,%d, row_ids_a:%d,%d,%d, (%d)g_coords:%d.%d.%d, bridge_sst_win:%d"
"acc:%.1f,%.1f,%.1f,%.1f,%.1f,%.1f,%.1f,%.1f,%.1f,%.1f,%.1f,%.1f,%.1f,%.1f,%.1f,%.1f\n",
static_cast<int>(blockIdx.x),
static_cast<int>(blockIdx.y),
static_cast<int>(threadIdx.x),
sorted_tile_id,
intermediate_tile_id,
expert_id,
interm_idx_nr,
row_coords_a[0],
row_coords_a[1],
row_coords_a[7],
row_ids_a[0],
row_ids_a[1],
row_ids_a[7],
kr_0 * BlockShape::Block_W0,
g_coords[number<0>{}],
g_coords[number<1>{}],
g_coords[number<7>{}],
bridge_sst_win.cached_coords_[number<0>{}].get_offset(),
acc_0.get_thread_buffer()[number<0>{}],
acc_0.get_thread_buffer()[number<1>{}],
acc_0.get_thread_buffer()[number<2>{}],
acc_0.get_thread_buffer()[number<3>{}],
acc_0.get_thread_buffer()[number<4>{}],
acc_0.get_thread_buffer()[number<5>{}],
acc_0.get_thread_buffer()[number<6>{}],
acc_0.get_thread_buffer()[number<7>{}],
acc_0.get_thread_buffer()[number<8 + 0>{}],
acc_0.get_thread_buffer()[number<8 + 1>{}],
acc_0.get_thread_buffer()[number<8 + 2>{}],
acc_0.get_thread_buffer()[number<8 + 3>{}],
acc_0.get_thread_buffer()[number<8 + 4>{}],
acc_0.get_thread_buffer()[number<8 + 5>{}],
acc_0.get_thread_buffer()[number<8 + 6>{}],
acc_0.get_thread_buffer()[number<8 + 7>{}]);
#endif
auto y_pre = cast_tile<YDataType>(acc_0);
store_tile(bridge_sst_win, y_pre);
block_sync_lds();
auto uk_1 = Policy::template GetUK_1<Problem>();
uk_1(d_res,
d_coords,
o_res,
o_coords,
o_flags,
smem,
kargs.hidden_size, // total n number
w_scale,
BlockShape::Block_Nr1 * kr_1 * BlockShape::Block_W1, // along N
BlockShape::Block_N1); // along N
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment