Commit 84b27d75 authored by coderfeli's avatar coderfeli
Browse files

merge max_token_id and fix err

parents b3ae04f8 83be79ba
...@@ -192,9 +192,10 @@ int main(int argc, char* argv[]) ...@@ -192,9 +192,10 @@ int main(int argc, char* argv[])
ck::index_t N = 6144; ck::index_t N = 6144;
ck::index_t K = 8192; ck::index_t K = 8192;
ck::index_t experts = 8; ck::index_t experts = 8;
ck::index_t sorted_tile_num = 8; ck::index_t sorted_tile_num = 9;
ck::index_t sorted_tile_size = MPerBlock; ck::index_t valid_tile_num = 8;
ck::index_t SORTED_SIZE = sorted_tile_num * sorted_tile_size; ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t batch = 64; ck::index_t batch = 64;
ck::index_t topk = 2; ck::index_t topk = 2;
...@@ -234,15 +235,17 @@ int main(int argc, char* argv[]) ...@@ -234,15 +235,17 @@ int main(int argc, char* argv[])
// const ck::index_t experts = 8; // const ck::index_t experts = 8;
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({experts}, {1})); Tensor<ck::index_t> expert_ids(HostTensorDescriptor({experts}, {1}));
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({SORTED_SIZE}, {1})); Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1}));
max_token_id.mData[0] = valid_size;
for (int i = 0; i < sorted_tile_num; i++) { for (int i = 0; i < sorted_tile_num; i++) {
expert_ids.mData[i] = i; expert_ids.mData[i] = i;
} }
int token_per_tile = tokens / sorted_tile_num; int token_per_tile = tokens / valid_tile_num;
int tokenid = 0; int tokenid = 0;
// sorted_token_ids.mData[0] = 0; // sorted_token_ids.mData[0] = 0;
for (int i = 0; i < SORTED_SIZE; i++) { for (int i = 0; i < sorted_size; i++) {
int tile_off = i % sorted_tile_size; int tile_off = i % valid_size;
if(tile_off < token_per_tile) if(tile_off < token_per_tile)
{ {
sorted_token_ids.mData[i] = (tokenid % batch) | ((tokenid / batch) << 24); sorted_token_ids.mData[i] = (tokenid % batch) | ((tokenid / batch) << 24);
...@@ -294,6 +297,7 @@ int main(int argc, char* argv[]) ...@@ -294,6 +297,7 @@ int main(int argc, char* argv[])
d1_e_n.savetxt("d1_e_n.txt", "int"); d1_e_n.savetxt("d1_e_n.txt", "int");
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize()); DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize());
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize()); DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize());
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.mDesc.GetElementSpaceSize()); DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize()); DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize()); DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize());
...@@ -302,6 +306,7 @@ int main(int argc, char* argv[]) ...@@ -302,6 +306,7 @@ int main(int argc, char* argv[])
a0_t_k.savetxt("a.txt"); a0_t_k.savetxt("a.txt");
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
expert_ids_dev.ToDevice(expert_ids.mData.data()); expert_ids_dev.ToDevice(expert_ids.mData.data());
max_token_id_dev.ToDevice(max_token_id.mData.data());
a0_device_buf.ToDevice(a0_t_k.mData.data()); a0_device_buf.ToDevice(a0_t_k.mData.data());
d0_device_buf.ToDevice(d0_t_n.mData.data()); d0_device_buf.ToDevice(d0_t_n.mData.data());
d1_device_buf.ToDevice(d1_e_n.mData.data()); d1_device_buf.ToDevice(d1_e_n.mData.data());
...@@ -323,15 +328,16 @@ int main(int argc, char* argv[]) ...@@ -323,15 +328,16 @@ int main(int argc, char* argv[])
auto invoker = device_op.MakeInvoker(); auto invoker = device_op.MakeInvoker();
auto argument = auto argument =
device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(), device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(),
expert_ids_dev.GetDeviceBuffer(), expert_ids_dev.GetDeviceBuffer(),
a0_device_buf.GetDeviceBuffer(), max_token_id_dev.GetDeviceBuffer(),
a0_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(), b0_device_buf.GetDeviceBuffer(),
std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer(), std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer(),
d1_device_buf.GetDeviceBuffer()}, d1_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(), e_device_buf.GetDeviceBuffer(),
tokens, tokens,
topk, topk,
SORTED_SIZE, sorted_size,
N, N,
K, K,
StrideA, StrideA,
...@@ -352,9 +358,9 @@ int main(int argc, char* argv[]) ...@@ -352,9 +358,9 @@ int main(int argc, char* argv[])
if (time_kernel) { if (time_kernel) {
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * SORTED_SIZE * N * K; std::size_t flop = std::size_t(2) * sorted_size * N * K;
std::size_t num_btype = std::size_t num_btype =
sizeof(A0DataType) * SORTED_SIZE * K + sizeof(B0DataType) * K * N * experts + sizeof(EDataType) * SORTED_SIZE * N; sizeof(A0DataType) * sorted_size * K + sizeof(B0DataType) * K * N * experts + sizeof(EDataType) * sorted_size * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
...@@ -383,26 +389,26 @@ int main(int argc, char* argv[]) ...@@ -383,26 +389,26 @@ int main(int argc, char* argv[])
auto ref_invoker = ref_moe_gemm.MakeInvoker(); auto ref_invoker = ref_moe_gemm.MakeInvoker();
auto ref_argument = ref_moe_gemm.MakeArgument( auto ref_argument = ref_moe_gemm.MakeArgument(
sorted_token_ids, expert_ids, sorted_tile_size, a0_t_k, b0_e_n_k, c_t_k_n, PassThrough{}, PassThrough{}, PassThrough{}); sorted_token_ids, expert_ids, max_token_id, MPerBlock, a0_t_k, b0_e_n_k, c_t_k_n, PassThrough{}, PassThrough{}, PassThrough{});
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
for(int m = 0; m < SORTED_SIZE; ++m) for(int m = 0; m < valid_size; ++m)
{ {
const int fuse_t = sorted_token_ids.mData[m]; const int fuse_t = sorted_token_ids.mData[m];
const int t = fuse_t & 0xffffff; const int t = fuse_t & 0xffffff;
const int topk_id = (fuse_t & 0xff000000) >> 24; const int topk_id = (fuse_t & 0xff000000) >> 24;
printf("m %d fuset %d %d %d\n",m, fuse_t, t, topk_id); // printf("m %d fuset %d %d %d\n",m, fuse_t, t, topk_id);
if (t >= tokens) if (t >= tokens)
{ {
continue; continue;
} }
const int e = expert_ids(m / sorted_tile_size); const int e = expert_ids(m / MPerBlock);
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
cde_element_op(e_t_n_host_result(t, topk_id, n), c_t_k_n(t, topk_id, n), d0_t_n(t, n), d1_e_n(e, n)); cde_element_op(e_t_n_host_result(t, topk_id, n), c_t_k_n(t, topk_id, n), d0_t_n(t, n), d1_e_n(e, n));
printf("m %d n %d topk %d token %d %f %f\n",m, n,topk_id, t, e_t_n_host_result(t, topk_id, n), c_t_k_n(t, topk_id, n)); // printf("m %d n %d topk %d token %d %f %f\n",m, n,topk_id, t, e_t_n_host_result(t, topk_id, n), c_t_k_n(t, topk_id, n));
} }
} }
......
...@@ -185,9 +185,11 @@ int main(int argc, char* argv[]) ...@@ -185,9 +185,11 @@ int main(int argc, char* argv[])
ck::index_t N = 6144; ck::index_t N = 6144;
ck::index_t K = 8192; ck::index_t K = 8192;
ck::index_t experts = 8; ck::index_t experts = 8;
ck::index_t sorted_tile_num = 8; ck::index_t sorted_tile_num = 9;
ck::index_t valid_tile_num = 8;
ck::index_t sorted_tile_size = MPerBlock; ck::index_t sorted_tile_size = MPerBlock;
ck::index_t SORTED_SIZE = sorted_tile_num * sorted_tile_size; ck::index_t sorted_size = sorted_tile_num * sorted_tile_size;
ck::index_t valid_size = valid_tile_num * sorted_tile_size;
ck::index_t tokens = 64; ck::index_t tokens = 64;
if(argc == 1) if(argc == 1)
...@@ -223,14 +225,16 @@ int main(int argc, char* argv[]) ...@@ -223,14 +225,16 @@ int main(int argc, char* argv[])
// const ck::index_t experts = 8; // const ck::index_t experts = 8;
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({experts}, {1})); Tensor<ck::index_t> expert_ids(HostTensorDescriptor({experts}, {1}));
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({SORTED_SIZE}, {1})); Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1}));
max_token_id.mData[0] = valid_size;
for (int i = 0; i < sorted_tile_num; i++) { for (int i = 0; i < sorted_tile_num; i++) {
expert_ids.mData[i] = i; expert_ids.mData[i] = i;
} }
int token_per_tile = tokens / sorted_tile_num; int token_per_tile = tokens / sorted_tile_num;
int tokenid = 0; int tokenid = 0;
// sorted_token_ids.mData[0] = 0; // sorted_token_ids.mData[0] = 0;
for (int i = 0; i < SORTED_SIZE; i++) { for (int i = 0; i < sorted_size; i++) {
int tile_off = i % sorted_tile_size; int tile_off = i % sorted_tile_size;
if(tile_off < token_per_tile) if(tile_off < token_per_tile)
sorted_token_ids.mData[i] = tokenid++; sorted_token_ids.mData[i] = tokenid++;
...@@ -238,12 +242,12 @@ int main(int argc, char* argv[]) ...@@ -238,12 +242,12 @@ int main(int argc, char* argv[])
sorted_token_ids.mData[i] = tokens; sorted_token_ids.mData[i] = tokens;
} }
Tensor<A0DataType> a0_m_k(HostTensorDescriptor({SORTED_SIZE, K}, {K, 1})); Tensor<A0DataType> a0_m_k(HostTensorDescriptor({sorted_size, K}, {K, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, N, K}, {N*K, K, 1})); Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N*K, K, 1})); Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
Tensor<D0DataType> d0_m_n(HostTensorDescriptor({SORTED_SIZE, N}, {StrideDs[0], 0})); Tensor<D0DataType> d0_m_n(HostTensorDescriptor({sorted_size, N}, {StrideDs[0], 0}));
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]})); Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({SORTED_SIZE, N}, {1, 0})); Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1}));
Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1}));
e_t_n_device_result.SetZero(); e_t_n_device_result.SetZero();
...@@ -280,6 +284,7 @@ int main(int argc, char* argv[]) ...@@ -280,6 +284,7 @@ int main(int argc, char* argv[])
} }
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize()); DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize());
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize()); DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize());
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize()); DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize());
...@@ -294,6 +299,7 @@ int main(int argc, char* argv[]) ...@@ -294,6 +299,7 @@ int main(int argc, char* argv[])
d2_e_n.savetxt("d2_e_n.txt", "int"); d2_e_n.savetxt("d2_e_n.txt", "int");
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
expert_ids_dev.ToDevice(expert_ids.mData.data()); expert_ids_dev.ToDevice(expert_ids.mData.data());
max_token_id_dev.ToDevice(max_token_id.mData.data());
a0_device_buf.ToDevice(a0_m_k.mData.data()); a0_device_buf.ToDevice(a0_m_k.mData.data());
d0_device_buf.ToDevice(d0_m_n.mData.data()); d0_device_buf.ToDevice(d0_m_n.mData.data());
d1_device_buf.ToDevice(d1_e_n.mData.data()); d1_device_buf.ToDevice(d1_e_n.mData.data());
...@@ -318,6 +324,7 @@ int main(int argc, char* argv[]) ...@@ -318,6 +324,7 @@ int main(int argc, char* argv[])
auto argument = auto argument =
device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(), device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(),
expert_ids_dev.GetDeviceBuffer(), expert_ids_dev.GetDeviceBuffer(),
max_token_id_dev.GetDeviceBuffer(),
a0_device_buf.GetDeviceBuffer(), a0_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(), b0_device_buf.GetDeviceBuffer(),
std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer(), std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer(),
...@@ -325,7 +332,7 @@ int main(int argc, char* argv[]) ...@@ -325,7 +332,7 @@ int main(int argc, char* argv[])
d2_device_buf.GetDeviceBuffer()}, d2_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(), e_device_buf.GetDeviceBuffer(),
tokens, tokens,
SORTED_SIZE, sorted_size,
N, N,
K, K,
StrideA, StrideA,
...@@ -347,9 +354,9 @@ int main(int argc, char* argv[]) ...@@ -347,9 +354,9 @@ int main(int argc, char* argv[])
// not result correct here because output buf not setzero // not result correct here because output buf not setzero
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * SORTED_SIZE * N * K; std::size_t flop = std::size_t(2) * sorted_size * N * K;
std::size_t num_btype = std::size_t num_btype =
sizeof(A0DataType) * SORTED_SIZE * K + sizeof(B0DataType) * K * N * experts + sizeof(EDataType) * SORTED_SIZE * N; sizeof(A0DataType) * sorted_size * K + sizeof(B0DataType) * K * N * experts + sizeof(EDataType) * sorted_size * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
...@@ -380,7 +387,7 @@ int main(int argc, char* argv[]) ...@@ -380,7 +387,7 @@ int main(int argc, char* argv[])
auto ref_moe_gemm = ReferenceGemmInstance{}; auto ref_moe_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_moe_gemm.MakeInvoker(); auto ref_invoker = ref_moe_gemm.MakeInvoker();
auto ref_argument = ref_moe_gemm.MakeArgument( auto ref_argument = ref_moe_gemm.MakeArgument(
sorted_token_ids, expert_ids, sorted_tile_size, a0_m_k, b0_e_n_k, d0_m_n, d1_e_n, d2_e_n, c_t_n, PassThrough{}, PassThrough{}, cde_element_op); sorted_token_ids, expert_ids, max_token_id, sorted_tile_size, a0_m_k, b0_e_n_k, d0_m_n, d1_e_n, d2_e_n, c_t_n, PassThrough{}, PassThrough{}, cde_element_op);
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
......
...@@ -500,6 +500,7 @@ struct DeviceMoeGemm ...@@ -500,6 +500,7 @@ struct DeviceMoeGemm
static auto MakeArgument(const void* p_sorted_token_ids, static auto MakeArgument(const void* p_sorted_token_ids,
const void* p_sorted_expert_ids, const void* p_sorted_expert_ids,
const void* p_max_token_id,
const void* p_a, const void* p_a,
const void* p_b, const void* p_b,
std::array<const void*, NumDTensor> p_ds, std::array<const void*, NumDTensor> p_ds,
...@@ -520,6 +521,7 @@ struct DeviceMoeGemm ...@@ -520,6 +521,7 @@ struct DeviceMoeGemm
{ {
return Argument{static_cast<const index_t*>(p_sorted_token_ids), return Argument{static_cast<const index_t*>(p_sorted_token_ids),
static_cast<const index_t*>(p_sorted_expert_ids), static_cast<const index_t*>(p_sorted_expert_ids),
static_cast<const index_t*>(p_max_token_id),
static_cast<const ADataType*>(p_a), static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_b),
p_ds, p_ds,
...@@ -560,7 +562,7 @@ struct DeviceMoeGemm ...@@ -560,7 +562,7 @@ struct DeviceMoeGemm
CElementwiseOperation c_element_op) override CElementwiseOperation c_element_op) override
{ {
// assert(0, "no impl"); // assert(0, "no impl");
return std::make_unique<Argument>(nullptr, nullptr, return std::make_unique<Argument>(nullptr, nullptr, nullptr,
static_cast<const ADataType*>(p_a), static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_b),
p_ds, p_ds,
......
...@@ -46,6 +46,7 @@ __global__ void ...@@ -46,6 +46,7 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>( GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_sorted_token_ids, karg.p_sorted_token_ids,
karg.p_sorted_expert_ids, karg.p_sorted_expert_ids,
karg.p_max_token_id,
karg.p_a_grid + splitk_batch_offset.a_k_split_offset, karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset, karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid, karg.p_ds_grid,
...@@ -618,6 +619,7 @@ struct GridwiseMoeGemmGather ...@@ -618,6 +619,7 @@ struct GridwiseMoeGemmGather
__host__ Argument( __host__ Argument(
const index_t* p_sorted_token_ids_, const index_t* p_sorted_token_ids_,
const index_t* p_sorted_expert_ids_, const index_t* p_sorted_expert_ids_,
const index_t* p_max_token_id_,
const ADataType* p_a_grid_, const ADataType* p_a_grid_,
const BDataType* p_b_grid_, const BDataType* p_b_grid_,
std::array<const void*, NumDTensor> p_ds_grid_, std::array<const void*, NumDTensor> p_ds_grid_,
...@@ -639,6 +641,7 @@ struct GridwiseMoeGemmGather ...@@ -639,6 +641,7 @@ struct GridwiseMoeGemmGather
p_sorted_token_ids{p_sorted_token_ids_}, p_sorted_token_ids{p_sorted_token_ids_},
p_sorted_expert_ids{p_sorted_expert_ids_}, p_sorted_expert_ids{p_sorted_expert_ids_},
p_max_token_id{p_max_token_id_},
p_a_grid{p_a_grid_}, p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_}, p_b_grid{p_b_grid_},
p_ds_grid{}, p_ds_grid{},
...@@ -659,6 +662,8 @@ struct GridwiseMoeGemmGather ...@@ -659,6 +662,8 @@ struct GridwiseMoeGemmGather
const index_t * p_sorted_token_ids; const index_t * p_sorted_token_ids;
const index_t * p_sorted_expert_ids; const index_t * p_sorted_expert_ids;
const index_t * p_max_token_id;
const ADataType* p_a_grid; const ADataType* p_a_grid;
const BDataType* p_b_grid; const BDataType* p_b_grid;
DsGridPointer p_ds_grid; DsGridPointer p_ds_grid;
...@@ -1123,6 +1128,7 @@ struct GridwiseMoeGemmGather ...@@ -1123,6 +1128,7 @@ struct GridwiseMoeGemmGather
__device__ static void Run( __device__ static void Run(
const index_t* p_sorted_token_ids, const index_t* p_sorted_token_ids,
const index_t* p_sorted_expert_ids, const index_t* p_sorted_expert_ids,
const index_t* p_max_token_id,
const ADataType* p_a_grid, const ADataType* p_a_grid,
const BDataType* p_b_grid, const BDataType* p_b_grid,
DsGridPointer& p_ds_grid, DsGridPointer& p_ds_grid,
...@@ -1150,6 +1156,8 @@ struct GridwiseMoeGemmGather ...@@ -1150,6 +1156,8 @@ struct GridwiseMoeGemmGather
const index_t block_n_id = __builtin_amdgcn_readfirstlane(blockIdx.x); const index_t block_n_id = __builtin_amdgcn_readfirstlane(blockIdx.x);
const index_t block_m_id = __builtin_amdgcn_readfirstlane(blockIdx.y); const index_t block_m_id = __builtin_amdgcn_readfirstlane(blockIdx.y);
const index_t expert_id = __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[block_m_id]); const index_t expert_id = __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[block_m_id]);
const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
const index_t token0 = __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
// constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
...@@ -1160,13 +1168,14 @@ struct GridwiseMoeGemmGather ...@@ -1160,13 +1168,14 @@ struct GridwiseMoeGemmGather
// static_assert(MLoadRepeats == 1, "only support 1 line per thread now!"); // static_assert(MLoadRepeats == 1, "only support 1 line per thread now!");
const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats; const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
const index_t t0 = p_sorted_token_ids[block_m_id * MPerBlock]; if(token_pos >= max_token_id || token0 >= problem.NumTokens)
if((t0 & 0xffffff) >= problem.NumTokens)
return; return;
const index_t topk_id = (t0 & 0xff000000) >> 24; const index_t topk_id = (p_sorted_token_ids[block_m_id * MPerBlock] & 0xff000000) >> 24;
StaticallyIndexedArray<index_t, AMRepeats> gather_offsets; //= p_sorted_token_ids[token_pos]; StaticallyIndexedArray<index_t, AMRepeats> gather_offsets; //= p_sorted_token_ids[token_pos];
static_for<0, AMRepeats, 1>{}([&](auto m0) { static_for<0, AMRepeats, 1>{}([&](auto m0) {
gather_offsets(m0) = (p_sorted_token_ids[token_pos + m0] & 0xffffff) * problem.K; const index_t token_offset = (token_pos + m0 < max_token_id) ?
(p_sorted_token_ids[token_pos + m0] & 0xffffff) : problem.NumTokens;
gather_offsets(m0) = token_offset * problem.K;
// printf("init off tid %d m %d off %d\n", threadIdx.x, m0(), gather_offsets(m0)); // printf("init off tid %d m %d off %d\n", threadIdx.x, m0(), gather_offsets(m0));
}); });
// const index_t m_block_data_idx_on_grid = // const index_t m_block_data_idx_on_grid =
......
...@@ -46,6 +46,7 @@ __global__ void ...@@ -46,6 +46,7 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>( GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_sorted_token_ids, karg.p_sorted_token_ids,
karg.p_sorted_expert_ids, karg.p_sorted_expert_ids,
karg.p_max_token_id,
karg.p_a_grid + splitk_batch_offset.a_k_split_offset, karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset, karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid, karg.p_ds_grid,
...@@ -614,6 +615,7 @@ struct GridwiseMoeGemmScatter ...@@ -614,6 +615,7 @@ struct GridwiseMoeGemmScatter
__host__ Argument( __host__ Argument(
const index_t* p_sorted_token_ids_, const index_t* p_sorted_token_ids_,
const index_t* p_sorted_expert_ids_, const index_t* p_sorted_expert_ids_,
const index_t* p_max_token_id_,
const ADataType* p_a_grid_, const ADataType* p_a_grid_,
const BDataType* p_b_grid_, const BDataType* p_b_grid_,
std::array<const void*, NumDTensor> p_ds_grid_, std::array<const void*, NumDTensor> p_ds_grid_,
...@@ -634,6 +636,7 @@ struct GridwiseMoeGemmScatter ...@@ -634,6 +636,7 @@ struct GridwiseMoeGemmScatter
p_sorted_token_ids{p_sorted_token_ids_}, p_sorted_token_ids{p_sorted_token_ids_},
p_sorted_expert_ids{p_sorted_expert_ids_}, p_sorted_expert_ids{p_sorted_expert_ids_},
p_max_token_id{p_max_token_id_},
p_a_grid{p_a_grid_}, p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_}, p_b_grid{p_b_grid_},
p_ds_grid{}, p_ds_grid{},
...@@ -654,6 +657,7 @@ struct GridwiseMoeGemmScatter ...@@ -654,6 +657,7 @@ struct GridwiseMoeGemmScatter
const index_t * p_sorted_token_ids; const index_t * p_sorted_token_ids;
const index_t * p_sorted_expert_ids; const index_t * p_sorted_expert_ids;
const index_t * p_max_token_id;
const ADataType* p_a_grid; const ADataType* p_a_grid;
const BDataType* p_b_grid; const BDataType* p_b_grid;
DsGridPointer p_ds_grid; DsGridPointer p_ds_grid;
...@@ -1118,6 +1122,7 @@ struct GridwiseMoeGemmScatter ...@@ -1118,6 +1122,7 @@ struct GridwiseMoeGemmScatter
__device__ static void Run( __device__ static void Run(
const index_t* p_sorted_token_ids, const index_t* p_sorted_token_ids,
const index_t* p_sorted_expert_ids, const index_t* p_sorted_expert_ids,
const index_t* p_max_token_id,
const ADataType* p_a_grid, const ADataType* p_a_grid,
const BDataType* p_b_grid, const BDataType* p_b_grid,
DsGridPointer& p_ds_grid, DsGridPointer& p_ds_grid,
...@@ -1143,13 +1148,14 @@ struct GridwiseMoeGemmScatter ...@@ -1143,13 +1148,14 @@ struct GridwiseMoeGemmScatter
const index_t block_n_id = __builtin_amdgcn_readfirstlane(blockIdx.x); const index_t block_n_id = __builtin_amdgcn_readfirstlane(blockIdx.x);
const index_t block_m_id = __builtin_amdgcn_readfirstlane(blockIdx.y); const index_t block_m_id = __builtin_amdgcn_readfirstlane(blockIdx.y);
const index_t expert_id = __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[block_m_id]); const index_t expert_id = __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[block_m_id]);
const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
const index_t token0 = __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
const index_t m_block_data_idx_on_grid = const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K); const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K);
const index_t t0 = (p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff); if(m_block_data_idx_on_grid >= max_token_id || token0 >= problem.NumTokens)
if(t0 >= problem.NumTokens)
return; return;
// N0, K0, Blocksize*KPack // N0, K0, Blocksize*KPack
const index_t n_block_data_idx_on_grid = const index_t n_block_data_idx_on_grid =
......
...@@ -30,6 +30,7 @@ struct ReferenceMoeGemm : public device::BaseOperator ...@@ -30,6 +30,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
{ {
Argument(const Tensor<ck::index_t>& sorted_token_ids, Argument(const Tensor<ck::index_t>& sorted_token_ids,
const Tensor<ck::index_t>& expert_ids, const Tensor<ck::index_t>& expert_ids,
const Tensor<ck::index_t>& max_token_id,
const index_t sorted_tile_size, const index_t sorted_tile_size,
const Tensor<ADataType>& a_t_k, const Tensor<ADataType>& a_t_k,
const Tensor<BDataType>& b_e_n_k, const Tensor<BDataType>& b_e_n_k,
...@@ -39,6 +40,7 @@ struct ReferenceMoeGemm : public device::BaseOperator ...@@ -39,6 +40,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
CElementwiseOperation c_element_op) CElementwiseOperation c_element_op)
: sorted_token_ids_{sorted_token_ids}, : sorted_token_ids_{sorted_token_ids},
expert_ids_{expert_ids}, expert_ids_{expert_ids},
max_token_id_{max_token_id},
sorted_tile_size_{sorted_tile_size}, sorted_tile_size_{sorted_tile_size},
a_t_k_{a_t_k}, a_t_k_{a_t_k},
b_e_n_k_{b_e_n_k}, b_e_n_k_{b_e_n_k},
...@@ -51,6 +53,7 @@ struct ReferenceMoeGemm : public device::BaseOperator ...@@ -51,6 +53,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
const Tensor<ck::index_t>& sorted_token_ids_; const Tensor<ck::index_t>& sorted_token_ids_;
const Tensor<ck::index_t>& expert_ids_; const Tensor<ck::index_t>& expert_ids_;
const Tensor<ck::index_t>& max_token_id_;
index_t sorted_tile_size_; index_t sorted_tile_size_;
const Tensor<ADataType>& a_t_k_; const Tensor<ADataType>& a_t_k_;
const Tensor<BDataType>& b_e_n_k_; const Tensor<BDataType>& b_e_n_k_;
...@@ -70,7 +73,6 @@ struct ReferenceMoeGemm : public device::BaseOperator ...@@ -70,7 +73,6 @@ struct ReferenceMoeGemm : public device::BaseOperator
{ {
auto f_mk_kn_mn = [&](auto m, auto n) { auto f_mk_kn_mn = [&](auto m, auto n) {
const int K = arg.a_t_k_.mDesc.GetLengths()[1]; const int K = arg.a_t_k_.mDesc.GetLengths()[1];
AccDataType v_acc{0}; AccDataType v_acc{0};
ComputeTypeA v_a{0}; ComputeTypeA v_a{0};
ComputeTypeB v_b{0}; ComputeTypeB v_b{0};
...@@ -110,12 +112,12 @@ struct ReferenceMoeGemm : public device::BaseOperator ...@@ -110,12 +112,12 @@ struct ReferenceMoeGemm : public device::BaseOperator
arg.c_element_op_(v_c, v_acc); arg.c_element_op_(v_c, v_acc);
arg.c_t_k_n_(t, topk_id, n) = v_c; arg.c_t_k_n_(t, topk_id, n) = v_c;
printf("ref m %d n %d t %d topk %d v %f\n", m, n, t, topk_id, v_c);
} }
}; };
const ck::index_t max_token_id = arg.max_token_id_(0);
make_ParallelTensorFunctor( make_ParallelTensorFunctor(
f_mk_kn_mn, arg.sorted_token_ids_.GetLengths()[0], arg.c_t_k_n_.mDesc.GetLengths()[2])( f_mk_kn_mn, max_token_id, arg.c_t_k_n_.mDesc.GetLengths()[2])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
...@@ -138,6 +140,7 @@ struct ReferenceMoeGemm : public device::BaseOperator ...@@ -138,6 +140,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
static auto MakeArgument(const Tensor<ck::index_t>& sorted_token_ids, static auto MakeArgument(const Tensor<ck::index_t>& sorted_token_ids,
const Tensor<ck::index_t>& expert_ids, const Tensor<ck::index_t>& expert_ids,
const Tensor<ck::index_t>& max_token_id,
const index_t sorted_tile_size, const index_t sorted_tile_size,
const Tensor<ADataType>& a_t_k, const Tensor<ADataType>& a_t_k,
const Tensor<BDataType>& b_e_n_k, const Tensor<BDataType>& b_e_n_k,
...@@ -146,7 +149,7 @@ struct ReferenceMoeGemm : public device::BaseOperator ...@@ -146,7 +149,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) CElementwiseOperation c_element_op)
{ {
return Argument{sorted_token_ids, expert_ids, sorted_tile_size, a_t_k, b_e_n_k, c_t_k_n, a_element_op, b_element_op, c_element_op}; return Argument{sorted_token_ids, expert_ids, max_token_id, sorted_tile_size, a_t_k, b_e_n_k, c_t_k_n, a_element_op, b_element_op, c_element_op};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
......
...@@ -33,6 +33,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator ...@@ -33,6 +33,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
{ {
Argument(const Tensor<ck::index_t>& sorted_token_ids, Argument(const Tensor<ck::index_t>& sorted_token_ids,
const Tensor<ck::index_t>& expert_ids, const Tensor<ck::index_t>& expert_ids,
const Tensor<ck::index_t>& max_token_id,
const index_t sorted_tile_size, const index_t sorted_tile_size,
const Tensor<ADataType>& a_m_k, const Tensor<ADataType>& a_m_k,
const Tensor<BDataType>& b_e_n_k, const Tensor<BDataType>& b_e_n_k,
...@@ -45,6 +46,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator ...@@ -45,6 +46,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
CElementwiseOperation c_element_op) CElementwiseOperation c_element_op)
: sorted_token_ids_{sorted_token_ids}, : sorted_token_ids_{sorted_token_ids},
expert_ids_{expert_ids}, expert_ids_{expert_ids},
max_token_id_{max_token_id},
sorted_tile_size_{sorted_tile_size}, sorted_tile_size_{sorted_tile_size},
a_m_k_{a_m_k}, a_m_k_{a_m_k},
b_e_n_k_{b_e_n_k}, b_e_n_k_{b_e_n_k},
...@@ -60,6 +62,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator ...@@ -60,6 +62,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
const Tensor<ck::index_t>& sorted_token_ids_; const Tensor<ck::index_t>& sorted_token_ids_;
const Tensor<ck::index_t>& expert_ids_; const Tensor<ck::index_t>& expert_ids_;
const Tensor<ck::index_t>& max_token_id_;
index_t sorted_tile_size_; index_t sorted_tile_size_;
const Tensor<ADataType>& a_m_k_; const Tensor<ADataType>& a_m_k_;
const Tensor<BDataType>& b_e_n_k_; const Tensor<BDataType>& b_e_n_k_;
...@@ -126,9 +129,11 @@ struct ReferenceMoeGemm2 : public device::BaseOperator ...@@ -126,9 +129,11 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
} }
}; };
const ck::index_t max_token_id = arg.max_token_id_(0);
make_ParallelTensorFunctor( make_ParallelTensorFunctor(
f_mk_kn_mn, arg.a_m_k_.mDesc.GetLengths()[0], arg.c_t_n_.mDesc.GetLengths()[1])( f_mk_kn_mn, max_token_id, arg.c_t_n_.mDesc.GetLengths()[1])(
1); std::thread::hardware_concurrency());
return 0; return 0;
} }
...@@ -150,6 +155,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator ...@@ -150,6 +155,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
static auto MakeArgument(const Tensor<ck::index_t>& sorted_token_ids, static auto MakeArgument(const Tensor<ck::index_t>& sorted_token_ids,
const Tensor<ck::index_t>& expert_ids, const Tensor<ck::index_t>& expert_ids,
const Tensor<ck::index_t>& max_token_id,
const index_t sorted_tile_size, const index_t sorted_tile_size,
const Tensor<ADataType>& a_m_k, const Tensor<ADataType>& a_m_k,
const Tensor<BDataType>& b_e_n_k, const Tensor<BDataType>& b_e_n_k,
...@@ -161,7 +167,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator ...@@ -161,7 +167,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) CElementwiseOperation c_element_op)
{ {
return Argument{sorted_token_ids, expert_ids, sorted_tile_size, a_m_k, b_e_n_k, d0, d1, d2, c_t_n, a_element_op, b_element_op, c_element_op}; return Argument{sorted_token_ids, expert_ids, max_token_id, sorted_tile_size, a_m_k, b_e_n_k, d0, d1, d2, c_t_n, a_element_op, b_element_op, c_element_op};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment