Commit 1caa8198 authored by “letaoqin”'s avatar “letaoqin”
Browse files

write a, g,d and o tensor

parent 84755f74
......@@ -297,7 +297,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
tokens,
experts,
topk,
stride};
stride,
max_num_tokens_padded};
float ave_time = fused_moegemm(
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
......
......@@ -57,4 +57,44 @@ struct indexing_adaptor_onshot_cached
return ck_tile::is_known_at_compile_time<IndexingType>::value;
}
};
template <typename IndexingType>
struct indexing_adaptor
{
CK_TILE_HOST_DEVICE constexpr indexing_adaptor() = default;
CK_TILE_HOST_DEVICE constexpr indexing_adaptor(const IndexingType* idx) : cached_idx_(idx) {}
const IndexingType* cached_idx_;
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
const UpIdx& idx_up) const
{
static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
idx_low(number<0>{}) = *(cached_idx_ + idx_up[number<0>{}]);
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& /*idx_low*/,
const UpIdx& /*idx_up*/) const
{
// TODO: nonthing changed here
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
idx_diff_low(number<0>{}) = idx_diff_up[number<0>{}];
// pass the diff to lower, but not changing the actually index
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<IndexingType>::value;
}
};
} // namespace ck_tile
......@@ -197,6 +197,7 @@ struct FusedMoeGemmGlKernel
index_t topk; // need this?
index_t stride_token; // for input/output, stride for each row, should >= hidden_size
index_t max_num_tokens_padded; // size of sorted_token_ids_ptr
};
// TODO: switch karg based on
......@@ -230,17 +231,13 @@ struct FusedMoeGemmGlKernel
*reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr));
constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;
index_t nr_0 = kargs.intermediate_size / BlockShape::Block_Nr0;
index_t kr_0 = kargs.hidden_size / BlockShape::Block_Kr0;
index_t nr_1 = kargs.hidden_size / BlockShape::Block_Nr1; // should be same as kr_0
index_t kr_1 = kargs.intermediate_size / BlockShape::Block_Kr1; // should be same as nr_0
index_t expert_stride_0 = kargs.intermediate_size * hidden_radio_0 * kargs.hidden_size;
index_t expert_stride_1 = kargs.intermediate_size * kargs.hidden_size;
__shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()];
// note this is in unit of tile, need multiple tile size to get the index(i_m and i_n)
// note this is in unit of tile, need multiple tile size to get the index(block_m and
// block_n)
const auto [sorted_tile_id, intermediate_tile_id] =
Partitioner{}(num_sorted_tiles, kargs.intermediate_size);
if(sorted_tile_id >= num_sorted_tiles)
......@@ -252,17 +249,28 @@ struct FusedMoeGemmGlKernel
// index along intermediate_size
// index_t hidden_idx = __builtin_amdgcn_readfirstlane(intermediate_tile_id *
// BlockShape::Block_N0);
index_t interm_idx_nr =
__builtin_amdgcn_readfirstlane(intermediate_tile_id * BlockShape::Block_Nr0);
const auto a_coord = Pipeline::GetACoord(); // 2d thread offset, [i_row, i_col]
const auto sorted_token_id = a_coord[number<0>{}] + sorted_tile_id * BlockShape::Block_M0;
index_t token_id =
reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr)[sorted_token_id];
index_t idx_m0 = __builtin_amdgcn_readfirstlane(sorted_tile_id * BlockShape::Block_M0);
index_t idx_n0 = __builtin_amdgcn_readfirstlane(sorted_tile_id * BlockShape::Block_N0);
// const auto a_coord = Pipeline::GetACoord(); // 2d thread offset, [i_row, i_col]
// if(threadIdx.x == 200 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0){
// printf("\n*************a_coord[0]: %d, a_coord[1]: %d size: %d \n",
// a_coord[number<0>{}], a_coord[number<1>{}], a_coord.size());
// }
// const auto sorted_token_id = a_coord[number<0>{}] + sorted_tile_id *
// BlockShape::Block_M0; //not block pos?
const auto sorted_token_id = sorted_tile_id * BlockShape::Block_M0; // start block_m
// position
// index_t token_id =
// reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr)[sorted_token_id];
auto topk_weight =
reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr)[sorted_token_id];
const index_t* sorted_token_ids_ptr = reinterpret_cast<const index_t*>(
&(reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr)[sorted_token_id]));
const auto a_window = [&]() {
// A is already pre-padded in previous kernel
const ADataType* a_ptr = reinterpret_cast<const ADataType*>(kargs.a_ptr);
......@@ -276,7 +284,9 @@ struct FusedMoeGemmGlKernel
// gather is here use indexing transform
const auto a_gather_view_ = transform_tensor_view(
a_view_,
make_tuple(make_indexing_transform(kargs.num_tokens, token_id),
make_tuple(make_indexing_transform_with_adaptor(
kargs.max_num_tokens_padded,
indexing_adaptor<index_t>{sorted_token_ids_ptr}),
make_pass_through_transform(kargs.hidden_size)),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
......@@ -284,7 +294,7 @@ struct FusedMoeGemmGlKernel
const auto a_window_ = make_tile_window(
a_gather_view_,
make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_K0>{}),
{0, 0});
{idx_m0, 0});
return a_window_;
}();
......@@ -292,52 +302,38 @@ struct FusedMoeGemmGlKernel
const auto g_window = [&]() {
const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_0 +
interm_idx_nr * kr_0 * BlockShape::Block_W0;
idx_n0 * kargs.hidden_size;
const auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
g_ptr,
make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}),
make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
make_tuple(BlockShape::Block_N0, kargs.hidden_size),
make_tuple(kargs.hidden_size, 1),
number<Pipeline::kAlignmentG>{},
number<1>{});
const auto g_view_1_ =
pad_tensor_view(g_view_,
make_tuple(number<BlockShape::Block_Nr0>{},
number<BlockShape::Block_Kr0>{},
number<BlockShape::Block_W0>{}),
sequence<PadIntermediateSize, PadHiddenSize, 0>{});
const auto g_window_ = make_tile_window(g_view_1_,
make_tuple(number<BlockShape::Block_Nr0>{},
number<BlockShape::Block_Kr0>{},
number<BlockShape::Block_W0>{}),
{0, 0, 0});
const auto g_window_ = make_tile_window(
g_view_,
make_tuple(number<BlockShape::Block_N0>{}, number<BlockShape::Block_K0>{}),
{0, 0});
return g_window_;
}();
const auto d_window = [&]() {
const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_1 +
interm_idx_nr * BlockShape::Block_W1;
idx_n0;
// note interm_idx_nr is along the gemm-k dim of 2nd gemm
const auto d_view_ = make_naive_tensor_view<address_space_enum::global>(
d_ptr,
make_tuple(nr_1, kr_1, BlockShape::Block_W1),
make_tuple(kr_1 * BlockShape::Block_W1, BlockShape::Block_W1, 1),
make_tuple(kargs.hidden_size, BlockShape::Block_K1),
make_tuple(kargs.intermediate_size, 1),
number<Pipeline::kAlignmentD>{},
number<1>{});
const auto d_view_1_ =
pad_tensor_view(d_view_,
make_tuple(number<BlockShape::Block_Nr1>{},
number<BlockShape::Block_Kr1>{},
number<BlockShape::Block_W1>{}),
sequence<PadHiddenSize, PadIntermediateSize, 0>{});
const auto d_window_ = make_tile_window(d_view_1_,
make_tuple(number<BlockShape::Block_Nr1>{},
number<BlockShape::Block_Kr1>{},
number<BlockShape::Block_W1>{}),
{0, 0, 0});
const auto d_window_ = make_tile_window(
d_view_,
make_tuple(number<BlockShape::Block_N1>{}, number<BlockShape::Block_K1>{}),
{0, 0});
return d_window_;
}();
......@@ -354,7 +350,9 @@ struct FusedMoeGemmGlKernel
// gather is here
auto o_scatter_view_ = transform_tensor_view(
o_view_,
make_tuple(make_indexing_transform(kargs.num_tokens, token_id),
make_tuple(make_indexing_transform_with_adaptor(
kargs.max_num_tokens_padded,
indexing_adaptor<index_t>{sorted_token_ids_ptr}),
make_pass_through_transform(kargs.hidden_size)),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
......@@ -362,7 +360,7 @@ struct FusedMoeGemmGlKernel
auto o_window_ = make_tile_window(
o_scatter_view_,
make_tuple(number<BlockShape::Block_M1>{}, number<BlockShape::Block_N1>{}),
{0, 0});
{idx_m0, 0});
return o_window_;
}();
......@@ -374,8 +372,7 @@ struct FusedMoeGemmGlKernel
topk_weight,
smem,
kargs.hidden_size,
kargs.intermediate_size,
kargs.stride_token);
kargs.intermediate_size);
}
};
......
......@@ -105,6 +105,7 @@ struct FusedMoeGemmHostArgs
index_t topk; // need this?
index_t stride_token; // for input/output, stride for each row, should >= hidden_size
index_t max_num_tokens_padded; // size of sorted_token_ids_ptr
};
// This is scatter/gather b2b group-gemm
......@@ -198,6 +199,7 @@ struct FusedMoeGemmKernel
index_t topk; // need this?
index_t stride_token; // for input/output, stride for each row, should >= hidden_size
index_t max_num_tokens_padded; // size of sorted_token_ids_ptr
};
// TODO: switch karg based on
......
......@@ -80,16 +80,30 @@ struct FusedMoeGemmPipeline_FlatmmGl
return max(smem_mat_a, smem_bridge);
}
template <typename Karg>
CK_TILE_DEVICE auto operator()(const Karg& kargs,
// this is the thread-offset along row/col
CK_TILE_HOST_DEVICE static auto GetACoord()
{
constexpr auto a_dist = Policy::template MakeGlobalTileDistribution_A<Problem>();
const auto a_coord = a_dist.calculate_index();
return a_coord;
}
template <typename AWindow, typename GWindow, typename DWindow, typename OWindow>
CK_TILE_DEVICE auto operator()(const AWindow& a_window_,
const GWindow& g_window_,
const DWindow& d_window_,
OWindow& o_window_,
TopkWeightDataType /*topk_weight*/,
CK_TILE_LDS_ADDR void* smem,
index_t sorted_tile_id,
index_t intermediate_tile_id)
index_t hidden_size,
index_t intermediate_size)
{
ignore = kargs;
ignore = a_window_;
ignore = g_window_;
ignore = d_window_;
ignore = o_window_;
ignore = smem;
ignore = sorted_tile_id;
ignore = intermediate_tile_id;
ignore = hidden_size;
ignore = intermediate_size;
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment