Commit 1caa8198 authored by letaoqin

write a, g, d and o tensors

parent 84755f74
@@ -297,7 +297,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
                               tokens,
                               experts,
                               topk,
-                              stride};
+                              stride,
+                              max_num_tokens_padded};

     float ave_time = fused_moegemm(
         traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
...
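The new max_num_tokens_padded argument is the length of the sorted_token_ids_ptr buffer produced by the moe-sorting step. As a hedged host-side sketch of its worst-case value, assuming the common convention that each expert's token segment is padded up to a multiple of the M-tile size (the helper name and exact convention below are assumptions, not this repository's API):

#include <cstdint>

// Hypothetical helper: worst-case length of sorted_token_ids when every
// expert's segment is rounded up to a multiple of block_m (assumed convention).
int32_t worst_case_sorted_ids(int32_t tokens, int32_t topk,
                              int32_t num_experts, int32_t block_m)
{
    // each (token, expert) pair contributes one id, and each of the
    // num_experts segments can carry up to block_m - 1 padding ids
    return tokens * topk + num_experts * (block_m - 1);
}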
@@ -57,4 +57,44 @@ struct indexing_adaptor_onshot_cached
         return ck_tile::is_known_at_compile_time<IndexingType>::value;
     }
 };
+
+template <typename IndexingType>
+struct indexing_adaptor
+{
+    CK_TILE_HOST_DEVICE constexpr indexing_adaptor() = default;
+    CK_TILE_HOST_DEVICE constexpr indexing_adaptor(const IndexingType* idx) : cached_idx_(idx) {}
+
+    const IndexingType* cached_idx_;
+
+    template <typename LowIdx, typename UpIdx>
+    CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
+                                                             const UpIdx& idx_up) const
+    {
+        static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
+                      "wrong! inconsistent # of dimension");
+        idx_low(number<0>{}) = *(cached_idx_ + idx_up[number<0>{}]);
+    }
+
+    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
+    CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
+                                                const UpIdxDiff& idx_diff_up,
+                                                LowIdx& /*idx_low*/,
+                                                const UpIdx& /*idx_up*/) const
+    {
+        // TODO: nothing changes here
+        static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
+                          UpIdx::size() == 1,
+                      "wrong! inconsistent # of dimension");
+        idx_diff_low(number<0>{}) = idx_diff_up[number<0>{}];
+        // pass the diff through to the lower index without changing the actual index
+    }
+
+    CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
+    {
+        return ck_tile::is_known_at_compile_time<IndexingType>::value;
+    }
+};
 } // namespace ck_tile
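The adaptor's only real work is in calculate_lower_index: the upper (gathered) row index is replaced by a lookup into cached_idx_, i.e. low = cached_idx_[up], while update_lower_index forwards index deltas untouched. A standalone sketch of that semantics with made-up data:

#include <array>
#include <cstdio>

int main()
{
    // example permutation standing in for sorted_token_ids
    std::array<int, 4> cached_idx{2, 0, 3, 1};
    for(int up = 0; up < 4; ++up)
    {
        int low = cached_idx[up]; // what calculate_lower_index computes per element
        std::printf("gathered row %d reads physical row %d\n", up, low);
    }
    return 0;
}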
@@ -197,6 +197,7 @@ struct FusedMoeGemmGlKernel
         index_t topk;         // need this?
         index_t stride_token; // for input/output, stride for each row, should be >= hidden_size
+        index_t max_num_tokens_padded; // size of sorted_token_ids_ptr
     };

     // TODO: switch karg based on
@@ -230,17 +231,13 @@ struct FusedMoeGemmGlKernel
             *reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr));
         constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;

-        index_t nr_0 = kargs.intermediate_size / BlockShape::Block_Nr0;
-        index_t kr_0 = kargs.hidden_size / BlockShape::Block_Kr0;
-        index_t nr_1 = kargs.hidden_size / BlockShape::Block_Nr1;       // should be same as kr_0
-        index_t kr_1 = kargs.intermediate_size / BlockShape::Block_Kr1; // should be same as nr_0
         index_t expert_stride_0 = kargs.intermediate_size * hidden_radio_0 * kargs.hidden_size;
         index_t expert_stride_1 = kargs.intermediate_size * kargs.hidden_size;

         __shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()];

-        // note this is in unit of tile, need multiple tile size to get the index(i_m and i_n)
+        // note this is in units of tiles; multiply by the tile size to get the
+        // index (block_m and block_n)
         const auto [sorted_tile_id, intermediate_tile_id] =
             Partitioner{}(num_sorted_tiles, kargs.intermediate_size);
         if(sorted_tile_id >= num_sorted_tiles)
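Partitioner{} maps the flat workgroup id onto a (sorted_tile_id, intermediate_tile_id) pair. A hedged host-side sketch of the simplest row-major split follows; the actual Partitioner policy may swizzle the order for locality:

#include <utility>

// block_n0 is the N-tile size of gemm-0 (BlockShape::Block_N0 in the kernel)
std::pair<int, int> partition_2d(int block_id, int intermediate_size, int block_n0)
{
    const int num_n_tiles          = (intermediate_size + block_n0 - 1) / block_n0;
    const int sorted_tile_id       = block_id / num_n_tiles;
    const int intermediate_tile_id = block_id % num_n_tiles;
    return {sorted_tile_id, intermediate_tile_id};
}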
@@ -252,17 +249,28 @@ struct FusedMoeGemmGlKernel
         // index along intermediate_size
         // index_t hidden_idx = __builtin_amdgcn_readfirstlane(intermediate_tile_id *
         // BlockShape::Block_N0);
-        index_t interm_idx_nr =
-            __builtin_amdgcn_readfirstlane(intermediate_tile_id * BlockShape::Block_Nr0);
+        index_t idx_m0 = __builtin_amdgcn_readfirstlane(sorted_tile_id * BlockShape::Block_M0);
+        index_t idx_n0 =
+            __builtin_amdgcn_readfirstlane(intermediate_tile_id * BlockShape::Block_N0);
-        const auto a_coord = Pipeline::GetACoord(); // 2d thread offset, [i_row, i_col]
-        const auto sorted_token_id = a_coord[number<0>{}] + sorted_tile_id * BlockShape::Block_M0;
-
-        index_t token_id =
-            reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr)[sorted_token_id];
+        // const auto a_coord = Pipeline::GetACoord(); // 2d thread offset, [i_row, i_col]
+        // if(threadIdx.x == 200 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0){
+        //     printf("\n*************a_coord[0]: %d, a_coord[1]: %d size: %d \n",
+        //            a_coord[number<0>{}], a_coord[number<1>{}], a_coord.size());
+        // }
+        // const auto sorted_token_id = a_coord[number<0>{}] + sorted_tile_id *
+        //     BlockShape::Block_M0; // not a block position?
+        const auto sorted_token_id = sorted_tile_id * BlockShape::Block_M0; // start of this
+                                                                            // block_m tile
+        // index_t token_id =
+        //     reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr)[sorted_token_id];

         auto topk_weight =
             reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr)[sorted_token_id];
+        // hand the gather adaptor the base id table; the tile offset idx_m0 is already
+        // applied through the window origin below, so the pointer is not rebased here
+        const index_t* sorted_token_ids_ptr =
+            reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr);

         const auto a_window = [&]() {
             // A is already pre-padded by the previous (sorting) kernel
             const ADataType* a_ptr = reinterpret_cast<const ADataType*>(kargs.a_ptr);
@@ -276,7 +284,9 @@ struct FusedMoeGemmGlKernel
             // the gather happens here, via the indexing transform
             const auto a_gather_view_ = transform_tensor_view(
                 a_view_,
-                make_tuple(make_indexing_transform(kargs.num_tokens, token_id),
+                make_tuple(make_indexing_transform_with_adaptor(
+                               kargs.max_num_tokens_padded,
+                               indexing_adaptor<index_t>{sorted_token_ids_ptr}),
                            make_pass_through_transform(kargs.hidden_size)),
                 make_tuple(sequence<0>{}, sequence<1>{}),
                 make_tuple(sequence<0>{}, sequence<1>{}));
@@ -284,7 +294,7 @@ struct FusedMoeGemmGlKernel
             const auto a_window_ = make_tile_window(
                 a_gather_view_,
                 make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_K0>{}),
-                {0, 0});
+                {idx_m0, 0});
             return a_window_;
         }();
@@ -292,52 +302,38 @@ struct FusedMoeGemmGlKernel
         const auto g_window = [&]() {
             const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
                                      static_cast<long_index_t>(expert_id) * expert_stride_0 +
-                                     interm_idx_nr * kr_0 * BlockShape::Block_W0;
+                                     idx_n0 * kargs.hidden_size;
             const auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
                 g_ptr,
-                make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}),
-                make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
+                make_tuple(BlockShape::Block_N0, kargs.hidden_size),
+                make_tuple(kargs.hidden_size, 1),
                 number<Pipeline::kAlignmentG>{},
                 number<1>{});
-            const auto g_view_1_ =
-                pad_tensor_view(g_view_,
-                                make_tuple(number<BlockShape::Block_Nr0>{},
-                                           number<BlockShape::Block_Kr0>{},
-                                           number<BlockShape::Block_W0>{}),
-                                sequence<PadIntermediateSize, PadHiddenSize, 0>{});
-            const auto g_window_ = make_tile_window(g_view_1_,
-                                                    make_tuple(number<BlockShape::Block_Nr0>{},
-                                                               number<BlockShape::Block_Kr0>{},
-                                                               number<BlockShape::Block_W0>{}),
-                                                    {0, 0, 0});
+            const auto g_window_ = make_tile_window(
+                g_view_,
+                make_tuple(number<BlockShape::Block_N0>{}, number<BlockShape::Block_K0>{}),
+                {0, 0});
             return g_window_;
         }();

         const auto d_window = [&]() {
             const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
                                      static_cast<long_index_t>(expert_id) * expert_stride_1 +
-                                     interm_idx_nr * BlockShape::Block_W1;
-            // note interm_idx_nr is along the gemm-k dim of the 2nd gemm
+                                     idx_n0;
+            // note idx_n0 is along the gemm-k dim of the 2nd gemm
             const auto d_view_ = make_naive_tensor_view<address_space_enum::global>(
                 d_ptr,
-                make_tuple(nr_1, kr_1, BlockShape::Block_W1),
-                make_tuple(kr_1 * BlockShape::Block_W1, BlockShape::Block_W1, 1),
+                make_tuple(kargs.hidden_size, BlockShape::Block_K1),
+                make_tuple(kargs.intermediate_size, 1),
                 number<Pipeline::kAlignmentD>{},
                 number<1>{});
-            const auto d_view_1_ =
-                pad_tensor_view(d_view_,
-                                make_tuple(number<BlockShape::Block_Nr1>{},
-                                           number<BlockShape::Block_Kr1>{},
-                                           number<BlockShape::Block_W1>{}),
-                                sequence<PadHiddenSize, PadIntermediateSize, 0>{});
-            const auto d_window_ = make_tile_window(d_view_1_,
-                                                    make_tuple(number<BlockShape::Block_Nr1>{},
-                                                               number<BlockShape::Block_Kr1>{},
-                                                               number<BlockShape::Block_W1>{}),
-                                                    {0, 0, 0});
+            const auto d_window_ = make_tile_window(
+                d_view_,
+                make_tuple(number<BlockShape::Block_N1>{}, number<BlockShape::Block_K1>{}),
+                {0, 0});
             return d_window_;
         }();
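With the flatmm Nr/Kr/W swizzle dropped, g and d become plain row-major per-expert matrices, and the tile base offsets reduce to simple stride arithmetic: for g, the start of row idx_n0 of an [N, hidden_size] matrix; for d, column idx_n0 of a [hidden_size, intermediate_size] matrix. A sketch of that element-offset arithmetic, mirroring the code above:

// g per expert: [intermediate_size * ratio, hidden_size], row-major
// d per expert: [hidden_size, intermediate_size], row-major
long g_tile_offset(long expert_id, long idx_n0,
                   long hidden_size, long intermediate_size, bool gate_only)
{
    const long ratio         = gate_only ? 1 : 2; // gate-only vs fused gate+up
    const long expert_stride = intermediate_size * ratio * hidden_size;
    return expert_id * expert_stride + idx_n0 * hidden_size; // start of row idx_n0
}

long d_tile_offset(long expert_id, long idx_n0,
                   long hidden_size, long intermediate_size)
{
    const long expert_stride = intermediate_size * hidden_size;
    return expert_id * expert_stride + idx_n0; // column idx_n0 of row 0
}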
@@ -354,7 +350,9 @@ struct FusedMoeGemmGlKernel
             // the scatter to output rows happens here
             auto o_scatter_view_ = transform_tensor_view(
                 o_view_,
-                make_tuple(make_indexing_transform(kargs.num_tokens, token_id),
+                make_tuple(make_indexing_transform_with_adaptor(
+                               kargs.max_num_tokens_padded,
+                               indexing_adaptor<index_t>{sorted_token_ids_ptr}),
                            make_pass_through_transform(kargs.hidden_size)),
                 make_tuple(sequence<0>{}, sequence<1>{}),
                 make_tuple(sequence<0>{}, sequence<1>{}));
@@ -362,7 +360,7 @@ struct FusedMoeGemmGlKernel
             auto o_window_ = make_tile_window(
                 o_scatter_view_,
                 make_tuple(number<BlockShape::Block_M1>{}, number<BlockShape::Block_N1>{}),
-                {0, 0});
+                {idx_m0, 0});
             return o_window_;
         }();
@@ -374,8 +372,7 @@ struct FusedMoeGemmGlKernel
             topk_weight,
             smem,
             kargs.hidden_size,
-            kargs.intermediate_size,
-            kargs.stride_token);
+            kargs.intermediate_size);
     }
 };
...
@@ -104,7 +104,8 @@ struct FusedMoeGemmHostArgs
     index_t num_experts;  // number of groups
     index_t topk;         // need this?
     index_t stride_token; // for input/output, stride for each row, should be >= hidden_size
+    index_t max_num_tokens_padded; // size of sorted_token_ids_ptr
 };

 // This is a scatter/gather b2b group-gemm
@@ -198,6 +199,7 @@ struct FusedMoeGemmKernel
         index_t topk;         // need this?
         index_t stride_token; // for input/output, stride for each row, should be >= hidden_size
+        index_t max_num_tokens_padded; // size of sorted_token_ids_ptr
     };

     // TODO: switch karg based on
...
@@ -80,16 +80,30 @@ struct FusedMoeGemmPipeline_FlatmmGl
         return max(smem_mat_a, smem_bridge);
     }

-    template <typename Karg>
-    CK_TILE_DEVICE auto operator()(const Karg& kargs,
+    // this is the thread offset along row/col
+    CK_TILE_HOST_DEVICE static auto GetACoord()
+    {
+        constexpr auto a_dist = Policy::template MakeGlobalTileDistribution_A<Problem>();
+        const auto a_coord = a_dist.calculate_index();
+        return a_coord;
+    }
+
+    template <typename AWindow, typename GWindow, typename DWindow, typename OWindow>
+    CK_TILE_DEVICE auto operator()(const AWindow& a_window_,
+                                   const GWindow& g_window_,
+                                   const DWindow& d_window_,
+                                   OWindow& o_window_,
+                                   TopkWeightDataType /*topk_weight*/,
                                    CK_TILE_LDS_ADDR void* smem,
-                                   index_t sorted_tile_id,
-                                   index_t intermediate_tile_id)
+                                   index_t hidden_size,
+                                   index_t intermediate_size)
     {
-        ignore = kargs;
+        ignore = a_window_;
+        ignore = g_window_;
+        ignore = d_window_;
+        ignore = o_window_;
         ignore = smem;
-        ignore = sorted_tile_id;
-        ignore = intermediate_tile_id;
+        ignore = hidden_size;
+        ignore = intermediate_size;
    }
 };
...
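For reference, the call site in the kernel (see the epilogue hunk above) now hands the pipeline pre-built windows plus the two sizes; everything else the old Karg carried travels inside the windows. A hypothetical invocation matching the new signature:

// hypothetical call matching the new operator() signature; the current
// body is still a stub that ignores all of its arguments
Pipeline{}(a_window, g_window, d_window, o_window,
           topk_weight, smem,
           kargs.hidden_size, kargs.intermediate_size);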