Commit 854cd8b4 authored by mtgu0705's avatar mtgu0705
Browse files

commit missing files

parent 182e7480
...@@ -268,8 +268,8 @@ int main(int argc, char* argv[]) ...@@ -268,8 +268,8 @@ int main(int argc, char* argv[])
expert_ids.savetxt("expert_ids.txt", "int"); expert_ids.savetxt("expert_ids.txt", "int");
sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
Tensor<A0DataType> a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk*K, K, 1})); Tensor<A0DataType> a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk*K, K, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N*K, 1, N})); Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N*K, 1, K}));
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N*K, 1, N})); Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N*K, 1, K}));
Tensor<D0DataType> d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0})); Tensor<D0DataType> d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}));
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]})); Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
......
...@@ -36,7 +36,7 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; ...@@ -36,7 +36,7 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = F8; using A0DataType = F8;
using B0DataType = I4; using B0DataType = I4;
using EDataType = F32; using EDataType = F16;
using AccDataType = F32; using AccDataType = F32;
using CShuffleDataType = F32; using CShuffleDataType = F32;
using D0DataType = F32; using D0DataType = F32;
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v1.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp"
namespace ck { namespace ck {
...@@ -35,7 +36,7 @@ constexpr auto BlockGemmBPreshufflePipeline_Selector() ...@@ -35,7 +36,7 @@ constexpr auto BlockGemmBPreshufflePipeline_Selector()
{ {
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{ {
if(std::is_same<ADataType, BDataType>::value) if constexpr(std::is_same<ADataType, BDataType>::value)
{ {
return BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlkGemmPipeSche, return BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlkGemmPipeSche,
BlockSize, BlockSize,
...@@ -109,9 +110,28 @@ constexpr auto BlockGemmBPreshufflePipeline_Selector() ...@@ -109,9 +110,28 @@ constexpr auto BlockGemmBPreshufflePipeline_Selector()
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{ {
static_assert(MRepeat >= 4, "MRepeat should at least be 4 in BlockGemmPipelineVersion::v3"); static_assert(MRepeat >= 4, "MRepeat should at least be 4 in BlockGemmPipelineVersion::v3");
if(std::is_same<ADataType, BDataType>::value) if constexpr(std::is_same<ADataType, BDataType>::value)
{ {
std::cerr << "BlockGemmPipeline v3 configuration is not available" << std::endl; return BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlkGemmPipeSche,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
} }
else else
{ {
......
...@@ -144,7 +144,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I ...@@ -144,7 +144,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
static constexpr index_t PrefetchStages = 2; static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1; static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1; static constexpr index_t GlobalBufferNum = 2;
template <typename TileDesc_M0_M1_M2_K> template <typename TileDesc_M0_M1_M2_K>
__host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&) __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
...@@ -249,7 +249,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I ...@@ -249,7 +249,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
// Global prefetch A1 B1 // Global prefetch A1 B1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
b_blockwise_copy.Run(b_grid_desc, b_blockwise_copy.Run(b_grid_desc,
b_grid_buf, b_grid_buf,
b_block_desc_n0_n1_k0_k1, b_block_desc_n0_n1_k0_k1,
...@@ -258,12 +258,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I ...@@ -258,12 +258,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
__builtin_amdgcn_sched_barrier(0);
// // Local prefill A1 // // Local prefill A1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
// // Global prefetch A2 // // Global prefetch A2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
// Local prefetch A1 // Local prefetch A1
...@@ -296,13 +297,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I ...@@ -296,13 +297,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
b_block_desc_n0_n1_k0_k1, b_block_desc_n0_n1_k0_k1,
b_block_origin_idx, b_block_origin_idx,
b_thread_bufs(local_read_buf)); b_thread_bufs(local_read_buf));
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
block_sync_lds(); block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf);
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
// printf("bid %d tid %d %f %f\n", blockIdx.x, threadIdx.x, // printf("bid %d tid %d %f %f\n", blockIdx.x, threadIdx.x,
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
...@@ -380,11 +380,12 @@ struct DeviceMoeGemm ...@@ -380,11 +380,12 @@ struct DeviceMoeGemm
// } // }
// else // else
{ {
const auto kernel = kernel_moe_gemm_gather<GridwiseGemm, const auto kernel = kernel_moe_gemm<GridwiseGemm,
true, true,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
minimum_occupancy, minimum_occupancy,
TailNumber::Odd>; IsInputGemm,
TailNumber::Odd>;
RunKernel(kernel); RunKernel(kernel);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment