Commit 12301455 authored by coderfeli's avatar coderfeli
Browse files

gemm2 result ok

parent 7ba5bff4
...@@ -33,8 +33,8 @@ using F32 = float; ...@@ -33,8 +33,8 @@ using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = F16; using A0DataType = F8;
using B0DataType = F16; using B0DataType = F8;
using AccDataType = F32; using AccDataType = F32;
using CShuffleDataType = F32; using CShuffleDataType = F32;
using D0DataType = F32; using D0DataType = F32;
...@@ -172,10 +172,10 @@ int main(int argc, char* argv[]) ...@@ -172,10 +172,10 @@ int main(int argc, char* argv[])
// experts = 8 // experts = 8
// per expert: // per expert:
// GEMM shape // GEMM shape
ck::index_t N = 128; ck::index_t N = 6144;
ck::index_t K = 1024; ck::index_t K = 8192;
ck::index_t experts = 8; ck::index_t experts = 8;
ck::index_t sorted_tile_num = 2; ck::index_t sorted_tile_num = 8;
ck::index_t sorted_tile_size = MPerBlock; ck::index_t sorted_tile_size = MPerBlock;
ck::index_t SORTED_SIZE = sorted_tile_num * sorted_tile_size; ck::index_t SORTED_SIZE = sorted_tile_num * sorted_tile_size;
ck::index_t tokens = 32; ck::index_t tokens = 32;
...@@ -341,6 +341,7 @@ int main(int argc, char* argv[]) ...@@ -341,6 +341,7 @@ int main(int argc, char* argv[])
"not support this GEMM problem"); "not support this GEMM problem");
} }
if (time_kernel) { if (time_kernel) {
// not result correct here because output buf not setzero
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * SORTED_SIZE * N * K; std::size_t flop = std::size_t(2) * SORTED_SIZE * N * K;
...@@ -357,9 +358,10 @@ int main(int argc, char* argv[]) ...@@ -357,9 +358,10 @@ int main(int argc, char* argv[])
if(do_verification) if(do_verification)
{ {
//gemm2 use atomic, so need to reinit outputs
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
invoker.Run(argument, StreamConfig{nullptr, false, 0 ,0,1}); invoker.Run(argument, StreamConfig{nullptr, false, 0 ,0,1});
// e_device_buf.FromDevice(e_t_n_device_result.mData.data());
Tensor<CShuffleDataType> c_t_n({tokens, N}); Tensor<CShuffleDataType> c_t_n({tokens, N});
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm2<A0DataType, using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm2<A0DataType,
......
...@@ -279,7 +279,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle ...@@ -279,7 +279,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle< const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
GridwiseGemm, GridwiseGemm,
true, true,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy, minimum_occupancy,
TailNumber::Odd>; TailNumber::Odd>;
Run(kernel); Run(kernel);
...@@ -289,7 +289,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle ...@@ -289,7 +289,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle< const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
GridwiseGemm, GridwiseGemm,
true, true,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy, minimum_occupancy,
TailNumber::Even>; TailNumber::Even>;
Run(kernel); Run(kernel);
......
...@@ -1170,7 +1170,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -1170,7 +1170,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
true, true,
BlockwiseGemmPipe::GlobalBufferNum>( BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
make_multi_index(0, 0, 0), make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op, a_element_op,
a_block_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0), make_multi_index(0, 0, 0),
...@@ -1397,7 +1397,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -1397,7 +1397,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets; //= p_sorted_token_ids[token_pos]; StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets; //= p_sorted_token_ids[token_pos];
static_for<0, EMRepeats, 1>{}([&](auto m0) { static_for<0, EMRepeats, 1>{}([&](auto m0) {
scatter_offsets(m0) = (p_sorted_token_ids[token_pos + m0] & 0xffffff) * problem.N; scatter_offsets(m0) = (p_sorted_token_ids[token_pos + m0] & 0xffffff) * problem.N;
printf("init off tid %d m %d off %d\n", threadIdx.x, m0(), gather_offsets(m0)); // printf("init off bid %d tid %d m %d off %d\n", blockIdx.y, threadIdx.x, m0(), scatter_offsets(m0));
}); });
// printf("tid %d pos %d offset %d size %d\n", threadIdx.x, token_pos, scatter_offsets(I0), c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); // printf("tid %d pos %d offset %d size %d\n", threadIdx.x, token_pos, scatter_offsets(I0), c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment