Commit 2bf0057a authored by dummycoderfe's avatar dummycoderfe
Browse files

add moe_sorting & check ok

parent 24d996aa
...@@ -494,11 +494,6 @@ include_directories(BEFORE ...@@ -494,11 +494,6 @@ include_directories(BEFORE
${HIP_INCLUDE_DIRS} ${HIP_INCLUDE_DIRS}
) )
SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
if(BUILD_DEV)
add_compile_options(-Werror)
add_compile_options(-Weverything)
endif()
message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang")
......
...@@ -66,7 +66,6 @@ else() ...@@ -66,7 +66,6 @@ else()
-Wunreachable-code -Wunreachable-code
-Wunused -Wunused
-Wno-reserved-identifier -Wno-reserved-identifier
-Werror
-Wno-option-ignored -Wno-option-ignored
-Wsign-compare -Wsign-compare
-Wno-extra-semi-stmt -Wno-extra-semi-stmt
......
add_executable(tile_example_moe_sorting EXCLUDE_FROM_ALL moe_sorting.cpp moe_sorting_api.cpp)
target_include_directories(tile_example_moe_sorting PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/)
set(EXAMPLE_MOE_SORTING_COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND EXAMPLE_MOE_SORTING_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
# list(APPEND EXAMPLE_MOE_SORTING_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
target_compile_options(tile_example_moe_sorting PRIVATE ${EXAMPLE_MOE_SORTING_COMPILE_OPTIONS})
# topk-softmax
This folder contains example for topk-softmax kernel using ck_tile tile-programming implementation. This kernel is often used in Moe model, before launching the fused-moe-gemm block. The input is a `token*expert` 2d matrix. The op will do a softmax per row(`expert`), then find the `topk` value for each row. Output is a `token*topk` weight(usually fp32) and index(int32) 2d tensor.
## build
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_example_moe_sorting -j
```
This will result in an executable `build/bin/tile_example_moe_sorting`
## example
```
args:
-v weather do CPU validation or not (default:1)
-pr_i input data type. fp16/fp32 (representing 8/16/32 bit data) (default:fp16)
-pr_w output weight data type(currently only fp32 supported now) (default:fp32)
-t number of input tokens (default:32)
-e number of experts (default:8)
-k topk (default:2)
-st_i row stride of input, -1 means same as experts (default:-1)
-st_o row stride of output/indices, -1 means same as topk (default:-1)
-seed seed to be used, -1 means random every time (default:-1)
-kname when set to 1 it will print kernel name (default:0)
```
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <set>
#include <vector>
#include <iostream>
#include <numeric>
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <time.h>
#include <unordered_set>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce.hpp"
#include "moe_sorting_api.hpp"
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("v", "1", "weather do CPU validation or not")
.insert("pr_i", "fp16", "input data type. fp16/fp32 (representing 8/16/32 bit data)")
.insert("pr_w", "fp32", "output weight data type(currently only fp32 supported now)")
.insert("t", "128", "number of input tokens")
.insert("e", "8", "number of experts")
.insert("k", "4", "topk")
.insert("unit", "32", "unit_size")
.insert("st_i", "-1", "row stride of input, -1 means same as topk")
.insert("seed", "-1", "seed to be used, -1 means random every time")
.insert("kname", "0", "when set to 1 it will print kernel name")
.insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "20", "number of iterations to benchmark the kernel");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
template <typename IndexType>
void topid_unique_gen(
std::vector<IndexType>& host_tensor, int tokens, int topk, int num_expert, int seed)
{
size_t total_size = topk * tokens;
std::srand(seed);
std::set<IndexType> unique_set;
IndexType current_v;
for(size_t i = 0; i < total_size; i++)
{
if(i % topk == 0)
{
unique_set.clear();
}
current_v = std::rand() % num_expert;
while(unique_set.find(current_v) != unique_set.end())
{
current_v = std::rand() % num_expert;
}
unique_set.insert(current_v);
host_tensor[i] = current_v;
}
}
template <typename WeightType, typename IndexType = ck_tile::index_t>
bool test_moe_sorting(ck_tile::ArgParser args)
{
int validate = args.get_int("v");
std::string input_prec = args.get_str("pr_i");
std::string weight_prec = args.get_str("pr_w");
int tokens = args.get_int("t");
int experts = args.get_int("e");
int topk = args.get_int("k");
int seed = args.get_int("seed");
int stride_input = args.get_int("st_i");
int unit_size = args.get_int("unit");
int kname = args.get_int("kname");
int warmup = args.get_int("warmup");
int repeat = args.get_int("repeat");
int max_output_ids = (topk * tokens * experts + (unit_size - 1)) / unit_size * unit_size;
if(stride_input < 0)
{
stride_input = topk;
}
assert(stride_input >= topk);
if(seed < 0)
{
seed = std::time(nullptr);
}
if(topk > experts)
{
printf("topk:%d value should be smaller than, or equal to number of experts:%d\n",
topk,
experts);
return false;
}
// tokens already considered batch size
ck_tile::HostTensor<IndexType> topk_ids_host({tokens, topk}, {stride_input, 1});
ck_tile::HostTensor<WeightType> weights_host({tokens, topk}, {stride_input, 1});
ck_tile::HostTensor<IndexType> sorted_ids_host({max_output_ids}, {1});
ck_tile::HostTensor<WeightType> sorted_weights_host({max_output_ids}, {1});
ck_tile::HostTensor<IndexType> expert_ids_host({max_output_ids / unit_size}, {1});
ck_tile::HostTensor<IndexType> sorted_id_cnt_host({1}, {1});
ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(weights_host);
topid_unique_gen<IndexType>(topk_ids_host.mData, tokens, topk, experts, seed);
ck_tile::DeviceMem topk_ids_dev(topk_ids_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem weights_dev(weights_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem sorted_ids_dev(sorted_ids_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem sorted_weights_dev(sorted_weights_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem expert_ids_dev(expert_ids_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem sorted_id_cnt_dev(sorted_id_cnt_host.get_element_space_size_in_bytes());
topk_ids_dev.ToDevice(topk_ids_host.data());
weights_dev.ToDevice(weights_host.data());
moe_sorting_trait trait{input_prec, weight_prec, experts, topk, unit_size, tokens};
moe_sorting_kargs karg{topk_ids_dev.GetDeviceBuffer(),
weights_dev.GetDeviceBuffer(),
sorted_ids_dev.GetDeviceBuffer(),
sorted_weights_dev.GetDeviceBuffer(),
expert_ids_dev.GetDeviceBuffer(),
sorted_id_cnt_dev.GetDeviceBuffer(),
tokens,
unit_size,
experts,
topk};
ck_tile::stream_config sc{nullptr,
true,
/* log_level = */ (kname ? 1 : 0),
warmup,
repeat};
auto ms = moe_sorting(trait, karg, sc);
printf("[%s|%s]tokens:%d, experts:%d, topk:%d, st_i:%d, ms:%f , ",
input_prec.c_str(),
weight_prec.c_str(),
tokens,
experts,
topk,
stride_input,
ms);
if(ms < 0)
printf("not supported\n");
fflush(stdout);
if(ms < 0)
{
return false;
}
sorted_ids_dev.FromDevice(sorted_ids_host.data());
sorted_weights_dev.FromDevice(sorted_weights_host.data());
expert_ids_dev.FromDevice(expert_ids_host.data());
sorted_id_cnt_dev.FromDevice(sorted_id_cnt_host.data());
bool rtn = true;
if(validate)
{
ck_tile::HostTensor<IndexType> sorted_ids_ref({max_output_ids}, {1});
ck_tile::HostTensor<WeightType> sorted_weights_ref({max_output_ids}, {1});
ck_tile::HostTensor<IndexType> expert_ids_ref({max_output_ids / unit_size}, {1});
int32_t total_tokens_post_pad = 0;
ck_tile::reference_moe_sorting<WeightType, IndexType>(sorted_ids_ref.data(),
sorted_weights_ref.data(),
expert_ids_ref.data(),
total_tokens_post_pad,
weights_host.data(),
topk_ids_host.data(),
topk_ids_host.size() / topk,
experts,
topk,
unit_size);
float atol = 1e-6;
float rtol = 1e-6;
rtn &= ck_tile::check_err(
sorted_ids_host, sorted_ids_ref, std::string("OUT Error: Incorrect ids!"), rtol, atol);
rtn &= ck_tile::check_err(sorted_weights_host,
sorted_weights_ref,
std::string("OUT Error: Incorrect w!"),
rtol,
atol);
rtn &= ck_tile::check_err(
expert_ids_host, expert_ids_ref, std::string("OUT Error: Incorrect eid!"), rtol, atol);
rtn &= total_tokens_post_pad == sorted_id_cnt_host.mData[0];
}
printf("valid:%s\n", rtn ? "y" : "n");
fflush(stdout);
return rtn;
}
int main(int argc, char** argv)
{
auto [result, args] = create_args(argc, argv);
if(!result)
return -1;
std::string input_prec = args.get_str("pr_i");
std::string weight_prec = args.get_str("pr_w");
bool r = true;
if(weight_prec.compare("fp32") == 0)
{
r &= test_moe_sorting<float, ck_tile::index_t>(args);
}
return r ? 0 : -1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_sorting_api.hpp"
float moe_sorting(moe_sorting_trait t, moe_sorting_kargs a, ck_tile::stream_config s)
{
if(t.weight_type == "fp32")
{
using index_t = ck_tile::index_t;
using ms_weight_type = float;
using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type>;
using ms_pipeline = ck_tile::MoeSortingPipeline<ms_problem>;
using kernel = ck_tile::MoeSortingKernel<ms_pipeline>;
auto kargs = kernel::MakeKargs(a);
const dim3 grids = 1;
const dim3 blocks = ck_tile::max(t.experts, ck_tile::get_warp_size());
const size_t lds_size = ((blocks.x + 1) * t.experts + (t.experts + 1)) * sizeof(index_t);
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<64, 1>(kernel{}, grids, blocks, lds_size, kargs));
return ave_time;
}
return -1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/moe_sorting.hpp"
struct moe_sorting_trait
{
std::string input_type;
std::string weight_type; // currently always float
int experts;
int topk;
int unit_size;
int tokens;
};
struct moe_sorting_kargs : public ck_tile::MoeSortingHostArgs
{
};
float moe_sorting(moe_sorting_trait t, moe_sorting_kargs a, ck_tile::stream_config s);
# #!/bin/sh
# EXE=./build/bin/tile_example_moe_sorting
# for pr_i in "fp16" "bf16" ; do
# $EXE -pr_i=$pr_i -t=80 -e=17
# $EXE -pr_i=$pr_i -t=111 -e=117
# $EXE -pr_i=$pr_i -t=1000 -e=55
# $EXE -pr_i=$pr_i -t=99 -e=180
# $EXE -pr_i=$pr_i -t=175 -e=64 -k=8
# $EXE -pr_i=$pr_i -t=65 -e=8 -k=2
# $EXE -pr_i=$pr_i -t=1 -e=25
# $EXE -pr_i=$pr_i -t=31 -e=19 -k=15
# $EXE -pr_i=$pr_i -t=81 -e=37 -k=7
# $EXE -pr_i=$pr_i -t=199 -e=128 -k=13
# $EXE -pr_i=$pr_i -t=23 -e=1 -k=1
# $EXE -pr_i=$pr_i -t=127 -e=99 -k=19 -st_i=233 -st_o=31
# $EXE -pr_i=$pr_i -t=71 -e=11 -k=11 -st_i=30 -st_o=12
# $EXE -pr_i=$pr_i -t=1 -e=1 -k=1
# $EXE -pr_i=$pr_i -t=99 -e=2 -k=1 -st_i=11 -st_o=5
# $EXE -pr_i=$pr_i -t=333 -e=99 -k=13 -st_i=191 -st_o=17
# done
...@@ -11,3 +11,4 @@ add_subdirectory(06_permute) ...@@ -11,3 +11,4 @@ add_subdirectory(06_permute)
add_subdirectory(09_topk_softmax) add_subdirectory(09_topk_softmax)
add_subdirectory(10_rmsnorm2d) add_subdirectory(10_rmsnorm2d)
add_subdirectory(11_add_rmsnorm2d_rdquant) add_subdirectory(11_add_rmsnorm2d_rdquant)
add_subdirectory(12_moe_sorting)
...@@ -29,5 +29,6 @@ ...@@ -29,5 +29,6 @@
#include "ck_tile/host/reference/reference_rowwise_quantization2d.hpp" #include "ck_tile/host/reference/reference_rowwise_quantization2d.hpp"
#include "ck_tile/host/reference/reference_softmax.hpp" #include "ck_tile/host/reference/reference_softmax.hpp"
#include "ck_tile/host/reference/reference_topk.hpp" #include "ck_tile/host/reference/reference_topk.hpp"
#include "ck_tile/host/reference/reference_moe_sorting.hpp"
#include "ck_tile/host/stream_config.hpp" #include "ck_tile/host/stream_config.hpp"
#include "ck_tile/host/timer.hpp" #include "ck_tile/host/timer.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace ck_tile {
template <typename WeightType, typename IndexType = index_t>
CK_TILE_HOST void reference_moe_sorting(IndexType* sorted_token_ids_ptr,
WeightType* sorted_weight_buf,
IndexType* sorted_expert_ids_ptr,
index_t& sub_x_cnt,
const WeightType* weights_ptr,
const IndexType* topk_ids_ptr,
const index_t num_token,
const index_t experts,
const index_t topk,
const index_t sub_x)
{
std::vector<std::vector<IndexType>> expert_tokens(experts,
std::vector<IndexType>(sub_x, num_token));
std::vector<std::vector<WeightType>> expert_token_weights(experts,
std::vector<WeightType>(sub_x, 0));
std::vector<IndexType> expert_slices(experts, 1);
std::vector<IndexType> expert_slice_idxs(experts, 0);
for(index_t t = 0; t < num_token; t++)
{
for(index_t k = 0; k < topk; k++)
{
index_t e = *(topk_ids_ptr + t * topk + k);
WeightType w = *(weights_ptr + t * topk + k);
index_t idx = expert_slice_idxs[e];
if(idx > expert_slices[e] * sub_x - 1)
{
expert_slices[e]++;
index_t new_size = expert_slices[e] * sub_x;
expert_tokens[e].resize(new_size);
expert_token_weights[e].resize(new_size);
for(index_t idx = (expert_slices[e] - 1) * sub_x; idx < new_size; idx++)
{
expert_tokens[e][idx] = num_token;
expert_token_weights[e][idx] = 0;
}
}
expert_tokens[e][idx] = t;
expert_token_weights[e][idx] = w;
expert_slice_idxs[e]++;
}
}
IndexType* tokens = sorted_token_ids_ptr;
WeightType* weights = sorted_weight_buf;
IndexType* erp_ids = sorted_expert_ids_ptr;
for(index_t e = 0; e < experts; e++)
{
memcpy(tokens, expert_tokens[e].data(), sizeof(index_t) * expert_slices[e] * sub_x);
tokens += expert_slices[e] * sub_x;
memcpy(
weights, expert_token_weights[e].data(), sizeof(WeightType) * expert_slices[e] * sub_x);
weights += expert_slices[e] * sub_x;
for(index_t s = 0; s < expert_slices[e]; s++)
{
erp_ids[s] = e;
sub_x_cnt++;
}
erp_ids += expert_slices[e];
}
return;
}
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/moe_sorting/kernel/moe_sorting_kernel.hpp"
#include "ck_tile/ops/moe_sorting/pipeline/moe_sorting_pipeline.hpp"
#include "ck_tile/ops/moe_sorting/pipeline/moe_sorting_policy.hpp"
#include "ck_tile/ops/moe_sorting/pipeline/moe_sorting_problem.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
struct MoeSortingHostArgs
{
const void* p_topk_ids;
const void* p_weights;
void* sorted_token_ids;
void* sorted_weights;
void* expert_ids;
void* total_tokens_post_pad;
index_t tokens;
index_t unit_size;
index_t num_experts;
index_t topk;
};
template <typename Pipeline_>
struct MoeSortingKernel
{
using Pipeline = remove_cvref_t<Pipeline_>;
using Problem = remove_cvref_t<typename Pipeline::Problem>;
using IndexType = typename Problem::IndexType;
using WeightType = typename Problem::WeightType;
typedef MoeSortingHostArgs MoeSortingKargs;
using Kargs = MoeSortingKargs;
using Hargs = MoeSortingHostArgs;
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h) { return h; }
CK_TILE_DEVICE index_t calc_index(index_t total_col, index_t row, index_t col) const
{
return row * total_col + col;
}
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
#define MAX(x, y) (((x) > (y)) ? (x) : (y))
CK_TILE_DEVICE void moe_align_block_size_kernel(const IndexType* __restrict__ topk_id,
const WeightType* __restrict__ weights,
index_t* sorted_token_ids,
WeightType* sorted_weights,
index_t* expert_ids,
index_t* total_tokens_post_pad,
const index_t num_experts,
const index_t unit_size,
const size_t numel,
const index_t topk) const
{
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
const size_t start_idx = threadIdx.x * tokens_per_thread;
extern __shared__ index_t shared_mem[];
index_t* tokens_cnts = shared_mem; // 2d: (blockDim.x + 1, num_experts)
index_t* cumsum = shared_mem + (blockDim.x + 1) * num_experts; // 1: (num_experts + 1)
for(int i = 0; i < num_experts; ++i)
{
tokens_cnts[calc_index(num_experts, threadIdx.x + 1, i)] = 0;
}
__syncthreads();
if(threadIdx.x < num_experts)
{
tokens_cnts[calc_index(num_experts, 0, threadIdx.x)] = 0;
for(int i = 1; i <= blockDim.x; ++i)
{
tokens_cnts[calc_index(num_experts, i, threadIdx.x)] +=
tokens_cnts[calc_index(num_experts, i - 1, threadIdx.x)];
}
}
__syncthreads();
if(threadIdx.x == 0)
{
cumsum[0] = 0;
for(int i = 1; i <= num_experts; ++i)
{
cumsum[i] =
cumsum[i - 1] +
MAX(CEILDIV(tokens_cnts[calc_index(num_experts, blockDim.x, i - 1)], unit_size),
1) *
unit_size;
}
*total_tokens_post_pad = cumsum[num_experts] / unit_size;
}
__syncthreads();
if(threadIdx.x < num_experts)
{
for(int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += unit_size)
{
expert_ids[i / unit_size] = threadIdx.x;
}
}
for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
{
index_t expert_id = topk_id[i];
index_t rank_post_pad =
tokens_cnts[calc_index(num_experts, threadIdx.x, expert_id)] + cumsum[expert_id];
sorted_token_ids[rank_post_pad] = i / topk;
sorted_weights[rank_post_pad] = weights[i];
++tokens_cnts[calc_index(num_experts, threadIdx.x, expert_id)];
}
const index_t prefill_token = numel / topk;
if(threadIdx.x < num_experts)
{
index_t expert_offset =
cumsum[threadIdx.x] + tokens_cnts[calc_index(num_experts, blockDim.x, threadIdx.x)];
while(expert_offset < cumsum[threadIdx.x + 1])
{
sorted_token_ids[expert_offset] = prefill_token;
sorted_weights[expert_offset] = static_cast<WeightType>(0.0);
expert_offset++;
}
}
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
const size_t numel = kargs.tokens * kargs.topk;
return moe_align_block_size_kernel(static_cast<const IndexType *>(kargs.p_topk_ids),
static_cast<const WeightType *>(kargs.p_weights),
static_cast<IndexType *>(kargs.sorted_token_ids),
static_cast<WeightType *>(kargs.sorted_weights),
static_cast<IndexType *>(kargs.expert_ids),
static_cast<IndexType *>(kargs.total_tokens_post_pad),
kargs.num_experts,
kargs.unit_size,
numel,
kargs.topk);
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/moe_sorting/pipeline/moe_sorting_policy.hpp"
#include <string>
#include <type_traits>
#ifndef TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
#define TOPK_SOFTMAX_USE_RAW_TILE_WINDOW 0
#endif
namespace ck_tile {
template <typename Problem_, typename Policy_ = MoeSortingPolicy>
struct MoeSortingPipeline
{
// TODO: this kernel only support warp per row
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using WeightType = typename Problem::WeightType;
// template <typename TopkIdWindow, typename WeightWindow>
// CK_TILE_DEVICE auto operator()(const TopkIdWindow& topk_id_window,
// const WeightWindow& weight_window,
// index_t* sorted_token_ids,
// WeightType* sorted_weights,
// index_t* expert_ids,
// index_t* total_tokens_post_pad,
// const index_t num_experts,
// const index_t unit_size,
// const size_t numel,
// const index_t topk)
// {
// }
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/softmax.hpp"
#include "ck_tile/ops/topk.hpp"
namespace ck_tile {
struct MoeSortingPolicy
{
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
template <typename IndexType_, typename WeightType_>
struct MoeSortingProblem
{
// TODO: this kernel only support warp per row
using WeightType = remove_cvref_t<WeightType_>;
using IndexType = remove_cvref_t<IndexType_>;
static constexpr index_t WarpSize = get_warp_size();
static constexpr index_t WarpsPerBlock = 1;
};
} // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment